[Draft] Add support for seq split in Domino #7111

Draft · wants to merge 2 commits into master
121 changes: 100 additions & 21 deletions deepspeed/runtime/domino/transformer.py
@@ -118,7 +118,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
         context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                           key_layer,
                                                                           value_layer,
-                                                                          attn_mask=None,
+                                                                          attn_mask=attention_mask,
                                                                           dropout_p=self.att_dropout_p,
                                                                           is_causal=True,
                                                                           scale=None)
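
This hunk threads the caller's attention_mask through to scaled_dot_product_attention, which is what lets the seq-split path further down hand each query half only its own rows of the causal mask. As a quick sanity check of that idea, here is a minimal, self-contained sketch in plain PyTorch (hypothetical shapes, not Domino code); it relies on the explicit mask alone rather than combining it with is_causal=True, since recent PyTorch releases reject an explicit attn_mask passed together with is_causal=True.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    b, h, s, d = 2, 4, 8, 16  # batch, heads, sequence length, head dim
    q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

    # Boolean causal mask in the [b, h, s_q, s_k] layout that the diff slices into.
    causal = torch.tril(torch.ones(s, s, dtype=torch.bool)).expand(b, h, s, s)

    # Reference: causal attention over the full sequence.
    full = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)

    # Seq split: each query half attends to the full K/V with its matching mask rows.
    half0 = F.scaled_dot_product_attention(q[:, :, :s // 2, :], k, v, attn_mask=causal[:, :, :s // 2, :])
    half1 = F.scaled_dot_product_attention(q[:, :, s // 2:, :], k, v, attn_mask=causal[:, :, s // 2:, :])

    # The two halves concatenated along the sequence dimension match full attention.
    assert torch.allclose(torch.cat([half0, half1], dim=2), full, atol=1e-5)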
@@ -260,21 +260,47 @@ def __init__(self,
         self.bias_dropout_add_fused_inference = bias_dropout_add_fused_inference
         self.mpu = mpu
         self.output_bias = output_bias
+        self.input_split_dim = config.input_split_dim

         # Layernorm on the input data.
         self.input_layernorm = fused_layer_norm(config.hidden_size,
                                                 eps=config.layernorm_epsilon,
                                                 no_persist_layer_norm=config.no_persist_layer_norm)

         # Self attention.
-        self.self_attention = ShardedAttention(config,
-                                               layer_number,
-                                               mpu,
-                                               ColumnParallelLinear,
-                                               RowParallelLinearNoComm,
-                                               apply_rotary_pos_emb,
-                                               attention_type=AttnType.self_attn,
-                                               attn_mask_type=self_attn_mask_type)
+        if self.input_split_dim == "batch":
+            self.self_attention = ShardedAttention(config,
+                                                   layer_number,
+                                                   mpu,
+                                                   ColumnParallelLinear,
+                                                   RowParallelLinearNoComm,
+                                                   apply_rotary_pos_emb,
+                                                   attention_type=AttnType.self_attn,
+                                                   attn_mask_type=self_attn_mask_type)
+        elif self.input_split_dim == "seq":
+            query_projection_size = config.kv_channels * config.num_attention_heads
+            kv_projection_size = config.kv_channels * config.num_attention_heads
+
+            # Per attention head and per partition values.
+            world_size = mpu.get_tensor_model_parallel_world_size()
+            self.hidden_size_per_attention_head = query_projection_size // config.num_attention_heads
+            self.num_attention_heads_per_partition = config.num_attention_heads // world_size
+            self.query_key_value = ColumnParallelLinear(config.hidden_size,
+                                                        query_projection_size + 2 * kv_projection_size,
+                                                        config=config,
+                                                        init_method=config.init_method,
+                                                        bias=config.add_bias_linear,
+                                                        gather_output=False)
+            self.self_attention_sp = CoreAttention(config, self.layer_number, mpu, self_attn_mask_type)
+            self.dense = RowParallelLinearNoComm(query_projection_size,
+                                                 config.hidden_size,
+                                                 config=config,
+                                                 init_method=config.output_layer_init_method,
+                                                 bias=config.add_bias_linear,
+                                                 input_is_parallel=True,
+                                                 skip_bias_add=True)
Comment on lines +280 to +301

Contributor

This duplicates the ShardedAttention module's implementation details.

+        else:
+            raise NotImplementedError

         self.hidden_dropout = config.hidden_dropout

@@ -356,20 +382,73 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
         if not self.llama_model:
             rotary_pos_emb = None

-        attention_output0, attention_bias0 = \
-            self.self_attention(
-                layernorm_output0,
-                attention_mask,
-                rotary_pos_emb=rotary_pos_emb)
-        handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
-
-        attention_output1, attention_bias1 = \
-            self.self_attention(
-                layernorm_output1,
-                attention_mask,
-                rotary_pos_emb=rotary_pos_emb)
-        handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
-        handle0.wait()
+        if self.input_split_dim == "seq":
+            # Compute full Q, K, V
+            layernorm_output = torch.concat([layernorm_output0, layernorm_output1], dim=0)
+            mixed_x_layer, _ = self.query_key_value(layernorm_output)
+
+            # [s, b, np * 3 * hn] --> [s, b, np, 3 * hn]
+            new_tensor_shape = mixed_x_layer.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+            # [s, b, np, 3 * hn] -> [b, np, s, 3 * hn]
+            mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()
+            # [b, np, s, 3 * hn] --> [b, np, s, hn], [b, np, s, hn], [b, np, s, hn]
+            (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [
+                self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head
+            ], dim=3)
+
+            # [b, np, s, hn]
+            query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1,
+                                           self.hidden_size_per_attention_head)
+
+            if rotary_pos_emb is not None:
+                if isinstance(rotary_pos_emb, tuple):
+                    rotary_pos_emb = rotary_pos_emb
+                else:
+                    rotary_pos_emb = ((rotary_pos_emb, ) * 2)
+                q_pos_emb, k_pos_emb = rotary_pos_emb
+                query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb)
+                key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb)
+
+            batchsize, num_heads, seq_len, hidden_per_head = query_layer.shape
+
+            # seq 0: core attention over the first half of the sequence
+            context_layer0 = self.self_attention_sp(query_layer[:, :, :seq_len // 2, :], key_layer, value_layer,
+                                                    attention_mask[:, :, :seq_len // 2, :])
+            # Output. [s, b, h]
+            attention_output0, attention_bias0 = self.dense(context_layer0)
+
+            # Overlap the all-reduce of the first half with the attention compute of the second half.
+            handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+
+            # seq 1: core attention over the second half of the sequence
+            context_layer1 = self.self_attention_sp(query_layer[:, :, seq_len // 2:, :], key_layer, value_layer,
+                                                    attention_mask[:, :, seq_len // 2:, :])
+            # Output. [s, b, h]
+            attention_output1, attention_bias1 = self.dense(context_layer1)
+
+            handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+            handle0.wait()

Comment on lines +387 to +433

Contributor

These lines are exactly the same as the ShardedAttention forward; I don't see why we need this duplication here.

Also, please follow our current code hierarchy: do not pull lower-layer module implementation code up into an upper-layer module.

For example, if a real change needs to be made to ShardedAttention, create a similar module, say XYZAttention; then here in the DominoTransformerLayer forward function we can simply call the XYZAttention module without duplicating every line of its forward.
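
As one hypothetical shape this suggestion could take, here is a self-contained toy of an "XYZAttention"-style module that owns the QKV projection, the per-half core attention, and the output projection, so the transformer layer's forward would only dispatch on input_split_dim and place its async all-reduce and wait calls around the two halves. Plain PyTorch with no tensor parallelism; the class name, method names, and shapes are assumptions for illustration, not code from this PR or from DeepSpeed.

    import torch
    import torch.nn.functional as F


    class SeqSplitSelfAttention(torch.nn.Module):
        """Toy stand-in for the reviewer's 'XYZAttention' placeholder."""

        def __init__(self, hidden_size, num_heads):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = hidden_size // num_heads
            self.query_key_value = torch.nn.Linear(hidden_size, 3 * hidden_size)
            self.dense = torch.nn.Linear(hidden_size, hidden_size)

        def project_qkv(self, hidden_states):
            # hidden_states: [s, b, h] -> q, k, v, each [b, np, s, hn]
            s, b, _ = hidden_states.shape
            qkv = self.query_key_value(hidden_states).view(s, b, self.num_heads, 3 * self.head_dim)
            return torch.split(qkv.permute(1, 2, 0, 3), self.head_dim, dim=3)

        def attend_half(self, q, k, v, attention_mask, half):
            # One query half attends to the full K/V, then is projected back to [s_half, b, h].
            s = q.size(2)
            lo, hi = (0, s // 2) if half == 0 else (s // 2, s)
            ctx = F.scaled_dot_product_attention(q[:, :, lo:hi, :], k, v,
                                                 attn_mask=attention_mask[:, :, lo:hi, :])
            ctx = ctx.permute(2, 0, 1, 3).reshape(hi - lo, q.size(0), -1)
            return self.dense(ctx)


    # Usage sketch: the calling layer only splits the work and sequences the halves.
    s, b, h, heads = 8, 2, 32, 4
    x0, x1 = torch.randn(s // 2, b, h), torch.randn(s // 2, b, h)  # the two layernorm outputs
    mask = torch.tril(torch.ones(s, s, dtype=torch.bool)).view(1, 1, s, s)
    attn = SeqSplitSelfAttention(h, heads)

    q, k, v = attn.project_qkv(torch.cat([x0, x1], dim=0))
    out0 = attn.attend_half(q, k, v, mask, half=0)  # the real layer would launch its async all-reduce here
    out1 = attn.attend_half(q, k, v, mask, half=1)  # ...and wait on the first one after this line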


elif self.input_split_dim == "batch":
attention_output0, attention_bias0 = \
self.self_attention(
layernorm_output0,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

attention_output1, attention_bias1 = \
self.self_attention(
layernorm_output1,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)

attention_output1, attention_bias1 = \
self.self_attention(
layernorm_output1,
attention_mask,
rotary_pos_emb=rotary_pos_emb)
handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle0.wait()
handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
handle0.wait()
else:
raise NotImplementedError

         # Residual0 connection.
         if self.apply_residual_connection_post_layernorm: