Skip to content

Commit

Permalink
Add Gemini adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
martosaur committed Oct 21, 2024
1 parent ca197f9 commit eb4c683
Show file tree
Hide file tree
Showing 3 changed files with 510 additions and 1 deletion.
130 changes: 130 additions & 0 deletions lib/instructor_lite/adapters/gemini.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
defmodule InstructorLite.Adapters.Gemini do
@moduledoc """
"""
@behaviour InstructorLite.Adapter

@send_request_schema NimbleOptions.new!(
api_key: [
type: :string,
required: true,
doc: "Gemini API key"
],
http_client: [
type: :atom,
default: Req,
doc: "Any module that follows `Req.post/2` interface"
],
http_options: [
type: :keyword_list,
default: [receive_timeout: 60_000],
doc: "Options passed to `http_client.post/2`"
],
url: [
type: :string,
default:
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
doc: "API endpoint to use for sending requests"
],
model: [
type: :string,
default: "gemini-1.5-flash-8b",
doc:
"Gemini [model](https://ai.google.dev/gemini-api/docs/models/gemini)"
]
)

@impl InstructorLite.Adapter
def send_request(params, opts) do
context =
opts
|> Keyword.get(:adapter_context, [])
|> NimbleOptions.validate!(@send_request_schema)

options =
[
path_params: [model: context[:model]],
path_params_style: :curly
]
|> Keyword.merge(context[:http_options])
|> Keyword.merge(
json: params,
params: [key: context[:api_key]]
)

case context[:http_client].post(context[:url], options) do
{:ok, %{status: 200, body: body}} -> {:ok, body}
{:ok, response} -> {:error, response}
{:error, reason} -> {:error, reason}
end
end

@impl InstructorLite.Adapter
def initial_prompt(params, opts) do
mandatory_part = """
As a genius expert, your task is to understand the content and provide the parsed objects in json that match json schema
"""

optional_notes =
if notes = opts[:notes] do
"""
Additional notes on the schema:\n
#{notes}
"""
else
""
end

sys_instruction = %{
parts: [
%{text: mandatory_part <> optional_notes}
]
}

generation_config = %{
responseMimeType: "application/json",
responseSchema: Keyword.fetch!(opts, :json_schema)
}

params
|> Map.put_new(:systemInstruction, sys_instruction)
|> Map.update(:generationConfig, generation_config, fn user_config ->
Map.merge(generation_config, user_config)
end)
end

@impl InstructorLite.Adapter
def retry_prompt(params, resp_params, errors, _response, _opts) do
do_better = [
%{role: "model", parts: [%{text: Jason.encode!(resp_params)}]},
%{
role: "user",
parts: [
%{
text: """
The response did not pass validation. Please try again and fix the following validation errors:\n
#{errors}
"""
}
]
}
]

Map.update(params, :contents, do_better, fn contents -> contents ++ do_better end)
end

@impl InstructorLite.Adapter
def parse_response(response, _opts) do
case response do
%{"candidates" => [%{"content" => %{"parts" => [%{"text" => text}]}}]} ->
Jason.decode(text)

%{"promptFeedback" => %{"blockReason" => reason}} ->
{:error, :refusal, reason}

other ->
{:error, :unexpected_response, other}
end
end
end
234 changes: 234 additions & 0 deletions test/instructor_lite/adapters/gemini_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
defmodule InstructorLite.Adapters.GeminiTest do
use ExUnit.Case, async: true

import Mox

alias InstructorLite.Adapters.Gemini
alias InstructorLite.HTTPClient

setup :verify_on_exit!

describe "initial_prompt/2" do
test "adds structured response parameters" do
params = %{}

assert Gemini.initial_prompt(params, json_schema: :json_schema, notes: "Explanation") == %{
generationConfig: %{
responseMimeType: "application/json",
responseSchema: :json_schema
},
systemInstruction: %{
parts: [
%{
text: """
As a genius expert, your task is to understand the content and provide the parsed objects in json that match json schema
Additional notes on the schema:
Explanation
"""
}
]
}
}
end
end

describe "retry_prompt/5" do
test "adds new content entries" do
params = %{contents: []}

