Skip to content

Commit

Permalink
Not a public change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555220375
  • Loading branch information
Fiddle-Config Team authored and copybara-github committed Aug 17, 2023
1 parent b5abdf1 commit 6864e25
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 3 deletions.
4 changes: 4 additions & 0 deletions fiddle/_src/absl_flags/sample_test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def base_experiment() -> fdl.Config:
return fake_encoder_decoder.fixture.as_buildable()


def base_experiment_with_bias() -> fdl.Config:
return fake_encoder_decoder.fixture_with_bias.as_buildable()


def set_dtypes(config, dtype: str):
def traverse(value, state):
if state.current_path and state.current_path[-1] == daglish.Attr("dtype"):
Expand Down
44 changes: 41 additions & 3 deletions fiddle/_src/experimental/auto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,28 @@ def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs):
return partial_cls(buildable_or_callable, *args, **kwargs)


def _override_values_recursively(
base: config.Buildable, overrides: config.Buildable
) -> None:
"""Recursively replaces fields in base with values from overrides."""
for field_name in dir(overrides):
field = getattr(overrides, field_name, config.NO_VALUE)
if field != config.NO_VALUE:
if isinstance(field, config.Buildable):
_override_values_recursively(getattr(base, field_name), field)
else:
setattr(base, field_name, field)


def override_values(
base: config.Buildable, overrides: config.Buildable
) -> config.Buildable:
"""Returns a copy of base with any values present in overrides overridden."""
base = base.__deepcopy__(memo={})
_override_values_recursively(base, overrides)
return base


def exempt(fn_or_cls: Callable[..., Any]) -> Callable[..., Any]:
"""Wrap a callable so that it's exempted from auto_config.
Expand Down Expand Up @@ -599,6 +621,7 @@ def auto_config(
experimental_exemption_policy: Optional[auto_config_policy.Policy] = None,
experimental_config_types: ConfigTypes = ConfigTypes(),
experimental_result_must_contain_buildable: bool = True,
base_config: Optional[AutoConfig] = None,
) -> Any: # TODO(b/272377821): More precise return type.
"""Rewrites the given function to make it generate a ``Config``.
Expand Down Expand Up @@ -693,6 +716,9 @@ def build_model():
experimental_result_must_contain_buildable: If true, then raise an error if
`fn.as_buildable` returns a result that does not contain any `Buildable`
values -- e.g., if it returns an empty dict.
base_config: Ff given, would be used as a default values. This allows to
have common settings defined once while defing multiple slightly different
configurations.
Returns:
A wrapped version of ``fn``, but with an additional ``as_buildable``
Expand Down Expand Up @@ -857,13 +883,25 @@ def make_auto_config(fn):
auto_config_fn.__defaults__ = fn.__defaults__
auto_config_fn.__kwdefaults__ = fn.__kwdefaults__

if base_config is not None:

@functools.wraps(auto_config_fn)
def auto_config_fn_with_base(*args, **kwargs):
return override_values(
base_config.as_buildable(*args, **kwargs),
auto_config_fn(*args, **kwargs), # pylint: disable=not-callable
)

else:
auto_config_fn_with_base = auto_config_fn

# Finally we wrap the rewritten function to perform additional error
# checking and enforce that the output contains a `fdl.Buildable`.
if experimental_result_must_contain_buildable:

@functools.wraps(auto_config_fn)
@functools.wraps(auto_config_fn_with_base)
def as_buildable(*args, **kwargs):
output = auto_config_fn(*args, **kwargs) # pylint: disable=not-callable
output = auto_config_fn_with_base(*args, **kwargs) # pylint: disable=not-callable
if not _contains_buildable(output):
raise TypeError(
f'The `auto_config` rewritten version of `{fn.__qualname__}` '
Expand All @@ -875,7 +913,7 @@ def as_buildable(*args, **kwargs):
return output

else:
as_buildable = auto_config_fn
as_buildable = auto_config_fn_with_base

if method_type:
fn = method_type(fn)
Expand Down
13 changes: 13 additions & 0 deletions fiddle/_src/testing/example/fake_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,16 @@ def fixture():
bias_init),
mlp=Mlp(dtype, False, ["num_heads", "head_dim", "embed"]),
))


@auto_config.auto_config(base_config=fixture)
def fixture_with_bias():
# pylint: disable=no-value-for-parameter
# pytype: disable=missing-parameter
return FakeEncoderDecoder(
encoder=FakeEncoder(
mlp=Mlp(use_bias=True),
),
)
# pytype: enable=missing-parameter
# pylint: enable=no-value-for-parameter

0 comments on commit 6864e25

Please sign in to comment.