diff --git a/CHANGELOG.md b/CHANGELOG.md index f45ad7e84..f37f01725 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Match weight and m and u probabilities charts now have improved tooltips ([#2392](https://github.com/moj-analytical-services/splink/pull/2392)) - Added new `AbsoluteDifferenceLevel` comparison level for numerical columns ([#2398](https://github.com/moj-analytical-services/splink/pull/2398)) +- Added new `CosineSimilarityLevel` and `CosineSimilarityAtThresholds` for comparing array columns using cosine similarity ([#2405](https://github.com/moj-analytical-services/splink/pull/2405)) ### Fixed diff --git a/splink/comparison_level_library.py b/splink/comparison_level_library.py index 364148c9d..493d9148e 100644 --- a/splink/comparison_level_library.py +++ b/splink/comparison_level_library.py @@ -5,6 +5,7 @@ And, ArrayIntersectLevel, ColumnsReversedLevel, + CosineSimilarityLevel, CustomLevel, DamerauLevenshteinLevel, DistanceFunctionLevel, @@ -44,4 +45,5 @@ "And", "Not", "Or", + "CosineSimilarityLevel", ] diff --git a/splink/comparison_library.py b/splink/comparison_library.py index 3cca66a6a..978370db6 100644 --- a/splink/comparison_library.py +++ b/splink/comparison_library.py @@ -2,6 +2,7 @@ AbsoluteDateDifferenceAtThresholds, AbsoluteTimeDifferenceAtThresholds, ArrayIntersectAtSizes, + CosineSimilarityAtThresholds, CustomComparison, DamerauLevenshteinAtThresholds, DateOfBirthComparison, @@ -36,4 +37,5 @@ "ForenameSurnameComparison", "NameComparison", "PostcodeComparison", + "CosineSimilarityAtThresholds", ] diff --git a/splink/internals/comparison_level_library.py b/splink/internals/comparison_level_library.py index b330b7a89..08c44d875 100644 --- a/splink/internals/comparison_level_library.py +++ b/splink/internals/comparison_level_library.py @@ -792,6 +792,41 @@ def create_label_for_charts(self) -> str: return f"Distance less than {self.km_threshold}km" +class CosineSimilarityLevel(ComparisonLevelCreator): + def __init__( + self, + col_name: Union[str, ColumnExpression], + similarity_threshold: float, + ): + """A comparison level using a cosine similarity function + + e.g. array_cosine_similarity(val_l, val_r) >= similarity_threshold + + Args: + col_name (str): Input column name + similarity_threshold (float): The threshold to use to assess + similarity. Should be between 0 and 1. + """ + self.col_expression = ColumnExpression.instantiate_if_str(col_name) + self.similarity_threshold = validate_numeric_parameter( + lower_bound=0.0, + upper_bound=1.0, + parameter_value=similarity_threshold, + level_name=self.__class__.__name__, + parameter_name="similarity_threshold", + ) + + def create_sql(self, sql_dialect: SplinkDialect) -> str: + self.col_expression.sql_dialect = sql_dialect + col = self.col_expression + cs_fn = sql_dialect.cosine_similarity_function_name + return f"{cs_fn}({col.name_l}, {col.name_r}) >= {self.similarity_threshold}" + + def create_label_for_charts(self) -> str: + col = self.col_expression + return f"Cosine similarity of {col.label} >= {self.similarity_threshold}" + + class ArrayIntersectLevel(ComparisonLevelCreator): def __init__(self, col_name: str | ColumnExpression, min_intersection: int): """Represents a comparison level based around the size of an intersection of diff --git a/splink/internals/comparison_library.py b/splink/internals/comparison_library.py index 59c413f2c..474489493 100644 --- a/splink/internals/comparison_library.py +++ b/splink/internals/comparison_library.py @@ -1118,3 +1118,47 @@ def create_output_column_name(self) -> str: forename_output_name = self.col_expressions["forename"].output_column_name surname_output_name = self.col_expressions["surname"].output_column_name return f"{forename_output_name}_{surname_output_name}" + + +class CosineSimilarityAtThresholds(ComparisonCreator): + def __init__( + self, + col_name: str, + score_threshold_or_thresholds: Union[Iterable[float], float] = [0.9, 0.8, 0.7], + ): + """ + Represents a comparison of the data in `col_name` with two or more levels: + + - Cosine similarity levels at specified thresholds + - ... + - Anything else + + For example, with score_threshold_or_thresholds = [0.9, 0.7] the levels are: + + - Cosine similarity in `col_name` >= 0.9 + - Cosine similarity in `col_name` >= 0.7 + - Anything else + + Args: + col_name (str): The name of the column to compare. + score_threshold_or_thresholds (Union[float, list], optional): The + threshold(s) to use for the cosine similarity level(s). + Defaults to [0.9, 0.7]. + """ + + thresholds_as_iterable = ensure_is_iterable(score_threshold_or_thresholds) + self.thresholds = [*thresholds_as_iterable] + super().__init__(col_name) + + def create_comparison_levels(self) -> List[ComparisonLevelCreator]: + return [ + cll.NullLevel(self.col_expression), + *[ + cll.CosineSimilarityLevel(self.col_expression, threshold) + for threshold in self.thresholds + ], + cll.ElseLevel(), + ] + + def create_output_column_name(self) -> str: + return self.col_expression.output_column_name diff --git a/splink/internals/dialects.py b/splink/internals/dialects.py index d1564369b..a0e889301 100644 --- a/splink/internals/dialects.py +++ b/splink/internals/dialects.py @@ -97,6 +97,12 @@ def jaccard_function_name(self): f"Backend '{self.name}' does not have a 'Jaccard' function" ) + @property + def cosine_similarity_function_name(self): + raise NotImplementedError( + f"Backend '{self.name}' does not have a 'Cosine Similarity' function" + ) + def random_sample_sql( self, proportion, sample_size, seed=None, table=None, unique_id=None ): @@ -252,6 +258,10 @@ def explode_arrays_sql( return f"""select {','.join(cols_to_select)} from ({self.explode_arrays_sql(tbl_name,columns_to_explode,other_columns_to_retain)})""" # noqa: E501 + @property + def cosine_similarity_function_name(self): + return "array_cosine_similarity" + class SparkDialect(SplinkDialect): _dialect_name_for_factory = "spark"