Skip to content

Commit

Permalink
Fix/PerformanceWarning (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Mar 21, 2024
1 parent 08eb5d1 commit 02842aa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.26.5"
version = "4.26.6"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
32 changes: 15 additions & 17 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,9 +856,7 @@ def loss(stability):
rating_stability[int(first_rating)] = stability
rating_count[int(first_rating)] = sum(count)
predict_recall = power_forgetting_curve(delta_t, *params)
rmse = root_mean_squared_error(
recall, predict_recall, sample_weight=count
)
rmse = root_mean_squared_error(recall, predict_recall, sample_weight=count)

if verbose:
fig = plt.figure()
Expand Down Expand Up @@ -1506,7 +1504,8 @@ def formula_analysis(self):
analysis_df[analysis_df["r_history"].str.endswith(last_rating)]
.groupby(
by=["last_s_bin", "last_d_bin", "last_r_bin", "delta_t"],
group_keys=False,
group_keys=True,
as_index=False,
)
.agg(
{
Expand All @@ -1517,12 +1516,15 @@ def formula_analysis(self):
}
)
)
analysis_group.reset_index(inplace=True)
analysis_group.columns = [
"_".join(col_name).rstrip("_")
for col_name in analysis_group.columns
]

def cal_stability(tmp):
delta_t = tmp["delta_t"]
recall = tmp["y"]["mean"]
count = tmp["y"]["count"]
recall = tmp["y_mean"]
count = tmp["y_count"]
total_count = sum(count)

def loss(stability):
Expand All @@ -1542,18 +1544,16 @@ def loss(stability):
else:
tmp["true_s"] = np.nan
tmp["predicted_s"] = np.average(
tmp["stability"]["mean"], weights=count
tmp["stability_mean"], weights=count
)
tmp["total_count"] = total_count
return tmp

analysis_group = (
analysis_group.groupby(by=[group_key], group_keys=False)
.apply(cal_stability)
.reset_index(drop=True)
)
analysis_group = analysis_group.groupby(
by=[group_key], group_keys=False
).apply(cal_stability)
analysis_group.dropna(inplace=True)
analysis_group.drop_duplicates(subset=[(group_key, "")], inplace=True)
analysis_group.drop_duplicates(subset=[group_key], inplace=True)
analysis_group.sort_values(by=[group_key], inplace=True)
rmse = root_mean_squared_error(
analysis_group["true_s"],
Expand Down Expand Up @@ -1877,9 +1877,7 @@ def rmse_matrix(df):
.agg({"y": "mean", "p": "mean", "card_id": "count"})
.reset_index()
)
return root_mean_squared_error(
tmp["y"], tmp["p"], sample_weight=tmp["card_id"]
)
return root_mean_squared_error(tmp["y"], tmp["p"], sample_weight=tmp["card_id"])


if __name__ == "__main__":
Expand Down

0 comments on commit 02842aa

Please sign in to comment.