Skip to content

Commit

Permalink
Add strftime_now callable function for minijinja chat templates (#…
Browse files Browse the repository at this point in the history
…2983)

* Add `chrono` and `strftime_now` function callable

* Fix `test_chat_template_valid_with_strftime_now`

* Fix `test_chat_template_valid_with_strftime_now`
  • Loading branch information
alvarobartt authored Feb 3, 2025
1 parent e3f2018 commit 88fd56f
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 1 deletion.
53 changes: 53 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
csv = "1.3.0"
ureq = "=2.9"
pyo3 = { workspace = true }
chrono = "0.4.39"


[build-dependencies]
Expand Down
83 changes: 82 additions & 1 deletion router/src/infer/chat_template.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::infer::InferError;
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
use chrono::Local;
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;

Expand All @@ -8,6 +9,11 @@ pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Err
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}

/// Get the current date in a specific format (custom function), similar to `datetime.now().strftime()` in Python
pub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Error> {
Ok(Local::now().format(&format_str).to_string())
}

#[derive(Clone)]
pub(crate) struct ChatTemplate {
template: Template<'static, 'static>,
Expand All @@ -27,6 +33,7 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
tracing::debug!("Loading template: {}", template_str);

// leaking env and template_str as read-only, static resources for performance.
Expand Down Expand Up @@ -109,11 +116,12 @@ impl ChatTemplate {
// tests
#[cfg(test)]
mod tests {
use crate::infer::chat_template::raise_exception;
use crate::infer::chat_template::{raise_exception, strftime_now};
use crate::infer::ChatTemplate;
use crate::{
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
};
use chrono::Local;
use minijinja::Environment;

#[test]
Expand Down Expand Up @@ -182,6 +190,7 @@ mod tests {
fn test_chat_template_invalid_with_raise() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);

let source = r#"
{{ bos_token }}
Expand Down Expand Up @@ -253,6 +262,7 @@ mod tests {
fn test_chat_template_valid_with_raise() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);

let source = r#"
{{ bos_token }}
Expand Down Expand Up @@ -307,10 +317,79 @@ mod tests {
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
}

#[test]
fn test_chat_template_valid_with_strftime_now() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);

let source = r#"
{% set today = strftime_now("%Y-%m-%d") %}
{% set default_system_message = "The current date is " + today + "." %}
{{ bos_token }}
{% if messages[0]['role'] == 'system' %}
{ set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{% else %}
{%- set system_message = default_system_message %}
{%- set loop_messages = messages %}
{% endif %}
{{ '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
{% for message in loop_messages %}
{% if message['role'] == 'user' %}
{{ '[INST]' + message['content'] + '[/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ message['content'] + eos_token }}
{% else %}
{{ raise_exception('Only user and assistant roles are supported!') }}
{% endif %}
{% endfor %}
"#;

// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");

let tmpl = env.template_from_str(&source);

let chat_template_inputs = ChatTemplateInputs {
messages: vec![
TextMessage {
role: "user".to_string(),
content: "Hi!".to_string(),
},
TextMessage {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
TextMessage {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
TextMessage {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
..Default::default()
};

let current_date = Local::now().format("%Y-%m-%d").to_string();
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, format!("[BOS][SYSTEM_PROMPT]The current date is {}.[/SYSTEM_PROMPT][INST]Hi![/INST]Hello how can I help?[EOS][INST]What is Deep Learning?[/INST]magic![EOS]", current_date));
}

#[test]
fn test_chat_template_valid_with_add_generation_prompt() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);

let source = r#"
{% for message in messages %}
Expand Down Expand Up @@ -502,6 +581,7 @@ mod tests {
{
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
let tmpl = env.template_from_str(chat_template);
let result = tmpl.unwrap().render(input).unwrap();
assert_eq!(result, target);
Expand Down Expand Up @@ -776,6 +856,7 @@ mod tests {
{
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
env.add_function("strftime_now", strftime_now);
// trim all the whitespace
let chat_template = chat_template
.lines()
Expand Down

0 comments on commit 88fd56f

Please sign in to comment.