diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 15e443b..9be0b10 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -26,9 +26,51 @@ jobs: - name: Build run: cargo build --all-features --verbose - - name: Test + - name: Check formatting + run: cargo fmt --all -- --check + + - name: Clippy + run: cargo clippy --all-features + + - name: Test Without Default Features + run: cargo test --all-features --verbose --no-default-features + + - name: Test With all Features run: cargo test --all-features --verbose + - name: Test With Image Feature + run: cargo test --features image --verbose + + - name: Test with JPEG Feature + run: cargo test --features jpeg --verbose + + - name: Test with PNG Feature + run: cargo test --features png --verbose + + - name: Test with GIF Feature + run: cargo test --features gif --verbose + + - name: Test with WEBP Feature + run: cargo test --features webp --verbose + + - name: Test with Prompt Caching Feature + run: cargo test --features prompt-caching --verbose + + - name: Test with Log Feature + run: cargo test --features log --verbose + + - name: Test with Markdown Feature + run: cargo test --features markdown --verbose + + - name: Test with PartialEq Feature + run: cargo test --features partial-eq --verbose + + - name: Test with Langsan feature + run: cargo test --features langsan --verbose + + - name: Test with Memsecurity feature + run: cargo test --features memsecurity --verbose + # This should only happen on push to main. PRs should not upload coverage. - name: Install llvm-cov uses: taiki-e/install-action@cargo-llvm-cov diff --git a/Cargo.toml b/Cargo.toml index 5174f10..1ab31d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "misanthropic" -version = "0.4.2" +version = "0.5.0" edition = "2021" authors = ["Michael de Gans "] description = "An async, ergonomic, client for Anthropic's Messages API" @@ -17,9 +17,18 @@ categories = [ ] license = "MIT" +[package.metadata.docs.rs] +cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] + [lints.rust] unsafe_code = "forbid" +[profile.release] +lto = true +strip = true +# There are (hopefully) no panics in this library. If there are, they are bugs. +panic = "abort" + [dependencies] base64 = "0.22" derive_more = { version = "1", features = ["from", "is_variant", "display"] } @@ -27,14 +36,15 @@ eventsource-stream = "0.2" futures = "0.3" image = { version = "0.25", optional = true } log = { version = "0.4", optional = true } -memsecurity = "3.5" +memsecurity = { version = "3.5", optional = true } +zeroize = { version = "1", features = ["derive"] } # rustls because I am sick of getting Dependabot alerts for OpenSSL. reqwest = { version = "0.12", features = ["json", "stream"] } serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = "1" # markdown support -pulldown-cmark = { version = "0.12", optional = true } +pulldown-cmark = { version = "0.12", optional = true, features = ["serde"] } pulldown-cmark-to-cmark = { version = "17", optional = true } static_assertions = "1" langsan = { version = "0", features = [ @@ -46,6 +56,8 @@ langsan = { version = "0", features = [ "emoji", "verbose", ], optional = true } +# For HTML escaping +xml-rs = { version = "0.8", optional = true } [dev-dependencies] # for all examples @@ -84,16 +96,32 @@ log = ["dep:log"] rustls-tls = ["reqwest/rustls-tls"] # Use `pulldown-cmark` for markdown parsing and `pulldown-cmark-to-cmark` for # converting to CommonMark. -markdown = ["dep:pulldown-cmark", "dep:pulldown-cmark-to-cmark"] +markdown = ["pulldown-cmark/serde", "dep:pulldown-cmark-to-cmark"] +# Utilities for converting prompts and messages to HTML. Enables `markdown`. +html = ["markdown", "xml-rs"] # Derive PartialEq for all structs and enums. -partial_eq = [] +partial-eq = [] # Input and output sanitization langsan = ["dep:langsan"] +# Encrypted key in memory. Without this the key is still zeroed on drop, but is +# not encrypted. This is a more secure option for the paranoid. Does not build +# on wasm32. +memsecurity = ["dep:memsecurity"] [[example]] name = "strawberry" required-features = ["markdown"] +doc-scrape-examples = true [[example]] name = "python" required-features = ["markdown", "prompt-caching"] +doc-scrape-examples = true + +[[example]] +name = "website_wizard" +doc-scrape-examples = true + +[[example]] +name = "neologism" +doc-scrape-examples = true diff --git a/README.md b/README.md index 0bfdc7b..4e0d1cc 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ println!("{}", message); - [x] Message responses - [x] Image support with or without the `image` crate - [x] Markdown formatting of messages, including images +- [x] HTML formatting of messages\*. - [x] Prompt caching support - [x] Custom request and endpoint support - [x] Zero-copy where possible @@ -75,6 +76,8 @@ println!("{}", message); - [ ] Amazon Bedrock support - [ ] Vertex AI support +\* _Base64 encoded images are currently not implemented for HTML but this is a planned feature._ + [reqwest]: https://docs.rs/reqwest ## FAQ diff --git a/src/client.rs b/src/client.rs index 7eee03d..2b967c0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -120,6 +120,8 @@ impl Client { log::debug!("{} request to {}", method, url.as_str()); } + #[allow(clippy::useless_asref)] + // because with memsecurity feature it's not useless let mut val = reqwest::header::HeaderValue::from_bytes(self.key.read().as_ref()) .unwrap(); diff --git a/src/html.rs b/src/html.rs new file mode 100644 index 0000000..5d6feb6 --- /dev/null +++ b/src/html.rs @@ -0,0 +1,415 @@ +use std::ops::Deref; + +use pulldown_cmark::html::push_html; + +use crate::markdown::ToMarkdown; + +pub use crate::markdown::{Options, DEFAULT_OPTIONS, VERBOSE_OPTIONS}; + +/// Immutable wrapper around a [`String`]. Guaranteed to be valid HTML. +#[derive(derive_more::Display)] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] +#[display("{inner}")] +pub struct Html { + inner: String, +} + +impl Html { + /// Create a new `Html` from a stream of markdown events. + pub fn from_events<'a>( + events: impl Iterator>, + ) -> Self { + events.collect::() + } + + /// Extend the HTML with a stream of markdown events. + pub fn extend<'a, It>( + &mut self, + events: impl IntoIterator, IntoIter = It>, + ) where + It: Iterator>, + { + use pulldown_cmark::{CowStr, Event, Tag, TagEnd}; + use std::borrow::Cow; + use xml::escape::escape_str_pcdata; + + let escape_pcdata = |cow_str: CowStr<'a>| -> CowStr<'a> { + // This is necessary because `escape_str_pcdata` does not have + // lifetime annotations, although it could since it doesn't copy the + // string and this is documented. + match escape_str_pcdata(cow_str.as_ref()) { + Cow::Borrowed(_) => cow_str, + Cow::Owned(s) => s.into(), + } + }; + + let raw: It = events.into_iter(); + let escaped = raw.map(|e| { + match e { + // We must escape the HTML to prevent XSS attacks. A frontend should + // take other measures as well, but we can at least provide some + // protection. + Event::Text(cow_str) => Event::Text(escape_pcdata(cow_str)), + // Without this the escaping test fails because the paragraph + // tags are missing because of how the markdown is parsed. We + // always want message content to be in paragraphs. + Event::Start(Tag::HtmlBlock) => Event::Start(Tag::CodeBlock( + pulldown_cmark::CodeBlockKind::Fenced("html".into()), + )), + Event::End(TagEnd::HtmlBlock) => Event::End(TagEnd::CodeBlock), + Event::Code(cow_str) => Event::Code(escape_pcdata(cow_str)), + Event::InlineMath(cow_str) => { + Event::InlineMath(escape_pcdata(cow_str)) + } + Event::DisplayMath(cow_str) => { + Event::DisplayMath(escape_pcdata(cow_str)) + } + Event::Html(cow_str) => Event::Html(escape_pcdata(cow_str)), + Event::InlineHtml(cow_str) => { + Event::InlineHtml(escape_pcdata(cow_str)) + } + Event::FootnoteReference(cow_str) => { + Event::FootnoteReference(escape_pcdata(cow_str)) + } + // No other events have text or attributes that need to be + // escaped. Heading attributes are already escaped by + // pulldown-cmark when rendering to HTML, so we don't need to + // escape them here or we double escape them. + e => e, + } + }); + push_html(&mut self.inner, escaped); + } +} + +impl From for String { + fn from(html: Html) -> Self { + html.inner + } +} + +impl AsRef for Html { + fn as_ref(&self) -> &str { + self.deref() + } +} + +impl std::borrow::Borrow for Html { + fn borrow(&self) -> &str { + self.as_ref() + } +} + +impl std::ops::Deref for Html { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl<'a> FromIterator> for Html { + fn from_iter>>( + iter: T, + ) -> Self { + let mut html = Html { + inner: String::new(), + }; + html.extend(iter); + html + } +} + +/// A trait for types that can be converted to HTML. This generally does not +/// need to be implemented directly, as it is already implemented for types +/// that implement [`ToMarkdown`]. +/// +/// # Note +/// - `attrs` are always enabled for HTML rendering so this does not have to be +/// set on the [`MarkdownOptions`]. +/// +/// [`MarkdownOptions`]: struct.MarkdownOptions.html +pub trait ToHtml: ToMarkdown { + /// Render the type to an HTML string. + fn html(&self) -> Html { + let mut opts = DEFAULT_OPTIONS; + opts.attrs = true; + self.html_custom(DEFAULT_OPTIONS) + } + + /// Render the type to an HTML string with maximum verbosity. + fn html_verbose(&self) -> Html { + self.html_custom(VERBOSE_OPTIONS) + } + + /// Render the type to an HTML string with custom [`Options`]. + fn html_custom(&self, options: Options) -> Html { + self.markdown_events_custom(options).collect() + } +} + +impl ToHtml for T where T: ToMarkdown {} + +#[cfg(test)] +mod tests { + use std::borrow::Borrow; + + use serde_json::json; + + use crate::{ + prompt::{message::Role, Message}, + tool, Tool, + }; + + use super::*; + + #[test] + fn test_message_html() { + let message = Message { + role: Role::User, + content: "Hello, **world**!".into(), + }; + + assert_eq!( + message.html().as_ref(), + "

