Skip to content

Commit

Permalink
[Primitive] Add .replace_all() (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 authored Mar 23, 2023
1 parent 3e46518 commit a4e00da
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
28 changes: 28 additions & 0 deletions slapo/primitives/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,31 @@ def apply(sch, new_mod_or_func, target_ops=None, name=None, concrete_args=None):
sch,
sch.group,
)


@register_primitive()
class ReplaceAllPrimitive(Primitive):
"""Replace all the specified submodules with the new module.
Parameters
----------
sch : Schedule
The schedule with the module/function to be replaced.
target_mod_type : Type
A target nn.Module type to be replaced.
make_mod_fn : FunctionType
A function that takes the original module and generate a new module.
"""

@staticmethod
def name():
return "replace_all"

@staticmethod
def apply(sch, target_mod_type, make_mod_fn):
module_names = dict(sch.mod.named_modules()).keys()
for name in module_names:
subsch = sch[name]
if isinstance(subsch.mod, target_mod_type):
new_mod = make_mod_fn(name, subsch.mod)
subsch.replace(new_mod)
56 changes: 56 additions & 0 deletions tests/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,62 @@ def forward(self, x):
assert isinstance(sch["activation"].mod, nn.GELU)


def test_replace_all_module():
class SubMod(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1024, 1024)
self.activation = nn.ReLU()

def forward(self, x):
x = self.linear(x)
x = self.activation(x)
return x

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(1024, 1024)
self.act1 = nn.ReLU()
self.fc2 = nn.Linear(1024, 1024)
self.act2 = nn.ReLU()
self.submod = SubMod()

def forward(self, x):
x = self.fc1(x)
x = self.act1(x)
x = self.fc2(x)
x = self.act2(x)
x = self.submod(x)
return x

model = Model()
sch = slapo.create_schedule(model)

def make_gelu(name, mod):
return nn.GELU()

sch.replace_all(nn.ReLU, make_gelu)
assert isinstance(sch["act1"].mod, nn.GELU)
assert isinstance(sch["act2"].mod, nn.GELU)
assert isinstance(sch["submod.activation"].mod, nn.GELU)

# test giving different shape of parameters
def make_linear(name, mod):
if name == "fc1":
in_feat, out_feat = 1024, 1025
elif name == "fc2":
in_feat, out_feat = 1025, 1026
else:
in_feat, out_feat = 1026, 1027
return nn.Linear(in_feat, out_feat)

sch.replace_all(nn.Linear, make_linear)
assert sch["fc1"].mod.out_features == 1025
assert sch["fc2"].mod.out_features == 1026
assert sch["submod.linear"].mod.out_features == 1027


def test_vertical_replacement():
class Model(nn.Module):
def __init__(self):
Expand Down

0 comments on commit a4e00da

Please sign in to comment.