-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
lstm.ts
145 lines (133 loc) · 4.09 KB
/
lstm.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import { Matrix } from './matrix';
import { Equation } from './matrix/equation';
import { RandomMatrix } from './matrix/random-matrix';
import { IRNNHiddenLayer, RNN } from './rnn';
/**
 * Parameter set for one LSTM hidden layer: a (weights-on-input,
 * weights-on-previous-result, bias) triple for each of the input gate, the
 * forget gate, the output gate and the cell-write activation.
 *
 * `getHiddenLSTMLayer` builds every `*Matrix` member with constructor
 * arguments `(hiddenSize, prevSize)`, every `*Hidden` member with
 * `(hiddenSize, hiddenSize)` and every `*Bias` member with `(hiddenSize, 1)`.
 * `getLSTMEquation` throws if any of the twelve members is falsy.
 */
export interface ILSTMHiddenLayer extends IRNNHiddenLayer {
/** Input gate: weights multiplied with the current step's input matrix. */
inputMatrix: Matrix;
/** Input gate: weights multiplied with the previous result. */
inputHidden: Matrix;
/** Input gate: bias added before the sigmoid. */
inputBias: Matrix;
/** Forget gate: weights multiplied with the current step's input matrix. */
forgetMatrix: Matrix;
/** Forget gate: weights multiplied with the previous result. */
forgetHidden: Matrix;
/** Forget gate: bias added before the sigmoid. */
forgetBias: Matrix;
/** Output gate: weights multiplied with the current step's input matrix. */
outputMatrix: Matrix;
/** Output gate: weights multiplied with the previous result. */
outputHidden: Matrix;
/** Output gate: bias added before the sigmoid. */
outputBias: Matrix;
/** Cell write: weights multiplied with the current step's input matrix. */
cellActivationMatrix: Matrix;
/** Cell write: weights multiplied with the previous result. */
cellActivationHidden: Matrix;
/** Cell write: bias added before the tanh. */
cellActivationBias: Matrix;
}
/**
 * RNN subclass that plugs the LSTM layer factory (`getHiddenLSTMLayer`) and
 * the LSTM step builder (`getLSTMEquation`) into the two hooks it overrides.
 */
export class LSTM extends RNN {
  /**
   * Creates the parameter matrices for one LSTM hidden layer.
   *
   * @param hiddenSize - Forwarded unchanged to `getHiddenLSTMLayer`.
   * @param prevSize - Forwarded unchanged to `getHiddenLSTMLayer`.
   * @returns The freshly built LSTM layer, typed as the base layer interface.
   */
  getHiddenLayer(hiddenSize: number, prevSize: number): IRNNHiddenLayer {
    const lstmLayer: ILSTMHiddenLayer = getHiddenLSTMLayer(
      hiddenSize,
      prevSize
    );
    return lstmLayer;
  }

  /**
   * Adds one LSTM step to `equation` by delegating to `getLSTMEquation`.
   *
   * @param equation - Equation the step's operations are recorded on.
   * @param inputMatrix - Input to this step.
   * @param previousResult - Result carried over from the previous step.
   * @param hiddenLayer - Layer parameters; must actually be an LSTM layer.
   * @returns Whatever `getLSTMEquation` returns for these arguments.
   */
  getEquation(
    equation: Equation,
    inputMatrix: Matrix,
    previousResult: Matrix,
    hiddenLayer: IRNNHiddenLayer
  ): Matrix {
    // The parameter is declared with the base layer type, so this cast is
    // unchecked at compile time; getLSTMEquation verifies at runtime that the
    // LSTM-specific members are present and throws otherwise.
    const lstmLayer = hiddenLayer as ILSTMHiddenLayer;
    return getLSTMEquation(equation, inputMatrix, previousResult, lstmLayer);
  }
}
/**
 * Builds the twelve parameter matrices of a single LSTM hidden layer.
 *
 * For each of the input gate, forget gate, output gate and cell write, three
 * matrices are created in this order:
 *   1. `RandomMatrix(hiddenSize, prevSize, 0.08)`   - applied to the input
 *   2. `RandomMatrix(hiddenSize, hiddenSize, 0.08)` - applied to the previous result
 *   3. `Matrix(hiddenSize, 1)`                      - bias
 *
 * @param hiddenSize - First constructor argument of every matrix in the layer.
 * @param prevSize - Second constructor argument of the input-side weights.
 * @returns A new layer object; nothing is shared between calls.
 */
