Skip to content

Commit

Permalink
refacto: yes
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisVLRT committed Feb 21, 2024
1 parent bad8eb6 commit c143d49
Show file tree
Hide file tree
Showing 66 changed files with 916 additions and 546 deletions.
15 changes: 15 additions & 0 deletions backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os

from dotenv import load_dotenv

load_dotenv()
DATABASE_URL = os.getenv("DATABASE_URL")

# Private key used to generate the JWT tokens for secure authentication
SECRET_KEY = os.getenv("SECRET_KEY", "default_unsecure_key")

# Algorithm used to generate JWT tokens
ALGORITHM = os.getenv("ALGORITHM", "HS256")

# If the API runs in admin mode, it will allow the creation of new users
ADMIN_MODE = bool(int(os.getenv("ADMIN_MODE", False)))
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import Depends, HTTPException, status
from pathlib import Path
from fastapi import Depends, HTTPException, Response, status

from backend.api_plugins.lib.user_management import (
User,
Expand All @@ -7,11 +8,15 @@
get_user,
user_exists,
)
from backend.database import Database


def insecure_authentication_routes(app):
async def get_current_user(token: str) -> User:
email = token.replace("Bearer ", "")
with Database() as connection:
connection.run_script(Path(__file__).parent / "users_tables.sql")

async def get_current_user(email: str) -> User:
email = email.replace("Bearer ", "")
user = get_user(email)
return user

Expand Down Expand Up @@ -51,5 +56,10 @@ async def login(email: str) -> dict:
@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
return current_user


@app.get("/user")
async def user_root() -> dict:
return Response("Insecure user management routes are enabled. Do not use in prod.", status_code=200)

return Depends(get_current_user)
7 changes: 7 additions & 0 deletions backend/api_plugins/insecure_authentication/users_tables.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Dialect MUST be sqlite, even if the database you use is different.
-- It is transpiled to the right dialect when executed.

CREATE TABLE IF NOT EXISTS "users" (
"email" VARCHAR(255) PRIMARY KEY,
"password" TEXT
);
5 changes: 1 addition & 4 deletions backend/api_plugins/lib/user_management.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import os
from datetime import datetime, timedelta
from typing import Optional

import argon2
from jose import jwt
from pydantic import BaseModel

from backend import ALGORITHM, SECRET_KEY
from backend.database import Database

SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key")
ALGORITHM = "HS256"


class UnsecureUser(BaseModel):
email: str = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os

from dotenv import load_dotenv
from fastapi import Depends, HTTPException, status
from pathlib import Path
from typing import List
from fastapi import Depends, HTTPException, Response, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from backend import ADMIN_MODE

from backend.api_plugins.lib.user_management import (
ALGORITHM,
Expand All @@ -18,13 +18,12 @@
user_exists,
)

load_dotenv()

# If the API runs in admin mode, it will allow the creation of new users
# For public deployments, that prevents unwanted people to singup and consume tokens
ADMIN_MODE = bool(int(os.getenv("ADMIN_MODE", False)))

def authentication_routes(app):
def authentication_routes(app, dependencies=List[Depends]):
from backend.database import Database
with Database() as connection:
connection.run_script(Path(__file__).parent / "users_tables.sql")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/user/login")

async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
Expand Down Expand Up @@ -97,5 +96,11 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
return current_user


@app.get("/user")
async def user_root() -> dict:
return Response("User management routes are enabled.", status_code=200)


return Depends(get_current_user)
7 changes: 7 additions & 0 deletions backend/api_plugins/secure_authentcation/users_tables.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Dialect MUST be sqlite, even if the database you use is different.
-- It is transpiled to the right dialect when executed.

CREATE TABLE IF NOT EXISTS "users" (
"email" VARCHAR(255) PRIMARY KEY,
"password" TEXT
);
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from datetime import datetime
import json
from pathlib import Path
from typing import List, Optional, Sequence
from uuid import uuid4

from fastapi import APIRouter, Depends, FastAPI
from fastapi import APIRouter, Depends, FastAPI, Response

from backend.api_plugins.lib.user_management import User

