Skip to content

Commit

Permalink
Optimize fp16 model loading performance
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Nov 14, 2024
1 parent e8a9884 commit 9528778
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
45 changes: 28 additions & 17 deletions Seq2SeqSharp/Models/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@
using Seq2SeqSharp.Utils;
using Seq2SeqSharp.Enums;
using TensorSharp;
using System.Runtime.InteropServices;

namespace Seq2SeqSharp.Models
{
[StructLayout(LayoutKind.Explicit)]
public class Name2WeightsHalf
{
[FieldOffset(0)] public Dictionary<string, ushort[]> usDict;
[FieldOffset(0)] public Dictionary<string, half[]> halfDict;
}

[Serializable]
public abstract class Model : IModel
{
Expand Down Expand Up @@ -56,7 +64,7 @@ public abstract class Model : IModel

public Dictionary<string, float[]> Name2Weights { get; set; }

public Dictionary<string, ushort[]> Name2WeightsHalf { get; set; }
public Name2WeightsHalf Name2WeightsHalf { get; set; }

public VQTypeEnums VQType { get; set; }
public Dictionary<string, byte[]> Name2WeightsVQ { get; set; }
Expand Down Expand Up @@ -92,7 +100,7 @@ public Model(Options opts,Vocab srcVocab, Vocab tgtVocab)
KVGroupNum = opts.KVGroupNum;

Name2Weights = new Dictionary<string, float[]>();
Name2WeightsHalf= new Dictionary<string, ushort[]>();
Name2WeightsHalf = new Name2WeightsHalf();
Name2WeightsVQ = new Dictionary<string, byte[]>();
Name2CodeBook = new Dictionary<string, double[]>();
}
Expand All @@ -117,7 +125,8 @@ public Model(Model_4_ProtoBufSerializer m)
VQType = m.VQType;

Name2Weights = m.Name2Weights;
Name2WeightsHalf = m.Name2WeightsHalf;
Name2WeightsHalf = new Name2WeightsHalf();
Name2WeightsHalf.usDict = m.Name2WeightsHalf;
Name2WeightsVQ = m.Name2WeightsVQ;
Name2CodeBook = m.Name2CodeBook;
PEType = m.PEType;
Expand All @@ -132,9 +141,9 @@ public Model(Model_4_ProtoBufSerializer m)
Name2Weights = new Dictionary<string, float[]>();
}

if (Name2WeightsHalf == null)
if (Name2WeightsHalf.usDict == null)
{
Name2WeightsHalf = new Dictionary<string, ushort[]>();
Name2WeightsHalf.usDict = new Dictionary<string, ushort[]>();
}

if (Name2WeightsVQ == null)
Expand All @@ -159,7 +168,7 @@ public void AddWeights(string name, float[] weights)
{
weightsHalf[i] = (new half(weights[i])).x;
}
Name2WeightsHalf.Add(name, weightsHalf);
Name2WeightsHalf.usDict.Add(name, weightsHalf);
}
else if (VQType == VQTypeEnums.INT8)
{
Expand Down Expand Up @@ -237,7 +246,7 @@ public float[] GetWeights(string name)
{
weight = Name2Weights[name];
}
else if (Name2WeightsHalf.ContainsKey(name))
else if (Name2WeightsHalf.halfDict.ContainsKey(name))
{
throw new InvalidCastException($"The model is saved as Float16 type, so please enable AMP for model loading.");
}
Expand Down Expand Up @@ -298,14 +307,16 @@ public half[] GetWeightsHalfType(string name)
weights[i] = new half(values[i]);
}
}
else if (Name2WeightsHalf.ContainsKey(name))
else if (Name2WeightsHalf.halfDict.ContainsKey(name))
{
var values = Name2WeightsHalf[name];
weights = new half[values.Length];
for (int i = 0; i < values.Length; i++)
{
weights[i] = new half(values[i]);
}
weights = Name2WeightsHalf.halfDict[name];

//var values = Name2WeightsHalf[name];
//weights = new half[values.Length];
//for (int i = 0; i < values.Length; i++)
//{
// weights[i] = new half(values[i]);
//}
}
else if (VQType == VQTypeEnums.INT8)
{
Expand Down Expand Up @@ -368,9 +379,9 @@ public void DeleteWeights(string name)
Name2Weights.Remove(name);
}

if (Name2WeightsHalf != null && Name2WeightsHalf.ContainsKey(name))
if (Name2WeightsHalf != null && Name2WeightsHalf.halfDict.ContainsKey(name))
{
Name2WeightsHalf.Remove(name);
Name2WeightsHalf.halfDict.Remove(name);
}
}

Expand All @@ -379,7 +390,7 @@ public void ClearWeights()
Name2WeightsVQ.Clear();
Name2CodeBook.Clear();
Name2Weights.Clear();
Name2WeightsHalf.Clear();
Name2WeightsHalf.halfDict.Clear();
}

public void ShowModelInfo()
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ public Model_4_ProtoBufSerializer(Model m)
Name2Weights = new Dictionary<string, float[]>();
}

Name2WeightsHalf = m.Name2WeightsHalf;
Name2WeightsHalf = m.Name2WeightsHalf.usDict;
if (Name2WeightsHalf == null)
{
Name2WeightsHalf = new Dictionary<string, ushort[]>();
Expand Down

0 comments on commit 9528778

Please sign in to comment.