From 9539e13ca8f8dd3c10f3410e9ba77c256b3c8ed0 Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Wed, 27 Nov 2024 09:57:30 +0800
Subject: [PATCH] fix shape

---
 src/fsrs_optimizer/fsrs_optimizer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py
index 9d55a10..c7cda30 100644
--- a/src/fsrs_optimizer/fsrs_optimizer.py
+++ b/src/fsrs_optimizer/fsrs_optimizer.py
@@ -350,7 +350,7 @@ def train(self, verbose: bool = True):
                 retentions = power_forgetting_curve(delta_ts, stabilities)
                 pls_flag = sequences[seq_lens-1, torch.arange(real_batch_size), 1] == 1
                 penalty = torch.ones_like(retentions, requires_grad=False)
-                penalty[pls_flag] *= 10
+                penalty[pls_flag] *= 2
                 loss = (self.loss_fn(retentions, labels) * penalty).sum()
                 loss.backward()
                 if self.float_delta_t:
@@ -398,7 +398,7 @@ def eval(self):
                 stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
                 retentions = power_forgetting_curve(delta_ts, stabilities)
                 penalty = torch.ones_like(retentions, requires_grad=False)
-                pls_flag = sequences[seq_lens-1, torch.arange(real_batch_size), 1] == 1
+                pls_flag = sequences[torch.arange(real_batch_size), seq_lens-1, 1] == 1
                 penalty[pls_flag] *= 2
                 loss = (self.loss_fn(retentions, labels) * penalty).mean()
                 losses.append(loss)
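
The eval-path change is easier to see in isolation. The sketch below is illustrative only: the shapes, values, and variable names are assumptions, not taken from the patch, and it simply assumes that `sequences` in eval() is laid out batch-first, (batch, time, feature), while the index order in the removed line only matches a time-first layout such as the one used on the train path.

    import torch

    # Toy shapes; the real optimizer builds these tensors from review logs.
    batch_size, max_len, n_feat = 4, 6, 2          # assumed values
    sequences = torch.zeros(batch_size, max_len, n_feat)  # assumed batch-first in eval()
    seq_lens = torch.tensor([3, 6, 2, 5])

    # Put a rating of 1 in feature channel 1 at each sequence's last real step.
    for b, length in enumerate(seq_lens):
        sequences[b, length - 1, 1] = 1.0

    # Batch-first indexing, as in the added line: (batch, time, feature).
    pls_flag = sequences[torch.arange(batch_size), seq_lens - 1, 1] == 1
    print(pls_flag)  # tensor([True, True, True, True])

    # The removed order, (time, batch, feature), only fits a time-first tensor;
    # on this batch-first tensor it would index dim 0 (size 4) with seq_lens - 1,
    # so any value >= 4 raises an IndexError instead of selecting the right cells.

With the swapped index order the flag lines up with the per-item `retentions` and `penalty` tensors of shape (batch,), which is what the elementwise loss in eval() expects.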