Skip to content

Commit

Permalink
chore: add repositories
Browse files Browse the repository at this point in the history
  • Loading branch information
dantetemplar committed Oct 30, 2023
1 parent d11ffee commit 615c5d5
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 40 deletions.
4 changes: 4 additions & 0 deletions src/repositories/some_module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
__all__ = ["AbstractSomeModuleRepository", "SomeModuleRepository"]

from src.repositories.some_module.abc import AbstractSomeModuleRepository
from src.repositories.some_module.repository import SomeModuleRepository
37 changes: 37 additions & 0 deletions src/repositories/some_module/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
__all__ = ["AbstractSomeModuleRepository"]

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
# safe import during type checking stage;
# we can do it as abstraction layer doesn't use any scheme
from src.schemas import CreateSomeScheme, ViewSomeScheme, UpdateSomeScheme


class AbstractSomeModuleRepository(metaclass=ABCMeta):
# ----------------- CRUD ----------------- #
@abstractmethod
async def create(self, data: "CreateSomeScheme") -> "ViewSomeScheme":
...

@abstractmethod
async def batch_create(self, data: list["CreateSomeScheme"]) -> list["ViewSomeScheme"]:

Check failure on line 19 in src/repositories/some_module/abc.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

src/repositories/some_module/abc.py:19:89: E501 Line too long (91 > 88 characters)
...

@abstractmethod
async def read(self, id: int) -> "ViewSomeScheme":
...

@abstractmethod
async def read_all(self) -> list["ViewSomeScheme"]:
...

@abstractmethod
async def batch_read(self, ids: list[int]) -> list["ViewSomeScheme"]:
...

@abstractmethod
async def update(self, id_: int, data: "UpdateSomeScheme"):
...
# ^^^^^^^^^^^^^^^^^ CRUD ^^^^^^^^^^^^^^^^^ #
103 changes: 103 additions & 0 deletions src/repositories/some_module/repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
__all__ = ["SomeModuleRepository"]

from typing import Optional

from sqlalchemy import select, insert, or_, update
from sqlalchemy.dialects.postgresql import insert as postgres_insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.base import ExecutableOption

from src.repositories.some_module.abc import AbstractSomeModuleRepository
from src.schemas import ViewSomeScheme, CreateSomeScheme, UpdateSomeScheme
from src.storages.sqlalchemy import AbstractSQLAlchemyStorage
from src.storages.sqlalchemy.models import SomeModel

# get_options = (selectinload(Model.some_field_raising_greenlet), )
get_options: tuple[ExecutableOption, ...] = ()


class SomeModuleRepository(AbstractSomeModuleRepository):
storage: AbstractSQLAlchemyStorage

def __init__(self, storage: AbstractSQLAlchemyStorage):
self.storage = storage

def _create_session(self) -> AsyncSession:
return self.storage.create_session()

# ----------------- CRUD ----------------- #

async def create(self, data: CreateSomeScheme) -> ViewSomeScheme:
async with self._create_session() as session:
_insert_query = insert(SomeModel).returning(SomeModel)
if get_options:
_insert_query = _insert_query.options(*get_options)
obj = await session.scalar(_insert_query, params=data.model_dump())
await session.commit()
return ViewSomeScheme.model_validate(obj)

async def batch_create(self, data: list[CreateSomeScheme]) -> list[ViewSomeScheme]:
async with self._create_session() as session:
if not data:
return []
_insert_query = insert(SomeModel).returning(SomeModel)
if get_options:

Check failure on line 44 in src/repositories/some_module/repository.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

src/repositories/some_module/repository.py:44:89: E501 Line too long (98 > 88 characters)
_insert_query = _insert_query.options(*get_options)
objs = await session.scalars(_insert_query, params=[obj.model_dump() for obj in data])
await session.commit()
return [ViewSomeScheme.model_validate(obj) for obj in objs]

async def read(self, id: int) -> ViewSomeScheme:
async with self._create_session() as session:
q = select(SomeModel).where(SomeModel.id == id)
if get_options:
q = q.options(*get_options)
obj = await session.scalar(q)

if obj:
return ViewSomeScheme.model_validate(obj)

async def read_all(self) -> list[ViewSomeScheme]:
async with self._create_session() as session:
q = select(SomeModel)
if get_options:
q = q.options(*get_options)
objs = await session.scalars(q)
return [ViewSomeScheme.model_validate(obj) for obj in objs]