assert Gemini.retry_prompt(params, %{foo: "bar"}, "list of errors", nil, []) == %{
contents: [
%{parts: [%{text: "{\"foo\":\"bar\"}"}], role: "model"},
%{
parts: [
%{
text: """
The response did not pass validation. Please try again and fix the following validation errors:
list of errors
"""
}
],
role: "user"
}
]
}
end
end

describe "parse_response/2" do
test "decodes json from expected output" do
response = %{
"candidates" => [
%{
"avgLogprobs" => -0.0510383415222168,
"content" => %{
"parts" => [
%{
"text" => "{\"birth_date\": \"1732-02-22\", \"name\": \"George Washington\"}\n"
}
],
"role" => "model"
},
"finishReason" => "STOP",
"safetyRatings" => [
%{
"category" => "HARM_CATEGORY_HATE_SPEECH",
"probability" => "NEGLIGIBLE"
},
%{
"category" => "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability" => "NEGLIGIBLE"
},
%{
"category" => "HARM_CATEGORY_HARASSMENT",
"probability" => "NEGLIGIBLE"
},
%{
"category" => "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability" => "NEGLIGIBLE"
}
]
}
],
"usageMetadata" => %{
"candidatesTokenCount" => 25,
"promptTokenCount" => 34,
"totalTokenCount" => 59
}
}

assert {:ok, %{"birth_date" => "1732-02-22", "name" => "George Washington"}} =
Gemini.parse_response(response, [])
end

test "invalid json" do
response = %{
"candidates" => [
%{
"avgLogprobs" => -0.0510383415222168,
"content" => %{
"parts" => [
%{
"text" => "{{"
}
],
"role" => "model"
},
"finishReason" => "STOP",
"safetyRatings" => [
%{
"category" => "HARM_CATEGORY_HATE_SPEECH",
"probability" => "NEGLIGIBLE"
},
%{
"category" => "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability" => "NEGLIGIBLE"
},
%{
"category" => "HARM_CATEGORY_HARASSMENT",
"probability" => "NEGLIGIBLE"
},
%{
"category" => "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability" => "NEGLIGIBLE"
}
]
}
],
"usageMetadata" => %{
"candidatesTokenCount" => 25,
"promptTokenCount" => 34,
"totalTokenCount" => 59
}
}

assert {:error, %Jason.DecodeError{}} = Gemini.parse_response(response, [])
end

test "returns refusal" do
response = %{
"promptFeedback" => %{
"blockReason" => "OTHER"
},
"usageMetadata" => %{
"candidatesTokenCount" => 25,
"promptTokenCount" => 34,
"totalTokenCount" => 59
}
}

assert {:error, :refusal, "OTHER"} =
Gemini.parse_response(response, [])
end

test "unexpected content" do
response = "Internal Server Error"

assert {:error, :unexpected_response, "Internal Server Error"} =
Gemini.parse_response(response, [])
end
end

describe "send_request/2" do
test "overridable options" do
params = %{hello: "world"}

opts = [
adapter_context: [
http_client: HTTPClient.Mock,
api_key: "api-key",
http_options: [foo: "bar", path_params: [model: "new-model"], path_params_style: :colon],
url: "https://generativelanguage.googleapis.com/v2alpha/models/:model/foo",
model: "gemini-1.5-flash"
]
]

expect(HTTPClient.Mock, :post, fn url, options ->
assert url == "https://generativelanguage.googleapis.com/v2alpha/models/:model/foo"

assert options == [
foo: "bar",
path_params: [model: "new-model"],
path_params_style: :colon,
json: %{hello: "world"},
params: [key: "api-key"]
]

{:ok, %{status: 200, body: "response"}}
end)

assert {:ok, "response"} = Gemini.send_request(params, opts)
end

test "non-200 response" do
opts = [
adapter_context: [
http_client: HTTPClient.Mock,
api_key: "api-key"
]
]

expect(HTTPClient.Mock, :post, fn _url, _options ->
{:ok, %{status: 400, body: "response"}}
end)

assert {:error, %{status: 400, body: "response"}} = Gemini.send_request(%{}, opts)
end

test "request error" do
opts = [
adapter_context: [
http_client: HTTPClient.Mock,
api_key: "api-key"
]
]

expect(HTTPClient.Mock, :post, fn _url, _options -> {:error, :timeout} end)

assert {:error, :timeout} = Gemini.send_request(%{}, opts)
end
end
end
Loading

0 comments on commit eb4c683

Please sign in to comment.