Skip to content

Commit

Permalink
Merge pull request #63 from michaelbornholdt/fix_prec_recall
Browse files Browse the repository at this point in the history
Quick fix to precision recall
  • Loading branch information
gwaybio authored Oct 20, 2021
2 parents 7daede2 + 212e2a5 commit 3c907ea
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Currently, six metric operations are supported:
3. mp-value
4. Grit
5. Enrichment
6. Hit@k

## Demos

Expand Down
1 change: 1 addition & 0 deletions cytominer_eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def evaluate(
metric_result = precision_recall(
similarity_melted_df=similarity_melted_df,
replicate_groups=replicate_groups,
groupby_columns=groupby_columns,
k=precision_recall_k,
)
elif operation == "grit":
Expand Down
22 changes: 14 additions & 8 deletions cytominer_eval/operations/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
def precision_recall(
similarity_melted_df: pd.DataFrame,
replicate_groups: List[str],
groupby_columns: List[str],
k: Union[int, List[int]],
) -> pd.DataFrame:
"""Determine the precision and recall at k for all unique replicate groups
"""Determine the precision and recall at k for all unique groupby_columns samples
based on a predefined similarity metric (see cytominer_eval.transform.metric_melt)
Parameters
Expand All @@ -26,15 +27,20 @@ def precision_recall(
samples. Importantly, it must follow the exact structure as output from
:py:func:`cytominer_eval.transform.transform.metric_melt`.
replicate_groups : List
a list of metadata column names in the original profile dataframe to use as
replicate columns.
a list of metadata column names in the original profile dataframe to use as replicate columns.
groupby_columns : List of str
Columns by which the similarity matrix is grouped and by which the precision/recall is calculated.
For example, if groupby_columns = ["Metadata_sample"], then the precision is calculated for each sample.
Calculating the precision by sample is the default,
but it is also mathematically valid to calculate the precision at the MOA level;
the result is just less intuitive to interpret.
k : List of ints or int
an integer, or a list of integers, indicating how many pairwise comparisons to threshold.
Returns
-------
pandas.DataFrame
precision and recall metrics for all replicate groups given k
precision and recall metrics for all groupby_column groups given k
"""
# Determine pairwise replicates and make sure to sort based on the metric!
similarity_melted_df = assign_replicates(
Expand All @@ -46,9 +52,9 @@ def precision_recall(

# Extract out specific columns
pair_ids = set_pair_ids()
replicate_group_cols = [
groupby_cols_suffix = [
"{x}{suf}".format(x=x, suf=pair_ids[list(pair_ids)[0]]["suffix"])
for x in replicate_groups
for x in groupby_columns
]
# iterate over all k
precision_recall_df = pd.DataFrame()
Expand All @@ -57,11 +63,11 @@ def precision_recall(
for k_ in k:
# Calculate precision and recall for all groups
precision_recall_df_at_k = similarity_melted_df.groupby(
replicate_group_cols
groupby_cols_suffix
).apply(lambda x: calculate_precision_recall(x, k=k_))
precision_recall_df = precision_recall_df.append(precision_recall_df_at_k)

# Rename the suffixed columns back to the groupby columns provided
rename_cols = dict(zip(replicate_group_cols, replicate_groups))
rename_cols = dict(zip(groupby_cols_suffix, groupby_columns))

return precision_recall_df.reset_index().rename(rename_cols, axis="columns")
5 changes: 5 additions & 0 deletions cytominer_eval/tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def test_evaluate_precision_recall():
},
}

gene_groupby_columns = ["Metadata_pert_name"]
compound_groupby_columns = ["Metadata_broad_sample"]

for k in ks:

# first test the function with k = int, later we test with k = list of ints
Expand All @@ -140,6 +143,7 @@ def test_evaluate_precision_recall():
features=gene_features,
meta_features=gene_meta_features,
replicate_groups=gene_groups,
groupby_columns=gene_groupby_columns,
operation="precision_recall",
similarity_metric="pearson",
precision_recall_k=k,
Expand All @@ -159,6 +163,7 @@ def test_evaluate_precision_recall():
features=compound_features,
meta_features=compound_meta_features,
replicate_groups=["Metadata_broad_sample"],
groupby_columns=compound_groupby_columns,
operation="precision_recall",
similarity_metric="pearson",
precision_recall_k=[k],
Expand Down
19 changes: 10 additions & 9 deletions cytominer_eval/tests/test_operations/test_precision_recall.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import os
import random
import pytest
import pathlib
import tempfile
import numpy as np
import pandas as pd


from cytominer_eval.transform import metric_melt
from cytominer_eval.operations import precision_recall

random.seed(123)
tmpdir = tempfile.gettempdir()
random.seed(42)

# Load CRISPR dataset
example_file = "SQ00014610_normalized_feature_select.csv.gz"
Expand All @@ -37,32 +34,36 @@

replicate_groups = ["Metadata_gene_name", "Metadata_cell_line"]

groupby_columns = ["Metadata_pert_name"]


def test_precision_recall():
result_list = precision_recall(
similarity_melted_df=similarity_melted_df,
replicate_groups=replicate_groups,
groupby_columns=groupby_columns,
k=[5, 10],
)

result_int = precision_recall(
similarity_melted_df=similarity_melted_df,
replicate_groups=replicate_groups,
groupby_columns=groupby_columns,
k=5,
)

assert len(result_list.k.unique()) == 2
assert result_list.k.unique()[0] == 5

# ITGAV has a really strong profile
# ITGAV-1 has a really strong profile
assert (
result_list.sort_values(by="recall", ascending=False)
.reset_index(drop=True)
.iloc[0, :]
.Metadata_gene_name
== "ITGAV"
.Metadata_pert_name
== "ITGAV-1"
)

assert all(x in result_list.columns for x in replicate_groups)
assert all(x in result_list.columns for x in groupby_columns)

assert result_int.equals(result_list.query("k == 5"))

0 comments on commit 3c907ea

Please sign in to comment.