
Commit 6199125
clean gemma attention class
eljandoubi committed Aug 25, 2024
1 parent 4f9f00f commit 6199125
Showing 1 changed file with 7 additions and 11 deletions.
src/models/gemma.py: 18 changes (7 additions & 11 deletions)
@@ -84,35 +84,31 @@ def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
         self.layer_idx = layer_idx

         self.attention_dropout = config.attention_dropout
-        self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = config.head_dim
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.rope_theta = config.rope_theta
         self.is_causal = True

-        assert self.hidden_size % self.num_heads == 0
+        assert config.hidden_size % self.num_heads == 0

-        self.q_proj = nn.Linear(self.hidden_size,
+        self.q_proj = nn.Linear(config.hidden_size,
                                 self.num_heads * self.head_dim,
                                 bias=config.attention_bias)

-        self.k_proj = nn.Linear(self.hidden_size,
+        self.k_proj = nn.Linear(config.hidden_size,
                                 self.num_key_value_heads * self.head_dim,
                                 bias=config.attention_bias)

-        self.v_proj = nn.Linear(self.hidden_size,
+        self.v_proj = nn.Linear(config.hidden_size,
                                 self.num_key_value_heads * self.head_dim,
                                 bias=config.attention_bias)

         self.o_proj = nn.Linear(self.num_heads * self.head_dim,
-                                self.hidden_size,
+                                config.hidden_size,
                                 bias=config.attention_bias)

         self.rotary_emb = GemmaRotaryEmbedding(
             dim=self.head_dim,
-            base=self.rope_theta,
+            base=config.rope_theta,
         )

     def forward(self,
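For context, a minimal sketch of how the projection shapes set up in this __init__ fit together under grouped-query attention. The numbers below are illustrative placeholders rather than real GemmaConfig defaults, bias=False stands in for config.attention_bias, and the head-repeating step is a generic illustration of how num_key_value_groups is typically consumed in a forward pass (the forward method itself is collapsed in this diff).

import torch
import torch.nn as nn

# Illustrative values only -- not the actual GemmaConfig defaults.
hidden_size = 2048
num_heads = 8                # config.num_attention_heads
num_key_value_heads = 2      # config.num_key_value_heads
head_dim = 256               # config.head_dim
num_key_value_groups = num_heads // num_key_value_heads  # 4

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

x = torch.randn(1, 16, hidden_size)  # (batch, seq_len, hidden_size)
q = q_proj(x)                        # (1, 16, num_heads * head_dim)
k = k_proj(x)                        # (1, 16, num_key_value_heads * head_dim)
v = v_proj(x)                        # (1, 16, num_key_value_heads * head_dim)

# Split into heads: queries get num_heads, keys/values get num_key_value_heads.
q = q.view(1, 16, num_heads, head_dim).transpose(1, 2)
k = k.view(1, 16, num_key_value_heads, head_dim).transpose(1, 2)
v = v.view(1, 16, num_key_value_heads, head_dim).transpose(1, 2)

# Each key/value head is shared by num_key_value_groups query heads;
# repeating along the head axis is one common way to expand them.
k = k.repeat_interleave(num_key_value_groups, dim=1)
v = v.repeat_interleave(num_key_value_groups, dim=1)

print(q.shape, k.shape, v.shape)  # all (1, 8, 16, 256) with these numbers

With these numbers, q, k, and v all end up as (1, 8, 16, 256) tensors, so attention can be applied head by head even though only 2 key/value heads are actually projected and cached.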
