Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utilize config_normalizer APIs to make a Pax-specific graphviz function. #460

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions fiddle/_src/codegen/auto_config/add_type_signatures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# coding=utf-8
# Copyright 2022 The Fiddle-Config Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Adds type signatures to modules.

For now, we only populate return types.
"""

import inspect

from fiddle._src import config as config_lib
from fiddle._src import signatures
from fiddle._src.codegen import import_manager as import_manager_lib
from fiddle._src.codegen.auto_config import code_ir
from fiddle._src.codegen.auto_config import import_manager_wrapper


_BUILTIN_TYPE_MAP = {
type(None): "None",
str: "str",
int: "int",
float: "float",
bool: "bool",
}


def _get_annotation_from_type(typ) -> code_ir.CodegenNode:
if typ in _BUILTIN_TYPE_MAP:
return code_ir.BuiltinReference(code_ir.Name(_BUILTIN_TYPE_MAP[typ]))
else:
# TODO(b/293352960): import typing.Any correctly.
# TODO(b/293509806): Handle more types, especially from function return
# signatures.
return code_ir.BuiltinReference(code_ir.Name("Any"))


def get_type_annotation(
value, import_manager: import_manager_lib.ImportManager
) -> code_ir.CodegenNode:
"""Gets the type annotation for a given value."""
if isinstance(value, config_lib.Buildable):
buildable_type = import_manager_wrapper.add(type(value), import_manager)
fn_or_cls = config_lib.get_callable(value)
if isinstance(fn_or_cls, type):
sub_type = import_manager_wrapper.add(fn_or_cls, import_manager)
else:
signature = signatures.get_signature(fn_or_cls)
if isinstance(signature.return_annotation, type) and (
signature.return_annotation is not inspect.Signature.empty
):
sub_type = _get_annotation_from_type(signature.return_annotation)
else:
return buildable_type
return code_ir.ParameterizedTypeExpression(buildable_type, [sub_type])
elif isinstance(value, (list, tuple)):
base_expression = code_ir.BuiltinReference(
code_ir.Name("list" if isinstance(value, list) else "tuple")
)
sub_value_annotations = [
get_type_annotation(item, import_manager) for item in value
]
if sub_value_annotations and all(
annotation == sub_value_annotations[0]
for annotation in sub_value_annotations
):
return code_ir.ParameterizedTypeExpression(
base_expression, [sub_value_annotations[0]]
)
else:
return base_expression
elif isinstance(value, dict):
base_expression = code_ir.BuiltinReference(code_ir.Name("dict"))
key_annotations = [
get_type_annotation(item, import_manager) for item in value.keys()
]
value_annotations = [
get_type_annotation(item, import_manager) for item in value.values()
]
if key_annotations and all(
annotation == key_annotations[0] for annotation in key_annotations
):
key_annotation = key_annotations[0]
else:
# TODO(b/293352960): import typing.Any correctly.
key_annotation = code_ir.BuiltinReference(code_ir.Name("Any"))
if value_annotations and all(
annotation == value_annotations[0] for annotation in value_annotations
):
value_annotation = value_annotations[0]
else:
value_annotation = code_ir.BuiltinReference(code_ir.Name("Any"))
return code_ir.ParameterizedTypeExpression(
base_expression, [key_annotation, value_annotation]
)
else:
return _get_annotation_from_type(type(value))


def add_return_types(task: code_ir.CodegenTask) -> None:
"""Adds return type signatures.

This is normally based on config types, so for `auto_config`, it would reflect
the as_buildable() path. Hence, we don't add it by default yet.

Args:
task: Codegen task.
"""
for fn in task.top_level_call.all_fixture_functions():
fn.return_type_annotation = get_type_annotation(
fn.output_value, task.import_manager
)
123 changes: 123 additions & 0 deletions fiddle/_src/codegen/auto_config/add_type_signatures_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# coding=utf-8
# Copyright 2022 The Fiddle-Config Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for add_type_signatures."""

from typing import List

from absl.testing import absltest
from absl.testing import parameterized
import fiddle as fdl
from fiddle._src.codegen import import_manager as import_manager_lib
from fiddle._src.codegen import namespace as namespace_lib
from fiddle._src.codegen.auto_config import add_type_signatures
from fiddle._src.codegen.auto_config import ir_printer
from fiddle._src.codegen.auto_config import test_fixtures
from fiddle._src.testing.example import fake_encoder_decoder


def foo(x):
return x


def bar(x: int) -> int:
return x


def baz() -> List[int]:
return [1]


def qux() -> list: # pylint: disable=g-bare-generic
return [1]


class AddTypeSignaturesTest(parameterized.TestCase):

@parameterized.parameters(
{
"value": True,
"expected": "bool",
},
{
"value": [1, 2, 3],
"expected": "list[int]",
},
{
"value": [1, 2, "a"],
"expected": "list",
},
{
"value": {"hi": 3, "bye": 4},
"expected": "dict[str, int]",
},
{
"value": {},
"expected": "dict[Any, Any]",
},
{
# Custom types are replaced with Any.
# (Rationale: Don't put custom objects in Fiddle configs.)
"value": namespace_lib.Namespace(set()),
"expected": "Any",
},
{
"value": fdl.Config(foo, x=1),
"expected": "fdl.Config",
},
{
"value": fdl.Config(bar, x=1),
"expected": "fdl.Config[int]",
},
{
# TODO(b/293509806): Handle more types, especially from function
# return signatures.
"value": fdl.Config(baz),
"expected": "fdl.Config",
},
{
# TODO(b/293509806): Handle more types, especially from function
# return signatures.
"value": fdl.Config(qux),
"expected": "fdl.Config[Any]",
},
{
"value": fdl.Config(fake_encoder_decoder.FakeEncoderDecoder),
"expected": "fdl.Config[fake_encoder_decoder.FakeEncoderDecoder]",
},
{
"value": fdl.Partial(foo, x=1),
"expected": "fdl.Partial",
},
{
"value": fdl.Partial(bar, x=1),
"expected": "fdl.Partial[int]",
},
)
def test_get_type_annotation(self, value, expected):
import_manager = import_manager_lib.ImportManager(namespace_lib.Namespace())
expression = add_type_signatures.get_type_annotation(
value=value, import_manager=import_manager
)
formatted = ir_printer.format_expr(expression)
self.assertEqual(formatted, expected)

