ScatterMoE backward fix. #4

Closed
wants to merge 2 commits into from
@@ -12,6 +12,7 @@
"hard_sigmoid": nn.modules.Hardsigmoid,
"hard_swish": nn.modules.Hardswish,
"hard_tanh": nn.modules.Hardtanh,
"identity": nn.modules.Identity,
"laplace": ACT2CLS["laplace"],
"leaky_reLU": nn.modules.LeakyReLU,
"log_sigmoid": nn.modules.LogSigmoid,
20 changes: 10 additions & 10 deletions dolomite_engine/hf_models/models/moe_dolomite/moe/scatter.py
@@ -45,16 +45,16 @@ def forward(
grouped_out=False,
):
results = scattered_experts(
inputs,
self.weight.permute(0, 2, 1),
k,
sorted_expert_idxs,
sorted_scattered_idxs,
padded_block_idxs,
expert_offsets,
gates,
grouped_in,
grouped_out,
inputs=inputs,
expert_weights=self.weight.permute(0, 2, 1),
k=k,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
padded_block_idxs=padded_block_idxs,
expert_offsets=expert_offsets,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)

return results
74 changes: 62 additions & 12 deletions dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py
@@ -3,13 +3,14 @@
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor.placement_types import _Partial as Partial

from .....utils import ProcessGroupManager, is_scattermoe_available
from ....enums import InitMethod
from ....modeling_utils import get_activation_function, is_glu
from ....modeling_utils import ParameterizedLinear, get_activation_function, is_glu
from ....modeling_utils_TP import (
DTensorModule,
ReplicatedLinear,
@@ -28,7 +29,28 @@
from scattermoe.parallel_experts import parallel_linear as scattered_experts


class ColumnParallelScatteredExperts(ParameterizedScatteredExperts, DTensorModule):
class ReplicatedRouter(ParameterizedLinear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
std: float | None = None,
use_padding_free_transformer: bool = False,
sequence_parallel: bool = False,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype, std)

self.weight = nn.Parameter(
DTensor.from_local(
self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()]
)
)


class ColumnParallelScatteredExperts(ParameterizedScatteredExperts):
def __init__(
self,
num_experts: int,
@@ -67,8 +89,8 @@ def __init__(
run_check=False,
)
)

self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
# Put in MLP
# self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)

def forward(
self,
@@ -84,8 +106,6 @@
):
# F.linear manually triggers an all gather for sequence parallel but custom kernels are not aware of the placements
# so we manually call an all gather here
inputs = tensor_to_dtensor(inputs, current_placement=self.input_placement)
inputs = dtensor_to_tensor(inputs, desired_placement=Replicate(), grad_placement=Partial())

weight = self.weight.to_local()

@@ -145,7 +165,7 @@ def __init__(
)
)

self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
# self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)

def forward(
self,
@@ -174,9 +194,6 @@ def forward(
grouped_out,
)

inputs = tensor_to_dtensor(inputs, current_placement=Partial())
inputs = dtensor_to_tensor(inputs, desired_placement=self.output_placement)

return inputs


@@ -185,6 +202,7 @@ def __init__(
self,
config: MoEDolomiteConfig,
use_padding_free_transformer: bool,
sequence_parallel: bool = False,
layer_idx: int | None = None,
) -> None:
nn.Module.__init__(self)
@@ -205,13 +223,13 @@ def __init__(
init_method = InitMethod(config.init_method)
residual_dropout = config.resid_pdrop

self.gate = ReplicatedLinear(
self.gate = ReplicatedRouter(
in_features=self.hidden_size,
out_features=config.num_experts,
bias=False,
std=config.initializer_range,
use_padding_free_transformer=use_padding_free_transformer,
sequence_parallel=False,
sequence_parallel=False, # replicate even if SP
)

std = initializer_range
@@ -240,3 +258,35 @@
)

self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout)
self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
self.output_placement = self.input_placement

def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# hidden_states -> (total_q, hidden_size)
router_logits = self.gate(hidden_states)
router_logits = dtensor_to_tensor(router_logits, desired_placement=Replicate(), grad_placement=Partial())
# router_logits -> (total_q, num_experts)

router_weights, selected_experts = self._get_topk(router_logits)
router_weights = F.softmax(router_weights.float(), dim=-1)

# we cast back to the input dtype
router_weights = router_weights.type_as(hidden_states)

return router_logits, router_weights, selected_experts

def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = tensor_to_dtensor(hidden_states, current_placement=self.input_placement)

router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states)

hidden_states = dtensor_to_tensor(hidden_states, desired_placement=Replicate(), grad_placement=Partial())

hidden_states = self._compute_experts(hidden_states, router_weights, selected_experts)

hidden_states = tensor_to_dtensor(hidden_states, current_placement=Partial())
hidden_states = dtensor_to_tensor(
hidden_states, desired_placement=self.output_placement, grad_placement=self.output_placement
)
hidden_states = self.dropout(hidden_states)
return hidden_states, router_logits
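
For context on the placement handling above (the comment in ColumnParallelScatteredExperts.forward notes that the custom kernels are not placement-aware, so the all-gather is done manually, and this PR moves that handling into ScatterMoE_TP.forward), here is a minimal standalone sketch of the same round trip: sequence-sharded input, replicated plain tensor for the expert kernel, Partial output, reduced back to the sequence-sharded layout. It uses raw DTensor APIs rather than the repo's tensor_to_dtensor/dtensor_to_tensor helpers; the mesh setup, torchrun launch, and the stand-in kernel are illustrative assumptions, not part of the PR.

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor.placement_types import _Partial as Partial
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical setup, not from the PR: a 1-D mesh over all ranks under torchrun.
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Sequence-parallel activations: each rank holds a slice of the tokens along dim 0.
local_tokens = torch.randn(4, 8, device="cuda")
x = DTensor.from_local(local_tokens, device_mesh=mesh, placements=[Shard(0)])

# The kernel is placement-unaware, so gather a replicated plain tensor first (all-gather).
x_full = x.redistribute(device_mesh=mesh, placements=[Replicate()]).to_local()

# Stand-in for the column- then row-parallel expert kernels: each rank produces a
# partial sum over its shard of the intermediate dimension.
y_partial = x_full * 2

# Tag the plain-tensor output as Partial, then reduce-scatter back to sequence parallel.
y = DTensor.from_local(y_partial, device_mesh=mesh, placements=[Partial()])
y = y.redistribute(device_mesh=mesh, placements=[Shard(0)])
assert y.to_local().shape == local_tokens.shape

dist.destroy_process_group()
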
184 changes: 184 additions & 0 deletions test_scattermoe_sp.py
@@ -0,0 +1,184 @@
import os

import scattermoe
import torch
import torch.distributed
from torch import nn
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from transformers import set_seed

from dolomite_engine.hf_models.modeling_utils_TP.TP import (
dtensor_to_tensor,
get_module_placements,
modify_state_dict_to_dtensor_dict,
tensor_parallel_split_safetensor_slice,
tensor_to_dtensor,
)
from dolomite_engine.hf_models.models.moe_dolomite.config import MoEDolomiteConfig
from dolomite_engine.hf_models.models.moe_dolomite.moe.scatter import ParameterizedScatteredExperts, ScatterMoE
from dolomite_engine.hf_models.models.moe_dolomite_TP.moe_TP.scatter import ScatterMoE_TP
from dolomite_engine.utils import ProcessGroupManager


def load_dparams(module, name, tensor):
device_mesh = getattr(module, name).device_mesh
placements = getattr(module, name).placements
setattr(module, name, nn.Parameter(DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements)))


set_seed(42)
tp_size = int(os.getenv("WORLD_SIZE"))
ProcessGroupManager(tensor_parallel_size=tp_size)
rank = torch.distributed.get_rank()
torch_dtype = torch.float32

config = MoEDolomiteConfig(
n_embd=2048,
n_inner=2048,
num_experts=16,
num_experts_per_tok=2,
activation_function="relu",
add_bias=False,
embd_pdrop=0.0,
resid_pdrop=0.0,
)

if rank == 0:
print(config)

batch_size = 1024
# ones = torch.ones(config.num_experts, device=torch.cuda.current_device(), dtype=torch_dtype)
# eye = torch.eye(config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)
# expert_idxs = 1 + torch.arange(config.num_experts, device=torch.cuda.current_device(), dtype=torch_dtype)
# batch_idxs = 1 + torch.arange(batch_size, device=torch.cuda.current_device(), dtype=torch_dtype)
# dim_idxs = 1 + torch.arange(config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)

local_moe = ScatterMoE(config, use_padding_free_transformer=True, layer_idx=0)
local_moe = local_moe.to(device=torch.cuda.current_device(), dtype=torch_dtype)
shard_moe = ScatterMoE_TP(config, use_padding_free_transformer=True, sequence_parallel=True, layer_idx=0).to(
device=torch.cuda.current_device(), dtype=torch_dtype
)
input_tensor = 0.02 * torch.randn(
batch_size, config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype, requires_grad=True
)
gate_weight = 0.02 * torch.randn_like(local_moe.gate.weight, requires_grad=True)
c_fc_weight = 0.02 * torch.randn_like(local_moe.c_fc.weight)
c_proj_weight = 0.02 * torch.randn_like(local_moe.c_proj.weight)
grad_tensor = 0.02 * torch.randn(batch_size, config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)

torch.distributed.broadcast(input_tensor, 0)
torch.distributed.broadcast(gate_weight, 0)
torch.distributed.broadcast(c_fc_weight, 0)
torch.distributed.broadcast(c_proj_weight, 0)
torch.distributed.broadcast(grad_tensor, 0)


if rank == 0:
print("Rank", rank)
print(local_moe)
print([(n, p.size()) for n, p in local_moe.named_parameters()])
print(shard_moe)
print([(n, p.size()) for n, p in shard_moe.named_parameters()])

if rank == 0:
print("Distributing local_moe params...")

params_dict = {"gate.weight": gate_weight, "c_fc.weight": c_fc_weight, "c_proj.weight": c_proj_weight}
local_moe.load_state_dict(params_dict)
torch.distributed.barrier()

if rank == 0:
print("Distributing shard_moe params...")

# shard_moe.gate.load_state_dict({"weight": gate_weight})
load_dparams(shard_moe.gate, "weight", gate_weight)

# sharded_inter_dim = shard_moe.c_proj.in_features_per_device
# c_fc_1_weight, c_fc_2_weight = c_fc_weight.chunk(2, dim=1)
# shard_moe.c_fc.load_state_dict(
# {
# "weight": torch.cat(
# (
# c_fc_1_weight[:, sharded_inter_dim * rank : (rank + 1) * sharded_inter_dim, :],
# c_fc_2_weight[:, sharded_inter_dim * rank : (rank + 1) * sharded_inter_dim, :],
# ),
# dim=1,
# )
# }
# )
# shard_moe.c_fc.load_state_dict({"weight": c_fc_weight.view(c_fc_weight.size(0), tp_size, -1, c_fc_weight.size(2))[:, rank]})

load_dparams(
shard_moe.c_fc, "weight", c_fc_weight.view(c_fc_weight.size(0), tp_size, -1, c_fc_weight.size(2))[:, rank]
)

# shard_moe.c_proj.load_state_dict({"weight": c_proj_weight.view(c_proj_weight.size(0), c_proj_weight.size(1), tp_size, -1)[:, :, rank]})
load_dparams(
shard_moe.c_proj,
"weight",
c_proj_weight.view(c_proj_weight.size(0), c_proj_weight.size(1), tp_size, -1)[:, :, rank],
)

torch.distributed.barrier()
local_input_tensor = input_tensor
shard_input_tensor = input_tensor.clone().chunk(tp_size, dim=0)[rank]


local_out, local_logits, _ = local_moe(local_input_tensor)
shard_out, shard_logits = shard_moe(shard_input_tensor)

local_input_tensor_grad, local_gate_weight_grad = torch.autograd.grad(
outputs=(local_out),
inputs=(local_input_tensor, local_moe.gate.weight),
grad_outputs=(grad_tensor,),
)

shard_input_tensor_grad, shard_gate_weight_grad = torch.autograd.grad(
outputs=(shard_out),
inputs=(shard_input_tensor, shard_moe.gate.weight),
grad_outputs=(grad_tensor.chunk(tp_size, dim=0)[rank],),
)

shard_gate_weight_grad = dtensor_to_tensor(shard_gate_weight_grad, desired_placement=Replicate())
torch.distributed.barrier()
# print(list(shard_moe.parameters()))
# print(list(local_moe.parameters()))
if rank == 0:
print("Error:")
print()
print("logits:")
for r in range(tp_size):
if rank == r:
print("Rank %d:" % r, (local_logits - shard_logits).abs().max())
torch.distributed.barrier()


if rank == 0:
print()
print("out:")

for r in range(tp_size):
if rank == r:
print("Rank %d:" % r, (local_out.chunk(tp_size, dim=0)[rank] - shard_out).abs().max())
torch.distributed.barrier()

if rank == 0:
print()
print("input grad:")
for r in range(tp_size):
if rank == r:
diff = (local_input_tensor_grad.chunk(tp_size, dim=0)[rank] - shard_input_tensor_grad).abs()
print("Rank %d:" % r, diff.max())
torch.distributed.barrier()

if rank == 0:
print()
print("gate grad:")

for r in range(tp_size):
if rank == r:
print("Rank %d:" % r, (local_gate_weight_grad - shard_gate_weight_grad).abs().max())
torch.distributed.barrier()

ProcessGroupManager.destroy_process_groups()
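
The script derives the tensor-parallel size from the WORLD_SIZE environment variable, so it is presumably meant to be started with a distributed launcher that sets it, e.g. (assuming 2 GPUs) `torchrun --nproc_per_node=2 test_scattermoe_sp.py`.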