feat: customizable column names for DQEngine along with placeholder for other future configurations
hrfmartins committed Jan 23, 2025
1 parent 982bd60 commit dbc3dd2
Showing 2 changed files with 74 additions and 17 deletions.
64 changes: 48 additions & 16 deletions src/databricks/labs/dqx/engine.py
@@ -7,7 +7,7 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any
+from typing import Any, Optional
 import yaml

 import pyspark.sql.functions as F
@@ -18,19 +18,26 @@
 from databricks.labs.dqx.base import DQEngineBase
 from databricks.labs.dqx.config import WorkspaceConfig
 from databricks.labs.dqx.utils import get_column_name
+from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import NotFound

 logger = logging.getLogger(__name__)


-# TODO: make this configurable
-class Columns(Enum):
+class DefaultColumnNames(Enum):
     """Enum class to represent columns in the dataframe that will be used for error and warning reporting."""

     ERRORS = "_errors"
     WARNINGS = "_warnings"


+class ColumnArguments(Enum):
+    """Enum class that is used as input parsing for custom column naming."""
+
+    ERRORS = "errors"
+    WARNINGS = "warnings"
+
+
 class Criticality(Enum):
     """Enum class to represent criticality of the check."""

@@ -142,9 +149,32 @@ def get_rules(self) -> list[DQRule]:
             rules.append(rule)
         return rules


+@dataclass(frozen=True)
+class ExtraParams:
+    """Class to represent extra parameters for DQEngine."""
+
+    column_names: Optional[dict[str, str]] = field(default_factory=dict)
+
+
 class DQEngine(DQEngineBase):
-    """Data Quality Engine class to apply data quality checks to a given dataframe."""
+    """Data Quality Engine class to apply data quality checks to a given dataframe.
+    Args:
+        workspace_client (WorkspaceClient): WorkspaceClient instance to use for accessing the workspace.
+        extra_params (ExtraParams): Extra parameters for the DQEngine.
+    """
+
+    def __init__(self, workspace_client: WorkspaceClient, extra_params: Optional[ExtraParams] = None):
+        super().__init__(workspace_client)
+
+        extra_params = extra_params or ExtraParams()
+
+        self._column_names = {
+            ColumnArguments.ERRORS: extra_params.column_names.get(ColumnArguments.ERRORS.value, DefaultColumnNames.ERRORS.value),
+            ColumnArguments.WARNINGS: extra_params.column_names.get(
+                ColumnArguments.WARNINGS.value, DefaultColumnNames.WARNINGS.value
+            ),
+        }

     @staticmethod
     def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
@@ -156,17 +186,16 @@ def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
         """
         return [check for check in checks if check.rule_criticality == criticality]

-    @staticmethod
-    def _append_empty_checks(df: DataFrame) -> DataFrame:
+    def _append_empty_checks(self, df: DataFrame) -> DataFrame:
         """Append empty checks at the end of dataframe.

         :param df: dataframe without checks
         :return: dataframe with checks
         """
         return df.select(
             "*",
-            F.lit(None).cast("map<string, string>").alias(Columns.ERRORS.value),
-            F.lit(None).cast("map<string, string>").alias(Columns.WARNINGS.value),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.ERRORS]),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.WARNINGS]),
         )

     @staticmethod
@@ -204,8 +233,8 @@ def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:

         warning_checks = self._get_check_columns(checks, Criticality.WARN.value)
         error_checks = self._get_check_columns(checks, Criticality.ERROR.value)
-        ndf = self._create_results_map(df, error_checks, Columns.ERRORS.value)
-        ndf = self._create_results_map(ndf, warning_checks, Columns.WARNINGS.value)
+        ndf = self._create_results_map(df, error_checks, self._column_names[ColumnArguments.ERRORS])
+        ndf = self._create_results_map(ndf, warning_checks, self._column_names[ColumnArguments.WARNINGS])

         return ndf

@@ -228,23 +257,26 @@ def apply_checks_and_split(self, df: DataFrame, checks: list[DQRule]) -> tuple[DataFrame, DataFrame]:

         return good_df, bad_df

-    @staticmethod
-    def get_invalid(df: DataFrame) -> DataFrame:
+    def get_invalid(self, df: DataFrame) -> DataFrame:
         """
         Get records that violate data quality checks (records with warnings and errors).
         @param df: input DataFrame.
+        @param column_names: dictionary with column names for errors and warnings.
         @return: dataframe with error and warning rows and corresponding reporting columns.
         """
-        return df.where(F.col(Columns.ERRORS.value).isNotNull() | F.col(Columns.WARNINGS.value).isNotNull())
+        return df.where(
+            F.col(self._column_names[ColumnArguments.ERRORS]).isNotNull()
+            | F.col(self._column_names[ColumnArguments.WARNINGS]).isNotNull()
+        )

-    @staticmethod
-    def get_valid(df: DataFrame) -> DataFrame:
+    def get_valid(self, df: DataFrame) -> DataFrame:
         """
         Get records that don't violate data quality checks (records with warnings but no errors).
         @param df: input DataFrame.
+        @param column_names: dictionary with column names for errors and warnings.
         @return: dataframe with warning rows but no reporting columns.
         """
-        return df.where(F.col(Columns.ERRORS.value).isNull()).drop(Columns.ERRORS.value, Columns.WARNINGS.value)
+        return df.where(F.col(self._column_names[ColumnArguments.ERRORS]).isNull()).drop(self._column_names[ColumnArguments.ERRORS], self._column_names[ColumnArguments.WARNINGS])

     @staticmethod
     def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
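For orientation, a minimal usage sketch of the API introduced above (not part of the commit): the column names dq_errors/dq_warnings and the pre-built WorkspaceClient ws are illustrative assumptions.

# Usage sketch for the change above -- illustrative, not part of the commit.
# Assumes `ws` is an already-configured databricks.sdk.WorkspaceClient.
from databricks.labs.dqx.engine import DQEngine, ExtraParams

# Dict keys mirror the ColumnArguments values ("errors"/"warnings");
# missing keys fall back to DefaultColumnNames ("_errors"/"_warnings").
custom_engine = DQEngine(ws, ExtraParams(column_names={"errors": "dq_errors", "warnings": "dq_warnings"}))

# Omitting extra_params keeps the previous behaviour with the default names.
default_engine = DQEngine(ws)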
27 changes: 26 additions & 1 deletion tests/integration/test_apply_checks.py
@@ -6,11 +6,12 @@
 from databricks.labs.dqx.col_functions import is_not_null_and_not_empty, make_condition
 from databricks.labs.dqx.engine import (
     DQRule,
-    DQEngine,
+    DQEngine, ExtraParams,
 )

 SCHEMA = "a: int, b: int, c: int"
 EXPECTED_SCHEMA = SCHEMA + ", _errors: map<string,string>, _warnings: map<string,string>"
+EXPECTED_SCHEMA_WITH_CUSTOM_NAMES = SCHEMA + ", ERROR: map<string,string>, WARN: map<string,string>"


 def test_apply_checks_on_empty_checks(ws, spark):
@@ -442,3 +443,27 @@ def col_test_check_func(col_name: str) -> Column:
     check_col = F.col(col_name)
     condition = check_col.isNull() | (check_col == "") | (check_col == "null")
     return make_condition(condition, "new check failed", f"{col_name}_is_null_or_empty")
+
+
+def test_apply_checks_with_custom_column_naming(ws, spark):
+    dq_engine = DQEngine(ws, ExtraParams(column_names={'errors': 'ERROR', 'warnings': 'WARN'}))
+    test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA)
+
+    checks = [{"criticality": "warn", "check": {"function": "col_test_check_func", "arguments": {"col_name": "a"}}}]
+    checked = dq_engine.apply_checks_by_metadata(test_df, checks, globals())
+
+    assert 'ERROR' in checked.columns
+    assert 'WARN' in checked.columns
+
+    expected = spark.createDataFrame(
+        [
+            [1, 3, 3, None, None],
+            [2, None, 4, None, None],
+            [None, 4, None, None, {"col_a_is_null_or_empty": "new check failed"}],
+            [None, None, None, None, {"col_a_is_null_or_empty": "new check failed"}],
+        ],
+        EXPECTED_SCHEMA_WITH_CUSTOM_NAMES,
+    )
+
+    assert_df_equality(checked, expected, ignore_nullable=True)
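Because get_invalid and get_valid are now instance methods that read the configured names, the split helpers follow the custom columns automatically. A short sketch reusing dq_engine and checked from the test above (illustrative, not part of the commit):

# Sketch only: splitting checked data under the custom names from the test.
bad_df = dq_engine.get_invalid(checked)  # rows where "ERROR" or "WARN" is non-null
good_df = dq_engine.get_valid(checked)   # rows with "ERROR" null; drops "ERROR" and "WARN"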
