Skip to content
This repository has been archived by the owner on Sep 3, 2024. It is now read-only.

Commit

Permalink
Indexing logic improvements (#6)
Browse files Browse the repository at this point in the history
* Create packages context

* Add DocItem schema

* Add DocFragment schema

* Add default host to config

* Add Embeddings context

* Add paraphrase-L3 as an embedding model

* Add more options for the embedding macro

* Add tests, clean up old unused modules

* Set up tests to handle multiple embedding models

* Add another embedding model to the app

* Add the embed function

* Rework HexClient to only use hex repo

* Add the function to add packages

* Add tests and bugfixes for the items and fragments helper functions

* Add mix task for adding documentation

* Add mix task for embeddings

* Update the frontend to work with new changes

* Formatting

* Remove unnecessary metaprogramming from the embedding logic

* Consolidate all migrations into one

* Update deps

* Change logic of finding the latest release to account for prerelease versions

* Add tests for the latest release function, move code to more appropriate module

* Remove unnecessary docstrings

* Use a transaction_with helper function

* Rename the mix tasks to comply with convention

* Change the mix tasks to comply with new functions

* Move the embedding process outside of transaction to preven timeout errors in postgres

* Change the add_package function to make better use of the transaction_with helper

* Add error messages for the embed mix task

* Decouple the DB logic from embedding provider behaviour

* Simplify the latest function for releases

* Change error handling in HexClient

* Remove redundant Nx.Serving batch size options

* Fix typo

* Bump Req and ReqHex to get rid of dodgy error management in HexClient

* Remove unnecessary `with` clauses

* Change release tests to be more realistic
  • Loading branch information
karol-t-wilk authored Jun 14, 2024
1 parent 1219e1b commit 6350441
Show file tree
Hide file tree
Showing 33 changed files with 1,020 additions and 345 deletions.
31 changes: 30 additions & 1 deletion config/config.exs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,35 @@ config :search,
# Add types added by the pgvector-elixir extension to Postgrex
config :search, Search.Repo, types: Search.PostgrexTypes

# Register embedding providers
config :search, :embedding_providers,
paraphrase_l3: {
Search.Embeddings.BumblebeeProvider,
serving_name: Search.Embeddings.ParaphraseL3,
model: {:hf, "sentence-transformers/paraphrase-MiniLM-L3-v2"},
embedding_size: 384,
load_model_opts: [
backend: EXLA.Backend
],
serving_opts: [
compile: [batch_size: 16, sequence_length: 512],
defn_options: [compiler: EXLA]
]
},
paraphrase_albert_small: {
Search.Embeddings.BumblebeeProvider,
serving_name: Search.Embeddings.ParaphraseAlbertSmall,
model: {:hf, "sentence-transformers/paraphrase-albert-small-v2"},
embedding_size: 768,
load_model_opts: [
backend: EXLA.Backend
],
serving_opts: [
compile: [batch_size: 16, sequence_length: 100],
defn_options: [compiler: EXLA]
]
}

# Configures the endpoint
config :search, SearchWeb.Endpoint,
url: [host: "localhost"],
Expand Down Expand Up @@ -56,7 +85,7 @@ config :logger, :console,
config :phoenix, :json_library, Jason

# Configure the EXLA backend for Nx
config :nx, :default_backend, EXLA.Backend
config :nx, :default_backend, {EXLA.Backend, client: :host}

# Import environment specific config. This must remain at the bottom
# of this file so it overrides the configuration defined above.
Expand Down
37 changes: 37 additions & 0 deletions lib/mix/tasks/search.add.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
defmodule Mix.Tasks.Search.Add do
alias Search.Packages
alias Search.HexClient

@moduledoc """
Usage: mix #{Mix.Task.task_name(__MODULE__)} <PACKAGE> [<VERSION>]
Fetches the documentation for the given package from Hex. Does not embed it yet.
If the version is ommitted, it will choose the newest release.
"""
@shortdoc "Adds a package's documentation to the index"

use Mix.Task

@requirements ["app.start"]

@impl Mix.Task
def run(args) do
[package | args_tail] = args

package_or_release =
case args_tail do
[version] ->
version = Version.parse!(version)
%HexClient.Release{package_name: package, version: version}

[] ->
package
end

case Packages.add_package(package_or_release) do
{:ok, package} -> Mix.shell().info("Package #{package.name}@#{package.version} added.")
{:error, err} -> Mix.shell().error("Error: #{err}")
end
end
end
31 changes: 31 additions & 0 deletions lib/mix/tasks/search.embed.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
defmodule Mix.Tasks.Search.Embed do
@moduledoc """
Usage: mix #{Mix.Task.task_name(__MODULE__)} <MODEL_NAME>
Embeds the unembedded docs using the model registered in the config
"""
@shortdoc "Embeds the unembedded doc fragments"

use Mix.Task

@requirements ["app.start"]

defp callback({total, done}) do
ProgressBar.render(done, total)
end

@impl Mix.Task
def run([model_name]) do
embedding_models =
Search.Application.embedding_models()
|> Keyword.keys()
|> Enum.map(&Atom.to_string/1)

if Enum.member?(embedding_models, model_name) do
Search.Embeddings.embed(String.to_existing_atom(model_name), &callback/1)
Mix.shell().info("Done.")
else
Mix.shell().error("Expected model name to be one of: #{Enum.join(embedding_models, ", ")}.")
end
end
end
53 changes: 0 additions & 53 deletions lib/mix/tasks/search/index.ex

This file was deleted.

26 changes: 16 additions & 10 deletions lib/search/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,24 @@ defmodule Search.Application do

use Application

def embedding_models do
Application.fetch_env!(:search, :embedding_providers)
end

@impl true
def start(_type, _args) do
children = [
SearchWeb.Telemetry,
Search.Repo,
{DNSCluster, query: Application.get_env(:search, :dns_cluster_query) || :ignore},
{Phoenix.PubSub, name: Search.PubSub},
# Start a worker by calling: Search.Worker.start_link(arg)
{Search.Embedding, name: Search.Embedding},
# Start to serve requests, typically the last entry
SearchWeb.Endpoint
]
children =
[
SearchWeb.Telemetry,
Search.Repo,
{DNSCluster, query: Application.get_env(:search, :dns_cluster_query) || :ignore},
{Phoenix.PubSub, name: Search.PubSub}
] ++
Enum.map(embedding_models(), fn {_, {provider, opts}} -> provider.child_spec(opts) end) ++
[
# Start to serve requests, typically the last entry
SearchWeb.Endpoint
]

# See https://hexdocs.pm/elixir/Supervisor.html
# for other strategies and supported options
Expand Down
23 changes: 0 additions & 23 deletions lib/search/embedding.ex

This file was deleted.

107 changes: 107 additions & 0 deletions lib/search/embeddings.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
defmodule Search.Embeddings do
@moduledoc """
The Embeddings context.
"""

import Ecto.Query
import Pgvector.Ecto.Query
require Logger
alias Search.{Embeddings, Repo, Packages}

@doc """
Embeds any doc fragments which do not have an embedding yet.
Recieves an optional callback,
which is called to notify about the embedding progress with the tuple {total, done} as its argument.
"""
def embed(model_name, progress_callback \\ &Function.identity/1) do
{provider, config} =
Application.fetch_env!(:search, :embedding_providers)
|> Keyword.fetch!(model_name)

table_name = table_name(model_name)

fragments =
from f in Packages.DocFragment,
left_join: e in ^{table_name, Embeddings.Embedding},
on: e.doc_fragment_id == f.id,
where: is_nil(e)

fragments = Repo.all(fragments)
texts = Enum.map(fragments, & &1.text)

embeddings = provider.embed(texts, progress_callback, config)

now = DateTime.utc_now(:second)

embeddings_params =
Stream.zip(fragments, embeddings)
|> Enum.map(fn {fragment, embedding} ->
%{
doc_fragment_id: fragment.id,
embedding: embedding,
updated_at: now,
inserted_at: now
}
end)

Repo.transaction_with(fn ->
{inserted_count, inserted_embeddings} =
Repo.insert_all({table_name, Embeddings.Embedding}, embeddings_params, returning: true)

if inserted_count == length(embeddings) do
{:ok, inserted_embeddings}
else
{:error, "Could not insert all embeddings."}
end
end)
end

def embedding_size(model_name), do: get_config(model_name, :embedding_size)
def table_name(model_name), do: "embeddings__#{model_name}"

def embed_one(model_name, text) do
{provider, config} =
Application.fetch_env!(:search, :embedding_providers)
|> Keyword.fetch!(model_name)

provider.embed_one(text, config)
end

def knn_query(model_name, query_vector, opts \\ []) do
table_name = table_name(model_name)

%{metric: metric, k: k} =
opts
|> Keyword.validate!(metric: :cosine, k: nil)
|> Map.new()

query =
from e in {table_name, Embeddings.Embedding},
preload: [doc_fragment: [doc_item: :package]],
select: e,
limit: ^k

query =
case metric do
:cosine ->
from e in query,
order_by: cosine_distance(e.embedding, ^query_vector)

:l2 ->
from e in query,
order_by: l2_distance(e.embedding, ^query_vector)
end

Repo.all(query)
end

defp get_config(model_name, key) do
{_provider, config} =
Application.fetch_env!(:search, :embedding_providers)
|> Keyword.fetch!(model_name)

Keyword.fetch!(config, key)
end
end
Loading

0 comments on commit 6350441

Please sign in to comment.