metalearner.py
from __future__ import division, print_function, absolute_import
import pdb
import math
import torch
import torch.nn as nn
# 2023/10/21
class MetaLSTMCell(nn.Module):
    r"""C_t = f_t * C_{t-1} + i_t * \tilde{C_t}"""

    def __init__(self, input_size, hidden_size, n_learner_params):
        """Args:
            input_size (int): cell input size, default = 20
            hidden_size (int): should be 1
            n_learner_params (int): number of learner's parameters
        """
        super(MetaLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_learner_params = n_learner_params
        self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1))
        self.bI = nn.Parameter(torch.Tensor(1, hidden_size))
        self.bF = nn.Parameter(torch.Tensor(1, hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.parameters():
            nn.init.uniform_(weight, -0.01, 0.01)
        # want initial forget value to be high and input value to be low so that
        # model starts with gradient descent
        nn.init.uniform_(self.bF, 4, 6)
        nn.init.uniform_(self.bI, -5, -4)

    def init_cI(self, flat_params):
        self.cI.data.copy_(flat_params.unsqueeze(1))

    def forward(self, inputs, hx=None):
        """Args:
            inputs = [x_all, grad]:
                x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM
                grad (torch.Tensor of size [n_learner_params]): gradients from learner
            hx = [f_prev, i_prev, c_prev]:
                f (torch.Tensor of size [n_learner_params, 1]): forget gate
                i (torch.Tensor of size [n_learner_params, 1]): input gate
                c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters
        """
        x_all, grad = inputs
        batch, _ = x_all.size()

        if hx is None:
            f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device)
            i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device)
            c_prev = self.cI
            hx = [f_prev, i_prev, c_prev]

        f_prev, i_prev, c_prev = hx

        # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f)
        f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev)
        # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i)
        i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev)
        # next cell/params: theta_t = sigmoid(f_t) * theta_{t-1} - sigmoid(i_t) * grad_t
        c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad)

        return c_next, [f_next, i_next, c_next]

    def extra_repr(self):
        s = '{input_size}, {hidden_size}, {n_learner_params}'
        return s.format(**self.__dict__)
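

# Value, Key and Query below are plain linear projections; MetaLearner.forward
# uses them for an attention-style mixing of the first LSTM's hidden states
# (scaled dot-product scores without a softmax, as written).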
class Value(torch.nn.Module):
    def __init__(self, dim_input, dim_val):
        super(Value, self).__init__()
        self.dim_val = dim_val
        self.fc1 = nn.Linear(dim_input, dim_val, bias=True)

    def forward(self, x):
        x = self.fc1(x)
        return x


class Key(torch.nn.Module):
    def __init__(self, dim_input, dim_attn):
        super(Key, self).__init__()
        self.dim_attn = dim_attn
        self.fc1 = nn.Linear(dim_input, dim_attn, bias=True)

    def forward(self, x):
        x = self.fc1(x)
        return x


class Query(torch.nn.Module):
    def __init__(self, dim_input, dim_attn):
        super(Query, self).__init__()
        self.dim_attn = dim_attn
        self.fc1 = nn.Linear(dim_input, dim_attn, bias=True)

    def forward(self, x):
        x = self.fc1(x)
        return x


class MetaLearner(nn.Module):
    def __init__(self, input_size, hidden_size, n_learner_params):
        """Args:
            input_size (int): for the first LSTM layer, default = 4
            hidden_size (int): for the first LSTM layer, default = 20
            n_learner_params (int): number of learner's parameters
        """
        super(MetaLearner, self).__init__()
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params)
        self.value = Value(948, 948)
        self.key = Key(948, 20)
        self.query = Query(948, 20)
        self.attend = nn.Softmax(dim=-1)
        self.linear = nn.Linear(in_features=4, out_features=20)

    def forward(self, inputs, hs=None):
        """Args:
            inputs = [loss, grad_prep, grad]
                loss (torch.Tensor of size [1, 2])
                grad_prep (torch.Tensor of size [n_learner_params, 2])
                grad (torch.Tensor of size [n_learner_params])
            hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]]
        """
        loss, grad_prep, grad = inputs
        loss = loss.expand_as(grad_prep)
        inputs = torch.cat((loss, grad_prep), 1)

        if hs is None:
            hs = [None, None]

        lstmhx, lstmcx = self.lstm(inputs, hs[0])

        # attention over the first LSTM's hidden states
        # (the reshape sizes are hard-coded for a specific learner size)
        hx = torch.reshape(lstmhx, (1217, 948))
        v_h = self.value(hx)
        k_h = self.key(hx)
        q_h = self.query(hx)
        dots = torch.matmul(q_h, k_h.transpose(-1, -2)) * 948 ** -0.5
        out = torch.matmul(dots, v_h)
        attention_h = hx + out
        att = torch.reshape(attention_h, (288429, 4))

        flat_learner_unsqzd, metalstm_hs = self.metalstm([att, grad], hs[1])

        return flat_learner_unsqzd.squeeze(), [(lstmhx, lstmcx), metalstm_hs]
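

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training pipeline).
    # All sizes below are assumptions chosen only to satisfy the shapes this
    # module expects: the hard-coded reshapes in MetaLearner.forward imply
    # hidden_size = 4 and n_learner_params = 288429 (288429 * 4 == 1217 * 948),
    # and grad is passed as [n_learner_params, 1] so the element-wise update in
    # MetaLSTMCell broadcasts per parameter.
    torch.manual_seed(0)

    # 1) MetaLSTMCell on its own, with small illustrative sizes.
    n_params, in_size = 10, 20
    cell = MetaLSTMCell(input_size=in_size, hidden_size=1, n_learner_params=n_params)
    cell.init_cI(torch.randn(n_params))              # flattened learner parameters
    x_all = torch.randn(n_params, in_size)           # hidden states from the first LSTM
    grad = torch.randn(n_params, 1)                  # per-parameter gradients
    c_next, _ = cell([x_all, grad])
    print('MetaLSTMCell output:', c_next.shape)      # torch.Size([10, 1])

    # 2) Full MetaLearner forward pass with the sizes implied by the reshapes.
    n_learner_params = 288429
    meta = MetaLearner(input_size=4, hidden_size=4, n_learner_params=n_learner_params)
    meta.metalstm.init_cI(torch.randn(n_learner_params))
    loss = torch.randn(1, 2)                         # preprocessed loss
    grad_prep = torch.randn(n_learner_params, 2)     # preprocessed gradients
    grad = torch.randn(n_learner_params, 1)          # raw flattened gradients
    flat_params, hs = meta([loss, grad_prep, grad])
    print('MetaLearner output:', flat_params.shape)  # torch.Size([288429])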