From 42baaaf4e578ce0afc1863af9825a7b24f79e54c Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Wed, 27 Nov 2024 10:31:31 +0800 Subject: [PATCH] ignore same-day reviews --- src/fsrs_optimizer/fsrs_optimizer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index c7cda30..f6bf23e 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -233,6 +233,7 @@ def __init__( ) self.t_train = torch.tensor(dataframe["delta_t"].values, dtype=torch.float) self.y_train = torch.tensor(dataframe["y"].values, dtype=torch.float) + self.last_rating = torch.tensor(dataframe["last_rating"].values, dtype=torch.long) self.seq_len = torch.tensor( dataframe["tensor"].map(len).values, dtype=torch.long ) @@ -252,6 +253,7 @@ def __init__( sequences_truncated.transpose(0, 1).to(device), self.t_train[start_index:end_index].to(device), self.y_train[start_index:end_index].to(device), + self.last_rating[start_index:end_index].to(device), seq_lens.to(device), ) @@ -312,6 +314,7 @@ def __init__( self.avg_eval_losses = [] self.loss_fn = nn.BCELoss(reduction="none") self.float_delta_t = float_delta_t + self.pls_penalty = 4 def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]): self.train_set = BatchDataset( @@ -343,14 +346,13 @@ def train(self, verbose: bool = True): for i, batch in enumerate(self.train_data_loader): self.model.train() self.optimizer.zero_grad() - sequences, delta_ts, labels, seq_lens = batch + sequences, delta_ts, labels, last_ratings, seq_lens = batch real_batch_size = seq_lens.shape[0] outputs, _ = self.model(sequences) stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] retentions = power_forgetting_curve(delta_ts, stabilities) - pls_flag = sequences[seq_lens-1, torch.arange(real_batch_size), 1] == 1 penalty = torch.ones_like(retentions, requires_grad=False) - penalty[pls_flag] *= 2 + penalty[last_ratings == 1] *= self.pls_penalty loss = (self.loss_fn(retentions, labels) * penalty).sum() loss.backward() if self.float_delta_t: @@ -387,10 +389,11 @@ def eval(self): if len(dataset) == 0: losses.append(0) continue - sequences, delta_ts, labels, seq_lens = ( + sequences, delta_ts, labels, last_ratings, seq_lens = ( dataset.x_train, dataset.t_train, dataset.y_train, + dataset.last_rating, dataset.seq_len, ) real_batch_size = seq_lens.shape[0] @@ -398,8 +401,7 @@ def eval(self): stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] retentions = power_forgetting_curve(delta_ts, stabilities) penalty = torch.ones_like(retentions, requires_grad=False) - pls_flag = sequences[torch.arange(real_batch_size), seq_lens-1, 1] == 1 - penalty[pls_flag] *= 2 + penalty[last_ratings == 1] *= self.pls_penalty loss = (self.loss_fn(retentions, labels) * penalty).mean() losses.append(loss) self.avg_train_losses.append(losses[0]) @@ -883,7 +885,6 @@ def cum_concat(x): "real_days", "review_rating", "t_history", - "last_rating", "y", ], inplace=True, @@ -1178,7 +1179,6 @@ def train( lambda x: lineToTensor(list(zip([x["t_history"]], [x["r_history"]]))[0]), axis=1, ) - self.dataset["group"] = self.dataset["r_history"] + self.dataset["t_history"] if verbose: tqdm.write("Tensorized!")