[Traceable FSDP2] Dynamo support FSDP2 use_training_state context manager (pytorch#127854)

Improve Dynamo to support the FSDP2 `use_training_state()` context manager.

Test command:
`pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_dynamo_trace_use_training_state`
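
`use_training_state()` is a small context manager on `FSDPParamGroup` that temporarily switches the group's `_training_state` and restores it on exit; it is defined in `_fsdp_param_group.py` and is not touched by this diff. The sketch below is a minimal, self-contained illustration of that assumed eager behavior, which is what Dynamo must now reproduce while tracing (`DummyParamGroup` and the enum members here are illustrative stand-ins, not the real classes):

import contextlib
import enum


class TrainingState(enum.Enum):  # stand-in for _fsdp_common.TrainingState
    IDLE = enum.auto()
    FORWARD = enum.auto()


class DummyParamGroup:  # stand-in for FSDPParamGroup, state handling only
    def __init__(self):
        self._training_state = TrainingState.IDLE

    @contextlib.contextmanager
    def use_training_state(self, training_state):
        old_state = self._training_state
        self._training_state = training_state  # e.g. TrainingState.FORWARD
        try:
            yield
        finally:
            self._training_state = old_state  # restored even if the body raises


pg = DummyParamGroup()
with pg.use_training_state(TrainingState.FORWARD):
    assert pg._training_state is TrainingState.FORWARD
assert pg._training_state is TrainingState.IDLE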

Pull Request resolved: pytorch#127854
Approved by: https://github.com/yanboliang
yf225 authored and pytorchmergebot committed Jun 16, 2024
1 parent e4d8aa4 commit 979edbb
Showing 8 changed files with 129 additions and 2 deletions.
41 changes: 41 additions & 0 deletions test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -4,7 +4,10 @@
import unittest

import torch
import torch._dynamo.testing
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_common import TrainingState
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
@@ -60,5 +63,43 @@ def patched_trace_rules_check(*args, **kwargs):
        self.assertTrue(trace_rules_check_count > 0)


class TestFullyShardCompile(FSDPTest):
    def test_dynamo_trace_use_training_state(self):
        torch._dynamo.reset()
        # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
        param_group = FSDPParamGroup(
            [],  # params: List[nn.Parameter],
            torch.nn.Linear(1, 1),  # module: nn.Module,
            None,  # mesh_info: FSDPMeshInfo,
            None,  # post_forward_mesh_info: Optional[FSDPMeshInfo],
            None,  # device: torch.device,
            None,  # mp_policy: MixedPrecisionPolicy,
            None,  # offload_policy: OffloadPolicy,
        )

        def f(x):
            param_group._training_state = TrainingState.IDLE
            with param_group.use_training_state(TrainingState.FORWARD):
                if param_group._training_state == TrainingState.FORWARD:
                    return x + 1
                else:
                    return x

        inp = torch.zeros(1)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)

        eager_out = f(inp)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)
        self.assertEqual(eager_out, inp + 1)

        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        compiled_out = torch.compile(f, backend=cnt, fullgraph=True)(inp)
        self.assertEqual(param_group._training_state, TrainingState.IDLE)
        self.assertEqual(eager_out, compiled_out)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)
        self.assertEqual(len(cnt.graphs), 1)


if __name__ == "__main__":
    run_tests()
3 changes: 3 additions & 0 deletions torch/_dynamo/guards.py
@@ -1663,6 +1663,9 @@ def DETERMINISTIC_ALGORITHMS(self, guard: Guard):
    def TORCH_FUNCTION_STATE(self, guard: Guard):
        pass  # we always guard on this via GlobalStateGuard()

    def FSDP_TRAINING_STATE(self, guard: Guard):
        pass  # we always guard on this via GlobalStateGuard()

    def DEFAULT_DEVICE(self, guard: Guard):
        """Guard on CURRENT_DEVICE per torch.utils._device"""
        assert guard.source is GuardSource.GLOBAL
1 change: 1 addition & 0 deletions torch/_dynamo/variables/__init__.py
@@ -9,6 +9,7 @@
    DeterministicAlgorithmsVariable,
    DisabledSavedTensorsHooksVariable,
    DualLevelContextManager,
    FSDPParamGroupUseTrainingStateVariable,
    GradIncrementNestingCtxManagerVariable,
    GradInplaceRequiresGradCtxManagerVariable,
    GradModeVariable,
56 changes: 56 additions & 0 deletions torch/_dynamo/variables/ctx_manager.py
@@ -843,6 +843,62 @@ def reconstruct(self, codegen):
        )


class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE)

    @staticmethod
    def create(tx, param_group_var, target_value, **kwargs):
        var = FSDPParamGroupUseTrainingStateVariable(
            param_group_var=param_group_var,
            target_values=[target_value],
            initial_values=[param_group_var.value._training_state],
            **kwargs,
        )
        return var

    def __init__(self, param_group_var, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.param_group_var = param_group_var
        install_guard(self._guards_singleton)

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx, *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ):
        self._call_func(tx, self.initial_values)  # undo eager initialization
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
        if self.param_group_var.value._training_state != value:
            self.param_group_var.call_method(
                tx,
                "__setattr__",
                (
                    variables.ConstantVariable.create("_training_state"),
                    variables.EnumVariable(value),
                ),
                {},
            )
            self.param_group_var.value._training_state = value

    def module_name(self):
        return "torch.distributed._composable.fsdp._fsdp_param_group.FSDPParamGroup"

    def fn_name(self):
        return "use_training_state"


class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs):
        if proxy is not None and "example_value" in proxy.node.meta:
12 changes: 12 additions & 0 deletions torch/_dynamo/variables/functions.py
@@ -22,6 +22,11 @@
if TYPE_CHECKING:
    from torch._guards import Source

try:
    from torch.distributed._composable.fsdp import _fsdp_param_group
except ModuleNotFoundError:
    _fsdp_param_group = None


def wrap_bound_arg(tx, val, source=None):
    # Source propagation is best effort since not every object we encounter has a source to begin with.
@@ -338,6 +343,13 @@ def call_function(
            return self.obj.call_method(
                tx, self.fn.__name__, args, kwargs, constant=self.is_constant
            )
        elif (
            _fsdp_param_group is not None
            and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state
        ):
            return variables.TorchCtxManagerClassVariable(self.fn).call_function(
                tx, (self.obj, *args), kwargs
            )
        if self.is_constant:
            fn = getattr(self.obj.value, self.fn.__name__)
            return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
14 changes: 14 additions & 0 deletions torch/_dynamo/variables/torch.py
@@ -51,6 +51,11 @@
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

try:
    from torch.distributed._composable.fsdp import _fsdp_param_group
except ModuleNotFoundError:
    _fsdp_param_group = None  # type: ignore[assignment]

log = logging.getLogger(__name__)

supported_ctx_manager_classes = dict.fromkeys(
@@ -203,6 +208,7 @@ def call_function(
        from . import (
            DisabledSavedTensorsHooksVariable,
            DualLevelContextManager,
            FSDPParamGroupUseTrainingStateVariable,
            GradIncrementNestingCtxManagerVariable,
            GradInplaceRequiresGradCtxManagerVariable,
            GradModeVariable,
@@ -300,6 +306,14 @@ def call_function(
            return DisabledSavedTensorsHooksVariable.create(
                tx, args[0].as_python_constant()
            )
        elif (
            _fsdp_param_group is not None
            and self.value is _fsdp_param_group.FSDPParamGroup.use_training_state
        ):
            assert len(args) == 2
            return FSDPParamGroupUseTrainingStateVariable.create(
                tx, args[0], args[1].as_python_constant()
            )

        return super().call_function(tx, args, kwargs)

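Putting the dispatch pieces together: the condensed variant of the new test below walks the full path a `with param_group.use_training_state(...)` block now takes under compilation. It assumes a build in which `torch.distributed` and the FSDP2 prototype modules are importable; the dummy `FSDPParamGroup` construction mirrors the new test above.

import torch
from torch.distributed._composable.fsdp._fsdp_common import TrainingState
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup

# Same dummy construction as the new test: only the `use_training_state` path is exercised.
param_group = FSDPParamGroup([], torch.nn.Linear(1, 1), None, None, None, None, None)
param_group._training_state = TrainingState.IDLE


def f(x):
    # Path added by this PR:
    #   1. `param_group.use_training_state` is intercepted as a bound method (variables/functions.py),
    #   2. rewrapped as a TorchCtxManagerClassVariable,
    #   3. which builds an FSDPParamGroupUseTrainingStateVariable (variables/torch.py),
    #   4. whose enter()/exit() replay the `_training_state` switch and restore it (variables/ctx_manager.py).
    with param_group.use_training_state(TrainingState.FORWARD):
        return x + 1


out = torch.compile(f, backend="aot_eager", fullgraph=True)(torch.zeros(1))
assert param_group._training_state is TrainingState.IDLE  # state restored after the traced region
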
2 changes: 1 addition & 1 deletion torch/distributed/_composable/fsdp/fully_shard.py
@@ -24,7 +24,7 @@

# The decorator adds a state object to `module` that can be accessed via
# `fully_shard.state(module)`. The state object and module are 1:1.
@contract(state_cls=FSDPState)
@contract(state_cls=FSDPState) # type: ignore[operator]
def fully_shard(
    module: nn.Module,
    *,
2 changes: 1 addition & 1 deletion torch/distributed/utils.py
@@ -281,7 +281,7 @@ def _to_kwargs(
def _verify_param_shape_across_processes(
    process_group: dist.ProcessGroup,
    tensors: List[torch.Tensor],
    logger: Optional[dist.Logger] = None,
    logger: Optional["dist.Logger"] = None,
):
    return dist._verify_params_across_processes(process_group, tensors, logger)

