Skip to content

Commit

Permalink
✅ pass pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Apr 15, 2024
1 parent 8f9c34f commit e40c3e7
Show file tree
Hide file tree
Showing 41 changed files with 457 additions and 205 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ This is a starter kit to deploy a modularizable RAG locally or on the cloud (or

## Quickstart

This quickstart will guide you through the steps to serve the RAG and load a few documents.
This quickstart will guide you through the steps to serve the RAG and load a few documents.

You will run both the back and front on your machine.

Expand Down
2 changes: 1 addition & 1 deletion backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ EXPOSE $PORT

COPY . ./backend

CMD python -m uvicorn backend.main:app --host 0.0.0.0 --port $PORT
CMD python -m uvicorn backend.main:app --host 0.0.0.0 --port $PORT
14 changes: 12 additions & 2 deletions backend/api_plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from backend.api_plugins.insecure_authentication.insecure_authentication import insecure_authentication_routes
from backend.api_plugins.secure_authentication.secure_authentication import authentication_routes
from backend.api_plugins.insecure_authentication.insecure_authentication import (
insecure_authentication_routes,
)
from backend.api_plugins.secure_authentication.secure_authentication import (
authentication_routes,
)
from backend.api_plugins.sessions.sessions import session_routes

__all__ = [
"insecure_authentication_routes",
"authentication_routes",
"session_routes",
]
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,30 @@ async def get_current_user(email: str) -> User:
async def signup(email: str) -> dict:
user = User(email=email)
if user_exists(user.email):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"User {user.email} already registered",
)
create_user(user)
return {"email": user.email}


@app.delete("/user/")
async def del_user(current_user: User = Depends(get_current_user)) -> dict:
email = current_user.email
try:
user = get_user(email)
if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User {email} not found",
)
delete_user(email)
return {"detail": f"User {email} deleted"}
except Exception:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error")

raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error",
)

@app.post("/user/login")
async def login(email: str) -> dict:
Expand All @@ -51,16 +58,20 @@ async def login(email: str) -> dict:
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username",
)
return {"access_token": email, "token_type": "bearer"} # Fake bearer token to still provide user authentication

return {
"access_token": email,
"token_type": "bearer",
} # Fake bearer token to still provide user authentication

@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
return current_user


@app.get("/user")
async def user_root() -> dict:
return Response("Insecure user management routes are enabled. Do not use in prod.", status_code=200)
return Response(
"Insecure user management routes are enabled. Do not use in prod.",
status_code=200,
)

return Depends(get_current_user)
13 changes: 10 additions & 3 deletions backend/api_plugins/lib/user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class UnsecureUser(BaseModel):
email: str = None
password: bytes = None


class User(BaseModel):
email: str = None
hashed_password: str = None
Expand All @@ -26,7 +27,8 @@ def from_unsecure_user(cls, unsecure_user: UnsecureUser):
def create_user(user: User) -> None:
with Database() as connection:
connection.execute(
"INSERT INTO users (email, password) VALUES (?, ?)", (user.email, user.hashed_password)
"INSERT INTO users (email, password) VALUES (?, ?)",
(user.email, user.hashed_password),
)


Expand Down Expand Up @@ -54,12 +56,17 @@ def authenticate_user(username: str, password: bytes) -> bool | User:
if not user:
return False

if argon2.verify_password(user.hashed_password.encode("utf-8"), password.encode("utf-8")):
if argon2.verify_password(
user.hashed_password.encode("utf-8"), password.encode("utf-8")
):
return user

return False

def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None) -> str:

def create_access_token(
*, data: dict, expires_delta: Optional[timedelta] = None
) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
Expand Down
20 changes: 10 additions & 10 deletions backend/api_plugins/secure_authentication/secure_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

def authentication_routes(app, dependencies=List[Depends]):
from backend.database import Database

with Database() as connection:
connection.run_script(Path(__file__).parent / "users_tables.sql")

Expand All @@ -46,39 +47,41 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
except JWTError:
raise credentials_exception


@app.post("/user/signup", include_in_schema=ADMIN_MODE)
async def signup(user: UnsecureUser) -> dict:
if not ADMIN_MODE:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Signup is disabled"
)

user = User.from_unsecure_user(user)
if user_exists(user.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered"
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"User {user.email} already registered",
)

create_user(user)
return {"email": user.email}


@app.delete("/user/")
async def del_user(current_user: User = Depends(get_current_user)) -> dict:
email = current_user.email
try:
user = get_user(email)
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found"
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User {email} not found",
)
delete_user(email)
return {"detail": f"User {email} deleted"}
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error"
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error",
)


