Research/llama/bmm quantization (#94)
* Quantize attention matmuls

* Quantize attention matmuls
anmarques authored Oct 20, 2023
1 parent e25912d commit 27494ce
Showing 1 changed file with 52 additions and 2 deletions.
54 changes: 52 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -39,6 +39,51 @@
_CONFIG_FOR_DOC = "LlamaConfig"


class QuantizableIdentity(nn.Module):
def forward(self, x):
return x


class MatMulLeftInput_QK(QuantizableIdentity):
...


class MatMulRightInput_QK(QuantizableIdentity):
...


class MatMulOutput_QK(QuantizableIdentity):
...


class MatMulLeftInput_PV(QuantizableIdentity):
...


class MatMulRightInput_PV(QuantizableIdentity):
...


class MatMulOutput_PV(QuantizableIdentity):
...


class QuantizableMatMul(nn.Module):
"""
Wrapper around torch.matmul with distinct inputs/output class
instances that could be quantized through SparseML recipe
"""

def __init__(self, left_input_cls, right_input_cls, output_cls):
super().__init__()
self.left_input = left_input_cls()
self.right_input = right_input_cls()
self.output = output_cls()

def forward(self, a: torch.Tensor, b: torch.Tensor):
return self.output(torch.matmul(self.left_input(a), self.right_input(b)))


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -253,8 +298,13 @@ def __init__(self, config: LlamaConfig):
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

self.attn_weights_matmul = QuantizableMatMul(MatMulLeftInput_QK, MatMulRightInput_QK, MatMulOutput_QK)
self.attn_output_matmul = QuantizableMatMul(MatMulLeftInput_PV, MatMulRightInput_PV, MatMulOutput_PV)

self._init_rope()


def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
@@ -327,7 +377,7 @@ def forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ attn_weights = self.attn_weights_matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
@@ -344,7 +394,7 @@ def forward(

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = self.attn_output_matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
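For context, the sketch below (not part of the commit) shows how the wrapper behaves. It re-declares the classes from this diff so it runs standalone: with the identity submodules untouched, QuantizableMatMul is numerically a drop-in for torch.matmul, and the named QuantizableIdentity instances are simply attachment points that a SparseML recipe can target for quantization. The final FakeQuantize swap is a hand-rolled illustration of that mechanism, not how the recipe is actually applied.

import torch
from torch import nn
from torch.ao.quantization import FakeQuantize


class QuantizableIdentity(nn.Module):
    """No-op module whose only job is to be a named, replaceable hook point."""

    def forward(self, x):
        return x


class MatMulLeftInput_QK(QuantizableIdentity):
    ...


class MatMulRightInput_QK(QuantizableIdentity):
    ...


class MatMulOutput_QK(QuantizableIdentity):
    ...


class QuantizableMatMul(nn.Module):
    """Matmul whose operands and result pass through replaceable identity modules."""

    def __init__(self, left_input_cls, right_input_cls, output_cls):
        super().__init__()
        self.left_input = left_input_cls()
        self.right_input = right_input_cls()
        self.output = output_cls()

    def forward(self, a, b):
        return self.output(torch.matmul(self.left_input(a), self.right_input(b)))


attn_weights_matmul = QuantizableMatMul(MatMulLeftInput_QK, MatMulRightInput_QK, MatMulOutput_QK)

# Shapes as in LlamaAttention: (batch, heads, q_len, head_dim) @ (batch, heads, head_dim, kv_len)
q = torch.randn(1, 8, 16, 64)
k_t = torch.randn(1, 8, 64, 16)

# With the identities untouched, the wrapper matches torch.matmul exactly.
assert torch.equal(attn_weights_matmul(q, k_t), torch.matmul(q, k_t))

# A quantization framework can now replace any of the identity submodules by name.
# Swapping in a FakeQuantize module by hand, purely to illustrate the mechanism:
attn_weights_matmul.left_input = FakeQuantize()
out = attn_weights_matmul(q, k_t)  # the left operand is fake-quantized before the matmul
print(out.shape)  # torch.Size([1, 8, 16, 16])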
