More communication optimizations (#57)
siddharth9820 authored and Sathwik Yanamaddi committed Jan 25, 2024
1 parent 357aab4 commit d70c5ff
Showing 6 changed files with 128 additions and 35 deletions.
5 changes: 5 additions & 0 deletions axonn/axonn.py
@@ -137,6 +137,11 @@ def init(
config.inter_layer_parallel_rank = comm_handle.inter_layer_parallel_rank
config.data_parallel_rank = comm_handle.data_parallel_rank
config.intra_layer_parallel_rank = comm_handle.intra_layer_parallel_rank
config.intra_layer_depth_parallel_rank = comm_handle.intra_layer_depth_parallel_rank
config.intra_layer_row_parallel_rank = comm_handle.intra_layer_row_parallel_rank
config.intra_layer_column_parallel_rank = (
comm_handle.intra_layer_column_parallel_rank
)
is_initialized = True
if mixed_precision:
computation_dtype = torch.float16
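
The new attributes expose a process's position along each axis of the intra-layer decomposition. A minimal sketch of reading them after initialization, assuming the G_intra_* keyword names used for init below (only the attribute names are taken from this hunk):

from axonn import axonn as ax

# Assumed 2 x 2 x 2 intra-layer grid (row x column x depth); the keyword names
# G_intra_r / G_intra_c / G_intra_d are illustrative assumptions, not part of this diff.
ax.init(G_inter=1, G_data=1, G_intra_r=2, G_intra_c=2, G_intra_d=2)

# Each process can now query its coordinates in the intra-layer grid.
handle = ax.comm_handle
print(handle.intra_layer_depth_parallel_rank)    # depth axis
print(handle.intra_layer_row_parallel_rank)      # row axis
print(handle.intra_layer_column_parallel_rank)   # column axis
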
14 changes: 14 additions & 0 deletions axonn/communication.py
@@ -184,6 +184,20 @@ def __init__(
if self.world_rank in group_members:
self.depth_intra_layer_parallel_group = group

# combined inner+outer
for i in range(G_intra_d):
group_members = list(
ranks_in_ith_jth_intra_layer_group[i, :, :].flatten()
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.outer_inner_intra_layer_parallel_group = group
self.outer_inner_intra_layer_parallel_group_root = (
group_members[0]
)

def _torch_to_mpi(self, tensor: torch.Tensor):
"""Converts a PyTorch tensor into an mpi4py compatible array using its
unified virtual address
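
The new loop creates one NCCL group per depth slice, containing every rank in that slice's full inner x outer plane. A standalone sketch of the rank bookkeeping, assuming the last two axes of ranks_in_ith_jth_intra_layer_group are the row and column dimensions (the diff only fixes axis 0 as the depth axis):

import numpy as np

# Illustration: 2 x 2 x 2 = 8 intra-layer ranks laid out as (depth, row, column).
G_intra_d, G_intra_r, G_intra_c = 2, 2, 2
ranks = np.arange(G_intra_d * G_intra_r * G_intra_c).reshape(
    G_intra_d, G_intra_r, G_intra_c
)

# One combined inner+outer group per depth slice, mirroring
# ranks_in_ith_jth_intra_layer_group[i, :, :].flatten() above.
for i in range(G_intra_d):
    group_members = ranks[i, :, :].flatten().tolist()
    print(f"depth slice {i}: {group_members}")
# depth slice 0: [0, 1, 2, 3]
# depth slice 1: [4, 5, 6, 7]
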
109 changes: 89 additions & 20 deletions axonn/intra_layer/__init__.py
@@ -10,31 +10,38 @@
import torch.distributed as dist


def drop(x, transpose=False, dim=-1, batch_dim=0):
def drop(
x, transpose=False, dim=-1, batch_dim=0, skip_channels=False, skip_batch=False
):
if not transpose:
group = ax.comm_handle.inner_intra_layer_parallel_group
else:
group = ax.comm_handle.outer_intra_layer_parallel_group

x = Drop.apply(x, group, dim)
x = Drop.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
if not skip_channels:
x = Drop.apply(x, group, dim)
if not skip_batch:
x = Drop.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
return x


def gather(x, transpose=False, dim=-1, batch_dim=0):
def gather(
x, transpose=False, dim=-1, batch_dim=0, skip_channels=False, skip_batch=False
):
if not transpose:
group = ax.comm_handle.inner_intra_layer_parallel_group
else:
group = ax.comm_handle.outer_intra_layer_parallel_group

x = Gather.apply(x, group, dim)
x = Gather.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
if not skip_channels:
x = Gather.apply(x, group, dim)
if not skip_batch:
x = Gather.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
return x


OVERLAP_REDUCE_SCATTER = False
OVERLAP_ALL_REDUCE = False
CACHE_WEIGHTS = False
ALL_GATHER_ITERATOR = None
handles = []
pending_grad_accumulations = []
@@ -110,43 +117,64 @@ def enqueue_next_all_gather():
pass


def retrieve_all_gathered_weight(weight):
global CACHE_WEIGHTS, ALL_GATHER_ITERATOR
def retrieve_all_gathered_weight(weight, delete):
global ALL_GATHER_ITERATOR
assert weight in weights_cache
all_gathered_weight, handle = weights_cache[weight]
if ALL_GATHER_ITERATOR is not None:
enqueue_next_all_gather()
if delete:
del weights_cache[weight]
return all_gathered_weight, handle


@contextmanager
def overlap_all_gathers_for_checkpointed_forward(
model_object_for_overlapping_allgathers,
):
global ALL_GATHER_ITERATOR
if ALL_GATHER_ITERATOR is None: # this is a false call
try:
yield None
finally:
pass
else:
old_iterator = ALL_GATHER_ITERATOR
ALL_GATHER_ITERATOR = trigger_async_all_gathers(
model_object_for_overlapping_allgathers
)
enqueue_next_all_gather()
try:
yield None
finally:
ALL_GATHER_ITERATOR = old_iterator


@contextmanager
def optimize_communication(
overlap_all_reduce=True,
overlap_reduce_scatter=False,
cache_weights=False,
overlap_all_gather=False,
model=None,
model_object_for_overlapping_allgathers=None,
*args,
**kwargs
):
global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER, CACHE_WEIGHTS
global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER
global ALL_GATHER_ITERATOR
OVERLAP_ALL_REDUCE = overlap_all_reduce
OVERLAP_REDUCE_SCATTER = overlap_reduce_scatter

CACHE_WEIGHTS = cache_weights

if overlap_all_gather:
if model is None:
if model_object_for_overlapping_allgathers is None:
raise ValueError(
"You need to pass your model as an argument - "
"optimize_communication(...,model=model, ...)"
"optimize_communication(...,model_object_"
"for_overlapping_allgathers=model, ...)"
"if overlap_all_gather is True"
)
assert (
cache_weights
), "all gathers can only be overlapped if cache_weights is True"
ALL_GATHER_ITERATOR = trigger_async_all_gathers(model)
ALL_GATHER_ITERATOR = trigger_async_all_gathers(
model_object_for_overlapping_allgathers
)
enqueue_next_all_gather()

try:
@@ -157,3 +185,44 @@ def optimize_communication(
OVERLAP_ALL_REDUCE = False
OVERLAP_REDUCE_SCATTER = False
ALL_GATHER_ITERATOR = None


@torch.no_grad()
def sync_gradients(model, gradient_attr_name="grad", mean=False, vectorize=False):
grads_to_sync = []
for param in model.parameters():
if param.requires_grad:
grad = getattr(param, gradient_attr_name)
if grad is not None:
if hasattr(param, "is_tensor_parallel") and param.is_tensor_parallel:
if (
hasattr(param, "needs_gradient_sync")
and param.needs_gradient_sync
):
grads_to_sync.append(grad)
else:
grads_to_sync.append(grad)

if not grads_to_sync:
return

world_size = dist.get_world_size(ax.comm_handle.depth_intra_layer_parallel_group)
if vectorize:
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

global_grad = _flatten_dense_tensors(grads_to_sync)
dist.all_reduce(
global_grad, group=ax.comm_handle.depth_intra_layer_parallel_group
)
if mean:
global_grad.div_(world_size)

for old_tensor, new_tensor in zip(
grads_to_sync, _unflatten_dense_tensors(global_grad, grads_to_sync)
):
old_tensor.data = new_tensor
else:
for grad in grads_to_sync:
dist.all_reduce(grad, group=ax.comm_handle.depth_intra_layer_parallel_group)
if mean:
grad.div_(world_size)
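
Taken together with the updated test below, the renamed keyword and the new sync_gradients helper are used roughly as follows; a minimal training-step sketch, assuming layer is an AxoNN intra-layer-parallel module and that depth-parallel gradients should be averaged (mean=True; the default is a plain sum):

from axonn.intra_layer import (
    optimize_communication,
    sync_gradients,
    clear_weights_cache,
)

def training_step(layer, x, y_grad):
    # layer is assumed to be an AxoNN intra-layer-parallel module, e.g. the
    # tensor-parallel Linear defined in axonn/intra_layer/fully_connected.py.
    with optimize_communication(
        overlap_all_reduce=True,
        overlap_reduce_scatter=True,
        cache_weights=True,
        overlap_all_gather=True,
        # renamed from model= in this commit
        model_object_for_overlapping_allgathers=layer,
    ):
        y = layer(x)
        y.backward(y_grad)

    # All-reduce gradients over the depth-parallel group for parameters that
    # need it (non-tensor-parallel parameters and those tagged
    # needs_gradient_sync=True, such as the bias above).
    sync_gradients(layer, mean=True)

    # Gathered weights were cached for the overlapped all-gathers; free them.
    clear_weights_cache()

The new overlap_all_gathers_for_checkpointed_forward(model) context manager appears intended for activation checkpointing: it temporarily points the all-gather iterator at the recomputed sub-module, pre-launches the next all-gather, and restores the outer iterator on exit; when no overlap is active it simply yields and does nothing.
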
7 changes: 5 additions & 2 deletions axonn/intra_layer/communication.py
@@ -32,10 +32,13 @@ def _gather(input_, dim, process_group=None, cache=False):
return input_

if input_ in axonn.intra_layer.weights_cache:
output, handle = axonn.intra_layer.retrieve_all_gathered_weight(input_)
output, handle = axonn.intra_layer.retrieve_all_gathered_weight(
input_, delete=not cache
)
if handle is not None:
handle.wait()
axonn.intra_layer.weights_cache[input_][1] = None
if cache:
axonn.intra_layer.weights_cache[input_][1] = None
else:
input_ = input_.contiguous()
# Size and dimension.
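
The cache branch above amounts to: look up the pre-launched all-gather, wait on its handle, and either keep the entry (cache=True, only the handle is cleared) or drop it immediately (delete=not cache). A simplified, self-contained sketch of that bookkeeping; the [tensor, handle] dict layout mirrors weights_cache but is an illustration, not the actual AxoNN internals:

weights_cache = {}  # param -> [all_gathered_tensor, async_handle]

def retrieve_all_gathered_weight(weight, delete):
    gathered, handle = weights_cache[weight]
    if delete:
        # one-shot use: discard the entry as soon as it is handed back
        del weights_cache[weight]
    return gathered, handle

def gather_from_cache(weight, cache):
    if weight not in weights_cache:
        return None  # fall back to a regular (synchronous) all-gather
    gathered, handle = retrieve_all_gathered_weight(weight, delete=not cache)
    if handle is not None:
        handle.wait()  # block until the pre-launched all-gather finishes
        if cache:
            weights_cache[weight][1] = None  # keep the tensor, clear the handle
    return gathered
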
23 changes: 11 additions & 12 deletions axonn/intra_layer/fully_connected.py
@@ -6,12 +6,7 @@

from axonn import axonn as ax
import axonn
from .communication import (
Drop,
Gather,
ForwardGather_BackwardReduceScatter,
BackwardAllReduce,
)
from .communication import Drop, Gather, ForwardGather_BackwardReduceScatter


def divide(a, b):
@@ -168,6 +163,7 @@ def __init__(
self.weight = torch.nn.Parameter(initial_params, requires_grad=True)

setattr(self.weight, "is_tensor_parallel", True)
setattr(self.weight, "needs_gradient_sync", False)
setattr(
self.weight,
"process_group_for_norm_reduction",
@@ -181,6 +177,7 @@ def __init__(
)
)
setattr(self.bias, "is_tensor_parallel", True)
setattr(self.bias, "needs_gradient_sync", True)
if not transpose:
setattr(
self.bias,
@@ -204,15 +201,21 @@ def __init__(
def get_output_feature_size(self):
return self.local_out_features

def forward(self, x, scatter_input=True, gather_output=True):
def forward(
self,
x,
scatter_input=True,
gather_output=True,
cache_weights_in_all_gather=False,
):
# gather weights from depth parallel group
# reduce scatter in the backward pass
weight = ForwardGather_BackwardReduceScatter.apply(
self.weight,
self.depth_group,
0,
axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
axonn.intra_layer.CACHE_WEIGHTS,
cache_weights_in_all_gather,
).reshape(self.local_out_features, self.local_in_features)

if not self.transpose:
@@ -256,10 +259,6 @@ def forward(self, x, scatter_input=True, gather_output=True):
bias,
self.outer_group if not self.transpose else self.inner_group,
)
else:
bias = BackwardAllReduce.apply(
bias, self.depth_group, axonn.intra_layer.OVERLAP_REDUCE_SCATTER
)
if self.skip_bias_add:
return x, bias
else:
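
Weight caching is now requested per call through cache_weights_in_all_gather instead of the removed global CACHE_WEIGHTS flag. A minimal sketch of driving it directly, assuming the layer is exposed as axonn.intra_layer.Linear with a Linear(in_features, out_features) style constructor and that ax.init has already set up the process groups (neither is shown in this hunk):

import torch
from axonn.intra_layer import Linear, clear_weights_cache  # assumed export names

# Assumed constructor; only the forward() keywords below appear in this diff.
layer = Linear(in_features=1024, out_features=4096).cuda()
x = torch.randn(16, 1024, device="cuda")

# Gather the depth-sharded weight and keep the gathered copy in the cache.
y = layer(x, scatter_input=True, gather_output=True, cache_weights_in_all_gather=True)

# Later forwards may reuse the cached all-gathered weight; release it when done.
clear_weights_cache()
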
5 changes: 4 additions & 1 deletion axonn/tests/test_intra_layer_fc.py
@@ -7,6 +7,7 @@
clip_grad_norm_,
optimize_communication,
clear_weights_cache,
sync_gradients,
)


@@ -138,11 +139,13 @@ def test_bw_pass(
overlap_reduce_scatter=comm_opt_level >= 2,
cache_weights=comm_opt_level >= 3,
overlap_all_gather=comm_opt_level == 4,
model=layer,
model_object_for_overlapping_allgathers=layer,
):
Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
Y_local.backward(Y_local_grad)

if not easy_tp:
sync_gradients(layer)
if comm_opt_level >= 3:
clear_weights_cache()
# sequential backward pass