Skip to content

Commit

Permalink
Merge pull request #131 from ipums/bug_fix_column_mappings_validation
Browse files Browse the repository at this point in the history
Fix a bug with the override_column_X attributes in conf_validations.py
  • Loading branch information
riley-harper authored Feb 20, 2024
2 parents e8db991 + 96d8c0a commit 12ff643
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 23 deletions.
26 changes: 16 additions & 10 deletions hlink/linking/matching/link_step_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,32 @@ def _explode(self, df, comparisons, comparison_features, blocking, id_column, is
expand_length = exploding_column["expand_length"]
derived_from_column = exploding_column["derived_from"]
explode_selects = [
explode(self._expand(derived_from_column, expand_length)).alias(
exploding_column_name
(
explode(self._expand(derived_from_column, expand_length)).alias(
exploding_column_name
)
if exploding_column_name == column
else column
)
if exploding_column_name == column
else column
for column in all_column_names
]
else:
explode_selects = [
explode(col(exploding_column_name)).alias(exploding_column_name)
if exploding_column_name == c
else c
(
explode(col(exploding_column_name)).alias(exploding_column_name)
if exploding_column_name == c
else c
)
for c in all_column_names
]
if "dataset" in exploding_column:
derived_from_column = exploding_column["derived_from"]
explode_selects_with_derived_column = [
col(derived_from_column).alias(exploding_column_name)
if exploding_column_name == column
else column
(
col(derived_from_column).alias(exploding_column_name)
if exploding_column_name == column
else column
)
for column in all_column_names
]
if exploding_column["dataset"] == "a":
Expand Down
60 changes: 48 additions & 12 deletions hlink/scripts/lib/conf_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
# in this project's top-level directory, and also on-line at:
# https://github.com/ipums/hlink

from pyspark.sql.utils import AnalysisException
from os import path
from typing import Any, Literal

import colorama
from pyspark.sql.utils import AnalysisException
from pyspark.sql import DataFrame


def print_checking(section: str):
Expand Down Expand Up @@ -265,7 +268,47 @@ def check_substitution_columns(config, columns_available):
)


def check_column_mappings(config, df_a, df_b):
def check_column_mappings_column_available(
    column_mapping: dict[str, Any],
    df: DataFrame,
    previous_mappings: list[str],
    a_or_b: Literal["a", "b"],
) -> None:
    """
    Validate that the column referenced by a single [[column_mappings]] entry
    exists, raising a ValueError when it does not.

    previous_mappings lists the columns created by earlier column mappings,
    which are also legal sources for column_name.
    """
    column_name = column_mapping["column_name"]
    override_column = column_mapping.get(f"override_column_{a_or_b}")
    # Matching against the data source's columns is case-insensitive.
    lowered_columns = {existing.lower() for existing in df.columns}

    if override_column is not None:
        # An explicit override must name a real column in this data source.
        if override_column.lower() in lowered_columns:
            return
        raise ValueError(
            f"Within a [[column_mappings]] the override_column_{a_or_b} column "
            f"'{override_column}' does not exist in datasource_{a_or_b}.\n"
            f"Column mapping: {column_mapping}\n"
            f"Available columns: {df.columns}"
        )

    # Without an override, column_name must come from the data source itself
    # or from a column produced by an earlier [[column_mappings]] entry.
    if column_name.lower() in lowered_columns or column_name in previous_mappings:
        return
    raise ValueError(
        f"Within a [[column_mappings]] the column_name '{column_name}' "
        f"does not exist in datasource_{a_or_b} and no previous "
        "[[column_mapping]] alias exists for it.\n"
        f"Column mapping: {column_mapping}.\n"
        f"Available columns:\n {df.columns}"
    )


def check_column_mappings(
config: dict[str, Any], df_a: DataFrame, df_b: DataFrame
) -> list[str]:
column_mappings = config.get("column_mappings")
if not column_mappings:
raise ValueError("No [[column_mappings]] exist in the conf file.")
Expand All @@ -276,22 +319,15 @@ def check_column_mappings(config, df_a, df_b):
column_name = c.get("column_name")
set_value_column_a = c.get("set_value_column_a")
set_value_column_b = c.get("set_value_column_b")

if not column_name:
raise ValueError(
f"The following [[column_mappings]] has no 'column_name' attribute: {c}"
)
if set_value_column_a is None:
if column_name.lower() not in [c.lower() for c in df_a.columns]:
if column_name not in columns_available:
raise ValueError(
f"Within a [[column_mappings]] the column_name: '{column_name}' does not exist in datasource_a and no previous [[column_mapping]] alias exists for it. \nColumn mapping: {c}. \nAvailable columns: \n {df_a.columns}"
)
check_column_mappings_column_available(c, df_a, columns_available, "a")
if set_value_column_b is None:
if column_name.lower() not in [c.lower() for c in df_b.columns]:
if column_name not in columns_available:
raise ValueError(
f"Within a [[column_mappings]] the column_name: '{column_name}' does not exist in datasource_b and no previous [[column_mapping]] alias exists for it. Column mapping: {c}. Available columns: \n {df_b.columns}"
)
check_column_mappings_column_available(c, df_b, columns_available, "b")
if alias in columns_available:
duplicates.append(alias)
elif not alias and column_name in columns_available:
Expand Down
200 changes: 199 additions & 1 deletion hlink/tests/conf_validations_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import pytest

from pyspark.sql import SparkSession

from hlink.configs.load_config import load_conf_file
from hlink.scripts.lib.conf_validations import analyze_conf
from hlink.scripts.lib.conf_validations import analyze_conf, check_column_mappings
from hlink.linking.link_run import LinkRun


Expand All @@ -25,3 +27,199 @@ def test_invalid_conf(conf_dir_path, spark, conf_name, error_msg):

with pytest.raises(ValueError, match=error_msg):
analyze_conf(link_run)


def test_check_column_mappings_mappings_missing(spark: SparkSession) -> None:
    """
    check_column_mappings() rejects a config with no column_mappings section.
    """
    empty_config = {}
    frame_a = spark.createDataFrame([[1], [2], [3]], ["a"])
    frame_b = spark.createDataFrame([[4], [5], [6]], ["b"])

    expected_err = r"No \[\[column_mappings\]\] exist in the conf file"
    with pytest.raises(ValueError, match=expected_err):
        check_column_mappings(empty_config, frame_a, frame_b)


def test_check_column_mappings_no_column_name(spark: SparkSession) -> None:
    """
    Every [[column_mappings]] entry must carry a column_name attribute.
    """
    conf = {
        "column_mappings": [
            {"column_name": "AGE", "alias": "age"},
            {"alias": "height"},
        ]
    }
    frame_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    with pytest.raises(
        ValueError,
        match=r"The following \[\[column_mappings\]\] has no 'column_name' attribute:",
    ):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_column_name_not_available_datasource_a(
    spark: SparkSession,
) -> None:
    """
    A column_name must exist in datasource A or be produced by an earlier
    column mapping.
    """
    conf = {"column_mappings": [{"column_name": "HEIGHT"}]}

    frame_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    frame_b = spark.createDataFrame(
        [[70, 123], [50, 123], [30, 123]], ["AGE", "HEIGHT"]
    )

    expected_err = (
        r"Within a \[\[column_mappings\]\] the column_name 'HEIGHT' "
        r"does not exist in datasource_a and no previous \[\[column_mapping\]\] "
        "alias exists for it"
    )

    with pytest.raises(ValueError, match=expected_err):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_set_value_column_a_does_not_need_column(
    spark: SparkSession,
) -> None:
    """
    A mapping with set_value_column_a may name a column that is absent from
    datasource A.
    """
    conf = {
        "column_mappings": [{"column_name": "HEIGHT", "set_value_column_a": 125}]
    }

    frame_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    frame_b = spark.createDataFrame(
        [[70, 123], [50, 123], [30, 123]], ["AGE", "HEIGHT"]
    )

    # Should not raise.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_column_name_not_available_datasource_b(
    spark: SparkSession,
) -> None:
    """
    A column_name must exist in datasource B or be produced by an earlier
    column mapping.
    """
    conf = {"column_mappings": [{"column_name": "HEIGHT"}]}

    frame_a = spark.createDataFrame(
        [[70, 123], [50, 123], [30, 123]], ["AGE", "HEIGHT"]
    )
    frame_b = spark.createDataFrame([[20], [40], [60]], ["AGE"])

    expected_err = (
        r"Within a \[\[column_mappings\]\] the column_name 'HEIGHT' "
        r"does not exist in datasource_b and no previous \[\[column_mapping\]\] "
        "alias exists for it"
    )

    with pytest.raises(ValueError, match=expected_err):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_set_value_column_b_does_not_need_column(
    spark: SparkSession,
) -> None:
    """
    A mapping with set_value_column_b may name a column that is absent from
    datasource B.
    """
    conf = {
        "column_mappings": [{"column_name": "HEIGHT", "set_value_column_b": 125}]
    }

    frame_a = spark.createDataFrame(
        [[70, 123], [50, 123], [30, 123]], ["AGE", "HEIGHT"]
    )
    frame_b = spark.createDataFrame([[20], [40], [60]], ["AGE"])

    # Should not raise.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_previous_mappings_are_available(
    spark: SparkSession,
) -> None:
    """
    A column created by an earlier column mapping may be referenced as the
    column_name of a later mapping.
    """
    conf = {
        "column_mappings": [
            {"column_name": "AGE", "alias": "AGE_HLINK"},
            {"column_name": "AGE_HLINK", "alias": "AGE_HLINK2"},
        ]
    }
    frame_a = spark.createDataFrame([[70], [50], [30]], ["AGE"])
    frame_b = spark.createDataFrame([[20], [40], [60]], ["AGE"])

    # Should not raise.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_override_column_a(spark: SparkSession) -> None:
    """
    override_column_a selects which datasource A column the mapping reads.
    """
    conf = {
        "column_mappings": [{"column_name": "AGE", "override_column_a": "ageColumn"}]
    }
    frame_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    # Should not raise.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_override_column_b(spark: SparkSession) -> None:
    """
    override_column_b selects which datasource B column the mapping reads.
    """
    conf = {
        "column_mappings": [{"column_name": "ageColumn", "override_column_b": "AGE"}]
    }
    frame_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    # Should not raise.
    check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_override_column_a_not_present(
    spark: SparkSession,
) -> None:
    """
    An override_column_a that names a column missing from datasource A is an
    error.
    """
    conf = {
        "column_mappings": [
            {"column_name": "AGE", "override_column_a": "oops_not_there"}
        ]
    }
    frame_a = spark.createDataFrame([[20], [40], [60]], ["ageColumn"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    with pytest.raises(
        ValueError,
        match=(
            r"Within a \[\[column_mappings\]\] the override_column_a column "
            "'oops_not_there' does not exist in datasource_a"
        ),
    ):
        check_column_mappings(conf, frame_a, frame_b)


def test_check_column_mappings_override_column_b_not_present(
    spark: SparkSession,
) -> None:
    """
    An override_column_b that names a column missing from datasource B is an
    error.
    """
    conf = {
        "column_mappings": [
            {"column_name": "AGE", "override_column_b": "oops_not_there"}
        ]
    }
    frame_a = spark.createDataFrame([[20], [40], [60]], ["AGE"])
    frame_b = spark.createDataFrame([[70], [50], [30]], ["AGE"])

    with pytest.raises(
        ValueError,
        match=(
            r"Within a \[\[column_mappings\]\] the override_column_b column "
            "'oops_not_there' does not exist in datasource_b"
        ),
    ):
        check_column_mappings(conf, frame_a, frame_b)

0 comments on commit 12ff643

Please sign in to comment.