Skip to content

Commit

Permalink
Feat/add fuzz option to simulator (#160)
Browse files Browse the repository at this point in the history
* Feat/add fuzz option to simulator

* bump version
  • Loading branch information
L-M-Sherlock authored Jan 10, 2025
1 parent a2bbc6b commit aba321d
Showing 1 changed file with 49 additions and 4 deletions.
53 changes: 49 additions & 4 deletions src/fsrs_optimizer/fsrs_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,54 @@ def power_forgetting_curve(t, s):
return (1 + FACTOR * t / s) ** DECAY


def next_interval(s, r, float_ivl: bool = False):
def next_interval(s, r, float_ivl: bool = False, fuzz: bool = False):
ivl = s / FACTOR * (r ** (1 / DECAY) - 1)
return (
np.maximum(1, np.round(ivl).astype(int)) if not float_ivl else np.round(ivl, 6)
)
if float_ivl:
ivl = np.round(ivl, 6)
else:
ivl = np.maximum(1, np.round(ivl).astype(int))
if fuzz:
fuzz_mask = ivl >= 3
ivl[fuzz_mask] = fuzz_interval(ivl[fuzz_mask])
return ivl


FUZZ_RANGES = [
{
"start": 2.5,
"end": 7.0,
"factor": 0.15,
},
{
"start": 7.0,
"end": 20.0,
"factor": 0.1,
},
{
"start": 20.0,
"end": math.inf,
"factor": 0.05,
},
]


def get_fuzz_range(interval):
delta = np.ones_like(interval, dtype=float)
for range in FUZZ_RANGES:
delta += range["factor"] * np.maximum(
np.minimum(interval, range["end"]) - range["start"], 0.0
)
min_ivl = np.round(interval - delta).astype(int)
max_ivl = np.round(interval + delta).astype(int)
min_ivl = np.maximum(2, min_ivl)
min_ivl = np.minimum(min_ivl, max_ivl)
return min_ivl, max_ivl


def fuzz_interval(interval):
min_ivl, max_ivl = get_fuzz_range(interval)
# max_ivl + 1 because randint upper bound is exclusive
return np.random.randint(min_ivl, max_ivl + 1, size=min_ivl.shape)


columns = [
Expand Down Expand Up @@ -64,6 +107,7 @@ def simulate(
forget_session_len=DEFAULT_FORGET_SESSION_LEN,
loss_aversion=2.5,
seed=42,
fuzz=False,
):
np.random.seed(seed)
card_table = np.zeros((len(columns), deck_size))
Expand Down Expand Up @@ -224,6 +268,7 @@ def mean_reversion(init, current):
next_interval(
card_table[col["stability"]][true_review | true_learn],
request_retention,
fuzz=fuzz,
),
1,
max_ivl,
Expand Down

0 comments on commit aba321d

Please sign in to comment.