Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for weaviate vector store #474

Closed
2 changes: 1 addition & 1 deletion gptcache/manager/scalar_data/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,4 +324,4 @@ def report_cache(

def close(self):
    """Disconnect mongoengine and close the underlying connection.

    The original called ``self.con.close()`` twice (a duplicated line);
    the connection only needs to be closed once.
    """
    me.disconnect()
    self.con.close()
26 changes: 26 additions & 0 deletions gptcache/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,32 @@ def get(name, **kwargs):
flush_interval_sec=flush_interval_sec,
index_params=index_params,
)
elif name == "weaviate":
from gptcache.manager.vector_data.weaviate import Weaviate
url = kwargs.get("url", None)
auth_client_secret = kwargs.get('auth_client_secret', None),
timeout_config = kwargs.get("timeout_config", (10, 60))
proxies = kwargs.get("proxies", None)
trust_env = kwargs.get("trust_env", False)
additional_headers = kwargs.get("additional_headers", None)
startup_period = kwargs.get("startup_period", 5)
embedded_options = kwargs.get("embedded_options", None)
additional_config = kwargs.get("additional_config", None)
class_name = kwargs.get("class_name", "Gptcache")
top_k = kwargs.get("top_k", 1)
vector_base = Weaviate(
url= url,
auth_client_secret = auth_client_secret,
timeout_config = timeout_config,
proxies = proxies,
trust_env = trust_env,
additional_headers = additional_headers,
startup_period = startup_period,
embedded_options = embedded_options,
additional_config = additional_config,
class_name = class_name,
top_k = top_k,
)
else:
raise NotFoundError("vector store", name)
return vector_base
131 changes: 131 additions & 0 deletions gptcache/manager/vector_data/weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import List, Optional, Union

import numpy as np

from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_weaviate
from gptcache.utils.log import gptcache_log

import_weaviate()

from weaviate import Client, EmbeddedOptions, Config


class Weaviate(VectorBase):
    """Weaviate vector store backend for GPTCache.

    :param url: endpoint of a running Weaviate instance; when falsy, an
        embedded Weaviate instance is started instead.
    :param auth_client_secret: credentials object forwarded to the client.
    :param timeout_config: (connect, read) timeout configuration.
    :param proxies: proxy configuration forwarded to the client.
    :param trust_env: whether the client reads proxy settings from the env.
    :param additional_headers: extra HTTP headers sent with every request.
    :param startup_period: seconds to wait for the instance to become ready.
    :param embedded_options: options for the embedded instance
        (only used for the remote-URL branch; the embedded branch always
        builds a fresh ``EmbeddedOptions()``).
    :param additional_config: extra client configuration.
    :param top_k: default number of results returned by :meth:`search`.
    :param distance: vector index distance metric (e.g. ``"cosine"``).
    :param class_name: name of the Weaviate class used for caching.
    """

    def __init__(
        self,
        url: str = None,
        auth_client_secret=None,
        timeout_config=(10, 60),
        proxies: Optional[Union[dict, str]] = None,
        trust_env: bool = False,
        additional_headers: Optional[dict] = None,
        startup_period: Optional[int] = 5,
        embedded_options=None,
        additional_config=None,
        top_k: int = 1,
        distance: str = "cosine",
        class_name: str = "Gptcache",
    ):
        self.class_name = class_name
        self.top_k = top_k
        self.distance = distance
        if not url:
            # No URL given: spin up an embedded Weaviate instance.
            self.client = Client(
                embedded_options=EmbeddedOptions(),
                startup_period=startup_period,
                timeout_config=timeout_config,
                additional_config=additional_config,
            )
        else:
            self.client = Client(
                url,
                auth_client_secret,
                timeout_config,
                proxies,
                trust_env,
                additional_headers,
                startup_period,
                embedded_options,
                additional_config,
            )

    def _create_collection(self, class_name: str):
        """Create the cache class in Weaviate unless it already exists."""
        if not class_name:
            class_name = self.class_name
        if self.client.schema.exists(class_name):
            gptcache_log.info(
                "The %s already exists, and it will be used directly", class_name
            )
        else:
            gptcache_class_schema = {
                "class": class_name,
                "description": "caching LLM responses",
                "properties": [
                    {
                        # GPTCache's integer id is stored as a payload
                        # property; Weaviate assigns its own uuid per object.
                        "name": "id_",
                        "dataType": ["int"],
                    }
                ],
                "vectorIndexConfig": {"distance": self.distance},
            }
            self.client.schema.create_class(gptcache_class_schema)

    def mul_add(self, datas: List[VectorData]):
        """Batch-insert vectors together with their GPTCache ids."""
        # Ensure the target class exists before writing, mirroring the lazy
        # creation done in search(); otherwise Weaviate would auto-create a
        # class with an inferred schema that may not match ours.
        if not self.client.schema.exists(self.class_name):
            self._create_collection(self.class_name)
        with self.client.batch(batch_size=len(datas)) as batch:
            for data in datas:
                properties = {"id_": data.id}
                # Use the batch object yielded by the context manager so the
                # queued objects are flushed when the `with` block exits.
                batch.add_data_object(
                    properties,
                    self.class_name,
                    vector=data.data.tolist(),
                )

    def search(self, data: np.ndarray, top_k: int = -1):
        """Return ``[(distance, id_), ...]`` for the nearest stored vectors."""
        if not self.client.schema.exists(self.class_name):
            self._create_collection(self.class_name)
        if top_k == -1:
            top_k = self.top_k
        result = (
            self.client.query.get(class_name=self.class_name, properties=["id_"])
            .with_near_vector(content={"vector": data.tolist()})
            .with_additional(["distance"])
            .with_limit(top_k)
            .do()
        )
        return list(
            map(
                lambda x: (x["_additional"]["distance"], x["id_"]),
                result["data"]["Get"][self.class_name],
            )
        )

    def get_uuids(self, ids: List[str]):
        """Map GPTCache ids to the Weaviate uuids of the stored objects."""
        uuid_list = []
        for id_ in ids:
            # `id_` is declared with dataType ["int"], so the where-filter
            # must use `valueInt`; the original used `valueNumber`, which is
            # for float/number properties and never matches an int field.
            res = (
                self.client.query.get(class_name=self.class_name, properties=["id_"])
                .with_where({"path": ["id_"], "operator": "Equal", "valueInt": id_})
                .with_additional(["id"])
                .do()
            )
            uuid_list.append(
                res["data"]["Get"][self.class_name][0]["_additional"]["id"]
            )
        return uuid_list

    def delete(self, ids: List[str]):
        """Delete the objects whose GPTCache ids are listed in ``ids``."""
        uuids = self.get_uuids(ids)
        for uuid_ in uuids:
            self.client.data_object.delete(class_name=self.class_name, uuid=uuid_)

    def rebuild(self, ids=None):
        """No-op: Weaviate maintains its index incrementally."""
        return

    def flush(self):
        """No-op: batch inserts are flushed by the batch context manager."""
        return True

    def close(self):
        """No-op: no explicit client shutdown is performed here."""
        pass