User

\n

Hello, world!

\n", + ); + + let opts = Options { + attrs: true, + ..Default::default() + }; + + assert_eq!( + message.html_custom(opts).as_ref(), + "

User

\n

Hello, world!

\n", + ); + } + + #[test] + fn test_prompt_html() { + let prompt = crate::prompt::Prompt { + system: Some("Do stuff the user says.".into()), + tools: Some(vec![Tool { + name: "python".into(), + description: "Run a Python script.".into(), + input_schema: json!({ + "type": "object", + "properties": { + "script": { + "type": "string", + "description": "Python script to run.", + }, + }, + "required": ["script"], + }), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }]), + messages: vec![ + Message { + role: Role::User, + content: "Run a hello world python program.".into(), + }, + tool::Use { + id: "id".into(), + name: "python".into(), + input: json!({ + "script": "print('Hello, world!')", + }), + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(), + tool::Result { + tool_use_id: "id".into(), + content: json!({ + "stdout": "Hello, world!\n", + }) + .to_string() + .into(), + is_error: false, + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(), + Message { + role: Role::Assistant, + content: "It is done!".into(), + }, + ], + ..Default::default() + }; + + assert_eq!( + prompt.html().as_ref(), + "

User

\n

Run a hello world python program.

\n

Assistant

\n

It is done!

\n", + ); + + let opts = Options { + attrs: true, + ..Default::default() + }; + + assert_eq!( + prompt.html_custom(opts).as_ref(), + "

User

\n

Run a hello world python program.

\n

Assistant

\n

It is done!

\n", + ); + + assert_eq!( + prompt.html_verbose().as_ref(), + "

System

\n

Do stuff the user says.

\n

User

\n

Run a hello world python program.

\n

Assistant

