From e1a5ba491678fb1725ed50a4bd8f962d899c51bd Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Fri, 10 Jan 2025 12:08:11 +0800
Subject: [PATCH] add gamma to other.py && fix some bugs

---
 other.py | 124 +++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 88 insertions(+), 36 deletions(-)

diff --git a/other.py b/other.py
index cf3c0931868..4a82e8e04d8 100644
--- a/other.py
+++ b/other.py
@@ -275,6 +275,7 @@ def loss(stability):
         self.w.data[0:4] = Tensor(
             list(map(lambda x: max(min(INIT_S_MAX, x), S_MIN), init_s0))
         )
+        self.init_w_tensor = self.w.data.clone()
 
 
 class FSRS1ParameterClipper:
@@ -962,12 +963,55 @@ class FSRS5(FSRS):
     ]
     clipper = FSRS5ParameterClipper()
     lr: float = 4e-2
+    gamma: float = 2
     wd: float = 1e-5
     n_epoch: int = 5
+    default_params_stddev_tensor = torch.tensor(
+        [
+            6.61,
+            9.52,
+            17.69,
+            27.74,
+            0.55,
+            0.28,
+            0.67,
+            0.12,
+            0.4,
+            0.18,
+            0.34,
+            0.27,
+            0.08,
+            0.14,
+            0.57,
+            0.25,
+            1.03,
+            0.27,
+            0.39,
+        ]
+    )
 
     def __init__(self, w: List[float] = init_w):
         super(FSRS5, self).__init__()
         self.w = nn.Parameter(torch.tensor(w, dtype=torch.float32))
+        self.init_w_tensor = self.w.data.clone()
+
+    def iter(
+        self,
+        sequences: Tensor,
+        delta_ts: Tensor,
+        seq_lens: Tensor,
+        real_batch_size: int,
+    ) -> dict[str, Tensor]:
+        output = super().iter(sequences, delta_ts, seq_lens, real_batch_size)
+        output["penalty"] = (
+            torch.sum(
+                torch.square(self.w - self.init_w_tensor)
+                / torch.square(self.default_params_stddev_tensor)
+            )
+            * real_batch_size
+            * self.gamma
+        )
+        return output
 
     def forgetting_curve(self, t, s):
         return (1 + FACTOR * t / s) ** DECAY
@@ -2070,23 +2114,24 @@ def __init__(
         self.loss_fn = nn.BCELoss(reduction="none")
 
     def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]):
-        pre_train_set = train_set[train_set["i"] == 2]
-        self.pre_train_set = BatchDataset(
-            pre_train_set.copy(),
-            self.batch_size,
-            max_seq_len=self.max_seq_len,
-            device=DEVICE,
-        )
-        self.pre_train_data_loader = BatchLoader(self.pre_train_set)
+        if isinstance(self.model, (FSRS4, FSRS4dot5)):
+            pre_train_set = train_set[train_set["i"] == 2]
+            self.pre_train_set = BatchDataset(
+                pre_train_set.copy(),
+                self.batch_size,
+                max_seq_len=self.max_seq_len,
+                device=DEVICE,
+            )
+            self.pre_train_data_loader = BatchLoader(self.pre_train_set)
 
-        next_train_set = train_set[train_set["i"] > 2]
-        self.next_train_set = BatchDataset(
-            next_train_set.copy(),
-            self.batch_size,
-            max_seq_len=self.max_seq_len,
-            device=DEVICE,
-        )
-        self.next_train_data_loader = BatchLoader(self.next_train_set)
+            next_train_set = train_set[train_set["i"] > 2]
+            self.next_train_set = BatchDataset(
+                next_train_set.copy(),
+                self.batch_size,
+                max_seq_len=self.max_seq_len,
+                device=DEVICE,
+            )
+            self.next_train_data_loader = BatchLoader(self.next_train_set)
 
         self.train_set = BatchDataset(
             train_set.copy(),
@@ -2112,6 +2157,7 @@ def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]
 
     def train(self):
         best_loss = np.inf
+        epoch_len = len(self.train_set.y_train)
         for k in range(self.n_epoch):
             weighted_loss, w = self.eval()
             if weighted_loss < best_loss:
@@ -2129,6 +2175,8 @@
                     self.loss_fn(result["retentions"], result["labels"])
                     * result["weights"]
                 ).sum()
+                if "penalty" in result:
+                    loss += result["penalty"] / epoch_len
                 loss.backward()
                 if isinstance(
                     self.model, (FSRS4, FSRS4dot5)
@@ -2156,6 +2204,7 @@ def eval(self):
                 continue
             loss = 0
             total = 0
+            epoch_len = len(data_loader.dataset.y_train)
             for batch in data_loader:
                 result = iter(self.model, batch)
                 loss += (
@@ -2167,6 +2216,8 @@
                     .detach()
                     .item()
                 )
+                if "penalty" in result:
+                    loss += (result["penalty"] / epoch_len).detach().item()
                 total += batch[3].shape[0]
             losses.append(loss / total)
         self.train_data_loader.shuffle = True
@@ -2475,39 +2526,39 @@ def process(user_id):
         )
         df_revlogs.drop(columns=["user_id"], inplace=True)
     if MODEL_NAME in ("RNN", "LSTM", "GRU"):
-        model = RNN
+        Model = RNN
     elif MODEL_NAME == "GRU-P":
-        model = GRU_P
+        Model = GRU_P
     elif MODEL_NAME == "FSRSv1":
-        model = FSRS1
+        Model = FSRS1
     elif MODEL_NAME == "FSRSv2":
-        model = FSRS2
+        Model = FSRS2
     elif MODEL_NAME == "FSRSv3":
-        model = FSRS3
+        Model = FSRS3
     elif MODEL_NAME == "FSRSv4":
-        model = FSRS4
+        Model = FSRS4
     elif MODEL_NAME == "FSRS-4.5":
-        model = FSRS4dot5
+        Model = FSRS4dot5
     elif MODEL_NAME == "FSRS-5":
         global SHORT_TERM
         SHORT_TERM = True
-        model = FSRS5
+        Model = FSRS5
     elif MODEL_NAME == "HLR":
-        model = HLR
+        Model = HLR
     elif MODEL_NAME == "Transformer":
-        model = Transformer
+        Model = Transformer
     elif MODEL_NAME == "ACT-R":
-        model = ACT_R
+        Model = ACT_R
    elif MODEL_NAME in ("DASH", "DASH[MCM]"):
-        model = DASH
+        Model = DASH
     elif MODEL_NAME == "DASH[ACT-R]":
-        model = DASH_ACTR
+        Model = DASH_ACTR
     elif MODEL_NAME == "NN-17":
-        model = NN_17
+        Model = NN_17
     elif MODEL_NAME == "SM2-trainable":
-        model = SM2
+        Model = SM2
     elif MODEL_NAME == "Anki":
-        model = Anki
+        Model = Anki
 
     dataset = create_features(df_revlogs, MODEL_NAME)
     if dataset.shape[0] < 6:
@@ -2544,14 +2595,15 @@ def process(user_id):
     for partition in train_set["partition"].unique():
         try:
             train_partition = train_set[train_set["partition"] == partition].copy()
+            model = Model()
             if RECENCY:
                 x = np.linspace(0, 1, len(train_partition))
                 train_partition["weights"] = 0.25 + 0.75 * np.power(x, 3)
             if DRY_RUN:
-                partition_weights[partition] = model().state_dict()
+                partition_weights[partition] = model.state_dict()
                 continue
             trainer = Trainer(
-                model(),
+                model,
                 train_partition,
                 None,
                 n_epoch=model.n_epoch,
@@ -2567,7 +2619,7 @@
             else:
                 tb = sys.exc_info()[2]
                 print("User:", user_id, "Error:", e.with_traceback(tb))
-            partition_weights[partition] = model().state_dict()
+            partition_weights[partition] = Model().state_dict()
     w_list.append(partition_weights)
 
     p = []
@@ -2578,7 +2630,7 @@ def process(user_id):
     for partition in testset["partition"].unique():
         partition_testset = testset[testset["partition"] == partition].copy()
         weights = w.get(partition, None)
-        my_collection = Collection(model(weights) if weights else model())
+        my_collection = Collection(Model(weights) if weights else Model())
         retentions, stabilities = my_collection.batch_predict(partition_testset)
         partition_testset["p"] = retentions
         if stabilities:
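
A note on what the new regularizer computes, with an illustrative sketch (the tensor values below are made up; only the arithmetic mirrors the patch). FSRS5.iter now emits a "penalty": an L2 prior on each weight's squared deviation from its snapshot at initialization (init_w_tensor), normalized by the per-parameter spread of typical parameters (default_params_stddev_tensor) and scaled by gamma and the batch size. Trainer.train and Trainer.eval then add penalty / epoch_len to the loss; since the batch sizes of one epoch sum to epoch_len, the prior appears to be applied with total weight gamma per epoch, independent of batch size.

import torch

# Illustrative stand-ins for the quantities in the patch (values are made up).
w = torch.tensor([0.5, 1.2, 3.0], requires_grad=True)  # current parameters
init_w = torch.tensor([0.4, 1.0, 3.5])                 # self.init_w_tensor snapshot
stddev = torch.tensor([0.55, 0.28, 0.67])              # default_params_stddev_tensor
gamma = 2.0
real_batch_size = 32
epoch_len = 10_000  # len(self.train_set.y_train) in the patch

# The "penalty" computed in FSRS5.iter: squared deviation from the initial
# weights, normalized by the prior variance, scaled by batch size and gamma.
penalty = (
    torch.sum(torch.square(w - init_w) / torch.square(stddev))
    * real_batch_size
    * gamma
)

# Trainer.train adds penalty / epoch_len onto the BCE loss. Summed over an
# epoch, the batch sizes add up to epoch_len, so the prior contributes with
# total weight gamma per epoch regardless of batch size.
loss = penalty / epoch_len
loss.backward()  # the prior's gradient flows into w like any other loss term

The remaining hunks fix a name-shadowing bug in process(): the class selected for MODEL_NAME is now bound to Model, and model names the instance created per partition, so model.state_dict() and n_epoch=model.n_epoch read from the instance while the exception fallback constructs a fresh Model().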