# model.py
import torch
import torch.nn as nn
import torchviz
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import init
import math


class DownSample(nn.Module):
    def __init__(self, output_dim, method=None, input_dim=None):
        super(DownSample, self).__init__()
        # Record the method unconditionally so forward() can always check it.
        self.method = method
        if method is not None:
            self.pool1 = nn.AdaptiveMaxPool1d(output_dim)
            self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        if self.method == 'MaxPool':
            reduced = self.pool1(x)
        elif self.method == 'Linear':
            # x shape: (batch, 7, 512)
            batch_size, seq_len, input_dim = x.size()
            # Reshape x to (-1, 512) to apply the linear layer
            x = x.view(-1, input_dim)
            x = self.linear(x)
            # Reshape x back to (batch, 7, 23)
            reduced = x.view(batch_size, seq_len, -1)
        else:
            raise ValueError("Invalid method for down sampling. Choose 'MaxPool' or 'Linear'.")
        return reduced
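
# Illustrative usage sketch for DownSample (the sizes below are assumptions chosen for
# demonstration, not values taken from this file):
#
#   ds = DownSample(output_dim=23, method='Linear', input_dim=512)
#   out = ds(torch.randn(4, 7, 512))      # -> (4, 7, 23): the linear layer maps each token
#
#   ds = DownSample(output_dim=23, method='MaxPool', input_dim=512)
#   out = ds(torch.randn(4, 7, 512))      # -> (4, 7, 23): AdaptiveMaxPool1d pools the last dim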

# Define a class that applies the transformer L times
class MultiLayerTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, output_dim=None, downsmaple_method=None):
        super(MultiLayerTransformer, self).__init__()
        # Define a single transformer encoder layer with batch_first=True
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads,
                                                        dim_feedforward=hidden_dim, batch_first=True)
        # Stack num_layers of these layers to form the complete transformer encoder
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.down_sample = DownSample(output_dim, downsmaple_method, input_dim)
        self.downsample_method = downsmaple_method

    def forward(self, z):
        z = self.transformer_encoder(z)
        # Reduce the output dimension if specified
        if self.downsample_method is not None:
            z = self.down_sample(z)
        return z
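
# Illustrative usage sketch for MultiLayerTransformer (all sizes are assumptions):
#
#   encoder = MultiLayerTransformer(input_dim=512, hidden_dim=1024, num_heads=8, num_layers=4,
#                                   output_dim=64, downsmaple_method='Linear')
#   z = encoder(torch.randn(4, 7, 512))   # -> (4, 7, 64) after the 'Linear' down-sampling step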

# Apply a MultiLayerTransformer to each modality (two or three) independently;
# the exchange via bottleneck tokens happens later in FusionTransformers.
class ModalitySpecificTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, T, downsmaple_method, fusion_dim):
        super(ModalitySpecificTransformer, self).__init__()
        self.modality1_transformer = MultiLayerTransformer(input_dim[0], hidden_dim[0], num_heads[0], num_layers[0],
                                                           fusion_dim, downsmaple_method)
        self.modality2_transformer = MultiLayerTransformer(input_dim[1], hidden_dim[1], num_heads[1], num_layers[1],
                                                           fusion_dim, downsmaple_method)
        if len(input_dim) == 3:
            self.modality3_transformer = MultiLayerTransformer(input_dim[2], hidden_dim[2], num_heads[2], num_layers[2],
                                                               fusion_dim, downsmaple_method)

    def forward(self, z1, z2, z3):
        z1_final = self.modality1_transformer(z1)
        z2_final = self.modality2_transformer(z2)
        if z3 is not None:
            z3_final = self.modality3_transformer(z3)
            return z1_final, z2_final, z3_final
        return z1_final, z2_final, None
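
# Illustrative usage sketch for ModalitySpecificTransformer (sizes are assumptions): each
# modality is encoded by its own MultiLayerTransformer and projected to the shared fusion_dim,
# e.g. for two modalities:
#
#   mst = ModalitySpecificTransformer(input_dim=[512, 256], hidden_dim=[1024, 512],
#                                     num_heads=[8, 8], num_layers=[2, 2], T=4,
#                                     downsmaple_method='Linear', fusion_dim=64)
#   z1, z2, z3 = mst(torch.randn(4, 7, 512), torch.randn(4, 7, 256), None)  # (4, 7, 64) each, z3 is None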

class FusionTransformers(nn.Module):
    def __init__(self, input_dim, num_heads, hidden_dim, Lf, B, fusion_dim):
        super(FusionTransformers, self).__init__()
        self.Lf = Lf
        self.T = B  # number of bottleneck tokens
        # Adjusting the bottleneck tokens shape for batch_first=True
        self.bottleneck_tokens = nn.Parameter(torch.empty(1, B, fusion_dim), requires_grad=True)
        init.normal_(self.bottleneck_tokens, mean=0, std=0.02)
        self.layers_modality1 = self._get_layers(fusion_dim, num_heads[-1], hidden_dim[-1], Lf)
        self.layers_modality2 = self._get_layers(fusion_dim, num_heads[-1], hidden_dim[-1], Lf)
        self.layers_modality3 = self._get_layers(fusion_dim, num_heads[-1], hidden_dim[-1], Lf)

    @staticmethod
    def _get_layers(input_dim, num_heads, hidden_dim, Lf):
        # Build Lf independent encoder layers; constructing the layer inside the list
        # comprehension keeps each layer's weights separate instead of reusing one module.
        return nn.ModuleList([nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads,
                                                         dim_feedforward=hidden_dim, batch_first=True)
                              for _ in range(Lf)])

    def forward(self, z1, z2, z3=None):
        # Adjusting concatenation for batch_first=True
        # Repeat bottleneck tokens for the batch size
        final_tokens = self.bottleneck_tokens.repeat(z1.size(0), 1, 1)
        for i in range(self.Lf):
            z1 = torch.cat((z1, final_tokens), dim=1)
            z2 = torch.cat((z2, final_tokens), dim=1)
            if z3 is not None:
                z3 = torch.cat((z3, final_tokens), dim=1)
            z1 = self.layers_modality1[i](z1)
            z2 = self.layers_modality2[i](z2)
            if z3 is not None:
                z3 = self.layers_modality3[i](z3)
            # Separate the output into z1, temp_tokens1, z2, and temp_tokens2
            z1, temp_tokens1 = z1[:, :-self.T, :], z1[:, -self.T:, :]
            z2, temp_tokens2 = z2[:, :-self.T, :], z2[:, -self.T:, :]
            if z3 is not None:
                z3, temp_tokens3 = z3[:, :-self.T, :], z3[:, -self.T:, :]
                final_tokens = (temp_tokens1 + temp_tokens2 + temp_tokens3) / 3
                continue
            # Average the two temporary tokens
            final_tokens = (temp_tokens1 + temp_tokens2) / 2
        return z1, z2, z3
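
# Illustrative sketch of the bottleneck exchange (Lf, B and the sizes are assumptions): at each
# of the Lf fusion layers, the B shared tokens are appended to every modality's sequence, each
# modality attends over them, and the per-modality copies are averaged back into one set of
# tokens before the next layer. For example:
#
#   fusion = FusionTransformers(input_dim=[512, 256], num_heads=[8, 8], hidden_dim=[1024, 1024],
#                               Lf=2, B=4, fusion_dim=64)
#   z1, z2, z3 = fusion(torch.randn(4, 7, 64), torch.randn(4, 7, 64))   # (4, 7, 64) each, z3 stays None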

def positional_encoding(sequence_length, d_model, device):
    """Compute the sinusoidal positional encoding for a batch of sequences."""
    # Initialize a matrix to store the positional encodings
    pos_enc = torch.zeros(sequence_length, d_model)
    # Compute the positional encodings
    for pos in range(sequence_length):
        for i in range(0, d_model, 2):
            div_term = torch.exp(torch.tensor(-math.log(10000.0) * (i // 2) / d_model))
            pos_enc[pos, i] = torch.sin(pos * div_term)
            if i + 1 < d_model:  # guard against an odd d_model
                pos_enc[pos, i + 1] = torch.cos(pos * div_term)
    # Add a batch dimension so the encoding broadcasts over the batch
    pos_enc = pos_enc.unsqueeze(0).to(device)
    return pos_enc
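
# Note: the loop above fills even channels with sin(pos * w_i) and odd channels with
# cos(pos * w_i), where w_i = 10000 ** (-(i // 2) / d_model). A vectorized sketch of the same
# computation (illustrative only, assuming an even d_model) would be:
#
#   position = torch.arange(sequence_length, dtype=torch.float32).unsqueeze(1)            # (L, 1)
#   freqs = torch.exp(-math.log(10000.0) * (torch.arange(0, d_model, 2) // 2) / d_model)  # (d_model/2,)
#   pe = torch.zeros(sequence_length, d_model)
#   pe[:, 0::2] = torch.sin(position * freqs)
#   pe[:, 1::2] = torch.cos(position * freqs)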

class ClassificationHead(nn.Module):
    def __init__(self, embedding_dim, seq_len, dropout_rate, head_layer_sizes, n_classes: int = 5):
        super().__init__()
        self.norm = nn.LayerNorm(embedding_dim)
        self.seq = nn.Sequential(nn.Flatten(),
                                 nn.Linear(embedding_dim * seq_len, head_layer_sizes[0]), nn.ReLU(),
                                 nn.Dropout(dropout_rate),
                                 nn.Linear(head_layer_sizes[0], head_layer_sizes[1]), nn.ReLU(),
                                 nn.Dropout(dropout_rate),
                                 nn.Linear(head_layer_sizes[1], head_layer_sizes[2]), nn.ReLU(),
                                 nn.Dropout(dropout_rate),
                                 nn.Linear(head_layer_sizes[2], n_classes))

    def forward(self, x):
        x = self.norm(x)
        x = self.seq(x)
        return x
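
# Illustrative usage sketch for ClassificationHead (sizes are assumptions): the head flattens
# the full (seq_len, embedding_dim) representation, so seq_len must match the sequence length
# of the tensor that is actually passed in:
#
#   head = ClassificationHead(embedding_dim=64, seq_len=7, dropout_rate=0.1,
#                             head_layer_sizes=[256, 128, 64], n_classes=5)
#   logits = head(torch.randn(4, 7, 64))   # -> (4, 5)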

class ClassificationProcessor(nn.Module):
    def __init__(self, input_dim, dropout_rate, mode,
                 classification_head, max_seq_length, head_layer_sizes, num_classes, modalities, fusion_dim):
        super(ClassificationProcessor, self).__init__()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.dropout3 = nn.Dropout(dropout_rate) if len(modalities) == 3 else None
        # Classification heads or layers
        if classification_head:
            self.combined_classifier = ClassificationHead(fusion_dim, max_seq_length * len(modalities),
                                                          dropout_rate, head_layer_sizes)
            self.classifier1 = ClassificationHead(fusion_dim, max_seq_length, dropout_rate, head_layer_sizes)
            self.classifier2 = ClassificationHead(fusion_dim, max_seq_length, dropout_rate, head_layer_sizes)
            self.classifier3 = ClassificationHead(fusion_dim, max_seq_length, dropout_rate, head_layer_sizes)
        elif not classification_head:
            self.combined_classifier = nn.Linear(len(modalities) * fusion_dim, num_classes)  # Combined classifier
            self.classifier1 = nn.Linear(fusion_dim, num_classes)  # Separate classifier for modality 1
            self.classifier2 = nn.Linear(fusion_dim, num_classes)  # Separate classifier for modality 2
            self.classifier3 = nn.Linear(fusion_dim, num_classes)  # Separate classifier for modality 3
        self.mode = mode
        self.classification_head = classification_head

    def forward(self, z1_out, z2_out, z3_out=None):
        if self.classification_head:
            if self.mode == 'concat':
                representations = [z1_out, z2_out]
                if z3_out is not None:
                    representations.append(z3_out)
                combined_cls = torch.cat(representations, dim=1)
                final_output = self.combined_classifier(combined_cls)
                return final_output
            elif self.mode == 'separate':
                logits_output_1 = self.classifier1(z1_out)
                logits_output_2 = self.classifier2(z2_out)
                logits = [logits_output_1, logits_output_2]
                if z3_out is not None and self.classifier3:
                    logits_output_3 = self.classifier3(z3_out)
                    logits.append(logits_output_3)
                final_output = sum(logits) / len(logits)
                return final_output
            else:
                raise ValueError("Invalid mode. Choose 'concat' or 'separate'.")
        elif not self.classification_head:
            cls_representation1 = self.dropout1(z1_out[:, 0, :])
            cls_representation2 = self.dropout2(z2_out[:, 0, :])
            cls_representation3 = self.dropout3(z3_out[:, 0, :]) if z3_out is not None else None
            if self.mode == 'concat':
                representations = [cls_representation1, cls_representation2]
                if cls_representation3 is not None:
                    representations.append(cls_representation3)
                combined_cls = torch.cat(representations, dim=1)
                final_output = self.combined_classifier(combined_cls)
            elif self.mode == 'separate':
                logits_output_1 = self.classifier1(cls_representation1)
                logits_output_2 = self.classifier2(cls_representation2)
                logits = [logits_output_1, logits_output_2]
                if cls_representation3 is not None:
                    logits.append(self.classifier3(cls_representation3))
                final_output = sum(logits) / len(logits)
            else:
                raise ValueError("Invalid mode. Choose 'concat' or 'separate'.")
        else:
            raise ValueError("Invalid classification head. Choose True or False.")
        return final_output
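
# Mode summary for ClassificationProcessor (an informal note, not taken from the original file):
#   'concat'   -> concatenate the modality outputs (full sequences when classification_head=True,
#                 CLS vectors otherwise) and classify them jointly.
#   'separate' -> classify each modality on its own and average the resulting logits.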

class AttentionBottleneckFusion(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, Lf, T, num_classes, device, max_seq_length,
                 mode, dropout_rate, downsmaple_method, classification_head, head_layer_sizes, modalities,
                 fusion_dim):
        super(AttentionBottleneckFusion, self).__init__()
        # CLS tokens for each modality
        self.cls_token1 = nn.Parameter(2 * torch.rand(1, 1, input_dim[0]) - 1, requires_grad=True)
        self.cls_token2 = nn.Parameter(2 * torch.rand(1, 1, input_dim[1]) - 1, requires_grad=True)
        self.cls_token3 = nn.Parameter(2 * torch.rand(1, 1, input_dim[2]) - 1, requires_grad=True) \
            if len(modalities) == 3 else None
        # Positional encodings
        self.positional_encodings1 = positional_encoding(100, input_dim[0], device)
        self.positional_encodings2 = positional_encoding(100, input_dim[1], device)
        self.positional_encodings3 = positional_encoding(100, input_dim[2], device) if len(modalities) == 3 else None
        # Initialize ModalitySpecificTransformer
        self.modality_specific_transformer = ModalitySpecificTransformer(input_dim, hidden_dim, num_heads, num_layers,
                                                                         T, downsmaple_method, fusion_dim)
        # Initialize FusionTransformers
        self.fusion_transformer = FusionTransformers(input_dim, num_heads, hidden_dim, Lf, T, fusion_dim)
        self.classification_processor = ClassificationProcessor(input_dim, dropout_rate, mode, classification_head,
                                                                max_seq_length, head_layer_sizes, num_classes,
                                                                modalities, fusion_dim)
        self.classification_head = classification_head

    def forward(self, z1, z2, z3=None):
        if not self.classification_head:
            # Prepend the CLS token for each modality (only needed when classifying from the CLS vector)
            cls_token1_embed = self.cls_token1.repeat(z1.size(0), 1, 1)
            cls_token2_embed = self.cls_token2.repeat(z2.size(0), 1, 1)
            cls_token3_embed = self.cls_token3.repeat(z3.size(0), 1, 1) if z3 is not None else None
            z1 = torch.cat([cls_token1_embed, z1], dim=1)
            z2 = torch.cat([cls_token2_embed, z2], dim=1)
            z3 = torch.cat([cls_token3_embed, z3], dim=1) if z3 is not None else None
        # Add positional encodings (slice up to the current sequence length)
        z1 = z1 + self.positional_encodings1[:, :z1.size(1), :]
        z2 = z2 + self.positional_encodings2[:, :z2.size(1), :]
        z3 = z3 + self.positional_encodings3[:, :z3.size(1), :] if z3 is not None else None
        # Get the outputs from the modality-specific transformers
        z1, z2, z3 = self.modality_specific_transformer(z1, z2, z3)
        # Feed the outputs to the FusionTransformers
        z1_out, z2_out, z3_out = self.fusion_transformer(z1, z2, z3)
        # Classification using the classification processor
        final_output = self.classification_processor(z1_out, z2_out, z3_out)
        return final_output
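
# Illustrative end-to-end sketch for AttentionBottleneckFusion (every value below is an assumed
# hyperparameter chosen to satisfy the shape constraints above, not a configuration from this
# repository):
#
#   model = AttentionBottleneckFusion(input_dim=[32, 16], hidden_dim=[64, 64], num_heads=[4, 4],
#                                     num_layers=[2, 2], Lf=2, T=4, num_classes=5, device='cpu',
#                                     max_seq_length=7, mode='concat', dropout_rate=0.1,
#                                     downsmaple_method='Linear', classification_head=True,
#                                     head_layer_sizes=[128, 64, 32], modalities=['mod1', 'mod2'],
#                                     fusion_dim=64)
#   logits = model(torch.randn(2, 7, 32), torch.randn(2, 7, 16))   # -> (2, 5)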

class SingleModalityTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, Lf, T, num_classes, device, max_seq_length,
                 mode, dropout_rate, downsmaple_method, classification_head, head_layer_sizes, output_dim):
        """
        - output_dim: The transformer output is down-sampled to this dimension.
        - downsmaple_method: The method used to down-sample the output of the transformer.
          Choose 'MaxPool', 'Linear', or None for no down-sampling.
        """
        super(SingleModalityTransformer, self).__init__()
        # CLS token for the single modality
        self.cls_token1 = nn.Parameter(2 * torch.rand(1, 1, input_dim) - 1, requires_grad=True)
        # Positional encodings
        self.positional_encodings1 = positional_encoding(100, input_dim, device)
        # Initialize the stacked transformer encoder
        self.multi_layer_transformer = MultiLayerTransformer(input_dim, hidden_dim, num_heads, num_layers,
                                                             output_dim, downsmaple_method)
        # Dropout layer
        self.dropout1 = nn.Dropout(dropout_rate)
        # Classification head or linear layer
        if classification_head:
            self.classifier1 = ClassificationHead(output_dim, max_seq_length, dropout_rate, head_layer_sizes)
        elif not classification_head:
            self.classifier1 = nn.Linear(input_dim, num_classes)  # Linear classifier on the CLS representation
        self.classification_head = classification_head

    def forward(self, z1):
        # Prepend the CLS token
        cls_token1_embed = self.cls_token1.repeat(z1.size(0), 1, 1)
        z1 = torch.cat([cls_token1_embed, z1], dim=1)
        # Add positional encodings (slice up to the current sequence length)
        z1 = z1 + self.positional_encodings1[:, :z1.size(1), :]
        # Encode with the stacked transformer
        z1_out = self.multi_layer_transformer(z1)
        # Classification using the classification head
        if self.classification_head:
            final_output = self.classifier1(z1_out)
            return final_output
        # Classification without a classification head: use the CLS token's representation
        elif not self.classification_head:
            cls_representation1 = self.dropout1(z1_out[:, 0, :])
            final_output = self.classifier1(cls_representation1)
        else:
            raise ValueError("Invalid classification head. Choose True or False.")
        return final_output
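
# Minimal smoke test (illustrative sketch only: the hyperparameters are assumed values chosen
# to satisfy the shape constraints above, not a configuration from this repository).
if __name__ == "__main__":
    # Single-modality model, classifying from the CLS token with a plain linear layer.
    single = SingleModalityTransformer(input_dim=32, hidden_dim=64, num_heads=4, num_layers=2,
                                       Lf=2, T=4, num_classes=5, device='cpu', max_seq_length=8,
                                       mode='separate', dropout_rate=0.1, downsmaple_method=None,
                                       classification_head=False, head_layer_sizes=[64, 32, 16],
                                       output_dim=None)
    out = single(torch.randn(2, 7, 32))
    print(out.shape)   # expected: torch.Size([2, 5])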