Skip to content

Commit

Permalink
Merge pull request #47 from cblessing24/fix_issue46
Browse files Browse the repository at this point in the history
Fix parts of computed source getting incorrect name
  • Loading branch information
christoph-blessing authored Oct 6, 2023
2 parents 84ac630 + 74df031 commit 758a26c
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 144 deletions.
2 changes: 1 addition & 1 deletion link/infrastructure/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def create_dj_table() -> dj.Table:
if not child.table_name.startswith(parts().table_name + "__"):
continue
part_definition = child.describe(printout=False).replace(parts().full_table_name, "master")
part_definitions[dj.utils.to_camel_case(child.table_name.split("__")[1])] = part_definition
part_definitions[dj.utils.to_camel_case(child.table_name.split("__")[-1])] = part_definition
for part_name, part_definition in part_definitions.items():
part_definitions[part_name] = replace_stores(part_definition, replacement_stores)
part_tables: dict[str, type[dj.Part]] = {}
Expand Down
142 changes: 66 additions & 76 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import pathlib
from concurrent import futures
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from random import choices
from string import ascii_lowercase
from typing import Dict

import datajoint as dj
import docker
Expand Down Expand Up @@ -59,7 +58,6 @@ class HealthCheckConfig:
@dataclass(frozen=True)
class DatabaseConfig:
password: str # MYSQL root user password
users: Dict[str, UserConfig]
schema_name: str


Expand All @@ -71,6 +69,7 @@ class DatabaseSpec:

@dataclass(frozen=True)
class UserConfig:
host: str
name: str
password: str
grants: list[str]
Expand Down Expand Up @@ -104,35 +103,6 @@ def docker_client():
return docker.client.from_env()


@pytest.fixture(scope=SCOPE)
def create_user_configs(outbound_schema_name):
def _create_user_configs(schema_name):
return dict(
admin_user=UserConfig(
"admin_user",
"admin_user_password",
grants=[
f"GRANT ALL PRIVILEGES ON `{outbound_schema_name}`.* TO '$name'@'%';",
],
),
end_user=UserConfig(
"end_user",
"end_user_password",
grants=[r"GRANT ALL PRIVILEGES ON `end_user\_%`.* TO '$name'@'%';"],
),
dj_user=UserConfig(
"dj_user",
"dj_user_password",
grants=[
f"GRANT SELECT, REFERENCES ON `{schema_name}`.* TO '$name'@'%';",
f"GRANT ALL PRIVILEGES ON `{outbound_schema_name}`.* TO '$name'@'%';",
],
),
)

return _create_user_configs


@pytest.fixture(scope=SCOPE)
def create_random_string():
def _create_random_string(length=6):
Expand All @@ -147,7 +117,7 @@ def network():


@pytest.fixture(scope=SCOPE)
def get_db_spec(create_random_string, create_user_configs, network):
def get_db_spec(create_random_string, network):
def _get_db_spec(name):
schema_name = "end_user_schema"
return DatabaseSpec(
Expand All @@ -160,7 +130,6 @@ def _get_db_spec(name):
),
DatabaseConfig(
password=DATABASE_ROOT_PASSWORD,
users=create_user_configs(schema_name),
schema_name=schema_name,
),
)
Expand Down Expand Up @@ -188,13 +157,6 @@ def _get_minio_spec(name):
return _get_minio_spec


@pytest.fixture(scope=SCOPE)
def outbound_schema_name():
name = "outbound_schema"
os.environ["LINK_OUTBOUND"] = name
return name


