Implement MetricValues #503

Merged · 36 commits · Sep 29, 2022
Commits
3321668
Use SerializableDataclass to serialize and deserialize MetricConfig o…
tetsuok Aug 30, 2022
a134e0d
Fix isort
tetsuok Aug 30, 2022
49baf8d
Bring back cls_name
tetsuok Aug 30, 2022
0e7ebd4
Merge branch 'main' into issue-427-metric-config
tetsuok Aug 31, 2022
72c0a8e
Merge branch 'main' into issue-427-metric-config
tetsuok Sep 5, 2022
5b33546
Replace metric_config_from_dict with get_metric_config_serializer
tetsuok Sep 5, 2022
793096f
Fix isort issue
tetsuok Sep 5, 2022
05f6e9d
Merge branch 'main' into issue-427-metric-config
tetsuok Sep 5, 2022
3471647
Fix TypeError
tetsuok Sep 5, 2022
0b464b5
Format with Black
tetsuok Sep 5, 2022
ace4515
Merge branch 'main' into issue-427-metric-config
tetsuok Sep 10, 2022
301f123
Ignore typecheck for dataclass
tetsuok Sep 10, 2022
3654a1e
Update explainaboard/metrics/nlg_meta_evaluation.py
tetsuok Sep 10, 2022
8f74cbd
Add type annotations to to_metric
tetsuok Sep 10, 2022
02d48f6
Merge branch 'main' into issue-427-metric-config
tetsuok Sep 11, 2022
bdf5698
Remove external_stats from tests
tetsuok Sep 11, 2022
64823af
Merge branch 'main' into issue-427-metric-config
Sep 14, 2022
c35232e
add MetricValue
odashi Sep 14, 2022
9f1592c
Merge branch 'issue-427-metric-config' of github.com:tetsuok/Explaina…
odashi Sep 14, 2022
1d8c8be
fix linter errors
odashi Sep 14, 2022
62232ee
fix serialization bug
odashi Sep 14, 2022
7143ec2
Merge branch 'issue-427-metric-config' into refactor-metricresult
odashi Sep 16, 2022
cfcf8dc
add some tests.
odashi Sep 19, 2022
bb18379
Merge branch 'main' into refactor-metricresult
odashi Sep 19, 2022
67b6d13
fix bugs and introduce shape checking.
odashi Sep 19, 2022
de3dac4
Fix batching in most metrics
neubig Sep 19, 2022
d07e733
Simplification
neubig Sep 20, 2022
d2823e7
Merge branch 'main' into refactor-metricresult
odashi Sep 20, 2022
e11cc81
Merge branch 'fix-calc-metric-from-aggregate' into refactor-metricresult
odashi Sep 20, 2022
c5fe886
fix bugs
odashi Sep 20, 2022
b850461
Merge branch 'main' into refactor-metricresult
odashi Sep 20, 2022
caed18f
Change RuntimeError to assertion.
odashi Sep 20, 2022
0cf56d9
Merge branch 'fix-calc-metric-from-aggregate' into refactor-metricresult
odashi Sep 20, 2022
9f19f8f
Merge branch 'main' into refactor-metricresult
odashi Sep 23, 2022
ad02ab2
remove metric registry
odashi Sep 24, 2022
70ee795
Merge branch 'main' into refactor-metricresult
odashi Sep 27, 2022
31 changes: 19 additions & 12 deletions explainaboard/analysis/analyses.py
@@ -14,7 +14,13 @@
from explainaboard.analysis.case import AnalysisCase, AnalysisCaseCollection
from explainaboard.analysis.feature import FeatureType
from explainaboard.analysis.performance import BucketPerformance, Performance
from explainaboard.metrics.metric import Metric, MetricConfig, MetricStats
from explainaboard.metrics.metric import (
ConfidenceInterval,
Metric,
MetricConfig,
MetricStats,
Score,
)
from explainaboard.serialization.serializers import PrimitiveSerializer
from explainaboard.utils.typing_utils import narrow, unwrap, unwrap_generator

@@ -284,28 +290,29 @@ def perform(
# has no samples
if n_samples == 0.0:
value = 0.0
conf_low: Optional[float] = None
conf_high: Optional[float] = None
ci_low: Optional[float] = None
ci_high: Optional[float] = None
else:
bucket_stats = metric_stat.filter(bucket_collection.samples)
metric_result = metric_func.evaluate_from_stats(
bucket_stats,
confidence_alpha=confidence_alpha,
)

conf_low, conf_high = (
metric_result.confidence_interval
if metric_result.confidence_interval
else (None, None)
)

value = metric_result.value
value = unwrap(metric_result.get_value(Score, "score")).value
ci = metric_result.get_value(ConfidenceInterval, "score_ci")
if ci is not None:
ci_low = ci.low
ci_high = ci.high
else:
ci_low = None
ci_high = None

performance = Performance(
metric_name=metric_func.config.name,
value=value,
confidence_score_low=conf_low,
confidence_score_high=conf_high,
confidence_score_low=ci_low,
confidence_score_high=ci_high,
)

bucket_performance.performances.append(performance)
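
The hunk above switches the consumer side from the old MetricResult.value / confidence_interval attributes to named, typed lookups. As a minimal usage sketch (not part of the diff; it assumes AccuracyConfig is importable from explainaboard.metrics.accuracy, as in the tests below), the new retrieval pattern looks roughly like this:

from explainaboard.metrics.accuracy import AccuracyConfig
from explainaboard.metrics.metric import ConfidenceInterval, Score
from explainaboard.utils.typing_utils import unwrap

metric = AccuracyConfig(name="Accuracy").to_metric()
result = metric.evaluate(
    ["a", "b", "a", "b"], ["a", "b", "b", "b"], confidence_alpha=0.05
)

# "score" holds the headline value; unwrap turns the Optional lookup into a hard failure.
accuracy = unwrap(result.get_value(Score, "score")).value

# The confidence interval is stored under "score_ci" and may be absent.
ci = result.get_value(ConfidenceInterval, "score_ci")
low, high = (ci.low, ci.high) if ci is not None else (None, None)
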
58 changes: 58 additions & 0 deletions explainaboard/metrics/accuracy_test.py
@@ -12,6 +12,8 @@
SeqCorrectCount,
SeqCorrectCountConfig,
)
from explainaboard.metrics.metric import Score
from explainaboard.utils.typing_utils import unwrap


class AccuracyConfigTest(unittest.TestCase):
@@ -35,6 +37,17 @@ def test_to_metric(self) -> None:
self.assertIsInstance(AccuracyConfig("Accuracy").to_metric(), Accuracy)


class AccuracyTest(unittest.TestCase):
def test_evaluate(self) -> None:
metric = AccuracyConfig(name='Accuracy').to_metric()
true = ['a', 'b', 'a', 'b', 'a', 'b']
pred = ['a', 'b', 'a', 'b', 'b', 'a']
result = metric.evaluate(true, pred, confidence_alpha=0.05)
self.assertAlmostEqual(
unwrap(result.get_value(Score, "score")).value, 2.0 / 3.0
)


class CorrectCountConfigTest(unittest.TestCase):
def test_serialize(self) -> None:
self.assertEqual(
@@ -59,6 +72,15 @@ def test_to_metric(self) -> None:
)


class CorrectCountTest(unittest.TestCase):
def test_evaluate(self) -> None:
metric = CorrectCountConfig(name='CorrectCount').to_metric()
true = ['a', 'b', 'a', 'b', 'a', 'b']
pred = ['a', 'b', 'a', 'b', 'b', 'a']
result = metric.evaluate(true, pred, confidence_alpha=0.05)
self.assertAlmostEqual(unwrap(result.get_value(Score, "score")).value, 4)


class SeqCorrectCountConfigTest(unittest.TestCase):
def test_serialize(self) -> None:
self.assertEqual(
@@ -81,3 +103,39 @@ def test_to_metric(self) -> None:
SeqCorrectCountConfig("SeqCorrectCount").to_metric(),
SeqCorrectCount,
)


class SeqCorrectCountTest(unittest.TestCase):
def test_evaluate(self) -> None:
metric = SeqCorrectCountConfig(name='SeqCorrectCount').to_metric()
true = [
{
"start_idx": [8, 17, 39, 46, 58, 65, 65, 80],
"end_idx": [8, 18, 40, 47, 59, 65, 66, 81],
"corrections": [
["the"],
["found"],
["other"],
["there"],
["chickens."],
["in"],
["which"],
["selling"],
],
}
]
pred = [
{
"start_idx": [8, 17, 39, 46, 58],
"end_idx": [8, 18, 40, 47, 59],
"corrections": [
["the"],
["found"],
["other"],
["there"],
["chickens."],
],
}
]
result = metric.evaluate(true, pred)
self.assertAlmostEqual(unwrap(result.get_value(Score, "score")).value, 5)
30 changes: 16 additions & 14 deletions explainaboard/metrics/external_eval.py
@@ -9,10 +9,13 @@

from explainaboard.metrics.metric import (
AuxiliaryMetricResult,
ConfidenceInterval,
Metric,
MetricConfig,
MetricResult,
MetricStats,
MetricValue,
Score,
SimpleMetricStats,
)
from explainaboard.serialization import common_registry
@@ -185,17 +188,16 @@ def evaluate_from_stats(
"""
config = self._get_config(config)
agg_stats = self.aggregate_stats(stats)
agreement = self.calc_agreement(stats)
value = self.calc_metric_from_aggregate(agg_stats, config)
confidence_interval = (
self.calc_confidence_interval(stats, confidence_alpha)
if confidence_alpha
else None
)
return MetricResult(
config,
float(value),
confidence_interval,
confidence_alpha,
ExternalEvalResult(agreement),
)

metric_values: dict[str, MetricValue] = {
"score": Score(float(self.calc_metric_from_aggregate(agg_stats, config))),
"agreement": Score(self.calc_agreement(stats)),
}
if confidence_alpha is not None:
ci = self.calc_confidence_interval(stats, confidence_alpha)
if ci is not None:
metric_values["score_ci"] = ConfidenceInterval(
ci[0], ci[1], confidence_alpha
)

return MetricResult(config, metric_values)
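
On the producer side, the hunk above shows a metric returning a dict of named MetricValues instead of positional fields. A hedged sketch of that pattern, using a hypothetical build_result helper rather than any class from this PR:

from __future__ import annotations

from explainaboard.metrics.metric import (
    ConfidenceInterval,
    MetricConfig,
    MetricResult,
    MetricValue,
    Score,
)


def build_result(
    config: MetricConfig,
    score: float,
    ci: tuple[float, float] | None,
    confidence_alpha: float,
) -> MetricResult:
    # Downstream code (e.g. analyses.py above) looks up "score" and "score_ci".
    metric_values: dict[str, MetricValue] = {"score": Score(float(score))}
    if ci is not None:
        metric_values["score_ci"] = ConfidenceInterval(ci[0], ci[1], confidence_alpha)
    return MetricResult(config, metric_values)
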
62 changes: 62 additions & 0 deletions explainaboard/metrics/f1_score_test.py
@@ -4,12 +4,16 @@

import unittest

from sklearn.metrics import f1_score

from explainaboard.metrics.f1_score import (
F1Score,
F1ScoreConfig,
SeqF1Score,
SeqF1ScoreConfig,
)
from explainaboard.metrics.metric import Score
from explainaboard.utils.typing_utils import unwrap


class F1ScoreConfigTest(unittest.TestCase):
@@ -56,6 +60,28 @@ def test_to_metric(self) -> None:
)


class F1ScoreTest(unittest.TestCase):
def test_evaluate_micro(self) -> None:
metric = F1ScoreConfig(name='F1', average='micro').to_metric()
true = ['a', 'b', 'a', 'b', 'a', 'a', 'c', 'c']
pred = ['a', 'b', 'a', 'b', 'b', 'a', 'c', 'a']
sklearn_f1 = f1_score(true, pred, average='micro')
result = metric.evaluate(true, pred, confidence_alpha=0.05)
self.assertAlmostEqual(
unwrap(result.get_value(Score, "score")).value, sklearn_f1
)

def test_evaluate_macro(self) -> None:
metric = F1ScoreConfig(name='F1', average='macro').to_metric()
true = ['a', 'b', 'a', 'b', 'a', 'a', 'c', 'c']
pred = ['a', 'b', 'a', 'b', 'b', 'a', 'c', 'a']
sklearn_f1 = f1_score(true, pred, average='macro')
result = metric.evaluate(true, pred, confidence_alpha=None)
self.assertAlmostEqual(
unwrap(result.get_value(Score, "score")).value, sklearn_f1
)


class SeqF1ScoreConfigTest(unittest.TestCase):
def test_serialize(self) -> None:
self.assertEqual(
@@ -102,3 +128,39 @@ def test_to_metric(self) -> None:
SeqF1ScoreConfig("SeqF1Score").to_metric(),
SeqF1Score,
)


class SeqF1ScoreTest(unittest.TestCase):
def test_evaluate_micro(self) -> None:
true = [
['O', 'O', 'B-MISC', 'I-MISC', 'B-MISC', 'O', 'O'],
['B-PER', 'I-PER', 'O'],
]
pred = [
['O', 'O', 'B-MISC', 'I-MISC', 'B-MISC', 'I-MISC', 'O'],
['B-PER', 'I-PER', 'O'],
]
metric = SeqF1ScoreConfig(
name='MicroF1', average='micro', tag_schema='bio'
).to_metric()
result = metric.evaluate(true, pred, confidence_alpha=None)
self.assertAlmostEqual(
unwrap(result.get_value(Score, "score")).value, 2.0 / 3.0
)

def test_evaluate_macro(self) -> None:
true = [
['O', 'O', 'B-MISC', 'I-MISC', 'B-MISC', 'O', 'O'],
['B-PER', 'I-PER', 'O'],
]
pred = [
['O', 'O', 'B-MISC', 'I-MISC', 'B-MISC', 'I-MISC', 'O'],
['B-PER', 'I-PER', 'O'],
]
metric = SeqF1ScoreConfig(
name='MacroF1', average='macro', tag_schema='bio'
).to_metric()
result = metric.evaluate(true, pred, confidence_alpha=None)
self.assertAlmostEqual(
unwrap(result.get_value(Score, "score")).value, 3.0 / 4.0
)