From d70c5ffbe653741277751294bf145d178c35b331 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Mon, 1 Jan 2024 17:07:43 +0530
Subject: [PATCH] More communication optimizations (#57)

---
 axonn/axonn.py                       |   5 ++
 axonn/communication.py               |  14 ++++
 axonn/intra_layer/__init__.py        | 109 ++++++++++++++++++++++-----
 axonn/intra_layer/communication.py   |   7 +-
 axonn/intra_layer/fully_connected.py |  23 +++---
 axonn/tests/test_intra_layer_fc.py   |   5 +-
 6 files changed, 128 insertions(+), 35 deletions(-)

diff --git a/axonn/axonn.py b/axonn/axonn.py
index c651963..52357f1 100644
--- a/axonn/axonn.py
+++ b/axonn/axonn.py
@@ -137,6 +137,11 @@ def init(
     config.inter_layer_parallel_rank = comm_handle.inter_layer_parallel_rank
     config.data_parallel_rank = comm_handle.data_parallel_rank
     config.intra_layer_parallel_rank = comm_handle.intra_layer_parallel_rank
+    config.intra_layer_depth_parallel_rank = comm_handle.intra_layer_depth_parallel_rank
+    config.intra_layer_row_parallel_rank = comm_handle.intra_layer_row_parallel_rank
+    config.intra_layer_column_parallel_rank = (
+        comm_handle.intra_layer_column_parallel_rank
+    )
     is_initialized = True
     if mixed_precision:
         computation_dtype = torch.float16
diff --git a/axonn/communication.py b/axonn/communication.py
index cb718a6..949ab90 100644
--- a/axonn/communication.py
+++ b/axonn/communication.py
@@ -184,6 +184,20 @@ def __init__(
                 if self.world_rank in group_members:
                     self.depth_intra_layer_parallel_group = group
 
+        # combined inner+outer
+        for i in range(G_intra_d):
+            group_members = list(
+                ranks_in_ith_jth_intra_layer_group[i, :, :].flatten()
+            )
+            group = torch.distributed.new_group(
+                ranks=group_members, backend="nccl"
+            )
+            if self.world_rank in group_members:
+                self.outer_inner_intra_layer_parallel_group = group
+                self.outer_inner_intra_layer_parallel_group_root = (
+                    group_members[0]
+                )
+
     def _torch_to_mpi(self, tensor: torch.Tensor):
         """Converts a PyTorch tensor into an mpi4py compatible array using
         its unified virtual address
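
A standalone sketch (not part of the patch): the communication.py hunk above builds one fused "outer+inner" (row x column) process group per depth slice of the 3-D intra-layer rank grid. The snippet below assumes a hypothetical 2 x 2 x 2 decomposition with global ranks laid out contiguously as [depth, row, column]; it only prints which ranks would land in each fused group and which rank would be recorded as the group root.

import torch

G_intra_d, G_intra_r, G_intra_c = 2, 2, 2  # hypothetical 2 x 2 x 2 decomposition
# hypothetical grid of global ranks, indexed as [depth, row, column]
ranks_grid = torch.arange(G_intra_d * G_intra_r * G_intra_c).reshape(
    G_intra_d, G_intra_r, G_intra_c
)

for i in range(G_intra_d):
    # every rank in depth slice i joins one fused row+column group;
    # its first member acts as the group root, as in the hunk above
    group_members = ranks_grid[i, :, :].flatten().tolist()
    print(f"depth slice {i}: members={group_members}, root={group_members[0]}")
    # in the patch, this list is passed to
    # torch.distributed.new_group(ranks=group_members, backend="nccl")
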
diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py
index f4c9edf..a0e9bc2 100644
--- a/axonn/intra_layer/__init__.py
+++ b/axonn/intra_layer/__init__.py
@@ -10,31 +10,38 @@
 import torch.distributed as dist
 
 
-def drop(x, transpose=False, dim=-1, batch_dim=0):
+def drop(
+    x, transpose=False, dim=-1, batch_dim=0, skip_channels=False, skip_batch=False
+):
     if not transpose:
         group = ax.comm_handle.inner_intra_layer_parallel_group
     else:
         group = ax.comm_handle.outer_intra_layer_parallel_group
 
-    x = Drop.apply(x, group, dim)
-    x = Drop.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
+    if not skip_channels:
+        x = Drop.apply(x, group, dim)
+    if not skip_batch:
+        x = Drop.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
     return x
 
 
-def gather(x, transpose=False, dim=-1, batch_dim=0):
+def gather(
+    x, transpose=False, dim=-1, batch_dim=0, skip_channels=False, skip_batch=False
+):
     if not transpose:
         group = ax.comm_handle.inner_intra_layer_parallel_group
     else:
         group = ax.comm_handle.outer_intra_layer_parallel_group
 
-    x = Gather.apply(x, group, dim)
-    x = Gather.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
+    if not skip_channels:
+        x = Gather.apply(x, group, dim)
+    if not skip_batch:
+        x = Gather.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
     return x
 
 
 OVERLAP_REDUCE_SCATTER = False
 OVERLAP_ALL_REDUCE = False
-CACHE_WEIGHTS = False
 ALL_GATHER_ITERATOR = None
 handles = []
 pending_grad_accumulations = []
@@ -110,43 +117,64 @@ def enqueue_next_all_gather():
         pass
 
 
-def retrieve_all_gathered_weight(weight):
-    global CACHE_WEIGHTS, ALL_GATHER_ITERATOR
+def retrieve_all_gathered_weight(weight, delete):
+    global ALL_GATHER_ITERATOR
     assert weight in weights_cache
     all_gathered_weight, handle = weights_cache[weight]
     if ALL_GATHER_ITERATOR is not None:
         enqueue_next_all_gather()
+    if delete:
+        del weights_cache[weight]
     return all_gathered_weight, handle
 
 
+@contextmanager
+def overlap_all_gathers_for_checkpointed_forward(
+    model_object_for_overlapping_allgathers,
+):
+    global ALL_GATHER_ITERATOR
+    if ALL_GATHER_ITERATOR is None:  # this is a false call
+        try:
+            yield None
+        finally:
+            pass
+    else:
+        old_iterator = ALL_GATHER_ITERATOR
+        ALL_GATHER_ITERATOR = trigger_async_all_gathers(
+            model_object_for_overlapping_allgathers
+        )
+        enqueue_next_all_gather()
+        try:
+            yield None
+        finally:
+            ALL_GATHER_ITERATOR = old_iterator
+
+
 @contextmanager
 def optimize_communication(
     overlap_all_reduce=True,
     overlap_reduce_scatter=False,
-    cache_weights=False,
     overlap_all_gather=False,
-    model=None,
+    model_object_for_overlapping_allgathers=None,
     *args,
     **kwargs
 ):
-    global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER, CACHE_WEIGHTS
+    global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER
     global ALL_GATHER_ITERATOR
 
     OVERLAP_ALL_REDUCE = overlap_all_reduce
     OVERLAP_REDUCE_SCATTER = overlap_reduce_scatter
-    CACHE_WEIGHTS = cache_weights
-
     if overlap_all_gather:
-        if model is None:
+        if model_object_for_overlapping_allgathers is None:
             raise ValueError(
                 "You need to pass your model as an argument - "
-                "optimize_communication(...,model=model, ...)"
+                "optimize_communication(...,model_object_"
+                "for_overlapping_allgathers=model, ...)"
                 "if overlap_all_gather is True"
             )
-        assert (
-            cache_weights
-        ), "all gathers can only be overlapped if cache_weights is True"
-        ALL_GATHER_ITERATOR = trigger_async_all_gathers(model)
+        ALL_GATHER_ITERATOR = trigger_async_all_gathers(
+            model_object_for_overlapping_allgathers
+        )
        enqueue_next_all_gather()
 
     try:
@@ -157,3 +185,44 @@ def optimize_communication(
         OVERLAP_ALL_REDUCE = False
         OVERLAP_REDUCE_SCATTER = False
         ALL_GATHER_ITERATOR = None
+
+
+@torch.no_grad()
+def sync_gradients(model, gradient_attr_name="grad", mean=False, vectorize=False):
+    grads_to_sync = []
+    for param in model.parameters():
+        if param.requires_grad:
+            grad = getattr(param, gradient_attr_name)
+            if grad is not None:
+                if hasattr(param, "is_tensor_parallel") and param.is_tensor_parallel:
+                    if (
+                        hasattr(param, "needs_gradient_sync")
+                        and param.needs_gradient_sync
+                    ):
+                        grads_to_sync.append(grad)
+                else:
+                    grads_to_sync.append(grad)
+
+    if not grads_to_sync:
+        return
+
+    world_size = dist.get_world_size(ax.comm_handle.depth_intra_layer_parallel_group)
+    if vectorize:
+        from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+        global_grad = _flatten_dense_tensors(grads_to_sync)
+        dist.all_reduce(
+            global_grad, group=ax.comm_handle.depth_intra_layer_parallel_group
+        )
+        if mean:
+            global_grad.div_(world_size)
+
+        for old_tensor, new_tensor in zip(
+            grads_to_sync, _unflatten_dense_tensors(global_grad, grads_to_sync)
+        ):
+            old_tensor.data = new_tensor
+    else:
+        for grad in grads_to_sync:
+            dist.all_reduce(grad, group=ax.comm_handle.depth_intra_layer_parallel_group)
+            if mean:
+                grad.div_(world_size)
diff --git a/axonn/intra_layer/communication.py b/axonn/intra_layer/communication.py
index d4f3334..3926ba8 100644
--- a/axonn/intra_layer/communication.py
+++ b/axonn/intra_layer/communication.py
@@ -32,10 +32,13 @@ def _gather(input_, dim, process_group=None, cache=False):
         return input_
 
     if input_ in axonn.intra_layer.weights_cache:
-        output, handle = axonn.intra_layer.retrieve_all_gathered_weight(input_)
+        output, handle = axonn.intra_layer.retrieve_all_gathered_weight(
+            input_, delete=not cache
+        )
         if handle is not None:
             handle.wait()
-            axonn.intra_layer.weights_cache[input_][1] = None
+            if cache:
+                axonn.intra_layer.weights_cache[input_][1] = None
     else:
         input_ = input_.contiguous()
         # Size and dimension.
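
A standalone sketch (not part of the patch): sync_gradients above all-reduces gradients over the depth-parallel group, but only for the parameters that need it. Tensor-parallel parameters are skipped unless they are tagged needs_gradient_sync=True; the next file in the patch tags the weight False (its depth-group gradient reduction already happens through the reduce-scatter in ForwardGather_BackwardReduceScatter) and the bias True. The helper below restates that selection rule; its name is illustrative and not part of AxoNN's API.

import torch


def needs_depth_all_reduce(param: torch.nn.Parameter) -> bool:
    # tensor-parallel parameters are synced only when explicitly tagged;
    # regular (replicated) parameters are always synced
    if getattr(param, "is_tensor_parallel", False):
        return bool(getattr(param, "needs_gradient_sync", False))
    return True


weight = torch.nn.Parameter(torch.zeros(4, 4))
weight.is_tensor_parallel = True
weight.needs_gradient_sync = False  # grad already reduce-scattered in backward

bias = torch.nn.Parameter(torch.zeros(4))
bias.is_tensor_parallel = True
bias.needs_gradient_sync = True  # grad still needs the depth-group all-reduce

print(needs_depth_all_reduce(weight))  # False
print(needs_depth_all_reduce(bias))    # True
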
diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index d677648..e86c2a4 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -6,12 +6,7 @@
 
 from axonn import axonn as ax
 import axonn
-from .communication import (
-    Drop,
-    Gather,
-    ForwardGather_BackwardReduceScatter,
-    BackwardAllReduce,
-)
+from .communication import Drop, Gather, ForwardGather_BackwardReduceScatter
 
 
 def divide(a, b):
@@ -168,6 +163,7 @@ def __init__(
         self.weight = torch.nn.Parameter(initial_params, requires_grad=True)
 
         setattr(self.weight, "is_tensor_parallel", True)
+        setattr(self.weight, "needs_gradient_sync", False)
         setattr(
             self.weight,
             "process_group_for_norm_reduction",
@@ -181,6 +177,7 @@ def __init__(
             )
         )
         setattr(self.bias, "is_tensor_parallel", True)
+        setattr(self.bias, "needs_gradient_sync", True)
         if not transpose:
             setattr(
                 self.bias,
@@ -204,7 +201,13 @@ def __init__(
     def get_output_feature_size(self):
         return self.local_out_features
 
-    def forward(self, x, scatter_input=True, gather_output=True):
+    def forward(
+        self,
+        x,
+        scatter_input=True,
+        gather_output=True,
+        cache_weights_in_all_gather=False,
+    ):
         # gather weights from depth parallel group
         # reduce scatter in the backward pass
         weight = ForwardGather_BackwardReduceScatter.apply(
@@ -212,7 +215,7 @@ def forward(self, x, scatter_input=True, gather_output=True):
             self.depth_group,
             0,
             axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
-            axonn.intra_layer.CACHE_WEIGHTS,
+            cache_weights_in_all_gather,
         ).reshape(self.local_out_features, self.local_in_features)
 
         if not self.transpose:
@@ -256,10 +259,6 @@ def forward(self, x, scatter_input=True, gather_output=True):
                     bias,
                     self.outer_group if not self.transpose else self.inner_group,
                 )
-            else:
-                bias = BackwardAllReduce.apply(
-                    bias, self.depth_group, axonn.intra_layer.OVERLAP_REDUCE_SCATTER
-                )
             if self.skip_bias_add:
                 return x, bias
             else:
diff --git a/axonn/tests/test_intra_layer_fc.py b/axonn/tests/test_intra_layer_fc.py
index 29409d2..d8d5821 100644
--- a/axonn/tests/test_intra_layer_fc.py
+++ b/axonn/tests/test_intra_layer_fc.py
@@ -7,6 +7,7 @@
     clip_grad_norm_,
     optimize_communication,
     clear_weights_cache,
+    sync_gradients,
 )
 
 
@@ -138,11 +139,13 @@ def test_bw_pass(
         overlap_reduce_scatter=comm_opt_level >= 2,
         cache_weights=comm_opt_level >= 3,
         overlap_all_gather=comm_opt_level == 4,
-        model=layer,
+        model_object_for_overlapping_allgathers=layer,
     ):
         Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
         Y_local.backward(Y_local_grad)
 
+    if not easy_tp:
+        sync_gradients(layer)
     if comm_opt_level >= 3:
         clear_weights_cache()
     # sequential backward pass
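
An end-to-end usage sketch (not part of the patch) of the API after this change, mirroring the updated test above. It assumes a one-process-per-GPU launch on 8 GPUs with AxoNN handling process-group setup inside ax.init; the ax.init arguments, tensor shapes, and the Linear import path are illustrative assumptions rather than prescribed values.

import torch
from axonn import axonn as ax
from axonn.intra_layer import Linear, optimize_communication, sync_gradients

# hypothetical 2 x 2 x 2 intra-layer decomposition on 8 GPUs
ax.init(G_inter=1, G_data=1, G_intra_r=2, G_intra_c=2, G_intra_d=2)

layer = Linear(in_features=1024, out_features=1024).cuda()
x = torch.randn(16, 1024, device="cuda")

with optimize_communication(
    overlap_all_reduce=True,
    overlap_reduce_scatter=True,
    overlap_all_gather=True,
    # renamed from model=... by this patch
    model_object_for_overlapping_allgathers=layer,
):
    y = layer(x, scatter_input=False, gather_output=False)
    y.backward(torch.randn_like(y))

# parameters tagged needs_gradient_sync=True (e.g. the bias) still need an
# explicit all-reduce over the depth-parallel group after the backward pass
sync_gradients(layer)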