From bc1f241b1c36ec8f2c0483664c0b03a245ace847 Mon Sep 17 00:00:00 2001 From: CedricCortenraede Date: Wed, 19 Jun 2024 12:10:13 +0000 Subject: [PATCH] Ran Black formatted on python code in ./src/ There appeared to be some python formatting in bf6e156adb0e5359db2f84f32731b9db06cd724f that did not conform with Black's formatting standards. So Black(https://github.com/psf/black) formatter was used to fix these issues. --- src/api/api/main.py | 8 +++--- src/api/api/models/base.py | 2 +- src/api/api/routers/v1/internal.py | 31 ++++++++++++---------- src/api/api/routers/v1/internal_streams.py | 18 ++++++++----- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/src/api/api/main.py b/src/api/api/main.py index c433471..1d33196 100644 --- a/src/api/api/main.py +++ b/src/api/api/main.py @@ -32,7 +32,7 @@ async def init_db(app: Litestar) -> None: from models.base import Base - + # Import models. import models.country import models.stream @@ -43,7 +43,7 @@ async def init_db(app: Litestar) -> None: import db.seeders.country_seeder import db.seeders.stream_tag_seeder import db.seeders.stream_seeder - + async with app.state.db_engine.begin() as connection: await connection.run_sync(Base.metadata.create_all) @@ -119,12 +119,12 @@ def create_app_private() -> Litestar: app = create_app() app_private = create_app_private() - + if __name__ == "__main__": # Run the API (for debugging) # uvicorn.run("main:app", reload=True, reload_dirs="./", port=8002) - + subprocess.Popen( [ "litestar", diff --git a/src/api/api/models/base.py b/src/api/api/models/base.py index 1c2dcc4..fa2b68a 100644 --- a/src/api/api/models/base.py +++ b/src/api/api/models/base.py @@ -2,4 +2,4 @@ class Base(DeclarativeBase): - pass \ No newline at end of file + pass diff --git a/src/api/api/routers/v1/internal.py b/src/api/api/routers/v1/internal.py index 85f61be..b271ad1 100644 --- a/src/api/api/routers/v1/internal.py +++ b/src/api/api/routers/v1/internal.py @@ -25,16 +25,17 @@ class AnimalItem: class internalController(Controller): path = "/internal" tags = ["internal"] - + @get("/streams") - async def get_streams( - self, session: AsyncSession - ) -> list[Stream]: + async def get_streams(self, session: AsyncSession) -> list[Stream]: pass - + @post("/stream_animals") async def store_stream_animals( - self, session: AsyncSession, stream_id: int, data: Annotated[list[AnimalItem], Body()] + self, + session: AsyncSession, + stream_id: int, + data: Annotated[list[AnimalItem], Body()], ) -> Response: # Check if provided stream_id is valid. if not await session.scalars(select(Stream.id).filter_by(id=stream_id)).first(): @@ -43,14 +44,16 @@ async def store_stream_animals( content="Provided stream id is not valid.", status_code=422, ) - + # Save animals to provided stream_id. for animal in data: animal_name, animal_count = animal.animal, animal.count - + # Check if animal already exists in database, if not create animal. - animal_id = await session.scalars(select(Animal.id).filter_by(name=animal_name)).first() - + animal_id = await session.scalars( + select(Animal.id).filter_by(name=animal_name) + ).first() + # If no animal exists, create one. if not animal_id: # Create animal object, get extra information from external API. @@ -58,11 +61,11 @@ async def store_stream_animals( name=animal_name, ) session.add(animal_db) - + # Get id of newly created animal. session.flush() animal_id = animal_db.id - + # Link animal to stream_id. stmt = insert(streams_animals).values( stream_id=stream_id, @@ -74,10 +77,10 @@ async def store_stream_animals( set_={"count": streams_animals.c.count + animal_count}, ) await session.execute(stmt) - + # Save all changes to database. await session.commit() - + return Response( media_type=MediaType.TEXT, content="Successfully saved provided animals to stream.", diff --git a/src/api/api/routers/v1/internal_streams.py b/src/api/api/routers/v1/internal_streams.py index fdf61d6..2b49c78 100644 --- a/src/api/api/routers/v1/internal_streams.py +++ b/src/api/api/routers/v1/internal_streams.py @@ -18,8 +18,8 @@ class StreamRepository(SQLAlchemyAsyncRepository[Stream]): model_type = Stream - - + + async def provide_streams_repository(session: AsyncSession) -> StreamRepository: return StreamRepository(session=session) @@ -29,13 +29,17 @@ async def provide_streams_repository(session: AsyncSession) -> StreamRepository: class internalController(Controller): path = "/internal" tags = ["internal-streams"] - + dependencies = {"streams_repository": Provide(provide_streams_repository)} - + @get("/streams") async def get_streams(self, stream_repository: StreamRepository) -> list[Stream]: return await stream_repository.list() - + @get("/streams/{stream_id}") - async def get_stream(self, stream_repository: StreamRepository, stream_id: int) -> Stream: - return await stream_repository.get(item_id=stream_id, load=[Stream.tag, Stream.country, Stream.animals]) + async def get_stream( + self, stream_repository: StreamRepository, stream_id: int + ) -> Stream: + return await stream_repository.get( + item_id=stream_id, load=[Stream.tag, Stream.country, Stream.animals] + )