Skip to content

Commit

Permalink
improve preview in float delta_t mode
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Oct 9, 2024
1 parent 76733c4 commit 07a506f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
56 changes: 40 additions & 16 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,10 +1214,10 @@ def train(
tqdm.write("\nTraining finished!")
return plots

def preview(self, requestRetention: float, verbose=False):
def preview(self, requestRetention: float, verbose=False, n_steps=3):
my_collection = Collection(self.w, self.float_delta_t)
preview_text = "1:again, 2:hard, 3:good, 4:easy\n"
n_learning_steps = 3
n_learning_steps = n_steps if not self.float_delta_t else 0
for first_rating in (1, 2, 3, 4):
preview_text += f"\nfirst rating: {first_rating}\n"
t_history = "0"
Expand All @@ -1242,8 +1242,14 @@ def preview(self, requestRetention: float, verbose=False):
)
)
left -= 1
next_t = next_interval(states[0], requestRetention) if left <= 0 else 0
t_history += f",{int(next_t)}"
next_t = (
next_interval(
states[0].detach().numpy(), requestRetention, self.float_delta_t
)
if left <= 0
else 0
)
t_history += f",{next_t}"
d_history += f",{difficulty}"
s_history += f",{stability}"
r_history += f",3"
Expand All @@ -1254,13 +1260,22 @@ def preview(self, requestRetention: float, verbose=False):
+ ",".join(
[
(
f"{ivl}d"
if ivl < 30
f"{ivl:.4f}d"
if ivl < 1 and ivl > 0
else (
f"{ivl / 30:.1f}m" if ivl < 365 else f"{ivl / 365:.1f}y"
f"{ivl:.1f}d"
if ivl < 30
else (
f"{ivl / 30:.1f}m"
if ivl < 365
else f"{ivl / 365:.1f}y"
)
)
)
for ivl in map(int, t_history.split(","))
for ivl in map(
int if not self.float_delta_t else float,
t_history.split(","),
)
]
)
+ "\n"
Expand Down Expand Up @@ -1295,8 +1310,10 @@ def preview_sequence(self, test_rating_sequence: str, requestRetention: float):
for i in range(len(test_rating_sequence.split(","))):
r_history = test_rating_sequence[: 2 * i + 1]
states = my_collection.predict(t_history, r_history)
next_t = next_interval(states[0], requestRetention)
t_history += f",{int(next_t)}"
next_t = next_interval(
states[0].detach().numpy(), requestRetention, self.float_delta_t
)
t_history += f",{next_t}"
difficulty = round(float(states[1]), 1)
d_history += f",{difficulty}"
preview_text = f"rating history: {test_rating_sequence}\n"
Expand All @@ -1305,13 +1322,20 @@ def preview_sequence(self, test_rating_sequence: str, requestRetention: float):
+ ",".join(
[
(
f"{ivl}d"
if ivl < 30
else f"{ivl / 30:.1f}m"
if ivl < 365
else f"{ivl / 365:.1f}y"
f"{ivl:.4f}d"
if ivl < 1 and ivl > 0
else (
f"{ivl:.1f}d"
if ivl < 30
else (
f"{ivl / 30:.1f}m" if ivl < 365 else f"{ivl / 365:.1f}y"
)
)
)
for ivl in map(
int if not self.float_delta_t else float,
t_history.split(","),
)
for ivl in map(int, t_history.split(","))
]
)
+ "\n"
Expand Down
6 changes: 4 additions & 2 deletions src/fsrs_optimizer/fsrs_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ def power_forgetting_curve(t, s):
return (1 + FACTOR * t / s) ** DECAY


def next_interval(s, r):
def next_interval(s, r, float_ivl: bool = False):
ivl = s / FACTOR * (r ** (1 / DECAY) - 1)
return np.maximum(1, np.round(ivl))
return (
np.maximum(1, np.round(ivl).astype(int)) if not float_ivl else np.round(ivl, 6)
)


columns = [
Expand Down

0 comments on commit 07a506f

Please sign in to comment.