diff --git a/.vscode/cspell.json b/.vscode/cspell.json index a22575878d..991ae2753a 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -32,6 +32,7 @@ "downcasted", "downcasting", "entra", + "endregion", "etag", "eventhub", "eventhubs", @@ -39,6 +40,7 @@ "iothub", "keyvault", "msrc", + "openai", "pageable", "pkce", "pkcs", diff --git a/Cargo.toml b/Cargo.toml index 694c89ae73..69fb2f6d51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "eng/test/mock_transport", "sdk/storage", "sdk/storage/azure_storage_blob", + "sdk/openai/inference", ] [workspace.package] diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml new file mode 100644 index 0000000000..0f1700da99 --- /dev/null +++ b/sdk/openai/inference/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "azure_openai_inference" +version = "1.0.0-beta.1" +description = "Rust client library for Azure OpenAI Inference" +readme = "README.md" +authors.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true +keywords = ["sdk", "azure", "rest"] +categories = ["api-bindings"] + +[lints] +workspace = true + +[dependencies] +azure_core = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +bytes = { workspace = true } +typespec_client_core = { workspace = true, features = ["derive"] } + +[dev-dependencies] +azure_core = { workspace = true, features = ["reqwest"] } +azure_identity = { workspace = true } +reqwest = { workspace = true } +tokio = { workspace = true } diff --git a/sdk/openai/inference/README.md b/sdk/openai/inference/README.md new file mode 100644 index 0000000000..92306126f5 --- /dev/null +++ b/sdk/openai/inference/README.md @@ -0,0 +1,39 @@ +# Azure OpenAI Inference SDK for Rust + +## Introduction + +This SDK provides Rust types to interact with both OpenAI and Azure OpenAI services. + +Note: Currently request and response models have as few fields as possible, leveraging the server side defaults wherever they can. + +### Features + +All features are showcased in the `example` folder of this crate. The following is a list of what is currently supported: + +- Supporting both usage with OpenAI and Azure OpenAI services by using `OpenAIClient` or `AzureOpenAIClient`, respectively. +- Key credential authentication is supported. +- [Azure Only] Azure Active Directory (AAD) authentication is supported. +- `ChatCompletions` operation supported (limited fields). +- Streaming for `ChatCompletions` is supported + +## Authentication methods + +### Azure Active Directory + +This authentication method is only supported for Azure OpenAI services. + +```rust +AzureOpenAIClient::new( + endpoint, + Arc::new(DefaultAzureCredentialBuilder::new().build()?), + None, +)? +``` + +### Key Credentials + +This method of authentication is supported both for Azure and non-Azure OpenAI services. + +```rust +OpenAIClient::with_key_credential(secret, None)? +``` diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs new file mode 100644 index 0000000000..1e0da9b73a --- /dev/null +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
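+//
+// Run with `cargo run --example azure_chat_completions` from this crate's directory (requires the environment variables read below).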
+use azure_openai_inference::{ + clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, + AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest, +}; + +// This example illustrates how to use Azure OpenAI with key credential authentication to generate a chat completion. +#[tokio::main] +pub async fn main() { + let endpoint = + std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); + let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); + + let chat_completions_client = AzureOpenAIClient::with_key_credential( + endpoint, + secret, + Some( + AzureOpenAIClientOptions::builder() + .with_api_version(AzureServiceVersion::V2023_12_01Preview) + .build(), + ), + ) + .unwrap() + .chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::with_user_message( + "gpt-4-1106-preview", + "Tell me a joke about pineapples", + ); + + let response = chat_completions_client + .create_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await; + + match response { + Ok(chat_completions_response) => { + let chat_completions = chat_completions_response + .deserialize_body() + .await + .expect("Failed to deserialize response"); + println!("{:#?}", &chat_completions); + } + Err(e) => { + println!("Error: {}", e); + } + }; +} diff --git a/sdk/openai/inference/examples/azure_chat_completions_aad.rs b/sdk/openai/inference/examples/azure_chat_completions_aad.rs new file mode 100644 index 0000000000..4d187d4bfb --- /dev/null +++ b/sdk/openai/inference/examples/azure_chat_completions_aad.rs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +use azure_identity::DefaultAzureCredentialBuilder; +use azure_openai_inference::{ + clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, + AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest, +}; + +/// This example illustrates how to use Azure OpenAI Chat Completions with Azure Active Directory authentication. +#[tokio::main] +async fn main() { + let endpoint = + std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); + + let chat_completions_client = AzureOpenAIClient::new( + endpoint, + DefaultAzureCredentialBuilder::new() + .build() + .expect("Failed to create Azure credential"), + Some( + AzureOpenAIClientOptions::builder() + .with_api_version(AzureServiceVersion::V2023_12_01Preview) + .build(), + ), + ) + .unwrap() + .chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::with_user_message( + "gpt-4-1106-preview", + "Tell me a joke about pineapples", + ); + + let response = chat_completions_client + .create_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await; + + match response { + Ok(chat_completions_response) => { + let chat_completions = chat_completions_response + .deserialize_body() + .await + .expect("Failed to deserialize response"); + println!("{:#?}", &chat_completions); + } + Err(e) => { + println!("Error: {}", e); + } + }; +} diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs new file mode 100644 index 0000000000..233326b1f3 --- /dev/null +++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. +use azure_openai_inference::{ + clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, + AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest, +}; +use futures::stream::StreamExt; +use std::io::{self, Write}; + +/// This example illustrates how to use Azure OpenAI with key credential authentication to stream chat completions. +#[tokio::main] +async fn main() { + let endpoint = + std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); + let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); + + let chat_completions_client = AzureOpenAIClient::with_key_credential( + endpoint, + secret, + Some( + AzureOpenAIClientOptions::builder() + .with_api_version(AzureServiceVersion::V2023_12_01Preview) + .build(), + ), + ) + .unwrap() + .chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::with_user_message_and_stream( + "gpt-4-1106-preview", + "Write me an essay that is at least 200 words long on the nutritional values (or lack thereof) of fast food. + Start the essay by stating 'this essay will be x many words long' where x is the number of words in the essay.",); + + let response = chat_completions_client + .stream_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await + .unwrap(); + + // this pins the stream to the stack so it is safe to poll it (namely, it won't be de-allocated or moved) + futures::pin_mut!(response); + + while let Some(result) = response.next().await { + match result { + Ok(delta) => { + if let Some(choice) = delta.choices.get(0) { + choice.delta.as_ref().map(|d| { + d.content.as_ref().map(|c| { + print!("{}", c); + let _ = io::stdout().flush(); + }); + }); + } + } + Err(e) => println!("Error: {:?}", e), + } + } +} diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs new file mode 100644 index 0000000000..a5e6fe8261 --- /dev/null +++ b/sdk/openai/inference/examples/chat_completions.rs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +use azure_openai_inference::{ + clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods}, + CreateChatCompletionsRequest, +}; + +/// This example illustrates how to use OpenAI to generate a chat completion. 
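+/// Requires the `OPENAI_KEY` environment variable to be set.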
+#[tokio::main] +pub async fn main() { + let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); + + let chat_completions_client = OpenAIClient::with_key_credential(secret, None) + .unwrap() + .chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::with_user_message( + "gpt-3.5-turbo-1106", + "Tell me a joke about pineapples", + ); + + let response = chat_completions_client + .create_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await; + + match response { + Ok(chat_completions_response) => { + let chat_completions = chat_completions_response + .deserialize_body() + .await + .expect("Failed to deserialize response"); + println!("{:#?}", &chat_completions); + } + Err(e) => { + println!("Error: {}", e); + } + }; +} diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs new file mode 100644 index 0000000000..3058d3e0bb --- /dev/null +++ b/sdk/openai/inference/examples/chat_completions_stream.rs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +use azure_openai_inference::{ + clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods}, + CreateChatCompletionsRequest, +}; +use futures::stream::StreamExt; +use std::io::{self, Write}; + +/// This example illustrates how to use OpenAI to stream chat completions. +#[tokio::main] +async fn main() { + let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); + + let chat_completions_client = OpenAIClient::with_key_credential(secret, None) + .unwrap() + .chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::with_user_message_and_stream( + "gpt-3.5-turbo-1106", + "Write me an essay that is at least 200 words long on the nutritional values (or lack thereof) of fast food. + Start the essay by stating 'this essay will be x many words long' where x is the number of words in the essay.",); + + let response = chat_completions_client + .stream_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await + .unwrap(); + + // this pins the stream to the stack so it is safe to poll it (namely, it won't be de-allocated or moved) + futures::pin_mut!(response); + + while let Some(result) = response.next().await { + match result { + Ok(delta) => { + if let Some(choice) = delta.choices.get(0) { + choice.delta.as_ref().map(|d| { + d.content.as_ref().map(|c| { + print!("{}", c); + let _ = io::stdout().flush(); + }); + }); + } + } + Err(e) => println!("Error: {:?}", e), + } + } +} diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs new file mode 100644 index 0000000000..e5e1004b84 --- /dev/null +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +use std::sync::Arc; + +use crate::credentials::{AzureKeyCredential, DEFAULT_SCOPE}; + +use crate::options::AzureOpenAIClientOptions; +use azure_core::credentials::TokenCredential; +use azure_core::{self, Policy, Result}; +use azure_core::{BearerTokenCredentialPolicy, Url}; + +use super::chat_completions_client::ChatCompletionsClient; +use super::BaseOpenAIClientMethods; + +/// Defines the methods provided by a [`AzureOpenAIClient`] and can be used for mocking. 
+pub trait AzureOpenAIClientMethods {
+    /// Returns the endpoint [`Url`] of the client.
+    fn endpoint(&self) -> &Url;
+
+    /// Returns a new instance of the [`ChatCompletionsClient`].
+    fn chat_completions_client(&self) -> ChatCompletionsClient;
+}
+
+/// An Azure OpenAI client.
+#[derive(Debug, Clone)]
+pub struct AzureOpenAIClient {
+    /// The Azure resource endpoint
+    endpoint: Url,
+
+    /// The pipeline for sending requests to the service.
+    pipeline: azure_core::Pipeline,
+}
+
+impl AzureOpenAIClient {
+    /// Creates a new [`AzureOpenAIClient`] using a [`TokenCredential`].
+    /// See the following example for Azure Active Directory authentication:
+    ///
+    /// # Parameters
+    /// * `endpoint` - The full URL of the Azure OpenAI resource endpoint.
+    /// * `credential` - An implementation of [`TokenCredential`] used for authentication.
+    /// * `client_options` - Optional configuration for the client. The [`AzureServiceVersion`](crate::options::AzureServiceVersion) can be provided here.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use azure_openai_inference::clients::{AzureOpenAIClient, AzureOpenAIClientMethods};
+    /// use azure_identity::DefaultAzureCredentialBuilder;
+    ///
+    /// let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT environment variable");
+    /// let client = AzureOpenAIClient::new(
+    ///     endpoint,
+    ///     DefaultAzureCredentialBuilder::new().build().unwrap(),
+    ///     None,
+    /// ).unwrap();
+    /// ```
+    pub fn new(
+        endpoint: impl AsRef<str>,
+        credential: Arc<dyn TokenCredential>,
+        client_options: Option<AzureOpenAIClientOptions>,
+    ) -> Result<Self> {
+        let endpoint = Url::parse(endpoint.as_ref())?;
+
+        let options = client_options.unwrap_or_default();
+
+        let auth_policy = Arc::new(BearerTokenCredentialPolicy::new(credential, DEFAULT_SCOPE));
+        let version_policy: Arc<dyn Policy> = options.api_service_version.clone().into();
+        let per_call_policies: Vec<Arc<dyn Policy>> = vec![auth_policy, version_policy];
+
+        let pipeline = super::new_pipeline(per_call_policies, options.client_options.clone());
+
+        Ok(AzureOpenAIClient { endpoint, pipeline })
+    }
+
+    /// Creates a new [`AzureOpenAIClient`] using a key credential.
+    ///
+    /// # Parameters
+    /// * `endpoint` - The full URL of the Azure OpenAI resource endpoint.
+    /// * `secret` - The key credential used for authentication. Passed as a header parameter in the request.
+    /// * `client_options` - Optional configuration for the client. The [`AzureServiceVersion`](crate::options::AzureServiceVersion) can be provided here.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use azure_openai_inference::clients::{AzureOpenAIClient, AzureOpenAIClientMethods};
+    ///
+    /// let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT environment variable");
+    /// let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY environment variable");
+    /// let client = AzureOpenAIClient::with_key_credential(
+    ///     endpoint,
+    ///     secret,
+    ///     None,
+    /// ).unwrap();
+    /// ```
+    pub fn with_key_credential(
+        endpoint: impl AsRef<str>,
+        secret: impl Into<String>,
+        client_options: Option<AzureOpenAIClientOptions>,
+    ) -> Result<Self> {
+        let endpoint = Url::parse(endpoint.as_ref())?;
+
+        let options = client_options.unwrap_or_default();
+
+        let auth_policy: Arc<dyn Policy> = AzureKeyCredential::new(secret).into();
+        let version_policy: Arc<dyn Policy> = options.api_service_version.clone().into();
+        let per_call_policies: Vec<Arc<dyn Policy>> = vec![auth_policy, version_policy];
+
+        let pipeline = super::new_pipeline(per_call_policies, options.client_options.clone());
+
+        Ok(AzureOpenAIClient { endpoint, pipeline })
+    }
+}
+
+impl AzureOpenAIClientMethods for AzureOpenAIClient {
+    /// Returns the endpoint [`Url`] of the client.
+    fn endpoint(&self) -> &Url {
+        &self.endpoint
+    }
+
+    /// Returns a new instance of the [`ChatCompletionsClient`] using an [`AzureOpenAIClient`] underneath.
+    fn chat_completions_client(&self) -> ChatCompletionsClient {
+        ChatCompletionsClient::new(Box::new(self.clone()))
+    }
+}
+
+impl BaseOpenAIClientMethods for AzureOpenAIClient {
+    fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result<Url> {
+        // TODO gracefully handle this, if it makes sense. A panic seems appropriate IMO.
+        Ok(self
+            .endpoint()
+            .join("openai/")?
+            .join("deployments/")?
+            .join(&format!(
+                "{}/",
+                deployment_name.expect("Deployment name is required.")
+            ))?)
+    }
+
+    fn pipeline(&self) -> &azure_core::Pipeline {
+        &self.pipeline
+    }
+}
diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs
new file mode 100644
index 0000000000..9d2016c267
--- /dev/null
+++ b/sdk/openai/inference/src/clients/chat_completions_client.rs
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+use super::{new_json_request, BaseOpenAIClientMethods};
+use crate::{
+    helpers::streaming::{string_chunks, EventStreamer},
+    response::{CreateChatCompletionsResponse, CreateChatCompletionsStreamResponse},
+    CreateChatCompletionsRequest,
+};
+use azure_core::{Context, Method, Response, Result};
+use futures::{Stream, StreamExt};
+
+/// A [`ChatCompletionsClient`]'s methods. This trait can be used for mocking.
+pub trait ChatCompletionsClientMethods {
+    /// Creates a new chat completion.
+    ///
+    /// # Arguments
+    /// * `deployment_name` - The name of the deployment in Azure. In OpenAI it is the model name to be used.
+    /// * `chat_completions_request` - The request specifying the chat completion to be created.
+    #[allow(async_fn_in_trait)]
+    async fn create_chat_completions(
+        &self,
+        deployment_name: impl AsRef<str>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+    ) -> Result<Response<CreateChatCompletionsResponse>>;
+
+    /// Creates a new chat completion and returns a streamed response.
+    ///
+    /// # Arguments
+    /// * `deployment_name` - The name of the deployment in Azure. In OpenAI it is the model name to be used.
+    /// * `chat_completions_request` - The request specifying the chat completion to be created.
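+    ///
+    /// # Returns
+    /// A stream of [`CreateChatCompletionsStreamResponse`] chunks, yielded as they arrive from the service.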
+    #[allow(async_fn_in_trait)]
+    async fn stream_chat_completions(
+        &self,
+        deployment_name: impl AsRef<str>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+    ) -> Result<impl Stream<Item = Result<CreateChatCompletionsStreamResponse>>>;
+}
+
+/// A client for Chat Completions related operations.
+pub struct ChatCompletionsClient {
+    /// The underlying HTTP client with an associated pipeline.
+    base_client: Box<dyn BaseOpenAIClientMethods>,
+}
+
+impl ChatCompletionsClient {
+    pub(super) fn new(base_client: Box<dyn BaseOpenAIClientMethods>) -> Self {
+        Self { base_client }
+    }
+}
+
+impl ChatCompletionsClientMethods for ChatCompletionsClient {
+    /// Creates a new chat completion.
+    ///
+    /// # Arguments
+    /// * `deployment_name` - The name of the deployment in Azure. In OpenAI it is the model name to be used.
+    /// * `chat_completions_request` - The request specifying the chat completion to be created.
+    async fn create_chat_completions(
+        &self,
+        deployment_name: impl AsRef<str>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+    ) -> Result<Response<CreateChatCompletionsResponse>> {
+        let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?;
+        let request_url = base_url.join("chat/completions")?;
+
+        let mut request = new_json_request(request_url, Method::Post, &chat_completions_request);
+
+        self.base_client
+            .pipeline()
+            .send::<CreateChatCompletionsResponse>(&Context::new(), &mut request)
+            .await
+    }
+
+    /// Creates a new chat completion and returns a streamed response.
+    ///
+    /// # Arguments
+    /// * `deployment_name` - The name of the deployment in Azure. In OpenAI it is the model name to be used.
+    /// * `chat_completions_request` - The request specifying the chat completion to be created.
+    async fn stream_chat_completions(
+        &self,
+        deployment_name: impl AsRef<str>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+    ) -> Result<impl Stream<Item = Result<CreateChatCompletionsStreamResponse>>> {
+        let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?;
+        let request_url = base_url.join("chat/completions")?;
+
+        let mut request = new_json_request(request_url, Method::Post, &chat_completions_request);
+
+        let response_body = self
+            .base_client
+            .pipeline()
+            .send::<()>(&Context::new(), &mut request)
+            .await?
+            .into_body();
+
+        Ok(ChatCompletionsStreamHandler::event_stream(response_body))
+    }
+}
+
+/// A placeholder type to provide an implementation for the [`EventStreamer`] trait specifically for chat completions.
+struct ChatCompletionsStreamHandler;
+
+impl EventStreamer<CreateChatCompletionsStreamResponse> for ChatCompletionsStreamHandler {
+    fn event_stream(
+        response_body: azure_core::ResponseBody,
+    ) -> impl Stream<Item = Result<CreateChatCompletionsStreamResponse>> {
+        let stream_event_delimiter = "\n\n";
+
+        string_chunks(response_body, stream_event_delimiter).map(|event| match event {
+            Ok(event) => serde_json::from_str::<CreateChatCompletionsStreamResponse>(&event)
+                .map_err(|e| e.into()),
+            Err(e) => Err(e),
+        })
+    }
+}
diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs
new file mode 100644
index 0000000000..f892407cda
--- /dev/null
+++ b/sdk/openai/inference/src/clients/mod.rs
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+mod azure_openai_client;
+mod chat_completions_client;
+mod openai_client;
+
+use std::sync::Arc;
+
+pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods};
+pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods};
+pub use openai_client::{OpenAIClient, OpenAIClientMethods};
+
+/// A trait that defines the common behavior expected from an [`OpenAIClient`] and an [`AzureOpenAIClient`].
+/// This trait will be used as a boxed type by clients such as [`ChatCompletionsClient`] so they can issue HTTP requests.
+trait BaseOpenAIClientMethods {
+    /// Returns the base [`Url`] of the underlying client.
+    ///
+    /// # Arguments
+    /// * `deployment_name` - The name of the deployment in Azure. In an [`OpenAIClient`] this parameter is ignored.
+    fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result<azure_core::Url>;
+
+    /// Returns the [`azure_core::Pipeline`] of the underlying client.
+    fn pipeline(&self) -> &azure_core::Pipeline;
+}
+
+fn new_pipeline(
+    per_call_policies: Vec<Arc<dyn azure_core::Policy>>,
+    options: azure_core::ClientOptions,
+) -> azure_core::Pipeline {
+    azure_core::Pipeline::new(
+        option_env!("CARGO_PKG_NAME"),
+        option_env!("CARGO_PKG_VERSION"),
+        options,
+        per_call_policies,
+        Vec::new(),
+    )
+}
+
+fn new_json_request<T>(
+    url: azure_core::Url,
+    method: azure_core::Method,
+    json_body: &T,
+) -> azure_core::Request
+where
+    T: serde::Serialize,
+{
+    let mut request = azure_core::Request::new(url, method);
+
+    // For some reason non-Azure OpenAI's API is strict about these headers being present
+    request.insert_header(azure_core::headers::CONTENT_TYPE, "application/json");
+    request.insert_header(azure_core::headers::ACCEPT, "application/json");
+
+    request.set_json(json_body).unwrap();
+    request
+}
diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs
new file mode 100644
index 0000000000..7d0eeed3ce
--- /dev/null
+++ b/sdk/openai/inference/src/clients/openai_client.rs
@@ -0,0 +1,65 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+use std::sync::Arc;
+
+use azure_core::{Policy, Result, Url};
+
+use crate::{credentials::OpenAIKeyCredential, OpenAIClientOptions};
+
+use super::{BaseOpenAIClientMethods, ChatCompletionsClient};
+
+/// Defines the methods provided by an [`OpenAIClient`] and can be used for mocking.
+pub trait OpenAIClientMethods {
+    fn chat_completions_client(&self) -> ChatCompletionsClient;
+}
+
+/// An OpenAI client.
+#[derive(Debug, Clone)]
+pub struct OpenAIClient {
+    base_url: Url,
+    pipeline: azure_core::Pipeline,
+}
+
+impl OpenAIClient {
+    /// Creates a new [`OpenAIClient`] using a secret key.
+    ///
+    /// # Parameters
+    /// * `secret` - The key credential used for authentication.
+    /// * `client_options` - Optional configuration for the client. Reserved for future use; currently it can always be `None`.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use azure_openai_inference::clients::{OpenAIClient, OpenAIClientMethods};
+    ///
+    /// let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable");
+    /// let client = OpenAIClient::with_key_credential(secret, None).unwrap();
+    /// ```
+    pub fn with_key_credential(
+        secret: impl Into<String>,
+        client_options: Option<OpenAIClientOptions>,
+    ) -> Result<Self> {
+        let base_url = Url::parse("https://api.openai.com/v1/")?;
+        let options = client_options.unwrap_or_default();
+        let auth_policy: Arc<dyn Policy> = OpenAIKeyCredential::new(secret).into();
+
+        let pipeline = super::new_pipeline(vec![auth_policy], options.client_options.clone());
+
+        Ok(OpenAIClient { base_url, pipeline })
+    }
+}
+
+impl OpenAIClientMethods for OpenAIClient {
+    fn chat_completions_client(&self) -> ChatCompletionsClient {
+        ChatCompletionsClient::new(Box::new(self.clone()))
+    }
+}
+
+impl BaseOpenAIClientMethods for OpenAIClient {
+    fn pipeline(&self) -> &azure_core::Pipeline {
+        &self.pipeline
+    }
+
+    fn base_url(&self, _deployment_name: Option<&str>) -> Result<Url> {
+        Ok(self.base_url.clone())
+    }
+}
diff --git a/sdk/openai/inference/src/credentials/azure_key_credential.rs b/sdk/openai/inference/src/credentials/azure_key_credential.rs
new file mode 100644
index 0000000000..f39f560d2c
--- /dev/null
+++ b/sdk/openai/inference/src/credentials/azure_key_credential.rs
@@ -0,0 +1,64 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+use async_trait::async_trait;
+use std::sync::Arc;
+
+use azure_core::{
+    credentials::Secret,
+    headers::{HeaderName, HeaderValue},
+    Context, Header, Policy, PolicyResult, Request,
+};
+
+/// A key credential for the [AzureOpenAIClient](crate::clients::AzureOpenAIClient).
+///
+/// # Example
+/// ```no_run
+/// use azure_openai_inference::clients::{AzureOpenAIClient, AzureOpenAIClientMethods};
+///
+/// let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable");
+/// let azure_open_ai_client = AzureOpenAIClient::with_key_credential(
+///     "https://my.endpoint/",
+///     secret,
+///     None,
+/// ).unwrap();
+/// ```
+#[derive(Debug, Clone)]
+pub struct AzureKeyCredential(Secret);
+
+impl AzureKeyCredential {
+    /// Create a new [`AzureKeyCredential`].
+    pub fn new(api_key: impl Into<String>) -> Self {
+        Self(Secret::new(api_key.into()))
+    }
+}
+
+impl Header for AzureKeyCredential {
+    fn name(&self) -> HeaderName {
+        HeaderName::from_static("api-key")
+    }
+
+    fn value(&self) -> HeaderValue {
+        HeaderValue::from_cow(self.0.secret().to_string())
+    }
+}
+
+// code lifted from BearerTokenCredentialPolicy
+#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
+#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
+impl Policy for AzureKeyCredential {
+    async fn send(
+        &self,
+        ctx: &Context,
+        request: &mut Request,
+        next: &[Arc<dyn Policy>],
+    ) -> PolicyResult {
+        request.insert_header(Header::name(self), Header::value(self));
+        next[0].send(ctx, request, &next[1..]).await
+    }
+}
+
+impl From<AzureKeyCredential> for Arc<dyn Policy> {
+    fn from(credential: AzureKeyCredential) -> Arc<dyn Policy> {
+        Arc::new(credential)
+    }
+}
diff --git a/sdk/openai/inference/src/credentials/mod.rs b/sdk/openai/inference/src/credentials/mod.rs
new file mode 100644
index 0000000000..5ace38ea0e
--- /dev/null
+++ b/sdk/openai/inference/src/credentials/mod.rs
@@ -0,0 +1,9 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+mod azure_key_credential; +mod openai_key_credential; + +pub(crate) use azure_key_credential::*; +pub(crate) use openai_key_credential::*; + +pub(crate) const DEFAULT_SCOPE: [&str; 1] = ["https://cognitiveservices.azure.com/.default"]; diff --git a/sdk/openai/inference/src/credentials/openai_key_credential.rs b/sdk/openai/inference/src/credentials/openai_key_credential.rs new file mode 100644 index 0000000000..7bf09d0c1f --- /dev/null +++ b/sdk/openai/inference/src/credentials/openai_key_credential.rs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +use async_trait::async_trait; +use std::sync::Arc; + +use azure_core::{ + credentials::Secret, + headers::{HeaderName, HeaderValue, AUTHORIZATION}, + Context, Header, Policy, PolicyResult, Request, +}; + +/// A key credential for the [OpenAIClient](crate::clients::OpenAIClient). +/// +/// # Example +/// ```no_run +/// use azure_openai_inference::clients::{OpenAIClient, OpenAIClientMethods}; +/// +/// let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); +/// let open_ai_client = OpenAIClient::with_key_credential( +/// secret, +/// None, +/// ).unwrap(); +/// ``` +#[derive(Debug, Clone)] +pub struct OpenAIKeyCredential(Secret); + +impl OpenAIKeyCredential { + pub fn new(access_token: impl Into) -> Self { + Self(Secret::new(access_token.into())) + } +} + +impl Header for OpenAIKeyCredential { + fn name(&self) -> HeaderName { + AUTHORIZATION + } + + fn value(&self) -> HeaderValue { + HeaderValue::from_cow(format!("Bearer {}", &self.0.secret())) + } +} + +// code lifted from BearerTokenCredentialPolicy +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl Policy for OpenAIKeyCredential { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + request.insert_header(Header::name(self), Header::value(self)); + next[0].send(ctx, request, &next[1..]).await + } +} + +impl From for Arc { + fn from(credential: OpenAIKeyCredential) -> Arc { + Arc::new(credential) + } +} diff --git a/sdk/openai/inference/src/helpers/mod.rs b/sdk/openai/inference/src/helpers/mod.rs new file mode 100644 index 0000000000..2d0184ea05 --- /dev/null +++ b/sdk/openai/inference/src/helpers/mod.rs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +pub(crate) mod streaming; diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs new file mode 100644 index 0000000000..9f234acb22 --- /dev/null +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +use azure_core::{Error, Result}; +use futures::{Stream, StreamExt}; + +/// A trait used to designate a type into which the streams will be deserialized. +pub(crate) trait EventStreamer +where + T: serde::de::DeserializeOwned, +{ + fn event_stream(response_body: azure_core::ResponseBody) -> impl Stream>; +} + +/// A helper function to be used in streaming scenarios. The `response_body`, the input stream +/// is buffered until a `stream_event_delimiter` is found. This constitutes a single event. +/// These series of events are then returned as a stream. +/// +/// # Arguments +/// * `response_body` - The response body stream of an HTTP request. 
+/// * `stream_event_delimiter` - The delimiter that separates events in the stream. In some cases `\n\n`, in other cases can be `\n\r\n\n`. +/// # Returns +/// The `response_body` stream segmented and streamed into String events demarcated by `stream_event_delimiter`. +pub(crate) fn string_chunks<'a>( + response_body: (impl Stream> + Unpin + 'a), + stream_event_delimiter: &'a str, +) -> impl Stream> + 'a { + let chunk_buffer = Vec::new(); + let stream = futures::stream::unfold( + (response_body, chunk_buffer), + move |(mut response_body, mut chunk_buffer)| async move { + let delimiter = stream_event_delimiter.as_bytes(); + let delimiter_len = delimiter.len(); + + if let Some(Ok(bytes)) = response_body.next().await { + chunk_buffer.extend_from_slice(&bytes); + if let Some(pos) = chunk_buffer + .windows(delimiter_len) + .position(|window| window == delimiter) + { + // the range must include the delimiter bytes + let mut bytes = chunk_buffer + .drain(..pos + delimiter_len) + .collect::>(); + bytes.truncate(bytes.len() - delimiter_len); + + return if let Ok(yielded_value) = std::str::from_utf8(&bytes) { + // We strip the "data: " portion of the event. The rest is always JSON and will be deserialized + // by a subsequent mapping function for this stream + let yielded_value = yielded_value.trim_start_matches("data:").trim(); + if yielded_value == "[DONE]" { + return None; + } else { + Some((Ok(yielded_value.to_string()), (response_body, chunk_buffer))) + } + } else { + None + }; + } + if !chunk_buffer.is_empty() { + return Some(( + Err(Error::with_message( + azure_core::error::ErrorKind::DataConversion, + || "Incomplete chunk", + )), + (response_body, chunk_buffer), + )); + } + // We drain the buffer of any messages that may be left over. + // The block above will be skipped, since response_body.next() will be None every time + } else if !chunk_buffer.is_empty() { + if let Some(pos) = chunk_buffer + .windows(delimiter_len) + .position(|window| window == delimiter) + { + // the range must include the delimiter bytes + let mut bytes = chunk_buffer + .drain(..pos + delimiter_len) + .collect::>(); + bytes.truncate(bytes.len() - delimiter_len); + + return if let Ok(yielded_value) = std::str::from_utf8(&bytes) { + let yielded_value = yielded_value.trim_start_matches("data:").trim(); + if yielded_value == "[DONE]" { + return None; + } else { + Some((Ok(yielded_value.to_string()), (response_body, chunk_buffer))) + } + } else { + None + }; + } + // if we get to this point, it means we have drained the buffer of all events, meaning that we haven't been able to find the next delimiter + } + None + }, + ); + + // We specifically allow the Error::with_message(ErrorKind::DataConversion, || "Incomplete chunk") + // So that we are able to continue pushing bytes to the buffer until we find the next delimiter + return stream.filter(|it| { + std::future::ready( + it.is_ok() + || it.as_ref().unwrap_err().to_string() + != Error::with_message(azure_core::error::ErrorKind::DataConversion, || { + "Incomplete chunk" + }) + .to_string(), + ) + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::pin_mut; + + #[tokio::test] + async fn clean_chunks() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static(b"data: piece 1\n\n")), + Ok(bytes::Bytes::from_static(b"data: piece 2\n\n")), + Ok(bytes::Bytes::from_static(b"data: [DONE]\n\n")), + ]); + + let actual = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual); + let actual: Vec> = actual.collect().await; + + let 
expected: Vec> = + vec![Ok("piece 1".to_string()), Ok("piece 2".to_string())]; + assert_result_vectors(expected, actual); + } + + #[tokio::test] + async fn multiple_message_in_one_chunk() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static( + b"data: piece 1\n\ndata: piece 2\n\n", + )), + Ok(bytes::Bytes::from_static( + b"data: piece 3\n\ndata: [DONE]\n\n", + )), + ]); + + let mut actual = Vec::new(); + + let actual_stream = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual_stream); + + while let Some(event) = actual_stream.next().await { + actual.push(event); + } + + let expected: Vec> = vec![ + Ok("piece 1".to_string()), + Ok("piece 2".to_string()), + Ok("piece 3".to_string()), + ]; + assert_result_vectors(expected, actual); + } + + #[tokio::test] + async fn data_marker_in_previous_chunk() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static( + b"data: piece 1\n\ndata: piece 2\n\ndata:", + )), + Ok(bytes::Bytes::from_static(b" piece 3\n\ndata: [DONE]\n\n")), + ]); + + let mut actual = Vec::new(); + + let actual_stream = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual_stream); + + while let Some(event) = actual_stream.next().await { + actual.push(event); + } + + let expected: Vec> = vec![ + Ok("piece 1".to_string()), + Ok("piece 2".to_string()), + Ok("piece 3".to_string()), + ]; + assert_result_vectors(expected, actual); + } + + #[tokio::test] + async fn event_delimiter_split_across_chunks() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static(b"data: piece 1\n")), + Ok(bytes::Bytes::from_static(b"\ndata: [DONE]")), + ]); + + let actual = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual); + let actual: Vec> = actual.collect().await; + + let expected: Vec> = vec![Ok("piece 1".to_string())]; + assert_result_vectors(expected, actual); + } + + #[tokio::test] + async fn event_delimiter_at_start_of_next_chunk() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static(b"data: piece 1")), + Ok(bytes::Bytes::from_static(b"\n\ndata: [DONE]")), + ]); + + let actual = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual); + let actual: Vec> = actual.collect().await; + + let expected: Vec> = vec![Ok("piece 1".to_string())]; + assert_result_vectors(expected, actual); + } + + // This is an over simplification, reasonable for an MVP. We should: + // 1. propagate error upwards + // 2. 
handle an unexpected "data:" marker (this will simply send the string as is, which will fail deserialization in an upper mapping layer)
+    #[tokio::test]
+    async fn error_in_response_ends_stream() {
+        let mut source_stream = futures::stream::iter(vec![
+            Ok(bytes::Bytes::from_static(b"data: piece 1\n\n")),
+            Err(Error::with_message(
+                azure_core::error::ErrorKind::Other,
+                || "Incomplete chunk",
+            )),
+        ]);
+
+        let actual = string_chunks(&mut source_stream, "\n\n");
+        pin_mut!(actual);
+        let actual: Vec<Result<String>> = actual.collect().await;
+
+        let expected: Vec<Result<String>> = vec![Ok("piece 1".to_string())];
+        assert_result_vectors(expected, actual);
+    }
+
+    fn assert_result_vectors<T>(expected: Vec<Result<T>>, actual: Vec<Result<T>>)
+    where
+        T: std::fmt::Debug + PartialEq,
+    {
+        assert_eq!(expected.len(), actual.len());
+        for (expected, actual) in expected.iter().zip(actual.iter()) {
+            if let Ok(actual) = actual {
+                assert_eq!(actual, expected.as_ref().unwrap());
+            } else {
+                let actual_err = actual.as_ref().unwrap_err();
+                let expected_err = expected.as_ref().unwrap_err();
+                assert_eq!(actual_err.kind(), expected_err.kind());
+                assert_eq!(actual_err.to_string(), expected_err.to_string());
+            }
+        }
+    }
+}
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
new file mode 100644
index 0000000000..dceeb85b4d
--- /dev/null
+++ b/sdk/openai/inference/src/lib.rs
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+pub mod clients;
+mod credentials;
+mod helpers;
+mod models;
+mod options;
+
+pub use models::{request::*, response};
+pub use options::*;
diff --git a/sdk/openai/inference/src/models/chat_completions.rs b/sdk/openai/inference/src/models/chat_completions.rs
new file mode 100644
index 0000000000..9d11d55223
--- /dev/null
+++ b/sdk/openai/inference/src/models/chat_completions.rs
@@ -0,0 +1,190 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+pub mod request {
+
+    use serde::Serialize;
+
+    /// The configuration information for a chat completions request.
+    /// Completions support a wide variety of tasks and generate text that continues from or "completes"
+    /// provided prompt data.
+    #[derive(Serialize, Debug, Clone, Default)]
+    #[non_exhaustive]
+    pub struct CreateChatCompletionsRequest {
+        pub messages: Vec<ChatCompletionRequestMessage>,
+        pub model: String,
+        pub stream: Option<bool>,
+    }
+
+    /// An abstract representation of a chat message as provided in a request.
+    #[derive(Serialize, Debug, Clone, Default)]
+    #[non_exhaustive]
+    pub struct ChatCompletionRequestMessageBase {
+        /// An optional name for the participant.
+        #[serde(skip)]
+        pub name: Option<String>,
+        /// The contents of the message.
+        pub content: String, // TODO this should be either a string or ChatCompletionRequestMessageContentPart (a polymorphic type)
+    }
+
+    /// A description of the intended purpose of a message within a chat completions interaction.
+    #[derive(Serialize, Debug, Clone)]
+    #[non_exhaustive]
+    #[serde(tag = "role")]
+    pub enum ChatCompletionRequestMessage {
+        /// The role that instructs or sets the behavior of the assistant.
+        #[serde(rename = "system")]
+        System(ChatCompletionRequestMessageBase),
+
+        /// The role that provides input for chat completions.
+        #[serde(rename = "user")]
+        User(ChatCompletionRequestMessageBase),
+    }
+
+    impl ChatCompletionRequestMessage {
+        /// Creates a new [`ChatCompletionRequestMessage`] with a single `user` message.
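+        ///
+        /// # Example
+        /// ```rust
+        /// let message = azure_openai_inference::ChatCompletionRequestMessage::with_user_role(
+        ///     "Tell me a joke about pineapples");
+        /// ```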
+        pub fn with_user_role(content: impl Into<String>) -> Self {
+            Self::User(ChatCompletionRequestMessageBase {
+                content: content.into(),
+                name: None,
+            })
+        }
+
+        /// Creates a new [`ChatCompletionRequestMessage`] with a single `system` message.
+        pub fn with_system_role(content: impl Into<String>) -> Self {
+            Self::System(ChatCompletionRequestMessageBase {
+                content: content.into(),
+                name: None,
+            })
+        }
+    }
+
+    impl CreateChatCompletionsRequest {
+        /// Creates a new [`CreateChatCompletionsRequest`] with a single `user` message.
+        ///
+        /// # Example
+        ///
+        /// ```rust
+        /// let request = azure_openai_inference::CreateChatCompletionsRequest::with_user_message(
+        ///     "gpt-3.5-turbo-1106",
+        ///     "Why couldn't the eagles take Frodo directly to Mount Doom?");
+        /// ```
+        pub fn with_user_message(model: &str, prompt: &str) -> Self {
+            Self {
+                model: model.to_string(),
+                messages: vec![ChatCompletionRequestMessage::with_user_role(prompt)],
+                ..Default::default()
+            }
+        }
+
+        /// Creates a new [`CreateChatCompletionsRequest`] with a single `user` message whose response will be streamed.
+        ///
+        /// # Example
+        ///
+        /// ```rust
+        /// let request = azure_openai_inference::CreateChatCompletionsRequest::with_user_message_and_stream(
+        ///     "gpt-3.5-turbo-1106",
+        ///     "Why couldn't the eagles take Frodo directly to Mount Doom?");
+        /// ```
+        pub fn with_user_message_and_stream(
+            model: impl Into<String>,
+            prompt: impl Into<String>,
+        ) -> Self {
+            Self {
+                model: model.into(),
+                messages: vec![ChatCompletionRequestMessage::with_user_role(prompt)],
+                stream: Some(true),
+                ..Default::default()
+            }
+        }
+
+        /// Creates a new [`CreateChatCompletionsRequest`] with a list of messages.
+        ///
+        /// # Example
+        /// ```rust
+        /// let request = azure_openai_inference::CreateChatCompletionsRequest::with_messages(
+        ///     "gpt-3.5-turbo-1106",
+        ///     vec![
+        ///         azure_openai_inference::ChatCompletionRequestMessage::with_system_role("You are a good math tutor who explains things briefly."),
+        ///         azure_openai_inference::ChatCompletionRequestMessage::with_user_role("What is the value of 'x' in the equation: '2x + 3 = 11'?"),
+        ///     ]);
+        /// ```
+        pub fn with_messages(
+            model: impl Into<String>,
+            messages: Vec<ChatCompletionRequestMessage>,
+        ) -> Self {
+            Self {
+                model: model.into(),
+                messages,
+                ..Default::default()
+            }
+        }
+    }
+}
+
+pub mod response {
+
+    use azure_core::Model;
+    use serde::Deserialize;
+
+    /// Representation of the response data from a chat completions request.
+    /// Completions support a wide variety of tasks and generate text that continues from or "completes"
+    /// provided prompt data.
+    #[derive(Debug, Clone, Deserialize, Model)]
+    #[non_exhaustive]
+    pub struct CreateChatCompletionsResponse {
+        /// The collection of completions choices associated with this completions response.
+        /// Generally, `n` choices are generated per provided prompt with a default value of 1.
+        /// Token limits and other settings may limit the number of choices generated.
+        pub choices: Vec<ChatCompletionChoice>,
+    }
+
+    /// The representation of a single prompt completion as part of an overall chat completions request.
+    /// Generally, `n` choices are generated per provided prompt with a default value of 1.
+    /// Token limits and other settings may limit the number of choices generated.
+    #[derive(Debug, Clone, Deserialize)]
+    #[non_exhaustive]
+    pub struct ChatCompletionChoice {
+        /// The chat message for a given chat completions prompt.
+        pub message: ChatCompletionResponseMessage,
+    }
+
+    #[derive(Debug, Clone, Deserialize)]
+    #[non_exhaustive]
+    pub struct ChatCompletionResponseMessage {
+        /// The content of the message.
+        pub content: Option<String>,
+
+        /// The chat role associated with the message.
+        pub role: String,
+    }
+
+    // region: --- Streaming
+    /// Represents a streamed chunk of a chat completion response returned by the model, based on the provided input.
+    #[derive(Debug, Clone, Deserialize)]
+    #[non_exhaustive]
+    pub struct CreateChatCompletionsStreamResponse {
+        /// A list of chat completion choices. Can contain more than one element if `n` is greater than 1.
+        pub choices: Vec<ChatCompletionStreamChoice>,
+    }
+
+    /// A chat completion delta generated by streamed model responses.
+    #[derive(Debug, Clone, Deserialize)]
+    #[non_exhaustive]
+    pub struct ChatCompletionStreamChoice {
+        /// The delta message content for a streaming response.
+        pub delta: Option<ChatCompletionStreamResponseMessage>,
+    }
+
+    /// A chat completion delta generated by streamed model responses.
+    ///
+    /// Note: all fields are optional because in a streaming scenario there is no guarantee of what is present in the model.
+    #[derive(Debug, Clone, Deserialize)]
+    #[non_exhaustive]
+    pub struct ChatCompletionStreamResponseMessage {
+        /// The content of the chunk message.
+        pub content: Option<String>,
+
+        /// The chat role associated with the message.
+        pub role: Option<String>,
+    }
+    // endregion: Streaming
+}
diff --git a/sdk/openai/inference/src/models/mod.rs b/sdk/openai/inference/src/models/mod.rs
new file mode 100644
index 0000000000..67a1b7a918
--- /dev/null
+++ b/sdk/openai/inference/src/models/mod.rs
@@ -0,0 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+mod chat_completions;
+
+pub use chat_completions::*;
diff --git a/sdk/openai/inference/src/options/azure_openai_client_options.rs b/sdk/openai/inference/src/options/azure_openai_client_options.rs
new file mode 100644
index 0000000000..bcdbef9047
--- /dev/null
+++ b/sdk/openai/inference/src/options/azure_openai_client_options.rs
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+use azure_core::ClientOptions;
+
+use crate::AzureServiceVersion;
+
+/// Options to be passed to [`AzureOpenAIClient`](crate::clients::AzureOpenAIClient).
+// TODO: I was not able to find ClientOptions as a derive macro
+#[derive(Clone, Debug, Default)]
+pub struct AzureOpenAIClientOptions {
+    #[allow(dead_code)]
+    pub(crate) client_options: ClientOptions,
+    pub(crate) api_service_version: AzureServiceVersion,
+}
+
+impl AzureOpenAIClientOptions {
+    pub fn builder() -> builders::AzureOpenAIClientOptionsBuilder {
+        builders::AzureOpenAIClientOptionsBuilder::new()
+    }
+}
+
+pub mod builders {
+    use super::*;
+
+    #[derive(Clone, Debug, Default)]
+    pub struct AzureOpenAIClientOptionsBuilder {
+        options: AzureOpenAIClientOptions,
+    }
+
+    impl AzureOpenAIClientOptionsBuilder {
+        pub(super) fn new() -> Self {
+            Self::default()
+        }
+
+        /// Configures the [`AzureOpenAIClient`](crate::clients::AzureOpenAIClient) to use the specified API version.
+        /// If no value is supplied, the latest version will be used as default. See [`AzureServiceVersion::get_latest()`](AzureServiceVersion::get_latest).
+        pub fn with_api_version(mut self, api_service_version: AzureServiceVersion) -> Self {
+            self.options.api_service_version = api_service_version;
+            self
+        }
+
+        /// Builds the [`AzureOpenAIClientOptions`].
+        ///
+        /// # Examples
+        ///
+        /// ```rust
+        /// let options = azure_openai_inference::AzureOpenAIClientOptions::builder().build();
+        /// ```
+        pub fn build(&self) -> AzureOpenAIClientOptions {
+            self.options.clone()
+        }
+    }
+}
diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs
new file mode 100644
index 0000000000..ed2a303a01
--- /dev/null
+++ b/sdk/openai/inference/src/options/mod.rs
@@ -0,0 +1,9 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+mod azure_openai_client_options;
+mod openai_client_options;
+mod service_version;
+
+pub use azure_openai_client_options::{builders::*, AzureOpenAIClientOptions};
+pub use openai_client_options::{builders::*, OpenAIClientOptions};
+pub use service_version::AzureServiceVersion;
diff --git a/sdk/openai/inference/src/options/openai_client_options.rs b/sdk/openai/inference/src/options/openai_client_options.rs
new file mode 100644
index 0000000000..b428727dae
--- /dev/null
+++ b/sdk/openai/inference/src/options/openai_client_options.rs
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+use azure_core::ClientOptions;
+
+/// Options to be passed to [`OpenAIClient`](crate::clients::OpenAIClient).
+///
+/// Note: There are currently no options to be set.
+/// This struct is a placeholder for future options.
+// TODO: I was not able to find ClientOptions as a derive macro
+#[derive(Clone, Debug, Default)]
+pub struct OpenAIClientOptions {
+    pub(crate) client_options: ClientOptions,
+}
+
+impl OpenAIClientOptions {
+    /// Creates a new [`builders::OpenAIClientOptionsBuilder`].
+    pub fn builder() -> builders::OpenAIClientOptionsBuilder {
+        builders::OpenAIClientOptionsBuilder::new()
+    }
+}
+
+/// Builder to construct an [`OpenAIClientOptions`].
+pub mod builders {
+    use super::*;
+
+    #[derive(Clone, Debug, Default)]
+    pub struct OpenAIClientOptionsBuilder {
+        options: OpenAIClientOptions,
+    }
+
+    impl OpenAIClientOptionsBuilder {
+        pub(super) fn new() -> Self {
+            Self::default()
+        }
+
+        /// Builds the [`OpenAIClientOptions`].
+        ///
+        /// # Examples
+        ///
+        /// ```rust
+        /// let options = azure_openai_inference::OpenAIClientOptions::builder().build();
+        /// ```
+        pub fn build(&self) -> OpenAIClientOptions {
+            self.options.clone()
+        }
+    }
+}
diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs
new file mode 100644
index 0000000000..02df7b5b25
--- /dev/null
+++ b/sdk/openai/inference/src/options/service_version.rs
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+use async_trait::async_trait;
+use std::{fmt::Display, sync::Arc};
+
+use azure_core::{Context, Policy, PolicyResult, Request};
+
+/// The version of the Azure service to use.
+/// This enum is passed to the [`AzureOpenAIClientOptionsBuilder`](crate::AzureOpenAIClientOptionsBuilder) to specify which version of the service an [`AzureOpenAIClient`](crate::clients::AzureOpenAIClient) should use.
+///
+/// If no version is specified, the latest version will be used. See [`AzureServiceVersion::get_latest()`](AzureServiceVersion::get_latest).
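+///
+/// # Example
+/// Selecting a specific service version when building client options:
+/// ```rust
+/// use azure_openai_inference::{AzureOpenAIClientOptions, AzureServiceVersion};
+///
+/// let options = AzureOpenAIClientOptions::builder()
+///     .with_api_version(AzureServiceVersion::V2023_12_01Preview)
+///     .build();
+/// ```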
+#[derive(Debug, Clone)]
+pub enum AzureServiceVersion {
+    V2023_09_01Preview,
+    V2023_12_01Preview,
+    V2024_07_01Preview,
+}
+
+impl Default for AzureServiceVersion {
+    fn default() -> AzureServiceVersion {
+        AzureServiceVersion::get_latest()
+    }
+}
+
+impl AzureServiceVersion {
+    /// Returns the latest supported version of the Azure OpenAI service.
+    pub fn get_latest() -> AzureServiceVersion {
+        AzureServiceVersion::V2024_07_01Preview
+    }
+}
+
+impl From<AzureServiceVersion> for String {
+    fn from(version: AzureServiceVersion) -> String {
+        let as_str = match version {
+            AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview",
+            AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
+            AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
+        };
+        String::from(as_str)
+    }
+}
+
+impl Display for AzureServiceVersion {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&String::from(self.clone()))
+    }
+}
+
+// code lifted from BearerTokenCredentialPolicy
+#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
+#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
+impl Policy for AzureServiceVersion {
+    async fn send(
+        &self,
+        ctx: &Context,
+        request: &mut Request,
+        next: &[Arc<dyn Policy>],
+    ) -> PolicyResult {
+        request
+            .url_mut()
+            .query_pairs_mut()
+            .append_pair("api-version", &self.to_string());
+        next[0].send(ctx, request, &next[1..]).await
+    }
+}
+
+impl From<AzureServiceVersion> for Arc<dyn Policy> {
+    fn from(version: AzureServiceVersion) -> Arc<dyn Policy> {
+        Arc::new(version)
+    }
+}
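+
+// A minimal sanity-check test module (an illustrative sketch): it verifies that the default
+// service version matches the latest supported one and that versions serialize to the
+// expected `api-version` strings.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn default_is_latest_version() {
+        assert_eq!(
+            AzureServiceVersion::default().to_string(),
+            AzureServiceVersion::get_latest().to_string()
+        );
+    }
+
+    #[test]
+    fn version_serializes_to_expected_string() {
+        assert_eq!(
+            String::from(AzureServiceVersion::V2023_12_01Preview),
+            "2023-12-01-preview"
+        );
+    }
+}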