8 changes: 7 additions & 1 deletion gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"import_fastapi",
"import_redis",
"import_qdrant",
"import_weaviate"
]

import importlib.util
Expand Down Expand Up @@ -116,7 +117,7 @@ def import_hnswlib():


def import_chromadb():
    """Ensure chromadb is importable, pinning the install hint to 0.3.26.

    The original contained the `_check_library` call twice (diff artifact:
    old and new line both present); only the pinned variant is kept.
    """
    _check_library("chromadb", package="chromadb==0.3.26")


def import_sqlalchemy():
Expand Down Expand Up @@ -260,5 +261,10 @@ def import_redis():
_check_library("redis_om")


def import_weaviate():
    """Ensure the weaviate client library is importable.

    The importable module is ``weaviate`` while the PyPI distribution is
    named ``weaviate-client``, so the two names must be passed separately
    (the original checked for a module literally named ``weaviate-client``,
    which can never be found). Mirrors the pattern of import_chromadb.
    """
    _check_library("weaviate", package="weaviate-client")


def import_starlette():
    """Ensure the starlette library is importable, installing it if needed."""
    _check_library("starlette")

2 changes: 2 additions & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ grpcio==1.53.0
protobuf==3.20.0
milvus==2.2.8
pymilvus==2.2.8
pymongo
mongoengine
30 changes: 30 additions & 0 deletions tests/unit_tests/manager/test_weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import unittest

import numpy as np

from gptcache.manager.vector_data import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestWeaviate(unittest.TestCase):
    """Tests for the weaviate VectorBase backend.

    Renamed from ``TestUSearchDB`` — the original class name was
    copy-pasted from the USearch tests but exercises weaviate.
    Discovery is unaffected: unittest/pytest match the ``Test*``/``test_*``
    prefixes, not exact names.
    """

    def test_normal(self):
        size = 1000
        dim = 512
        top_k = 10
        weaviate = VectorBase(
            "weaviate",
            top_k=top_k,
        )
        data = np.random.randn(size, dim).astype(np.float32)
        # Insert `size` random vectors with ids 0..size-1.
        weaviate.mul_add([VectorData(id=i, data=v) for i, v in enumerate(data)])
        search_result = weaviate.search(data[0], top_k)
        self.assertEqual(len(search_result), top_k)
        # Add a duplicate of vector 0 under a new id; both should rank first.
        weaviate.mul_add([VectorData(id=size, data=data[0])])
        ret = weaviate.search(data[0])
        self.assertIn(ret[0][1], [0, size])
        self.assertIn(ret[1][1], [0, size])
        # After deleting both copies, neither id may appear in results.
        weaviate.delete([0, 1, 2, 3, 4, 5, size])
        ret = weaviate.search(data[0])
        self.assertNotIn(ret[0][1], [0, size])
        weaviate.rebuild()
        weaviate.close()