Skip to content

Commit

Permalink
. refactor - decouple AdapterDispatcher from Adapter trait
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremychone committed Dec 9, 2024
1 parent c47ae85 commit a64126b
Show file tree
Hide file tree
Showing 12 changed files with 49 additions and 46 deletions.
8 changes: 3 additions & 5 deletions examples/c06-target-resolver.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
//! This example demonstrates how to use a custom ServiceTargetResolver which gives full control of the final
//! mapping for Endpoint, Model/AdapterKind, and Auth
//! This example demonstrates how to use a custom ServiceTargetResolver which gives full control over the final
//! mapping for Endpoint, Model/AdapterKind, and Auth.
//!
//! IMPORTANT - Here we are using xAI as an example of a custom ServiceTarget.
//! It works with regular chat using the basic OpenAIAdapter,
//! but for streaming, xAI does not follow OpenAI's specifications.
//! Therefore, below we use regular chat, and this crate provides an XaiAdapter.
//! However, there is now an XaiAdapter, which gets activated on `starts_with("grok")`.
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatRequest};
Expand Down
4 changes: 2 additions & 2 deletions src/adapter/adapter_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ pub trait Adapter {
// #[deprecated(note = "use default_auth")]
// fn default_key_env_name(kind: AdapterKind) -> Option<&'static str>;

fn default_auth(kind: AdapterKind) -> AuthData;
fn default_auth() -> AuthData;

fn default_endpoint(kind: AdapterKind) -> Endpoint;
fn default_endpoint() -> Endpoint;

// NOTE: Adapter is a crate trait, so it is acceptable to use async fn here.
async fn all_model_names(kind: AdapterKind) -> Result<Vec<String>>;
Expand Down
4 changes: 2 additions & 2 deletions src/adapter/adapters/anthropic/adapter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ const MODELS: &[&str] = &[
];

impl Adapter for AnthropicAdapter {
fn default_endpoint(_kind: AdapterKind) -> Endpoint {
fn default_endpoint() -> Endpoint {
const BASE_URL: &str = "https://api.anthropic.com/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(_kind: AdapterKind) -> AuthData {
fn default_auth() -> AuthData {
AuthData::from_env("ANTHROPIC_API_KEY")
}

Expand Down
4 changes: 2 additions & 2 deletions src/adapter/adapters/cohere/adapter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ const MODELS: &[&str] = &[
];

impl Adapter for CohereAdapter {
fn default_endpoint(_kind: AdapterKind) -> Endpoint {
fn default_endpoint() -> Endpoint {
const BASE_URL: &str = "https://api.cohere.com/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(_kind: AdapterKind) -> AuthData {
fn default_auth() -> AuthData {
AuthData::from_env("COHERE_API_KEY")
}

Expand Down
4 changes: 2 additions & 2 deletions src/adapter/adapters/gemini/adapter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ const MODELS: &[&str] = &[
// -X POST 'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'

impl Adapter for GeminiAdapter {
fn default_endpoint(_kind: AdapterKind) -> Endpoint {
fn default_endpoint() -> Endpoint {
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(_kind: AdapterKind) -> AuthData {
fn default_auth() -> AuthData {
AuthData::from_env("GEMINI_API_KEY")
}

Expand Down
4 changes: 2 additions & 2 deletions src/adapter/adapters/groq/adapter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ pub(in crate::adapter) const MODELS: &[&str] = &[

// The Groq API adapter is modeled after the OpenAI adapter, as the Groq API is compatible with the OpenAI API.
impl Adapter for GroqAdapter {
fn default_endpoint(_kind: AdapterKind) -> Endpoint {
fn default_endpoint() -> Endpoint {
const BASE_URL: &str = "https://api.groq.com/openai/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(_kind: AdapterKind) -> AuthData {
fn default_auth() -> AuthData {
AuthData::from_env("GROQ_API_KEY")
}

Expand Down
6 changes: 3 additions & 3 deletions src/adapter/adapters/ollama/adapter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ pub struct OllamaAdapter;
/// (https://github.com/ollama/ollama/blob/main/docs/openai.md)
/// Since the base Ollama API supports `application/x-ndjson` for streaming, whereas others support `text/event-stream`
impl Adapter for OllamaAdapter {
fn default_endpoint(_kind: AdapterKind) -> Endpoint {
fn default_endpoint() -> Endpoint {
const BASE_URL: &str = "http://localhost:11434/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(_kind: AdapterKind) -> AuthData {
fn default_auth() -> AuthData {
AuthData::from_single("ollama")
}

Expand All @@ -33,7 +33,7 @@ impl Adapter for OllamaAdapter {
/// Later, we might add another function with a endpoint, so the the user can give an custom endpoint.
async fn all_model_names(adapter_kind: AdapterKind) -> Result<Vec<String>> {
// FIXME: This is harcoded to the default endpoint, should take endpoint as Argument
let endpoint = Self::default_endpoint(adapter_kind);
let endpoint = Self::default_endpoint();
let base_url = endpoint.base_url();
let url = format!("{base_url}models");

Expand Down
4 changes: 2 additions & 2 deletions src/adapter/adapters/openai/adapter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ const MODELS: &[&str] = &[
];

impl Adapter for OpenAIAdapter {
fn default_endpoint(_kind: AdapterKind) -> Endpoint {
fn default_endpoint() -> Endpoint {
const BASE_URL: &str = "https://api.openai.com/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(_kind: AdapterKind) -> AuthData {
fn default_auth() -> AuthData {
AuthData::from_env("OPENAI_API_KEY")
}

Expand Down
4 changes: 2 additions & 2 deletions src/adapter/adapters/xai/adapter_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ pub(in crate::adapter) const MODELS: &[&str] = &["grok-beta"];

// The Groq API adapter is modeled after the OpenAI adapter, as the Groq API is compatible with the OpenAI API.
impl Adapter for XaiAdapter {
fn default_endpoint(_kind: AdapterKind) -> Endpoint {
fn default_endpoint() -> Endpoint {
const BASE_URL: &str = "https://api.x.ai/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(_kind: AdapterKind) -> AuthData {
fn default_auth() -> AuthData {
AuthData::from_env("XAI_API_KEY")
}

Expand Down
49 changes: 27 additions & 22 deletions src/adapter/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,39 @@ use super::groq::GroqAdapter;
use crate::adapter::xai::XaiAdapter;
use crate::resolver::{AuthData, Endpoint};

/// A construct that allows dispatching calls to the Adapters.
///
/// Note 1: This struct does not need to implement the Adapter trait, as some of its methods take the adapter_kind as a parameter.
///
/// Note 2: This struct might be renamed to avoid confusion with the traditional Rust dispatcher pattern.
pub struct AdapterDispatcher;

impl Adapter for AdapterDispatcher {
fn default_endpoint(kind: AdapterKind) -> Endpoint {
impl AdapterDispatcher {
pub fn default_endpoint(kind: AdapterKind) -> Endpoint {
match kind {
AdapterKind::OpenAI => OpenAIAdapter::default_endpoint(kind),
AdapterKind::Anthropic => AnthropicAdapter::default_endpoint(kind),
AdapterKind::Cohere => CohereAdapter::default_endpoint(kind),
AdapterKind::Ollama => OllamaAdapter::default_endpoint(kind),
AdapterKind::Gemini => GeminiAdapter::default_endpoint(kind),
AdapterKind::Groq => GroqAdapter::default_endpoint(kind),
AdapterKind::Xai => XaiAdapter::default_endpoint(kind),
AdapterKind::OpenAI => OpenAIAdapter::default_endpoint(),
AdapterKind::Anthropic => AnthropicAdapter::default_endpoint(),
AdapterKind::Cohere => CohereAdapter::default_endpoint(),
AdapterKind::Ollama => OllamaAdapter::default_endpoint(),
AdapterKind::Gemini => GeminiAdapter::default_endpoint(),
AdapterKind::Groq => GroqAdapter::default_endpoint(),
AdapterKind::Xai => XaiAdapter::default_endpoint(),
}
}

fn default_auth(kind: AdapterKind) -> AuthData {
pub fn default_auth(kind: AdapterKind) -> AuthData {
match kind {
AdapterKind::OpenAI => OpenAIAdapter::default_auth(kind),
AdapterKind::Anthropic => AnthropicAdapter::default_auth(kind),
AdapterKind::Cohere => CohereAdapter::default_auth(kind),
AdapterKind::Ollama => OllamaAdapter::default_auth(kind),
AdapterKind::Gemini => GeminiAdapter::default_auth(kind),
AdapterKind::Groq => GroqAdapter::default_auth(kind),
AdapterKind::Xai => XaiAdapter::default_auth(kind),
AdapterKind::OpenAI => OpenAIAdapter::default_auth(),
AdapterKind::Anthropic => AnthropicAdapter::default_auth(),
AdapterKind::Cohere => CohereAdapter::default_auth(),
AdapterKind::Ollama => OllamaAdapter::default_auth(),
AdapterKind::Gemini => GeminiAdapter::default_auth(),
AdapterKind::Groq => GroqAdapter::default_auth(),
AdapterKind::Xai => XaiAdapter::default_auth(),
}
}

async fn all_model_names(kind: AdapterKind) -> Result<Vec<String>> {
pub async fn all_model_names(kind: AdapterKind) -> Result<Vec<String>> {
match kind {
AdapterKind::OpenAI => OpenAIAdapter::all_model_names(kind).await,
AdapterKind::Anthropic => AnthropicAdapter::all_model_names(kind).await,
Expand All @@ -53,7 +58,7 @@ impl Adapter for AdapterDispatcher {
}
}

fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
pub fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
match model.adapter_kind {
AdapterKind::OpenAI => OpenAIAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Anthropic => AnthropicAdapter::get_service_url(model, service_type, endpoint),
Expand All @@ -65,7 +70,7 @@ impl Adapter for AdapterDispatcher {
}
}

fn to_web_request_data(
pub fn to_web_request_data(
target: ServiceTarget,
service_type: ServiceType,
chat_req: ChatRequest,
Expand All @@ -85,7 +90,7 @@ impl Adapter for AdapterDispatcher {
}
}

fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
pub fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
match model_iden.adapter_kind {
AdapterKind::OpenAI => OpenAIAdapter::to_chat_response(model_iden, web_response),
AdapterKind::Anthropic => AnthropicAdapter::to_chat_response(model_iden, web_response),
Expand All @@ -97,7 +102,7 @@ impl Adapter for AdapterDispatcher {
}
}

fn to_chat_stream(
pub fn to_chat_stream(
model_iden: ModelIden,
reqwest_builder: RequestBuilder,
options_set: ChatOptionsSet<'_, '_>,
Expand Down
2 changes: 1 addition & 1 deletion src/client/client_impl.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind, ServiceType, WebRequestData};
use crate::adapter::{AdapterDispatcher, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptions, ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::{Client, Error, ModelIden, Result, ServiceTarget};

Expand Down
2 changes: 1 addition & 1 deletion src/client/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::adapter::{Adapter, AdapterDispatcher};
use crate::adapter::AdapterDispatcher;
use crate::chat::ChatOptions;
use crate::client::ServiceTarget;
use crate::resolver::{AuthResolver, ModelMapper, ServiceTargetResolver};
Expand Down

0 comments on commit a64126b

Please sign in to comment.