From 4970509203ba603d2056a14bb721bebf967dcff9 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Mon, 23 Dec 2024 18:46:45 +0800 Subject: [PATCH] Improve recency weighting (#154) * improve recency weighting * bump version * apply weights on evaluation --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d07146f..ebd70b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.6.0" +version = "5.6.1" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 72491a7..e812eff 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -1230,7 +1230,8 @@ def train( plots.append(trainer.plot()) else: if recency_weight: - self.dataset["weights"] = np.linspace(0.5, 1.5, len(self.dataset)) + x = np.linspace(0, 1, len(self.dataset)) + self.dataset["weights"] = 0.25 + 0.75 * np.power(x, 3) trainer = Trainer( self.dataset, None, @@ -1541,6 +1542,11 @@ def evaluate(self, save_to_file=True): lambda row: -np.log(row["p"]) if row["y"] == 1 else -np.log(1 - row["p"]), axis=1, ) + self.dataset["log_loss"] = ( + self.dataset["log_loss"] + * self.dataset["weights"] + / self.dataset["weights"].mean() + ) loss_before = self.dataset["log_loss"].mean() my_collection = Collection(self.w, self.float_delta_t) @@ -1554,6 +1560,11 @@ def evaluate(self, save_to_file=True): lambda row: -np.log(row["p"]) if row["y"] == 1 else -np.log(1 - row["p"]), axis=1, ) + self.dataset["log_loss"] = ( + self.dataset["log_loss"] + * self.dataset["weights"] + / self.dataset["weights"].mean() + ) loss_after = self.dataset["log_loss"].mean() if save_to_file: tmp = self.dataset.copy() @@ -2100,10 +2111,10 @@ def count_lapse(r_history, t_history): ) tmp = ( tmp.groupby(["delta_t", "i", "lapse"]) - .agg({"y": "mean", "p": "mean", "card_id": "count"}) + .agg({"y": "mean", "p": "mean", "weights": "sum"}) .reset_index() ) - return root_mean_squared_error(tmp["y"], tmp["p"], sample_weight=tmp["card_id"]) + return root_mean_squared_error(tmp["y"], tmp["p"], sample_weight=tmp["weights"]) def wrap_short_term_ratings(r_history, t_history):