Skip to content

Commit

Permalink
pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
slava-vishnyakov committed Aug 6, 2024
1 parent 141b684 commit d5437ef
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 24 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ async def main():

# Add some sentences
sentences = ["This is a test sentence.", "Another example sentence."]
await rag.add(sentences)
rag.add(sentences)

# Search for similar sentences
results = await rag.search("test sentence", n=2)
results = rag.search("test sentence", n=2)
print(results)

# Run the async function
Expand Down
38 changes: 36 additions & 2 deletions rag_engine/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from openai import AsyncOpenAI
from openai import AsyncOpenAI, OpenAI
from typing import List, Tuple
import asyncio

Expand Down Expand Up @@ -44,7 +44,7 @@ async def get_embedding_async(text: str, api_key: str, model: str, size: int = N
response = await client.embeddings.create(**kwargs)
return response.data[0].embedding, model, size

async def get_embeddings(texts: List[str], api_key: str, model: str, size: int = None) -> List[Tuple[List[float], str, int]]:
async def get_embeddings_async(texts: List[str], api_key: str, model: str, size: int = None) -> List[Tuple[List[float], str, int]]:
"""
Asynchronously get embeddings for multiple texts using the specified model.
Expand All @@ -56,3 +56,37 @@ async def get_embeddings(texts: List[str], api_key: str, model: str, size: int =
"""
tasks = [get_embedding_async(text, api_key, model, size) for text in texts]
return await asyncio.gather(*tasks)

def get_embeddings(texts: List[str], api_key: str, model: str, size: int = None) -> List[Tuple[List[float], str, int]]:
    """
    Synchronously get embeddings for multiple texts using the specified model.

    Arguments are validated before the OpenAI client is created, so an invalid
    model or size is reported as ValueError without any client setup or
    network traffic taking place first.

    :param texts: List of input texts to embed. One request is issued per text, in order.
    :param api_key: OpenAI API key.
    :param model: The embedding model to use. Must be a key of MODEL_DIMENSIONS.
    :param size: The size of the embedding vector (only for non-ADA_002 models).
        When omitted, MODEL_DIMENSIONS[model] is used.
    :return: A list of tuples, each containing an embedding vector, model name, and vector size.
    :raises ValueError: If the model is unknown, if size is given for ADA_002,
        or if size is outside 1..MODEL_DIMENSIONS[model].
    """
    # Validate first: constructing the client may itself raise (for example
    # when no API key is available), which would hide the real argument error.
    if model not in MODEL_DIMENSIONS:
        raise ValueError(f"Invalid model: {model}")

    max_size = MODEL_DIMENSIONS[model]
    if size is None:
        size = max_size
    elif model == ADA_002:
        raise ValueError(f"Size parameter not supported for {ADA_002}")
    elif size < 1 or size > max_size:
        raise ValueError(f"Invalid size for {model}. Must be between 1 and {max_size}")

    # The "dimensions" argument is never sent for ADA_002; the size reported
    # for that model is simply its entry in MODEL_DIMENSIONS.
    kwargs = {"model": model}
    if model != ADA_002:
        kwargs["dimensions"] = size

    client = OpenAI(api_key=api_key)

    results = []
    for text in texts:
        response = client.embeddings.create(input=text, **kwargs)
        results.append((response.data[0].embedding, model, size))

    return results
8 changes: 4 additions & 4 deletions rag_engine/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, filename: str, api_key: str = None, model: str = ADA_002, siz
self.model = model
self.size = size if model != ADA_002 else None

async def add(self, sentences: List[Union[str, Dict[str, Union[int, str]]]]) -> List[int]:
def add(self, sentences: List[Union[str, Dict[str, Union[int, str]]]]) -> List[int]:
"""
Add sentences to the database and generate their embeddings.
Expand All @@ -40,19 +40,19 @@ async def add(self, sentences: List[Union[str, Dict[str, Union[int, str]]]]) ->
texts.append(item['text'])
ids.append(item['id'])

embeddings = await get_embeddings(texts, self.api_key, self.model, self.size)
embeddings = get_embeddings(texts, self.api_key, self.model, self.size)

return self.db.insert_embeddings(texts, embeddings, ids)

async def search(self, query: str, n: int = 5) -> List[Dict[str, Union[str, float]]]:
def search(self, query: str, n: int = 5) -> List[Dict[str, Union[str, float]]]:
"""
Search for similar sentences based on the query.
:param query: The search query.
:param n: Number of results to return. Default is 5.
:return: List of dictionaries containing similar sentences and their similarity scores.
"""
query_embedding, model, size = (await get_embeddings([query], self.api_key, self.model, self.size))[0]
query_embedding, model, size = get_embeddings([query], self.api_key, self.model, self.size)[0]
return self.db.search_similar(query_embedding, n)

def delete_ids(self, ids: List[int]) -> None:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
# Add the parent directory to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from rag_engine.embeddings import get_embeddings, ADA_002, SMALL_3, LARGE_3
from rag_engine.embeddings import get_embeddings_async, ADA_002, SMALL_3, LARGE_3

load_dotenv()

@pytest.mark.asyncio
async def test_get_embeddings_ada_002():
api_key = os.getenv("OPENAI_API_KEY")
texts = ["This is a test", "Another test"]
embeddings = await get_embeddings(texts, api_key, ADA_002)
embeddings = await get_embeddings_async(texts, api_key, ADA_002)

assert len(embeddings) == 2
assert len(embeddings[0][0]) == 1536 # ADA_002 returns 1536-dimensional vectors
Expand All @@ -28,7 +28,7 @@ async def test_get_embeddings_small_3():
pytest.skip("OPENAI_API_KEY not set")

texts = ["This is a test", "Another test"]
embeddings = await get_embeddings(texts, api_key, SMALL_3, size=512)
embeddings = await get_embeddings_async(texts, api_key, SMALL_3, size=512)

assert len(embeddings) == 2
assert len(embeddings[0][0]) == 512
Expand All @@ -42,7 +42,7 @@ async def test_get_embeddings_large_3():
pytest.skip("OPENAI_API_KEY not set")

texts = ["This is a test", "Another test"]
embeddings = await get_embeddings(texts, api_key, LARGE_3)
embeddings = await get_embeddings_async(texts, api_key, LARGE_3)

assert len(embeddings) == 2
assert len(embeddings[0][0]) == 3072 # LARGE_3 returns 3072-dimensional vectors by default
Expand All @@ -56,7 +56,7 @@ async def test_invalid_model():
pytest.skip("OPENAI_API_KEY not set")

with pytest.raises(ValueError):
await get_embeddings(["Test"], api_key, "invalid_model")
await get_embeddings_async(["Test"], api_key, "invalid_model")

@pytest.mark.asyncio
async def test_invalid_size():
Expand All @@ -65,4 +65,4 @@ async def test_invalid_size():
pytest.skip("OPENAI_API_KEY not set")

with pytest.raises(ValueError):
await get_embeddings(["Test"], api_key, SMALL_3, size=5000)
await get_embeddings_async(["Test"], api_key, SMALL_3, size=5000)
20 changes: 10 additions & 10 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def rag():
@pytest.mark.asyncio
async def test_add_and_search(rag):
sentences = ["This is a test sentence.", "Another example sentence."]
await rag.add(sentences)
rag.add(sentences)

results = await rag.search("test sentence", n=2)
results = rag.search("test sentence", n=2)
assert len(results) == 2
assert results[0]['text'] == "This is a test sentence."

Expand All @@ -34,11 +34,11 @@ async def test_delete_ids(rag):
{"id": 1, "text": "Delete me"},
{"id": 2, "text": "Keep me"}
]
await rag.add(sentences)
rag.add(sentences)

rag.delete_ids([1])

results = await rag.search("Delete", n=2)
results = rag.search("Delete", n=2)
assert len(results) == 1
assert results[0]['text'] == "Keep me"

Expand All @@ -55,11 +55,11 @@ async def test_different_models():

sentences = ["This is a test sentence."]

await rag_ada.add(sentences)
await rag_small.add(sentences)
rag_ada.add(sentences)
rag_small.add(sentences)

results_ada = await rag_ada.search("test", n=1)
results_small = await rag_small.search("test", n=1)
results_ada = rag_ada.search("test", n=1)
results_small = rag_small.search("test", n=1)

assert len(results_ada) == 1
assert len(results_small) == 1
Expand All @@ -79,11 +79,11 @@ async def test_model_consistency():
rag = RAGEngine(db_file, api_key, model=ADA_002)

sentences = ["This is a test sentence."]
await rag.add(sentences)
rag.add(sentences)

# Try to create a new RAGEngine with a different model on the same database
with pytest.raises(ValueError):
rag_small = RAGEngine(db_file, api_key, model=SMALL_3)
await rag_small.add(sentences)
rag_small.add(sentences)

os.remove(db_file)

0 comments on commit d5437ef

Please sign in to comment.