diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 15fc3b033a228e..0fda78b66a677f 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -120,6 +120,46 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int
         return super().forward(positions + self.offset)
 
 
+class BMMLeftInput_QK(nn.Identity):
+    ...
+
+
+class BMMRightInput_QK(nn.Identity):
+    ...
+
+
+class BMMOutput_QK(nn.Identity):
+    ...
+
+
+class BMMLeftInput_PV(nn.Identity):
+    ...
+
+
+class BMMRightInput_PV(nn.Identity):
+    ...
+
+
+class BMMOutput_PV(nn.Identity):
+    ...
+
+
+class QuantizableBatchMatMul(nn.Module):
+    """
+    Wrapper around torch.bmm with distinct inputs/output class
+    instances that could be quantized through SparseML recipe
+    """
+
+    def __init__(self, left_input_cls, right_input_cls, output_cls):
+        super().__init__()
+        self.left_input = left_input_cls()
+        self.right_input = right_input_cls()
+        self.output = output_cls()
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return self.output(torch.bmm(self.left_input(a), self.right_input(b)))
+
+
 class OPTAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -150,6 +190,9 @@ def __init__(
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
+        self.attn_weights_bmm = QuantizableBatchMatMul(BMMLeftInput_QK, BMMRightInput_QK, BMMOutput_QK)
+        self.attn_output_bmm = QuantizableBatchMatMul(BMMLeftInput_PV, BMMRightInput_PV, BMMOutput_PV)
+
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
@@ -208,7 +251,7 @@ def forward(
         value_states = value_states.view(*proj_shape)
 
         src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+        attn_weights = self.attn_weights_bmm(query_states, key_states.transpose(1, 2))
 
         if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
             raise ValueError(
@@ -254,7 +297,7 @@ def forward(
 
         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
-        attn_output = torch.bmm(attn_probs, value_states)
+        attn_output = self.attn_output_bmm(attn_probs, value_states)
 
         if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
             raise ValueError(
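
Below is a minimal standalone sketch (not part of the patch) that restates the QuantizableBatchMatMul wrapper and checks that, before a SparseML recipe attaches quantization observers to the nn.Identity markers, it behaves as a drop-in replacement for torch.bmm. The tensor shapes are arbitrary illustration values chosen to mirror the attention usage.

import torch
import torch.nn as nn


class BMMLeftInput_QK(nn.Identity):
    ...


class BMMRightInput_QK(nn.Identity):
    ...


class BMMOutput_QK(nn.Identity):
    ...


class QuantizableBatchMatMul(nn.Module):
    """Same wrapper as in the patch above: the Identity submodules are no-ops
    at float precision and only serve as named attachment points that a
    SparseML recipe can target for quantization."""

    def __init__(self, left_input_cls, right_input_cls, output_cls):
        super().__init__()
        self.left_input = left_input_cls()
        self.right_input = right_input_cls()
        self.output = output_cls()

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return self.output(torch.bmm(self.left_input(a), self.right_input(b)))


# Shapes mirror the QK matmul in OPTAttention.forward: (bsz * num_heads, tgt_len, head_dim).
query_states = torch.randn(2, 4, 8)
key_states = torch.randn(2, 4, 8)

wrapped = QuantizableBatchMatMul(BMMLeftInput_QK, BMMRightInput_QK, BMMOutput_QK)
assert torch.equal(
    wrapped(query_states, key_states.transpose(1, 2)),
    torch.bmm(query_states, key_states.transpose(1, 2)),
)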