make backward pass of fc layer asynchronous
siddharth9820 committed Oct 23, 2023
1 parent e447012 commit 7b2df38
Showing 1 changed file with 35 additions and 22 deletions.
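
The commit replaces the synchronous BackwardAllReduce/ForwardAllReduce wrappers around torch.nn.Linear with a custom autograd Function (AsyncLinear) whose backward pass launches the all-reduce of the input gradient asynchronously and overlaps it with the weight-gradient GEMM. Below is a minimal sketch of that overlap pattern, not the commit's code itself; the helper name backward_with_overlap and the single process-group argument are illustrative assumptions.

import torch
import torch.distributed as dist

def backward_with_overlap(grad_output, input_, weight, group):
    # Start reducing grad_input across the tensor-parallel group without blocking ...
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, group=group, async_op=True)
    # ... and compute the weight gradient while that communication is in flight.
    grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().mm(
        input_.view(-1, input_.shape[-1])
    )
    handle.wait()  # grad_input is only safe to read once the all-reduce completes
    return grad_input, grad_weight
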
57 changes: 35 additions & 22 deletions axonn/intra_layer/fully_connected.py
@@ -2,7 +2,8 @@
import torch.distributed as dist
import torch
from .communication import ForwardAllReduce, BackwardAllReduce, Drop

import torch.distributed as dist
from torch.autograd import Function

def divide(a, b):
    assert a % b == 0
@@ -21,6 +22,36 @@ def initialize_params(
    return params


class AsyncLinear(Function):
    @staticmethod
    def forward(ctx, input_, weight,
                forward_all_reduce_group,
                backward_all_reduce_group,
                backward_comm_async):
        ctx.save_for_backward(input_, weight)
        ctx.backward_all_reduce_group = backward_all_reduce_group
        ctx.backward_comm_async = backward_comm_async
        output = input_.matmul(weight.t())
        # The forward all-reduce stays synchronous; only the backward one is overlapped.
        dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        handle = None
        grad_input = grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
            # Launch the all-reduce of grad_input asynchronously so it can overlap
            # with the weight-gradient GEMM below.
            handle = dist.all_reduce(grad_input,
                                     group=ctx.backward_all_reduce_group,
                                     async_op=ctx.backward_comm_async)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().mm(
                input_.view(-1, input_.shape[-1])
            )
        if handle and ctx.backward_comm_async:
            handle.wait()
        return grad_input, grad_weight, None, None, None

class Linear(torch.nn.Module):
    def __init__(
        self,
@@ -66,16 +97,7 @@ def __init__(
            init_method,
        )

        self.linear = torch.nn.Linear(
            in_features=self.local_in_features,
            out_features=self.local_out_features,
            *args,
            **kwargs,
            bias=False
        )

        if init_method:
            self.linear.weight.data.copy_(initial_params)
        self.weight = torch.nn.Parameter(initial_params, requires_grad=True)

        self.bias = torch.nn.Parameter(
            torch.zeros(
@@ -90,18 +112,9 @@ def get_output_feature_size(self):

    def forward(self, x):
        if not self.transpose:
            if x.size(-1) == self.local_in_features * self.inner_group_size:
                x = Drop.apply(x, self.inner_group)
            x = BackwardAllReduce.apply(x, self.outer_group)
            x = self.linear(x)
            x = ForwardAllReduce.apply(x, self.inner_group)
            x = AsyncLinear.apply(x, self.weight, self.inner_group, self.outer_group, True)
        else:
            if x.size(-1) == self.local_in_features * self.outer_group_size:
                x = Drop.apply(x, self.outer_group)
            x = BackwardAllReduce.apply(x, self.inner_group)
            x = self.linear(x)
            x = ForwardAllReduce.apply(x, self.outer_group)
            x = AsyncLinear.apply(x, self.weight, self.outer_group, self.inner_group, True)

        if self.skip_bias_add:
            return x, self.bias
        else:

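For a quick end-to-end check of the new autograd function, a single-rank process group is enough. The sketch below is illustrative only: the gloo backend, the address/port, the toy tensor shapes, and running on a single rank are assumptions, not part of the commit, and it presumes the axonn package is importable.

import torch
import torch.distributed as dist
from axonn.intra_layer.fully_connected import AsyncLinear

# Single-process group purely for illustration; real axonn runs create their
# inner/outer groups during axonn's own initialization.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
group = dist.new_group(ranks=[0])

x = torch.randn(4, 8, requires_grad=True)   # (batch, local_in_features)
w = torch.randn(16, 8, requires_grad=True)  # (local_out_features, local_in_features)

# Forward all-reduce is synchronous; the last argument requests an async backward all-reduce.
y = AsyncLinear.apply(x, w, group, group, True)
y.sum().backward()

print(x.grad.shape, w.grad.shape)  # torch.Size([4, 8]) torch.Size([16, 8])
dist.destroy_process_group()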