async def batch_read(self, ids: list[int]) -> list[ViewSomeScheme]:
async with self._create_session() as session:
if not ids:
return []

q = select(SomeModel).where(
or_(
*[SomeModel.id == id for id in ids],
)
)

if get_options:
q = q.options(*get_options)
objs = await session.scalars(q)

return [ViewSomeScheme.model_validate(obj) for obj in objs]

async def update(self, id_: int, data: UpdateSomeScheme):
async with self._create_session() as session:
q = (
update(SomeModel)
.where(SomeModel.id == id_)
.values(**data.model_dump())
.returning(SomeModel)
)

if get_options:
q = q.options(*get_options)

obj = await session.scalar(q)
await session.commit()

if obj:
return ViewSomeScheme.model_validate(obj)

# ^^^^^^^^^^^^^^^^^ CRUD ^^^^^^^^^^^^^^^^^ #
10 changes: 5 additions & 5 deletions src/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# fmt: off
__all__ = [
"CreateSomeSchemeInSingle", "ViewSomeSchemeInSingle", "UpdateSomeSchemeInSingle",
"CreateSomeScheme", "ViewSomeScheme", "UpdateSomeScheme",
]

# fmt: on

from src.schemas.some_scheme_in_plural import (
CreateSomeSchemeInSingle,
ViewSomeSchemeInSingle,
UpdateSomeSchemeInSingle,
from src.schemas.some_scheme import (
CreateSomeScheme,
ViewSomeScheme,
UpdateSomeScheme,
)
25 changes: 25 additions & 0 deletions src/schemas/some_scheme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
__all__ = [
"CreateSomeScheme",
"ViewSomeScheme",
"UpdateSomeScheme",
]

from pydantic import BaseModel, ConfigDict


class CreateSomeScheme(BaseModel):
...


class ViewSomeScheme(BaseModel):
model_config = ConfigDict(
from_attributes=True,
)


class UpdateSomeScheme(BaseModel):
...

# Note: if some relation is needed, add it here
# from src.schemas.some_scheme2_in_plural import ViewSomeScheme2InPlural
# ViewSomeSchemeInSingle.model_rebuild()
24 changes: 0 additions & 24 deletions src/schemas/some_scheme_in_plural.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/storages/sqlalchemy/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
import src.storages.sql.models.__mixin__ # noqa: F401, E402

# Add all models here
from src.storages.sqlalchemy.models.some_model_in_plural import SomeModelInSingle
from src.storages.sqlalchemy.models.some_model import SomeModel

__all__ = ["Base", "SomeModelInSingle"]
__all__ = ["Base", "SomeModel"]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["SomeModelInSingle", "SomeModelInSingleXTag"]
__all__ = ["SomeModel", "SomeModelInSingleXTag"]

from typing import TYPE_CHECKING

Expand All @@ -20,29 +20,29 @@ class Tag:
pass


class SomeModelInSingle(
class SomeModel(
Base,
IdMixin,
):
# - Meta
__tablename__ = "some_model_in_plural"
__tablename__ = "some_model"
# - Fields
alias: Mapped[str] = mapped_column(String(255), unique=True)
# - - Relationships
# one-to-many
user_id: Mapped[int] = mapped_column(
ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
)
user: Mapped["User"] = relationship("User", back_populates="some_model_in_single")
user: Mapped["User"] = relationship("User", back_populates="some_model")
# many-to-many
tags: Mapped[list["Tag"]] = relationship(
"Tag", secondary="some_model_in_single_x_tag"
"Tag", secondary="some_model_x_tag"
)


class SomeModelInSingleXTag(Base):
# - Meta
__tablename__ = "some_model_in_single_x_tag"
__tablename__ = "some_model_x_tag"
# - Fields
some_model_in_single_id: Mapped[int] = mapped_column(
ForeignKey("some_model_in_plural.id", ondelete="CASCADE"), primary_key=True
Expand All @@ -51,7 +51,7 @@ class SomeModelInSingleXTag(Base):
ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True
)
# - - Relationships
some_model_in_single: Mapped[SomeModelInSingle] = relationship(
"SomeModelInSingle", back_populates="some_model_in_single_x_tags"
some_model_in_single: Mapped["SomeModel"] = relationship(
"SomeModel", back_populates="some_model_x_tag"
)
tag: Mapped[Tag] = relationship("Tag", back_populates="some_model_in_single_x_tags")
tag: Mapped["Tag"] = relationship("Tag", back_populates="some_model_x_tag")

0 comments on commit 615c5d5

Please sign in to comment.