Skip to content

Commit

Permalink
Merge pull request #49 from piercefreeman/feature/group-by-function-support
Browse files Browse the repository at this point in the history

Support function use in group by
  • Loading branch information
piercefreeman authored Dec 24, 2024
2 parents eb99385 + 4ee29b7 commit f3c4e7a
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 28 deletions.
15 changes: 14 additions & 1 deletion iceaxe/__tests__/conf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pyinstrument import Profiler

from iceaxe.base import Field, TableBase
from iceaxe.base import Field, TableBase, UniqueConstraint


class UserDemo(TableBase):
Expand Down Expand Up @@ -103,6 +103,19 @@ class DemoModelB(TableBase):
code: str = Field(unique=True)


class JsonDemo(TableBase):
    """
    Model for testing JSON field updates.

    Exercises `is_json` columns through insert/update/upsert round trips.
    """

    # Surrogate primary key; stays None until the database assigns it on insert.
    id: int | None = Field(primary_key=True, default=None)
    # Required JSON column (is_json marks it for JSON (de)serialization).
    settings: dict[Any, Any] = Field(is_json=True)
    # Optional JSON column; may be NULL in the database.
    metadata: dict[Any, Any] | None = Field(is_json=True)
    # Plain text column; uniqueness lets upsert tests use it as the conflict target.
    unique_val: str

    table_args = [UniqueConstraint(columns=["unique_val"])]


@contextmanager
def run_profile(request):
TESTS_ROOT = Path.cwd()
Expand Down
17 changes: 17 additions & 0 deletions iceaxe/__tests__/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,20 @@ def test_multiple_group_by():
'GROUP BY "employee"."department", "employee"."last_name"',
[],
)


def test_group_by_with_function():
    """GROUP BY must accept function expressions, not just plain columns."""
    month_bucket = func.date_trunc("month", FunctionDemoModel.created_at)
    row_count = func.count(FunctionDemoModel.id)

    query = (
        QueryBuilder()
        .select((month_bucket, row_count))
        # Build the grouping expression fresh rather than reusing the selected
        # (aliased) one, mirroring how callers typically write the query.
        .group_by(func.date_trunc("month", FunctionDemoModel.created_at))
    )

    expected_sql = (
        'SELECT date_trunc(\'month\', "functiondemomodel"."created_at")'
        " AS aggregate_0,"
        ' count("functiondemomodel"."id") AS aggregate_1'
        ' FROM "functiondemomodel"'
        ' GROUP BY date_trunc(\'month\', "functiondemomodel"."created_at")'
    )
    assert query.build() == (expected_sql, [])
84 changes: 84 additions & 0 deletions iceaxe/__tests__/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ComplexDemo,
DemoModelA,
DemoModelB,
JsonDemo,
UserDemo,
)
from iceaxe.alias_values import alias
Expand Down Expand Up @@ -1028,3 +1029,86 @@ async def test_select_with_order_by_func_count(db_connection: DBConnection):
assert result[1] == ("Jane", 1)
# Bob has 0 posts
assert result[2] == ("Bob", 0)


@pytest.mark.asyncio
async def test_json_update(db_connection: DBConnection):
    """
    Test that JSON fields are correctly serialized during updates.

    Inserts a row with dict-valued `is_json` columns, mutates them in memory,
    calls update(), and re-selects to confirm the new values round-trip.
    """
    # Create the table first
    # Drop any leftover table so create_all builds a fresh schema for JsonDemo.
    await db_connection.conn.execute("DROP TABLE IF EXISTS jsondemo")
    await create_all(db_connection, [JsonDemo])

    # Create initial object with JSON data
    demo = JsonDemo(
        settings={"theme": "dark", "notifications": True},
        metadata={"version": 1},
        unique_val="1",
    )
    await db_connection.insert([demo])

    # Update JSON fields
    # Reassigning the attributes marks them modified; update() should serialize
    # the dicts to JSON before sending them to the driver.
    demo.settings = {"theme": "light", "notifications": False}
    demo.metadata = {"version": 2, "last_updated": "2024-01-01"}
    await db_connection.update([demo])

    # Verify the update through a fresh select
    # (guards against asserting on the in-memory object rather than the DB state)
    result = await db_connection.exec(
        QueryBuilder().select(JsonDemo).where(JsonDemo.id == demo.id)
    )
    assert len(result) == 1
    assert result[0].settings == {"theme": "light", "notifications": False}
    assert result[0].metadata == {"version": 2, "last_updated": "2024-01-01"}


