diff --git a/lib/instructor_lite/adapters/gemini.ex b/lib/instructor_lite/adapters/gemini.ex new file mode 100644 index 0000000..c2bfd95 --- /dev/null +++ b/lib/instructor_lite/adapters/gemini.ex @@ -0,0 +1,130 @@ +defmodule InstructorLite.Adapters.Gemini do + @moduledoc """ + + """ + @behaviour InstructorLite.Adapter + + @send_request_schema NimbleOptions.new!( + api_key: [ + type: :string, + required: true, + doc: "Gemini API key" + ], + http_client: [ + type: :atom, + default: Req, + doc: "Any module that follows `Req.post/2` interface" + ], + http_options: [ + type: :keyword_list, + default: [receive_timeout: 60_000], + doc: "Options passed to `http_client.post/2`" + ], + url: [ + type: :string, + default: + "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent", + doc: "API endpoint to use for sending requests" + ], + model: [ + type: :string, + default: "gemini-1.5-flash-8b", + doc: + "Gemini [model](https://ai.google.dev/gemini-api/docs/models/gemini)" + ] + ) + + @impl InstructorLite.Adapter + def send_request(params, opts) do + context = + opts + |> Keyword.get(:adapter_context, []) + |> NimbleOptions.validate!(@send_request_schema) + + options = + [ + path_params: [model: context[:model]], + path_params_style: :curly + ] + |> Keyword.merge(context[:http_options]) + |> Keyword.merge( + json: params, + params: [key: context[:api_key]] + ) + + case context[:http_client].post(context[:url], options) do + {:ok, %{status: 200, body: body}} -> {:ok, body} + {:ok, response} -> {:error, response} + {:error, reason} -> {:error, reason} + end + end + + @impl InstructorLite.Adapter + def initial_prompt(params, opts) do + mandatory_part = """ + As a genius expert, your task is to understand the content and provide the parsed objects in json that match json schema + """ + + optional_notes = + if notes = opts[:notes] do + """ + Additional notes on the schema:\n + #{notes} + """ + else + "" + end + + sys_instruction = %{ + parts: [ + %{text: mandatory_part <> optional_notes} + ] + } + + generation_config = %{ + responseMimeType: "application/json", + responseSchema: Keyword.fetch!(opts, :json_schema) + } + + params + |> Map.put_new(:systemInstruction, sys_instruction) + |> Map.update(:generationConfig, generation_config, fn user_config -> + Map.merge(generation_config, user_config) + end) + end + + @impl InstructorLite.Adapter + def retry_prompt(params, resp_params, errors, _response, _opts) do + do_better = [ + %{role: "model", parts: [%{text: Jason.encode!(resp_params)}]}, + %{ + role: "user", + parts: [ + %{ + text: """ + The response did not pass validation. Please try again and fix the following validation errors:\n + + #{errors} + """ + } + ] + } + ] + + Map.update(params, :contents, do_better, fn contents -> contents ++ do_better end) + end + + @impl InstructorLite.Adapter + def parse_response(response, _opts) do + case response do + %{"candidates" => [%{"content" => %{"parts" => [%{"text" => text}]}}]} -> + Jason.decode(text) + + %{"promptFeedback" => %{"blockReason" => reason}} -> + {:error, :refusal, reason} + + other -> + {:error, :unexpected_response, other} + end + end +end diff --git a/test/instructor_lite/adapters/gemini_test.exs b/test/instructor_lite/adapters/gemini_test.exs new file mode 100644 index 0000000..bd34c23 --- /dev/null +++ b/test/instructor_lite/adapters/gemini_test.exs @@ -0,0 +1,234 @@ +defmodule InstructorLite.Adapters.GeminiTest do + use ExUnit.Case, async: true + + import Mox + + alias InstructorLite.Adapters.Gemini + alias InstructorLite.HTTPClient + + setup :verify_on_exit! + + describe "initial_prompt/2" do + test "adds structured response parameters" do + params = %{} + + assert Gemini.initial_prompt(params, json_schema: :json_schema, notes: "Explanation") == %{ + generationConfig: %{ + responseMimeType: "application/json", + responseSchema: :json_schema + }, + systemInstruction: %{ + parts: [ + %{ + text: """ + As a genius expert, your task is to understand the content and provide the parsed objects in json that match json schema + Additional notes on the schema: + + Explanation + """ + } + ] + } + } + end + end + + describe "retry_prompt/5" do + test "adds new content entries" do + params = %{contents: []} + + assert Gemini.retry_prompt(params, %{foo: "bar"}, "list of errors", nil, []) == %{ + contents: [ + %{parts: [%{text: "{\"foo\":\"bar\"}"}], role: "model"}, + %{ + parts: [ + %{ + text: """ + The response did not pass validation. Please try again and fix the following validation errors: + + + list of errors + """ + } + ], + role: "user" + } + ] + } + end + end + + describe "parse_response/2" do + test "decodes json from expected output" do + response = %{ + "candidates" => [ + %{ + "avgLogprobs" => -0.0510383415222168, + "content" => %{ + "parts" => [ + %{ + "text" => "{\"birth_date\": \"1732-02-22\", \"name\": \"George Washington\"}\n" + } + ], + "role" => "model" + }, + "finishReason" => "STOP", + "safetyRatings" => [ + %{ + "category" => "HARM_CATEGORY_HATE_SPEECH", + "probability" => "NEGLIGIBLE" + }, + %{ + "category" => "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability" => "NEGLIGIBLE" + }, + %{ + "category" => "HARM_CATEGORY_HARASSMENT", + "probability" => "NEGLIGIBLE" + }, + %{ + "category" => "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability" => "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata" => %{ + "candidatesTokenCount" => 25, + "promptTokenCount" => 34, + "totalTokenCount" => 59 + } + } + + assert {:ok, %{"birth_date" => "1732-02-22", "name" => "George Washington"}} = + Gemini.parse_response(response, []) + end + + test "invalid json" do + response = %{ + "candidates" => [ + %{ + "avgLogprobs" => -0.0510383415222168, + "content" => %{ + "parts" => [ + %{ + "text" => "{{" + } + ], + "role" => "model" + }, + "finishReason" => "STOP", + "safetyRatings" => [ + %{ + "category" => "HARM_CATEGORY_HATE_SPEECH", + "probability" => "NEGLIGIBLE" + }, + %{ + "category" => "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability" => "NEGLIGIBLE" + }, + %{ + "category" => "HARM_CATEGORY_HARASSMENT", + "probability" => "NEGLIGIBLE" + }, + %{ + "category" => "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability" => "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata" => %{ + "candidatesTokenCount" => 25, + "promptTokenCount" => 34, + "totalTokenCount" => 59 + } + } + + assert {:error, %Jason.DecodeError{}} = Gemini.parse_response(response, []) + end + + test "returns refusal" do + response = %{ + "promptFeedback" => %{ + "blockReason" => "OTHER" + }, + "usageMetadata" => %{ + "candidatesTokenCount" => 25, + "promptTokenCount" => 34, + "totalTokenCount" => 59 + } + } + + assert {:error, :refusal, "OTHER"} = + Gemini.parse_response(response, []) + end + + test "unexpected content" do + response = "Internal Server Error" + + assert {:error, :unexpected_response, "Internal Server Error"} = + Gemini.parse_response(response, []) + end + end + + describe "send_request/2" do + test "overridable options" do + params = %{hello: "world"} + + opts = [ + adapter_context: [ + http_client: HTTPClient.Mock, + api_key: "api-key", + http_options: [foo: "bar", path_params: [model: "new-model"], path_params_style: :colon], + url: "https://generativelanguage.googleapis.com/v2alpha/models/:model/foo", + model: "gemini-1.5-flash" + ] + ] + + expect(HTTPClient.Mock, :post, fn url, options -> + assert url == "https://generativelanguage.googleapis.com/v2alpha/models/:model/foo" + + assert options == [ + foo: "bar", + path_params: [model: "new-model"], + path_params_style: :colon, + json: %{hello: "world"}, + params: [key: "api-key"] + ] + + {:ok, %{status: 200, body: "response"}} + end) + + assert {:ok, "response"} = Gemini.send_request(params, opts) + end + + test "non-200 response" do + opts = [ + adapter_context: [ + http_client: HTTPClient.Mock, + api_key: "api-key" + ] + ] + + expect(HTTPClient.Mock, :post, fn _url, _options -> + {:ok, %{status: 400, body: "response"}} + end) + + assert {:error, %{status: 400, body: "response"}} = Gemini.send_request(%{}, opts) + end + + test "request error" do + opts = [ + adapter_context: [ + http_client: HTTPClient.Mock, + api_key: "api-key" + ] + ] + + expect(HTTPClient.Mock, :post, fn _url, _options -> {:error, :timeout} end) + + assert {:error, :timeout} = Gemini.send_request(%{}, opts) + end + end +end diff --git a/test/integration_test.exs b/test/integration_test.exs index cf587f6..88d7cdd 100644 --- a/test/integration_test.exs +++ b/test/integration_test.exs @@ -2,7 +2,7 @@ defmodule InstructorLite.IntegrationTest do use ExUnit.Case, async: true alias InstructorLite.TestSchemas - alias InstructorLite.Adapters.{Anthropic, OpenAI, Llamacpp} + alias InstructorLite.Adapters.{Anthropic, OpenAI, Llamacpp, Gemini} @moduletag :integration @@ -411,4 +411,149 @@ defmodule InstructorLite.IntegrationTest do assert is_float(score) end end + + describe "Gemini" do + def to_gemini_schema(json_schema), do: Map.drop(json_schema, [:title, :additionalProperties]) + + test "schemaless" do + response_model = %{name: :string, birth_date: :date} + + json_schema = + response_model |> InstructorLite.JSONSchema.from_ecto_schema() |> to_gemini_schema() + + result = + InstructorLite.instruct( + %{ + contents: [ + %{role: "user", parts: [%{text: "Who was the first president of the USA?"}]} + ] + }, + response_model: response_model, + json_schema: json_schema, + adapter: Gemini, + adapter_context: [ + http_client: Req, + api_key: Application.fetch_env!(:instructor_lite, :gemini_key) + ] + ) + + assert {:ok, %{name: name, birth_date: birth_date}} = result + assert is_binary(name) + assert %Date{} = birth_date + end + + test "basic ecto schema" do + result = + InstructorLite.instruct( + %{ + contents: [ + %{ + role: "user", + parts: [ + %{ + text: + "Classify the following text: Hello, I am a Nigerian prince and I would like to give you $1,000,000." + } + ] + } + ] + }, + response_model: TestSchemas.SpamPrediction, + json_schema: TestSchemas.SpamPrediction.json_schema() |> to_gemini_schema(), + adapter: Gemini, + adapter_context: [ + http_client: Req, + api_key: Application.fetch_env!(:instructor_lite, :gemini_key) + ] + ) + + assert {:ok, %{class: :spam, score: score}} = result + assert is_float(score) + end + + test "all ecto types" do + result = + InstructorLite.instruct( + %{ + contents: [ + %{ + role: "user", + parts: [%{text: "Please fill test data"}] + } + ] + }, + response_model: TestSchemas.AllEctoTypes, + json_schema: TestSchemas.AllEctoTypes.json_schema() |> to_gemini_schema(), + adapter: Gemini, + adapter_context: [ + http_client: Req, + api_key: Application.fetch_env!(:instructor_lite, :gemini_key) + ] + ) + + assert {:ok, + %{ + binary_id: binary_id, + integer: integer, + float: float, + boolean: boolean, + string: string, + array: array, + # map: _map, + # map_two: _map_two, + decimal: decimal, + date: date, + time: time, + time_usec: time_usec, + naive_datetime: naive_datetime, + naive_datetime_usec: naive_datetime_usec, + utc_datetime: utc_datetime, + utc_datetime_usec: utc_datetime_usec + }} = result + + assert is_binary(binary_id) + assert is_integer(integer) + assert is_float(float) + assert is_boolean(boolean) + assert is_binary(string) + assert is_list(array) + # Doesn't work? + # assert is_map(map) + # assert is_map(map_two) + assert %Decimal{} = decimal + assert %Date{} = date + assert %Time{} = time + assert %Time{} = time_usec + assert %NaiveDateTime{} = naive_datetime + assert %NaiveDateTime{} = naive_datetime_usec + assert %DateTime{} = utc_datetime + assert %DateTime{} = utc_datetime_usec + end + + test "with validate_changeset" do + result = + InstructorLite.instruct( + %{ + contents: [ + %{ + role: "user", + parts: [%{text: "Guess the result!"}] + } + ] + }, + response_model: TestSchemas.CoinGuess, + json_schema: TestSchemas.CoinGuess.json_schema() |> to_gemini_schema(), + max_retries: 1, + adapter: Gemini, + adapter_context: [ + http_client: Req, + api_key: Application.fetch_env!(:instructor_lite, :gemini_key), + model: "gemini-1.5-pro" + ], + extra: :tails + ) + + assert {:ok, %{guess: :tails}} = result + end + end end