From bf6e156adb0e5359db2f84f32731b9db06cd724f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Cortenraede?= Date: Wed, 19 Jun 2024 14:09:22 +0200 Subject: [PATCH] Added database table creation --- .../api/db/connector/postgres_connection.py | 1 - src/api/api/db/data/streams.csv | 2 +- src/api/api/db/postgres.py | 22 ------ src/api/api/main.py | 25 ++++++ src/api/api/models/animal.py | 2 +- src/api/api/models/base.py | 5 ++ src/api/api/models/country.py | 2 +- src/api/api/models/stream.py | 2 +- src/api/api/models/streams_animals.py | 4 +- src/api/api/routers/v1/internal.py | 77 ++++++++++++++++++- src/api/api/routers/v1/internal_streams.py | 41 ++++++++++ 11 files changed, 152 insertions(+), 31 deletions(-) delete mode 100644 src/api/api/db/postgres.py create mode 100644 src/api/api/models/base.py create mode 100644 src/api/api/routers/v1/internal_streams.py diff --git a/src/api/api/db/connector/postgres_connection.py b/src/api/api/db/connector/postgres_connection.py index 060ec5d..4c00629 100644 --- a/src/api/api/db/connector/postgres_connection.py +++ b/src/api/api/db/connector/postgres_connection.py @@ -12,7 +12,6 @@ from core.config import settings -from db.postgres import create_tables # @asynccontextmanager diff --git a/src/api/api/db/data/streams.csv b/src/api/api/db/data/streams.csv index e66f348..d0b403c 100644 --- a/src/api/api/db/data/streams.csv +++ b/src/api/api/db/data/streams.csv @@ -90,7 +90,7 @@ Görbeháza,https://youtu.be/YfqeaUnj_EY,Other,Hungary,Hajdú-Bihar,47.820573748 Somogy,https://youtu.be/jhA9eugQJoo,Other,Hungary,Somogy,46.441561975993835, 17.576754124448122 Hai-Bar Nature Reserve,https://youtu.be/3Cq9kfMqXu4,Other,Israel,Haifa,32.75344231467224,35.01646219959212 Gamla Nature Reserve,https://youtu.be/8mi2qdmUVmI,Other,Israel,Golan Heights,32.90383296931956,35.75227027761189 -Hula Valley,https://youtu.be/h4OHj17aPck,Other,Isreal,Hula Valley,33.111275533691966,35.60408346850616 +Hula Valley,https://youtu.be/h4OHj17aPck,Other,Israel,Hula Valley,33.111275533691966,35.60408346850616 Makov,https://youtu.be/S46DdA8Mc4I,Other,Czechia,Pardubice,49.85484207606672,16.193971454841716 Mississippi River,https://youtu.be/Hkj9L-HKXJU,Other,United States of America,Wisconson,43.933960379423326,-91.36723896289995 Wellington,https://youtu.be/L9Qs9kuTA10,Other,New Zealand,Wellington,-41.26555225076079,174.74304855496442 diff --git a/src/api/api/db/postgres.py b/src/api/api/db/postgres.py deleted file mode 100644 index aa10c82..0000000 --- a/src/api/api/db/postgres.py +++ /dev/null @@ -1,22 +0,0 @@ -from sqlalchemy.orm import DeclarativeBase - - -class Base(DeclarativeBase): - pass - - -async def create_tables(engine): - # Import models. - import models.country - import models.stream - import models.animal - import models.streams_animals - - # Import seeders. - import db.seeders.country_seeder - import db.seeders.stream_tag_seeder - import db.seeders.stream_seeder - - # Create tables. - async with engine.begin() as connection: - await connection.run_sync(Base.metadata.create_all) diff --git a/src/api/api/main.py b/src/api/api/main.py index ccb7189..c433471 100644 --- a/src/api/api/main.py +++ b/src/api/api/main.py @@ -1,3 +1,4 @@ +from pathlib import Path import subprocess import litestar.cli.commands.core @@ -29,6 +30,24 @@ db_config = postgres_connection() +async def init_db(app: Litestar) -> None: + from models.base import Base + + # Import models. + import models.country + import models.stream + import models.animal + import models.streams_animals + + # Import seeders. + import db.seeders.country_seeder + import db.seeders.stream_tag_seeder + import db.seeders.stream_seeder + + async with app.state.db_engine.begin() as connection: + await connection.run_sync(Base.metadata.create_all) + + def create_app() -> Litestar: # Setup Litestar application and return this return Litestar( @@ -59,6 +78,9 @@ def create_app() -> Litestar: # postgres_connection, redis_connection, ], + on_startup=[ + init_db, + ], ) @@ -86,6 +108,7 @@ def create_app_private() -> Litestar: ), plugins=[ StructlogPlugin(StructlogConfig(config)), + # SQLAlchemyPlugin(config=db_config), ], lifespan=[ postgres_connection, @@ -96,10 +119,12 @@ def create_app_private() -> Litestar: app = create_app() app_private = create_app_private() + if __name__ == "__main__": # Run the API (for debugging) # uvicorn.run("main:app", reload=True, reload_dirs="./", port=8002) + subprocess.Popen( [ "litestar", diff --git a/src/api/api/models/animal.py b/src/api/api/models/animal.py index b58345c..ffab2c1 100644 --- a/src/api/api/models/animal.py +++ b/src/api/api/models/animal.py @@ -5,7 +5,7 @@ mapped_column, relationship, ) -from db.postgres import Base +from models.base import Base from models.stream import Stream diff --git a/src/api/api/models/base.py b/src/api/api/models/base.py new file mode 100644 index 0000000..1c2dcc4 --- /dev/null +++ b/src/api/api/models/base.py @@ -0,0 +1,5 @@ +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): + pass \ No newline at end of file diff --git a/src/api/api/models/country.py b/src/api/api/models/country.py index f1ca7ba..c51ae8b 100644 --- a/src/api/api/models/country.py +++ b/src/api/api/models/country.py @@ -5,7 +5,7 @@ mapped_column, relationship, ) -from db.postgres import Base +from models.base import Base from models.stream import Stream diff --git a/src/api/api/models/stream.py b/src/api/api/models/stream.py index 89cc6d8..6f2e5ba 100644 --- a/src/api/api/models/stream.py +++ b/src/api/api/models/stream.py @@ -5,7 +5,7 @@ mapped_column, relationship, ) -from db.postgres import Base +from models.base import Base class StreamTag(Base): diff --git a/src/api/api/models/streams_animals.py b/src/api/api/models/streams_animals.py index 51648d9..d6264f6 100644 --- a/src/api/api/models/streams_animals.py +++ b/src/api/api/models/streams_animals.py @@ -1,8 +1,8 @@ -from db.postgres import Base +from models.base import Base from sqlalchemy import Table, Column, ForeignKey from sqlalchemy.types import Integer -Table( +streams_animals = Table( "streams_animals", Base.metadata, Column("stream_id", ForeignKey("streams.id"), primary_key=True), diff --git a/src/api/api/routers/v1/internal.py b/src/api/api/routers/v1/internal.py index 7be9c9f..85f61be 100644 --- a/src/api/api/routers/v1/internal.py +++ b/src/api/api/routers/v1/internal.py @@ -1,8 +1,23 @@ -from litestar import Controller, get, Request +from dataclasses import dataclass +from typing import Annotated +from litestar import Controller, get, Request, post, Response, MediaType from litestar.exceptions import * +from litestar.enums import RequestEncodingType +from litestar.params import Body +from sqlalchemy.ext.asyncio import AsyncSession -from modules.weather_information import get_weather_information +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy import select +from models.animal import Animal +from models.stream import Stream from litestar.datastructures import State +from models.streams_animals import streams_animals + + +@dataclass +class AnimalItem: + animal: str + count: int # TODO: exclude from schemas @@ -10,3 +25,61 @@ class internalController(Controller): path = "/internal" tags = ["internal"] + + @get("/streams") + async def get_streams( + self, session: AsyncSession + ) -> list[Stream]: + pass + + @post("/stream_animals") + async def store_stream_animals( + self, session: AsyncSession, stream_id: int, data: Annotated[list[AnimalItem], Body()] + ) -> Response: + # Check if provided stream_id is valid. + if not await session.scalars(select(Stream.id).filter_by(id=stream_id)).first(): + return Response( + media_type=MediaType.TEXT, + content="Provided stream id is not valid.", + status_code=422, + ) + + # Save animals to provided stream_id. + for animal in data: + animal_name, animal_count = animal.animal, animal.count + + # Check if animal already exists in database, if not create animal. + animal_id = await session.scalars(select(Animal.id).filter_by(name=animal_name)).first() + + # If no animal exists, create one. + if not animal_id: + # Create animal object, get extra information from external API. + animal_db = Animal( + name=animal_name, + ) + session.add(animal_db) + + # Get id of newly created animal. + session.flush() + animal_id = animal_db.id + + # Link animal to stream_id. + stmt = insert(streams_animals).values( + stream_id=stream_id, + animal_id=animal_id, + count=animal_count, + ) + stmt = stmt.on_conflict_do_update( + index_elements=["stream_id", "animal_id"], + set_={"count": streams_animals.c.count + animal_count}, + ) + await session.execute(stmt) + + # Save all changes to database. + await session.commit() + + return Response( + media_type=MediaType.TEXT, + content="Successfully saved provided animals to stream.", + status_code=201, + ) diff --git a/src/api/api/routers/v1/internal_streams.py b/src/api/api/routers/v1/internal_streams.py new file mode 100644 index 0000000..fdf61d6 --- /dev/null +++ b/src/api/api/routers/v1/internal_streams.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass +from typing import Annotated +from litestar import Controller, get, Request, post, Response, MediaType +from litestar.exceptions import * +from litestar.enums import RequestEncodingType +from litestar.params import Body +from sqlalchemy.ext.asyncio import AsyncSession + +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy import select +from models.animal import Animal +from models.stream import Stream +from litestar.datastructures import State +from models.streams_animals import streams_animals +from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository +from litestar.di import Provide + + +class StreamRepository(SQLAlchemyAsyncRepository[Stream]): + model_type = Stream + + +async def provide_streams_repository(session: AsyncSession) -> StreamRepository: + return StreamRepository(session=session) + + +# TODO: exclude from schemas +# Controller for internal endpoints +class internalController(Controller): + path = "/internal" + tags = ["internal-streams"] + + dependencies = {"streams_repository": Provide(provide_streams_repository)} + + @get("/streams") + async def get_streams(self, stream_repository: StreamRepository) -> list[Stream]: + return await stream_repository.list() + + @get("/streams/{stream_id}") + async def get_stream(self, stream_repository: StreamRepository, stream_id: int) -> Stream: + return await stream_repository.get(item_id=stream_id, load=[Stream.tag, Stream.country, Stream.animals])