From 9392fa705282b37c78f6559f344808bc5558490a Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Wed, 18 Oct 2023 11:34:26 -0500
Subject: [PATCH] add grad normalization feature

---
 axonn/intra_layer/__init__.py        |  3 ++
 axonn/intra_layer/fully_connected.py | 42 +++++++++++++++-------------
 axonn/tests/test_intra_layer_fc.py   |  4 +--
 3 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py
index cce4cb5..fbdf102 100644
--- a/axonn/intra_layer/__init__.py
+++ b/axonn/intra_layer/__init__.py
@@ -1,5 +1,7 @@
 from .fully_connected import Linear  # noqa: F401
 from .communication import Drop, Gather
+from .gradient_normalization import clip_grad_norm_
+
 from axonn import axonn as ax
 
 
@@ -17,4 +19,5 @@ def gather(x, transpose=False):
         group = ax.comm_handle.inner_intra_layer_parallel_group
     else:
         group = ax.comm_handle.outer_intra_layer_parallel_group
+
     return Gather.apply(x, group)
diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index 61ac6ef..af7de7c 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -1,8 +1,9 @@
 from axonn import axonn as ax
 import torch.distributed as dist
 import torch
-from .communication import ForwardAllReduce, BackwardAllReduce, Drop, Gather
+import math
+from .communication import ForwardAllReduce, BackwardAllReduce, Drop, Gather
 
 
 def divide(a, b):
     assert a % b == 0
@@ -20,7 +21,6 @@ def initialize_params(
     params = Drop.apply(params, in_features_group)
     return params
 
-
 class Linear(torch.nn.Module):
     def __init__(
         self,
@@ -38,33 +38,35 @@ def __init__(
 
         self.inner_group_size = dist.get_world_size(self.inner_group)
         self.outer_group_size = dist.get_world_size(self.outer_group)
-
+
+        if init_method is None:
+            ## this is the same as pytorch 2.1
+            init_method = lambda weight : torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
+
         if not transpose:
             assert in_features % self.inner_group_size == 0
             assert out_features % self.outer_group_size == 0
             self.local_in_features = divide(in_features, self.inner_group_size)
             self.local_out_features = divide(out_features, self.outer_group_size)
-            if init_method:
-                initial_params = initialize_params(
-                    out_features,
-                    in_features,
-                    self.outer_group,
-                    self.inner_group,
-                    init_method,
-                )
+            initial_params = initialize_params(
+                out_features,
+                in_features,
+                self.outer_group,
+                self.inner_group,
+                init_method,
+            )
         else:
             assert out_features % self.inner_group_size == 0
             assert in_features % self.outer_group_size == 0
             self.local_in_features = divide(in_features, self.outer_group_size)
             self.local_out_features = divide(out_features, self.inner_group_size)
-            if init_method:
-                initial_params = initialize_params(
-                    out_features,
-                    in_features,
-                    self.inner_group,
-                    self.outer_group,
-                    init_method,
-                )
+            initial_params = initialize_params(
+                out_features,
+                in_features,
+                self.inner_group,
+                self.outer_group,
+                init_method,
+            )
 
         self.linear = torch.nn.Linear(
             in_features=self.local_in_features,
@@ -77,6 +79,8 @@ def __init__(
         if init_method:
             self.linear.weight.data.copy_(initial_params)
 
+        setattr(self.linear.weight, "is_tensor_parallel", True)
+
         self.bias = torch.nn.Parameter(
             torch.zeros(
                 self.local_out_features,
diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py
index 7216103..4f600f8 100644
--- a/axonn/tests/test_intra_layer_fc.py
+++ b/axonn/tests/test_intra_layer_fc.py
@@ -26,8 +26,8 @@ def test_fw_pass(G_intra_r, G_intra_c, B, H):
     X_local = _drop(
         X, 1, inner_group
     )  # divide colunns of X along the inner tensor group
-    layer = Tensor_Parallel_Linear(
-        in_features=H, out_features=H, skip_bias_add=True
+    layer = Linear(
+        in_features=H, out_features=H, skip_bias_add=True,
     ).cuda()
 
     with torch.no_grad():
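Usage sketch (illustrative, not part of the commit): the patch imports clip_grad_norm_ from gradient_normalization.py but does not show that file, so the snippet below assumes it mirrors the torch.nn.utils.clip_grad_norm_ signature (parameters, max_norm) and that the AxoNN runtime and its intra-layer process groups are already initialized. The is_tensor_parallel attribute set on the weight above is what a routine like this could use to distinguish sharded tensor-parallel gradients from replicated ones when computing the global norm.

import torch
from axonn.intra_layer import Linear, clip_grad_norm_

# Assumes AxoNN and its inner/outer intra-layer groups were initialized earlier.
model = Linear(in_features=1024, out_features=1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).sum()
loss.backward()

# Clip the gradient norm before stepping; the (parameters, max_norm) signature
# is an assumption borrowed from torch.nn.utils.clip_grad_norm_.
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()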