diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index a4bc738..2aa03e2 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -1541,6 +1541,11 @@ def moving_average(data, window_size=365 // 20): def evaluate(self, save_to_file=True): my_collection = Collection(DEFAULT_PARAMETER, self.float_delta_t) + if "tensor" not in self.dataset.columns: + self.dataset["tensor"] = self.dataset.progress_apply( + lambda x: lineToTensor(list(zip([x["t_history"]], [x["r_history"]]))[0]), + axis=1, + ) stabilities, difficulties = my_collection.batch_predict(self.dataset) self.dataset["stability"] = stabilities self.dataset["difficulty"] = difficulties