@pytest.mark.asyncio
async def test_json_upsert(db_connection: DBConnection):
    """
    Test that JSON fields are correctly serialized during upsert operations.

    Covers both upsert paths: the initial insert, and the conflict-triggered
    update keyed on the unique `unique_val` column. Also checks that JSON
    values in RETURNING rows come back deserialized as dicts.
    """
    # Create the table first
    # Drop any leftover table so create_all builds a fresh schema for JsonDemo.
    await db_connection.conn.execute("DROP TABLE IF EXISTS jsondemo")
    await create_all(db_connection, [JsonDemo])

    # Initial insert via upsert
    demo = JsonDemo(
        settings={"theme": "dark", "notifications": True},
        metadata={"version": 1},
        unique_val="1",
    )
    result = await db_connection.upsert(
        [demo],
        conflict_fields=(JsonDemo.unique_val,),
        update_fields=(JsonDemo.metadata,),
        returning_fields=(JsonDemo.unique_val, JsonDemo.metadata),
    )

    # RETURNING should yield one tuple per object, with the JSON column
    # already decoded to a dict (not a raw JSON string).
    assert result is not None
    assert len(result) == 1
    assert result[0][0] == "1"
    assert result[0][1] == {"version": 1}

    # Update via upsert
    demo2 = JsonDemo(
        settings={"theme": "dark", "notifications": True},
        metadata={"version": 2, "last_updated": "2024-01-01"},  # New metadata
        unique_val="1",  # Same value to trigger update
    )
    result = await db_connection.upsert(
        [demo2],
        conflict_fields=(JsonDemo.unique_val,),
        update_fields=(JsonDemo.metadata,),
        returning_fields=(JsonDemo.unique_val, JsonDemo.metadata),
    )

    assert result is not None
    assert len(result) == 1
    assert result[0][0] == "1"
    assert result[0][1] == {"version": 2, "last_updated": "2024-01-01"}

    # Verify through a fresh select
    # Only metadata was in update_fields, so settings keeps its original value.
    result = await db_connection.exec(QueryBuilder().select(JsonDemo))
    assert len(result) == 1
    assert result[0].settings == {"theme": "dark", "notifications": True}
    assert result[0].metadata == {"version": 2, "last_updated": "2024-01-01"}
8 changes: 8 additions & 0 deletions iceaxe/mountaineer/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic_settings import BaseSettings

from iceaxe.modifications import MODIFICATION_TRACKER_VERBOSITY


class DatabaseConfig(BaseSettings):
"""
Expand Down Expand Up @@ -36,3 +38,9 @@ class DatabaseConfig(BaseSettings):
The port number where PostgreSQL server is listening.
Defaults to the standard PostgreSQL port 5432 if not specified.
"""