Expand All @@ -16,6 +18,9 @@ def session_routes(
from backend.database import Database
from backend.model import Message

with Database() as connection:
connection.run_script(Path(__file__).parent / "sessions_tables.sql")

@app.post("/session/new")
async def chat_new(current_user: User=authentication, dependencies=dependencies) -> dict:
chat_id = str(uuid4())
Expand All @@ -26,7 +31,7 @@ async def chat_new(current_user: User=authentication, dependencies=dependencies)
"INSERT INTO session (id, timestamp, user_id) VALUES (?, ?, ?)",
(chat_id, timestamp, user_id),
)
return {"chat_id": chat_id}
return {"session_id": chat_id}


@app.get("/session/list")
Expand All @@ -43,20 +48,26 @@ async def chat_list(current_user: User=authentication, dependencies=dependencies


@app.get("/session/{session_id}")
async def chat(session_id: str, dependencies=dependencies) -> dict:
async def chat(session_id: str, current_user: User=authentication, dependencies=dependencies) -> dict:
messages: List[Message] = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp, session_id, sender, content FROM message_history WHERE session_id = ? ORDER BY timestamp ASC",
"SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC",
(session_id,),
)
for row in result:
content = json.loads(row[3])["data"]["content"]
message_type = json.loads(row[3])["type"]
message = Message(
id=row[0],
timestamp=row[1],
chat_id=row[2],
sender=row[3],
content=row[4]
session_id=row[2],
sender=message_type if message_type == "human" else "ai",
content=content,
)
messages.append(message)
return {"chat_id": session_id, "messages": [message.model_dump() for message in messages]}
return {"chat_id": session_id, "messages": [message.dict() for message in messages]}

@app.get("/session")
async def session_root(current_user: User=authentication, dependencies=dependencies) -> dict:
return Response("Sessions management routes are enabled.", status_code=200)
9 changes: 9 additions & 0 deletions backend/api_plugins/sessions/sessions_tables.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- Dialect MUST be sqlite, even if the database you use is different.
-- It is transpiled to the right dialect when executed.

CREATE TABLE IF NOT EXISTS "session" (
"id" VARCHAR(255) PRIMARY KEY,
"timestamp" DATETIME,
"user_id" VARCHAR(255),
FOREIGN KEY ("user_id") REFERENCES "users" ("email")
);
65 changes: 31 additions & 34 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,60 +5,57 @@
import yaml
from dotenv import load_dotenv
from jinja2 import Template
from langchain.chat_models.base import BaseChatModel
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores import VectorStore
from langchain_core.language_models import LLM
from langchain_core.language_models.chat_models import BaseChatModel

load_dotenv()

@dataclass
class LLMConfig:
source: BaseChatModel | str = "AzureChatOpenAI"
source_config: dict = field(default_factory=lambda: {
"openai_api_type": "azure",
"openai_api_base": "https://poc-genai-gpt4.openai.azure.com/",
"openai_api_version": "2023-07-01-preview",
"openai_api_key": os.environ.get("OPENAI_API_KEY"),
"deployment_name": "gpt4v",
})

temperature: float = 0.1
source: BaseChatModel | LLM | str
source_config: dict

@dataclass
class VectorStoreConfig:
source: VectorStore | str = "Chroma"
source_config: dict = field(default_factory=lambda: {
"persist_directory": "vector_database/",
"collection_metadata": {
"hnsw:space": "cosine"
}
})

retriever_search_type: str = "similarity"
retriever_config: dict = field(default_factory=lambda: {
"top_k": 20,
"score_threshold": 0.5
})
source: VectorStore | str
source_config: dict

insertion_mode: str = "full" # "None", "full", "incremental"
retriever_search_type: str
retriever_config: dict
insertion_mode: str # "None", "full", "incremental"

@dataclass
class EmbeddingModelConfig:
source: Embeddings | str = "OpenAIEmbeddings"
source_config: dict = field(default_factory=lambda: {
"openai_api_type": "azure",
"openai_api_base": "https://poc-openai-artefact.openai.azure.com/",
"openai_api_key": os.environ.get("EMBEDDING_API_KEY"),
"deployment": "embeddings",
"chunk_size": 500,
})
source: Embeddings | str
source_config: dict

@dataclass
class DatabaseConfig:
database_url: str = os.environ.get("DATABASE_URL")
database_url: str

