From 37145588a75b34cc9b9f5feb82a76f23bf692c95 Mon Sep 17 00:00:00 2001
From: Yusuke Oda
Date: Sat, 24 Sep 2022 01:13:24 +0900
Subject: [PATCH] Strict shape checking of `aggregate_stats` and
 `calc_metric_from_aggregate` (#499)

Co-authored-by: Graham Neubig
---
 explainaboard/metrics/accuracy.py            |   2 +-
 explainaboard/metrics/continuous.py          |   6 +-
 explainaboard/metrics/eaas.py                |  18 +-
 explainaboard/metrics/external_eval.py       |   2 +-
 explainaboard/metrics/f1_score.py            |  41 ++--
 explainaboard/metrics/log_prob.py            |  13 +-
 explainaboard/metrics/metric.py              | 121 ++++++++--
 explainaboard/metrics/metric_test.py         | 223 +++++++++++++++++++
 explainaboard/metrics/nlg_meta_evaluation.py |  27 ++-
 9 files changed, 392 insertions(+), 61 deletions(-)
 create mode 100644 explainaboard/metrics/metric_test.py

diff --git a/explainaboard/metrics/accuracy.py b/explainaboard/metrics/accuracy.py
index 6c1d0b5a..1a9d4b98 100644
--- a/explainaboard/metrics/accuracy.py
+++ b/explainaboard/metrics/accuracy.py
@@ -76,7 +76,7 @@ def calc_stats_from_data(
             )
         )
 
-    def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
+    def _aggregate_stats(self, stats: MetricStats) -> np.ndarray:
        """See Metric.aggregate_stats."""
        data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
        if data.size == 0:
diff --git a/explainaboard/metrics/continuous.py b/explainaboard/metrics/continuous.py
index db321cbe..6a550f85 100644
--- a/explainaboard/metrics/continuous.py
+++ b/explainaboard/metrics/continuous.py
@@ -41,11 +41,13 @@ def is_simple_average(self, stats: MetricStats):
         """See Metric.is_simple_average."""
         return False
 
-    def calc_metric_from_aggregate(
+    def _calc_metric_from_aggregate(
         self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
     ) -> np.ndarray:
         """See Metric.calc_metric_from_aggregate."""
-        return np.sqrt(agg_stats)
+        if agg_stats.shape[-1] != 1:
+            raise ValueError(f"Invalid shape for aggregate stats {agg_stats.shape}")
+        return np.sqrt(np.squeeze(agg_stats, axis=-1))
 
 
 @dataclass
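
As an aside, the shape contract this hunk enforces can be sketched with plain NumPy. `rms_from_aggregate` is an illustrative stand-in for the method above, not library code:

import numpy as np

def rms_from_aggregate(agg_stats: np.ndarray) -> np.ndarray:
    # The aggregate keeps a trailing stats axis of size 1; squeezing it yields
    # shape () for non-batched input and (num_batches,) for batched input.
    if agg_stats.shape[-1] != 1:
        raise ValueError(f"Invalid shape for aggregate stats {agg_stats.shape}")
    return np.sqrt(np.squeeze(agg_stats, axis=-1))

assert rms_from_aggregate(np.array([4.0])).shape == ()
assert rms_from_aggregate(np.array([[4.0], [9.0]])).shape == (2,)
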
diff --git a/explainaboard/metrics/eaas.py b/explainaboard/metrics/eaas.py
index bdb8c154..61d67753 100644
--- a/explainaboard/metrics/eaas.py
+++ b/explainaboard/metrics/eaas.py
@@ -96,11 +96,12 @@ class EaaSMetric(Metric):
 
     _NOT_SIMPLE_METRICS = {'bleu', 'chrf', 'length_ratio', 'length'}
 
-    def calc_metric_from_aggregate(
+    def _calc_metric_from_aggregate(
         self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
     ) -> np.ndarray:
         """See Metric.calc_metric_from_aggregate."""
-        if agg_stats.ndim == 1:
+        is_batched = agg_stats.ndim != 1
+        if not is_batched:
             agg_stats = agg_stats.reshape((1, agg_stats.shape[0]))
         n_samples = agg_stats.shape[0]
         if self.config.name in {'bleu', 'chrf'}:
@@ -113,19 +114,20 @@ def calc_metric_from_aggregate(
                     metric_class._compute_score_from_stats(list(single_stat)).score
                     / 100.0
                 )
-            return ret_metric
+            calc_result = ret_metric
         elif self.config.name == 'length_ratio':
-            return agg_stats[:, 0] / agg_stats[:, 1]
-        elif self.config.name == 'length':
-            return agg_stats[:, 0]
-        else:
-            return agg_stats
+            calc_result = agg_stats[:, 0] / agg_stats[:, 1]
+        else:
+            calc_result = agg_stats[:, 0]
+        if not is_batched:
+            calc_result = calc_result[0]
+        return calc_result
 
     def is_simple_average(self, stats: MetricStats):
         """See Metric.is_simple_average."""
         return self.config.name not in self._NOT_SIMPLE_METRICS
 
-    def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
+    def _aggregate_stats(self, stats: MetricStats) -> np.ndarray:
         """See: Metric.aggregate_stats."""
         data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
         if self.config.name in {'bleu', 'chrf'}:
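
The batching pattern above (promote a 1-D aggregate to a single-row batch, compute per row, then strip the batch axis) can be shown standalone. The function name is hypothetical, and the column layout (hypothesis length, reference length) is assumed from the `length_ratio` branch:

import numpy as np

def length_ratio_from_aggregate(agg_stats: np.ndarray) -> np.ndarray:
    is_batched = agg_stats.ndim != 1
    if not is_batched:
        agg_stats = agg_stats.reshape((1, agg_stats.shape[0]))
    result = agg_stats[:, 0] / agg_stats[:, 1]
    # Non-batched input yields a scalar-shaped result; batched, one value per batch.
    return result if is_batched else result[0]

assert length_ratio_from_aggregate(np.array([12.0, 10.0])).shape == ()
assert length_ratio_from_aggregate(np.array([[12.0, 10.0], [8.0, 10.0]])).shape == (2,)

diff --git a/explainaboard/metrics/external_eval.py b/explainaboard/metrics/external_eval.py
index 8ce8675b..066c19f9 100644
--- a/explainaboard/metrics/external_eval.py
+++ b/explainaboard/metrics/external_eval.py
@@ -156,7 +156,7 @@ def calc_agreement(self, stats: MetricStats) -> float:
         # self.config.agreement = fleiss_kappa(mat_kappa)
         return fleiss_kappa(mat_kappa)
 
-    def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
+    def _aggregate_stats(self, stats: MetricStats) -> np.ndarray:
         """See Metric.aggregate_stats."""
         data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
 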
diff --git a/explainaboard/metrics/f1_score.py b/explainaboard/metrics/f1_score.py
index 751303e5..fedf7e8d 100644
--- a/explainaboard/metrics/f1_score.py
+++ b/explainaboard/metrics/f1_score.py
@@ -100,14 +100,12 @@ def calc_stats_from_data(
                     stats[i, tid * stat_mult + 3] += 1
         return SimpleMetricStats(stats)
 
-    def calc_metric_from_aggregate(
+    def _calc_metric_from_aggregate(
         self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
     ) -> np.ndarray:
         """See Metric.calc_metric_from_aggregate."""
-        if agg_stats.size == 1:
-            return agg_stats
-
-        if agg_stats.ndim == 1:
+        is_batched = agg_stats.ndim != 1
+        if not is_batched:
             agg_stats = agg_stats.reshape((1, agg_stats.shape[0]))
 
         config = cast(F1ScoreConfig, unwrap_or(config, self.config))
@@ -135,6 +133,9 @@ def calc_metric_from_aggregate(
         if config.average == 'macro':
             f1 = np.mean(f1, axis=1)
 
+        if not is_batched:
+            f1 = f1[0]
+
         return f1
 
 
@@ -185,23 +186,19 @@ def calc_stats_from_data(
             )
         return SimpleMetricStats(np.array(stats))
 
-    def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
-        """See Metric.aggregate_stats."""
-        data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
-        if data.size == 0:
-            return np.array(0.0)
-        else:
-            # when data.ndim == 3, e.g.,
-            # * 1000 * 100 * 3 -> 1000 * 3
-            data_sum = np.sum(data, axis=(-2))
-            total_gold = data_sum[0] if data.ndim == 2 else data_sum[:, 0]
-            total_pred = data_sum[1] if data.ndim == 2 else data_sum[:, 1]
-            correct_num = data_sum[2] if data.ndim == 2 else data_sum[:, 2]
-
-            precision = correct_num * 1.0 / total_pred
-            recall = correct_num * 1.0 / total_gold
-            fscore = 2.0 * precision * recall / (precision + recall)
-            return np.array(fscore)
+    def _calc_metric_from_aggregate(
+        self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
+    ) -> np.ndarray:
+        """See Metric._calc_metric_from_aggregate."""
+        is_batched = agg_stats.ndim == 2
+        if not is_batched:
+            agg_stats = agg_stats.reshape((1, -1))
+        precision = agg_stats[:, 2] * 1.0 / agg_stats[:, 1]
+        recall = agg_stats[:, 2] * 1.0 / agg_stats[:, 0]
+        fscore = 2.0 * precision * recall / (precision + recall)
+        if not is_batched:
+            fscore = fscore[0]
+        return fscore
 
 
 @dataclass
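
To make the new F1 `_calc_metric_from_aggregate` concrete, here is a self-contained sketch; the column layout [total_gold, total_pred, correct] matches the indices used above, and the function name is illustrative:

import numpy as np

def f1_from_aggregate(agg_stats: np.ndarray) -> np.ndarray:
    # Columns: [total_gold, total_pred, correct].
    is_batched = agg_stats.ndim == 2
    if not is_batched:
        agg_stats = agg_stats.reshape((1, -1))
    precision = agg_stats[:, 2] / agg_stats[:, 1]
    recall = agg_stats[:, 2] / agg_stats[:, 0]
    fscore = 2.0 * precision * recall / (precision + recall)
    return fscore if is_batched else fscore[0]

# 8 gold spans, 10 predictions, 6 correct: P = 0.6, R = 0.75, F1 = 0.9 / 1.35 ≈ 0.667
print(f1_from_aggregate(np.array([8.0, 10.0, 6.0])))

diff --git a/explainaboard/metrics/log_prob.py b/explainaboard/metrics/log_prob.py
index a60ada65..9b8e2f8a 100644
--- a/explainaboard/metrics/log_prob.py
+++ b/explainaboard/metrics/log_prob.py
@@ -56,14 +56,21 @@ def calc_stats_from_data(
             t = type(pred_data[0])
             raise ValueError(f'Invalid type of pred_data for calc_stats_from_data {t}')
 
-    def calc_metric_from_aggregate(
+    def _calc_metric_from_aggregate(
         self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
     ) -> np.ndarray:
         """See Metric.calc_metric_from_aggregate."""
-        if agg_stats.ndim == 1:
+        is_batched = agg_stats.ndim != 1
+        if not is_batched:
             agg_stats = agg_stats.reshape((1, agg_stats.shape[0]))
         config = cast(LogProbConfig, unwrap_or(config, self.config))
-        val = agg_stats if agg_stats.size == 1 else agg_stats[:, 0] / agg_stats[:, 1]
+        val = (
+            agg_stats[:, 0]
+            if agg_stats.size == 1
+            else agg_stats[:, 0] / agg_stats[:, 1]
+        )
         if config.ppl:
             val = np.exp(-val)
+        if not is_batched:
+            val = val[0]
         return val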
diff --git a/explainaboard/metrics/metric.py b/explainaboard/metrics/metric.py
index 2c72493c..38f11e89 100644
--- a/explainaboard/metrics/metric.py
+++ b/explainaboard/metrics/metric.py
@@ -269,44 +269,110 @@ def calc_stats_from_data(
         evaluation metric can be calculated later. In the simplest form, this is just
         the evaluation metric value for each example.
 
-        :param true_data: gold-standard data
-        :param pred_data: predicted data
-        :param config: a configuration to over-ride the default for this object
-        :return: a numpy array of shape [len(true_data), X] where X=1 in the simplest
-            case of decomposable eval metrics
+        Args:
+            true_data: gold-standard data
+            pred_data: predicted data
+            config: a configuration to over-ride the default for this object
+
+        Returns:
+            A numpy array of shape [len(true_data), X] where X=1 in the simplest
+            case of decomposable eval metrics
         """
         ...
 
-    def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
+    @final
+    def aggregate_stats(
+        self, stats: MetricStats
+    ) -> np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any]:
         """Aggregate sufficient statistics from multiple examples into a single example.
 
         Args:
             stats: stats for every example
 
         Returns:
-            Aggregated stats
+            Aggregated stats. Shape must be:
+            - Non-batched data: [num_aggregate_stats]
+            - Batched data: [num_batches, num_aggregate_stats]
         """
+        result = self._aggregate_stats(stats)
+
+        num_stats = (
+            result.shape[-1]
+            if self.uses_customized_aggregate()
+            else stats.num_statistics()
+        )
+        result_shape = (
+            (stats.get_batch_data().shape[0], num_stats)
+            if stats.is_batched()
+            else (num_stats,)
+        )
+
+        assert result.shape == result_shape, (
+            "BUG: invalid operation: "
+            f"{type(self).__name__}._aggregate_stats(): "
+            f"Expected shape {result_shape}, but got {result.shape}."
+        )
+
+        return result
+
+    def _aggregate_stats(
+        self, stats: MetricStats
+    ) -> np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any]:
+        """Inner function of aggregate_stats."""
         data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
-        if data.size == 0:
-            return np.array(0.0)
+        if data.shape[-2] == 0:
+            return np.zeros(
+                shape=data.shape[:-2] + (data.shape[-1],),
+                dtype=np.float32,
+            )
         else:
             return np.mean(data, axis=-2)
 
+    @final
     def calc_metric_from_aggregate(
-        self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
-    ) -> np.ndarray:
+        self,
+        agg_stats: np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any],
+        config: Optional[MetricConfig] = None,
+    ) -> np.ndarray[tuple[()], Any] | np.ndarray[tuple[int], Any]:
         """From aggregated sufficient statistics, calculate the metric value.
 
         Args:
-            agg_stats: aggregated statistics, either:
-                one-dimensional [metric_size]
-                two-dimensional [batch_size, metric_size]
+            agg_stats: aggregated statistics. Shape must be:
+                - Non-batched data: [num_aggregate_stats]
+                - Batched data: [num_batches, num_aggregate_stats]
             config: a configuration to over-ride the default for this object
 
         Returns:
-            calculated metric of size 1, or metrics of size [batch_size]
+            Calculated metrics. Shape must be:
+                - Non-batched data: []
+                - Batched data: [num_batches]
         """
-        return agg_stats
+        if agg_stats.ndim not in (1, 2):
+            raise ValueError(f"Invalid shape size: {agg_stats.shape}")
+
+        result = self._calc_metric_from_aggregate(agg_stats, config)
+        result_shape = () if agg_stats.ndim == 1 else (agg_stats.shape[0],)
+
+        assert result.shape == result_shape, (
+            "BUG: invalid operation: "
+            f"{type(self).__name__}._calc_metric_from_aggregate(): "
+            f"Expected shape {result_shape}, but got {result.shape}."
+        )
+
+        return result
+
+    def _calc_metric_from_aggregate(
+        self,
+        agg_stats: np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any],
+        config: Optional[MetricConfig] = None,
+    ) -> np.ndarray[tuple[()], Any] | np.ndarray[tuple[int], Any]:
+        """Inner function of calc_metric_from_aggregate."""
+        if agg_stats.shape[-1] != 1:
+            raise ValueError(
+                "Multiple aggregates can't be integrated without specific algorithms."
+            )
+
+        return agg_stats.squeeze(-1)
 
     def is_simple_average(self, stats: MetricStats):
         """Whether the eval score is a simple average of the sufficient statistics.
@@ -317,6 +383,14 @@ def is_simple_average(self, stats: MetricStats):
         """
         return True
 
+    def uses_customized_aggregate(self) -> bool:
+        """Whether the metric uses aggregated stats other than the example-level stats.
+
+        If this function returns True, aggregate_stats() skips checking the size of
+        the last dimension of the returned ndarray.
+        """
+        return False
+
     def calc_confidence_interval(
         self,
         stats: MetricStats,
@@ -335,8 +409,13 @@ def calc_confidence_interval(
         Returns:
             A confidence interval or `None` if one cannot be calculated.
         """
-        if confidence_alpha <= 0.0 or confidence_alpha >= 1.0:
-            raise ValueError(f'Bad confidence value {confidence_alpha}')
+        if not (0.0 < confidence_alpha < 1.0):
+            raise ValueError(f'Invalid confidence_alpha: {confidence_alpha}')
+
+        if stats.is_batched():
+            raise ValueError(
+                "Confidence interval can't be calculated for batched data."
+            )
 
         stats_data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
         num_stats = stats.num_statistics()
@@ -373,6 +452,12 @@ def calc_confidence_interval(
             filt_stats = stats.filter(all_indices)
             agg_stats = self.aggregate_stats(filt_stats)
             samp_results = self.calc_metric_from_aggregate(agg_stats, config)
+
+            if samp_results.ndim != 1:
+                raise ValueError(
+                    f"Invalid shape of sampled metrics: {samp_results.shape}"
+                )
+
             samp_results.sort()
             low = int(num_iterations * confidence_alpha / 2.0)
             high = int(num_iterations * (1.0 - confidence_alpha / 2.0))
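
The perplexity path above reduces to a short standalone sketch. It assumes column 0 holds summed log-probabilities and column 1 token counts (inferred from the `agg_stats[:, 0] / agg_stats[:, 1]` division), and the function name is hypothetical:

import numpy as np

def perplexity_from_aggregate(agg_stats: np.ndarray) -> np.ndarray:
    is_batched = agg_stats.ndim != 1
    if not is_batched:
        agg_stats = agg_stats.reshape((1, agg_stats.shape[0]))
    val = agg_stats[:, 0] / agg_stats[:, 1]  # per-token log-probability
    ppl = np.exp(-val)  # perplexity is exp of the negated per-token log-prob
    return ppl if is_batched else ppl[0]

# Log-prob sum -2.0 over 1.0 token: ppl = e^2 ≈ 7.389
print(perplexity_from_aggregate(np.array([-2.0, 1.0])))
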
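
The heart of the patch is this template-method contract: the public entry points become @final and validate result shapes, while subclasses override only the underscored hooks. A minimal standalone sketch of the same pattern (class name and method bodies illustrative, not the library's code):

from typing import final

import numpy as np

class ShapeCheckedMetric:
    @final
    def calc_metric_from_aggregate(self, agg_stats: np.ndarray) -> np.ndarray:
        # Validate input rank, delegate to the hook, then check the result shape.
        if agg_stats.ndim not in (1, 2):
            raise ValueError(f"Invalid shape size: {agg_stats.shape}")
        result = self._calc_metric_from_aggregate(agg_stats)
        expected = () if agg_stats.ndim == 1 else (agg_stats.shape[0],)
        assert result.shape == expected, f"Expected {expected}, but got {result.shape}."
        return result

    def _calc_metric_from_aggregate(self, agg_stats: np.ndarray) -> np.ndarray:
        # Default hook: a single aggregate statistic is the metric itself.
        return agg_stats.squeeze(-1)

metric = ShapeCheckedMetric()
assert metric.calc_metric_from_aggregate(np.array([3.0])).shape == ()
assert metric.calc_metric_from_aggregate(np.array([[1.0], [2.0]])).shape == (2,)
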
diff --git a/explainaboard/metrics/metric_test.py b/explainaboard/metrics/metric_test.py
new file mode 100644
index 00000000..ef7adb7b
--- /dev/null
+++ b/explainaboard/metrics/metric_test.py
@@ -0,0 +1,223 @@
+"""Tests for explainaboard.metrics.metric"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+import dataclasses
+from typing import Any
+import unittest
+
+import numpy as np
+
+from explainaboard.metrics.metric import (
+    Metric,
+    MetricConfig,
+    MetricStats,
+    SimpleMetricStats,
+)
+from explainaboard.utils.typing_utils import narrow, unwrap
+
+
+@dataclasses.dataclass
+class _DummyMetricConfig(MetricConfig):
+    is_simple_average: bool = True
+    uses_customized_aggregate: bool = False
+    aggregate_stats_fn: Callable[[MetricStats], np.ndarray[Any, Any]] | None = None
+
+    def to_metric(self) -> Metric:
+        return _DummyMetric(self)
+
+
+class _DummyMetric(Metric):
+    def is_simple_average(self, stats: MetricStats) -> bool:
+        return narrow(_DummyMetricConfig, self.config).is_simple_average
+
+    def uses_customized_aggregate(self) -> bool:
+        return narrow(_DummyMetricConfig, self.config).uses_customized_aggregate
+
+    def _aggregate_stats(
+        self, stats: MetricStats
+    ) -> np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any]:
+        user_agg_fn = narrow(_DummyMetricConfig, self.config).aggregate_stats_fn
+        agg_fn = user_agg_fn if user_agg_fn is not None else super()._aggregate_stats
+        return agg_fn(stats)
+
+    def calc_stats_from_data(
+        self,
+        true_data: list,
+        pred_data: list,
+        config: MetricConfig | None = None,
+    ) -> MetricStats:
+        raise NotImplementedError
+
+
+class MetricTest(unittest.TestCase):
+    def test_aggregate_stats_1dim(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        stats = SimpleMetricStats(np.array([1.0, 2.0, 3.0]))
+        aggregate = metric.aggregate_stats(stats)
+        self.assertTrue(np.array_equal(aggregate, np.array([2.0])))
+
+    def test_aggregate_stats_2dim(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        stats = SimpleMetricStats(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
+        aggregate = metric.aggregate_stats(stats)
+        self.assertTrue(np.array_equal(aggregate, np.array([3.0, 4.0])))
+
+    def test_aggregate_stats_2dim_empty(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        stats = SimpleMetricStats(np.zeros((0, 3)))
+        aggregate = metric.aggregate_stats(stats)
+        self.assertTrue(np.array_equal(aggregate, np.zeros((3,))))
+
+    def test_aggregate_stats_3dim(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        stats = SimpleMetricStats(
+            np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
+        )
+        aggregate = metric.aggregate_stats(stats)
+        self.assertTrue(np.array_equal(aggregate, np.array([[2.0, 3.0], [6.0, 7.0]])))
+
+    def test_aggregate_stats_3dim_empty(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        stats = SimpleMetricStats(np.zeros((2, 0, 3)))
+        aggregate = metric.aggregate_stats(stats)
+        self.assertTrue(np.array_equal(aggregate, np.zeros((2, 3))))
+
+    def test_aggregate_stats_customized_nonbatch(self) -> None:
+        def agg_fn(stats: MetricStats) -> np.ndarray[Any, Any]:
+            return stats.get_data().max(axis=-2).max(axis=-1, keepdims=True)
+
+        metric = _DummyMetric(
+            _DummyMetricConfig(
+                "test", uses_customized_aggregate=True, aggregate_stats_fn=agg_fn
+            )
+        )
+        stats = SimpleMetricStats(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
+        aggregate = metric.aggregate_stats(stats)
+        self.assertTrue(np.array_equal(aggregate, np.array([6.0])))
+
+    def test_aggregate_stats_customized_nonbatch_invalid(self) -> None:
+        def agg_fn(stats: MetricStats) -> np.ndarray[Any, Any]:
+            return stats.get_data().max(axis=-2).max(axis=-1, keepdims=True)
+
+        metric = _DummyMetric(_DummyMetricConfig("test", aggregate_stats_fn=agg_fn))
+        stats = SimpleMetricStats(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
+        with self.assertRaisesRegex(
+            AssertionError, r"Expected shape \(2,\), but got \(1,\)\.$"
+        ):
+            metric.aggregate_stats(stats)
+
+    def test_aggregate_stats_customized_batch(self) -> None:
+        def agg_fn(stats: MetricStats) -> np.ndarray[Any, Any]:
+            return stats.get_batch_data().max(axis=-2).max(axis=-1, keepdims=True)
+
+        metric = _DummyMetric(
+            _DummyMetricConfig(
+                "test", uses_customized_aggregate=True, aggregate_stats_fn=agg_fn
+            )
+        )
+        stats = SimpleMetricStats(
+            np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
+        )
+        aggregate = metric.aggregate_stats(stats)
+        self.assertTrue(np.array_equal(aggregate, np.array([[4.0], [8.0]])))
+
+    def test_aggregate_stats_customized_batch_invalid(self) -> None:
+        def agg_fn(stats: MetricStats) -> np.ndarray[Any, Any]:
+            return stats.get_batch_data().max(axis=-2).max(axis=-1, keepdims=True)
+
+        metric = _DummyMetric(_DummyMetricConfig("test", aggregate_stats_fn=agg_fn))
+        stats = SimpleMetricStats(
+            np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
+        )
+        with self.assertRaisesRegex(
+            AssertionError, r"Expected shape \(2, 2\), but got \(2, 1\)\.$"
+        ):
+            metric.aggregate_stats(stats)
+
+    def test_calc_metric_from_aggregate_0dim(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        aggregate = np.array(3.0)
+        with self.assertRaisesRegex(ValueError, r"^Invalid shape size: \(\)$"):
+            metric.calc_metric_from_aggregate(aggregate)
+
+    def test_calc_metric_from_aggregate_1dim(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        aggregate = np.array([3.0])
+        result = metric.calc_metric_from_aggregate(aggregate)
+        self.assertTrue(np.array_equal(result, np.array(3.0)))
+
+    def test_calc_metric_from_aggregate_1dim_multi(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        aggregate = np.array([3.0, 4.0])
+        with self.assertRaisesRegex(ValueError, r"^Multiple aggregates"):
+            metric.calc_metric_from_aggregate(aggregate)
+
+    def test_calc_metric_from_aggregate_2dim(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        aggregate = np.array([[1.0], [2.0], [3.0]])
+        result = metric.calc_metric_from_aggregate(aggregate)
+        self.assertTrue(np.array_equal(result, np.array([1.0, 2.0, 3.0])))
+
+    def test_calc_metric_from_aggregate_2dim_multi(self) -> None:
+        metric = _DummyMetric(_DummyMetricConfig("test"))
+        aggregate = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
+        with self.assertRaisesRegex(ValueError, r"^Multiple aggregates"):
+            metric.calc_metric_from_aggregate(aggregate)
+
+    def test_calc_metric_from_aggregate_3dim(self) -> None:
_DummyMetric(_DummyMetricConfig("test")) + aggregate = np.array([[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]]) + with self.assertRaisesRegex(ValueError, r"Invalid shape size: \(2, 3, 1\)$"): + metric.calc_metric_from_aggregate(aggregate) + + def test_calc_confidence_interval_tdist(self) -> None: + metric = _DummyMetric(_DummyMetricConfig("test")) + stats = SimpleMetricStats(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + ci = unwrap(metric.calc_confidence_interval(stats, 0.05)) + self.assertAlmostEqual(ci[0], 3.387428953673732) + self.assertAlmostEqual(ci[1], 3.612571046326268) + + def test_calc_confidence_interval_tdist_multi_agg(self) -> None: + metric = _DummyMetric(_DummyMetricConfig("test")) + stats = SimpleMetricStats(np.array([[1.0, 2.0], [3.0, 4.0]])) + with self.assertRaisesRegex(ValueError, r"^t-test can be applied"): + metric.calc_confidence_interval(stats, 0.05) + + def test_calc_confidence_interval_bootstrap(self) -> None: + metric = _DummyMetric(_DummyMetricConfig("test", is_simple_average=False)) + stats = SimpleMetricStats(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + ci = unwrap(metric.calc_confidence_interval(stats, 0.05)) + # NOTE(odashi): + # The sampler takes only 3 samples for each bootstrap iteration, resulting in + # very wide confidence interval. This is a limitation of bootstrapping. + self.assertLess(ci[0], ci[1]) + self.assertGreaterEqual(ci[0], 1.0) + self.assertLessEqual(ci[1], 6.0) + + def test_calc_confidence_interval_bootstrap_multi_agg(self) -> None: + metric = _DummyMetric(_DummyMetricConfig("test", is_simple_average=False)) + stats = SimpleMetricStats(np.array([[0.5, 1.5], [1.5, 2.5], [2.5, 3.5]])) + with self.assertRaisesRegex(ValueError, r"^Multiple aggregates"): + metric.calc_confidence_interval(stats, 0.05) + + def test_calc_confidence_interval_invalid_alpha(self) -> None: + metric = _DummyMetric(_DummyMetricConfig("test")) + stats = SimpleMetricStats(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + with self.assertRaisesRegex(ValueError, r"^Invalid confidence_alpha: -0.125$"): + self.assertIsNone(metric.calc_confidence_interval(stats, -0.125)) + with self.assertRaisesRegex(ValueError, r"^Invalid confidence_alpha: 0.0$"): + self.assertIsNone(metric.calc_confidence_interval(stats, 0.0)) + with self.assertRaisesRegex(ValueError, r"^Invalid confidence_alpha: 1.0$"): + self.assertIsNone(metric.calc_confidence_interval(stats, 1.0)) + with self.assertRaisesRegex(ValueError, r"^Invalid confidence_alpha: 1.125$"): + self.assertIsNone(metric.calc_confidence_interval(stats, 1.125)) + + def test_calc_confidence_interval_single_example(self) -> None: + for is_single_average in (False, True): + metric = _DummyMetric( + _DummyMetricConfig("test", is_simple_average=is_single_average) + ) + stats = SimpleMetricStats(np.array([[1.0]])) + self.assertIsNone(metric.calc_confidence_interval(stats, 0.05)) diff --git a/explainaboard/metrics/nlg_meta_evaluation.py b/explainaboard/metrics/nlg_meta_evaluation.py index 8a1ee7be..7fd61128 100644 --- a/explainaboard/metrics/nlg_meta_evaluation.py +++ b/explainaboard/metrics/nlg_meta_evaluation.py @@ -45,10 +45,14 @@ def to_metric(self) -> Metric: class CorrelationMetric(Metric): """A metric that calculates correlations.""" - def is_simple_average(self, stats: MetricStats): + def is_simple_average(self, stats: MetricStats) -> bool: """See Metric.is_simple_average.""" return False + def uses_customized_aggregate(self) -> bool: + """See Metric.uses_customized_aggregate.""" + return True + def calc_stats_from_data( self, true_data: 
         true_data: list[Union[str, list[str]]],
@@ -111,19 +115,30 @@ def get_scores_from_stats(
 
         return scores
 
-    def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
+    def _aggregate_stats(self, stats: MetricStats) -> np.ndarray:
         """See Metric.aggregate_stats."""
-        return stats.get_batch_data() if stats.is_batched() else stats.get_data()
+        if stats.is_batched():
+            data = stats.get_batch_data()
+            assert data.shape[-1] == 4
+            return data.reshape((data.shape[0], data.shape[-2] * data.shape[-1]))
+        else:
+            data = stats.get_data()
+            assert data.shape[-1] == 4
+            return data.reshape((data.shape[-2] * data.shape[-1]))
 
-    def calc_metric_from_aggregate(
+    def _calc_metric_from_aggregate(
         self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
     ) -> np.ndarray:
         """See Metric.calc_metric_from_aggregate."""
-        if len(agg_stats.shape) == 2:
+        if agg_stats.ndim == 1:
+            agg_stats = agg_stats.reshape((int(agg_stats.shape[0] / 4), 4))
             val = self.calc_metric_from_aggregate_single(agg_stats, config)
-            return np.array([val])
+            return np.array(val)
         else:
             n_samples = agg_stats.shape[0]
+            agg_stats = agg_stats.reshape(
+                (agg_stats.shape[0], int(agg_stats.shape[1] / 4), 4)
+            )
             ret_metric = np.zeros(n_samples)
             for i, single_stat in enumerate(agg_stats):
                 val = self.calc_metric_from_aggregate_single(single_stat, config)
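
The flatten/restore trick used by CorrelationMetric above, where each example contributes 4 statistics that are flattened so the base class sees a flat aggregate vector, comes down to a NumPy reshape round trip:

import numpy as np

stats = np.arange(12, dtype=float).reshape((3, 4))        # [n_examples, 4] stats
flat = stats.reshape((stats.shape[0] * stats.shape[1],))  # _aggregate_stats: -> (12,)
restored = flat.reshape((flat.shape[0] // 4, 4))          # _calc_metric_from_aggregate: -> (3, 4)
assert np.array_equal(stats, restored)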