Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add filter to rule #141

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions demos/dqx_demo_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@
arguments:
col_name: col3

- criticality: error
filter: col1<3
check:
function: is_not_null_and_not_empty
arguments:
col_name: col4

- criticality: warn
check:
function: value_is_in_list
Expand Down Expand Up @@ -189,6 +196,11 @@
name='col3_is_null_or_empty',
criticality='error',
check=is_not_null_and_not_empty('col3')),
DQRule( # define rule with a filter
name='col_4_is_null_or_empty',
criticality='error',
filter='col1<3',
check=is_not_null_and_not_empty('col4')),
DQRule( # name auto-generated if not provided
criticality='warn',
check=value_is_in_list('col4', ['1', '2']))
Expand Down
12 changes: 12 additions & 0 deletions docs/dqx/docs/guide.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ checks = DQRuleColSet( # define rule for multiple columns at once
name='col3_is_null_or_empty',
criticality='error',
check=is_not_null_and_not_empty('col3')),
DQRule( # define rule with a filter
name='col_4_is_null_or_empty',
criticality='error',
filter='col1<3',
check=is_not_null_and_not_empty('col4')),
DQRule( # name auto-generated if not provided
criticality='warn',
check=value_is_in_list('col4', ['1', '2']))
Expand Down Expand Up @@ -288,6 +293,13 @@ checks = yaml.safe_load("""
arguments:
col_name: col3

- criticality: error
filter: col1<3
check:
function: is_not_null_and_not_empty
arguments:
col_name: col4

- criticality: warn
check:
function: value_is_in_list
Expand Down
13 changes: 13 additions & 0 deletions docs/dqx/docs/reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ The following quality rules / functions are currently available:

You can check implementation details of the rules [here](https://github.com/databrickslabs/dqx/blob/main/src/databricks/labs/dqx/col_functions.py).

#### Apply Filter on quality rule

If you want to apply a filter to a part of the dataframe, you can add a `filter` to the rule.
For example, if you want to check that a col `a` is not null when `b` is positive, you can do it like this:
```yaml
- criticality: "error"
filter: b>0
check:
function: "is_not_null"
arguments:
col_name: "a"
```

### Creating your own checks

#### Use sql expression
Expand Down
4 changes: 3 additions & 1 deletion src/databricks/labs/dqx/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,22 @@ def build_checks_by_metadata(checks: list[dict], glbs: dict[str, Any] | None = N
assert func # should already be validated
func_args = check.get("arguments", {})
criticality = check_def.get("criticality", "error")
filter_expr = check_def.get("filter")

if "col_names" in func_args:
logger.debug(f"Adding DQRuleColSet with columns: {func_args['col_names']}")
dq_rule_checks += DQRuleColSet(
columns=func_args["col_names"],
check_func=func,
criticality=criticality,
filter=filter_expr,
# provide arguments without "col_names"
check_func_kwargs={k: func_args[k] for k in func_args.keys() - {"col_names"}},
).get_rules()
else:
name = check_def.get("name", None)
check_func = func(**func_args)
dq_rule_checks.append(DQRule(check=check_func, name=name, criticality=criticality))
dq_rule_checks.append(DQRule(check=check_func, name=name, criticality=criticality, filter=filter_expr))

logger.debug("Exiting build_checks_by_metadata function with dq_rule_checks")
return dq_rule_checks
Expand Down
8 changes: 7 additions & 1 deletion src/databricks/labs/dqx/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DQRule:
check: Column
name: str = ""
criticality: str = Criticality.ERROR.value
filter: str | None = None

def __post_init__(self):
# take the name from the alias of the column expression if not provided
Expand All @@ -58,7 +59,10 @@ def check_column(self) -> Column:

:return: Column object
"""
return F.when(self.check.isNull(), F.lit(None).cast("string")).otherwise(self.check)
# if filter is provided, apply the filter to the check
filter_col = F.expr(self.filter) if self.filter else F.lit(True)

return F.when(self.check.isNotNull(), F.when(filter_col, self.check)).otherwise(F.lit(None).cast("string"))


@dataclass(frozen=True)
Expand All @@ -75,6 +79,7 @@ class DQRuleColSet:
columns: list[str]
check_func: Callable
criticality: str = Criticality.ERROR.value
filter: str | None = None
check_func_args: list[Any] = field(default_factory=list)
check_func_kwargs: dict[str, Any] = field(default_factory=dict)

Expand All @@ -88,6 +93,7 @@ def get_rules(self) -> list[DQRule]:
rule = DQRule(
criticality=self.criticality,
check=self.check_func(col_name, *self.check_func_args, **self.check_func_kwargs),
filter=self.filter,
)
rules.append(rule)
return rules
Expand Down
70 changes: 69 additions & 1 deletion tests/integration/test_apply_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from chispa.dataframe_comparer import assert_df_equality # type: ignore
from databricks.labs.dqx.col_functions import is_not_null_and_not_empty, make_condition
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.rule import DQRule
from databricks.labs.dqx.rule import DQRule, DQRuleColSet


SCHEMA = "a: int, b: int, c: int"
Expand Down Expand Up @@ -379,6 +379,74 @@ def test_apply_checks_by_metadata(ws, spark):
assert_df_equality(checked, expected, ignore_nullable=True)


def test_apply_checks_with_filter(ws, spark):
mwojtyczka marked this conversation as resolved.
Show resolved Hide resolved
dq_engine = DQEngine(ws)
test_df = spark.createDataFrame(
[[1, 3, 3], [2, None, 4], [3, 4, None], [4, None, None], [None, None, None]], SCHEMA
)

checks = DQRuleColSet(
check_func=is_not_null_and_not_empty, criticality="warn", filter="b>3", columns=["a", "c"]
).get_rules() + [
DQRule(
name="col_b_is_null_or_empty",
criticality="error",
check=is_not_null_and_not_empty("b"),
filter="a<3",
)
]

checked = dq_engine.apply_checks(test_df, checks)

expected = spark.createDataFrame(
[
[1, 3, 3, None, None],
[2, None, 4, {"col_b_is_null_or_empty": "Column b is null or empty"}, None],
[3, 4, None, None, {"col_c_is_null_or_empty": "Column c is null or empty"}],
[4, None, None, None, None],
[None, None, None, None, None],
],
EXPECTED_SCHEMA,
)

assert_df_equality(checked, expected, ignore_nullable=True)


def test_apply_checks_by_metadata_with_filter(ws, spark):
dq_engine = DQEngine(ws)
test_df = spark.createDataFrame(
[[1, 3, 3], [2, None, 4], [3, 4, None], [4, None, None], [None, None, None]], SCHEMA
)

checks = [
{
"criticality": "warn",
"filter": "b>3",
"check": {"function": "is_not_null_and_not_empty", "arguments": {"col_names": ["b", "c"]}},
},
{
"criticality": "error",
"filter": "a<3",
"check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "b"}},
},
]

checked = dq_engine.apply_checks_by_metadata(test_df, checks, globals())

expected = spark.createDataFrame(
[
[1, 3, 3, None, None],
[2, None, 4, {"col_b_is_null_or_empty": "Column b is null or empty"}, None],
[3, 4, None, None, {"col_c_is_null_or_empty": "Column c is null or empty"}],
[4, None, None, None, None],
[None, None, None, None, None],
],
EXPECTED_SCHEMA,
)

assert_df_equality(checked, expected, ignore_nullable=True)


def test_apply_checks_from_json_file_by_metadata(ws, spark):
dq_engine = DQEngine(ws)
schema = "col1: int, col2: int, col3: int, col4 int"
Expand Down
28 changes: 20 additions & 8 deletions tests/unit/test_build_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_get_rules():
def test_build_rules():
actual_rules = DQEngineCore.build_checks(
# set of columns for the same check
DQRuleColSet(columns=["a", "b"], criticality="error", check_func=is_not_null_and_not_empty),
DQRuleColSet(columns=["a", "b"], criticality="error", filter="c>0", check_func=is_not_null_and_not_empty),
DQRuleColSet(columns=["c"], criticality="warn", check_func=is_not_null_and_not_empty),
# with check function params provided as positional arguments
DQRuleColSet(columns=["d", "e"], criticality="error", check_func=value_is_in_list, check_func_args=[[1, 2]]),
Expand All @@ -73,21 +73,21 @@ def test_build_rules():
DQRuleColSet(columns=["a", "b"], criticality="error", check_func=is_not_null_and_not_empty_array),
DQRuleColSet(columns=["c"], criticality="warn", check_func=is_not_null_and_not_empty_array),
) + [
DQRule(name="col_g_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("g")),
DQRule(name="col_g_is_null_or_empty", criticality="warn", filter="a=0", check=is_not_null_and_not_empty("g")),
DQRule(criticality="warn", check=value_is_in_list("h", allowed=[1, 2])),
]

expected_rules = [
DQRule(name="col_a_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("a")),
DQRule(name="col_b_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("b")),
DQRule(name="col_a_is_null_or_empty", criticality="error", filter="c>0", check=is_not_null_and_not_empty("a")),
DQRule(name="col_b_is_null_or_empty", criticality="error", filter="c>0", check=is_not_null_and_not_empty("b")),
DQRule(name="col_c_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("c")),
DQRule(name="col_d_value_is_not_in_the_list", criticality="error", check=value_is_in_list("d", allowed=[1, 2])),
DQRule(name="col_e_value_is_not_in_the_list", criticality="error", check=value_is_in_list("e", allowed=[1, 2])),
DQRule(name="col_f_value_is_not_in_the_list", criticality="warn", check=value_is_in_list("f", allowed=[3])),
DQRule(name="col_a_is_null_or_empty_array", criticality="error", check=is_not_null_and_not_empty_array("a")),
DQRule(name="col_b_is_null_or_empty_array", criticality="error", check=is_not_null_and_not_empty_array("b")),
DQRule(name="col_c_is_null_or_empty_array", criticality="warn", check=is_not_null_and_not_empty_array("c")),
DQRule(name="col_g_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("g")),
DQRule(name="col_g_is_null_or_empty", criticality="warn", filter="a=0", check=is_not_null_and_not_empty("g")),
DQRule(name="col_h_value_is_not_in_the_list", criticality="warn", check=value_is_in_list("h", allowed=[1, 2])),
]

Expand All @@ -101,10 +101,12 @@ def test_build_rules_by_metadata():
},
{
"criticality": "warn",
"filter": "a>0",
"check": {"function": "is_not_null_and_not_empty", "arguments": {"col_names": ["c"]}},
},
{
"criticality": "error",
"filter": "c=0",
"check": {"function": "value_is_in_list", "arguments": {"col_names": ["d", "e"], "allowed": [1, 2]}},
},
{
Expand Down Expand Up @@ -142,9 +144,19 @@ def test_build_rules_by_metadata():
expected_rules = [
DQRule(name="col_a_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("a")),
DQRule(name="col_b_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("b")),
DQRule(name="col_c_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("c")),
DQRule(name="col_d_value_is_not_in_the_list", criticality="error", check=value_is_in_list("d", allowed=[1, 2])),
DQRule(name="col_e_value_is_not_in_the_list", criticality="error", check=value_is_in_list("e", allowed=[1, 2])),
DQRule(name="col_c_is_null_or_empty", criticality="warn", filter="a>0", check=is_not_null_and_not_empty("c")),
DQRule(
name="col_d_value_is_not_in_the_list",
criticality="error",
filter="c=0",
check=value_is_in_list("d", allowed=[1, 2]),
),
DQRule(
name="col_e_value_is_not_in_the_list",
criticality="error",
filter="c=0",
check=value_is_in_list("e", allowed=[1, 2]),
),
DQRule(name="col_f_value_is_not_in_the_list", criticality="warn", check=value_is_in_list("f", allowed=[3])),
DQRule(name="col_g_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("g")),
DQRule(name="col_h_value_is_not_in_the_list", criticality="warn", check=value_is_in_list("h", allowed=[1, 2])),
Expand Down
Loading