Skip to content

Commit

Permalink
Optimize GetNextLength
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Dec 1, 2023
1 parent 6b3861c commit 5bd2a37
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 38 deletions.
16 changes: 8 additions & 8 deletions Seq2SeqSharp/Applications/Image2Seq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class Image2Seq : BaseSeq2SeqFramework<Seq2SeqModel>
private MultiProcessorNetworkWrapper<IWeightTensor> m_tgtEmbedding; //The embeddings over devices for source
private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_srcEmbedding; //The embeddings over devices for source

private MultiProcessorNetworkWrapper<IWeightTensor> m_pixelEmbeddings;
// private MultiProcessorNetworkWrapper<IWeightTensor> m_pixelEmbeddings;

private MultiProcessorNetworkWrapper<IEncoder> m_encoder; //The encoders over devices.
private MultiProcessorNetworkWrapper<IDecoder> m_decoder; //The decoders over devices
Expand Down Expand Up @@ -130,8 +130,8 @@ private bool CreateTrainableParameters(IModel model)
m_srcEmbedding = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("SrcEmbedding_Decoder_0", 768, model.HiddenDim, dropoutRatio: 0.2f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: true, learningRateFactor: m_options.EncoderStartLearningRateFactor, elementType: elementType), DeviceIds);

m_pixelEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 768, 1 }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "PIXEL", learningRateFactor: m_options.EncoderStartLearningRateFactor,
isTrainable: true, dtype: elementType), raDeviceIds.ToArray());
//m_pixelEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 768, 1 }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "PIXEL", learningRateFactor: m_options.EncoderStartLearningRateFactor,
// isTrainable: true, dtype: elementType), raDeviceIds.ToArray());


if (model.PointerGenerator)
Expand Down Expand Up @@ -163,17 +163,17 @@ public void VQModel()
/// <summary>
/// Get networks on specific devices
/// </summary>
private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IFeedForwardLayer, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IFeedForwardLayer, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
{
var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId);
return (m_encoder.GetNetworkOnDevice(deviceIdIdx),
m_decoder.GetNetworkOnDevice(deviceIdIdx),
m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx),
m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx),
m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx));
// m_cls.GetNetworkOnDevice(deviceIdIdx), m_layerNorm.GetNetworkOnDevice(deviceIdIdx),
m_pixelEmbeddings.GetNetworkOnDevice(deviceIdIdx));
// m_pixelEmbeddings.GetNetworkOnDevice(deviceIdIdx));
}

private string GenerateCacheKey(List<List<string>> strs)
Expand All @@ -199,10 +199,10 @@ private string GenerateCacheKey(List<List<string>> strs)
/// <returns>The cost of forward part</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, IPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
{
(var encoder, var decoder, var decoderFFLayer, var srcEmbeddings, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings, var pixelEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
(var encoder, var decoder, var decoderFFLayer, var srcEmbeddings, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId);

var srcSnts = sntPairBatch.GetSrcTokens();
IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbeddings, posEmbeddings, null, m_modelMetaData.HiddenDim, null, pixelEmbeddings);
IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbeddings, posEmbeddings, null, m_modelMetaData.HiddenDim, null);

List<NetworkResult> nrs = new List<NetworkResult>();

Expand Down
13 changes: 6 additions & 7 deletions Seq2SeqSharp/Applications/ImgEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static class ImgEncoder
static int TOKEN_W = 16;
static int TOKEN_H = 16;

static private IWeightTensor LoadImageToTokens(IComputeGraph g, string filePath, IWeightTensor pixelEmbeddings)
static private IWeightTensor LoadImageToTokens(IComputeGraph g, string filePath)
{

List<float[]> tokens = new List<float[]>();
Expand Down Expand Up @@ -75,8 +75,7 @@ static private IWeightTensor LoadImageToTokens(IComputeGraph g, string filePath,
}
});

