diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py index 7c39b28..1770b4c 100644 --- a/demos/dqx_demo_library.py +++ b/demos/dqx_demo_library.py @@ -104,6 +104,7 @@ """) dq_engine = DQEngine(WorkspaceClient()) + status = dq_engine.validate_checks(checks) print(status.has_errors) print(status.errors) @@ -334,5 +335,6 @@ def ends_with_foo(col_name: str) -> Column: input_df = spark.createDataFrame([["str1"], ["foo"], ["str3"]], schema) dq_engine = DQEngine(WorkspaceClient()) + valid_and_quarantined_df = dq_engine.apply_checks_by_metadata(input_df, checks, globals()) display(valid_and_quarantined_df) \ No newline at end of file diff --git a/demos/dqx_demo_tool.py b/demos/dqx_demo_tool.py index d95cec8..9c9c970 100644 --- a/demos/dqx_demo_tool.py +++ b/demos/dqx_demo_tool.py @@ -84,7 +84,7 @@ ws = WorkspaceClient() dq_engine = DQEngine(ws) -run_config = dq_engine.load_run_config(run_config="default", assume_user=True) +run_config = dq_engine.load_run_config(run_config_name="default", assume_user=True) # read the input data, limit to 1000 rows for demo purpose input_df = spark.read.format(run_config.input_format).load(run_config.input_location).limit(1000) @@ -101,14 +101,14 @@ print(yaml.safe_dump(checks)) # save generated checks to location specified in the default run configuration inside workspace installation folder -dq_engine.save_checks(checks, run_config_name="default") +dq_engine.save_checks_in_installation(checks, run_config_name="default") # or save it to an arbitrary workspace location #dq_engine.save_checks_in_workspace_file(checks, workspace_path="/Shared/App1/checks.yml") # COMMAND ---------- # MAGIC %md -# MAGIC ### Prepare checks manually (optional) +# MAGIC ### Prepare checks manually and save in the workspace (optional) # MAGIC # MAGIC You can modify the check candidates generated by the profiler to suit your needs. Alternatively, you can create checks manually, as demonstrated below, without using the profiler. 
@@ -161,7 +161,7 @@ dq_engine = DQEngine(WorkspaceClient()) # save checks to location specified in the default run configuration inside workspace installation folder -dq_engine.save_checks(checks, run_config_name="default") +dq_engine.save_checks_in_installation(checks, run_config_name="default") # or save it to an arbitrary workspace location #dq_engine.save_checks_in_workspace_file(checks, workspace_path="/Shared/App1/checks.yml") @@ -175,7 +175,7 @@ from databricks.labs.dqx.engine import DQEngine from databricks.sdk import WorkspaceClient -run_config = dq_engine.load_run_config(run_config="default", assume_user=True) +run_config = dq_engine.load_run_config(run_config_name="default", assume_user=True) # read the data, limit to 1000 rows for demo purpose bronze_df = spark.read.format(run_config.input_format).load(run_config.input_location).limit(1000) @@ -186,7 +186,7 @@ dq_engine = DQEngine(WorkspaceClient()) # load checks from location defined in the run configuration -checks = dq_engine.load_checks(assume_user=True, run_config_name="default") +checks = dq_engine.load_checks_from_installation(assume_user=True, run_config_name="default") # or load checks from arbitrary workspace file # checks = dq_engine.load_checks_from_workspace_file(workspace_path="/Shared/App1/checks.yml") print(checks) diff --git a/docs/dqx/docs/demos.mdx b/docs/dqx/docs/demos.mdx index 8fc3575..3e8926b 100644 --- a/docs/dqx/docs/demos.mdx +++ b/docs/dqx/docs/demos.mdx @@ -4,8 +4,7 @@ sidebar_position: 4 # Demos -After the [installation](/docs/installation) of the framework, -you can import the following notebooks in the Databricks workspace to try it out: +[Install](/docs/installation) the framework, and import the following notebooks in the Databricks workspace to try it out: * [DQX Demo Notebook (library)](https://github.com/databrickslabs/dqx/blob/main/demos/dqx_demo_library.py) - demonstrates how to use DQX as a library. 
-* [DQX Demo Notebook (tool)](https://github.com/databrickslabs/dqx/blob/main/demos/dqx_demo_tool.py) - demonstrates how to use DQX when installed in the workspace, including usage of DQX dashboards. +* [DQX Demo Notebook (tool)](https://github.com/databrickslabs/dqx/blob/main/demos/dqx_demo_tool.py) - demonstrates how to use DQX as a tool when installed in the workspace. * [DQX DLT Demo Notebook](https://github.com/databrickslabs/dqx/blob/main/demos/dqx_dlt_demo.py) - demonstrates how to use DQX with Delta Live Tables (DLT). diff --git a/docs/dqx/docs/dev/contributing.mdx b/docs/dqx/docs/dev/contributing.mdx index b6d068e..ba2f9aa 100644 --- a/docs/dqx/docs/dev/contributing.mdx +++ b/docs/dqx/docs/dev/contributing.mdx @@ -93,12 +93,9 @@ make lint make test ``` -Configure auth to Databricks workspace for integration testing by configuring credentials. - -If you want to run the tests from an IDE you must setup `.env` or `~/.databricks/debug-env.json` file -(see [instructions](https://github.com/databrickslabs/pytester?tab=readme-ov-file#debug_env_name-fixture)). - -Setup required environment variables for executing integration tests and code coverage: +Set up the required environment variables for executing integration tests and code coverage using the command line. +Note that integration tests are run automatically when you create a Pull Request in GitHub. +You can also run them from a local machine by configuring authentication to a Databricks workspace as below: ```shell export DATABRICKS_HOST=https:// export DATABRICKS_CLUSTER_ID= @@ -119,9 +116,13 @@ Calculate test coverage and display report in html: make coverage ``` +If you want to be able to run integration tests from your IDE, you must set up a `.env` or `~/.databricks/debug-env.json` file +(see [instructions](https://github.com/databrickslabs/pytester?tab=readme-ov-file#debug_env_name-fixture)). +The name of the debug environment that you define must be `ws`. 
+ ## Running CLI from the local repo -Once you clone the repo locally and install Databricks CLI you can run labs CLI commands. +Once you clone the repo locally and install Databricks CLI you can run labs CLI commands from the root of the repository. Similar to other databricks cli commands we can specify profile to use with `--profile`. Authenticate your current machine to your Databricks Workspace: diff --git a/docs/dqx/docs/guide.mdx b/docs/dqx/docs/guide.mdx index 4772bf8..601f9d6 100644 --- a/docs/dqx/docs/guide.mdx +++ b/docs/dqx/docs/guide.mdx @@ -24,22 +24,23 @@ from databricks.labs.dqx.profiler.dlt_generator import DQDltGenerator from databricks.labs.dqx.engine import DQEngine from databricks.sdk import WorkspaceClient -df = spark.read.table("catalog1.schema1.table1") +input_df = spark.read.table("catalog1.schema1.table1") +# profile input data ws = WorkspaceClient() profiler = DQProfiler(ws) -summary_stats, profiles = profiler.profile(df) +summary_stats, profiles = profiler.profile(input_df) # generate DQX quality rules/checks generator = DQGenerator(ws) checks = generator.generate_dq_rules(profiles) # with default level "error" -# save checks in the workspace dq_engine = DQEngine(ws) -# in arbitrary workspace location + +# save checks in arbitrary workspace location dq_engine.save_checks_in_workspace_file(checks, workspace_path="/Shared/App1/checks.yml") -# in workspace location specified in the run config (only works if DQX is installed in the workspace) -dq_engine.save_checks(checks, run_config_name="default") +# save checks in the installation folder specified in the default run config (only works if DQX is installed in the workspace) +dq_engine.save_checks_in_installation(checks, run_config_name="default") # generate DLT expectations dlt_generator = DQDltGenerator(ws) @@ -153,9 +154,9 @@ Fields: - `check`: column expression containing "function" (check function to apply), "arguments" (check function arguments), and "col_name" (column name as `str` 
the check will be applied for) or "col_names" (column names as `array` the check will be applied for). - (optional) `name` for the check: autogenerated if not provided. -#### Loading and execution methods +### Loading and execution methods -**Method 1: load checks from a workspace file in the installation folder** +#### Method 1: Loading checks from a workspace file in the installation folder If DQX is installed in the workspace, you can load checks based on the run configuration: @@ -164,9 +165,10 @@ from databricks.labs.dqx.engine import DQEngine from databricks.sdk import WorkspaceClient dq_engine = DQEngine(WorkspaceClient()) - # load check file specified in the run configuration -checks = dq_engine.load_checks(assume_user=True, run_config_name="default") +checks = dq_engine.load_checks_from_installation(assume_user=True, run_config_name="default") + +input_df = spark.read.table("catalog1.schema1.table1") # Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df, checks) @@ -175,9 +177,7 @@ valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df valid_and_quarantined_df = dq_engine.apply_checks_by_metadata(input_df, checks) ``` -Checks are validated automatically as part of the `apply_checks_by_metadata_and_split` and `apply_checks_by_metadata` methods. 
- -**Method 2: load checks from a workspace file** +#### Method 2: Loading checks from a workspace file The checks can also be loaded from any file in the Databricks workspace: @@ -188,6 +188,8 @@ from databricks.sdk import WorkspaceClient dq_engine = DQEngine(WorkspaceClient()) checks = dq_engine.load_checks_from_workspace_file(workspace_path="/Shared/App1/checks.yml") +input_df = spark.read.table("catalog1.schema1.table1") + # Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df, checks) @@ -197,7 +199,7 @@ valid_and_quarantined_df = dq_engine.apply_checks_by_metadata(input_df, checks) Checks are validated automatically as part of the `apply_checks_by_metadata_and_split` and `apply_checks_by_metadata` methods. -**Method 3: load checks from a local file** +#### Method 3: Loading checks from a local file Checks can also be loaded from a file in the local file system: @@ -208,6 +210,8 @@ from databricks.sdk import WorkspaceClient checks = DQEngine.load_checks_from_local_file("checks.yml") dq_engine = DQEngine(WorkspaceClient()) +input_df = spark.read.table("catalog1.schema1.table1") + # Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df, checks) @@ -217,13 +221,15 @@ valid_and_quarantined_df = dq_engine.apply_checks_by_metadata(input_df, checks) ### Quality rules defined as code -**Method 1: using DQX classes** +#### Method 1: Using DQX classes ```python from databricks.labs.dqx.col_functions import is_not_null, is_not_null_and_not_empty, value_is_in_list -from databricks.labs.dqx.engine import DQEngine, DQRuleColSet, DQRule +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.rule import DQRuleColSet, DQRule from databricks.sdk import WorkspaceClient + dq_engine = 
DQEngine(WorkspaceClient()) checks = DQRuleColSet( # define rule for multiple columns at once @@ -239,6 +245,8 @@ checks = DQRuleColSet( # define rule for multiple columns at once check=value_is_in_list('col4', ['1', '2'])) ] +input_df = spark.read.table("catalog1.schema1.table1") + # Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes valid_df, quarantined_df = dq_engine.apply_checks_and_split(input_df, checks) @@ -246,9 +254,9 @@ valid_df, quarantined_df = dq_engine.apply_checks_and_split(input_df, checks) valid_and_quarantined_df = dq_engine.apply_checks(input_df, checks) ``` -See details of the check functions [here](/docs/reference#quality-rules--functions). +See details of the check functions [here](/docs/reference#quality-rules). -**Method 2: using yaml config** +#### Method 2: Using yaml config ```python import yaml @@ -282,6 +290,8 @@ checks = yaml.safe_load(""" - 2 """) +input_df = spark.read.table("catalog1.schema1.table1") + # Option 1: apply quality rules on the dataframe and provide valid and invalid (quarantined) dataframes valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df, checks) @@ -289,15 +299,15 @@ valid_df, quarantined_df = dq_engine.apply_checks_by_metadata_and_split(input_df valid_and_quarantined_df = dq_engine.apply_checks_by_metadata(input_df, checks) ``` -See details of the check functions [here](/docs/reference/#quality-rules--functions). +See details of the check functions [here](/docs/reference#quality-rules). ### Integration with DLT (Delta Live Tables) DLT provides [expectations](https://docs.databricks.com/en/delta-live-tables/expectations.html) to enforce data quality constraints. However, expectations don't offer detailed insights into why certain checks fail. The example below demonstrates how to integrate DQX with DLT to provide comprehensive quality information. -The DQX integration does not use expectations with DLT but DQX own methods. 
+The DQX integration with DLT does not use DLT Expectations but DQX's own methods. -**Option 1: apply quality rules and quarantine bad records** +#### Option 1: Apply quality rules and quarantine bad records ```python import dlt @@ -326,7 +336,7 @@ def quarantine(): return dq_engine.get_invalid(df) ``` -**Option 2: apply quality rules as additional columns (`_warning` and `_error`)** +#### Option 2: Apply quality rules and report issues as additional columns ```python import dlt @@ -367,6 +377,29 @@ After executing the command: Note: the dashboards are only using the quarantined data as input as defined during the installation process. If you change the quarantine table in the run config after the deployment (`quarantine_table` field), you need to update the dashboard queries accordingly. -## Explore Quality Rules and Create Custom Checks +## Quality Rules and Creation of Custom Checks + +Discover the full list of available data quality rules and learn how to define your own custom checks in our [Reference](/docs/reference#quality-rules) section. + +## Details on DQX Engine and Workspace Client + +To perform data quality checking with DQX, you need to create a `DQEngine` object. +The engine requires a Databricks workspace client for authentication and interaction with the Databricks workspace. + +When running the code on a Databricks workspace (e.g. in a notebook or as a job), the workspace client is automatically authenticated. +For external environments (e.g. CI servers or local machines), you can authenticate using any method supported by the Databricks SDK. Detailed instructions are available in the [default authentication flow](https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#default-authentication-flow). 
+ +If you use Databricks [configuration profiles](https://docs.databricks.com/dev-tools/auth.html#configuration-profiles) or Databricks-specific [environment variables](https://docs.databricks.com/dev-tools/auth.html#environment-variables) for authentication, you only need the following code to create a workspace client: +```python +from databricks.sdk import WorkspaceClient +from databricks.labs.dqx.engine import DQEngine + +ws = WorkspaceClient() + +# use the workspace client to create the DQX engine +dq_engine = DQEngine(ws) +``` + +For details on the specific methods available in the engine, visit the [reference](/docs/reference#dq-engine-methods) section. -Discover the full list of available data quality rules and learn how to define your own custom checks in our [Reference](/docs/reference) section. +Information on testing applications that use `DQEngine` can be found [here](/docs/reference#testing-applications-using-dqx). diff --git a/docs/dqx/docs/reference.mdx b/docs/dqx/docs/reference.mdx index 47b4c6d..7431b80 100644 --- a/docs/dqx/docs/reference.mdx +++ b/docs/dqx/docs/reference.mdx @@ -3,11 +3,13 @@ sidebar_position: 5 title: Reference --- -# Quality rules and functions reference +# Reference -This page provides a reference for the quality rules and functions available in DQX. +## Quality rules -## Quality rules / functions +This page provides a reference for the quality rule functions (checks) available in DQX. + +### Quality rule functions (checks) The following quality rules / functions are currently available: @@ -32,9 +34,9 @@ The following quality rules / functions are currently available: You can check implementation details of the rules [here](https://github.com/databrickslabs/dqx/blob/main/src/databricks/labs/dqx/col_functions.py). 
-## Creating your own checks +### Creating your own checks -### Use sql expression +#### Use sql expression If a check that you need does not exist in DQX, you can define them using sql expression rule (`sql_expression`), for example: @@ -57,7 +59,7 @@ Sql expression is also useful if you want to make cross-column validation, for e msg: a is greater than b ``` -### Define custom check functions +#### Define custom check functions If you need a reusable check or need to implement a more complicated logic you can define your own check functions. A check is a function available from 'globals' that returns `pyspark.sql.Column`, for example: @@ -72,7 +74,7 @@ def ends_with_foo(col_name: str) -> Column: return make_condition(column.endswith("foo"), f"Column {col_name} ends with foo", f"{col_name}_ends_with_foo") ``` -Then you can use the function as a check: +and use the function as a check: ```python import yaml from databricks.labs.dqx.engine import DQEngine @@ -100,3 +102,105 @@ You can see all existing DQX checks [here](https://github.com/databrickslabs/dqx Feel free to submit a PR to DQX with your own check so that other can benefit from it (see [contribution guide](/docs/dev/contributing)). +## DQ engine methods + +Performing data quality checks using DQX requires creating DQEngine object. + +The following table outlines the available methods and their functionalities: + +| Check | Description | Arguments | +| ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| apply_checks | Applies quality checks to the DataFrame and returns a DataFrame with reporting columns. | df: DataFrame to check; checks: List of checks to the DataFrame. Each check is an instance of DQRule class. 
| +| apply_checks_and_split | Applies quality checks to the DataFrame and returns valid and invalid (quarantine) DataFrames with reporting columns. | df: DataFrame to check; checks: List of checks to apply to the DataFrame. Each check is an instance of DQRule class. | +| apply_checks_by_metadata | Applies quality checks defined as a dictionary to the DataFrame and returns a DataFrame with reporting columns. | df: DataFrame to check. checks: List of dictionaries describing checks. glbs: Optional dictionary with functions mapping (e.g., globals() of the calling module). | +| apply_checks_by_metadata_and_split | Applies quality checks defined as a dictionary and returns valid and invalid (quarantine) DataFrames. | df: DataFrame to check; checks: List of dictionaries describing checks. glbs: Optional dictionary with functions mapping (e.g., globals() of the calling module). | +| validate_checks | Validates the provided quality checks to ensure they conform to the expected structure and types. | checks: List of checks to validate; glbs: Optional dictionary of global functions that can be used. | +| get_invalid | Retrieves records from the DataFrame that violate data quality checks (records with warnings and errors). | df: Input DataFrame. | +| get_valid | Retrieves records from the DataFrame that pass all data quality checks. | df: Input DataFrame. | +| load_checks_from_local_file | Loads quality rules from a local file (supports YAML and JSON). | path: Path to a file containing the checks. | +| save_checks_in_local_file | Saves quality rules to a local file in YAML format. | checks: List of checks to save; path: Path to a file containing the checks. | +| load_checks_from_workspace_file | Loads checks from a file (JSON or YAML) stored in the Databricks workspace. | workspace_path: Path to the file in the workspace. | +| load_checks_from_installation | Loads checks from the workspace installation configuration file (`checks_file` field). 
| run_config_name: Name of the run config to use; product_name: Name of the product/installation directory; assume_user: If True, assume user installation. | +| save_checks_in_workspace_file | Saves checks to a file (YAML) in the Databricks workspace. | checks: List of checks to save; workspace_path: Destination path for the checks file in the workspace. | +| save_checks_in_installation | Saves checks to the installation folder as a YAML file. | checks: List of checks to save; run_config_name: Name of the run config to use; assume_user: If True, assume user installation. | +| load_run_config | Loads run configuration from the installation folder. | run_config_name: Name of the run config to use; assume_user: If True, assume user installation. | + +## Testing Applications Using DQX + +### Standard testing with DQEngine + +Testing applications that use DQEngine requires proper initialization of the Databricks workspace client. Detailed guidance on authentication for the workspace client is available [here](https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html#default-authentication-flow). + +For testing, we recommend: +* [pytester fixtures](https://github.com/databrickslabs/pytester) to setup Databricks remote Spark session and workspace client. For pytester to be able to authenticate to a workspace you need to use [debug_env_name fixture](https://github.com/databrickslabs/pytester?tab=readme-ov-file#debug_env_name-fixture). We recommend using the `~/.databricks/debug-env.json` file to store different sets of environment variables. +* [chispa](https://github.com/MrPowers/chispa) for asserting Spark DataFrames. + +These libraries are also used internally for testing DQX. 
+ +Example test: +```python +from chispa.dataframe_comparer import assert_df_equality +from databricks.labs.dqx.col_functions import is_not_null_and_not_empty +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.rule import DQRule + + +@pytest.fixture +def debug_env_name(): + return "ws" # Specify the name of the target environment from ~/.databricks/debug-env.json + + +def test_dq(ws, spark): # use ws and spark pytester fixtures to initialize workspace client and spark session + schema = "a: int, b: int, c: int" + expected_schema = schema + ", _errors: map, _warnings: map" + test_df = spark.createDataFrame([[1, 3, 3]], schema) + + checks = [ + DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")), + DQRule(name="col_b_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("b")), + ] + + dq_engine = DQEngine(ws) + df = dq_engine.apply_checks(test_df, checks) + + expected_df = spark.createDataFrame([[1, 3, 3, None, None]], expected_schema) + assert_df_equality(df, expected_df) +``` + +### Local testing with DQEngine + +If workspace-level access is unavailable in your unit testing environment, you can perform local testing by installing the latest `pyspark` package and mocking the workspace client. + +**Note: This approach should be treated as experimental!** It does not offer the same level of testing as the standard approach and it is only applicable to selected methods. +We strongly recommend following the standard testing procedure outlined above, which includes proper initialization of the workspace client. 
+ +Example test: +```python +from unittest.mock import MagicMock +from databricks.sdk import WorkspaceClient +from pyspark.sql import SparkSession +from chispa.dataframe_comparer import assert_df_equality +from databricks.labs.dqx.col_functions import is_not_null_and_not_empty +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.rule import DQRule + + +def test_dq(): + spark = SparkSession.builder.master("local[*]").getOrCreate() # create spark local session + ws = MagicMock(spec=WorkspaceClient, **{"catalogs.list.return_value": []}) # mock the workspace client + + schema = "a: int, b: int, c: int" + expected_schema = schema + ", _errors: map, _warnings: map" + test_df = spark.createDataFrame([[1, 3, 3]], schema) + + checks = [ + DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")), + DQRule(name="col_b_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("b")), + ] + + dq_engine = DQEngine(ws) + df = dq_engine.apply_checks(test_df, checks) + + expected_df = spark.createDataFrame([[1, 3, 3, None, None]], expected_schema) + assert_df_equality(df, expected_df) +``` diff --git a/src/databricks/labs/dqx/base.py b/src/databricks/labs/dqx/base.py index d8fb8f9..b1bbffa 100644 --- a/src/databricks/labs/dqx/base.py +++ b/src/databricks/labs/dqx/base.py @@ -1,6 +1,8 @@ import abc -from typing import final from functools import cached_property +from typing import Any, final +from pyspark.sql import DataFrame +from databricks.labs.dqx.rule import DQRule, ChecksValidationStatus from databricks.sdk import WorkspaceClient from databricks.labs.dqx.__about__ import __version__ @@ -26,6 +28,127 @@ def _verify_workspace_client(ws: WorkspaceClient) -> WorkspaceClient: product_info = getattr(ws.config, '_product_info') if product_info[0] != "dqx": setattr(ws.config, '_product_info', ('dqx', __version__)) + # make sure Unity Catalog is accessible in the current Databricks workspace ws.catalogs.list() 
return ws + + +class DQEngineCoreBase(DQEngineBase): + + @abc.abstractmethod + def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame: + """Applies data quality checks to a given dataframe. + + :param df: dataframe to check + :param checks: list of checks to apply to the dataframe. Each check is an instance of DQRule class. + :return: dataframe with errors and warning reporting columns + """ + + @abc.abstractmethod + def apply_checks_and_split(self, df: DataFrame, checks: list[DQRule]) -> tuple[DataFrame, DataFrame]: + """Applies data quality checks to a given dataframe and split it into two ("good" and "bad"), + according to the data quality checks. + + :param df: dataframe to check + :param checks: list of checks to apply to the dataframe. Each check is an instance of DQRule class. + :return: two dataframes - "good" which includes warning rows but no reporting columns, and "data" having + error and warning rows and corresponding reporting columns + """ + + @abc.abstractmethod + def apply_checks_by_metadata_and_split( + self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None + ) -> tuple[DataFrame, DataFrame]: + """Wrapper around `apply_checks_and_split` for use in the metadata-driven pipelines. The main difference + is how the checks are specified - instead of using functions directly, they are described as function name plus + arguments. + + :param df: dataframe to check + :param checks: list of dictionaries describing checks. Each check is a dictionary consisting of following fields: + * `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - + it will be used as an error/warning message, or `null` if it's evaluated to `false` + * `name` - name that will be given to a resulting column. 
Autogenerated if not provided + * `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe), + and `warn` (data is going into both dataframes) + :param glbs: dictionary with functions mapping (eg. ``globals()`` of the calling module). + If not specified, then only built-in functions are used for the checks. + :return: two dataframes - "good" which includes warning rows but no reporting columns, and "bad" having + error and warning rows and corresponding reporting columns + """ + + @abc.abstractmethod + def apply_checks_by_metadata( + self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None + ) -> DataFrame: + """Wrapper around `apply_checks` for use in the metadata-driven pipelines. The main difference + is how the checks are specified - instead of using functions directly, they are described as function name plus + arguments. + + :param df: dataframe to check + :param checks: list of dictionaries describing checks. Each check is a dictionary consisting of following fields: + * `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - + it will be used as an error/warning message, or `null` if it's evaluated to `false` + * `name` - name that will be given to a resulting column. Autogenerated if not provided + * `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe), + and `warn` (data is going into both dataframes) + :param glbs: dictionary with functions mapping (eg. ``globals()`` of calling module). + If not specified, then only built-in functions are used for the checks. + :return: dataframe with errors and warning reporting columns + """ + + @staticmethod + @abc.abstractmethod + def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus: + """ + Validate the input dict to ensure they conform to expected structure and types. + + Each check can be a dictionary. 
The function validates + the presence of required keys, the existence and callability of functions, and the types + of arguments passed to these functions. + + :param checks: List of checks to apply to the dataframe. Each check should be a dictionary. + :param glbs: Optional dictionary of global functions that can be used in checks. + + :return ValidationStatus: The validation status. + """ + + @staticmethod + @abc.abstractmethod + def get_invalid(df: DataFrame) -> DataFrame: + """ + Get records that violate data quality checks (records with warnings and errors). + @param df: input DataFrame. + @return: dataframe with error and warning rows and corresponding reporting columns. + """ + + @staticmethod + @abc.abstractmethod + def get_valid(df: DataFrame) -> DataFrame: + """ + Get records that don't violate data quality checks (records with warnings but no errors). + @param df: input DataFrame. + @return: dataframe with warning rows but no reporting columns. + """ + + @staticmethod + @abc.abstractmethod + def load_checks_from_local_file(path: str) -> list[dict]: + """ + Load checks (dq rules) from a file (json or yml) in the local file system. + This does not require installation of DQX in the workspace. + The returning checks can be used as input for `apply_checks_by_metadata` function. + + :param path: path to a file containing the checks. + :return: list of dq rules + """ + + @staticmethod + @abc.abstractmethod + def save_checks_in_local_file(checks: list[dict], path: str): + """ + Save checks (dq rules) to yml file in the local file system. + + :param checks: list of dq rules to save + :param path: path to a file containing the checks. 
+ """ diff --git a/src/databricks/labs/dqx/engine.py b/src/databricks/labs/dqx/engine.py index 0a6785e..c18552f 100644 --- a/src/databricks/labs/dqx/engine.py +++ b/src/databricks/labs/dqx/engine.py @@ -5,148 +5,167 @@ import itertools from pathlib import Path from collections.abc import Callable -from dataclasses import dataclass, field -from enum import Enum from typing import Any import yaml - import pyspark.sql.functions as F -from pyspark.sql import Column, DataFrame +from pyspark.sql import DataFrame +from databricks.labs.dqx.rule import DQRule, Criticality, Columns, DQRuleColSet, ChecksValidationStatus +from databricks.labs.dqx.utils import deserialize_dicts from databricks.labs.dqx import col_functions from databricks.labs.blueprint.installation import Installation - -from databricks.labs.dqx.base import DQEngineBase +from databricks.labs.dqx.base import DQEngineBase, DQEngineCoreBase from databricks.labs.dqx.config import WorkspaceConfig, RunConfig -from databricks.labs.dqx.utils import get_column_name from databricks.sdk.errors import NotFound from databricks.sdk.service.workspace import ImportFormat +from databricks.sdk import WorkspaceClient logger = logging.getLogger(__name__) -# TODO: make this configurable -class Columns(Enum): - """Enum class to represent columns in the dataframe that will be used for error and warning reporting.""" +class DQEngineCore(DQEngineCoreBase): + """Data Quality Engine Core class to apply data quality checks to a given dataframe.""" - ERRORS = "_errors" - WARNINGS = "_warnings" + def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame: + if not checks: + return self._append_empty_checks(df) + warning_checks = self._get_check_columns(checks, Criticality.WARN.value) + error_checks = self._get_check_columns(checks, Criticality.ERROR.value) + ndf = self._create_results_map(df, error_checks, Columns.ERRORS.value) + ndf = self._create_results_map(ndf, warning_checks, Columns.WARNINGS.value) -class 
Criticality(Enum): - """Enum class to represent criticality of the check.""" + return ndf - WARN = "warn" - ERROR = "error" + def apply_checks_and_split(self, df: DataFrame, checks: list[DQRule]) -> tuple[DataFrame, DataFrame]: + if not checks: + return df, self._append_empty_checks(df).limit(0) + checked_df = self.apply_checks(df, checks) -@dataclass(frozen=True) -class ChecksValidationStatus: - """Class to represent the validation status.""" + good_df = self.get_valid(checked_df) + bad_df = self.get_invalid(checked_df) + + return good_df, bad_df - _errors: list[str] = field(default_factory=list) + def apply_checks_by_metadata_and_split( + self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None + ) -> tuple[DataFrame, DataFrame]: + dq_rule_checks = self.build_checks_by_metadata(checks, glbs) - def add_error(self, error: str): - """Add an error to the validation status.""" - self._errors.append(error) + good_df, bad_df = self.apply_checks_and_split(df, dq_rule_checks) - def add_errors(self, errors: list[str]): - """Add an error to the validation status.""" - self._errors.extend(errors) + return good_df, bad_df - @property - def has_errors(self) -> bool: - """Check if there are any errors in the validation status.""" - return bool(self._errors) + def apply_checks_by_metadata( + self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None + ) -> DataFrame: + dq_rule_checks = self.build_checks_by_metadata(checks, glbs) - @property - def errors(self) -> list[str]: - """Get the list of errors in the validation status.""" - return self._errors + return self.apply_checks(df, dq_rule_checks) - def to_string(self) -> str: - """Convert the validation status to a string.""" - if self.has_errors: - return "\n".join(self._errors) - return "No errors found" + @staticmethod + def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus: + status = ChecksValidationStatus() - def __str__(self) -> str: - 
"""String representation of the ValidationStatus class.""" - return self.to_string() + for check in checks: + logger.debug(f"Processing check definition: {check}") + if isinstance(check, dict): + status.add_errors(DQEngineCore._validate_checks_dict(check, glbs)) + else: + status.add_error(f"Unsupported check type: {type(check)}") + return status -@dataclass(frozen=True) -class DQRule: - """Class to represent a data quality rule consisting of following fields: - * `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - - it will be used as an error/warning message, or `null` if it's evaluated to `false` - * `name` - optional name that will be given to a resulting column. Autogenerated if not provided - * `criticality` (optional) - possible values are `error` (critical problems), and `warn` (potential problems) - """ + @staticmethod + def get_invalid(df: DataFrame) -> DataFrame: + return df.where(F.col(Columns.ERRORS.value).isNotNull() | F.col(Columns.WARNINGS.value).isNotNull()) - check: Column - name: str = "" - criticality: str = Criticality.ERROR.value + @staticmethod + def get_valid(df: DataFrame) -> DataFrame: + return df.where(F.col(Columns.ERRORS.value).isNull()).drop(Columns.ERRORS.value, Columns.WARNINGS.value) - def __post_init__(self): - # take the name from the alias of the column expression if not provided - object.__setattr__(self, "name", self.name if self.name else "col_" + get_column_name(self.check)) + @staticmethod + def load_checks_from_local_file(path: str) -> list[dict]: + if not path: + raise ValueError("filename must be provided") - @ft.cached_property - def rule_criticality(self) -> str: - """Returns criticality of the check. 
+ try: + checks = Installation.load_local(list[dict[str, str]], Path(path)) + return deserialize_dicts(checks) + except FileNotFoundError: + msg = f"Checks file {path} missing" + raise FileNotFoundError(msg) from None - :return: string describing criticality - `warn` or `error`. - :raises ValueError: if criticality is invalid. - """ - criticality = self.criticality - if criticality not in {Criticality.WARN.value, Criticality.ERROR.value}: - raise ValueError(f"Invalid criticality value: {criticality}") + @staticmethod + def save_checks_in_local_file(checks: list[dict], path: str): + if not path: + raise ValueError("filename must be provided") - return criticality + try: + with open(path, 'w', encoding="utf-8") as file: + yaml.safe_dump(checks, file) + except FileNotFoundError: + msg = f"Checks file {path} missing" + raise FileNotFoundError(msg) from None - def check_column(self) -> Column: - """Creates a Column object from the given check. + @staticmethod + def build_checks_by_metadata(checks: list[dict], glbs: dict[str, Any] | None = None) -> list[DQRule]: + """Build checks based on check specification, i.e. function name plus arguments. - :return: Column object + :param checks: list of dictionaries describing checks. Each check is a dictionary consisting of following fields: + * `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - + it will be used as an error/warning message, or `null` if it's evaluated to `false` + * `name` - name that will be given to a resulting column. Autogenerated if not provided + * `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe), + and `warn` (data is going into both dataframes) + :param glbs: dictionary with functions mapping (eg. ``globals()`` of the calling module). + If not specified, then only built-in functions are used for the checks. 
+ :return: list of data quality check rules """ - return F.when(self.check.isNull(), F.lit(None).cast("string")).otherwise(self.check) + status = DQEngineCore.validate_checks(checks, glbs) + if status.has_errors: + raise ValueError(str(status)) + dq_rule_checks = [] + for check_def in checks: + logger.debug(f"Processing check definition: {check_def}") + check = check_def.get("check", {}) + func_name = check.get("function", None) + func = DQEngineCore._resolve_function(func_name, glbs, fail_on_missing=True) + assert func # should already be validated + func_args = check.get("arguments", {}) + criticality = check_def.get("criticality", "error") -@dataclass(frozen=True) -class DQRuleColSet: - """Class to represent a data quality col rule set which defines quality check function for a set of columns. - The class consists of the following fields: - * `columns` - list of column names to which the given check function should be applied - * `criticality` - criticality level ('warn' or 'error') - * `check_func` - check function to be applied - * `check_func_args` - non-keyword / positional arguments for the check function after the col_name - * `check_func_kwargs` - keyword /named arguments for the check function after the col_name - """ + if "col_names" in func_args: + logger.debug(f"Adding DQRuleColSet with columns: {func_args['col_names']}") + dq_rule_checks += DQRuleColSet( + columns=func_args["col_names"], + check_func=func, + criticality=criticality, + # provide arguments without "col_names" + check_func_kwargs={k: func_args[k] for k in func_args.keys() - {"col_names"}}, + ).get_rules() + else: + name = check_def.get("name", None) + check_func = func(**func_args) + dq_rule_checks.append(DQRule(check=check_func, name=name, criticality=criticality)) - columns: list[str] - check_func: Callable - criticality: str = Criticality.ERROR.value - check_func_args: list[Any] = field(default_factory=list) - check_func_kwargs: dict[str, Any] = field(default_factory=dict) + 
logger.debug("Exiting build_checks_by_metadata function with dq_rule_checks") + return dq_rule_checks - def get_rules(self) -> list[DQRule]: - """Build a list of rules for a set of columns. + @staticmethod + def build_checks(*rules_col_set: DQRuleColSet) -> list[DQRule]: + """ + Build rules from dq rules and rule sets. + :param rules_col_set: list of dq rules which define multiple columns for the same check function :return: list of dq rules """ - rules = [] - for col_name in self.columns: - rule = DQRule( - criticality=self.criticality, - check=self.check_func(col_name, *self.check_func_args, **self.check_func_kwargs), - ) - rules.append(rule) - return rules - + rules_nested = [rule_set.get_rules() for rule_set in rules_col_set] + flat_rules = list(itertools.chain(*rules_nested)) -class DQEngine(DQEngineBase): - """Data Quality Engine class to apply data quality checks to a given dataframe.""" + return list(filter(None, flat_rules)) @staticmethod def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]: @@ -194,85 +213,6 @@ def _create_results_map(df: DataFrame, checks: list[DQRule], dest_col: str) -> D m_col = F.map_filter(m_col, lambda _, v: v.isNotNull()) return df.withColumn(dest_col, F.when(F.size(m_col) > 0, m_col).otherwise(empty_type)) - def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame: - """Applies data quality checks to a given dataframe. - - :param df: dataframe to check - :param checks: list of checks to apply to the dataframe. Each check is an instance of DQRule class. 
- :return: dataframe with errors and warning reporting columns - """ - if not checks: - return self._append_empty_checks(df) - - warning_checks = self._get_check_columns(checks, Criticality.WARN.value) - error_checks = self._get_check_columns(checks, Criticality.ERROR.value) - ndf = self._create_results_map(df, error_checks, Columns.ERRORS.value) - ndf = self._create_results_map(ndf, warning_checks, Columns.WARNINGS.value) - - return ndf - - def apply_checks_and_split(self, df: DataFrame, checks: list[DQRule]) -> tuple[DataFrame, DataFrame]: - """Applies data quality checks to a given dataframe and split it into two ("good" and "bad"), - according to the data quality checks. - - :param df: dataframe to check - :param checks: list of checks to apply to the dataframe. Each check is an instance of DQRule class. - :return: two dataframes - "good" which includes warning rows but no reporting columns, and "data" having - error and warning rows and corresponding reporting columns - """ - if not checks: - return df, self._append_empty_checks(df).limit(0) - - checked_df = self.apply_checks(df, checks) - - good_df = self.get_valid(checked_df) - bad_df = self.get_invalid(checked_df) - - return good_df, bad_df - - @staticmethod - def get_invalid(df: DataFrame) -> DataFrame: - """ - Get records that violate data quality checks (records with warnings and errors). - @param df: input DataFrame. - @return: dataframe with error and warning rows and corresponding reporting columns. - """ - return df.where(F.col(Columns.ERRORS.value).isNotNull() | F.col(Columns.WARNINGS.value).isNotNull()) - - @staticmethod - def get_valid(df: DataFrame) -> DataFrame: - """ - Get records that don't violate data quality checks (records with warnings but no errors). - @param df: input DataFrame. - @return: dataframe with warning rows but no reporting columns. 
- """ - return df.where(F.col(Columns.ERRORS.value).isNull()).drop(Columns.ERRORS.value, Columns.WARNINGS.value) - - @staticmethod - def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus: - """ - Validate the input dict to ensure they conform to expected structure and types. - - Each check can be a dictionary. The function validates - the presence of required keys, the existence and callability of functions, and the types - of arguments passed to these functions. - - :param checks: List of checks to apply to the dataframe. Each check should be a dictionary. - :param glbs: Optional dictionary of global functions that can be used in checks. - - :return ValidationStatus: The validation status. - """ - status = ChecksValidationStatus() - - for check in checks: - logger.debug(f"Processing check definition: {check}") - if isinstance(check, dict): - status.add_errors(DQEngine._validate_checks_dict(check, glbs)) - else: - status.add_error(f"Unsupported check type: {type(check)}") - - return status - @staticmethod def _validate_checks_dict(check: dict, glbs: dict[str, Any] | None) -> list[str]: """ @@ -295,7 +235,7 @@ def _validate_checks_dict(check: dict, glbs: dict[str, Any] | None) -> list[str] elif not isinstance(check["check"], dict): errors.append(f"'check' field should be a dictionary: {check}") else: - errors.extend(DQEngine._validate_check_block(check, glbs)) + errors.extend(DQEngineCore._validate_check_block(check, glbs)) return errors @@ -317,12 +257,12 @@ def _validate_check_block(check: dict, glbs: dict[str, Any] | None) -> list[str] return [f"'function' field is missing in the 'check' block: {check}"] func_name = check_block["function"] - func = DQEngine.resolve_function(func_name, glbs, fail_on_missing=False) + func = DQEngineCore._resolve_function(func_name, glbs, fail_on_missing=False) if not callable(func): return [f"function '{func_name}' is not defined: {check}"] arguments = check_block.get("arguments", {}) 
- return DQEngine._validate_check_function_arguments(arguments, func, check) + return DQEngineCore._validate_check_function_arguments(arguments, func, check) @staticmethod def _validate_check_function_arguments(arguments: dict, func: Callable, check: dict) -> list[str]: @@ -351,9 +291,9 @@ def _validate_check_function_arguments(arguments: dict, func: Callable, check: d 'col_name' if k == 'col_names' else k: arguments['col_names'][0] if k == 'col_names' else v for k, v in arguments.items() } - return DQEngine._validate_func_args(arguments, func, check) + return DQEngineCore._validate_func_args(arguments, func, check) - return DQEngine._validate_func_args(arguments, func, check) + return DQEngineCore._validate_func_args(arguments, func, check) @staticmethod def _validate_func_args(arguments: dict, func: Callable, check: dict) -> list[str]: @@ -395,52 +335,7 @@ def cached_signature(check_func): return errors @staticmethod - def build_checks_by_metadata(checks: list[dict], glbs: dict[str, Any] | None = None) -> list[DQRule]: - """Build checks based on check specification, i.e. function name plus arguments. - - :param checks: list of dictionaries describing checks. Each check is a dictionary consisting of following fields: - * `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - - it will be used as an error/warning message, or `null` if it's evaluated to `false` - * `name` - name that will be given to a resulting column. Autogenerated if not provided - * `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe), - and `warn` (data is going into both dataframes) - :param glbs: dictionary with functions mapping (eg. ``globals()`` of the calling module). - If not specified, then only built-in functions are used for the checks. 
- :return: list of data quality check rules - """ - status = DQEngine.validate_checks(checks, glbs) - if status.has_errors: - raise ValueError(str(status)) - - dq_rule_checks = [] - for check_def in checks: - logger.debug(f"Processing check definition: {check_def}") - check = check_def.get("check", {}) - func_name = check.get("function", None) - func = DQEngine.resolve_function(func_name, glbs, fail_on_missing=True) - assert func # should already be validated - func_args = check.get("arguments", {}) - criticality = check_def.get("criticality", "error") - - if "col_names" in func_args: - logger.debug(f"Adding DQRuleColSet with columns: {func_args['col_names']}") - dq_rule_checks += DQRuleColSet( - columns=func_args["col_names"], - check_func=func, - criticality=criticality, - # provide arguments without "col_names" - check_func_kwargs={k: func_args[k] for k in func_args.keys() - {"col_names"}}, - ).get_rules() - else: - name = check_def.get("name", None) - check_func = func(**func_args) - dq_rule_checks.append(DQRule(check=check_func, name=name, criticality=criticality)) - - logger.debug("Exiting build_checks_by_metadata function with dq_rule_checks") - return dq_rule_checks - - @staticmethod - def resolve_function(func_name: str, glbs: dict[str, Any] | None = None, fail_on_missing=True) -> Callable | None: + def _resolve_function(func_name: str, glbs: dict[str, Any] | None = None, fail_on_missing=True) -> Callable | None: logger.debug(f"Resolving function: {func_name}") if glbs: func = glbs.get(func_name) @@ -451,104 +346,45 @@ def resolve_function(func_name: str, glbs: dict[str, Any] | None = None, fail_on logger.debug(f"Function {func_name} resolved successfully") return func - def apply_checks_by_metadata_and_split( - self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None - ) -> tuple[DataFrame, DataFrame]: - """Wrapper around `apply_checks_and_split` for use in the metadata-driven pipelines. 
The main difference - is how the checks are specified - instead of using functions directly, they are described as function name plus - arguments. - :param df: dataframe to check - :param checks: list of dictionaries describing checks. Each check is a dictionary consisting of following fields: - * `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - - it will be used as an error/warning message, or `null` if it's evaluated to `false` - * `name` - name that will be given to a resulting column. Autogenerated if not provided - * `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe), - and `warn` (data is going into both dataframes) - :param glbs: dictionary with functions mapping (eg. ``globals()`` of the calling module). - If not specified, then only built-in functions are used for the checks. - :return: two dataframes - "good" which includes warning rows but no reporting columns, and "bad" having - error and warning rows and corresponding reporting columns - """ - dq_rule_checks = self.build_checks_by_metadata(checks, glbs) +class DQEngine(DQEngineBase): + """Data Quality Engine class to apply data quality checks to a given dataframe.""" - good_df, bad_df = self.apply_checks_and_split(df, dq_rule_checks) + def __init__(self, workspace_client: WorkspaceClient, engine: DQEngineCoreBase | None = None): + super().__init__(workspace_client) + self._engine = engine or DQEngineCore(workspace_client) - return good_df, bad_df + def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame: + return self._engine.apply_checks(df, checks) + + def apply_checks_and_split(self, df: DataFrame, checks: list[DQRule]) -> tuple[DataFrame, DataFrame]: + return self._engine.apply_checks_and_split(df, checks) + + def apply_checks_by_metadata_and_split( + self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None + ) -> tuple[DataFrame, DataFrame]: + return 
self._engine.apply_checks_by_metadata_and_split(df, checks, glbs) def apply_checks_by_metadata( self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None ) -> DataFrame: - """Wrapper around `apply_checks` for use in the metadata-driven pipelines. The main difference - is how the checks are specified - instead of using functions directly, they are described as function name plus - arguments. - - :param df: dataframe to check - :param checks: list of dictionaries describing checks. Each check is a dictionary consisting of following fields: - * `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - - it will be used as an error/warning message, or `null` if it's evaluated to `false` - * `name` - name that will be given to a resulting column. Autogenerated if not provided - * `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe), - and `warn` (data is going into both dataframes) - :param glbs: dictionary with functions mapping (eg. ``globals()`` of calling module). - If not specified, then only built-in functions are used for the checks. - :return: dataframe with errors and warning reporting columns - """ - dq_rule_checks = self.build_checks_by_metadata(checks, glbs) - - return self.apply_checks(df, dq_rule_checks) + return self._engine.apply_checks_by_metadata(df, checks, glbs) @staticmethod - def build_checks(*rules_col_set: DQRuleColSet) -> list[DQRule]: - """ - Build rules from dq rules and rule sets. 
- - :param rules_col_set: list of dq rules which define multiple columns for the same check function - :return: list of dq rules - """ - rules_nested = [rule_set.get_rules() for rule_set in rules_col_set] - flat_rules = list(itertools.chain(*rules_nested)) - - return list(filter(None, flat_rules)) - - def load_run_config( - self, run_config: str | None = "default", assume_user: bool = True, product_name: str = "dqx" - ) -> RunConfig: - """ - Load run configuration from the installation. - - :param run_config: name of the run configuration to use - :param assume_user: if True, assume user installation - :param product_name: name of the product - """ - installation = self._get_installation(assume_user, product_name) - return self._load_run_config(installation, run_config) + def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus: + return DQEngineCore.validate_checks(checks, glbs) @staticmethod - def _load_run_config(installation, run_config): - """Load run configuration from the installation.""" - config = installation.load(WorkspaceConfig) - return config.get_run_config(run_config) + def get_invalid(df: DataFrame) -> DataFrame: + return DQEngineCore.get_invalid(df) @staticmethod - def load_checks_from_local_file(filename: str) -> list[dict]: - """ - Load checks (dq rules) from a file (json or yml) in the local file system. - This does not require installation of DQX in the workspace. - The returning checks can be used as input for `apply_checks_by_metadata` function. - - :param filename: file name / path containing the checks. 
- :return: list of dq rules - """ - if not filename: - raise ValueError("filename must be provided") + def get_valid(df: DataFrame) -> DataFrame: + return DQEngineCore.get_valid(df) - try: - checks = Installation.load_local(list[dict[str, str]], Path(filename)) - return DQEngine._deserialize_dicts(checks) - except FileNotFoundError: - msg = f"Checks file {filename} missing" - raise FileNotFoundError(msg) from None + @staticmethod + def load_checks_from_local_file(path: str) -> list[dict]: + return DQEngineCore.load_checks_from_local_file(path) def load_checks_from_workspace_file(self, workspace_path: str) -> list[dict]: """Load checks (dq rules) from a file (json or yml) in the workspace. @@ -565,7 +401,7 @@ def load_checks_from_workspace_file(self, workspace_path: str) -> list[dict]: logger.info(f"Loading quality rules (checks) from {workspace_path} in the workspace.") return self._load_checks_from_file(installation, filename) - def load_checks( + def load_checks_from_installation( self, run_config_name: str | None = "default", product_name: str = "dqx", assume_user: bool = True ) -> list[dict]: """ @@ -584,7 +420,11 @@ def load_checks( logger.info(f"Loading quality rules (checks) from {installation.install_folder()}/{filename} in the workspace.") return self._load_checks_from_file(installation, filename) - def save_checks( + @staticmethod + def save_checks_in_local_file(checks: list[dict], path: str): + return DQEngineCore.save_checks_in_local_file(checks, path) + + def save_checks_in_installation( self, checks: list[dict], run_config_name: str | None = "default", @@ -623,23 +463,18 @@ def save_checks_in_workspace_file(self, checks: list[dict], workspace_path: str) workspace_path, yaml.safe_dump(checks).encode('utf-8'), format=ImportFormat.AUTO, overwrite=True ) - @staticmethod - def save_checks_in_local_file(checks: list[dict], filename: str): + def load_run_config( + self, run_config_name: str | None = "default", assume_user: bool = True, product_name: str = 
"dqx" + ) -> RunConfig: """ - Save checks (dq rules) to yml file in the local file system. + Load run configuration from the installation. - :param checks: list of dq rules to save - :param filename: file name / path containing the checks. + :param run_config_name: name of the run configuration to use + :param assume_user: if True, assume user installation + :param product_name: name of the product """ - if not filename: - raise ValueError("filename must be provided") - - try: - with open(filename, 'w', encoding="utf-8") as file: - yaml.safe_dump(checks, file) - except FileNotFoundError: - msg = f"Checks file {filename} missing" - raise FileNotFoundError(msg) from None + installation = self._get_installation(assume_user, product_name) + return self._load_run_config(installation, run_config_name) def _get_installation(self, assume_user, product_name): if assume_user: @@ -651,23 +486,17 @@ def _get_installation(self, assume_user, product_name): installation.current(self.ws, product_name, assume_user=assume_user) return installation - def _load_checks_from_file(self, installation: Installation, filename: str) -> list[dict]: + @staticmethod + def _load_run_config(installation, run_config_name): + """Load run configuration from the installation.""" + config = installation.load(WorkspaceConfig) + return config.get_run_config(run_config_name) + + @staticmethod + def _load_checks_from_file(installation: Installation, filename: str) -> list[dict]: try: checks = installation.load(list[dict[str, str]], filename=filename) - return self._deserialize_dicts(checks) + return deserialize_dicts(checks) except NotFound: msg = f"Checks file {filename} missing" raise NotFound(msg) from None - - @classmethod - def _deserialize_dicts(cls, checks: list[dict[str, str]]) -> list[dict]: - """ - deserialize string fields instances containing dictionaries - @param checks: list of checks - @return: - """ - for item in checks: - for key, value in item.items(): - if value.startswith("{") and 
value.endswith("}"): - item[key] = yaml.safe_load(value.replace("'", '"')) - return checks diff --git a/src/databricks/labs/dqx/installer/install.py b/src/databricks/labs/dqx/installer/install.py index b192398..f692b9f 100644 --- a/src/databricks/labs/dqx/installer/install.py +++ b/src/databricks/labs/dqx/installer/install.py @@ -1,3 +1,4 @@ +import re import logging import dataclasses import os @@ -45,7 +46,7 @@ from databricks.labs.dqx.__about__ import __version__ from databricks.labs.dqx.config import WorkspaceConfig, RunConfig from databricks.labs.dqx.contexts.workspace import WorkspaceContext -from databricks.labs.dqx.utils import extract_major_minor + logger = logging.getLogger(__name__) with_user_agent_extra("cmd", "install") @@ -147,6 +148,19 @@ def run( raise err return config + @staticmethod + def extract_major_minor(version_string: str): + """ + Extracts the major and minor version from a version string. + + :param version_string: The version string to extract from. + :return: The major.minor version as a string, or None if not found. + """ + match = re.search(r"(\d+\.\d+)", version_string) + if match: + return match.group(1) + return None + def _is_testing(self): return self.product_info.product_name() != "dqx" @@ -211,7 +225,7 @@ def _compare_remote_local_versions(self): try: local_version = self.product_info.released_version() remote_version = self.installation.load(Version).version - if extract_major_minor(remote_version) == extract_major_minor(local_version): + if self.extract_major_minor(remote_version) == self.extract_major_minor(local_version): logger.info(f"DQX v{self.product_info.version()} is already installed on this workspace") msg = "Do you want to update the existing installation?" 
if not self.prompts.confirm(msg): @@ -515,7 +529,7 @@ def _create_dashboard(self, folder: Path, *, parent_path: str) -> None: dashboard_id = self._install_state.dashboards.get(reference) logger.debug(f"dashboard id retrieved is {dashboard_id}") - logger.info(f"Installing '{metadata.display_name}' dashboard...") + logger.info(f"Installing '{metadata.display_name}' dashboard in '{parent_path}'") if dashboard_id is not None: dashboard_id = self._handle_existing_dashboard(dashboard_id, metadata.display_name, parent_path) dashboard = Dashboards(self._ws).create_dashboard( diff --git a/src/databricks/labs/dqx/installer/workflow_task.py b/src/databricks/labs/dqx/installer/workflow_task.py index 3fc415a..636d0c4 100644 --- a/src/databricks/labs/dqx/installer/workflow_task.py +++ b/src/databricks/labs/dqx/installer/workflow_task.py @@ -7,7 +7,6 @@ from databricks.sdk import WorkspaceClient from databricks.labs.dqx.config import WorkspaceConfig -from databricks.labs.dqx.utils import remove_extra_indentation logger = logging.getLogger(__name__) @@ -80,3 +79,19 @@ def register(func): return register register(fn) return fn + + +def remove_extra_indentation(doc: str) -> str: + """ + Remove extra indentation from docstring. 
+ + :param doc: Docstring + """ + lines = doc.splitlines() + stripped = [] + for line in lines: + if line.startswith(" " * 4): + stripped.append(line[4:]) + else: + stripped.append(line) + return "\n".join(stripped) diff --git a/src/databricks/labs/dqx/rule.py b/src/databricks/labs/dqx/rule.py new file mode 100644 index 0000000..b04ed4b --- /dev/null +++ b/src/databricks/labs/dqx/rule.py @@ -0,0 +1,128 @@ +from enum import Enum +from dataclasses import dataclass, field +import functools as ft +from typing import Any +from collections.abc import Callable +from pyspark.sql import Column +import pyspark.sql.functions as F +from databricks.labs.dqx.utils import get_column_name + + +# TODO: make this configurable +class Columns(Enum): + """Enum class to represent columns in the dataframe that will be used for error and warning reporting.""" + + ERRORS = "_errors" + WARNINGS = "_warnings" + + +class Criticality(Enum): + """Enum class to represent criticality of the check.""" + + WARN = "warn" + ERROR = "error" + + +@dataclass(frozen=True) +class DQRule: + """Class to represent a data quality rule consisting of following fields: + * `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - + it will be used as an error/warning message, or `null` if it's evaluated to `false` + * `name` - optional name that will be given to a resulting column. Autogenerated if not provided + * `criticality` (optional) - possible values are `error` (critical problems), and `warn` (potential problems) + """ + + check: Column + name: str = "" + criticality: str = Criticality.ERROR.value + + def __post_init__(self): + # take the name from the alias of the column expression if not provided + object.__setattr__(self, "name", self.name if self.name else "col_" + get_column_name(self.check)) + + @ft.cached_property + def rule_criticality(self) -> str: + """Returns criticality of the check. 
+ + :return: string describing criticality - `warn` or `error`. + :raises ValueError: if criticality is invalid. + """ + criticality = self.criticality + if criticality not in {Criticality.WARN.value, Criticality.ERROR.value}: + raise ValueError(f"Invalid criticality value: {criticality}") + + return criticality + + def check_column(self) -> Column: + """Creates a Column object from the given check. + + :return: Column object + """ + return F.when(self.check.isNull(), F.lit(None).cast("string")).otherwise(self.check) + + +@dataclass(frozen=True) +class DQRuleColSet: + """Class to represent a data quality col rule set which defines quality check function for a set of columns. + The class consists of the following fields: + * `columns` - list of column names to which the given check function should be applied + * `criticality` - criticality level ('warn' or 'error') + * `check_func` - check function to be applied + * `check_func_args` - non-keyword / positional arguments for the check function after the col_name + * `check_func_kwargs` - keyword /named arguments for the check function after the col_name + """ + + columns: list[str] + check_func: Callable + criticality: str = Criticality.ERROR.value + check_func_args: list[Any] = field(default_factory=list) + check_func_kwargs: dict[str, Any] = field(default_factory=dict) + + def get_rules(self) -> list[DQRule]: + """Build a list of rules for a set of columns. 
+ + :return: list of dq rules + """ + rules = [] + for col_name in self.columns: + rule = DQRule( + criticality=self.criticality, + check=self.check_func(col_name, *self.check_func_args, **self.check_func_kwargs), + ) + rules.append(rule) + return rules + + +@dataclass(frozen=True) +class ChecksValidationStatus: + """Class to represent the validation status.""" + + _errors: list[str] = field(default_factory=list) + + def add_error(self, error: str): + """Add an error to the validation status.""" + self._errors.append(error) + + def add_errors(self, errors: list[str]): + """Add an error to the validation status.""" + self._errors.extend(errors) + + @property + def has_errors(self) -> bool: + """Check if there are any errors in the validation status.""" + return bool(self._errors) + + @property + def errors(self) -> list[str]: + """Get the list of errors in the validation status.""" + return self._errors + + def to_string(self) -> str: + """Convert the validation status to a string.""" + if self.has_errors: + return "\n".join(self._errors) + return "No errors found" + + def __str__(self) -> str: + """String representation of the ValidationStatus class.""" + return self.to_string() diff --git a/src/databricks/labs/dqx/utils.py b/src/databricks/labs/dqx/utils.py index 7d2c26d..c087df1 100644 --- a/src/databricks/labs/dqx/utils.py +++ b/src/databricks/labs/dqx/utils.py @@ -1,4 +1,5 @@ import re +import yaml from pyspark.sql import Column from pyspark.sql import SparkSession @@ -44,30 +45,14 @@ def read_input_data(spark: SparkSession, input_location: str | None, input_forma ) -def remove_extra_indentation(doc: str) -> str: +def deserialize_dicts(checks: list[dict[str, str]]) -> list[dict]: """ - Remove extra indentation from docstring. 
- - :param doc: Docstring - """ - lines = doc.splitlines() - stripped = [] - for line in lines: - if line.startswith(" " * 4): - stripped.append(line[4:]) - else: - stripped.append(line) - return "\n".join(stripped) - - -def extract_major_minor(version_string: str): - """ - Extracts the major and minor version from a version string. - - :param version_string: The version string to extract from. - :return: The major.minor version as a string, or None if not found. + deserialize string fields instances containing dictionaries + @param checks: list of checks + @return: """ - match = re.search(r"(\d+\.\d+)", version_string) - if match: - return match.group(1) - return None + for item in checks: + for key, value in item.items(): + if value.startswith("{") and value.endswith("}"): + item[key] = yaml.safe_load(value.replace("'", '"')) + return checks diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index ff5caaf..33d9907 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -32,7 +32,7 @@ @pytest.fixture def debug_env_name(): - return "ws" + return "ws" # Specify the name of the debug environment from ~/.databricks/debug-env.json @pytest.fixture diff --git a/tests/integration/test_apply_checks.py b/tests/integration/test_apply_checks.py index 388dd72..63b9684 100644 --- a/tests/integration/test_apply_checks.py +++ b/tests/integration/test_apply_checks.py @@ -4,10 +4,9 @@ from pyspark.sql import Column from chispa.dataframe_comparer import assert_df_equality # type: ignore from databricks.labs.dqx.col_functions import is_not_null_and_not_empty, make_condition -from databricks.labs.dqx.engine import ( - DQRule, - DQEngine, -) +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.rule import DQRule + SCHEMA = "a: int, b: int, c: int" EXPECTED_SCHEMA = SCHEMA + ", _errors: map, _warnings: map" @@ -442,3 +441,53 @@ def col_test_check_func(col_name: str) -> Column: check_col = F.col(col_name) condition = 
check_col.isNull() | (check_col == "") | (check_col == "null") return make_condition(condition, "new check failed", f"{col_name}_is_null_or_empty") + + +def test_get_valid_records(ws, spark): + dq_engine = DQEngine(ws) + + test_df = spark.createDataFrame( + [ + [1, 1, 1, None, None], + [None, 2, 2, None, {"col_a_is_null_or_empty": "check failed"}], + [None, 2, 2, {"col_b_is_null_or_empty": "check failed"}, None], + ], + EXPECTED_SCHEMA, + ) + + valid_df = dq_engine.get_valid(test_df) + + expected_valid_df = spark.createDataFrame( + [ + [1, 1, 1], + [None, 2, 2], + ], + SCHEMA, + ) + + assert_df_equality(valid_df, expected_valid_df) + + +def test_get_invalid_records(ws, spark): + dq_engine = DQEngine(ws) + + test_df = spark.createDataFrame( + [ + [1, 1, 1, None, None], + [None, 2, 2, None, {"col_a_is_null_or_empty": "check failed"}], + [None, 2, 2, {"col_b_is_null_or_empty": "check failed"}, None], + ], + EXPECTED_SCHEMA, + ) + + invalid_df = dq_engine.get_invalid(test_df) + + expected_invalid_df = spark.createDataFrame( + [ + [None, 2, 2, None, {"col_a_is_null_or_empty": "check failed"}], + [None, 2, 2, {"col_b_is_null_or_empty": "check failed"}, None], + ], + EXPECTED_SCHEMA, + ) + + assert_df_equality(invalid_df, expected_invalid_df) diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index a49f211..5758cd0 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -122,7 +122,7 @@ def test_profiler(ws, setup_workflows, caplog): profile(installation_ctx.workspace_client, run_config=run_config.name, ctx=installation_ctx.workspace_installer) - checks = DQEngine(ws).load_checks( + checks = DQEngine(ws).load_checks_from_installation( run_config_name=run_config.name, assume_user=True, product_name=installation_ctx.installation.product() ) assert checks, "Checks were not loaded correctly" diff --git a/tests/integration/test_config.py b/tests/integration/test_config.py new file mode 100644 index 0000000..47f860e --- 
/dev/null +++ b/tests/integration/test_config.py @@ -0,0 +1,28 @@ +from unittest.mock import patch +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.blueprint.installation import Installation + + +def test_load_run_config_from_user_installation(ws, installation_ctx): + installation_ctx.installation.save(installation_ctx.config) + product_name = installation_ctx.product_info.product_name() + + run_config = DQEngine(ws).load_run_config(run_config_name="default", assume_user=True, product_name=product_name) + expected_run_config = installation_ctx.config.get_run_config("default") + + assert run_config == expected_run_config + + +def test_load_run_config_from_global_installation(ws, installation_ctx): + product_name = installation_ctx.product_info.product_name() + expected_run_config = installation_ctx.config.get_run_config("default") + + with patch.object(Installation, '_global_installation', return_value=f"/Shared/{product_name}"): + installation_ctx.installation = Installation.assume_global(ws, product_name) + installation_ctx.installation.save(installation_ctx.config) + + run_config = DQEngine(ws).load_run_config( + run_config_name="default", assume_user=False, product_name=product_name + ) + + assert run_config == expected_run_config diff --git a/tests/integration/test_load_checks_from_file.py b/tests/integration/test_load_checks_from_file.py index 9d815c8..53c92c3 100644 --- a/tests/integration/test_load_checks_from_file.py +++ b/tests/integration/test_load_checks_from_file.py @@ -18,7 +18,9 @@ def test_load_checks_when_checks_file_does_not_exist_in_workspace(ws, installati def test_load_checks_from_installation_when_checks_file_does_not_exist_in_workspace(ws, installation_ctx): installation_ctx.installation.save(installation_ctx.config) with pytest.raises(NotFound, match="Checks file checks.yml missing"): - DQEngine(ws).load_checks(assume_user=True, product_name=installation_ctx.installation.product()) + 
DQEngine(ws).load_checks_from_installation( + run_config_name="default", assume_user=True, product_name=installation_ctx.installation.product() + ) def test_load_checks_from_file(ws, installation_ctx, make_check_file_as_yaml): @@ -37,7 +39,9 @@ def test_load_checks_from_user_installation(ws, installation_ctx, make_check_fil installation_ctx.installation.save(installation_ctx.config) make_check_file_as_yaml(install_dir=installation_ctx.installation.install_folder()) - checks = DQEngine(ws).load_checks(assume_user=True, product_name=installation_ctx.installation.product()) + checks = DQEngine(ws).load_checks_from_installation( + run_config_name="default", assume_user=True, product_name=installation_ctx.installation.product() + ) assert checks, "Checks were not loaded correctly" @@ -49,16 +53,18 @@ def test_load_checks_from_global_installation(ws, installation_ctx, make_check_f installation_ctx.installation = Installation.assume_global(ws, product_name) installation_ctx.installation.save(installation_ctx.config) make_check_file_as_yaml(install_dir=install_dir) - checks = DQEngine(ws).load_checks(assume_user=False, product_name=product_name) + checks = DQEngine(ws).load_checks_from_installation( + run_config_name="default", assume_user=False, product_name=product_name + ) assert checks, "Checks were not loaded correctly" assert installation_ctx.workspace_installation.folder == f"/Shared/{product_name}" def test_load_checks_when_global_installation_missing(ws): with pytest.raises(NotInstalled, match="Application not installed: dqx"): - DQEngine(ws).load_checks(assume_user=False) + DQEngine(ws).load_checks_from_installation(run_config_name="default", assume_user=False) def test_load_checks_when_user_installation_missing(ws): with pytest.raises(NotFound): - DQEngine(ws).load_checks(assume_user=True) + DQEngine(ws).load_checks_from_installation(run_config_name="default", assume_user=True) diff --git a/tests/integration/test_profiler_runner.py 
b/tests/integration/test_profiler_runner.py index 3df442c..65c32b2 100644 --- a/tests/integration/test_profiler_runner.py +++ b/tests/integration/test_profiler_runner.py @@ -106,7 +106,7 @@ def test_profiler_workflow(ws, spark, setup_workflows): ProfilerWorkflow().profile(ctx) # type: ignore - checks = DQEngine(ws).load_checks( + checks = DQEngine(ws).load_checks_from_installation( run_config_name=run_config.name, assume_user=True, product_name=installation_ctx.installation.product() ) assert checks, "Checks were not loaded correctly" diff --git a/tests/integration/test_profiler_workflow.py b/tests/integration/test_profiler_workflow.py index 506a00f..a5d76ff 100644 --- a/tests/integration/test_profiler_workflow.py +++ b/tests/integration/test_profiler_workflow.py @@ -37,7 +37,7 @@ def test_profiler_workflow_e2e(ws, setup_workflows): installation_ctx.deployed_workflows.run_workflow("profiler", run_config.name) - checks = DQEngine(ws).load_checks( + checks = DQEngine(ws).load_checks_from_installation( run_config_name=run_config.name, assume_user=True, product_name=installation_ctx.installation.product() ) assert checks, "Checks were not loaded correctly" diff --git a/tests/integration/test_save_checks_to_file.py b/tests/integration/test_save_checks_to_file.py index 17946b8..8dd314a 100644 --- a/tests/integration/test_save_checks_to_file.py +++ b/tests/integration/test_save_checks_to_file.py @@ -30,9 +30,13 @@ def test_save_checks_in_user_installation(ws, installation_ctx): product_name = installation_ctx.product_info.product_name() dq_engine = DQEngine(ws) - dq_engine.save_checks(TEST_CHECKS, assume_user=True, product_name=product_name) + dq_engine.save_checks_in_installation( + TEST_CHECKS, run_config_name="default", assume_user=True, product_name=product_name + ) - checks = dq_engine.load_checks(assume_user=True, product_name=product_name) + checks = dq_engine.load_checks_from_installation( + run_config_name="default", assume_user=True, product_name=product_name + 
) assert TEST_CHECKS == checks, "Checks were not saved correctly" @@ -45,18 +49,22 @@ def test_save_checks_in_global_installation(ws, installation_ctx): installation_ctx.installation.save(installation_ctx.config) dq_engine = DQEngine(ws) - dq_engine.save_checks(TEST_CHECKS, assume_user=False, product_name=product_name) + dq_engine.save_checks_in_installation( + TEST_CHECKS, run_config_name="default", assume_user=False, product_name=product_name + ) - checks = dq_engine.load_checks(assume_user=False, product_name=product_name) + checks = dq_engine.load_checks_from_installation( + run_config_name="default", assume_user=False, product_name=product_name + ) assert TEST_CHECKS == checks, "Checks were not saved correctly" assert installation_ctx.workspace_installation.folder == f"/Shared/{product_name}" def test_save_checks_when_global_installation_missing(ws): with pytest.raises(NotInstalled, match="Application not installed: dqx"): - DQEngine(ws).save_checks(TEST_CHECKS, assume_user=False) + DQEngine(ws).save_checks_in_installation(TEST_CHECKS, run_config_name="default", assume_user=False) def test_load_checks_when_user_installation_missing(ws): with pytest.raises(NotFound): - DQEngine(ws).save_checks(TEST_CHECKS, assume_user=True) + DQEngine(ws).save_checks_in_installation(TEST_CHECKS, run_config_name="default", assume_user=True) diff --git a/tests/unit/test_build_rules.py b/tests/unit/test_build_rules.py index c338090..b09118a 100644 --- a/tests/unit/test_build_rules.py +++ b/tests/unit/test_build_rules.py @@ -11,14 +11,14 @@ from databricks.labs.dqx.engine import ( DQRule, DQRuleColSet, - DQEngine, + DQEngineCore, ) SCHEMA = "a: int, b: int, c: int" def test_build_rules_empty() -> None: - actual_rules = DQEngine.build_checks() + actual_rules = DQEngineCore.build_checks() expected_rules: list[DQRule] = [] @@ -57,7 +57,7 @@ def test_get_rules(): def test_build_rules(): - actual_rules = DQEngine.build_checks( + actual_rules = DQEngineCore.build_checks( # set of columns 
for the same check DQRuleColSet(columns=["a", "b"], criticality="error", check_func=is_not_null_and_not_empty), DQRuleColSet(columns=["c"], criticality="warn", check_func=is_not_null_and_not_empty), @@ -137,7 +137,7 @@ def test_build_rules_by_metadata(): }, ] - actual_rules = DQEngine.build_checks_by_metadata(checks) + actual_rules = DQEngineCore.build_checks_by_metadata(checks) expected_rules = [ DQRule(name="col_a_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("a")), @@ -165,14 +165,14 @@ def test_build_checks_by_metadata_when_check_spec_is_missing() -> None: checks: list[dict] = [{}] # missing check spec with pytest.raises(ValueError, match="'check' field is missing"): - DQEngine.build_checks_by_metadata(checks) + DQEngineCore.build_checks_by_metadata(checks) def test_build_checks_by_metadata_when_function_spec_is_missing() -> None: checks: list[dict] = [{"check": {}}] # missing func spec with pytest.raises(ValueError, match="'function' field is missing in the 'check' block"): - DQEngine.build_checks_by_metadata(checks) + DQEngineCore.build_checks_by_metadata(checks) def test_build_checks_by_metadata_when_arguments_are_missing(): @@ -188,14 +188,14 @@ def test_build_checks_by_metadata_when_arguments_are_missing(): with pytest.raises( ValueError, match="No arguments provided for function 'is_not_null_and_not_empty' in the 'arguments' block" ): - DQEngine.build_checks_by_metadata(checks) + DQEngineCore.build_checks_by_metadata(checks) def test_build_checks_by_metadata_when_function_does_not_exist(): checks = [{"check": {"function": "function_does_not_exists", "arguments": {"col_name": "a"}}}] with pytest.raises(ValueError, match="function 'function_does_not_exists' is not defined"): - DQEngine.build_checks_by_metadata(checks) + DQEngineCore.build_checks_by_metadata(checks) def test_build_checks_by_metadata_logging_debug_calls(caplog): @@ -208,5 +208,5 @@ def test_build_checks_by_metadata_logging_debug_calls(caplog): logger = 
logging.getLogger("databricks.labs.dqx.engine") logger.setLevel(logging.DEBUG) with caplog.at_level("DEBUG"): - DQEngine.build_checks_by_metadata(checks) + DQEngineCore.build_checks_by_metadata(checks) assert "Resolving function: is_not_null_and_not_empty" in caplog.text diff --git a/tests/unit/test_installer.py b/tests/unit/test_installer.py index 2977ba0..f55b810 100644 --- a/tests/unit/test_installer.py +++ b/tests/unit/test_installer.py @@ -49,3 +49,15 @@ def test_configure_raises_many_errors(): installer.configure() assert exc_info.value.errs == errors + + +def test_extract_major_minor(): + assert WorkspaceInstaller.extract_major_minor("1.2.3") == "1.2" + assert WorkspaceInstaller.extract_major_minor("10.20.30") == "10.20" + assert WorkspaceInstaller.extract_major_minor("v1.2.3") == "1.2" + assert WorkspaceInstaller.extract_major_minor("version 1.2.3") == "1.2" + assert WorkspaceInstaller.extract_major_minor("1.2") == "1.2" + assert WorkspaceInstaller.extract_major_minor("1.2.3.4") == "1.2" + assert WorkspaceInstaller.extract_major_minor("no version") is None + assert WorkspaceInstaller.extract_major_minor("") is None + assert WorkspaceInstaller.extract_major_minor("1") is None diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index de671ae..2943749 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,6 +1,6 @@ import pyspark.sql.functions as F import pytest -from databricks.labs.dqx.utils import read_input_data, get_column_name, remove_extra_indentation, extract_major_minor +from databricks.labs.dqx.utils import read_input_data, get_column_name def test_get_column_name(): @@ -81,45 +81,3 @@ def test_read_invalid_input_location(spark_session_mock): with pytest.raises(ValueError, match="Invalid input location."): read_input_data(spark_session_mock, input_location, input_format) - - -def test_remove_extra_indentation_no_indentation(): - doc = "This is a test docstring." - expected = "This is a test docstring." 
- assert remove_extra_indentation(doc) == expected - - -def test_remove_extra_indentation_with_indentation(): - doc = " This is a test docstring with indentation." - expected = "This is a test docstring with indentation." - assert remove_extra_indentation(doc) == expected - - -def test_remove_extra_indentation_mixed_indentation(): - doc = " This is a test docstring with indentation.\nThis line has no indentation." - expected = "This is a test docstring with indentation.\nThis line has no indentation." - assert remove_extra_indentation(doc) == expected - - -def test_remove_extra_indentation_multiple_lines(): - doc = " Line one.\n Line two.\n Line three." - expected = "Line one.\nLine two.\nLine three." - assert remove_extra_indentation(doc) == expected - - -def test_remove_extra_indentation_empty_string(): - doc = "" - expected = "" - assert remove_extra_indentation(doc) == expected - - -def test_extract_major_minor(): - assert extract_major_minor("1.2.3") == "1.2" - assert extract_major_minor("10.20.30") == "10.20" - assert extract_major_minor("v1.2.3") == "1.2" - assert extract_major_minor("version 1.2.3") == "1.2" - assert extract_major_minor("1.2") == "1.2" - assert extract_major_minor("1.2.3.4") == "1.2" - assert extract_major_minor("no version") is None - assert extract_major_minor("") is None - assert extract_major_minor("1") is None diff --git a/tests/unit/test_workflow_task.py b/tests/unit/test_workflow_task.py index ea7932b..8e12f2a 100644 --- a/tests/unit/test_workflow_task.py +++ b/tests/unit/test_workflow_task.py @@ -1,5 +1,5 @@ import pytest -from databricks.labs.dqx.installer.workflow_task import workflow_task, Task, Workflow +from databricks.labs.dqx.installer.workflow_task import workflow_task, Task, Workflow, remove_extra_indentation def test_dependencies(): @@ -74,3 +74,33 @@ def test_workflow_task_returns_register(): decorator = workflow_task() assert callable(decorator) assert decorator.__name__ == "register" + + +def 
test_remove_extra_indentation_no_indentation(): + doc = "This is a test docstring." + expected = "This is a test docstring." + assert remove_extra_indentation(doc) == expected + + +def test_remove_extra_indentation_with_indentation(): + doc = " This is a test docstring with indentation." + expected = "This is a test docstring with indentation." + assert remove_extra_indentation(doc) == expected + + +def test_remove_extra_indentation_mixed_indentation(): + doc = " This is a test docstring with indentation.\nThis line has no indentation." + expected = "This is a test docstring with indentation.\nThis line has no indentation." + assert remove_extra_indentation(doc) == expected + + +def test_remove_extra_indentation_multiple_lines(): + doc = " Line one.\n Line two.\n Line three." + expected = "Line one.\nLine two.\nLine three." + assert remove_extra_indentation(doc) == expected + + +def test_remove_extra_indentation_empty_string(): + doc = "" + expected = "" + assert remove_extra_indentation(doc) == expected