From c2d74188ee768ef9c3040239dee2447f3f209353 Mon Sep 17 00:00:00 2001
From: Sarthak Deshpande
Date: Wed, 23 Oct 2024 13:16:36 +0530
Subject: [PATCH 1/3] Added Pinecone Memory Adapter

---
 .../adapters/memory/pinecone/__init__.py  |  14 ++
 .../adapters/memory/pinecone/config.py    |  17 ++
 .../adapters/memory/pinecone/pinecone.py  | 195 ++++++++++++++++++
 3 files changed, 226 insertions(+)
 create mode 100644 llama_stack/providers/adapters/memory/pinecone/__init__.py
 create mode 100644 llama_stack/providers/adapters/memory/pinecone/config.py
 create mode 100644 llama_stack/providers/adapters/memory/pinecone/pinecone.py

diff --git a/llama_stack/providers/adapters/memory/pinecone/__init__.py b/llama_stack/providers/adapters/memory/pinecone/__init__.py
new file mode 100644
index 00000000..d91442e1
--- /dev/null
+++ b/llama_stack/providers/adapters/memory/pinecone/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import PineconeConfig, PineconeRequestProviderData  # noqa: F401
+from .pinecone import PineconeMemoryAdapter
+
+
+async def get_adapter_impl(config: PineconeConfig, _deps):
+    impl = PineconeMemoryAdapter(config)
+    await impl.initialize()
+    return impl

diff --git a/llama_stack/providers/adapters/memory/pinecone/config.py b/llama_stack/providers/adapters/memory/pinecone/config.py
new file mode 100644
index 00000000..8043e0a5
--- /dev/null
+++ b/llama_stack/providers/adapters/memory/pinecone/config.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class PineconeRequestProviderData(BaseModel):
+    pinecone_api_key: str
+
+
+class PineconeConfig(BaseModel):
+    dimensions: int
+    cloud: str
+    region: str

diff --git a/llama_stack/providers/adapters/memory/pinecone/pinecone.py b/llama_stack/providers/adapters/memory/pinecone/pinecone.py
new file mode 100644
index 00000000..0cade2b1
--- /dev/null
+++ b/llama_stack/providers/adapters/memory/pinecone/pinecone.py
@@ -0,0 +1,195 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+
+from numpy.typing import NDArray
+from pinecone import ServerlessSpec
+from pinecone.grpc import PineconeGRPC as Pinecone
+
+from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.utils.memory.vector_store import (
+    BankWithIndex,
+    EmbeddingIndex,
+)
+from .config import PineconeConfig, PineconeRequestProviderData
+
+
+class PineconeIndex(EmbeddingIndex):
+    def __init__(self, client: Pinecone, index_name: str):
+        self.client = client
+        self.index_name = index_name
+
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        assert len(chunks) == len(
+            embeddings
+        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+
+        data_objects = []
+        for i, chunk in enumerate(chunks):
+            data_objects.append(
+                {
+                    "id": f"vec{i+1}",
+                    "values": embeddings[i].tolist(),
+                    "metadata": {"chunk": chunk},
+                }
+            )
+
+        # Inserting chunks into the prespecified Pinecone index
+        index = self.client.Index(self.index_name)
+        index.upsert(vectors=data_objects)
+
+    async def query(
+        self, embedding: NDArray, k: int, score_threshold: float
+    ) -> QueryDocumentsResponse:
+        index = self.client.Index(self.index_name)
+
+        results = index.query(
+            vector=embedding, top_k=k, include_values=True, include_metadata=True
+        )
+
+        chunks = []
+        scores = []
+        for doc in results["matches"]:
+            chunk_json = doc["metadata"]["chunk"]
+            try:
+                chunk_dict = json.loads(chunk_json)
+                chunk = Chunk(**chunk_dict)
+            except Exception:
+                import traceback
+
+                traceback.print_exc()
+                print(f"Failed to parse document: {chunk_json}")
+                continue
+
+            chunks.append(chunk)
+            scores.append(doc.score)
+
+        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+
+
+class PineconeMemoryAdapter(
+    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+):
+    def __init__(self, config: PineconeConfig) -> None:
+        self.config = config
+        self.client_cache = {}
+        self.cache = {}
+
+    def _get_client(self) -> Pinecone:
+        provider_data = self.get_request_provider_data()
+        assert provider_data is not None, "Request provider data must be set"
+        assert isinstance(provider_data, PineconeRequestProviderData)
+
+        key = provider_data.pinecone_api_key
+        if key in self.client_cache:
+            return self.client_cache[key]
+
+        client = Pinecone(api_key=provider_data.pinecone_api_key)
+        self.client_cache[key] = client
+        return client
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    def check_if_index_exists(
+        self,
+        client: Pinecone,
+        index_name: str,
+    ) -> bool:
+        try:
+            # Get list of all indexes
+            active_indexes = client.list_indexes()
+            for index in active_indexes:
+                if index["name"] == index_name:
+                    return True
+            return False
+        except Exception as e:
+            print(f"Error checking index: {e}")
+            return False
+
+    async def register_memory_bank(
+        self,
+        memory_bank: MemoryBankDef,
+    ) -> None:
+        assert (
+            memory_bank.type == MemoryBankType.vector.value
+        ), f"Only vector banks are supported, got {memory_bank.type}"
+
+        client = self._get_client()
+
+        # Create the index if it doesn't exist
+        if not self.check_if_index_exists(client, memory_bank.identifier):
+            client.create_index(
+                name=memory_bank.identifier,
+                dimension=self.config.dimensions if self.config.dimensions else 1024,
+                metric="cosine",
+                spec=ServerlessSpec(
+                    cloud=self.config.cloud if self.config.cloud else "aws",
+                    region=self.config.region if self.config.region else "us-east-1",
+                ),
+            )
+
+        index = BankWithIndex(
+            bank=memory_bank,
+            index=PineconeIndex(client=client, index_name=memory_bank.identifier),
+        )
+        self.cache[memory_bank.identifier] = index
+
+    async def list_memory_banks(self) -> List[MemoryBankDef]:
+        # TODO: right now the Llama Stack is the source of truth for these banks. That is
+        # not ideal. It should be Weaviate which is the source of truth. Unfortunately,
+        # list() happens at Stack startup when the Pinecone client (credentials) is not
+        # yet available. We need to figure out a way to make this work.
+        return [i.bank for i in self.cache.values()]
+
+    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
+        if bank_id in self.cache:
+            return self.cache[bank_id]
+
+        bank = await self.memory_bank_store.get_memory_bank(bank_id)
+        if not bank:
+            raise ValueError(f"Bank {bank_id} not found")
+
+        client = self._get_client()
+        if not self.check_if_index_exists(client, bank_id):
+            raise ValueError(f"Index with name `{bank_id}` not found")
+
+        index = BankWithIndex(
+            bank=bank,
+            index=PineconeIndex(client=client, index_name=bank_id),
+        )
+        self.cache[bank_id] = index
+        return index
+
+    async def insert_documents(
+        self,
+        bank_id: str,
+        documents: List[MemoryBankDocument],
+        ttl_seconds: Optional[int] = None,
+    ) -> None:
+        index = await self._get_and_cache_bank_index(bank_id)
+        if not index:
+            raise ValueError(f"Bank {bank_id} not found")
+
+        await index.insert_documents(documents)
+
+    async def query_documents(
+        self,
+        bank_id: str,
+        query: InterleavedTextMedia,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryDocumentsResponse:
+        index = await self._get_and_cache_bank_index(bank_id)
+        if not index:
+            raise ValueError(f"Bank {bank_id} not found")
+
+        return await index.query_documents(query, params)

From 07e9da19b311f1e6a8695b681182db9d54b3792e Mon Sep 17 00:00:00 2001
From: Sarthak Deshpande
Date: Wed, 23 Oct 2024 23:45:01 +0530
Subject: [PATCH 2/3] Added in registry and tests passed

---
 .../adapters/memory/pinecone/config.py        |  6 ++---
 .../adapters/memory/pinecone/pinecone.py      | 27 +++++++++++--------
 llama_stack/providers/registry/memory.py      | 10 +++++++
 .../tests/memory/provider_config_example.yaml |  8 +++---
 .../providers/tests/memory/test_memory.py     | 16 +++++------
 5 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/llama_stack/providers/adapters/memory/pinecone/config.py b/llama_stack/providers/adapters/memory/pinecone/config.py
index 8043e0a5..8e66eefe 100644
--- a/llama_stack/providers/adapters/memory/pinecone/config.py
+++ b/llama_stack/providers/adapters/memory/pinecone/config.py
@@ -12,6 +12,6 @@ class PineconeRequestProviderData(BaseModel):
 
 
 class PineconeConfig(BaseModel):
-    dimensions: int
-    cloud: str
-    region: str
+    dimension: int = 384
+    cloud: str = "aws"
+    region: str = "us-east-1"

diff --git a/llama_stack/providers/adapters/memory/pinecone/pinecone.py b/llama_stack/providers/adapters/memory/pinecone/pinecone.py
index 0cade2b1..acc2a8a9 100644
--- a/llama_stack/providers/adapters/memory/pinecone/pinecone.py
+++ b/llama_stack/providers/adapters/memory/pinecone/pinecone.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
+import asyncio
 
 from numpy.typing import NDArray
 from pinecone import ServerlessSpec
@@ -34,15 +34,20 @@ async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
         for i, chunk in enumerate(chunks):
             data_objects.append(
                 {
-                    "id": f"vec{i+1}",
+                    "id": chunk.document_id,
                     "values": embeddings[i].tolist(),
-                    "metadata": {"chunk": chunk},
+                    "metadata": {
+                        "content": chunk.content,
+                        "token_count": chunk.token_count,
+                        "document_id": chunk.document_id,
+                    },
                 }
             )
 
         # Inserting chunks into the prespecified Pinecone index
         index = self.client.Index(self.index_name)
         index.upsert(vectors=data_objects)
+        await asyncio.sleep(1)  # upserts are eventually consistent; wait without blocking the event loop
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
@@ -50,16 +55,16 @@ async def query(
         index = self.client.Index(self.index_name)
 
         results = index.query(
-            vector=embedding, top_k=k, include_values=True, include_metadata=True
+            vector=embedding, top_k=k, include_values=False, include_metadata=True
         )
 
         chunks = []
         scores = []
         for doc in results["matches"]:
-            chunk_json = doc["metadata"]["chunk"]
+            chunk_json = doc["metadata"]
+            print(f"chunk_json: {chunk_json}")
             try:
-                chunk_dict = json.loads(chunk_json)
-                chunk = Chunk(**chunk_dict)
+                chunk = Chunk(**chunk_json)
             except Exception:
                 import traceback
 
@@ -130,11 +135,11 @@ async def register_memory_bank(
         if not self.check_if_index_exists(client, memory_bank.identifier):
             client.create_index(
                 name=memory_bank.identifier,
-                dimension=self.config.dimensions if self.config.dimensions else 1024,
+                dimension=self.config.dimension,
                 metric="cosine",
                 spec=ServerlessSpec(
-                    cloud=self.config.cloud if self.config.cloud else "aws",
-                    region=self.config.region if self.config.region else "us-east-1",
+                    cloud=self.config.cloud,
+                    region=self.config.region,
                 ),
             )
 
@@ -146,7 +151,7 @@ async def list_memory_banks(self) -> List[MemoryBankDef]:
         # TODO: right now the Llama Stack is the source of truth for these banks. That is
-        # not ideal. It should be Weaviate which is the source of truth. Unfortunately,
+        # not ideal. It should be Pinecone which is the source of truth. Unfortunately,
         # list() happens at Stack startup when the Pinecone client (credentials) is not
         # yet available. We need to figure out a way to make this work.
         return [i.bank for i in self.cache.values()]

diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py
index a0fbf163..62b07e9c 100644
--- a/llama_stack/providers/registry/memory.py
+++ b/llama_stack/providers/registry/memory.py
@@ -84,4 +84,14 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.adapters.memory.qdrant.QdrantConfig",
         ),
     ),
+    remote_provider_spec(
+        Api.memory,
+        AdapterSpec(
+            adapter_type="pinecone",
+            pip_packages=EMBEDDING_DEPS + ["pinecone"],
+            module="llama_stack.providers.adapters.memory.pinecone",
+            config_class="llama_stack.providers.adapters.memory.pinecone.PineconeConfig",
+            provider_data_validator="llama_stack.providers.adapters.memory.pinecone.PineconeRequestProviderData",
+        ),
+    ),
 ]

diff --git a/llama_stack/providers/tests/memory/provider_config_example.yaml b/llama_stack/providers/tests/memory/provider_config_example.yaml
index 13575a59..da226d69 100644
--- a/llama_stack/providers/tests/memory/provider_config_example.yaml
+++ b/llama_stack/providers/tests/memory/provider_config_example.yaml
@@ -20,10 +20,12 @@ providers:
     config:
       host: localhost
       port: 6333
+  - provider_id: test-pinecone
+    provider_type: remote::pinecone
+    config: {}
 # if a provider needs private keys from the client, they use the
 # "get_request_provider_data" function (see distribution/request_headers.py)
 # this is a place to provide such data.
 provider_data:
-  "test-weaviate":
-    weaviate_api_key: 0xdeadbeefputrealapikeyhere
-    weaviate_cluster_url: http://foobarbaz
+  "test-pinecone":
+    pinecone_api_key:

diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
index b26bf75a..7043772d 100644
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ b/llama_stack/providers/tests/memory/test_memory.py
@@ -69,7 +69,7 @@ def sample_documents():
 
 async def register_memory_bank(banks_impl: MemoryBanks):
     bank = VectorMemoryBankDef(
-        identifier="test_bank",
+        identifier="test-bank",
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
@@ -95,7 +95,7 @@ async def test_banks_register(memory_settings):
     # but so far we don't have an unregister API unfortunately, so be careful
     banks_impl = memory_settings["memory_banks_impl"]
     bank = VectorMemoryBankDef(
-        identifier="test_bank_no_provider",
+        identifier="test-bank-no-provider",
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
@@ -119,33 +119,33 @@ async def test_query_documents(memory_settings, sample_documents):
     banks_impl = memory_settings["memory_banks_impl"]
 
     with pytest.raises(ValueError):
-        await memory_impl.insert_documents("test_bank", sample_documents)
+        await memory_impl.insert_documents("test-bank", sample_documents)
 
     await register_memory_bank(banks_impl)
-    await memory_impl.insert_documents("test_bank", sample_documents)
+    await memory_impl.insert_documents("test-bank", sample_documents)
 
     query1 = "programming language"
-    response1 = await memory_impl.query_documents("test_bank", query1)
+    response1 = await memory_impl.query_documents("test-bank", query1)
     assert_valid_response(response1)
     assert any("Python" in chunk.content for chunk in response1.chunks)
 
     # Test case 3: Query with semantic similarity
     query3 = "AI and brain-inspired computing"
-    response3 = await memory_impl.query_documents("test_bank", query3)
+    response3 = await memory_impl.query_documents("test-bank", query3)
     assert_valid_response(response3)
     assert any("neural networks" in chunk.content.lower() for chunk in response3.chunks)
 
     # Test case 4: Query with limit on number of results
     query4 = "computer"
     params4 = {"max_chunks": 2}
-    response4 = await memory_impl.query_documents("test_bank", query4, params4)
+    response4 = await memory_impl.query_documents("test-bank", query4, params4)
     assert_valid_response(response4)
     assert len(response4.chunks) <= 2
 
     # Test case 5: Query with threshold on similarity score
     query5 = "quantum computing"  # Not directly related to any document
     params5 = {"score_threshold": 0.2}
-    response5 = await memory_impl.query_documents("test_bank", query5, params5)
+    response5 = await memory_impl.query_documents("test-bank", query5, params5)
     assert_valid_response(response5)
     print("The scores are:", response5.scores)
     assert all(score >= 0.2 for score in response5.scores)

From 9d630601b9796e5547ecb7554a3fd6db83563bd6 Mon Sep 17 00:00:00 2001
From: Sarthak Deshpande
Date: Thu, 24 Oct 2024 14:02:48 +0530
Subject: [PATCH 3/3] print statements removed

---
 llama_stack/providers/adapters/memory/pinecone/pinecone.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llama_stack/providers/adapters/memory/pinecone/pinecone.py b/llama_stack/providers/adapters/memory/pinecone/pinecone.py
index acc2a8a9..40547242 100644
--- a/llama_stack/providers/adapters/memory/pinecone/pinecone.py
+++ b/llama_stack/providers/adapters/memory/pinecone/pinecone.py
@@ -62,7 +62,6 @@ async def query(
         scores = []
         for doc in results["matches"]:
             chunk_json = doc["metadata"]
-            print(f"chunk_json: {chunk_json}")
             try:
                 chunk = Chunk(**chunk_json)
             except Exception:
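
A few usage sketches for reviewers follow; they are not part of the patches. First, a minimal standalone round-trip that mirrors what the adapter does against Pinecone, using only the client calls that appear in the diff. It assumes a `PINECONE_API_KEY` environment variable, a toy 384-dimension embedding, and the illustrative index name `test-bank`:

```python
import os
import time

from pinecone import ServerlessSpec
from pinecone.grpc import PineconeGRPC as Pinecone

client = Pinecone(api_key=os.environ["PINECONE_API_KEY"])

# Same defaults the adapter uses: 384-dim vectors, cosine metric, aws/us-east-1.
if not any(ix["name"] == "test-bank" for ix in client.list_indexes()):
    client.create_index(
        name="test-bank",
        dimension=384,
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )

index = client.Index("test-bank")

# Upsert one toy vector shaped like the records add_chunks() writes.
index.upsert(
    vectors=[
        {
            "id": "doc-1",
            "values": [0.1] * 384,
            "metadata": {"content": "hello", "token_count": 1, "document_id": "doc-1"},
        }
    ]
)
time.sleep(1)  # upserts are eventually consistent, hence the sleep in the patch

# Query it back; matches carry the metadata the adapter turns into Chunks.
results = index.query(vector=[0.1] * 384, top_k=1, include_metadata=True)
for match in results["matches"]:
    print(match["metadata"]["content"], match.score)
```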
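The test renames from `test_bank` to `test-bank` track Pinecone's naming rule: index names may contain only lowercase alphanumeric characters and hyphens. A hypothetical helper (not in the patch) that makes the constraint explicit:

```python
import re

# Pinecone index names: lowercase alphanumeric segments separated by hyphens.
_INDEX_NAME_RE = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")


def to_pinecone_index_name(identifier: str) -> str:
    # Hypothetical normalization, not part of the patch: lowercase the bank
    # identifier and swap underscores for hyphens, then validate.
    candidate = identifier.lower().replace("_", "-")
    if not _INDEX_NAME_RE.fullmatch(candidate):
        raise ValueError(f"cannot derive a Pinecone index name from {identifier!r}")
    return candidate


assert to_pinecone_index_name("test_bank") == "test-bank"
```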
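Finally, the adapter reads credentials per request via `NeedsRequestProviderData`. Below is a sketch of passing the key from a client through the `X-LlamaStack-ProviderData` header parsed by `distribution/request_headers.py`; the URL, port, and route are assumptions for a locally running stack and may differ in your deployment:

```python
import json

import requests

# Assumed local endpoint; adjust route and port to your stack deployment.
url = "http://localhost:5000/memory/query"
headers = {
    # JSON payload validated server-side against PineconeRequestProviderData.
    "X-LlamaStack-ProviderData": json.dumps({"pinecone_api_key": "<your-api-key>"}),
}
body = {"bank_id": "test-bank", "query": "programming language"}

resp = requests.post(url, headers=headers, json=body)
resp.raise_for_status()
print(resp.json())
```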