Skip to content

Commit

Permalink
Fix/consider short-term params when clipping PLS
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Dec 9, 2024
1 parent 7477d2b commit a924412
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ mod tests {
let fsrs = FSRS::new(Some(&[]))?;
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.216326, 0.038727]);
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.216286, 0.038692]);

let fsrs = FSRS::new(Some(PARAMETERS))?;
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
Expand All @@ -510,7 +510,7 @@ mod tests {
.universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true)
.unwrap();

assert_approx_eq([self_by_other, other_by_self], [0.016236, 0.031085]);
assert_approx_eq([self_by_other, other_by_self], [0.016570, 0.031037]);

Ok(())
}
Expand Down
8 changes: 6 additions & 2 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ impl<B: Backend> Model<B> {
* last_d.pow(-self.w.get(12))
* ((last_s.clone() + 1).pow(self.w.get(13)) - 1)
* ((-r + 1) * self.w.get(14)).exp();
let new_s_min = last_s / (self.w.get(17) * self.w.get(18)).exp();
new_s
.clone()
.mask_where(last_s.clone().lower(new_s), last_s)
.mask_where(new_s_min.clone().lower(new_s), new_s_min)
}

fn stability_short_term(&self, last_s: Tensor<B, 1>, rating: Tensor<B, 1>) -> Tensor<B, 1> {
Expand Down Expand Up @@ -380,7 +381,10 @@ mod tests {
&device,
);
let state = model.forward(delta_ts, ratings, None);
dbg!(&state);
let stability = state.stability.to_data();
let difficulty = state.difficulty.to_data();
stability.assert_approx_eq(&Data::from([0.2619, 1.7074, 5.8691, 25.0124, 0.2859, 2.1482]), 4);
difficulty.assert_approx_eq(&Data::from([8.0827, 7.0405, 5.2729, 2.1301, 8.0827, 7.0405]), 4);
}

#[test]
Expand Down
8 changes: 4 additions & 4 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ fn stability_after_success(w: &[f32], s: f32, r: f32, d: f32, rating: usize) ->

fn stability_after_failure(w: &[f32], s: f32, r: f32, d: f32) -> f32 {
(w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f32::exp((1.0 - r) * w[14]))
.clamp(S_MIN, s)
.clamp(S_MIN, s / (w[17] * w[18]).exp())
}

fn stability_short_term(w: &[f32], s: f32, rating_offset: f32, session_len: f32) -> f32 {
Expand Down Expand Up @@ -903,7 +903,7 @@ mod tests {
simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?;
assert_eq!(
memorized_cnt_per_day[memorized_cnt_per_day.len() - 1],
6919.944
6911.91
);
Ok(())
}
Expand Down Expand Up @@ -1023,7 +1023,7 @@ mod tests {
..Default::default()
};
let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?;
assert_eq!(results.0[results.0.len() - 1], 6591.4854);
assert_eq!(results.0[results.0.len() - 1], 6559.517);
Ok(())
}

Expand Down Expand Up @@ -1076,7 +1076,7 @@ mod tests {
..Default::default()
};
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.84499365);
assert_eq!(optimal_retention, 0.84458643);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down

0 comments on commit a924412

Please sign in to comment.