This repository has been archived by the owner on Sep 3, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create packages context * Add DocItem schema * Add DocFragment schema * Add default host to config * Add Embeddings context * Add paraphrase-L3 as an embedding model * Add more options for the embedding macro * Add tests, clean up old unused modules * Set up tests to handle multiple embedding models * Add another embedding model to the app * Add the embed function * Rework HexClient to only use hex repo * Add the function to add packages * Add tests and bugfixes for the items and fragments helper functions * Add mix task for adding documentation * Add mix task for embeddings * Update the frontend to work with new changes * Formatting * Remove unnecessary metaprogramming from the embedding logic * Consolidate all migrations into one * Update deps * Change logic of finding the latest release to account for prerelease versions * Add tests for the latest release function, move code to more appropriate module * Remove unnecessary docstrings * Use a transaction_with helper function * Rename the mix tasks to comply with convention * Change the mix tasks to comply with new functions * Move the embedding process outside of transaction to preven timeout errors in postgres * Change the add_package function to make better use of the transaction_with helper * Add error messages for the embed mix task * Decouple the DB logic from embedding provider behaviour * Simplify the latest function for releases * Change error handling in HexClient * Remove redundant Nx.Serving batch size options * Fix typo * Bump Req and ReqHex to get rid of dodgy error management in HexClient * Remove unnecessary `with` clauses * Change release tests to be more realistic
- Loading branch information
1 parent
1219e1b
commit 6350441
Showing
33 changed files
with
1,020 additions
and
345 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
defmodule Mix.Tasks.Search.Add do | ||
alias Search.Packages | ||
alias Search.HexClient | ||
|
||
@moduledoc """ | ||
Usage: mix #{Mix.Task.task_name(__MODULE__)} <PACKAGE> [<VERSION>] | ||
Fetches the documentation for the given package from Hex. Does not embed it yet. | ||
If the version is ommitted, it will choose the newest release. | ||
""" | ||
@shortdoc "Adds a package's documentation to the index" | ||
|
||
use Mix.Task | ||
|
||
@requirements ["app.start"] | ||
|
||
@impl Mix.Task | ||
def run(args) do | ||
[package | args_tail] = args | ||
|
||
package_or_release = | ||
case args_tail do | ||
[version] -> | ||
version = Version.parse!(version) | ||
%HexClient.Release{package_name: package, version: version} | ||
|
||
[] -> | ||
package | ||
end | ||
|
||
case Packages.add_package(package_or_release) do | ||
{:ok, package} -> Mix.shell().info("Package #{package.name}@#{package.version} added.") | ||
{:error, err} -> Mix.shell().error("Error: #{err}") | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
defmodule Mix.Tasks.Search.Embed do | ||
@moduledoc """ | ||
Usage: mix #{Mix.Task.task_name(__MODULE__)} <MODEL_NAME> | ||
Embeds the unembedded docs using the model registered in the config | ||
""" | ||
@shortdoc "Embeds the unembedded doc fragments" | ||
|
||
use Mix.Task | ||
|
||
@requirements ["app.start"] | ||
|
||
defp callback({total, done}) do | ||
ProgressBar.render(done, total) | ||
end | ||
|
||
@impl Mix.Task | ||
def run([model_name]) do | ||
embedding_models = | ||
Search.Application.embedding_models() | ||
|> Keyword.keys() | ||
|> Enum.map(&Atom.to_string/1) | ||
|
||
if Enum.member?(embedding_models, model_name) do | ||
Search.Embeddings.embed(String.to_existing_atom(model_name), &callback/1) | ||
Mix.shell().info("Done.") | ||
else | ||
Mix.shell().error("Expected model name to be one of: #{Enum.join(embedding_models, ", ")}.") | ||
end | ||
end | ||
end |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
defmodule Search.Embeddings do | ||
@moduledoc """ | ||
The Embeddings context. | ||
""" | ||
|
||
import Ecto.Query | ||
import Pgvector.Ecto.Query | ||
require Logger | ||
alias Search.{Embeddings, Repo, Packages} | ||
|
||
@doc """ | ||
Embeds any doc fragments which do not have an embedding yet. | ||
Recieves an optional callback, | ||
which is called to notify about the embedding progress with the tuple {total, done} as its argument. | ||
""" | ||
def embed(model_name, progress_callback \\ &Function.identity/1) do | ||
{provider, config} = | ||
Application.fetch_env!(:search, :embedding_providers) | ||
|> Keyword.fetch!(model_name) | ||
|
||
table_name = table_name(model_name) | ||
|
||
fragments = | ||
from f in Packages.DocFragment, | ||
left_join: e in ^{table_name, Embeddings.Embedding}, | ||
on: e.doc_fragment_id == f.id, | ||
where: is_nil(e) | ||
|
||
fragments = Repo.all(fragments) | ||
texts = Enum.map(fragments, & &1.text) | ||
|
||
embeddings = provider.embed(texts, progress_callback, config) | ||
|
||
now = DateTime.utc_now(:second) | ||
|
||
embeddings_params = | ||
Stream.zip(fragments, embeddings) | ||
|> Enum.map(fn {fragment, embedding} -> | ||
%{ | ||
doc_fragment_id: fragment.id, | ||
embedding: embedding, | ||
updated_at: now, | ||
inserted_at: now | ||
} | ||
end) | ||
|
||
Repo.transaction_with(fn -> | ||
{inserted_count, inserted_embeddings} = | ||
Repo.insert_all({table_name, Embeddings.Embedding}, embeddings_params, returning: true) | ||
|
||
if inserted_count == length(embeddings) do | ||
{:ok, inserted_embeddings} | ||
else | ||
{:error, "Could not insert all embeddings."} | ||
end | ||
end) | ||
end | ||
|
||
def embedding_size(model_name), do: get_config(model_name, :embedding_size) | ||
def table_name(model_name), do: "embeddings__#{model_name}" | ||
|
||
def embed_one(model_name, text) do | ||
{provider, config} = | ||
Application.fetch_env!(:search, :embedding_providers) | ||
|> Keyword.fetch!(model_name) | ||
|
||
provider.embed_one(text, config) | ||
end | ||
|
||
def knn_query(model_name, query_vector, opts \\ []) do | ||
table_name = table_name(model_name) | ||
|
||
%{metric: metric, k: k} = | ||
opts | ||
|> Keyword.validate!(metric: :cosine, k: nil) | ||
|> Map.new() | ||
|
||
query = | ||
from e in {table_name, Embeddings.Embedding}, | ||
preload: [doc_fragment: [doc_item: :package]], | ||
select: e, | ||
limit: ^k | ||
|
||
query = | ||
case metric do | ||
:cosine -> | ||
from e in query, | ||
order_by: cosine_distance(e.embedding, ^query_vector) | ||
|
||
:l2 -> | ||
from e in query, | ||
order_by: l2_distance(e.embedding, ^query_vector) | ||
end | ||
|
||
Repo.all(query) | ||
end | ||
|
||
defp get_config(model_name, key) do | ||
{_provider, config} = | ||
Application.fetch_env!(:search, :embedding_providers) | ||
|> Keyword.fetch!(model_name) | ||
|
||
Keyword.fetch!(config, key) | ||
end | ||
end |
Oops, something went wrong.