diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 24151aa..a254d08 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -166,9 +166,10 @@ def __call__(self, module): if hasattr(module, "w"): w = module.w.data w[0] = w[0].clamp(S_MIN, 100) - w[1] = w[1].clamp(w[0] * 1.05, 100) - w[2] = w[2].clamp(w[1] * 1.05, 100) - w[3] = w[3].clamp(w[2] * 1.05, 100) + # this ensures that w[n] is at least 5% greater than w[n-1], and also greater by at least 1 hour + w[1] = w[1].clamp(max(w[0] * 1.05, w[0] + 0.05), 100) + w[2] = w[2].clamp(max(w[1] * 1.05, w[1] + 0.05), 100) + w[3] = w[3].clamp(max(w[2] * 1.05, w[2] + 0.05), 100) w[4] = w[4].clamp(1, 10) w[5] = w[5].clamp(0.01, 4) w[6] = w[6].clamp(0.01, 4)