diff --git a/llama/model.py b/llama/model.py index 562fcad1b..5be939d2c 100755 --- a/llama/model.py +++ b/llama/model.py @@ -8,6 +8,7 @@ import fairscale.nn.model_parallel.initialize as fs_init import torch import torch.nn.functional as F +from einops import rearrange, repeat from fairscale.nn.model_parallel.layers import ( ColumnParallelLinear, ParallelEmbedding, @@ -125,8 +126,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim assert 0 <= 1 < ndim assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + return rearrange(freqs_cis, 's d -> 1 s 1 d') def apply_rotary_emb( @@ -153,11 +153,16 @@ def apply_rotary_emb( """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + # xq_.shape + # (B, Seq_Len, H, Head_dim) -> (B, Seq_Len, H, Head_dim // 2, 2) + # -> (B, Seq_Len, H, Head_dim // 2) complex + xq_ = torch.view_as_complex(rearrange(xq.float(), '... (c d) -> ... c d', d=2)) + xk_ = torch.view_as_complex(rearrange(xk.float(), '... (c d) -> ... c d', d=2)) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + # (B, Seq_Len, H, Head_dim/2) -> (B, Seq_Len, H, Head_dim/2, 2) + # -> (B, Seq_Len, H, Head_dim) + xq_out = rearrange(torch.view_as_real(xq_ * freqs_cis), '... c d -> ... (c d)') + xk_out = rearrange(torch.view_as_real(xk_ * freqs_cis), '... c d -> ... (c d)') return xq_out.type_as(xq), xk_out.type_as(xk) @@ -166,11 +171,8 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: bs, slen, n_kv_heads, head_dim = x.shape if n_rep == 1: return x - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) + + return rearrange(repeat(x, 'b s h d -> b s h r d', r=n_rep), 'b s h r d -> b s (h r) d') class Attention(nn.Module): @@ -273,9 +275,9 @@ def forward( bsz, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = rearrange(xq, 'b s (h d) -> b s h d', h=self.n_local_heads) + xk = rearrange(xk, 'b s (h d) -> b s h d', h=self.n_local_kv_heads) + xv = rearrange(xv, 'b s (h d) -> b s h d', h=self.n_local_kv_heads) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) @@ -292,15 +294,15 @@ def forward( keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) - xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) - values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + xq = rearrange(xq, 'b s h d -> b h s d') # (bs, n_local_heads, seqlen, head_dim) + keys = rearrange(keys, 'b s h d -> b h s d') # (bs, n_local_heads, cache_len + seqlen, head_dim) + values = rearrange(values, 'b s h d -> b h s d') # (bs, n_local_heads, cache_len + seqlen, head_dim) + scores = torch.matmul(xq, rearrange(keys, 'b h s d -> b h d s')) / math.sqrt(self.head_dim) if mask is not None: scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) - output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + output = rearrange(output, 'b h s d -> b s (h d)').contiguous() return self.wo(output) @@ -491,5 +493,5 @@ def forward(self, tokens: torch.Tensor, start_pos: int): for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) h = self.norm(h) - output = self.output(h).float() + output = self.output(h).float() # explicitly convert to full precision during inference return output diff --git a/requirements.txt b/requirements.txt index 66f8a64f5..27edd31c7 100755 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ torch fairscale fire sentencepiece +einops