From 96cea5fa0d52583cf8068b9e9599d97231f48b54 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 2 Jan 2025 12:23:53 +0800 Subject: [PATCH] Fix/S0 dataset initialization fallback --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index eafb07d..8837ecf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.6.2" +version = "5.6.3" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index c9fe57a..662da30 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -519,6 +519,7 @@ def loss(stability): class Optimizer: float_delta_t: bool = False enable_short_term: bool = True + S0_dataset_group = None def __init__( self, float_delta_t: bool = False, enable_short_term: bool = True @@ -1007,6 +1008,16 @@ def pretrain(self, dataset=None, verbose=True): ] if self.dataset.empty: raise ValueError("Training data is inadequate.") + if self.S0_dataset_group is None: + self.dataset["first_rating"] = self.dataset["r_history"].map( + lambda x: x[0] if len(x) > 0 else "" + ) + self.S0_dataset_group = ( + self.dataset[self.dataset["i"] == 2] + .groupby(by=["first_rating", "delta_t"], group_keys=False) + .agg({"y": ["mean", "count"]}) + .reset_index() + ) rating_stability = {} rating_count = {} average_recall = self.dataset["y"].mean()