add gamma to other.py && fix some bugs
L-M-Sherlock committed Jan 10, 2025
1 parent 3e09651 commit e1a5ba4
Showing 1 changed file with 88 additions and 36 deletions.
124 changes: 88 additions & 36 deletions other.py
@@ -275,6 +275,7 @@ def loss(stability):
         self.w.data[0:4] = Tensor(
             list(map(lambda x: max(min(INIT_S_MAX, x), S_MIN), init_s0))
         )
+        self.init_w_tensor = self.w.data.clone()
 
 
 class FSRS1ParameterClipper:
@@ -962,12 +963,55 @@ class FSRS5(FSRS):
     ]
     clipper = FSRS5ParameterClipper()
     lr: float = 4e-2
+    gamma: float = 2
     wd: float = 1e-5
     n_epoch: int = 5
+    default_params_stddev_tensor = torch.tensor(
+        [
+            6.61,
+            9.52,
+            17.69,
+            27.74,
+            0.55,
+            0.28,
+            0.67,
+            0.12,
+            0.4,
+            0.18,
+            0.34,
+            0.27,
+            0.08,
+            0.14,
+            0.57,
+            0.25,
+            1.03,
+            0.27,
+            0.39,
+        ]
+    )
 
     def __init__(self, w: List[float] = init_w):
         super(FSRS5, self).__init__()
         self.w = nn.Parameter(torch.tensor(w, dtype=torch.float32))
+        self.init_w_tensor = self.w.data.clone()
+
+    def iter(
+        self,
+        sequences: Tensor,
+        delta_ts: Tensor,
+        seq_lens: Tensor,
+        real_batch_size: int,
+    ) -> dict[str, Tensor]:
+        output = super().iter(sequences, delta_ts, seq_lens, real_batch_size)
+        output["penalty"] = (
+            torch.sum(
+                torch.square(self.w - self.init_w_tensor)
+                / torch.square(self.default_params_stddev_tensor)
+            )
+            * real_batch_size
+            * self.gamma
+        )
+        return output
 
     def forgetting_curve(self, t, s):
         return (1 + FACTOR * t / s) ** DECAY
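
The `penalty` added in `iter` is a quadratic (L2) prior that pulls the trained weights back toward the defaults captured in `init_w_tensor`, with each parameter's deviation normalized by its stddev and the whole term scaled by `gamma` and the batch size. A minimal standalone sketch of the same computation, with made-up placeholder values in place of the class attributes above:

import torch

# Hypothetical stand-ins for self.w, self.init_w_tensor, and
# self.default_params_stddev_tensor in the diff above.
w = torch.tensor([0.5, 1.2, 3.0])       # current trainable parameters
init_w = torch.tensor([0.4, 1.0, 3.5])  # defaults captured in __init__
stddev = torch.tensor([0.1, 0.5, 1.0])  # per-parameter spread of the defaults
gamma = 2.0
real_batch_size = 256

# Squared deviation from the defaults, normalized per parameter so that
# loosely constrained parameters (large stddev) are penalized less.
penalty = torch.sum(torch.square(w - init_w) / torch.square(stddev))
penalty = penalty * real_batch_size * gamma
print(penalty.item())  # scalar regularization term added to the batch loss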
@@ -2070,23 +2114,24 @@ def __init__(
         self.loss_fn = nn.BCELoss(reduction="none")
 
     def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]):
-        pre_train_set = train_set[train_set["i"] == 2]
-        self.pre_train_set = BatchDataset(
-            pre_train_set.copy(),
-            self.batch_size,
-            max_seq_len=self.max_seq_len,
-            device=DEVICE,
-        )
-        self.pre_train_data_loader = BatchLoader(self.pre_train_set)
+        if isinstance(self.model, (FSRS4, FSRS4dot5)):
+            pre_train_set = train_set[train_set["i"] == 2]
+            self.pre_train_set = BatchDataset(
+                pre_train_set.copy(),
+                self.batch_size,
+                max_seq_len=self.max_seq_len,
+                device=DEVICE,
+            )
+            self.pre_train_data_loader = BatchLoader(self.pre_train_set)
 
-        next_train_set = train_set[train_set["i"] > 2]
-        self.next_train_set = BatchDataset(
-            next_train_set.copy(),
-            self.batch_size,
-            max_seq_len=self.max_seq_len,
-            device=DEVICE,
-        )
-        self.next_train_data_loader = BatchLoader(self.next_train_set)
+            next_train_set = train_set[train_set["i"] > 2]
+            self.next_train_set = BatchDataset(
+                next_train_set.copy(),
+                self.batch_size,
+                max_seq_len=self.max_seq_len,
+                device=DEVICE,
+            )
+            self.next_train_data_loader = BatchLoader(self.next_train_set)
 
         self.train_set = BatchDataset(
             train_set.copy(),
@@ -2112,6 +2157,7 @@ def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame
 
     def train(self):
         best_loss = np.inf
+        epoch_len = len(self.train_set.y_train)
         for k in range(self.n_epoch):
             weighted_loss, w = self.eval()
             if weighted_loss < best_loss:
@@ -2129,6 +2175,8 @@ def train(self):
                     self.loss_fn(result["retentions"], result["labels"])
                     * result["weights"]
                 ).sum()
+                if "penalty" in result:
+                    loss += result["penalty"] / epoch_len
                 loss.backward()
                 if isinstance(
                     self.model, (FSRS4, FSRS4dot5)
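
Dividing by `epoch_len` here complements the `real_batch_size` factor applied inside `iter`: over one epoch the batch sizes sum to `epoch_len`, so the prior contributes a total weight of `gamma` per epoch no matter how large the collection is. A quick sanity check of that bookkeeping, with hypothetical batch sizes:

# Over one epoch, the weights batch_size / epoch_len sum to 1, so the
# regularizer's total contribution is gamma * sum((w - w0)^2 / sd^2).
batch_sizes = [256, 256, 88]  # hypothetical batches in one epoch
epoch_len = sum(batch_sizes)
base = 1.7                    # stands in for sum((w - w0)^2 / sd^2)
gamma = 2.0

total = sum(gamma * base * b / epoch_len for b in batch_sizes)
assert abs(total - gamma * base) < 1e-9  # independent of dataset size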
@@ -2156,6 +2204,7 @@ def eval(self):
                 continue
             loss = 0
             total = 0
+            epoch_len = len(data_loader.dataset.y_train)
             for batch in data_loader:
                 result = iter(self.model, batch)
                 loss += (
@@ -2167,6 +2216,8 @@ def eval(self):
                     .detach()
                     .item()
                 )
+                if "penalty" in result:
+                    loss += (result["penalty"] / epoch_len).detach().item()
                 total += batch[3].shape[0]
             losses.append(loss / total)
         self.train_data_loader.shuffle = True
@@ -2475,39 +2526,39 @@ def process(user_id):
     )
     df_revlogs.drop(columns=["user_id"], inplace=True)
     if MODEL_NAME in ("RNN", "LSTM", "GRU"):
-        model = RNN
+        Model = RNN
     elif MODEL_NAME == "GRU-P":
-        model = GRU_P
+        Model = GRU_P
     elif MODEL_NAME == "FSRSv1":
-        model = FSRS1
+        Model = FSRS1
     elif MODEL_NAME == "FSRSv2":
-        model = FSRS2
+        Model = FSRS2
     elif MODEL_NAME == "FSRSv3":
-        model = FSRS3
+        Model = FSRS3
     elif MODEL_NAME == "FSRSv4":
-        model = FSRS4
+        Model = FSRS4
     elif MODEL_NAME == "FSRS-4.5":
-        model = FSRS4dot5
+        Model = FSRS4dot5
     elif MODEL_NAME == "FSRS-5":
         global SHORT_TERM
         SHORT_TERM = True
-        model = FSRS5
+        Model = FSRS5
     elif MODEL_NAME == "HLR":
-        model = HLR
+        Model = HLR
     elif MODEL_NAME == "Transformer":
-        model = Transformer
+        Model = Transformer
     elif MODEL_NAME == "ACT-R":
-        model = ACT_R
+        Model = ACT_R
     elif MODEL_NAME in ("DASH", "DASH[MCM]"):
-        model = DASH
+        Model = DASH
     elif MODEL_NAME == "DASH[ACT-R]":
-        model = DASH_ACTR
+        Model = DASH_ACTR
     elif MODEL_NAME == "NN-17":
-        model = NN_17
+        Model = NN_17
     elif MODEL_NAME == "SM2-trainable":
-        model = SM2
+        Model = SM2
     elif MODEL_NAME == "Anki":
-        model = Anki
+        Model = Anki
 
     dataset = create_features(df_revlogs, MODEL_NAME)
     if dataset.shape[0] < 6:
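
The rename from `model` to `Model` for the class binding (with a single `model = Model()` instance created per partition further down) removes a class/instance name collision: previously `model` named the class and `model()` built a fresh instance at every use, which makes it easy to train one instance and read the weights of another. A stripped-down illustration of that hazard with a hypothetical class, not code from the repo:

class Net:
    def __init__(self):
        self.w = 0.0  # fresh, untrained weight

def train(net):
    net.w = 42.0  # pretend training happened

# Hazard: class and instances share the name "model".
model = Net
train(model())              # trains a throwaway instance
weights_lost = model().w    # 0.0 -- a brand-new, untrained instance

# Fix: bind the class to Model and keep one instance around.
Model = Net
model = Model()
train(model)
weights_kept = model.w      # 42.0 -- the instance that was trained
print(weights_lost, weights_kept)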
@@ -2544,14 +2595,15 @@ def process(user_id):
     for partition in train_set["partition"].unique():
         try:
             train_partition = train_set[train_set["partition"] == partition].copy()
+            model = Model()
             if RECENCY:
                 x = np.linspace(0, 1, len(train_partition))
                 train_partition["weights"] = 0.25 + 0.75 * np.power(x, 3)
             if DRY_RUN:
-                partition_weights[partition] = model().state_dict()
+                partition_weights[partition] = model.state_dict()
                 continue
             trainer = Trainer(
-                model(),
+                model,
                 train_partition,
                 None,
                 n_epoch=model.n_epoch,
@@ -2567,7 +2619,7 @@ def process(user_id):
            else:
                tb = sys.exc_info()[2]
                print("User:", user_id, "Error:", e.with_traceback(tb))
-               partition_weights[partition] = model().state_dict()
+               partition_weights[partition] = Model().state_dict()
        w_list.append(partition_weights)
 
    p = []
@@ -2578,7 +2630,7 @@ def process(user_id):
     for partition in testset["partition"].unique():
         partition_testset = testset[testset["partition"] == partition].copy()
         weights = w.get(partition, None)
-        my_collection = Collection(model(weights) if weights else model())
+        my_collection = Collection(Model(weights) if weights else Model())
         retentions, stabilities = my_collection.batch_predict(partition_testset)
         partition_testset["p"] = retentions
         if stabilities:
