Skip to content

Commit

Permalink
create configurable topN per workspace (#616)
Browse files Browse the repository at this point in the history
* create configurable topN per workspace

* Update TopN UI text
Fix fallbacks for all providers
Add SQLite CHECK to TOPN value

* merge with master
Update zilliz provider for variable TopN

---------

Co-authored-by: timothycarambat <[email protected]>
  • Loading branch information
shatfield4 and timothycarambat authored Jan 18, 2024
1 parent 683cd69 commit 56fa17c
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 19 deletions.
35 changes: 35 additions & 0 deletions frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ function castToType(key, value) {
similarityThreshold: {
cast: (value) => parseFloat(value),
},
topN: {
cast: (value) => Number(value),
},
};

if (!definitions.hasOwnProperty(key)) return value;
Expand Down Expand Up @@ -236,6 +239,38 @@ export default function WorkspaceSettings({ active, workspace, settings }) {
autoComplete="off"
onChange={() => setHasChanges(true)}
/>

<div className="mt-4">
<div className="flex flex-col">
<label
htmlFor="name"
className="block text-sm font-medium text-white"
>
Max Context Snippets
</label>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
This setting controls the maximum amount of context
snippets the will be sent to the LLM for per chat or
query.
<br />
<i>Recommended: 4</i>
</p>
</div>
<input
name="topN"
type="number"
min={1}
max={12}
step={1}
onWheel={(e) => e.target.blur()}
defaultValue={workspace?.topN ?? 4}
className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
placeholder="4"
required={true}
autoComplete="off"
onChange={() => setHasChanges(true)}
/>
</div>
<div className="mt-4">
<div className="flex flex-col">
<label
Expand Down
1 change: 1 addition & 0 deletions server/models/workspace.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const Workspace = {
"openAiPrompt",
"similarityThreshold",
"chatModel",
"topN",
],

new: async function (name = null, creatorId = null) {
Expand Down
2 changes: 2 additions & 0 deletions server/prisma/migrations/20240118201333_init/migration.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "workspaces" ADD COLUMN "topN" INTEGER DEFAULT 4 CHECK ("topN" > 0);
1 change: 1 addition & 0 deletions server/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ model workspaces {
openAiPrompt String?
similarityThreshold Float? @default(0.25)
chatModel String?
topN Int? @default(4)
workspace_users workspace_users[]
documents workspace_documents[]
}
Expand Down
1 change: 1 addition & 0 deletions server/utils/chats/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ async function chatWithWorkspace(
input: message,
LLMConnector,
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
});

// Failed similarity search.
Expand Down
1 change: 1 addition & 0 deletions server/utils/chats/stream.js
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ async function streamChatWithWorkspace(
input: message,
LLMConnector,
similarityThreshold: workspace?.similarityThreshold,
topN: workspace?.topN,
});

// Failed similarity search.
Expand Down
9 changes: 6 additions & 3 deletions server/utils/vectorDbProviders/chroma/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ const Chroma = {
client,
namespace,
queryVector,
similarityThreshold = 0.25
similarityThreshold = 0.25,
topN = 4
) {
const collection = await client.getCollection({ name: namespace });
const result = {
Expand All @@ -78,7 +79,7 @@ const Chroma = {

const response = await collection.query({
queryEmbeddings: queryVector,
nResults: 4,
nResults: topN,
});
response.ids[0].forEach((_, i) => {
if (
Expand Down Expand Up @@ -271,6 +272,7 @@ const Chroma = {
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
Expand All @@ -289,7 +291,8 @@ const Chroma = {
client,
namespace,
queryVector,
similarityThreshold
similarityThreshold,
topN
);

const sources = sourceDocuments.map((metadata, i) => {
Expand Down
9 changes: 6 additions & 3 deletions server/utils/vectorDbProviders/lance/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ const LanceDb = {
client,
namespace,
queryVector,
similarityThreshold = 0.25
similarityThreshold = 0.25,
topN = 4
) {
const collection = await client.openTable(namespace);
const result = {
Expand All @@ -74,7 +75,7 @@ const LanceDb = {
const response = await collection
.search(queryVector)
.metricType("cosine")
.limit(5)
.limit(topN)
.execute();

response.forEach((item) => {
Expand Down Expand Up @@ -240,6 +241,7 @@ const LanceDb = {
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
Expand All @@ -258,7 +260,8 @@ const LanceDb = {
client,
namespace,
queryVector,
similarityThreshold
similarityThreshold,
topN
);

const sources = sourceDocuments.map((metadata, i) => {
Expand Down
8 changes: 6 additions & 2 deletions server/utils/vectorDbProviders/milvus/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ const Milvus = {
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
Expand All @@ -283,7 +284,8 @@ const Milvus = {
client,
namespace,
queryVector,
similarityThreshold
similarityThreshold,
topN
);

const sources = sourceDocuments.map((metadata, i) => {
Expand All @@ -299,7 +301,8 @@ const Milvus = {
client,
namespace,
queryVector,
similarityThreshold = 0.25
similarityThreshold = 0.25,
topN = 4
) {
const result = {
contextTexts: [],
Expand All @@ -309,6 +312,7 @@ const Milvus = {
const response = await client.search({
collection_name: namespace,
vectors: queryVector,
limit: topN,
});
response.results.forEach((match) => {
if (match.score < similarityThreshold) return;
Expand Down
9 changes: 6 additions & 3 deletions server/utils/vectorDbProviders/pinecone/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ const Pinecone = {
index,
namespace,
queryVector,
similarityThreshold = 0.25
similarityThreshold = 0.25,
topN = 4
) {
const result = {
contextTexts: [],
Expand All @@ -55,7 +56,7 @@ const Pinecone = {
queryRequest: {
namespace,
vector: queryVector,
topK: 4,
topK: topN,
includeMetadata: true,
},
});
Expand Down Expand Up @@ -237,6 +238,7 @@ const Pinecone = {
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
Expand All @@ -252,7 +254,8 @@ const Pinecone = {
pineconeIndex,
namespace,
queryVector,
similarityThreshold
similarityThreshold,
topN
);

const sources = sourceDocuments.map((metadata, i) => {
Expand Down
9 changes: 6 additions & 3 deletions server/utils/vectorDbProviders/qdrant/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ const QDrant = {
_client,
namespace,
queryVector,
similarityThreshold = 0.25
similarityThreshold = 0.25,
topN = 4
) {
const { client } = await this.connect();
const result = {
Expand All @@ -64,7 +65,7 @@ const QDrant = {

const responses = await client.search(namespace, {
vector: queryVector,
limit: 4,
limit: topN,
with_payload: true,
});

Expand Down Expand Up @@ -301,6 +302,7 @@ const QDrant = {
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
Expand All @@ -319,7 +321,8 @@ const QDrant = {
client,
namespace,
queryVector,
similarityThreshold
similarityThreshold,
topN
);

const sources = sourceDocuments.map((metadata, i) => {
Expand Down
9 changes: 6 additions & 3 deletions server/utils/vectorDbProviders/weaviate/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ const Weaviate = {
client,
namespace,
queryVector,
similarityThreshold = 0.25
similarityThreshold = 0.25,
topN = 4
) {
const result = {
contextTexts: [],
Expand All @@ -95,7 +96,7 @@ const Weaviate = {
.withClassName(camelCase(namespace))
.withFields(`${fields} _additional { id certainty }`)
.withNearVector({ vector: queryVector })
.withLimit(4)
.withLimit(topN)
.do();

const responses = queryResponse?.data?.Get?.[camelCase(namespace)];
Expand Down Expand Up @@ -347,6 +348,7 @@ const Weaviate = {
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
Expand All @@ -365,7 +367,8 @@ const Weaviate = {
client,
namespace,
queryVector,
similarityThreshold
similarityThreshold,
topN
);

const sources = sourceDocuments.map((metadata, i) => {
Expand Down
8 changes: 6 additions & 2 deletions server/utils/vectorDbProviders/zilliz/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ const Zilliz = {
input = "",
LLMConnector = null,
similarityThreshold = 0.25,
topN = 4,
}) {
if (!namespace || !input || !LLMConnector)
throw new Error("Invalid request to performSimilaritySearch.");
Expand All @@ -284,7 +285,8 @@ const Zilliz = {
client,
namespace,
queryVector,
similarityThreshold
similarityThreshold,
topN
);

const sources = sourceDocuments.map((metadata, i) => {
Expand All @@ -300,7 +302,8 @@ const Zilliz = {
client,
namespace,
queryVector,
similarityThreshold = 0.25
similarityThreshold = 0.25,
topN = 4
) {
const result = {
contextTexts: [],
Expand All @@ -310,6 +313,7 @@ const Zilliz = {
const response = await client.search({
collection_name: namespace,
vectors: queryVector,
limit: topN,
});
response.results.forEach((match) => {
if (match.score < similarityThreshold) return;
Expand Down

0 comments on commit 56fa17c

Please sign in to comment.