# -*- coding: utf-8 -*-#
#-------------------------------------------------------------------------------
# Name: classifier
# Description: Modified by Xuan B. Nguyen from Jin-Hwa Kim's Bilinear Attention Networks repository (https://github.com/jnhwkim/ban-vqa)
# Author: Boliu.Kelvin
# Date: 2020/4/7
#-------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm


class SimpleClassifier(nn.Module):
    """Two-layer MLP classifier with weight-normalized linear layers."""

    def __init__(self, in_dim, hid_dim, out_dim, args):
        super(SimpleClassifier, self).__init__()
        # Map the configured activation name to its module; only ReLU is supported so far.
        activation_dict = {'relu': nn.ReLU()}
        try:
            activation_func = activation_dict[args.activation]
        except KeyError:
            raise ValueError(args.activation + " is not supported yet!")
        # Linear -> activation -> dropout -> linear, producing unnormalized class scores.
        layers = [
            weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
            activation_func,
            nn.Dropout(args.dropout, inplace=True),
            weight_norm(nn.Linear(hid_dim, out_dim), dim=None)
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.main(x)
        return logits
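

# Below is a minimal usage sketch (not part of the original file): the feature
# dimensions and the `activation`/`dropout` fields assumed on `args` are
# illustrative placeholders matching what SimpleClassifier reads from `args`.
if __name__ == "__main__":
    import torch
    from argparse import Namespace

    args = Namespace(activation='relu', dropout=0.5)
    # e.g. a fused feature of size 1024 mapped to 3000 answer classes
    classifier = SimpleClassifier(in_dim=1024, hid_dim=2048, out_dim=3000, args=args)
    features = torch.randn(4, 1024)   # a batch of 4 feature vectors
    logits = classifier(features)
    print(logits.shape)               # expected: torch.Size([4, 3000])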