OPT with quantizable MatMuls (#85)
natuan authored Jul 27, 2023
1 parent 2aca427 commit 38ae788
Showing 1 changed file with 45 additions and 2 deletions.
src/transformers/models/opt/modeling_opt.py
@@ -120,6 +120,46 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
         return super().forward(positions + self.offset)
 
 
+class BMMLeftInput_QK(nn.Identity):
+    ...
+
+
+class BMMRightInput_QK(nn.Identity):
+    ...
+
+
+class BMMOutput_QK(nn.Identity):
+    ...
+
+
+class BMMLeftInput_PV(nn.Identity):
+    ...
+
+
+class BMMRightInput_PV(nn.Identity):
+    ...
+
+
+class BMMOutput_PV(nn.Identity):
+    ...
+
+
+class QuantizableBatchMatMul(nn.Module):
+    """
+    Wrapper around torch.bmm with distinct input and output class
+    instances that can be quantized through a SparseML recipe
+    """
+
+    def __init__(self, left_input_cls, right_input_cls, output_cls):
+        super().__init__()
+        self.left_input = left_input_cls()
+        self.right_input = right_input_cls()
+        self.output = output_cls()
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return self.output(torch.bmm(self.left_input(a), self.right_input(b)))
+
+
 class OPTAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -150,6 +190,9 @@ def __init__(
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
+        self.attn_weights_bmm = QuantizableBatchMatMul(BMMLeftInput_QK, BMMRightInput_QK, BMMOutput_QK)
+        self.attn_output_bmm = QuantizableBatchMatMul(BMMLeftInput_PV, BMMRightInput_PV, BMMOutput_PV)
+
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

@@ -208,7 +251,7 @@ def forward(
         value_states = value_states.view(*proj_shape)
 
         src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+        attn_weights = self.attn_weights_bmm(query_states, key_states.transpose(1, 2))
 
         if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
             raise ValueError(
@@ -254,7 +297,7 @@ def forward(
 
         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
-        attn_output = torch.bmm(attn_probs, value_states)
+        attn_output = self.attn_output_bmm(attn_probs, value_states)
 
         if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
             raise ValueError(
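Because the six wrapper classes are plain nn.Identity subclasses, QuantizableBatchMatMul is numerically a drop-in replacement for torch.bmm until a quantization flow swaps the wrappers out. A minimal sketch, not part of the commit, assuming the patched modeling_opt.py is importable:

import torch
from transformers.models.opt.modeling_opt import (
    BMMLeftInput_QK,
    BMMOutput_QK,
    BMMRightInput_QK,
    QuantizableBatchMatMul,
)

# Shapes as used inside OPTAttention: (bsz * num_heads, seq_len, head_dim)
query_states = torch.randn(24, 10, 64)
key_states = torch.randn(24, 10, 64)

bmm = QuantizableBatchMatMul(BMMLeftInput_QK, BMMRightInput_QK, BMMOutput_QK)
attn_weights = bmm(query_states, key_states.transpose(1, 2))

# With Identity wrappers in place, the output is bit-identical to torch.bmm
assert torch.equal(attn_weights, torch.bmm(query_states, key_states.transpose(1, 2)))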

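The point of giving each BMM input and output its own class is that a quantization flow can match these modules by type and attach observers to tensors that a bare torch.bmm call never exposes. The sketch below illustrates the idea in plain PyTorch by swapping the Q/K input wrappers for torch.ao.quantization.FakeQuantize; a SparseML recipe drives this kind of replacement through its own modifiers, so the loop is illustrative only and fake_quantize_qk_inputs is a hypothetical helper name.

import torch
from torch.ao.quantization import FakeQuantize
from transformers.models.opt.modeling_opt import BMMLeftInput_QK, BMMRightInput_QK

def fake_quantize_qk_inputs(model: torch.nn.Module) -> None:
    # Walk the module tree; the distinct Identity subclasses make the
    # query/key BMM inputs addressable with isinstance checks.
    for parent in model.modules():
        for name, child in list(parent.named_children()):
            if isinstance(child, (BMMLeftInput_QK, BMMRightInput_QK)):
                # Replace the pass-through stub with a fake-quant observer
                setattr(parent, name, FakeQuantize())

Applied to an OPT model built from this file, every attn_weights_bmm would then observe and fake-quantize its query and key inputs during calibration.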