diff --git a/dbtmetabase/__init__.py b/dbtmetabase/__init__.py index d9a1988b..574db846 100644 --- a/dbtmetabase/__init__.py +++ b/dbtmetabase/__init__.py @@ -5,7 +5,10 @@ logger = logging.getLogger(__name__) -__all__ = ["DbtReader", "MetabaseClient"] +__all__ = [ + "DbtReader", + "MetabaseClient", +] try: from ._version import __version__ as version # type: ignore diff --git a/dbtmetabase/__main__.py b/dbtmetabase/__main__.py index d6b526d5..8a9acddc 100644 --- a/dbtmetabase/__main__.py +++ b/dbtmetabase/__main__.py @@ -1,86 +1,15 @@ import functools import logging -from logging.handlers import RotatingFileHandler from pathlib import Path -from typing import Callable, Iterable, List, Optional, Union +from typing import Callable, Iterable, Optional import click import yaml -from rich.logging import RichHandler from typing_extensions import cast +from ._format import click_list_option_kwargs, setup_logging from .dbt import DbtReader -from .metabase import MetabaseClient - -LOG_PATH = Path.home().absolute() / ".dbt-metabase" / "logs" / "dbtmetabase.log" - -logger = logging.getLogger(__name__) - - -def _setup_logger(level: int = logging.INFO): - """Basic logger configuration for the CLI. - - Args: - level (int, optional): Logging level. Defaults to logging.INFO. - """ - - LOG_PATH.parent.mkdir(parents=True, exist_ok=True) - file_handler = RotatingFileHandler( - filename=LOG_PATH, - maxBytes=int(1e6), - backupCount=3, - ) - file_handler.setFormatter( - logging.Formatter("%(asctime)s — %(name)s — %(levelname)s — %(message)s") - ) - file_handler.setLevel(logging.WARNING) - - rich_handler = RichHandler( - level=level, - rich_tracebacks=True, - markup=True, - show_time=False, - ) - - logging.basicConfig( - level=level, - format="%(asctime)s — %(message)s", - datefmt="%Y-%m-%d %H:%M:%S %z", - handlers=[file_handler, rich_handler], - force=True, - ) - - -def _comma_separated_list_callback( - ctx: click.Context, - param: click.Option, - value: Union[str, List[str]], -) -> Optional[List[str]]: - """Click callback for handling comma-separated lists.""" - - if value is None: - return None - - assert ( - param.type == click.UNPROCESSED or param.type.name == "list" - ), "comma-separated list options must be of type UNPROCESSED or list" - - if ctx.get_parameter_source(str(param.name)) in ( - click.core.ParameterSource.DEFAULT, - click.core.ParameterSource.DEFAULT_MAP, - ) and isinstance(value, list): - # Lists in defaults (config or option) should be lists - return value - - elif isinstance(value, str): - str_value = value - if isinstance(value, list): - # When type=list, string value will be a list of chars - str_value = "".join(value) - else: - raise click.BadParameter("must be comma-separated list") - - return str_value.split(",") +from .metabase import MetabaseClient, MetabaseExposuresClient, MetabaseModelsClient @click.group() @@ -137,8 +66,7 @@ def _add_setup(func: Callable) -> Callable: metavar="SCHEMAS", envvar="DBT_SCHEMA_EXCLUDES", show_envvar=True, - type=click.UNPROCESSED, - callback=_comma_separated_list_callback, + **click_list_option_kwargs(), help="Target dbt schemas to exclude. 
Ignored in project parser.", ) @click.option( @@ -146,8 +74,7 @@ def _add_setup(func: Callable) -> Callable: metavar="MODELS", envvar="DBT_INCLUDES", show_envvar=True, - type=click.UNPROCESSED, - callback=_comma_separated_list_callback, + **click_list_option_kwargs(), help="Include specific dbt models names.", ) @click.option( @@ -155,8 +82,7 @@ def _add_setup(func: Callable) -> Callable: metavar="MODELS", envvar="DBT_EXCLUDES", show_envvar=True, - type=click.UNPROCESSED, - callback=_comma_separated_list_callback, + **click_list_option_kwargs(), help="Exclude specific dbt model names.", ) @click.option( @@ -197,7 +123,7 @@ def _add_setup(func: Callable) -> Callable: "metabase_verify", envvar="METABASE_VERIFY", show_envvar=True, - default=True, + default=MetabaseClient.DEFAULT_VERIFY, help="Verify the TLS certificate at the Metabase end.", ) @click.option( @@ -214,7 +140,7 @@ def _add_setup(func: Callable) -> Callable: envvar="METABASE_TIMEOUT", show_envvar=True, type=click.INT, - default=15, + default=MetabaseClient.DEFAULT_HTTP_TIMEOUT, show_default=True, help="Metabase API HTTP timeout in seconds.", ) @@ -242,7 +168,10 @@ def wrapper( verbose: bool, **kwargs, ): - _setup_logger(level=logging.DEBUG if verbose else logging.INFO) + setup_logging( + level=logging.DEBUG if verbose else logging.INFO, + path=Path.home().absolute() / ".dbt-metabase" / "logs" / "dbtmetabase.log", + ) return func( dbt_reader=DbtReader( @@ -299,7 +228,7 @@ def wrapper( metavar="SECS", envvar="METABASE_SYNC_TIMEOUT", show_envvar=True, - default=30, + default=MetabaseModelsClient.DEFAULT_SYNC_TIMEOUT, type=click.INT, help="Synchronization timeout in secs. When set, command fails on failed synchronization. Otherwise, command proceeds regardless. Only valid if sync is enabled.", ) @@ -307,6 +236,7 @@ def wrapper( "--metabase-exclude-sources", envvar="METABASE_EXCLUDE_SOURCES", show_envvar=True, + default=MetabaseModelsClient.DEFAULT_EXCLUDE_SOURCES, is_flag=True, help="Skip exporting sources to Metabase.", ) @@ -338,7 +268,7 @@ def models( envvar="OUTPUT_PATH", show_envvar=True, type=click.Path(exists=True, file_okay=False), - default=".", + default=MetabaseExposuresClient.DEFAULT_OUTPUT_PATH, show_default=True, help="Output path for generated exposure YAML files.", ) @@ -354,6 +284,7 @@ def models( envvar="METABASE_INCLUDE_PERSONAL_COLLECTIONS", show_envvar=True, is_flag=True, + default=MetabaseExposuresClient.DEFAULT_INCLUDE_PERSONAL_COLLECTIONS, help="Include personal collections when parsing exposures.", ) @click.option( @@ -361,8 +292,7 @@ def models( metavar="COLLECTIONS", envvar="METABASE_COLLECTION_INCLUDES", show_envvar=True, - type=click.UNPROCESSED, - callback=_comma_separated_list_callback, + **click_list_option_kwargs(), help="Metabase collection names to includes.", ) @click.option( @@ -370,8 +300,7 @@ def models( metavar="COLLECTIONS", envvar="METABASE_COLLECTION_EXCLUDES", show_envvar=True, - type=click.UNPROCESSED, - callback=_comma_separated_list_callback, + **click_list_option_kwargs(), help="Metabase collection names to exclude.", ) def exposures( diff --git a/dbtmetabase/_format.py b/dbtmetabase/_format.py index 22dea718..cb2fc236 100644 --- a/dbtmetabase/_format.py +++ b/dbtmetabase/_format.py @@ -1,5 +1,98 @@ +import logging import re -from typing import Optional +from logging.handlers import RotatingFileHandler +from pathlib import Path +from typing import Any, Iterable, List, Mapping, Optional, Union + +import click +from rich.logging import RichHandler + +_logger = logging.getLogger(__name__) 
+
+
+class _NullValue(str):
+    """Explicitly null field value."""
+
+    def __eq__(self, other: object) -> bool:
+        return other is None
+
+
+NullValue = _NullValue()
+
+
+def setup_logging(level: int, path: Path):
+    """Basic logger configuration for the CLI.
+
+    Args:
+        level (int): Logging level.
+        path (Path): Path to the log file.
+    """
+
+    path.parent.mkdir(parents=True, exist_ok=True)
+    file_handler = RotatingFileHandler(
+        filename=path,
+        maxBytes=int(1e6),
+        backupCount=3,
+    )
+    file_handler.setFormatter(
+        logging.Formatter("%(asctime)s — %(name)s — %(levelname)s — %(message)s")
+    )
+    file_handler.setLevel(logging.WARNING)
+
+    rich_handler = RichHandler(
+        level=level,
+        rich_tracebacks=True,
+        markup=True,
+        show_time=False,
+    )
+
+    logging.basicConfig(
+        level=level,
+        format="%(asctime)s — %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S %z",
+        handlers=[file_handler, rich_handler],
+        force=True,
+    )
+
+
+def click_list_option_kwargs() -> Mapping[str, Any]:
+    """Click option that accepts comma-separated values.
+
+    Built-in list only allows repeated flags, which is ugly for larger lists.
+
+    Returns:
+        Mapping[str, Any]: Mapping of kwargs (to be unpacked with **).
+    """
+
+    def callback(
+        ctx: click.Context,
+        param: click.Option,
+        value: Union[str, List[str]],
+    ) -> Optional[List[str]]:
+        if value is None:
+            return None
+
+        if ctx.get_parameter_source(str(param.name)) in (
+            click.core.ParameterSource.DEFAULT,
+            click.core.ParameterSource.DEFAULT_MAP,
+        ) and isinstance(value, list):
+            # Lists in defaults (config or option) should be lists
+            return value
+
+        elif isinstance(value, str):
+            str_value = value
+            if isinstance(value, list):
+                # When type=list, string value will be a list of chars
+                str_value = "".join(value)
+        else:
+            raise click.BadParameter("must be comma-separated list")
+
+        return str_value.split(",")
+
+    return {
+        "type": click.UNPROCESSED,
+        "callback": callback,
+    }
 
 
 def safe_name(text: Optional[str]) -> str:
@@ -26,3 +119,23 @@ def safe_description(text: Optional[str]) -> str:
         str: Sanitized string with escaped Jinja syntax.
     """
     return re.sub(r"{{(.*)}}", r"\1", text or "")
+
+
+def scan_fields(t: Mapping, fields: Iterable[str], ns: str) -> Mapping:
+    """Reads meta fields from a schema object's meta mapping.
+
+    Args:
+        t (Mapping): Meta mapping to scan for fields.
+        fields (Iterable[str]): Fields to accept.
+        ns (str): Field namespace (separated by .).
+
+    Returns:
+        Mapping: Field values.
+ """ + + vals = {} + for field in fields: + if f"{ns}.{field}" in t: + value = t[f"{ns}.{field}"] + vals[field] = value if value is not None else NullValue + return vals diff --git a/dbtmetabase/dbt/__init__.py b/dbtmetabase/dbt/__init__.py new file mode 100644 index 00000000..1922f28b --- /dev/null +++ b/dbtmetabase/dbt/__init__.py @@ -0,0 +1,10 @@ +from .models import MetabaseColumn, MetabaseModel, ResourceType +from .reader import DEFAULT_SCHEMA, DbtReader + +__all__ = [ + "DEFAULT_SCHEMA", + "DbtReader", + "MetabaseModel", + "MetabaseColumn", + "ResourceType", +] diff --git a/dbtmetabase/dbt/models.py b/dbtmetabase/dbt/models.py new file mode 100644 index 00000000..1b2cdcdb --- /dev/null +++ b/dbtmetabase/dbt/models.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import MutableMapping, Optional, Sequence + + +class ResourceType(str, Enum): + node = "nodes" + source = "sources" + + +@dataclass +class MetabaseColumn: + name: str + description: Optional[str] = None + + display_name: Optional[str] = None + visibility_type: Optional[str] = None + semantic_type: Optional[str] = None + has_field_values: Optional[str] = None + coercion_strategy: Optional[str] = None + number_style: Optional[str] = None + + fk_target_table: Optional[str] = None + fk_target_field: Optional[str] = None + + meta_fields: MutableMapping = field(default_factory=dict) + + +@dataclass +class MetabaseModel: + name: str + schema: str + description: str = "" + + display_name: Optional[str] = None + visibility_type: Optional[str] = None + points_of_interest: Optional[str] = None + caveats: Optional[str] = None + + res_type: ResourceType = ResourceType.node + source: Optional[str] = None + unique_id: Optional[str] = None + + columns: Sequence[MetabaseColumn] = field(default_factory=list) + + @property + def ref(self) -> Optional[str]: + if self.res_type == ResourceType.node: + return f"ref('{self.name}')" + elif self.res_type == ResourceType.source: + return f"source('{self.source}', '{self.name}')" + return None diff --git a/dbtmetabase/dbt.py b/dbtmetabase/dbt/reader.py similarity index 71% rename from dbtmetabase/dbt.py rename to dbtmetabase/dbt/reader.py index 00403514..ab1f2efe 100644 --- a/dbtmetabase/dbt.py +++ b/dbtmetabase/dbt/reader.py @@ -1,92 +1,33 @@ -import dataclasses import json import logging -import re -from enum import Enum from pathlib import Path -from typing import Iterable, List, Mapping, MutableMapping, Optional, Sequence +from typing import Iterable, List, Mapping, Optional -logger = logging.getLogger(__name__) +from .._format import scan_fields +from .models import MetabaseColumn, MetabaseModel, ResourceType + +_logger = logging.getLogger(__name__) # Allowed metabase.* fields -_METABASE_COMMON_META_FIELDS = [ +_COMMON_META_FIELDS = [ "display_name", "visibility_type", ] # Must be covered by MetabaseColumn attributes -METABASE_COLUMN_META_FIELDS = _METABASE_COMMON_META_FIELDS + [ +_COLUMN_META_FIELDS = _COMMON_META_FIELDS + [ "semantic_type", "has_field_values", "coercion_strategy", "number_style", ] # Must be covered by MetabaseModel attributes -METABASE_MODEL_META_FIELDS = _METABASE_COMMON_META_FIELDS + [ +_MODEL_META_FIELDS = _COMMON_META_FIELDS + [ "points_of_interest", "caveats", ] # Default model schema (only schema in BigQuery) -METABASE_MODEL_DEFAULT_SCHEMA = "PUBLIC" - - -class ModelType(str, Enum): - nodes = "nodes" - sources = "sources" - - -@dataclasses.dataclass -class MetabaseColumn: - name: str - description: Optional[str] = None - - 
display_name: Optional[str] = None - visibility_type: Optional[str] = None - semantic_type: Optional[str] = None - has_field_values: Optional[str] = None - coercion_strategy: Optional[str] = None - number_style: Optional[str] = None - - fk_target_table: Optional[str] = None - fk_target_field: Optional[str] = None - - meta_fields: MutableMapping = dataclasses.field(default_factory=dict) - - -@dataclasses.dataclass -class MetabaseModel: - name: str - schema: str - description: str = "" - - display_name: Optional[str] = None - visibility_type: Optional[str] = None - points_of_interest: Optional[str] = None - caveats: Optional[str] = None - - model_type: ModelType = ModelType.nodes - source: Optional[str] = None - unique_id: Optional[str] = None - - columns: Sequence[MetabaseColumn] = dataclasses.field(default_factory=list) - - @property - def ref(self) -> Optional[str]: - if self.model_type == ModelType.nodes: - return f"ref('{self.name}')" - elif self.model_type == ModelType.sources: - return f"source('{self.source}', '{self.name}')" - return None - - -class _NullValue(str): - """Explicitly null field value.""" - - def __eq__(self, other: object) -> bool: - return other is None - - -NullValue = _NullValue() +DEFAULT_SCHEMA = "PUBLIC" class DbtReader: @@ -144,17 +85,17 @@ def read_models( model_database = node["database"].upper() if node["resource_type"] != "model": - logger.debug("Skipping %s not of resource type model", model_name) + _logger.debug("Skipping %s not of resource type model", model_name) continue if node["config"]["materialized"] == "ephemeral": - logger.debug( + _logger.debug( "Skipping ephemeral model %s not manifested in database", model_name ) continue if model_database != self.database: - logger.debug( + _logger.debug( "Skipping %s in database %s, not in target database %s", model_name, model_database, @@ -163,7 +104,7 @@ def read_models( continue if self.schema and model_schema != self.schema: - logger.debug( + _logger.debug( "Skipping %s in schema %s not in target schema %s", model_name, model_schema, @@ -172,15 +113,15 @@ def read_models( continue if model_schema in self.schema_excludes: - logger.debug( + _logger.debug( "Skipping %s in schema %s marked for exclusion", model_name, model_schema, ) continue - if not self.model_selected(model_name): - logger.debug( + if not self._model_selected(model_name): + _logger.debug( "Skipping %s not included in includes or excluded by excludes", model_name, ) @@ -192,7 +133,7 @@ def read_models( node, include_tags=include_tags, docs_url=docs_url, - model_type=ModelType.nodes, + res_type=ResourceType.node, source=None, ) ) @@ -203,17 +144,17 @@ def read_models( source_database = node["database"].upper() if node["resource_type"] != "source": - logger.debug("Skipping %s not of resource type source", source_name) + _logger.debug("Skipping %s not of resource type source", source_name) continue if source_database != self.database: - logger.debug( + _logger.debug( "Skipping %s not in target database %s", source_name, self.database ) continue if self.schema and source_schema != self.schema: - logger.debug( + _logger.debug( "Skipping %s in schema %s not in target schema %s", source_name, source_schema, @@ -222,15 +163,15 @@ def read_models( continue if source_schema in self.schema_excludes: - logger.debug( + _logger.debug( "Skipping %s in schema %s marked for exclusion", source_name, source_schema, ) continue - if not self.model_selected(source_name): - logger.debug( + if not self._model_selected(source_name): + _logger.debug( "Skipping %s 
not included in includes or excluded by excludes", source_name, ) @@ -242,7 +183,7 @@ def read_models( node, include_tags=include_tags, docs_url=docs_url, - model_type=ModelType.sources, + res_type=ResourceType.source, source=node["source_name"], ) ) @@ -254,7 +195,7 @@ def _read_model( manifest: Mapping, model: dict, source: Optional[str] = None, - model_type: ModelType = ModelType.nodes, + res_type: ResourceType = ResourceType.node, include_tags: bool = True, docs_url: Optional[str] = None, ) -> MetabaseModel: @@ -265,7 +206,7 @@ def _read_model( relationships = self._read_model_relationships( manifest=manifest, - model_type=model_type, + res_type=res_type, unique_id=unique_id, ) @@ -301,22 +242,29 @@ def _read_model( schema=schema, description=description, columns=metabase_columns, - model_type=model_type, + res_type=res_type, unique_id=unique_id, source=source, - **self.read_meta_fields(model, METABASE_MODEL_META_FIELDS), + **scan_fields( + model.get("meta", {}), + fields=_MODEL_META_FIELDS, + ns="metabase", + ), ) def _read_model_relationships( - self, manifest: Mapping, model_type: ModelType, unique_id: str + self, + manifest: Mapping, + res_type: ResourceType, + unique_id: str, ) -> Mapping[str, Mapping[str, str]]: children = manifest["child_map"][unique_id] relationship_tests = {} for child_id in children: child = {} - if manifest[model_type]: - child = manifest[model_type].get(child_id, {}) + if manifest[res_type]: + child = manifest[res_type].get(child_id, {}) # Only proceed if we are seeing an explicitly declared relationship test if ( @@ -329,9 +277,9 @@ def _read_model_relationships( # From experience, nodes contains at most two tables: the referenced model and the current model. # Note, sometimes only the referenced model is returned. - depends_on_nodes = list(child["depends_on"][model_type]) + depends_on_nodes = list(child["depends_on"][res_type]) if len(depends_on_nodes) > 2: - logger.warning( + _logger.warning( "Expected at most two nodes, got %d {} nodes, skipping %s {}", len(depends_on_nodes), unique_id, @@ -342,8 +290,8 @@ def _read_model_relationships( # Otherwise, the primary key of the current model would be (incorrectly) determined to be a foreign key. 
is_incoming_relationship_test = depends_on_nodes[1] != unique_id if len(depends_on_nodes) == 2 and is_incoming_relationship_test: - logger.debug( - "Skip this incoming relationship test, concerning nodes %s.", + _logger.debug( + "Skip this incoming relationship test, concerning nodes %s", depends_on_nodes, ) continue @@ -354,7 +302,7 @@ def _read_model_relationships( depends_on_nodes.remove(unique_id) if len(depends_on_nodes) != 1: - logger.warning( + _logger.warning( "Expected single node after filtering, got %d nodes, skipping %s", len(depends_on_nodes), unique_id, @@ -363,21 +311,21 @@ def _read_model_relationships( depends_on_id = depends_on_nodes[0] - foreign_key_model = manifest[model_type].get(depends_on_id, {}) + foreign_key_model = manifest[res_type].get(depends_on_id, {}) fk_target_table_alias = foreign_key_model.get( "alias", foreign_key_model.get("identifier", foreign_key_model.get("name")), ) if not fk_target_table_alias: - logger.debug( + _logger.debug( "Could not resolve depends on model id %s to a model in manifest", depends_on_id, ) continue - fk_target_schema = manifest[model_type][depends_on_id].get( - "schema", METABASE_MODEL_DEFAULT_SCHEMA + fk_target_schema = manifest[res_type][depends_on_id].get( + "schema", DEFAULT_SCHEMA ) fk_target_field = child["test_metadata"]["kwargs"]["field"].strip('"') @@ -394,15 +342,17 @@ def _read_column( schema: str, relationship: Optional[Mapping], ) -> MetabaseColumn: - column_name = column.get("name", "").upper().strip('"') - column_description = column.get("description") metabase_column = MetabaseColumn( - name=column_name, - description=column_description, - **self.read_meta_fields(column, METABASE_COLUMN_META_FIELDS), + name=column.get("name", "").upper().strip('"'), + description=column.get("description"), + **scan_fields( + column.get("meta", {}), + fields=_COLUMN_META_FIELDS, + ns="metabase", + ), ) - self.set_column_foreign_key( + self._set_column_fk( column=column, metabase_column=metabase_column, table=relationship["fk_target_table"] if relationship else None, @@ -412,19 +362,7 @@ def _read_column( return metabase_column - def model_selected(self, name: str) -> bool: - """Checks whether model passes inclusion/exclusion criteria. - - Args: - name (str): Model name. - - Returns: - bool: True if included, false otherwise. - """ - n = name.upper() - return n not in self.excludes and (not self.includes or n in self.includes) - - def set_column_foreign_key( + def _set_column_fk( self, column: Mapping, metabase_column: MetabaseColumn, @@ -448,8 +386,8 @@ def set_column_foreign_key( if not table or not field: if table or field: - logger.warning( - "Foreign key requires table and field for column %s", + _logger.warning( + "FK requires table and field for column %s", metabase_column.name, ) return @@ -463,47 +401,14 @@ def set_column_foreign_key( [x.strip('"').upper() for x in table_path] ) metabase_column.fk_target_field = field.strip('"').upper() - logger.debug( + _logger.debug( "Relation from %s to %s.%s", metabase_column.name, metabase_column.fk_target_table, metabase_column.fk_target_field, ) - @staticmethod - def read_meta_fields(obj: Mapping, fields: List) -> Mapping: - """Reads meta fields from a schem object. - - Args: - obj (Mapping): Schema object. - fields (List): List of fields to read. - - Returns: - Mapping: Field values. 
- """ - - vals = {} - meta = obj.get("meta", {}) - for field in fields: - if f"metabase.{field}" in meta: - value = meta[f"metabase.{field}"] - vals[field] = value if value is not None else NullValue - return vals - - @staticmethod - def parse_ref(text: str) -> Optional[str]: - """Parses dbt ref() or source() statement. - - Arguments: - text {str} -- Full statement in dbt YAML. - - Returns: - str -- Name of the reference. - """ - - # We are catching the rightmost argument of either source or ref which is ultimately the table name - matches = re.findall(r"['\"]([\w\_\-\ ]+)['\"][ ]*\)$", text.strip()) - if matches: - logger.debug("%s -> %s", text, matches[0]) - return matches[0] - return None + def _model_selected(self, name: str) -> bool: + """Checks whether model passes inclusion/exclusion criteria.""" + n = name.upper() + return n not in self.excludes and (not self.includes or n in self.includes) diff --git a/dbtmetabase/metabase.py b/dbtmetabase/metabase.py deleted file mode 100644 index 393ca16e..00000000 --- a/dbtmetabase/metabase.py +++ /dev/null @@ -1,970 +0,0 @@ -from __future__ import annotations - -import logging -import re -import time -from pathlib import Path -from typing import ( - Any, - Dict, - Iterable, - List, - Mapping, - MutableMapping, - Optional, - Tuple, - Union, -) - -import requests -import yaml -from requests.adapters import HTTPAdapter, Retry - -from ._format import safe_description, safe_name -from .dbt import ( - METABASE_MODEL_DEFAULT_SCHEMA, - MetabaseColumn, - MetabaseModel, - ModelType, - NullValue, -) - -logger = logging.getLogger(__name__) - - -class MetabaseArgumentError(ValueError): - """Invalid Metabase arguments supplied.""" - - -class MetabaseRuntimeError(RuntimeError): - """Metabase execution failed.""" - - -class _MetabaseClientJob: - """Scoped abstraction for jobs depending on the Metabase client.""" - - def __init__(self, client: MetabaseClient): - self.client = client - - -class _ExportModelsJob(_MetabaseClientJob): - """Job abstraction for exporting models.""" - - _SYNC_PERIOD = 5 - - def __init__( - self, - client: MetabaseClient, - database: str, - models: List[MetabaseModel], - exclude_sources: bool, - sync_timeout: int, - ): - super().__init__(client) - - self.database = database - self.models = [ - model - for model in models - if model.model_type != ModelType.sources or not exclude_sources - ] - self.sync_timeout = sync_timeout - - self.tables: Mapping[str, MutableMapping] = {} - self.updates: MutableMapping[str, MutableMapping[str, Any]] = {} - - def execute(self): - success = True - - database_id = None - for database in self.client.api("get", "/api/database"): - if database["name"].upper() == self.database.upper(): - database_id = database["id"] - break - if not database_id: - raise MetabaseRuntimeError(f"Cannot find database by name {self.database}") - - if self.sync_timeout: - self.client.api("post", f"/api/database/{database_id}/sync_schema") - time.sleep(self._SYNC_PERIOD) - - deadline = int(time.time()) + self.sync_timeout - synced = False - while not synced: - tables = self._load_tables(database_id) - - synced = True - for model in self.models: - schema_name = model.schema.upper() - model_name = model.name.upper() - table_key = f"{schema_name}.{model_name}" - - table = tables.get(table_key) - if not table: - logger.warning( - "Model %s not found in %s schema", table_key, schema_name - ) - synced = False - continue - - for column in model.columns: - column_name = column.name.upper() - - field = table.get("fields", 
{}).get(column_name) - if not field: - logger.warning( - "Column %s not found in %s model", column_name, table_key - ) - synced = False - continue - - self.tables = tables - - if int(time.time()) < deadline: - time.sleep(self._SYNC_PERIOD) - - if not synced and self.sync_timeout: - raise MetabaseRuntimeError("Unable to sync models between dbt and Metabase") - - for model in self.models: - success &= self._export_model(model) - - for update in self.updates.values(): - self.client.api( - "put", - f"/api/{update['kind']}/{update['id']}", - json=update["body"], - ) - logger.info( - "API %s/%s updated successfully: %s", - update["kind"], - update["id"], - ", ".join(update.get("body", {}).keys()), - ) - - if not success: - raise MetabaseRuntimeError( - "Model export encountered non-critical errors, check output" - ) - - def queue_update(self, entity: MutableMapping, delta: Mapping): - entity.update(delta) - - key = f"{entity['kind']}.{entity['id']}" - update = self.updates.get(key, {}) - update["kind"] = entity["kind"] - update["id"] = entity["id"] - - body = update.get("body", {}) - body.update(delta) - update["body"] = body - - self.updates[key] = update - - def _export_model(self, model: MetabaseModel) -> bool: - """Exports one dbt model to Metabase database schema. - - Arguments: - model {dict} -- One dbt model read from project. - - Returns: - bool -- True if exported successfully, false if there were errors. - """ - - success = True - - schema_name = model.schema.upper() - model_name = model.name.upper() - table_key = f"{schema_name}.{model_name}" - - api_table = self.tables.get(table_key) - if not api_table: - logger.error("Table %s does not exist in Metabase", table_key) - return False - - # Empty strings not accepted by Metabase - model_display_name = model.display_name or None - model_description = model.description or None - model_points_of_interest = model.points_of_interest or None - model_caveats = model.caveats or None - model_visibility = model.visibility_type or None - - body_table = {} - - # Update if specified, otherwise reset one that had been set - api_display_name = api_table.get("display_name") - if api_display_name != model_display_name and ( - model_display_name - or safe_name(api_display_name) != safe_name(api_table.get("name")) - ): - body_table["display_name"] = model_display_name - - if api_table.get("description") != model_description: - body_table["description"] = model_description - if api_table.get("points_of_interest") != model_points_of_interest: - body_table["points_of_interest"] = model_points_of_interest - if api_table.get("caveats") != model_caveats: - body_table["caveats"] = model_caveats - if api_table.get("visibility_type") != model_visibility: - body_table["visibility_type"] = model_visibility - - if body_table: - self.queue_update(entity=api_table, delta=body_table) - logger.info("Table %s will be updated", table_key) - else: - logger.info("Table %s is up-to-date", table_key) - - for column in model.columns: - success &= self._export_column(schema_name, model_name, column) - - return success - - def _export_column( - self, - schema_name: str, - model_name: str, - column: MetabaseColumn, - ) -> bool: - """Exports one dbt column to Metabase database schema. - - Arguments: - schema_name {str} -- Target schema name.s - model_name {str} -- One dbt model name read from project. - column {dict} -- One dbt column read from project. - - Returns: - bool -- True if exported successfully, false if there were errors. 
- """ - - success = True - - table_key = f"{schema_name}.{model_name}" - column_name = column.name.upper() - - api_field = self.tables.get(table_key, {}).get("fields", {}).get(column_name) - if not api_field: - logger.error( - "Field %s.%s does not exist in Metabase", - table_key, - column_name, - ) - return False - - if "special_type" in api_field: - semantic_type_key = "special_type" - else: - semantic_type_key = "semantic_type" - - fk_target_field_id = None - if column.semantic_type == "type/FK": - # Target table could be aliased if we parse_ref() on a source, so we caught aliases during model parsing - # This way we can unpack any alias mapped to fk_target_table when using yml project reader - target_table = ( - column.fk_target_table.upper() - if column.fk_target_table is not None - else None - ) - target_field = ( - column.fk_target_field.upper() - if column.fk_target_field is not None - else None - ) - - if not target_table or not target_field: - logger.info( - "Skipping FK resolution for %s table, %s field not resolved during dbt parsing", - table_key, - target_field, - ) - - else: - logger.debug( - "Looking for field %s in table %s", - target_field, - target_table, - ) - - fk_target_field = ( - self.tables.get(target_table, {}) - .get("fields", {}) - .get(target_field) - ) - if fk_target_field: - fk_target_field_id = fk_target_field.get("id") - if fk_target_field.get(semantic_type_key) != "type/PK": - logger.info( - "API field/%s will become PK (for %s column FK)", - fk_target_field_id, - column_name, - ) - body_fk_target_field = { - semantic_type_key: "type/PK", - } - self.queue_update( - entity=fk_target_field, delta=body_fk_target_field - ) - else: - logger.info( - "API field/%s is already PK (for %s column FK)", - fk_target_field_id, - column_name, - ) - else: - logger.error( - "Unable to find PK for %s.%s column FK", - target_table, - target_field, - ) - success = False - - # Empty strings not accepted by Metabase - column_description = column.description or None - column_display_name = column.display_name or None - column_visibility = column.visibility_type or "normal" - - # Preserve this relationship by default - if api_field["fk_target_field_id"] and not fk_target_field_id: - fk_target_field_id = api_field["fk_target_field_id"] - - body_field: MutableMapping[str, Optional[Any]] = {} - - # Update if specified, otherwise reset one that had been set - api_display_name = api_field.get("display_name") - if api_display_name != column_display_name and ( - column_display_name - or safe_name(api_display_name) != safe_name(api_field.get("name")) - ): - body_field["display_name"] = column_display_name - - if api_field.get("description") != column_description: - body_field["description"] = column_description - if api_field.get("visibility_type") != column_visibility: - body_field["visibility_type"] = column_visibility - if api_field.get("fk_target_field_id") != fk_target_field_id: - body_field["fk_target_field_id"] = fk_target_field_id - if ( - api_field.get("has_field_values") != column.has_field_values - and column.has_field_values - ): - body_field["has_field_values"] = column.has_field_values - if ( - api_field.get("coercion_strategy") != column.coercion_strategy - and column.coercion_strategy - ): - body_field["coercion_strategy"] = column.coercion_strategy - - settings = api_field.get("settings") or {} - if settings.get("number_style") != column.number_style and column.number_style: - settings["number_style"] = column.number_style - - if settings: - body_field["settings"] = 
settings - - # Allow explicit null type to override detected one - api_semantic_type = api_field.get(semantic_type_key) - if (column.semantic_type and api_semantic_type != column.semantic_type) or ( - column.semantic_type is NullValue and api_semantic_type - ): - body_field[semantic_type_key] = column.semantic_type or None - - if body_field: - self.queue_update(entity=api_field, delta=body_field) - logger.info("Field %s.%s will be updated", model_name, column_name) - else: - logger.info("Field %s.%s is up-to-date", model_name, column_name) - - return success - - def _load_tables(self, database_id: str) -> Mapping[str, MutableMapping]: - tables = {} - - metadata = self.client.api( - "get", - f"/api/database/{database_id}/metadata", - params={"include_hidden": "true"}, - ) - - bigquery_schema = metadata.get("details", {}).get("dataset-id") - - for table in metadata.get("tables", []): - # table[schema] is null for bigquery datasets - table["schema"] = ( - table.get("schema") or bigquery_schema or METABASE_MODEL_DEFAULT_SCHEMA - ).upper() - - fields = {} - for field in table.get("fields", []): - new_field = field.copy() - new_field["kind"] = "field" - - field_name = field["name"].upper() - fields[field_name] = new_field - - new_table = table.copy() - new_table["kind"] = "table" - new_table["fields"] = fields - - schema_name = table["schema"].upper() - table_name = table["name"].upper() - tables[f"{schema_name}.{table_name}"] = new_table - - return tables - - -class _ExtractExposuresJob(_MetabaseClientJob): - _RESOURCE_VERSION = 2 - - # This regex is looking for from and join clauses, and extracting the table part. - # It won't recognize some valid sql table references, such as `from "table with spaces"`. - _EXPOSURE_PARSER = re.compile(r"[FfJj][RrOo][OoIi][MmNn]\s+([\w.\"]+)") - _CTE_PARSER = re.compile( - r"[Ww][Ii][Tt][Hh]\s+\b(\w+)\b\s+as|[)]\s*[,]\s*\b(\w+)\b\s+as" - ) - - class DbtDumper(yaml.Dumper): - def increase_indent(self, flow=False, indentless=False): - return super().increase_indent(flow, indentless=False) - - def __init__( - self, - client: MetabaseClient, - models: List[MetabaseModel], - output_path: str, - output_grouping: Optional[str], - include_personal_collections: bool, - collection_includes: Optional[Iterable], - collection_excludes: Optional[Iterable], - ): - super().__init__(client) - - self.model_refs = {model.name.upper(): model.ref for model in models} - self.output_path = Path(output_path).expanduser() - - if output_grouping in (None, "collection", "type"): - self.output_grouping = output_grouping - else: - raise ValueError(f"Unsupported output_grouping: {output_grouping}") - - self.include_personal_collections = include_personal_collections - self.collection_includes = collection_includes or [] - self.collection_excludes = collection_excludes or [] - - self.table_names: Mapping = {} - self.models_exposed: List = [] - self.native_query: str = "" - - def execute(self) -> Mapping: - """Extracts exposures in Metabase downstream of dbt models and sources as parsed by dbt reader. - - Returns: - Mapping: JSON object representation of all exposures parsed. 
- """ - - self.table_names = { - table["id"]: table["name"] for table in self.client.api("get", "/api/table") - } - - documented_exposure_names = [] - parsed_exposures = [] - - for collection in self.client.api("get", "/api/collection"): - # Inclusion/exclusion criteria check - name_included = ( - collection["name"] in self.collection_includes - or not self.collection_includes - ) - name_excluded = collection["name"] in self.collection_excludes - personal_included = self.include_personal_collections or not collection.get( - "personal_owner_id" - ) - if not name_included or name_excluded or not personal_included: - logging.debug("Skipping collection %s", collection["name"]) - continue - - # Iter through collection - logger.info("Exploring collection %s", collection["name"]) - for item in self.client.api( - "get", f"/api/collection/{collection['id']}/items" - ): - # Ensure collection item is of parsable type - exposure_type = item["model"] - exposure_id = item["id"] - if exposure_type not in ("card", "dashboard"): - continue - - # Prepare attributes for population through _extract_card_exposures calls - self.models_exposed = [] - self.native_query = "" - native_query = "" - - exposure = self.client.api("get", f"/api/{exposure_type}/{exposure_id}") - exposure_name = exposure.get("name", "Exposure [Unresolved Name]") - logger.info( - "Introspecting exposure: %s", - exposure_name, - ) - - header = None - creator_name = None - creator_email = None - - # Process exposure - if exposure_type == "card": - # Build header for card and extract models to self.models_exposed - header = "### Visualization: {}\n\n".format( - exposure.get("display", "Unknown").title() - ) - - # Parse Metabase question - self._extract_card_exposures(exposure_id, exposure) - native_query = self.native_query - - elif exposure_type == "dashboard": - # We expect this dict key in order to iter through questions - if "ordered_cards" not in exposure: - continue - - # Build header for dashboard and extract models for each question to self.models_exposed - header = "### Dashboard Cards: {}\n\n".format( - str(len(exposure["ordered_cards"])) - ) - - # Iterate through dashboard questions - for dashboard_item in exposure["ordered_cards"]: - dashboard_item_reference = dashboard_item.get("card", {}) - if "id" not in dashboard_item_reference: - continue - - # Parse Metabase question - self._extract_card_exposures(dashboard_item_reference["id"]) - - if not self.models_exposed: - logger.info("No models mapped to exposure") - - # Extract creator info - if "creator" in exposure: - creator_email = exposure["creator"]["email"] - creator_name = exposure["creator"]["common_name"] - elif "creator_id" in exposure: - # If a metabase user is deactivated, the API returns a 404 - try: - creator = self.client.api( - "get", f"/api/user/{exposure['creator_id']}" - ) - except requests.exceptions.HTTPError as error: - creator = {} - if error.response is None or error.response.status_code != 404: - raise - - creator_name = creator.get("common_name") - creator_email = creator.get("email") - - exposure_label = exposure_name - # Only letters, numbers and underscores allowed in model names in dbt docs DAG / no duplicate model names - exposure_name = safe_name(exposure_name) - enumer = 1 - while exposure_name in documented_exposure_names: - exposure_name = f"{exposure_name}_{enumer}" - enumer += 1 - - # Construct exposure - parsed_exposures.append( - { - "id": item["id"], - "type": item["model"], - "collection": collection, - "exposure": self._build_exposure( - 
exposure_type=exposure_type, - exposure_id=exposure_id, - name=exposure_name, - label=exposure_label, - header=header or "", - created_at=exposure["created_at"], - creator_name=creator_name or "", - creator_email=creator_email or "", - description=exposure.get("description", ""), - native_query=native_query, - ), - } - ) - - documented_exposure_names.append(exposure_name) - - for group, exposures in self._group_exposures(parsed_exposures).items(): - path = self.output_path.joinpath(*group[:-1]) / f"{group[-1]}.yml" - path.parent.mkdir(parents=True, exist_ok=True) - - exposures_unwrapped = map(lambda x: x["exposure"], exposures) - exposures_sorted = sorted(exposures_unwrapped, key=lambda x: x["name"]) - - with open(path, "w", encoding="utf-8") as f: - yaml.dump( - { - "version": self._RESOURCE_VERSION, - "exposures": exposures_sorted, - }, - f, - Dumper=self.DbtDumper, - default_flow_style=False, - allow_unicode=True, - sort_keys=False, - ) - - return {"exposures": parsed_exposures} # todo: decide on output? - - def _extract_card_exposures( - self, - card_id: int, - exposure: Optional[Mapping] = None, - ): - """Extracts exposures from Metabase questions populating `self.models_exposed` - - Arguments: - card_id {int} -- Id of Metabase question used to pull question from api - - Keyword Arguments: - exposure {str} -- JSON api response from a question in Metabase, allows us to use the object if already in memory - - Returns: - None -- self.models_exposed is populated through this method. - """ - - # If an exposure is not passed, pull from id - if not exposure: - exposure = self.client.api("get", f"/api/card/{card_id}") - - query = exposure.get("dataset_query", {}) - - if query.get("type") == "query": - # Metabase GUI derived query - source_table_id = query.get("query", {}).get( - "source-table", exposure.get("table_id") - ) - - if str(source_table_id).startswith("card__"): - # Handle questions based on other question in virtual db - self._extract_card_exposures(int(source_table_id.split("__")[-1])) - else: - # Normal question - source_table = self.table_names.get(source_table_id) - if source_table: - logger.info( - "Model extracted from Metabase question: %s", - source_table, - ) - self.models_exposed.append(source_table) - - # Find models exposed through joins - for query_join in query.get("query", {}).get("joins", []): - # Handle questions based on other question in virtual db - if str(query_join.get("source-table", "")).startswith("card__"): - self._extract_card_exposures( - int(query_join.get("source-table").split("__")[-1]) - ) - continue - - # Joined model parsed - joined_table = self.table_names.get(query_join.get("source-table")) - if joined_table: - logger.info( - "Model extracted from Metabase question join: %s", - joined_table, - ) - self.models_exposed.append(joined_table) - - elif query.get("type") == "native": - # Metabase native query - native_query = query["native"].get("query") - ctes: List[str] = [] - - # Parse common table expressions for exclusion - for matched_cte in re.findall(self._CTE_PARSER, native_query): - ctes.extend(group.upper() for group in matched_cte if group) - - # Parse SQL for exposures through FROM or JOIN clauses - for sql_ref in re.findall(self._EXPOSURE_PARSER, native_query): - # Grab just the table / model name - clean_exposure = sql_ref.split(".")[-1].strip('"').upper() - - # Scrub CTEs (qualified sql_refs can not reference CTEs) - if clean_exposure in ctes and "." 
not in sql_ref: - continue - # Verify this is one of our parsed refable models so exposures dont break the DAG - if not self.model_refs.get(clean_exposure): - continue - - if clean_exposure: - logger.info( - "Model extracted from native query: %s", - clean_exposure, - ) - self.models_exposed.append(clean_exposure) - self.native_query = native_query - - def _build_exposure( - self, - exposure_type: str, - exposure_id: int, - name: str, - label: str, - header: str, - created_at: str, - creator_name: str, - creator_email: str, - description: str = "", - native_query: str = "", - ) -> Mapping: - """Builds an exposure object representation as defined here: https://docs.getdbt.com/reference/exposure-properties - - Arguments: - exposure_type {str} -- Model type in Metabase being either `card` or `dashboard` - exposure_id {str} -- Card or Dashboard id in Metabase - name {str} -- Name of exposure - label {str} -- Title of the card or dashboard in Metabase - header {str} -- The header goes at the top of the description and is useful for prefixing metadata - created_at {str} -- Timestamp of exposure creation derived from Metabase - creator_name {str} -- Creator name derived from Metabase - creator_email {str} -- Creator email derived from Metabase - - Keyword Arguments: - description {str} -- The description of the exposure as documented in Metabase. (default: No description provided in Metabase) - native_query {str} -- If exposure contains SQL, this arg will include the SQL in the dbt exposure documentation. (default: {""}) - - Returns: - Mapping -- JSON object representation of single exposure. - """ - - # Ensure model type is compatible - assert exposure_type in ( - "card", - "dashboard", - ), "Cannot construct exposure for object type of {}".format(exposure_type) - - if native_query: - # Format query into markdown code block - native_query = "#### Query\n\n```\n{}\n```\n\n".format( - "\n".join( - sql_line - for sql_line in self.native_query.strip().split("\n") - if sql_line.strip() != "" - ) - ) - - if not description: - description = "No description provided in Metabase\n\n" - - # Format metadata as markdown - metadata = ( - "#### Metadata\n\n" - + "Metabase Id: __{}__\n\n".format(exposure_id) - + "Created On: __{}__".format(created_at) - ) - - # Build description - description = ( - header + ("{}\n\n".format(description.strip())) + native_query + metadata - ) - - # Output exposure - return { - "name": name, - "label": label, - "description": safe_description(description), - "type": "analysis" if exposure_type == "card" else "dashboard", - "url": f"{self.client.url}/{exposure_type}/{exposure_id}", - "maturity": "medium", - "owner": { - "name": creator_name, - "email": creator_email, - }, - "depends_on": list( - { - self.model_refs[exposure.upper()] - for exposure in list({m for m in self.models_exposed}) - if exposure.upper() in self.model_refs - } - ), - } - - def _group_exposures( - self, exposures: Iterable[Mapping] - ) -> Mapping[Tuple[str, ...], Iterable[Mapping]]: - """Group exposures by configured output grouping. - - Args: - exposures (Iterable[Mapping]): Collection of exposures. - - Returns: - Mapping[Tuple[str, ...], Iterable[Mapping]]: Exposures indexed by configured grouping. - """ - - results: Dict[Tuple[str, ...], List[Mapping]] = {} - - for exposure in exposures: - group: Tuple[str, ...] 
= ("exposures",) - if self.output_grouping == "collection": - collection = exposure["collection"] - group = (collection.get("slug") or safe_name(collection["name"]),) - elif self.output_grouping == "type": - group = (exposure["type"], exposure["id"]) - - result = results.get(group, []) - result.append(exposure) - if group not in results: - results[group] = result - - return results - - -class MetabaseClient: - """Metabase API client.""" - - def __init__( - self, - url: str, - username: Optional[str] = None, - password: Optional[str] = None, - session_id: Optional[str] = None, - verify: bool = True, - cert: Optional[Union[str, Tuple[str, str]]] = None, - http_timeout: int = 15, - http_headers: Optional[dict] = None, - http_adapter: Optional[HTTPAdapter] = None, - ): - """New Metabase client. - - Args: - url (str): Metabase URL, e.g. "https://metabase.example.com". - username (Optional[str], optional): Metabase username (required unless providing session_id). Defaults to None. - password (Optional[str], optional): Metabase password (required unless providing session_id). Defaults to None. - session_id (Optional[str], optional): Metabase session ID. Defaults to None. - verify (bool, optional): Verify the TLS certificate at the Metabase end. Defaults to True. - cert (Optional[Union[str, Tuple[str, str]]], optional): Path to a custom certificate. Defaults to None. - http_timeout (int, optional): HTTP request timeout in secs. Defaults to 15. - http_headers (Optional[dict], optional): Additional HTTP headers. Defaults to None. - http_adapter (Optional[HTTPAdapter], optional): Custom requests HTTP adapter. Defaults to None. - """ - - self.url = url.rstrip("/") - - self.http_timeout = http_timeout - - self.session = requests.Session() - self.session.verify = verify - self.session.cert = cert - - if http_headers: - self.session.headers.update(http_headers) - - self.session.mount( - self.url, - http_adapter or HTTPAdapter(max_retries=Retry(total=3, backoff_factor=1)), - ) - - if not session_id: - if username and password: - session = self.api( - "post", - "/api/session", - json={"username": username, "password": password}, - ) - session_id = str(session["id"]) - else: - raise MetabaseArgumentError("Credentials or session ID required") - self.session.headers["X-Metabase-Session"] = session_id - - logger.info("Session established successfully") - - def api( - self, - method: str, - path: str, - critical: bool = True, - **kwargs, - ) -> Mapping: - """Unified way of calling Metabase API. - - Args: - method (str): HTTP verb, e.g. get, post, put. - path (str): Relative path of endpoint, e.g. /api/database. - critical (bool, optional): Raise on any HTTP errors. Defaults to True. - - Returns: - Mapping: JSON payload of the endpoint. - """ - - response = self.session.request( - method, - f"{self.url}{path}", - timeout=self.http_timeout, - **kwargs, - ) - - try: - response.raise_for_status() - except requests.exceptions.HTTPError: - if critical: - logger.error("HTTP request failed: %s", response.text) - raise - return {} - - response_json = response.json() - if "data" in response_json: - # Since X.40.0 responses are encapsulated in "data" with pagination parameters - return response_json["data"] - - return response_json - - def export_models( - self, - database: str, - models: List[MetabaseModel], - exclude_sources: bool = False, - sync_timeout: int = 30, - ): - """Exports dbt models to Metabase database schema. - - Args: - database (str): Metabase database name. 
- models (List[MetabaseModel]): List of dbt models read from project. - exclude_sources (bool, optional): Exclude dbt sources from export. Defaults to False. - """ - _ExportModelsJob( - client=self, - database=database, - models=models, - exclude_sources=exclude_sources, - sync_timeout=sync_timeout, - ).execute() - - def extract_exposures( - self, - models: List[MetabaseModel], - output_path: str = ".", - output_grouping: Optional[str] = None, - include_personal_collections: bool = True, - collection_includes: Optional[Iterable] = None, - collection_excludes: Optional[Iterable] = None, - ) -> Mapping: - """Extracts exposures in Metabase downstream of dbt models and sources as parsed by dbt reader. - - Args: - models (List[MetabaseModel]): List of dbt models. - output_path (str, optional): Path for output files. Defaults to ".". - output_grouping (Optional[str], optional): Grouping for output YAML files, supported values: "collection" (by collection slug) or "type" (by entity type). Defaults to None. - include_personal_collections (bool, optional): Include personal Metabase collections. Defaults to True. - collection_includes (Optional[Iterable], optional): Include certain Metabase collections. Defaults to None. - collection_excludes (Optional[Iterable], optional): Exclude certain Metabase collections. Defaults to None. - - Returns: - Mapping: _description_ - """ - return _ExtractExposuresJob( - client=self, - models=models, - output_path=output_path, - output_grouping=output_grouping, - include_personal_collections=include_personal_collections, - collection_includes=collection_includes, - collection_excludes=collection_excludes, - ).execute() diff --git a/dbtmetabase/metabase/__init__.py b/dbtmetabase/metabase/__init__.py new file mode 100644 index 00000000..4333f04e --- /dev/null +++ b/dbtmetabase/metabase/__init__.py @@ -0,0 +1,9 @@ +from .client import MetabaseClient +from .exposures import MetabaseExposuresClient +from .models import MetabaseModelsClient + +__all__ = [ + "MetabaseClient", + "MetabaseExposuresClient", + "MetabaseModelsClient", +] diff --git a/dbtmetabase/metabase/client.py b/dbtmetabase/metabase/client.py new file mode 100644 index 00000000..e8466ac9 --- /dev/null +++ b/dbtmetabase/metabase/client.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import logging +from typing import Mapping, Optional, Tuple, Union + +import requests +from requests.adapters import HTTPAdapter, Retry + +from .errors import MetabaseArgumentError +from .exposures import MetabaseExposuresClient +from .models import MetabaseModelsClient + +_logger = logging.getLogger(__name__) + + +class MetabaseClient(MetabaseModelsClient, MetabaseExposuresClient): + """Metabase API client.""" + + DEFAULT_VERIFY = True + DEFAULT_HTTP_TIMEOUT = 15 + + def __init__( + self, + url: str, + username: Optional[str] = None, + password: Optional[str] = None, + session_id: Optional[str] = None, + verify: bool = DEFAULT_VERIFY, + cert: Optional[Union[str, Tuple[str, str]]] = None, + http_timeout: int = DEFAULT_HTTP_TIMEOUT, + http_headers: Optional[dict] = None, + http_adapter: Optional[HTTPAdapter] = None, + ): + """New Metabase client. + + Args: + url (str): Metabase URL, e.g. "https://metabase.example.com". + username (Optional[str], optional): Metabase username (required unless providing session_id). Defaults to None. + password (Optional[str], optional): Metabase password (required unless providing session_id). Defaults to None. + session_id (Optional[str], optional): Metabase session ID. 
Defaults to None. + verify (bool, optional): Verify the TLS certificate at the Metabase end. Defaults to True. + cert (Optional[Union[str, Tuple[str, str]]], optional): Path to a custom certificate. Defaults to None. + http_timeout (int, optional): HTTP request timeout in secs. Defaults to 15. + http_headers (Optional[dict], optional): Additional HTTP headers. Defaults to None. + http_adapter (Optional[HTTPAdapter], optional): Custom requests HTTP adapter. Defaults to None. + """ + + self.url = url.rstrip("/") + + self.http_timeout = http_timeout + + self.session = requests.Session() + self.session.verify = verify + self.session.cert = cert + + if http_headers: + self.session.headers.update(http_headers) + + self.session.mount( + self.url, + http_adapter or HTTPAdapter(max_retries=Retry(total=3, backoff_factor=1)), + ) + + if not session_id: + if username and password: + session = self.api( + "post", + "/api/session", + json={"username": username, "password": password}, + ) + session_id = str(session["id"]) + else: + raise MetabaseArgumentError("Credentials or session ID required") + self.session.headers["X-Metabase-Session"] = session_id + + _logger.info("Session established successfully") + + def api( + self, + method: str, + path: str, + critical: bool = True, + **kwargs, + ) -> Mapping: + """Unified way of calling Metabase API. + + Args: + method (str): HTTP verb, e.g. get, post, put. + path (str): Relative path of endpoint, e.g. /api/database. + critical (bool, optional): Raise on any HTTP errors. Defaults to True. + + Returns: + Mapping: JSON payload of the endpoint. + """ + + response = self.session.request( + method, + f"{self.url}{path}", + timeout=self.http_timeout, + **kwargs, + ) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError: + if critical: + _logger.error("HTTP request failed: %s", response.text) + raise + return {} + + response_json = response.json() + if "data" in response_json: + # Since X.40.0 responses are encapsulated in "data" with pagination parameters + return response_json["data"] + + return response_json + + def format_url(self, path: str) -> str: + return self.url + path diff --git a/dbtmetabase/metabase/errors.py b/dbtmetabase/metabase/errors.py new file mode 100644 index 00000000..9a77268b --- /dev/null +++ b/dbtmetabase/metabase/errors.py @@ -0,0 +1,6 @@ +class MetabaseArgumentError(ValueError): + """Invalid Metabase arguments supplied.""" + + +class MetabaseRuntimeError(RuntimeError): + """Metabase execution failed.""" diff --git a/dbtmetabase/metabase/exposures.py b/dbtmetabase/metabase/exposures.py new file mode 100644 index 00000000..de0ee18a --- /dev/null +++ b/dbtmetabase/metabase/exposures.py @@ -0,0 +1,454 @@ +from __future__ import annotations + +import logging +import re +from abc import ABC, abstractmethod +from operator import itemgetter +from pathlib import Path +from typing import Iterable, Mapping, MutableMapping, MutableSequence, Optional, Tuple + +import requests +import yaml + +from .._format import safe_description, safe_name +from ..dbt.reader import MetabaseModel +from .errors import MetabaseArgumentError + +_RESOURCE_VERSION = 2 + +# Extracting table in `from` and `join` clauses (won't recognize some valid SQL, e.g. 
`from "table with spaces"`) +_EXPOSURE_PARSER = re.compile(r"[FfJj][RrOo][OoIi][MmNn]\s+([\w.\"]+)") +_CTE_PARSER = re.compile( + r"[Ww][Ii][Tt][Hh]\s+\b(\w+)\b\s+as|[)]\s*[,]\s*\b(\w+)\b\s+as" +) + +_logger = logging.getLogger(__name__) + + +class MetabaseExposuresClient(ABC): + """Abstraction for extracting exposures.""" + + DEFAULT_OUTPUT_PATH = "." + DEFAULT_INCLUDE_PERSONAL_COLLECTIONS = True + + @abstractmethod + def api(self, method: str, path: str, **kwargs) -> Mapping: + pass + + @abstractmethod + def format_url(self, path: str) -> str: + pass + + def extract_exposures( + self, + models: Iterable[MetabaseModel], + output_path: str = DEFAULT_OUTPUT_PATH, + output_grouping: Optional[str] = None, + include_personal_collections: bool = DEFAULT_INCLUDE_PERSONAL_COLLECTIONS, + collection_includes: Optional[Iterable] = None, + collection_excludes: Optional[Iterable] = None, + ) -> Iterable[Mapping]: + """Extracts exposures in Metabase downstream of dbt models and sources as parsed by dbt reader. + + Args: + models (Iterable[MetabaseModel]): List of dbt models. + output_path (str, optional): Path for output files. Defaults to ".". + output_grouping (Optional[str], optional): Grouping for output YAML files, supported values: "collection" (by collection slug) or "type" (by entity type). Defaults to None. + include_personal_collections (bool, optional): Include personal Metabase collections. Defaults to True. + collection_includes (Optional[Iterable], optional): Include certain Metabase collections. Defaults to None. + collection_excludes (Optional[Iterable], optional): Exclude certain Metabase collections. Defaults to None. + + Returns: + Iterable[Mapping]: List of parsed exposures. + """ + + if output_grouping not in (None, "collection", "type"): + raise MetabaseArgumentError( + f"Unsupported output_grouping: {output_grouping}" + ) + + collection_includes = collection_includes or [] + collection_excludes = collection_excludes or [] + + ctx = _Context( + model_refs={m.name.upper(): m.ref for m in models if m.ref}, + table_names={t["id"]: t["name"] for t in self.api("get", "/api/table")}, + ) + + exposures = [] + exposure_counts: MutableMapping[str, int] = {} + + for collection in self.api("get", "/api/collection"): + # Inclusion/exclusion check + name_included = ( + collection["name"] in collection_includes or not collection_includes + ) + name_excluded = collection["name"] in collection_excludes + personal_included = ( + not collection.get("personal_owner_id") or include_personal_collections + ) + if not name_included or name_excluded or not personal_included: + _logger.debug("Skipping collection %s", collection["name"]) + continue + + _logger.info("Exploring collection %s", collection["name"]) + for item in self.api("get", f"/api/collection/{collection['id']}/items"): + # Ensure collection item is of parsable type + exposure_type = item["model"] + exposure_id = item["id"] + if exposure_type not in ("card", "dashboard"): + continue + + # Prepare attributes for population through _extract_card_exposures calls + ctx.models_exposed = [] + ctx.native_query = "" + native_query = "" + + exposure = self.api("get", f"/api/{exposure_type}/{exposure_id}") + exposure_name = exposure.get("name", "Exposure [Unresolved Name]") + _logger.info("Introspecting exposure: %s", exposure_name) + + header = None + creator_name = None + creator_email = None + + # Process exposure + if exposure_type == "card": + # Build header for card and extract models to self.models_exposed + header = "### Visualization: 
{}\n\n".format( + exposure.get("display", "Unknown").title() + ) + + # Parse Metabase question + self.__extract_card_exposures(ctx, exposure_id, exposure) + native_query = ctx.native_query + + elif exposure_type == "dashboard": + # We expect this dict key in order to iter through questions + if "ordered_cards" not in exposure: + continue + + # Build header for dashboard and extract models for each question to self.models_exposed + header = "### Dashboard Cards: {}\n\n".format( + str(len(exposure["ordered_cards"])) + ) + + # Iterate through dashboard questions + for dashboard_item in exposure["ordered_cards"]: + dashboard_item_reference = dashboard_item.get("card", {}) + if "id" not in dashboard_item_reference: + continue + + # Parse Metabase question + self.__extract_card_exposures( + ctx, dashboard_item_reference["id"] + ) + + if not ctx.models_exposed: + _logger.info("No models mapped to exposure") + + # Extract creator info + if "creator" in exposure: + creator_email = exposure["creator"]["email"] + creator_name = exposure["creator"]["common_name"] + elif "creator_id" in exposure: + try: + creator = self.api("get", f"/api/user/{exposure['creator_id']}") + except requests.exceptions.HTTPError as error: + # If a Metabase user is deactivated, the API returns a 404 + creator = {} + if error.response is None or error.response.status_code != 404: + raise + + creator_name = creator.get("common_name") + creator_email = creator.get("email") + + exposure_label = exposure_name + # Unique names with letters, numbers and underscores allowed in dbt docs DAG + exposure_name = safe_name(exposure_name) + exposure_count = exposure_counts.get(exposure_name, 0) + exposure_counts[exposure_name] = exposure_count + 1 + exposure_suffix = f"_{exposure_count}" if exposure_count > 0 else "" + + exposures.append( + { + "id": item["id"], + "type": item["model"], + "collection": collection, + "body": self.__build_exposure( + ctx, + exposure_type=exposure_type, + exposure_id=exposure_id, + name=exposure_name + exposure_suffix, + label=exposure_label, + header=header or "", + created_at=exposure["created_at"], + creator_name=creator_name or "", + creator_email=creator_email or "", + description=exposure.get("description", ""), + native_query=native_query, + ), + } + ) + + self.__write_exposures(exposures, output_path, output_grouping) + + return exposures + + def __extract_card_exposures( + self, + ctx: _Context, + card_id: int, + exposure: Optional[Mapping] = None, + ): + """Extracts exposures from Metabase questions populating `ctx.models_exposed` + + Arguments: + card_id {int} -- Metabase question ID used to pull question from API. + + Keyword Arguments: + exposure {str} -- API response from a question in Metabase, allows us to use the object if already in memory. + + Returns: + None -- ctx.models_exposed is populated through this method. 
+ """ + + # If an exposure is not passed, pull from id + if not exposure: + exposure = self.api("get", f"/api/card/{card_id}") + + query = exposure.get("dataset_query", {}) + + if query.get("type") == "query": + # Metabase GUI derived query + source_table_id = query.get("query", {}).get( + "source-table", exposure.get("table_id") + ) + + if str(source_table_id).startswith("card__"): + # Handle questions based on other question in virtual db + self.__extract_card_exposures( + ctx, + card_id=int(source_table_id.split("__")[-1]), + ) + else: + # Normal question + source_table = ctx.table_names.get(source_table_id) + if source_table: + _logger.info( + "Model extracted from Metabase question: %s", + source_table, + ) + ctx.models_exposed.append(source_table) + + # Find models exposed through joins + for query_join in query.get("query", {}).get("joins", []): + # Handle questions based on other question in virtual db + if str(query_join.get("source-table", "")).startswith("card__"): + self.__extract_card_exposures( + ctx, + card_id=int(query_join.get("source-table").split("__")[-1]), + ) + continue + + # Joined model parsed + joined_table = ctx.table_names.get(query_join.get("source-table")) + if joined_table: + _logger.info( + "Model extracted from Metabase question join: %s", + joined_table, + ) + ctx.models_exposed.append(joined_table) + + elif query.get("type") == "native": + # Metabase native query + native_query = query["native"].get("query") + ctes: MutableSequence[str] = [] + + # Parse common table expressions for exclusion + for matched_cte in re.findall(_CTE_PARSER, native_query): + ctes.extend(group.upper() for group in matched_cte if group) + + # Parse SQL for exposures through FROM or JOIN clauses + for sql_ref in re.findall(_EXPOSURE_PARSER, native_query): + # Grab just the table / model name + clean_exposure = sql_ref.split(".")[-1].strip('"').upper() + + # Scrub CTEs (qualified sql_refs can not reference CTEs) + if clean_exposure in ctes and "." not in sql_ref: + continue + # Verify this is one of our parsed refable models so exposures dont break the DAG + if not ctx.model_refs.get(clean_exposure): + continue + + if clean_exposure: + _logger.info( + "Model extracted from native query: %s", + clean_exposure, + ) + ctx.models_exposed.append(clean_exposure) + ctx.native_query = native_query + + def __build_exposure( + self, + ctx: _Context, + exposure_type: str, + exposure_id: int, + name: str, + label: str, + header: str, + created_at: str, + creator_name: str, + creator_email: str, + description: str = "", + native_query: str = "", + ) -> Mapping: + """Builds an exposure object representation as defined here: https://docs.getdbt.com/reference/exposure-properties + + Arguments: + exposure_type {str} -- Model type in Metabase being either `card` or `dashboard` + exposure_id {str} -- Card or Dashboard id in Metabase + name {str} -- Name of exposure + label {str} -- Title of the card or dashboard in Metabase + header {str} -- The header goes at the top of the description and is useful for prefixing metadata + created_at {str} -- Timestamp of exposure creation derived from Metabase + creator_name {str} -- Creator name derived from Metabase + creator_email {str} -- Creator email derived from Metabase + + Keyword Arguments: + description {str} -- The description of the exposure as documented in Metabase. (default: No description provided in Metabase) + native_query {str} -- If exposure contains SQL, this arg will include the SQL in the dbt exposure documentation. 
(default: {""}) + + Returns: + Mapping -- JSON object representation of single exposure. + """ + + # Ensure model type is compatible + assert exposure_type in ( + "card", + "dashboard", + ), "Cannot construct exposure for object type of {}".format(exposure_type) + + if native_query: + # Format query into markdown code block + native_query = "#### Query\n\n```\n{}\n```\n\n".format( + "\n".join( + sql_line + for sql_line in ctx.native_query.strip().split("\n") + if sql_line.strip() != "" + ) + ) + + if not description: + description = "No description provided in Metabase\n\n" + + # Format metadata as markdown + metadata = ( + "#### Metadata\n\n" + + "Metabase Id: __{}__\n\n".format(exposure_id) + + "Created On: __{}__".format(created_at) + ) + + # Build description + description = ( + header + ("{}\n\n".format(description.strip())) + native_query + metadata + ) + + # Output exposure + return { + "name": name, + "label": label, + "description": safe_description(description), + "type": "analysis" if exposure_type == "card" else "dashboard", + "url": self.format_url(f"/{exposure_type}/{exposure_id}"), + "maturity": "medium", + "owner": { + "name": creator_name, + "email": creator_email, + }, + "depends_on": list( + { + ctx.model_refs[exposure.upper()] + for exposure in list({m for m in ctx.models_exposed}) + if exposure.upper() in ctx.model_refs + } + ), + } + + def __write_exposures( + self, + exposures: Iterable[Mapping], + output_path: str, + output_grouping: Optional[str], + ): + """Write exposures to output files. + + Args: + output_path (str): Path for output files. + exposures (Iterable[Mapping]): Collection of exposures. + """ + + for group, exp in self.__group_exposures(exposures, output_grouping).items(): + path = Path(output_path).expanduser() + path = path.joinpath(*group[:-1]) / f"{group[-1]}.yml" + path.parent.mkdir(parents=True, exist_ok=True) + + exposures_unwrapped = map(lambda x: x["body"], exp) + exposures_sorted = sorted(exposures_unwrapped, key=itemgetter("name")) + + with open(path, "w", encoding="utf-8") as f: + yaml.dump( + { + "version": _RESOURCE_VERSION, + "exposures": exposures_sorted, + }, + f, + Dumper=_YAMLDumper, + default_flow_style=False, + allow_unicode=True, + sort_keys=False, + ) + + def __group_exposures( + self, + exposures: Iterable[Mapping], + output_grouping: Optional[str], + ) -> Mapping[Tuple[str, ...], Iterable[Mapping]]: + """Group exposures by configured output grouping. + + Args: + exposures (Iterable[Mapping]): Collection of exposures. + + Returns: + Mapping[Tuple[str, ...], Iterable[Mapping]]: Exposures indexed by configured grouping. + """ + + results: MutableMapping[Tuple[str, ...], MutableSequence[Mapping]] = {} + + for exposure in exposures: + group: Tuple[str, ...] 
= ("exposures",) + if output_grouping == "collection": + collection = exposure["collection"] + group = (collection.get("slug") or safe_name(collection["name"]),) + elif output_grouping == "type": + group = (exposure["type"], exposure["id"]) + + result = results.get(group, []) + result.append(exposure) + if group not in results: + results[group] = result + + return results + + +class _Context: + def __init__(self, model_refs: Mapping[str, str], table_names: Mapping[str, str]): + self.model_refs = model_refs + self.table_names = table_names + self.models_exposed: MutableSequence[str] = [] + self.native_query = "" + + +class _YAMLDumper(yaml.Dumper): + def increase_indent(self, flow=False, indentless=False): + return super().increase_indent(flow, indentless=False) diff --git a/dbtmetabase/metabase/models.py b/dbtmetabase/metabase/models.py new file mode 100644 index 00000000..291442de --- /dev/null +++ b/dbtmetabase/metabase/models.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import logging +import time +from abc import ABC, abstractmethod +from typing import Any, Iterable, Mapping, MutableMapping, Optional + +from .._format import NullValue, safe_name +from ..dbt import DEFAULT_SCHEMA, MetabaseColumn, MetabaseModel, ResourceType +from .errors import MetabaseRuntimeError + +_SYNC_PERIOD = 5 + +logger = logging.getLogger(__name__) + + +class MetabaseModelsClient(ABC): + """Abstraction for exporting models.""" + + DEFAULT_SYNC_TIMEOUT = 30 + DEFAULT_EXCLUDE_SOURCES = False + + @abstractmethod + def api(self, method: str, path: str, **kwargs) -> Mapping: + pass + + def export_models( + self, + database: str, + models: Iterable[MetabaseModel], + exclude_sources: bool = DEFAULT_EXCLUDE_SOURCES, + sync_timeout: int = DEFAULT_SYNC_TIMEOUT, + ): + """Exports dbt models to Metabase database schema. + + Args: + database (str): Metabase database name. + models (List[MetabaseModel]): List of dbt models read from project. + exclude_sources (bool, optional): Exclude dbt sources from export. Defaults to False. 
+ """ + + ctx = _Context() + + success = True + + database_id = None + for api_database in self.api("get", "/api/database"): + if api_database["name"].upper() == database.upper(): + database_id = api_database["id"] + break + if not database_id: + raise MetabaseRuntimeError(f"Cannot find database by name {database}") + + if sync_timeout: + self.api("post", f"/api/database/{database_id}/sync_schema") + time.sleep(_SYNC_PERIOD) + + deadline = int(time.time()) + sync_timeout + synced = False + while not synced: + tables = self.__get_tables(database_id) + + synced = True + for model in models: + schema_name = model.schema.upper() + model_name = model.name.upper() + table_key = f"{schema_name}.{model_name}" + + table = tables.get(table_key) + if not table: + logger.warning( + "Model %s not found in %s schema", table_key, schema_name + ) + synced = False + continue + + for column in model.columns: + column_name = column.name.upper() + + field = table.get("fields", {}).get(column_name) + if not field: + logger.warning( + "Column %s not found in %s model", column_name, table_key + ) + synced = False + continue + + ctx.tables = tables + + if int(time.time()) < deadline: + time.sleep(_SYNC_PERIOD) + + if not synced and sync_timeout: + raise MetabaseRuntimeError("Unable to sync models between dbt and Metabase") + + models = [ + model + for model in models + if model.res_type != ResourceType.source or not exclude_sources + ] + for model in models: + success &= self.__export_model(ctx, model) + + for update in ctx.updates.values(): + self.api( + "put", + f"/api/{update['res_type']}/{update['id']}", + json=update["body"], + ) + logger.info( + "API %s/%s updated successfully: %s", + update["res_type"], + update["id"], + ", ".join(update.get("body", {}).keys()), + ) + + if not success: + raise MetabaseRuntimeError( + "Model export encountered non-critical errors, check output" + ) + + def __export_model(self, ctx: _Context, model: MetabaseModel) -> bool: + """Exports one dbt model to Metabase database schema. + + Arguments: + model {dict} -- One dbt model read from project. + + Returns: + bool -- True if exported successfully, false if there were errors. 
+ """ + + success = True + + schema_name = model.schema.upper() + model_name = model.name.upper() + table_key = f"{schema_name}.{model_name}" + + api_table = ctx.tables.get(table_key) + if not api_table: + logger.error("Table %s does not exist in Metabase", table_key) + return False + + # Empty strings not accepted by Metabase + model_display_name = model.display_name or None + model_description = model.description or None + model_points_of_interest = model.points_of_interest or None + model_caveats = model.caveats or None + model_visibility = model.visibility_type or None + + body_table = {} + + # Update if specified, otherwise reset one that had been set + api_display_name = api_table.get("display_name") + if api_display_name != model_display_name and ( + model_display_name + or safe_name(api_display_name) != safe_name(api_table.get("name")) + ): + body_table["display_name"] = model_display_name + + if api_table.get("description") != model_description: + body_table["description"] = model_description + if api_table.get("points_of_interest") != model_points_of_interest: + body_table["points_of_interest"] = model_points_of_interest + if api_table.get("caveats") != model_caveats: + body_table["caveats"] = model_caveats + if api_table.get("visibility_type") != model_visibility: + body_table["visibility_type"] = model_visibility + + if body_table: + ctx.queue_update(entity=api_table, delta=body_table) + logger.info("Table %s will be updated", table_key) + else: + logger.info("Table %s is up-to-date", table_key) + + for column in model.columns: + success &= self.__export_column(ctx, schema_name, model_name, column) + + return success + + def __export_column( + self, + ctx: _Context, + schema_name: str, + model_name: str, + column: MetabaseColumn, + ) -> bool: + """Exports one dbt column to Metabase database schema. + + Arguments: + schema_name {str} -- Target schema name.s + model_name {str} -- One dbt model name read from project. + column {dict} -- One dbt column read from project. + + Returns: + bool -- True if exported successfully, false if there were errors. 
+ """ + + success = True + + table_key = f"{schema_name}.{model_name}" + column_name = column.name.upper() + + api_field = ctx.tables.get(table_key, {}).get("fields", {}).get(column_name) + if not api_field: + logger.error( + "Field %s.%s does not exist in Metabase", + table_key, + column_name, + ) + return False + + if "special_type" in api_field: + semantic_type_key = "special_type" + else: + semantic_type_key = "semantic_type" + + fk_target_field_id = None + if column.semantic_type == "type/FK": + # Target table could be aliased if we parse_ref() on a source, so we caught aliases during model parsing + # This way we can unpack any alias mapped to fk_target_table when using yml project reader + target_table = ( + column.fk_target_table.upper() + if column.fk_target_table is not None + else None + ) + target_field = ( + column.fk_target_field.upper() + if column.fk_target_field is not None + else None + ) + + if not target_table or not target_field: + logger.info( + "Skipping FK resolution for %s table, %s field not resolved during dbt parsing", + table_key, + target_field, + ) + + else: + logger.debug( + "Looking for field %s in table %s", + target_field, + target_table, + ) + + fk_target_field = ( + ctx.tables.get(target_table, {}).get("fields", {}).get(target_field) + ) + if fk_target_field: + fk_target_field_id = fk_target_field.get("id") + if fk_target_field.get(semantic_type_key) != "type/PK": + logger.info( + "API field/%s will become PK (for %s column FK)", + fk_target_field_id, + column_name, + ) + body_fk_target_field = { + semantic_type_key: "type/PK", + } + ctx.queue_update( + entity=fk_target_field, delta=body_fk_target_field + ) + else: + logger.info( + "API field/%s is already PK (for %s column FK)", + fk_target_field_id, + column_name, + ) + else: + logger.error( + "Unable to find PK for %s.%s column FK", + target_table, + target_field, + ) + success = False + + # Empty strings not accepted by Metabase + column_description = column.description or None + column_display_name = column.display_name or None + column_visibility = column.visibility_type or "normal" + + # Preserve this relationship by default + if api_field["fk_target_field_id"] and not fk_target_field_id: + fk_target_field_id = api_field["fk_target_field_id"] + + body_field: MutableMapping[str, Optional[Any]] = {} + + # Update if specified, otherwise reset one that had been set + api_display_name = api_field.get("display_name") + if api_display_name != column_display_name and ( + column_display_name + or safe_name(api_display_name) != safe_name(api_field.get("name")) + ): + body_field["display_name"] = column_display_name + + if api_field.get("description") != column_description: + body_field["description"] = column_description + if api_field.get("visibility_type") != column_visibility: + body_field["visibility_type"] = column_visibility + if api_field.get("fk_target_field_id") != fk_target_field_id: + body_field["fk_target_field_id"] = fk_target_field_id + if ( + api_field.get("has_field_values") != column.has_field_values + and column.has_field_values + ): + body_field["has_field_values"] = column.has_field_values + if ( + api_field.get("coercion_strategy") != column.coercion_strategy + and column.coercion_strategy + ): + body_field["coercion_strategy"] = column.coercion_strategy + + settings = api_field.get("settings") or {} + if settings.get("number_style") != column.number_style and column.number_style: + settings["number_style"] = column.number_style + + if settings: + body_field["settings"] = settings + + 
# Allow explicit null type to override detected one + api_semantic_type = api_field.get(semantic_type_key) + if (column.semantic_type and api_semantic_type != column.semantic_type) or ( + column.semantic_type is NullValue and api_semantic_type + ): + body_field[semantic_type_key] = column.semantic_type or None + + if body_field: + ctx.queue_update(entity=api_field, delta=body_field) + logger.info("Field %s.%s will be updated", model_name, column_name) + else: + logger.info("Field %s.%s is up-to-date", model_name, column_name) + + return success + + def __get_tables(self, database_id: str) -> Mapping[str, MutableMapping]: + tables = {} + + metadata = self.api( + "get", + f"/api/database/{database_id}/metadata", + params={"include_hidden": "true"}, + ) + + bigquery_schema = metadata.get("details", {}).get("dataset-id") + + for table in metadata.get("tables", []): + # table[schema] is null for bigquery datasets + table["schema"] = ( + table.get("schema") or bigquery_schema or DEFAULT_SCHEMA + ).upper() + + fields = {} + for field in table.get("fields", []): + new_field = field.copy() + new_field["res_type"] = "field" + + field_name = field["name"].upper() + fields[field_name] = new_field + + new_table = table.copy() + new_table["res_type"] = "table" + new_table["fields"] = fields + + schema_name = table["schema"].upper() + table_name = table["name"].upper() + tables[f"{schema_name}.{table_name}"] = new_table + + return tables + + +class _Context: + def __init__(self): + self.tables: Mapping[str, MutableMapping] = {} + self.updates: MutableMapping[str, MutableMapping[str, Any]] = {} + + def queue_update(self, entity: MutableMapping, delta: Mapping): + entity.update(delta) + + key = f"{entity['res_type']}.{entity['id']}" + update = self.updates.get(key, {}) + update["res_type"] = entity["res_type"] + update["id"] = entity["id"] + + body = update.get("body", {}) + body.update(delta) + update["body"] = body + + self.updates[key] = update diff --git a/tests/test_dbt.py b/tests/test_dbt.py index 698c9763..fbc3726b 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -2,13 +2,9 @@ import unittest from pathlib import Path -from dbtmetabase.dbt import ( - DbtReader, - MetabaseColumn, - MetabaseModel, - ModelType, - NullValue, -) +from dbtmetabase import DbtReader +from dbtmetabase._format import NullValue +from dbtmetabase.dbt.models import MetabaseColumn, MetabaseModel, ResourceType class TestDbtReader(unittest.TestCase): @@ -29,7 +25,7 @@ def test_read_models(self): name="orders", schema="PUBLIC", description="This table has basic information about orders, as well as some derived facts based on payments", - model_type=ModelType.nodes, + res_type=ResourceType.node, source=None, unique_id="model.jaffle_shop.orders", columns=[ @@ -120,7 +116,7 @@ def test_read_models(self): name="customers", schema="PUBLIC", description="This table has basic information about a customer, as well as some derived facts based on a customer's orders", - model_type=ModelType.nodes, + res_type=ResourceType.node, source=None, unique_id="model.jaffle_shop.customers", columns=[ @@ -193,7 +189,7 @@ def test_read_models(self): name="stg_orders", schema="PUBLIC", description="", - model_type=ModelType.nodes, + res_type=ResourceType.node, source=None, unique_id="model.jaffle_shop.stg_orders", columns=[ @@ -221,7 +217,7 @@ def test_read_models(self): name="stg_payments", schema="PUBLIC", description="", - model_type=ModelType.nodes, + res_type=ResourceType.node, source=None, unique_id="model.jaffle_shop.stg_payments", columns=[ 
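For orientation, the updated tests import model primitives from dbtmetabase.dbt.models and use ResourceType.node in place of the old ModelType.nodes. A minimal sketch of constructing a model the way these tests do, assuming the remaining MetabaseColumn fields keep their defaults:

    from dbtmetabase.dbt.models import MetabaseColumn, MetabaseModel, ResourceType

    # A bare-bones dbt model as the test fixtures build it
    stg_orders = MetabaseModel(
        name="stg_orders",
        schema="PUBLIC",
        description="",
        res_type=ResourceType.node,
        columns=[MetabaseColumn(name="ORDER_ID")],
    )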
@@ -249,7 +245,7 @@ def test_read_models(self): name="stg_customers", schema="PUBLIC", description="", - model_type=ModelType.nodes, + res_type=ResourceType.node, source=None, unique_id="model.jaffle_shop.stg_customers", columns=[ diff --git a/tests/test_metabase.py b/tests/test_metabase.py index 45157955..1d5ac212 100644 --- a/tests/test_metabase.py +++ b/tests/test_metabase.py @@ -1,4 +1,4 @@ -# pylint: disable=protected-access +# pylint: disable=protected-access,no-member import json import logging @@ -8,8 +8,8 @@ import yaml -from dbtmetabase.dbt import MetabaseColumn, MetabaseModel, ModelType -from dbtmetabase.metabase import MetabaseClient, _ExportModelsJob, _ExtractExposuresJob +from dbtmetabase import MetabaseClient +from dbtmetabase.dbt.models import MetabaseColumn, MetabaseModel, ResourceType FIXTURES_PATH = Path("tests") / "fixtures" TMP_PATH = Path("tests") / "tmp" @@ -19,7 +19,7 @@ name="orders", schema="PUBLIC", description="This table has basic information about orders, as well as some derived facts based on payments", - model_type=ModelType.nodes, + res_type=ResourceType.node, columns=[ MetabaseColumn( name="ORDER_ID", @@ -108,7 +108,7 @@ name="customers", schema="PUBLIC", description="This table has basic information about a customer, as well as some derived facts based on a customer's orders", - model_type=ModelType.nodes, + res_type=ResourceType.node, columns=[ MetabaseColumn( name="CUSTOMER_ID", @@ -179,7 +179,7 @@ name="stg_orders", schema="PUBLIC", description="", - model_type=ModelType.nodes, + res_type=ResourceType.node, columns=[ MetabaseColumn( name="ORDER_ID", @@ -205,7 +205,7 @@ name="stg_payments", schema="PUBLIC", description="", - model_type=ModelType.nodes, + res_type=ResourceType.node, columns=[ MetabaseColumn( name="PAYMENT_ID", @@ -231,7 +231,7 @@ name="stg_customers", schema="PUBLIC", description="", - model_type=ModelType.nodes, + res_type=ResourceType.node, columns=[ MetabaseColumn( name="CUSTOMER_ID", @@ -283,8 +283,7 @@ def _assert_exposures(self, expected_path: Path, actual_path: Path): def test_exposures(self): fixtures_path = FIXTURES_PATH / "exposure" / "default" output_path = TMP_PATH / "exposure" / "default" - job = _ExtractExposuresJob( - client=self.client, + self.client.extract_exposures( models=MODELS, output_path=str(output_path), output_grouping=None, @@ -292,7 +291,6 @@ def test_exposures(self): collection_includes=None, collection_excludes=None, ) - job.execute() self._assert_exposures( fixtures_path / "exposures.yml", @@ -302,8 +300,7 @@ def test_exposures(self): def test_exposures_collection_grouping(self): fixtures_path = FIXTURES_PATH / "exposure" / "collection" output_path = TMP_PATH / "exposure" / "collection" - job = _ExtractExposuresJob( - client=self.client, + self.client.extract_exposures( models=MODELS, output_path=str(output_path), output_grouping="collection", @@ -311,7 +308,6 @@ def test_exposures_collection_grouping(self): collection_includes=None, collection_excludes=None, ) - job.execute() self._assert_exposures( fixtures_path / "a_look_at_your_customers_table.yml", @@ -325,8 +321,7 @@ def test_exposures_collection_grouping(self): def test_exposures_type_grouping(self): fixtures_path = FIXTURES_PATH / "exposure" / "type" output_path = TMP_PATH / "exposure" / "type" - job = _ExtractExposuresJob( - client=self.client, + self.client.extract_exposures( models=MODELS, output_path=str(output_path), output_grouping="type", @@ -334,7 +329,6 @@ def test_exposures_type_grouping(self): collection_includes=None, 
collection_excludes=None, ) - job.execute() for i in range(1, 18): self._assert_exposures( @@ -349,8 +343,7 @@ def test_exposures_type_grouping(self): ) def test_build_lookups(self): - job = _ExportModelsJob( - client=self.client, + self.client.export_models( database="unit_testing", models=[], exclude_sources=True, @@ -367,7 +360,7 @@ def test_build_lookups(self): "PUBLIC.STG_ORDERS", "PUBLIC.STG_PAYMENTS", ] - actual_tables = job._load_tables(database_id="2") + actual_tables = self.client._MetabaseModelsClient__get_tables(database_id="2") # type: ignore self.assertEqual(expected_tables, list(actual_tables.keys())) expected_columns = [