diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index c42b9d67c4..c5adb7e110 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -163,7 +163,12 @@ def _deep_copy_full_config(subconfig: Any) -> Any: return copy.deepcopy(subconfig) -def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: +def instantiate( + config: Any, + *args: Any, + _skip_instantiate_full_deepcopy_: bool = False, + **kwargs: Any, +) -> Any: """ :param config: An config object describing what to call and what params to use. In addition to the parameters, the config must contain: @@ -186,6 +191,10 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: are converted to dicts / lists too. _partial_: If True, return functools.partial wrapped method or object False by default. Configure per target. + :param _skip_instantiate_full_deepcopy_: If True, deep copy just the input config instead + of full config before resolving omegaconf interpolations, which may + potentially modify the config's parent/sibling configs in place. + False by default. :param args: Optional positional parameters pass-through :param kwargs: Optional named parameters to override parameters in the config object. Parameters not present @@ -225,8 +234,12 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: if OmegaConf.is_dict(config): # Finalize config (convert targets to strings, merge with kwargs) - # Create full copy to avoid mutating original - config_copy = _deep_copy_full_config(config) + # Create copy to avoid mutating original + if _skip_instantiate_full_deepcopy_: + config_copy = copy.deepcopy(config) + config_copy._set_parent(config._get_parent()) + else: + config_copy = _deep_copy_full_config(config) config_copy._set_flag( flags=["allow_objects", "struct", "readonly"], values=[True, False, False] ) @@ -246,8 +259,12 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: ) elif OmegaConf.is_list(config): # Finalize config (convert targets to strings, merge with kwargs) - # Create full copy to avoid mutating original - config_copy = _deep_copy_full_config(config) + # Create copy to avoid mutating original + if _skip_instantiate_full_deepcopy_: + config_copy = copy.deepcopy(config) + config_copy._set_parent(config._get_parent()) + else: + config_copy = _deep_copy_full_config(config) config_copy._set_flag( flags=["allow_objects", "struct", "readonly"], values=[True, False, False] ) diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index c022ccb72b..10a67b1c14 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -485,6 +485,7 @@ def test_none_cases( assert str(cfg) == original_config_str +@mark.parametrize("skip_deepcopy", [True, False]) @mark.parametrize("convert_to_list", [True, False]) @mark.parametrize( "input_conf, passthrough, expected", @@ -578,6 +579,7 @@ def test_interpolation_accessing_parent( passthrough: Dict[str, Any], expected: Any, convert_to_list: bool, + skip_deepcopy: bool, ) -> Any: if convert_to_list: input_conf = copy.deepcopy(input_conf) @@ -586,15 +588,24 @@ def test_interpolation_accessing_parent( input_conf = OmegaConf.create(input_conf) original_config_str = str(input_conf) if convert_to_list: - obj = instantiate_func(input_conf.node[0], **passthrough) + obj = instantiate_func( + input_conf.node[0], + _skip_instantiate_full_deepcopy_=skip_deepcopy, + **passthrough, + ) else: - obj = instantiate_func(input_conf.node, **passthrough) + obj = instantiate_func( + input_conf.node, + _skip_instantiate_full_deepcopy_=skip_deepcopy, + **passthrough, + ) if isinstance(expected, partial): assert partial_equal(obj, expected) else: assert obj == expected assert input_conf == cfg_copy - assert str(input_conf) == original_config_str + if not skip_deepcopy: + assert str(input_conf) == original_config_str @mark.parametrize(