diff --git a/fiddle/_src/absl_flags/sample_test_binary.py b/fiddle/_src/absl_flags/sample_test_binary.py index dfda4f53..23d4f252 100644 --- a/fiddle/_src/absl_flags/sample_test_binary.py +++ b/fiddle/_src/absl_flags/sample_test_binary.py @@ -35,6 +35,10 @@ def base_experiment() -> fdl.Config: return fake_encoder_decoder.fixture.as_buildable() +def base_experiment_with_bias() -> fdl.Config: + return fake_encoder_decoder.fixture_with_bias.as_buildable() + + def set_dtypes(config, dtype: str): def traverse(value, state): if state.current_path and state.current_path[-1] == daglish.Attr("dtype"): diff --git a/fiddle/_src/experimental/auto_config.py b/fiddle/_src/experimental/auto_config.py index a998547d..de592cf4 100644 --- a/fiddle/_src/experimental/auto_config.py +++ b/fiddle/_src/experimental/auto_config.py @@ -552,6 +552,28 @@ def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs): return partial_cls(buildable_or_callable, *args, **kwargs) +def _override_values_recursively( + base: config.Buildable, overrides: config.Buildable +) -> None: + """Recursively replaces fields in base with values from overrides.""" + for field_name in dir(overrides): + field = getattr(overrides, field_name, config.NO_VALUE) + if field != config.NO_VALUE: + if isinstance(field, config.Buildable): + _override_values_recursively(getattr(base, field_name), field) + else: + setattr(base, field_name, field) + + +def override_values( + base: config.Buildable, overrides: config.Buildable +) -> config.Buildable: + """Returns a copy of base with any values present in overrides overridden.""" + base = base.__deepcopy__(memo={}) + _override_values_recursively(base, overrides) + return base + + def exempt(fn_or_cls: Callable[..., Any]) -> Callable[..., Any]: """Wrap a callable so that it's exempted from auto_config. @@ -599,6 +621,7 @@ def auto_config( experimental_exemption_policy: Optional[auto_config_policy.Policy] = None, experimental_config_types: ConfigTypes = ConfigTypes(), experimental_result_must_contain_buildable: bool = True, + base_config: Optional[AutoConfig] = None, ) -> Any: # TODO(b/272377821): More precise return type. """Rewrites the given function to make it generate a ``Config``. @@ -693,6 +716,9 @@ def build_model(): experimental_result_must_contain_buildable: If true, then raise an error if `fn.as_buildable` returns a result that does not contain any `Buildable` values -- e.g., if it returns an empty dict. + base_config: Ff given, would be used as a default values. This allows to + have common settings defined once while defing multiple slightly different + configurations. Returns: A wrapped version of ``fn``, but with an additional ``as_buildable`` @@ -857,13 +883,25 @@ def make_auto_config(fn): auto_config_fn.__defaults__ = fn.__defaults__ auto_config_fn.__kwdefaults__ = fn.__kwdefaults__ + if base_config is not None: + + @functools.wraps(auto_config_fn) + def auto_config_fn_with_base(*args, **kwargs): + return override_values( + base_config.as_buildable(*args, **kwargs), + auto_config_fn(*args, **kwargs), # pylint: disable=not-callable + ) + + else: + auto_config_fn_with_base = auto_config_fn + # Finally we wrap the rewritten function to perform additional error # checking and enforce that the output contains a `fdl.Buildable`. if experimental_result_must_contain_buildable: - @functools.wraps(auto_config_fn) + @functools.wraps(auto_config_fn_with_base) def as_buildable(*args, **kwargs): - output = auto_config_fn(*args, **kwargs) # pylint: disable=not-callable + output = auto_config_fn_with_base(*args, **kwargs) # pylint: disable=not-callable if not _contains_buildable(output): raise TypeError( f'The `auto_config` rewritten version of `{fn.__qualname__}` ' @@ -875,7 +913,7 @@ def as_buildable(*args, **kwargs): return output else: - as_buildable = auto_config_fn + as_buildable = auto_config_fn_with_base if method_type: fn = method_type(fn) diff --git a/fiddle/_src/testing/example/fake_encoder_decoder.py b/fiddle/_src/testing/example/fake_encoder_decoder.py index 015915bb..3076a229 100644 --- a/fiddle/_src/testing/example/fake_encoder_decoder.py +++ b/fiddle/_src/testing/example/fake_encoder_decoder.py @@ -99,3 +99,16 @@ def fixture(): bias_init), mlp=Mlp(dtype, False, ["num_heads", "head_dim", "embed"]), )) + + +@auto_config.auto_config(base_config=fixture) +def fixture_with_bias(): + # pylint: disable=no-value-for-parameter + # pytype: disable=missing-parameter + return FakeEncoderDecoder( + encoder=FakeEncoder( + mlp=Mlp(use_bias=True), + ), + ) + # pytype: enable=missing-parameter + # pylint: enable=no-value-for-parameter