Skip to content

Commit

Permalink
Feat/add ICI and related metrics (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Jan 15, 2024
1 parent eb96867 commit 9c25848
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.20.8"
version = "4.21.0"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
16 changes: 15 additions & 1 deletion src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import List, Optional
from datetime import timedelta, datetime
import statsmodels.api as sm
from statsmodels.nonparametric.smoothers_lowess import lowess
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch
Expand Down Expand Up @@ -1735,6 +1736,14 @@ def get_bin(x, bins=bins):


def plot_brier(predictions, real, bins=20, ax=None, title=None):
y, p = zip(*sorted(zip(real, predictions), key=lambda x: x[1]))
observation = lowess(
y, p, it=0, delta=0.01 * (max(p) - min(p)), is_sorted=True, return_sorted=False
)
ici = np.mean(np.abs(observation - p))
e_50 = np.median(np.abs(observation - p))
e_90 = np.quantile(np.abs(observation - p), 0.9)
e_max = np.max(np.abs(observation - p))
brier = load_brier(predictions, real, bins=bins)
bin_prediction_means = brier["detail"]["bin_prediction_means"]
bin_correct_means = brier["detail"]["bin_correct_means"]
Expand All @@ -1759,6 +1768,10 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
tqdm.write(f"R-squared: {r2:.4f}")
tqdm.write(f"RMSE: {rmse:.4f}")
tqdm.write(f"MAE: {mae:.4f}")
tqdm.write(f"ICI: {ici:.4f}")
tqdm.write(f"E50: {e_50:.4f}")
tqdm.write(f"E90: {e_90:.4f}")
tqdm.write(f"EMax: {e_max:.4f}")
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.grid(True)
Expand All @@ -1785,6 +1798,7 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
color="#1f77b4",
marker="*",
)
ax.plot(p, observation, label="Lowess Smoothing", color="red")
ax.plot((0, 1), (0, 1), label="Perfect Calibration", color="#ff7f0e")
bin_count = brier["detail"]["bin_count"]
counts = np.array(bin_counts)
Expand All @@ -1809,7 +1823,7 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
ax2.legend(loc="lower center")
if title:
ax.set_title(title)
metrics = {"R-squared": r2, "RMSE": rmse, "MAE": mae}
metrics = {"R-squared": r2, "RMSE": rmse, "MAE": mae, "ICI": ici}
return metrics


Expand Down

0 comments on commit 9c25848

Please sign in to comment.