Add support for the Falcon new decoder architecture (40b and 180b models)
arashb committed Dec 8, 2023
1 parent b04811f commit 8766da4
Showing 5 changed files with 81 additions and 16 deletions.
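For context, the "new decoder architecture" used by the larger Falcon checkpoints changes where LayerNorm sits in each decoder layer: instead of one shared input LayerNorm feeding both the attention and MLP branches, each branch gets its own norm (ln_attn and ln_mlp). The sketch below is a minimal vanilla-PyTorch illustration of that difference, assuming parallel attention in both variants; it is not the DeepSpeed kernel path touched by this commit, and the module shapes are placeholders.

```python
import torch
import torch.nn as nn

def falcon_old_arch_layer(x, ln_input, attn, mlp):
    # 7b-style layer: a single input LayerNorm feeds both parallel branches.
    h = ln_input(x)
    return x + attn(h) + mlp(h)

def falcon_new_arch_layer(x, ln_attn, ln_mlp, attn, mlp):
    # 40b/180b-style layer: separate LayerNorms, one per branch.
    return x + attn(ln_attn(x)) + mlp(ln_mlp(x))

# Toy usage with stand-in modules (real attention/MLP blocks are assumed away).
d = 64
x = torch.randn(2, 8, d)
y = falcon_new_arch_layer(x, nn.LayerNorm(d), nn.LayerNorm(d), nn.Linear(d, d), nn.Linear(d, d))
```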
@@ -227,6 +227,8 @@ void launch_kv_rotary_kernel(T* kv_cache,
     DISPATCH_KV_ROTARY_IMPL(5, 128)
     DISPATCH_KV_ROTARY_IMPL(8, 64)
     DISPATCH_KV_ROTARY_IMPL(8, 128)
+    DISPATCH_KV_ROTARY_IMPL(16, 64)
+    DISPATCH_KV_ROTARY_IMPL(16, 128)
     DISPATCH_KV_ROTARY_IMPL(35, 64)
     DISPATCH_KV_ROTARY_IMPL(35, 128)
     DISPATCH_KV_ROTARY_IMPL(36, 64)
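The two new dispatch entries cover a query-to-KV-head ratio of 16, which is what the grouped-query layout of the 40b-class checkpoints works out to. A rough back-of-the-envelope check, with the head counts assumed from the published Falcon-40b configuration rather than taken from this diff:

```python
# Assumed Falcon-40b-style attention geometry (not read from this commit).
hidden_size = 8192
num_attention_heads = 128
num_kv_heads = 8  # grouped-query attention

head_size = hidden_size // num_attention_heads   # 64
q_ratio = num_attention_heads // num_kv_heads    # 16

# This is the (q_ratio, head_size) pair the kernel must now dispatch on.
print(q_ratio, head_size)  # 16 64 -> DISPATCH_KV_ROTARY_IMPL(16, 64)
```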
@@ -19,7 +19,7 @@ class BlockedRotaryEmbeddings(DSKernelBase):
 
     supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
     supported_head_sizes = [64, 128]
-    supported_q_ratios = [1, 2, 4, 5, 8, 35, 36, 71]
+    supported_q_ratios = [1, 2, 4, 5, 8, 16, 35, 36, 71]
 
     def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
         """
@@ -8,7 +8,7 @@
 from ...model_implementations.common_parameters import *
 from ...model_implementations.layer_container_base import LayerContainer
 '''
-# HF Falcon model looks like this:
+# HF Falcon 7b model looks like this:
 FalconForCausalLM(
   (transformer): FalconModel(
@@ -44,16 +44,16 @@ class FalconTransformerContainer(LayerContainer):
     attn_out_w: AttentionOutputParameter
     mlp_1_w: MLP1Parameter
     mlp_2_w: MLP2Parameter
-    input_layernorm_gamma: NormParameter
-    input_layernorm_beta: NormParameter
+    ln_attn_gamma: NormParameter
+    ln_attn_beta: NormParameter
 
     PARAM_MAPPING = {
         "self_attention.query_key_value.weight": "qkv_w.params",
         "self_attention.dense.weight": "attn_out_w.params",
         "mlp.dense_h_to_4h.weight": "mlp_1_w.params",
         "mlp.dense_4h_to_h.weight": "mlp_2_w.params",
-        "input_layernorm.weight": "input_layernorm_gamma.params",
-        "input_layernorm.bias": "input_layernorm_beta.params",
+        "input_layernorm.weight": "ln_attn_gamma.params",
+        "input_layernorm.bias": "ln_attn_beta.params",
     }
 
 
@@ -72,3 +72,58 @@ class FalconNonTransformerContainer(LayerContainer):
         "transformer.ln_f.bias": "final_norm_beta.params",
         "lm_head.weight": "word_unembed.params",
     }
+
+
+'''
+# HF Falcon 40b model looks like this:
+FalconForCausalLM(
+  (transformer): FalconModel(
+    (word_embeddings): Embedding(65024, 8192)
+    (h): ModuleList(
+      (0-59): 60 x FalconDecoderLayer(
+        (self_attention): FalconAttention(
+          (maybe_rotary): FalconRotaryEmbedding()
+          (query_key_value): FalconLinear(in_features=8192, out_features=9216, bias=False)
+          (dense): FalconLinear(in_features=8192, out_features=8192, bias=False)
+          (attention_dropout): Dropout(p=0.0, inplace=False)
+        )
+        (mlp): FalconMLP(
+          (dense_h_to_4h): FalconLinear(in_features=8192, out_features=32768, bias=False)
+          (act): GELU(approximate='none')
+          (dense_4h_to_h): FalconLinear(in_features=32768, out_features=8192, bias=False)
+        )
+        (ln_attn): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
+        (ln_mlp): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
+      )
+    )
+    (ln_f): LayerNorm((8192,), eps=1e-05, elementwise_affine=True)
+  )
+  (lm_head): Linear(in_features=8192, out_features=65024, bias=False)
+)
+'''
+
+
+class FalconNewArchTransformerContainer(LayerContainer):
+    """
+    Transformer layer container for the Falcon model with the new decoder architecture (40b/180b).
+    """
+    qkv_w: GQAMegatronQKVParameter
+    attn_out_w: AttentionOutputParameter
+    mlp_1_w: MLP1Parameter
+    mlp_2_w: MLP2Parameter
+    ln_attn_gamma: NormParameter
+    ln_attn_beta: NormParameter
+    ln_mlp_gamma: NormParameter
+    ln_mlp_beta: NormParameter
+
+    PARAM_MAPPING = {
+        "self_attention.query_key_value.weight": "qkv_w.params",
+        "self_attention.dense.weight": "attn_out_w.params",
+        "mlp.dense_h_to_4h.weight": "mlp_1_w.params",
+        "mlp.dense_4h_to_h.weight": "mlp_2_w.params",
+        "ln_attn.weight": "ln_attn_gamma.params",
+        "ln_attn.bias": "ln_attn_beta.params",
+        "ln_mlp.weight": "ln_mlp_gamma.params",
+        "ln_mlp.bias": "ln_mlp_beta.params",
+    }
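A quick consistency check on the 40b module dump above: the fused query_key_value projection's out_features=9216 matches a grouped-query layout at head size 64, where hidden_size=8192 implies 128 query heads and the remaining 1024 features split into 8 key heads and 8 value heads. That fused layout is what the GQAMegatronQKVParameter declared by both containers is meant to hold; the arithmetic below only restates the repr, while the head counts themselves are inferred rather than stated in this diff.

```python
hidden_size = 8192
head_size = 64
n_q_heads = hidden_size // head_size            # 128
kv_features = 9216 - hidden_size                # 1024 left for keys + values
n_kv_heads = kv_features // (2 * head_size)     # 8 shared KV heads

assert head_size * (n_q_heads + 2 * n_kv_heads) == 9216
print(n_q_heads, n_kv_heads, n_q_heads // n_kv_heads)  # 128 8 16
```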
@@ -129,19 +129,25 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
                 hidden states after pre normalization.
             ragged_batch_info (RaggedBatchWrapper): The batch metadata.
         """
-        assert not self.config.new_decoder_architecture, "Falcon new decoder architecture is supported in separate model implementation!"
         assert self.config.parallel_attn, "Only parallel attention implementation is supported"
 
         cur_params = self._transformer[layer_idx]
         kv_cache = self.state_manager.get_cache(layer_idx)
 
-        attention_layernorm_out = hidden_states
-        attn_hidden_state = self.qkv(attention_layernorm_out, cur_params.qkv_w, b=None)
+        attn_ln_out = hidden_states
+        attn_hidden_state = self.qkv(attn_ln_out, cur_params.qkv_w, b=None)
         attn_hidden_state = self.attn(attn_hidden_state, kv_cache, ragged_batch_info)
         attention_output = self.attn_out(attn_hidden_state, cur_params.attn_out_w, b=None)
 
-        mlp_layernorm_out = hidden_states
-        mlp_hidden_state = self.mlp_1(mlp_layernorm_out, cur_params.mlp_1_w, b=None)
+        if self.config.new_decoder_architecture:
+            residual, mlp_ln_out = self.norm(residual,
+                                             None,
+                                             gamma=cur_params.ln_mlp_gamma,
+                                             beta=cur_params.ln_mlp_beta)
+        else:
+            mlp_ln_out = hidden_states
+
+        mlp_hidden_state = self.mlp_1(mlp_ln_out, cur_params.mlp_1_w, b=None)
         mlp_output = self.mlp_2(mlp_hidden_state, cur_params.mlp_2_w, b=None)
 
         mlp_output.add_(attention_output)
@@ -153,8 +159,8 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
             next_params = self._transformer[layer_idx + 1]
             residual, mlp_output = self.norm(residual,
                                              mlp_output,
-                                             next_params.input_layernorm_gamma,
-                                             beta=next_params.input_layernorm_beta)
+                                             next_params.ln_attn_gamma,
+                                             beta=next_params.ln_attn_beta)
         else:
             # On last layer, we just need to perform the residual add. Adding into the residual
             # here is safe.
@@ -190,8 +196,8 @@ def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
 
         residual, hidden_states = self.norm(residual,
                                             None,
-                                            gamma=self._transformer[0].input_layernorm_gamma,
-                                            beta=self._transformer[0].input_layernorm_beta)
+                                            gamma=self._transformer[0].ln_attn_gamma,
+                                            beta=self._transformer[0].ln_attn_beta)
 
         for layer_idx in range(self.num_layers):
             residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
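The new-architecture branch above reuses the same fused self.norm helper the model already applies between layers. Judging only from its call sites in this diff, it folds the second tensor into the running residual (when one is given) and returns both the updated residual and its LayerNorm. The snippet below is a rough functional stand-in written against plain PyTorch to show that contract; it is not the actual DeepSpeed kernel, and the keyword names are assumptions.

```python
import torch
import torch.nn.functional as F

def fused_residual_norm(residual, hidden_states, gamma, beta, eps=1e-5):
    # Approximation of the self.norm(...) contract as used in this file:
    # fold hidden_states into the residual (if provided), then normalize.
    if hidden_states is not None:
        residual = residual + hidden_states
    normed = F.layer_norm(residual, residual.shape[-1:], weight=gamma, bias=beta, eps=eps)
    return residual, normed
```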
@@ -8,6 +8,7 @@
 from ...config_v2 import RaggedInferenceEngineConfig
 from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy
 from ...model_implementations.falcon.falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer
+from ...model_implementations.falcon.falcon_containers import FalconNewArchTransformerContainer
 from ...model_implementations.falcon.falcon_model import FalconInferenceModel
 
 
@@ -19,7 +20,8 @@ def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group
     def build_container_map(self) -> ContainerMap:
         map = ContainerMap()
 
-        transformer_containers = [FalconTransformerContainer(self.model) for _ in range(self.model.num_layers)]
+        trans_container_cls = FalconNewArchTransformerContainer if self._model_config.new_decoder_architecture else FalconTransformerContainer
+        transformer_containers = [trans_container_cls(self.model) for _ in range(self.model.num_layers)]
 
         map.set_transformer_params(['transformer.h'], transformer_containers)
 
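The container choice keys off the same new_decoder_architecture flag that the HF Falcon config exposes. A small sketch of that selection logic, using transformers' FalconConfig directly (a recent transformers release with the native Falcon port is assumed):

```python
from transformers import FalconConfig

def pick_container_name(cfg: FalconConfig) -> str:
    # Mirrors the branch added in build_container_map above, by class name only.
    return ("FalconNewArchTransformerContainer" if cfg.new_decoder_architecture
            else "FalconTransformerContainer")

print(pick_container_name(FalconConfig(new_decoder_architecture=True)))   # 40b/180b-style
print(pick_container_name(FalconConfig(new_decoder_architecture=False)))  # 7b-style
```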