diff --git a/src/inference.rs b/src/inference.rs index 96b68b0b..40a93d1c 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -17,7 +17,7 @@ use burn::tensor::ElementConversion; pub type Weights = [f32]; pub static DEFAULT_WEIGHTS: [f32; 17] = [ - 0.4, 0.6, 2.4, 5.8, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29, + 0.4, 0.9, 2.3, 10.9, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29, 2.61, ]; @@ -387,7 +387,7 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.20944944, 0.042762663]), 5); + .assert_approx_eq(&Data::from([0.20753297, 0.041122540]), 5); let fsrs = FSRS::new(Some(WEIGHTS))?; let metrics = fsrs.evaluate(items, |_| true).unwrap(); diff --git a/src/model.rs b/src/model.rs index 2d0b1fc5..f94eafa1 100644 --- a/src/model.rs +++ b/src/model.rs @@ -282,7 +282,7 @@ mod tests { let stability = model.init_stability(rating); assert_eq!( stability.to_data(), - Data::from([0.4, 0.6, 2.4, 5.8, 0.4, 0.6]) + Data::from([0.4, 0.9, 2.3, 10.9, 0.4, 0.9]) ) } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index abbc2e24..59e4dea6 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -436,7 +436,7 @@ mod tests { 0.9, None, ); - assert_eq!(memorization, 2542.50223082592) + assert_eq!(memorization, 2635.689850107157) } #[test] @@ -444,7 +444,7 @@ mod tests { let config = SimulatorConfig::default(); let fsrs = FSRS::new(None)?; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.8687319006249048); + assert_eq!(optimal_retention, 0.8633668648071942); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) } diff --git a/src/pre_training.rs b/src/pre_training.rs index 4007dbec..fea6e72b 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -4,7 +4,7 @@ use itertools::Itertools; use ndarray::Array1; use std::collections::HashMap; -static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[(1, 0.4), (2, 0.6), (3, 2.4), (4, 5.8)]; +static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[(1, 0.4), (2, 0.9), (3, 2.3), (4, 10.9)]; pub fn pretrain(fsrs_items: Vec, average_recall: f32) -> Result<[f32; 4]> { let pretrainset = create_pretrain_data(fsrs_items); @@ -340,7 +340,7 @@ mod tests { ], )]); let actual = search_parameters(pretrainset, 0.9); - let expected = [(4, 1.2502396)].into_iter().collect(); + let expected = [(4, 1.4098487)].into_iter().collect(); assert_eq!(actual, expected); } @@ -352,15 +352,15 @@ mod tests { let pretrainset = split_data(items, 1).0; assert_eq!( pretrain(pretrainset, average_recall).unwrap(), - [0.948_268_3, 1.696_434_9, 4.059_253_7, 9.001_528,], + [0.948_268_3, 1.695_154, 4.051_595_7, 9.332_188,], ) } #[test] fn test_smooth_and_fill() { - let mut rating_stability = HashMap::from([(1, 0.4), (3, 2.4), (4, 5.8)]); + let mut rating_stability = HashMap::from([(1, 0.4), (3, 2.3), (4, 10.9)]); let rating_count = HashMap::from([(1, 1), (2, 1), (3, 1), (4, 1)]); let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap(); - assert_eq!(actual, [0.4, 0.81906897, 2.4, 5.8,]); + assert_eq!(actual, [0.4, 0.8052433, 2.3, 10.9,]); } }