Skip to content

Commit

Permalink
return span on apply
Browse files Browse the repository at this point in the history
  • Loading branch information
Pat Hov committed Jun 9, 2024
1 parent 3fa9add commit 8c9bbb6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 50 deletions.
45 changes: 8 additions & 37 deletions LLama.Unittest/TemplateTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,10 @@ public void BasicTemplate()
templater.Add("user", "ccc");
Assert.Equal(8, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(8, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

var dest = templater.Apply();
Assert.Equal(8, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
Expand Down Expand Up @@ -93,17 +85,10 @@ public void CustomTemplate()
Assert.Equal(4, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(4, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

var dest = templater.Apply();
Assert.Equal(4, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<start_of_turn>model\n" +
"hello<end_of_turn>\n" +
"<start_of_turn>user\n" +
Expand Down Expand Up @@ -143,17 +128,10 @@ public void BasicTemplateWithAddAssistant()
Assert.Equal(8, templater.Count);

// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(8, templater.Count);

// Call again to get contents
length = templater.Apply(dest);

var dest = templater.Apply();
Assert.Equal(8, templater.Count);

var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var templateResult = Encoding.UTF8.GetString(dest);
const string expected = "<|im_start|>assistant\nhello<|im_end|>\n" +
"<|im_start|>user\nworld<|im_end|>\n" +
"<|im_start|>assistant\n" +
Expand Down Expand Up @@ -267,15 +245,8 @@ public void Clear_ResetsTemplateState()
templater.Add("user", userData);

// Generate the template string
// Call once with empty array to discover length
var length = templater.Apply(Array.Empty<byte>());
var dest = new byte[length];

Assert.Equal(1, templater.Count);

// Call again to get contents
_ = templater.Apply(dest);
var templateResult = Encoding.UTF8.GetString(dest.AsSpan(0, length));
var dest = templater.Apply();
var templateResult = Encoding.UTF8.GetString(dest);

const string expectedTemplate = $"<|im_start|>user\n{userData}<|im_end|>\n";
Assert.Equal(expectedTemplate, templateResult);
Expand Down
13 changes: 3 additions & 10 deletions LLama/LLamaTemplate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ public bool AddAssistant
}
}
}

/// <summary>
/// Get a span to the underlying bytes as specified by Encoding
/// </summary>
public ReadOnlySpan<byte> TemplateDataBuffer => _result;
#endregion

#region construction
Expand Down Expand Up @@ -217,9 +212,8 @@ public void Clear()
/// <summary>
/// Apply the template to the messages and write it into the output buffer
/// </summary>
/// <param name="dest">Destination to write template bytes into</param>
/// <returns>The length of the template. If this is longer than dest.Length this method should be called again with a larger dest buffer</returns>
public int Apply(Memory<byte> dest)
/// <returns>A span over the buffer that holds the applied template</returns>
public ReadOnlySpan<byte> Apply()
{
// Recalculate template if necessary
if (_dirty)
Expand Down Expand Up @@ -285,8 +279,7 @@ public int Apply(Memory<byte> dest)
}

// Now that the template has been applied and is in the result buffer, hand the result back to the caller
_result.AsSpan(0, Math.Min(dest.Length, _resultLength)).CopyTo(dest.Span);
return _resultLength;
return _result.AsSpan(0, _resultLength);

unsafe int ApplyInternal(Span<LLamaChatMessage> messages, byte[] output)
{
Expand Down
6 changes: 3 additions & 3 deletions LLama/Transformers/PromptTemplateTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ public IHistoryTransform Clone()
public static string ToModelPrompt(LLamaTemplate template)
{
// Apply the template to update state and obtain the resulting template data
var dataLength = template.Apply(Array.Empty<byte>());
var templateBuffer = template.Apply();

// convert the resulting buffer to a string
#if NET6_0_OR_GREATER
return LLamaTemplate.Encoding.GetString(template.TemplateDataBuffer[..dataLength]);
return LLamaTemplate.Encoding.GetString(templateBuffer);
#endif

// need the ToArray call for netstandard -- avoided in newer runtimes
return LLamaTemplate.Encoding.GetString(template.TemplateDataBuffer.Slice(0, dataLength).ToArray());
return LLamaTemplate.Encoding.GetString(templateBuffer.ToArray());
}
#endregion utils
}

0 comments on commit 8c9bbb6

Please sign in to comment.