@dataclass
class RagConfig:
"""
Configuration class for the Retrieval-Augmented Generation (RAG) system.
It is meant to be injected in the RAG class to configure the various components.
This class holds the configuration for the various components that make up the RAG system, including
the language model, vector store, embedding model, and database configurations. It provides a method
to construct a RagConfig instance from a YAML file, allowing for easy external configuration.
Attributes:
llm (LLMConfig): Configuration for the language model component.
vector_store (VectorStoreConfig): Configuration for the vector store component.
embedding_model (EmbeddingModelConfig): Configuration for the embedding model component.
database (DatabaseConfig): Configuration for the database connection.
Methods:
from_yaml: Class method to create an instance of RagConfig from a YAML file, with optional environment
variables for template rendering.
"""

llm: LLMConfig = field(default_factory=LLMConfig)
vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig)
embedding_model: EmbeddingModelConfig = field(default_factory=EmbeddingModelConfig)
Expand Down
20 changes: 6 additions & 14 deletions backend/config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
LLMConfig: &LLMConfig
source: AzureChatOpenAI
source: ChatOllama
source_config:
openai_api_type: azure
openai_api_key: {{ OPENAI_API_KEY }}
openai_api_base: https://genai-ds.openai.azure.com/
openai_api_version: 2023-07-01-preview
deployment_name: gpt4
temperature: 0.1
model: tinyllama
temperature: 0

VectorStoreConfig: &VectorStoreConfig
source: Chroma
Expand All @@ -19,20 +15,16 @@ VectorStoreConfig: &VectorStoreConfig
retriever_config:
k: 20
score_threshold: 0.5

insertion_mode: null

EmbeddingModelConfig: &EmbeddingModelConfig
source: OpenAIEmbeddings
source: HuggingFaceEmbeddings
source_config:
openai_api_type: azure
openai_api_key: {{ EMBEDDING_API_KEY }}
openai_api_base: https://poc-openai-artefact.openai.azure.com/
deployment: embeddings
model_name: BAAI/bge-base-en-v1.5
chunk_size: 500

DatabaseConfig: &DatabaseConfig
database_url: {{ DATABASE_URL }}
database_url: sqlite:///database/rag.sqlite3

RagConfig:
llm: *LLMConfig
Expand Down
22 changes: 17 additions & 5 deletions backend/database.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import os
from logging import Logger
from pathlib import Path
from typing import Any, Optional

import sqlglot
from dbutils.pooled_db import PooledDB
from dotenv import load_dotenv
from sqlalchemy.engine.url import make_url

from backend import DATABASE_URL
from backend.logger import get_logger

load_dotenv()
POOL = None


class Database:
"""
Handles database operations.
Expand All @@ -28,15 +28,15 @@ class Database:
conn (Connection): The current database connection.
DIALECT_PLACEHOLDERS (dict): Mapping of database dialects to their placeholder symbols.
"""

DIALECT_PLACEHOLDERS = {
"sqlite": "?",
"postgresql": "%s",
"mysql": "%s",
}

def __init__(self, connection_string: str = None, logger: Logger = None):
self.connection_string = connection_string or os.getenv("DATABASE_URL")
self.connection_string = connection_string or DATABASE_URL
self.logger = logger or get_logger()

self.url = make_url(self.connection_string)
Expand Down Expand Up @@ -101,6 +101,18 @@ def initialize_schema(self):
self.logger.exception("Schema initialization failed", exc_info=e)
raise

def run_script(self, path: Path):
try:
self.logger.debug(f"Running Database script at {str(path)}")
sql_script = path.read_text()
transpiled_sql = sqlglot.transpile(sql_script, read="sqlite", write=self.url.drivername.replace("postgresql", "postgres"))
for statement in transpiled_sql:
self.execute(statement)
self.logger.info(f"Successfuly ran script at {path} for {self.url.drivername}")
except Exception as e:
self.logger.exception(f"Failed to execute the script {path} for {self.url.drivername}", exc_info=e)
raise

def _create_pool(self) -> PooledDB:
if self.connection_string.startswith("sqlite:///"):
import sqlite3
Expand Down
Loading

0 comments on commit c143d49

Please sign in to comment.