diff --git a/dolomite_engine/hf_models/modeling_utils/activations/base.py b/dolomite_engine/hf_models/modeling_utils/activations/base.py
index 32170d6e..3a8d1550 100644
--- a/dolomite_engine/hf_models/modeling_utils/activations/base.py
+++ b/dolomite_engine/hf_models/modeling_utils/activations/base.py
@@ -12,6 +12,7 @@
     "hard_sigmoid": nn.modules.Hardsigmoid,
     "hard_swish": nn.modules.Hardswish,
     "hard_tanh": nn.modules.Hardtanh,
+    "identity": nn.modules.Identity,
     "laplace": ACT2CLS["laplace"],
     "leaky_reLU": nn.modules.LeakyReLU,
     "log_sigmoid": nn.modules.LogSigmoid,
diff --git a/dolomite_engine/hf_models/models/moe_dolomite/moe/scatter.py b/dolomite_engine/hf_models/models/moe_dolomite/moe/scatter.py
index bf6dcd17..6d47a8c5 100644
--- a/dolomite_engine/hf_models/models/moe_dolomite/moe/scatter.py
+++ b/dolomite_engine/hf_models/models/moe_dolomite/moe/scatter.py
@@ -45,16 +45,16 @@ def forward(
         grouped_out=False,
     ):
         results = scattered_experts(
-            inputs,
-            self.weight.permute(0, 2, 1),
-            k,
-            sorted_expert_idxs,
-            sorted_scattered_idxs,
-            padded_block_idxs,
-            expert_offsets,
-            gates,
-            grouped_in,
-            grouped_out,
+            inputs=inputs,
+            expert_weights=self.weight.permute(0, 2, 1),
+            k=k,
+            sorted_expert_idxs=sorted_expert_idxs,
+            sorted_scattered_idxs=sorted_scattered_idxs,
+            padded_block_idxs=padded_block_idxs,
+            expert_offsets=expert_offsets,
+            gates=gates,
+            grouped_in=grouped_in,
+            grouped_out=grouped_out,
         )

         return results
diff --git a/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py b/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py
index 824abea0..d2cce55d 100644
--- a/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py
+++ b/dolomite_engine/hf_models/models/moe_dolomite_TP/moe_TP/scatter.py
@@ -3,13 +3,14 @@
 import torch
 import torch.distributed
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.distributed._tensor.api import DTensor
 from torch.distributed._tensor.placement_types import Replicate, Shard
 from torch.distributed._tensor.placement_types import _Partial as Partial

 from .....utils import ProcessGroupManager, is_scattermoe_available
 from ....enums import InitMethod
-from ....modeling_utils import get_activation_function, is_glu
+from ....modeling_utils import ParameterizedLinear, get_activation_function, is_glu
 from ....modeling_utils_TP import (
     DTensorModule,
     ReplicatedLinear,
@@ -28,7 +29,28 @@
     from scattermoe.parallel_experts import parallel_linear as scattered_experts


-class ColumnParallelScatteredExperts(ParameterizedScatteredExperts, DTensorModule):
+class ReplicatedRouter(ParameterizedLinear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        device: torch.device | None = None,
+        dtype: torch.dtype | None = None,
+        std: float | None = None,
+        use_padding_free_transformer: bool = False,
+        sequence_parallel: bool = False,
+    ) -> None:
+        super().__init__(in_features, out_features, bias, device, dtype, std)
+
+        self.weight = nn.Parameter(
+            DTensor.from_local(
+                self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), placements=[Replicate()]
+            )
+        )
+
+
+class ColumnParallelScatteredExperts(ParameterizedScatteredExperts):
     def __init__(
         self,
         num_experts: int,
@@ -67,8 +89,8 @@ def __init__(
                 run_check=False,
             )
         )
-
-        self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
+        # input placement handling is moved into ScatterMoE_TP.forward
+        # self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)

     def forward(
         self,
@@ -84,8 +106,6 @@ def forward(
     ):
         # F.linear manually triggers an all gather for sequence parallel but custom kernels are not aware of the placements
         # so we manually call an all gather here
-        inputs = tensor_to_dtensor(inputs, current_placement=self.input_placement)
-        inputs = dtensor_to_tensor(inputs, desired_placement=Replicate(), grad_placement=Partial())

         weight = self.weight.to_local()

@@ -145,7 +165,7 @@ def __init__(
             )
         )

-        self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
+        # self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)

     def forward(
         self,
@@ -174,9 +194,6 @@ def forward(
             grouped_out,
         )

-        inputs = tensor_to_dtensor(inputs, current_placement=Partial())
-        inputs = dtensor_to_tensor(inputs, desired_placement=self.output_placement)
-
         return inputs


@@ -185,6 +202,7 @@ def __init__(
         self,
         config: MoEDolomiteConfig,
         use_padding_free_transformer: bool,
+        sequence_parallel: bool = False,
         layer_idx: int | None = None,
     ) -> None:
         nn.Module.__init__(self)
@@ -205,13 +223,13 @@ def __init__(
         init_method = InitMethod(config.init_method)
         residual_dropout = config.resid_pdrop

-        self.gate = ReplicatedLinear(
+        self.gate = ReplicatedRouter(
             in_features=self.hidden_size,
             out_features=config.num_experts,
             bias=False,
             std=config.initializer_range,
             use_padding_free_transformer=use_padding_free_transformer,
-            sequence_parallel=False,
+            sequence_parallel=False,  # the router weights stay replicated even under sequence parallelism
         )

         std = initializer_range
@@ -240,3 +258,35 @@ def __init__(
             )
         )
         self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout)
+        self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
+        self.output_placement = self.input_placement
+
+    def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # hidden_states -> (total_q, hidden_size)
+        router_logits = self.gate(hidden_states)
+        router_logits = dtensor_to_tensor(router_logits, desired_placement=Replicate(), grad_placement=Partial())
+        # router_logits -> (total_q, num_experts)
+
+        router_weights, selected_experts = self._get_topk(router_logits)
+        router_weights = F.softmax(router_weights.float(), dim=-1)
+
+        # we cast back to the input dtype
+        router_weights = router_weights.type_as(hidden_states)
+
+        return router_logits, router_weights, selected_experts
+
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = tensor_to_dtensor(hidden_states, current_placement=self.input_placement)
+
+        router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states)
+
+        hidden_states = dtensor_to_tensor(hidden_states, desired_placement=Replicate(), grad_placement=Partial())
+
+        hidden_states = self._compute_experts(hidden_states, router_weights, selected_experts)

+        hidden_states = tensor_to_dtensor(hidden_states, current_placement=Partial())
+        hidden_states = dtensor_to_tensor(
+            hidden_states, desired_placement=self.output_placement, grad_placement=self.output_placement
+        )
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states, router_logits
diff --git a/test_scattermoe_sp.py b/test_scattermoe_sp.py
new file mode 100644
index 00000000..b8afdfd4
--- /dev/null
+++ b/test_scattermoe_sp.py
@@ -0,0 +1,184 @@
+import os
+
+import scattermoe
+import torch
+import torch.distributed
+from torch import nn
+from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.placement_types import Replicate, Shard
+from transformers import set_seed
+
+from dolomite_engine.hf_models.modeling_utils_TP.TP import (
+    dtensor_to_tensor,
+    get_module_placements,
+    modify_state_dict_to_dtensor_dict,
+    tensor_parallel_split_safetensor_slice,
+    tensor_to_dtensor,
+)
+from dolomite_engine.hf_models.models.moe_dolomite.config import MoEDolomiteConfig
+from dolomite_engine.hf_models.models.moe_dolomite.moe.scatter import ParameterizedScatteredExperts, ScatterMoE
+from dolomite_engine.hf_models.models.moe_dolomite_TP.moe_TP.scatter import ScatterMoE_TP
+from dolomite_engine.utils import ProcessGroupManager
+
+
+def load_dparams(module, name, tensor):
+    device_mesh = getattr(module, name).device_mesh
+    placements = getattr(module, name).placements
+    setattr(module, name, nn.Parameter(DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements)))
+
+
+set_seed(42)
+tp_size = int(os.getenv("WORLD_SIZE"))
+ProcessGroupManager(tensor_parallel_size=tp_size)
+rank = torch.distributed.get_rank()
+torch_dtype = torch.float32
+
+config = MoEDolomiteConfig(
+    n_embd=2048,
+    n_inner=2048,
+    num_experts=16,
+    num_experts_per_tok=2,
+    activation_function="relu",
+    add_bias=False,
+    embd_pdrop=0.0,
+    resid_pdrop=0.0,
+)
+
+if rank == 0:
+    print(config)
+
+batch_size = 1024
+# ones = torch.ones(config.num_experts, device=torch.cuda.current_device(), dtype=torch_dtype)
+# eye = torch.eye(config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)
+# expert_idxs = 1 + torch.arange(config.num_experts, device=torch.cuda.current_device(), dtype=torch_dtype)
+# batch_idxs = 1 + torch.arange(batch_size, device=torch.cuda.current_device(), dtype=torch_dtype)
+# dim_idxs = 1 + torch.arange(config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)
+
+local_moe = ScatterMoE(config, use_padding_free_transformer=True, layer_idx=0)
+local_moe = local_moe.to(device=torch.cuda.current_device(), dtype=torch_dtype)
+shard_moe = ScatterMoE_TP(config, use_padding_free_transformer=True, sequence_parallel=True, layer_idx=0).to(
+    device=torch.cuda.current_device(), dtype=torch_dtype
+)
+input_tensor = 0.02 * torch.randn(
+    batch_size, config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype, requires_grad=True
+)
+gate_weight = 0.02 * torch.randn_like(local_moe.gate.weight, requires_grad=True)
+c_fc_weight = 0.02 * torch.randn_like(local_moe.c_fc.weight)
+c_proj_weight = 0.02 * torch.randn_like(local_moe.c_proj.weight)
+grad_tensor = 0.02 * torch.randn(batch_size, config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)
+
+torch.distributed.broadcast(input_tensor, 0)
+torch.distributed.broadcast(gate_weight, 0)
+torch.distributed.broadcast(c_fc_weight, 0)
+torch.distributed.broadcast(c_proj_weight, 0)
+torch.distributed.broadcast(grad_tensor, 0)
+
+
+if rank == 0:
+    print("Rank", rank)
+    print(local_moe)
+    print([(n, p.size()) for n, p in local_moe.named_parameters()])
+    print(shard_moe)
+    print([(n, p.size()) for n, p in shard_moe.named_parameters()])
+
+if rank == 0:
+    print("Distributing local_moe params...")
+
+params_dict = {"gate.weight": gate_weight, "c_fc.weight": c_fc_weight, "c_proj.weight": c_proj_weight}
+local_moe.load_state_dict(params_dict)
+torch.distributed.barrier()
+
+if rank == 0:
+    print("Distributing shard_moe params...")
+
+# shard_moe.gate.load_state_dict({"weight": gate_weight})
+load_dparams(shard_moe.gate, "weight", gate_weight)
+
+# sharded_inter_dim = shard_moe.c_proj.in_features_per_device
+# c_fc_1_weight, c_fc_2_weight = c_fc_weight.chunk(2, dim=1)
+# shard_moe.c_fc.load_state_dict(
+#     {
+#         "weight": torch.cat(
+#             (
+#                 c_fc_1_weight[:, sharded_inter_dim * rank : (rank + 1) * sharded_inter_dim, :],
+#                 c_fc_2_weight[:, sharded_inter_dim * rank : (rank + 1) * sharded_inter_dim, :],
+#             ),
+#             dim=1,
+#         )
+#     }
+# )
+# shard_moe.c_fc.load_state_dict({"weight": c_fc_weight.view(c_fc_weight.size(0), tp_size, -1, c_fc_weight.size(2))[:, rank]})
+
+load_dparams(
+    shard_moe.c_fc, "weight", c_fc_weight.view(c_fc_weight.size(0), tp_size, -1, c_fc_weight.size(2))[:, rank]
+)
+
+# shard_moe.c_proj.load_state_dict({"weight": c_proj_weight.view(c_proj_weight.size(0), c_proj_weight.size(1), tp_size, -1)[:, :, rank]})
+load_dparams(
+    shard_moe.c_proj,
+    "weight",
+    c_proj_weight.view(c_proj_weight.size(0), c_proj_weight.size(1), tp_size, -1)[:, :, rank],
+)
+
+torch.distributed.barrier()
+local_input_tensor = input_tensor
+shard_input_tensor = input_tensor.clone().chunk(tp_size, dim=0)[rank]
+
+
+local_out, local_logits, _ = local_moe(local_input_tensor)
+shard_out, shard_logits = shard_moe(shard_input_tensor)
+
+local_input_tensor_grad, local_gate_weight_grad = torch.autograd.grad(
+    outputs=(local_out),
+    inputs=(local_input_tensor, local_moe.gate.weight),
+    grad_outputs=(grad_tensor,),
+)
+
+shard_input_tensor_grad, shard_gate_weight_grad = torch.autograd.grad(
+    outputs=(shard_out),
+    inputs=(shard_input_tensor, shard_moe.gate.weight),
+    grad_outputs=(grad_tensor.chunk(tp_size, dim=0)[rank],),
+)
+
+shard_gate_weight_grad = dtensor_to_tensor(shard_gate_weight_grad, desired_placement=Replicate())
+torch.distributed.barrier()
+# print(list(shard_moe.parameters()))
+# print(list(local_moe.parameters()))
+if rank == 0:
+    print("Error:")
+    print()
+    print("logits:")
+for r in range(tp_size):
+    if rank == r:
+        print("Rank %d:" % r, (local_logits - shard_logits).abs().max())
+    torch.distributed.barrier()
+
+
+if rank == 0:
+    print()
+    print("out:")
+
+for r in range(tp_size):
+    if rank == r:
+        print("Rank %d:" % r, (local_out.chunk(tp_size, dim=0)[rank] - shard_out).abs().max())
+    torch.distributed.barrier()
+
+if rank == 0:
+    print()
+    print("input grad:")
+for r in range(tp_size):
+    if rank == r:
+        diff = (local_input_tensor_grad.chunk(tp_size, dim=0)[rank] - shard_input_tensor_grad).abs()
+        print("Rank %d:" % r, diff.max())
+    torch.distributed.barrier()
+
+if rank == 0:
+    print()
+    print("gate grad:")
+
+for r in range(tp_size):
+    if rank == r:
+        print("Rank %d:" % r, (local_gate_weight_grad - shard_gate_weight_grad).abs().max())
+    torch.distributed.barrier()
+
+ProcessGroupManager.destroy_process_groups()
diff --git a/test_scattermoe_tp.py b/test_scattermoe_tp.py
new file mode 100644
index 00000000..5d919295
--- /dev/null
+++ b/test_scattermoe_tp.py
@@ -0,0 +1,184 @@
+import os
+
+import scattermoe
+import torch
+import torch.distributed
+from torch import nn
+from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.placement_types import Replicate, Shard
+from transformers import set_seed
+
+from dolomite_engine.hf_models.modeling_utils_TP.TP import (
+    dtensor_to_tensor,
+    get_module_placements,
+    modify_state_dict_to_dtensor_dict,
+    tensor_parallel_split_safetensor_slice,
+    tensor_to_dtensor,
+)
+from dolomite_engine.hf_models.models.moe_dolomite.config import MoEDolomiteConfig
+from dolomite_engine.hf_models.models.moe_dolomite.moe.scatter import ParameterizedScatteredExperts, ScatterMoE
+from dolomite_engine.hf_models.models.moe_dolomite_TP.moe_TP.scatter import ScatterMoE_TP
+from dolomite_engine.utils import ProcessGroupManager
+
+
+def load_dparams(module, name, tensor):
+    device_mesh = getattr(module, name).device_mesh
+    placements = getattr(module, name).placements
+    setattr(module, name, nn.Parameter(DTensor.from_local(tensor, device_mesh=device_mesh, placements=placements)))
+
+
+set_seed(42)
+tp_size = int(os.getenv("WORLD_SIZE"))
+ProcessGroupManager(tensor_parallel_size=tp_size)
+rank = torch.distributed.get_rank()
+torch_dtype = torch.float32
+
+config = MoEDolomiteConfig(
+    n_embd=2048,
+    n_inner=2048,
+    num_experts=16,
+    num_experts_per_tok=2,
+    activation_function="relu",
+    add_bias=False,
+    embd_pdrop=0.0,
+    resid_pdrop=0.0,
+)
+
+if rank == 0:
+    print(config)
+
+batch_size = 128
+# ones = torch.ones(config.num_experts, device=torch.cuda.current_device(), dtype=torch_dtype)
+# eye = torch.eye(config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)
+# expert_idxs = 1 + torch.arange(config.num_experts, device=torch.cuda.current_device(), dtype=torch_dtype)
+# batch_idxs = 1 + torch.arange(batch_size, device=torch.cuda.current_device(), dtype=torch_dtype)
+# dim_idxs = 1 + torch.arange(config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)
+
+local_moe = ScatterMoE(config, use_padding_free_transformer=True, layer_idx=0)
+local_moe = local_moe.to(device=torch.cuda.current_device(), dtype=torch_dtype)
+shard_moe = ScatterMoE_TP(config, use_padding_free_transformer=True, layer_idx=0).to(
+    device=torch.cuda.current_device(), dtype=torch_dtype
+)
+input_tensor = 0.02 * torch.randn(
+    batch_size, config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype, requires_grad=True
+)
+gate_weight = 0.02 * torch.randn_like(local_moe.gate.weight, requires_grad=True)
+c_fc_weight = 0.02 * torch.randn_like(local_moe.c_fc.weight)
+c_proj_weight = 0.02 * torch.randn_like(local_moe.c_proj.weight)
+grad_tensor = 0.02 * torch.randn(batch_size, config.n_embd, device=torch.cuda.current_device(), dtype=torch_dtype)
+
+torch.distributed.broadcast(input_tensor, 0)
+torch.distributed.broadcast(gate_weight, 0)
+torch.distributed.broadcast(c_fc_weight, 0)
+torch.distributed.broadcast(c_proj_weight, 0)
+torch.distributed.broadcast(grad_tensor, 0)
+
+
+if rank == 0:
+    print("Rank", rank)
+    print(local_moe)
+    print([(n, p.size()) for n, p in local_moe.named_parameters()])
+    print(shard_moe)
+    print([(n, p.size()) for n, p in shard_moe.named_parameters()])
+
+if rank == 0:
+    print("Distributing local_moe params...")
+
+params_dict = {"gate.weight": gate_weight, "c_fc.weight": c_fc_weight, "c_proj.weight": c_proj_weight}
+local_moe.load_state_dict(params_dict)
+torch.distributed.barrier()
+
+if rank == 0:
+    print("Distributing shard_moe params...")
+
+# shard_moe.gate.load_state_dict({"weight": gate_weight})
+load_dparams(shard_moe.gate, "weight", gate_weight)
+if False:
+    sharded_inter_dim = shard_moe.c_proj.in_features_per_device
+    c_fc_1_weight, c_fc_2_weight = c_fc_weight.chunk(2, dim=1)
+    shard_moe.c_fc.load_state_dict(
+        {
+            "weight": torch.cat(
+                (
+                    c_fc_1_weight[:, sharded_inter_dim * rank : (rank + 1) * sharded_inter_dim, :],
+                    c_fc_2_weight[:, sharded_inter_dim * rank : (rank + 1) * sharded_inter_dim, :],
+                ),
+                dim=1,
+            )
+        }
+    )
+else:
+    # shard_moe.c_fc.load_state_dict({"weight": c_fc_weight.view(c_fc_weight.size(0), tp_size, -1, c_fc_weight.size(2))[:, rank]})
+    load_dparams(
+        shard_moe.c_fc, "weight", c_fc_weight.view(c_fc_weight.size(0), tp_size, -1, c_fc_weight.size(2))[:, rank]
+    )
+
+# shard_moe.c_proj.load_state_dict({"weight": c_proj_weight.view(c_proj_weight.size(0), c_proj_weight.size(1), tp_size, -1)[:, :, rank]})
+load_dparams(
+    shard_moe.c_proj,
+    "weight",
+    c_proj_weight.view(c_proj_weight.size(0), c_proj_weight.size(1), tp_size, -1)[:, :, rank],
+)
+
+torch.distributed.barrier()
+local_input_tensor = input_tensor
+shard_input_tensor = input_tensor.clone()
+
+local_out, local_logits, _ = local_moe(local_input_tensor)
+shard_out, shard_logits = shard_moe(shard_input_tensor)
+
+local_input_tensor_grad, local_gate_weight_grad = torch.autograd.grad(
+    outputs=(local_out),
+    inputs=(local_input_tensor, local_moe.gate.weight),
+    grad_outputs=(grad_tensor,),
+)
+
+shard_input_tensor_grad, shard_gate_weight_grad = torch.autograd.grad(
+    outputs=(shard_out),
+    inputs=(shard_input_tensor, shard_moe.gate.weight),
+    grad_outputs=(grad_tensor,),
+)
+
+shard_gate_weight_grad = dtensor_to_tensor(shard_gate_weight_grad, desired_placement=Replicate())
+
+torch.distributed.barrier()
+# print(list(shard_moe.parameters()))
+# print(list(local_moe.parameters()))
+if rank == 0:
+    print("Error:")
+    print()
+    print("logits:")
+for r in range(tp_size):
+    if rank == r:
+        print("Rank %d:" % r, (local_logits - shard_logits).abs().max())
+    torch.distributed.barrier()
+
+
+if rank == 0:
+    print()
+    print("out:")
+
+for r in range(tp_size):
+    if rank == r:
+        print("Rank %d:" % r, (local_out - shard_out).abs().max())
+    torch.distributed.barrier()
+
+if rank == 0:
+    print()
+    print("input grad:")
+for r in range(tp_size):
+    if rank == r:
+        diff = (local_input_tensor_grad - shard_input_tensor_grad).abs()
+        print("Rank %d:" % r, diff.max())
+    torch.distributed.barrier()
+
+if rank == 0:
+    print()
+    print("gate grad:")
+
+for r in range(tp_size):
+    if rank == r:
+        print("Rank %d:" % r, (local_gate_weight_grad - shard_gate_weight_grad).abs().max())
+    torch.distributed.barrier()
+
+ProcessGroupManager.destroy_process_groups()
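
For reference, below is a standalone sketch (not part of the patch) of the placement round-trip that this change moves out of the Column/Row parallel expert layers and into ScatterMoE_TP.forward: all-gather the (possibly sequence-parallel) input before the placement-unaware scattermoe kernels, then reduce the partial row-parallel output back to the module's output placement. It reuses tensor_to_dtensor, dtensor_to_tensor and ProcessGroupManager from the diff above; the file name, the Shard(0) placement and the plain matmul standing in for the expert kernels are illustrative assumptions only. Both test scripts read WORLD_SIZE, so a launch along the lines of "torchrun --nproc_per_node=2 test_scattermoe_tp.py" is assumed for them as well.

# placement_roundtrip_sketch.py -- illustrative only; launch with torchrun, which sets WORLD_SIZE:
#   torchrun --nproc_per_node=2 placement_roundtrip_sketch.py
import os

import torch
import torch.distributed
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor.placement_types import _Partial as Partial

from dolomite_engine.hf_models.modeling_utils_TP.TP import dtensor_to_tensor, tensor_to_dtensor
from dolomite_engine.utils import ProcessGroupManager

ProcessGroupManager(tensor_parallel_size=int(os.getenv("WORLD_SIZE")))
device = torch.cuda.current_device()

# sequence-parallel layout: every rank holds its own shard of the token dimension
local_tokens = torch.randn(8, 16, device=device)
input_placement = Shard(0)  # assumption: what get_module_placements returns for SP + padding-free

# input side: Shard(0) -> Replicate() is the manual all-gather that previously lived in
# ColumnParallelScatteredExperts.forward and now runs once in ScatterMoE_TP.forward
x = tensor_to_dtensor(local_tokens, current_placement=input_placement)
x = dtensor_to_tensor(x, desired_placement=Replicate(), grad_placement=Partial())

# ... the scattermoe expert kernels would run here on the gathered tokens ...
y = x @ torch.eye(16, device=device)  # stand-in for the row-parallel c_proj output

# output side: each rank's row-parallel result is a partial sum, so mark it Partial()
# and bring it back to the sequence-parallel output placement (a reduce-scatter),
# mirroring the tail of ScatterMoE_TP.forward
y = tensor_to_dtensor(y, current_placement=Partial())
y = dtensor_to_tensor(y, desired_placement=input_placement, grad_placement=input_placement)

ProcessGroupManager.destroy_process_groups()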