Skip to content

Commit

Permalink
Strict shape checking of aggregate_stats and `calc_metric_from_aggregate` (#499)
Browse files Browse the repository at this point in the history

Co-authored-by: Graham Neubig <[email protected]>
  • Loading branch information
Yusuke Oda and neubig authored Sep 23, 2022
1 parent f66f9ad commit 3714558
Show file tree
Hide file tree
Showing 9 changed files with 392 additions and 61 deletions.
2 changes: 1 addition & 1 deletion explainaboard/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def calc_stats_from_data(
)
)

def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
def _aggregate_stats(self, stats: MetricStats) -> np.ndarray:
"""See Metric.aggregate_stats."""
data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
if data.size == 0:
Expand Down
6 changes: 4 additions & 2 deletions explainaboard/metrics/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ def is_simple_average(self, stats: MetricStats):
"""See Metric.is_simple_average."""
return False

def calc_metric_from_aggregate(
def _calc_metric_from_aggregate(
    self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
) -> np.ndarray:
    """See Metric.calc_metric_from_aggregate.

    Args:
        agg_stats: aggregated statistics. The last axis must have size 1.
        config: accepted for interface compatibility; this implementation
            does not read it.

    Returns:
        The element-wise square root of ``agg_stats`` with the trailing
        size-1 axis removed.

    Raises:
        ValueError: if the last axis of ``agg_stats`` does not have size 1.
    """
    if agg_stats.shape[-1] != 1:
        # f-string so that the offending shape actually appears in the message.
        raise ValueError(f"Invalid shape for aggregate stats {agg_stats.shape}")
    return np.sqrt(np.squeeze(agg_stats, axis=-1))


@dataclass
Expand Down
18 changes: 10 additions & 8 deletions explainaboard/metrics/eaas.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,12 @@ class EaaSMetric(Metric):

_NOT_SIMPLE_METRICS = {'bleu', 'chrf', 'length_ratio', 'length'}

def calc_metric_from_aggregate(
def _calc_metric_from_aggregate(
self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
) -> np.ndarray:
"""See Metric.calc_metric_from_aggregate."""
if agg_stats.ndim == 1:
is_batched = agg_stats.ndim != 1
if not is_batched:
agg_stats = agg_stats.reshape((1, agg_stats.shape[0]))
n_samples = agg_stats.shape[0]
if self.config.name in {'bleu', 'chrf'}:
Expand All @@ -113,19 +114,20 @@ def calc_metric_from_aggregate(
metric_class._compute_score_from_stats(list(single_stat)).score
/ 100.0
)
return ret_metric
calc_result = ret_metric
elif self.config.name == 'length_ratio':
return agg_stats[:, 0] / agg_stats[:, 1]
elif self.config.name == 'length':
return agg_stats[:, 0]
calc_result = agg_stats[:, 0] / agg_stats[:, 1]
else:
return agg_stats
calc_result = agg_stats[:, 0]
if not is_batched:
calc_result = calc_result[0]
return calc_result

def is_simple_average(self, stats: MetricStats):
"""See Metric.is_simple_average."""
return self.config.name not in self._NOT_SIMPLE_METRICS

def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
def _aggregate_stats(self, stats: MetricStats) -> np.ndarray:
"""See: Metric.aggregate_stats."""
data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
if self.config.name in {'bleu', 'chrf'}:
Expand Down
2 changes: 1 addition & 1 deletion explainaboard/metrics/external_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def calc_agreement(self, stats: MetricStats) -> float:
# self.config.agreement = fleiss_kappa(mat_kappa)
return fleiss_kappa(mat_kappa)

def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
def _aggregate_stats(self, stats: MetricStats) -> np.ndarray:
"""See Metric.aggregate_stats."""
data = stats.get_batch_data() if stats.is_batched() else stats.get_data()

Expand Down
41 changes: 19 additions & 22 deletions explainaboard/metrics/f1_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,12 @@ def calc_stats_from_data(
stats[i, tid * stat_mult + 3] += 1
return SimpleMetricStats(stats)

def calc_metric_from_aggregate(
def _calc_metric_from_aggregate(
self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
) -> np.ndarray:
"""See Metric.calc_metric_from_aggregate."""
if agg_stats.size == 1:
return agg_stats

if agg_stats.ndim == 1:
is_batched = agg_stats.ndim != 1
if not is_batched:
agg_stats = agg_stats.reshape((1, agg_stats.shape[0]))

config = cast(F1ScoreConfig, unwrap_or(config, self.config))
Expand Down Expand Up @@ -135,6 +133,9 @@ def calc_metric_from_aggregate(
if config.average == 'macro':
f1 = np.mean(f1, axis=1)

if not is_batched:
f1 = f1[0]

return f1


Expand Down Expand Up @@ -185,23 +186,19 @@ def calc_stats_from_data(
)
return SimpleMetricStats(np.array(stats))

def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
"""See Metric.aggregate_stats."""
data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
if data.size == 0:
return np.array(0.0)
else:
# when data.ndim == 3, e.g.,
# * 1000 * 100 * 3 -> 1000 * 3
data_sum = np.sum(data, axis=(-2))
total_gold = data_sum[0] if data.ndim == 2 else data_sum[:, 0]
total_pred = data_sum[1] if data.ndim == 2 else data_sum[:, 1]
correct_num = data_sum[2] if data.ndim == 2 else data_sum[:, 2]

precision = correct_num * 1.0 / total_pred
recall = correct_num * 1.0 / total_gold
fscore = 2.0 * precision * recall / (precision + recall)
return np.array(fscore)
def _calc_metric_from_aggregate(
    self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
) -> np.ndarray:
    """See Metric._calc_metric_from_aggregate.

    Args:
        agg_stats: aggregated statistics. A 2-dimensional array is treated as
            batched; anything else is flattened into a single row. Columns are
            presumably (gold total, predicted total, correct count) -- confirm
            against the stats producer.
        config: accepted for interface compatibility; not read here.

    Returns:
        The F-score per row for batched input, otherwise the single F-score of
        the only row. Zero denominators are not guarded against.
    """
    batched = agg_stats.ndim == 2
    rows = agg_stats if batched else agg_stats.reshape((1, -1))

    correct = rows[:, 2]
    precision = correct * 1.0 / rows[:, 1]
    recall = correct * 1.0 / rows[:, 0]
    fscore = 2.0 * precision * recall / (precision + recall)

    # Drop the artificial batch axis again for non-batched input.
    return fscore if batched else fscore[0]


@dataclass
Expand Down
13 changes: 10 additions & 3 deletions explainaboard/metrics/log_prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,21 @@ def calc_stats_from_data(
t = type(pred_data[0])
raise ValueError(f'Invalid type of pred_data for calc_stats_from_data {t}')

def calc_metric_from_aggregate(
def _calc_metric_from_aggregate(
self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
) -> np.ndarray:
"""See Metric.calc_metric_from_aggregate."""
if agg_stats.ndim == 1:
is_batched = agg_stats.ndim != 1
if not is_batched:
agg_stats = agg_stats.reshape((1, agg_stats.shape[0]))
config = cast(LogProbConfig, unwrap_or(config, self.config))
val = agg_stats if agg_stats.size == 1 else agg_stats[:, 0] / agg_stats[:, 1]
val = (
agg_stats[:, 0]
if agg_stats.size == 1
else agg_stats[:, 0] / agg_stats[:, 1]
)
if config.ppl:
val = np.exp(-val)
if not is_batched:
val = val[0]
return val
121 changes: 103 additions & 18 deletions explainaboard/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,44 +269,110 @@ def calc_stats_from_data(
evaluation metric can be calculated later. In the simplest form, this is just
the evaluation metric value for each example.
:param true_data: gold-standard data
:param pred_data: predicted data
:param config: a configuration to over-ride the default for this object
:return: a numpy array of shape [len(true_data), X] where X=1 in the simplest
case of decomposable eval metrics
Args:
true_data: gold-standard data
pred_data: predicted data
config: a configuration to over-ride the default for this object
Returns:
A numpy array of shape [len(true_data), X] where X=1 in the simplest case of
decomposable eval metrics
"""
...

def aggregate_stats(self, stats: MetricStats) -> np.ndarray:
@final
def aggregate_stats(
    self, stats: MetricStats
) -> np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any]:
    """Aggregate sufficient statistics from multiple examples into a single example.

    Delegates the computation to ``_aggregate_stats`` and then verifies the
    shape of what it returned.

    Args:
        stats: stats for every example

    Returns:
        Aggregated stats. Shape must be:
            - Non-batched data: [num_aggregate_stats]
            - Batched data: [num_batches, num_aggregate_stats]
    """
    aggregated = self._aggregate_stats(stats)

    # With customized aggregation the last dimension is taken as-is from the
    # result; otherwise it must equal the number of example-level statistics.
    if self.uses_customized_aggregate():
        last_dim = aggregated.shape[-1]
    else:
        last_dim = stats.num_statistics()

    if stats.is_batched():
        expected_shape = (stats.get_batch_data().shape[0], last_dim)
    else:
        expected_shape = (last_dim,)

    # Internal invariant on the subclass implementation, not input validation.
    assert aggregated.shape == expected_shape, (
        "BUG: invalid operation: "
        f"{type(self).__name__}._aggregate_stats(): "
        f"Expected shape {expected_shape}, but got {aggregated.shape}."
    )

    return aggregated

def _aggregate_stats(
self, stats: MetricStats
) -> np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any]:
"""Inner function of aggregate_stats."""
data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
if data.size == 0:
return np.array(0.0)
if data.shape[-2] == 0:
return np.zeros(
shape=data.shape[:-2] + (data.shape[-1],),
dtype=np.float32,
)
else:
return np.mean(data, axis=-2)

@final
def calc_metric_from_aggregate(
self, agg_stats: np.ndarray, config: Optional[MetricConfig] = None
) -> np.ndarray:
self,
agg_stats: np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any],
config: Optional[MetricConfig] = None,
) -> np.ndarray[tuple[()], Any] | np.ndarray[tuple[int], Any]:
"""From aggregated sufficient statistics, calculate the metric value.
Args:
agg_stats: aggregated statistics, either:
one-dimensional [metric_size]
two-dimensional [batch_size, metric_size]
agg_stats: aggregated statistics. Shape must be:
- Non-batched data: [num_aggregate_stats]
- Batched data: [num_batches, num_aggregate_stats]
config: a configuration to over-ride the default for this object
Returns:
calculated metric of size 1, or metrics of size [batch_size]
Calculated metrics. Shape must be:
- Non-batched data: []
- Batched data: [num_batches]
"""
return agg_stats
if agg_stats.ndim not in (1, 2):
raise ValueError(f"Invalid shape size: {agg_stats.shape}")

result = self._calc_metric_from_aggregate(agg_stats, config)
result_shape = () if agg_stats.ndim == 1 else (agg_stats.shape[0],)

assert result.shape == result_shape, (
"BUG: invalid operation: "
f"{type(self).__name__}._calc_metric_from_aggregate(): "
f"Expected shape {result_shape}, but got {result.shape}."
)

return result

def _calc_metric_from_aggregate(
    self,
    agg_stats: np.ndarray[tuple[int], Any] | np.ndarray[tuple[int, int], Any],
    config: Optional[MetricConfig] = None,
) -> np.ndarray[tuple[()], Any] | np.ndarray[tuple[int], Any]:
    """Inner function of calc_metric_from_aggregate.

    This default treats the single aggregate statistic as the metric value.

    Args:
        agg_stats: aggregated statistics whose last axis must have size 1.
        config: accepted for interface compatibility; not read here.

    Returns:
        ``agg_stats`` with its last axis removed.

    Raises:
        ValueError: if the last axis holds more than one (or zero) aggregates.
    """
    if agg_stats.shape[-1] == 1:
        return agg_stats.squeeze(-1)

    raise ValueError(
        "Multiple aggregates can't be integrated without specific algorithms."
    )

def is_simple_average(self, stats: MetricStats):
"""Whether the eval score is a simple average of the sufficient statistics.
Expand All @@ -317,6 +383,14 @@ def is_simple_average(self, stats: MetricStats):
"""
return True

def uses_customized_aggregate(self) -> bool:
    """Tell whether aggregated stats differ from the example-level stats.

    When this returns True, aggregate_stats() does not compare the size of the
    last dimension of the aggregated ndarray against the number of
    example-level statistics.

    Returns:
        Always False in this implementation.
    """
    return False

def calc_confidence_interval(
self,
stats: MetricStats,
Expand All @@ -335,8 +409,13 @@ def calc_confidence_interval(
Returns:
A confidence interval or `None` if one cannot be calculated.
"""
if confidence_alpha <= 0.0 or confidence_alpha >= 1.0:
raise ValueError(f'Bad confidence value {confidence_alpha}')
if not (0.0 < confidence_alpha < 1.0):
raise ValueError(f'Invalid confidence_alpha: {confidence_alpha}')

if stats.is_batched():
raise ValueError(
"Confidence interval can't be calculated for batched data."
)

stats_data = stats.get_batch_data() if stats.is_batched() else stats.get_data()
num_stats = stats.num_statistics()
Expand Down Expand Up @@ -373,6 +452,12 @@ def calc_confidence_interval(
filt_stats = stats.filter(all_indices)
agg_stats = self.aggregate_stats(filt_stats)
samp_results = self.calc_metric_from_aggregate(agg_stats, config)

if samp_results.ndim != 1:
raise ValueError(
f"Invalid shape of sampled metrics: {samp_results.shape}"
)

samp_results.sort()
low = int(num_iterations * confidence_alpha / 2.0)
high = int(num_iterations * (1.0 - confidence_alpha / 2.0))
Expand Down
Loading

0 comments on commit 3714558

Please sign in to comment.