From 3d132ab6275a6503399901c1ae8a7ff3e9fb5c0c Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Fri, 31 Jan 2025 21:25:41 +0100 Subject: [PATCH 1/3] Add `chrono` and `strftime_now` function callable --- Cargo.lock | 55 ++++++++++++++++++++- router/Cargo.toml | 1 + router/src/infer/chat_template.rs | 80 ++++++++++++++++++++++++++++++- 3 files changed, 134 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e63d15407a3..057c2d49c01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -52,6 +52,21 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.18" @@ -651,6 +666,20 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chrono" +version = "0.4.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets 0.52.6", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -1802,6 +1831,29 @@ dependencies = [ 
"tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "1.5.0" @@ -4512,6 +4564,7 @@ dependencies = [ "axum 0.7.9", "axum-tracing-opentelemetry", "base64 0.22.1", + "chrono", "clap 4.5.21", "csv", "futures", diff --git a/router/Cargo.toml b/router/Cargo.toml index 2e621dfc669..e4d0179a426 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -64,6 +64,7 @@ uuid = { version = "1.9.1", default-features = false, features = [ csv = "1.3.0" ureq = "=2.9" pyo3 = { workspace = true } +chrono = "0.4.39" [build-dependencies] diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 2bda71933ab..61c8eed020c 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,5 +1,6 @@ use crate::infer::InferError; use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; +use chrono::Local; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -8,6 +9,11 @@ pub(crate) fn raise_exception(err_text: String) -> Result Result { + Ok(Local::now().format(&format_str).to_string()) +} + #[derive(Clone)] pub(crate) struct ChatTemplate { template: Template<'static, 'static>, @@ -27,6 +33,7 @@ impl ChatTemplate { env.set_unknown_method_callback(pycompat::unknown_method_callback); let template_str = template.into_boxed_str(); env.add_function("raise_exception", 
raise_exception); + env.add_function("strftime_now", strftime_now); tracing::debug!("Loading template: {}", template_str); // leaking env and template_str as read-only, static resources for performance. @@ -109,11 +116,12 @@ impl ChatTemplate { // tests #[cfg(test)] mod tests { - use crate::infer::chat_template::raise_exception; + use crate::infer::chat_template::{raise_exception, strftime_now}; use crate::infer::ChatTemplate; use crate::{ ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool, }; + use chrono::Local; use minijinja::Environment; #[test] @@ -182,6 +190,7 @@ mod tests { fn test_chat_template_invalid_with_raise() { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); let source = r#" {{ bos_token }} @@ -253,6 +262,7 @@ mod tests { fn test_chat_template_valid_with_raise() { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); let source = r#" {{ bos_token }} @@ -307,10 +317,76 @@ mod tests { assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? 
[/INST]magic![EOS]"); } + #[test] + fn test_chat_template_valid_with_strftime_now() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); + + let source = r#" + {% set today = strftime_now("%Y-%m-%d") %} + {% set default_system_message = "The current date is " + today + ".\n" %} + {{ bos_token }} + {% if messages[0]['role'] == 'system' %} + {% set loop_messages = messages[1:] %} + {% else %} + {% set loop_messages = messages %} + {% endif %} + {% for message in loop_messages %} + {% if message['role'] == 'user' %} + {{ '[INST]' + message['content'] + '[/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ message['content'] + eos_token }} + {% else %} + {{ raise_exception('Only user and assistant roles are supported!') }} + {% endif %} + {% endfor %} + "#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::<Vec<_>>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + TextMessage { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + TextMessage { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let current_date = Local::now().format("%Y-%m-%d").to_string(); + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, format!("The current date is {}. [BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? 
[/INST]magic![EOS]", current_date)); + } + #[test] fn test_chat_template_valid_with_add_generation_prompt() { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); let source = r#" {% for message in messages %} @@ -502,6 +578,7 @@ mod tests { { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); let tmpl = env.template_from_str(chat_template); let result = tmpl.unwrap().render(input).unwrap(); assert_eq!(result, target); @@ -776,6 +853,7 @@ mod tests { { let mut env = Environment::new(); env.add_function("raise_exception", raise_exception); + env.add_function("strftime_now", strftime_now); // trim all the whitespace let chat_template = chat_template .lines() From 1c17d8a76825915791cfc3ddfa75ba895fe67479 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Fri, 31 Jan 2025 21:41:15 +0100 Subject: [PATCH 2/3] Fix `test_chat_template_valid_with_strftime_now` --- router/src/infer/chat_template.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 61c8eed020c..00938aba226 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -325,7 +325,8 @@ mod tests { let source = r#" {% set today = strftime_now("%Y-%m-%d") %} - {% set default_system_message = "The current date is " + today + ".\n" %} + {% set default_system_message = "The current date is " + today + "." %} + {{ default_system_message }} {{ bos_token }} {% if messages[0]['role'] == 'system' %} {% set loop_messages = messages[1:] %} @@ -379,7 +380,7 @@ mod tests { let current_date = Local::now().format("%Y-%m-%d").to_string(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); - assert_eq!(result, format!("The current date is {}. [BOS][INST] Hi! 
[/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]", current_date)); + assert_eq!(result, format!("The current date is {}.[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]", current_date)); } #[test] From 77940ac73f4efd0d8735466504ce12a66a1eb061 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Fri, 31 Jan 2025 21:52:55 +0100 Subject: [PATCH 3/3] Fix `test_chat_template_valid_with_strftime_now` --- router/src/infer/chat_template.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 00938aba226..8303ee76829 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -326,13 +326,15 @@ mod tests { let source = r#" {% set today = strftime_now("%Y-%m-%d") %} {% set default_system_message = "The current date is " + today + "." %} - {{ default_system_message }} {{ bos_token }} {% if messages[0]['role'] == 'system' %} - {% set loop_messages = messages[1:] %} + {% set system_message = messages[0]['content'] %} + {%- set loop_messages = messages[1:] %} {% else %} - {% set loop_messages = messages %} + {%- set system_message = default_system_message %} + {%- set loop_messages = messages %} {% endif %} + {{ '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }} {% for message in loop_messages %} {% if message['role'] == 'user' %} {{ '[INST]' + message['content'] + '[/INST]' }} @@ -380,7 +382,7 @@ mod tests { let current_date = Local::now().format("%Y-%m-%d").to_string(); let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); - assert_eq!(result, format!("The current date is {}.[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? 
[/INST]magic![EOS]", current_date)); + assert_eq!(result, format!("[BOS][SYSTEM_PROMPT]The current date is {}.[/SYSTEM_PROMPT][INST]Hi![/INST]Hello how can I help?[EOS][INST]What is Deep Learning?[/INST]magic![EOS]", current_date)); } #[test]