\n
{\"type\":\"tool_use\",\"id\":\"id\",\"name\":\"python\",\"input\":{\"script\":\"print('Hello, world!')\"}}
\n

Tool

\n
{\"type\":\"tool_result\",\"tool_use_id\":\"id\",\"content\":[{\"type\":\"text\",\"text\":\"{\\\"stdout\\\":\\\"Hello, world!\\\\n\\\"}\"}],\"is_error\":false}
\n

Assistant

\n

It is done!

\n", + ) + } + + #[test] + fn test_html_from_events() { + let events = vec![ + pulldown_cmark::Event::Start(pulldown_cmark::Tag::Paragraph), + pulldown_cmark::Event::Text("Hello, world!".into()), + pulldown_cmark::Event::End(pulldown_cmark::TagEnd::Paragraph), + ]; + + let html = Html::from_events(events.into_iter()); + assert_eq!(html.as_ref(), "

Hello, world!

\n"); + } + + #[test] + fn test_html_extend() { + let mut html = Html { + inner: String::new(), + }; + + let events = vec![ + pulldown_cmark::Event::Start(pulldown_cmark::Tag::Paragraph), + pulldown_cmark::Event::Text("Hello, world!".into()), + pulldown_cmark::Event::End(pulldown_cmark::TagEnd::Paragraph), + ]; + + html.extend(events.into_iter()); + assert_eq!(html.as_ref(), "

Hello, world!

\n"); + } + + #[test] + fn test_html_from_iter() { + let events = vec![ + pulldown_cmark::Event::Start(pulldown_cmark::Tag::Paragraph), + pulldown_cmark::Event::Text("Hello, world!".into()), + pulldown_cmark::Event::End(pulldown_cmark::TagEnd::Paragraph), + ]; + + let html: Html = events.into_iter().collect(); + assert_eq!(html.as_ref(), "

Hello, world!

\n"); + } + + #[test] + fn test_to_html() { + let message = Message { + role: Role::User, + content: "Hello, **world**!".into(), + }; + + assert_eq!( + message.html().as_ref(), + "

User

\n

Hello, world!

\n", + ); + + assert_eq!( + message.html_verbose().as_ref(), + "

User

\n

Hello, world!

\n", + ); + + assert_eq!( + message + .html_custom(Options { + attrs: true, + ..Default::default() + }) + .as_ref(), + // `attrs` are always enabled for HTML rendering + "

User

\n

Hello, world!

\n", + ); + } + + #[test] + fn test_borrow() { + let message = Message { + role: Role::User, + content: "Hello, **world**!".into(), + }; + + let html: Html = message.html(); + let borrowed: &str = html.borrow(); + assert_eq!(borrowed, html.as_ref()); + } + + #[test] + fn test_into_string() { + let message = Message { + role: Role::User, + content: "Hello, **world**!".into(), + }; + + let html: Html = message.html(); + let string: String = html.into(); + assert_eq!( + string, + "

User

\n

Hello, world!

\n" + ); + } + + #[test] + fn test_escaping() { + use pulldown_cmark::{Event, HeadingLevel::H3, Tag, TagEnd}; + + let message = Message { + role: Role::Assistant, + content: "bla blabla bla".into(), + }; + + assert_eq!( + message.html().as_ref(), + "

Assistant

\n

bla bla<script>alert('XSS')</script>bla bla

\n", + ); + + let message = Message { + role: Role::Assistant, + content: "".into(), + }; + + assert_eq!( + message.html_verbose().as_ref(), + // In the case where a content block is entirely code, it is + // rendered as a code block. This is mostly done because of how + // markdown is parsed and we're lazy, but also it's nice behavior. + "

Assistant

\n
<script>alert('XSS')</script>
\n", + ); + + // Test escaping of attributes + let bad_attrs = vec![ + Event::Start(Tag::Heading { + level: H3, + id: None, + classes: vec![], + attrs: vec![( + r#"

badkey

"#.into(), + Some(r#""sneaky">"#.into()), + )], + }), + Event::Text("Hello, world!".into()), + Event::End(TagEnd::Heading(H3)), + ]; + + let html = Html::from_events(bad_attrs.into_iter()); + // FIXME: This is not the correct behavior. pulldown_cmark is escaping + // the attributes, but not forward slashes in keys leading to a broken + // key. This is a bug in pulldown_cmark. Fixing this is a low priority + // since it only applies to cases where a third party is providing the + // trait and doing very silly things with attributes. + assert_eq!( + html.as_ref(), + r#"

Hello, world!

+"# + ); + } +} diff --git a/src/key.rs b/src/key.rs index 5d19d95..7afa81e 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,135 +1,10 @@ -//! [`Key`] management for Anthropic API keys. - -// TODO: Remove this dependency for wasm32 and find an alternative. It's not a -// super idea to use this in a web app but wasm32 also has server use cases. -// This is the only thing that prevents this from building on wasm32. -use memsecurity::zeroize::Zeroize; - -/// The length of an Anthropic API key in bytes. -pub const LEN: usize = 108; - -/// Type alias for an Anthropic API key. -pub type Arr = [u8; LEN]; - -/// Error for when a key is not the correct [`key::LEN`]. -/// -/// [`key::LEN`]: LEN -#[derive(Debug, thiserror::Error)] -#[error("Invalid key length: {actual} (expected {LEN})")] -pub struct InvalidKeyLength { - /// The incorrect actual length of the key. - pub actual: usize, -} - -/// Stores an Anthropic API key securely. The API key is encrypted in memory. -/// The object features a [`Display`] implementation that can be used to write -/// out the key. **Be sure to zeroize whatever you write it to**. Prefer -/// [`Key::read`] if you want a return value that will automatically zeroize -/// the key on drop. -/// -/// [`Display`]: std::fmt::Display -#[derive(Debug)] -pub struct Key { - // FIXME: `memsecurity` does not build on wasm32. Find a solution for web. - // The `keyring` crate may work, but I'm likewise not sure if it builds on - // wasm32. It's not listed in the platforms so likely not. - mem: memsecurity::EncryptedMem, -} - -impl Key { - /// Read the key. The key is zeroized on drop. - pub fn read(&self) -> memsecurity::ZeroizeBytes { - self.mem.decrypt().unwrap() - } -} - -impl TryFrom for Key { - type Error = InvalidKeyLength; - - /// Create a new key from a string securely. The string is zeroized after - /// conversion. - fn try_from(s: String) -> Result { - // This just unwraps the internal Vec so the data can still be - // zeroized. - Self::try_from(s.into_bytes()) - } -} - -impl TryFrom> for Key { - type Error = InvalidKeyLength; - - /// Create a new key from a byte vector securely. The vector is zeroized - /// after conversion. - fn try_from(mut v: Vec) -> Result { - let mut arr: Arr = [0; LEN]; - if v.len() != LEN { - let actual = v.len(); - v.zeroize(); - return Err(InvalidKeyLength { actual }); - } - - arr.copy_from_slice(&v); - let ret = Ok(Self::from(arr)); - - v.zeroize(); - - ret - } -} - -impl From for Key { - /// Create a new key from a [`key::Arr`] byte array securely. The array is - /// zeroized. - /// - /// [`key::Arr`]: Arr - fn from(mut arr: Arr) -> Self { - let mut mem = memsecurity::EncryptedMem::new(); - - // Unwrap is desirable here because this can only fail if encryption - // is broken, which is a catastrophic failure. - mem.encrypt(&arr).unwrap(); - arr.zeroize(); - - Self { mem } - } -} - -impl std::fmt::Display for Key { - /// Write out the key. Make sure to zeroize whatever you write it to if at - /// all possible. - /// - /// Prefer [`Self::read`] if you want a return value that will automatically - /// zeroize the key on drop. - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // Zeroized on drop - let key = self.read(); - // Unwrap can never panic because a Key can only be created from a str - // or String which are guaranteed to be valid UTF-8. - let key_str = std::str::from_utf8(key.as_ref()).unwrap(); - write!(f, "{}", key_str) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - // Note: This is a real key but it's been disabled. As is warned in the - // docs above, do not use a string literal for a real key. There is no - // TryFrom<&'static str> for Key for this reason. - const API_KEY: &str = "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"; - - #[test] - fn test_key() { - let key = Key::try_from(API_KEY.to_string()).unwrap(); - let key_str = key.to_string(); - assert_eq!(key_str, API_KEY); - } - - #[test] - fn test_invalid_key_length() { - let key = "test_key".to_string(); - let err = Key::try_from(key).unwrap_err(); - assert_eq!(err.to_string(), "Invalid key length: 8 (expected 108)"); - } -} +//! [`Key`] is a wrapper around an Anthropic API key. + +#[cfg(feature = "memsecurity")] +mod encrypted; +#[cfg(feature = "memsecurity")] +pub use encrypted::{InvalidKeyLength, Key}; +#[cfg(not(feature = "memsecurity"))] +mod unencrypted; +#[cfg(not(feature = "memsecurity"))] +pub use unencrypted::{InvalidKeyLength, Key}; diff --git a/src/key/encrypted.rs b/src/key/encrypted.rs new file mode 100644 index 0000000..c576dd4 --- /dev/null +++ b/src/key/encrypted.rs @@ -0,0 +1,103 @@ +//! Encrypted [`Key`] management for Anthropic API keys. + +// TODO: Remove this dependency for wasm32 and find an alternative. It's not a +// super idea to use this in a web app but wasm32 also has server use cases. +// This is the only thing that prevents this from building on wasm32. +use memsecurity::zeroize::Zeroizing; + +/// The length of an Anthropic API key in bytes. +pub const LEN: usize = 108; + +/// Error for when a key is not 108 bytes. +#[derive(Debug, thiserror::Error)] +#[error("Invalid key length: {actual} (expected {LEN})")] +pub struct InvalidKeyLength { + /// The incorrect actual length of the key. + pub actual: usize, +} + +/// Stores an Anthropic API key securely. The API key is encrypted in memory. +/// The object features a [`Display`] implementation that can be used to write +/// out the key. **Be sure to zeroize whatever you write it to**. Prefer +/// [`Key::read`] if you want a return value that will automatically zeroize +/// the key on drop. +/// +/// [`Display`]: std::fmt::Display +#[derive(Debug)] +pub struct Key { + // FIXME: `memsecurity` does not build on wasm32. Find a solution for web. + // The `keyring` crate may work, but I'm likewise not sure if it builds on + // wasm32. It's not listed in the platforms so likely not. + mem: memsecurity::EncryptedMem, +} + +impl Key { + /// Read the key. The key is zeroized on drop. + pub fn read(&self) -> memsecurity::ZeroizeBytes { + self.mem.decrypt().unwrap() + } +} + +impl TryFrom for Key { + type Error = InvalidKeyLength; + + /// Create a new key from a string securely. The string is zeroized after + /// conversion. + fn try_from(s: String) -> Result { + // This just unwraps the internal Vec so the data can still be + // zeroized. + let v = Zeroizing::new(s.into_bytes()); + if v.len() != LEN { + let actual = v.len(); + return Err(InvalidKeyLength { actual }); + } + + let mut mem = memsecurity::EncryptedMem::new(); + + // Unwrap is desirable here because this can only fail if encryption + // is broken, which is a catastrophic failure. + mem.encrypt(&v).unwrap(); + + Ok(Self { mem }) + } +} + +impl std::fmt::Display for Key { + /// Write out the key. Make sure to zeroize whatever you write it to if at + /// all possible. + /// + /// Prefer [`Self::read`] if you want a return value that will automatically + /// zeroize the key on drop. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Zeroized on drop + let key = self.read(); + // Unwrap can never panic because a Key can only be created from a + // String which is guaranteed to be valid UTF-8. + let key_str = std::str::from_utf8(key.as_ref()).unwrap(); + write!(f, "{}", key_str) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Note: This is a real key but it's been disabled. As is warned in the + // docs above, do not use a string literal for a real key. There is no + // TryFrom<&'static str> for Key for this reason. + const API_KEY: &str = "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"; + + #[test] + fn test_key() { + let key = Key::try_from(API_KEY.to_string()).unwrap(); + let key_str = key.to_string(); + assert_eq!(key_str, API_KEY); + } + + #[test] + fn test_invalid_key_length() { + let key = "test_key".to_string(); + let err = Key::try_from(key).unwrap_err(); + assert_eq!(err.to_string(), "Invalid key length: 8 (expected 108)"); + } +} diff --git a/src/key/unencrypted.rs b/src/key/unencrypted.rs new file mode 100644 index 0000000..0041c72 --- /dev/null +++ b/src/key/unencrypted.rs @@ -0,0 +1,93 @@ +//! Unencrypted [`Key`] management for Anthropic API keys. +use zeroize::{ZeroizeOnDrop, Zeroizing}; + +/// The length of an Anthropic API key in bytes. +pub const LEN: usize = 108; + +/// Type alias for an Anthropic API key. +type Arr = [u8; LEN]; + +/// Error for when a key is not 108 bytes. +#[derive(Debug, thiserror::Error)] +#[error("Invalid key length: {actual} (expected {LEN})")] +pub struct InvalidKeyLength { + /// The incorrect actual length of the key. + pub actual: usize, +} + +/// Stores an Anthropic API key securely. The object features a [`Display`] +/// implementation that can be used to write out the key. **Be sure to zeroize +/// whatever you write it to**. The key is zeroized on drop. +/// +/// [`Display`]: std::fmt::Display +#[derive(Debug, ZeroizeOnDrop)] +pub struct Key { + mem: Arr, +} + +impl Key { + /// Read the key. The key is zeroized on drop. + // We can't return a &str becuase the other implementation of Key::read + // returns a memsecurity::ZeroizeBytes, and we can't return a reference to + // that because it's a temporary value, so this returns a slice instead, + // which has more or less the same public API. + pub fn read(&self) -> &[u8] { + &self.mem + } +} + +impl TryFrom for Key { + type Error = InvalidKeyLength; + + /// Create a new key from a string securely. The string is zeroized after + /// conversion. + fn try_from(s: String) -> Result { + let v = Zeroizing::new(s.into_bytes()); + let mut arr: Arr = [0; LEN]; + if v.len() != LEN { + let actual = v.len(); + return Err(InvalidKeyLength { actual }); + } + + arr.copy_from_slice(&v); + Ok(Key { mem: arr }) + } +} + +impl std::fmt::Display for Key { + /// Write out the key. Make sure to zeroize whatever you write it to if at + /// all possible. + /// + /// Prefer [`Self::read`] if you want a return value that will automatically + /// zeroize the key on drop. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Unwrap can never panic because a Key can only be created from String + // whic is guaranteed to be valid UTF-8. + let key_str = std::str::from_utf8(self.read()).unwrap(); + write!(f, "{}", key_str) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Note: This is a real key but it's been disabled. As is warned in the + // docs above, do not use a string literal for a real key. There is no + // TryFrom<&'static str> for Key for this reason. + const API_KEY: &str = "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"; + + #[test] + fn test_key() { + let key = Key::try_from(API_KEY.to_string()).unwrap(); + let key_str = key.to_string(); + assert_eq!(key_str, API_KEY); + } + + #[test] + fn test_invalid_key_length() { + let key = "test_key".to_string(); + let err = Key::try_from(key).unwrap_err(); + assert_eq!(err.to_string(), "Invalid key length: 8 (expected 108)"); + } +} diff --git a/src/lib.rs b/src/lib.rs index 5b6a927..94d2854 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,6 +37,10 @@ pub use response::Response; /// Markdown utilities for parsing and rendering. pub mod markdown; +#[cfg(feature = "html")] +/// Converts prompts and messages to HTML. +pub mod html; + #[cfg(not(feature = "langsan"))] pub(crate) type CowStr<'a> = std::borrow::Cow<'a, str>; #[cfg(feature = "langsan")] @@ -54,6 +58,7 @@ pub mod exports { pub use langsan; #[cfg(feature = "log")] pub use log; + #[cfg(feature = "memsecurity")] pub use memsecurity; #[cfg(feature = "markdown")] pub use pulldown_cmark; diff --git a/src/markdown.rs b/src/markdown.rs index bf503cf..dd10173 100644 --- a/src/markdown.rs +++ b/src/markdown.rs @@ -1,5 +1,6 @@ use std::ops::Deref; +use pulldown_cmark::HeadingLevel; use serde::{Deserialize, Serialize}; /// Default [`Options`] @@ -8,6 +9,8 @@ pub const DEFAULT_OPTIONS: Options = Options { tool_use: false, tool_results: false, system: false, + attrs: false, + heading_level: None, }; /// Verbose [`Options`] @@ -16,14 +19,10 @@ pub const VERBOSE_OPTIONS: Options = Options { tool_use: true, tool_results: true, system: true, + attrs: true, + heading_level: None, }; -/// A static reference to the default [`Options`]. -pub static DEFAULT_OPTIONS_REF: &'static Options = &DEFAULT_OPTIONS; - -/// A static reference to the verbose [`Options`]. -pub static VERBOSE_OPTIONS_REF: &'static Options = &VERBOSE_OPTIONS; - mod serde_inner { use super::*; @@ -49,21 +48,35 @@ mod serde_inner { } /// Options for parsing, generating, and rendering [`Markdown`]. -#[derive(Serialize, Deserialize)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[derive(Clone, Copy, Serialize, Deserialize)] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] +#[serde(default)] pub struct Options { /// Inner [`pulldown_cmark::Options`]. #[serde(with = "serde_inner")] pub inner: pulldown_cmark::Options, /// Whether to include the system prompt - #[serde(default)] pub system: bool, /// Whether to include tool uses. - #[serde(default)] pub tool_use: bool, /// Whether to include tool results. - #[serde(default)] pub tool_results: bool, + /// Whether to include attributes. Useful when converting to HTML. + /// + /// This adds: + /// - `role` attribute to the [`Prompt`] and [`Message`]s. Possible values + /// are: + /// - `system` - for the system prompt + /// - `assistant` - for generated messages + /// - `tool` - for tool results + /// - `user` - for user messages + /// - `error` - for errors + /// + /// [`Prompt`]: crate::prompt::Prompt + /// [`Message`]: crate::prompt::Message + pub attrs: bool, + /// Heading level to begin at (optional) + pub heading_level: Option, } impl Options { @@ -112,21 +125,21 @@ impl From for Options { /// /// [`Display`]: std::fmt::Display #[derive(derive_more::Display)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[display("{text}")] pub struct Markdown { text: String, } -impl Into for Markdown { - fn into(self) -> String { - self.text +impl From for String { + fn from(markdown: Markdown) -> Self { + markdown.text } } impl AsRef for Markdown { fn as_ref(&self) -> &str { - self.deref().as_ref() + self.deref() } } @@ -159,14 +172,21 @@ where } } -#[cfg(any(test, feature = "partial_eq"))] +#[cfg(any(test, feature = "partial-eq"))] impl PartialEq for Markdown { fn eq(&self, other: &str) -> bool { self.text == other } } -/// A trait for types that can be converted to [`Markdown`]. +/// A trait for types that can be converted to [`Markdown`] +/// +/// # Note +/// +/// - Any of these methods returning an iterator of [`pulldown_cmark::Event`]s +/// can be used to render to html using [`pulldown_cmark::html::push_html`] +/// and other similar functions. +/// - Implementers should guarantee tags are properly closed and nested. pub trait ToMarkdown { /// Render the type to a [`Markdown`] string with [`DEFAULT_OPTIONS`]. fn markdown(&self) -> Markdown { @@ -174,13 +194,13 @@ pub trait ToMarkdown { } /// Render the type to a [`Markdown`] string with custom [`Options`]. - fn markdown_custom(&self, options: &Options) -> Markdown { + fn markdown_custom(&self, options: Options) -> Markdown { self.markdown_events_custom(options).into() } /// Render the type to a [`Markdown`] string with maximum verbosity. fn markdown_verbose(&self) -> Markdown { - self.markdown_custom(VERBOSE_OPTIONS_REF) + self.markdown_custom(VERBOSE_OPTIONS) } /// Render the markdown to a type implementing [`std::fmt::Write`] with @@ -189,7 +209,7 @@ pub trait ToMarkdown { &self, writer: &mut dyn std::fmt::Write, ) -> std::fmt::Result { - self.write_markdown_custom(writer, DEFAULT_OPTIONS_REF) + self.write_markdown_custom(writer, DEFAULT_OPTIONS) } /// Render the markdown to a type implementing [`std::fmt::Write`] with @@ -197,7 +217,7 @@ pub trait ToMarkdown { fn write_markdown_custom( &self, writer: &mut dyn std::fmt::Write, - options: &Options, + options: Options, ) -> std::fmt::Result { use pulldown_cmark_to_cmark::cmark; @@ -211,14 +231,14 @@ pub trait ToMarkdown { fn markdown_events<'a>( &'a self, ) -> Box> + 'a> { - self.markdown_events_custom(DEFAULT_OPTIONS_REF) + self.markdown_events_custom(DEFAULT_OPTIONS) } /// Return an iterator of [`pulldown_cmark::Event`]s with custom /// [`Options`]. fn markdown_events_custom<'a>( &'a self, - options: &'a Options, + options: Options, ) -> Box> + 'a>; } @@ -287,4 +307,10 @@ mod tests { "### User\n\nHello, **world**!" ); } + + #[test] + fn test_options_with_system() { + let options = Options::default().with_system(); + assert!(options.system); + } } diff --git a/src/model.rs b/src/model.rs index 8d8965a..e43cc7b 100644 --- a/src/model.rs +++ b/src/model.rs @@ -32,16 +32,22 @@ pub enum Model { #[serde(rename = "claude-3-opus-20240229")] Opus30_20240229, /// Sonnet 3.0 - #[cfg(not(feature = "prompt-caching"))] #[serde(rename = "claude-3-sonnet-20240229")] Sonnet30, + /// Haiku 3.5 (latest) + #[serde(rename = "claude-3-5-haiku-latest")] + Haiku35, + /// Haiku 3.5 2024-10-22 + #[serde(rename = "claude-3-5-haiku-20241022")] + Haiku35_20241022, /// Haiku 3.0 (latest) This is the default model. + // Note: The `latest` tag is not yet supported by the API for Haiku 3.0, so + // in the future this might point to a separate model. We can't use the same + // serde tag for both, so there's only one option here for now. #[default] - // The `latest` alias is not enabled yet, but (very likely) will be in the - // future. If not we will manually update this. - #[serde(rename = "claude-3-haiku-20240307")] + #[serde( + rename = "claude-3-haiku-20240307", + alias = "claude-3-haiku-latest" + )] Haiku30, - /// Haiku 3.0 2024-03-07 - #[serde(rename = "claude-3-haiku-20240307")] - Haiku30_20240307, } diff --git a/src/prompt.rs b/src/prompt.rs index 0a49a14..4e62afe 100644 --- a/src/prompt.rs +++ b/src/prompt.rs @@ -16,7 +16,7 @@ pub use message::Message; /// /// [Anthropic Messages API]: #[derive(Serialize, Deserialize)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[serde(default)] pub struct Prompt<'a> { /// [`Model`] to use for inference. @@ -505,7 +505,7 @@ impl<'a> Prompt<'a> { // If there are no messages or system prompt, add a cache breakpoint to // the tools if they exist. if let Some(tool) = - self.tools.as_mut().map(|tools| tools.last_mut()).flatten() + self.tools.as_mut().and_then(|tools| tools.last_mut()) { tool.cache(); return self; @@ -526,9 +526,9 @@ impl crate::markdown::ToMarkdown for Prompt<'_> { /// [`Role`]: message::Role fn markdown_events_custom<'a>( &'a self, - options: &'a crate::markdown::Options, + options: crate::markdown::Options, ) -> Box> + 'a> { - use pulldown_cmark::{Event, HeadingLevel, Tag, TagEnd}; + use pulldown_cmark::{Event, HeadingLevel::H3, Tag, TagEnd}; // TODO: Add the title if there is metadata for it. Also add a metadata // option to Options to include arbitrary metadata. In my use case I am @@ -541,15 +541,21 @@ impl crate::markdown::ToMarkdown for Prompt<'_> { .map(|s| s.markdown_events_custom(options)) { if options.system { + let heading_level = options.heading_level.unwrap_or(H3); + let header = [ Event::Start(Tag::Heading { - level: HeadingLevel::H3, + level: heading_level, id: None, classes: vec![], - attrs: vec![], + attrs: if options.attrs { + vec![("role".into(), Some("system".into()))] + } else { + vec![] + }, }), Event::Text("System".into()), - Event::End(TagEnd::Heading(HeadingLevel::H3)), + Event::End(TagEnd::Heading(heading_level)), ]; Box::new(header.into_iter().chain(system)) @@ -962,6 +968,7 @@ mod tests { input: json!({ "host": "example.com" }), + #[cfg(feature = "prompt-caching")] cache_control: None, } .into(), @@ -979,15 +986,13 @@ mod tests { }, ]); - let opts = crate::markdown::Options::verbose(); - - let markdown: Markdown = request.markdown_custom(&opts); + let markdown: Markdown = request.markdown_verbose(); // OpenAI format. Anthropic doesn't have a "system" or "tool" role but // we generate markdown like this because it's easier to read. The user // does not submit a tool result, so it's confusing if the header is // "User". - let expected = "### System\n\nYou are a very succinct assistant.\n\n### User\n\nHello\n\n### Assistant\n\nHi\n\n### User\n\nCall a tool.\n\n### Assistant\n\n````json\n{\"type\":\"tool_use\",\"id\":\"abc123\",\"name\":\"ping\",\"input\":{\"host\":\"example.com\"}}\n````\n\n### Tool\n\n````json\n{\"type\":\"tool_result\",\"tool_use_id\":\"abc123\",\"content\":[{\"type\":\"text\",\"text\":\"Pinging example.com.\"}],\"is_error\":false}\n````\n\n### Assistant\n\nDone."; + let expected = "### System { role=system }\n\nYou are a very succinct assistant.\n\n### User { role=user }\n\nHello\n\n### Assistant { role=assistant }\n\nHi\n\n### User { role=user }\n\nCall a tool.\n\n### Assistant { role=assistant }\n\n````json\n{\"type\":\"tool_use\",\"id\":\"abc123\",\"name\":\"ping\",\"input\":{\"host\":\"example.com\"}}\n````\n\n### Tool { role=tool }\n\n````json\n{\"type\":\"tool_result\",\"tool_use_id\":\"abc123\",\"content\":[{\"type\":\"text\",\"text\":\"Pinging example.com.\"}],\"is_error\":false}\n````\n\n### Assistant { role=assistant }\n\nDone."; assert_eq!(markdown.as_ref(), expected); } diff --git a/src/prompt/message.rs b/src/prompt/message.rs index 6a34bcd..638b171 100644 --- a/src/prompt/message.rs +++ b/src/prompt/message.rs @@ -16,7 +16,7 @@ use crate::{ /// Role of the [`Message`] author. #[derive(Clone, Copy, Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] pub enum Role { /// From the user. User, @@ -43,14 +43,13 @@ impl std::fmt::Display for Role { /// A message in a [`Request`]. See [`response::Message`] for the version with /// additional metadata. /// -/// A message is [`Display`]ed as markdown with a [heading] indicating the +/// A message is [`Display`]ed as markdown with a heading indicating the /// [`Role`] of the author. [`Image`]s are supported and will be rendered as /// markdown images with embedded base64 data. /// /// [`Display`]: std::fmt::Display /// [`Request`]: crate::prompt /// [`response::Message`]: crate::response::Message -/// [heading]: Message::HEADING #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] #[cfg_attr( @@ -58,7 +57,7 @@ impl std::fmt::Display for Role { derive(derive_more::Display), display("{}{}{}{}", Self::HEADING, role, Content::SEP, content) )] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] pub struct Message<'a> { /// Who is the message from. pub role: Role, @@ -75,24 +74,17 @@ impl Message<'_> { /// [`Display`]: std::fmt::Display #[cfg(not(feature = "markdown"))] pub const HEADING: &'static str = "### "; - /// Heading for the message when rendered as markdown using markdown methods - /// as well as [`Display`]. - /// - /// [`Display`]: std::fmt::Display - #[cfg(feature = "markdown")] - pub const HEADING: pulldown_cmark::Tag<'static> = - pulldown_cmark::Tag::Heading { - level: pulldown_cmark::HeadingLevel::H3, - id: None, - classes: vec![], - attrs: vec![], - }; /// Returns the number of [`Content`] [`Block`]s in the message. pub fn len(&self) -> usize { self.content.len() } + /// Returns true if self has no parts. + pub fn is_empty(&self) -> bool { + self.content.is_empty() + } + /// Returns Some([`tool::Use`]) if the final [`Content`] [`Block`] is a /// [`Block::ToolUse`]. pub fn tool_use(&self) -> Option<&crate::tool::Use> { @@ -101,6 +93,8 @@ impl Message<'_> { /// Convert to a `'static` lifetime by taking ownership of the [`Cow`] /// fields. + /// + /// [`Cow`]: std::borrow::Cow pub fn into_static(self) -> Message<'static> { Message { role: self.role, @@ -154,9 +148,9 @@ impl crate::markdown::ToMarkdown for Message<'_> { /// [`Options`]: crate::markdown::Options fn markdown_events_custom<'a>( &'a self, - options: &'a crate::markdown::Options, + options: crate::markdown::Options, ) -> Box> + 'a> { - use pulldown_cmark::Event; + use pulldown_cmark::{Event, HeadingLevel::H3, Tag}; let content = self.content.markdown_events_custom(options); let role = match self.content.last() { @@ -182,10 +176,21 @@ impl crate::markdown::ToMarkdown for Message<'_> { } _ => self.role.as_str(), }; + let heading_tag = Tag::Heading { + level: options.heading_level.unwrap_or(H3), + id: None, + classes: vec![], + attrs: if options.attrs { + vec![("role".into(), Some(role.to_lowercase().into()))] + } else { + vec![] + }, + }; + let heading_end = heading_tag.to_end(); let heading = [ - Event::Start(Self::HEADING), + Event::Start(heading_tag), Event::Text(role.into()), - Event::End(Self::HEADING.to_end()), + Event::End(heading_end), ]; Box::new(heading.into_iter().chain(content)) @@ -205,7 +210,7 @@ impl std::fmt::Display for Message<'_> { #[derive(Clone, Debug, Serialize, Deserialize, derive_more::IsVariant)] #[serde(rename_all = "snake_case")] #[serde(untagged)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] pub enum Content<'a> { /// Single part text-only content. SinglePart(crate::CowStr<'a>), @@ -229,15 +234,16 @@ impl<'a> Content<'a> { Self::SinglePart(text.into()) } - /// Returns the number of [`Block`]s in `self`. + /// Returns the number of bytes in self. Does not include tool use or other + /// metadata. Does include the base64 encoded image data length. pub fn len(&self) -> usize { match self { - Self::SinglePart(_) => 1, - Self::MultiPart(parts) => parts.len(), + Self::SinglePart(s) => s.as_bytes().len(), + Self::MultiPart(parts) => parts.iter().map(Block::len).sum(), } } - /// Returns true if the content is empty. + /// Returns true if `self` is empty. pub fn is_empty(&self) -> bool { self.len() == 0 } @@ -320,6 +326,8 @@ impl<'a> Content<'a> { /// Convert to a `'static` lifetime by taking ownership of the [`Cow`] /// fields. + /// + /// [`Cow`]: std::borrow::Cow pub fn into_static(self) -> Content<'static> { match self { Self::SinglePart(text) => { @@ -343,8 +351,6 @@ impl<'a> Content<'a> { /// Push a [`Delta`] into the [`Content`]. The types must be compatible or /// this will return a [`ContentMismatch`] error. pub fn push_delta(&mut self, delta: Delta<'a>) -> Result<(), DeltaError> { - let delta = delta.into(); - match self { Self::SinglePart(_) => { let mut old = Content::MultiPart(vec![]); @@ -373,7 +379,7 @@ impl crate::markdown::ToMarkdown for Content<'_> { #[cfg(feature = "markdown")] fn markdown_events_custom<'a>( &'a self, - options: &'a crate::markdown::Options, + options: crate::markdown::Options, ) -> Box> + 'a> { use pulldown_cmark::Event; @@ -473,7 +479,7 @@ where #[cfg_attr(not(feature = "markdown"), derive(derive_more::Display))] #[serde(rename_all = "snake_case")] #[serde(tag = "type")] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] pub enum Block<'a> { /// Text content. #[serde(alias = "text_delta")] @@ -487,6 +493,7 @@ pub enum Block<'a> { cache_control: Option, }, /// Image content. + #[cfg_attr(not(feature = "markdown"), display("{}", image))] Image { #[serde(rename = "source")] /// An base64 encoded image. @@ -592,8 +599,21 @@ impl<'a> Block<'a> { }, Delta::Json { partial_json }, ) => { - *input = serde_json::from_str(&partial_json) - .map_err(|e| e.to_string())?; + use serde_json::Value::Object; + // Parse the partial json as an object and merge it into the + // input. + let partial_json: serde_json::Value = + serde_json::from_str(&partial_json).map_err(|e| { + DeltaError::Parse { + error: format!( + "Could not merge partial json `{}` into `{}` because {}", + partial_json, input, e + ), + } + })?; + if let (Object(new), Object(old)) = (partial_json, input) { + old.extend(new); + } } (this, acc) => { let variant_name = match this { @@ -664,6 +684,8 @@ impl<'a> Block<'a> { /// Convert to a `'static` lifetime by taking ownership of the [`Cow`] /// fields. + /// + /// [`Cow`]: std::borrow::Cow pub fn into_static(self) -> Block<'static> { match self { Self::Text { @@ -695,6 +717,18 @@ impl<'a> Block<'a> { }, } } + + /// Returns the number of bytes in the block. Does not include tool use or + /// other metadata. Does include the base64 encoded image data length. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + match self { + Self::Text { text, .. } => text.as_bytes().len(), + Self::Image { image, .. } => image.len(), + Self::ToolUse { .. } => 0, + Self::ToolResult { .. } => 0, + } + } } #[cfg(feature = "markdown")] @@ -706,7 +740,7 @@ impl crate::markdown::ToMarkdown for Block<'_> { #[cfg(feature = "markdown")] fn markdown_events_custom<'a>( &'a self, - options: &crate::markdown::Options, + options: crate::markdown::Options, ) -> Box> + 'a> { use pulldown_cmark::{CodeBlockKind, Event, Tag, TagEnd}; @@ -816,6 +850,7 @@ impl<'a> From> for Block<'a> { #[cfg(feature = "png")] impl From for Block<'_> { fn from(image: image::RgbaImage) -> Self { + #[allow(unused_variables)] // for the `e` variable Image::encode(MediaType::Png, image) // Unwrap can never panic unless the PNG encoding fails, which // should really never happen, but no matter what we don't panic. @@ -838,7 +873,7 @@ impl From for Block<'_> { /// Cache control for prompt caching. #[cfg(feature = "prompt-caching")] #[derive(Clone, Default, Debug, Serialize, Deserialize)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[serde(tag = "type")] #[serde(rename_all = "snake_case")] pub enum CacheControl { @@ -851,7 +886,7 @@ pub enum CacheControl { /// /// [`MultiPart`]: Content::MultiPart #[derive(Clone, Debug, Serialize, Deserialize, derive_more::Display)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[serde(rename_all = "snake_case")] #[serde(tag = "type")] pub enum Image<'a> { @@ -923,6 +958,8 @@ impl Image<'_> { /// Convert to a `'static` lifetime by taking ownership of the [`Cow`] /// fields. + /// + /// [`Cow`]: std::borrow::Cow pub fn into_static(self) -> Image<'static> { match self { Self::Base64 { media_type, data } => Image::Base64 { @@ -934,6 +971,15 @@ impl Image<'_> { }, } } + + /// Returns the number of bytes in the image data (base64 encoded). Call + /// [`decode`] to get the actual image size. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + match self { + Self::Base64 { data, .. } => data.as_bytes().len(), + } + } } /// Errors that can occur when decoding an [`Image`]. @@ -961,20 +1007,16 @@ impl TryInto for Image<'_> { /// Encoding format for [`Image`]s. #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[serde(rename_all = "snake_case")] #[allow(missing_docs)] pub enum MediaType { - #[cfg(feature = "jpeg")] #[serde(rename = "image/jpeg")] Jpeg, - #[cfg(feature = "png")] #[serde(rename = "image/png")] Png, - #[cfg(feature = "gif")] #[serde(rename = "image/gif")] Gif, - #[cfg(feature = "webp")] #[serde(rename = "image/webp")] Webp, } @@ -995,13 +1037,9 @@ impl From for image::ImageFormat { /// A [`MediaType`] can always be converted into an [`image::ImageFormat`]. fn from(value: MediaType) -> image::ImageFormat { match value { - #[cfg(feature = "jpeg")] MediaType::Jpeg => image::ImageFormat::Jpeg, - #[cfg(feature = "png")] MediaType::Png => image::ImageFormat::Png, - #[cfg(feature = "gif")] MediaType::Gif => image::ImageFormat::Gif, - #[cfg(feature = "webp")] MediaType::Webp => image::ImageFormat::WebP, } } @@ -1024,13 +1062,9 @@ impl TryFrom for MediaType { /// [`UnsupportedImageFormat`] error. fn try_from(value: image::ImageFormat) -> Result { match value { - #[cfg(feature = "jpeg")] image::ImageFormat::Jpeg => Ok(Self::Jpeg), - #[cfg(feature = "png")] image::ImageFormat::Png => Ok(Self::Png), - #[cfg(feature = "gif")] image::ImageFormat::Gif => Ok(Self::Gif), - #[cfg(feature = "webp")] image::ImageFormat::WebP => Ok(Self::Webp), _ => Err(UnsupportedImageFormat(value)), } @@ -1041,6 +1075,9 @@ impl TryFrom for MediaType { mod tests { use std::vec; + #[cfg(feature = "markdown")] + use crate::markdown::ToMarkdown; + use super::*; pub const CONTENT_SINGLE: &str = "\"Hello, world!\""; @@ -1093,6 +1130,88 @@ mod tests { ); } + #[test] + fn test_message_is_empty() { + let message: Message = (Role::User, "Hello, world!").into(); + assert!(!message.is_empty()); + let message: Message = Message { + role: Role::User, + content: Content::MultiPart(vec![]), + }; + assert!(message.is_empty()); + } + + #[test] + fn test_message_tool_use() { + let tool_use: Message = tool::Use { + id: "tool_123".into(), + name: "tool".into(), + input: serde_json::json!({}), + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(); + + assert!(tool_use.tool_use().is_some()); + } + + #[test] + #[cfg(feature = "markdown")] + // mostly for coverage + fn test_into_static() { + let content: Content = "Hello, world!".into(); + let content: Content<'static> = content.into_static(); + assert_eq!(content.to_string(), "Hello, world!"); + + let content = Content::SinglePart("Hello, world!".into()); + let content: Content<'static> = content.into_static(); + assert_eq!(content.to_string(), "Hello, world!"); + + let block: Block = "Hello, world!".into(); + let block: Block<'static> = block.into_static(); + assert_eq!(block.to_string(), "Hello, world!"); + + let image: Image = Image::from_parts(MediaType::Png, String::new()); + let image: Image<'static> = image.into_static(); + assert_eq!(image.to_string(), "![Image](data:image/png;base64,)"); + + let tool_use: Block = tool::Use { + id: "tool_123".into(), + name: "tool".into(), + input: serde_json::json!({}), + #[cfg(feature = "prompt-caching")] + cache_control: None, + } + .into(); + let tool_use: Block<'static> = tool_use.into_static(); + assert_eq!( + tool_use.markdown_verbose().as_ref(), + "\n````json\n{\"type\":\"tool_use\",\"id\":\"tool_123\",\"name\":\"tool\",\"input\":{}}\n````" + ); + + let message: Message = (Role::User, "Hello, world!").into(); + let _: Message<'static> = message.into_static(); + } + + #[test] + fn test_push_delta() { + let mut content = Content::SinglePart("Hello, world!".into()); + content + .push_delta(Delta::Text { + text: " How are you?".into(), + }) + .unwrap(); + + assert_eq!(content.to_string(), "Hello, world! How are you?"); + assert!(content.is_multi_part()); + + // an incompatible delta + let err = content.push_delta(Delta::Json { + partial_json: "blabla".into(), + }); + assert!(err.is_err()); + } + #[test] #[cfg(feature = "markdown")] fn test_merge_deltas() { @@ -1137,13 +1256,23 @@ mod tests { // by default tool use is hidden let opts = crate::markdown::Options::default().with_tool_use(); - let markdown = block.markdown_custom(&opts); + let markdown = block.markdown_custom(opts); assert_eq!( markdown.as_ref(), "\n````json\n{\"type\":\"tool_use\",\"id\":\"tool_123\",\"name\":\"tool\",\"input\":{\"key\":\"value\"}}\n````" ); + // test junk json + let deltas = [Delta::Json { + partial_json: "blabla".into(), + }]; + let err = block.merge_deltas(deltas).unwrap_err(); + assert_eq!( + err.to_string(), + "Cannot apply delta because deserialization failed because: Could not merge partial json `blabla` into `{\"key\":\"value\"}` because expected value at line 1 column 1" + ); + // content mismatch let deltas = [Delta::Json { partial_json: "blabla".into(), @@ -1168,11 +1297,11 @@ mod tests { content: Content::SinglePart("Hello, world!".into()), }; - assert_eq!(message.len(), 1); + assert_eq!(message.len(), 13); message.content.push("How are you?"); - assert_eq!(message.len(), 2); + assert_eq!(message.len(), 25); } #[test] @@ -1237,6 +1366,12 @@ mod tests { assert_eq!(content.to_string(), "Hello, world!"); } + #[test] + fn test_content_from_slice_of_str() { + let content: Content = ["Hello, world!"].into(); + assert_eq!(content.to_string(), "Hello, world!"); + } + #[test] fn test_content_from_block() { let content: Content = Block::text("Hello, world!").into(); @@ -1244,14 +1379,35 @@ mod tests { } #[test] + #[cfg(feature = "markdown")] fn test_merge_deltas_error() { - let mut block: Block = "Hello, world!".into(); + let mut text_block: Block = "Hello, world!".into(); - let deltas = [Delta::Json { - partial_json: "blabla".into(), + let json_deltas = [Delta::Json { + partial_json: "{\"k\": \"v\"}".into(), }]; - let err = block.merge_deltas(deltas).unwrap_err(); + let err = text_block.merge_deltas(json_deltas).unwrap_err(); + + let mut json_block = Block::ToolUse { + call: tool::Use { + id: "tool_123".into(), + name: "tool".into(), + input: serde_json::json!({}), + #[cfg(feature = "prompt-caching")] + cache_control: None, + }, + }; + + let json_deltas = [Delta::Json { + partial_json: "{\"k\": \"v\"}".into(), + }]; + + json_block.merge_deltas(json_deltas).unwrap(); + assert_eq!( + json_block.markdown_verbose().as_ref(), + "\n````json\n{\"type\":\"tool_use\",\"id\":\"tool_123\",\"name\":\"tool\",\"input\":{\"k\":\"v\"}}\n````" + ); assert!(matches!(err, DeltaError::ContentMismatch { .. })); } @@ -1272,7 +1428,7 @@ mod tests { .with_tool_results(); assert_eq!( - message.markdown_custom(&opts).to_string(), + message.markdown_custom(opts).to_string(), "### User\n\nHello, world!" ); @@ -1286,7 +1442,7 @@ mod tests { }; assert_eq!( - message.markdown_custom(&opts).to_string(), + message.markdown_custom(opts).to_string(), "### Assistant\n\nHello, world!\n\nHow are you?" ); @@ -1301,7 +1457,7 @@ mod tests { .into(); assert_eq!( - message.markdown_custom(&opts).to_string(), + message.markdown_custom(opts).to_string(), "### Tool\n\n````json\n{\"type\":\"tool_result\",\"tool_use_id\":\"tool_123\",\"content\":\"Hello, world!\",\"is_error\":false}\n````" ); @@ -1316,7 +1472,7 @@ mod tests { .into(); assert_eq!( - message.markdown_custom(&opts).to_string(), + message.markdown_custom(opts).to_string(), "### Error\n\n````json\n{\"type\":\"tool_result\",\"tool_use_id\":\"tool_123\",\"content\":\"Hello, world!\",\"is_error\":true}\n````" ); } diff --git a/src/response/message.rs b/src/response/message.rs index 718ae21..8d7cc45 100644 --- a/src/response/message.rs +++ b/src/response/message.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; /// A [`prompt::message`] with additional response metadata. #[derive(Debug, Serialize, Deserialize, derive_more::Display)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[display("{}", message)] pub struct Message<'a> { /// Unique `id` for the message. @@ -70,7 +70,7 @@ impl Message<'_> { /// Reason the model stopped generating tokens. #[derive(Debug, Serialize, Deserialize)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[serde(rename_all = "snake_case")] pub enum StopReason { /// The model reached a natural stopping point. @@ -86,7 +86,7 @@ pub enum StopReason { /// Usage statistics from the API. This is used in multiple contexts, not just /// for messages. #[derive(Debug, Serialize, Deserialize, Default)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] pub struct Usage { /// Number of input tokens used. pub input_tokens: u64, @@ -104,7 +104,7 @@ pub struct Usage { impl crate::markdown::ToMarkdown for Message<'_> { fn markdown_events_custom<'a>( &'a self, - options: &'a crate::markdown::Options, + options: crate::markdown::Options, ) -> Box> + 'a> { self.message.markdown_events_custom(options) } @@ -139,7 +139,7 @@ mod tests { #[test] fn deserialize_response_message() { let message: Message = serde_json::from_str(RESPONSE_JSON).unwrap(); - assert_eq!(message.message.content.len(), 1); + assert_eq!(message.message.content.len(), 22); assert_eq!(message.id, "msg_013Zva2CMHLNnXjNJJKqJ2EF"); assert_eq!(message.model, crate::Model::Sonnet35_20240620); assert!(matches!(message.stop_reason, Some(StopReason::EndTurn))); @@ -186,4 +186,53 @@ mod tests { }); assert!(message.tool_use().is_some()); } + + #[test] + fn test_into_static() { + // Refers to json: + let message: Message = serde_json::from_str(RESPONSE_JSON).unwrap(); + // Owns the `Cow` fields: + let static_message = message.into_static(); + + assert_eq!(static_message.id, "msg_013Zva2CMHLNnXjNJJKqJ2EF"); + assert_eq!(static_message.model, crate::Model::Sonnet35_20240620); + assert!(matches!( + static_message.stop_reason, + Some(StopReason::EndTurn) + )); + assert_eq!(static_message.stop_sequence, None); + assert_eq!(static_message.usage.input_tokens, 2095); + assert_eq!(static_message.usage.output_tokens, 503); + } + + #[test] + #[cfg(feature = "markdown")] + fn test_markdown() { + use crate::markdown::ToMarkdown; + + let message = Message { + id: "id".into(), + message: prompt::Message { + role: prompt::message::Role::User, + content: prompt::message::Content::SinglePart( + "Hello, **world**!".into(), + ), + }, + model: crate::Model::Sonnet35, + stop_reason: None, + stop_sequence: None, + usage: Usage { + input_tokens: 1, + #[cfg(feature = "prompt-caching")] + cache_creation_input_tokens: Some(2), + #[cfg(feature = "prompt-caching")] + cache_read_input_tokens: Some(3), + output_tokens: 4, + }, + }; + + let expected = "### User\n\nHello, **world**!"; + let markdown = message.markdown(); + assert_eq!(markdown.as_ref(), expected); + } } diff --git a/src/tool.rs b/src/tool.rs index 5d99451..6878cd7 100644 --- a/src/tool.rs +++ b/src/tool.rs @@ -12,7 +12,8 @@ use serde::{Deserialize, Serialize}; /// [`prompt::message`]: crate::prompt::message #[derive(Serialize, Deserialize)] #[serde(rename_all = "snake_case", tag = "type")] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] +#[cfg_attr(test, derive(Debug))] pub enum Choice { /// Model chooses which tool to use, or no tool at all. Auto, @@ -28,8 +29,9 @@ pub enum Choice { /// A tool a model can use while completing a [`prompt::Message`]. /// /// [`prompt::Message`]: crate::prompt::Message -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[derive(Serialize, Deserialize)] +#[serde(try_from = "ToolBuilder<'a>")] pub struct Tool<'a> { /// Name of the tool. pub name: Cow<'a, str>, @@ -50,12 +52,62 @@ pub struct Tool<'a> { pub cache_control: Option, } +impl<'a> TryFrom> for Tool<'a> { + type Error = ToolBuildError; + + fn try_from( + builder: ToolBuilder<'a>, + ) -> std::result::Result { + builder.build() + } +} + /// A builder for creating a [`Tool`] with some basic validation. See /// [`Tool::builder`] to create a new builder. pub struct ToolBuilder<'a> { tool: Tool<'a>, } +// ToolBuilder must implement Deserialize but we can't derive it because it +// would recursively require Tool to implement Deserialize, so we have to +// implement it manually. This is a bit ugly, but it works and ensures that +// a Tool is always valid when deserialized. +impl<'de> Deserialize<'de> for ToolBuilder<'_> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct Foreign { + name: Cow<'static, str>, + description: Cow<'static, str>, + input_schema: serde_json::Value, + #[cfg(feature = "prompt-caching")] + cache_control: Option, + } + + let foreign = Foreign::deserialize(deserializer)?; + + let Foreign { + name, + description, + input_schema, + #[cfg(feature = "prompt-caching")] + cache_control, + } = foreign; + + Ok(ToolBuilder { + tool: Tool { + name, + description, + input_schema, + #[cfg(feature = "prompt-caching")] + cache_control, + }, + }) + } +} + impl<'a> ToolBuilder<'a> { /// Set the description for the tool. pub fn description(mut self, description: impl Into>) -> Self { @@ -122,6 +174,10 @@ impl<'a> ToolBuilder<'a> { schema: &serde_json::Value, ) -> std::result::Result<(), Cow<'static, str>> { let obj = if let Some(obj) = schema.as_object() { + if obj.is_empty() { + return Err("Input `schema` is an empty object.".into()); + } + obj } else { return Err(format!( @@ -262,7 +318,7 @@ impl<'a> Tool<'a> { // Serialize. This is a bit of a hack but it works. pub fn from_serializable( value: T, - ) -> std::result::Result + ) -> std::result::Result, serde_json::Error> where T: Serialize, { @@ -271,13 +327,16 @@ impl<'a> Tool<'a> { } } -impl TryFrom for Tool<'_> { +impl TryFrom for Tool<'static> { type Error = serde_json::Error; fn try_from( value: serde_json::Value, ) -> std::result::Result { - serde_json::from_value(value) + let builder: ToolBuilder<'static> = serde_json::from_value(value)?; + builder + .build() + .map_err(|e| serde::de::Error::custom(e.to_string())) } } @@ -290,7 +349,7 @@ impl TryFrom for Tool<'_> { derive(derive_more::Display), display("\n````json\n{}\n````\n", serde_json::to_string_pretty(self).unwrap()) )] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Use<'a> { /// Unique Id for this tool call. @@ -341,7 +400,7 @@ impl TryFrom for Use<'_> { impl crate::markdown::ToMarkdown for Use<'_> { fn markdown_events_custom<'a>( &'a self, - options: &'a crate::markdown::Options, + options: crate::markdown::Options, ) -> Box> + 'a> { use pulldown_cmark::{CodeBlockKind, Event, Tag, TagEnd}; @@ -378,7 +437,7 @@ impl std::fmt::Display for Use<'_> { /// [`User`]: crate::prompt::message::Role::User /// [`Message`]: crate::prompt::message #[derive(Clone, Debug, Serialize, Deserialize)] -#[cfg_attr(any(feature = "partial_eq", test), derive(PartialEq))] +#[cfg_attr(any(feature = "partial-eq", test), derive(PartialEq))] // On the one hand this can clash with the `Result` type from the standard // library, but on the other hand it's what the API uses, and I'm trying to // be as faithful to the API as possible. @@ -465,4 +524,386 @@ mod tests { // crate. assert_eq!(use_.to_string(), ""); } + + #[test] + fn test_tool_schema_validation() { + let schema = serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string"], + }); + + assert!(ToolBuilder::is_valid_input_schema(&schema).is_ok()); + + let schema = serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": "letter", + }); + + assert!(ToolBuilder::is_valid_input_schema(&schema).is_err()); + } + + #[test] + fn test_build() { + let tool = Tool::builder("test_name") + .description("test_description") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string"], + })) + .build() + .unwrap(); + + assert_eq!(tool.name, "test_name"); + assert_eq!(tool.description, "test_description"); + assert_eq!( + tool.input_schema, + serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string"], + }) + ); + + // Test error cases + let tool = Tool::builder("test_name") + .description("test_description") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": "letter", + })) + .build(); + + assert!(matches!( + tool, + Err(ToolBuildError::InvalidInputSchema { .. }) + )); + + // input schema not an object + let tool = Tool::builder("test_name") + .description("test_description") + .schema(serde_json::Value::String("blah".into())) + .build(); + + assert!(matches!( + tool, + Err(ToolBuildError::InvalidInputSchema { .. }) + )); + + // Properties not an object + let tool = Tool::builder("test_name") + .description("test_description") + .schema(serde_json::json!({ + "type": "object", + "properties": "blah", + "required": ["letter", "string"], + })) + .build(); + + assert!(matches!( + tool, + Err(ToolBuildError::InvalidInputSchema { .. }) + )); + + // Schema does not have properties + let tool = Tool::builder("test_name") + .description("test_description") + .schema(serde_json::json!({ + "type": "object", + "required": ["letter", "string"], + })) + .build(); + + assert!(matches!( + tool, + Err(ToolBuildError::InvalidInputSchema { .. }) + )); + + // Schema does not have `required` keys (empty array allowed, but it + // must be present) + let tool = Tool::builder("test_name") + .description("test_description") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + })) + .build(); + + assert!(matches!( + tool, + Err(ToolBuildError::InvalidInputSchema { .. }) + )); + + // required keys not found in properties + let tool = Tool::builder("test_name") + .description("test_description") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string", "foo"], + })) + .build(); + + assert!(matches!( + tool, + Err(ToolBuildError::InvalidInputSchema { .. }) + )); + + // required keys not strings + let tool = Tool::builder("test_name") + .description("test_description") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": [1, 2], + })) + .build(); + + assert!(matches!( + tool, + Err(ToolBuildError::InvalidInputSchema { .. }) + )); + + // missing schema + let tool = Tool::builder("test_name") + .description("test_description") + .build(); + + assert!(matches!(tool, Err(ToolBuildError::EmptyInputSchema))); + + // with missing names and descriptions + let tool = Tool::builder("") + .description("foo") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string"], + })) + .build(); + + assert!(matches!(tool, Err(ToolBuildError::EmptyName))); + + let tool = Tool::builder("foo") + .description("") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string"], + })) + .build(); + + assert!(matches!(tool, Err(ToolBuildError::EmptyDescription))); + } + + #[test] + fn test_choice_serde() { + let choice = Choice::Auto; + let json = serde_json::to_string(&choice).unwrap(); + let choice2: Choice = serde_json::from_str(&json).unwrap(); + assert_eq!(choice, choice2); + + let choice = Choice::Any; + let json = serde_json::to_string(&choice).unwrap(); + let choice2: Choice = serde_json::from_str(&json).unwrap(); + assert_eq!(choice, choice2); + + let choice = Choice::Tool { + name: "test_name".into(), + }; + let json = serde_json::to_string(&choice).unwrap(); + let choice2: Choice = serde_json::from_str(&json).unwrap(); + assert_eq!(choice, choice2); + } + + #[test] + fn test_result_serde() { + let result = Result { + tool_use_id: "test_id".into(), + content: "test_content".into(), + is_error: false, + #[cfg(feature = "prompt-caching")] + cache_control: None, + }; + + let json = serde_json::to_string(&result).unwrap(); + let result2: Result = serde_json::from_str(&json).unwrap(); + assert_eq!(result, result2); + } + + #[test] + fn test_result_into_static() { + let result = Result { + tool_use_id: "test_id".into(), + content: "test_content".into(), + is_error: false, + #[cfg(feature = "prompt-caching")] + cache_control: None, + }; + + let result = result.into_static(); + + assert_eq!(result.tool_use_id, "test_id"); + assert_eq!(result.content.to_string(), "test_content"); + assert_eq!(result.is_error, false); + } + + #[test] + fn test_tool_from_serializable() { + let tool = Tool::from_serializable(serde_json::json!({ + "name": "test_name", + "description": "test_description", + "input_schema": { + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string"], + }, + })) + .unwrap(); + + assert_eq!(tool.name, "test_name"); + assert_eq!(tool.description, "test_description"); + assert_eq!( + tool.input_schema, + serde_json::json!({ + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + "required": ["letter", "string"], + }) + ); + + // Test invalid schema. Comprehensive testing of this is in the builder + // tests. This just makes sure that the error is propagated. + let tool = Tool::from_serializable(serde_json::json!({ + "name": "test_name", + "description": "test_description", + "input_schema": { + "type": "object", + "properties": { + "letter": { + "type": "string", + "description": "The letter to count", + }, + "string": { + "type": "string", + "description": "The string to count letters in", + }, + }, + // should be an array + "required": "letter", + }, + })); + + assert!(tool.is_err()); + } }