diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py
index d816205..26a736c 100644
--- a/demos/dqx_demo_library.py
+++ b/demos/dqx_demo_library.py
@@ -356,4 +356,38 @@ def ends_with_foo(col_name: str) -> Column:
 dq_engine = DQEngine(WorkspaceClient())
 
 valid_and_quarantined_df = dq_engine.apply_checks_by_metadata(input_df, checks, globals())
-display(valid_and_quarantined_df)
\ No newline at end of file
+display(valid_and_quarantined_df)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## Applying custom column names
+
+# COMMAND ----------
+
+from databricks.sdk import WorkspaceClient
+from databricks.labs.dqx.engine import (
+    DQEngine,
+    ExtraParams,
+    DQRule
+)
+
+from databricks.labs.dqx.col_functions import is_not_null_and_not_empty
+
+# using ExtraParams class to configure the engine with custom column names
+extra_parameters = ExtraParams(column_names={"errors": "dq_errors", "warnings": "dq_warnings"})
+
+ws = WorkspaceClient()
+dq_engine = DQEngine(ws, extra_params=extra_parameters)
+
+schema = "col1: string"
+input_df = spark.createDataFrame([["str1"], ["foo"], ["str3"]], schema)
+
+checks = [ DQRule(
+        name='col_1_is_null_or_empty',
+        criticality='error',
+        check=is_not_null_and_not_empty('col1')),
+    ]
+
+valid_and_quarantined_df = dq_engine.apply_checks(input_df, checks)
+display(valid_and_quarantined_df)
diff --git a/docs/dqx/docs/guide.mdx b/docs/dqx/docs/guide.mdx
index 2ff8d7a..76e3269 100644
--- a/docs/dqx/docs/guide.mdx
+++ b/docs/dqx/docs/guide.mdx
@@ -55,7 +55,7 @@ dlt_expectations = dlt_generator.generate_dlt_rules(profiles, language="Python_D
 print(dlt_expectations)
 ```
 
-### Using CLI 
+### Using CLI
 
 You can optionally install DQX in the workspace, see the [Installation Guide](/docs/installation#dqx-installation-in-a-databricks-workspace).
 As part of the installation, a config, dashboards and profiler workflow is installed. The workflow can be run manually in the workspace UI or using the CLI as below.
@@ -116,7 +116,7 @@ print(status)
 ```
 
 Note that checks are validated automatically when applied as part of the
-`apply_checks_by_metadata_and_split` and `apply_checks_by_metadata` methods 
+`apply_checks_by_metadata_and_split` and `apply_checks_by_metadata` methods
 (see [Quality rules defined as config](#quality-rules-defined-as-config)).
 
 ### Using CLI
@@ -178,7 +178,7 @@ checks = dq_engine.load_checks_from_installation(assume_user=True, run_config_na
 
 input_df = spark.read.table("catalog1.schema1.table1")
 
-# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes 
+# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes
 valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df, checks)
 
 # Option 2: apply quality rules on the dataframe and report issues as additional columns (`_warning` and `_error`)
@@ -198,7 +198,7 @@ checks = dq_engine.load_checks_from_workspace_file(workspace_path="/Shared/App1/
 
 input_df = spark.read.table("catalog1.schema1.table1")
 
-# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes 
+# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes
 valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df, checks)
 
 # Option 2: apply quality rules on the dataframe and report issues as additional columns (`_warning` and `_error`)
@@ -220,7 +220,7 @@ dq_engine = DQEngine(WorkspaceClient())
 
 input_df = spark.read.table("catalog1.schema1.table1")
 
-# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes 
+# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes
 valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df, checks)
 
 # Option 2: apply quality rules on the dataframe and report issues as additional columns (`_warning` and `_error`)
@@ -241,18 +241,18 @@ from databricks.sdk import WorkspaceClient
 dq_engine = DQEngine(WorkspaceClient())
 
 checks = DQRuleColSet( # define rule for multiple columns at once
-            columns=["col1", "col2"], 
-            criticality="error", 
+            columns=["col1", "col2"],
+            criticality="error",
             check_func=is_not_null).get_rules() + [
          DQRule( # define rule for a single column
-            name='col3_is_null_or_empty', 
-            criticality='error', 
-            check=is_not_null_and_not_empty('col3')),
+            name="col3_is_null_or_empty",
+            criticality="error",
+            check=is_not_null_and_not_empty("col3")),
          DQRule( # define rule with a filter
-            name='col_4_is_null_or_empty', 
-            criticality='error', 
-            filter='col1<3', 
-            check=is_not_null_and_not_empty('col4')),
+            name="col_4_is_null_or_empty",
+            criticality="error",
+            filter="col1<3",
+            check=is_not_null_and_not_empty("col4")),
          DQRule( # name auto-generated if not provided
             criticality='warn',
             check=value_is_in_list('col4', ['1', '2']))
@@ -260,7 +260,7 @@ checks = DQRuleColSet( # define rule for multiple columns at once
 
 input_df = spark.read.table("catalog1.schema1.table1")
 
-# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes 
+# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes
 valid_df, quarantined_df = dq_engine.apply_checks_and_split(input_df, checks)
 
 # Option 2: apply quality rules on the dataframe and report issues as additional columns (`_warning` and `_error`)
@@ -312,7 +312,7 @@ checks = yaml.safe_load("""
 
 input_df = spark.read.table("catalog1.schema1.table1")
 
-# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes 
+# Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes
 valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df, checks)
 
 # Option 2: apply quality rules on the dataframe and report issues as additional columns (`_warning` and `_error`)
@@ -423,3 +423,27 @@ dq_engine = DQEngine(ws)
 For details on the specific methods available in the engine, visit to the [reference](/docs/reference#dq-engine-methods) section.
 
 Information on testing applications that use `DQEngine` can be found [here](/docs/reference#testing-applications-using-dqx).
+
+## Additional Configuration
+
+### Customizing Reporting Error and Warning Columns
+
+By default, DQX appends `_errors` and `_warnings` reporting columns to the output DataFrame to flag quality issues.
+
+You can customize the names of these reporting columns by passing additional configuration to the engine.
+
+```python
+from databricks.sdk import WorkspaceClient
+from databricks.labs.dqx.engine import (
+    DQEngine,
+    ExtraParams,
+)
+
+# customize reporting column names
+extra_parameters = ExtraParams(column_names={"errors": "dq_errors", "warnings": "dq_warnings"})
+
+ws = WorkspaceClient()
+dq_engine = DQEngine(ws, extra_params=extra_parameters)
+```
+
+
diff --git a/src/databricks/labs/dqx/base.py b/src/databricks/labs/dqx/base.py
index b1bbffa..651e696 100644
--- a/src/databricks/labs/dqx/base.py
+++ b/src/databricks/labs/dqx/base.py
@@ -113,18 +113,16 @@ def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> C
         :return ValidationStatus: The validation status.
         """
 
-    @staticmethod
     @abc.abstractmethod
-    def get_invalid(df: DataFrame) -> DataFrame:
+    def get_invalid(self, df: DataFrame) -> DataFrame:
         """
         Get records that violate data quality checks (records with warnings and errors).
         @param df: input DataFrame.
         @return: dataframe with error and warning rows and corresponding reporting columns.
         """
 
-    @staticmethod
     @abc.abstractmethod
-    def get_valid(df: DataFrame) -> DataFrame:
+    def get_valid(self, df: DataFrame) -> DataFrame:
         """
         Get records that don't violate data quality checks (records with warnings but no errors).
         @param df: input DataFrame.
diff --git a/src/databricks/labs/dqx/engine.py b/src/databricks/labs/dqx/engine.py
index 7cefb2a..ce2a665 100644
--- a/src/databricks/labs/dqx/engine.py
+++ b/src/databricks/labs/dqx/engine.py
@@ -9,10 +9,19 @@
 import yaml
 import pyspark.sql.functions as F
 from pyspark.sql import DataFrame
-from databricks.labs.dqx.rule import DQRule, Criticality, Columns, DQRuleColSet, ChecksValidationStatus
+from databricks.labs.dqx.rule import (
+    DQRule,
+    Criticality,
+    DQRuleColSet,
+    ChecksValidationStatus,
+    ColumnArguments,
+    ExtraParams,
+    DefaultColumnNames,
+)
 from databricks.labs.dqx.utils import deserialize_dicts
 from databricks.labs.dqx import col_functions
 from databricks.labs.blueprint.installation import Installation
+
 from databricks.labs.dqx.base import DQEngineBase, DQEngineCoreBase
 from databricks.labs.dqx.config import WorkspaceConfig, RunConfig
 from databricks.sdk.errors import NotFound
@@ -24,7 +33,25 @@
 
 
 class DQEngineCore(DQEngineCoreBase):
-    """Data Quality Engine Core class to apply data quality checks to a given dataframe."""
+    """Data Quality Engine Core class to apply data quality checks to a given dataframe.
+    Args:
+        workspace_client (WorkspaceClient): WorkspaceClient instance to use for accessing the workspace.
+        extra_params (ExtraParams): Extra parameters for the DQEngine.
+    """
+
+    def __init__(self, workspace_client: WorkspaceClient, extra_params: ExtraParams | None = None):
+        super().__init__(workspace_client)
+
+        extra_params = extra_params or ExtraParams()
+
+        self._column_names = {
+            ColumnArguments.ERRORS: extra_params.column_names.get(
+                ColumnArguments.ERRORS.value, DefaultColumnNames.ERRORS.value
+            ),
+            ColumnArguments.WARNINGS: extra_params.column_names.get(
+                ColumnArguments.WARNINGS.value, DefaultColumnNames.WARNINGS.value
+            ),
+        }
 
     def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
         if not checks:
@@ -32,8 +59,8 @@ def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
 
         warning_checks = self._get_check_columns(checks, Criticality.WARN.value)
         error_checks = self._get_check_columns(checks, Criticality.ERROR.value)
-        ndf = self._create_results_map(df, error_checks, Columns.ERRORS.value)
-        ndf = self._create_results_map(ndf, warning_checks, Columns.WARNINGS.value)
+        ndf = self._create_results_map(df, error_checks, self._column_names[ColumnArguments.ERRORS])
+        ndf = self._create_results_map(ndf, warning_checks, self._column_names[ColumnArguments.WARNINGS])
 
         return ndf
@@ -77,13 +104,16 @@ def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> C
 
         return status
 
-    @staticmethod
-    def get_invalid(df: DataFrame) -> DataFrame:
-        return df.where(F.col(Columns.ERRORS.value).isNotNull() | F.col(Columns.WARNINGS.value).isNotNull())
+    def get_invalid(self, df: DataFrame) -> DataFrame:
+        return df.where(
+            F.col(self._column_names[ColumnArguments.ERRORS]).isNotNull()
+            | F.col(self._column_names[ColumnArguments.WARNINGS]).isNotNull()
+        )
 
-    @staticmethod
-    def get_valid(df: DataFrame) -> DataFrame:
-        return df.where(F.col(Columns.ERRORS.value).isNull()).drop(Columns.ERRORS.value, Columns.WARNINGS.value)
+    def get_valid(self, df: DataFrame) -> DataFrame:
+        return df.where(F.col(self._column_names[ColumnArguments.ERRORS]).isNull()).drop(
+            self._column_names[ColumnArguments.ERRORS], self._column_names[ColumnArguments.WARNINGS]
+        )
 
     @staticmethod
     def load_checks_from_local_file(path: str) -> list[dict]:
@@ -179,8 +209,7 @@ def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
         """
         return [check for check in checks if check.rule_criticality == criticality]
 
-    @staticmethod
-    def _append_empty_checks(df: DataFrame) -> DataFrame:
+    def _append_empty_checks(self, df: DataFrame) -> DataFrame:
         """Append empty checks at the end of dataframe.
 
         :param df: dataframe without checks
@@ -188,8 +217,8 @@ def _append_empty_checks(df: DataFrame) -> DataFrame:
         """
         return df.select(
             "*",
-            F.lit(None).cast("map<string, string>").alias(Columns.ERRORS.value),
-            F.lit(None).cast("map<string, string>").alias(Columns.WARNINGS.value),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.ERRORS]),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.WARNINGS]),
         )
 
     @staticmethod
@@ -352,9 +381,14 @@ def _resolve_function(func_name: str, glbs: dict[str, Any] | None = None, fail_o
 class DQEngine(DQEngineBase):
     """Data Quality Engine class to apply data quality checks to a given dataframe."""
 
-    def __init__(self, workspace_client: WorkspaceClient, engine: DQEngineCoreBase | None = None):
+    def __init__(
+        self,
+        workspace_client: WorkspaceClient,
+        engine: DQEngineCoreBase | None = None,
+        extra_params: ExtraParams | None = None,
+    ):
         super().__init__(workspace_client)
-        self._engine = engine or DQEngineCore(workspace_client)
+        self._engine = engine or DQEngineCore(workspace_client, extra_params)
 
     def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
         return self._engine.apply_checks(df, checks)
@@ -376,13 +410,11 @@ def apply_checks_by_metadata(
     def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
         return DQEngineCore.validate_checks(checks, glbs)
 
-    @staticmethod
-    def get_invalid(df: DataFrame) -> DataFrame:
-        return DQEngineCore.get_invalid(df)
+    def get_invalid(self, df: DataFrame) -> DataFrame:
+        return self._engine.get_invalid(df)
 
-    @staticmethod
-    def get_valid(df: DataFrame) -> DataFrame:
-        return DQEngineCore.get_valid(df)
+    def get_valid(self, df: DataFrame) -> DataFrame:
+        return self._engine.get_valid(df)
 
     @staticmethod
     def load_checks_from_local_file(path: str) -> list[dict]:
diff --git a/src/databricks/labs/dqx/rule.py b/src/databricks/labs/dqx/rule.py
index dcd494c..79b29d3 100644
--- a/src/databricks/labs/dqx/rule.py
+++ b/src/databricks/labs/dqx/rule.py
@@ -8,19 +8,32 @@
 from databricks.labs.dqx.utils import get_column_name
 
 
-# TODO: make this configurable
-class Columns(Enum):
+class Criticality(Enum):
+    """Enum class to represent criticality of the check."""
+
+    WARN = "warn"
+    ERROR = "error"
+
+
+class DefaultColumnNames(Enum):
     """Enum class to represent columns in the dataframe that will be used for error and warning reporting."""
 
     ERRORS = "_errors"
     WARNINGS = "_warnings"
 
 
-class Criticality(Enum):
-    """Enum class to represent criticality of the check."""
+class ColumnArguments(Enum):
+    """Enum class that is used as input parsing for custom column naming."""
 
-    WARN = "warn"
-    ERROR = "error"
+    ERRORS = "errors"
+    WARNINGS = "warnings"
+
+
+@dataclass(frozen=True)
+class ExtraParams:
+    """Class to represent extra parameters for DQEngine."""
+
+    column_names: dict[str, str] = field(default_factory=dict)
 
 
 @dataclass(frozen=True)
diff --git a/tests/integration/test_apply_checks.py b/tests/integration/test_apply_checks.py
index 909489a..93d4a1e 100644
--- a/tests/integration/test_apply_checks.py
+++ b/tests/integration/test_apply_checks.py
@@ -4,12 +4,15 @@
 from pyspark.sql import Column
 from chispa.dataframe_comparer import assert_df_equality  # type: ignore
 from databricks.labs.dqx.col_functions import is_not_null_and_not_empty, make_condition
-from databricks.labs.dqx.engine import DQEngine
-from databricks.labs.dqx.rule import DQRule, DQRuleColSet
-
+from databricks.labs.dqx.engine import (
+    DQEngine,
+    ExtraParams,
+)
+from databricks.labs.dqx.rule import DQRule, DQRuleColSet, ColumnArguments
 
 SCHEMA = "a: int, b: int, c: int"
 EXPECTED_SCHEMA = SCHEMA + ", _errors: map<string, string>, _warnings: map<string, string>"
+EXPECTED_SCHEMA_WITH_CUSTOM_NAMES = SCHEMA + ", dq_errors: map<string, string>, dq_warnings: map<string, string>"
 
 
 def test_apply_checks_on_empty_checks(ws, spark):
@@ -559,3 +562,100 @@ def test_get_invalid_records(ws, spark):
     )
 
     assert_df_equality(invalid_df, expected_invalid_df)
+
+
+def test_apply_checks_with_custom_column_naming(ws, spark):
+    dq_engine = DQEngine(
+        ws,
+        extra_params=ExtraParams(
+            column_names={ColumnArguments.ERRORS.value: "dq_errors", ColumnArguments.WARNINGS.value: "dq_warnings"}
+        ),
+    )
+    test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA)
+
+    checks = [
+        {"criticality": "warn", "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "a"}}}
+    ]
+    checked = dq_engine.apply_checks_by_metadata(test_df, checks)
+
+    expected = spark.createDataFrame(
+        [
+            [1, 3, 3, None, None],
+            [2, None, 4, None, None],
+            [None, 4, None, None, {"col_a_is_null_or_empty": "Column a is null or empty"}],
+            [None, None, None, None, {"col_a_is_null_or_empty": "Column a is null or empty"}],
+        ],
+        EXPECTED_SCHEMA_WITH_CUSTOM_NAMES,
+    )
+
+    assert_df_equality(checked, expected, ignore_nullable=True)
+
+
+def test_apply_checks_by_metadata_with_custom_column_naming(ws, spark):
+    dq_engine = DQEngine(
+        ws,
+        extra_params=ExtraParams(
+            column_names={ColumnArguments.ERRORS.value: "dq_errors", ColumnArguments.WARNINGS.value: "dq_warnings"}
+        ),
+    )
+    test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA)
+
+    checks = [
+        {"criticality": "warn", "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "a"}}},
+        {"criticality": "error", "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "b"}}},
+    ]
+    good, bad = dq_engine.apply_checks_by_metadata_and_split(test_df, checks)
+
+    assert_df_equality(good, spark.createDataFrame([[1, 3, 3], [None, 4, None]], SCHEMA), ignore_nullable=True)
+
+    assert_df_equality(
+        bad,
+        spark.createDataFrame(
+            [
+                [2, None, 4, {"col_b_is_null_or_empty": "Column b is null or empty"}, None],
+                [None, 4, None, None, {"col_a_is_null_or_empty": "Column a is null or empty"}],
+                [
+                    None,
+                    None,
+                    None,
+                    {"col_b_is_null_or_empty": "Column b is null or empty"},
+                    {"col_a_is_null_or_empty": "Column a is null or empty"},
+                ],
+            ],
+            EXPECTED_SCHEMA_WITH_CUSTOM_NAMES,
+        ),
+    )
+
+
+def test_apply_checks_by_metadata_with_custom_column_naming_fallback_to_default(ws, spark):
+    dq_engine = DQEngine(
+        ws,
+        extra_params=ExtraParams(column_names={"errors_invalid": "dq_errors", "warnings_invalid": "dq_warnings"}),
+    )
+    test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA)
+
+    checks = [
+        {"criticality": "warn", "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "a"}}},
+        {"criticality": "error", "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "b"}}},
+    ]
+    good, bad = dq_engine.apply_checks_by_metadata_and_split(test_df, checks)
+
+    assert_df_equality(good, spark.createDataFrame([[1, 3, 3], [None, 4, None]], SCHEMA), ignore_nullable=True)
+
+    assert_df_equality(
+        bad,
+        spark.createDataFrame(
+            [
+                [2, None, 4, {"col_b_is_null_or_empty": "Column b is null or empty"}, None],
+                [None, 4, None, None, {"col_a_is_null_or_empty": "Column a is null or empty"}],
+                [
+                    None,
+                    None,
+                    None,
+                    {"col_b_is_null_or_empty": "Column b is null or empty"},
+                    {"col_a_is_null_or_empty": "Column a is null or empty"},
+                ],
+            ],
+            EXPECTED_SCHEMA,
+        ),
+    )
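
Usage sketch for reviewers: a minimal, end-to-end illustration of the behaviour this diff introduces, assuming only the `ExtraParams`, `DQEngine`, `DQRule`, and `is_not_null_and_not_empty` APIs shown above and a Databricks environment where `spark` is predefined; the data, table-free input, and column names are illustrative only, not part of the change.

```python
# Sketch only: exercises the configurable reporting-column names added in this diff.
# Assumes a Databricks notebook where `spark` exists and a dqx build containing the
# changes above; values below are illustrative.
from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.engine import DQEngine, ExtraParams
from databricks.labs.dqx.rule import DQRule
from databricks.labs.dqx.col_functions import is_not_null_and_not_empty

# Rename the default `_errors`/`_warnings` reporting columns.
extra_params = ExtraParams(column_names={"errors": "dq_errors", "warnings": "dq_warnings"})
dq_engine = DQEngine(WorkspaceClient(), extra_params=extra_params)

input_df = spark.createDataFrame([["str1"], [None]], "col1: string")
checks = [
    DQRule(
        name="col1_is_null_or_empty",
        criticality="error",
        check=is_not_null_and_not_empty("col1"),
    ),
]

checked_df = dq_engine.apply_checks(input_df, checks)
valid_df = dq_engine.get_valid(checked_df)      # rows with no errors; reporting columns dropped
invalid_df = dq_engine.get_invalid(checked_df)  # rows flagged in dq_errors or dq_warnings
```

Because the engine looks up the `errors` and `warnings` keys with a default, unrecognized keys in `column_names` are simply ignored and the engine falls back to `_errors`/`_warnings`, which is the behaviour covered by the `..._fallback_to_default` test above.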