@parameterized.named_parameters(*test_fixtures.parameters_for_testcases())
def test_smoke_add_return_types(self, task):
add_type_signatures.add_return_types(task)


if __name__ == "__main__":
absltest.main()
1 change: 1 addition & 0 deletions fiddle/_src/codegen/auto_config/code_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class FixtureFunction(CodegenNode):
parameters: List[Parameter]
variables: List[VariableDeclaration]
output_value: Any # Value that can involve VariableReference's
return_type_annotation: Optional[Any] = None

def __hash__(self):
return id(self)
Expand Down
1 change: 1 addition & 0 deletions fiddle/_src/codegen/auto_config/code_ir_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_daglish_iteration(self):
(".variables", []),
(".output_value", fn.output_value),
(".output_value.x", 2),
(".return_type_annotation", None),
],
)

Expand Down
8 changes: 8 additions & 0 deletions fiddle/_src/codegen/auto_config/ir_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def traverse(value, state: daglish.State) -> str:
f"(*[{positional_arg_expressions}],"
f" **{arg_expressions})>"
)
elif isinstance(value, code_ir.ParameterizedTypeExpression):
base_expression = state.call(
value.base_expression, daglish.Attr("base_expression")
)
param_expressions = state.call(
value.param_expressions, daglish.Attr("param_expressions")
)
return f"{base_expression}{param_expressions}"
elif isinstance(value, code_ir.Name):
return value.value
elif isinstance(value, type):
Expand Down
15 changes: 15 additions & 0 deletions fiddle/_src/codegen/auto_config/ir_to_cst.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ def _prepare_args_helper(
elif isinstance(value, code_ir.AttributeExpression):
base = state.call(value.base, daglish.Attr("base"))
return cst.Attribute(value=base, attr=cst.Name(value.attribute))
elif isinstance(value, code_ir.ParameterizedTypeExpression):
return cst.Subscript(
value=code_for_expr(value.base_expression),
slice=[
cst.SubscriptElement(cst.Index(code_for_expr(param)))
for param in value.param_expressions
],
)
elif isinstance(value, code_ir.SymbolOrFixtureCall):
attr = daglish.Attr("arg_expressions")
args = []
Expand Down Expand Up @@ -199,6 +207,12 @@ def code_for_fn(
),
]
)
if fn.return_type_annotation:
returns = cst.Annotation(
annotation=code_for_expr(fn.return_type_annotation)
)
else:
returns = None
if fn.parameters and len(fn.parameters) > 1:
whitespace_before_params = cst.ParenthesizedWhitespace(
cst.TrailingWhitespace(),
Expand All @@ -211,6 +225,7 @@ def code_for_fn(
cst.Name(fn.name.value),
params,
body,
returns=returns,
decorators=[cst.Decorator(auto_config_expr)] if auto_config_expr else [],
whitespace_before_params=whitespace_before_params,
leading_lines=[cst.EmptyLine(), cst.EmptyLine()],
Expand Down
13 changes: 13 additions & 0 deletions fiddle/_src/codegen/new_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import fiddle as fdl
from fiddle._src.codegen import newcg_symbolic_references
from fiddle._src.codegen.auto_config import add_type_signatures
from fiddle._src.codegen.auto_config import code_ir
from fiddle._src.codegen.auto_config import experimental_top_level_api
from fiddle._src.codegen.auto_config import make_symbolic_references as old_symbolic_references
Expand Down Expand Up @@ -61,6 +62,13 @@ class MakeSymbolicReferences(experimental_top_level_api.MutationCodegenPass):
)


@dataclasses.dataclass(frozen=True)
class AddTypeSignatures(experimental_top_level_api.MutationCodegenPass):
"""Adds return type signatures to fixtures."""

fn: Callable[..., Any] = add_type_signatures.add_return_types


def _get_pass_idx(
codegen_config: fdl.Config[experimental_top_level_api.Codegen],
cls: Type[experimental_top_level_api.CodegenPass],
Expand Down Expand Up @@ -100,6 +108,11 @@ def code_generator(
# Replace MakeSymbolicReferences
idx = _get_pass_idx(config, experimental_top_level_api.MakeSymbolicReferences)
fdl.update_callable(config.passes[idx], MakeSymbolicReferences)

# Insert type annotations before MakeSymbolicReferences. These type
# annotations currently make more sense for non-auto_config cases.
config.passes.insert(idx, fdl.Config(AddTypeSignatures))

return config


Expand Down
2 changes: 1 addition & 1 deletion fiddle/_src/codegen/new_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_code_output(self):
from fiddle._src.testing.example import fake_encoder_decoder


def config_fixture():
def config_fixture() -> fdl.Config[fake_encoder_decoder.FakeEncoder]:
mlp = fdl.Config(fake_encoder_decoder.Mlp, dtype='float32',
use_bias=False, sharding_axes=['embed', 'num_heads', 'head_dim'])
return fdl.Config(fake_encoder_decoder.FakeEncoder, embedders={'tokens':
Expand Down