From 2a7daadd87c90872a981c5b11d2ec72244b56647 Mon Sep 17 00:00:00 2001 From: Alexis VIALARET Date: Thu, 14 Dec 2023 09:05:30 +0100 Subject: [PATCH] upd: better db class, better testing --- authentication.py | 6 ++-- database/database.py | 19 ++++++++++-- main.py | 31 +++++++++++++++---- sandbox_alexis/storage_backend.py | 15 ---------- tests/test_api.py | 40 +++++++++++++++++++++---- user_management.py | 50 +++++++++++++++++++++++++++++++ 6 files changed, 129 insertions(+), 32 deletions(-) delete mode 100644 sandbox_alexis/storage_backend.py create mode 100644 user_management.py diff --git a/authentication.py b/authentication.py index e796a22..92e477f 100644 --- a/authentication.py +++ b/authentication.py @@ -4,7 +4,7 @@ from jose import jwt -from database.database import DatabaseConnection +from database.database import Database SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key") ALGORITHM = "HS256" @@ -14,11 +14,11 @@ class User(BaseModel): password: str = None def create_user(user: User): - with DatabaseConnection() as connection: + with Database() as connection: connection.query("INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password)) def get_user(email: str): - with DatabaseConnection() as connection: + with Database() as connection: user_row = connection.query("SELECT * FROM user WHERE email = ?", (email,))[0] for row in user_row: return User(**row) diff --git a/database/database.py b/database/database.py index 0390c9a..bab151a 100644 --- a/database/database.py +++ b/database/database.py @@ -1,14 +1,21 @@ +import os from pathlib import Path import sqlite3 from typing import List -class DatabaseConnection: +class Database: + def __init__(self): + db_name = "test.sqlite" if os.getenv("TESTING", "false").lower() == "true" else "database.sqlite" + self.db_path = Path(__file__).parent / db_name + def __enter__(self): - self.conn = sqlite3.connect(Path(__file__).parent / "database.sqlite") + self.conn = sqlite3.connect(self.db_path) self.conn.row_factory = sqlite3.Row return self def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + self.conn.rollback() self.conn.commit() self.conn.close() @@ -26,5 +33,11 @@ def query_from_file(self, file_path): query = file.read() self.query(query) -with DatabaseConnection() as connection: + def delete_db(self): + if self.conn: + self.conn.close() + if self.db_path.exists(): + self.db_path.unlink(missing_ok=True) + +with Database() as connection: connection.query_from_file(Path(__file__).parent / "database_init.sql") \ No newline at end of file diff --git a/main.py b/main.py index 2eabb6b..c089cdd 100644 --- a/main.py +++ b/main.py @@ -6,8 +6,8 @@ from jose import jwt, JWTError import document_store -from authentication import (authenticate_user, create_access_token, create_user, - get_user, User, SECRET_KEY, ALGORITHM) +from user_management import (authenticate_user, create_access_token, create_user, + get_user, User, SECRET_KEY, ALGORITHM, user_exists) from document_store import StorageBackend from model import Doc @@ -41,16 +41,35 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: @app.post("/user/signup") async def signup(user: User): - try: - user = get_user(user.email) + if user_exists(user.email): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered" ) - except Exception as e: - create_user(user) + + create_user(user) return {"email": user.email} + +@app.delete("/user/") +async def delete_user(current_user: User = Depends(get_current_user)): + email = current_user.email + try: + user = get_user(email) + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User {email} not found" + ) + delete_user(email) + return {"detail": f"User {email} deleted"} + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal Server Error" + ) + + @app.post("/user/login") async def login(form_data: OAuth2PasswordRequestForm = Depends()): user = authenticate_user(form_data.username, form_data.password) diff --git a/sandbox_alexis/storage_backend.py b/sandbox_alexis/storage_backend.py deleted file mode 100644 index 693f63e..0000000 --- a/sandbox_alexis/storage_backend.py +++ /dev/null @@ -1,15 +0,0 @@ -from upath import UPath as Path -from enum import Enum - - -class StorageBackend(Enum): - LOCAL = "local" - MEMORY = "memory" - GCS = "gcs" - S3 = "s3" - AZURE = "az" - - -def get_storage_root_path(bucket_name, storage_backend: StorageBackend): - root_path = Path(f"{storage_backend.value}://{bucket_name}") - return root_path diff --git a/tests/test_api.py b/tests/test_api.py index ca663d0..893afae 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,21 +1,51 @@ +import os +os.environ["TESTING"] = "True" + +from pathlib import Path from fastapi.testclient import TestClient +import pytest + +from database.database import Database from main import app client = TestClient(app) -def test_signup(): +@pytest.fixture() +def initialize_database(): + db = Database() + with db: + db.query_from_file(Path(__file__).parents[1] / "database" / "database_init.sql") + yield db + db.delete_db() + +def test_signup(initialize_database): + response = client.post("/user/signup", json={"email": "test@example.com", "password": "testpassword"}) + assert response.status_code == 200 + assert response.json()["email"] == "test@example.com" + + response = client.post("/user/signup", json={"email": "test@example.com", "password": "testpassword"}) + assert response.status_code == 400 + assert "detail" in response.json() + assert response.json()["detail"] == "User test@example.com already registered" + +def test_login(initialize_database): response = client.post("/user/signup", json={"email": "test@example.com", "password": "testpassword"}) assert response.status_code == 200 assert response.json()["email"] == "test@example.com" + response = client.post("/user/login", data={"username": "test@example.com", "password": "testpassword"}) + assert response.status_code == 200 + assert "access_token" in response.json() -def test_login(): +def test_user_me(initialize_database): + response = client.post("/user/signup", json={"email": "test@example.com", "password": "testpassword"}) + assert response.status_code == 200 + assert response.json()["email"] == "test@example.com" + response = client.post("/user/login", data={"username": "test@example.com", "password": "testpassword"}) assert response.status_code == 200 assert "access_token" in response.json() -def test_user_me(): - login_response = client.post("/user/login", data={"username": "test@example.com", "password": "testpassword"}) - token = login_response.json()["access_token"] + token = response.json()["access_token"] response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"}) assert response.status_code == 200 assert response.json()["email"] == "test@example.com" diff --git a/user_management.py b/user_management.py new file mode 100644 index 0000000..0574614 --- /dev/null +++ b/user_management.py @@ -0,0 +1,50 @@ +from datetime import timedelta, datetime +import os +from pydantic import BaseModel +from jose import jwt + + +from database.database import Database + +SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key") +ALGORITHM = "HS256" + +class User(BaseModel): + email: str = None + password: str = None + +def create_user(user: User): + with Database() as connection: + connection.query("INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password)) + +def user_exists(email: str) -> bool: + with Database() as connection: + result = connection.query("SELECT 1 FROM user WHERE email = ?", (email,))[0] + return bool(result) + +def get_user(email: str): + with Database() as connection: + user_row = connection.query("SELECT * FROM user WHERE email = ?", (email,))[0] + for row in user_row: + return User(**row) + raise Exception("User not found") + +def delete_user(email: str): + with Database() as connection: + connection.query("DELETE FROM user WHERE email = ?", (email,)) + +def authenticate_user(username: str, password: str): + user = get_user(username) + if not user or not password == user.password: + return False + return user + +def create_access_token(*, data: dict, expires_delta: timedelta = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt