Skip to content

Commit

Permalink
Feat/update initial values of S0
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Oct 14, 2023
1 parent 0e71824 commit c06f616
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use burn::tensor::ElementConversion;
pub type Weights = [f32];

pub static DEFAULT_WEIGHTS: [f32; 17] = [
0.4, 0.6, 2.4, 5.8, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29,
0.4, 0.9, 2.3, 10.9, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29,
2.61,
];

Expand Down Expand Up @@ -387,7 +387,7 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.20944944, 0.042762663]), 5);
.assert_approx_eq(&Data::from([0.20753297, 0.041122540]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ mod tests {
let stability = model.init_stability(rating);
assert_eq!(
stability.to_data(),
Data::from([0.4, 0.6, 2.4, 5.8, 0.4, 0.6])
Data::from([0.4, 0.9, 2.3, 10.9, 0.4, 0.9])
)
}

Expand Down
4 changes: 2 additions & 2 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,15 +436,15 @@ mod tests {
0.9,
None,
);
assert_eq!(memorization, 2542.50223082592)
assert_eq!(memorization, 2635.689850107157)
}

#[test]
fn optimal_retention() -> Result<()> {
let config = SimulatorConfig::default();
let fsrs = FSRS::new(None)?;
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.8687319006249048);
assert_eq!(optimal_retention, 0.8633668648071942);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down
10 changes: 5 additions & 5 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use itertools::Itertools;
use ndarray::Array1;
use std::collections::HashMap;

static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[(1, 0.4), (2, 0.6), (3, 2.4), (4, 5.8)];
static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[(1, 0.4), (2, 0.9), (3, 2.3), (4, 10.9)];

pub fn pretrain(fsrs_items: Vec<FSRSItem>, average_recall: f32) -> Result<[f32; 4]> {
let pretrainset = create_pretrain_data(fsrs_items);
Expand Down Expand Up @@ -340,7 +340,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
let expected = [(4, 1.2502396)].into_iter().collect();
let expected = [(4, 1.4098487)].into_iter().collect();
assert_eq!(actual, expected);
}

Expand All @@ -352,15 +352,15 @@ mod tests {
let pretrainset = split_data(items, 1).0;
assert_eq!(
pretrain(pretrainset, average_recall).unwrap(),
[0.948_268_3, 1.696_434_9, 4.059_253_7, 9.001_528,],
[0.948_268_3, 1.695_154, 4.051_595_7, 9.332_188,],
)
}

#[test]
fn test_smooth_and_fill() {
let mut rating_stability = HashMap::from([(1, 0.4), (3, 2.4), (4, 5.8)]);
let mut rating_stability = HashMap::from([(1, 0.4), (3, 2.3), (4, 10.9)]);
let rating_count = HashMap::from([(1, 1), (2, 1), (3, 1), (4, 1)]);
let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap();
assert_eq!(actual, [0.4, 0.81906897, 2.4, 5.8,]);
assert_eq!(actual, [0.4, 0.8052433, 2.3, 10.9,]);
}
}

0 comments on commit c06f616

Please sign in to comment.