Skip to content

Commit

Permalink
remove pre_train_set & next_train_set
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed May 10, 2024
1 parent 3f5b3b2 commit 2b0752f
Showing 1 changed file with 1 addition and 13 deletions.
14 changes: 1 addition & 13 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __init__(
self.max_seq_len = max_seq_len
self.build_dataset(train_set, test_set)
self.n_epoch = n_epoch
self.batch_nums = self.next_train_data_loader.batch_nums
self.batch_nums = self.train_data_loader.batch_nums
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer, T_max=self.batch_nums * n_epoch
)
Expand All @@ -304,18 +304,6 @@ def __init__(
self.loss_fn = nn.BCELoss(reduction="none")

def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]):
pre_train_set = train_set[train_set["i"] == 2]
self.pre_train_set = BatchDataset(
pre_train_set, batch_size=self.batch_size, max_seq_len=self.max_seq_len
)
self.pre_train_data_loader = BatchLoader(self.pre_train_set)

next_train_set = train_set[train_set["i"] > 2]
self.next_train_set = BatchDataset(
next_train_set, batch_size=self.batch_size, max_seq_len=self.max_seq_len
)
self.next_train_data_loader = BatchLoader(self.next_train_set)

self.train_set = BatchDataset(
train_set, batch_size=self.batch_size, max_seq_len=self.max_seq_len
)
Expand Down

0 comments on commit 2b0752f

Please sign in to comment.