Add pipe.get_stage_module
kwen2501 committed Mar 4, 2024
1 parent 4de7fc2 commit 83b80d3
Showing 3 changed files with 13 additions and 10 deletions.
5 changes: 5 additions & 0 deletions pippy/IR.py
@@ -649,6 +649,11 @@ def forward(self, *args, **kwargs):
 
         return res
 
+    def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
+        if stage_idx < 0 or stage_idx >= self.num_stages:
+            raise ValueError(f"Invalid stage index {stage_idx}!")
+        return getattr(self.split_gm, f"submod_{stage_idx}")
+
     @staticmethod
     def _number_and_count_forward_stages(gm: fx.GraphModule):
         num_stages = 0
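For reference, a minimal usage sketch of the new accessor is below. It assumes a `pipe` object has already been constructed (as in the updated tests); the parameter-count printout is illustrative and not part of this commit.

# Sketch only: assumes `pipe` is an existing Pipe instance built elsewhere.
for stage_idx in range(pipe.num_stages):
    stage_mod = pipe.get_stage_module(stage_idx)
    # stage_mod is the torch.nn.Module for this stage (submod_<idx> of split_gm)
    n_params = sum(p.numel() for p in stage_mod.parameters())
    print(f"Stage {stage_idx}: {n_params} parameters")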
9 changes: 4 additions & 5 deletions test/test_transformer.py
@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-import pippy
 import torch
 from pippy import annotate_split_points, Pipe, SplitPoint
 
@@ -60,7 +59,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     (x,),
 )
 assert pipe.num_stages == 2
-gm = pipe.split_gm
 
 
 def get_layers(module):
@@ -70,9 +68,10 @@ def get_layers(module):
 
 # Collect all layers in pipe
 layers = []
-for name, submod in gm.named_children():
-    print(f"\nStage {name}: \n", submod)
-    layers += get_layers(submod)
+for stage_idx in range(pipe.num_stages):
+    stage_mod = pipe.get_stage_module(stage_idx)
+    print(f"\nStage {stage_idx}: \n", stage_mod)
+    layers += get_layers(stage_mod)
 
 # Check layer completeness
 orig_layers = get_layers(transformer)
9 changes: 4 additions & 5 deletions test/test_unflatten.py
@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-import pippy
 import torch
 from pippy import Pipe, pipe_split
 
@@ -52,14 +51,14 @@ def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
 )
 
 assert pipe.num_stages == 4
-gm = pipe.split_gm
 orig_state_dict = mod.state_dict()
 
 # Check qualnames
 print("\nParameters of each stage:")
-for name, submod in gm.named_children():
-    print(f"\nStage {name}:")
-    for param_name, param in submod.named_parameters():
+for stage_idx in range(pipe.num_stages):
+    print(f"\nStage {stage_idx}:")
+    stage_mod = pipe.get_stage_module(stage_idx)
+    for param_name, param in stage_mod.named_parameters():
         assert (
             param_name in orig_state_dict
         ), f"{param_name} not in original state dict"
