-
-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Reranker WIP * Add caching and singleton loading * Add field to workspaces for vectorSearchMode; add UI for LanceDB to change the mode; update all search endpoints to pass in the reranker prop if the provider can use it * Update hint text * When reranking, swap score to rerank score * Update optional chaining
- Loading branch information
1 parent
bb5c3b7
commit ad01df8
Showing
16 changed files
with
339 additions
and
9 deletions.
There are no files selected for viewing
51 changes: 51 additions & 0 deletions
51
frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import { useState } from "react"; | ||
|
||
// We don't support all vector databases yet for reranking due to the complexities of how each
// provider returns information. The response data must be normalized before the reranker can be
// used with a provider.
// Frozen because these are module-level lookup tables shared by every render of the component.
const supportedVectorDBs = Object.freeze(["lancedb"]);

// User-facing title and help text for each selectable vector search mode, keyed by mode value.
const hint = Object.freeze({
  default: Object.freeze({
    title: "Default",
    description:
      "This is the fastest performance, but may not return the most relevant results leading to model hallucinations.",
  }),
  rerank: Object.freeze({
    title: "Accuracy Optimized",
    description:
      "LLM responses may take longer to generate, but your responses will be more accurate and relevant.",
  }),
});
|
||
export default function VectorSearchMode({ workspace, setHasChanges }) { | ||
const [selection, setSelection] = useState( | ||
workspace?.vectorSearchMode ?? "default" | ||
); | ||
if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB)) | ||
return null; | ||
|
||
return ( | ||
<div> | ||
<div className="flex flex-col"> | ||
<label htmlFor="name" className="block input-label"> | ||
Search Preference | ||
</label> | ||
</div> | ||
<select | ||
name="vectorSearchMode" | ||
value={selection} | ||
className="border-none bg-theme-settings-input-bg text-white text-sm mt-2 rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5" | ||
onChange={(e) => { | ||
setSelection(e.target.value); | ||
setHasChanges(true); | ||
}} | ||
required={true} | ||
> | ||
<option value="default">Default</option> | ||
<option value="rerank">Accuracy Optimized</option> | ||
</select> | ||
<p className="text-white text-opacity-60 text-xs font-medium py-1.5"> | ||
{hint[selection]?.description} | ||
</p> | ||
</div> | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
-- AlterTable | ||
ALTER TABLE "workspaces" ADD COLUMN "vectorSearchMode" TEXT DEFAULT 'default'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,5 @@ downloaded/* | |
!downloaded/.placeholder | ||
openrouter | ||
apipie | ||
novita | ||
novita | ||
mixedbread-ai* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
const path = require("path"); | ||
const fs = require("fs"); | ||
|
||
/**
 * Reranks text snippets against a query using a sequence-classification model
 * loaded through `@xenova/transformers`.
 *
 * The transformers module, the model and the tokenizer are held in static
 * private fields, so they are loaded once per process and shared by every
 * instance of this class.
 */
class NativeEmbeddingReranker {
  static #model = null;
  static #tokenizer = null;
  static #transformers = null;
  // In-flight initialization shared by concurrent callers; cleared once it settles.
  static #initPromise = null;

  constructor() {
    // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s)
    // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s)
    this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
    this.cacheDir = path.resolve(
      process.env.STORAGE_DIR
        ? path.resolve(process.env.STORAGE_DIR, `models`)
        : path.resolve(__dirname, `../../../storage/models`)
    );
    this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
    // Make directory when it does not exist in existing installations.
    // `recursive` also creates missing parent directories instead of throwing ENOENT.
    if (!fs.existsSync(this.cacheDir))
      fs.mkdirSync(this.cacheDir, { recursive: true });
    this.log("Initialized");
  }

  /**
   * Writes a message to stdout prefixed with the class tag.
   * @param {string} text - Message to print.
   * @param {...any} args - Extra values forwarded to `console.log`.
   */
  log(text, ...args) {
    console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
  }

  /**
   * This function will preload the reranker suite and tokenizer.
   * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such
   * to avoid having to wait for the models to download on the first rerank call.
   * Failures are logged and never thrown; the next `rerank` call retries the load.
   * @returns {Promise<void>}
   */
  async preload() {
    try {
      this.log(`Preloading reranker suite...`);
      await this.initClient();
      this.log(
        `Preloaded reranker suite. Reranking is available as a service now.`
      );
    } catch (e) {
      console.error(e);
      this.log(
        `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
      );
    }
  }

  /**
   * Ensures the model and tokenizer singletons are loaded.
   *
   * Readiness is decided by the model and tokenizer themselves, not by the
   * imported module, so a failed or still-running load is never mistaken for
   * a finished one. Concurrent callers await the same in-flight load.
   * @returns {Promise<void>}
   * @throws Propagates any error raised while importing the library or loading the model/tokenizer.
   */
  async initClient() {
    if (NativeEmbeddingReranker.#model && NativeEmbeddingReranker.#tokenizer) {
      this.log(`Reranker suite already initialized - reusing.`);
      return;
    }

    if (!NativeEmbeddingReranker.#initPromise) {
      NativeEmbeddingReranker.#initPromise = this.#loadSuite().finally(() => {
        // Clear the slot so that a failed load can be retried by a later call.
        NativeEmbeddingReranker.#initPromise = null;
      });
    }
    await NativeEmbeddingReranker.#initPromise;
  }

  /**
   * Imports the transformers library (once) and loads the model and tokenizer.
   * @returns {Promise<void>}
   */
  async #loadSuite() {
    if (!NativeEmbeddingReranker.#transformers) {
      const { AutoModelForSequenceClassification, AutoTokenizer } =
        await import("@xenova/transformers");
      NativeEmbeddingReranker.#transformers = {
        AutoModelForSequenceClassification,
        AutoTokenizer,
      };
    }

    this.log(`Loading reranker suite...`);
    await this.#getPreTrainedModel();
    await this.#getPreTrainedTokenizer();
  }

  /**
   * Builds the progress callback handed to `from_pretrained`.
   * @param {string} label - Word used in the log line ("model" or "tokenizer").
   * @returns {(p: {status: string, progress?: number}) => void}
   */
  #progressLogger(label) {
    return (p) =>
      p.status === "progress" &&
      this.log(`Loading ${label} ${this.model}... ${p?.progress}%`);
  }

  /**
   * Returns the shared model, loading it into the static singleton on first use.
   * @returns {Promise<any>}
   */
  async #getPreTrainedModel() {
    if (NativeEmbeddingReranker.#model) {
      this.log(`Loading model from singleton...`);
      return NativeEmbeddingReranker.#model;
    }

    const model =
      await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained(
        this.model,
        {
          progress_callback: this.#progressLogger("model"),
          cache_dir: this.cacheDir,
        }
      );
    this.log(`Loaded model ${this.model}`);
    NativeEmbeddingReranker.#model = model;
    return model;
  }

  /**
   * Returns the shared tokenizer, loading it into the static singleton on first use.
   * @returns {Promise<any>}
   */
  async #getPreTrainedTokenizer() {
    if (NativeEmbeddingReranker.#tokenizer) {
      this.log(`Loading tokenizer from singleton...`);
      return NativeEmbeddingReranker.#tokenizer;
    }

    const tokenizer =
      await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained(
        this.model,
        {
          progress_callback: this.#progressLogger("tokenizer"),
          cache_dir: this.cacheDir,
        }
      );
    this.log(`Loaded tokenizer ${this.model}`);
    NativeEmbeddingReranker.#tokenizer = tokenizer;
    return tokenizer;
  }

  /**
   * Reranks a list of documents based on the query.
   * @param {string} query - The query to rerank the documents against.
   * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
   * @param {Object} options - The options for the reranking.
   * @param {number} options.topK - The number of top documents to return. Falls back to 4 when omitted.
   * @returns {Promise<any[]>} - The reranked list of documents, each extended with `rerank_corpus_id`
   * (its index in `documents`) and `rerank_score`, sorted by descending score.
   */
  async rerank(query, documents, options = { topK: 4 }) {
    // Nothing to score: skip loading the model and tokenizing an empty batch.
    if (!documents?.length) return [];
    // A caller-supplied options object without `topK` must not disable the limit.
    const topK = options?.topK ?? 4;

    await this.initClient();
    const model = NativeEmbeddingReranker.#model;
    const tokenizer = NativeEmbeddingReranker.#tokenizer;

    const start = Date.now();
    this.log(`Reranking ${documents.length} documents...`);
    // Pair the same query with every document text so each pair is scored together.
    const inputs = tokenizer(new Array(documents.length).fill(query), {
      text_pair: documents.map((doc) => doc.text),
      padding: true,
      truncation: true,
    });
    const { logits } = await model(inputs);
    const reranked = logits
      .sigmoid()
      .tolist()
      .map(([score], i) => ({
        rerank_corpus_id: i,
        rerank_score: score,
        ...documents[i],
      }))
      .sort((a, b) => b.rerank_score - a.rerank_score)
      .slice(0, topK);

    this.log(
      `Reranking ${documents.length} documents to top ${topK} took ${Date.now() - start}ms`
    );
    return reranked;
  }
}
|
||
module.exports = { | ||
NativeEmbeddingReranker, | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.