diff --git a/lea/conductor.py b/lea/conductor.py
index e2e1f98..6c8ea43 100644
--- a/lea/conductor.py
+++ b/lea/conductor.py
@@ -425,7 +425,7 @@ def name_user_dataset(self) -> str:
     def list_materialized_audit_table_refs(
         self, database_client: DatabaseClient, dataset: str
     ) -> set[TableRef]:
-        existing_tables = database_client.list_tables(dataset)
+        existing_tables = database_client.list_table_stats(dataset)
         existing_audit_tables = {
             table_ref: stats
             for table_ref, stats in existing_tables.items()
diff --git a/lea/databases.py b/lea/databases.py
index 30ed106..1d2fff8 100644
--- a/lea/databases.py
+++ b/lea/databases.py
@@ -39,6 +39,9 @@ class DatabaseClient(typing.Protocol):
     def create_dataset(self, dataset_name: str):
         pass
 
+    def delete_dataset(self, dataset_name: str):
+        pass
+
     def materialize_script(self, script: scripts.Script) -> DatabaseJob:
         pass
 
@@ -58,7 +61,10 @@ def delete_and_insert(
 
     def delete_table(self, table_ref: scripts.TableRef) -> DatabaseJob:
         pass
 
-    def list_tables(self, dataset_name: str) -> dict[scripts.TableRef, TableStats]:
+    def list_table_stats(self, dataset_name: str) -> dict[scripts.TableRef, TableStats]:
+        pass
+
+    def list_table_fields(self, dataset_name: str) -> dict[scripts.TableRef, list[scripts.Field]]:
         pass
 
@@ -102,7 +108,7 @@
     def exception(self) -> Exception:
         return self.query_job.exception()
 
-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class TableStats:
     n_rows: int
     n_bytes: int
@@ -129,8 +135,6 @@ def __init__(
         self.dry_run = dry_run
 
     def create_dataset(self, dataset_name: str):
-        from google.cloud import bigquery
-
         dataset_ref = bigquery.DatasetReference(
             project=self.write_project_id, dataset_id=dataset_name
         )
@@ -236,7 +240,7 @@ def delete_table(self, table_ref: scripts.TableRef) -> BigQueryJob:
             client=self, query_job=self.client.query(delete_code, job_config=job_config)
         )
 
-    def list_tables(self, dataset_name: str) -> dict[scripts.TableRef, TableStats]:
+    def list_table_stats(self, dataset_name: str) -> dict[scripts.TableRef, TableStats]:
         query = f"""
         SELECT table_id, row_count, size_bytes
         FROM `{self.write_project_id}.{dataset_name}.__TABLES__`
@@ -249,6 +253,22 @@ def list_tables(self, dataset_name: str) -> dict[scripts.TableRef, TableStats]:
             for row in job.result()
         }
 
+    def list_table_fields(self, dataset_name: str) -> dict[scripts.TableRef, list[scripts.Field]]:
+        query = f"""
+        SELECT table_name, column_name
+        FROM `{self.write_project_id}.{dataset_name}.INFORMATION_SCHEMA.COLUMNS`
+        """
+        job = self.client.query(query)
+        return {
+            BigQueryDialect.parse_table_ref(f"{dataset_name}.{table_name}"): [
+                scripts.Field(name=row["column_name"]) for _, row in rows.iterrows()
+            ]
+            for table_name, rows in job.result()
+            .to_dataframe()
+            .sort_values(["table_name", "column_name"])
+            .groupby("table_name")
+        }
+
     def make_job_config(
         self, script: scripts.SQLScript | None = None, **kwargs
     ) -> bigquery.QueryJobConfig:
diff --git a/lea/field.py b/lea/field.py
index 1e425cc..0bae9b3 100644
--- a/lea/field.py
+++ b/lea/field.py
@@ -4,11 +4,11 @@
 import enum
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class Field:
     name: str
-    tags: set[FieldTag]
-    description: str
+    tags: set[FieldTag] = dataclasses.field(default_factory=set)
+    description: str | None = None
 
     @property
     def is_unique(self):
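
A quick sanity-check of the new protocol surface, for reviewers; this is a minimal sketch, not part of the patch. Only the list_table_stats / list_table_fields names and return shapes come from lea/databases.py: the InMemoryClient class, its sample tables, and the plain-string table keys (standing in for scripts.TableRef) are illustrative assumptions.

    import dataclasses


    @dataclasses.dataclass(frozen=True)
    class TableStats:
        n_rows: int
        n_bytes: int


    @dataclasses.dataclass(frozen=True)
    class Field:
        name: str


    class InMemoryClient:
        """Stand-in for BigQueryClient; keeps stats and schemas in plain dicts."""

        def __init__(self) -> None:
            self._stats = {"analytics.orders": TableStats(n_rows=1_000, n_bytes=64_000)}
            self._fields = {"analytics.orders": [Field("order_id"), Field("amount")]}

        def list_table_stats(self, dataset_name: str) -> dict[str, TableStats]:
            # The real protocol keys by scripts.TableRef; plain strings stand in here.
            return {
                ref: stats
                for ref, stats in self._stats.items()
                if ref.startswith(f"{dataset_name}.")
            }

        def list_table_fields(self, dataset_name: str) -> dict[str, list[Field]]:
            # Mirror the BigQuery implementation, which sorts by column_name.
            return {
                ref: sorted(fields, key=lambda field: field.name)
                for ref, fields in self._fields.items()
                if ref.startswith(f"{dataset_name}.")
            }


    client = InMemoryClient()
    assert client.list_table_stats("analytics")["analytics.orders"].n_rows == 1_000
    assert [
        field.name for field in client.list_table_fields("analytics")["analytics.orders"]
    ] == ["amount", "order_id"]

Returning a sorted list (rather than a set) keeps the field order deterministic, which is also why the BigQuery implementation's return annotation should read list[scripts.Field] to match both the protocol and the list it actually builds.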