allow loading of nested states when Strategy.load_checkpoint is used
Markus28 committed Aug 18, 2024
1 parent 025c30e commit a618341
Showing 2 changed files with 30 additions and 16 deletions.
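For context, a minimal sketch of the round trip this commit enables, assuming the public Fabric.save/Fabric.load API (the model, file name, and "run" key are illustrative): nested dicts inside the state are now traversed instead of being replaced or pickled wholesale.

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
model = fabric.setup(torch.nn.Linear(4, 4))

# The model and its metadata live one level down in a nested dict; before
# this commit only top-level entries were treated as stateful objects.
state = {"run": {"model": model, "step": 1200}}
fabric.save("checkpoint.ckpt", state)

state["run"]["step"] = 0  # pretend we restarted
fabric.load("checkpoint.ckpt", state)
assert state["run"]["model"] is model  # wrapper restored in place, not replaced
assert state["run"]["step"] == 1200    # nested metadata copied back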
17 changes: 12 additions & 5 deletions src/lightning/fabric/fabric.py
@@ -85,6 +85,17 @@ def _do_nothing(*_: Any) -> None:
     pass
 
 
+def _recursively_update_state(old_state: Dict[str, Any], new_unwrapped_state: Dict[str, Any]) -> None:
+    for k in list(new_unwrapped_state.keys()):
+        obj, _ = _unwrap_compiled(old_state[k])
+        if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
+            pass
+        elif isinstance(obj, dict):
+            _recursively_update_state(old_state[k], new_unwrapped_state[k])
+        else:
+            old_state[k] = new_unwrapped_state[k]
+
+
 class Fabric:
     r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.
@@ -775,11 +786,7 @@ def load(
         if state is not None:
             # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
             # (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
-            for k in list(unwrapped_state.keys()):
-                obj, _ = _unwrap_compiled(state[k])
-                if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
-                    continue
-                state[k] = unwrapped_state[k]
+            _recursively_update_state(state, unwrapped_state)
         return remainder
 
     def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], strict: bool = True) -> None:
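To illustrate the copy-back semantics, here is a self-contained toy version of the new helper; the wrapper class and _unwrap_compiled are simplified stand-ins for the real ones in lightning.fabric.wrappers:

from typing import Any, Dict


class _FabricModule:
    """Toy stand-in for lightning.fabric.wrappers._FabricModule."""


def _unwrap_compiled(obj: Any) -> tuple:
    # The real helper also strips a torch.compile wrapper; identity is enough here.
    return obj, None


def _recursively_update_state(old_state: Dict[str, Any], new_unwrapped_state: Dict[str, Any]) -> None:
    for k in list(new_unwrapped_state.keys()):
        obj, _ = _unwrap_compiled(old_state[k])
        if isinstance(obj, _FabricModule):
            pass  # wrappers were already restored in place; keep them
        elif isinstance(obj, dict):
            _recursively_update_state(old_state[k], new_unwrapped_state[k])
        else:
            old_state[k] = new_unwrapped_state[k]


model = _FabricModule()
state = {"model": model, "extra": {"model": model, "step": 0}}
loaded = {"model": "ckpt-model", "extra": {"model": "ckpt-model", "step": 7}}
_recursively_update_state(state, loaded)
assert state["model"] is model           # top-level wrapper kept (old behavior)
assert state["extra"]["model"] is model  # nested wrapper kept (new with this commit)
assert state["extra"]["step"] == 7       # nested metadata copied back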
29 changes: 18 additions & 11 deletions src/lightning/fabric/strategies/strategy.py
@@ -301,6 +301,21 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         # for optimizers that are not sharded, we return the state dict on all ranks
         return optimizer.state_dict()
 
+    def _recursively_load_state(self, state: Dict[str, Any], checkpoint: Dict[str, Any], strict: bool = True) -> None:
+        _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
+        for name, obj in state.copy().items():
+            if name not in checkpoint:
+                continue
+            if isinstance(obj, _Stateful):
+                if isinstance(obj, Module):
+                    self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=strict)
+                else:
+                    obj.load_state_dict(checkpoint.pop(name))
+            elif isinstance(obj, dict):
+                self._recursively_load_state(state=state[name], checkpoint=checkpoint.pop(name), strict=strict)
+            else:
+                state[name] = checkpoint.pop(name)
+
     def load_checkpoint(
         self,
         path: _PATH,
@@ -338,17 +353,7 @@ def load_checkpoint(
             state.load_state_dict(checkpoint)
             return {}
 
-        _validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
-        for name, obj in state.copy().items():
-            if name not in checkpoint:
-                continue
-            if isinstance(obj, _Stateful):
-                if isinstance(obj, Module):
-                    self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=strict)
-                else:
-                    obj.load_state_dict(checkpoint.pop(name))
-            else:
-                state[name] = checkpoint.pop(name)
+        self._recursively_load_state(state, checkpoint, strict=strict)
         return checkpoint
 
     def teardown(self) -> None:
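To make the new recursion concrete, a self-contained sketch of _recursively_load_state's effect on a nested state follows (toy stand-ins; the real method also validates keys per nesting level via _validate_keys_for_strict_loading and routes Modules through load_module_state_dict):

from typing import Any, Dict


class Counter:
    """Toy stand-in for a _Stateful object: anything with state_dict/load_state_dict."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.count = state["count"]


def recursively_load(state: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
    # Same control flow as _recursively_load_state above, minus the strict-key
    # validation and the nn.Module special case.
    for name, obj in state.copy().items():
        if name not in checkpoint:
            continue
        if hasattr(obj, "load_state_dict"):
            obj.load_state_dict(checkpoint.pop(name))
        elif isinstance(obj, dict):
            recursively_load(state[name], checkpoint.pop(name))
        else:
            state[name] = checkpoint.pop(name)


counter = Counter()
state = {"nested": {"counter": counter, "step": 0}}
checkpoint = {"nested": {"counter": {"count": 5}, "step": 9}, "extra": 1}
recursively_load(state, checkpoint)
assert counter.count == 5            # stateful object restored in place
assert state["nested"]["step"] == 9  # plain values overwritten
assert checkpoint == {"extra": 1}    # consumed keys are popped; load_checkpoint returns the leftovers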
@@ -405,6 +410,8 @@ def _convert_stateful_objects_in_state(
                 converted = self.get_optimizer_state(optimizer=obj)
             elif isinstance(obj, _Stateful):
                 converted = obj.state_dict()
+            elif isinstance(obj, dict):
+                converted = self._convert_stateful_objects_in_state(obj, filter)
             else:
                 converted = obj
             _apply_filter(key, filter, converted, converted_state)
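The save path gains the symmetric branch added above: _convert_stateful_objects_in_state now recurses into nested dicts, so stateful objects below the top level are converted with their state_dict() instead of falling through the final else branch as live objects. A hedged usage sketch (the "schedulers" key and objects are illustrative):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
model = fabric.setup(torch.nn.Linear(4, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# The scheduler sits inside a nested dict and is now serialized via its
# state_dict(); the plain "epoch" value passes through unchanged.
fabric.save("checkpoint.ckpt", {"model": model, "schedulers": {"lr": scheduler}, "epoch": 3})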
