-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathvit.py
147 lines (129 loc) · 4.97 KB
/
vit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.layers import (
Dense,
Dropout,
LayerNormalization,
)
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with three separate
    ``Dense(embed_dim)`` layers, split into ``num_heads`` heads of size
    ``embed_dim // num_heads``, attended per head, merged back and passed
    through a final ``Dense(embed_dim)`` projection.

    Args:
        embed_dim: Width of the query/key/value/output projections.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads=8, **kwargs):
        super(MultiHeadSelfAttention, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}"
            )
        self.projection_dim = embed_dim // num_heads
        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        self.combine_heads = Dense(embed_dim)

    def attention(self, query, key, value):
        """Return ``(output, weights)`` of scaled dot-product attention.

        ``weights`` is the softmax over the last axis of
        ``query @ key^T / sqrt(d_k)``, where ``d_k`` is the size of the last
        axis of ``key``; ``output`` is ``weights @ value``.
        """
        score = tf.matmul(query, key, transpose_b=True)
        # Cast the scale factor to the score's own dtype rather than a fixed
        # float32: TensorFlow does not auto-promote, so a float32 divisor
        # fails when the layer computes in float16/bfloat16.
        dim_key = tf.cast(tf.shape(key)[-1], score.dtype)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        """Reshape ``x`` to ``(batch_size, num_heads, -1, projection_dim)``.

        The last axis is split into ``(num_heads, projection_dim)`` and the
        head axis is moved in front of the sequence axis.
        """
        x = tf.reshape(
            x, (batch_size, -1, self.num_heads, self.projection_dim)
        )
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        """Apply self-attention to ``inputs``.

        Args:
            inputs: Tensor whose first axis is the batch; expected to be
                ``(batch, seq_len, features)`` -- the rank is not checked
                here.

        Returns:
            Tensor reshaped to ``(batch, -1, embed_dim)`` and passed through
            the output projection. The attention weights are computed but
            not returned.
        """
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)
        attention, _ = self.attention(query, key, value)
        # Undo the head/sequence swap done by separate_heads, then merge the
        # heads back into a single feature axis of size embed_dim.
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(
            attention, (batch_size, -1, self.embed_dim)
        )
        return self.combine_heads(concat_attention)
class TransformerBlock(tf.keras.layers.Layer):
    """Post-norm Transformer encoder block.

    Computes ``out1 = LayerNorm(inputs + Dropout(SelfAttention(inputs)))``
    followed by ``LayerNorm(out1 + Dropout(FFN(out1)))``, where the FFN is
    ``Dense(ff_dim, relu) -> Dense(embed_dim)``.

    Args:
        embed_dim: Feature width of the attention and FFN output.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout rate applied after attention and after the FFN.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1, **kwargs):
        super(TransformerBlock, self).__init__(**kwargs)
        self.att = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ffn = tf.keras.Sequential(
            [
                Dense(ff_dim, activation="relu"),
                Dense(embed_dim),
            ]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

    def call(self, inputs, training=None):
        """Run the block on ``inputs``.

        Args:
            inputs: Tensor added residually to the attention output, so its
                last axis must match ``embed_dim``.
            training: Forwarded to both dropout layers. Defaults to ``None``
                so the block can be called as ``block(x)``; ``None`` leaves
                the mode to Keras.

        Returns:
            Tensor produced by the second layer normalization.
        """
        attn_output = self.att(inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
class VisionTransformer(tf.keras.Model):
    """Vision Transformer classifier.

    Images are rescaled by ``1/255``, cut into non-overlapping
    ``patch_size x patch_size`` patches, linearly projected to ``d_model``,
    prefixed with a learned class token, offset by a learned position
    embedding, run through ``num_layers`` ``TransformerBlock`` layers, and
    classified from the class-token output by an MLP head.

    Args:
        image_size: Side length used to size the position embedding;
            ``(image_size // patch_size) ** 2`` patches are expected.
        patch_size: Side length of each square patch.
        num_layers: Number of encoder blocks.
        num_classes: Width of the final ``Dense`` layer.
        d_model: Token embedding width.
        num_heads: Attention heads per encoder block.
        mlp_dim: Hidden width of each block's FFN and of the MLP head.
        channels: Number of image channels (used to size flattened patches).
        dropout: Dropout rate for the encoder blocks and the MLP head.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(
        self,
        image_size,
        patch_size,
        num_layers,
        num_classes,
        d_model,
        num_heads,
        mlp_dim,
        channels=3,
        dropout=0.1,
        **kwargs,
    ):
        super(VisionTransformer, self).__init__(**kwargs)
        num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2
        self.patch_size = patch_size
        self.d_model = d_model
        self.num_layers = num_layers

        self.rescale = Rescaling(1.0 / 255)
        # Pass name/shape by keyword: the positional order of add_weight is
        # not the same in every Keras release, so a positional name is
        # fragile. One extra position is reserved for the class token.
        self.pos_emb = self.add_weight(
            name="pos_emb", shape=(1, num_patches + 1, d_model)
        )
        self.class_emb = self.add_weight(
            name="class_emb", shape=(1, 1, d_model)
        )
        self.patch_proj = Dense(d_model)
        self.enc_layers = [
            TransformerBlock(d_model, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ]
        self.mlp_head = tf.keras.Sequential(
            [
                Dense(mlp_dim, activation=tfa.activations.gelu),
                Dropout(dropout),
                Dense(num_classes),
            ]
        )

    def extract_patches(self, images):
        """Split ``images`` into flattened non-overlapping patches.

        Args:
            images: 4-D image batch as accepted by
                ``tf.image.extract_patches``.

        Returns:
            Tensor reshaped to ``(batch, -1, patch_dim)``.
        """
        batch_size = tf.shape(images)[0]
        # Stride equals the patch size, so patches do not overlap; "VALID"
        # padding drops any partial patch at the border.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        return tf.reshape(patches, [batch_size, -1, self.patch_dim])

    def call(self, x, training=None):
        """Compute class outputs for a batch of images.

        Args:
            x: Image batch; multiplied by ``1/255`` before patching
                (presumably pixel values in ``[0, 255]`` -- confirm against
                the input pipeline).
            training: Forwarded to every encoder block and to the MLP head.
                Defaults to ``None`` so the model can be called as
                ``model(x)``; ``None`` leaves the mode to Keras.

        Returns:
            Output of the MLP head applied to the class-token position; the
            last layer is ``Dense(num_classes)`` with no activation.
        """
        batch_size = tf.shape(x)[0]
        x = self.rescale(x)
        patches = self.extract_patches(x)
        x = self.patch_proj(patches)

        # Prepend one copy of the learned class token to every sequence.
        class_emb = tf.broadcast_to(
            self.class_emb, [batch_size, 1, self.d_model]
        )
        x = tf.concat([class_emb, x], axis=1)
        x = x + self.pos_emb

        for layer in self.enc_layers:
            x = layer(x, training=training)

        # First (class token) is used for classification. training is passed
        # explicitly because the head contains a Dropout layer.
        return self.mlp_head(x[:, 0], training=training)