From 082c71add7f213fe094f7bb9285cf3002f6a44ea Mon Sep 17 00:00:00 2001 From: Baptiste Pasquier Date: Mon, 15 Apr 2024 16:59:42 +0200 Subject: [PATCH 1/4] Update Ruff config --- pyproject.toml | 66 ++++++++++---------------------------------------- 1 file changed, 13 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9644a8..ca0497f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,64 +20,24 @@ requires-python = ">=3.8" packages = ["lib", "config", "tests"] [tool.ruff] +target-version = "py310" + +[tool.ruff.lint] select = [ - "E", - "W", - "F", - "I", - "N", - "Q", - "PTH", - "PD", -] # See: https://beta.ruff.rs/docs/rules/ + "E", # pycodestyle + "W", # pycodestyle + "F", # Pyflakes + "I", # isort + "N", # pep8-naming + "Q", # flake8-quotes + "PTH", # flake8-use-pathlib + "PD", # pandas-vet +] # See: https://docs.astral.sh/ruff/rules/ ignore = ["D100", "D103", "D203", "D213", "ANN101", "ANN102"] -line-length = 140 -target-version = "py310" -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "venv", -] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" -# https://black.readthedocs.io/en/stable/usage_and_configuration/the_basics.html -[tool.black] -line-length = 100 -target-version = ["py310"] -include = '\.pyi?$' -exclude = ''' -( - /( - \.direnv - | \.eggs - | \.git - | \.tox - | \.venv - | _build - | build - | dist - | venv - )/ -) -''' - -[tool.isort] -profile = "black" - [tool.bandit] exclude_dirs = [".venv", "tests"] skips = ["B101", "B104"] From 967266126e876f4120101bafe2a88e9ef50a001d Mon Sep 17 00:00:00 2001 From: Baptiste Pasquier Date: Wed, 17 Apr 2024 11:27:46 +0200 Subject: [PATCH 2/4] Add ruff and nbstripout as pre-commit hooks --- .pre-commit-config.yaml | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b7b65ac..eb2011d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,22 +7,17 @@ repos: - id: check-toml - id: check-json - id: check-added-large-files - - repo: local + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.1 hooks: - - id: isort - name: Sorting imports (isort) - entry: isort - types: [python] - language: system - id: ruff - name: Linting (ruff) - entry: ruff --fix - types: [python] - language: system + args: [ --fix ] + - id: ruff-format + + - repo: https://github.com/kynan/nbstripout + rev: 0.7.1 + hooks: - id: nbstripout - name: Strip Jupyter notebook output (nbstripout) - entry: nbstripout - types: [file] - files: (.ipynb)$ - language: system + exclude: ^(.svn|CVS|.bzr|.hg|.git|__pycache__|.tox|.ipynb_checkpoints|assets|tests/assets/|venv/|.venv/) From 58abbcf5556116a91794a6b5203c43f11facf1e5 Mon Sep 17 00:00:00 2001 From: Baptiste Pasquier Date: Mon, 15 Apr 2024 17:22:52 +0200 Subject: [PATCH 3/4] :white_check_mark: pass pre-commit --- README.md | 2 +- backend/Dockerfile | 2 +- backend/api_plugins/__init__.py | 14 ++- .../insecure_authentication.py | 29 +++-- backend/api_plugins/lib/user_management.py | 13 ++- .../secure_authentication.py | 20 ++-- backend/api_plugins/sessions/sessions.py | 29 +++-- backend/config.py | 33 +++--- backend/database.py | 62 +++++++--- backend/logger.py | 1 + backend/main.py | 2 +- backend/model.py | 7 +- .../answer_question_from_docs_and_history.py | 20 +++- .../chain_links/condense_question.py | 35 ++++-- .../chain_links/documented_runnable.py | 108 ++++++++++++------ .../rag_components/chain_links/rag_basic.py | 16 ++- .../chain_links/rag_with_history.py | 25 ++-- .../chain_links/retrieve_and_format_docs.py | 12 +- .../rag_components/chat_message_history.py | 2 +- backend/rag_components/document_loader.py | 7 +- backend/rag_components/embedding.py | 4 +- backend/rag_components/llm.py | 4 +- backend/rag_components/rag.py | 39 ++++--- backend/rag_components/retriever.py | 2 - backend/rag_components/vector_store.py | 16 ++- backend/requirements.txt | 2 +- docker-compose.yaml | 2 +- docs/backend/chains/basic_chain.md | 1 - docs/backend/chains/chain_with_memory.md | 1 - docs/backend/chains/chains.md | 10 +- docs/doc_generation.py | 22 ++-- docs/index.md | 2 +- frontend/Dockerfile | 2 +- frontend/front.py | 36 ++++-- frontend/lib/auth.py | 35 ++++-- frontend/lib/backend_interface.py | 5 +- frontend/lib/basic_chat.py | 1 - frontend/lib/session_chat.py | 11 +- frontend/lib/sidebar.py | 24 +++- frontend/requirements.txt | 2 +- requirements-dev.txt | 2 +- 41 files changed, 457 insertions(+), 205 deletions(-) diff --git a/README.md b/README.md index 5f1c60b..daeeb78 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ This is a starter kit to deploy a modularizable RAG locally or on the cloud (or ## Quickstart -This quickstart will guide you through the steps to serve the RAG and load a few documents. +This quickstart will guide you through the steps to serve the RAG and load a few documents. You will run both the back and front on your machine. diff --git a/backend/Dockerfile b/backend/Dockerfile index f0c912e..769df96 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -21,4 +21,4 @@ EXPOSE $PORT COPY . ./backend -CMD python -m uvicorn backend.main:app --host 0.0.0.0 --port $PORT \ No newline at end of file +CMD python -m uvicorn backend.main:app --host 0.0.0.0 --port $PORT diff --git a/backend/api_plugins/__init__.py b/backend/api_plugins/__init__.py index ce320d6..8481b5c 100644 --- a/backend/api_plugins/__init__.py +++ b/backend/api_plugins/__init__.py @@ -1,3 +1,13 @@ -from backend.api_plugins.insecure_authentication.insecure_authentication import insecure_authentication_routes -from backend.api_plugins.secure_authentication.secure_authentication import authentication_routes +from backend.api_plugins.insecure_authentication.insecure_authentication import ( + insecure_authentication_routes, +) +from backend.api_plugins.secure_authentication.secure_authentication import ( + authentication_routes, +) from backend.api_plugins.sessions.sessions import session_routes + +__all__ = [ + "insecure_authentication_routes", + "authentication_routes", + "session_routes", +] diff --git a/backend/api_plugins/insecure_authentication/insecure_authentication.py b/backend/api_plugins/insecure_authentication/insecure_authentication.py index 0c3d60e..983faa8 100644 --- a/backend/api_plugins/insecure_authentication/insecure_authentication.py +++ b/backend/api_plugins/insecure_authentication/insecure_authentication.py @@ -25,23 +25,30 @@ async def get_current_user(email: str) -> User: async def signup(email: str) -> dict: user = User(email=email) if user_exists(user.email): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"User {user.email} already registered", + ) create_user(user) return {"email": user.email} - @app.delete("/user/") async def del_user(current_user: User = Depends(get_current_user)) -> dict: email = current_user.email try: user = get_user(email) if user is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User {email} not found", + ) delete_user(email) return {"detail": f"User {email} deleted"} except Exception: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") - + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal Server Error", + ) @app.post("/user/login") async def login(email: str) -> dict: @@ -51,16 +58,20 @@ async def login(email: str) -> dict: status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username", ) - return {"access_token": email, "token_type": "bearer"} # Fake bearer token to still provide user authentication - + return { + "access_token": email, + "token_type": "bearer", + } # Fake bearer token to still provide user authentication @app.get("/user/me") async def user_me(current_user: User = Depends(get_current_user)) -> User: return current_user - @app.get("/user") async def user_root() -> dict: - return Response("Insecure user management routes are enabled. Do not use in prod.", status_code=200) + return Response( + "Insecure user management routes are enabled. Do not use in prod.", + status_code=200, + ) return Depends(get_current_user) diff --git a/backend/api_plugins/lib/user_management.py b/backend/api_plugins/lib/user_management.py index 8b8aa63..3f403af 100644 --- a/backend/api_plugins/lib/user_management.py +++ b/backend/api_plugins/lib/user_management.py @@ -13,6 +13,7 @@ class UnsecureUser(BaseModel): email: str = None password: bytes = None + class User(BaseModel): email: str = None hashed_password: str = None @@ -26,7 +27,8 @@ def from_unsecure_user(cls, unsecure_user: UnsecureUser): def create_user(user: User) -> None: with Database() as connection: connection.execute( - "INSERT INTO users (email, password) VALUES (?, ?)", (user.email, user.hashed_password) + "INSERT INTO users (email, password) VALUES (?, ?)", + (user.email, user.hashed_password), ) @@ -54,12 +56,17 @@ def authenticate_user(username: str, password: bytes) -> bool | User: if not user: return False - if argon2.verify_password(user.hashed_password.encode("utf-8"), password.encode("utf-8")): + if argon2.verify_password( + user.hashed_password.encode("utf-8"), password.encode("utf-8") + ): return user return False -def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None) -> str: + +def create_access_token( + *, data: dict, expires_delta: Optional[timedelta] = None +) -> str: to_encode = data.copy() if expires_delta: expire = datetime.utcnow() + expires_delta diff --git a/backend/api_plugins/secure_authentication/secure_authentication.py b/backend/api_plugins/secure_authentication/secure_authentication.py index ea03f7d..ab8e191 100644 --- a/backend/api_plugins/secure_authentication/secure_authentication.py +++ b/backend/api_plugins/secure_authentication/secure_authentication.py @@ -22,6 +22,7 @@ def authentication_routes(app, dependencies=List[Depends]): from backend.database import Database + with Database() as connection: connection.run_script(Path(__file__).parent / "users_tables.sql") @@ -46,22 +47,23 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: except JWTError: raise credentials_exception - @app.post("/user/signup", include_in_schema=ADMIN_MODE) async def signup(user: UnsecureUser) -> dict: if not ADMIN_MODE: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled" + ) user = User.from_unsecure_user(user) if user_exists(user.email): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered" + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"User {user.email} already registered", ) create_user(user) return {"email": user.email} - @app.delete("/user/") async def del_user(current_user: User = Depends(get_current_user)) -> dict: email = current_user.email @@ -69,16 +71,17 @@ async def del_user(current_user: User = Depends(get_current_user)) -> dict: user = get_user(email) if user is None: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found" + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User {email} not found", ) delete_user(email) return {"detail": f"User {email} deleted"} except Exception: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal Server Error", ) - @app.post("/user/login") async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict: user = authenticate_user(form_data.username, form_data.password) @@ -93,15 +96,12 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict: access_token = create_access_token(data=user_data) return {"access_token": access_token, "token_type": "bearer"} - @app.get("/user/me") async def user_me(current_user: User = Depends(get_current_user)) -> User: return current_user - @app.get("/user") async def user_root() -> dict: return Response("User management routes are enabled.", status_code=200) - return Depends(get_current_user) diff --git a/backend/api_plugins/sessions/sessions.py b/backend/api_plugins/sessions/sessions.py index 2649bd4..20e64a6 100644 --- a/backend/api_plugins/sessions/sessions.py +++ b/backend/api_plugins/sessions/sessions.py @@ -22,7 +22,9 @@ def session_routes( connection.run_script(Path(__file__).parent / "sessions_tables.sql") @app.post("/session/new") - async def chat_new(current_user: User=authentication, dependencies=dependencies) -> dict: + async def chat_new( + current_user: User = authentication, dependencies=dependencies + ) -> dict: chat_id = str(uuid4()) timestamp = datetime.utcnow().isoformat() user_id = current_user.email if current_user else "unauthenticated" @@ -33,26 +35,30 @@ async def chat_new(current_user: User=authentication, dependencies=dependencies) ) return {"session_id": chat_id} - @app.get("/session/list") - async def chat_list(current_user: User=authentication, dependencies=dependencies) -> List[dict]: + async def chat_list( + current_user: User = authentication, dependencies=dependencies + ) -> List[dict]: user_email = current_user.email if current_user else "unauthenticated" chats = [] with Database() as connection: result = connection.execute( - "SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp DESC", + "SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp" + " DESC", (user_email,), ) chats = [{"id": row[0], "timestamp": row[1]} for row in result] return chats - @app.get("/session/{session_id}") - async def chat(session_id: str, current_user: User=authentication, dependencies=dependencies) -> dict: + async def chat( + session_id: str, current_user: User = authentication, dependencies=dependencies + ) -> dict: messages: List[Message] = [] with Database() as connection: result = connection.execute( - "SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC", + "SELECT id, timestamp, session_id, message FROM message_history WHERE" + " session_id = ? ORDER BY timestamp ASC", (session_id,), ) for row in result: @@ -66,8 +72,13 @@ async def chat(session_id: str, current_user: User=authentication, dependencies= content=content, ) messages.append(message) - return {"chat_id": session_id, "messages": [message.dict() for message in messages]} + return { + "chat_id": session_id, + "messages": [message.dict() for message in messages], + } @app.get("/session") - async def session_root(current_user: User=authentication, dependencies=dependencies) -> dict: + async def session_root( + current_user: User = authentication, dependencies=dependencies + ) -> dict: return Response("Sessions management routes are enabled.", status_code=200) diff --git a/backend/config.py b/backend/config.py index 92b2854..d7c5bc3 100644 --- a/backend/config.py +++ b/backend/config.py @@ -12,11 +12,13 @@ load_dotenv() + @dataclass class LLMConfig: source: BaseChatModel | LLM | str source_config: dict + @dataclass class VectorStoreConfig: source: VectorStore | str @@ -24,43 +26,48 @@ class VectorStoreConfig: insertion_mode: str # "None", "full", "incremental" + @dataclass class EmbeddingModelConfig: source: Embeddings | str source_config: dict + @dataclass class DatabaseConfig: database_url: str + @dataclass class RagConfig: """ Configuration class for the Retrieval-Augmented Generation (RAG) system. It is meant to be injected in the RAG class to configure the various components. - This class holds the configuration for the various components that make up the RAG system, including - the language model, vector store, embedding model, and database configurations. It provides a method - to construct a RagConfig instance from a YAML file, allowing for easy external configuration. + This class holds the configuration for the various components that make up the RAG + system, including the language model, vector store, embedding model, and database + configurations. It provides a method to construct a RagConfig instance from a YAML + file, allowing for easy external configuration. Attributes: llm (LLMConfig): Configuration for the language model component. vector_store (VectorStoreConfig): Configuration for the vector store component. - embedding_model (EmbeddingModelConfig): Configuration for the embedding model component. + embedding_model (EmbeddingModelConfig): Configuration for the embedding model + component. database (DatabaseConfig): Configuration for the database connection. Methods: - from_yaml: Class method to create an instance of RagConfig from a YAML file, with optional environment - variables for template rendering. + from_yaml: Class method to create an instance of RagConfig from a YAML file, + with optional environment variables for template rendering. """ - llm: LLMConfig = field(default_factory=LLMConfig) - vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig) - embedding_model: EmbeddingModelConfig = field(default_factory=EmbeddingModelConfig) - database: DatabaseConfig = field(default_factory=DatabaseConfig) - chat_history_window_size: int = 5 - max_tokens_limit: int = 3000 - response_mode: str = None + llm: LLMConfig = field(default_factory=LLMConfig) + vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig) + embedding_model: EmbeddingModelConfig = field(default_factory=EmbeddingModelConfig) + database: DatabaseConfig = field(default_factory=DatabaseConfig) + chat_history_window_size: int = 5 + max_tokens_limit: int = 3000 + response_mode: str = None @classmethod def from_yaml(cls, yaml_path: Path, env: dict = None): diff --git a/backend/database.py b/backend/database.py index d1c42cc..4e3c58d 100644 --- a/backend/database.py +++ b/backend/database.py @@ -26,7 +26,8 @@ class Database: url (URL): The parsed URL object of the connection string. pool (PooledDB): The connection pool for database connections. conn (Connection): The current database connection. - DIALECT_PLACEHOLDERS (dict): Mapping of database dialects to their placeholder symbols. + DIALECT_PLACEHOLDERS (dict): Mapping of database dialects to their placeholder + symbols. """ DIALECT_PLACEHOLDERS = { @@ -52,10 +53,17 @@ def __enter__(self) -> "Database": self.conn = self.pool.connection() return self - def __exit__(self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any]) -> None: + def __exit__( + self, + exc_type: Optional[type], + exc_value: Optional[BaseException], + traceback: Optional[Any], + ) -> None: if self.conn: if exc_type: - self.logger.error("Transaction failed", exc_info=(exc_type, exc_value, traceback)) + self.logger.error( + "Transaction failed", exc_info=(exc_type, exc_value, traceback) + ) self.conn.rollback() else: self.conn.commit() @@ -93,10 +101,16 @@ def initialize_schema(self): try: self.logger.debug("Initializing database schema") sql_script = Path(__file__).parent.joinpath("db_init.sql").read_text() - transpiled_sql = sqlglot.transpile(sql_script, read="sqlite", write=self.url.drivername.replace("postgresql", "postgres")) + transpiled_sql = sqlglot.transpile( + sql_script, + read="sqlite", + write=self.url.drivername.replace("postgresql", "postgres"), + ) for statement in transpiled_sql: self.execute(statement) - self.logger.info(f"Database schema initialized successfully for {self.url.drivername}") + self.logger.info( + f"Database schema initialized successfully for {self.url.drivername}" + ) except Exception as e: self.logger.exception("Schema initialization failed", exc_info=e) raise @@ -105,33 +119,46 @@ def run_script(self, path: Path): try: self.logger.debug(f"Running Database script at {str(path)}") sql_script = path.read_text() - transpiled_sql = sqlglot.transpile(sql_script, read="sqlite", write=self.url.drivername.replace("postgresql", "postgres")) + transpiled_sql = sqlglot.transpile( + sql_script, + read="sqlite", + write=self.url.drivername.replace("postgresql", "postgres"), + ) for statement in transpiled_sql: self.execute(statement) - self.logger.info(f"Successfuly ran script at {path} for {self.url.drivername}") + self.logger.info( + f"Successfuly ran script at {path} for {self.url.drivername}" + ) except Exception as e: - self.logger.exception(f"Failed to execute the script {path} for {self.url.drivername}", exc_info=e) + self.logger.exception( + f"Failed to execute the script {path} for {self.url.drivername}", + exc_info=e, + ) raise def _create_pool(self) -> PooledDB: if self.connection_string.startswith("sqlite:///"): import sqlite3 - Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir(parents=True, exist_ok=True) + + Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir( + parents=True, exist_ok=True + ) return PooledDB( creator=sqlite3, database=self.connection_string.replace("sqlite:///", ""), - maxconnections=5 + maxconnections=5, ) elif self.connection_string.startswith("postgresql://"): import psycopg2 + return PooledDB( - creator=psycopg2, - dsn=self.connection_string, - maxconnections=5 + creator=psycopg2, dsn=self.connection_string, maxconnections=5 ) - elif self.connection_string.startswith("mysql://") or \ - self.connection_string.startswith("mysql+pymysql://"): + elif self.connection_string.startswith( + "mysql://" + ) or self.connection_string.startswith("mysql+pymysql://"): import mysql.connector + return PooledDB( creator=mysql.connector, user=self.url.username, @@ -139,14 +166,15 @@ def _create_pool(self) -> PooledDB: host=self.url.host, port=self.url.port, database=self.url.database, - maxconnections=5 + maxconnections=5, ) elif self.connection_string.startswith("sqlserver://"): import pyodbc + return PooledDB( creator=pyodbc, dsn=self.connection_string.replace("sqlserver://", ""), - maxconnections=5 + maxconnections=5, ) else: raise ValueError(f"Unsupported database type: {self.url.drivername}") diff --git a/backend/logger.py b/backend/logger.py index a3099c5..f518fa7 100644 --- a/backend/logger.py +++ b/backend/logger.py @@ -4,6 +4,7 @@ # Implement your custom logging logic here. Eg. send logs to a cloud's logging tool. _logger_instance = None + def get_logger() -> Logger: global _logger_instance if _logger_instance is None: diff --git a/backend/main.py b/backend/main.py index 1595601..e2116f4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -3,7 +3,7 @@ from fastapi import FastAPI from langserve import add_routes -from backend.api_plugins import authentication_routes, session_routes +# from backend.api_plugins import authentication_routes, session_routes from backend.rag_components.rag import RAG # Initialize a RAG as discribed in the config.yaml file diff --git a/backend/model.py b/backend/model.py index 7e63ef4..aadd47e 100644 --- a/backend/model.py +++ b/backend/model.py @@ -12,16 +12,21 @@ class Message(BaseModel): sender: str content: str + class Input(BaseModel): question: str + class InvokeRequest(BaseModel): input: Input | List[Input] # Supports batched input id: str = Field(default_factory=lambda: str(uuid4())) session_id: str = None config: dict = Field(default_factory=dict) kwargs: dict = Field(default_factory=dict) - timestamp: str | datetime = Field(default_factory=lambda: datetime.utcnow().isoformat()) + timestamp: str | datetime = Field( + default_factory=lambda: datetime.utcnow().isoformat() + ) + class UserMessage(InvokeRequest): sender: str = "user" diff --git a/backend/rag_components/chain_links/answer_question_from_docs_and_history.py b/backend/rag_components/chain_links/answer_question_from_docs_and_history.py index 9934f3a..8eee958 100644 --- a/backend/rag_components/chain_links/answer_question_from_docs_and_history.py +++ b/backend/rag_components/chain_links/answer_question_from_docs_and_history.py @@ -1,10 +1,12 @@ -"""This chain answers the provided question based on documents it retreives and the conversation history""" +"""This chain answers the provided question based on documents it retreives and the +conversation history""" + from langchain_core.retrievers import BaseRetriever from pydantic import BaseModel -from backend.rag_components.chain_links.rag_basic import rag_basic -from backend.rag_components.chain_links.condense_question import condense_question +from backend.rag_components.chain_links.condense_question import condense_question from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable +from backend.rag_components.chain_links.rag_basic import rag_basic class QuestionWithHistory(BaseModel): @@ -16,11 +18,17 @@ class Response(BaseModel): response: str -def answer_question_from_docs_and_history_chain(llm, retriever: BaseRetriever) -> DocumentedRunnable: +def answer_question_from_docs_and_history_chain( + llm, retriever: BaseRetriever +) -> DocumentedRunnable: reformulate_question = condense_question(llm) answer_question = rag_basic(llm, retriever) - chain = reformulate_question | answer_question + chain = reformulate_question | answer_question typed_chain = chain.with_types(input_type=QuestionWithHistory, output_type=Response) - return DocumentedRunnable(typed_chain, chain_name="Answer question from docs and history", user_doc=__doc__) + return DocumentedRunnable( + typed_chain, + chain_name="Answer question from docs and history", + user_doc=__doc__, + ) diff --git a/backend/rag_components/chain_links/condense_question.py b/backend/rag_components/chain_links/condense_question.py index 70a38a2..9fd115a 100644 --- a/backend/rag_components/chain_links/condense_question.py +++ b/backend/rag_components/chain_links/condense_question.py @@ -1,6 +1,8 @@ -"""This chain condenses the chat history and the question into one standalone question.""" -from langchain_core.prompts import PromptTemplate +"""This chain condenses the chat history and the question into one standalone +question.""" + from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate from pydantic import BaseModel from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable @@ -16,20 +18,35 @@ class StandaloneQuestion(BaseModel): prompt = """\ -Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. You are presented with a conversation that may contain some spelling mistakes and grammatical errors, but your goal is to understand the underlying question. Make sure to avoid the use of unclear pronouns. +Given the conversation history and the following question, can you rephrase the user's \ +question in its original language so that it is self-sufficient. You are presented \ +with a conversation that may contain some spelling mistakes and grammatical errors, \ +but your goal is to understand the underlying question. Make sure to avoid the use of \ +unclear pronouns. -If the question is already self-sufficient, return the original question. If it seem the user is authorizing the chatbot to answer without specific context, make sure to reflect that in the rephrased question. +If the question is already self-sufficient, return the original question. If it seem \ +the user is authorizing the chatbot to answer without specific context, make sure to \ +reflect that in the rephrased question. Chat history: {chat_history} Question: {question} -""" # noqa: E501 +""" # noqa: E501 def condense_question(llm) -> DocumentedRunnable: - condense_question_prompt = PromptTemplate.from_template(prompt) # chat_history, question + condense_question_prompt = PromptTemplate.from_template( + prompt + ) # chat_history, question standalone_question = condense_question_prompt | llm | StrOutputParser() - - typed_chain = standalone_question.with_types(input_type=QuestionWithChatHistory, output_type=StandaloneQuestion) - return DocumentedRunnable(typed_chain, chain_name="Condense question and history", prompt=prompt, user_doc=__doc__) \ No newline at end of file + + typed_chain = standalone_question.with_types( + input_type=QuestionWithChatHistory, output_type=StandaloneQuestion + ) + return DocumentedRunnable( + typed_chain, + chain_name="Condense question and history", + prompt=prompt, + user_doc=__doc__, + ) diff --git a/backend/rag_components/chain_links/documented_runnable.py b/backend/rag_components/chain_links/documented_runnable.py index ba42d6d..c0608d5 100644 --- a/backend/rag_components/chain_links/documented_runnable.py +++ b/backend/rag_components/chain_links/documented_runnable.py @@ -1,30 +1,38 @@ from __future__ import annotations + import json +from dataclasses import asdict, dataclass from typing import Any, List, Optional + +import tabulate from docdantic import get_field_info -from langchain_core.runnables.base import Runnable, RunnableBinding, RunnableBindingBase, RunnableSequence, RunnableParallel +from jinja2 import Template +from langchain_core.runnables.base import ( + Runnable, + RunnableBinding, + RunnableBindingBase, + RunnableParallel, + RunnableSequence, +) from langchain_core.runnables.utils import Input, Output from pydantic.main import ModelMetaclass -import tabulate -from jinja2 import Template - - -from dataclasses import asdict, dataclass -from typing import Optional, Any @dataclass class RunnableSequenceDocumentation: docs: List[RunnableDocumentation] + @dataclass class RunnableParallelDocumentation: docs: List[RunnableDocumentation] + @dataclass class RunnableBindingDocumentation: docs: List[RunnableDocumentation] + @dataclass class RunnableDocumentation: chain_name: str @@ -38,9 +46,9 @@ def to_json(self): class EnhancedJSONEncoder(json.JSONEncoder): def default(self, o): return o.__name__ - + return json.dumps(asdict(self), cls=EnhancedJSONEncoder) - + def to_markdown(self) -> str: rendered_sub_docs = "" if self.sub_docs: @@ -60,19 +68,21 @@ def to_markdown(self) -> str: prompt=self.prompt, io_doc=input_doc + "\n" + output_doc, sub_docs=rendered_sub_docs, - user_doc=self.user_doc + user_doc=self.user_doc, ) class DocumentedRunnable(RunnableBindingBase[Input, Output]): """A DocumentedRunnable is a wrapper around a Runnable that generates documentation. - FIXME: Bound runnables that have configurable fields are not handled correctly, causing the playground to not be properly usable. + FIXME: Bound runnables that have configurable fields are not handled correctly, + causing the playground to not be properly usable. TODO: Add Mermaid diagrams. - This class is used to create a documented version of a Runnable, which is an executable - unit of work in the langchain framework. The documentation includes information about - the input and output types, as well as any additional user-provided prompts or documentation. + This class is used to create a documented version of a Runnable, which is an + executable unit of work in the langchain framework. The documentation includes + information about the input and output types, as well as any additional + user-provided prompts or documentation. Attributes: documentation (str): The generated markdown documentation for the Runnable. @@ -81,26 +91,33 @@ class DocumentedRunnable(RunnableBindingBase[Input, Output]): runnable (Runnable): The Runnable to be documented. chain_name (Optional[str]): The name of the chain that the Runnable belongs to. prompt (Optional[str]): The prompt that the Runnable uses, if applicable. - user_doc (Optional[str]): Any additional documentation you want displayed in the documentation. + user_doc (Optional[str]): Any additional documentation you want displayed in the + documentation. """ documentation: RunnableDocumentation | None def __init__( - self, - runnable: Runnable[Input, Output], - chain_name: Optional[str]=None, - prompt: Optional[str]=None, - user_doc: Optional[str]=None, + self, + runnable: Runnable[Input, Output], + chain_name: Optional[str] = None, + prompt: Optional[str] = None, + user_doc: Optional[str] = None, **kwargs: Any, ) -> None: - custom_input_type = runnable.InputType custom_output_type = runnable.OutputType final_chain_name = chain_name or type(runnable).__name__ if isinstance(runnable, RunnableSequence): - sub_docs = [runnable.documentation if isinstance(runnable, DocumentedRunnable) else DocumentedRunnable(runnable).documentation for runnable in runnable.steps] + sub_docs = [ + ( + runnable.documentation + if isinstance(runnable, DocumentedRunnable) + else DocumentedRunnable(runnable).documentation + ) + for runnable in runnable.steps + ] sub_docs = [doc for doc in sub_docs if doc] if len(sub_docs) >= 2: documentation = RunnableDocumentation( @@ -109,13 +126,20 @@ def __init__( input_type=custom_input_type, output_type=custom_output_type, user_doc=user_doc, - sub_docs=RunnableSequenceDocumentation(docs=sub_docs) + sub_docs=RunnableSequenceDocumentation(docs=sub_docs), ) else: documentation = sub_docs[0] if len(sub_docs) else None - + elif isinstance(runnable, RunnableParallel): - sub_docs = [runnable.documentation if isinstance(runnable, DocumentedRunnable) else DocumentedRunnable(runnable).documentation for runnable in runnable.steps.values()] + sub_docs = [ + ( + runnable.documentation + if isinstance(runnable, DocumentedRunnable) + else DocumentedRunnable(runnable).documentation + ) + for runnable in runnable.steps.values() + ] sub_docs = [doc for doc in sub_docs if doc] if len(sub_docs) >= 2: documentation = RunnableDocumentation( @@ -124,7 +148,7 @@ def __init__( input_type=custom_input_type, output_type=custom_output_type, user_doc=user_doc, - sub_docs=RunnableParallelDocumentation(docs=sub_docs) + sub_docs=RunnableParallelDocumentation(docs=sub_docs), ) else: documentation = sub_docs[0] if len(sub_docs) else None @@ -134,9 +158,13 @@ def __init__( while isinstance(bound_runnable, RunnableBinding): bound_runnable = bound_runnable.bound - sub_docs = [bound_runnable.documentation if isinstance(bound_runnable, DocumentedRunnable) else DocumentedRunnable(bound_runnable).documentation] + sub_docs = [ + bound_runnable.documentation + if isinstance(bound_runnable, DocumentedRunnable) + else DocumentedRunnable(bound_runnable).documentation + ] sub_docs = [doc for doc in sub_docs if doc] - + if final_chain_name == "RunnableBinding": documentation = sub_docs[0] if len(sub_docs) else None else: @@ -146,25 +174,30 @@ def __init__( input_type=custom_input_type, output_type=custom_output_type, user_doc=user_doc, - sub_docs=RunnableBindingDocumentation(docs=sub_docs) if len(sub_docs) else None + sub_docs=( + RunnableBindingDocumentation(docs=sub_docs) + if len(sub_docs) + else None + ), ) else: documentation = None super().__init__( - bound=runnable, + bound=runnable, documentation=documentation, custom_input_type=custom_input_type, custom_output_type=custom_output_type, **kwargs, ) + def render_io_doc(input, output) -> tuple[str]: if isinstance(input, ModelMetaclass): input_doc = render_model_doc(input, "Input") else: input_doc = f"### Input: {input.__name__}" - + if isinstance(output, ModelMetaclass): output_doc = render_model_doc(output, "Output") else: @@ -179,16 +212,21 @@ def render_model_doc(model: ModelMetaclass, input_or_output: str) -> str: tables = { cls: tabulate.tabulate( [ - (field.name, field.type, field.required, field.default,) + ( + field.name, + field.type, + field.required, + field.default, + ) for field in fields ], headers=["Name", "Type", "Required", "Default"], - tablefmt="github" + tablefmt="github", ) for cls, fields in field_info.items() } - return '\n'.join( + return "\n".join( f"\n### {input_or_output}: {cls}\n\n{table}\n\n" for cls, table in tables.items() ) @@ -232,4 +270,4 @@ def render_model_doc(model: ModelMetaclass, input_or_output: str) -> str: {{ doc }} {% endfor %} -""" \ No newline at end of file +""" diff --git a/backend/rag_components/chain_links/rag_basic.py b/backend/rag_components/chain_links/rag_basic.py index bf6b2a7..569a9fc 100644 --- a/backend/rag_components/chain_links/rag_basic.py +++ b/backend/rag_components/chain_links/rag_basic.py @@ -1,4 +1,5 @@ """This chain answers the provided question based on documents it retreives.""" + from langchain_core.prompts import ChatPromptTemplate from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import RunnablePassthrough @@ -7,13 +8,12 @@ from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable from backend.rag_components.chain_links.retrieve_and_format_docs import fetch_docs_chain - prompt = """ As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail. Do not invent information. You must sift through various sources of information, disregarding any data that is not relevant to the query's context. Your response should integrate knowledge from the valid sources you have identified. Additionally, the question might include hypothetical or counterfactual statements. You need to recognize these and adjust your response to provide accurate, relevant information without being misled by the counterfactuals. Respond to the question only taking into account the following context. If no context is provided, do not answer. You may provide an answer if the user explicitely asked for a general answer. You may ask the user to rephrase their question, or their permission to answer without specific context from your own knowledge. Context: {relevant_documents} Question: {question} -""" # noqa: E501 +""" # noqa: E501 class Question(BaseModel): @@ -26,9 +26,17 @@ class Response(BaseModel): def rag_basic(llm, retriever: BaseRetriever) -> DocumentedRunnable: chain = ( - {"relevant_documents": fetch_docs_chain(retriever), "question": RunnablePassthrough(Question)} + { + "relevant_documents": fetch_docs_chain(retriever), + "question": RunnablePassthrough(Question), + } | ChatPromptTemplate.from_template(prompt) | llm ) typed_chain = chain.with_types(input_type=str, output_type=Response) - return DocumentedRunnable(typed_chain, chain_name="Answer questions from documents stored in a vector store", prompt=prompt, user_doc=__doc__) + return DocumentedRunnable( + typed_chain, + chain_name="Answer questions from documents stored in a vector store", + prompt=prompt, + user_doc=__doc__, + ) diff --git a/backend/rag_components/chain_links/rag_with_history.py b/backend/rag_components/chain_links/rag_with_history.py index 40b5e71..0e71a40 100644 --- a/backend/rag_components/chain_links/rag_with_history.py +++ b/backend/rag_components/chain_links/rag_with_history.py @@ -1,27 +1,36 @@ from langchain_core.retrievers import BaseRetriever from langchain_core.runnables.history import RunnableWithMessageHistory -from backend.config import RagConfig -from backend.rag_components.chain_links.answer_question_from_docs_and_history import answer_question_from_docs_and_history_chain +from backend.config import RagConfig +from backend.rag_components.chain_links.answer_question_from_docs_and_history import ( + answer_question_from_docs_and_history_chain, +) from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable from backend.rag_components.chat_message_history import get_chat_message_history -def rag_with_history_chain(config: RagConfig, llm, retriever: BaseRetriever) -> DocumentedRunnable: +def rag_with_history_chain( + config: RagConfig, llm, retriever: BaseRetriever +) -> DocumentedRunnable: chain = answer_question_from_docs_and_history_chain(llm, retriever) chain_with_mem = RunnableWithMessageHistory( chain, lambda session_id: get_chat_message_history(config, session_id), input_messages_key="question", - history_messages_key="chat_history" + history_messages_key="chat_history", ) return chain_with_mem - # A bug in DocumentedRunnable makes configurable Runnables such as RunnableWithMessageHistory not work properly. + # A bug in DocumentedRunnable makes configurable Runnables such as + # RunnableWithMessageHistory not work properly. return DocumentedRunnable( - chain_with_mem, - chain_name="RAG with persistant memory", - user_doc="This chain answers the provided question based on documents it retreives and the conversation history. It uses a persistant memory to store the conversation history.", + chain_with_mem, + chain_name="RAG with persistant memory", + user_doc=( + "This chain answers the provided question based on documents it retreives" + " and the conversation history. It uses a persistant memory to store the" + " conversation history." + ), ) diff --git a/backend/rag_components/chain_links/retrieve_and_format_docs.py b/backend/rag_components/chain_links/retrieve_and_format_docs.py index a34249c..e81d624 100644 --- a/backend/rag_components/chain_links/retrieve_and_format_docs.py +++ b/backend/rag_components/chain_links/retrieve_and_format_docs.py @@ -1,7 +1,7 @@ """This chain fetches the relevant documents and combines them into a single string.""" -from langchain_core.prompts import PromptTemplate from langchain.schema import format_document +from langchain_core.prompts import PromptTemplate from pydantic import BaseModel from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable @@ -19,11 +19,15 @@ class Documents(BaseModel): def fetch_docs_chain(retriever) -> DocumentedRunnable: relevant_documents = retriever | _combine_documents - typed_chain = relevant_documents.with_types(input_type=Question, output_type=Documents) - return DocumentedRunnable(typed_chain, chain_name="Fetch documents", prompt=prompt, user_doc=__doc__) + typed_chain = relevant_documents.with_types( + input_type=Question, output_type=Documents + ) + return DocumentedRunnable( + typed_chain, chain_name="Fetch documents", prompt=prompt, user_doc=__doc__ + ) def _combine_documents(docs, document_separator="\n\n"): document_prompt = PromptTemplate.from_template(template=prompt) doc_strings = [format_document(doc, document_prompt) for doc in docs] - return document_separator.join(doc_strings) \ No newline at end of file + return document_separator.join(doc_strings) diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py index 22704a3..20dd63a 100644 --- a/backend/rag_components/chat_message_history.py +++ b/backend/rag_components/chat_message_history.py @@ -1,4 +1,3 @@ - from datetime import datetime from langchain_community.chat_message_histories import SQLChatMessageHistory @@ -23,6 +22,7 @@ def get_chat_message_history(config: RagConfig, chat_id): custom_message_converter=TimestampedMessageConverter(TABLE_NAME), ) + class TimestampedMessageConverter(DefaultMessageConverter): def __init__(self, table_name: str): self.model_class = create_message_model(table_name, declarative_base()) diff --git a/backend/rag_components/document_loader.py b/backend/rag_components/document_loader.py index 04aa56f..5878309 100644 --- a/backend/rag_components/document_loader.py +++ b/backend/rag_components/document_loader.py @@ -33,14 +33,17 @@ def get_best_loader(file_extension: str, llm: BaseChatModel): input_variables=["file_extension", "loaders"], template=""" Among the following loaders, which is the best to load a "{file_extension}" file? \ - Only give me one the class name without any other special characters. If no relevant loader is found, respond "None". + Only give me one the class name without any other special characters. If no \ + relevant loader is found, respond "None". Loaders: {loaders} """, ) chain = LLMChain(llm=llm, prompt=prompt, output_key="loader_class_name") - return chain({"file_extension": file_extension, "loaders": loaders})["loader_class_name"] + return chain({"file_extension": file_extension, "loaders": loaders})[ + "loader_class_name" + ] def get_loaders() -> List[str]: diff --git a/backend/rag_components/embedding.py b/backend/rag_components/embedding.py index 7a38223..75f2f24 100644 --- a/backend/rag_components/embedding.py +++ b/backend/rag_components/embedding.py @@ -6,6 +6,8 @@ def get_embedding_model(config: RagConfig): spec = getattr(embeddings, config.embedding_model.source) kwargs = { - key: value for key, value in config.embedding_model.source_config.items() if key in spec.__fields__.keys() + key: value + for key, value in config.embedding_model.source_config.items() + if key in spec.__fields__.keys() } return spec(**kwargs) diff --git a/backend/rag_components/llm.py b/backend/rag_components/llm.py index cbd8306..bdfb23d 100644 --- a/backend/rag_components/llm.py +++ b/backend/rag_components/llm.py @@ -9,7 +9,9 @@ def get_llm_model(config: RagConfig, callbacks: List[BaseCallbackHandler] = []): llm_spec = getattr(chat_models, config.llm.source) kwargs = { - key: value for key, value in config.llm.source_config.items() if key in llm_spec.__fields__.keys() + key: value + for key, value in config.llm.source_config.items() + if key in llm_spec.__fields__.keys() } kwargs["streaming"] = True kwargs["callbacks"] = callbacks diff --git a/backend/rag_components/rag.py b/backend/rag_components/rag.py index a7a74f6..a7c6d72 100644 --- a/backend/rag_components/rag.py +++ b/backend/rag_components/rag.py @@ -6,38 +6,45 @@ from langchain.indexes import SQLRecordManager, index from langchain.schema.embeddings import Embeddings from langchain.vectorstores import VectorStore -from langchain_core.retrievers import BaseRetriever from langchain.vectorstores.utils import filter_complex_metadata +from langchain_core.retrievers import BaseRetriever from backend.config import RagConfig from backend.database import Database from backend.logger import get_logger +from backend.rag_components.chain_links.rag_basic import rag_basic +from backend.rag_components.chain_links.rag_with_history import rag_with_history_chain from backend.rag_components.document_loader import get_documents from backend.rag_components.embedding import get_embedding_model from backend.rag_components.llm import get_llm_model from backend.rag_components.retriever import get_retriever from backend.rag_components.vector_store import get_vector_store -from backend.rag_components.chain_links.rag_basic import rag_basic -from backend.rag_components.chain_links.rag_with_history import rag_with_history_chain + class RAG: """ - The RAG class orchestrates the components necessary for a retrieval-augmented generation pipeline. + The RAG class orchestrates the components necessary for a retrieval-augmented + generation pipeline. It initializes with a configuration, either directly or from a file. The RAG has two main purposes: - - loading the RAG with documents, which involves ingesting and processing documents to be retrievable by the system - - generating the chain from the components as specified in the configuration, which entails \ - assembling the various components (language model, embeddings, vector store) into a \ - coherent pipeline for generating responses based on retrieved information. + - loading the RAG with documents, which involves ingesting and processing + documents to be retrievable by the system + - generating the chain from the components as specified in the configuration, + which entails assembling the various components (language model, embeddings, + vector store) into a coherent pipeline for generating responses based on + retrieved information. Attributes: config (RagConfig): Configuration object containing settings for RAG components. llm (BaseChatModel): The language model used for generating responses. - embeddings (Embeddings): The embedding model used for creating vector representations of text. - vector_store (VectorStore): The vector store that holds and allows for searching of embeddings. + embeddings (Embeddings): The embedding model used for creating vector + representations of text. + vector_store (VectorStore): The vector store that holds and allows for searching + of embeddings. logger (Logger): Logger for logging information, warnings, and errors. """ + def __init__(self, config: Union[Path, RagConfig]): if isinstance(config, RagConfig): self.config = config @@ -60,13 +67,17 @@ def get_chain(self, memory: bool = False): chain = rag_basic(self.llm, self.retriever) return chain - def load_file(self, file_path: Path) -> List[Document]: documents = get_documents(file_path, self.llm) filtered_documents = filter_complex_metadata(documents) return self.load_documents(filtered_documents) - def load_documents(self, documents: List[Document], insertion_mode: str = None, namespace: str = "default"): + def load_documents( + self, + documents: List[Document], + insertion_mode: str = None, + namespace: str = "default", + ): insertion_mode = insertion_mode or self.config.vector_store.insertion_mode record_manager = SQLRecordManager( @@ -79,7 +90,9 @@ def load_documents(self, documents: List[Document], insertion_mode: str = None, batch_size = 100 for batch in range(0, len(documents), batch_size): - self.logger.info(f"Indexing batch {batch} to {min(len(documents), batch + batch_size)}.") + self.logger.info( + f"Indexing batch {batch} to {min(len(documents), batch + batch_size)}." + ) indexing_output = index( documents[batch : min(len(documents), batch + batch_size)], diff --git a/backend/rag_components/retriever.py b/backend/rag_components/retriever.py index 260e457..25cd552 100644 --- a/backend/rag_components/retriever.py +++ b/backend/rag_components/retriever.py @@ -1,6 +1,4 @@ from langchain_core.vectorstores import VectorStore -from langchain import retrievers as base_retrievers -from langchain_community import retrievers as community_retrievers def get_retriever(vector_store: VectorStore): diff --git a/backend/rag_components/vector_store.py b/backend/rag_components/vector_store.py index 9959ef8..e0474ce 100644 --- a/backend/rag_components/vector_store.py +++ b/backend/rag_components/vector_store.py @@ -8,17 +8,27 @@ def get_vector_store(embedding_model, config: RagConfig): vector_store_spec = getattr(vectorstores, config.vector_store.source) - # the vector store class in langchain doesn't have a uniform interface to pass the embedding model + # the vector store class in langchain doesn't have a uniform interface to pass the + # embedding model # we extract the propertiy of the class that matches the 'Embeddings' type # and instanciate the vector store with our embedding model signature = inspect.signature(vector_store_spec.__init__) parameters = signature.parameters params_dict = dict(parameters) embedding_param = next( - (param for param in params_dict.values() if "Embeddings" in str(param.annotation)), None + ( + param + for param in params_dict.values() + if "Embeddings" in str(param.annotation) + ), + None, ) - kwargs = {key: value for key, value in config.vector_store.source_config.items() if key in parameters.keys()} + kwargs = { + key: value + for key, value in config.vector_store.source_config.items() + if key in parameters.keys() + } kwargs[embedding_param.name] = embedding_model vector_store = vector_store_spec(**kwargs) return vector_store diff --git a/backend/requirements.txt b/backend/requirements.txt index 0107f98..8561d22 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -20,4 +20,4 @@ sse_starlette sentence-transformers chromadb mysql_connector_repackaged -psycopg2-binary \ No newline at end of file +psycopg2-binary diff --git a/docker-compose.yaml b/docker-compose.yaml index 4889de1..d93ef0d 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -35,4 +35,4 @@ services: networks: app-network: - driver: bridge \ No newline at end of file + driver: bridge diff --git a/docs/backend/chains/basic_chain.md b/docs/backend/chains/basic_chain.md index 25a6e32..6f907e0 100644 --- a/docs/backend/chains/basic_chain.md +++ b/docs/backend/chains/basic_chain.md @@ -54,4 +54,3 @@ This chain fetches the relevant documents and combines them into a single string - diff --git a/docs/backend/chains/chain_with_memory.md b/docs/backend/chains/chain_with_memory.md index f2bc516..dcb0b2f 100644 --- a/docs/backend/chains/chain_with_memory.md +++ b/docs/backend/chains/chain_with_memory.md @@ -173,4 +173,3 @@ This chain fetches the relevant documents and combines them into a single string - diff --git a/docs/backend/chains/chains.md b/docs/backend/chains/chains.md index 9af9d9b..6be2722 100644 --- a/docs/backend/chains/chains.md +++ b/docs/backend/chains/chains.md @@ -7,7 +7,7 @@ We provide two basic RAG chains to get you started, one does simple one-shot Q&A ## Chains and chain links -This repo does not define large monolithic chains. To make it easier to pick and chose the required functionalities, we provide "chain links" at `backend/rag_components/chain_links`. All links are valid, self-sufficient chains. You can think of it as a toolbox of langchain components meant to be composed and stacked together to build actually useful chains. +This repo does not define large monolithic chains. To make it easier to pick and chose the required functionalities, we provide "chain links" at `backend/rag_components/chain_links`. All links are valid, self-sufficient chains. You can think of it as a toolbox of langchain components meant to be composed and stacked together to build actually useful chains. As all links are `Runnable` objects, they can be built from other chain links which are themselves made of chain links, etc... @@ -21,7 +21,7 @@ Typically, a link has: - A chain definition -For example the `condense_question` chain link has `QuestionWithChatHistory` and `StandaloneQuestion` as input an output models. +For example the `condense_question` chain link has `QuestionWithChatHistory` and `StandaloneQuestion` as input an output models. ```python class QuestionWithChatHistory(BaseModel): question: str @@ -69,9 +69,9 @@ Documentation is recursively generated from all the `DocumentedRunnable` chains In order to document your chain, just wrap it in a `DocumentedRunnable`: ```python documented_chain = DocumentedRunnable( - runnable=my_chain, - chain_name="My documented Chain", - user_doc="Additional chain explainations that will be displayed in the markdown", + runnable=my_chain, + chain_name="My documented Chain", + user_doc="Additional chain explainations that will be displayed in the markdown", prompt=prompt, ) ``` diff --git a/docs/doc_generation.py b/docs/doc_generation.py index 1739ccf..ab1ce78 100644 --- a/docs/doc_generation.py +++ b/docs/doc_generation.py @@ -1,20 +1,28 @@ -"""Quick ad-hoc script that generates markdown documentation for the basic RAG chain and the RAG chain with memory.""" +"""Quick ad-hoc script that generates markdown documentation for the basic RAG chain and +the RAG chain with memory.""" + from pathlib import Path + from backend.rag_components.chain_links.documented_runnable import DocumentedRunnable from backend.rag_components.rag import RAG - rag = RAG(config=Path(__file__).parents[1] / "backend" / "config.yaml") chain = rag.get_chain(memory=False) -with open(Path(__file__).parent / "backend" / "chains" / "basic_chain.md", "w") as f: +with (Path(__file__).parent / "backend" / "chains" / "basic_chain.md").open("w") as f: f.write(chain.documentation.to_markdown()) chain = rag.get_chain(memory=True) doc_chain = DocumentedRunnable( - chain, - chain_name="RAG with persistant memory", - user_doc="This chain answers the provided question based on documents it retreives and the conversation history. It uses a persistant memory to store the conversation history.", + chain, + chain_name="RAG with persistant memory", + user_doc=( + "This chain answers the provided question based on documents it retreives and " + "the conversation history. It uses a persistant memory to store the " + "conversation history." + ), ) -with open(Path(__file__).parent / "backend" / "chains" / "chain_with_memory.md", "w") as f: +with (Path(__file__).parent / "backend" / "chains" / "chain_with_memory.md").open( + "w" +) as f: f.write(doc_chain.documentation.to_markdown()) diff --git a/docs/index.md b/docs/index.md index 9d6967b..5bf9852 100644 --- a/docs/index.md +++ b/docs/index.md @@ -17,7 +17,7 @@ This is a starter kit to deploy a modularizable RAG locally or on the cloud (or ## Quickstart -This quickstart will guide you through the steps to serve the RAG and load a few documents. +This quickstart will guide you through the steps to serve the RAG and load a few documents. You will run both the back and front on your machine. diff --git a/frontend/Dockerfile b/frontend/Dockerfile index 7fefc4e..2a1bbef 100644 --- a/frontend/Dockerfile +++ b/frontend/Dockerfile @@ -21,4 +21,4 @@ EXPOSE $PORT COPY . ./frontend -CMD python -m streamlit run frontend/front.py --server.port $PORT \ No newline at end of file +CMD python -m streamlit run frontend/front.py --server.port $PORT diff --git a/frontend/front.py b/frontend/front.py index cea1b6f..704279a 100644 --- a/frontend/front.py +++ b/frontend/front.py @@ -23,37 +23,53 @@ def browser_tab_title(): def application_header(): st.image(Image.open(ASSETS_PATH / "logo_title.jpeg")) - st.caption("Learn more about the RAG indus kit here: https://artefactory.github.io/skaff-rag-accelerator") + st.caption( + "Learn more about the RAG indus kit here:" + " https://artefactory.github.io/skaff-rag-accelerator" + ) if __name__ == "__main__": browser_tab_title() application_header() - # The session is used to make requests to the backend. It helps with the handling of cookies, auth, and other session data + # The session is used to make requests to the backend. It helps with the handling of + # cookies, auth, and other session data initialize_state_variable("session", value=create_session()) # The chain is our RAG that will be used to answer questions. - # Langserve's RemoteRunnable allows us to work as if the RAG was local, but it's actually running on the backend + # Langserve's RemoteRunnable allows us to work as if the RAG was local, but it's + # actually running on the backend initialize_state_variable("chain", value=RemoteRunnable(BACKEND_URL)) - # If the backend supports authentication but the user is not authenticated, show the authentication page - if backend_supports_auth() and st.session_state.get("authenticated_session", None) is None: + # If the backend supports authentication but the user is not authenticated, show the + # authentication page + if ( + backend_supports_auth() + and st.session_state.get("authenticated_session", None) is None + ): initialize_state_variable("login_status_message", value="") initialize_state_variable("login_status_level", value="info") authentication_page() - st.stop() # Stop the script to avoid running the rest of the code if the user is not authenticated + # Stop the script to avoid running the rest of the code if the user is not + # authenticated + st.stop() - # If the backend does not support authentication, just use the session as the authenticated session - if not backend_supports_auth() and st.session_state.get("authenticated_session", None) is None: + # If the backend does not support authentication, just use the session as the + # authenticated session + if ( + not backend_supports_auth() + and st.session_state.get("authenticated_session", None) is None + ): st.session_state["authenticated_session"] = st.session_state["session"] # Once we have an authenticated session, show the chat interface if st.session_state.get("authenticated_session") is not None: - # If the backend supports sessions, enable session navigation if backend_supports_sessions(): - initialize_state_variable("email", value="demo.user@email.com") # With authentication, this will take the user's email + initialize_state_variable( + "email", value="demo.user@email.com" + ) # With authentication, this will take the user's email sidebar() session_chat() else: diff --git a/frontend/lib/auth.py b/frontend/lib/auth.py index c7b5b39..b711d9a 100644 --- a/frontend/lib/auth.py +++ b/frontend/lib/auth.py @@ -13,7 +13,9 @@ def authentication_page(): auth_form_tabs = [stx.TabBarItemData(id="Login", title="Login", description="")] if ADMIN_MODE: - auth_form_tabs += [stx.TabBarItemData(id="Signup", title="Signup", description="")] + auth_form_tabs += [ + stx.TabBarItemData(id="Signup", title="Signup", description="") + ] tab = stx.tab_bar(data=auth_form_tabs, default="Login") if tab == "Login": @@ -27,8 +29,11 @@ def login_form(): username = st.text_input("Username", key="username") password = st.text_input("Password", type="password") if st.session_state["login_status_message"]: - # Dynamically decides whether to call st.error, st.success, or any other Streamlit method - getattr(st, st.session_state["login_status_level"])(st.session_state["login_status_message"]) + # Dynamically decides whether to call st.error, st.success, or any other + # Streamlit method + getattr(st, st.session_state["login_status_level"])( + st.session_state["login_status_message"] + ) submit = st.form_submit_button("Log in") if submit: @@ -39,7 +44,9 @@ def login_form(): session = authenticate_session(session, token) else: st.session_state["login_status_level"] = "error" - st.session_state["login_status_message"] = "Username/password combination not found" + st.session_state["login_status_message"] = ( + "Username/password combination not found" + ) st.session_state["authenticated_session"] = session st.session_state["email"] = username st.rerun() @@ -50,7 +57,9 @@ def signup_form(): username = st.text_input("Username", key="username") password = st.text_input("Password", type="password") if st.session_state["login_status_message"]: - getattr(st, st.session_state["login_status_level"])(st.session_state["login_status_message"]) + getattr(st, st.session_state["login_status_level"])( + st.session_state["login_status_message"] + ) submit = st.form_submit_button("Sign up") if submit: @@ -63,7 +72,9 @@ def signup_form(): st.session_state["login_status_level"] = "success" st.session_state["login_status_message"] = "Success! Account created." if st.session_state["login_status_message"]: - getattr(st, st.session_state["login_status_level"])(st.session_state["login_status_message"]) + getattr(st, st.session_state["login_status_level"])( + st.session_state["login_status_message"] + ) sleep(1.5) else: st.session_state["login_status_level"] = "error" @@ -75,7 +86,9 @@ def signup_form(): def get_token(username: str, password: str) -> Optional[str]: session = create_session() - response = session.post("/user/login", data={"username": username, "password": password}) + response = session.post( + "/user/login", data={"username": username, "password": password} + ) if response.status_code == 200 and "access_token" in response.json(): return response.json()["access_token"] else: @@ -84,7 +97,9 @@ def get_token(username: str, password: str) -> Optional[str]: def sign_up(username: str, password: str) -> bool: session = create_session() - response = session.post("/user/signup", json={"email": username, "password": password}) + response = session.post( + "/user/signup", json={"email": username, "password": password} + ) if response.status_code == 200 and "email" in response.json(): return True else: @@ -93,5 +108,7 @@ def sign_up(username: str, password: str) -> bool: def authenticate_session(session, bearer_token: str) -> requests.Session: session.headers.update({"Authorization": f"Bearer {bearer_token}"}) - st.session_state["chain"] = RemoteRunnable(BACKEND_URL, headers={"Authorization": f"Bearer {bearer_token}"}) + st.session_state["chain"] = RemoteRunnable( + BACKEND_URL, headers={"Authorization": f"Bearer {bearer_token}"} + ) return session diff --git a/frontend/lib/backend_interface.py b/frontend/lib/backend_interface.py index 19af143..612824b 100644 --- a/frontend/lib/backend_interface.py +++ b/frontend/lib/backend_interface.py @@ -16,7 +16,9 @@ def query(verb: str, url: str, **kwargs): st.session_state["authenticated_session"] = None st.session_state["email"] = None st.session_state["login_status_level"] = "error" - st.session_state["login_status_message"] = "Session expired. Please log in again." + st.session_state["login_status_message"] = ( + "Session expired. Please log in again." + ) st.rerun() return response @@ -34,6 +36,7 @@ def create_session() -> Session: session = BaseUrlSession(BACKEND_URL) return session + class BaseUrlSession(Session): def __init__(self, base_url): super().__init__() diff --git a/frontend/lib/basic_chat.py b/frontend/lib/basic_chat.py index 51c3738..baffe47 100644 --- a/frontend/lib/basic_chat.py +++ b/frontend/lib/basic_chat.py @@ -6,7 +6,6 @@ def basic_chat(): with st.container(border=True): if user_question: - with st.chat_message("user"): st.write(user_question) diff --git a/frontend/lib/session_chat.py b/frontend/lib/session_chat.py index 3f20969..d8da3d0 100644 --- a/frontend/lib/session_chat.py +++ b/frontend/lib/session_chat.py @@ -17,7 +17,10 @@ class Message: def __post_init__(self): self.id = str(uuid4()) if self.id is None else self.id - self.timestamp = datetime.utcnow().isoformat() if self.timestamp is None else self.timestamp + self.timestamp = ( + datetime.utcnow().isoformat() if self.timestamp is None else self.timestamp + ) + def session_chat(): user_question = st.chat_input("Say something") @@ -40,7 +43,10 @@ def session_chat(): st.session_state["messages"].append(user_message) chain = st.session_state.get("chain") - response = chain.stream({"question": user_question}, {"configurable": {"session_id": session_id}}) + response = chain.stream( + {"question": user_question}, + {"configurable": {"session_id": session_id}}, + ) with st.chat_message("assistant"): full_response = "" @@ -60,6 +66,7 @@ def new_session(): st.session_state["messages"] = [] return session_id + def get_session(session_id: str): session = query("get", f"/session/{session_id}").json() return session diff --git a/frontend/lib/sidebar.py b/frontend/lib/sidebar.py index d9a3c9e..6e19a8a 100644 --- a/frontend/lib/sidebar.py +++ b/frontend/lib/sidebar.py @@ -10,9 +10,14 @@ def sidebar(): with st.sidebar: st.sidebar.title("RAG Industrialization Kit", anchor="top") - st.sidebar.markdown(f"

