Skip to content

Commit

Permalink
fix shape
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Nov 27, 2024
1 parent b236f73 commit 9539e13
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def train(self, verbose: bool = True):
retentions = power_forgetting_curve(delta_ts, stabilities)
pls_flag = sequences[seq_lens-1, torch.arange(real_batch_size), 1] == 1
penalty = torch.ones_like(retentions, requires_grad=False)
penalty[pls_flag] *= 10
penalty[pls_flag] *= 2
loss = (self.loss_fn(retentions, labels) * penalty).sum()
loss.backward()
if self.float_delta_t:
Expand Down Expand Up @@ -398,7 +398,7 @@ def eval(self):
stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
retentions = power_forgetting_curve(delta_ts, stabilities)
penalty = torch.ones_like(retentions, requires_grad=False)
pls_flag = sequences[seq_lens-1, torch.arange(real_batch_size), 1] == 1
pls_flag = sequences[torch.arange(real_batch_size), seq_lens-1, 1] == 1
penalty[pls_flag] *= 2
loss = (self.loss_fn(retentions, labels) * penalty).mean()
losses.append(loss)
Expand Down

0 comments on commit 9539e13

Please sign in to comment.