From bc6d2abe24183cee07899a1306f0029906bf8e10 Mon Sep 17 00:00:00 2001
From: Patrick Hovsepian
Date: Fri, 7 Jun 2024 19:42:23 -0700
Subject: [PATCH 1/4] Wip2 (#1)

- more generic default template history transformer
- EOT and EOS tokens exposed as strings
- minor refactor of ModelTokens to a readonly struct
- template convenience methods
- exposed metadata get-by-key on the native handle
- updated the example to use the default history transformer and the
  model's anti-prompt convenience value
---
 LLama.Examples/Examples/LLama3ChatSession.cs  | 103 +++++-------------
 .../Native/SafeLlamaModelHandleTests.cs       |  39 +++++++
 LLama.Unittest/TemplateTests.cs               |  67 +++++++++++-
 .../PromptTemplateTransformerTests.cs         |  36 ++++++
 LLama/ChatSession.cs                          |   2 +-
 LLama/Common/InferenceParams.cs               |   3 +-
 LLama/LLamaExecutorBase.cs                    |   4 +-
 LLama/LLamaInteractExecutor.cs                |   2 +-
 LLama/LLamaTemplate.cs                        |  44 +++++++-
 LLama/LLamaWeights.cs                         |   1 -
 LLama/Native/LLamaToken.cs                    |   3 +
 LLama/Native/SafeLLamaContextHandle.cs        |   4 +
 LLama/Native/SafeLlamaModelHandle.cs          |  92 ++++++++++++++--
 .../Transformers/PromptTemplateTransformer.cs |  45 ++++++++
 14 files changed, 343 insertions(+), 102 deletions(-)
 create mode 100644 LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
 create mode 100644 LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
 create mode 100644 LLama/Transformers/PromptTemplateTransformer.cs

diff --git a/LLama.Examples/Examples/LLama3ChatSession.cs b/LLama.Examples/Examples/LLama3ChatSession.cs
index c9a32e0ce..1b5b4442c 100644
--- a/LLama.Examples/Examples/LLama3ChatSession.cs
+++ b/LLama.Examples/Examples/LLama3ChatSession.cs
@@ -1,38 +1,47 @@
-using LLama.Abstractions;
-using LLama.Common;
+using LLama.Common;
+using LLama.Transformers;
 
 namespace LLama.Examples.Examples;
 
-// When using chatsession, it's a common case that you want to strip the role names
-// rather than display them. This example shows how to use transforms to strip them.
+///
+/// This sample shows a simple chatbot
+/// It's configured to use the default prompt template as provided by llama.cpp and supports
+/// models such as llama3, llama2, phi3, qwen1.5, etc.
+///
 public class LLama3ChatSession
 {
    public static async Task Run()
    {
-        string modelPath = UserSettings.GetModelPath();
-
+        var modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath)
        {
            Seed = 1337,
            GpuLayerCount = 10
        };
+
        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);
 
        var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
-        ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
+        var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
 
        ChatSession session = new(executor, chatHistory);
-        session.WithHistoryTransform(new LLama3HistoryTransform());
+
+        // add the default templater. 
If llama.cpp doesn't support the template by default,
+        // you'll need to write your own transformer to format the prompt correctly
+        session.WithHistoryTransform(new PromptTemplateTransformer(model, withAssistant: true));
+
+        // Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed somtimes
        session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
-            new string[] { "User:", "Assistant:", "�" },
+            [model.Tokens.EndOfTurnToken!, "�"],
            redundancyLength: 5));
 
-        InferenceParams inferenceParams = new InferenceParams()
+        var inferenceParams = new InferenceParams()
        {
+            MaxTokens = -1, // keep generating tokens until the anti-prompt is encountered
            Temperature = 0.6f,
-            AntiPrompts = new List<string> { "User:" }
+            AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
        };
 
        Console.ForegroundColor = ConsoleColor.Yellow;
@@ -40,10 +49,15 @@ public static async Task Run()
 
        // show the prompt
        Console.ForegroundColor = ConsoleColor.Green;
-        string userInput = Console.ReadLine() ?? "";
+        Console.Write("User> ");
+        var userInput = Console.ReadLine() ?? "";
 
        while (userInput != "exit")
        {
+            Console.ForegroundColor = ConsoleColor.White;
+            Console.Write("Assistant> ");
+
+            // as each token (partial or whole word) is streamed back, print it to the console, stream it to a web client, etc.
            await foreach (
                var text
                in session.ChatAsync(
@@ -56,71 +70,8 @@ in session.ChatAsync(
            Console.WriteLine();
 
            Console.ForegroundColor = ConsoleColor.Green;
+            Console.Write("User> ");
            userInput = Console.ReadLine() ?? "";
-
-            Console.ForegroundColor = ConsoleColor.White;
-        }
-    }
-
-    class LLama3HistoryTransform : IHistoryTransform
-    {
-        ///
-        /// Convert a ChatHistory instance to plain text.
-        ///
-        /// The ChatHistory instance
-        ///
-        public string HistoryToText(ChatHistory history)
-        {
-            string res = Bos;
-            foreach (var message in history.Messages)
-            {
-                res += EncodeMessage(message);
-            }
-            res += EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, ""));
-            return res;
-        }
-
-        private string EncodeHeader(ChatHistory.Message message)
-        {
-            string res = StartHeaderId;
-            res += message.AuthorRole.ToString();
-            res += EndHeaderId;
-            res += "\n\n";
-            return res;
-        }
-
-        private string EncodeMessage(ChatHistory.Message message)
-        {
-            string res = EncodeHeader(message);
-            res += message.Content;
-            res += EndofTurn;
-            return res;
-        }
-
-        ///
-        /// Converts plain text to a ChatHistory instance.
-        ///
-        /// The role for the author.
-        /// The chat history as plain text.
-        /// The updated history.
-        public ChatHistory TextToHistory(AuthorRole role, string text)
-        {
-            return new ChatHistory(new ChatHistory.Message[] { new ChatHistory.Message(role, text) });
-        }
-
-        ///
-        /// Copy the transform. 
- /// - /// - public IHistoryTransform Clone() - { - return new LLama3HistoryTransform(); - } - - private const string StartHeaderId = "<|start_header_id|>"; - private const string EndHeaderId = "<|end_header_id|>"; - private const string Bos = "<|begin_of_text|>"; - private const string Eos = "<|end_of_text|>"; - private const string EndofTurn = "<|eot_id|>"; } } diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs new file mode 100644 index 000000000..5211d4f6a --- /dev/null +++ b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs @@ -0,0 +1,39 @@ +using System.Text; +using LLama.Common; +using LLama.Native; +using LLama.Extensions; + +namespace LLama.Unittest.Native; + +public class SafeLlamaModelHandleTests +{ + private readonly LLamaWeights _model; + private readonly SafeLlamaModelHandle TestableHandle; + + public SafeLlamaModelHandleTests() + { + var @params = new ModelParams(Constants.GenerativeModelPath) + { + ContextSize = 1, + GpuLayerCount = Constants.CIGpuLayerCount + }; + _model = LLamaWeights.LoadFromFile(@params); + + TestableHandle = _model.NativeHandle; + } + + [Fact] + public void MetadataValByKey_ReturnsCorrectly() + { + const string key = "general.name"; + var template = _model.NativeHandle.MetadataValueByKey(key); + var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span); + + const string expected = "LLaMA v2"; + Assert.Equal(expected, name); + + var metadataLookup = _model.Metadata[key]; + Assert.Equal(expected, metadataLookup); + Assert.Equal(name, metadataLookup); + } +} diff --git a/LLama.Unittest/TemplateTests.cs b/LLama.Unittest/TemplateTests.cs index 3a5bb0cea..b262b154a 100644 --- a/LLama.Unittest/TemplateTests.cs +++ b/LLama.Unittest/TemplateTests.cs @@ -1,6 +1,6 @@ using System.Text; using LLama.Common; -using LLama.Native; +using LLama.Extensions; namespace LLama.Unittest; @@ -8,7 +8,7 @@ public sealed class TemplateTests : IDisposable { private readonly LLamaWeights _model; - + public TemplateTests() { var @params = new ModelParams(Constants.GenerativeModelPath) @@ -18,12 +18,12 @@ public TemplateTests() }; _model = LLamaWeights.LoadFromFile(@params); } - + public void Dispose() { _model.Dispose(); } - + [Fact] public void BasicTemplate() { @@ -173,6 +173,53 @@ public void BasicTemplateWithAddAssistant() Assert.Equal(expected, templateResult); } + [Fact] + public void ToModelPrompt_FormatsCorrectly() + { + var templater = new LLamaTemplate(_model) + { + AddAssistant = true, + }; + + Assert.Equal(0, templater.Count); + templater.Add("assistant", "hello"); + Assert.Equal(1, templater.Count); + templater.Add("user", "world"); + Assert.Equal(2, templater.Count); + templater.Add("assistant", "111"); + Assert.Equal(3, templater.Count); + templater.Add("user", "aaa"); + Assert.Equal(4, templater.Count); + templater.Add("assistant", "222"); + Assert.Equal(5, templater.Count); + templater.Add("user", "bbb"); + Assert.Equal(6, templater.Count); + templater.Add("assistant", "333"); + Assert.Equal(7, templater.Count); + templater.Add("user", "ccc"); + Assert.Equal(8, templater.Count); + + // Call once with empty array to discover length + var templateResult = templater.ToModelPrompt(); + const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" + + "<|im_start|>user\nworld<|im_end|>\n" + + "<|im_start|>assistant\n" + + "111<|im_end|>" + + "\n<|im_start|>user\n" + + "aaa<|im_end|>\n" + + "<|im_start|>assistant\n" + + "222<|im_end|>\n" + + "<|im_start|>user\n" + + "bbb<|im_end|>\n" + + 
"<|im_start|>assistant\n" + + "333<|im_end|>\n" + + "<|im_start|>user\n" + + "ccc<|im_end|>\n" + + "<|im_start|>assistant\n"; + + Assert.Equal(expected, templateResult); + } + [Fact] public void GetOutOfRangeThrows() { @@ -249,4 +296,16 @@ public void RemoveOutOfRange() Assert.Throws(() => templater.RemoveAt(-1)); Assert.Throws(() => templater.RemoveAt(2)); } + + [Fact] + public void EndOTurnToken_ReturnsExpected() + { + Assert.Null(_model.Tokens.EndOfTurnToken); + } + + [Fact] + public void EndOSpeechToken_ReturnsExpected() + { + Assert.Equal("", _model.Tokens.EndOfSpeechToken); + } } \ No newline at end of file diff --git a/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs new file mode 100644 index 000000000..0713e1236 --- /dev/null +++ b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs @@ -0,0 +1,36 @@ +using LLama.Common; +using LLama.Transformers; + +namespace LLama.Unittest.Transformers; + +public class PromptTemplateTransformerTests +{ + private readonly LLamaWeights _model; + private readonly PromptTemplateTransformer TestableTransformer; + + public PromptTemplateTransformerTests() + { + var @params = new ModelParams(Constants.GenerativeModelPath) + { + ContextSize = 1, + GpuLayerCount = Constants.CIGpuLayerCount + }; + _model = LLamaWeights.LoadFromFile(@params); + + TestableTransformer = new PromptTemplateTransformer(_model, true); + } + + [Fact] + public void HistoryToText_EncodesCorrectly() + { + const string userData = nameof(userData); + var template = TestableTransformer.HistoryToText(new ChatHistory(){ + Messages = [new ChatHistory.Message(AuthorRole.User, userData)] + }); + + const string expected = "<|im_start|>user\n" + + $"{userData}<|im_end|>\n" + + "<|im_start|>assistant\n"; + Assert.Equal(expected, template); + } +} diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 3d5b5b616..2f667be0b 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -62,7 +62,7 @@ public class ChatSession /// /// The input transform pipeline used in this session. /// - public List InputTransformPipeline { get; set; } = new(); + public List InputTransformPipeline { get; set; } = []; /// /// The output transform used in this session. diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index 44818a1ff..b2e429f83 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -1,5 +1,4 @@ using LLama.Abstractions; -using System; using System.Collections.Generic; using LLama.Native; using LLama.Sampling; @@ -31,7 +30,7 @@ public record InferenceParams /// /// Sequences where the model will stop generating further tokens. /// - public IReadOnlyList AntiPrompts { get; set; } = Array.Empty(); + public IReadOnlyList AntiPrompts { get; set; } = []; /// public int TopK { get; set; } = 40; diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 263ab2716..e01a40ccc 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -307,7 +307,7 @@ public virtual async IAsyncEnumerable InferAsync(string text, IInference var args = new InferStateArgs { - Antiprompts = inferenceParams.AntiPrompts.ToList(), + Antiprompts = [.. 
inferenceParams.AntiPrompts], RemainedTokens = inferenceParams.MaxTokens, ReturnValue = false, WaitForInput = false, @@ -359,7 +359,7 @@ public virtual async Task PrefillPromptAsync(string prompt) }; var args = new InferStateArgs { - Antiprompts = new List(), + Antiprompts = [], RemainedTokens = 0, ReturnValue = false, WaitForInput = true, diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 226b18ef9..869a0bb44 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -123,7 +123,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args) } else { - PreprocessLlava(text, args, true ); + PreprocessLlava(text, args, true); } } else diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs index 0677ddb43..7e2b51ddc 100644 --- a/LLama/LLamaTemplate.cs +++ b/LLama/LLamaTemplate.cs @@ -28,12 +28,12 @@ public sealed class LLamaTemplate /// /// Keep a cache of roles converted into bytes. Roles are very frequently re-used, so this saves converting them many times. /// - private readonly Dictionary> _roleCache = new(); + private readonly Dictionary> _roleCache = []; /// /// Array of messages. The property indicates how many messages there are /// - private TextMessage?[] _messages = new TextMessage[4]; + private TextMessage[] _messages = new TextMessage[4]; /// /// Backing field for @@ -53,7 +53,7 @@ public sealed class LLamaTemplate /// /// Result bytes of last call to /// - private byte[] _result = Array.Empty(); + private byte[] _result = []; /// /// Indicates if this template has been modified and needs regenerating @@ -189,6 +189,21 @@ public LLamaTemplate RemoveAt(int index) return this; } + + /// + /// Remove all messags from the template and resets internal state to accept/generate new messages + /// + public void RemoveAllMessages() + { + _messages = new TextMessage[4]; + Count = 0; + + _resultLength = 0; + _result = []; + _nativeChatMessages = new LLamaChatMessage[4]; + + _dirty = true; + } #endregion /// @@ -213,7 +228,6 @@ public int Apply(Memory dest) for (var i = 0; i < Count; i++) { ref var m = ref _messages[i]!; - Debug.Assert(m != null); totalInputBytes += m.RoleBytes.Length + m.ContentBytes.Length; // Pin byte arrays in place @@ -233,7 +247,6 @@ public int Apply(Memory dest) var output = ArrayPool.Shared.Rent(Math.Max(32, totalInputBytes * 2)); try { - // Run templater and discover true length var outputLength = ApplyInternal(_nativeChatMessages.AsSpan(0, Count), output); @@ -278,10 +291,29 @@ unsafe int ApplyInternal(Span messages, byte[] output) } } + /// + /// Apply the template to the messages and return the resulting prompt as a string + /// + /// + /// The formatted template string as defined by the model + public string ToModelPrompt() + { + // Apply the template to update state and get data length + var dataLength = Apply(Array.Empty()); + + // convert the resulting buffer to a string +#if NET6_0_OR_GREATER + return Encoding.GetString(_result.AsSpan(0, dataLength)); +#endif + + // need the ToArray call for netstandard -- avoided in newer runtimes + return Encoding.GetString(_result.AsSpan(0, dataLength).ToArray()); + } + /// /// A message that has been added to a template /// - public sealed class TextMessage + public readonly struct TextMessage { /// /// The "role" string for this message diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index ce712b724..8646e4d93 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -5,7 +5,6 @@ using System.Threading.Tasks; using 
LLama.Abstractions; using LLama.Exceptions; -using LLama.Extensions; using LLama.Native; using Microsoft.Extensions.Logging; diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index 64d263a7a..e77193e09 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -10,6 +10,9 @@ namespace LLama.Native; [DebuggerDisplay("{Value}")] public readonly record struct LLamaToken { + /// Token Value used when token is inherently null + public static readonly LLamaToken INVALID_TOKEN = -1; + /// /// The raw value /// diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index f54a8680b..3812a3517 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -19,6 +19,10 @@ public sealed class SafeLLamaContextHandle /// public int VocabCount => ThrowIfDisposed().VocabCount; + /// + /// The underlying vocabulary for the model + /// + /// public LLamaVocabType LLamaVocabType => ThrowIfDisposed().VocabType; /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index f24cfe5fd..f3177193b 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -6,7 +6,6 @@ using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; -using LLama.Extensions; namespace LLama.Native { @@ -221,11 +220,30 @@ private static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, /// /// /// - /// - /// + /// /// The length of the string on success, or -1 on failure - [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size); + private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string key, Span dest) + { + var bytesCount = Encoding.UTF8.GetByteCount(key); + var bytes = ArrayPool.Shared.Rent(bytesCount); + + unsafe + { + fixed (char* keyPtr = key) + fixed (byte* bytesPtr = bytes) + fixed (byte* destPtr = dest) + { + // Convert text into bytes + Encoding.UTF8.GetBytes(keyPtr, key.Length, bytesPtr, bytes.Length); + + return llama_model_meta_val_str_native(model, bytesPtr, destPtr, dest.Length); + } + } + + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str")] + static extern unsafe int llama_model_meta_val_str_native(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size); + } + /// /// Get the number of tokens in the model vocabulary @@ -461,8 +479,8 @@ internal Span TokensToSpan(IReadOnlyList tokens, Span de public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { // Early exit if there's no work to do - if (text == "" && !add_bos) - return Array.Empty(); + if (text == string.Empty && !add_bos) + return []; // Convert string to bytes, adding one extra byte to the end (null terminator) var bytesCount = encoding.GetByteCount(text); @@ -484,7 +502,7 @@ public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding e var tokens = new LLamaToken[count]; fixed (LLamaToken* tokensPtr = tokens) { - NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); + _ = NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); return tokens; } } @@ -510,6 +528,26 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) #endregion #region metadata + /// + /// Get 
the metadata value for the given key
+        ///
+        /// The key to fetch
+        /// The value, null if there is no such key
+        public Memory<byte>? MetadataValueByKey(string key)
+        {
+            // Check if the key exists, without getting any bytes of data
+            var keyLength = llama_model_meta_val_str(this, key, []);
+            if (keyLength < 0)
+                return null;
+
+            // get a buffer large enough to hold it
+            var buffer = new byte[keyLength + 1];
+            keyLength = llama_model_meta_val_str(this, key, buffer);
+            Debug.Assert(keyLength >= 0);
+
+            return buffer.AsMemory().Slice(0, keyLength);
+        }
+
        ///
        /// Get the metadata key for the given index
        ///
@@ -576,13 +614,39 @@ internal IReadOnlyDictionary<string, string> ReadMetadata()
        ///
        /// Get tokens for a model
        ///
-        public class ModelTokens
+        public readonly struct ModelTokens
        {
            private readonly SafeLlamaModelHandle _model;
+            private readonly string? _eot;
+            private readonly string? _eos;
 
            internal ModelTokens(SafeLlamaModelHandle model)
            {
                _model = model;
+                _eot = LLamaTokenToString(EOT, true);
+                _eos = LLamaTokenToString(EOS, true);
+            }
+
+            private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
+            {
+                const int buffSize = 32;
+                Span<byte> buff = stackalloc byte[buffSize];
+                var tokenLength = _model.TokenToSpan(token ?? LLamaToken.INVALID_TOKEN, buff, special: isSpecialToken);
+
+                if (tokenLength <= 0)
+                {
+                    return null;
+                }
+
+                // if the original buffer wasn't large enough, create a new one
+                if (tokenLength > buffSize)
+                {
+                    buff = stackalloc byte[(int)tokenLength];
+                    _ = _model.TokenToSpan(token ?? LLamaToken.INVALID_TOKEN, buff, special: isSpecialToken);
+                }
+
+                var slice = buff.Slice(0, (int)tokenLength);
+                return Encoding.UTF8.GetStringFromSpan(slice);
            }
 
            private static LLamaToken? Normalize(LLamaToken token)
@@ -599,6 +663,11 @@ internal ModelTokens(SafeLlamaModelHandle model)
            /// Get the End of Sentence token for this model
            ///
            public LLamaToken? EOS => Normalize(llama_token_eos(_model));
+
+            ///
+            /// The textual representation of the end of speech special token for this model
+            ///
+            public string? EndOfSpeechToken => _eos;
 
            ///
            /// Get the newline token for this model
@@ -635,6 +704,11 @@ internal ModelTokens(SafeLlamaModelHandle model)
            ///
            public LLamaToken? EOT => Normalize(llama_token_eot(_model));
 
+            ///
+            /// Returns the string representation of this model's end of turn token
+            ///
+            public string? EndOfTurnToken => _eot;
+
            ///
            /// Check if the given token should end generation
            ///
diff --git a/LLama/Transformers/PromptTemplateTransformer.cs b/LLama/Transformers/PromptTemplateTransformer.cs
new file mode 100644
index 000000000..19bacae93
--- /dev/null
+++ b/LLama/Transformers/PromptTemplateTransformer.cs
@@ -0,0 +1,45 @@
+using System.Text;
+using LLama.Abstractions;
+using LLama.Common;
+
+namespace LLama.Transformers;
+
+///
+/// A prompt formatter that will use llama.cpp's template formatter
+/// If your model is not supported, you will need to define your own formatter according to the chat prompt specification for your model
+///
+public class PromptTemplateTransformer(LLamaWeights model,
+    bool withAssistant = true) : IHistoryTransform
+{
+    private readonly LLamaWeights _model = model;
+    private readonly bool _withAssistant = withAssistant;
+
+    ///
+    public string HistoryToText(ChatHistory history)
+    {
+        var template = new LLamaTemplate(_model.NativeHandle)
+        {
+            AddAssistant = _withAssistant,
+        };
+
+        // encode each message and return the final prompt
+        foreach (var message in history.Messages)
+        {
+            template.Add(message.AuthorRole.ToString().ToLowerInvariant(), message.Content);
+        }
+        return template.ToModelPrompt();
+    }
+
+    ///
+    public ChatHistory TextToHistory(AuthorRole role, string text)
+    {
+        return new ChatHistory([new ChatHistory.Message(role, text)]);
+    }
+
+    ///
+    public IHistoryTransform Clone()
+    {
+        // need to preserve history?
+        return new PromptTemplateTransformer(_model);
+    }
+}

From 6808139ba1595017343c9dae313aa1dce4e61153 Mon Sep 17 00:00:00 2001
From: pat_hov
Date: Sat, 8 Jun 2024 17:17:50 -0700
Subject: [PATCH 2/4] PR feedback and test tweaks

---
 LLama.Examples/Examples/LLama3ChatSession.cs  |  2 +-
 LLama.Unittest/TemplateTests.cs               | 78 ++++++++-----------
 .../PromptTemplateTransformerTests.cs         | 47 +++++++++++
 LLama/LLamaTemplate.cs                        | 35 +++------
 LLama/Native/LLamaToken.cs                    |  4 +-
 LLama/Native/SafeLlamaModelHandle.cs          |  6 +-
 .../Transformers/PromptTemplateTransformer.cs | 26 ++++++-
 7 files changed, 121 insertions(+), 77 deletions(-)

diff --git a/LLama.Examples/Examples/LLama3ChatSession.cs b/LLama.Examples/Examples/LLama3ChatSession.cs
index 1b5b4442c..01aa33cd6 100644
--- a/LLama.Examples/Examples/LLama3ChatSession.cs
+++ b/LLama.Examples/Examples/LLama3ChatSession.cs
@@ -32,7 +32,7 @@ public static async Task Run()
        // you'll need to write your own transformer to format the prompt correctly
        session.WithHistoryTransform(new PromptTemplateTransformer(model, withAssistant: true));
 
-        // Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed somtimes
+        // Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed sometimes
        session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
            [model.Tokens.EndOfTurnToken!, "�"],
            redundancyLength: 5));
diff --git a/LLama.Unittest/TemplateTests.cs b/LLama.Unittest/TemplateTests.cs
index b262b154a..a435409de 100644
--- a/LLama.Unittest/TemplateTests.cs
+++ b/LLama.Unittest/TemplateTests.cs
@@ -173,53 +173,6 @@ public void BasicTemplateWithAddAssistant()
        Assert.Equal(expected, templateResult);
    }
 
-    [Fact]
-    public void ToModelPrompt_FormatsCorrectly()
-    {
-        var templater = new LLamaTemplate(_model)
-        {
-            AddAssistant = true,
-        };
-
-        Assert.Equal(0, templater.Count);
-        templater.Add("assistant", "hello");
-        Assert.Equal(1, templater.Count);
-        templater.Add("user", "world");
-        Assert.Equal(2, templater.Count);
-        templater.Add("assistant", "111");
-        Assert.Equal(3, templater.Count);
-        templater.Add("user", "aaa");
-        Assert.Equal(4, templater.Count);
-        templater.Add("assistant", "222");
-        Assert.Equal(5, templater.Count);
-        templater.Add("user", "bbb");
-        Assert.Equal(6, templater.Count);
-        templater.Add("assistant", "333");
-        Assert.Equal(7, templater.Count);
-        templater.Add("user", "ccc");
-        Assert.Equal(8, templater.Count);
-
-        // Call once with empty array to discover length
-        var templateResult = templater.ToModelPrompt();
-        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
-                                "<|im_start|>user\nworld<|im_end|>\n" +
-                                "<|im_start|>assistant\n" +
-                                "111<|im_end|>" +
-                                "\n<|im_start|>user\n" +
-                                "aaa<|im_end|>\n" +
-                                "<|im_start|>assistant\n" +
-                                "222<|im_end|>\n" +
-                                "<|im_start|>user\n" +
-                                "bbb<|im_end|>\n" +
-                                "<|im_start|>assistant\n" +
-                                "333<|im_end|>\n" +
-                                "<|im_start|>user\n" +
-                                "ccc<|im_end|>\n" +
-                                "<|im_start|>assistant\n";
-
-        Assert.Equal(expected, templateResult);
-    }
-
    [Fact]
    public void GetOutOfRangeThrows()
    {
@@ -249,6 +250,37 @@ public void RemoveOutOfRange()
        Assert.Throws<ArgumentOutOfRangeException>(() => templater.RemoveAt(2));
    }
 
+    [Fact]
+    public void Clear_ResetsTemplateState()
+    {
+        var templater = new LLamaTemplate(_model);
+        templater.Add("assistant", "1")
+                 .Add("user", "2");
+
+        Assert.Equal(2, templater.Count);
+
+        templater.Clear();
+
+        Assert.Equal(0, templater.Count);
+
+        const string userData = nameof(userData);
+        templater.Add("user", userData);
+
+        // Generate the template string
+        // Call once with empty array to discover length
+        var length = templater.Apply(Array.Empty<byte>());
+        var dest = new byte[length];
+
+        Assert.Equal(1, templater.Count);
+
+        // Call again to get contents
+        _ = templater.Apply(dest);
+        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+
+        const string expectedTemplate = $"<|im_start|>user\n{userData}<|im_end|>\n";
+        Assert.Equal(expectedTemplate, templateResult);
+    }
+
    [Fact]
    public void EndOTurnToken_ReturnsExpected()
    {
diff --git a/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
index 0713e1236..9b1255f9b 100644
--- a/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
+++ b/LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
@@ -33,4 +33,51 @@ public void HistoryToText_EncodesCorrectly()
        "<|im_start|>assistant\n";
        Assert.Equal(expected, template);
    }
+
+    [Fact]
+    public void ToModelPrompt_FormatsCorrectly()
+    {
+        var templater = new LLamaTemplate(_model)
+        {
+            AddAssistant = true,
+        };
+
+        Assert.Equal(0, templater.Count);
+        templater.Add("assistant", "hello");
+        Assert.Equal(1, templater.Count);
+        templater.Add("user", "world");
+        Assert.Equal(2, templater.Count);
+        templater.Add("assistant", "111");
+        Assert.Equal(3, templater.Count);
+        templater.Add("user", "aaa");
+        Assert.Equal(4, templater.Count);
+        templater.Add("assistant", "222");
+        Assert.Equal(5, templater.Count);
+        templater.Add("user", "bbb");
+        Assert.Equal(6, templater.Count);
+        templater.Add("assistant", "333");
+        Assert.Equal(7, templater.Count);
+        templater.Add("user", "ccc");
+        Assert.Equal(8, templater.Count);
+
+        // Apply the template via the static helper
+        var templateResult = PromptTemplateTransformer.ToModelPrompt(templater);
+        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
+                                "<|im_start|>user\nworld<|im_end|>\n" +
+                                "<|im_start|>assistant\n" +
+                                "111<|im_end|>" +
+                                "\n<|im_start|>user\n" +
+                                "aaa<|im_end|>\n" 
+ + "<|im_start|>assistant\n" + + "222<|im_end|>\n" + + "<|im_start|>user\n" + + "bbb<|im_end|>\n" + + "<|im_start|>assistant\n" + + "333<|im_end|>\n" + + "<|im_start|>user\n" + + "ccc<|im_end|>\n" + + "<|im_start|>assistant\n"; + + Assert.Equal(expected, templateResult); + } } diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs index 7e2b51ddc..414fd5c67 100644 --- a/LLama/LLamaTemplate.cs +++ b/LLama/LLamaTemplate.cs @@ -13,8 +13,6 @@ namespace LLama; public sealed class LLamaTemplate { #region private state - private static readonly Encoding Encoding = Encoding.UTF8; - /// /// The model this template is for. May be null if a custom template was supplied to the constructor. /// @@ -62,6 +60,11 @@ public sealed class LLamaTemplate #endregion #region properties + /// + /// The encoding algorithm to use + /// + public static readonly Encoding Encoding = Encoding.UTF8; + /// /// Number of messages added to this template /// @@ -101,6 +104,11 @@ public bool AddAssistant } } } + + /// + /// Get a span to the underlying bytes as specified by Encoding + /// + public ReadOnlySpan TemplateDataBuffer => _result; #endregion #region construction @@ -191,9 +199,9 @@ public LLamaTemplate RemoveAt(int index) } /// - /// Remove all messags from the template and resets internal state to accept/generate new messages + /// Remove all messages from the template and resets internal state to accept/generate new messages /// - public void RemoveAllMessages() + public void Clear() { _messages = new TextMessage[4]; Count = 0; @@ -291,25 +299,6 @@ unsafe int ApplyInternal(Span messages, byte[] output) } } - /// - /// Apply the template to the messages and return the resulting prompt as a string - /// - /// - /// The formatted template string as defined by the model - public string ToModelPrompt() - { - // Apply the template to update state and get data length - var dataLength = Apply(Array.Empty()); - - // convert the resulting buffer to a string -#if NET6_0_OR_GREATER - return Encoding.GetString(_result.AsSpan(0, dataLength)); -#endif - - // need the ToArray call for netstandard -- avoided in newer runtimes - return Encoding.GetString(_result.AsSpan(0, dataLength).ToArray()); - } - /// /// A message that has been added to a template /// diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index e77193e09..dd8bca1e2 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -10,8 +10,10 @@ namespace LLama.Native; [DebuggerDisplay("{Value}")] public readonly record struct LLamaToken { + /// /// Token Value used when token is inherently null - public static readonly LLamaToken INVALID_TOKEN = -1; + /// + public static readonly LLamaToken InvalidToken = -1; /// /// The raw value diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index f3177193b..fa9495aac 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -234,7 +234,7 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k fixed (byte* destPtr = dest) { // Convert text into bytes - Encoding.UTF8.GetBytes(keyPtr, key.Length, bytesPtr, bytes.Length); + Encoding.UTF8.GetBytes(keyPtr, key.Length, bytesPtr, bytesCount); return llama_model_meta_val_str_native(model, bytesPtr, destPtr, dest.Length); } @@ -631,7 +631,7 @@ internal ModelTokens(SafeLlamaModelHandle model) { const int buffSize = 32; Span buff = stackalloc byte[buffSize]; - var tokenLength = _model.TokenToSpan(token ?? 
LLamaToken.INVALID_TOKEN, buff, special: isSpecialToken); + var tokenLength = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken); if (tokenLength <= 0) { @@ -642,7 +642,7 @@ internal ModelTokens(SafeLlamaModelHandle model) if (tokenLength > buffSize) { buff = stackalloc byte[(int)tokenLength]; - _ = _model.TokenToSpan(token ?? LLamaToken.INVALID_TOKEN, buff, special: isSpecialToken); + _ = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken); } var slice = buff.Slice(0, (int)tokenLength); diff --git a/LLama/Transformers/PromptTemplateTransformer.cs b/LLama/Transformers/PromptTemplateTransformer.cs index 19bacae93..3b78acb9f 100644 --- a/LLama/Transformers/PromptTemplateTransformer.cs +++ b/LLama/Transformers/PromptTemplateTransformer.cs @@ -1,4 +1,5 @@ -using System.Text; +using System; +using System.Text; using LLama.Abstractions; using LLama.Common; @@ -27,7 +28,7 @@ public string HistoryToText(ChatHistory history) { template.Add(message.AuthorRole.ToString().ToLowerInvariant(), message.Content); } - return template.ToModelPrompt(); + return ToModelPrompt(template); } /// @@ -42,4 +43,25 @@ public IHistoryTransform Clone() // need to preserve history? return new PromptTemplateTransformer(_model); } + + #region utils + /// + /// Apply the template to the messages and return the resulting prompt as a string + /// + /// + /// The formatted template string as defined by the model + public static string ToModelPrompt(LLamaTemplate template) + { + // Apply the template to update state and get data length + var dataLength = template.Apply(Array.Empty()); + + // convert the resulting buffer to a string +#if NET6_0_OR_GREATER + return LLamaTemplate.Encoding.GetString(template.TemplateDataBuffer[..dataLength]); +#endif + + // need the ToArray call for netstandard -- avoided in newer runtimes + return LLamaTemplate.Encoding.GetString(template.TemplateDataBuffer.Slice(0, dataLength).ToArray()); + } + #endregion utils } From 3fa9addeacdef7c5309c9bdae7561d7ab6a92b37 Mon Sep 17 00:00:00 2001 From: Pat Hov Date: Sun, 9 Jun 2024 11:09:23 -0700 Subject: [PATCH 3/4] sealed class --- LLama/Native/SafeLlamaModelHandle.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index fa9495aac..1597908e3 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -614,7 +614,7 @@ internal IReadOnlyDictionary ReadMetadata() /// /// Get tokens for a model /// - public readonly struct ModelTokens + public sealed class ModelTokens { private readonly SafeLlamaModelHandle _model; private readonly string? 
_eot;

From 8c9bbb637cdaee287ec6e26e37077497803efc6d Mon Sep 17 00:00:00 2001
From: Pat Hov
Date: Sun, 9 Jun 2024 15:37:24 -0700
Subject: [PATCH 4/4] return span on apply

---
 LLama.Unittest/TemplateTests.cs               | 45 ++----
 LLama/LLamaTemplate.cs                        | 13 ++----
 .../Transformers/PromptTemplateTransformer.cs |  6 +--
 3 files changed, 14 insertions(+), 50 deletions(-)

diff --git a/LLama.Unittest/TemplateTests.cs b/LLama.Unittest/TemplateTests.cs
index a435409de..9520905b6 100644
--- a/LLama.Unittest/TemplateTests.cs
+++ b/LLama.Unittest/TemplateTests.cs
@@ -47,18 +47,10 @@ public void BasicTemplate()
        templater.Add("user", "ccc");
        Assert.Equal(8, templater.Count);
 
-        // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
-        Assert.Equal(8, templater.Count);
-
-        // Call again to get contents
-        length = templater.Apply(dest);
-
+        var dest = templater.Apply();
        Assert.Equal(8, templater.Count);
 
-        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        var templateResult = Encoding.UTF8.GetString(dest);
        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
                                "<|im_start|>user\nworld<|im_end|>\n" +
                                "<|im_start|>assistant\n" +
@@ -93,17 +85,10 @@ public void CustomTemplate()
        Assert.Equal(4, templater.Count);
 
        // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
-        Assert.Equal(4, templater.Count);
-
-        // Call again to get contents
-        length = templater.Apply(dest);
-
+        var dest = templater.Apply();
        Assert.Equal(4, templater.Count);
 
-        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        var templateResult = Encoding.UTF8.GetString(dest);
        const string expected = "model\n" +
                                "hello\n" +
                                "user\n" +
@@ -143,17 +128,10 @@ public void BasicTemplateWithAddAssistant()
        Assert.Equal(8, templater.Count);
 
        // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
-        Assert.Equal(8, templater.Count);
-
-        // Call again to get contents
-        length = templater.Apply(dest);
-
+        var dest = templater.Apply();
        Assert.Equal(8, templater.Count);
 
-        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        var templateResult = Encoding.UTF8.GetString(dest);
        const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
                                "<|im_start|>user\nworld<|im_end|>\n" +
                                "<|im_start|>assistant\n" +
@@ -267,15 +245,8 @@ public void Clear_ResetsTemplateState()
        templater.Add("user", userData);
 
        // Generate the template string
-        // Call once with empty array to discover length
-        var length = templater.Apply(Array.Empty<byte>());
-        var dest = new byte[length];
-
-        Assert.Equal(1, templater.Count);
-
-        // Call again to get contents
-        _ = templater.Apply(dest);
-        var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
+        var dest = templater.Apply();
+        var templateResult = Encoding.UTF8.GetString(dest);
 
        const string expectedTemplate = $"<|im_start|>user\n{userData}<|im_end|>\n";
        Assert.Equal(expectedTemplate, templateResult);
diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs
index 414fd5c67..fb2268ac2 100644
--- a/LLama/LLamaTemplate.cs
+++ b/LLama/LLamaTemplate.cs
@@ -104,11 +104,6 @@ public bool AddAssistant
            }
        }
    }
-
-    ///
-    /// Get a span to the underlying bytes as specified by Encoding
-    ///
-    public ReadOnlySpan<byte> TemplateDataBuffer => _result;
    #endregion
 
    #region construction
@@ -217,9 +212,8 @@ public void Clear()
    ///
    /// Apply the template to the 
messages and write it into the output buffer /// - /// Destination to write template bytes into - /// The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer - public int Apply(Memory dest) + /// A span over the buffer that holds the applied template + public ReadOnlySpan Apply() { // Recalculate template if necessary if (_dirty) @@ -285,8 +279,7 @@ public int Apply(Memory dest) } // Now that the template has been applied and is in the result buffer, copy it to the dest - _result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span); - return _resultLength; + return _result.AsSpan(0, _resultLength); unsafe int ApplyInternal(Span messages, byte[] output) { diff --git a/LLama/Transformers/PromptTemplateTransformer.cs b/LLama/Transformers/PromptTemplateTransformer.cs index 3b78acb9f..3543f9a1a 100644 --- a/LLama/Transformers/PromptTemplateTransformer.cs +++ b/LLama/Transformers/PromptTemplateTransformer.cs @@ -53,15 +53,15 @@ public IHistoryTransform Clone() public static string ToModelPrompt(LLamaTemplate template) { // Apply the template to update state and get data length - var dataLength = template.Apply(Array.Empty()); + var templateBuffer = template.Apply(); // convert the resulting buffer to a string #if NET6_0_OR_GREATER - return LLamaTemplate.Encoding.GetString(template.TemplateDataBuffer[..dataLength]); + return LLamaTemplate.Encoding.GetString(templateBuffer); #endif // need the ToArray call for netstandard -- avoided in newer runtimes - return LLamaTemplate.Encoding.GetString(template.TemplateDataBuffer.Slice(0, dataLength).ToArray()); + return LLamaTemplate.Encoding.GetString(templateBuffer.ToArray()); } #endregion utils }
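
Taken together, the four patches leave the template surface in roughly the shape sketched below. This snippet is editorial illustration rather than part of the patch series: the model path is a hypothetical placeholder, and it assumes a chat model whose template llama.cpp already knows how to render (for anything else, a custom IHistoryTransform is still required, as the example's comment notes).

    using LLama;
    using LLama.Common;
    using LLama.Transformers;

    // hypothetical model path -- substitute a real GGUF file
    var parameters = new ModelParams("models/example-chat-model.gguf");
    using var model = LLamaWeights.LoadFromFile(parameters);

    // render a ChatHistory through the model's built-in chat template
    var transformer = new PromptTemplateTransformer(model, withAssistant: true);
    var history = new ChatHistory();
    history.Messages.Add(new ChatHistory.Message(AuthorRole.User, "Hello!"));
    var prompt = transformer.HistoryToText(history);

    // or drive LLamaTemplate directly; after PATCH 4/4, Apply() returns the
    // rendered bytes and the static helper decodes them to a string
    var template = new LLamaTemplate(model.NativeHandle) { AddAssistant = true };
    template.Add("user", "Hello!");
    var rendered = PromptTemplateTransformer.ToModelPrompt(template);

    // the model-specific end-of-turn string doubles as an anti-prompt;
    // EndOfTurnToken may be null when the model defines no EOT token
    var inferenceParams = new InferenceParams
    {
        MaxTokens = -1,
        AntiPrompts = [model.Tokens.EndOfTurnToken!]
    };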
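The metadata accessor added in PATCH 1/4 deliberately returns the raw UTF-8 bytes (Memory<byte>?) rather than a string, mirroring the native llama_model_meta_val_str call. A minimal decode, reusing the model handle from the sketch above and assuming a .NET 6+ target where Encoding.GetString accepts a span:

    using System.Text;

    // raw UTF-8 metadata bytes, or null when the key is absent
    var raw = model.NativeHandle.MetadataValueByKey("general.name");
    if (raw is not null)
    {
        // decode the span directly; older targets need raw.Value.ToArray() first
        var name = Encoding.UTF8.GetString(raw.Value.Span);
        Console.WriteLine($"general.name = {name}");
    }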