add grad normalization feature
siddharth9820 committed Oct 18, 2023
1 parent 34abe16 commit 9392fa7
Showing 3 changed files with 28 additions and 21 deletions.
3 changes: 3 additions & 0 deletions axonn/intra_layer/__init__.py
@@ -1,5 +1,7 @@
 from .fully_connected import Linear # noqa: F401
 from .communication import Drop, Gather
+from .gradient_normalization import clip_grad_norm_
 
 from axonn import axonn as ax
 
+
@@ -17,4 +19,5 @@ def gather(x, transpose=False):
         group = ax.comm_handle.inner_intra_layer_parallel_group
     else:
         group = ax.comm_handle.outer_intra_layer_parallel_group
+
     return Gather.apply(x, group)
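Note: the gradient_normalization module that defines clip_grad_norm_ is not among the three files shown in this commit, so the sketch below is only a rough illustration of what a tensor-parallel-aware gradient clip generally has to do; the function name, signature, and the single tp_group argument are assumptions, not AxoNN's actual API. The idea is that a sharded weight (tagged is_tensor_parallel later in this commit) only holds part of its gradient on each rank, so the partial squared norms must be summed across the tensor-parallel group before the clip coefficient is applied.

import torch
import torch.distributed as dist


def clip_grad_norm_sketch(parameters, max_norm, tp_group=None, eps=1e-6):
    # Illustrative only -- not the axonn.intra_layer.gradient_normalization code.
    grads = [p for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return 0.0
    device = grads[0].grad.device

    sharded_sq = torch.zeros(1, device=device)     # shards of tensor-parallel weights
    replicated_sq = torch.zeros(1, device=device)  # parameters replicated on every rank
    for p in grads:
        sq = p.grad.detach().float().norm(2) ** 2
        if getattr(p, "is_tensor_parallel", False):
            sharded_sq += sq
        else:
            replicated_sq += sq

    # Each rank only sees its shard of a tensor-parallel weight, so the partial
    # squared norms are summed across the tensor-parallel group.
    if tp_group is not None:
        dist.all_reduce(sharded_sq, group=tp_group)

    total_norm = torch.sqrt(sharded_sq + replicated_sq).item()
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for p in grads:
            p.grad.detach().mul_(clip_coef)
    return total_norm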
42 changes: 23 additions & 19 deletions axonn/intra_layer/fully_connected.py
@@ -1,8 +1,9 @@
 from axonn import axonn as ax
 import torch.distributed as dist
 import torch
-from .communication import ForwardAllReduce, BackwardAllReduce, Drop, Gather
+import math
 
+from .communication import ForwardAllReduce, BackwardAllReduce, Drop, Gather
 
 def divide(a, b):
     assert a % b == 0
@@ -20,7 +21,6 @@ def initialize_params(
     params = Drop.apply(params, in_features_group)
     return params
 
-
 class Linear(torch.nn.Module):
     def __init__(
         self,
@@ -38,33 +38,35 @@ def __init__(
 
         self.inner_group_size = dist.get_world_size(self.inner_group)
         self.outer_group_size = dist.get_world_size(self.outer_group)
 
-
+        if init_method is None:
+            ## this is the same as pytorch 2.1
+            init_method = lambda weight : torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+
         if not transpose:
             assert in_features % self.inner_group_size == 0
             assert out_features % self.outer_group_size == 0
             self.local_in_features = divide(in_features, self.inner_group_size)
             self.local_out_features = divide(out_features, self.outer_group_size)
-            if init_method:
-                initial_params = initialize_params(
-                    out_features,
-                    in_features,
-                    self.outer_group,
-                    self.inner_group,
-                    init_method,
-                )
+            initial_params = initialize_params(
+                out_features,
+                in_features,
+                self.outer_group,
+                self.inner_group,
+                init_method,
+            )
         else:
             assert out_features % self.inner_group_size == 0
             assert in_features % self.outer_group_size == 0
             self.local_in_features = divide(in_features, self.outer_group_size)
             self.local_out_features = divide(out_features, self.inner_group_size)
-            if init_method:
-                initial_params = initialize_params(
-                    out_features,
-                    in_features,
-                    self.inner_group,
-                    self.outer_group,
-                    init_method,
-                )
+            initial_params = initialize_params(
+                out_features,
+                in_features,
+                self.inner_group,
+                self.outer_group,
+                init_method,
+            )
+
         self.linear = torch.nn.Linear(
             in_features=self.local_in_features,
@@ -77,6 +79,8 @@ def __init__(
         if init_method:
             self.linear.weight.data.copy_(initial_params)
 
+        setattr(self.linear.weight, "is_tensor_parallel", True)
+
         self.bias = torch.nn.Parameter(
             torch.zeros(
                 self.local_out_features,
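Two behavioural notes on the Linear changes above: callers may now omit init_method (the default matches the kaiming-uniform scheme torch.nn.Linear.reset_parameters uses for the weight in PyTorch 2.1, i.e. torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))), and the local weight shard is tagged is_tensor_parallel so norm-based utilities can tell sharded weights apart from replicated parameters. A minimal sketch, assuming the intra-layer process groups have already been created via ax.init:

from axonn.intra_layer import Linear

# assumes ax.init(...) has already created the intra-layer process groups
layer = Linear(in_features=1024, out_features=1024)  # init_method is now optional

# the local weight shard carries the tag added in this commit
assert getattr(layer.linear.weight, "is_tensor_parallel", False)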
4 changes: 2 additions & 2 deletions axonn/tests/test_intra_layer_fc.py
@@ -26,8 +26,8 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
     X_local = _drop(
         X, 1, inner_group
     )  # divide colunns of X along the inner tensor group
-    layer = Tensor_Parallel_Linear(
-        in_features=H, out_features=H, skip_bias_add=True
+    layer = Linear(
+        in_features=H, out_features=H, skip_bias_add=True,
     ).cuda()
 
     with torch.no_grad():
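The existing test only exercises the forward pass with the renamed Linear. A natural follow-up check for the new feature, sketched below but not part of this commit, would run on a single GPU (both intra-layer groups of size 1, so nothing is actually sharded) and verify that after clip_grad_norm_ the total gradient norm does not exceed max_norm; the clip_grad_norm_ signature is assumed to mirror torch.nn.utils.clip_grad_norm_, and the helper name and tolerance are made up for illustration.

import torch
from axonn.intra_layer import Linear, clip_grad_norm_


def check_grad_clip(H=64, max_norm=1e-3):
    # assumes ax.init(...) was called with G_intra_r == G_intra_c == 1
    layer = Linear(in_features=H, out_features=H).cuda()
    X = torch.randn(8, H, device="cuda")
    layer(X).sum().backward()  # assumes forward returns a single tensor

    clip_grad_norm_(layer.parameters(), max_norm=max_norm)  # assumed signature

    total = torch.sqrt(
        sum(p.grad.norm(2) ** 2 for p in layer.parameters() if p.grad is not None)
    )
    assert total <= max_norm * 1.01  # small slack for floating point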
