From ec0ab07b81a2b12dab1a88d9ece34db6a8a64e79 Mon Sep 17 00:00:00 2001 From: Fiddle-Config Team Date: Thu, 27 Jul 2023 13:55:59 -0700 Subject: [PATCH] Utilize config_normalizer APIs to make a Pax-specific graphviz function. PiperOrigin-RevId: 551632323 --- .../auto_config/add_type_signatures.py | 123 ++++++++++++++++++ .../auto_config/add_type_signatures_test.py | 123 ++++++++++++++++++ fiddle/_src/codegen/auto_config/code_ir.py | 1 + .../_src/codegen/auto_config/code_ir_test.py | 1 + fiddle/_src/codegen/auto_config/ir_printer.py | 8 ++ fiddle/_src/codegen/auto_config/ir_to_cst.py | 15 +++ fiddle/_src/codegen/new_codegen.py | 13 ++ fiddle/_src/codegen/new_codegen_test.py | 2 +- 8 files changed, 285 insertions(+), 1 deletion(-) create mode 100644 fiddle/_src/codegen/auto_config/add_type_signatures.py create mode 100644 fiddle/_src/codegen/auto_config/add_type_signatures_test.py diff --git a/fiddle/_src/codegen/auto_config/add_type_signatures.py b/fiddle/_src/codegen/auto_config/add_type_signatures.py new file mode 100644 index 00000000..300262fb --- /dev/null +++ b/fiddle/_src/codegen/auto_config/add_type_signatures.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Copyright 2022 The Fiddle-Config Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adds type signatures to modules. + +For now, we only populate return types. +""" + +import inspect + +from fiddle._src import config as config_lib +from fiddle._src import signatures +from fiddle._src.codegen import import_manager as import_manager_lib +from fiddle._src.codegen.auto_config import code_ir +from fiddle._src.codegen.auto_config import import_manager_wrapper + + +_BUILTIN_TYPE_MAP = { + type(None): "None", + str: "str", + int: "int", + float: "float", + bool: "bool", +} + + +def _get_annotation_from_type(typ) -> code_ir.CodegenNode: + if typ in _BUILTIN_TYPE_MAP: + return code_ir.BuiltinReference(code_ir.Name(_BUILTIN_TYPE_MAP[typ])) + else: + # TODO(b/293352960): import typing.Any correctly. + # TODO(b/293509806): Handle more types, especially from function return + # signatures. + return code_ir.BuiltinReference(code_ir.Name("Any")) + + +def get_type_annotation( + value, import_manager: import_manager_lib.ImportManager +) -> code_ir.CodegenNode: + """Gets the type annotation for a given value.""" + if isinstance(value, config_lib.Buildable): + buildable_type = import_manager_wrapper.add(type(value), import_manager) + fn_or_cls = config_lib.get_callable(value) + if isinstance(fn_or_cls, type): + sub_type = import_manager_wrapper.add(fn_or_cls, import_manager) + else: + signature = signatures.get_signature(fn_or_cls) + if isinstance(signature.return_annotation, type) and ( + signature.return_annotation is not inspect.Signature.empty + ): + sub_type = _get_annotation_from_type(signature.return_annotation) + else: + return buildable_type + return code_ir.ParameterizedTypeExpression(buildable_type, [sub_type]) + elif isinstance(value, (list, tuple)): + base_expression = code_ir.BuiltinReference( + code_ir.Name("list" if isinstance(value, list) else "tuple") + ) + sub_value_annotations = [ + get_type_annotation(item, import_manager) for item in value + ] + if sub_value_annotations and all( + annotation == sub_value_annotations[0] + for annotation in sub_value_annotations + ): + return code_ir.ParameterizedTypeExpression( + base_expression, [sub_value_annotations[0]] + ) + else: + return base_expression + elif isinstance(value, dict): + base_expression = code_ir.BuiltinReference(code_ir.Name("dict")) + key_annotations = [ + get_type_annotation(item, import_manager) for item in value.keys() + ] + value_annotations = [ + get_type_annotation(item, import_manager) for item in value.values() + ] + if key_annotations and all( + annotation == key_annotations[0] for annotation in key_annotations + ): + key_annotation = key_annotations[0] + else: + # TODO(b/293352960): import typing.Any correctly. + key_annotation = code_ir.BuiltinReference(code_ir.Name("Any")) + if value_annotations and all( + annotation == value_annotations[0] for annotation in value_annotations + ): + value_annotation = value_annotations[0] + else: + value_annotation = code_ir.BuiltinReference(code_ir.Name("Any")) + return code_ir.ParameterizedTypeExpression( + base_expression, [key_annotation, value_annotation] + ) + else: + return _get_annotation_from_type(type(value)) + + +def add_return_types(task: code_ir.CodegenTask) -> None: + """Adds return type signatures. + + This is normally based on config types, so for `auto_config`, it would reflect + the as_buildable() path. Hence, we don't add it by default yet. + + Args: + task: Codegen task. + """ + for fn in task.top_level_call.all_fixture_functions(): + fn.return_type_annotation = get_type_annotation( + fn.output_value, task.import_manager + ) diff --git a/fiddle/_src/codegen/auto_config/add_type_signatures_test.py b/fiddle/_src/codegen/auto_config/add_type_signatures_test.py new file mode 100644 index 00000000..4ff74910 --- /dev/null +++ b/fiddle/_src/codegen/auto_config/add_type_signatures_test.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Copyright 2022 The Fiddle-Config Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for add_type_signatures.""" + +from typing import List + +from absl.testing import absltest +from absl.testing import parameterized +import fiddle as fdl +from fiddle._src.codegen import import_manager as import_manager_lib +from fiddle._src.codegen import namespace as namespace_lib +from fiddle._src.codegen.auto_config import add_type_signatures +from fiddle._src.codegen.auto_config import ir_printer +from fiddle._src.codegen.auto_config import test_fixtures +from fiddle._src.testing.example import fake_encoder_decoder + + +def foo(x): + return x + + +def bar(x: int) -> int: + return x + + +def baz() -> List[int]: + return [1] + + +def qux() -> list: # pylint: disable=g-bare-generic + return [1] + + +class AddTypeSignaturesTest(parameterized.TestCase): + + @parameterized.parameters( + { + "value": True, + "expected": "bool", + }, + { + "value": [1, 2, 3], + "expected": "list[int]", + }, + { + "value": [1, 2, "a"], + "expected": "list", + }, + { + "value": {"hi": 3, "bye": 4}, + "expected": "dict[str, int]", + }, + { + "value": {}, + "expected": "dict[Any, Any]", + }, + { + # Custom types are replaced with Any. + # (Rationale: Don't put custom objects in Fiddle configs.) + "value": namespace_lib.Namespace(set()), + "expected": "Any", + }, + { + "value": fdl.Config(foo, x=1), + "expected": "fdl.Config", + }, + { + "value": fdl.Config(bar, x=1), + "expected": "fdl.Config[int]", + }, + { + # TODO(b/293509806): Handle more types, especially from function + # return signatures. + "value": fdl.Config(baz), + "expected": "fdl.Config", + }, + { + # TODO(b/293509806): Handle more types, especially from function + # return signatures. + "value": fdl.Config(qux), + "expected": "fdl.Config[Any]", + }, + { + "value": fdl.Config(fake_encoder_decoder.FakeEncoderDecoder), + "expected": "fdl.Config[fake_encoder_decoder.FakeEncoderDecoder]", + }, + { + "value": fdl.Partial(foo, x=1), + "expected": "fdl.Partial", + }, + { + "value": fdl.Partial(bar, x=1), + "expected": "fdl.Partial[int]", + }, + ) + def test_get_type_annotation(self, value, expected): + import_manager = import_manager_lib.ImportManager(namespace_lib.Namespace()) + expression = add_type_signatures.get_type_annotation( + value=value, import_manager=import_manager + ) + formatted = ir_printer.format_expr(expression) + self.assertEqual(formatted, expected) + + @parameterized.named_parameters(*test_fixtures.parameters_for_testcases()) + def test_smoke_add_return_types(self, task): + add_type_signatures.add_return_types(task) + + +if __name__ == "__main__": + absltest.main() diff --git a/fiddle/_src/codegen/auto_config/code_ir.py b/fiddle/_src/codegen/auto_config/code_ir.py index 489f11b4..e7581b7f 100644 --- a/fiddle/_src/codegen/auto_config/code_ir.py +++ b/fiddle/_src/codegen/auto_config/code_ir.py @@ -225,6 +225,7 @@ class FixtureFunction(CodegenNode): parameters: List[Parameter] variables: List[VariableDeclaration] output_value: Any # Value that can involve VariableReference's + return_type_annotation: Optional[Any] = None def __hash__(self): return id(self) diff --git a/fiddle/_src/codegen/auto_config/code_ir_test.py b/fiddle/_src/codegen/auto_config/code_ir_test.py index 05a71664..f01b507f 100644 --- a/fiddle/_src/codegen/auto_config/code_ir_test.py +++ b/fiddle/_src/codegen/auto_config/code_ir_test.py @@ -68,6 +68,7 @@ def test_daglish_iteration(self): (".variables", []), (".output_value", fn.output_value), (".output_value.x", 2), + (".return_type_annotation", None), ], ) diff --git a/fiddle/_src/codegen/auto_config/ir_printer.py b/fiddle/_src/codegen/auto_config/ir_printer.py index a3ca425a..82d16acf 100644 --- a/fiddle/_src/codegen/auto_config/ir_printer.py +++ b/fiddle/_src/codegen/auto_config/ir_printer.py @@ -108,6 +108,14 @@ def traverse(value, state: daglish.State) -> str: f"(*[{positional_arg_expressions}]," f" **{arg_expressions})>" ) + elif isinstance(value, code_ir.ParameterizedTypeExpression): + base_expression = state.call( + value.base_expression, daglish.Attr("base_expression") + ) + param_expressions = state.call( + value.param_expressions, daglish.Attr("param_expressions") + ) + return f"{base_expression}{param_expressions}" elif isinstance(value, code_ir.Name): return value.value elif isinstance(value, type): diff --git a/fiddle/_src/codegen/auto_config/ir_to_cst.py b/fiddle/_src/codegen/auto_config/ir_to_cst.py index f4d076dc..2652984c 100644 --- a/fiddle/_src/codegen/auto_config/ir_to_cst.py +++ b/fiddle/_src/codegen/auto_config/ir_to_cst.py @@ -101,6 +101,14 @@ def _prepare_args_helper( elif isinstance(value, code_ir.AttributeExpression): base = state.call(value.base, daglish.Attr("base")) return cst.Attribute(value=base, attr=cst.Name(value.attribute)) + elif isinstance(value, code_ir.ParameterizedTypeExpression): + return cst.Subscript( + value=code_for_expr(value.base_expression), + slice=[ + cst.SubscriptElement(cst.Index(code_for_expr(param))) + for param in value.param_expressions + ], + ) elif isinstance(value, code_ir.SymbolOrFixtureCall): attr = daglish.Attr("arg_expressions") args = [] @@ -199,6 +207,12 @@ def code_for_fn( ), ] ) + if fn.return_type_annotation: + returns = cst.Annotation( + annotation=code_for_expr(fn.return_type_annotation) + ) + else: + returns = None if fn.parameters and len(fn.parameters) > 1: whitespace_before_params = cst.ParenthesizedWhitespace( cst.TrailingWhitespace(), @@ -211,6 +225,7 @@ def code_for_fn( cst.Name(fn.name.value), params, body, + returns=returns, decorators=[cst.Decorator(auto_config_expr)] if auto_config_expr else [], whitespace_before_params=whitespace_before_params, leading_lines=[cst.EmptyLine(), cst.EmptyLine()], diff --git a/fiddle/_src/codegen/new_codegen.py b/fiddle/_src/codegen/new_codegen.py index 101f18df..108ef4a4 100644 --- a/fiddle/_src/codegen/new_codegen.py +++ b/fiddle/_src/codegen/new_codegen.py @@ -23,6 +23,7 @@ import fiddle as fdl from fiddle._src.codegen import newcg_symbolic_references +from fiddle._src.codegen.auto_config import add_type_signatures from fiddle._src.codegen.auto_config import code_ir from fiddle._src.codegen.auto_config import experimental_top_level_api from fiddle._src.codegen.auto_config import make_symbolic_references as old_symbolic_references @@ -61,6 +62,13 @@ class MakeSymbolicReferences(experimental_top_level_api.MutationCodegenPass): ) +@dataclasses.dataclass(frozen=True) +class AddTypeSignatures(experimental_top_level_api.MutationCodegenPass): + """Adds return type signatures to fixtures.""" + + fn: Callable[..., Any] = add_type_signatures.add_return_types + + def _get_pass_idx( codegen_config: fdl.Config[experimental_top_level_api.Codegen], cls: Type[experimental_top_level_api.CodegenPass], @@ -100,6 +108,11 @@ def code_generator( # Replace MakeSymbolicReferences idx = _get_pass_idx(config, experimental_top_level_api.MakeSymbolicReferences) fdl.update_callable(config.passes[idx], MakeSymbolicReferences) + + # Insert type annotations before MakeSymbolicReferences. These type + # annotations currently make more sense for non-auto_config cases. + config.passes.insert(idx, fdl.Config(AddTypeSignatures)) + return config diff --git a/fiddle/_src/codegen/new_codegen_test.py b/fiddle/_src/codegen/new_codegen_test.py index 7858fcb0..4ce1ed4b 100644 --- a/fiddle/_src/codegen/new_codegen_test.py +++ b/fiddle/_src/codegen/new_codegen_test.py @@ -91,7 +91,7 @@ def test_code_output(self): from fiddle._src.testing.example import fake_encoder_decoder - def config_fixture(): + def config_fixture() -> fdl.Config[fake_encoder_decoder.FakeEncoder]: mlp = fdl.Config(fake_encoder_decoder.Mlp, dtype='float32', use_bias=False, sharding_axes=['embed', 'num_heads', 'head_dim']) return fdl.Config(fake_encoder_decoder.FakeEncoder, embedders={'tokens':