Skip to content

Commit

Permalink
Feat/train with .csv
Browse files Browse the repository at this point in the history
  • Loading branch information
ishiko732 committed Mar 1, 2024
1 parent 9e334db commit 5091d56
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 5 deletions.
2 changes: 2 additions & 0 deletions sandbox/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"@surma/rollup-plugin-off-main-thread": "^2.2.3",
"@types/sql.js": "^1.4.9",
"@types/surma__rollup-plugin-off-main-thread": "^2.2.3",
"@types/papaparse": "^5.3.14",
"prettier": "^3.2.5",
"sass": "^1.71.1",
"solid-devtools": "^0.29.3",
Expand All @@ -23,6 +24,7 @@
},
"dependencies": {
"@popperjs/core": "^2.11.8",
"papaparse": "^5.4.1",
"bootstrap": "^5.3.3",
"fsrs-browser": "link:../pkg",
"solid-js": "^1.8.15",
Expand Down
17 changes: 16 additions & 1 deletion sandbox/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sandbox/postinstall.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ cp ./node_modules/sql.js/dist/sql-wasm.wasm ./src/assets/sql-wasm.wasm
mkdir -p ./public/
wget -nc -O ./public/collection.anki21.zip https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip
unzip -n ./public/collection.anki21.zip -d ./public
wget -nc -O ./public/revlog.csv https://github.com/open-spaced-repetition/fsrs4anki/files/12515294/revlog.csv
9 changes: 7 additions & 2 deletions sandbox/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ const App: Component = () => {
</button>
<label>
Train with custom file
<input type='file' onChange={customFile} accept='.anki21' />
<input type='file' onChange={customFile} accept='.anki21, .csv'/>
</label>
<button onclick={testSerialization}>
<div>Test Serialization</div>
Expand All @@ -62,7 +62,12 @@ async function customFile(
// My mental static analysis says to use `currentTarget`, but it seems to randomly be null, hence `target`. I'm confused but whatever.
event.target.files?.item(0) ?? throwExp('There should be a file selected')
let ab = await file.arrayBuffer()
train({ data: ab })
if (file.name.endsWith('.csv')) {
train({data: file})
} else {
train({data: ab})
}

}

export default App
Expand Down
65 changes: 63 additions & 2 deletions sandbox/src/train.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import wasm, { initThreadPool, Fsrs } from 'fsrs-browser/fsrs_browser'
import initSqlJs, { type Database } from 'sql.js'
import sqliteWasmUrl from './assets/sql-wasm.wasm?url'
import * as papa from "papaparse";

// @ts-ignore https://github.com/rustwasm/console_error_panic_hook#errorstacktracelimit
Error.stackTraceLimit = 30
Expand All @@ -9,14 +10,16 @@ const sqlJs = initSqlJs({
locateFile: () => sqliteWasmUrl,
})

export const train = async (event: { data: 'autotrain' | ArrayBuffer }) => {
export const train = async (event: { data: 'autotrain' | ArrayBuffer | File }) => {
if (event.data === 'autotrain') {
let db = await fetch('/collection.anki21')
let ab = await db.arrayBuffer()
loadSqliteAndRun(ab)
} else if (event.data instanceof ArrayBuffer) {
loadSqliteAndRun(event.data)
}
} else if (event.data instanceof File) {
csvTrain(event.data)
}
}

async function loadSqliteAndRun(ab: ArrayBuffer) {
Expand Down Expand Up @@ -68,6 +71,64 @@ async function loadSqliteAndRun(ab: ArrayBuffer) {
}
}


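// Train the model from a review-log CSV file whose header row names the columns
// review_time, card_id, review_rating, review_duration and review_state.
// Sample dataset: https://github.com/open-spaced-repetition/fsrs4anki/issues/450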
/**
 * One raw CSV row as handed to the papaparse `step` callback when parsing
 * with `header: true`. Every column arrives as text and is converted to
 * `bigint` / `number` by `csvTrain` before training.
 */
interface ParseData {
  /** Review time column; converted with `BigInt(...)`. */
  review_time: string;
  /** Card identifier column; converted with `BigInt(...)`. */
  card_id: string;
  /** Review rating column; converted with `Number(...)`. */
  review_rating: string;
  /** Review duration column; present in the file but not read by `csvTrain`. */
  review_duration: string;
  /** Review state column; converted with `Number(...)`. */
  review_state: string;
}

/**
 * A single CSV review row after conversion from text, in the shape that
 * `csvTrain` accumulates before packing the columns into typed arrays.
 */
interface csvTrainDataItem {
  /** Card identifier; packed into the `BigInt64Array` passed as `cids`. */
  card_id: bigint;
  /** Review time; packed into the `BigInt64Array` passed as `ids`. */
  review_time: bigint;
  /** Review state; packed into the `Uint8Array` passed as `types`. */
  review_state: number;
  /** Review rating; packed into the `Uint8Array` passed as `eases`. */
  review_rating: number;
}
/**
 * Trains FSRS weights from a review-log CSV file and logs the result.
 *
 * The file is streamed through papaparse row by row; each row is converted
 * to numeric form and collected, then the columns are packed into typed
 * arrays and handed to `Fsrs.computeWeightsAnki`.
 *
 * Unlike a bare `papa.parse(...)` call, the returned promise settles only
 * after parsing and training have finished, and it rejects when the file
 * cannot be read, a cell cannot be converted, or training throws.
 *
 * @param csv - CSV file with a header row naming at least `card_id`,
 *   `review_time`, `review_state` and `review_rating`.
 */
async function csvTrain(csv: File) {
  await wasm()
  await initThreadPool(navigator.hardwareConcurrency)
  await sleep(1000) // the workers need time to spin up. TODO, post an init message and await a response. Also maybe move worker construction to Javascript.
  console.time('full training time')
  const result: csvTrainDataItem[] = []
  const fsrs = new Fsrs()

  // papaparse reads a File asynchronously and reports through callbacks.
  // Wrapping it in a promise lets this function actually await completion
  // and lets failures surface as a rejection instead of being lost.
  await new Promise<void>((resolve, reject) => {
    papa.parse<ParseData>(csv, {
      header: true,
      delimiter: ',',
      // A blank (e.g. trailing) line would otherwise yield a row whose first
      // column is '' rather than undefined and could slip past the guard below.
      skipEmptyLines: true,
      step: function (row, parser) {
        const data = row.data
        if (data.card_id === undefined) return
        try {
          result.push({
            card_id: BigInt(data.card_id),
            review_time: BigInt(data.review_time),
            review_state: Number(data.review_state),
            review_rating: Number(data.review_rating),
          })
        } catch (err: unknown) {
          // BigInt() throws on non-integer text; stop parsing and report it
          // rather than letting the exception escape inside papaparse.
          reject(err)
          parser.abort()
        }
      },
      // Resolving after an earlier reject is a no-op, so an aborted parse
      // still ends up rejected.
      complete: function () {
        resolve()
      },
      error: function (err: unknown) {
        reject(err)
      },
    })
  })

  // Pack each column into the typed array passed to computeWeightsAnki.
  const cids = BigInt64Array.from(result, (r) => r.card_id)
  const eases = Uint8Array.from(result, (r) => r.review_rating)
  const ids = BigInt64Array.from(result, (r) => r.review_time)
  const types = Uint8Array.from(result, (r) => r.review_state)
  const weights = fsrs.computeWeightsAnki(cids, eases, ids, types)
  console.timeEnd('full training time')
  console.log('trained weights are', weights)
  console.log('revlog count', result.length)
}

async function getDb(ab: ArrayBuffer): Promise<Database> {
const sql = await sqlJs
return new sql.Database(new Uint8Array(ab))
Expand Down

0 comments on commit 5091d56

Please sign in to comment.