From 34abe16c2c2774229e671b8bed0c9484b065a12b Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 17 Oct 2023 16:17:59 -0500 Subject: [PATCH] reformat and change Tensor_Parallel_Linear to Linear --- axonn/intra_layer/__init__.py | 2 +- axonn/intra_layer/fully_connected.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py index 57efc32..cce4cb5 100644 --- a/axonn/intra_layer/__init__.py +++ b/axonn/intra_layer/__init__.py @@ -1,4 +1,4 @@ -from .fully_connected import Linear as Tensor_Parallel_Linear # noqa: F401 +from .fully_connected import Linear # noqa: F401 from .communication import Drop, Gather from axonn import axonn as ax diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py index d924824..61ac6ef 100644 --- a/axonn/intra_layer/fully_connected.py +++ b/axonn/intra_layer/fully_connected.py @@ -108,7 +108,9 @@ def forward(self, x, scatter_input=True, gather_output=True): bias = self.bias if gather_output: - bias = Gather.apply(self.bias, self.outer_group if not self.transpose else self.inner_group) + bias = Gather.apply( + self.bias, self.outer_group if not self.transpose else self.inner_group + ) if self.skip_bias_add: return x, bias