From 2b0752fbfb80b9b1cb1a00b21ecd6da966e7ba8a Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Fri, 10 May 2024 18:58:10 +0800
Subject: [PATCH] remove pre_train_set & next_train_set

---
 src/fsrs_optimizer/fsrs_optimizer.py | 14 +-------------
 1 file changed, 1 insertion(+), 13 deletions(-)

diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py
index a768179..0ea5d7f 100644
--- a/src/fsrs_optimizer/fsrs_optimizer.py
+++ b/src/fsrs_optimizer/fsrs_optimizer.py
@@ -295,7 +295,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.build_dataset(train_set, test_set)
         self.n_epoch = n_epoch
-        self.batch_nums = self.next_train_data_loader.batch_nums
+        self.batch_nums = self.train_data_loader.batch_nums
         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             self.optimizer, T_max=self.batch_nums * n_epoch
         )
@@ -304,18 +304,6 @@ def __init__(
         self.loss_fn = nn.BCELoss(reduction="none")
 
     def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]):
-        pre_train_set = train_set[train_set["i"] == 2]
-        self.pre_train_set = BatchDataset(
-            pre_train_set, batch_size=self.batch_size, max_seq_len=self.max_seq_len
-        )
-        self.pre_train_data_loader = BatchLoader(self.pre_train_set)
-
-        next_train_set = train_set[train_set["i"] > 2]
-        self.next_train_set = BatchDataset(
-            next_train_set, batch_size=self.batch_size, max_seq_len=self.max_seq_len
-        )
-        self.next_train_data_loader = BatchLoader(self.next_train_set)
-
         self.train_set = BatchDataset(
             train_set, batch_size=self.batch_size, max_seq_len=self.max_seq_len
        )
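
Note: below is a minimal, self-contained sketch of the splitting behavior this patch deletes, using toy data. The sample frame, values, and variable names are illustrative only, not the repository's code; "i" appears to be the per-card review index, so the old code filtered first-review rows (i == 2) into a separate pre-training loader and later reviews (i > 2) into another. After this patch, a single loader covers the whole training set, which is why batch_nums is now read from self.train_data_loader.

    import pandas as pd

    # Toy review log; "i" is assumed to be the per-card review index.
    train_set = pd.DataFrame(
        {"i": [2, 2, 3, 4, 2, 3], "rating": [3, 4, 3, 2, 1, 3]}
    )

    # Before this patch: two disjoint subsets, each fed to its own loader.
    pre_train_set = train_set[train_set["i"] == 2]   # first-review rows only
    next_train_set = train_set[train_set["i"] > 2]   # all later reviews

    # After this patch: one dataset over the full train_set, so the LR
    # scheduler's T_max is sized from that single loader's batch count
    # (self.train_data_loader.batch_nums in the diff above).
    assert len(pre_train_set) + len(next_train_set) == len(train_set)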