ICEAXE_UNCOMMITTED_VERBOSITY: MODIFICATION_TRACKER_VERBOSITY | None = None
"""
The verbosity level for uncommitted modifications.
If set to None, uncommitted modifications will not be tracked.
"""
4 changes: 3 additions & 1 deletion iceaxe/mountaineer/dependencies/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ async def get_users(db: DBConnection = Depends(get_db_connection)):
password=config.POSTGRES_PASSWORD,
database=config.POSTGRES_DB,
)
connection = DBConnection(conn)
connection = DBConnection(
conn, uncommitted_verbosity=config.ICEAXE_UNCOMMITTED_VERBOSITY
)
try:
yield connection
finally:
Expand Down
17 changes: 11 additions & 6 deletions iceaxe/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(self):
self._join_clauses: list[str] = []
self._limit_value: int | None = None
self._offset_value: int | None = None
self._group_by_fields: list[DBFieldClassDefinition] = []
self._group_by_clauses: list[str] = []
self._having_conditions: list[FieldComparison] = []
self._distinct_on_fields: list[QueryElementBase] = []
self._for_update_config: ForUpdateConfig = ForUpdateConfig()
Expand Down Expand Up @@ -771,9 +771,14 @@ def group_by(self, *fields: Any):
"""

for field in fields:
if not is_column(field):
raise ValueError(f"Invalid field for group by: {field}")
self._group_by_fields.append(field)
if is_column(field):
field_token, _ = field.to_query()
elif is_function_metadata(field):
field_token = field.literal
else:
raise ValueError(f"Invalid group by field: {field}")

self._group_by_clauses.append(str(field_token))

return self

Expand Down Expand Up @@ -987,9 +992,9 @@ def build(self) -> tuple[str, list[Any]]:
query += f" WHERE {comparison_literal}"
variables += comparison_variables

if self._group_by_fields:
if self._group_by_clauses:
query += " GROUP BY "
query += ", ".join(str(sql(field)) for field in self._group_by_fields)
query += ", ".join(str(field) for field in self._group_by_clauses)

if self._having_conditions:
query += " HAVING "
Expand Down
64 changes: 44 additions & 20 deletions iceaxe/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
from contextlib import asynccontextmanager
from json import loads as json_loads
from typing import (
Any,
Literal,
Expand All @@ -14,7 +15,7 @@
import asyncpg
from typing_extensions import TypeVarTuple

from iceaxe.base import TableBase
from iceaxe.base import DBFieldClassDefinition, TableBase
from iceaxe.logging import LOGGER
from iceaxe.modifications import ModificationTracker
from iceaxe.queries import (
Expand Down Expand Up @@ -81,6 +82,7 @@ class User(TableBase):
def __init__(
self,
conn: asyncpg.Connection,
*,
uncommitted_verbosity: Literal["ERROR", "WARNING", "INFO"] | None = None,
):
"""
Expand Down Expand Up @@ -273,8 +275,8 @@ async def upsert(
*,
conflict_fields: tuple[Any, ...],
update_fields: tuple[Any, ...] | None = None,
returning_fields: tuple[T, *Ts],
) -> list[tuple[T, *Ts]]: ...
returning_fields: tuple[T, *Ts] | None = None,
) -> list[tuple[T, *Ts]] | None: ...

@overload
async def upsert(
Expand Down Expand Up @@ -332,13 +334,26 @@ async def upsert(
return None

# Evaluate column types
conflict_fields_cols = [field for field in conflict_fields if is_column(field)]
update_fields_cols = [
field for field in update_fields or [] if is_column(field)
]
returning_fields_cols = [
field for field in returning_fields or [] if is_column(field)
]
conflict_fields_cols: list[DBFieldClassDefinition] = []
update_fields_cols: list[DBFieldClassDefinition] = []
returning_fields_cols: list[DBFieldClassDefinition] = []

# Explicitly validate types of all columns
for field in conflict_fields:
if is_column(field):
conflict_fields_cols.append(field)
else:
raise ValueError(f"Field {field} is not a column")
for field in update_fields or []:
if is_column(field):
update_fields_cols.append(field)
else:
raise ValueError(f"Field {field} is not a column")
for field in returning_fields or []:
if is_column(field):
returning_fields_cols.append(field)
else:
raise ValueError(f"Field {field} is not a column")

results: list[tuple[T, *Ts]] = []
async with self._ensure_transaction():
Expand Down Expand Up @@ -387,14 +402,17 @@ async def upsert(
if returning_fields_cols:
result = await self.conn.fetchrow(query, *values)
if result:
results.append(
tuple(
[
result[field.key]
for field in returning_fields_cols
]
)
)
# Process returned values, deserializing JSON if needed
processed_values = []
for field in returning_fields_cols:
value = result[field.key]
if (
value is not None
and field.root_model.model_fields[field.key].is_json
):
value = json_loads(value)
processed_values.append(value)
results.append(tuple(processed_values))
else:
await self.conn.execute(query, *values)

Expand Down Expand Up @@ -441,7 +459,7 @@ async def update(self, objects: Sequence[TableBase]):

for obj in model_objects:
modified_attrs = {
k: v
k: obj.model_fields[k].to_db_value(v)
for k, v in obj.get_modified_attributes().items()
if not obj.model_fields[k].exclude
}
Expand All @@ -455,7 +473,13 @@ async def update(self, objects: Sequence[TableBase]):

query = f"UPDATE {table_name} SET {set_clause} WHERE {primary_key_name} = $1"
values = [getattr(obj, primary_key)] + list(modified_attrs.values())
await self.conn.execute(query, *values)
try:
await self.conn.execute(query, *values)
except Exception as e:
LOGGER.error(
f"Error executing query: {query} with variables: {values}"
)
raise e
obj.clear_modified_attributes()

self.modification_tracker.clear_status(objects)
Expand Down

0 comments on commit f3c4e7a

Please sign in to comment.