From f464c06e6387592ebe989878c5b6302fbe28bece Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 2 Jan 2025 14:54:51 +0800 Subject: [PATCH] reset S0_dataset_group if dataset is not None --- src/fsrs_optimizer/fsrs_optimizer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 662da30..ec13d7a 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -1003,12 +1003,6 @@ def pretrain(self, dataset=None, verbose=True): ) else: self.dataset = dataset - self.dataset = self.dataset[ - (self.dataset["i"] > 1) & (self.dataset["delta_t"] > 0) - ] - if self.dataset.empty: - raise ValueError("Training data is inadequate.") - if self.S0_dataset_group is None: self.dataset["first_rating"] = self.dataset["r_history"].map( lambda x: x[0] if len(x) > 0 else "" ) @@ -1018,6 +1012,11 @@ def pretrain(self, dataset=None, verbose=True): .agg({"y": ["mean", "count"]}) .reset_index() ) + self.dataset = self.dataset[ + (self.dataset["i"] > 1) & (self.dataset["delta_t"] > 0) + ] + if self.dataset.empty: + raise ValueError("Training data is inadequate.") rating_stability = {} rating_count = {} average_recall = self.dataset["y"].mean()