diff --git a/pyproject.toml b/pyproject.toml index 3833775..911356b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.0.7" +version = "5.0.8" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_simulator.py b/src/fsrs_optimizer/fsrs_simulator.py index 5eb845d..611e58d 100644 --- a/src/fsrs_optimizer/fsrs_simulator.py +++ b/src/fsrs_optimizer/fsrs_simulator.py @@ -263,12 +263,19 @@ def sample( workload_only=False, ): results = [] - if learn_span < 100: - SAMPLE_SIZE = 16 - elif learn_span < 365: - SAMPLE_SIZE = 8 - else: - SAMPLE_SIZE = 4 + + def best_sample_size(days_to_simulate): + if days_to_simulate <= 30: + return 45 + elif days_to_simulate >= 365: + return 4 + else: + a1, a2, a3 = 8.20e-07, 2.41e-03, 1.30e-02 + factor = a1 * np.power(days_to_simulate, 2) + a2 * days_to_simulate + a3 + default_sample_size = 4 + return int(default_sample_size/factor) + + SAMPLE_SIZE = best_sample_size(learn_span) for i in range(SAMPLE_SIZE): _, _, _, memorized_cnt_per_day, cost_per_day = simulate( @@ -422,6 +429,7 @@ def workload_graph(default_params, sampling_size=30): default_params["deck_size"] / default_params["learn_span"] ) default_params["review_limit_perday"] = math.inf + default_params["loss_aversion"] = 1 workload = [sample(r=r, workload_only=True, **default_params) for r in R] # this is for testing @@ -513,10 +521,14 @@ def workload_graph(default_params, sampling_size=30): ax.xaxis.set_tick_params(labelsize=14) ax.set_xlim(0.7, 0.99) - if max_w >= 4.5 * min_w: - lim = 4.5 * min_w - elif max_w >= 3.5 * min_w: + if max_w >= 3.5 * min_w: lim = 3.5 * min_w + elif max_w >= 3 * min_w: + lim = 3 * min_w + elif max_w >= 2.5 * min_w: + lim = 2.5 * min_w + elif max_w >= 2 * min_w: + lim = 2 * min_w else: lim = 1.1 * max_w @@ -527,13 +539,13 @@ def workload_graph(default_params, sampling_size=30): ax.text( 0.701, min_w, - "min. workload", + "minimum workload", ha="left", va="bottom", color="black", fontsize=12, ) - if max_w >= 1.8 * min_w: + if lim >= 1.8 * min_w: ax.axhline(y=1.5 * min_w, color="black", alpha=0.75, ls="--") ax.text( 0.701, @@ -544,7 +556,7 @@ def workload_graph(default_params, sampling_size=30): color="black", fontsize=12, ) - if max_w >= 2.3 * min_w: + if lim >= 2.3 * min_w: ax.axhline(y=2 * min_w, color="black", alpha=0.75, ls="--") ax.text( 0.701, @@ -555,7 +567,7 @@ def workload_graph(default_params, sampling_size=30): color="black", fontsize=12, ) - if max_w >= 2.8 * min_w: + if lim >= 2.8 * min_w: ax.axhline(y=2.5 * min_w, color="black", alpha=0.75, ls="--") ax.text( 0.701, @@ -566,7 +578,7 @@ def workload_graph(default_params, sampling_size=30): color="black", fontsize=12, ) - if max_w >= 3.3 * min_w: + if lim >= 3.3 * min_w: ax.axhline(y=3 * min_w, color="black", alpha=0.75, ls="--") ax.text( 0.701,