prep refactor for custom endpoint
jeremychone committed Dec 8, 2024
1 parent 175ed48 commit 011fb40
Showing 21 changed files with 352 additions and 227 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -11,7 +11,7 @@ repository = "https://github.com/jeremychone/rust-genai"

[lints.rust]
unsafe_code = "forbid"
# unused = { level = "allow", priority = -1 } # For exploratory dev.
unused = { level = "allow", priority = -1 } # For exploratory dev.
# missing_docs = "warn"

[dependencies]
2 changes: 1 addition & 1 deletion examples/c00-readme.rs
@@ -54,7 +54,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
continue;
}

let adapter_kind = client.resolve_model_iden(model)?.adapter_kind;
let adapter_kind = client.resolve_service_target(model)?.model.adapter_kind;

println!("\n===== MODEL: {model} ({adapter_kind}) =====");

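The same rename shows up in the chat-options example below. A minimal sketch of the new call shape, inferred from these example diffs (the `genai::Client` crate path and the error handling are assumptions, not part of this commit):

// Sketch only: mirrors the example diffs; `genai::Client` is the assumed crate path.
fn print_adapter(client: &genai::Client, model: &str) -> Result<(), Box<dyn std::error::Error>> {
    // Before: client.resolve_model_iden(model)?.adapter_kind
    // After:  the resolved service target carries the ModelIden under `.model`.
    let adapter_kind = client.resolve_service_target(model)?.model.adapter_kind;
    println!("{model} ({adapter_kind})");
    Ok(())
}
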
2 changes: 1 addition & 1 deletion examples/c04-chat-options.rs
@@ -36,7 +36,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("\n--- Question:\n{question}");
let chat_res = client.exec_chat_stream(MODEL, chat_req.clone(), Some(&options)).await?;

let adapter_kind = client.resolve_model_iden(MODEL)?.adapter_kind;
let adapter_kind = client.resolve_service_target(MODEL)?.model.adapter_kind;
println!("\n--- Answer: ({MODEL} - {adapter_kind})");
print_chat_stream(chat_res, None).await?;

12 changes: 6 additions & 6 deletions src/adapter/adapter_kind.rs
@@ -51,12 +51,12 @@ impl AdapterKind {
}

/// Utilities
impl AdapterKind {
/// Get the default key environment variable name for the adapter kind.
pub fn default_key_env_name(&self) -> Option<&'static str> {
AdapterDispatcher::default_key_env_name(*self)
}
}
// impl AdapterKind {
// /// Get the default key environment variable name for the adapter kind.
// pub fn default_key_env_name(&self) -> Option<&'static str> {
// AdapterDispatcher::default_key_env_name(*self)
// }
// }

/// From Model implementations
impl AdapterKind {
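With the kind-level default_key_env_name commented out above, the per-adapter default_auth / default_endpoint pair introduced in adapter_types.rs below takes over. A hedged sketch of what an equivalent kind-level helper would presumably look like now (it assumes AdapterDispatcher also implements the new trait methods, which is not shown in the visible hunks):

// Sketch only; assumes AdapterDispatcher gained `default_auth` alongside the new trait.
impl AdapterKind {
    /// Get the default AuthData for the adapter kind.
    pub fn default_auth(&self) -> AuthData {
        AdapterDispatcher::default_auth(*self)
    }
}
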
14 changes: 10 additions & 4 deletions src/adapter/adapter_types.rs
@@ -1,24 +1,30 @@
use crate::adapter::AdapterKind;
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
use serde_json::Value;

pub trait Adapter {
fn default_key_env_name(kind: AdapterKind) -> Option<&'static str>;
// #[deprecated(note = "use default_auth")]
// fn default_key_env_name(kind: AdapterKind) -> Option<&'static str>;

fn default_auth(kind: AdapterKind) -> AuthData;

fn default_endpoint(kind: AdapterKind) -> Endpoint;

// NOTE: Adapter is a crate trait, so it is acceptable to use async fn here.
async fn all_model_names(kind: AdapterKind) -> Result<Vec<String>>;

/// The base service URL for this AdapterKind for the given service type.
/// NOTE: For some services, the URL will be further updated in the to_web_request_data method.
fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String;
fn get_service_url(model_iden: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String;

/// To be implemented by Adapters.
fn to_web_request_data(
model_iden: ModelIden,
service_target: ServiceTarget,
config_set: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
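ServiceTarget itself is not defined in the visible hunks; the shape below is reconstructed from the `let ServiceTarget { endpoint, auth, model } = target;` destructuring used in the adapter implementations, so the exact field types and visibility are assumptions:

// Assumed shape, inferred from the destructuring in the adapters below.
pub struct ServiceTarget {
    pub endpoint: Endpoint, // base service URL, e.g. Endpoint::from_static("https://api.anthropic.com/v1/")
    pub auth: AuthData,     // e.g. AuthData::from_env("ANTHROPIC_API_KEY")
    pub model: ModelIden,   // carries model_name and adapter_kind
}
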
45 changes: 28 additions & 17 deletions src/adapter/adapters/anthropic/adapter_impl.rs
@@ -1,22 +1,21 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::anthropic::AnthropicStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
ToolCall,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource;
use serde_json::{json, Value};
use value_ext::JsonValueExt;

pub struct AnthropicAdapter;

const BASE_URL: &str = "https://api.anthropic.com/v1/";

// NOTE: For Anthropic, the max_tokens must be specified.
// To avoid surprises, the default value for genai is the maximum for a given model.
// The 3-5 models have an 8k max token limit, while the 3 models have a 4k limit.
@@ -32,49 +31,61 @@ const MODELS: &[&str] = &[
];

impl Adapter for AnthropicAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("ANTHROPIC_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.anthropic.com/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("ANTHROPIC_API_KEY")
}

/// Note: For now, it returns the common models (see above)
async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}

fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => format!("{BASE_URL}messages"),
ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}messages"),
}
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let model_name = model_iden.model_name.clone();
let ServiceTarget { endpoint, auth, model } = target;

let stream = matches!(service_type, ServiceType::ChatStream);
let url = Self::get_service_url(model_iden.clone(), service_type);
// -- api_key
let api_key = get_api_key(auth, &model)?;

// -- api_key (this Adapter requires it)
let api_key = get_api_key(model_iden.clone(), client_config)?;
// -- url
let url = Self::get_service_url(&model, service_type, endpoint);

// -- headers
let headers = vec![
// headers
("x-api-key".to_string(), api_key.to_string()),
("x-api-key".to_string(), api_key),
("anthropic-version".to_string(), ANTRHOPIC_VERSION.to_string()),
];

let model_name = model.model_name.clone();

// -- Parts
let AnthropicRequestParts {
system,
messages,
tools,
} = Self::into_anthropic_request_parts(model_iden.clone(), chat_req)?;
} = Self::into_anthropic_request_parts(model, chat_req)?;

// -- Build the basic payload

let stream = matches!(service_type, ServiceType::ChatStream);
let mut payload = json!({
"model": model_name.to_string(),
"messages": messages,
@@ -99,7 +110,7 @@ impl Adapter for AnthropicAdapter {
}

let max_tokens = options_set.max_tokens().unwrap_or_else(|| {
if model_iden.model_name.contains("3-5") {
if model_name.contains("3-5") {
MAX_TOKENS_8K
} else {
MAX_TOKENS_4K
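The Cohere and Gemini adapters below follow the same default_endpoint / default_auth pattern. Since the commit message is "prep refactor for custom endpoint", a hedged sketch of the kind of override this structure enables (the resolver hook that would produce such a target is not part of this diff, and the field mutation assumes the ServiceTarget shape sketched earlier; the proxy URL and env var name are placeholders):

// Sketch only: swaps in a custom endpoint and key on an already-resolved target.
fn with_custom_endpoint(mut target: ServiceTarget) -> ServiceTarget {
    target.endpoint = Endpoint::from_static("https://my-anthropic-proxy.example.com/v1/");
    target.auth = AuthData::from_env("MY_PROXY_API_KEY");
    target
}
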
40 changes: 25 additions & 15 deletions src/adapter/adapters/cohere/adapter_impl.rs
@@ -1,19 +1,19 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::cohere::CohereStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use serde_json::{json, Value};
use value_ext::JsonValueExt;

pub struct CohereAdapter;

const BASE_URL: &str = "https://api.cohere.com/v1/";
const MODELS: &[&str] = &[
"command-r-plus",
"command-r",
@@ -24,49 +24,59 @@ const MODELS: &[&str] = &[
];

impl Adapter for CohereAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("COHERE_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.cohere.com/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("COHERE_API_KEY")
}

/// Note: For now, it returns the common ones (see above)
async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}

fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => format!("{BASE_URL}chat"),
ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat"),
}
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let model_name = model_iden.model_name.clone();

let stream = matches!(service_type, ServiceType::ChatStream);

let url = Self::get_service_url(model_iden.clone(), service_type);
let ServiceTarget { endpoint, auth, model } = target;

// -- api_key (this Adapter requires it)
let api_key = get_api_key(model_iden.clone(), client_config)?;
let api_key = get_api_key(auth, &model)?;

// -- url
let url = Self::get_service_url(&model, service_type, endpoint);

// -- headers
let headers = vec![
// headers
("Authorization".to_string(), format!("Bearer {api_key}")),
];

let model_name = model.model_name.clone();

// -- parts
let CohereChatRequestParts {
preamble,
message,
chat_history,
} = Self::into_cohere_request_parts(model_iden, chat_req)?;
} = Self::into_cohere_request_parts(model, chat_req)?;

// -- Build the basic payload
let stream = matches!(service_type, ServiceType::ChatStream);
let mut payload = json!({
"model": model_name.to_string(),
"message": message,
53 changes: 33 additions & 20 deletions src/adapter/adapters/gemini/adapter_impl.rs
@@ -1,20 +1,20 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::gemini::GeminiStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
MessageContent, MetaUsage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use serde_json::{json, Value};
use value_ext::JsonValueExt;

pub struct GeminiAdapter;

const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
const MODELS: &[&str] = &[
"gemini-1.5-pro",
"gemini-1.5-flash",
@@ -29,49 +29,62 @@ const MODELS: &[&str] = &[
// -X POST 'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'

impl Adapter for GeminiAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("GEMINI_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("GEMINI_API_KEY")
}

/// Note: For now, this returns the common models (see above)
async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}

fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
/// NOTE: As Google Gemini has decided to put their API_KEY in the URL,
/// this will return the URL without the API_KEY in it. The API_KEY will need to be added by the caller.
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
let model_name = model.model_name.clone();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => BASE_URL.to_string(),
ServiceType::Chat => format!("{base_url}models/{model_name}:generateContent"),
ServiceType::ChatStream => format!("{base_url}models/{model_name}:streamGenerateContent"),
}
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let api_key = get_api_key(model_iden.clone(), client_config)?;
let ServiceTarget { endpoint, auth, model } = target;

// For Gemini, the service URL returned is just the base URL
// since the model and API key are part of the URL (see below)
let url = Self::get_service_url(model_iden.clone(), service_type);
// -- api_key
let api_key = get_api_key(auth, &model)?;

// -- url
// NOTE: Somehow, Google decided to put the API key in the URL.
// This should be considered an antipattern from a security point of view
// even if it is done by the well-respected Google. Everybody can make a mistake once in a while.
// e.g., '...models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'
let model_name = &*model_iden.model_name;
let url = match service_type {
ServiceType::Chat => format!("{url}models/{model_name}:generateContent?key={api_key}"),
ServiceType::ChatStream => format!("{url}models/{model_name}:streamGenerateContent?key={api_key}"),
};
let url = Self::get_service_url(&model, service_type, endpoint);
let url = format!("{url}?key={api_key}");

let headers = vec![];

let GeminiChatRequestParts { system, contents } = Self::into_gemini_request_parts(model_iden, chat_req)?;
// -- parts
let GeminiChatRequestParts { system, contents } = Self::into_gemini_request_parts(model, chat_req)?;

// -- Payload
let mut payload = json!({
"contents": contents,
});

// -- headers (empty for gemini, since API_KEY is in url)
let headers = vec![];

// Note: It's unclear from the spec if the content of systemInstruction should have a role.
// Right now, it is omitted (since the spec states it can only be "user" or "model")
// It seems to work. https://ai.google.dev/api/rest/v1beta/models/generateContent
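To make the new URL split concrete: get_service_url now returns the key-less, endpoint-based URL, and to_web_request_data appends the key. A small standalone sketch using the default base URL from this file (the model name and key are placeholders):

fn main() {
    // Values taken from the hunks above; model name and key are placeholders.
    let base_url = "https://generativelanguage.googleapis.com/v1beta/";
    let model_name = "gemini-1.5-flash";
    let api_key = "YOUR_API_KEY";
    // get_service_url: key-less URL per service type
    let chat_url = format!("{base_url}models/{model_name}:generateContent");
    let _stream_url = format!("{base_url}models/{model_name}:streamGenerateContent");
    // to_web_request_data: Gemini expects the API key in the URL
    println!("{chat_url}?key={api_key}");
}
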