diff --git a/fsrs.go b/fsrs.go index dc8f3d9..4cbbcb8 100644 --- a/fsrs.go +++ b/fsrs.go @@ -34,10 +34,10 @@ func (p *Parameters) Repeat(card Card, now time.Time) map[Rating]SchedulingInfo s.schedule(now, hardInterval, goodInterval, easyInterval) case Review: - interval := float64(card.ElapsedDays) + elapsedDays := float64(card.ElapsedDays) lastD := card.Difficulty lastS := card.Stability - retrievability := math.Pow(1+interval/(9*lastS), -1) + retrievability := p.forgettingCurve(elapsedDays, lastS) p.nextDS(s, lastD, lastS, retrievability) hardInterval := p.nextInterval(s.Hard.Stability) @@ -121,6 +121,10 @@ func (s *schedulingCards) recordLog(card Card, now time.Time) map[Rating]Schedul return m } +func (p *Parameters) forgettingCurve(elapsedDays float64, stability float64) float64 { + return math.Pow(1+p.Factor*elapsedDays/stability, p.Decay) +} + func (p *Parameters) initDS(s *schedulingCards) { s.Again.Difficulty = p.initDifficulty(Again) s.Again.Stability = p.initStability(Again) @@ -155,7 +159,7 @@ func constrainDifficulty(d float64) float64 { } func (p *Parameters) nextInterval(s float64) float64 { - newInterval := s * 9 * (1/p.RequestRetention - 1) + newInterval := s / p.Factor * (math.Pow(p.RequestRetention, 1/p.Decay) - 1) return math.Max(math.Min(math.Round(newInterval), p.MaximumInterval), 1) } diff --git a/fsrs_test.go b/fsrs_test.go index f3358d2..57a1d33 100644 --- a/fsrs_test.go +++ b/fsrs_test.go @@ -3,14 +3,21 @@ package fsrs import ( "encoding/json" "fmt" + "math" "reflect" "testing" "time" ) +func roundFloat(val float64, precision uint) float64 { + ratio := math.Pow(10, float64(precision)) + return math.Round(val*ratio) / ratio +} + func TestExample(t *testing.T) { p := DefaultParam() - p.W = Weights{1.14, 1.01, 5.44, 14.67, 5.3024, 1.5662, 1.2503, 0.0028, 1.5489, 0.1763, 0.9953, 2.7473, 0.0179, 0.3105, 0.3976, 0.0, 2.0902} + p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321, + 2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244} card := NewCard() now := time.Date(2022, 11, 29, 12, 30, 0, 0, time.UTC) var ivlList []uint64 @@ -38,12 +45,58 @@ func TestExample(t *testing.T) { fmt.Println(ivlList) fmt.Println(stateList) - wantIvlList := []uint64{0, 5, 16, 43, 106, 236, 0, 0, 12, 25, 47, 85, 147} + wantIvlList := []uint64{0, 4, 15, 49, 143, 379, 0, 0, 15, 37, 85, 184, 376} if !reflect.DeepEqual(ivlList, wantIvlList) { - t.Errorf("excepted:%v, got:%v", ivlList, wantIvlList) + t.Errorf("excepted:%v, got:%v", wantIvlList, ivlList) } wantStateList := []State{New, Learning, Review, Review, Review, Review, Review, Relearning, Relearning, Review, Review, Review, Review} if !reflect.DeepEqual(stateList, wantStateList) { - t.Errorf("excepted:%v, got:%v", stateList, wantStateList) + t.Errorf("excepted:%v, got:%v", wantStateList, stateList) + } +} + +func TestMemoState(t *testing.T) { + p := DefaultParam() + p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321, + 2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244} + card := NewCard() + now := time.Date(2022, 11, 29, 12, 30, 0, 0, time.UTC) + + schedulingCards := p.Repeat(card, now) + var ratings = []Rating{Again, Good, Good, Good, Good, Good} + var ivlList = []uint64{0, 0, 1, 3, 8, 21} + var rating Rating + for i := 0; i < len(ratings); i++ { + rating = ratings[i] + card = schedulingCards[rating].Card + now = now.Add(time.Duration(ivlList[i]) * 24 * time.Hour) + schedulingCards = p.Repeat(card, now) + } + wantStability := 43.0554 + cardStability := roundFloat(schedulingCards[Good].Card.Stability, 4) + wantDifficulty := 7.7609 + cardDifficulty := roundFloat(schedulingCards[Good].Card.Difficulty, 4) + + if !reflect.DeepEqual(wantStability, cardStability) { + t.Errorf("excepted:%v, got:%v", wantStability, cardStability) + } + + if !reflect.DeepEqual(wantDifficulty, cardDifficulty) { + t.Errorf("excepted:%v, got:%v", wantDifficulty, cardDifficulty) + } +} + +func TestNextInterval(t *testing.T) { + p := DefaultParam() + p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321, + 2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244} + var ivlList []float64 + for i := 1; i <= 10; i++ { + p.RequestRetention = float64(i) / 10 + ivlList = append(ivlList, p.nextInterval(1)) + } + wantIvlList := []float64{422, 102, 43, 22, 13, 8, 4, 2, 1, 1} + if !reflect.DeepEqual(ivlList, wantIvlList) { + t.Errorf("excepted:%v, got:%v", wantIvlList, ivlList) } } diff --git a/params.go b/params.go index 7cc916d..b4a0754 100644 --- a/params.go +++ b/params.go @@ -1,21 +1,30 @@ package fsrs +import "math" + type Weights [17]float64 type Parameters struct { RequestRetention float64 `json:"RequestRetention"` MaximumInterval float64 `json:"MaximumInterval"` W Weights `json:"Weights"` + Decay float64 `json:"Decay"` + Factor float64 `json:"Factor"` } func DefaultParam() Parameters { + var Decay = -0.5 + var Factor = math.Pow(0.9, 1/Decay) - 1 return Parameters{ RequestRetention: 0.9, MaximumInterval: 36500, W: DefaultWeights(), + Decay: Decay, + Factor: Factor, } } func DefaultWeights() Weights { - return Weights{0.4, 0.6, 2.4, 5.8, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29, 2.61} + return Weights{0.5701, 1.4436, 4.1386, 10.9355, 5.1443, 1.2006, 0.8627, 0.0362, 1.629, 0.1342, 1.0166, 2.1174, + 0.0839, 0.3204, 1.4676, 0.219, 2.8237} }