Skip to content

Commit

Permalink
fix: do not default to compute project when building out scripts
Browse files Browse the repository at this point in the history
Runs could fail when using different compute and target projects on
BigQuery, because the project wasn't specified when running the
audit tables, so it defaulted to the compute project instead of the
proper target.

Fix by allowing a target project to be specified through the CLI,
defaulting to LEA_BQ_PROJECT_ID if it is set, and passing it down through
DAG creation when building the table_refs for the tables to run.
  • Loading branch information
GitSquared committed Dec 26, 2024
1 parent 5e85638 commit dc2ad6e
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 10 deletions.
9 changes: 7 additions & 2 deletions lea/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def app():
@app.command()
@click.option("--select", "-m", multiple=True, default=["*"], help="Scripts to materialize.")
@click.option("--dataset", default=None, help="Name of the base dataset.")
@click.option(
"--project", default=None, help="Name of the project where the base dataset is located."
)
@click.option("--scripts", default="views", help="Directory where the scripts are located.")
@click.option(
"--incremental", nargs=2, type=str, multiple=True, help="Incremental field name and value."
Expand All @@ -29,7 +32,9 @@ def app():
@click.option(
"--production", is_flag=True, default=False, help="Whether to run the scripts in production."
)
def run(select, dataset, scripts, incremental, stateful, dry, keep_going, print, production):
def run(
select, dataset, project, scripts, incremental, stateful, dry, keep_going, print, production
):
if not pathlib.Path(scripts).is_dir():
raise click.ClickException(f"Directory {scripts} does not exist")

Expand All @@ -42,7 +47,7 @@ def run(select, dataset, scripts, incremental, stateful, dry, keep_going, print,
incremental_field_name = next(iter(incremental_field_values), None)
incremental_field_values = incremental_field_values[incremental_field_name]

conductor = lea.Conductor(scripts_dir=scripts, dataset_name=dataset)
conductor = lea.Conductor(scripts_dir=scripts, dataset_name=dataset, project_name=project)
conductor.run(
*select,
production=production,
Expand Down
10 changes: 9 additions & 1 deletion lea/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ def delete_audit_tables(


class Conductor:
def __init__(self, scripts_dir: str, dataset_name: str | None = None):
def __init__(
self, scripts_dir: str, dataset_name: str | None = None, project_name: str | None = None
):
# Load environment variables from .env file
# TODO: is it Pythonic to do this here?
dotenv.load_dotenv(".env", verbose=True)
Expand All @@ -379,10 +381,16 @@ def __init__(self, scripts_dir: str, dataset_name: str | None = None):
if dataset_name is None:
raise ValueError("Dataset name could not be inferred")
self.dataset_name = dataset_name

if project_name is None:
project_name = os.environ.get("LEA_BQ_PROJECT_ID")
self.project_name = project_name

self.dag = DAGOfScripts.from_directory(
scripts_dir=self.scripts_dir,
sql_dialect=BigQueryDialect(),
dataset_name=self.dataset_name,
project_name=self.project_name,
)

def make_client(self, dry_run: bool = False):
Expand Down
21 changes: 18 additions & 3 deletions lea/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,28 @@ def __init__(
scripts: list[Script],
scripts_dir: pathlib.Path,
dataset_name: str,
project_name: str | None = None,
):
graphlib.TopologicalSorter.__init__(self, dependency_graph)
self.dependency_graph = dependency_graph
self.scripts = {script.table_ref: script for script in scripts}
self.scripts_dir = scripts_dir
self.dataset_name = dataset_name
self.project_name = project_name

@classmethod
def from_directory(
cls, scripts_dir: pathlib.Path, sql_dialect: SQLDialect, dataset_name: str
cls,
scripts_dir: pathlib.Path,
sql_dialect: SQLDialect,
dataset_name: str,
project_name: str | None = None,
) -> DAGOfScripts:
scripts = read_scripts(
scripts_dir=scripts_dir, sql_dialect=sql_dialect, dataset_name=dataset_name
scripts_dir=scripts_dir,
sql_dialect=sql_dialect,
dataset_name=dataset_name,
project_name=project_name,
)

# Fields in the script's code may contain tags. These tags induce assertion tests, which
Expand All @@ -48,6 +57,7 @@ def from_directory(
scripts=scripts,
scripts_dir=scripts_dir,
dataset_name=dataset_name,
project_name=project_name,
)

def select(self, *queries: str) -> set[TableRef]:
Expand Down Expand Up @@ -104,7 +114,12 @@ def _select(
return

*schema, name = query.split(".")
table_ref = TableRef(dataset=self.dataset_name, schema=tuple(schema), name=name)
table_ref = TableRef(
dataset=self.dataset_name,
schema=tuple(schema),
name=name,
project=self.project_name,
)
yield table_ref
if include_ancestors:
yield from iter_ancestors(self.dependency_graph, node=table_ref)
Expand Down
2 changes: 2 additions & 0 deletions lea/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def materialize_sql_script(self, sql_script: scripts.SQLScript) -> BigQueryJob:
table_ref_str = BigQueryDialect.format_table_ref(sql_script.table_ref)
destination = bigquery.TableReference.from_string(
f"{self.write_project_id}.{table_ref_str}"
if not sql_script.table_ref.project
else table_ref_str
)
job_config = self.make_job_config(
script=sql_script, destination=destination, write_disposition="WRITE_TRUNCATE"
Expand Down
16 changes: 13 additions & 3 deletions lea/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def __post_init__(self):

@classmethod
def from_path(
cls, scripts_dir: pathlib.Path, relative_path: pathlib.Path, sql_dialect: SQLDialect
cls,
scripts_dir: pathlib.Path,
relative_path: pathlib.Path,
sql_dialect: SQLDialect,
project_name: str | None,
) -> SQLScript:
# Either the file is a Jinja template
if relative_path.suffixes == [".sql", ".jinja"]:
Expand All @@ -82,7 +86,9 @@ def from_path(
code = (scripts_dir / relative_path).read_text().rstrip().rstrip(";")

return cls(
table_ref=TableRef.from_path(scripts_dir=scripts_dir, relative_path=relative_path),
table_ref=TableRef.from_path(
scripts_dir=scripts_dir, relative_path=relative_path, project_name=project_name
),
code=code,
sql_dialect=sql_dialect,
)
Expand Down Expand Up @@ -181,7 +187,10 @@ def __rich__(self):


def read_scripts(
scripts_dir: pathlib.Path, sql_dialect: SQLDialect, dataset_name: str
scripts_dir: pathlib.Path,
sql_dialect: SQLDialect,
dataset_name: str,
project_name: str | None = None,
) -> list[Script]:
def read_script(path: pathlib.Path) -> Script:
match tuple(path.suffixes):
Expand All @@ -190,6 +199,7 @@ def read_script(path: pathlib.Path) -> Script:
scripts_dir=scripts_dir,
relative_path=path.relative_to(scripts_dir),
sql_dialect=sql_dialect,
project_name=project_name,
)
case _:
raise ValueError(f"Unsupported script type: {path}")
Expand Down
5 changes: 4 additions & 1 deletion lea/table_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ def __str__(self):
return ".".join(filter(None, [self.project, self.dataset, *self.schema, self.name]))

@classmethod
def from_path(cls, scripts_dir: pathlib.Path, relative_path: pathlib.Path) -> TableRef:
def from_path(
cls, scripts_dir: pathlib.Path, relative_path: pathlib.Path, project_name: str | None
) -> TableRef:
parts = list(filter(None, relative_path.parts))
*schema, filename = parts
return cls(
dataset=scripts_dir.name,
schema=tuple(schema),
# Remove the extension(s) from the filename
name=filename.split(".")[0],
project=project_name,
)

def replace_dataset(self, dataset: str) -> TableRef:
Expand Down

0 comments on commit dc2ad6e

Please sign in to comment.