Remove RoPE from cross attention
zhongkaifu committed Oct 7, 2024
1 parent 2113ab7 commit 7c93f7a
Showing 1 changed file with 3 additions and 27 deletions.
30 changes: 3 additions & 27 deletions Seq2SeqSharp/Layers/MultiHeadAttention.cs
@@ -388,30 +388,14 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int

            //Multi-head attentions
            IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, newTokensIdx, m_d });
-           if (m_PEType == PositionEmbeddingEnums.RoPE)
-           {
-               Qs = g.RoPE(Qs, seqLenQ);
-           }
-
            IWeightTensor Ks = null;
            IWeightTensor Vs = null;

            if (cachedTensors == null) // We don't use any cached tensors
            {
                IWeightTensor allK = g.View(g.Affine(inputK, K, Kb), dims: new long[] { batchSize, seqLenK, m_multiHeadNum, m_d });
                IWeightTensor allV = g.View(g.Affine(inputV, V, Vb), dims: new long[] { batchSize, seqLenV, m_multiHeadNum, m_d });
-
-               if (m_PEType == PositionEmbeddingEnums.RoPE)
-               {
-                   Ks = g.View(g.AsContiguous(g.Transpose(allK, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenK, m_d });
-                   Ks = g.RoPE(Ks, seqLenK);
-                   Ks = g.View(g.AsContiguous(g.Transpose(Ks, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK });
-               }
-               else
-               {
-                   Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK });
-               }
-
+               Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK });
                Vs = g.View(g.AsContiguous(g.Transpose(allV, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenV, m_d });
            }
            else
@@ -422,16 +406,8 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int
                if (cachedTensors.ContainsKey(KsCacheName) == false)
                {
                    IWeightTensor allK = g.View(g.Affine(inputK, K, Kb), dims: new long[] { batchSize, seqLenK, m_multiHeadNum, m_d });
-                   if (m_PEType == PositionEmbeddingEnums.RoPE)
-                   {
-                       Ks = g.View(g.AsContiguous(g.Transpose(allK, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenK, m_d });
-                       Ks = g.RoPE(Ks, seqLenK);
-                       Ks = g.View(g.AsContiguous(g.Transpose(Ks, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK });
-                   }
-                   else
-                   {
-                       Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK });
-                   }
+                   Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK });
+
                    cachedTensors.Add(KsCacheName, Ks.CopyWeightsRef(KsCacheName, Ks.NeedGradient, graphToBind: null));
                }
                else
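Context note (not part of the commit): the removed branches rotated Q and K with g.RoPE before the attention score computation. In cross attention the queries come from the decoder while the keys and values come from the encoder, so the two sides live in unrelated position index spaces; the commit message does not state a rationale, but that mismatch is the usual motivation for keeping rotary position embeddings only in self-attention. The sketch below is a minimal, standalone illustration of the standard rotary rotation that a call like g.RoPE performs conceptually; ApplyRoPE and the float[,] layout are illustrative assumptions, not Seq2SeqSharp APIs.

// Minimal sketch (assumption, not the Seq2SeqSharp implementation):
// standard RoPE rotates each (even, odd) pair of feature dimensions of a
// query/key row by an angle proportional to the row's position index.
using System;

static class RoPESketch
{
    // x: [seqLen, d] slice of Q or K for one attention head; d must be even.
    public static void ApplyRoPE(float[,] x, int seqLen, int d)
    {
        for (int pos = 0; pos < seqLen; pos++)
        {
            for (int i = 0; i < d / 2; i++)
            {
                // Standard RoPE frequency: angle = pos * 10000^(-2i/d)
                double angle = pos * Math.Pow(10000.0, -2.0 * i / d);
                double cos = Math.Cos(angle);
                double sin = Math.Sin(angle);

                float x0 = x[pos, 2 * i];
                float x1 = x[pos, 2 * i + 1];

                // 2D rotation of the (x0, x1) pair.
                x[pos, 2 * i]     = (float)(x0 * cos - x1 * sin);
                x[pos, 2 * i + 1] = (float)(x0 * sin + x1 * cos);
            }
        }
    }
}

Because the rotation depends only on each row's own position, applying it to decoder queries and encoder keys (as the deleted branches did) offsets the scores by positions that have no shared meaning; per the commit title, RoPE remains in place for the self-attention path.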
