From 577f8079e1832dc79b64797411c8db5c3a7633e4 Mon Sep 17 00:00:00 2001
From: duanhx1037
Date: Tue, 4 Mar 2025 21:18:10 +0000
Subject: [PATCH] update seq split

---
 deepspeed/runtime/domino/transformer.py | 121 ++++++++++++++++++++----
 1 file changed, 100 insertions(+), 21 deletions(-)

diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py
index 88c5494c8147..e1334c7abe1c 100644
--- a/deepspeed/runtime/domino/transformer.py
+++ b/deepspeed/runtime/domino/transformer.py
@@ -118,7 +118,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
         context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                           key_layer,
                                                                           value_layer,
-                                                                          attn_mask=None,
+                                                                          attn_mask=attention_mask,
                                                                           dropout_p=self.att_dropout_p,
-                                                                          is_causal=True,
+                                                                          is_causal=attention_mask is None,
                                                                           scale=None)
@@ -260,6 +260,7 @@ def __init__(self,
         self.bias_dropout_add_fused_inference = bias_dropout_add_fused_inference
         self.mpu = mpu
         self.output_bias = output_bias
+        self.input_split_dim = config.input_split_dim
 
         # Layernorm on the input data.
         self.input_layernorm = fused_layer_norm(config.hidden_size,
@@ -267,14 +268,39 @@ def __init__(self,
                                                 no_persist_layer_norm=config.no_persist_layer_norm)
 
         # Self attention.
-        self.self_attention = ShardedAttention(config,
-                                               layer_number,
-                                               mpu,
-                                               ColumnParallelLinear,
-                                               RowParallelLinearNoComm,
-                                               apply_rotary_pos_emb,
-                                               attention_type=AttnType.self_attn,
-                                               attn_mask_type=self_attn_mask_type)
+        if self.input_split_dim == "batch":
+            self.self_attention = ShardedAttention(config,
+                                                   layer_number,
+                                                   mpu,
+                                                   ColumnParallelLinear,
+                                                   RowParallelLinearNoComm,
+                                                   apply_rotary_pos_emb,
+                                                   attention_type=AttnType.self_attn,
+                                                   attn_mask_type=self_attn_mask_type)
+        elif self.input_split_dim == "seq":
+            query_projection_size = config.kv_channels * config.num_attention_heads
+            kv_projection_size = config.kv_channels * config.num_attention_heads
+
+            # Per attention head and per partition values.
+            world_size = mpu.get_tensor_model_parallel_world_size()
+            self.hidden_size_per_attention_head = query_projection_size // config.num_attention_heads
+            self.num_attention_heads_per_partition = config.num_attention_heads // world_size
+            self.query_key_value = ColumnParallelLinear(config.hidden_size,
+                                                        query_projection_size + 2 * kv_projection_size,
+                                                        config=config,
+                                                        init_method=config.init_method,
+                                                        bias=config.add_bias_linear,
+                                                        gather_output=False)
+            self.self_attention_sp = CoreAttention(config, self.layer_number, mpu, self_attn_mask_type)
+            self.dense = RowParallelLinearNoComm(query_projection_size,
+                                                 config.hidden_size,
+                                                 config=config,
+                                                 init_method=config.output_layer_init_method,
+                                                 bias=config.add_bias_linear,
+                                                 input_is_parallel=True,
+                                                 skip_bias_add=True)
+        else:
+            raise NotImplementedError
 
         self.hidden_dropout = config.hidden_dropout
 
@@ -356,20 +382,73 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
         if not self.llama_model:
             rotary_pos_emb = None
 
-        attention_output0, attention_bias0 = \
-            self.self_attention(
-                layernorm_output0,
+        if self.input_split_dim == "seq":
+            # Compute full Q, K, V over the whole sequence.
+            layernorm_output = torch.concat([layernorm_output0, layernorm_output1], dim=0)
+            mixed_x_layer, _ = self.query_key_value(layernorm_output)
+
+            # [s, b, np * 3 * hn] --> [s, b, np, 3 * hn]
+            new_tensor_shape = mixed_x_layer.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+            # [s, b, np, 3 * hn] --> [b, np, s, 3 * hn]
+            mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()
+            # [b, np, s, 3 * hn] --> [b, np, s, hn], [b, np, s, hn], [b, np, s, hn]
+            (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [
+                self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head
+            ], dim=3)
+
+            # Ensure the query layout is [b, np, s, hn].
+            query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1,
+                                           self.hidden_size_per_attention_head)
+
+            if rotary_pos_emb is not None:
+                if isinstance(rotary_pos_emb, tuple):
+                    rotary_pos_emb = rotary_pos_emb
+                else:
+                    rotary_pos_emb = ((rotary_pos_emb, ) * 2)
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb)
+                key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb)
+
+            batchsize, num_heads, seq_len, hidden_per_head = query_layer.shape
+
+            # seq chunk 0: core attention over the first half of the queries.
+            context_layer0 = self.self_attention_sp(query_layer[:, :, :seq_len//2, :], key_layer, value_layer, attention_mask[:, :, :seq_len//2, :])
+            # Output. [s, b, h]
+            attention_output0, attention_bias0 = self.dense(context_layer0)
+
+            handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+
+            # seq chunk 1: core attention over the second half of the queries.
+            context_layer1 = self.self_attention_sp(query_layer[:, :, seq_len//2:, :], key_layer, value_layer, attention_mask[:, :, seq_len//2:, :])
+            # Output. [s, b, h]
+            attention_output1, attention_bias1 = self.dense(context_layer1)
+
+            handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+            handle0.wait()
+
+        elif self.input_split_dim == "batch":
+            attention_output0, attention_bias0 = \
+                self.self_attention(
+                    layernorm_output0,
+                    attention_mask,
+                    rotary_pos_emb=rotary_pos_emb)
+            handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+
+            attention_output1, attention_bias1 = \
+                self.self_attention(
+                    layernorm_output1,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
-
-        attention_output1, attention_bias1 = \
-            self.self_attention(
-                layernorm_output1,
-                attention_mask,
-                rotary_pos_emb=rotary_pos_emb)
-        handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
-        handle0.wait()
+            handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+            handle0.wait()
+        else:
+            raise NotImplementedError
 
         # Residual0 connection.
         if self.apply_residual_connection_post_layernorm:
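The seq path reuses the fused QKV projection and rearranges its output into per-head tensors before the queries are split. A minimal standalone sketch of that reshape, with made-up sizes (s, b, np_, hn are illustrative values, not taken from the patch):

import torch

# Toy sizes: sequence, batch, attention heads per partition, hidden per head.
s, b, np_, hn = 8, 2, 4, 16
mixed_x_layer = torch.randn(s, b, np_ * 3 * hn)                       # fused QKV output, [s, b, np * 3 * hn]

mixed_x_layer = mixed_x_layer.view(s, b, np_, 3 * hn)                 # [s, b, np, 3 * hn]
mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()        # [b, np, s, 3 * hn]
query, key, value = torch.split(mixed_x_layer, [hn, hn, hn], dim=3)   # three [b, np, s, hn] tensors

assert query.shape == key.shape == value.shape == (b, np_, s, hn)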
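Splitting only the queries works because each query half attends against the full key/value tensors with its own rows of the causal mask, which is exactly what the attention_mask[:, :, :seq_len//2, :] and [:, :, seq_len//2:, :] slices select. A small self-contained check of that equivalence, using toy tensors and a boolean mask (not code from the patch):

import torch
import torch.nn.functional as F

b, heads, s, hn = 2, 4, 8, 16
q = torch.randn(b, heads, s, hn)
k = torch.randn(b, heads, s, hn)
v = torch.randn(b, heads, s, hn)

# Full causal mask: position i may attend to positions <= i (True = keep).
attention_mask = torch.ones(s, s, dtype=torch.bool).tril().view(1, 1, s, s).expand(b, heads, s, s)

# Reference: causal attention over the whole sequence at once.
full = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)

# Seq split: each query half uses the full K/V plus the matching mask rows.
half0 = F.scaled_dot_product_attention(q[:, :, :s // 2, :], k, v,
                                       attn_mask=attention_mask[:, :, :s // 2, :])
half1 = F.scaled_dot_product_attention(q[:, :, s // 2:, :], k, v,
                                       attn_mask=attention_mask[:, :, s // 2:, :])

assert torch.allclose(torch.cat([half0, half1], dim=2), full, atol=1e-5)

This is also why the patch passes attn_mask=attention_mask to scaled_dot_product_attention: is_causal alone assumes queries and keys start at the same position, which is wrong for the second query half.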
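The ordering in the seq branch (project chunk 0, launch its all-reduce with async_op=True, compute chunk 1, then wait on handle0) lets the tensor-parallel all-reduce of the first half overlap with the attention of the second half. A sketch of that pattern, assuming an already-initialized tensor-parallel process group and stand-in attend/dense callables (names invented for the example; the patch uses CoreAttention and RowParallelLinearNoComm):

import torch
import torch.distributed as dist

def overlapped_seq_split_attention(q, k, v, mask, attend, dense, tp_group):
    # q, k, v: [b, np, s, hn]; attend/dense stand in for core attention and the
    # row-parallel output projection, whose partial results need an all-reduce.
    s = q.shape[2]

    # First half: compute, then launch its all-reduce without blocking.
    out0 = dense(attend(q[:, :, :s // 2, :], k, v, mask[:, :, :s // 2, :]))
    work0 = dist.all_reduce(out0, group=tp_group, async_op=True)

    # Second half overlaps with the in-flight all-reduce of the first half.
    out1 = dense(attend(q[:, :, s // 2:, :], k, v, mask[:, :, s // 2:, :]))
    work1 = dist.all_reduce(out1, group=tp_group, async_op=True)

    # Only the first collective is waited on here; the second handle is kept so
    # later computation can keep overlapping with it, mirroring the batch path.
    work0.wait()
    return out0, out1, work1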