Skip to content

Commit

Permalink
Use semantic cache
Browse files Browse the repository at this point in the history
  • Loading branch information
yankeexe committed Jan 29, 2025
1 parent 2dbe3b0 commit a9e9f72
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 27 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ SHELL :=/bin/bash
.PHONY: clean check setup
.DEFAULT_GOAL=help
VENV_DIR = .venv
PYTHON_VERSION=python3.11

check: # Ruff check
@ruff check .
Expand All @@ -23,7 +24,7 @@ run: # Run the application

setup: # Initial project setup
@echo "Creating virtual env at: $(VENV_DIR)"s
@python3 -m venv $(VENV_DIR)
@$(PYTHON_VERSION) -m venv $(VENV_DIR)
@echo "Installing dependencies..."
@source $(VENV_DIR)/bin/activate && pip install -r requirements/requirements-dev.txt && pip install -r requirements/requirements.txt
@echo -e "\n✅ Done.\n🎉 Run the following commands to get started:\n\n ➡️ source $(VENV_DIR)/bin/activate\n ➡️ make run\n"
Expand Down
129 changes: 107 additions & 22 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import csv
import os
import sys
import tempfile
from io import StringIO

import chromadb
import ollama
Expand All @@ -9,6 +12,8 @@
)
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from langchain_ollama import OllamaEmbeddings
from langchain_redis import RedisConfig, RedisVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder
from streamlit.runtime.uploaded_file_manager import UploadedFile
Expand Down Expand Up @@ -68,7 +73,26 @@ def process_document(uploaded_file: UploadedFile) -> list[Document]:
return text_splitter.split_documents(docs)


def get_vector_collection() -> chromadb.Collection:
def get_redis_store() -> RedisVectorStore:
    """Builds the Redis vector store used for the semantic cache.

    Embeddings are produced with the Ollama ``nomic-embed-text:latest`` model.
    The store targets the ``cached_contents`` index on the Redis instance at
    ``redis://localhost:6379``, uses the COSINE distance metric, and declares a
    single text metadata field named ``answer``.

    Returns:
        RedisVectorStore: A store configured as described above.
    """
    embedder = OllamaEmbeddings(model="nomic-embed-text:latest")

    # Each stored document carries its answer in the "answer" metadata field.
    answer_field = {"name": "answer", "type": "text"}
    cache_config = RedisConfig(
        index_name="cached_contents",
        redis_url="redis://localhost:6379",
        distance_metric="COSINE",
        metadata_schema=[answer_field],
    )
    return RedisVectorStore(embedder, config=cache_config)


def get_vector_collection(
collection_name: str,
) -> chromadb.Collection:
"""Gets or creates a ChromaDB collection for vector storage.
Creates an Ollama embedding function using the nomic-embed-text model and initializes
Expand All @@ -84,15 +108,17 @@ def get_vector_collection() -> chromadb.Collection:
model_name="nomic-embed-text:latest",
)

chroma_client = chromadb.PersistentClient(path="./demo-rag-chroma")
chroma_client = chromadb.PersistentClient(path="./demo-rag-chroma-db")
return chroma_client.get_or_create_collection(
name="rag_app",
name=collection_name,
embedding_function=ollama_ef,
metadata={"hnsw:space": "cosine"},
)


def add_to_vector_collection(all_splits: list[Document], file_name: str):
def add_to_vector_collection(
collection_name: str, all_splits: list[Document], file_name: str
):
"""Adds document splits to a vector collection for semantic search.
Takes a list of document splits and adds them to a ChromaDB vector collection
Expand All @@ -108,7 +134,7 @@ def add_to_vector_collection(all_splits: list[Document], file_name: str):
Raises:
ChromaDBError: If there are issues upserting documents to the collection
"""
collection = get_vector_collection()
collection = get_vector_collection(collection_name)
documents, metadatas, ids = [], [], []

for idx, split in enumerate(all_splits):
Expand All @@ -124,7 +150,7 @@ def add_to_vector_collection(all_splits: list[Document], file_name: str):
st.success("Data added to the vector store!")


def query_collection(prompt: str, n_results: int = 10):
def query_collection(collection_name: str, prompt: str, n_results: int = 10):
"""Queries the vector collection with a given prompt to retrieve relevant documents.
Args:
Expand All @@ -137,11 +163,24 @@ def query_collection(prompt: str, n_results: int = 10):
Raises:
ChromaDBError: If there are issues querying the collection.
"""
collection = get_vector_collection()
collection = get_vector_collection(collection_name)
results = collection.query(query_texts=[prompt], n_results=n_results)
return results


def query_semantic_cache(query: str, n_results: int = 1, threshold: float = 80.0):
    """Looks up the semantic cache for an entry similar enough to the query.

    Runs a scored similarity search against the store returned by
    ``get_redis_store()`` and converts the score of the top hit into a
    percentage via ``(1 - abs(score)) * 100``.

    Args:
        query: The text to search the cache for.
        n_results: Number of hits to request from the store (``k``).
        threshold: Minimum percentage the top hit must reach to be accepted.

    Returns:
        The full list of ``(document, score)`` results when the top hit meets
        the threshold; otherwise None (also when the search returns nothing).
    """
    matches = get_redis_store().similarity_search_with_score(query, k=n_results)
    if not matches:
        return None

    # NOTE(review): the score is presumably a cosine distance (the store is
    # configured with COSINE), so smaller means closer — confirm.
    top_score = matches[0][1]
    similarity_pct = (1 - abs(top_score)) * 100

    return matches if similarity_pct >= threshold else None


def call_llm(context: str, prompt: str):
"""Calls the language model with context and prompt to generate a response.
Expand Down Expand Up @@ -210,23 +249,59 @@ def re_rank_cross_encoders(documents: list[str]) -> tuple[str, list[int]]:
return relevant_text, relevant_text_ids