export function getHiddenLSTMLayer(
  hiddenSize: number,
  prevSize: number
): ILSTMHiddenLayer {
  // Third RandomMatrix argument; presumably the spread of the initial random
  // values - confirm against RandomMatrix.
  const randomScale = 0.08;

  const inputWeights = (): Matrix =>
    new RandomMatrix(hiddenSize, prevSize, randomScale);
  const recurrentWeights = (): Matrix =>
    new RandomMatrix(hiddenSize, hiddenSize, randomScale);
  const bias = (): Matrix => new Matrix(hiddenSize, 1);

  // Property order below is also construction order; keep it stable so the
  // random matrices are created in the same sequence as before.
  return {
    // --- gate parameters ---
    // input gate
    inputMatrix: inputWeights(), // wix
    inputHidden: recurrentWeights(), // wih
    inputBias: bias(), // bi
    // forget gate
    forgetMatrix: inputWeights(), // wfx
    forgetHidden: recurrentWeights(), // wfh
    forgetBias: bias(), // bf
    // output gate
    outputMatrix: inputWeights(), // wox
    outputHidden: recurrentWeights(), // woh
    outputBias: bias(), // bo
    // --- cell write parameters ---
    cellActivationMatrix: inputWeights(), // wcx
    cellActivationHidden: recurrentWeights(), // wch
    cellActivationBias: bias(), // bc
  };
}
/**
 * Records one LSTM step on `equation` and returns the matrix produced by the
 * last recorded operation.
 *
 * Operations are appended in a fixed order: input gate, forget gate, output
 * gate, cell write, then the combination of those four. Each gate is
 * `activation(W * inputMatrix + U * previousResult + b)` with sigmoid for the
 * three gates and tanh for the cell write.
 *
 * NOTE(review): the forget gate is applied to `previousResult` - the same
 * matrix that feeds the gates - rather than to a separately carried cell
 * state. Confirm this is intended before changing it.
 *
 * @param equation - Equation the operations are recorded on.
 * @param inputMatrix - Input to this step.
 * @param previousResult - Result carried over from the previous step.
 * @param hiddenLayer - The layer's twelve parameter matrices.
 * @returns `outputGate (element-wise *) tanh(cell)`.
 * @throws Error if any of the twelve expected layer members is falsy.
 */
export function getLSTMEquation(
  equation: Equation,
  inputMatrix: Matrix,
  previousResult: Matrix,
  hiddenLayer: ILSTMHiddenLayer
): Matrix {
  const requiredKeys: Array<keyof ILSTMHiddenLayer> = [
    'inputMatrix',
    'inputHidden',
    'inputBias',
    'forgetMatrix',
    'forgetHidden',
    'forgetBias',
    'outputMatrix',
    'outputHidden',
    'outputBias',
    'cellActivationMatrix',
    'cellActivationHidden',
    'cellActivationBias',
  ];
  // Fail before anything is recorded on the equation.
  if (requiredKeys.some((key) => !hiddenLayer[key])) {
    throw new Error('hiddenLayer does not have expected properties');
  }

  // weights * inputMatrix + recurrentWeights * previousResult + bias.
  // The two products are recorded first, then the two sums, in that order.
  const preActivation = (
    weights: Matrix,
    recurrentWeights: Matrix,
    bias: Matrix
  ): Matrix =>
    equation.add(
      equation.add(
        equation.multiply(weights, inputMatrix),
        equation.multiply(recurrentWeights, previousResult)
      ),
      bias
    );

  // The order of these four statements determines the order in which the
  // operations land on the equation; do not reorder.
  const inputGate = equation.sigmoid(
    preActivation(
      hiddenLayer.inputMatrix,
      hiddenLayer.inputHidden,
      hiddenLayer.inputBias
    )
  );
  const forgetGate = equation.sigmoid(
    preActivation(
      hiddenLayer.forgetMatrix,
      hiddenLayer.forgetHidden,
      hiddenLayer.forgetBias
    )
  );
  const outputGate = equation.sigmoid(
    preActivation(
      hiddenLayer.outputMatrix,
      hiddenLayer.outputHidden,
      hiddenLayer.outputBias
    )
  );
  // Candidate values to write into the cell.
  const candidate = equation.tanh(
    preActivation(
      hiddenLayer.cellActivationMatrix,
      hiddenLayer.cellActivationHidden,
      hiddenLayer.cellActivationBias
    )
  );

  // New cell contents: the part kept from the previous result plus the part
  // of the candidate that the input gate lets through.
  const kept = equation.multiplyElement(forgetGate, previousResult);
  const written = equation.multiplyElement(inputGate, candidate);
  const cell = equation.add(kept, written);

  // Result of the step: output gate applied to the tanh-squashed cell.
  const squashedCell = equation.tanh(cell);
  return equation.multiplyElement(outputGate, squashedCell);
}