var indice = g.CreateTensorWeights(new long[] { IMAGE_W * IMAGE_H *3, 1 }, processedImage);
IWeightTensor res = g.IndexSelect(pixelEmbeddings, indice);
IWeightTensor res = g.CreateTensorWeights(new long[] { IMAGE_W, IMAGE_H, 3 }, processedImage);
res = g.View(res, dims: new long[] {IMAGE_W / TOKEN_W, TOKEN_W, IMAGE_H / TOKEN_H, TOKEN_H, 3 });
res = g.AsContiguous(g.Transpose(res, 1, 2)); // shape: [IMAGE_W / TOKEN_W, IMAGE_H / TOKEN_H, TOKEN_W, TOKEN_H, 3]
res = g.View(res, dims: new long[] { -1, 768 });
Expand All @@ -91,24 +90,24 @@ static private IWeightTensor LoadImageToTokens(IComputeGraph g, string filePath,
//Size(token) = TOTAL_TOKEN_NUM_PER_IMG
//Size(embedding_dim) = 768
//Shape: [batchsize, TOTAL_TOKEN_NUM_PER_IMG, 768]
static private IWeightTensor InnerEncode(IComputeGraph g, List<string> imgPaths, IWeightTensor pixelEmbeddings)
static private IWeightTensor InnerEncode(IComputeGraph g, List<string> imgPaths)
{
int batchSize = imgPaths.Count;
List<IWeightTensor> batchTokens = new List<IWeightTensor>();

foreach (var picPath in imgPaths)
{
batchTokens.Add(LoadImageToTokens(g, picPath, pixelEmbeddings)); //shape: [TOTAL_TOKEN_NUM_PER_IMG, 768]
batchTokens.Add(LoadImageToTokens(g, picPath)); //shape: [TOTAL_TOKEN_NUM_PER_IMG, 768]
}

var res = g.Concate(batchTokens, 0);
return res;
}

static public IWeightTensor Run(IComputeGraph g, List<string> imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls, int dim, INormalization layernorm, IWeightTensor pixelEmbeddings)
static public IWeightTensor Run(IComputeGraph g, List<string> imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls, int dim, INormalization layernorm)
{
int batchSize = imgPaths.Count;
var inputEmbs = InnerEncode(g, imgPaths, pixelEmbeddings);
var inputEmbs = InnerEncode(g, imgPaths);

// inputEmbs = layernorm.Norm(inputEmbs, g);
inputEmbs = srcEmbeddings.Process(inputEmbs, batchSize, g);
Expand Down
22 changes: 11 additions & 11 deletions Seq2SeqSharp/Corpus/MonoCorpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,13 @@ public List<Dictionary<string, long>> CountTokenFreqs()
}


public long GetNextLength(Dictionary<long, LinkedList<long>> len2offsets, Dictionary<long, long> len2counts)
public long GetNextLength(Dictionary<long, long> len2counts, long totalRecordsNum)
{
long rndItems = rnd.NextInt64(totalRecordsNum);
long totalItems = 0;
foreach (var pair in len2offsets)
foreach (var pair in len2counts)
{
totalItems += len2counts[pair.Key];
}

int rndItems = rnd.Next((int)totalItems);
totalItems = 0;
foreach (var pair in len2offsets)
{
long length = len2counts[pair.Key];
long length = pair.Value;
if (totalItems <= rndItems && totalItems + length >= rndItems)
{
return pair.Key;
Expand All @@ -253,6 +247,11 @@ public void PrepareDataSet()
{
m_batchNumInTotal = 0;
(var length2offsets, var length2counts, string tmpDataSetFilePath) = BuildIndex();
long totalRecordsNum = 0;
foreach (var pair in length2offsets)
{
totalRecordsNum += length2counts[pair.Key];
}

Logger.WriteLine(Logger.Level.debug, $"Start to sort and shuffle data set by length.");

Expand All @@ -265,7 +264,7 @@ public void PrepareDataSet()
{
while (length2offsets.Count > 0)
{
long length = GetNextLength(length2offsets, length2counts);
long length = GetNextLength(length2counts, totalRecordsNum);
LinkedList<long> offsets = length2offsets[length];

int totalTgtTokenSize = 0;
Expand All @@ -276,6 +275,7 @@ public void PrepareDataSet()
long offset = offsets.First.Value;
offsets.RemoveFirst();
length2counts[length]--;
totalRecordsNum--;

br.BaseStream.Seek(offset, SeekOrigin.Begin);
string tgtLine = br.ReadString();
Expand Down
23 changes: 12 additions & 11 deletions Seq2SeqSharp/Corpus/ParallelCorpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -313,19 +313,13 @@ public interface ICorpus<out T> : IEnumerable<T>
}


public long GetNextLength(Dictionary<long, LinkedList<long>> len2offsets, Dictionary<long, long> len2counts)
public long GetNextLength(Dictionary<long, long> len2counts, long totalRecordsNum)
{
long rndItems = rnd.NextInt64(totalRecordsNum);
long totalItems = 0;
foreach (var pair in len2offsets)
foreach (var pair in len2counts)
{
totalItems += len2counts[pair.Key];
}

int rndItems = rnd.Next((int)totalItems);
totalItems = 0;
foreach (var pair in len2offsets)
{
long length = len2counts[pair.Key];
long length = pair.Value;
if (totalItems <= rndItems && totalItems + length >= rndItems)
{
return pair.Key;
Expand All @@ -343,6 +337,12 @@ public void PrepareDataSet()
m_batchNumInTotal = 0;
(var length2offsets, var length2counts, string tmpDataSetFilePath) = BuildIndex();

long totalRecordsNum = 0;
foreach (var pair in length2offsets)
{
totalRecordsNum += length2counts[pair.Key];
}

Logger.WriteLine(Logger.Level.debug, $"Start to sort and shuffle data set by length.");

m_sortedIndexedDataSetFilePath = tmpDataSetFilePath + ".sorted";
Expand All @@ -359,7 +359,7 @@ public void PrepareDataSet()
{
while (length2offsets.Count > 0)
{
long length = GetNextLength(length2offsets, length2counts);
long length = GetNextLength(length2counts, totalRecordsNum);
LinkedList<long> offsets = length2offsets[length];

int totalSrcTokenSize = 0;
Expand All @@ -372,6 +372,7 @@ public void PrepareDataSet()
long offset = offsets.First.Value;
offsets.RemoveFirst();
length2counts[length]--;
totalRecordsNum--;

br.BaseStream.Seek(offset, SeekOrigin.Begin);

Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Corpus/VisionTextCorpusBatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class VisionTextCorpusBatch : IVisionSntPairBatch

public int BatchSize => SrcBatchPaths.Count;

public int SrcTokenCount { get; set; } = 1;
public int SrcTokenCount { get; set; } = 768;
public int TgtTokenCount { get; set; }

public IPairBatch CloneSrcTokens()
Expand Down

0 comments on commit 5bd2a37

Please sign in to comment.