diff --git a/AdvUtils/AdvUtils.csproj b/AdvUtils/AdvUtils.csproj
index e72bd11..02957dd 100644
--- a/AdvUtils/AdvUtils.csproj
+++ b/AdvUtils/AdvUtils.csproj
@@ -14,7 +14,7 @@
enable
enable
True
- 2.8.16
+ 2.7.0
Zhongkai Fu
A utility for common alogrithms
diff --git a/Seq2SeqSharp/Applications/Decoder.cs b/Seq2SeqSharp/Applications/Decoder.cs
index 01cd25f..e00d74c 100644
--- a/Seq2SeqSharp/Applications/Decoder.cs
+++ b/Seq2SeqSharp/Applications/Decoder.cs
@@ -269,7 +269,7 @@ public static List<List<BeamSearchStatus>> CombineBeamSearchResults(...
public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<int>> tgtSeqs, IComputeGraph g, IWeightTensor encOutputs, TransformerDecoder decoder, IFeedForwardLayer decoderFFLayer,
IWeightTensor tgtEmbedding, float[] srcOriginalLenghts, Vocab tgtVocab, PaddingEnums paddingType, float dropoutRatio, DecodingOptions decodingOptions, bool isTraining = true,
- bool outputSentScore = true, List<BeamSearchStatus> previousBeamSearchResults = null, IFeedForwardLayer pointerGenerator = null, List<List<int>> srcSeqs = null, Dictionary<string, IWeightTensor> cachedTensors = null,
+ bool outputSentScore = true, List<BeamSearchStatus> previousBeamSearchResults = null, IFeedForwardLayer pointerGenerator = null, List<List<int>> srcSeqs = null, Dictionary<string, IWeightTensor> contextTensors = null,
List<List<int>> alignmentsToSrc = null, List<List<float>> alignmentScoresToSrc = null, bool teacherForcedAlignment = false, LossEnums lossType = LossEnums.CrossEntropy, float labelSmooth = 0.0f, float lossSmooth = 1e-9f,
List<int> blockedTokens = null, IWeightTensor segmentEmbeddings = null, bool amp = false, IWeightTensor posEmbeddings = null, float lossScaling = 1.0f, int paddingAlignmentFactor = 0)
{
@@ -281,33 +281,59 @@ public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<int>> tgtSeqs, ...
+ if (... > 0)
{
- if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1)
+ // Only add last token in the target sequence into token list and cross-mask tensor
+ if (srcTgtMask != null)
{
- tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
- tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen });
+ srcTgtMask = g.Peek(srcTgtMask, 2, tgtSeqLen - 1, 1);
}
- else
+
+ List<List<int>> t = new List<List<int>>();
+ for (int i = 0; i < tgtSeqs.Count; i++)
{
- tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
- tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen });
+ t.Add(new List<int>());
+ t[i].Add(tgtSeqs[i][tgtSeqs[i].Count - 1]);
}
- }
+ tgtSeqLen = t[0].Count;
- IWeightTensor inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
- if (posEmbeddings != null)
+ inputEmbs = TensorUtils.CreateTokensEmbeddings(t, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
+ if (posEmbeddings != null)
+ {
+ inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio); //Output Shape: [batchSize * seqLen, hidden_dim]
+ }
+ }
+ else
{
- inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio);
+ if (decoder.AttentionType == AttentionTypeEnums.Classic)
+ {
+ if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1)
+ {
+ tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
+ tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen });
+ }
+ else
+ {
+ tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32);
+ tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen });
+ }
+ }
+
+ inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
+ if (posEmbeddings != null)
+ {
+ inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio);
+ }
}
IWeightTensor decOutput;
IWeightTensor decEncAttnProbs;
- (decOutput, decEncAttnProbs) = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g, outputAttnWeights: pointerGenerator != null, cachedTensors: cachedTensors);
+ (decOutput, decEncAttnProbs) = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g, outputAttnWeights: pointerGenerator != null, cachedTensors: contextTensors);
if (isTraining == false && teacherForcedAlignment == false)
{
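
Note on the Decoder.cs hunk above: when a decoding cache (contextTensors) is passed in, DecodeTransformer now embeds and decodes only the newest target token per batch item and slices the cross-attention mask down to that single position with g.Peek, instead of re-running the whole target prefix at every step. Below is a minimal, self-contained sketch of the last-token selection using plain lists only; the graph/tensor calls of the real code are omitted and the names here are illustrative, not the repo's API.

```csharp
// Hypothetical sketch: reduce full target prefixes to just their last tokens,
// mirroring the List<List<int>> t built in DecodeTransformer when a cache exists.
using System;
using System.Collections.Generic;

static class IncrementalDecodingSketch
{
    public static List<List<int>> KeepLastTokens(List<List<int>> tgtSeqs)
    {
        var t = new List<List<int>>();
        foreach (var seq in tgtSeqs)
        {
            // Only the newest token needs to be embedded and decoded this step;
            // earlier positions are already covered by the cached states.
            t.Add(new List<int> { seq[seq.Count - 1] });
        }
        return t;
    }

    static void Main()
    {
        var tgtSeqs = new List<List<int>> { new() { 2, 17, 45 }, new() { 2, 9, 30 } };
        var lastOnly = KeepLastTokens(tgtSeqs);
        Console.WriteLine(lastOnly[0][0]); // 45: one token per sequence is decoded this step
    }
}
```

With a beam size of 1 this keeps the per-step decoding cost roughly constant instead of growing with the length generated so far.
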
diff --git a/Seq2SeqSharp/Applications/Seq2Seq.cs b/Seq2SeqSharp/Applications/Seq2Seq.cs
index 31d36d0..16d3f83 100644
--- a/Seq2SeqSharp/Applications/Seq2Seq.cs
+++ b/Seq2SeqSharp/Applications/Seq2Seq.cs
@@ -12,7 +12,7 @@
using System.Collections.Generic;
using System.IO;
using AdvUtils;
-using Microsoft.Extensions.Caching.Memory;
+using System.Runtime.Caching;
using Seq2SeqSharp.Enums;
using Seq2SeqSharp.Corpus;
using Seq2SeqSharp.Layers;
@@ -41,8 +41,6 @@ public class Seq2Seq : BaseSeq2SeqFramework<Seq2SeqModel>
private readonly PaddingEnums m_paddingType = PaddingEnums.AllowPadding;
readonly Seq2SeqOptions m_options = null;
- private MemoryCache m_memoryCache;
-
public Seq2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab = null)
: base(deviceIds: options.DeviceIds, processorType: options.ProcessorType, modelFilePath: options.ModelFilePath, memoryUsageRatio: options.MemoryUsageRatio,
compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,
@@ -57,11 +55,6 @@ public Seq2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab = n
// Check if options are valided.
m_options.ValidateOptions();
- m_memoryCache = new MemoryCache(new MemoryCacheOptions
- {
- SizeLimit = 1024
- });
-
if (File.Exists(m_options.ModelFilePath))
{
if (srcVocab != null || tgtVocab != null)
@@ -193,16 +186,15 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
//}
IWeightTensor encOutput;
- if (!isTraining && (m_options.ProcessorType == ProcessorTypeEnums.CPU))
+ if (m_options.UseKVCache)
{
// Try to get src tensor from cache
string cacheKey = GenerateCacheKey(srcSnts);
- if (!m_memoryCache.TryGetValue(cacheKey, out encOutput))
- {
+ encOutput = MemoryCache.Default[cacheKey] as IWeightTensor;
+ if (encOutput == null)
+ {
encOutput = Encoder.Run(computeGraph, encoder, m_modelMetaData, m_paddingType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths); // Shape: [batchsize * seqLen, embedding_dim]
-
- var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
- m_memoryCache.Set(cacheKey, encOutput.CopyWeightsRef($"cache_{encOutput.Name}", false, graphToBind: null), cacheEntryOptions);
+ MemoryCache.Default.Set(cacheKey, encOutput.CopyWeightsRef($"cache_{encOutput.Name}", false, graphToBind: null), DateTimeOffset.Now + TimeSpan.FromMinutes(10));
}
}
else
@@ -265,7 +257,18 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
}
else
{ // Test mode or running validation in Training mode
- Dictionary<string, IWeightTensor> cachedTensors = new Dictionary<string, IWeightTensor>();
+ string cacheKey = GenerateCacheKey(tgtSnts);
+ Dictionary<string, IWeightTensor> cachedTensors = null;
+ if (m_options.UseKVCache)
+ {
+ cachedTensors = MemoryCache.Default[cacheKey] as Dictionary<string, IWeightTensor>;
+ if (cachedTensors == null && decodingOptions.BeamSearchSize == 1)
+ {
+ cachedTensors = new Dictionary<string, IWeightTensor>();
+ }
+ MemoryCache.Default.Remove(cacheKey);
+ }
+
List<List<BeamSearchStatus>> beam2batchStatus = Decoder.InitBeamSearchStatusListList(batchSize, tgtTokensList);
for (int i = tgtTokensList[0].Count; i < decodingOptions.MaxTgtSentLength; i++)
{
@@ -292,7 +295,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
originalSrcLengths, m_modelMetaData.TgtVocab, m_paddingType, 0.0f, decodingOptions, isTraining,
outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus,
pointerGenerator: pointerGenerator, srcSeqs: srcTokensList,
- cachedTensors: cachedTensors, alignmentsToSrc: alignmentsToSrc, alignmentScoresToSrc: alignmentScores,
+ contextTensors: cachedTensors, alignmentsToSrc: alignmentsToSrc, alignmentScoresToSrc: alignmentScores,
blockedTokens: decodingOptions.BlockedTokens, segmentEmbeddings: segmentEmbedding, amp: m_options.AMP,
posEmbeddings: posEmbeddings, lossScaling: LossScaling, paddingAlignmentFactor: m_options.PaddingAlignmentFactor);
@@ -334,10 +337,13 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
if (cachedTensors != null)
{
+ cacheKey = GenerateCacheKey(nr.Output[0]);
+ Dictionary<string, IWeightTensor> newCachedTensors = new Dictionary<string, IWeightTensor>();
foreach (var pair in cachedTensors)
{
- pair.Value.Dispose();
+ newCachedTensors.Add(pair.Key, pair.Value.CopyWeightsRef(pair.Value.Name, false, graphToBind: null));
}
+ MemoryCache.Default.Set(cacheKey, newCachedTensors, DateTimeOffset.Now + TimeSpan.FromMinutes(10));
}
}
}
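
Note on the Seq2Seq.cs hunk above: both the cached encoder output and the target-side cached tensors now live in System.Runtime.Caching's MemoryCache.Default, gated by m_options.UseKVCache, instead of a per-model Microsoft.Extensions.Caching.Memory instance. Below is a minimal sketch of that API pattern with placeholder keys and values rather than the repo's real cache entries; on .NET Core / .NET 5+ this requires the System.Runtime.Caching NuGet package.

```csharp
// Minimal sketch of the System.Runtime.Caching pattern the diff switches to.
// The key and value here are placeholders, not the repo's cache entries.
using System;
using System.Runtime.Caching;

class CacheSketch
{
    static void Main()
    {
        string cacheKey = "demo_key";

        // Read: the indexer returns null when the key is missing or expired.
        var cached = MemoryCache.Default[cacheKey] as string;
        if (cached == null)
        {
            cached = "expensive result";
            // Write with an absolute expiration, mirroring the 10-minute window in the diff.
            MemoryCache.Default.Set(cacheKey, cached, DateTimeOffset.Now + TimeSpan.FromMinutes(10));
        }

        Console.WriteLine(cached);
        MemoryCache.Default.Remove(cacheKey); // explicit eviction, like the one-shot tgt-side cache
    }
}
```

The indexer returns null for a missing or expired entry, Set stores with an absolute expiration, and Remove evicts explicitly, which is how the diff consumes and then re-publishes the target-side cache on each forward call.
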
diff --git a/Seq2SeqSharp/Layers/MultiHeadAttention.cs b/Seq2SeqSharp/Layers/MultiHeadAttention.cs
index 708143b..950554c 100644
--- a/Seq2SeqSharp/Layers/MultiHeadAttention.cs
+++ b/Seq2SeqSharp/Layers/MultiHeadAttention.cs
@@ -437,27 +437,6 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int
using IComputeGraph g = graph.CreateSubGraph(keyName);
int seqLenQ = inputQ.Rows / batchSize;
- int newTokensIdx = seqLenQ;
- IWeightTensor m_cacheQs = null;
- string QKeyName = keyName + "_" + nameof(inputQ);
- if (cachedTensors != null)
- {
- if (cachedTensors.ContainsKey(QKeyName) == true)
- {
- m_cacheQs = cachedTensors[QKeyName];
- newTokensIdx = seqLenQ - (int)m_cacheQs.Sizes[0];
- }
- else
- {
- cachedTensors.Add(QKeyName, null);
- }
-
- // Optimize runtime for test that only processing new tokens
- inputQ = g.View(inputQ, dims: new long[] { batchSize, seqLenQ, -1 });
- inputQ = g.AsContiguous(g.Peek(inputQ, 1, seqLenQ - newTokensIdx, newTokensIdx)); // Shape: [batchSize, newTokensSize, input_dim]
- inputQ = g.View(inputQ, dims: new long[] { batchSize * newTokensIdx, -1 }); // Shape: [batchSize * newTokensSize, input_dim]
- }
-
// SeqLenK must be euqal to SeqLenV
int seqLenK = inputK.Rows / batchSize;
int seqLenV = inputV.Rows / batchSize;
@@ -465,10 +444,10 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int
IWeightTensor inputQNorm = layerNormQ.Norm(inputQ, g);
//Input projections
- IWeightTensor allQ = g.View(g.Affine(inputQNorm, Q, Qb), dims: new long[] { batchSize, newTokensIdx, m_multiHeadNum, m_d });
+ IWeightTensor allQ = g.View(g.Affine(inputQNorm, Q, Qb), dims: new long[] { batchSize, seqLenQ, m_multiHeadNum, m_d });
//Multi-head attentions
- IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, newTokensIdx, m_d });
+ IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, m_d });
IWeightTensor Ks = null;
IWeightTensor Vs = null;
@@ -512,15 +491,10 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int
// Scaled softmax
float scale = 1.0f / (float)(Math.Sqrt(m_d));
var attn = g.MulBatch(Qs, Ks, scale); // Shape: [batchSize * m_multiHeadNum, newTokensIdx, seqLenK]
- attn = g.View(attn, dims: new long[] { batchSize, m_multiHeadNum, newTokensIdx, seqLenK });
+ attn = g.View(attn, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenK });
if (keyMask != null)
{
- if (cachedTensors != null)
- {
- keyMask = g.Peek(keyMask, 2, seqLenQ - newTokensIdx, newTokensIdx); // Shape: [batchSize, m_multiHeadNum, newTokensIdx, seqLenK]
- }
-
attn = g.Add(attn, keyMask, inPlace: true); // Shape: [batchSize, m_multiHeadNum, newTokensIdx, seqLenK]
}
@@ -537,40 +511,18 @@ private IWeightTensor PerformFlashAttentionWithCausal(IWeightTensor inputQ, int
}
sumAttnWeights = graph.Div(sumAttnWeights, (float)m_multiHeadNum, inPlace: true);
- sumAttnWeights = graph.View(sumAttnWeights, new long[] { batchSize * newTokensIdx, seqLenK });
+ sumAttnWeights = graph.View(sumAttnWeights, new long[] { batchSize * seqLenQ, seqLenK });
}
- attnProbs = g.View(attnProbs, dims: new long[] { batchSize * m_multiHeadNum, newTokensIdx, seqLenK });
+ attnProbs = g.View(attnProbs, dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, seqLenK });
- IWeightTensor o = g.View(g.MulBatch(attnProbs, Vs), dims: new long[] { batchSize, m_multiHeadNum, newTokensIdx, m_d });
- IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * newTokensIdx, m_multiHeadNum * m_d });
+ IWeightTensor o = g.View(g.MulBatch(attnProbs, Vs), dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, m_d });
+ IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * seqLenQ, m_multiHeadNum * m_d });
// Output projection
IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), m_dropoutRatio, inPlace: true);
IWeightTensor result = graph.Add(finalAttResults, inputQ, inPlace: true); // Shape: [batchSize * newTokensSize, input_dim]
-
- if (cachedTensors != null)
- {
- result = g.View(result, dims: new long[] { batchSize, newTokensIdx, m_multiHeadNum * m_d });
- result = g.AsContiguous(g.Transpose(result, 0, 1)); // Shape: [newTokensIdx, batchSize, m_multiHeadNum * m_d]
-
- if (m_cacheQs == null)
- {
- m_cacheQs = result;// Shape: [newTokensIdx, batchSize, m_multiHeadNum * m_d]
- }
- else
- {
- m_cacheQs = g.Concate(0, m_cacheQs, result);
- }
- m_cacheQs.UnbindFromComputeGraph();
-
- cachedTensors[QKeyName] = m_cacheQs;
-
- result = g.AsContiguous(g.Transpose(m_cacheQs, 0, 1)); // Shape: [batchSize, seqLenQ, m_multiHeadNum * m_d]
- result = graph.View(result, dims: new long[] { batchSize * seqLenQ, m_multiHeadNum * m_d });
- }
-
return (result, sumAttnWeights);
}
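
Note on the MultiHeadAttention.cs hunk above: the per-layer query cache (m_cacheQs) and the newTokensIdx slicing are removed, so PerformFlashAttentionWithCausal always attends over the full seqLenQ; incremental decoding is instead handled upstream by feeding only the last target token. For reference, here is a single-head, plain-array sketch of the scaled dot-product attention this path computes. It is illustrative only; the names and shapes are assumptions, not the repo's tensor API.

```csharp
using System;

static class ScaledDotProductSketch
{
    // Q: [seqLenQ, d], K: [seqLenK, d], V: [seqLenK, d]; returns [seqLenQ, d].
    public static double[][] Attend(double[][] Q, double[][] K, double[][] V)
    {
        int seqLenQ = Q.Length, seqLenK = K.Length, d = Q[0].Length;
        double scale = 1.0 / Math.Sqrt(d); // same 1/sqrt(m_d) scaling as the diff
        var output = new double[seqLenQ][];

        for (int i = 0; i < seqLenQ; i++)
        {
            // scores[j] = scale * dot(Q[i], K[j]), then softmax over j
            var scores = new double[seqLenK];
            for (int j = 0; j < seqLenK; j++)
                for (int k = 0; k < d; k++)
                    scores[j] += scale * Q[i][k] * K[j][k];

            double max = double.NegativeInfinity, sum = 0.0;
            foreach (var s in scores) max = Math.Max(max, s);
            for (int j = 0; j < seqLenK; j++) { scores[j] = Math.Exp(scores[j] - max); sum += scores[j]; }

            // output[i] = sum_j softmax(scores)[j] * V[j]
            output[i] = new double[d];
            for (int j = 0; j < seqLenK; j++)
                for (int k = 0; k < d; k++)
                    output[i][k] += (scores[j] / sum) * V[j][k];
        }
        return output;
    }

    static void Main()
    {
        var q = new[] { new[] { 1.0, 0.0 } };                     // one query position, d = 2
        var k = new[] { new[] { 1.0, 0.0 }, new[] { 0.0, 1.0 } }; // two key/value positions
        var v = new[] { new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 } };
        Console.WriteLine(string.Join(",", Attend(q, k, v)[0])); // weighted mix of the two value rows
    }
}
```
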
diff --git a/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs b/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs
index 061d182..92e07b9 100644
--- a/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs
+++ b/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs
@@ -9,7 +9,6 @@
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details.
using AdvUtils;
-using Microsoft.Extensions.Logging.Abstractions;
using Seq2SeqSharp.Tools;
using Seq2SeqSharp.Utils;
using System;
diff --git a/Seq2SeqSharp/Networks/TransformerDecoder.cs b/Seq2SeqSharp/Networks/TransformerDecoder.cs
index 258d46c..43cf5d0 100644
--- a/Seq2SeqSharp/Networks/TransformerDecoder.cs
+++ b/Seq2SeqSharp/Networks/TransformerDecoder.cs
@@ -135,21 +135,16 @@ public void Reset(IWeightFactory weightFactory, int batchSize)
IWeightTensor attnProbs = null;
using (IComputeGraph subg = g.CreateSubGraph($"{m_name}_Decoder"))
{
- int seqLenQ = tgtInputs.Rows / batchSize;
-
- // SeqLenK must be euqal to SeqLenV
- int seqLenK = encOutputBatchFirst.Rows / batchSize;
-
IWeightTensor selfMaskTensor = null;
if (tgtSelfMask != null)
{
- selfMaskTensor = subg.Expand(tgtSelfMask, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenQ });
+ selfMaskTensor = subg.Expand(tgtSelfMask, dims: new long[] { batchSize, m_multiHeadNum, tgtSelfMask.Sizes[2], tgtSelfMask.Sizes[3] });
}
IWeightTensor crossMaskTensor = null;
if (srcTgtMask != null)
{
- crossMaskTensor = subg.Expand(srcTgtMask, dims: new long[] { batchSize, m_multiHeadNum, seqLenQ, seqLenK });
+ crossMaskTensor = subg.Expand(srcTgtMask, dims: new long[] { batchSize, m_multiHeadNum, srcTgtMask.Sizes[2], srcTgtMask.Sizes[3] });
}
for (int k = 0; k < m_selfAttns.Count; k++)
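
Note on the TransformerDecoder.cs hunk above: the mask expansion now takes the query/key lengths from the mask tensors' own Sizes[2] and Sizes[3] rather than recomputing them from tgtInputs.Rows / batchSize, because on the cached single-token path only one new target position reaches the decoder and a recomputed seqLenQ would no longer match the mask. A small hypothetical helper (not part of the repo) showing the shape derivation:

```csharp
using System;

static class MaskShapeSketch
{
    // Hypothetical helper: derive the expand shape for a [batch|1, 1, qLen, kLen] mask
    // from the mask's own sizes, so the full-sequence path (qLen == tgtSeqLen) and the
    // cached single-token path (qLen == 1) both expand correctly across attention heads.
    public static long[] ExpandShapeForHeads(long[] maskSizes, long batchSize, long numHeads)
        => new long[] { batchSize, numHeads, maskSizes[2], maskSizes[3] };

    static void Main()
    {
        var singleStepMask = new long[] { 4, 1, 1, 37 }; // one new target token, 37 source positions
        Console.WriteLine(string.Join(",", ExpandShapeForHeads(singleStepMask, 4, 8))); // 4,8,1,37
    }
}
```
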
diff --git a/Seq2SeqSharp/Seq2SeqSharp.csproj b/Seq2SeqSharp/Seq2SeqSharp.csproj
index 71d1d40..08bfb04 100644
--- a/Seq2SeqSharp/Seq2SeqSharp.csproj
+++ b/Seq2SeqSharp/Seq2SeqSharp.csproj
@@ -15,7 +15,7 @@
AnyCPU
false
bin\
- 2.8.16
+ 2.8.17
Seq2SeqSharp is a tensor based fast & flexible encoder-decoder deep neural network framework written by .NET (C#). It can be used for sequence-to-sequence task, sequence-labeling task and sequence-classification task and other NLP tasks. Seq2SeqSharp supports both CPUs (x86, x64 and ARM64) and GPUs. It's powered by .NET core, so Seq2SeqSharp can run on both Windows and Linux without any modification and recompilation.
README.md
Seq2SeqSharp
@@ -47,7 +47,6 @@
-