Commit 2936d77
Merge pull request #2436 from moj-analytical-services/take_converged_clusters_out_of_play

Take converged clusters out of play
RobinL authored Sep 30, 2024
2 parents 88376ac + 1e7a5a0 commit 2936d77
Showing 2 changed files with 150 additions and 39 deletions.
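
The change is easiest to see outside SQL. Below is a minimal pure-Python sketch of the optimisation this PR introduces (illustrative only, not Splink's implementation; all names are invented): connected components solved by iterative min-label propagation, where any cluster with no edges leaving it has converged and is taken "out of play" for subsequent iterations.

def connected_components(nodes, edges):
    # Neighbours in both directions, mirroring Splink's neighbours table
    nbrs = {n: set() for n in nodes}
    for a, b in edges:
        nbrs[a].add(b)
        nbrs[b].add(a)

    rep = {n: n for n in nodes}  # representative = min known node id in cluster
    finished = {}                # converged nodes, taken out of play

    changed = True
    while changed:
        # A cluster is stable when no member has a neighbour with a different
        # representative; freeze its members and remove them from play
        unstable = {
            rep[n]
            for n in rep
            for m in nbrs[n]
            if m not in finished and rep[m] != rep[n]
        }
        for n in [n for n in rep if rep[n] not in unstable]:
            finished[n] = rep.pop(n)

        # One pass of min-label propagation over the remaining active nodes
        changed = False
        new_rep = {}
        for n in rep:
            new_rep[n] = min([rep[n]] + [rep[m] for m in nbrs[n] if m in rep])
            changed = changed or new_rep[n] != rep[n]  # cf. "needs_updating"
        rep = new_rep

    finished.update(rep)
    return finished

print(connected_components([1, 2, 3, 4, 5], [(1, 2), (3, 4), (4, 5)]))
# {1: 1, 2: 1, 3: 3, 4: 3, 5: 3} - i.e. clusters {1, 2} and {3, 4, 5}
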
187 changes: 149 additions & 38 deletions splink/internals/connected_components.py
@@ -129,7 +129,7 @@ def _cc_update_representatives_first_iter() -> str:
     This is only used for the first iteration as we
-    In this SQL, we also generate "rep_match", which is a boolean
+    In this SQL, we also generate "needs_updating", which is a boolean
     that indicates whether the current representative differs
     from the previous representative.
@@ -141,17 +141,17 @@ def _cc_update_representatives_first_iter() -> str:
     select
         n.node_id,
         n.representative,
-        n.representative <> repr.representative as rep_match
+        n.representative <> repr.representative as needs_updating
     from neighbours_first_iter as n
-    left join representatives as repr
+    inner join representatives as repr
     on n.node_id = repr.node_id
     """
 
     return sql
 
 
 def _cc_generate_representatives_loop_cond(
-    prev_representatives: str,
+    prev_representatives: str, filtered_neighbours: str
 ) -> str:
     """SQL for Connected components main loop.
@@ -168,11 +168,11 @@ def _cc_generate_representatives_loop_cond(
     all of our neighbours' representatives to a solution.
 
     The key difference between this function and 'cc_update_neighbours_first_iter',
-    is the usage of 'rep_match'.
+    is the usage of 'needs_updating'.
 
-    The logic behind 'rep_match' is summarised in 'cc_update_representatives_first_iter'
-    and it can be used here to reduce our neighbours table to only those nodes that need
-    updating.
+    The logic behind 'needs_updating' is summarised in
+    'cc_update_representatives_first_iter' and it can be used here to reduce our
+    neighbours table to only those nodes that need updating.
     """
 
     sql = f"""
@@ -189,13 +189,13 @@ def _cc_generate_representatives_loop_cond(
             neighbours.node_id,
             repr_neighbour.representative as representative
-        from __splink__df_neighbours as neighbours
+        from {filtered_neighbours} as neighbours
-        left join {prev_representatives} as repr_neighbour
+        inner join {prev_representatives} as repr_neighbour
         on neighbours.neighbour = repr_neighbour.node_id
         where
-            repr_neighbour.rep_match
+            repr_neighbour.needs_updating
 
         UNION ALL
@@ -219,7 +219,7 @@ def _cc_update_representatives_loop_cond(
     """SQL to update our representatives table - while loop condition.
 
     Reorganises our representatives output generated in
-    cc_generate_representatives_loop_cond() and isolates 'rep_match',
+    cc_generate_representatives_loop_cond() and isolates 'needs_updating',
     which indicates whether all representatives have 'settled' (i.e.
     no change from previous iteration).
     """
@@ -229,7 +229,7 @@ def _cc_update_representatives_loop_cond(
         r.node_id,
         r.representative,
-        r.representative <> repr.representative as rep_match
+        r.representative <> repr.representative as needs_updating
     from r
@@ -243,20 +243,73 @@ def _cc_assess_exit_condition(representatives_name: str) -> str:
     """SQL exit condition for our Connected Components algorithm.
 
-    Where 'rep_match' (summarised in 'cc_update_representatives_first_iter')
+    Where 'needs_updating' (summarised in 'cc_update_representatives_first_iter')
     it indicates that some nodes still require updating and have not yet
     settled.
     """
 
     sql = f"""
-        select count(*) as count
+        select count(*) as count_of_nodes_needing_updating
        from {representatives_name}
-        where rep_match
+        where needs_updating
    """
 
     return sql
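
For intuition about "needs_updating" and this exit condition, here is a self-contained DuckDB sketch (toy data; table names are invented, not Splink's). Node 3's representative changed between passes, so it is flagged and the loop must continue:

import duckdb

con = duckdb.connect()
# Hypothetical state: representatives before and after one update pass
con.execute("""
    CREATE TABLE prev_reps AS
    SELECT * FROM (VALUES (1, 1), (2, 1), (3, 3), (4, 3))
        AS t(node_id, representative)
""")
con.execute("""
    CREATE TABLE new_reps AS
    SELECT * FROM (VALUES (1, 1), (2, 1), (3, 1), (4, 3))
        AS t(node_id, representative)
""")

# Mirrors _cc_update_representatives_loop_cond: flag changed representatives
con.execute("""
    CREATE TABLE representatives AS
    select
        r.node_id,
        r.representative,
        r.representative <> p.representative as needs_updating
    from new_reps as r
    inner join prev_reps as p on r.node_id = p.node_id
""")

# Mirrors _cc_assess_exit_condition: loop again while the count is non-zero
count = con.execute("""
    select count(*) as count_of_nodes_needing_updating
    from representatives
    where needs_updating
""").fetchone()[0]
print(count)  # 1 - only node 3 changed, so another iteration is needed
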


+def _cc_find_converged_nodes(
+    representatives_name: str, neighbours_name: str
+) -> list[dict[str, str]]:
+    """SQL to find nodes that have converged so are part of a stable cluster.
+
+    These can be removed 'from play' to slim down tables and make the algorithm
+    run faster.
+
+    Args:
+        representatives_name: The name of the representatives table.
+        neighbours_name: The name of the neighbours table.
+
+    Returns:
+        list[dict[str, str]]: SQL queries that find the converged nodes.
+    """
+
+    # Take nodes, and find neighbours to the nodes (follow edges in both directions).
+    # For each neighbour to the node, find its representative.
+    # Does that lead us back to the same cluster? If not, there is a neighbour
+    # outside of the cluster.
+    sqls = []
+
+    sql_non_stable = f"""
+    SELECT DISTINCT r.representative
+    FROM {representatives_name} r
+    JOIN {neighbours_name} n ON r.node_id = n.node_id
+    JOIN {representatives_name} r2 ON n.neighbour = r2.node_id
+    WHERE r.representative != r2.representative
+    """
+
+    sqls.append(
+        {
+            "sql": sql_non_stable,
+            "output_table_name": "non_stable_representatives",
+        }
+    )
+
+    sql_stable = f"""
+    SELECT *
+    FROM {representatives_name}
+    WHERE representative NOT IN (
+        SELECT representative FROM non_stable_representatives
+    )
+    """
+    sqls.append(
+        {
+            "sql": sql_stable,
+            "output_table_name": "__splink__representatives_stable",
+        }
+    )
+
+    return sqls
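
To see what the two queries above do, consider a toy graph where nodes 1 and 2 form a finished cluster, while nodes 3-5 are still merging. A self-contained DuckDB sketch (invented data, not Splink's code; the neighbours table stores each edge in both directions, as in Splink):

import duckdb

con = duckdb.connect()
con.execute("""
    CREATE TABLE representatives AS
    SELECT * FROM (VALUES (1, 1), (2, 1), (3, 3), (4, 3), (5, 5))
        AS t(node_id, representative)
""")
con.execute("""
    CREATE TABLE neighbours AS
    SELECT * FROM (VALUES (1, 2), (2, 1), (3, 4), (4, 3), (4, 5), (5, 4))
        AS t(node_id, neighbour)
""")

# First query: clusters with at least one edge leading outside themselves
con.execute("""
    CREATE TABLE non_stable_representatives AS
    SELECT DISTINCT r.representative
    FROM representatives r
    JOIN neighbours n ON r.node_id = n.node_id
    JOIN representatives r2 ON n.neighbour = r2.node_id
    WHERE r.representative != r2.representative
""")

# Second query: everything else has converged and can be taken out of play
stable = con.execute("""
    SELECT * FROM representatives
    WHERE representative NOT IN (
        SELECT representative FROM non_stable_representatives
    )
""").fetchall()
print(stable)  # [(1, 1), (2, 1)] - cluster {1, 2} is finished
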


 def solve_connected_components(
     nodes_table: SplinkDataFrame,
     edges_table: SplinkDataFrame,
@@ -279,7 +332,6 @@ def solve_connected_components(
         Splink dataframe containing our edges dataframe to be connected.
 
     Returns:
-
         SplinkDataFrame: A dataframe containing the connected components list
         for your link or dedupe job.
@@ -341,28 +393,74 @@ def solve_connected_components(
     prev_representatives_table = representatives
 
     # Loop while our representative table still has unsettled nodes
-    iteration, root_rows_count = 0, 1
-    while root_rows_count > 0:
+    # (nodes where the representative has changed since the last iteration)
+    converged_clusters_tables = []
+    filtered_neighbours = neighbours
+
+    iteration, needs_updating_count = 0, 1
+    while needs_updating_count > 0:
         start_time = time.time()
         iteration += 1
 
+        # Loop summary:
+        # 1. Find stable clusters and remove from representatives table.
+        #    Stable clusters are those where a set of nodes are within the same cluster
+        #    and those nodes have no neighbours outside of their cluster.
+        #    Add to list of converged clusters.
+        # 2. Update representatives table by following links from current reps
+        #    to their neighbours, and recalculating min representative.
+        # 3. Join on the representatives table from the previous iteration
+        #    to create the "needs_updating" column based on whether rep has changed.
+        # 4. Assess if any representatives changed between iterations, exit if not.
+
+        # 1a. Find stable clusters and remove from representatives table
+        pipeline = CTEPipeline([filtered_neighbours, prev_representatives_table])
+        converged_nodes_sqls = _cc_find_converged_nodes(
+            prev_representatives_table.templated_name,
+            filtered_neighbours.templated_name,
+        )
+        pipeline.enqueue_list_of_sqls(converged_nodes_sqls)
+
+        representatives_stable = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+        converged_clusters_tables.append(representatives_stable)
 
-        # 1. Update our neighbours table.
-        # 2. Join on the representatives table from the previous iteration
-        #    to create the "rep_match" column.
-        # 3. Assess if our exit condition has been met.
+        # Remove stable clusters from representatives table
+        pipeline = CTEPipeline([representatives_stable, prev_representatives_table])
+        sql = f"""
+        SELECT *
+        FROM {prev_representatives_table.templated_name}
+        WHERE representative NOT IN (
+            SELECT representative FROM __splink__representatives_stable
+        )
+        """
+        pipeline.enqueue_sql(sql, "__splink__representatives_unstable")
+        prev_representatives_thinned = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+
+        # 1b. Thin neighbours table - we can drop all rows that refer to
+        # node_ids that have converged
+        pipeline = CTEPipeline([prev_representatives_thinned, filtered_neighbours])
+        sql = f"""
+        select * from {filtered_neighbours.templated_name}
+        where node_id in
+            (select node_id from {prev_representatives_thinned.templated_name})
+        """
+        pipeline.enqueue_sql(sql, "__splink__df_neighbours_filtered")
+        filtered_neighbours_thinned = db_api.sql_pipeline_to_splink_dataframe(pipeline)
+        filtered_neighbours.drop_table_from_database_and_remove_from_cache()
+        filtered_neighbours = filtered_neighbours_thinned
 
         # Generates our representatives table for the next iteration
         # by joining our previous tables onto our neighbours table.
-        pipeline = CTEPipeline([neighbours])
+        pipeline = CTEPipeline([filtered_neighbours])
         sql = _cc_generate_representatives_loop_cond(
-            prev_representatives_table.physical_name,
+            prev_representatives_thinned.physical_name,
+            filtered_neighbours.templated_name,
         )
         pipeline.enqueue_sql(sql, "r")
-        # Update our rep_match column in the representatives table.
+        # Update our needs_updating column in the representatives table.
         sql = _cc_update_representatives_loop_cond(
-            prev_representatives_table.physical_name
+            prev_representatives_thinned.physical_name
         )
 
         repr_name = f"__splink__df_representatives_{iteration}"
@@ -374,6 +472,10 @@ def solve_connected_components(
 
         representatives = db_api.sql_pipeline_to_splink_dataframe(pipeline)
 
+        # Now the new representatives have been computed from the thinned
+        # representatives, we no longer need the older table
+        prev_representatives_thinned.drop_table_from_database_and_remove_from_cache()
+
         pipeline = CTEPipeline()
         # Update table reference
         prev_representatives_table.drop_table_from_database_and_remove_from_cache()
Expand All @@ -390,26 +492,35 @@ def solve_connected_components(

root_rows = root_rows_df.as_record_dict()
root_rows_df.drop_table_from_database_and_remove_from_cache()
root_rows_count = root_rows[0]["count"]
needs_updating_count = root_rows[0]["count_of_nodes_needing_updating"]
logger.info(
f"Completed iteration {iteration}, root rows count {root_rows_count}"
f"Completed iteration {iteration}, "
f"num representatives needing updating: {needs_updating_count}"
)
end_time = time.time()
logger.log(15, f" Iteration time: {end_time - start_time} seconds")

sql = f"""
select
node_id as {node_id_column_name},
representative as cluster_id
from {representatives.templated_name}
order by cluster_id, node_id
"""
converged_clusters_tables.append(representatives)
filtered_neighbours.drop_table_from_database_and_remove_from_cache()

pipeline = CTEPipeline()

pipeline = CTEPipeline([representatives])
pipeline.enqueue_sql(sql, "__splink__clustering_output")
sql = " UNION ALL ".join(
[
f"""select node_id as {node_id_column_name}, representative as cluster_id
from {t.physical_name}"""
for t in converged_clusters_tables
]
)

pipeline.enqueue_sql(sql, "__splink__clustering_output_final")

final_result = db_api.sql_pipeline_to_splink_dataframe(pipeline)

representatives.drop_table_from_database_and_remove_from_cache()
neighbours.drop_table_from_database_and_remove_from_cache()

for t in converged_clusters_tables:
t.drop_table_from_database_and_remove_from_cache()

return final_result
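
For context, here is a hedged sketch of how this code path is reached from the public API (Splink 4-style API as documented around this release; the dataset, comparisons, blocking rule and threshold below are arbitrary choices, not part of this PR):

import splink.comparison_library as cl
from splink import DuckDBAPI, Linker, SettingsCreator, block_on, splink_datasets

db_api = DuckDBAPI()
settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[cl.ExactMatch("first_name"), cl.ExactMatch("surname")],
    blocking_rules_to_generate_predictions=[block_on("first_name")],
)
linker = Linker(splink_datasets.fake_1000, settings, db_api)

predictions = linker.inference.predict()
# Internally this calls solve_connected_components, which now unions each
# converged cluster table into __splink__clustering_output_final
clusters = linker.clustering.cluster_pairwise_predictions_at_threshold(
    predictions, threshold_match_probability=0.95
)
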
2 changes: 1 addition & 1 deletion splink/internals/linker_components/clustering.py
@@ -150,7 +150,7 @@ def cluster_pairwise_predictions_at_threshold(
         select
             cc.cluster_id,
             {select_columns_sql}
-        from __splink__clustering_output as cc
+        from __splink__clustering_output_final as cc
         left join __splink__df_concat
         on cc.node_id = {uid_concat_nodes}
         """
