Commit

feat: customizable column names for DQEngine along with placeholder for other future configurations
hrfmartins authored and hrfmartins committed Feb 1, 2025
1 parent 27bc038 commit a09b0fe
Showing 3 changed files with 94 additions and 36 deletions.
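For context, a minimal usage sketch of the new option (mirroring the integration test added below; ws is assumed to be an existing WorkspaceClient and input_df an arbitrary DataFrame):

from databricks.labs.dqx.engine import DQEngine, ExtraParams

# Rename the reporting columns from the defaults (_errors / _warnings) to ERROR / WARN.
dq_engine = DQEngine(ws, extra_params=ExtraParams(column_names={"errors": "ERROR", "warnings": "WARN"}))

checks = [{"criticality": "error", "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "a"}}}]
checked_df = dq_engine.apply_checks_by_metadata(input_df, checks)
# checked_df now ends with ERROR and WARN map<string,string> columns instead of _errors / _warnings.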
72 changes: 45 additions & 27 deletions src/databricks/labs/dqx/engine.py
@@ -5,14 +5,17 @@
import itertools
from pathlib import Path
from collections.abc import Callable
from typing import Any
from typing import Any, Optional
from dataclasses import dataclass, field
import yaml
import pyspark.sql.functions as F
from pyspark.sql import DataFrame
from databricks.labs.dqx.rule import DQRule, Criticality, Columns, DQRuleColSet, ChecksValidationStatus
from databricks.labs.dqx.rule import DQRule, Criticality, DQRuleColSet, ChecksValidationStatus, ColumnArguments, \
ExtraParams, DefaultColumnNames
from databricks.labs.dqx.utils import deserialize_dicts
from databricks.labs.dqx import col_functions
from databricks.labs.blueprint.installation import Installation

from databricks.labs.dqx.base import DQEngineBase, DQEngineCoreBase
from databricks.labs.dqx.config import WorkspaceConfig, RunConfig
from databricks.sdk.errors import NotFound
@@ -24,16 +27,35 @@


class DQEngineCore(DQEngineCoreBase):
"""Data Quality Engine Core class to apply data quality checks to a given dataframe."""
"""Data Quality Engine Core class to apply data quality checks to a given dataframe.
Args:
workspace_client (WorkspaceClient): WorkspaceClient instance to use for accessing the workspace.
extra_params (ExtraParams): Extra parameters for the DQEngine.
"""

def __init__(self, workspace_client: WorkspaceClient, extra_params: ExtraParams | None = None):
super().__init__(workspace_client)

extra_params = extra_params or ExtraParams()

self._column_names = {
ColumnArguments.ERRORS: extra_params.column_names.get(
ColumnArguments.ERRORS.value, DefaultColumnNames.ERRORS.value
),
ColumnArguments.WARNINGS: extra_params.column_names.get(
ColumnArguments.WARNINGS.value, DefaultColumnNames.WARNINGS.value
),
}


def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
if not checks:
return self._append_empty_checks(df)

warning_checks = self._get_check_columns(checks, Criticality.WARN.value)
error_checks = self._get_check_columns(checks, Criticality.ERROR.value)
ndf = self._create_results_map(df, error_checks, Columns.ERRORS.value)
ndf = self._create_results_map(ndf, warning_checks, Columns.WARNINGS.value)
ndf = self._create_results_map(df, error_checks, self._column_names[ColumnArguments.ERRORS])
ndf = self._create_results_map(ndf, warning_checks, self._column_names[ColumnArguments.WARNINGS])

return ndf

@@ -57,12 +79,13 @@ def apply_checks_by_metadata_and_split(

return good_df, bad_df


def apply_checks_by_metadata(
self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None
) -> DataFrame:
dq_rule_checks = self.build_checks_by_metadata(checks, glbs)
self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None
) -> DataFrame:
dq_rule_checks = self.build_checks_by_metadata(checks, glbs)

return self.apply_checks(df, dq_rule_checks)
return self.apply_checks(df, dq_rule_checks)

@staticmethod
def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
@@ -77,13 +100,11 @@ def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> C

return status

@staticmethod
def get_invalid(df: DataFrame) -> DataFrame:
return df.where(F.col(Columns.ERRORS.value).isNotNull() | F.col(Columns.WARNINGS.value).isNotNull())
def get_invalid(self, df: DataFrame) -> DataFrame:
return df.where(F.col(self._column_names[ColumnArguments.ERRORS]).isNotNull() | F.col(self._column_names[ColumnArguments.WARNINGS]).isNotNull())

@staticmethod
def get_valid(df: DataFrame) -> DataFrame:
return df.where(F.col(Columns.ERRORS.value).isNull()).drop(Columns.ERRORS.value, Columns.WARNINGS.value)
def get_valid(self, df: DataFrame) -> DataFrame:
return df.where(F.col(self._column_names[ColumnArguments.ERRORS]).isNull()).drop(self._column_names[ColumnArguments.ERRORS], self._column_names[ColumnArguments.WARNINGS])

@staticmethod
def load_checks_from_local_file(path: str) -> list[dict]:
@@ -177,17 +198,16 @@ def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
"""
return [check for check in checks if check.rule_criticality == criticality]

@staticmethod
def _append_empty_checks(df: DataFrame) -> DataFrame:
def _append_empty_checks(self, df: DataFrame) -> DataFrame:
"""Append empty checks at the end of dataframe.
:param df: dataframe without checks
:return: dataframe with checks
"""
return df.select(
"*",
F.lit(None).cast("map<string, string>").alias(Columns.ERRORS.value),
F.lit(None).cast("map<string, string>").alias(Columns.WARNINGS.value),
F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.ERRORS]),
F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.WARNINGS]),
)

@staticmethod
@@ -350,9 +370,9 @@ def _resolve_function(func_name: str, glbs: dict[str, Any] | None = None, fail_o
class DQEngine(DQEngineBase):
"""Data Quality Engine class to apply data quality checks to a given dataframe."""

def __init__(self, workspace_client: WorkspaceClient, engine: DQEngineCoreBase | None = None):
def __init__(self, workspace_client: WorkspaceClient, engine: DQEngineCoreBase | None = None, extra_params: ExtraParams | None = None):
super().__init__(workspace_client)
self._engine = engine or DQEngineCore(workspace_client)
self._engine = engine or DQEngineCore(workspace_client, extra_params)

def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
return self._engine.apply_checks(df, checks)
@@ -374,13 +394,11 @@ def apply_checks_by_metadata(
def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
return DQEngineCore.validate_checks(checks, glbs)

@staticmethod
def get_invalid(df: DataFrame) -> DataFrame:
return DQEngineCore.get_invalid(df)
def get_invalid(self, df: DataFrame) -> DataFrame:
return self._engine.get_invalid(df)

@staticmethod
def get_valid(df: DataFrame) -> DataFrame:
return DQEngineCore.get_valid(df)
def get_valid(self, df: DataFrame) -> DataFrame:
return self._engine.get_valid(df)

@staticmethod
def load_checks_from_local_file(path: str) -> list[dict]:
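With the reporting columns configurable, get_valid and get_invalid are now instance methods that read the configured names rather than the old Columns enum. A rough usage sketch, continuing the example above:

# Rows with no entry in the configured errors column (here "ERROR"); both reporting columns are dropped.
good_df = dq_engine.get_valid(checked_df)

# Rows flagged in either the configured errors or warnings column; reporting columns are kept.
bad_df = dq_engine.get_invalid(checked_df)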
26 changes: 19 additions & 7 deletions src/databricks/labs/dqx/rule.py
@@ -1,27 +1,39 @@
from enum import Enum
from dataclasses import dataclass, field
import functools as ft
from typing import Any
from typing import Any, Optional
from collections.abc import Callable
from pyspark.sql import Column
import pyspark.sql.functions as F
from databricks.labs.dqx.utils import get_column_name


# TODO: make this configurable
class Columns(Enum):
class Criticality(Enum):
"""Enum class to represent criticality of the check."""

WARN = "warn"
ERROR = "error"


class DefaultColumnNames(Enum):
"""Enum class to represent columns in the dataframe that will be used for error and warning reporting."""

ERRORS = "_errors"
WARNINGS = "_warnings"


class Criticality(Enum):
"""Enum class to represent criticality of the check."""
class ColumnArguments(Enum):
"""Enum class that is used as input parsing for custom column naming."""

WARN = "warn"
ERROR = "error"
ERRORS = "errors"
WARNINGS = "warnings"


@dataclass(frozen=True)
class ExtraParams:
"""Class to represent extra parameters for DQEngine."""

column_names: Optional[dict[str, str]] = field(default_factory=dict)

@dataclass(frozen=True)
class DQRule:
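A note on the new dataclass: ExtraParams is frozen and currently carries only column_names; its keys follow the ColumnArguments values ("errors" / "warnings"), and anything omitted falls back to DefaultColumnNames in DQEngineCore.__init__ above. A small illustrative sketch (the variable name params is just an example):

from databricks.labs.dqx.rule import ColumnArguments, ExtraParams

# Override only the errors column; warnings keeps the default "_warnings".
params = ExtraParams(column_names={ColumnArguments.ERRORS.value: "quality_errors"})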
32 changes: 30 additions & 2 deletions tests/integration/test_apply_checks.py
@@ -4,12 +4,16 @@
from pyspark.sql import Column
from chispa.dataframe_comparer import assert_df_equality # type: ignore
from databricks.labs.dqx.col_functions import is_not_null_and_not_empty, make_condition
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.engine import (
DQRule,
DQEngine,
ExtraParams,
)
from databricks.labs.dqx.rule import DQRule


SCHEMA = "a: int, b: int, c: int"
EXPECTED_SCHEMA = SCHEMA + ", _errors: map<string,string>, _warnings: map<string,string>"
EXPECTED_SCHEMA_WITH_CUSTOM_NAMES = SCHEMA + ", ERROR: map<string,string>, WARN: map<string,string>"


def test_apply_checks_on_empty_checks(ws, spark):
@@ -491,3 +495,27 @@ def test_get_invalid_records(ws, spark):
)

assert_df_equality(invalid_df, expected_invalid_df)

def test_apply_checks_with_custom_column_naming(ws, spark):
dq_engine = DQEngine(ws, extra_params=ExtraParams(column_names = {'errors': 'ERROR', 'warnings': 'WARN'}))
test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA)

checks = [{"criticality": "warn", "check": {"function": "col_test_check_func", "arguments": {"col_name": "a"}}}]
checked = dq_engine.apply_checks_by_metadata(test_df, checks, globals())

assert 'ERROR' in checked.columns
assert 'WARN' in checked.columns

expected = spark.createDataFrame(
[
[1, 3, 3, None, None],
[2, None, 4, None, None],
[None, 4, None, None, {"col_a_is_null_or_empty": "new check failed"}],
[None, None, None, None, {"col_a_is_null_or_empty": "new check failed"}],
],
EXPECTED_SCHEMA_WITH_CUSTOM_NAMES,
)

assert_df_equality(checked, expected, ignore_nullable=True)

