Skip to content

Commit

Permalink
Merge pull request #2405 from moj-analytical-services/add_cosine_simi…
Browse files Browse the repository at this point in the history
…larity

Add cosine similarity comparison level and comparison
  • Loading branch information
RobinL authored Sep 16, 2024
2 parents 52059b5 + ef60aa3 commit f472116
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Match weight and m and u probabilities charts now have improved tooltips ([#2392](https://github.com/moj-analytical-services/splink/pull/2392))
- Added new `AbsoluteDifferenceLevel` comparison level for numerical columns ([#2398](https://github.com/moj-analytical-services/splink/pull/2398))
- Added new `CosineSimilarityLevel` and `CosineSimilarityAtThresholds` for comparing array columns using cosine similarity ([#2405](https://github.com/moj-analytical-services/splink/pull/2405))

### Fixed

Expand Down
2 changes: 2 additions & 0 deletions splink/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
And,
ArrayIntersectLevel,
ColumnsReversedLevel,
CosineSimilarityLevel,
CustomLevel,
DamerauLevenshteinLevel,
DistanceFunctionLevel,
Expand Down Expand Up @@ -44,4 +45,5 @@
"And",
"Not",
"Or",
"CosineSimilarityLevel",
]
2 changes: 2 additions & 0 deletions splink/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AbsoluteDateDifferenceAtThresholds,
AbsoluteTimeDifferenceAtThresholds,
ArrayIntersectAtSizes,
CosineSimilarityAtThresholds,
CustomComparison,
DamerauLevenshteinAtThresholds,
DateOfBirthComparison,
Expand Down Expand Up @@ -36,4 +37,5 @@
"ForenameSurnameComparison",
"NameComparison",
"PostcodeComparison",
"CosineSimilarityAtThresholds",
]
35 changes: 35 additions & 0 deletions splink/internals/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,41 @@ def create_label_for_charts(self) -> str:
return f"Distance less than {self.km_threshold}km"


class CosineSimilarityLevel(ComparisonLevelCreator):
    def __init__(
        self,
        col_name: Union[str, ColumnExpression],
        similarity_threshold: float,
    ):
        """A comparison level that is satisfied when the cosine similarity of
        the left and right values reaches a threshold, e.g.

        array_cosine_similarity(val_l, val_r) >= similarity_threshold

        Args:
            col_name (str | ColumnExpression): Input column name, or a
                column expression built on it.
            similarity_threshold (float): The minimum similarity for this
                level. It is checked with `validate_numeric_parameter`
                using bounds of 0.0 and 1.0.
        """
        self.col_expression = ColumnExpression.instantiate_if_str(col_name)
        # Keep whatever the validator hands back rather than the raw argument.
        checked_threshold = validate_numeric_parameter(
            lower_bound=0.0,
            upper_bound=1.0,
            parameter_value=similarity_threshold,
            level_name=self.__class__.__name__,
            parameter_name="similarity_threshold",
        )
        self.similarity_threshold = checked_threshold

    def create_sql(self, sql_dialect: SplinkDialect) -> str:
        """Return the SQL condition for this level in the given dialect."""
        expression = self.col_expression
        # The dialect is set first because the left/right names are read
        # from the expression afterwards.
        expression.sql_dialect = sql_dialect
        function_name = sql_dialect.cosine_similarity_function_name
        similarity_sql = f"{function_name}({expression.name_l}, {expression.name_r})"
        return f"{similarity_sql} >= {self.similarity_threshold}"

    def create_label_for_charts(self) -> str:
        """Return the human-readable label for this level."""
        column_label = self.col_expression.label
        return f"Cosine similarity of {column_label} >= {self.similarity_threshold}"


class ArrayIntersectLevel(ComparisonLevelCreator):
def __init__(self, col_name: str | ColumnExpression, min_intersection: int):
"""Represents a comparison level based around the size of an intersection of
Expand Down
44 changes: 44 additions & 0 deletions splink/internals/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,3 +1118,47 @@ def create_output_column_name(self) -> str:
forename_output_name = self.col_expressions["forename"].output_column_name
surname_output_name = self.col_expressions["surname"].output_column_name
return f"{forename_output_name}_{surname_output_name}"


class CosineSimilarityAtThresholds(ComparisonCreator):
    def __init__(
        self,
        col_name: str,
        score_threshold_or_thresholds: Union[Iterable[float], float] = [0.9, 0.8, 0.7],
    ):
        """
        Represents a comparison of the data in `col_name` with two or more levels:

        - Cosine similarity levels at specified thresholds
        - ...
        - Anything else

        For example, with score_threshold_or_thresholds = [0.9, 0.7] the levels are:

        - Cosine similarity in `col_name` >= 0.9
        - Cosine similarity in `col_name` >= 0.7
        - Anything else

        A null level for `col_name` is always placed ahead of these levels.
        The similarity levels are created in the order the thresholds are
        supplied.

        Args:
            col_name (str): The name of the column to compare.
            score_threshold_or_thresholds (Union[float, list], optional): The
                threshold(s) to use for the cosine similarity level(s).
                A single float is treated as a one-element list.
                Defaults to [0.9, 0.8, 0.7].
        """

        thresholds_as_iterable = ensure_is_iterable(score_threshold_or_thresholds)
        # Copy into a fresh list so the shared default list (and any list the
        # caller passed in) is never aliased by this instance.
        self.thresholds = [*thresholds_as_iterable]
        super().__init__(col_name)

    def create_comparison_levels(self) -> List[ComparisonLevelCreator]:
        """Return the null level, one level per threshold, then the else level."""
        return [
            cll.NullLevel(self.col_expression),
            *[
                cll.CosineSimilarityLevel(self.col_expression, threshold)
                for threshold in self.thresholds
            ],
            cll.ElseLevel(),
        ]

    def create_output_column_name(self) -> str:
        """Return the output column name taken from the column expression."""
        return self.col_expression.output_column_name
10 changes: 10 additions & 0 deletions splink/internals/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def jaccard_function_name(self):
f"Backend '{self.name}' does not have a 'Jaccard' function"
)

@property
def cosine_similarity_function_name(self):
    """Name of this backend's SQL cosine-similarity function.

    This implementation has no function to offer and always raises.

    Raises:
        NotImplementedError: Always; the message names the backend using
            `self.name`.
    """
    message = f"Backend '{self.name}' does not have a 'Cosine Similarity' function"
    raise NotImplementedError(message)

def random_sample_sql(
self, proportion, sample_size, seed=None, table=None, unique_id=None
):
Expand Down Expand Up @@ -252,6 +258,10 @@ def explode_arrays_sql(
return f"""select {','.join(cols_to_select)}
from ({self.explode_arrays_sql(tbl_name,columns_to_explode,other_columns_to_retain)})""" # noqa: E501

@property
def cosine_similarity_function_name(self):
    """Return the SQL function name this dialect uses for cosine similarity.

    The name is interpolated into generated SQL as
    `<name>(<left column>, <right column>)`.
    """
    return "array_cosine_similarity"


class SparkDialect(SplinkDialect):
_dialect_name_for_factory = "spark"
Expand Down

0 comments on commit f472116

Please sign in to comment.