Logged in as {st.session_state['email']}

", unsafe_allow_html=True) - - if st.sidebar.button("New Chat", use_container_width=True, key="new_chat_button"): + st.sidebar.markdown( + f"

Logged in as {st.session_state['email']}

", + unsafe_allow_html=True, + ) + + if st.sidebar.button( + "New Chat", use_container_width=True, key="new_chat_button" + ): st.session_state["messages"] = [] with st.empty(): @@ -20,7 +25,9 @@ def sidebar(): chats_by_time_ago = {} for chat in chat_list: chat_id, timestamp = chat["id"], chat["timestamp"] - time_ago = humanize.naturaltime(datetime.utcnow() - datetime.fromisoformat(timestamp)) + time_ago = humanize.naturaltime( + datetime.utcnow() - datetime.fromisoformat(timestamp) + ) if time_ago not in chats_by_time_ago: chats_by_time_ago[time_ago] = [] chats_by_time_ago[time_ago].append(chat) @@ -29,9 +36,14 @@ def sidebar(): st.sidebar.markdown(time_ago) for chat in chats: chat_id = chat["id"] - if st.sidebar.button(chat_id, key=chat_id, use_container_width=True): + if st.sidebar.button( + chat_id, key=chat_id, use_container_width=True + ): st.session_state["chat_id"] = chat_id - messages = [Message(**message) for message in get_session(chat_id)["messages"]] + messages = [ + Message(**message) + for message in get_session(chat_id)["messages"] + ] st.session_state["messages"] = messages diff --git a/frontend/requirements.txt b/frontend/requirements.txt index dd129ab..41f6ab3 100644 --- a/frontend/requirements.txt +++ b/frontend/requirements.txt @@ -7,4 +7,4 @@ python-dotenv Requests streamlit httpx_sse -pydantic==1.* \ No newline at end of file +pydantic==1.* diff --git a/requirements-dev.txt b/requirements-dev.txt index e3a19fe..31a0f07 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,4 +4,4 @@ mkdocs mkdocs_monorepo_plugin pymdown-extensions mkdocs-pymdownx-material-extras -mkdocs-techdocs-core \ No newline at end of file +mkdocs-techdocs-core From a80bf4f9086799df2977a0820e3b5ba4306b9335 Mon Sep 17 00:00:00 2001 From: Baptiste Pasquier Date: Mon, 15 Apr 2024 17:30:04 +0200 Subject: [PATCH 4/4] Update CI workflow --- .github/workflows/ci.yaml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d5b1987..dba2d00 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -11,15 +11,18 @@ jobs: python-version: ['3.11'] steps: - - uses: actions/checkout@v2 + - name: Check out repository + uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install requirements run: | python -m pip install --upgrade pip - pip install -r requirements-dev.txt + pip install pre-commit + - name: Run Pre commit hook (formatting, linting & tests) run: pre-commit run --all-files --hook-stage pre-push --show-diff-on-failure