-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathFastText.js
99 lines (78 loc) · 2.6 KB
/
FastText.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
const Args = require('./Args');
const Dictionary = require('./Dictionary');
const Matrix = require('./Matrix');
const QMatrix = require('./QMatrix');
const Model = require('./Model');
const SortedArray = require('./SortedArray');
const Vector = require('./Vector');
const FtzReader = require('./FtzReader');
/**
 * Loads a serialized model through FtzReader and predicts labels for text.
 */
class FastText {
  /**
   * Creates empty dense (Matrix) and quantized (QMatrix) input/output
   * matrices. They are filled by loadModel(); `quant` is set there from the
   * model file.
   */
  constructor () {
    this.input = new Matrix();
    this.output = new Matrix();
    this.qinput = new QMatrix();
    this.qoutput = new QMatrix();
    this.quant = false;
  }

  /**
   * Opens the model file and deserializes it in this order: Args,
   * Dictionary, a quantization flag, the input matrix (QMatrix when the flag
   * is set, Matrix otherwise), a `qout` flag, then the output matrix
   * (QMatrix only when both flags are set). It then builds the Model and
   * sets its target counts.
   *
   * @param {String} modelPath - path handed to FtzReader
   * @param {function(Error=)} callback - invoked exactly once: with no
   *   argument on success, or with the error when opening or reading the
   *   model fails.
   */
  loadModel(modelPath, callback) {
    this.ftzReader = new FtzReader(modelPath);
    this.ftzReader.open((err) => {
      if (err) {
        // Do not attempt to read from a reader that failed to open.
        callback(err);
        return;
      }
      try {
        this.args = new Args();
        this.args.load(this.ftzReader);
        this.dictionary = new Dictionary(this.args);
        this.dictionary.load(this.ftzReader);

        this.quant = !!this.ftzReader.readUInt8();
        if (this.quant) {
          this.qinput.load(this.ftzReader);
        } else {
          this.input.load(this.ftzReader);
        }

        this.args.qout = !!this.ftzReader.readUInt8();
        if (this.quant && this.args.qout) {
          this.qoutput.load(this.ftzReader);
        } else {
          this.output.load(this.ftzReader);
        }

        this.model = new Model(this.input, this.output, this.args, 0);
        this.model.quant = this.quant;
        this.model.setQuantizePointer(this.qinput, this.qoutput, this.args.qout);

        // Loose equality kept on purpose: the runtime types of `args.model`
        // and `model_name.sup` are not visible here. NOTE(review): switch to
        // === once both are confirmed to be the same type.
        const entryType = this.args.model == this.args.model_name.sup
          ? this.dictionary.entry_type.label
          : this.dictionary.entry_type.word;
        this.model.setTargetCounts(this.dictionary.getCounts(entryType));
      } catch (loadErr) {
        // Exceptions thrown here would otherwise escape the async callback
        // and the caller would never learn that loading failed.
        callback(loadErr);
        return;
      }
      // Called outside the try so an exception thrown by the caller's own
      // callback is not caught and reported through a second invocation.
      callback();
    });
  }

  /**
   * Predicts labels for `inputText` using the loaded model.
   *
   * @param {String} inputText - text tokenized by Dictionary#getLine
   * @param {Number} k - forwarded to Model#predict (presumably the number of
   *   top labels to return; confirm against Model)
   * @param {Number} numResults - currently unused; kept for API compatibility
   * @returns {Array<{probability: Number, label: *}>} predictions in the order
   *   produced by the model. The array is empty when the text yields no
   *   words. The same array is also stored on `this.predictions`.
   */
  predict(inputText, k, numResults) {
    const words = [];
    const labels = [];
    this.predictions = [];
    this.dictionary.getLine(inputText, words, labels);
    if (words.length === 0) {
      // Return the (empty) array rather than undefined so callers can
      // always iterate the result.
      return this.predictions;
    }

    const hidden = new Vector(this.args.dim);
    const output = new Vector(this.dictionary.nlabels);
    const modelPredictions = new SortedArray();
    this.model.predict(words, k, modelPredictions, hidden, output);

    // `pred.first` is exponentiated, so it is presumably a log-probability;
    // `pred.second` is the label index resolved through the dictionary.
    this.predictions = modelPredictions.data.map((pred) => ({
      probability: Math.exp(pred.first),
      label: this.dictionary.getLabel(pred.second),
    }));
    return this.predictions;
  }
}
// CommonJS export: the module's value is the FastText class itself.
module.exports = FastText;