def get_runner_kwargs(docker_client, spec):
common = dict(
detach=True,
Expand Down Expand Up @@ -235,20 +197,15 @@ def get_runner_kwargs(docker_client, spec):


@pytest.fixture(scope=SCOPE)
def create_user_config(create_random_string):
def _create_user_config(grants):
name = create_random_string()
return UserConfig(
name=name, password=create_random_string(), grants=[grant.replace("$name", name) for grant in grants]
)

return _create_user_config


@pytest.fixture(scope=SCOPE)
def create_user(create_user_config):
def create_user(create_random_string):
def _create_user(db_spec, grants):
config = create_user_config(grants)
user_name = create_random_string()
config = UserConfig(
host=db_spec.container.name,
name=user_name,
password=create_random_string(),
grants=[grant.replace("$name", user_name) for grant in grants],
)
with mysql_conn(db_spec) as connection:
with connection.cursor() as cursor:
cursor.execute(f"CREATE USER '{config.name}'@'%' IDENTIFIED BY '{config.password}';")
Expand Down Expand Up @@ -347,8 +304,10 @@ def _get_store_spec(minio_spec, protocol="s3", port=9000):
@pytest.fixture()
def dj_connection():
@contextmanager
def _dj_connection(db_spec, user_spec):
connection = dj.Connection(db_spec.container.name, user_spec.name, user_spec.password)
def _dj_connection():
connection = dj.Connection(
dj.config["database.host"], dj.config["database.user"], dj.config["database.password"]
)
try:
yield connection
finally:
Expand All @@ -360,12 +319,12 @@ def _dj_connection(db_spec, user_spec):
@pytest.fixture()
def connection_config():
@contextmanager
def _connection_config(db_spec, user):
def _connection_config(actor):
try:
with dj.config(
database__host=db_spec.container.name,
database__user=user.name,
database__password=user.password,
database__host=actor.credentials.host,
database__user=actor.credentials.name,
database__password=actor.credentials.password,
safemode=False,
):
dj.conn(reset=True)
Expand Down Expand Up @@ -449,16 +408,34 @@ def _temp_env_vars(**vars):
return _temp_env_vars


@pytest.fixture()
def act_as(connection_config, temp_env_vars):
@contextmanager
def _act_as(actor):
with connection_config(actor), temp_env_vars(**actor.environment):
yield

return _act_as


@pytest.fixture()
def configured_environment(temp_env_vars):
@contextmanager
def _configured_environment(user_spec, schema_name):
with temp_env_vars(LINK_USER=user_spec.name, LINK_PASS=user_spec.password, LINK_OUTBOUND=schema_name):
def _configured_environment(actor, schema_name):
with temp_env_vars(
LINK_USER=actor.credentials.name, LINK_PASS=actor.credentials.password, LINK_OUTBOUND=schema_name
):
yield

return _configured_environment


@dataclass(frozen=True)
class Actor:
credentials: UserConfig
environment: dict[str, str] = field(default_factory=dict)


@pytest.fixture()
def prepare_multiple_links(create_random_string, create_user, databases):
def _prepare_multiple_links(n_local_schemas):
Expand All @@ -468,24 +445,37 @@ def create_schema_names():
return names

schema_names = create_schema_names()
user_specs = {
"admin": create_user(databases["source"], grants=["GRANT ALL PRIVILEGES ON *.* TO '$name'@'%';"]),
"source": create_user(
databases["source"], grants=[f"GRANT ALL PRIVILEGES ON `{schema_names['source']}`.* TO '$name'@'%';"]
),
"local": create_user(
databases["local"],
grants=[f"GRANT ALL PRIVILEGES ON `{name}`.* TO '$name'@'%';" for name in schema_names["local"]],
),
"link": create_user(
link_actor = Actor(
create_user(
databases["source"],
grants=[
f"GRANT SELECT, REFERENCES ON `{schema_names['source']}`.* TO '$name'@'%';",
f"GRANT ALL PRIVILEGES ON `{schema_names['outbound']}`.* TO '$name'@'%';",
],
)
)
actors = {
"admin": Actor(create_user(databases["source"], grants=["GRANT ALL PRIVILEGES ON *.* TO '$name'@'%';"])),
"source": Actor(
create_user(
databases["source"],
grants=[f"GRANT ALL PRIVILEGES ON `{schema_names['source']}`.* TO '$name'@'%';"],
)
),
"local": Actor(
create_user(
databases["local"],
grants=[f"GRANT ALL PRIVILEGES ON `{name}`.* TO '$name'@'%';" for name in schema_names["local"]],
),
{
"LINK_USER": link_actor.credentials.name,
"LINK_PASS": link_actor.credentials.password,
"LINK_OUTBOUND": schema_names["outbound"],
},
),
"link": link_actor,
}
return schema_names, user_specs
return schema_names, actors

return _prepare_multiple_links

Expand Down Expand Up @@ -514,14 +504,14 @@ def _create_table(name, tier, definition, *, parts=None):

@pytest.fixture()
def prepare_table(dj_connection):
def _prepare_table(database, user, schema, table_cls, *, data=None, parts=None, context=None):
def _prepare_table(schema, table_cls, *, data=None, parts=None, context=None):
if data is None:
data = []
if parts is None:
parts = {}
with dj_connection(database, user) as connection:
with dj_connection() as connection:
dj.schema(schema, connection=connection, context=context)(table_cls)
table_cls().insert(data)
table_cls().insert(data, allow_direct_insert=True)
for name, part_data in parts.items():
getattr(table_cls, name).insert(part_data)

Expand Down
69 changes: 41 additions & 28 deletions tests/functional/test_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,59 @@


def test_local_table_creation_from_source_table_that_has_parent_raises_no_error(
prepare_link, create_table, prepare_table, databases, configured_environment, connection_config
prepare_link, create_table, prepare_table, act_as
):
schema_names, user_specs = prepare_link()
source_table_parent = create_table("Foo", dj.Manual, "foo: int")
prepare_table(databases["source"], user_specs["source"], schema_names["source"], source_table_parent)
source_table_name = "Bar"
source_table = create_table(source_table_name, dj.Manual, "-> source_table_parent")
prepare_table(
databases["source"],
user_specs["source"],
schema_names["source"],
source_table,
context={"source_table_parent": source_table_parent},
)
with connection_config(databases["local"], user_specs["local"]), configured_environment(
user_specs["link"], schema_names["outbound"]
):
schema_names, actors = prepare_link()
with act_as(actors["source"]):
source_table_parent = create_table("Foo", dj.Manual, "foo: int")
prepare_table(schema_names["source"], source_table_parent)
source_table_name = "Bar"
source_table = create_table(source_table_name, dj.Manual, "-> source_table_parent")
prepare_table(schema_names["source"], source_table, context={"source_table_parent": source_table_parent})
with act_as(actors["local"]):
link(
databases["source"].container.name,
actors["source"].credentials.host,
schema_names["source"],
schema_names["outbound"],
"Outbound",
schema_names["local"],
)(type(source_table_name, (dj.Manual,), {}))
)(type(source_table_name, tuple(), {}))


def test_local_table_creation_from_source_table_that_uses_current_timestamp_default_raises_no_error(
prepare_link, create_table, prepare_table, databases, configured_environment, connection_config
prepare_link, create_table, prepare_table, act_as
):
schema_names, user_specs = prepare_link()
source_table_name = "Foo"
source_table = create_table(source_table_name, dj.Manual, "foo = CURRENT_TIMESTAMP : timestamp")
prepare_table(databases["source"], user_specs["source"], schema_names["source"], source_table)
with connection_config(databases["local"], user_specs["local"]), configured_environment(
user_specs["link"], schema_names["outbound"]
):
schema_names, actors = prepare_link()
with act_as(actors["source"]):
source_table_name = "Foo"
source_table = create_table(source_table_name, dj.Manual, "foo = CURRENT_TIMESTAMP : timestamp")
prepare_table(schema_names["source"], source_table)
with act_as(actors["local"]):
link(
databases["source"].container.name,
actors["source"].credentials.host,
schema_names["source"],
schema_names["outbound"],
"Outbound",
schema_names["local"],
)(type(source_table_name, (dj.Manual,), {}))
)(type(source_table_name, tuple(), {}))


def test_part_tables_of_computed_source_gets_created_with_correct_name(
prepare_link, create_table, prepare_table, act_as
):
schema_names, actors = prepare_link()
with act_as(actors["source"]):
source_table_name = "Foo"
source_table = create_table(
source_table_name, dj.Computed, "foo: int", parts=[create_table("Bar", dj.Part, "-> master")]
)
prepare_table(schema_names["source"], source_table)
with act_as(actors["local"]):
local_table = link(
actors["source"].credentials.host,
schema_names["source"],
schema_names["outbound"],
"Outbound",
schema_names["local"],
)(type(source_table_name, tuple(), {}))
assert hasattr(local_table, "Bar")
Loading

0 comments on commit 758a26c

Please sign in to comment.