Skip to content

Commit

Permalink
add: database and auth mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisVLRT committed Dec 13, 2023
1 parent 54cabbe commit 5936f34
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 84 deletions.
73 changes: 1 addition & 72 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,4 @@

# skaff-rag-accelerator

[![CI status](https://github.com/artefactory/skaff-rag-accelerator/actions/workflows/ci.yaml/badge.svg)](https://github.com/artefactory/skaff-rag-accelerator/actions/workflows/ci.yaml?query=branch%3Amain)
[![Python Version](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10-blue.svg)]()

[![Linting, formatting, imports sorting: ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
[![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit)
[![Pre-commit](https://img.shields.io/badge/pre--commit-enabled-informational?logo=pre-commit&logoColor=white)](https://github.com/artefactory/skaff-rag-accelerator/blob/main/.pre-commit-config.yaml)
</div>

TODO: if not done already, check out the [Skaff documentation](https://artefact.roadie.so/catalog/default/component/repo-builder-ds/docs/) for more information about the generated repository.

Deploy RAGs quickly

## Table of Contents

- [skaff-rag-accelerator](#skaff-rag-accelerator)
- [Table of Contents](#table-of-contents)
- [Installation](#installation)
- [Usage](#usage)
- [Documentation](#documentation)
- [Repository Structure](#repository-structure)

## Installation

To install the required packages in a virtual environment, run the following command:

```bash
make install
```

TODO: Choose between conda and venv if necessary, or leave the Makefile as is and copy/paste the [MORE INFO installation section](MORE_INFO.md#eased-installation) to explain how to choose between conda and venv.

A complete list of available commands can be found using the following command:

```bash
make help
```

## Usage

TODO: Add usage instructions here

## Documentation

TODO: GitHub Pages is not enabled by default; you need to enable it in the repository settings: Settings > Pages > Source: "Deploy from a branch" / Branch: "gh-pages" / Folder: "/(root)"

Detailed documentation for this project is available [here](https://artefactory.github.io/skaff-rag-accelerator/).

To serve the documentation locally, run the following command:

```bash
mkdocs serve
```

To build it and deploy it to GitHub pages, run the following command:

```bash
make deploy_docs
```

## Repository Structure

```
.
├── .github <- GitHub Actions workflows and PR template
├── bin <- Bash files
├── config <- Configuration files
├── docs <- Documentation files (mkdocs)
├── lib <- Python modules
├── notebooks <- Jupyter notebooks
├── secrets <- Secret files (ignored by git)
└── tests <- Unit tests
```
</div>
41 changes: 41 additions & 0 deletions authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from datetime import timedelta, datetime
import os
from pydantic import BaseModel
from jose import jwt


from database.database import DatabaseConnection

# Key used to sign and verify JWTs. Read from the SECRET_KEY environment
# variable; a hard-coded placeholder is used when the variable is unset.
# NOTE(review): the placeholder is public, so tokens signed with it can be
# forged -- confirm SECRET_KEY is always set outside local development.
SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key")
# Signing algorithm passed to jwt.encode.
ALGORITHM = "HS256"

class User(BaseModel):
    """Account record: the email and password columns of the `user` table."""

    # Both fields default to None, so a User can be built without them.
    # NOTE(review): the annotation is `str` while the default is None --
    # confirm whether these should be required or Optional.
    email: str = None
    # Compared verbatim with the submitted password in authenticate_user.
    password: str = None

def create_user(user: User):
    """Insert ``user`` as a new row of the ``user`` table.

    The email and password are written exactly as they appear on ``user``.
    """
    insert_sql = "INSERT INTO user (email, password) VALUES (?, ?)"
    with DatabaseConnection() as db:
        values = (user.email, user.password)
        db.query(insert_sql, values)

def get_user(email: str):
    """Return the User stored under ``email``.

    Args:
        email: Value matched against the ``email`` column of the ``user`` table.

    Returns:
        A User built from the first matching row.

    Raises:
        Exception: With the message "User not found" when no row matches.
    """
    with DatabaseConnection() as db:
        # query() returns one row list per executed statement; only one ran.
        matches = db.query("SELECT * FROM user WHERE email = ?", (email,))[0]
        if matches:
            return User(**matches[0])
    raise Exception("User not found")

def authenticate_user(username: str, password: str):
    """Check a username/password pair against the stored user.

    Args:
        username: Email of the account to authenticate.
        password: Password submitted by the caller.

    Returns:
        The matching User when ``password`` equals the stored password,
        otherwise False (this includes an unknown ``username``).
    """
    try:
        user = get_user(username)
    except Exception:
        # get_user reports a missing account by raising a plain Exception, so
        # there is no narrower type to catch. Failing closed here turns an
        # unknown username into a normal "bad credentials" result instead of
        # an unhandled error.
        return False
    if not user or password != user.password:
        return False
    return user

def create_access_token(*, data: dict, expires_delta: timedelta = None):
    """Build a signed JWT from ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed; the mapping itself is not modified.
        expires_delta: Token lifetime. A falsy value (None or a zero
            timedelta) selects the 15-minute default.

    Returns:
        The token produced by jwt.encode with SECRET_KEY and ALGORITHM.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    claims = {**data, "exp": datetime.utcnow() + lifetime}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
Empty file added client/main.py
Empty file.
30 changes: 30 additions & 0 deletions database/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path
import sqlite3
from typing import List

class DatabaseConnection:
    """Context manager around a SQLite connection.

    Entering opens the connection and returns this object. Leaving commits
    when the ``with`` body succeeded, rolls back when it raised, and always
    closes the connection.
    """

    def __init__(self, db_path=None):
        """Remember which SQLite file to open.

        Args:
            db_path: Path of the SQLite file. Defaults to ``database.sqlite``
                in the directory of this module.
        """
        if db_path is None:
            db_path = Path(__file__).parent / "database.sqlite"
        self.db_path = db_path

    def __enter__(self):
        """Open the connection with rows returned as ``sqlite3.Row``."""
        self.conn = sqlite3.connect(self.db_path)
        self.conn.row_factory = sqlite3.Row
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit or roll back, then close; exceptions are not suppressed."""
        try:
            if exc_type is None:
                self.conn.commit()
            else:
                # Do not persist the partial work of a failed block.
                self.conn.rollback()
        finally:
            self.conn.close()

    def query(self, query, params=None) -> List[List[sqlite3.Row]]:
        """Execute one or more ``;``-separated statements.

        The text is split naively on ``;``, so a semicolon inside a string
        literal or a comment would break a statement apart. The same
        ``params`` are bound to every statement.

        Args:
            query: SQL text containing one or more statements.
            params: Values bound to the placeholders; None means no values.

        Returns:
            One list of fetched rows per executed statement, in order.
        """
        cursor = self.conn.cursor()
        results = []
        for command in query.split(";"):
            # Skip empty or whitespace-only fragments such as the newline
            # that follows the last ";" of a script.
            if not command.strip():
                continue
            cursor.execute(command, params or ())
            results.append(cursor.fetchall())
        return results

    def query_from_file(self, file_path):
        """Read ``file_path`` as UTF-8 text and run it through ``query``."""
        with open(file_path, "r", encoding="utf-8") as file:
            script = file.read()
        self.query(script)

# Runs at import time: executes database_init.sql (located next to this
# module) through a DatabaseConnection.
with DatabaseConnection() as connection:
    connection.query_from_file(Path(__file__).parent / "database_init.sql")
33 changes: 33 additions & 0 deletions database/database_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
-- Go to https://dbdiagram.io/d/RAGAAS-63dbdcc6296d97641d7e07c8
-- Make your changes
-- Export > Export to PostgreSQL (or other)
-- Translate to SQLite (works with a cmd+k in Cursor, or https://www.rebasedata.com/convert-postgresql-to-sqlite-online)
-- Paste here
-- Replace "CREATE TABLE" with "CREATE TABLE IF NOT EXISTS"

-- One row per account, keyed by email.
CREATE TABLE IF NOT EXISTS "user" (
  "email" TEXT PRIMARY KEY,
  "password" TEXT
);

-- One row per chat, linked to its owner through user_id -> user.email.
-- NOTE(review): SQLite only enforces foreign keys when PRAGMA foreign_keys
-- is ON, and no such pragma is visible in the connection code -- confirm.
CREATE TABLE IF NOT EXISTS "chat" (
  "id" TEXT PRIMARY KEY,
  "user_id" TEXT,
  FOREIGN KEY ("user_id") REFERENCES "user" ("email")
);

-- One row per message, linked to its chat through chat_id -> chat.id.
-- All columns, including timestamp, are stored as TEXT.
CREATE TABLE IF NOT EXISTS "message" (
  "id" TEXT PRIMARY KEY,
  "timestamp" TEXT,
  "chat_id" TEXT,
  "sender" TEXT,
  "content" TEXT,
  FOREIGN KEY ("chat_id") REFERENCES "chat" ("id")
);

-- One row per feedback entry, linked through message_id -> message.id.
CREATE TABLE IF NOT EXISTS "feedback" (
  "id" TEXT PRIMARY KEY,
  "message_id" TEXT,
  "feedback" TEXT,
  FOREIGN KEY ("message_id") REFERENCES "message" ("id")
);
4 changes: 2 additions & 2 deletions document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def persist_to_bucket(bucket_path: str, store: Chroma):


def store_documents(docs: List[Document], bucket_path: str, storage_backend: StorageBackend):
lagnchain_documents = [doc.to_langchain_document() for doc in docs]
langchain_documents = [doc.to_langchain_document() for doc in docs]
embeddings_model = OpenAIEmbeddings()
persistent_client = chromadb.PersistentClient()
collection = persistent_client.get_or_create_collection(get_storage_root_path(bucket_path, storage_backend))
collection.add(documents=lagnchain_documents)
collection.add(documents=langchain_documents)
langchain_chroma = Chroma(
client=persistent_client,
collection_name=bucket_path,
Expand Down
118 changes: 110 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,123 @@
from fastapi import FastAPI, HTTPException, status, Body
from datetime import timedelta
from typing import List
from langchain.docstore.document import Document
from document_store import StorageBackend

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt, JWTError

import document_store
from model import ChatMessage
from authentication import (authenticate_user, create_access_token, create_user,
get_user, User, SECRET_KEY, ALGORITHM)
from document_store import StorageBackend
from model import Doc


app = FastAPI()


############################################
### Authentication ###
############################################

# Bearer-token scheme consumed by get_current_user. tokenUrl is advertised in
# the OpenAPI schema as the endpoint that issues tokens, so it must match the
# login route, which is registered as "/user/login".
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/login")
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Resolve the bearer token of the current request to a stored User.

    Args:
        token: JWT supplied through the OAuth2 bearer scheme.

    Returns:
        The User that get_user returns for the token's "email" claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, carries no
            "email" claim, or no user can be loaded for that email.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise credentials_exception from err

    # The user identity is carried in the "email" claim of the token.
    email = payload.get("email")
    if email is None:
        raise credentials_exception

    try:
        user = get_user(email)
    except Exception as err:
        # get_user signals a missing account with a plain Exception; a valid
        # token for an account that cannot be loaded must yield 401, not an
        # unhandled server error.
        raise credentials_exception from err
    if user is None:
        raise credentials_exception
    return user

@app.post("/user/signup")
async def signup(user: User):
    """Register a new account.

    Args:
        user: Email and password of the account to create.

    Returns:
        A dict holding the email of the created account.

    Raises:
        HTTPException: 400 when a user with this email already exists.
    """
    try:
        existing_user = get_user(user.email)
    except Exception:
        # get_user raises a plain Exception when no row matches the email.
        existing_user = None

    # Raised outside the try block so the handler above cannot swallow it.
    if existing_user is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"User {user.email} already registered"
        )

    create_user(user)
    return {"email": user.email}

@app.post("/user/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange a username/password form for a bearer token valid 60 minutes.

    Args:
        form_data: OAuth2 password form; ``username`` holds the email.

    Returns:
        A dict with ``access_token`` and ``token_type`` set to "bearer".

    Raises:
        HTTPException: 401 when authenticate_user rejects the credentials.
    """
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=60)
    # A JWT is signed but not encrypted: anyone holding the token can read its
    # payload. Embed only the email claim that get_current_user reads, never
    # the password.
    access_token = create_access_token(
        data={"email": user.email}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)):
    """Return the User resolved from the request's bearer token."""
    # NOTE(review): the returned User also carries the password field --
    # confirm it is meant to be part of the response.
    return current_user


############################################
### Chat ###
############################################
# P1
@app.post("/chat/new")
async def chat_new(current_user: User = Depends(get_current_user)):
    """Placeholder for POST /chat/new; requires a valid token, does nothing yet."""
    pass

# P1
@app.post("/chat/user_message")
async def chat_prompt(current_user: User = Depends(get_current_user)):
    """Placeholder for POST /chat/user_message; requires a valid token, does nothing yet."""
    pass

@app.get("/chat/list")
async def chat_list(current_user: User = Depends(get_current_user)):
    """Placeholder for GET /chat/list; requires a valid token, does nothing yet."""
    pass

@app.get("/chat/{chat_id}")
async def chat(chat_id: str, current_user: User = Depends(get_current_user)):
    """Placeholder for GET /chat/{chat_id}; requires a valid token, does nothing yet."""
    pass


############################################
### Feedback ###
############################################

@app.post("/feedback/thumbs_up")
async def feedback_thumbs_up(current_user: User = Depends(get_current_user)):
    """Placeholder for POST /feedback/thumbs_up; requires a valid token, does nothing yet."""
    pass

@app.post("/feedback/thumbs_down")
async def feedback_thumbs_down(current_user: User = Depends(get_current_user)):
    """Placeholder for POST /feedback/thumbs_down; requires a valid token, does nothing yet."""
    pass

@app.post("/feedback/regenerate")
async def feedback_regenerate(current_user: User = Depends(get_current_user)):
    """Placeholder for POST /feedback/regenerate; requires a valid token, does nothing yet."""
    pass


############################################
### Other ###
############################################

@app.post("/index/documents")
async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend):
    """Forward the posted chunks to document_store.store_documents."""
    # NOTE(review): unlike the chat and feedback routes, this one declares no
    # get_current_user dependency -- confirm it is meant to be unauthenticated.
    document_store.store_documents(chunks, bucket, storage_backend)

@app.post("/chat")
async def chat(chat_message: ChatMessage):
    """Placeholder for POST /chat; accepts a ChatMessage body and does nothing yet."""
    # NOTE(review): reuses the function name `chat` of the GET /chat/{chat_id}
    # handler; this looks like the pre-change stub shown by the diff -- confirm.
    pass


if __name__ == "__main__":
import uvicorn
Expand Down
1 change: 1 addition & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

class ChatMessage(BaseModel):
    # Three required string fields; BaseModel is presumably pydantic's
    # (its import is outside the visible part of this file -- confirm).
    message: str
    message_id: str
    session_id: str

class Doc(BaseModel):
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ universal_pathlib
chromadb
langchain
langchainhub
gpt4all
gpt4all
python-multipart
httpx
2 changes: 1 addition & 1 deletion sandbox_alexis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@

split_documents = load_and_split_document(text=data)
root_path = get_storage_root_path("dbt-server-alexis3-36fe-rag", StorageBackend.GCS)
vector_store = Chroma(persist_directory=str(root_path / "chromadb"), embedding_function=GPT4AllEmbeddings())
vector_store = Chroma(persist_directory=root_path / "chromadb", embedding_function=GPT4AllEmbeddings())
db = vector_store.add_documents(split_documents)
21 changes: 21 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from fastapi.testclient import TestClient
from main import app

# Test client shared by every test in this module; it drives the FastAPI app
# imported from main.
client = TestClient(app)

def test_signup():
    """Signing up answers 200 and echoes the submitted email."""
    payload = {"email": "[email protected]", "password": "testpassword"}
    response = client.post("/user/signup", json=payload)
    assert response.status_code == 200
    body = response.json()
    assert body["email"] == "[email protected]"

def test_login():
    """Logging in with the signup credentials answers 200 with a token."""
    credentials = {"username": "[email protected]", "password": "testpassword"}
    response = client.post("/user/login", data=credentials)
    assert response.status_code == 200
    body = response.json()
    assert "access_token" in body

def test_user_me():
    """A token obtained from /user/login gives access to /user/me."""
    credentials = {"username": "[email protected]", "password": "testpassword"}
    login_response = client.post("/user/login", data=credentials)
    token = login_response.json()["access_token"]
    auth_headers = {"Authorization": f"Bearer {token}"}
    response = client.get("/user/me", headers=auth_headers)
    assert response.status_code == 200
    body = response.json()
    assert body["email"] == "[email protected]"

0 comments on commit 5936f34

Please sign in to comment.