From 97ba7f4b26773598affb4dd8ac119e9e1d1444e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:49:26 +0000 Subject: [PATCH] fix: hugr-py not adding extension-reqs on custom ops (#1759) Fixes #1758 Without these changes, the validation tests fail if we remove the workaround in `custom.rs` --- hugr-core/src/ops/custom.rs | 9 +------- hugr-py/src/hugr/_serialization/extension.py | 12 +++++++---- hugr-py/src/hugr/ext.py | 14 ++++++++++++- hugr-py/src/hugr/ops.py | 22 ++++++++++++++++---- hugr-py/src/hugr/std/int.py | 2 +- hugr-py/src/hugr/tys.py | 17 +++++++++++++++ hugr-py/tests/test_custom.py | 2 +- 7 files changed, 59 insertions(+), 19 deletions(-) diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index eed0a761f..91b93e332 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -257,14 +257,7 @@ impl DataflowOpTrait for OpaqueOp { } fn signature(&self) -> Cow<'_, Signature> { - // TODO: Return a borrowed cow once - // https://github.com/CQCL/hugr/issues/1758 - // gets fixed - Cow::Owned( - self.signature - .clone() - .with_extension_delta(self.extension.clone()), - ) + Cow::Borrowed(&self.signature) } } diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 6420bffff..3017bd6f8 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -101,15 +101,19 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): lower_funcs: list[FixedHugr] = pd.Field(default_factory=list) def deserialize(self, extension: ext.Extension) -> ext.OpDef: + signature = ext.OpDefSig( + self.signature.deserialize().with_extension_reqs([extension.name]) + if self.signature + else None, + self.binary, + ) + return extension.add_op_def( ext.OpDef( name=self.name, description=self.description, misc=self.misc or {}, - signature=ext.OpDefSig( - self.signature.deserialize() if self.signature else None, - self.binary, - ), + signature=signature, lower_funcs=[f.deserialize() for f in self.lower_funcs], ) ) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 533e55cd7..165ba89fe 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -160,7 +160,13 @@ def _to_serial(self) -> ext_s.FixedHugr: @dataclass class OpDefSig: - """Type signature of an :class:`OpDef`.""" + """Type signature of an :class:`OpDef`. + + Args: + poly_func: The polymorphic function type of the operation. + binary: If no static type scheme known, flag indicates a computation of the + signature + """ #: The polymorphic function type of the operation (type scheme). poly_func: tys.PolyFuncType | None @@ -311,6 +317,12 @@ def add_op_def(self, op_def: OpDef) -> OpDef: Returns: The added operation definition, now associated with the extension. """ + if op_def.signature.poly_func is not None: + # Ensure the op def signature has the extension as a requirement + op_def.signature.poly_func = op_def.signature.poly_func.with_extension_reqs( + [self.name] + ) + op_def._extension = self self.operations[op_def.name] = op_def return self.operations[op_def.name] diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 8c5b845a8..4e9b2d961 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -453,7 +453,11 @@ def op_def(self) -> ext.OpDef: return std.PRELUDE.get_op("MakeTuple") def cached_signature(self) -> tys.FunctionType | None: - return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)]) + return tys.FunctionType( + input=self.types, + output=[tys.Tuple(*self.types)], + extension_reqs=["prelude"], + ) def type_args(self) -> list[tys.TypeArg]: return [tys.SequenceArg([t.type_arg() for t in self.types])] @@ -492,7 +496,11 @@ def op_def(self) -> ext.OpDef: return std.PRELUDE.get_op("UnpackTuple") def cached_signature(self) -> tys.FunctionType | None: - return tys.FunctionType(input=[tys.Tuple(*self.types)], output=self.types) + return tys.FunctionType( + input=[tys.Tuple(*self.types)], + output=self.types, + extension_reqs=["prelude"], + ) def type_args(self) -> list[tys.TypeArg]: return [tys.SequenceArg([t.type_arg() for t in self.types])] @@ -1266,10 +1274,16 @@ def op_def(self) -> ext.OpDef: return std.PRELUDE.get_op("Noop") def cached_signature(self) -> tys.FunctionType | None: - return tys.FunctionType.endo([self.type_]) + return tys.FunctionType.endo( + [self.type_], + extension_reqs=["prelude"], + ) def outer_signature(self) -> tys.FunctionType: - return tys.FunctionType.endo([self.type_]) + return tys.FunctionType.endo( + [self.type_], + extension_reqs=["prelude"], + ) def _set_in_types(self, types: tys.TypeRow) -> None: (t,) = types diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 3f437df92..35985305f 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -87,7 +87,7 @@ def type_args(self) -> list[tys.TypeArg]: def cached_signature(self) -> tys.FunctionType | None: row: list[tys.Type] = [int_t(self.width)] * 2 - return tys.FunctionType.endo(row) + return tys.FunctionType.endo(row, extension_reqs=[INT_OPS_EXTENSION.name]) @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index 7c13e50b9..29f0f1482 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -514,6 +514,14 @@ def resolve(self, registry: ext.ExtensionRegistry) -> FunctionType: extension_reqs=self.extension_reqs, ) + def with_extension_reqs(self, extension_reqs: ExtensionSet) -> FunctionType: + """Adds a list of extension requirements to the function type, and + returns the new signature. + """ + exts = set(self.extension_reqs) + exts = exts.union(extension_reqs) + return FunctionType(self.input, self.output, [*exts]) + def __str__(self) -> str: return f"{comma_sep_str(self.input)} -> {comma_sep_str(self.output)}" @@ -543,6 +551,15 @@ def resolve(self, registry: ext.ExtensionRegistry) -> PolyFuncType: body=self.body.resolve(registry), ) + def with_extension_reqs(self, extension_reqs: ExtensionSet) -> PolyFuncType: + """Adds a list of extension requirements to the function type, and + returns the new signature. + """ + return PolyFuncType( + params=self.params, + body=self.body.with_extension_reqs(extension_reqs), + ) + def __str__(self) -> str: return f"∀ {comma_sep_str(self.params)}. {self.body!s}" diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 48f57de7a..b6865b6ab 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -37,7 +37,7 @@ def type_args(self) -> list[tys.TypeArg]: return [tys.StringArg(self.tag)] def cached_signature(self) -> tys.FunctionType | None: - return tys.FunctionType.endo([]) + return tys.FunctionType.endo([], extension_reqs=[STRINGLY_EXT.name]) @classmethod def from_ext(cls, custom: ops.ExtOp) -> "StringlyOp":