
Commit c4087ad
Update RMSNorm operator.
zhongkaifu committed Aug 21, 2024
1 parent 9629e64 commit c4087ad
Showing 9 changed files with 269 additions and 327 deletions.
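
The diff removes the learned bias (m_beta) from the RMSNorm layer and from every operator signature beneath it. For context: RMSNorm, as commonly defined (Zhang & Sennrich, 2019), rescales by the root mean square and applies only a learned gain, so LayerNorm's shift term has no counterpart in it — which is presumably why m_beta is dropped here. A common formulation (the exact placement of \epsilon may differ in the kernels):

\mathrm{LayerNorm}(x)_i = \alpha_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i
\qquad
\mathrm{RMSNorm}(x)_i = \alpha_i \, \frac{x_i}{\sqrt{\tfrac{1}{d} \sum_{j=1}^{d} x_j^2 + \epsilon}}
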
7 changes: 1 addition & 6 deletions Seq2SeqSharp/Layers/RMSNormalization.cs
@@ -20,19 +20,17 @@ namespace Seq2SeqSharp
     internal class RMSNormalization : INormalization
     {
         private readonly IWeightTensor m_alpha;
-        private readonly IWeightTensor m_beta;
         private readonly float m_epsilon;
 
         public RMSNormalization(string name, int dim, int deviceId, bool isTrainable, float learningRateFactor = 1.0f, float epsilon = 1e-06f, DType elementType = DType.Float32)
         {
             m_alpha = new WeightTensor(new long[2] { 1, dim }, 1.0f, deviceId, name: $"{name}.{nameof(m_alpha)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType);
-            m_beta = new WeightTensor(new long[2] { 1, dim }, 0, deviceId, name: $"{name}.{nameof(m_beta)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType);
             m_epsilon = epsilon;
         }
 
         public IWeightTensor Norm(IWeightTensor input, IComputeGraph g)
         {
-            var result = g.RMSNorm(input, m_alpha, m_beta, m_epsilon);
+            var result = g.RMSNorm(input, m_alpha, m_epsilon);
             return result;
         }
 
@@ -41,7 +39,6 @@ public virtual List<IWeightTensor> GetParams()
             List<IWeightTensor> response = new List<IWeightTensor>
             {
                 m_alpha,
-                m_beta
             };
 
             return response;
@@ -50,14 +47,12 @@ public virtual List<IWeightTensor> GetParams()
         public void Save(IModel stream)
         {
             m_alpha.Save(stream);
-            m_beta.Save(stream);
         }
 
 
         public void Load(IModel stream)
        {
             m_alpha.Load(stream);
-            m_beta.Load(stream);
         }
 
         public INeuralUnit CloneToDeviceAt(int deviceId)
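
For reference, a minimal CPU-side sketch of the forward math the layer now computes — illustrative only, not Seq2SeqSharp's actual Ops.RMSNorm kernel, and it assumes row-wise normalization over the last dimension:

using System;

// Reference sketch only: normalizes one row of a [rows, dim] tensor.
// Note there is no beta/shift term -- only the learned gain alpha.
static class RmsNormReference
{
    public static float[] Forward(float[] row, float[] alpha, float eps = 1e-6f)
    {
        // r = sqrt(mean(x^2) + eps)
        double sumSq = 0.0;
        foreach (float v in row) sumSq += (double)v * v;
        float r = (float)Math.Sqrt(sumSq / row.Length + eps);

        // y_i = alpha_i * x_i / r
        var y = new float[row.Length];
        for (int i = 0; i < row.Length; i++)
            y[i] = alpha[i] * row[i] / r;
        return y;
    }
}
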
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Seq2SeqSharp.csproj
@@ -15,7 +15,7 @@
     <PlatformTarget>AnyCPU</PlatformTarget>
     <AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
     <OutputPath>bin\</OutputPath>
-    <Version>2.8.9</Version>
+    <Version>2.8.10</Version>
     <Description>Seq2SeqSharp is a tensor based fast &amp; flexible encoder-decoder deep neural network framework written by .NET (C#). It can be used for sequence-to-sequence task, sequence-labeling task and sequence-classification task and other NLP tasks. Seq2SeqSharp supports both CPUs (x86, x64 and ARM64) and GPUs. It's powered by .NET core, so Seq2SeqSharp can run on both Windows and Linux without any modification and recompilation.</Description>
     <PackageReadmeFile>README.md</PackageReadmeFile>
     <Title>Seq2SeqSharp</Title>
13 changes: 5 additions & 8 deletions Seq2SeqSharp/Tools/ComputeGraphTensor.cs
@@ -3416,16 +3416,15 @@ void backward()
         }
 
 
-        public IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, IWeightTensor beta, float eps = 1e-9f)
+        public IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, float eps = 1e-9f)
         {
             WeightTensor srcT = src as WeightTensor;
             WeightTensor alphaT = alpha as WeightTensor;
-            WeightTensor betaT = beta as WeightTensor;
 
-            WeightTensor res = m_weightTensorFactory.CreateWeightTensor(srcT.Sizes, m_deviceId, name: $"{GetHashString(src.Name, alpha.Name, beta.Name)}.RMSNorm", graphToBind: this, needGradient: srcT.NeedGradient, dtype: src.ElementType);
-            VisualizeNodes(new IWeightTensor[] { src, alpha, beta }, res);
+            WeightTensor res = m_weightTensorFactory.CreateWeightTensor(srcT.Sizes, m_deviceId, name: $"{GetHashString(src.Name, alpha.Name)}.RMSNorm", graphToBind: this, needGradient: srcT.NeedGradient, dtype: src.ElementType);
+            VisualizeNodes(new IWeightTensor[] { src, alpha }, res);
 
-            Ops.RMSNorm(res.TWeight, srcT.TWeight, alphaT.TWeight, betaT.TWeight, eps);
+            Ops.RMSNorm(res.TWeight, srcT.TWeight, alphaT.TWeight, eps);
             if (m_autoCheckCorruption)
             {
                 if (res.IsWeightsCorrupted())
@@ -3439,17 +3438,15 @@ public IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, IWeightTens
             var srcTWeight = srcT.TWeight.CopyRef();
             var resTWeight = res.TWeight.CopyRef();
             var alphaTWeight = alphaT.TWeight.CopyRef();
-            var betaTWeight = betaT.TWeight.CopyRef();
             void backward()
             {
                 if (srcT.NeedGradient)
                 {
-                    Ops.RMSNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaTWeight, betaTWeight, eps);
+                    Ops.RMSNormGrad(srcT.TGradient, alphaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaTWeight, eps);
                 }
                 srcTWeight.Dispose();
                 resTWeight.Dispose();
                 alphaTWeight.Dispose();
-                betaTWeight.Dispose();
 
                 res.Dispose();
             }
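
With beta gone, the backward pass needs gradients only for the input and the gain, which is why Ops.RMSNormGrad loses two tensor arguments. Under the standard forward definition above, with r = \sqrt{\tfrac{1}{d}\sum_j x_j^2 + \epsilon} and upstream gradient g = \partial L / \partial y, the usual derivatives are:

\frac{\partial L}{\partial \alpha_i} = \sum_{\text{rows}} g_i \, \frac{x_i}{r},
\qquad
\frac{\partial L}{\partial x_i} = \frac{\alpha_i g_i}{r} - \frac{x_i}{d \, r^{3}} \sum_{j=1}^{d} \alpha_j g_j x_j

(The kernels' exact formulation may differ, e.g. in where \epsilon enters.)
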
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Tools/IComputeGraph.cs
@@ -53,7 +53,7 @@ public interface IComputeGraph : IDisposable
     IWeightTensor Transpose(IWeightTensor w);
     IWeightTensor Mul(IWeightTensor w, float v, bool inPlace = false);
     IWeightTensor LayerNorm(IWeightTensor src, IWeightTensor alpha, IWeightTensor beta, float eps = 1e-9f);
-    IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, IWeightTensor beta, float eps = 1e-9f);
+    IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, float eps = 1e-9f);
 
     IWeightTensor Select(IWeightTensor src, int dim, int index);
     void Backward();
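
Because RMSNorm is part of the public IComputeGraph interface, this is a breaking change for external callers, which now pass only the gain and epsilon. A hypothetical call site (hidden and m_alpha are illustrative names, not from this commit):

// Before this commit: var normed = g.RMSNorm(hidden, m_alpha, m_beta, 1e-6f);
// After this commit, beta is gone from the signature:
var normed = g.RMSNorm(hidden, m_alpha, 1e-6f);
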
4 changes: 2 additions & 2 deletions TensorSharp.CUDA/CudaBasicOps.cs
@@ -671,9 +671,9 @@ public Tensor BuildTriMask(Tensor result, float value, float maskedValue)
 
 
     [RegisterOpStorageType("rmsnorm", typeof(CudaStorage))]
-    public Tensor RMSNorm(Tensor result, Tensor src, Tensor alpha, Tensor beta, float eps = 1e-09f) { return advFuncKernels.RMSNorm(result, src, alpha, beta, eps); }
+    public Tensor RMSNorm(Tensor result, Tensor src, Tensor alpha, float eps = 1e-09f) { return advFuncKernels.RMSNorm(result, src, alpha, eps); }
     [RegisterOpStorageType("rmsnormgrad", typeof(CudaStorage))]
-    public Tensor RMSNormGrad(Tensor outGrad, Tensor alphaGrad, Tensor betaGrad, Tensor inGrad, Tensor y, Tensor x, Tensor alpha, Tensor beta, float eps = 1e-09f) { return advFuncKernels.RMSNormGrad(outGrad, alphaGrad, betaGrad, inGrad, y, x, alpha, beta, eps); }
+    public Tensor RMSNormGrad(Tensor outGrad, Tensor alphaGrad, Tensor inGrad, Tensor y, Tensor x, Tensor alpha, float eps = 1e-09f) { return advFuncKernels.RMSNormGrad(outGrad, alphaGrad, inGrad, y, x, alpha, eps); }
 
     [RegisterOpStorageType("flashattention", typeof(CudaStorage))]
     public Tensor FlashAttention(Tensor O, Tensor L, Tensor Q, Tensor K, Tensor V, int q_start_offset = 0) { return advFuncKernels.FlashAttention(O, L, Q, K, V, q_start_offset); }
[4 more changed files not shown]
