diff --git a/src/haddock/libs/libinteractive.py b/src/haddock/libs/libinteractive.py index ac6760d7d..1f9c3f860 100644 --- a/src/haddock/libs/libinteractive.py +++ b/src/haddock/libs/libinteractive.py @@ -9,6 +9,7 @@ from haddock.libs.libplots import read_capri_table from haddock.modules import get_module_steps_folders + def handle_ss_file( df_ss: pd.DataFrame, ) -> tuple[pd.DataFrame, dict]: @@ -35,16 +36,17 @@ def handle_ss_file( df_ss_grouped = df_ss.groupby("cluster_id") # calculate the mean and standard deviation of the first 4 elements # of each group - new_values = np.zeros((len(df_ss_grouped), 3)) + new_values = [] # loop over df_ss_grouped with enumerate - for i, clt_id in enumerate(df_ss_grouped): + for clt_id in df_ss_grouped: ave_score = np.mean(clt_id[1]["score"].iloc[:4]) std_score = np.std(clt_id[1]["score"].iloc[:4]) - new_values[i] = [ave_score, std_score, clt_id[0]] + new_values.append([ave_score, std_score, clt_id[0]]) # get the index that sorts the array by the first column - clt_ranks = np.argsort(new_values[:, 0]) + new_values_arr = np.array(new_values) + clt_ranks = np.argsort(new_values_arr[:, 0]) # the ranked clusters are the third column of the new_values array - clt_sorted = new_values[clt_ranks, 2] + clt_sorted = new_values_arr[clt_ranks, 2] clt_ranks_dict = {clt_sorted[i]: i + 1 for i in range(len(clt_sorted))} # adjust clustering values if there are clusters if list(np.unique(df_ss["cluster_id"])) != ["-"]: diff --git a/tests/test_libinteractive.py b/tests/test_libinteractive.py index a0a6cf4bb..085036735 100644 --- a/tests/test_libinteractive.py +++ b/tests/test_libinteractive.py @@ -1,6 +1,6 @@ import tempfile from haddock.libs.libplots import read_capri_table -from haddock.libs.libinteractive import handle_ss_file, rewrite_capri_tables, look_for_capri, handle_clt_file +from haddock.libs.libinteractive import handle_ss_file, look_for_capri, handle_clt_file import pytest from pathlib import Path import numpy as np @@ -26,7 +26,18 @@ def test_handle_ss_file(example_capri_df_ss): assert clt_ranks_dict[16] == 1 assert clt_ranks_dict[1] == 2 assert clt_ranks_dict[34] == 39 - + + +def test_handle_ss_file_unclustered(example_capri_df_ss): + """Test handle_ss_file function with unclustered data.""" + df_ss = example_capri_df_ss.copy() + df_ss['cluster_id'] = '-' + df_ss['cluster_ranking'] = "-" + df_ss['model-cluster_ranking'] = "-" + df_ss, clt_ranks_dict = handle_ss_file(df_ss) + assert clt_ranks_dict == {"-": 1} + assert df_ss['cluster_ranking'].values[0] == "-" + def test_handle_clt_file(example_capri_df_ss): """Test handle_clt_file function."""