@app.post("/user/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
user = authenticate_user(form_data.username, form_data.password)
Expand All @@ -93,15 +96,12 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
access_token = create_access_token(data=user_data)
return {"access_token": access_token, "token_type": "bearer"}


@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
return current_user


@app.get("/user")
async def user_root() -> dict:
return Response("User management routes are enabled.", status_code=200)


return Depends(get_current_user)
29 changes: 20 additions & 9 deletions backend/api_plugins/sessions/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def session_routes(
connection.run_script(Path(__file__).parent / "sessions_tables.sql")

@app.post("/session/new")
async def chat_new(current_user: User=authentication, dependencies=dependencies) -> dict:
async def chat_new(
current_user: User = authentication, dependencies=dependencies
) -> dict:
chat_id = str(uuid4())
timestamp = datetime.utcnow().isoformat()
user_id = current_user.email if current_user else "unauthenticated"
Expand All @@ -33,26 +35,30 @@ async def chat_new(current_user: User=authentication, dependencies=dependencies)
)
return {"session_id": chat_id}


@app.get("/session/list")
async def chat_list(current_user: User=authentication, dependencies=dependencies) -> List[dict]:
async def chat_list(
current_user: User = authentication, dependencies=dependencies
) -> List[dict]:
user_email = current_user.email if current_user else "unauthenticated"
chats = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp DESC",
"SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp"
" DESC",
(user_email,),
)
chats = [{"id": row[0], "timestamp": row[1]} for row in result]
return chats


@app.get("/session/{session_id}")
async def chat(session_id: str, current_user: User=authentication, dependencies=dependencies) -> dict:
async def chat(
session_id: str, current_user: User = authentication, dependencies=dependencies
) -> dict:
messages: List[Message] = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC",
"SELECT id, timestamp, session_id, message FROM message_history WHERE "
"session_id = ? ORDER BY timestamp ASC",
(session_id,),
)
for row in result:
Expand All @@ -66,8 +72,13 @@ async def chat(session_id: str, current_user: User=authentication, dependencies=
content=content,
)
messages.append(message)
return {"chat_id": session_id, "messages": [message.dict() for message in messages]}
return {
"chat_id": session_id,
"messages": [message.dict() for message in messages],
}

@app.get("/session")
async def session_root(current_user: User=authentication, dependencies=dependencies) -> dict:
async def session_root(
current_user: User = authentication, dependencies=dependencies
) -> dict:
return Response("Sessions management routes are enabled.", status_code=200)
33 changes: 20 additions & 13 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,55 +12,62 @@

load_dotenv()


@dataclass
class LLMConfig:
source: BaseChatModel | LLM | str
source_config: dict


@dataclass
class VectorStoreConfig:
source: VectorStore | str
source_config: dict

insertion_mode: str # "None", "full", "incremental"


@dataclass
class EmbeddingModelConfig:
source: Embeddings | str
source_config: dict


@dataclass
class DatabaseConfig:
database_url: str


@dataclass
class RagConfig:
"""
Configuration class for the Retrieval-Augmented Generation (RAG) system.
It is meant to be injected in the RAG class to configure the various components.
This class holds the configuration for the various components that make up the RAG system, including
the language model, vector store, embedding model, and database configurations. It provides a method
to construct a RagConfig instance from a YAML file, allowing for easy external configuration.
This class holds the configuration for the various components that make up the RAG
system, including the language model, vector store, embedding model, and database
configurations. It provides a method to construct a RagConfig instance from a YAML
file, allowing for easy external configuration.
Attributes:
llm (LLMConfig): Configuration for the language model component.
vector_store (VectorStoreConfig): Configuration for the vector store component.
embedding_model (EmbeddingModelConfig): Configuration for the embedding model component.
embedding_model (EmbeddingModelConfig): Configuration for the embedding model
component.
database (DatabaseConfig): Configuration for the database connection.
Methods:
from_yaml: Class method to create an instance of RagConfig from a YAML file, with optional environment
variables for template rendering.
from_yaml: Class method to create an instance of RagConfig from a YAML file,
with optional environment variables for template rendering.
"""

llm: LLMConfig = field(default_factory=LLMConfig)
vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig)
embedding_model: EmbeddingModelConfig = field(default_factory=EmbeddingModelConfig)
database: DatabaseConfig = field(default_factory=DatabaseConfig)
chat_history_window_size: int = 5
max_tokens_limit: int = 3000
response_mode: str = None
llm: LLMConfig = field(default_factory=LLMConfig)
vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig)
embedding_model: EmbeddingModelConfig = field(default_factory=EmbeddingModelConfig)
database: DatabaseConfig = field(default_factory=DatabaseConfig)
chat_history_window_size: int = 5
max_tokens_limit: int = 3000
response_mode: str = None

@classmethod
def from_yaml(cls, yaml_path: Path, env: dict = None):
Expand Down
Loading

0 comments on commit e40c3e7

Please sign in to comment.