def create_cached_contents(uploaded_file: UploadedFile) -> list[Document]:
    """Loads question/answer pairs from an uploaded CSV into the semantic cache.

    Each CSV row becomes a Document whose page content is the ``question``
    column and whose ``answer`` metadata is the ``answer`` column. The
    documents are added to the store returned by ``get_redis_store()`` and a
    success message is shown in the Streamlit UI.

    Args:
        uploaded_file: The uploaded CSV file; it must have a header row with
            ``question`` and ``answer`` columns.

    Returns:
        list[Document]: The documents that were added to the cache.

    Raises:
        KeyError: If a row lacks the ``question`` or ``answer`` column.
        UnicodeDecodeError: If the file content is not valid UTF-8.
    """
    # "utf-8-sig" decodes plain UTF-8 too, but also strips a leading BOM
    # (common in spreadsheet exports). Otherwise the first header would be
    # read as "\ufeffquestion" and the lookup below would raise KeyError.
    data = uploaded_file.getvalue().decode("utf-8-sig")

    docs = [
        Document(page_content=row["question"], metadata={"answer": row["answer"]})
        for row in csv.DictReader(StringIO(data))
    ]

    vector_store = get_redis_store()
    vector_store.add_documents(docs)
    st.success("Cache contents added!")

    # Return the documents so the function honours its declared return type;
    # the caller assigns the result.
    return docs


if __name__ == "__main__":
# Document Upload Area
with st.sidebar:
st.set_page_config(page_title="RAG Question Answer")
uploaded_file = st.file_uploader(
"**📑 Upload PDF files for QnA**", type=["pdf"], accept_multiple_files=False
"**📑 Upload PDF files for QnA**",
type=["pdf", "csv"],
accept_multiple_files=False,
help="Upload csv for cached results only",
)
upload_option = st.radio(
"Upload options:",
options=["Primary", "Cache"],
help="Choose Primary for uploading document for QnA.\n\nChoose Cache for uploading cached results",
)

if (
uploaded_file
and upload_option == "Primary"
and uploaded_file.name.split(".")[-1] == "csv"
):
st.error("CSV is only allowed for 'Cache' option.")
sys.exit(1)

process = st.button(
"⚡️ Process",
)
if uploaded_file and process:
normalize_uploaded_file_name = uploaded_file.name.translate(
str.maketrans({"-": "_", ".": "_", " ": "_"})
)
all_splits = process_document(uploaded_file)
add_to_vector_collection(all_splits, normalize_uploaded_file_name)

if upload_option == "Cache":
all_splits = create_cached_contents(uploaded_file)
else:
all_splits = process_document(uploaded_file)
add_to_vector_collection(
"rag_app", all_splits, normalize_uploaded_file_name
)

# Question and Answer Area
st.header("🗣️ RAG Question Answer")
Expand All @@ -236,15 +311,25 @@ def re_rank_cross_encoders(documents: list[str]) -> tuple[str, list[int]]:
)

if ask and prompt:
results = query_collection(prompt)
context = results.get("documents")[0]
relevant_text, relevant_text_ids = re_rank_cross_encoders(context)
response = call_llm(context=relevant_text, prompt=prompt)
st.write_stream(response)

with st.expander("See retrieved documents"):
st.write(results)

with st.expander("See most relevant document ids"):
st.write(relevant_text_ids)
st.write(relevant_text)
cached_results = query_semantic_cache(query=prompt)

if cached_results:
st.write(cached_results[0][0].metadata["answer"].replace("\\n", "\n"))
else:
results = query_collection(prompt=prompt, collection_name="rag_app")

context = results.get("documents")[0]
if not context:
st.write("No results found.")
sys.exit(1)

relevant_text, relevant_text_ids = re_rank_cross_encoders(context)
response = call_llm(context=relevant_text, prompt=prompt)
st.write_stream(response)

with st.expander("See retrieved documents"):
st.write(results)

with st.expander("See most relevant document ids"):
st.write(relevant_text_ids)
st.write(relevant_text)
9 changes: 5 additions & 4 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
chromadb==0.5.23 # Vector Database
langchain-community==0.3.7 # Utils for text splitting
langchain-ollama==0.2.2 # Ollama embedding provider
langchain-redis==0.1.2 # Redis semantic cache provider
ollama==0.3.3 # Local inference
chromadb==0.5.20 # Vector Database
PyMuPDF==1.24.14 # PDF Document loader
sentence-transformers==3.3.1 # CrossEncoder Re-ranking
streamlit==1.40.1 # Application UI
PyMuPDF==1.24.14 # PDF Document loader
langchain-community==0.3.7 # Utils for text splitting

0 comments on commit a9e9f72

Please sign in to comment.