From 6199125e883dd3b4f90f900776dbca66fe0737e7 Mon Sep 17 00:00:00 2001
From: eljandoubi
Date: Sun, 25 Aug 2024 02:45:28 +0200
Subject: [PATCH] clean gemma attention class

---
 src/models/gemma.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/src/models/gemma.py b/src/models/gemma.py
index 5fb8f86..bd4d1da 100644
--- a/src/models/gemma.py
+++ b/src/models/gemma.py
@@ -84,35 +84,31 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
         self.layer_idx = layer_idx
         self.attention_dropout = config.attention_dropout
-        self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = config.head_dim
         self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.rope_theta = config.rope_theta
-        self.is_causal = True
-
-        assert self.hidden_size % self.num_heads == 0
+
+        assert config.hidden_size % self.num_heads == 0
 
-        self.q_proj = nn.Linear(self.hidden_size,
+        self.q_proj = nn.Linear(config.hidden_size,
                                 self.num_heads * self.head_dim,
                                 bias=config.attention_bias)
-        self.k_proj = nn.Linear(self.hidden_size,
+        self.k_proj = nn.Linear(config.hidden_size,
                                 self.num_key_value_heads * self.head_dim,
                                 bias=config.attention_bias)
-        self.v_proj = nn.Linear(self.hidden_size,
+        self.v_proj = nn.Linear(config.hidden_size,
                                 self.num_key_value_heads * self.head_dim,
                                 bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim,
-                                self.hidden_size,
+                                config.hidden_size,
                                 bias=config.attention_bias)
 
         self.rotary_emb = GemmaRotaryEmbedding(
             dim=self.head_dim,
-            base=self.rope_theta,
+            base=config.rope_theta,
         )
 
     def forward(self,
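
For reference, below is a minimal runnable sketch of how GemmaAttention.__init__ reads once this patch is applied. Only the body of the hunk above is taken from the patch; the GemmaConfig field defaults, the GemmaRotaryEmbedding stub, and the super().__init__() call are assumptions added here so the sketch is self-contained.

    # Sketch of the cleaned-up GemmaAttention.__init__ after this patch.
    # GemmaConfig defaults and the GemmaRotaryEmbedding stub are placeholders;
    # the real definitions live elsewhere in src/models/gemma.py.
    from dataclasses import dataclass
    from typing import Optional

    from torch import nn


    @dataclass
    class GemmaConfig:
        # Example values, assumed for illustration only.
        hidden_size: int = 2048
        num_attention_heads: int = 8
        head_dim: int = 256
        num_key_value_heads: int = 1
        attention_dropout: float = 0.0
        attention_bias: bool = False
        rope_theta: float = 10000.0


    class GemmaRotaryEmbedding(nn.Module):
        # Minimal stand-in so the sketch runs on its own.
        def __init__(self, dim: int, base: float = 10000.0):
            super().__init__()
            self.dim = dim
            self.base = base


    class GemmaAttention(nn.Module):
        def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
            super().__init__()
            self.layer_idx = layer_idx
            self.attention_dropout = config.attention_dropout
            self.num_heads = config.num_attention_heads
            self.head_dim = config.head_dim
            self.num_key_value_heads = config.num_key_value_heads

            assert config.hidden_size % self.num_heads == 0

            # Construction-time values (hidden_size, attention_bias, rope_theta)
            # are read directly from the config instead of being re-stored on self.
            self.q_proj = nn.Linear(config.hidden_size,
                                    self.num_heads * self.head_dim,
                                    bias=config.attention_bias)
            self.k_proj = nn.Linear(config.hidden_size,
                                    self.num_key_value_heads * self.head_dim,
                                    bias=config.attention_bias)
            self.v_proj = nn.Linear(config.hidden_size,
                                    self.num_key_value_heads * self.head_dim,
                                    bias=config.attention_bias)
            self.o_proj = nn.Linear(self.num_heads * self.head_dim,
                                    config.hidden_size,
                                    bias=config.attention_bias)

            self.rotary_emb = GemmaRotaryEmbedding(
                dim=self.head_dim,
                base=config.rope_theta,
            )


    if __name__ == "__main__":
        attn = GemmaAttention(GemmaConfig(), layer_idx=0)
        print(attn.q_proj)  # Linear(in_features=2048, out_features=2048, bias=False)

The net effect of the patch is that values used only at construction time are read straight from the config, and attributes the class no longer relies on (hidden_size, num_key_value_groups, rope_theta, is_causal) are dropped from the module.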