Skip to content

Commit

Permalink
Complete SQLAlchemy 2.0 updates for database export
Browse files Browse the repository at this point in the history
  • Loading branch information
martinburchell committed Feb 21, 2025
1 parent 74e9f42 commit 76b68f9
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 102 deletions.
93 changes: 21 additions & 72 deletions server/camcops_server/cc_modules/cc_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
gen_orm_classes_from_base,
walk_orm_tree,
)
from sqlalchemy import insert, Integer
from sqlalchemy.exc import CompileError
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session as SqlASession
Expand Down Expand Up @@ -143,9 +144,6 @@
# =============================================================================


USE_LEGACY_DUMP_METHOD = 1


class DumpController(object):
"""
A controller class that manages the copying (dumping) of information from
Expand Down Expand Up @@ -328,79 +326,22 @@ def get_dest_table_for_src_object(self, src_obj: object) -> Table:
if tablename in self.dst_tables:
return self.dst_tables[tablename]

if USE_LEGACY_DUMP_METHOD:
dst_table = self.get_legacy_dest_table(src_obj)
else:
dst_table = self.get_new_dest_table(src_obj)

# ... that modifies the metadata, so:
self.dst_tables[tablename] = dst_table
return dst_table

def get_legacy_dest_table(self, src_obj: object) -> Table:
src_table = src_obj.__table__ # type: Table
tablename = src_table.name

# Copy columns, dropping any we don't want, and dropping FK constraints
dst_columns = [] # type: List[Column]
for src_column in src_table.columns:
# log.debug("trying {!r}", src_column.name)
if self._dump_skip_column(tablename, src_column.name):
# log.debug("... skipping {!r}", src_column.name)
continue
# You can't add the source column directly; you get
# "sqlalchemy.exc.ArgumentError: Column object 'ccc' already
# assigned to Table 'ttt'"
copied_column = src_column.copy()
if FOREIGN_KEY_CONSTRAINTS_IN_DUMP:
copied_column.foreign_keys = set(
fk.copy() for fk in src_column.foreign_keys
)
log.warning(
"NOT WORKING: foreign key commands not being " "emitted"
)
# but
# https://docs.sqlalchemy.org/en/latest/core/constraints.html
# works fine under SQLite, even if the other table hasn't been
# created yet. Does the table to which the FK refer have to be
# in the metadata already?
# That's quite possible, but I've not checked.
# Would need to iterate through tables in dependency order,
# like merge_db() does.
else:
# Probably blank already, as the copy() command only copies
# non-constraint-bound ForeignKey objects, but to be sure:
copied_column.foreign_keys = set()
# ... type is: Set[ForeignKey]
# if src_column.foreign_keys:
# log.debug("Column {}, FKs {!r} -> {!r}", src_column.name,
# src_column.foreign_keys,
# copied_column.foreign_keys)
dst_columns.append(copied_column)

dst_columns += self.get_extra_columns(src_obj)

return Table(tablename, self.dst_metadata, *dst_columns)

def get_new_dest_table(self, src_obj: object) -> Table:
src_table = src_obj.__table__ # type: Table
dst_table = src_table.to_metadata(self.dst_metadata)

dst_columns = self.get_extra_columns(src_obj)

for dst_column in dst_columns:
dst_table.append_column(dst_column)

return dst_table
# Copy columns, dropping any we don't want, and dropping FK constraints
changed_columns = [] # type: List[Column]

def get_extra_columns(self, src_obj: object) -> List[Column]:
dst_columns = []
for dst_column in dst_table.columns:
if dst_column.foreign_keys:
changed_columns.append(Column(dst_column.name, Integer))
elif self._dump_skip_column(tablename, dst_column.name):
changed_columns.append(Column(dst_column.name, Integer))

# Add extra columns?
if self.export_options.db_include_summaries:
if isinstance(src_obj, GenericTabletRecordMixin):
for summary_element in src_obj.get_summaries(self.req):
dst_columns.append(
changed_columns.append(
CamcopsColumn(
summary_element.name,
summary_element.coltype,
Expand All @@ -411,11 +352,19 @@ def get_extra_columns(self, src_obj: object) -> List[Column]:
if self.export_options.db_patient_id_in_each_row:
merits, _ = self._merits_extra_id_num_columns(src_obj)
if merits:
dst_columns.extend(all_extra_id_columns(self.req))
changed_columns.extend(all_extra_id_columns(self.req))
if isinstance(src_obj, TaskDescendant):
dst_columns += src_obj.extra_task_xref_columns()
changed_columns += src_obj.extra_task_xref_columns()

return dst_columns
dst_table = Table(
tablename,
self.dst_metadata,
*changed_columns,
extend_existing=True,
)
# ... that modifies the metadata, so:
self.dst_tables[tablename] = dst_table
return dst_table

def get_dest_table_for_est(
self, est: "ExtraSummaryTable", add_extra_id_cols: bool = False
Expand Down Expand Up @@ -529,7 +478,7 @@ def _copy_object_to_dump(self, src_obj: object) -> None:
if isinstance(src_obj, TaskDescendant):
src_obj.add_extra_task_xref_info_to_row(row)
try:
self.dst_session.execute(dst_table.insert(row))
self.dst_session.execute(insert(dst_table).values(row))
except CompileError:
log.critical("\ndst_table:\n{}\nrow:\n{}", dst_table, row)
raise
Expand Down
46 changes: 23 additions & 23 deletions server/camcops_server/cc_modules/tests/cc_dump_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,19 @@
)


class GetDestTableForSrcObjectTests(DemoRequestTestCase):
@pytest.mark.usefixtures("setup_dest_session")
class DumpTestCase(DemoRequestTestCase):
pass


class GetDestTableForSrcObjectTests(DumpTestCase):
def test_copies_column_comments(self) -> None:
patient = PatientFactory()
src_table = patient.__table__

options = TaskExportOptions()
controller = DumpController(
self.engine, self.dbsession, options, self.req
self.dest_engine, self.dest_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(patient)
Expand All @@ -75,7 +80,7 @@ def test_foreign_keys_are_empty_set(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.engine, self.dbsession, options, self.req
self.dest_engine, self.dest_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(bmi)
Expand All @@ -88,7 +93,7 @@ def test_tablet_record_includes_summaries(self) -> None:

options = TaskExportOptions(db_include_summaries=True)
controller = DumpController(
self.engine, self.dbsession, options, self.req
self.dest_engine, self.dest_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(bmi)
Expand All @@ -106,7 +111,7 @@ def test_has_extra_id_num_columns(self) -> None:

options = TaskExportOptions(db_patient_id_per_row=True)
controller = DumpController(
self.engine, self.dbsession, options, self.req
self.dest_engine, self.dest_session, options, self.req
)

dest_table = controller.get_dest_table_for_src_object(patient)
Expand All @@ -120,7 +125,7 @@ def test_task_descendant_has_extra_task_xref_columns(self) -> None:

options = TaskExportOptions(db_patient_id_per_row=True)
controller = DumpController(
self.engine, self.dbsession, options, self.req
self.dest_engine, self.dest_session, options, self.req
)

single_photo = photo_sequence.photos[0]
Expand All @@ -131,14 +136,14 @@ def test_task_descendant_has_extra_task_xref_columns(self) -> None:
self.assertIn(EXTRA_TASK_TABLENAME_FIELD, dest_names)


class GetDestTableForEstTests(DemoRequestTestCase):
class GetDestTableForEstTests(DumpTestCase):
def test_copies_table_with_subset_of_columns(self) -> None:
patient = PatientFactory()
bmi = BmiFactory(patient=patient)

options = TaskExportOptions()
controller = DumpController(
self.engine, self.dbsession, options, self.req
self.dest_engine, self.dest_session, options, self.req
)

columns = [
Expand Down Expand Up @@ -168,7 +173,7 @@ def test_appends_extra_id_columns(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.engine, self.dbsession, options, self.req
self.dest_engine, self.dest_session, options, self.req
)

est = ExtraSummaryTable(
Expand All @@ -192,7 +197,7 @@ def test_appends_extra_task_xref_columns(self) -> None:

options = TaskExportOptions()
controller = DumpController(
self.engine, self.dbsession, options, self.req
self.dest_engine, self.dest_session, options, self.req
)

est = ExtraSummaryTable(
Expand All @@ -211,13 +216,11 @@ def test_appends_extra_task_xref_columns(self) -> None:
self.assertIn(EXTRA_TASK_TABLENAME_FIELD, dest_names)


@pytest.mark.usefixtures("setup_dest_session")
class CopyTasksAndSummariesTests(DemoRequestTestCase):
class CopyTasksAndSummariesTests(DumpTestCase):
def test_task_fields_copied(self) -> None:
export_options = TaskExportOptions(
include_blobs=False,
db_patient_id_per_row=False,
db_make_all_tables_even_empty=False,
db_include_summaries=False,
)

Expand All @@ -242,13 +245,13 @@ def test_task_fields_copied(self) -> None:
self.assertAlmostEqual(row.height_m, bmi.height_m)
self.assertAlmostEqual(row.mass_kg, bmi.mass_kg)

# TODO: Should be present but None
# for colname in [
# "_addition_pending",
# "_forcibly_preserved",
# "_manually_erased",
# ]: # not exhaustive list
# self.assertIsNone(getattr(row, colname))
# Should have been nulled
for colname in [
"_addition_pending",
"_forcibly_preserved",
"_manually_erased",
]: # not exhaustive list
self.assertIsNone(getattr(row, colname))

# No summaries
self.assertFalse(hasattr(row, SFN_IS_COMPLETE))
Expand All @@ -258,7 +261,6 @@ def test_summary_fields_copied(self) -> None:
export_options = TaskExportOptions(
include_blobs=False,
db_patient_id_per_row=False,
db_make_all_tables_even_empty=False,
db_include_summaries=True,
)

Expand Down Expand Up @@ -286,7 +288,6 @@ def test_has_extra_id_num_columns(self) -> None:
export_options = TaskExportOptions(
include_blobs=False,
db_patient_id_per_row=True,
db_make_all_tables_even_empty=False,
db_include_summaries=False,
)

Expand Down Expand Up @@ -315,7 +316,6 @@ def test_has_extra_task_xref_columns(self) -> None:
export_options = TaskExportOptions(
include_blobs=False,
db_patient_id_per_row=True,
db_make_all_tables_even_empty=False,
db_include_summaries=False,
)

Expand Down
39 changes: 32 additions & 7 deletions server/camcops_server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,24 +307,49 @@ def setup(


@pytest.fixture(scope="session")
def dest_engine(request: "FixtureRequest") -> Generator["Engine", None, None]:
engine = make_memory_sqlite_engine()
def dest_engine(
request: "FixtureRequest", echo: bool
) -> Generator["Engine", None, None]:
"""
An in-memory database for testing export via the dest_session fixture.
"""
engine = make_memory_sqlite_engine(echo=echo)

yield engine

engine.dispose()


# noinspection PyUnusedLocal
@pytest.fixture
def dest_tables(
request: "FixtureRequest", dest_engine: "Engine"
) -> Generator[None, None, None]:

# Unlike the tables fixture, we don't create any tables as they are created
# in the tests themselves and the columns change between tests. So the
# scope here is the default 'function', which means they are dropped after
# each test, rather than 'session', which would only drop them at the end
# of the test run.

yield

metadata = MetaData()
metadata.reflect(dest_engine)
metadata.drop_all(dest_engine)


# noinspection PyUnusedLocal
@pytest.fixture
def dest_session(
request: "FixtureRequest",
dest_engine: "Engine",
dest_tables: None,
) -> Generator[Session, None, None]:
"""
Returns an sqlalchemy session, and after the test tears down everything
properly.
"""

connection = dest_engine.connect()
# begin the nested transaction
transaction = connection.begin()
Expand All @@ -339,16 +364,16 @@ def dest_session(
# put back the connection to the connection pool
connection.close()

metadata = MetaData()
metadata.reflect(dest_engine)
metadata.drop_all(dest_engine)


@pytest.fixture
def setup_dest_session(
request: "FixtureRequest",
dest_engine: "Engine",
dest_session: Session,
) -> None:
"""
Use this fixture where a second, in-memory database is required.
Slow, so avoid use sparingly.
"""
request.cls.dest_session = dest_session
request.cls.dest_engine = dest_engine

0 comments on commit 76b68f9

Please sign in to comment.