[hotfix] moe hybrid parallelism benchmark & follow-up fix (hpcaitech#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark
botbw authored Sep 10, 2024
1 parent 8fd25d6 commit c54c4fc
Showing 21 changed files with 907 additions and 99 deletions.
35 changes: 21 additions & 14 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -64,13 +64,18 @@ def __init__(
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
if dp_process_group is moe_dp_group:
pg_param_list = {
dp_process_group: list(model.parameters()),
}
else:
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}

if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
raise ValueError("No parameters found in dp_process_group or moe_dp_group")
if len(pg_param_list[moe_dp_group]) == 0:
raise ValueError("No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead")

super().__init__(
model=model,
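
The hunk above changes how the optimizer's parameter buckets are built: when the data-parallel group and the MoE data-parallel group are the same object, every parameter goes into a single bucket, and the old hard failure on either bucket being empty is relaxed to a check on the MoE bucket only. A minimal stand-alone sketch of that bucketing logic, with a toy predicate and plain string keys standing in for colossalai's is_moe_tensor and real process groups:

```python
# Sketch only: `is_moe` and the string group keys are stand-ins, not colossalai APIs.
import torch.nn as nn

def bucket_params(model: nn.Module, is_moe, dp_group, moe_dp_group):
    # Same-group case from the diff: one bucket holding every parameter.
    if dp_group is moe_dp_group:
        return {dp_group: list(model.parameters())}
    buckets = {
        dp_group: [p for p in model.parameters() if not is_moe(p)],
        moe_dp_group: [p for p in model.parameters() if is_moe(p)],
    }
    if len(buckets[moe_dp_group]) == 0:
        raise ValueError(
            "No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead"
        )
    return buckets

# Toy usage: tag the second linear layer's parameters as "MoE" by identity.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
moe_ids = {id(p) for p in model[1].parameters()}
buckets = bucket_params(model, lambda p: id(p) in moe_ids, "dp", "moe_dp")
print({k: len(v) for k, v in buckets.items()})  # {'dp': 2, 'moe_dp': 2}
```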
@@ -407,24 +412,25 @@ def configure(
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)

# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group

if use_ddp:
self.logger.warning(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
ranks=[0],
)
self.ddp_config["find_unused_parameters"] = True

if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
)

# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group

model = HybridParallelModule(
module=model,
precision=self.precision,
Expand Down Expand Up @@ -466,6 +472,7 @@ def configure(
tp_process_group=self.tp_group,
)
else:
is_zero = True
if self.dp_size <= 1:
self.logger.warning(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
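
The `find_unused_parameters` change above exists because top-k routing can leave some experts without any tokens in a step, so their parameters receive no gradient and vanilla DDP would wait on reducers that never fire. A stand-alone toy module (no colossalai or distributed setup) showing the unused-gradient pattern that forces the flag:

```python
import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

    def forward(self, x, expert_idx: int):
        # Only the routed expert participates in this step.
        return self.experts[expert_idx](x)

m = ToyMoE()
m(torch.randn(2, 4), expert_idx=0).sum().backward()
print(m.experts[0].weight.grad is None)  # False: the routed expert got a gradient
print(m.experts[1].weight.grad is None)  # True: the idle expert got none
# Wrapping such a module in torch.nn.parallel.DistributedDataParallel requires
# find_unused_parameters=True, which is what the plugin sets when use_ddp is true.
```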
7 changes: 2 additions & 5 deletions colossalai/moe/_operation.py
@@ -308,7 +308,7 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad * ctx.ep_size
grad.mul_(ctx.ep_size)
return grad, None


@@ -328,7 +328,7 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
assert len(grad_outputs) == 1
grad = grad_outputs[0]
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
grad.div_(ctx.ep_size)
return grad, None


@@ -449,7 +449,4 @@ def all_to_all_uneven(
overlap: bool = False,
fp8_communication: bool = False,
):
assert (
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
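
The _operation.py hunks switch the gradient scaling in the EP grad-scaler backward passes from out-of-place multiplication/division to in-place mul_/div_, saving one temporary tensor per call, and drop the requires_grad assertion on all_to_all_uneven inputs. A minimal custom autograd Function with the same shape as those scalers (identity forward, scaled gradient in backward), kept out-of-place here for simplicity:

```python
import torch
from torch.autograd import Function

class ScaleGradByEP(Function):
    """Identity in forward; multiplies the incoming gradient by ep_size in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, ep_size: int) -> torch.Tensor:
        ctx.ep_size = ep_size
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        if ctx.ep_size != 1:
            grad_output = grad_output * ctx.ep_size  # the commit does this in-place via mul_
        return grad_output, None

x = torch.ones(3, requires_grad=True)
ScaleGradByEP.apply(x, 4).sum().backward()
print(x.grad)  # tensor([4., 4., 4.])
```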
81 changes: 78 additions & 3 deletions colossalai/shardformer/modeling/deepseek.py
@@ -3,7 +3,7 @@

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.functional as F
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
@@ -28,11 +28,13 @@
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
linear_with_async_comm,
split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


@@ -58,7 +60,7 @@ def backward(ctx, grad_output):
return grad_output, grad_loss


class EPDeepseekMoE(nn.Module):
class EPDeepseekMoE(ParallelModule):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

@@ -214,6 +216,79 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return output_hidden_states


class DeepseekMoEGate_Col(ParallelModule):
def parallel_linear(self, hidden_states):
assert (
hidden_states.shape[-1] == self.weight.shape[-1]
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
hidden_states.shape, self.weight.shape, self.weight.shape[-1]
)

output = linear_with_async_comm(
hidden_states, self.weight, None, self.process_group, True, fp8_communication=self.fp8_communication
)

# All-gather across the partitions.
output = gather_forward_split_backward(
output, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
)
return output

def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = self.parallel_linear(hidden_states)
if self.scoring_func == "softmax":
scores = logits.softmax(dim=-1)
else:
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")

### select top-k experts
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator

### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(
1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None

return topk_idx, topk_weight, aux_loss

@staticmethod
def from_native_module(
module, process_group: ProcessGroup, config, gather_output, fp8_communication
) -> "DeepseekMoEGate_Col":
LazyInitContext.materialize(module)
module.process_group = process_group
module.fp8_communication = fp8_communication
sharded_weight = shard_rowwise(module.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, module.weight)
module.__class__ = DeepseekMoEGate_Col
return module


class DeepseekPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
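
The DeepseekMoEGate_Col added above follows the usual column-parallel linear pattern: each tensor-parallel rank keeps a row shard of the gate weight (shard_rowwise), computes its slice of the routing logits, and the slices are gathered along the last dimension so every rank ends up with scores for all experts (gather_output=True). A single-process sketch of that equivalence, with plain tensor chunks standing in for sharded weights and process groups:

```python
import torch

def column_parallel_gate(hidden: torch.Tensor, full_weight: torch.Tensor, tp_size: int):
    # shard_rowwise analogue: split the (n_experts, hidden) weight along dim 0.
    shards = full_weight.chunk(tp_size, dim=0)
    # Each "rank" computes logits for its shard of experts ...
    partial_logits = [hidden @ w.t() for w in shards]
    # ... then the pieces are gathered along the last dim (gather_output=True).
    return torch.cat(partial_logits, dim=-1)

torch.manual_seed(0)
hidden = torch.randn(5, 16)        # (tokens, hidden_size)
gate_weight = torch.randn(8, 16)   # (n_routed_experts, hidden_size)
reference = hidden @ gate_weight.t()
sharded = column_parallel_gate(hidden, gate_weight, tp_size=4)
print(torch.allclose(reference, sharded))  # True: sharded gate matches the dense gate
```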
4 changes: 2 additions & 2 deletions colossalai/shardformer/modeling/mixtral.py
@@ -36,7 +36,7 @@
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@@ -49,7 +49,7 @@
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
class EPMixtralSparseMoeBlock(ParallelModule):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

36 changes: 31 additions & 5 deletions colossalai/shardformer/policies/deepseek.py
@@ -10,6 +10,7 @@
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
from colossalai.shardformer.layer.linear import Linear1D_Row
from colossalai.shardformer.modeling.deepseek import (
DeepseekMoEGate_Col,
DeepseekPipelineForwards,
EPDeepseekMoE,
get_deepseek_flash_attention_forward,
@@ -56,16 +57,24 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size

# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
if sp_mode == "all_to_all":
num_q_heads //= sp_size
decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size,
"num_heads": num_q_heads,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads

policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)

if self.shard_config.enable_sequence_parallelism:
if self.pipeline_stage_manager is not None:
# NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
@@ -97,6 +106,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding

if self.shard_config.enable_tensor_parallelism:
# tensor parallelism for non-moe params
assert (
@@ -107,10 +117,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
}
num_q_heads //= tp_size
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": num_q_heads,
}
if num_kv_heads:
num_kv_heads //= tp_size
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads

policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@@ -135,8 +150,19 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription(
suffix="mlp.gate",
target_module=DeepseekMoEGate_Col,
kwargs={
"gather_output": True,
"fp8_communication": self.shard_config.fp8_communication,
"config": self.model.config,
},
ignore_if_not_exist=True,
),
],
)

if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
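
The policy hunks above replace the inline head-count arithmetic with running num_q_heads / num_kv_heads counters, so the same values can be divided once for all_to_all sequence parallelism and again for tensor parallelism, and so models without num_key_value_heads are skipped cleanly. A hedged sketch of that bookkeeping with a stand-in config class (Config here is illustrative, not a colossalai or transformers type):

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class Config:
    num_attention_heads: int = 32
    num_key_value_heads: Optional[int] = 8

def head_replacements(cfg: Config, sp_mode: str, sp_size: int, tp_size: int) -> Dict[str, int]:
    num_q_heads = cfg.num_attention_heads
    num_kv_heads = getattr(cfg, "num_key_value_heads", None)
    replacement: Dict[str, int] = {}
    if sp_mode == "all_to_all":
        num_q_heads //= sp_size
        replacement["num_heads"] = num_q_heads
        if num_kv_heads:
            num_kv_heads //= sp_size
            replacement["num_key_value_heads"] = num_kv_heads
    if tp_size > 1:
        num_q_heads //= tp_size
        replacement["self_attn.num_heads"] = num_q_heads
        if num_kv_heads:
            num_kv_heads //= tp_size
            replacement["self_attn.num_key_value_heads"] = num_kv_heads
    return replacement

# 32 query heads and 8 kv heads, split across sp_size=2 and then tp_size=4.
print(head_replacements(Config(), "all_to_all", sp_size=2, tp_size=4))
```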
22 changes: 16 additions & 6 deletions colossalai/shardformer/policies/mixtral.py
@@ -51,12 +51,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size

# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)

if sp_mode == "all_to_all":
num_q_heads //= sp_size
decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size,
"num_heads": num_q_heads,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
num_kv_heads //= sp_size
decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads

policy[attn_cls] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
@@ -101,12 +109,14 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
num_q_heads //= tp_size
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.num_key_value_heads": self.model.config.num_key_value_heads
// self.shard_config.tensor_parallel_size,
"self_attn.num_heads": num_q_heads,
}
if num_kv_heads:
num_kv_heads //= tp_size
decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads

policy[MixtralDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
Expand All @@ -131,7 +141,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
target_module=Linear1D_Row,
kwargs={"fp8_communication": self.shard_config.fp8_communication},
),
SubModuleReplacementDescription( # or replicate?
SubModuleReplacementDescription(
suffix="block_sparse_moe.gate",
target_module=Linear1D_Col,
kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
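
Both the Mixtral gate kept here (Linear1D_Col with gather_output=True) and the new Deepseek gate feed their gathered logits into the same top-k routing math: softmax scores, top-k selection, and optional renormalization so the kept weights sum to one. A dependency-free sketch of that routing step:

```python
import torch

def route(logits: torch.Tensor, top_k: int, norm_topk_prob: bool = True):
    scores = logits.softmax(dim=-1)
    topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)
    if top_k > 1 and norm_topk_prob:
        # Renormalize so the selected expert weights sum to 1 per token.
        topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
    return topk_idx, topk_weight

logits = torch.randn(6, 8)                           # (tokens, n_experts)
idx, weight = route(logits, top_k=2)
print(idx.shape, weight.shape, weight.sum(dim=-1))   # per-token weights sum to ~1
```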