From 02842aa8200a22bf6436d2538c4cb9fd08be1c00 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 21 Mar 2024 16:41:47 +0800 Subject: [PATCH] Fix/PerformanceWarning (#94) --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 32 +++++++++++++--------------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1f24e96..43c1b5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "4.26.5" +version = "4.26.6" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 31bf922..b7d4f2d 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -856,9 +856,7 @@ def loss(stability): rating_stability[int(first_rating)] = stability rating_count[int(first_rating)] = sum(count) predict_recall = power_forgetting_curve(delta_t, *params) - rmse = root_mean_squared_error( - recall, predict_recall, sample_weight=count - ) + rmse = root_mean_squared_error(recall, predict_recall, sample_weight=count) if verbose: fig = plt.figure() @@ -1506,7 +1504,8 @@ def formula_analysis(self): analysis_df[analysis_df["r_history"].str.endswith(last_rating)] .groupby( by=["last_s_bin", "last_d_bin", "last_r_bin", "delta_t"], - group_keys=False, + group_keys=True, + as_index=False, ) .agg( { @@ -1517,12 +1516,15 @@ def formula_analysis(self): } ) ) - analysis_group.reset_index(inplace=True) + analysis_group.columns = [ + "_".join(col_name).rstrip("_") + for col_name in analysis_group.columns + ] def cal_stability(tmp): delta_t = tmp["delta_t"] - recall = tmp["y"]["mean"] - count = tmp["y"]["count"] + recall = tmp["y_mean"] + count = tmp["y_count"] total_count = sum(count) def loss(stability): @@ -1542,18 +1544,16 @@ def loss(stability): else: tmp["true_s"] = np.nan tmp["predicted_s"] = np.average( - tmp["stability"]["mean"], weights=count + tmp["stability_mean"], weights=count ) tmp["total_count"] = total_count return tmp - analysis_group = ( - analysis_group.groupby(by=[group_key], group_keys=False) - .apply(cal_stability) - .reset_index(drop=True) - ) + analysis_group = analysis_group.groupby( + by=[group_key], group_keys=False + ).apply(cal_stability) analysis_group.dropna(inplace=True) - analysis_group.drop_duplicates(subset=[(group_key, "")], inplace=True) + analysis_group.drop_duplicates(subset=[group_key], inplace=True) analysis_group.sort_values(by=[group_key], inplace=True) rmse = root_mean_squared_error( analysis_group["true_s"], @@ -1877,9 +1877,7 @@ def rmse_matrix(df): .agg({"y": "mean", "p": "mean", "card_id": "count"}) .reset_index() ) - return root_mean_squared_error( - tmp["y"], tmp["p"], sample_weight=tmp["card_id"] - ) + return root_mean_squared_error(tmp["y"], tmp["p"], sample_weight=tmp["card_id"]) if __name__ == "__main__":