diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index d5b1987..dba2d00 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -11,15 +11,18 @@ jobs:
python-version: ['3.11']
steps:
- - uses: actions/checkout@v2
+ - name: Check out repository
+ uses: actions/checkout@v4
+
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install requirements
run: |
python -m pip install --upgrade pip
- pip install -r requirements-dev.txt
+ pip install pre-commit
+
- name: Run Pre commit hook (formatting, linting & tests)
run: pre-commit run --all-files --hook-stage pre-push --show-diff-on-failure
diff --git a/.github/workflows/deploy_docs.yaml b/.github/workflows/deploy_docs.yaml
index 140d3b6..d8692f1 100644
--- a/.github/workflows/deploy_docs.yaml
+++ b/.github/workflows/deploy_docs.yaml
@@ -19,8 +19,8 @@ jobs:
- name: Install requirements
run: |
python -m pip install --upgrade pip
- pip install mkdocs mkdocs-techdocs-core pymdown-extensions mkdocs_monorepo_plugin
+ pip install mkdocs pymdown-extensions termynal mkdocs-material
- name: Deploying MkDocs documentation
run: |
mkdocs build
- mkdocs gh-deploy --force
+ mkdocs gh-deploy --force
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b7b65ac..eb2011d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,22 +7,17 @@ repos:
- id: check-toml
- id: check-json
- id: check-added-large-files
- - repo: local
+
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.3.1
hooks:
- - id: isort
- name: Sorting imports (isort)
- entry: isort
- types: [python]
- language: system
- id: ruff
- name: Linting (ruff)
- entry: ruff --fix
- types: [python]
- language: system
+ args: [ --fix ]
+ - id: ruff-format
+
+ - repo: https://github.com/kynan/nbstripout
+ rev: 0.7.1
+ hooks:
- id: nbstripout
- name: Strip Jupyter notebook output (nbstripout)
- entry: nbstripout
- types: [file]
- files: (.ipynb)$
- language: system
+
exclude: ^(.svn|CVS|.bzr|.hg|.git|__pycache__|.tox|.ipynb_checkpoints|assets|tests/assets/|venv/|.venv/)
diff --git a/.skaff/skaff.yaml b/.skaff/skaff.yaml
index d072ce7..c82ccac 100644
--- a/.skaff/skaff.yaml
+++ b/.skaff/skaff.yaml
@@ -7,7 +7,7 @@ name: GenAI RAG indus kit
owner: alexis.vialaret@artefact.com
description: >
Deploy production grade RAGs quickly, on any cloud.
-documentation_url: https://artefactory.github.io/skaff-rag-accelerator/
+documentation_url: https://artefactory-skaff.github.io/skaff-rag-accelerator/
type: deployable # deployable, knowldege pack
lifecycle: prototype # prototype, production
diff --git a/README.md b/README.md
index 5f1c60b..df6c34a 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ This is a starter kit to deploy a modularizable RAG locally or on the cloud (or
## Features
-- A configurable RAG setup based around Langchain ([Check out the configuration cookbook here](https://artefactory.github.io/skaff-rag-accelerator/cookbook/))
+- A configurable RAG setup based around Langchain ([Check out the configuration cookbook here](https://artefactory-skaff.github.io/skaff-rag-accelerator/cookbook/))
- `RAG` and `RagConfig` python classes that manage components (vector store, llm, retreiver, ...)
- A REST API based on Langserve + FastAPI to provide easy access to the RAG as a web backend
- Optional API plugins for secure user authentication, session management, ...
@@ -17,7 +17,7 @@ This is a starter kit to deploy a modularizable RAG locally or on the cloud (or
## Quickstart
-This quickstart will guide you through the steps to serve the RAG and load a few documents.
+This quickstart will guide you through the steps to serve the RAG and load a few documents.
You will run both the back and front on your machine.
@@ -80,7 +80,7 @@ Right now the RAG does not have any documents loaded, you can use the notebook i
To deep dive into under the hood, take a look at the documentation
-[On github pages](https://artefactory.github.io/skaff-rag-accelerator/)
+[On github pages](https://artefactory-skaff.github.io/skaff-rag-accelerator/)
Or serve them locally:
```shell
diff --git a/backend/Dockerfile b/backend/Dockerfile
index f0c912e..769df96 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -21,4 +21,4 @@ EXPOSE $PORT
COPY . ./backend
-CMD python -m uvicorn backend.main:app --host 0.0.0.0 --port $PORT
\ No newline at end of file
+CMD python -m uvicorn backend.main:app --host 0.0.0.0 --port $PORT
diff --git a/backend/api_plugins/__init__.py b/backend/api_plugins/__init__.py
index ce320d6..8481b5c 100644
--- a/backend/api_plugins/__init__.py
+++ b/backend/api_plugins/__init__.py
@@ -1,3 +1,13 @@
-from backend.api_plugins.insecure_authentication.insecure_authentication import insecure_authentication_routes
-from backend.api_plugins.secure_authentication.secure_authentication import authentication_routes
+from backend.api_plugins.insecure_authentication.insecure_authentication import (
+ insecure_authentication_routes,
+)
+from backend.api_plugins.secure_authentication.secure_authentication import (
+ authentication_routes,
+)
from backend.api_plugins.sessions.sessions import session_routes
+
+__all__ = [
+ "insecure_authentication_routes",
+ "authentication_routes",
+ "session_routes",
+]
diff --git a/backend/api_plugins/insecure_authentication/insecure_authentication.py b/backend/api_plugins/insecure_authentication/insecure_authentication.py
index 0c3d60e..983faa8 100644
--- a/backend/api_plugins/insecure_authentication/insecure_authentication.py
+++ b/backend/api_plugins/insecure_authentication/insecure_authentication.py
@@ -25,23 +25,30 @@ async def get_current_user(email: str) -> User:
async def signup(email: str) -> dict:
user = User(email=email)
if user_exists(user.email):
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"User {user.email} already registered",
+ )
create_user(user)
return {"email": user.email}
-
@app.delete("/user/")
async def del_user(current_user: User = Depends(get_current_user)) -> dict:
email = current_user.email
try:
user = get_user(email)
if user is None:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"User {email} not found",
+ )
delete_user(email)
return {"detail": f"User {email} deleted"}
except Exception:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error")
-
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Internal Server Error",
+ )
@app.post("/user/login")
async def login(email: str) -> dict:
@@ -51,16 +58,20 @@ async def login(email: str) -> dict:
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username",
)
- return {"access_token": email, "token_type": "bearer"} # Fake bearer token to still provide user authentication
-
+ return {
+ "access_token": email,
+ "token_type": "bearer",
+ } # Fake bearer token to still provide user authentication
@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
return current_user
-
@app.get("/user")
async def user_root() -> dict:
- return Response("Insecure user management routes are enabled. Do not use in prod.", status_code=200)
+ return Response(
+ "Insecure user management routes are enabled. Do not use in prod.",
+ status_code=200,
+ )
return Depends(get_current_user)
diff --git a/backend/api_plugins/lib/user_management.py b/backend/api_plugins/lib/user_management.py
index 8b8aa63..3f403af 100644
--- a/backend/api_plugins/lib/user_management.py
+++ b/backend/api_plugins/lib/user_management.py
@@ -13,6 +13,7 @@ class UnsecureUser(BaseModel):
email: str = None
password: bytes = None
+
class User(BaseModel):
email: str = None
hashed_password: str = None
@@ -26,7 +27,8 @@ def from_unsecure_user(cls, unsecure_user: UnsecureUser):
def create_user(user: User) -> None:
with Database() as connection:
connection.execute(
- "INSERT INTO users (email, password) VALUES (?, ?)", (user.email, user.hashed_password)
+ "INSERT INTO users (email, password) VALUES (?, ?)",
+ (user.email, user.hashed_password),
)
@@ -54,12 +56,17 @@ def authenticate_user(username: str, password: bytes) -> bool | User:
if not user:
return False
- if argon2.verify_password(user.hashed_password.encode("utf-8"), password.encode("utf-8")):
+ if argon2.verify_password(
+ user.hashed_password.encode("utf-8"), password.encode("utf-8")
+ ):
return user
return False
-def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None) -> str:
+
+def create_access_token(
+ *, data: dict, expires_delta: Optional[timedelta] = None
+) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
diff --git a/backend/api_plugins/secure_authentication/secure_authentication.py b/backend/api_plugins/secure_authentication/secure_authentication.py
index ea03f7d..ab8e191 100644
--- a/backend/api_plugins/secure_authentication/secure_authentication.py
+++ b/backend/api_plugins/secure_authentication/secure_authentication.py
@@ -22,6 +22,7 @@
def authentication_routes(app, dependencies=List[Depends]):
from backend.database import Database
+
with Database() as connection:
connection.run_script(Path(__file__).parent / "users_tables.sql")
@@ -46,22 +47,23 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
except JWTError:
raise credentials_exception
-
@app.post("/user/signup", include_in_schema=ADMIN_MODE)
async def signup(user: UnsecureUser) -> dict:
if not ADMIN_MODE:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled"
+ )
user = User.from_unsecure_user(user)
if user_exists(user.email):
raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered"
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"User {user.email} already registered",
)
create_user(user)
return {"email": user.email}
-
@app.delete("/user/")
async def del_user(current_user: User = Depends(get_current_user)) -> dict:
email = current_user.email
@@ -69,16 +71,17 @@ async def del_user(current_user: User = Depends(get_current_user)) -> dict:
user = get_user(email)
if user is None:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found"
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"User {email} not found",
)
delete_user(email)
return {"detail": f"User {email} deleted"}
except Exception:
raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error"
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Internal Server Error",
)
-
@app.post("/user/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
user = authenticate_user(form_data.username, form_data.password)
@@ -93,15 +96,12 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
access_token = create_access_token(data=user_data)
return {"access_token": access_token, "token_type": "bearer"}
-
@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
return current_user
-
@app.get("/user")
async def user_root() -> dict:
return Response("User management routes are enabled.", status_code=200)
-
return Depends(get_current_user)
diff --git a/backend/config.py b/backend/config.py
index 92b2854..d7c5bc3 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -12,11 +12,13 @@
load_dotenv()
+
@dataclass
class LLMConfig:
source: BaseChatModel | LLM | str
source_config: dict
+
@dataclass
class VectorStoreConfig:
source: VectorStore | str
@@ -24,43 +26,48 @@ class VectorStoreConfig:
insertion_mode: str # "None", "full", "incremental"
+
@dataclass
class EmbeddingModelConfig:
source: Embeddings | str
source_config: dict
+
@dataclass
class DatabaseConfig:
database_url: str
+
@dataclass
class RagConfig:
"""
Configuration class for the Retrieval-Augmented Generation (RAG) system.
It is meant to be injected in the RAG class to configure the various components.
- This class holds the configuration for the various components that make up the RAG system, including
- the language model, vector store, embedding model, and database configurations. It provides a method
- to construct a RagConfig instance from a YAML file, allowing for easy external configuration.
+ This class holds the configuration for the various components that make up the RAG
+ system, including the language model, vector store, embedding model, and database
+ configurations. It provides a method to construct a RagConfig instance from a YAML
+ file, allowing for easy external configuration.
Attributes:
llm (LLMConfig): Configuration for the language model component.
vector_store (VectorStoreConfig): Configuration for the vector store component.
- embedding_model (EmbeddingModelConfig): Configuration for the embedding model component.
+ embedding_model (EmbeddingModelConfig): Configuration for the embedding model
+ component.
database (DatabaseConfig): Configuration for the database connection.
Methods:
- from_yaml: Class method to create an instance of RagConfig from a YAML file, with optional environment
- variables for template rendering.
+ from_yaml: Class method to create an instance of RagConfig from a YAML file,
+ with optional environment variables for template rendering.
"""
- llm: LLMConfig = field(default_factory=LLMConfig)
- vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig)
- embedding_model: EmbeddingModelConfig = field(default_factory=EmbeddingModelConfig)
- database: DatabaseConfig = field(default_factory=DatabaseConfig)
- chat_history_window_size: int = 5
- max_tokens_limit: int = 3000
- response_mode: str = None
+ llm: LLMConfig = field(default_factory=LLMConfig)
+ vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig)
+ embedding_model: EmbeddingModelConfig = field(default_factory=EmbeddingModelConfig)
+ database: DatabaseConfig = field(default_factory=DatabaseConfig)
+ chat_history_window_size: int = 5
+ max_tokens_limit: int = 3000
+ response_mode: str = None
@classmethod
def from_yaml(cls, yaml_path: Path, env: dict = None):
diff --git a/backend/database.py b/backend/database.py
index d1c42cc..4e3c58d 100644
--- a/backend/database.py
+++ b/backend/database.py
@@ -26,7 +26,8 @@ class Database:
url (URL): The parsed URL object of the connection string.
pool (PooledDB): The connection pool for database connections.
conn (Connection): The current database connection.
- DIALECT_PLACEHOLDERS (dict): Mapping of database dialects to their placeholder symbols.
+ DIALECT_PLACEHOLDERS (dict): Mapping of database dialects to their placeholder
+ symbols.
"""
DIALECT_PLACEHOLDERS = {
@@ -52,10 +53,17 @@ def __enter__(self) -> "Database":
self.conn = self.pool.connection()
return self
- def __exit__(self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any]) -> None:
+ def __exit__(
+ self,
+ exc_type: Optional[type],
+ exc_value: Optional[BaseException],
+ traceback: Optional[Any],
+ ) -> None:
if self.conn:
if exc_type:
- self.logger.error("Transaction failed", exc_info=(exc_type, exc_value, traceback))
+ self.logger.error(
+ "Transaction failed", exc_info=(exc_type, exc_value, traceback)
+ )
self.conn.rollback()
else:
self.conn.commit()
@@ -93,10 +101,16 @@ def initialize_schema(self):
try:
self.logger.debug("Initializing database schema")
sql_script = Path(__file__).parent.joinpath("db_init.sql").read_text()
- transpiled_sql = sqlglot.transpile(sql_script, read="sqlite", write=self.url.drivername.replace("postgresql", "postgres"))
+ transpiled_sql = sqlglot.transpile(
+ sql_script,
+ read="sqlite",
+ write=self.url.drivername.replace("postgresql", "postgres"),
+ )
for statement in transpiled_sql:
self.execute(statement)
- self.logger.info(f"Database schema initialized successfully for {self.url.drivername}")
+ self.logger.info(
+ f"Database schema initialized successfully for {self.url.drivername}"
+ )
except Exception as e:
self.logger.exception("Schema initialization failed", exc_info=e)
raise
@@ -105,33 +119,46 @@ def run_script(self, path: Path):
try:
self.logger.debug(f"Running Database script at {str(path)}")
sql_script = path.read_text()
- transpiled_sql = sqlglot.transpile(sql_script, read="sqlite", write=self.url.drivername.replace("postgresql", "postgres"))
+ transpiled_sql = sqlglot.transpile(
+ sql_script,
+ read="sqlite",
+ write=self.url.drivername.replace("postgresql", "postgres"),
+ )
for statement in transpiled_sql:
self.execute(statement)
- self.logger.info(f"Successfuly ran script at {path} for {self.url.drivername}")
+ self.logger.info(
+ f"Successfuly ran script at {path} for {self.url.drivername}"
+ )
except Exception as e:
- self.logger.exception(f"Failed to execute the script {path} for {self.url.drivername}", exc_info=e)
+ self.logger.exception(
+ f"Failed to execute the script {path} for {self.url.drivername}",
+ exc_info=e,
+ )
raise
def _create_pool(self) -> PooledDB:
if self.connection_string.startswith("sqlite:///"):
import sqlite3
- Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir(parents=True, exist_ok=True)
+
+ Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir(
+ parents=True, exist_ok=True
+ )
return PooledDB(
creator=sqlite3,
database=self.connection_string.replace("sqlite:///", ""),
- maxconnections=5
+ maxconnections=5,
)
elif self.connection_string.startswith("postgresql://"):
import psycopg2
+
return PooledDB(
- creator=psycopg2,
- dsn=self.connection_string,
- maxconnections=5
+ creator=psycopg2, dsn=self.connection_string, maxconnections=5
)
- elif self.connection_string.startswith("mysql://") or \
- self.connection_string.startswith("mysql+pymysql://"):
+ elif self.connection_string.startswith(
+ "mysql://"
+ ) or self.connection_string.startswith("mysql+pymysql://"):
import mysql.connector
+
return PooledDB(
creator=mysql.connector,
user=self.url.username,
@@ -139,14 +166,15 @@ def _create_pool(self) -> PooledDB:
host=self.url.host,
port=self.url.port,
database=self.url.database,
- maxconnections=5
+ maxconnections=5,
)
elif self.connection_string.startswith("sqlserver://"):
import pyodbc
+
return PooledDB(
creator=pyodbc,
dsn=self.connection_string.replace("sqlserver://", ""),
- maxconnections=5
+ maxconnections=5,
)
else:
raise ValueError(f"Unsupported database type: {self.url.drivername}")
diff --git a/backend/logger.py b/backend/logger.py
index a3099c5..f518fa7 100644
--- a/backend/logger.py
+++ b/backend/logger.py
@@ -4,6 +4,7 @@
# Implement your custom logging logic here. Eg. send logs to a cloud's logging tool.
_logger_instance = None
+
def get_logger() -> Logger:
global _logger_instance
if _logger_instance is None:
diff --git a/backend/main.py b/backend/main.py
index 1595601..1bb37d8 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -3,17 +3,17 @@
from fastapi import FastAPI
from langserve import add_routes
-from backend.api_plugins import authentication_routes, session_routes
+# from backend.api_plugins import authentication_routes, session_routes
from backend.rag_components.rag import RAG
# Initialize a RAG as discribed in the config.yaml file
-# https://artefactory.github.io/skaff-rag-accelerator/backend/rag_ragconfig/
+# https://artefactory-skaff.github.io/skaff-rag-accelerator/backend/rag_ragconfig/
rag = RAG(config=Path(__file__).parent / "config.yaml")
chain = rag.get_chain()
# Create a minimal RAG server based on langserve
# Learn how to extend this configuration to add authentication and session management
-# https://artefactory.github.io/skaff-rag-accelerator/backend/plugins/plugins/
+# https://artefactory-skaff.github.io/skaff-rag-accelerator/backend/plugins/plugins/
app = FastAPI(
title="RAG Accelerator",
description="A RAG-based question answering API",
diff --git a/backend/model.py b/backend/model.py
index 7e63ef4..aadd47e 100644
--- a/backend/model.py
+++ b/backend/model.py
@@ -12,16 +12,21 @@ class Message(BaseModel):
sender: str
content: str
+
class Input(BaseModel):
question: str
+
class InvokeRequest(BaseModel):
input: Input | List[Input] # Supports batched input
id: str = Field(default_factory=lambda: str(uuid4()))
session_id: str = None
config: dict = Field(default_factory=dict)
kwargs: dict = Field(default_factory=dict)
- timestamp: str | datetime = Field(default_factory=lambda: datetime.utcnow().isoformat())
+ timestamp: str | datetime = Field(
+ default_factory=lambda: datetime.utcnow().isoformat()
+ )
+
class UserMessage(InvokeRequest):
sender: str = "user"
diff --git a/backend/rag_components/chain_links/answer_question_from_docs_and_history.py b/backend/rag_components/chain_links/answer_question_from_docs_and_history.py
index 9934f3a..8eee958 100644
--- a/backend/rag_components/chain_links/answer_question_from_docs_and_history.py
+++ b/backend/rag_components/chain_links/answer_question_from_docs_and_history.py
@@ -1,10 +1,12 @@
-"""This chain answers the provided question based on documents it retreives and the conversation history"""
+"""This chain answers the provided question based on documents it retreives and the
+conversation history"""
+
from langchain_core.retrievers import BaseRetriever
from pydantic import BaseModel
-from backend.rag_components.chain_links.rag_basic import rag_basic
-from backend.rag_components.chain_links.condense_question import condense_question
+from backend.rag_components.chain_links.condense_question import condense_question
from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable
+from backend.rag_components.chain_links.rag_basic import rag_basic
class QuestionWithHistory(BaseModel):
@@ -16,11 +18,17 @@ class Response(BaseModel):
response: str
-def answer_question_from_docs_and_history_chain(llm, retriever: BaseRetriever) -> DocumentedRunnable:
+def answer_question_from_docs_and_history_chain(
+ llm, retriever: BaseRetriever
+) -> DocumentedRunnable:
reformulate_question = condense_question(llm)
answer_question = rag_basic(llm, retriever)
- chain = reformulate_question | answer_question
+ chain = reformulate_question | answer_question
typed_chain = chain.with_types(input_type=QuestionWithHistory, output_type=Response)
- return DocumentedRunnable(typed_chain, chain_name="Answer question from docs and history", user_doc=__doc__)
+ return DocumentedRunnable(
+ typed_chain,
+ chain_name="Answer question from docs and history",
+ user_doc=__doc__,
+ )
diff --git a/backend/rag_components/chain_links/condense_question.py b/backend/rag_components/chain_links/condense_question.py
index 70a38a2..9fd115a 100644
--- a/backend/rag_components/chain_links/condense_question.py
+++ b/backend/rag_components/chain_links/condense_question.py
@@ -1,6 +1,8 @@
-"""This chain condenses the chat history and the question into one standalone question."""
-from langchain_core.prompts import PromptTemplate
+"""This chain condenses the chat history and the question into one standalone
+question."""
+
from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel
from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable
@@ -16,20 +18,35 @@ class StandaloneQuestion(BaseModel):
prompt = """\
-Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. You are presented with a conversation that may contain some spelling mistakes and grammatical errors, but your goal is to understand the underlying question. Make sure to avoid the use of unclear pronouns.
+Given the conversation history and the following question, can you rephrase the user's \
+question in its original language so that it is self-sufficient. You are presented \
+with a conversation that may contain some spelling mistakes and grammatical errors, \
+but your goal is to understand the underlying question. Make sure to avoid the use of \
+unclear pronouns.
-If the question is already self-sufficient, return the original question. If it seem the user is authorizing the chatbot to answer without specific context, make sure to reflect that in the rephrased question.
+If the question is already self-sufficient, return the original question. If it seem \
+the user is authorizing the chatbot to answer without specific context, make sure to \
+reflect that in the rephrased question.
Chat history: {chat_history}
Question: {question}
-""" # noqa: E501
+""" # noqa: E501
def condense_question(llm) -> DocumentedRunnable:
- condense_question_prompt = PromptTemplate.from_template(prompt) # chat_history, question
+ condense_question_prompt = PromptTemplate.from_template(
+ prompt
+ ) # chat_history, question
standalone_question = condense_question_prompt | llm | StrOutputParser()
-
- typed_chain = standalone_question.with_types(input_type=QuestionWithChatHistory, output_type=StandaloneQuestion)
- return DocumentedRunnable(typed_chain, chain_name="Condense question and history", prompt=prompt, user_doc=__doc__)
\ No newline at end of file
+
+ typed_chain = standalone_question.with_types(
+ input_type=QuestionWithChatHistory, output_type=StandaloneQuestion
+ )
+ return DocumentedRunnable(
+ typed_chain,
+ chain_name="Condense question and history",
+ prompt=prompt,
+ user_doc=__doc__,
+ )
diff --git a/backend/rag_components/chain_links/documented_runnable.py b/backend/rag_components/chain_links/documented_runnable.py
index ba42d6d..c0608d5 100644
--- a/backend/rag_components/chain_links/documented_runnable.py
+++ b/backend/rag_components/chain_links/documented_runnable.py
@@ -1,30 +1,38 @@
from __future__ import annotations
+
import json
+from dataclasses import asdict, dataclass
from typing import Any, List, Optional
+
+import tabulate
from docdantic import get_field_info
-from langchain_core.runnables.base import Runnable, RunnableBinding, RunnableBindingBase, RunnableSequence, RunnableParallel
+from jinja2 import Template
+from langchain_core.runnables.base import (
+ Runnable,
+ RunnableBinding,
+ RunnableBindingBase,
+ RunnableParallel,
+ RunnableSequence,
+)
from langchain_core.runnables.utils import Input, Output
from pydantic.main import ModelMetaclass
-import tabulate
-from jinja2 import Template
-
-
-from dataclasses import asdict, dataclass
-from typing import Optional, Any
@dataclass
class RunnableSequenceDocumentation:
docs: List[RunnableDocumentation]
+
@dataclass
class RunnableParallelDocumentation:
docs: List[RunnableDocumentation]
+
@dataclass
class RunnableBindingDocumentation:
docs: List[RunnableDocumentation]
+
@dataclass
class RunnableDocumentation:
chain_name: str
@@ -38,9 +46,9 @@ def to_json(self):
class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o):
return o.__name__
-
+
return json.dumps(asdict(self), cls=EnhancedJSONEncoder)
-
+
def to_markdown(self) -> str:
rendered_sub_docs = ""
if self.sub_docs:
@@ -60,19 +68,21 @@ def to_markdown(self) -> str:
prompt=self.prompt,
io_doc=input_doc + "\n" + output_doc,
sub_docs=rendered_sub_docs,
- user_doc=self.user_doc
+ user_doc=self.user_doc,
)
class DocumentedRunnable(RunnableBindingBase[Input, Output]):
"""A DocumentedRunnable is a wrapper around a Runnable that generates documentation.
- FIXME: Bound runnables that have configurable fields are not handled correctly, causing the playground to not be properly usable.
+ FIXME: Bound runnables that have configurable fields are not handled correctly,
+ causing the playground to not be properly usable.
TODO: Add Mermaid diagrams.
- This class is used to create a documented version of a Runnable, which is an executable
- unit of work in the langchain framework. The documentation includes information about
- the input and output types, as well as any additional user-provided prompts or documentation.
+ This class is used to create a documented version of a Runnable, which is an
+ executable unit of work in the langchain framework. The documentation includes
+ information about the input and output types, as well as any additional
+ user-provided prompts or documentation.
Attributes:
documentation (str): The generated markdown documentation for the Runnable.
@@ -81,26 +91,33 @@ class DocumentedRunnable(RunnableBindingBase[Input, Output]):
runnable (Runnable): The Runnable to be documented.
chain_name (Optional[str]): The name of the chain that the Runnable belongs to.
prompt (Optional[str]): The prompt that the Runnable uses, if applicable.
- user_doc (Optional[str]): Any additional documentation you want displayed in the documentation.
+ user_doc (Optional[str]): Any additional documentation you want displayed in the
+ documentation.
"""
documentation: RunnableDocumentation | None
def __init__(
- self,
- runnable: Runnable[Input, Output],
- chain_name: Optional[str]=None,
- prompt: Optional[str]=None,
- user_doc: Optional[str]=None,
+ self,
+ runnable: Runnable[Input, Output],
+ chain_name: Optional[str] = None,
+ prompt: Optional[str] = None,
+ user_doc: Optional[str] = None,
**kwargs: Any,
) -> None:
-
custom_input_type = runnable.InputType
custom_output_type = runnable.OutputType
final_chain_name = chain_name or type(runnable).__name__
if isinstance(runnable, RunnableSequence):
- sub_docs = [runnable.documentation if isinstance(runnable, DocumentedRunnable) else DocumentedRunnable(runnable).documentation for runnable in runnable.steps]
+ sub_docs = [
+ (
+ runnable.documentation
+ if isinstance(runnable, DocumentedRunnable)
+ else DocumentedRunnable(runnable).documentation
+ )
+ for runnable in runnable.steps
+ ]
sub_docs = [doc for doc in sub_docs if doc]
if len(sub_docs) >= 2:
documentation = RunnableDocumentation(
@@ -109,13 +126,20 @@ def __init__(
input_type=custom_input_type,
output_type=custom_output_type,
user_doc=user_doc,
- sub_docs=RunnableSequenceDocumentation(docs=sub_docs)
+ sub_docs=RunnableSequenceDocumentation(docs=sub_docs),
)
else:
documentation = sub_docs[0] if len(sub_docs) else None
-
+
elif isinstance(runnable, RunnableParallel):
- sub_docs = [runnable.documentation if isinstance(runnable, DocumentedRunnable) else DocumentedRunnable(runnable).documentation for runnable in runnable.steps.values()]
+ sub_docs = [
+ (
+ runnable.documentation
+ if isinstance(runnable, DocumentedRunnable)
+ else DocumentedRunnable(runnable).documentation
+ )
+ for runnable in runnable.steps.values()
+ ]
sub_docs = [doc for doc in sub_docs if doc]
if len(sub_docs) >= 2:
documentation = RunnableDocumentation(
@@ -124,7 +148,7 @@ def __init__(
input_type=custom_input_type,
output_type=custom_output_type,
user_doc=user_doc,
- sub_docs=RunnableParallelDocumentation(docs=sub_docs)
+ sub_docs=RunnableParallelDocumentation(docs=sub_docs),
)
else:
documentation = sub_docs[0] if len(sub_docs) else None
@@ -134,9 +158,13 @@ def __init__(
while isinstance(bound_runnable, RunnableBinding):
bound_runnable = bound_runnable.bound
- sub_docs = [bound_runnable.documentation if isinstance(bound_runnable, DocumentedRunnable) else DocumentedRunnable(bound_runnable).documentation]
+ sub_docs = [
+ bound_runnable.documentation
+ if isinstance(bound_runnable, DocumentedRunnable)
+ else DocumentedRunnable(bound_runnable).documentation
+ ]
sub_docs = [doc for doc in sub_docs if doc]
-
+
if final_chain_name == "RunnableBinding":
documentation = sub_docs[0] if len(sub_docs) else None
else:
@@ -146,25 +174,30 @@ def __init__(
input_type=custom_input_type,
output_type=custom_output_type,
user_doc=user_doc,
- sub_docs=RunnableBindingDocumentation(docs=sub_docs) if len(sub_docs) else None
+ sub_docs=(
+ RunnableBindingDocumentation(docs=sub_docs)
+ if len(sub_docs)
+ else None
+ ),
)
else:
documentation = None
super().__init__(
- bound=runnable,
+ bound=runnable,
documentation=documentation,
custom_input_type=custom_input_type,
custom_output_type=custom_output_type,
**kwargs,
)
+
def render_io_doc(input, output) -> tuple[str]:
if isinstance(input, ModelMetaclass):
input_doc = render_model_doc(input, "Input")
else:
input_doc = f"### Input: {input.__name__}"
-
+
if isinstance(output, ModelMetaclass):
output_doc = render_model_doc(output, "Output")
else:
@@ -179,16 +212,21 @@ def render_model_doc(model: ModelMetaclass, input_or_output: str) -> str:
tables = {
cls: tabulate.tabulate(
[
- (field.name, field.type, field.required, field.default,)
+ (
+ field.name,
+ field.type,
+ field.required,
+ field.default,
+ )
for field in fields
],
headers=["Name", "Type", "Required", "Default"],
- tablefmt="github"
+ tablefmt="github",
)
for cls, fields in field_info.items()
}
- return '\n'.join(
+ return "\n".join(
f"\n### {input_or_output}: {cls}\n\n{table}\n\n"
for cls, table in tables.items()
)
@@ -232,4 +270,4 @@ def render_model_doc(model: ModelMetaclass, input_or_output: str) -> str:
{{ doc }}
{% endfor %}
-"""
\ No newline at end of file
+"""
diff --git a/backend/rag_components/chain_links/rag_basic.py b/backend/rag_components/chain_links/rag_basic.py
index bf6b2a7..569a9fc 100644
--- a/backend/rag_components/chain_links/rag_basic.py
+++ b/backend/rag_components/chain_links/rag_basic.py
@@ -1,4 +1,5 @@
"""This chain answers the provided question based on documents it retreives."""
+
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnablePassthrough
@@ -7,13 +8,12 @@
from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable
from backend.rag_components.chain_links.retrieve_and_format_docs import fetch_docs_chain
-
prompt = """
As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail. Do not invent information. You must sift through various sources of information, disregarding any data that is not relevant to the query's context. Your response should integrate knowledge from the valid sources you have identified. Additionally, the question might include hypothetical or counterfactual statements. You need to recognize these and adjust your response to provide accurate, relevant information without being misled by the counterfactuals. Respond to the question only taking into account the following context. If no context is provided, do not answer. You may provide an answer if the user explicitely asked for a general answer. You may ask the user to rephrase their question, or their permission to answer without specific context from your own knowledge.
Context: {relevant_documents}
Question: {question}
-""" # noqa: E501
+""" # noqa: E501
class Question(BaseModel):
@@ -26,9 +26,17 @@ class Response(BaseModel):
def rag_basic(llm, retriever: BaseRetriever) -> DocumentedRunnable:
chain = (
- {"relevant_documents": fetch_docs_chain(retriever), "question": RunnablePassthrough(Question)}
+ {
+ "relevant_documents": fetch_docs_chain(retriever),
+ "question": RunnablePassthrough(Question),
+ }
| ChatPromptTemplate.from_template(prompt)
| llm
)
typed_chain = chain.with_types(input_type=str, output_type=Response)
- return DocumentedRunnable(typed_chain, chain_name="Answer questions from documents stored in a vector store", prompt=prompt, user_doc=__doc__)
+ return DocumentedRunnable(
+ typed_chain,
+ chain_name="Answer questions from documents stored in a vector store",
+ prompt=prompt,
+ user_doc=__doc__,
+ )
diff --git a/backend/rag_components/chain_links/rag_with_history.py b/backend/rag_components/chain_links/rag_with_history.py
index 40b5e71..0e71a40 100644
--- a/backend/rag_components/chain_links/rag_with_history.py
+++ b/backend/rag_components/chain_links/rag_with_history.py
@@ -1,27 +1,36 @@
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables.history import RunnableWithMessageHistory
-from backend.config import RagConfig
-from backend.rag_components.chain_links.answer_question_from_docs_and_history import answer_question_from_docs_and_history_chain
+from backend.config import RagConfig
+from backend.rag_components.chain_links.answer_question_from_docs_and_history import (
+ answer_question_from_docs_and_history_chain,
+)
from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable
from backend.rag_components.chat_message_history import get_chat_message_history
-def rag_with_history_chain(config: RagConfig, llm, retriever: BaseRetriever) -> DocumentedRunnable:
+def rag_with_history_chain(
+ config: RagConfig, llm, retriever: BaseRetriever
+) -> DocumentedRunnable:
chain = answer_question_from_docs_and_history_chain(llm, retriever)
chain_with_mem = RunnableWithMessageHistory(
chain,
lambda session_id: get_chat_message_history(config, session_id),
input_messages_key="question",
- history_messages_key="chat_history"
+ history_messages_key="chat_history",
)
return chain_with_mem
- # A bug in DocumentedRunnable makes configurable Runnables such as RunnableWithMessageHistory not work properly.
+ # A bug in DocumentedRunnable makes configurable Runnables such as
+ # RunnableWithMessageHistory not work properly.
return DocumentedRunnable(
- chain_with_mem,
- chain_name="RAG with persistant memory",
- user_doc="This chain answers the provided question based on documents it retreives and the conversation history. It uses a persistant memory to store the conversation history.",
+ chain_with_mem,
+ chain_name="RAG with persistant memory",
+ user_doc=(
+ "This chain answers the provided question based on documents it retreives"
+ " and the conversation history. It uses a persistant memory to store the"
+ " conversation history."
+ ),
)
diff --git a/backend/rag_components/chain_links/retrieve_and_format_docs.py b/backend/rag_components/chain_links/retrieve_and_format_docs.py
index a34249c..e81d624 100644
--- a/backend/rag_components/chain_links/retrieve_and_format_docs.py
+++ b/backend/rag_components/chain_links/retrieve_and_format_docs.py
@@ -1,7 +1,7 @@
"""This chain fetches the relevant documents and combines them into a single string."""
-from langchain_core.prompts import PromptTemplate
from langchain.schema import format_document
+from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel
from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable
@@ -19,11 +19,15 @@ class Documents(BaseModel):
def fetch_docs_chain(retriever) -> DocumentedRunnable:
relevant_documents = retriever | _combine_documents
- typed_chain = relevant_documents.with_types(input_type=Question, output_type=Documents)
- return DocumentedRunnable(typed_chain, chain_name="Fetch documents", prompt=prompt, user_doc=__doc__)
+ typed_chain = relevant_documents.with_types(
+ input_type=Question, output_type=Documents
+ )
+ return DocumentedRunnable(
+ typed_chain, chain_name="Fetch documents", prompt=prompt, user_doc=__doc__
+ )
def _combine_documents(docs, document_separator="\n\n"):
document_prompt = PromptTemplate.from_template(template=prompt)
doc_strings = [format_document(doc, document_prompt) for doc in docs]
- return document_separator.join(doc_strings)
\ No newline at end of file
+ return document_separator.join(doc_strings)
diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py
index 22704a3..20dd63a 100644
--- a/backend/rag_components/chat_message_history.py
+++ b/backend/rag_components/chat_message_history.py
@@ -1,4 +1,3 @@
-
from datetime import datetime
from langchain_community.chat_message_histories import SQLChatMessageHistory
@@ -23,6 +22,7 @@ def get_chat_message_history(config: RagConfig, chat_id):
custom_message_converter=TimestampedMessageConverter(TABLE_NAME),
)
+
class TimestampedMessageConverter(DefaultMessageConverter):
def __init__(self, table_name: str):
self.model_class = create_message_model(table_name, declarative_base())
diff --git a/backend/rag_components/document_loader.py b/backend/rag_components/document_loader.py
index 04aa56f..5878309 100644
--- a/backend/rag_components/document_loader.py
+++ b/backend/rag_components/document_loader.py
@@ -33,14 +33,17 @@ def get_best_loader(file_extension: str, llm: BaseChatModel):
input_variables=["file_extension", "loaders"],
template="""
Among the following loaders, which is the best to load a "{file_extension}" file? \
- Only give me one the class name without any other special characters. If no relevant loader is found, respond "None".
+ Only give me one the class name without any other special characters. If no \
+ relevant loader is found, respond "None".
Loaders: {loaders}
""",
)
chain = LLMChain(llm=llm, prompt=prompt, output_key="loader_class_name")
- return chain({"file_extension": file_extension, "loaders": loaders})["loader_class_name"]
+ return chain({"file_extension": file_extension, "loaders": loaders})[
+ "loader_class_name"
+ ]
def get_loaders() -> List[str]:
diff --git a/backend/rag_components/embedding.py b/backend/rag_components/embedding.py
index 7a38223..75f2f24 100644
--- a/backend/rag_components/embedding.py
+++ b/backend/rag_components/embedding.py
@@ -6,6 +6,8 @@
def get_embedding_model(config: RagConfig):
spec = getattr(embeddings, config.embedding_model.source)
kwargs = {
- key: value for key, value in config.embedding_model.source_config.items() if key in spec.__fields__.keys()
+ key: value
+ for key, value in config.embedding_model.source_config.items()
+ if key in spec.__fields__.keys()
}
return spec(**kwargs)
diff --git a/backend/rag_components/llm.py b/backend/rag_components/llm.py
index cbd8306..bdfb23d 100644
--- a/backend/rag_components/llm.py
+++ b/backend/rag_components/llm.py
@@ -9,7 +9,9 @@
def get_llm_model(config: RagConfig, callbacks: List[BaseCallbackHandler] = []):
llm_spec = getattr(chat_models, config.llm.source)
kwargs = {
- key: value for key, value in config.llm.source_config.items() if key in llm_spec.__fields__.keys()
+ key: value
+ for key, value in config.llm.source_config.items()
+ if key in llm_spec.__fields__.keys()
}
kwargs["streaming"] = True
kwargs["callbacks"] = callbacks
diff --git a/backend/rag_components/rag.py b/backend/rag_components/rag.py
index a7a74f6..a7c6d72 100644
--- a/backend/rag_components/rag.py
+++ b/backend/rag_components/rag.py
@@ -6,38 +6,45 @@
from langchain.indexes import SQLRecordManager, index
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores import VectorStore
-from langchain_core.retrievers import BaseRetriever
from langchain.vectorstores.utils import filter_complex_metadata
+from langchain_core.retrievers import BaseRetriever
from backend.config import RagConfig
from backend.database import Database
from backend.logger import get_logger
+from backend.rag_components.chain_links.rag_basic import rag_basic
+from backend.rag_components.chain_links.rag_with_history import rag_with_history_chain
from backend.rag_components.document_loader import get_documents
from backend.rag_components.embedding import get_embedding_model
from backend.rag_components.llm import get_llm_model
from backend.rag_components.retriever import get_retriever
from backend.rag_components.vector_store import get_vector_store
-from backend.rag_components.chain_links.rag_basic import rag_basic
-from backend.rag_components.chain_links.rag_with_history import rag_with_history_chain
+
class RAG:
"""
- The RAG class orchestrates the components necessary for a retrieval-augmented generation pipeline.
+ The RAG class orchestrates the components necessary for a retrieval-augmented
+ generation pipeline.
It initializes with a configuration, either directly or from a file.
The RAG has two main purposes:
- - loading the RAG with documents, which involves ingesting and processing documents to be retrievable by the system
- - generating the chain from the components as specified in the configuration, which entails \
- assembling the various components (language model, embeddings, vector store) into a \
- coherent pipeline for generating responses based on retrieved information.
+ - loading the RAG with documents, which involves ingesting and processing
+ documents to be retrievable by the system
+ - generating the chain from the components as specified in the configuration,
+ which entails assembling the various components (language model, embeddings,
+ vector store) into a coherent pipeline for generating responses based on
+ retrieved information.
Attributes:
config (RagConfig): Configuration object containing settings for RAG components.
llm (BaseChatModel): The language model used for generating responses.
- embeddings (Embeddings): The embedding model used for creating vector representations of text.
- vector_store (VectorStore): The vector store that holds and allows for searching of embeddings.
+ embeddings (Embeddings): The embedding model used for creating vector
+ representations of text.
+ vector_store (VectorStore): The vector store that holds and allows for searching
+ of embeddings.
logger (Logger): Logger for logging information, warnings, and errors.
"""
+
def __init__(self, config: Union[Path, RagConfig]):
if isinstance(config, RagConfig):
self.config = config
@@ -60,13 +67,17 @@ def get_chain(self, memory: bool = False):
chain = rag_basic(self.llm, self.retriever)
return chain
-
def load_file(self, file_path: Path) -> List[Document]:
documents = get_documents(file_path, self.llm)
filtered_documents = filter_complex_metadata(documents)
return self.load_documents(filtered_documents)
- def load_documents(self, documents: List[Document], insertion_mode: str = None, namespace: str = "default"):
+ def load_documents(
+ self,
+ documents: List[Document],
+ insertion_mode: str = None,
+ namespace: str = "default",
+ ):
insertion_mode = insertion_mode or self.config.vector_store.insertion_mode
record_manager = SQLRecordManager(
@@ -79,7 +90,9 @@ def load_documents(self, documents: List[Document], insertion_mode: str = None,
batch_size = 100
for batch in range(0, len(documents), batch_size):
- self.logger.info(f"Indexing batch {batch} to {min(len(documents), batch + batch_size)}.")
+ self.logger.info(
+ f"Indexing batch {batch} to {min(len(documents), batch + batch_size)}."
+ )
indexing_output = index(
documents[batch : min(len(documents), batch + batch_size)],
diff --git a/backend/rag_components/retriever.py b/backend/rag_components/retriever.py
index 260e457..25cd552 100644
--- a/backend/rag_components/retriever.py
+++ b/backend/rag_components/retriever.py
@@ -1,6 +1,4 @@
from langchain_core.vectorstores import VectorStore
-from langchain import retrievers as base_retrievers
-from langchain_community import retrievers as community_retrievers
def get_retriever(vector_store: VectorStore):
diff --git a/backend/rag_components/vector_store.py b/backend/rag_components/vector_store.py
index 9959ef8..e0474ce 100644
--- a/backend/rag_components/vector_store.py
+++ b/backend/rag_components/vector_store.py
@@ -8,17 +8,27 @@
def get_vector_store(embedding_model, config: RagConfig):
vector_store_spec = getattr(vectorstores, config.vector_store.source)
- # the vector store class in langchain doesn't have a uniform interface to pass the embedding model
+ # the vector store class in langchain doesn't have a uniform interface to pass the
+ # embedding model
# we extract the propertiy of the class that matches the 'Embeddings' type
# and instanciate the vector store with our embedding model
signature = inspect.signature(vector_store_spec.__init__)
parameters = signature.parameters
params_dict = dict(parameters)
embedding_param = next(
- (param for param in params_dict.values() if "Embeddings" in str(param.annotation)), None
+ (
+ param
+ for param in params_dict.values()
+ if "Embeddings" in str(param.annotation)
+ ),
+ None,
)
- kwargs = {key: value for key, value in config.vector_store.source_config.items() if key in parameters.keys()}
+ kwargs = {
+ key: value
+ for key, value in config.vector_store.source_config.items()
+ if key in parameters.keys()
+ }
kwargs[embedding_param.name] = embedding_model
vector_store = vector_store_spec(**kwargs)
return vector_store
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 0107f98..e1194f4 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -16,8 +16,10 @@ sqlglot
uvicorn
python-multipart
sse_starlette
+docdantic
+openai
sentence-transformers
chromadb
mysql_connector_repackaged
-psycopg2-binary
\ No newline at end of file
+psycopg2-binary
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 4889de1..d93ef0d 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -35,4 +35,4 @@ services:
networks:
app-network:
- driver: bridge
\ No newline at end of file
+ driver: bridge
diff --git a/docs/backend/chains/basic_chain.md b/docs/backend/chains/basic_chain.md
index 25a6e32..6f907e0 100644
--- a/docs/backend/chains/basic_chain.md
+++ b/docs/backend/chains/basic_chain.md
@@ -54,4 +54,3 @@ This chain fetches the relevant documents and combines them into a single string
-
diff --git a/docs/backend/chains/chain_with_memory.md b/docs/backend/chains/chain_with_memory.md
index f2bc516..dcb0b2f 100644
--- a/docs/backend/chains/chain_with_memory.md
+++ b/docs/backend/chains/chain_with_memory.md
@@ -173,4 +173,3 @@ This chain fetches the relevant documents and combines them into a single string
-
diff --git a/docs/backend/chains/chains.md b/docs/backend/chains/chains.md
index 9af9d9b..6be2722 100644
--- a/docs/backend/chains/chains.md
+++ b/docs/backend/chains/chains.md
@@ -7,7 +7,7 @@ We provide two basic RAG chains to get you started, one does simple one-shot Q&A
## Chains and chain links
-This repo does not define large monolithic chains. To make it easier to pick and chose the required functionalities, we provide "chain links" at `backend/rag_components/chain_links`. All links are valid, self-sufficient chains. You can think of it as a toolbox of langchain components meant to be composed and stacked together to build actually useful chains.
+This repo does not define large monolithic chains. To make it easier to pick and chose the required functionalities, we provide "chain links" at `backend/rag_components/chain_links`. All links are valid, self-sufficient chains. You can think of it as a toolbox of langchain components meant to be composed and stacked together to build actually useful chains.
As all links are `Runnable` objects, they can be built from other chain links which are themselves made of chain links, etc...
@@ -21,7 +21,7 @@ Typically, a link has:
- A chain definition
-For example the `condense_question` chain link has `QuestionWithChatHistory` and `StandaloneQuestion` as input an output models.
+For example the `condense_question` chain link has `QuestionWithChatHistory` and `StandaloneQuestion` as input an output models.
```python
class QuestionWithChatHistory(BaseModel):
question: str
@@ -69,9 +69,9 @@ Documentation is recursively generated from all the `DocumentedRunnable` chains
In order to document your chain, just wrap it in a `DocumentedRunnable`:
```python
documented_chain = DocumentedRunnable(
- runnable=my_chain,
- chain_name="My documented Chain",
- user_doc="Additional chain explainations that will be displayed in the markdown",
+ runnable=my_chain,
+ chain_name="My documented Chain",
+ user_doc="Additional chain explainations that will be displayed in the markdown",
prompt=prompt,
)
```
diff --git a/docs/doc_generation.py b/docs/doc_generation.py
index 1739ccf..ab1ce78 100644
--- a/docs/doc_generation.py
+++ b/docs/doc_generation.py
@@ -1,20 +1,28 @@
-"""Quick ad-hoc script that generates markdown documentation for the basic RAG chain and the RAG chain with memory."""
+"""Quick ad-hoc script that generates markdown documentation for the basic RAG chain and
+the RAG chain with memory."""
+
from pathlib import Path
+
from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable
from backend.rag_components.rag import RAG
-
rag = RAG(config=Path(__file__).parents[1] / "backend" / "config.yaml")
chain = rag.get_chain(memory=False)
-with open(Path(__file__).parent / "backend" / "chains" / "basic_chain.md", "w") as f:
+with (Path(__file__).parent / "backend" / "chains" / "basic_chain.md").open("w") as f:
f.write(chain.documentation.to_markdown())
chain = rag.get_chain(memory=True)
doc_chain = DocumentedRunnable(
- chain,
- chain_name="RAG with persistant memory",
- user_doc="This chain answers the provided question based on documents it retreives and the conversation history. It uses a persistant memory to store the conversation history.",
+ chain,
+ chain_name="RAG with persistant memory",
+ user_doc=(
+ "This chain answers the provided question based on documents it retreives and "
+ "the conversation history. It uses a persistant memory to store the "
+ "conversation history."
+ ),
)
-with open(Path(__file__).parent / "backend" / "chains" / "chain_with_memory.md", "w") as f:
+with (Path(__file__).parent / "backend" / "chains" / "chain_with_memory.md").open(
+ "w"
+) as f:
f.write(doc_chain.documentation.to_markdown())
diff --git a/docs/images/favicon.svg b/docs/images/favicon.svg
new file mode 100644
index 0000000..247c970
--- /dev/null
+++ b/docs/images/favicon.svg
@@ -0,0 +1,11 @@
+
+
\ No newline at end of file
diff --git a/docs/images/logo.svg b/docs/images/logo.svg
new file mode 100644
index 0000000..146732c
--- /dev/null
+++ b/docs/images/logo.svg
@@ -0,0 +1,36 @@
+
+
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 9d6967b..fa341f1 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -17,7 +17,7 @@ This is a starter kit to deploy a modularizable RAG locally or on the cloud (or
## Quickstart
-This quickstart will guide you through the steps to serve the RAG and load a few documents.
+This quickstart will guide you through the steps to serve the RAG and load a few documents.
You will run both the back and front on your machine.
@@ -79,7 +79,7 @@ Right now the RAG does not have any documents loaded, you can use the notebook i
To deep dive into under the hood, take a look at the documentation
-[On github pages](https://artefactory.github.io/skaff-rag-accelerator/)
+[On github pages](https://artefactory-skaff.github.io/skaff-rag-accelerator/)
Or serve them locally:
```shell
diff --git a/docs/stylesheets/skaff.css b/docs/stylesheets/skaff.css
new file mode 100644
index 0000000..b744c67
--- /dev/null
+++ b/docs/stylesheets/skaff.css
@@ -0,0 +1,73 @@
+/* https://squidfunk.github.io/mkdocs-material/setup/changing-the-colors/#custom-colors */
+:root {
+ --md-primary-fg-color: #142146;
+ --md-accent-fg-color: #fc2c7f;
+}
+
+ /* Revert hue value to that of pre mkdocs-material v9.4.0 */
+ [data-md-color-scheme="slate"] {
+ /* Hue taken from hsl of #142146, used for bg on website*/
+ --md-hue: 227;
+ /* Increase the lightness by 5%, opacity by 0.2 */
+ --md-default-fg-color: hsla(var(--md-hue),15%,95%,1.0);
+ --md-default-fg-color--light: hsla(var(--md-hue),15%,95%,0.76);
+ --md-default-fg-color--lighter: hsla(var(--md-hue),15%,95%,0.52);
+ --md-default-fg-color--lightest: hsla(var(--md-hue),15%,95%,0.32);
+ /* Change the saturation and lightness to match #142146 */
+ --md-default-bg-color: hsla(var(--md-hue),87%,6%,1);
+ --md-default-bg-color--light: hsla(var(--md-hue),87%,6%,0.54);
+ --md-default-bg-color--lighter: hsla(var(--md-hue),87%,6%,0.26);
+ --md-default-bg-color--lightest: hsla(var(--md-hue),87%,6%,0.07);
+ /* Increase the opacity of code to 1.0 */
+ --md-code-fg-color: hsla(var(--md-hue),18%,86%,1.0);
+ --md-code-hl-comment-color: #666666;
+ --md-typeset-a-color: #65cccc;
+ }
+
+ [data-md-color-scheme="default"] {
+ --md-hue: 227;
+ --md-default-bg-color: hsla(var(--md-hue),100%,96%,1);
+ --md-typeset-a-color: #65cccc;
+
+ --md-code-fg-color: hsla(var(--md-hue),18%,86%,1.0);
+ --md-code-bg-color: #262a32;
+ --md-code-hl-name-color: var(--md-code-fg-color);
+ --md-code-hl-operator-color: var(--md-code-fg-color);
+ --md-code-hl-punctuation-color: var(--md-code-fg-color);
+ --md-code-hl-comment-color: #666666;
+ --md-code-hl-variable-color: var(--md-code-fg-color);
+ }
+
+ .custom-source-wrapper {
+ display: flex;
+ justify-content: space-between;
+ align-items: center;
+ gap: 12px;
+ }
+
+ @media screen and (min-width: 60em) {
+ .md-header__source {
+ box-sizing: content-box;
+ max-width: 11.7rem;
+ width: 11.7rem;
+ }
+ }
+
+ .custom-login-button {
+ font-size: 0.8rem;
+ font-weight: 700;
+ float: right;
+ border: 2px solid #fff;
+ border-radius: 8px;
+ padding: 6px 12px;
+ transition: opacity .25s;
+ }
+
+ .custom-login-button:hover {
+ opacity: 0.7
+ }
+
+ /* Hide all ToC entries for parameters. */
+ li.md-nav__item>a[href*="("] {
+ display: none;
+ }
diff --git a/examples/load_documents.ipynb b/examples/load_documents.ipynb
deleted file mode 100644
index 524aa21..0000000
--- a/examples/load_documents.ipynb
+++ /dev/null
@@ -1,108 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This is an interactive example that will walk you through the initialization of a RAG and the basic embedding of a few documents."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from pathlib import Path\n",
- "import sys\n",
- "import os\n",
- "repo_root = Path(os.getcwd()).parent\n",
- "sys.path.append(str(repo_root))\n",
- "\n",
- "from backend.config import RagConfig\n",
- "from backend.rag_components.rag import RAG\n",
- "\n",
- "rag_config = RagConfig.from_yaml(repo_root / \"backend\" / \"config.yaml\")\n",
- "rag_config.database.database_url = f\"sqlite:////{repo_root}/database/rag.sqlite3\"\n",
- "\n",
- "rag = RAG(config=rag_config)\n",
- "\n",
- "print(\"LLM:\", rag.llm.__class__.__name__)\n",
- "print(\"Embedding model:\", rag.embeddings.__class__.__name__)\n",
- "print(\"Vector store:\", rag.vector_store.__class__.__name__)\n",
- "print(\"Retriever:\", rag.retriever.__class__.__name__)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Here we transform our CSV into standalone embeddable documents that we will be able to feed the vector store.\n",
- "\n",
- "We generate one document for each line, and each document will contain header:value pairs for all the columns.\n",
- "\n",
- "This is a very simplistic example, but vector store data models can get more advanced to support more [powerful retreival methods.](https://python.langchain.com/docs/modules/data_connection/retrievers/)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from langchain_community.document_loaders.csv_loader import CSVLoader\n",
- "from langchain.vectorstores.utils import filter_complex_metadata\n",
- "\n",
- "\n",
- "data_sample_path = repo_root / \"examples\" / \"billionaires.csv\"\n",
- "\n",
- "loader = CSVLoader(\n",
- " file_path=str(data_sample_path),\n",
- " csv_args={\"delimiter\": \",\", \"quotechar\": '\"', \"escapechar\": \"\\\\\"},\n",
- " encoding=\"utf-8-sig\",\n",
- ")\n",
- "\n",
- "raw_documents = loader.load()\n",
- "documents = filter_complex_metadata(raw_documents)\n",
- "documents[:5]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To load the docs in the vector store, we recommend using the `load_document` as it [indexes previously embedded docs](https://python.langchain.com/docs/modules/data_connection/indexing), making the process idempotent."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "rag.load_documents(documents)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.5"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/frontend/Dockerfile b/frontend/Dockerfile
index 7fefc4e..2a1bbef 100644
--- a/frontend/Dockerfile
+++ b/frontend/Dockerfile
@@ -21,4 +21,4 @@ EXPOSE $PORT
COPY . ./frontend
-CMD python -m streamlit run frontend/front.py --server.port $PORT
\ No newline at end of file
+CMD python -m streamlit run frontend/front.py --server.port $PORT
diff --git a/frontend/front.py b/frontend/front.py
index cea1b6f..7736957 100644
--- a/frontend/front.py
+++ b/frontend/front.py
@@ -23,37 +23,53 @@ def browser_tab_title():
def application_header():
st.image(Image.open(ASSETS_PATH / "logo_title.jpeg"))
- st.caption("Learn more about the RAG indus kit here: https://artefactory.github.io/skaff-rag-accelerator")
+ st.caption(
+ "Learn more about the RAG indus kit here:"
+ " https://artefactory-skaff.github.io/skaff-rag-accelerator/"
+ )
if __name__ == "__main__":
browser_tab_title()
application_header()
- # The session is used to make requests to the backend. It helps with the handling of cookies, auth, and other session data
+ # The session is used to make requests to the backend. It helps with the handling of
+ # cookies, auth, and other session data
initialize_state_variable("session", value=create_session())
# The chain is our RAG that will be used to answer questions.
- # Langserve's RemoteRunnable allows us to work as if the RAG was local, but it's actually running on the backend
+ # Langserve's RemoteRunnable allows us to work as if the RAG was local, but it's
+ # actually running on the backend
initialize_state_variable("chain", value=RemoteRunnable(BACKEND_URL))
- # If the backend supports authentication but the user is not authenticated, show the authentication page
- if backend_supports_auth() and st.session_state.get("authenticated_session", None) is None:
+ # If the backend supports authentication but the user is not authenticated, show the
+ # authentication page
+ if (
+ backend_supports_auth()
+ and st.session_state.get("authenticated_session", None) is None
+ ):
initialize_state_variable("login_status_message", value="")
initialize_state_variable("login_status_level", value="info")
authentication_page()
- st.stop() # Stop the script to avoid running the rest of the code if the user is not authenticated
+ # Stop the script to avoid running the rest of the code if the user is not
+ # authenticated
+ st.stop()
- # If the backend does not support authentication, just use the session as the authenticated session
- if not backend_supports_auth() and st.session_state.get("authenticated_session", None) is None:
+ # If the backend does not support authentication, just use the session as the
+ # authenticated session
+ if (
+ not backend_supports_auth()
+ and st.session_state.get("authenticated_session", None) is None
+ ):
st.session_state["authenticated_session"] = st.session_state["session"]
# Once we have an authenticated session, show the chat interface
if st.session_state.get("authenticated_session") is not None:
-
# If the backend supports sessions, enable session navigation
if backend_supports_sessions():
- initialize_state_variable("email", value="demo.user@email.com") # With authentication, this will take the user's email
+ initialize_state_variable(
+ "email", value="demo.user@email.com"
+ ) # With authentication, this will take the user's email
sidebar()
session_chat()
else:
diff --git a/frontend/lib/auth.py b/frontend/lib/auth.py
index c7b5b39..b711d9a 100644
--- a/frontend/lib/auth.py
+++ b/frontend/lib/auth.py
@@ -13,7 +13,9 @@
def authentication_page():
auth_form_tabs = [stx.TabBarItemData(id="Login", title="Login", description="")]
if ADMIN_MODE:
- auth_form_tabs += [stx.TabBarItemData(id="Signup", title="Signup", description="")]
+ auth_form_tabs += [
+ stx.TabBarItemData(id="Signup", title="Signup", description="")
+ ]
tab = stx.tab_bar(data=auth_form_tabs, default="Login")
if tab == "Login":
@@ -27,8 +29,11 @@ def login_form():
username = st.text_input("Username", key="username")
password = st.text_input("Password", type="password")
if st.session_state["login_status_message"]:
- # Dynamically decides whether to call st.error, st.success, or any other Streamlit method
- getattr(st, st.session_state["login_status_level"])(st.session_state["login_status_message"])
+ # Dynamically decides whether to call st.error, st.success, or any other
+ # Streamlit method
+ getattr(st, st.session_state["login_status_level"])(
+ st.session_state["login_status_message"]
+ )
submit = st.form_submit_button("Log in")
if submit:
@@ -39,7 +44,9 @@ def login_form():
session = authenticate_session(session, token)
else:
st.session_state["login_status_level"] = "error"
- st.session_state["login_status_message"] = "Username/password combination not found"
+ st.session_state["login_status_message"] = (
+ "Username/password combination not found"
+ )
st.session_state["authenticated_session"] = session
st.session_state["email"] = username
st.rerun()
@@ -50,7 +57,9 @@ def signup_form():
username = st.text_input("Username", key="username")
password = st.text_input("Password", type="password")
if st.session_state["login_status_message"]:
- getattr(st, st.session_state["login_status_level"])(st.session_state["login_status_message"])
+ getattr(st, st.session_state["login_status_level"])(
+ st.session_state["login_status_message"]
+ )
submit = st.form_submit_button("Sign up")
if submit:
@@ -63,7 +72,9 @@ def signup_form():
st.session_state["login_status_level"] = "success"
st.session_state["login_status_message"] = "Success! Account created."
if st.session_state["login_status_message"]:
- getattr(st, st.session_state["login_status_level"])(st.session_state["login_status_message"])
+ getattr(st, st.session_state["login_status_level"])(
+ st.session_state["login_status_message"]
+ )
sleep(1.5)
else:
st.session_state["login_status_level"] = "error"
@@ -75,7 +86,9 @@ def signup_form():
def get_token(username: str, password: str) -> Optional[str]:
session = create_session()
- response = session.post("/user/login", data={"username": username, "password": password})
+ response = session.post(
+ "/user/login", data={"username": username, "password": password}
+ )
if response.status_code == 200 and "access_token" in response.json():
return response.json()["access_token"]
else:
@@ -84,7 +97,9 @@ def get_token(username: str, password: str) -> Optional[str]:
def sign_up(username: str, password: str) -> bool:
session = create_session()
- response = session.post("/user/signup", json={"email": username, "password": password})
+ response = session.post(
+ "/user/signup", json={"email": username, "password": password}
+ )
if response.status_code == 200 and "email" in response.json():
return True
else:
@@ -93,5 +108,7 @@ def sign_up(username: str, password: str) -> bool:
def authenticate_session(session, bearer_token: str) -> requests.Session:
session.headers.update({"Authorization": f"Bearer {bearer_token}"})
- st.session_state["chain"] = RemoteRunnable(BACKEND_URL, headers={"Authorization": f"Bearer {bearer_token}"})
+ st.session_state["chain"] = RemoteRunnable(
+ BACKEND_URL, headers={"Authorization": f"Bearer {bearer_token}"}
+ )
return session
diff --git a/frontend/lib/backend_interface.py b/frontend/lib/backend_interface.py
index 19af143..612824b 100644
--- a/frontend/lib/backend_interface.py
+++ b/frontend/lib/backend_interface.py
@@ -16,7 +16,9 @@ def query(verb: str, url: str, **kwargs):
st.session_state["authenticated_session"] = None
st.session_state["email"] = None
st.session_state["login_status_level"] = "error"
- st.session_state["login_status_message"] = "Session expired. Please log in again."
+ st.session_state["login_status_message"] = (
+ "Session expired. Please log in again."
+ )
st.rerun()
return response
@@ -34,6 +36,7 @@ def create_session() -> Session:
session = BaseUrlSession(BACKEND_URL)
return session
+
class BaseUrlSession(Session):
def __init__(self, base_url):
super().__init__()
diff --git a/frontend/lib/basic_chat.py b/frontend/lib/basic_chat.py
index 51c3738..baffe47 100644
--- a/frontend/lib/basic_chat.py
+++ b/frontend/lib/basic_chat.py
@@ -6,7 +6,6 @@ def basic_chat():
with st.container(border=True):
if user_question:
-
with st.chat_message("user"):
st.write(user_question)
diff --git a/frontend/lib/session_chat.py b/frontend/lib/session_chat.py
index 3f20969..d8da3d0 100644
--- a/frontend/lib/session_chat.py
+++ b/frontend/lib/session_chat.py
@@ -17,7 +17,10 @@ class Message:
def __post_init__(self):
self.id = str(uuid4()) if self.id is None else self.id
- self.timestamp = datetime.utcnow().isoformat() if self.timestamp is None else self.timestamp
+ self.timestamp = (
+ datetime.utcnow().isoformat() if self.timestamp is None else self.timestamp
+ )
+
def session_chat():
user_question = st.chat_input("Say something")
@@ -40,7 +43,10 @@ def session_chat():
st.session_state["messages"].append(user_message)
chain = st.session_state.get("chain")
- response = chain.stream({"question": user_question}, {"configurable": {"session_id": session_id}})
+ response = chain.stream(
+ {"question": user_question},
+ {"configurable": {"session_id": session_id}},
+ )
with st.chat_message("assistant"):
full_response = ""
@@ -60,6 +66,7 @@ def new_session():
st.session_state["messages"] = []
return session_id
+
def get_session(session_id: str):
session = query("get", f"/session/{session_id}").json()
return session
diff --git a/frontend/requirements.txt b/frontend/requirements.txt
index dd129ab..41f6ab3 100644
--- a/frontend/requirements.txt
+++ b/frontend/requirements.txt
@@ -7,4 +7,4 @@ python-dotenv
Requests
streamlit
httpx_sse
-pydantic==1.*
\ No newline at end of file
+pydantic==1.*
diff --git a/load_documents.ipynb b/load_documents.ipynb
new file mode 100644
index 0000000..0d521a1
--- /dev/null
+++ b/load_documents.ipynb
@@ -0,0 +1,204 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is an interactive example that will walk you through the initialization of a RAG and the basic embedding of a few documents."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Successfuly ran script at /Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/backend/rag_components/rag_tables.sql for sqlite\n",
+ "/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/venv/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:117: LangChainDeprecationWarning: The class `langchain_community.chat_models.azure_openai.AzureChatOpenAI` was deprecated in langchain-community 0.0.10 and will be removed in 0.2.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import AzureChatOpenAI`.\n",
+ " warn_deprecated(\n",
+ "/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LLM: AzureChatOpenAI\n",
+ "Embedding model: HuggingFaceEmbeddings\n",
+ "Vector store: Chroma\n",
+ "Retriever: VectorStoreRetriever\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pathlib import Path\n",
+ "import os\n",
+ "\n",
+ "from backend.config import RagConfig\n",
+ "from backend.rag_components.rag import RAG\n",
+ "\n",
+ "repo_root = Path(os.getcwd())\n",
+ "\n",
+ "rag_config = RagConfig.from_yaml(repo_root / \"backend\" / \"config.yaml\")\n",
+ "rag = RAG(config=rag_config)\n",
+ "\n",
+ "print(\"LLM:\", rag.llm.__class__.__name__)\n",
+ "print(\"Embedding model:\", rag.embeddings.__class__.__name__)\n",
+ "print(\"Vector store:\", rag.vector_store.__class__.__name__)\n",
+ "print(\"Retriever:\", rag.retriever.__class__.__name__)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here we transform our CSV into standalone embeddable documents that we will be able to feed the vector store.\n",
+ "\n",
+ "We generate one document for each line, and each document will contain header:value pairs for all the columns.\n",
+ "\n",
+ "This is a very simplistic example, but vector store data models can get more advanced to support more [powerful retreival methods.](https://python.langchain.com/docs/modules/data_connection/retrievers/)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Document(page_content='rank: 1\\nfinalWorth: 211000\\ncategory: Fashion & Retail\\npersonName: Bernard Arnault & family\\nage: 74\\ncountry: France\\ncity: Paris\\nsource: LVMH\\nindustries: Fashion & Retail\\ncountryOfCitizenship: France\\norganization: LVMH Moët Hennessy Louis Vuitton\\nselfMade: FALSE\\nstatus: U\\ngender: M\\nbirthDate: 3/5/1949 0:00\\nlastName: Arnault\\nfirstName: Bernard\\ntitle: Chairman and CEO\\ndate: 4/4/2023 5:01\\nstate: \\nresidenceStateRegion: \\nbirthYear: 1949\\nbirthMonth: 3\\nbirthDay: 5\\ncpi_country: 110.05\\ncpi_change_country: 1.1\\ngdp_country: $2,715,518,274,227\\ngross_tertiary_education_enrollment: 65.6\\ngross_primary_education_enrollment_country: 102.5\\nlife_expectancy_country: 82.5\\ntax_revenue_country_country: 24.2\\ntotal_tax_rate_country: 60.7\\npopulation_country: 67059887\\nlatitude_country: 46.227638\\nlongitude_country: 2.213749', metadata={'source': '/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/examples/billionaires.csv', 'row': 0}),\n",
+ " Document(page_content='rank: 2\\nfinalWorth: 180000\\ncategory: Automotive\\npersonName: Elon Musk\\nage: 51\\ncountry: United States\\ncity: Austin\\nsource: Tesla, SpaceX\\nindustries: Automotive\\ncountryOfCitizenship: United States\\norganization: Tesla\\nselfMade: TRUE\\nstatus: D\\ngender: M\\nbirthDate: 6/28/1971 0:00\\nlastName: Musk\\nfirstName: Elon\\ntitle: CEO\\ndate: 4/4/2023 5:01\\nstate: Texas\\nresidenceStateRegion: South\\nbirthYear: 1971\\nbirthMonth: 6\\nbirthDay: 28\\ncpi_country: 117.24\\ncpi_change_country: 7.5\\ngdp_country: $21,427,700,000,000\\ngross_tertiary_education_enrollment: 88.2\\ngross_primary_education_enrollment_country: 101.8\\nlife_expectancy_country: 78.5\\ntax_revenue_country_country: 9.6\\ntotal_tax_rate_country: 36.6\\npopulation_country: 328239523\\nlatitude_country: 37.09024\\nlongitude_country: -95.712891', metadata={'source': '/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/examples/billionaires.csv', 'row': 1}),\n",
+ " Document(page_content='rank: 3\\nfinalWorth: 114000\\ncategory: Technology\\npersonName: Jeff Bezos\\nage: 59\\ncountry: United States\\ncity: Medina\\nsource: Amazon\\nindustries: Technology\\ncountryOfCitizenship: United States\\norganization: Amazon\\nselfMade: TRUE\\nstatus: D\\ngender: M\\nbirthDate: 1/12/1964 0:00\\nlastName: Bezos\\nfirstName: Jeff\\ntitle: Chairman and Founder\\ndate: 4/4/2023 5:01\\nstate: Washington\\nresidenceStateRegion: West\\nbirthYear: 1964\\nbirthMonth: 1\\nbirthDay: 12\\ncpi_country: 117.24\\ncpi_change_country: 7.5\\ngdp_country: $21,427,700,000,000\\ngross_tertiary_education_enrollment: 88.2\\ngross_primary_education_enrollment_country: 101.8\\nlife_expectancy_country: 78.5\\ntax_revenue_country_country: 9.6\\ntotal_tax_rate_country: 36.6\\npopulation_country: 328239523\\nlatitude_country: 37.09024\\nlongitude_country: -95.712891', metadata={'source': '/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/examples/billionaires.csv', 'row': 2}),\n",
+ " Document(page_content='rank: 4\\nfinalWorth: 107000\\ncategory: Technology\\npersonName: Larry Ellison\\nage: 78\\ncountry: United States\\ncity: Lanai\\nsource: Oracle\\nindustries: Technology\\ncountryOfCitizenship: United States\\norganization: Oracle\\nselfMade: TRUE\\nstatus: U\\ngender: M\\nbirthDate: 8/17/1944 0:00\\nlastName: Ellison\\nfirstName: Larry\\ntitle: CTO and Founder\\ndate: 4/4/2023 5:01\\nstate: Hawaii\\nresidenceStateRegion: West\\nbirthYear: 1944\\nbirthMonth: 8\\nbirthDay: 17\\ncpi_country: 117.24\\ncpi_change_country: 7.5\\ngdp_country: $21,427,700,000,000\\ngross_tertiary_education_enrollment: 88.2\\ngross_primary_education_enrollment_country: 101.8\\nlife_expectancy_country: 78.5\\ntax_revenue_country_country: 9.6\\ntotal_tax_rate_country: 36.6\\npopulation_country: 328239523\\nlatitude_country: 37.09024\\nlongitude_country: -95.712891', metadata={'source': '/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/examples/billionaires.csv', 'row': 3}),\n",
+ " Document(page_content='rank: 5\\nfinalWorth: 106000\\ncategory: Finance & Investments\\npersonName: Warren Buffett\\nage: 92\\ncountry: United States\\ncity: Omaha\\nsource: Berkshire Hathaway\\nindustries: Finance & Investments\\ncountryOfCitizenship: United States\\norganization: Berkshire Hathaway Inc. (Cl A)\\nselfMade: TRUE\\nstatus: D\\ngender: M\\nbirthDate: 8/30/1930 0:00\\nlastName: Buffett\\nfirstName: Warren\\ntitle: CEO\\ndate: 4/4/2023 5:01\\nstate: Nebraska\\nresidenceStateRegion: Midwest\\nbirthYear: 1930\\nbirthMonth: 8\\nbirthDay: 30\\ncpi_country: 117.24\\ncpi_change_country: 7.5\\ngdp_country: $21,427,700,000,000\\ngross_tertiary_education_enrollment: 88.2\\ngross_primary_education_enrollment_country: 101.8\\nlife_expectancy_country: 78.5\\ntax_revenue_country_country: 9.6\\ntotal_tax_rate_country: 36.6\\npopulation_country: 328239523\\nlatitude_country: 37.09024\\nlongitude_country: -95.712891', metadata={'source': '/Users/alexis.vialaret/vscode_projects/skaff-rag-accelerator/examples/billionaires.csv', 'row': 4})]"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from langchain_community.document_loaders.csv_loader import CSVLoader\n",
+ "from langchain.vectorstores.utils import filter_complex_metadata\n",
+ "\n",
+ "\n",
+ "data_sample_path = repo_root / \"examples\" / \"billionaires.csv\"\n",
+ "\n",
+ "loader = CSVLoader(\n",
+ " file_path=str(data_sample_path),\n",
+ " csv_args={\"delimiter\": \",\", \"quotechar\": '\"', \"escapechar\": \"\\\\\"},\n",
+ " encoding=\"utf-8-sig\",\n",
+ ")\n",
+ "\n",
+ "raw_documents = loader.load()\n",
+ "documents = filter_complex_metadata(raw_documents)\n",
+ "documents[:5]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To load the docs in the vector store, we recommend using the `load_document` as it [indexes previously embedded docs](https://python.langchain.com/docs/modules/data_connection/indexing), making the process idempotent."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Indexing 2640 documents.\n",
+ "Indexing batch 0 to 100.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 100 to 200.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 200 to 300.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 300 to 400.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 400 to 500.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 500 to 600.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 600 to 700.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 700 to 800.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 800 to 900.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 900 to 1000.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1000 to 1100.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1100 to 1200.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1200 to 1300.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1300 to 1400.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1400 to 1500.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1500 to 1600.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1600 to 1700.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1700 to 1800.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1800 to 1900.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 1900 to 2000.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 2000 to 2100.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 2100 to 2200.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 2200 to 2300.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 2300 to 2400.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 2400 to 2500.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 2500 to 2600.\n",
+ "{'event': 'load_documents', 'num_added': 100, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n",
+ "Indexing batch 2600 to 2640.\n",
+ "{'event': 'load_documents', 'num_added': 40, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}\n"
+ ]
+ }
+ ],
+ "source": [
+ "rag.load_documents(documents)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/mkdocs.yml b/mkdocs.yml
index a7f4f7d..ea5e3f5 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -2,24 +2,87 @@ site_name: GenAI RAG Accelerator
repo_name: artefactory/skaff-rag-accelerator
repo_url: https://github.com/artefactory/skaff-rag-accelerator
+theme:
+ name: material
+ logo: images/logo.svg
+ favicon: images/favicon.svg
+ font:
+ text: Oxygen
+ features:
+ - search.suggest
+ - search.highlight
+ - content.code.annotate
+ - content.code.copy
+ - content.code.select
+ - navigation.indexes
+ - navigation.path
+ - navigation.instant
+ - navigation.instant.preview
+ - navigation.instant.prefetch
+ - navigation.instant.progress
+ - navigation.tracking
+ - toc.follow
+ palette: # Light and dark mode
+ - media: "(prefers-color-scheme: light)"
+ scheme: default
+ primary: custom
+ accent: custom
+ toggle:
+ icon: material/lightbulb-outline
+ name: "Switch to dark mode"
+ - media: "(prefers-color-scheme: dark)"
+ scheme: slate
+ primary: custom
+ accent: custom
+ toggle:
+ icon: material/lightbulb
+ name: "Switch to light mode"
+
+extra:
+ # hide the "Made with Material for MkDocs" message
+ generator: false
+ analytics:
+ provider: google
+ property: G-7REH78BCSD
+ feedback:
+ title: Was this page helpful?
+ ratings:
+ - icon: material/thumb-up-outline
+ name: This page was helpful
+ data: 1@
+ note: >-
+ Thanks for your feedback!
+ - icon: material/thumb-down-outline
+ name: This page could be improved
+ data: 0
+ note: >-
+ Thanks for your feedback! Help us improve this page by
+ opening an issue.
+
+extra_css:
+ - stylesheets/skaff.css
+
plugins:
- - techdocs-core
+ - termynal
+ - search
markdown_extensions:
- admonition
- attr_list
- md_in_html
- pymdownx.superfences
- - pymdownx.tabbed
+ - pymdownx.tabbed:
+ alternate_style: true
- pymdownx.tasklist
- pymdownx.snippets
+
nav:
- Home: index.md
- The frontend: frontend.md
- The database: database.md
- The backend:
- - The backend: backend/backend.md
+ - The Backend: backend/backend.md
- RAG and RAGConfig classes: backend/rag_ragconfig.md
- Chains and chain links: backend/chains/chains.md
- API Plugins:
diff --git a/pyproject.toml b/pyproject.toml
index f9644a8..acdbf3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,71 +13,31 @@ readme = "README.md"
requires-python = ">=3.8"
[project.urls]
-"Homepage" = "https://github.com/artefactory-fr/rag-as-a-service"
-"Documentation" = "https://artefactory-fr.github.io/rag-as-a-service"
+"Homepage" = "https://github.com/artefactory-skaff/skaff-rag-accelerator"
+"Documentation" = "https://artefactory-skaff.github.io/skaff-rag-accelerator/"
[tool.setuptools]
packages = ["lib", "config", "tests"]
[tool.ruff]
+target-version = "py310"
+
+[tool.ruff.lint]
select = [
- "E",
- "W",
- "F",
- "I",
- "N",
- "Q",
- "PTH",
- "PD",
-] # See: https://beta.ruff.rs/docs/rules/
+ "E", # pycodestyle
+ "W", # pycodestyle
+ "F", # Pyflakes
+ "I", # isort
+ "N", # pep8-naming
+ "Q", # flake8-quotes
+ "PTH", # flake8-use-pathlib
+ "PD", # pandas-vet
+] # See: https://docs.astral.sh/ruff/rules/
ignore = ["D100", "D103", "D203", "D213", "ANN101", "ANN102"]
-line-length = 140
-target-version = "py310"
-exclude = [
- ".bzr",
- ".direnv",
- ".eggs",
- ".git",
- ".ruff_cache",
- ".svn",
- ".tox",
- ".venv",
- "__pypackages__",
- "_build",
- "buck-out",
- "build",
- "dist",
- "node_modules",
- "venv",
-]
-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
convention = "google"
-# https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html
-[tool.black]
-line-length = 100
-target-version = ["py310"]
-include = '\.pyi?$'
-exclude = '''
-(
- /(
- \.direnv
- | \.eggs
- | \.git
- | \.tox
- | \.venv
- | _build
- | build
- | dist
- | venv
- )/
-)
-'''
-
-[tool.isort]
-profile = "black"
-
[tool.bandit]
exclude_dirs = [".venv", "tests"]
skips = ["B101", "B104"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index e3a19fe..21c3689 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,4 +4,6 @@ mkdocs
mkdocs_monorepo_plugin
pymdown-extensions
mkdocs-pymdownx-material-extras
-mkdocs-techdocs-core
\ No newline at end of file
+mkdocs-techdocs-core
+termynal
+mkdocs-material