Mitigate MoE expert polarization issue

zhongkaifu committed Dec 14, 2023
1 parent dcace35 commit b65850f

Showing 12 changed files with 164 additions and 36 deletions.
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/GPT.cs
@@ -213,7 +213,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
if (isTraining)
{
(var c, _) = Decoder.GPTDecode(tgtTokensList, computeGraph, decoder as GPTDecoder, decoderFFLayer, tgtEmbedding, m_modelMetaData.TgtVocab, m_paddingType,
m_options.DropoutRatio, decodingOptions, isTraining, lossType: m_options.LossType, focalLossGamma: m_options.FocalLossGamma,
m_options.DropoutRatio, decodingOptions, isTraining, lossType: m_options.LossType, focalLossGamma: m_options.FocalLossGamma, lossSmooth: m_options.LossSmooth,
segmentEmbeddings: segmentEmbedding, amp: m_options.AMP, posEmbeddings: posEmbeddings);
nr.Cost = c;
nr.Output = null;
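Note on this change: GPT.cs now threads m_options.LossSmooth through to Decoder.GPTDecode alongside the loss type and focal-loss gamma. Assuming LossSmooth is a label-smoothing coefficient (an inference from the parameter name, not confirmed by this diff), the loss target mixes the one-hot gold label with a uniform distribution. A minimal self-contained sketch, illustrative names only:

    using System;

    static class LossSketch
    {
        // Cross entropy against a label-smoothed target: (1 - smooth) on the
        // gold token plus smooth / vocab spread uniformly over the vocabulary.
        public static float SmoothedCrossEntropy(float[] probs, int goldIdx, float smooth)
        {
            int vocab = probs.Length;
            double loss = 0.0;
            for (int i = 0; i < vocab; i++)
            {
                double target = smooth / vocab + (i == goldIdx ? 1.0 - smooth : 0.0);
                loss -= target * Math.Log(Math.Max(probs[i], 1e-12));
            }
            return (float)loss;
        }
    }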
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Layers/FeedForwardLayer.cs
@@ -57,7 +57,7 @@ public void ClearStatus()
public IWeightTensor Process(IWeightTensor inputT, int batchSize, IComputeGraph g, Dictionary<string, IWeightTensor> cachedTensors = null)
{
IWeightTensor res = g.Affine(inputT, m_Whd, m_Bd, 1.0f);
return g.Dropout(res, batchSize, m_dropoutRatio, inPlace: true);
return g.Dropout(res, m_dropoutRatio, inPlace: true);
}

public virtual List<IWeightTensor> GetParams()
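This commit also simplifies Dropout's signature by dropping its batchSize parameter; FeedForwardLayer here, plus MultiHeadAttention, AttentionDecoder, and PositionEmbedding below, are updated to match:

    // before: g.Dropout(res, batchSize, m_dropoutRatio, inPlace: true);
    // after:
    return g.Dropout(res, m_dropoutRatio, inPlace: true);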
45 changes: 45 additions & 0 deletions Seq2SeqSharp/Layers/MoEFeedForward.cs
@@ -14,6 +14,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Xml.Linq;
using TensorSharp;

namespace Seq2SeqSharp.Layers
@@ -70,12 +71,20 @@ public void ClearStatus()

}

//DateTime lastCheckDT = DateTime.Now;

public IWeightTensor Process(IWeightTensor input, int batchSize, IComputeGraph graph, Dictionary<string, IWeightTensor> cachedTensors = null)
{
//Computing routing result
using var g = graph.CreateSubGraph($"{m_name}_MoEFeedForward");
var inputNorm = layerNorm.Norm(input, g);
var inputRouterDense = g.Affine(inputNorm, m_Router, m_RouterBias); // [batchSize * seqLen, expertNum]

var maskTensor = g.CreateTensorWeights(new long[] { 1, m_expertNum }, -65000.0f);
maskTensor = g.Dropout(maskTensor, 0.1f);
maskTensor = g.Expand(maskTensor, inputRouterDense.Sizes); // [batchSize * seqLen, expertNum]
inputRouterDense = g.Add(inputRouterDense, maskTensor);

var inputRouter = g.Softmax(inputRouterDense); // [batchSize * seqLen, expertNum]

(var topValue, var topIndex) = g.TopK(inputRouter, m_expertsPerTokenFactor); // [batchSize * seqLen, m_expertsPerTokenFactor]
@@ -95,6 +104,42 @@ public IWeightTensor Process(IWeightTensor input, int batchSize, IComputeGraph g
}
}

////////////////////////////////////////////
//if (DateTime.Now - lastCheckDT >= TimeSpan.FromMinutes(5.0f))
//{
// lastCheckDT = DateTime.Now;

// Logger.WriteLine(Logger.Level.debug, $"Weight '{m_name}'");
// for (int i = 0; i < indexs.Length; i++)
// {
// Logger.WriteLine(Logger.Level.debug, $"Expert '{i}' is selected by '{indexs[i].Count}' tokens.");
// }

// var weights = inputRouter.ToWeightArray();

// StringBuilder sb = new StringBuilder();

// int colSize = (int)inputRouter.Sizes[^1];
// int idx = 0;
// foreach (var weight in weights)
// {
// sb.Append($"{weight:F4}, ");
// idx++;

// if (idx % colSize == 0)
// {
// sb.AppendLine();
// }
// }

// sb.Append("]");

// Logger.WriteLine(Logger.Level.debug, "*************************");
// Logger.WriteLine(Logger.Level.debug, sb.ToString());
// Logger.WriteLine(Logger.Level.debug, "*************************");
//}
////////////////////////////////////////////

List<IWeightTensor> tokenEmbsList = new List<IWeightTensor>();
List<IWeightTensor> tokenIdxList = new List<IWeightTensor>();

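This hunk is the core of the mitigation. A [1, expertNum] tensor filled with -65000 is passed through dropout and then broadcast-added to the router logits. Entries that dropout zeroes leave their experts' logits untouched; entries that survive keep a hugely negative value, so those experts' softmax weights collapse to ~0 and TopK cannot pick them. Each training step therefore routes within a random subset of experts, which feeds gradient to habitually unselected experts and counteracts polarization, where a few experts capture all tokens. Since Dropout returns its input unchanged when backprop is off (see ComputeGraphTensor.cs below), inference routing is unaffected. A plain-array sketch of the idea, with illustrative names; which fraction of entries is zeroed at drop_prob = 0.1 depends on the library's dropout convention, so it is left as a parameter:

    using System;

    static class ExpertMaskSketch
    {
        // Perturb router logits so only a random subset of experts is routable
        // this step. logits is [tokens, expertNum], flattened row-major.
        public static void MaskExperts(float[] logits, int tokens, int expertNum,
                                       double zeroProb, Random rng)
        {
            var mask = new float[expertNum];
            for (int e = 0; e < expertNum; e++)
            {
                // Zeroed entry: expert stays routable. Surviving entry: a large
                // negative logit removes the expert from softmax/TopK.
                mask[e] = rng.NextDouble() < zeroProb ? 0.0f : -65000.0f;
            }
            for (int t = 0; t < tokens; t++)              // broadcast over tokens
                for (int e = 0; e < expertNum; e++)
                    logits[t * expertNum + e] += mask[e];
        }
    }

(In the actual diff the surviving entries are also rescaled by dropout, but anything near -65000 already drives a softmax weight to zero, so the scale is immaterial.)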
4 changes: 2 additions & 2 deletions Seq2SeqSharp/Layers/MultiHeadAttention.cs
@@ -183,7 +183,7 @@ public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int ba
IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * newTokensIdx, m_multiHeadNum * m_d });

// Output projection
IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), batchSize, m_dropoutRatio, inPlace: true); // Shape: [batchSize * relPosSize, m_multiHeadNum * m_d]
IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), m_dropoutRatio, inPlace: true); // Shape: [batchSize * relPosSize, m_multiHeadNum * m_d]

if (cachedTensors != null)
{
@@ -337,7 +337,7 @@ public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int ba
IWeightTensor W = g.View(g.AsContiguous(g.Transpose(o, 1, 2)), dims: new long[] { batchSize * newTokensIdx, m_multiHeadNum * m_d });

// Output projection
IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), batchSize, m_dropoutRatio, inPlace: true);
IWeightTensor finalAttResults = g.Dropout(g.Affine(W, W0, b0), m_dropoutRatio, inPlace: true);
IWeightTensor result = graph.Add(finalAttResults, inputQ, inPlace: true); // Shape: [batchSize * newTokensSize, input_dim]


2 changes: 1 addition & 1 deletion Seq2SeqSharp/Networks/AttentionDecoder.cs
@@ -91,7 +91,7 @@ public IWeightTensor Decode(IWeightTensor input, AttentionPreProcessResult atten
V = e;
}

IWeightTensor eOutput = g.Dropout(V, batchSize, m_dropoutRatio, false);
IWeightTensor eOutput = g.Dropout(V, m_dropoutRatio, false);

return eOutput;
}
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Networks/GPTDecoder.cs
@@ -80,7 +80,7 @@ public GPTDecoder(string name, int multiHeadNum, int hiddenDim, int intermediate

for (int i = 0; i < depth; i++)
{
if (m_expertNum > 1 && i % 2 == 0)
if (m_expertNum > 1 && i > 1)
{
m_feedForwards.Add(new MoEFeedForward($"{name}.MoEFFN_{i}", m_expertNum, hiddenDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, expertsPerTokenFactor: expertsPerTokenFactor));
}
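The gating condition for building MoE layers changes from i % 2 == 0 (MoE feed-forward on even-numbered layers, dense on odd) to i > 1 (dense feed-forward in the first two layers, MoE everywhere above). TransformerDecoder and TransformerEncoder below receive the identical change. In sketch form, with hypothetical helper names:

    // Layers 0 and 1 keep a dense FFN; layers 2..depth-1 use MoE
    // (previously MoE was used on every even-indexed layer).
    for (int i = 0; i < depth; i++)
    {
        bool useMoE = m_expertNum > 1 && i > 1;
        m_feedForwards.Add(useMoE ? BuildMoEFeedForward(i)    // hypothetical helpers
                                  : BuildDenseFeedForward(i));
    }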
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Networks/TransformerDecoder.cs
@@ -89,7 +89,7 @@ public TransformerDecoder(string name, int multiHeadNum, int hiddenDim, int inte

for (int i = 0; i < depth; i++)
{
if (m_expertNum > 1 && i % 2 == 0)
if (m_expertNum > 1 && i > 1)
{
m_feedForwards.Add(new MoEFeedForward($"{name}.MoEFFN_{i}", m_expertNum, hiddenDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, expertsPerTokenFactor: expertsPerTokenFactor));
}
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Networks/TransformerEncoder.cs
@@ -80,7 +80,7 @@ public TransformerEncoder(string name, int multiHeadNum, int hiddenDim, int inte

for (int i = 0; i < depth; i++)
{
if (m_expertNum > 1 && i % 2 == 0)
if (m_expertNum > 1 && i > 1)
{
m_feedForwards.Add(new MoEFeedForward($"{name}.MoEFFN_{i}", m_expertNum, hiddenDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, expertsPerTokenFactor: expertsPerTokenFactor));
}
112 changes: 92 additions & 20 deletions Seq2SeqSharp/Tools/ComputeGraphTensor.cs
@@ -13,6 +13,8 @@
using System.Linq;
using TensorSharp;
using Seq2SeqSharp.Utils;
using ManagedCuda.BasicTypes;
using ManagedCuda.VectorTypes;

/// <summary>
/// Tensor based computing graph written by Zhongkai Fu.
@@ -1328,6 +1330,32 @@ void backward()
return res;
}

public IWeightTensor TruncateGradient(IWeightTensor w, IWeightTensor mask)
{
WeightTensor m = w as WeightTensor;
WeightTensor mt = mask as WeightTensor;
WeightTensor res = m.CopyWeightsRef($"{GetHashString(w.Name)}.TruncateGradient", needGradient: m.NeedGradient, graphToBind: this);
int colSize = (int)m.Sizes[1];

if (m_needsBackprop)
{
void backward()
{
if (m.NeedGradient)
{
res.ReleaseWeight();
m.AddMulGradient(res.TGradient, mt.TWeight);
}
res.Dispose();

}
m_backprop.Add(backward);
}

return res;
}
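TruncateGradient is a new identity-forward operator: the forward pass shares the input's weights via CopyWeightsRef, and the backward pass multiplies the upstream gradient elementwise by the mask's weights before accumulating into the source (AddMulGradient), so zero entries in the mask block gradient flow at those positions. (The colSize local is computed but unused.) A plausible pairing is with BuildMaskUntil, added later in this file, to stop gradients on prompt positions; the diff itself shows no call site. A plain-array restatement of the backward rule:

    // Forward: output == input. Backward: downstream = upstream * mask,
    // so mask == 0 stops the gradient while leaving activations intact.
    float[] upstream = { 0.5f, -1.2f, 2.0f };
    float[] mask     = { 1f, 0f, 1f };
    var downstream = new float[upstream.Length];
    for (int i = 0; i < upstream.Length; i++)
        downstream[i] = upstream[i] * mask[i];   // AddMulGradient analogue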


public IWeightTensor Argmax(IWeightTensor w, int dim = -1)
{
WeightTensor m = w as WeightTensor;
@@ -1405,13 +1433,6 @@ public IWeightTensor GreaterThan(IWeightTensor w, float val)
return res;
}

private static double PowerA(double a, double b)
{
int tmp = (int)(BitConverter.DoubleToInt64Bits(a) >> 32);
int tmp2 = (int)(b * (tmp - 1072632447) + 1072632447);
return BitConverter.Int64BitsToDouble(((long)tmp2) << 32);
}

/// <summary>
/// Top-P sampling for each row in given tensor
/// </summary>
@@ -1730,6 +1751,14 @@ public IWeightTensor CreateTensorWeights(long[] sizes, float[] values)
return res;
}

public IWeightTensor CreateTensorWeights(long[] sizes, float value)
{
WeightTensor res = m_weightTensorFactory.CreateWeightTensor(sizes, m_deviceId, name: $"Tensor_CopyFrom_Array", needGradient: false);
Ops.Fill(res.TWeight, value);

return res;
}
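This is a convenience overload that fills a fresh non-gradient tensor with a single constant via Ops.Fill; it is what MoEFeedForward.cs above uses to build its router mask:

    var maskTensor = g.CreateTensorWeights(new long[] { 1, m_expertNum }, -65000.0f);

(The debug name Tensor_CopyFrom_Array is inherited from the float[] overload above.)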

public IWeightTensor CreateUniformRandomTensor(long[] sizes, float minVal, float maxVal)
{
WeightTensor res = m_weightTensorFactory.CreateWeightTensor(sizes, m_deviceId, name: $"New_UniformRandom_Tensor", needGradient: false);
@@ -2333,7 +2362,7 @@ void backward()
//}


public IWeightTensor Dropout(IWeightTensor V, int batchSize, float drop_prob, bool inPlace = false)
public IWeightTensor Dropout(IWeightTensor V, float drop_prob, bool inPlace = false)
{
if (drop_prob == 0 || !m_needsBackprop)
{
@@ -2359,26 +2388,30 @@ public IWeightTensor Dropout(IWeightTensor V, int batchSize, float drop_prob, bo

Ops.Mul(res.TWeight, w.TWeight, noiseExp);

void backward()
if (m_needsBackprop)
{
if (w.NeedGradient)
Tensor tn = noiseExp.CopyRef();
void backward()
{
res.ReleaseWeight();

if (inPlace && w.IsGradientNull() && res.TGradient.IsOwnerExclusive())
if (w.NeedGradient)
{
w.TGradient = res.TGradient.CopyRef();
res.ReleaseWeight();

if (inPlace && w.IsGradientNull() && res.TGradient.IsOwnerExclusive())
{
w.TGradient = res.TGradient.CopyRef();
}
w.AddMulGradient(tn, res.TGradient, inPlace);
}

w.AddMulGradient(noiseExp, res.TGradient, inPlace);
res.Dispose();
tn.Dispose();
}

res.Dispose();
noise.Dispose();
noiseExp.Dispose();
m_backprop.Add(backward);
}
m_backprop.Add(backward);

noiseExp.Dispose();
noise.Dispose();

return res;
}
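Two fixes land in Dropout beyond the signature change: the backward closure is now only registered when m_needsBackprop is set, and the closure captures its own reference to the dropout noise (tn = noiseExp.CopyRef()), disposing it after use. Previously noiseExp was read inside backward but could be released by the forward pass first; CopyRef bumps a reference count so the underlying storage outlives the forward scope. A minimal self-contained sketch of the pattern (illustrative class, not the TensorSharp API):

    using System;

    sealed class RefCounted : IDisposable
    {
        int refs = 1;
        public RefCounted CopyRef() { refs++; return this; }
        public void Dispose() { if (--refs == 0) Console.WriteLine("storage freed"); }
    }

    class Demo
    {
        static void Main()
        {
            var noise = new RefCounted();
            var captured = noise.CopyRef();   // backward's own reference
            Action backward = () => captured.Dispose();
            noise.Dispose();                  // forward reference released; storage still alive
            backward();                       // prints "storage freed" only now
        }
    }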
@@ -2660,6 +2693,45 @@ void backward()
return res;
}

public IWeightTensor BuildMaskUntil(List<List<int>> paddedTokensList, int maskEndId, DType elementType = DType.Float32)
{
int batchSize = paddedTokensList.Count;
int seqLength = paddedTokensList[0].Count;

float[] buf = new float[batchSize * seqLength];
Array.Fill(buf, 1.0f);

for (int batchIdx = 0; batchIdx < batchSize; batchIdx++)
{
for (int tokenIdx = 0; tokenIdx < seqLength; tokenIdx++)
{
int token = paddedTokensList[batchIdx][tokenIdx];
if (token == maskEndId)
{
Array.Fill(buf, 0.0f, batchIdx * seqLength, tokenIdx);
}
}
}

WeightTensor res = m_weightTensorFactory.CreateWeightTensor(new long[] { batchSize, seqLength }, m_deviceId, name: $"MaskUntil_{m_deviceId}", graphToBind: this, needGradient: false, dtype: elementType);

res.SetWeightArray(buf);

if (m_needsBackprop)
{
void backward()
{
res.Dispose();
}
m_backprop.Add(backward);
}


return res;
}
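BuildMaskUntil builds a [batchSize, seqLength] float mask, initialized to 1, and for each row zeroes the positions strictly before an occurrence of maskEndId (Array.Fill writes tokenIdx zeros from the row's start, leaving the marker position itself at 1; with several occurrences the last one wins, since later fills cover longer prefixes). The mask carries no gradient. A plain-array restatement of the fill logic:

    using System;

    static class MaskSketch
    {
        public static float[] MaskUntil(int[][] rows, int maskEndId)
        {
            int batch = rows.Length, seqLen = rows[0].Length;
            var buf = new float[batch * seqLen];
            Array.Fill(buf, 1.0f);
            for (int b = 0; b < batch; b++)
                for (int t = 0; t < seqLen; t++)
                    if (rows[b][t] == maskEndId)
                        Array.Fill(buf, 0.0f, b * seqLen, t);  // zero the prefix before the marker
            return buf;
        }
    }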



private (float, IWeightTensor) CalculateEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float smooth, float gamma)
{
var scatterIdxTensor = View(truthTgtSeqs, new long[] { -1, 1 });
3 changes: 2 additions & 1 deletion Seq2SeqSharp/Tools/IComputeGraph.cs
@@ -41,7 +41,7 @@ public interface IComputeGraph : IDisposable
IWeightTensor AddTanh(IWeightTensor w1, IWeightTensor w2);
IWeightTensor AddTanh(IWeightTensor w1, IWeightTensor w2, IWeightTensor w3);
IWeightTensor Peek(IWeightTensor w, int dim, int ix, int num = 1);
IWeightTensor Dropout(IWeightTensor V, int batchSize, float drop_prob, bool inPlace = false);
IWeightTensor Dropout(IWeightTensor V, float drop_prob, bool inPlace = false);
IWeightTensor Softmax(IWeightTensor w, bool runGradients = true, bool inPlace = false);
List<IWeightTensor> SplitColumns2(IWeightTensor w, params int[] sizes);
(IWeightTensor r1, IWeightTensor r2) SplitColumns(IWeightTensor w, int size1, int size2);
@@ -70,6 +70,7 @@ public interface IComputeGraph : IDisposable

IWeightTensor Zero(long[] sizes);
IWeightTensor CreateTensorWeights(long[] sizes, float[] values);
IWeightTensor CreateTensorWeights(long[] sizes, float value);
IWeightTensor IndexSelect(IWeightTensor s, IWeightTensor indice, bool clearWeights = false, bool isAdd = false);
IWeightTensor IndexUpdate(long[] sizes, IWeightTensor s, IWeightTensor indice, bool clearWeights = false);

22 changes: 16 additions & 6 deletions Seq2SeqSharp/Tools/WeightTensor.cs
@@ -339,26 +339,36 @@ public float[] ToWeightArray()



static DateTime lastCheckDT = DateTime.Now;
// static DateTime lastCheckDT = DateTime.Now;
public void PrintWeights()
{
// if (DateTime.Now - lastCheckDT >= TimeSpan.FromMinutes(5.0f))
// {
lastCheckDT = DateTime.Now;
// if (DateTime.Now - lastCheckDT >= TimeSpan.FromMinutes(5.0f))
// {
// lastCheckDT = DateTime.Now;
var weights = ToWeightArray();

StringBuilder sb = new StringBuilder();
sb.Append($"Weights for '{Name}': [");

int colSize = (int)Sizes[^1];
int idx = 0;
foreach (var weight in weights)
{
sb.Append($"{weight:F4}, ");
idx++;

if (idx % colSize == 0)
{
sb.AppendLine();
}
}

sb.Append("]");

Logger.WriteLine(sb.ToString());
// }
Logger.WriteLine(Logger.Level.debug, "*************************");
Logger.WriteLine(Logger.Level.debug, sb.ToString());
Logger.WriteLine(Logger.Level.debug, "*************************");
// }
}

public void AddSoftmaxGradient(Tensor srcTWeight, Tensor srcTGradient, bool inPlace = false)
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Utils/PositionEmbedding.cs
@@ -43,7 +43,7 @@ public static IWeightTensor AddPositionEmbedding(IComputeGraph g, IWeightTensor

posEmbeddingPeek.Dispose();

inputEmbs = g.Dropout(inputEmbs, batchSize, dropoutRatio, inPlace: true);
inputEmbs = g.Dropout(inputEmbs, dropoutRatio, inPlace: true);

return inputEmbs;
}
