Skip to content

Commit

Permalink
Merge pull request #972 from haddocking/fix_rescore
Browse files Browse the repository at this point in the history
Fix rescore
  • Loading branch information
mgiulini authored Aug 14, 2024
2 parents 01eb9ae + f901e15 commit d71e541
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
12 changes: 7 additions & 5 deletions src/haddock/libs/libinteractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from haddock.libs.libplots import read_capri_table
from haddock.modules import get_module_steps_folders


def handle_ss_file(
df_ss: pd.DataFrame,
) -> tuple[pd.DataFrame, dict]:
Expand All @@ -35,16 +36,17 @@ def handle_ss_file(
df_ss_grouped = df_ss.groupby("cluster_id")
# calculate the mean and standard deviation of the first 4 elements
# of each group
new_values = np.zeros((len(df_ss_grouped), 3))
new_values = []
# loop over df_ss_grouped with enumerate
for i, clt_id in enumerate(df_ss_grouped):
for clt_id in df_ss_grouped:
ave_score = np.mean(clt_id[1]["score"].iloc[:4])
std_score = np.std(clt_id[1]["score"].iloc[:4])
new_values[i] = [ave_score, std_score, clt_id[0]]
new_values.append([ave_score, std_score, clt_id[0]])
# get the index that sorts the array by the first column
clt_ranks = np.argsort(new_values[:, 0])
new_values_arr = np.array(new_values)
clt_ranks = np.argsort(new_values_arr[:, 0])
# the ranked clusters are the third column of the new_values array
clt_sorted = new_values[clt_ranks, 2]
clt_sorted = new_values_arr[clt_ranks, 2]
clt_ranks_dict = {clt_sorted[i]: i + 1 for i in range(len(clt_sorted))}
# adjust clustering values if there are clusters
if list(np.unique(df_ss["cluster_id"])) != ["-"]:
Expand Down
15 changes: 13 additions & 2 deletions tests/test_libinteractive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import tempfile
from haddock.libs.libplots import read_capri_table
from haddock.libs.libinteractive import handle_ss_file, rewrite_capri_tables, look_for_capri, handle_clt_file
from haddock.libs.libinteractive import handle_ss_file, look_for_capri, handle_clt_file
import pytest
from pathlib import Path
import numpy as np
Expand All @@ -26,7 +26,18 @@ def test_handle_ss_file(example_capri_df_ss):
assert clt_ranks_dict[16] == 1
assert clt_ranks_dict[1] == 2
assert clt_ranks_dict[34] == 39



def test_handle_ss_file_unclustered(example_capri_df_ss):
    """Check handle_ss_file on a table in which no model belongs to a cluster.

    Every cluster-related column of a copy of the fixture is overwritten with
    the "-" placeholder before the table is handed to ``handle_ss_file``.
    """
    unclustered = example_capri_df_ss.copy()
    # flag all models as unclustered by blanking each cluster column in turn
    for column in ("cluster_id", "cluster_ranking", "model-cluster_ranking"):
        unclustered[column] = "-"
    ranked_df, ranks = handle_ss_file(unclustered)
    # the only group present is "-", and it must receive rank 1
    assert ranks == {"-": 1}
    # the first row's cluster ranking must still be the "-" placeholder
    assert ranked_df["cluster_ranking"].values[0] == "-"


def test_handle_clt_file(example_capri_df_ss):
"""Test handle_clt_file function."""
Expand Down

0 comments on commit d71e541

Please sign in to comment.