Commit: Enable KVCache for Encoder-Decoder network.

zhongkaifu committed Oct 22, 2024
1 parent 50a0e4a commit 113dd85

Showing 7 changed files with 74 additions and 97 deletions.
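For background on what this commit enables: with a KV cache, incremental decoding reuses the per-layer tensors computed for earlier target tokens, so each generation step feeds only the newest token through the decoder instead of re-running the whole prefix. The sketch below illustrates that append-and-reuse pattern in isolation; it is a minimal illustration with hypothetical names, not code from this commit.

    using System;
    using System.Collections.Generic;

    // Minimal KV-cache sketch: per-layer vectors for already-decoded tokens are
    // kept in a dictionary, and each step appends one new entry and reuses the rest.
    class KvCacheSketch
    {
        private readonly Dictionary<string, List<float[]>> _cache = new Dictionary<string, List<float[]>>();

        public IReadOnlyList<float[]> AppendAndGetAll(string layerKey, float[] newTokenVector)
        {
            if (!_cache.TryGetValue(layerKey, out var vectors))
            {
                vectors = new List<float[]>();
                _cache[layerKey] = vectors;
            }
            vectors.Add(newTokenVector); // O(1) work per step instead of recomputing the prefix
            return vectors;
        }

        static void Main()
        {
            var cache = new KvCacheSketch();
            for (int step = 0; step < 3; step++)
            {
                var visible = cache.AppendAndGetAll("decoder_layer0", new float[] { step });
                Console.WriteLine($"step {step}: attention sees {visible.Count} cached token(s)");
            }
        }
    }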
2 changes: 1 addition & 1 deletion AdvUtils/AdvUtils.csproj

@@ -14,7 +14,7 @@
     <ImplicitUsings>enable</ImplicitUsings>
     <Nullable>enable</Nullable>
     <GeneratePackageOnBuild>True</GeneratePackageOnBuild>
-    <Version>2.8.16</Version>
+    <Version>2.7.0</Version>
     <Authors>Zhongkai Fu</Authors>
     <Company />
     <Description>A utility for common algorithms</Description>
54 changes: 40 additions & 14 deletions Seq2SeqSharp/Applications/Decoder.cs

@@ -269,7 +269,7 @@ public static List<List<BeamSearchStatus>> CombineBeamSearchResults(List<List<Be

 public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<int>> tgtSeqs, IComputeGraph g, IWeightTensor encOutputs, TransformerDecoder decoder, IFeedForwardLayer decoderFFLayer,
     IWeightTensor tgtEmbedding, float[] srcOriginalLenghts, Vocab tgtVocab, PaddingEnums paddingType, float dropoutRatio, DecodingOptions decodingOptions, bool isTraining = true,
-    bool outputSentScore = true, List<BeamSearchStatus> previousBeamSearchResults = null, IFeedForwardLayer pointerGenerator = null, List<List<int>> srcSeqs = null, Dictionary<string, IWeightTensor> cachedTensors = null,
+    bool outputSentScore = true, List<BeamSearchStatus> previousBeamSearchResults = null, IFeedForwardLayer pointerGenerator = null, List<List<int>> srcSeqs = null, Dictionary<string, IWeightTensor> contextTensors = null,
     List<List<int>> alignmentsToSrc = null, List<List<float>> alignmentScoresToSrc = null, bool teacherForcedAlignment = false, LossEnums lossType = LossEnums.CrossEntropy, float labelSmooth = 0.0f, float lossSmooth = 1e-9f,
     List<int> blockedTokens = null, IWeightTensor segmentEmbeddings = null, bool amp = false, IWeightTensor posEmbeddings = null, float lossScaling = 1.0f, int paddingAlignmentFactor = 0)
 {

@@ -281,33 +281,59 @@ public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<
     IWeightTensor srcTgtMask = (paddingType == PaddingEnums.NoPadding || batchSize == 1) ? null : g.BuildSrcTgtMask(srcSeqLen, tgtSeqLen, tgtOriginalLengths, srcOriginalLenghts, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
     if (srcTgtMask != null)
     {
-        srcTgtMask = g.View(srcTgtMask, new long[] { srcTgtMask.Sizes[0], 1, srcTgtMask.Sizes[1], srcTgtMask.Sizes[2] });
+        srcTgtMask = g.View(srcTgtMask, new long[] { srcTgtMask.Sizes[0], 1, srcTgtMask.Sizes[1], srcTgtMask.Sizes[2] }); // Shape: [batch_size, 1, tgtSeqLen, srcSeqLen]
     }

     IWeightTensor tgtSelfTriMask = null;
-    if (decoder.AttentionType == AttentionTypeEnums.Classic)
+    IWeightTensor inputEmbs = null;
+    if (contextTensors != null && contextTensors.Count > 0)
     {
-        if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1)
+        // Only add the last token in the target sequence into the token list and cross-mask tensor
+        if (srcTgtMask != null)
         {
-            tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
-            tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen });
+            srcTgtMask = g.Peek(srcTgtMask, 2, tgtSeqLen - 1, 1);
         }
-        else
+
+        List<List<int>> t = new List<List<int>>();
+        for (int i = 0; i < tgtSeqs.Count; i++)
         {
-            tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
-            tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen });
+            t.Add(new List<int>());
+            t[i].Add(tgtSeqs[i][tgtSeqs[i].Count - 1]);
         }
-    }
+        tgtSeqLen = t[0].Count;

-    IWeightTensor inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
-    if (posEmbeddings != null)
+        inputEmbs = TensorUtils.CreateTokensEmbeddings(t, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
+        if (posEmbeddings != null)
+        {
+            inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio); //Output Shape: [batchSize * seqLen, hidden_dim]
+        }
+    }
+    else
     {
-        inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio); //Output Shape: [batchSize * seqLen, hidden_dim]
+        if (decoder.AttentionType == AttentionTypeEnums.Classic)
+        {
+            if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1)
+            {
+                tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
+                tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen });
+            }
+            else
+            {
+                tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
+                tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen });
+            }
+        }
+
+        inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
+        if (posEmbeddings != null)
+        {
+            inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio);
+        }
     }

     IWeightTensor decOutput;
     IWeightTensor decEncAttnProbs;
-    (decOutput, decEncAttnProbs) = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g, outputAttnWeights: pointerGenerator != null, cachedTensors: cachedTensors);
+    (decOutput, decEncAttnProbs) = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g, outputAttnWeights: pointerGenerator != null, cachedTensors: contextTensors);

     if (isTraining == false && teacherForcedAlignment == false)
     {
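One point worth spelling out in the hunk above: when contextTensors already holds state from previous steps, the decoder embeds only the last token of each target sequence and narrows the cross-attention mask to that single row with g.Peek, so the effective tgtSeqLen becomes 1. A compact equivalent of that last-token selection, sketched with LINQ (assuming tgtSeqs as declared above and a using System.Linq; directive):

    // Equivalent last-token selection (sketch): one single-token list per batch entry.
    List<List<int>> lastTokens = tgtSeqs
        .Select(seq => new List<int> { seq[seq.Count - 1] })
        .ToList();
    // Each lastTokens[i] has Count == 1, so the effective target length is 1.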
40 changes: 23 additions & 17 deletions Seq2SeqSharp/Applications/Seq2Seq.cs

@@ -12,7 +12,7 @@
 using System.Collections.Generic;
 using System.IO;
 using AdvUtils;
-using Microsoft.Extensions.Caching.Memory;
+using System.Runtime.Caching;
 using Seq2SeqSharp.Enums;
 using Seq2SeqSharp.Corpus;
 using Seq2SeqSharp.Layers;

@@ -41,8 +41,6 @@ public class Seq2Seq : BaseSeq2SeqFramework<Seq2SeqModel>
     private readonly PaddingEnums m_paddingType = PaddingEnums.AllowPadding;
     readonly Seq2SeqOptions m_options = null;

-    private MemoryCache m_memoryCache;
-
     public Seq2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab = null)
         : base(deviceIds: options.DeviceIds, processorType: options.ProcessorType, modelFilePath: options.ModelFilePath, memoryUsageRatio: options.MemoryUsageRatio,
           compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,

@@ -57,11 +55,6 @@ public Seq2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab = n
     // Check if options are valid.
     m_options.ValidateOptions();

-    m_memoryCache = new MemoryCache(new MemoryCacheOptions
-    {
-        SizeLimit = 1024
-    });
-
     if (File.Exists(m_options.ModelFilePath))
     {
         if (srcVocab != null || tgtVocab != null)

@@ -193,16 +186,15 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
     //}

     IWeightTensor encOutput;
-    if (!isTraining && (m_options.ProcessorType == ProcessorTypeEnums.CPU))
+    if (m_options.UseKVCache)
     {
         // Try to get src tensor from cache
         string cacheKey = GenerateCacheKey(srcSnts);
-        if (!m_memoryCache.TryGetValue(cacheKey, out encOutput))
+        encOutput = MemoryCache.Default[cacheKey] as IWeightTensor;
+        if (encOutput == null)
         {
             encOutput = Encoder.Run(computeGraph, encoder, m_modelMetaData, m_paddingType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths); // Shape: [batchsize * seqLen, embedding_dim]
-
-            var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
-            m_memoryCache.Set(cacheKey, encOutput.CopyWeightsRef($"cache_{encOutput.Name}", false, graphToBind: null), cacheEntryOptions);
+            MemoryCache.Default.Set(cacheKey, encOutput.CopyWeightsRef($"cache_{encOutput.Name}", false, graphToBind: null), DateTimeOffset.Now + TimeSpan.FromMinutes(10));
         }
     }
     else

@@ -265,7 +257,18 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
     }
     else
     { // Test mode or running validation in Training mode
-        Dictionary<string, IWeightTensor> cachedTensors = new Dictionary<string, IWeightTensor>();
+        string cacheKey = GenerateCacheKey(tgtSnts);
+        Dictionary<string, IWeightTensor> cachedTensors = null;
+        if (m_options.UseKVCache)
+        {
+            cachedTensors = MemoryCache.Default[cacheKey] as Dictionary<string, IWeightTensor>;
+            if (cachedTensors == null && decodingOptions.BeamSearchSize == 1)
+            {
+                cachedTensors = new Dictionary<string, IWeightTensor>();
+            }
+            MemoryCache.Default.Remove(cacheKey);
+        }
+
         List<List<BeamSearchStatus>> beam2batchStatus = Decoder.InitBeamSearchStatusListList(batchSize, tgtTokensList);
         for (int i = tgtTokensList[0].Count; i < decodingOptions.MaxTgtSentLength; i++)
         {

@@ -292,7 +295,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
             originalSrcLengths, m_modelMetaData.TgtVocab, m_paddingType, 0.0f, decodingOptions, isTraining,
             outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus,
             pointerGenerator: pointerGenerator, srcSeqs: srcTokensList,
-            cachedTensors: cachedTensors, alignmentsToSrc: alignmentsToSrc, alignmentScoresToSrc: alignmentScores,
+            contextTensors: cachedTensors, alignmentsToSrc: alignmentsToSrc, alignmentScoresToSrc: alignmentScores,
             blockedTokens: decodingOptions.BlockedTokens, segmentEmbeddings: segmentEmbedding, amp: m_options.AMP,
             posEmbeddings: posEmbeddings, lossScaling: LossScaling, paddingAlignmentFactor: m_options.PaddingAlignmentFactor);

@@ -334,10 +337,13 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
     if (cachedTensors != null)
     {
+        cacheKey = GenerateCacheKey(nr.Output[0]);
+        Dictionary<string, IWeightTensor> newCachedTensors = new Dictionary<string, IWeightTensor>();
         foreach (var pair in cachedTensors)
         {
-            pair.Value.Dispose();
+            newCachedTensors.Add(pair.Key, pair.Value.CopyWeightsRef(pair.Value.Name, false, graphToBind: null));
         }
+        MemoryCache.Default.Set(cacheKey, newCachedTensors, DateTimeOffset.Now + TimeSpan.FromMinutes(10));
     }
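The cache usage above also changes library: the instance-scoped, size-limited MemoryCache from Microsoft.Extensions.Caching.Memory is replaced by the process-wide MemoryCache.Default from System.Runtime.Caching, with a ten-minute absolute expiration. The following is a self-contained sketch of that lookup-or-compute pattern; the key and payload are illustrative only, and on .NET Core / modern .NET the System.Runtime.Caching package must be referenced:

    using System;
    using System.Runtime.Caching;

    class CacheLookupSketch
    {
        static void Main()
        {
            const string cacheKey = "demo_key"; // illustrative key, not GenerateCacheKey output

            // The indexer returns null on a miss; 'as' keeps the cast safe.
            var value = MemoryCache.Default[cacheKey] as string;
            if (value == null)
            {
                value = "expensive-to-recompute result"; // stand-in for Encoder.Run(...)
                MemoryCache.Default.Set(cacheKey, value,
                    DateTimeOffset.Now + TimeSpan.FromMinutes(10)); // absolute expiration
            }
            Console.WriteLine(value);
        }
    }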
62 changes: 7 additions & 55 deletions Seq2SeqSharp/Layers/MultiHeadAttention.cs

@@ -437,38 +437,17 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int
     using IComputeGraph g = graph.CreateSubGraph(keyName);
     int seqLenQ = inputQ.Rows / batchSize;

-    int newTokensIdx = seqLenQ;
-    IWeightTensor m_cacheQs = null;
-    string QKeyName = keyName + "_" + nameof(inputQ);
-    if (cachedTensors != null)
-    {
-        if (cachedTensors.ContainsKey(QKeyName) == true)
-        {
-            m_cacheQs = cachedTensors[QKeyName];
-            newTokensIdx = seqLenQ - (int)m_cacheQs.Sizes[0];
-        }
-        else
-        {
-            cachedTensors.Add(QKeyName, null);
-        }
-
-        // Optimize runtime for test by only processing new tokens
-        inputQ = g.View(inputQ, dims: new long[] { batchSize, seqLenQ, -1 });
-        inputQ = g.AsContiguous(g.Peek(inputQ, 1, seqLenQ - newTokensIdx, newTokensIdx)); // Shape: [batchSize, newTokensSize, input_dim]
-        inputQ = g.View(inputQ, dims: new long[] { batchSize * newTokensIdx, -1 }); // Shape: [batchSize * newTokensSize, input_dim]
-    }
-
     // SeqLenK must be equal to SeqLenV
     int seqLenK = inputK.Rows / batchSize;
     int seqLenV = inputV.Rows / batchSize;

     IWeightTensor inputQNorm = layerNormQ.Norm(inputQ, g);

     //Input projections
-    IWeightTensor allQ = g.View(g.Affine(inputQNorm, Q, Qb), dims: new long[] { batchSize, newTokensIdx, m_multiHeadNum, m_d });
+    IWeightTensor allQ = g.View(g.Affine(inputQNorm, Q, Qb), dims: new long[] { batchSize, seqLenQ, m_multiHeadNum, m_d });

     //Multi-head attentions
-    IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, newTokensIdx, m_d });
+    IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, m_d });
     IWeightTensor Ks = null;
     IWeightTensor Vs = null;

@@ -512,15 +491,10 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int
     // Scaled softmax
     float scale = 1.0f / (float)(Math.Sqrt(m_d));
     var attn = g.MulBatch(Qs, Ks, scale); // Shape: [batchSize * m_multiHeadNum, newTokensIdx, seqLenK]
-    attn = g.View(attn, dims: new long[] { batchSize, m_multiHeadNum, newTokensIdx, seqLenK });
+    attn = g.View(attn, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenK });

     if (keyMask != null)
     {
-        if (cachedTensors != null)
-        {
-            keyMask = g.Peek(keyMask, 2, seqLenQ - newTokensIdx, newTokensIdx); // Shape: [batchSize, m_multiHeadNum, newTokensIdx, seqLenK]
-        }
-
         attn = g.Add(attn, keyMask, inPlace: true); // Shape: [batchSize, m_multiHeadNum, newTokensIdx, seqLenK]
     }

@@ -537,40 +511,18 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int
     }

     sumAttnWeights = graph.Div(sumAttnWeights, (float)m_multiHeadNum, inPlace: true);
-    sumAttnWeights = graph.View(sumAttnWeights, new long[] { batchSize * newTokensIdx, seqLenK });
+    sumAttnWeights = graph.View(sumAttnWeights, new long[] { batchSize * seqLenQ, seqLenK });
     }

-    attnProbs = g.View(attnProbs, dims: new long[] { batchSize * m_multiHeadNum, newTokensIdx, seqLenK });
+    attnProbs = g.View(attnProbs, dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, seqLenK });

-    IWeightTensor o = g.View(g.MulBatch(attnProbs, Vs), dims: new long[] { batchSize, m_multiHeadNum, newTokensIdx, m_d });
-    IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * newTokensIdx, m_multiHeadNum * m_d });
+    IWeightTensor o = g.View(g.MulBatch(attnProbs, Vs), dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, m_d });
+    IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * seqLenQ, m_multiHeadNum * m_d });

     // Output projection
     IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), m_dropoutRatio, inPlace: true);
     IWeightTensor result = graph.Add(finalAttResults, inputQ, inPlace: true); // Shape: [batchSize * newTokensSize, input_dim]

-    if (cachedTensors != null)
-    {
-        result = g.View(result, dims: new long[] { batchSize, newTokensIdx, m_multiHeadNum * m_d });
-        result = g.AsContiguous(g.Transpose(result, 0, 1)); // Shape: [newTokensIdx, batchSize, m_multiHeadNum * m_d]
-
-        if (m_cacheQs == null)
-        {
-            m_cacheQs = result; // Shape: [newTokensIdx, batchSize, m_multiHeadNum * m_d]
-        }
-        else
-        {
-            m_cacheQs = g.Concate(0, m_cacheQs, result);
-        }
-        m_cacheQs.UnbindFromComputeGraph();
-
-        cachedTensors[QKeyName] = m_cacheQs;
-
-        result = g.AsContiguous(g.Transpose(m_cacheQs, 0, 1)); // Shape: [batchSize, seqLenQ, m_multiHeadNum * m_d]
-        result = graph.View(result, dims: new long[] { batchSize * seqLenQ, m_multiHeadNum * m_d });
-    }
-
     return (result, sumAttnWeights);
 }
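Net effect of the deletions above: this attention layer no longer maintains its own query-side result cache and always processes the full query slice handed to it; the last-token narrowing now happens once in Decoder.cs, and any key/value reuse presumably lives in the elided Ks/Vs section of this method. For intuition, a generic cache-append along the sequence axis could look like the following (a hypothetical helper on plain arrays, not this file's tensor API):

    using System;

    static class KvConcatSketch
    {
        // Appends the newest tokens' rows onto a cached [tokens, dim] matrix.
        public static float[,] ConcatAlongSeq(float[,] cached, float[,] fresh)
        {
            int dim = cached.GetLength(1);
            if (fresh.GetLength(1) != dim)
                throw new ArgumentException("dimension mismatch");

            var merged = new float[cached.GetLength(0) + fresh.GetLength(0), dim];
            Array.Copy(cached, merged, cached.Length);                 // old tokens first
            Array.Copy(fresh, 0, merged, cached.Length, fresh.Length); // then the new ones
            return merged;
        }
    }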
1 change: 0 additions & 1 deletion Seq2SeqSharp/MultiProcessorNetworkWrapper.cs

@@ -9,7 +9,6 @@
 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details.

 using AdvUtils;
-using Microsoft.Extensions.Logging.Abstractions;
 using Seq2SeqSharp.Tools;
 using Seq2SeqSharp.Utils;
 using System;
9 changes: 2 additions & 7 deletions Seq2SeqSharp/Networks/TransformerDecoder.cs

@@ -135,21 +135,16 @@ public void Reset(IWeightFactory weightFactory, int batchSize)
     IWeightTensor attnProbs = null;
     using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Decoder"))
     {
-        int seqLenQ = tgtInputs.Rows / batchSize;
-
-        // SeqLenK must be euqal to SeqLenV
-        int seqLenK = encOutputBatchFirst.Rows / batchSize;
-
         IWeightTensor selfMaskTensor = null;
         if (tgtSelfMask != null)
         {
-            selfMaskTensor = subg.Expand(tgtSelfMask, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenQ });
+            selfMaskTensor = subg.Expand(tgtSelfMask, dims: new long[] { batchSize, m_multiHeadNum, tgtSelfMask.Sizes[2], tgtSelfMask.Sizes[3] });
         }

         IWeightTensor crossMaskTensor = null;
         if (srcTgtMask != null)
         {
-            crossMaskTensor = subg.Expand(srcTgtMask, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenK });
+            crossMaskTensor = subg.Expand(srcTgtMask, dims: new long[] { batchSize, m_multiHeadNum, srcTgtMask.Sizes[2], srcTgtMask.Sizes[3] });
        }

         for (int k = 0; k < m_selfAttns.Count; k++)
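The reasoning behind this last change: with a warm cache the decoder input can hold a single token, so a query length computed as tgtInputs.Rows / batchSize would be 1 and would no longer match the masks built earlier. Taking the expand sizes from tgtSelfMask.Sizes[2]/[3] and srcTgtMask.Sizes[2]/[3] lets each mask describe its own shape, which works in both the cached and uncached paths.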