Skip to content

Commit

Permalink
Fixes #7271 AOT for ML.Tokenizers (#7272)
Browse files Browse the repository at this point in the history
* AOT for ML.Tokenizers

* Forgot to add ModelSourceGenerationContext

* Update src/Microsoft.ML.Tokenizers/Model/ModelSourceGenerationContext.cs

Co-authored-by: Eirik Tsarpalis <[email protected]>

---------

Co-authored-by: Eirik Tsarpalis <[email protected]>
  • Loading branch information
euju-ms and eiriktsarpalis authored Oct 17, 2024
1 parent 823fc17 commit f385b06
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 16 deletions.
7 changes: 3 additions & 4 deletions src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -757,11 +757,10 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
/// Read the given files to extract the vocab and merges
internal static async ValueTask<(Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>)> ReadModelDataAsync(Stream vocab, Stream? merges, bool useAsync, CancellationToken cancellationToken = default)
{
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
Dictionary<StringSpanOrdinalKey, int>? dic = useAsync
? await JsonSerializer.DeserializeAsync(vocab, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32, cancellationToken).ConfigureAwait(false)
: JsonSerializer.Deserialize(vocab, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32);

Dictionary<StringSpanOrdinalKey, int>? dic = useAsync ?
await JsonSerializer.DeserializeAsync<Dictionary<StringSpanOrdinalKey, int>>(vocab, options, cancellationToken).ConfigureAwait(false) as Dictionary<StringSpanOrdinalKey, int> :
JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options) as Dictionary<StringSpanOrdinalKey, int>;
var m = useAsync ?
await ConvertMergesToHashmapAsync(merges, useAsync, cancellationToken).ConfigureAwait(false) :
ConvertMergesToHashmapAsync(merges, useAsync).GetAwaiter().GetResult();
Expand Down
5 changes: 2 additions & 3 deletions src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1764,11 +1764,10 @@ void TryMerge(int left, int right, ReadOnlySpan<char> textSpan)

private static Dictionary<StringSpanOrdinalKey, (int, string)> GetVocabulary(Stream vocabularyStream)
{
Dictionary<StringSpanOrdinalKey, (int, string)>? vocab;
Vocabulary? vocab;
try
{
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyCustomConverter.Instance } };
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, (int, string)>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, (int, string)>;
vocab = JsonSerializer.Deserialize(vocabularyStream, ModelSourceGenerationContext.Default.Vocabulary);
}
catch (Exception e)
{
Expand Down
3 changes: 1 addition & 2 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ private static Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabu
Dictionary<StringSpanOrdinalKey, int>? vocab;
try
{
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
vocab = JsonSerializer.Deserialize(vocabularyStream, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32);
}
catch (Exception e)
{
Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/ModelSourceGenerationContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace Microsoft.ML.Tokenizers;

[JsonSerializable(typeof(Dictionary<StringSpanOrdinalKey, int>))]
[JsonSerializable(typeof(Vocabulary))]
internal partial class ModelSourceGenerationContext : JsonSerializerContext;
15 changes: 8 additions & 7 deletions src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace Microsoft.ML.Tokenizers
/// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
/// always be used with a string.
/// </remarks>
[JsonConverter(typeof(StringSpanOrdinalKeyConverter))]
internal readonly unsafe struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
{
public readonly char* Ptr;
Expand Down Expand Up @@ -124,12 +125,14 @@ internal void Set(string k, TValue v)
}
}

[JsonConverter(typeof(VocabularyConverter))]
internal sealed class Vocabulary : Dictionary<StringSpanOrdinalKey, (int, string)>;

/// <summary>
/// Custom JSON converter for <see cref="StringSpanOrdinalKey"/>.
/// </summary>
internal sealed class StringSpanOrdinalKeyConverter : JsonConverter<StringSpanOrdinalKey>
{
public static StringSpanOrdinalKeyConverter Instance { get; } = new StringSpanOrdinalKeyConverter();
public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
new StringSpanOrdinalKey(reader.GetString()!);

Expand All @@ -140,13 +143,11 @@ public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdina
public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!);
}

internal class StringSpanOrdinalKeyCustomConverter : JsonConverter<Dictionary<StringSpanOrdinalKey, (int, string)>>
internal class VocabularyConverter : JsonConverter<Vocabulary>
{
public static StringSpanOrdinalKeyCustomConverter Instance { get; } = new StringSpanOrdinalKeyCustomConverter();

public override Dictionary<StringSpanOrdinalKey, (int, string)> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
public override Vocabulary Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var dictionary = new Dictionary<StringSpanOrdinalKey, (int, string)>();
var dictionary = new Vocabulary();
while (reader.Read())
{
if (reader.TokenType == JsonTokenType.EndObject)
Expand All @@ -165,7 +166,7 @@ internal class StringSpanOrdinalKeyCustomConverter : JsonConverter<Dictionary<St
throw new JsonException("Invalid JSON.");
}

public override void Write(Utf8JsonWriter writer, Dictionary<StringSpanOrdinalKey, (int, string)> value, JsonSerializerOptions options) => throw new NotImplementedException();
public override void Write(Utf8JsonWriter writer, Vocabulary value, JsonSerializerOptions options) => throw new NotImplementedException();
}

/// <summary>
Expand Down

0 comments on commit f385b06

Please sign in to comment.