Skip to content

Commit

Permalink
Use the per-test table creation session in SQLA_coltypes tests
Browse files Browse the repository at this point in the history
and rename it to temp_session (and temp_engine).
  • Loading branch information
martinburchell committed Feb 26, 2025
1 parent b5bd4db commit ad93c07
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 56 deletions.
50 changes: 25 additions & 25 deletions server/camcops_server/cc_modules/tests/cc_dump_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
)


@pytest.mark.usefixtures("setup_dest_session")
@pytest.mark.usefixtures("setup_temp_session")
class DumpTestCase(DemoRequestTestCase):
pass

Expand All @@ -67,7 +67,7 @@ def test_copies_column_comments(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(patient)
Expand All @@ -80,7 +80,7 @@ def test_copies_column_info_for_foreign_key(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(patient)
Expand All @@ -99,7 +99,7 @@ def test_foreign_keys_are_empty_set(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(bmi)
Expand All @@ -112,7 +112,7 @@ def test_tablet_record_includes_summaries(self) -> None:

options = TaskExportOptions(db_include_summaries=True)
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(bmi)
Expand All @@ -130,7 +130,7 @@ def test_has_extra_id_num_columns(self) -> None:

options = TaskExportOptions(db_patient_id_per_row=True)
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(patient)
Expand All @@ -144,7 +144,7 @@ def test_task_descendant_has_extra_task_xref_columns(self) -> None:

options = TaskExportOptions(db_patient_id_per_row=True)
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

single_photo = photo_sequence.photos[0]
Expand All @@ -162,7 +162,7 @@ def test_copies_table_with_subset_of_columns(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

columns = [
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_appends_extra_id_columns(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

est = ExtraSummaryTable(
Expand All @@ -216,7 +216,7 @@ def test_appends_extra_task_xref_columns(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

est = ExtraSummaryTable(
Expand Down Expand Up @@ -248,15 +248,15 @@ def test_task_fields_copied(self) -> None:

copy_tasks_and_summaries(
tasks=[bmi],
dst_engine=self.dest_engine,
dst_session=self.dest_session,
dst_engine=self.temp_engine,
dst_session=self.temp_session,
export_options=export_options,
req=self.req,
)
self.dest_session.commit()
self.temp_session.commit()

query = select(text("*")).select_from(table("bmi"))
result = self.dest_session.execute(query)
result = self.temp_session.execute(query)

row = next(result)

Expand Down Expand Up @@ -288,15 +288,15 @@ def test_summary_fields_copied(self) -> None:

copy_tasks_and_summaries(
tasks=[bmi],
dst_engine=self.dest_engine,
dst_session=self.dest_session,
dst_engine=self.temp_engine,
dst_session=self.temp_session,
export_options=export_options,
req=self.req,
)
self.dest_session.commit()
self.temp_session.commit()

query = select(text("*")).select_from(table("bmi"))
result = self.dest_session.execute(query)
result = self.temp_session.execute(query)

row = next(result)

Expand All @@ -316,13 +316,13 @@ def test_has_extra_id_num_columns(self) -> None:

copy_tasks_and_summaries(
tasks=[bmi],
dst_engine=self.dest_engine,
dst_session=self.dest_session,
dst_engine=self.temp_engine,
dst_session=self.temp_session,
export_options=export_options,
req=self.req,
)
query = select(text("*")).select_from(table("bmi"))
result = self.dest_session.execute(query)
result = self.temp_session.execute(query)

row = next(result)

Expand All @@ -343,13 +343,13 @@ def test_has_extra_task_xref_columns(self) -> None:

copy_tasks_and_summaries(
tasks=[photo_sequence],
dst_engine=self.dest_engine,
dst_session=self.dest_session,
dst_engine=self.temp_engine,
dst_session=self.temp_session,
export_options=export_options,
req=self.req,
)
query = select(text("*")).select_from(table("photosequence_photos"))
result = self.dest_session.execute(query)
result = self.temp_session.execute(query)

row = next(result)

Expand All @@ -368,7 +368,7 @@ def test_omits_irrelevant_columns(self) -> None:
)

controller = DumpController(
self.dest_engine, self.dest_session, options, self.req
self.temp_engine, self.temp_session, options, self.req
)

table_column_names = {}
Expand Down
45 changes: 29 additions & 16 deletions server/camcops_server/cc_modules/tests/cc_sqla_coltypes_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
import pendulum
from pendulum import DateTime as Pendulum, Duration
import phonenumbers
import pytest
from semantic_version import Version
from sqlalchemy import insert
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import func
from sqlalchemy.sql.schema import Column
Expand All @@ -54,14 +56,17 @@
SemanticVersionColType,
unknown_field_to_utcdatetime,
)
from camcops_server.cc_modules.cc_sqlalchemy import Base
from camcops_server.cc_modules.cc_unittest import DemoRequestTestCase


class TestColTypeBase(DeclarativeBase):
pass


# =============================================================================
# Unit testing
# =============================================================================
class TestColType(Base):
class TestColType(TestColTypeBase):
__tablename__ = "test_coltype"

id = Column("id", Integer, primary_key=True)
Expand All @@ -83,7 +88,15 @@ class TestColType(Base):
)


class SqlaColtypesTest(DemoRequestTestCase):
@pytest.mark.usefixtures("setup_temp_session")
class SqlaColtypesTestCase(DemoRequestTestCase):
def setUp(self) -> None:
super().setUp()

TestColType.metadata.create_all(self.temp_engine)


class SqlaColtypesTest(SqlaColtypesTestCase):
def _assert_dt_equal(
self,
a: Union[datetime.datetime, Pendulum],
Expand All @@ -104,7 +117,7 @@ def test_iso_datetime_field(self) -> None:

table = TestColType.__table__

self.dbsession.execute(
self.temp_session.execute(
insert(table).values(
[
{
Expand Down Expand Up @@ -146,7 +159,7 @@ def test_iso_datetime_field(self) -> None:
.order_by(table.c.id)
)

rows = list(self.dbsession.execute(statement).mappings())
rows = list(self.temp_session.execute(statement).mappings())

self._assert_dt_equal(rows[0].dt_local, now)
self._assert_dt_equal(rows[0].dt_utc, now_utc)
Expand All @@ -172,7 +185,7 @@ def test_iso_duration_field(self) -> None:

table = TestColType.__table__

self.dbsession.execute(
self.temp_session.execute(
insert(table).values(
[
{"id": 1, "duration_iso": d1},
Expand All @@ -188,7 +201,7 @@ def test_iso_duration_field(self) -> None:
.order_by(table.c.id)
)

rows = list(self.dbsession.execute(statement).mappings())
rows = list(self.temp_session.execute(statement).mappings())

self._assert_duration_equal(rows[0].duration_iso, d1)
self._assert_duration_equal(rows[1].duration_iso, d2)
Expand All @@ -201,7 +214,7 @@ def test_semantic_version_field(self) -> None:

table = TestColType.__table__

self.dbsession.execute(
self.temp_session.execute(
insert(table).values(
[
{"id": 1, "version": v1},
Expand All @@ -217,7 +230,7 @@ def test_semantic_version_field(self) -> None:
.order_by(table.c.id)
)

rows = list(self.dbsession.execute(statement).mappings())
rows = list(self.temp_session.execute(statement).mappings())

self.assertEqual(rows[0]["version"], v1)
self.assertEqual(rows[1]["version"], v2)
Expand All @@ -232,7 +245,7 @@ def test_phone_number_field(self) -> None:

table = TestColType.__table__

self.dbsession.execute(
self.temp_session.execute(
insert(table).values(
[
{"id": 1, "phone_number": p1},
Expand All @@ -249,15 +262,15 @@ def test_phone_number_field(self) -> None:
.order_by(table.c.id)
)

rows = list(self.dbsession.execute(statement).mappings())
rows = list(self.temp_session.execute(statement).mappings())

self.assertEqual(rows[0]["phone_number"], p1)
self.assertEqual(rows[1]["phone_number"], p2)
self.assertEqual(rows[2]["phone_number"], p3)
self.assertIsNone(rows[3]["phone_number"])


class GenCamcopsColumnsTests(DemoRequestTestCase):
class GenCamcopsColumnsTests(SqlaColtypesTestCase):
def test_returns_camcops_columns(self) -> None:
obj = TestColType(id=1, number_1_to_3=1, flag=True)

Expand All @@ -269,7 +282,7 @@ def test_returns_camcops_columns(self) -> None:
self.assertTrue(column.info.get("is_camcops_column"))


class GenCamcopsBlobColumnsTests(DemoRequestTestCase):
class GenCamcopsBlobColumnsTests(SqlaColtypesTestCase):
def test_returns_camcops_columns(self) -> None:
obj = TestColType(id=1, blob_id=2)

Expand All @@ -281,7 +294,7 @@ def test_returns_camcops_columns(self) -> None:
self.assertTrue(column.info.get("is_blob_id_field"))


class GenColumnsMatchingAttrnamesTests(DemoRequestTestCase):
class GenColumnsMatchingAttrnamesTests(SqlaColtypesTestCase):
def test_returns_matching_columns(self) -> None:
obj = TestColType(id=1, number_1_to_3=1, flag=True)

Expand All @@ -296,7 +309,7 @@ def test_returns_matching_columns(self) -> None:
self.assertEqual(attrnames, [])


class PermittedValueFailureMsgsTests(DemoRequestTestCase):
class PermittedValueFailureMsgsTests(SqlaColtypesTestCase):
def test_returns_failure_messages(self) -> None:
obj = TestColType(id=1, number_1_to_3=123)

Expand All @@ -306,7 +319,7 @@ def test_returns_failure_messages(self) -> None:
self.assertIn("Invalid value", messages[0])


class PermittedValuesOkTests(DemoRequestTestCase):
class PermittedValuesOkTests(SqlaColtypesTestCase):
def test_returns_false_if_not_ok(self) -> None:
obj = TestColType(id=1, number_1_to_3=123)

Expand Down
Loading

0 comments on commit ad93c07

Please sign in to comment.