diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 05bb23e8ddd9..9ff3638eb0b8 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2678,12 +2678,10 @@ def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, f if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None: saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS] - for param in self.module.parameters(): + for name, param in self.module.named_parameters(): if param.requires_grad: continue - if param not in self.param_names: - raise ValueError(f"failed to find frozen {param} in named params") - name = self.param_names[param] + if hasattr(param, 'ds_id'): param.ds_tensor.data.copy_(saved_frozen_params[name].data) else: @@ -3414,12 +3412,9 @@ def _get_param_fragment_func(self, param): def _get_zero_frozen_param_attributes(self, attr_func): frozen_param_fragments = OrderedDict() - for param in self.module.parameters(): + for name, param in self.module.named_parameters(): if param.requires_grad: continue - if param not in self.param_names: - raise ValueError(f"failed to find frozen {param} in named params") - name = self.param_names[param] frozen_param_fragments[name] = attr_func(param) return frozen_param_fragments