Skip to content

Commit

Permalink
Truncate tokens at seqLength
Browse files Browse the repository at this point in the history
From @gevorgter's ticket NMZivkovic#18.
  • Loading branch information
ctwardy authored May 17, 2023
1 parent 150e40a commit 3b62257
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/Base/TokenizerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@ public TokenizerBase(string vocabularyFilePath)
}


// Includes @gevorgeter patch to take only sequenceLength tokens.
public List<(long InputIds, long TokenTypeIds, long AttentionMask)> Encode(int sequenceLength, params string[] texts)
{
var tokens = Tokenize(texts);

var padding = Enumerable.Repeat(0L, sequenceLength - tokens.Count).ToList();

var tokenIndexes = tokens.Select(token => (long)token.VocabularyIndex).Concat(padding).ToArray();
var segmentIndexes = tokens.Select(token => token.SegmentIndex).Concat(padding).ToArray();
var inputMask = tokens.Select(o => 1L).Concat(padding).ToArray();
List<long> padding;

if (sequenceLength > tokens.Count)
padding = Enumerable.Repeat(0L, sequenceLength - tokens.Count).ToList();
else
padding = new List<long>();

var tokenIndexes = tokens.Select(token => (long)token.VocabularyIndex).Concat(padding).Take(sequenceLength).ToArray();
var segmentIndexes = tokens.Select(token => token.SegmentIndex).Concat(padding).Take(sequenceLength).ToArray();
var inputMask = tokens.Select(o => 1L).Concat(padding).Take(sequenceLength).ToArray();

var output = tokenIndexes.Zip(segmentIndexes, Tuple.Create)
.Zip(inputMask, (t, z) => Tuple.Create(t.Item1, t.Item2, z));
Expand Down

0 comments on commit 3b62257

Please sign in to comment.