diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 0662329..72491a7 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -356,12 +356,12 @@ def train(self, verbose: bool = True): for i, batch in enumerate(self.train_data_loader): self.model.train() self.optimizer.zero_grad() - sequences, delta_ts, labels, seq_lens = batch + sequences, delta_ts, labels, seq_lens, weights = batch real_batch_size = seq_lens.shape[0] outputs, _ = self.model(sequences) stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] retentions = power_forgetting_curve(delta_ts, stabilities) - loss = self.loss_fn(retentions, labels).sum() + loss = (self.loss_fn(retentions, labels) * weights).sum() loss.backward() if self.float_delta_t or not self.enable_short_term: for param in self.model.parameters(): @@ -400,17 +400,18 @@ def eval(self): if len(dataset) == 0: losses.append(0) continue - sequences, delta_ts, labels, seq_lens = ( + sequences, delta_ts, labels, seq_lens, weights = ( dataset.x_train, dataset.t_train, dataset.y_train, dataset.seq_len, + dataset.weights, ) real_batch_size = seq_lens.shape[0] outputs, _ = self.model(sequences.transpose(0, 1)) stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] retentions = power_forgetting_curve(delta_ts, stabilities) - loss = self.loss_fn(retentions, labels).mean() + loss = (self.loss_fn(retentions, labels) * weights).mean() losses.append(loss) self.avg_train_losses.append(losses[0]) self.avg_eval_losses.append(losses[1]) @@ -1186,6 +1187,7 @@ def train( batch_size: int = 512, verbose: bool = True, split_by_time: bool = False, + recency_weight: bool = False, ): """Step 4""" self.dataset["tensor"] = self.dataset.progress_apply( @@ -1198,9 +1200,9 @@ def train( w = [] plots = [] + self.dataset.sort_values(by=["review_time"], inplace=True) if split_by_time: tscv = TimeSeriesSplit(n_splits=5) - self.dataset.sort_values(by=["review_time"], inplace=True) for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)): if verbose: tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}") @@ -1227,6 +1229,8 @@ def train( print(metrics) plots.append(trainer.plot()) else: + if recency_weight: + self.dataset["weights"] = np.linspace(0.5, 1.5, len(self.dataset)) trainer = Trainer( self.dataset, None,