Skip to content

Commit

Permalink
Replace object creation methods with factories
Browse files Browse the repository at this point in the history
  • Loading branch information
martinburchell committed Oct 11, 2024
1 parent 665c822 commit da50f3e
Show file tree
Hide file tree
Showing 14 changed files with 2,491 additions and 1,698 deletions.
3 changes: 0 additions & 3 deletions server/camcops_server/cc_modules/cc_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2540,8 +2540,5 @@ def get_unittest_request(
req.set_get_params(params)

req._debugging_db_session = dbsession
user = User()
user.superuser = True
req._debugging_user = user

return req
161 changes: 137 additions & 24 deletions server/camcops_server/cc_modules/cc_testfactories.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"""

from typing import TYPE_CHECKING
from typing import Optional, TYPE_CHECKING

from cardinal_pythonlib.datetimefunc import (
convert_datetime_to_utc,
Expand All @@ -37,6 +37,7 @@
from faker import Faker
import pendulum

from camcops_server.cc_modules.cc_blob import Blob
from camcops_server.cc_modules.cc_constants import DateFormat, ERA_NOW
from camcops_server.cc_modules.cc_device import Device
from camcops_server.cc_modules.cc_email import Email
Expand All @@ -57,6 +58,8 @@

if TYPE_CHECKING:
from factory.builder import Resolver
from camcops_server.cc_modules.cc_request import CamcopsRequest


# Avoid any ID clashes with objects not created with factories
ID_OFFSET = 1000
Expand Down Expand Up @@ -121,10 +124,27 @@ class UserFactory(BaseFactory):
class Meta:
model = User

id = factory.Sequence(lambda n: n + ID_OFFSET)
username = factory.Sequence(lambda n: f"user{n + ID_OFFSET}")
username = factory.Sequence(lambda n: f"user{n}")
hashedpw = ""

@factory.post_generation
def password(
obj: "Resolver",
create: bool,
password: Optional[str],
request: "CamcopsRequest" = None,
**kwargs,
) -> None:
if not create:
return

if password is None:
return

assert request is not None

obj.set_password(request, password)


class GenericTabletRecordFactory(BaseFactory):
class Meta:
Expand All @@ -139,21 +159,21 @@ class Meta:
_adding_user = factory.SubFactory(UserFactory)

@factory.lazy_attribute
def _when_added_exact(self) -> pendulum.DateTime:
return pendulum.parse(self.default_iso_datetime)
def _when_added_exact(obj: "Resolver") -> pendulum.DateTime:
return pendulum.parse(obj.default_iso_datetime)

@factory.lazy_attribute
def _when_added_batch_utc(self) -> pendulum.DateTime:
era_time = pendulum.parse(self.default_iso_datetime)
def _when_added_batch_utc(obj: "Resolver") -> pendulum.DateTime:
era_time = pendulum.parse(obj.default_iso_datetime)
return convert_datetime_to_utc(era_time)

@factory.lazy_attribute
def _era(self) -> str:
era_time = pendulum.parse(self.default_iso_datetime)
def _era(obj: "Resolver") -> str:
era_time = pendulum.parse(obj.default_iso_datetime)
return format_datetime(era_time, DateFormat.ISO8601)

@factory.lazy_attribute
def _current(self) -> bool:
def _current(obj: "Resolver") -> bool:
# _current = True gets ignored for some reason
return True

Expand All @@ -168,6 +188,7 @@ class Meta:
address = factory.LazyFunction(Fake.en_gb.address)
gp = factory.LazyFunction(Fake.en_gb.name)
other = factory.LazyFunction(Fake.en_us.paragraph)
email = factory.LazyFunction(Fake.en_gb.email)

@factory.lazy_attribute
def forename(obj: "Resolver") -> str:
Expand All @@ -178,14 +199,14 @@ def forename(obj: "Resolver") -> str:

class ServerCreatedPatientFactory(PatientFactory):
@factory.lazy_attribute
def _device(self) -> Device:
def _device(obj: "Resolver") -> Device:
# Should have been created in BasicDatabaseTestCase.setUp
return Device.get_server_device(
ServerCreatedPatientFactory._meta.sqlalchemy_session
)

@factory.lazy_attribute
def _era(self) -> str:
def _era(obj: "Resolver") -> str:
return ERA_NOW


Expand Down Expand Up @@ -260,14 +281,14 @@ class ServerCreatedPatientIdNumFactory(PatientIdNumFactory):
patient = factory.SubFactory(ServerCreatedPatientFactory)

@factory.lazy_attribute
def _device(self) -> Device:
def _device(obj: "Resolver") -> Device:
# Should have been created in BasicDatabaseTestCase.setUp
return Device.get_server_device(
ServerCreatedPatientIdNumFactory._meta.sqlalchemy_session
)

@factory.lazy_attribute
def _era(self) -> str:
def _era(obj: "Resolver") -> str:
return ERA_NOW


Expand Down Expand Up @@ -341,19 +362,19 @@ class Meta:
# be a SQLite thing.
@factory.post_generation
def sent_at_utc(
self, create: bool, sent_at_utc: pendulum.DateTime, **kwargs
obj: "Resolver", create: bool, sent_at_utc: pendulum.DateTime, **kwargs
) -> None:
if not create:
return

self.sent_at_utc = sent_at_utc
obj.sent_at_utc = sent_at_utc

@factory.post_generation
def sent(self, create: bool, sent: bool, **kwargs) -> None:
def sent(obj: "Resolver", create: bool, sent: bool, **kwargs) -> None:
if not create:
return

self.sent = sent
obj.sent = sent


class PatientTaskScheduleEmailFactory(BaseFactory):
Expand All @@ -374,25 +395,117 @@ class Meta:
# __init__() does not accept arbitrary keyword args.
@factory.post_generation
def may_run_reports(
self, create: bool, may_run_reports: bool, **kwargs
obj: "Resolver", create: bool, may_run_reports: bool, **kwargs
) -> None:
if not create:
return

self.may_run_reports = may_run_reports
obj.may_run_reports = may_run_reports

@factory.post_generation
def groupadmin(self, create: bool, groupadmin: bool, **kwargs) -> None:
def groupadmin(
obj: "Resolver", create: bool, groupadmin: bool, **kwargs
) -> None:
if not create:
return

self.groupadmin = groupadmin
obj.groupadmin = groupadmin

@factory.post_generation
def may_manage_patients(
self, create: bool, may_manage_patients: bool, **kwargs
obj: "Resolver", create: bool, may_manage_patients: bool, **kwargs
) -> None:
if not create:
return

obj.may_manage_patients = may_manage_patients

@factory.post_generation
def may_use_webviewer(
obj: "Resolver", create: bool, may_use_webviewer: bool, **kwargs
) -> None:
if not create:
return

obj.may_use_webviewer = may_use_webviewer

@factory.post_generation
def view_all_patients_when_unfiltered(
obj: "Resolver",
create: bool,
view_all_patients_when_unfiltered: bool,
**kwargs,
) -> None:
if not create:
return

obj.view_all_patients_when_unfiltered = (
view_all_patients_when_unfiltered
)

@factory.post_generation
def may_add_notes(
obj: "Resolver",
create: bool,
may_add_notes: bool,
**kwargs,
) -> None:
if not create:
return

obj.may_add_notes = may_add_notes

@factory.post_generation
def may_dump_data(
obj: "Resolver",
create: bool,
may_dump_data: bool,
**kwargs,
) -> None:
if not create:
return

obj.may_dump_data = may_dump_data

@factory.post_generation
def may_email_patients(
obj: "Resolver",
create: bool,
may_email_patients: bool,
**kwargs,
) -> None:
if not create:
return

obj.may_email_patients = may_email_patients

@factory.post_generation
def may_upload(
obj: "Resolver",
create: bool,
may_upload: bool,
**kwargs,
) -> None:
if not create:
return

self.may_manage_patients = may_manage_patients
obj.may_upload = may_upload

@factory.post_generation
def may_register_devices(
obj: "Resolver",
create: bool,
may_register_devices: bool,
**kwargs,
) -> None:
if not create:
return

obj.may_register_devices = may_register_devices


class BlobFactory(GenericTabletRecordFactory):
class Meta:
model = Blob

id = factory.Sequence(lambda n: n + ID_OFFSET)
12 changes: 12 additions & 0 deletions server/camcops_server/cc_modules/cc_testproviders.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ def sex(self) -> str:
return self.random_choice(["M", "F", "X"], weights=[49.8, 49.8, 0.4])


class ValidPhoneNumberProvider(BaseProvider):
# The default Faker phone_number provider for en_GB uses
# https://www.ofcom.org.uk/phones-telecoms-and-internet/information-for-industry/numbering/numbers-for-drama # noqa: E501
# 07700 900000 to 900999 reserved for TV and Radio drama purposes
# but unfortunately the phonenumbers library considers these invalid.
def valid_phone_number(self) -> str:
number = self.generator.random_int(min=7000000000, max=7999999999)

return f"+44{number}"


class WaistProvider(BaseProvider):
def waist_cm(self) -> float:
return float(self.generator.random_int(min=40, max=130))
Expand All @@ -122,5 +133,6 @@ def register_all_providers(fake: Faker) -> None:
fake.add_provider(HeightProvider)
fake.add_provider(MassProvider)
fake.add_provider(NhsNumberProvider)
fake.add_provider(ValidPhoneNumberProvider)
fake.add_provider(WaistProvider)
fake.add_provider(SexProvider)
Loading

0 comments on commit da50f3e

Please sign in to comment.