Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Not a public change. #473

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fiddle/_src/absl_flags/sample_test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def base_experiment() -> fdl.Config:
return fake_encoder_decoder.fixture.as_buildable()


def base_experiment_with_bias() -> fdl.Config:
return fake_encoder_decoder.fixture_with_bias.as_buildable()


def set_dtypes(config, dtype: str):
def traverse(value, state):
if state.current_path and state.current_path[-1] == daglish.Attr("dtype"):
Expand Down
44 changes: 41 additions & 3 deletions fiddle/_src/experimental/auto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,28 @@ def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs):
return partial_cls(buildable_or_callable, *args, **kwargs)


def _override_values_recursively(
base: config.Buildable, overrides: config.Buildable
) -> None:
"""Recursively replaces fields in base with values from overrides."""
for field_name in dir(overrides):
field = getattr(overrides, field_name, config.NO_VALUE)
if field != config.NO_VALUE:
if isinstance(field, config.Buildable):
_override_values_recursively(getattr(base, field_name), field)
else:
setattr(base, field_name, field)


def override_values(
base: config.Buildable, overrides: config.Buildable
) -> config.Buildable:
"""Returns a copy of base with any values present in overrides overridden."""
base = base.__deepcopy__(memo={})
_override_values_recursively(base, overrides)
return base


def exempt(fn_or_cls: Callable[..., Any]) -> Callable[..., Any]:
"""Wrap a callable so that it's exempted from auto_config.

Expand Down Expand Up @@ -599,6 +621,7 @@ def auto_config(
experimental_exemption_policy: Optional[auto_config_policy.Policy] = None,
experimental_config_types: ConfigTypes = ConfigTypes(),
experimental_result_must_contain_buildable: bool = True,
base_config: Optional[AutoConfig] = None,
) -> Any: # TODO(b/272377821): More precise return type.
"""Rewrites the given function to make it generate a ``Config``.

Expand Down Expand Up @@ -693,6 +716,9 @@ def build_model():
experimental_result_must_contain_buildable: If true, then raise an error if
`fn.as_buildable` returns a result that does not contain any `Buildable`
values -- e.g., if it returns an empty dict.
base_config: Ff given, would be used as a default values. This allows to
have common settings defined once while defing multiple slightly different
configurations.

Returns:
A wrapped version of ``fn``, but with an additional ``as_buildable``
Expand Down Expand Up @@ -857,13 +883,25 @@ def make_auto_config(fn):
auto_config_fn.__defaults__ = fn.__defaults__
auto_config_fn.__kwdefaults__ = fn.__kwdefaults__

if base_config is not None:

@functools.wraps(auto_config_fn)
def auto_config_fn_with_base(*args, **kwargs):
return override_values(
base_config.as_buildable(*args, **kwargs),
auto_config_fn(*args, **kwargs), # pylint: disable=not-callable
)

else:
auto_config_fn_with_base = auto_config_fn

# Finally we wrap the rewritten function to perform additional error
# checking and enforce that the output contains a `fdl.Buildable`.
if experimental_result_must_contain_buildable:

@functools.wraps(auto_config_fn)
@functools.wraps(auto_config_fn_with_base)
def as_buildable(*args, **kwargs):
output = auto_config_fn(*args, **kwargs) # pylint: disable=not-callable
output = auto_config_fn_with_base(*args, **kwargs) # pylint: disable=not-callable
if not _contains_buildable(output):
raise TypeError(
f'The `auto_config` rewritten version of `{fn.__qualname__}` '
Expand All @@ -875,7 +913,7 @@ def as_buildable(*args, **kwargs):
return output

else:
as_buildable = auto_config_fn
as_buildable = auto_config_fn_with_base

if method_type:
fn = method_type(fn)
Expand Down
13 changes: 13 additions & 0 deletions fiddle/_src/testing/example/fake_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,16 @@ def fixture():
bias_init),
mlp=Mlp(dtype, False, ["num_heads", "head_dim", "embed"]),
))


@auto_config.auto_config(base_config=fixture)
def fixture_with_bias():
# pylint: disable=no-value-for-parameter
# pytype: disable=missing-parameter
return FakeEncoderDecoder(
encoder=FakeEncoder(
mlp=Mlp(use_bias=True),
),
)
# pytype: enable=missing-parameter
# pylint: enable=no-value-for-parameter