diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 93aff68..a4bc738 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -1551,6 +1551,8 @@ def evaluate(self, save_to_file=True): lambda row: -np.log(row["p"]) if row["y"] == 1 else -np.log(1 - row["p"]), axis=1, ) + if "weights" not in self.dataset.columns: + self.dataset["weights"] = 1 self.dataset["log_loss"] = ( self.dataset["log_loss"] * self.dataset["weights"]