From 37ece5e87d6c65f02a2ab11533c0f39015f9adca Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Sat, 20 Apr 2024 10:47:47 +0800 Subject: [PATCH] Feat/not sqrt in weights for pretrain (#106) --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 77 +++++++++++++--------------- 2 files changed, 36 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7954678..5c655cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "4.27.7" +version = "4.28.0" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 4faf65a..175a949 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -48,23 +48,23 @@ Relearning = 3 DEFAULT_WEIGHT = [ - 0.5701, - 1.4436, - 4.1386, - 10.9355, - 5.1443, - 1.2006, - 0.8627, - 0.0362, - 1.629, - 0.1342, - 1.0166, - 2.1174, - 0.0839, - 0.3204, - 1.4676, - 0.219, - 2.8237, + 0.4872, + 1.4003, + 3.7145, + 13.8206, + 5.1618, + 1.2298, + 0.8975, + 0.031, + 1.6474, + 0.1367, + 1.0461, + 2.1072, + 0.0793, + 0.3246, + 1.587, + 0.2272, + 2.8755, ] S_MIN = 0.01 @@ -443,6 +443,18 @@ def remove_non_continuous_rows(group): return group.loc[: first_non_continuous_index - 1] +def fit_stability(delta_t, retention, size): + def loss(stability): + y_pred = power_forgetting_curve(delta_t, stability) + loss = sum( + -(retention * np.log(y_pred) + (1 - retention) * np.log(1 - y_pred)) * size + ) + return loss + + res = minimize(loss, x0=1, bounds=[(S_MIN, 36500)]) + return res.x[0] + + class Optimizer: def __init__(self) -> None: tqdm.pandas() @@ -707,12 +719,9 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame: group["group_cnt"] = group_cnt if group["i"].values[0] > 1: group["stability"] = round( - curve_fit( - power_forgetting_curve, - group["delta_t"], - group["retention"], - sigma=1 / np.sqrt(group["total_cnt"]), - )[0][0], + fit_stability( + group["delta_t"], group["retention"], group["total_cnt"] + ), 1, ) else: @@ -820,7 +829,6 @@ def pretrain(self, dataset=None, verbose=True): group["y"]["count"] + 1 ) count = group["y"]["count"] - weight = np.sqrt(count) init_s0 = r_s0_default[first_rating] @@ -828,7 +836,7 @@ def loss(stability): y_pred = power_forgetting_curve(delta_t, stability) logloss = sum( -(recall * np.log(y_pred) + (1 - recall) * np.log(1 - y_pred)) - * weight + * count ) l1 = np.abs(stability - init_s0) / 16 return logloss + l1 @@ -837,7 +845,7 @@ def loss(stability): loss, x0=init_s0, bounds=((S_MIN, 100),), - options={"maxiter": int(sum(weight))}, + options={"maxiter": int(sum(count))}, ) params = res.x stability = params[0] @@ -1494,22 +1502,7 @@ def cal_stability(tmp): count = tmp["y_count"] total_count = sum(count) - def loss(stability): - y_pred = power_forgetting_curve(delta_t, stability) - logloss = sum( - -( - recall * np.log(y_pred) - + (1 - recall) * np.log(1 - y_pred) - ) - * np.sqrt(count) - ) - return logloss - - res = minimize(loss, 1, bounds=((S_MIN, 3650),)) - if res.success: - tmp["true_s"] = res.x[0] - else: - tmp["true_s"] = np.nan + tmp["true_s"] = fit_stability(delta_t, recall, count) tmp["predicted_s"] = np.average( tmp["stability_mean"], weights=count )