forked from pavelgonchar/colornet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1a251a3
Showing
13 changed files
with
378 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# colornet
Neural Network to colorize grayscale images | ||
|
||
Results | ||
------- | ||
|
||
| Grayscale | Prediction | Ground Truth | | ||
|---|---|---| | ||
|
||
![grayscale-pred-groundtruth](summary/3000_0.png?raw=true "grayscale-pred-groundtruth-3000") | ||
|
||
![grayscale-pred-groundtruth](summary/7000_0.png?raw=true "grayscale-pred-groundtruth-7000") | ||
|
||
|
||
Sources | ||
------- | ||
- [Automatic Colorization](http://tinyclouds.org/colorize/) | ||
- [Hypercolumns for Object Segmentation and Fine-grained Localization](http://arxiv.org/pdf/1411.5752v2.pdf) | ||
- [ILSVRC-2014 model (VGG team) with 16 weight layers](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md) and [Tensorflow version](https://github.com/ry/tensorflow-vgg16) | ||
- [YUV from Wikipedia](https://en.wikipedia.org/wiki/YUV) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""A helper class for managing batch normalization state. | ||
This class is designed to simplify adding batch normalization | ||
(http://arxiv.org/pdf/1502.03167v3.pdf) to your model by | ||
managing the state variables associated with it. | ||
Important use note: The function get_assigner() returns | ||
an op that must be executed to save the updated state. | ||
A suggested way to do this is to make execution of the | ||
model optimizer force it, e.g., by: | ||
update_assignments = tf.group(bn1.get_assigner(), | ||
bn2.get_assigner()) | ||
with tf.control_dependencies([optimizer]): | ||
optimizer = tf.group(update_assignments) | ||
""" | ||
|
||
import tensorflow as tf | ||
|
||
|
||
class ConvolutionalBatchNormalizer(object):
    """Helper class that groups the normalization logic and variables.

    Use:
        ewma = tf.train.ExponentialMovingAverage(decay=0.99)
        bn = ConvolutionalBatchNormalizer(depth, 0.001, ewma, True)
        update_assignments = bn.get_assigner()
        x = bn.normalize(y, train=training?)
    (the output x will be batch-normalized).
    """

    def __init__(self, depth, epsilon, ewma_trainer, scale_after_norm):
        # Running batch statistics.  Marked non-trainable: they are updated
        # by direct assignment in normalize() and tracked by the EWMA, not
        # by the optimizer.
        self.mean = tf.Variable(tf.constant(0.0, shape=[depth]),
                                trainable=False)
        self.variance = tf.Variable(tf.constant(1.0, shape=[depth]),
                                    trainable=False)
        # Learnable shift (beta) and scale (gamma) parameters.
        self.beta = tf.Variable(tf.constant(0.0, shape=[depth]))
        self.gamma = tf.Variable(tf.constant(1.0, shape=[depth]))
        self.ewma_trainer = ewma_trainer
        self.epsilon = epsilon
        self.scale_after_norm = scale_after_norm

    def get_assigner(self):
        """Returns an EWMA apply op that must be invoked after optimization."""
        return self.ewma_trainer.apply([self.mean, self.variance])

    def normalize(self, x, train=True):
        """Returns a batch-normalized version of x.

        NOTE(review): `train` is evaluated with Python truthiness at graph
        construction time; passing a tf.bool placeholder (as colornet.py
        does) always builds the training branch — confirm callers intend
        a Python bool here.
        """
        if train:
            # Per-channel statistics over the batch, height and width axes.
            mean, variance = tf.nn.moments(x, [0, 1, 2])
            assign_mean = self.mean.assign(mean)
            assign_variance = self.variance.assign(variance)
            # Force the assignments to run before normalization so the
            # tracked variables always hold the latest batch statistics.
            with tf.control_dependencies([assign_mean, assign_variance]):
                return tf.nn.batch_norm_with_global_normalization(
                    x, mean, variance, self.beta, self.gamma,
                    self.epsilon, self.scale_after_norm)
        else:
            # Inference path: use the moving averages accumulated by the
            # EWMA trainer instead of per-batch statistics.
            mean = self.ewma_trainer.average(self.mean)
            variance = self.ewma_trainer.average(self.variance)
            local_beta = tf.identity(self.beta)
            local_gamma = tf.identity(self.gamma)
            return tf.nn.batch_norm_with_global_normalization(
                x, mean, variance, local_beta, local_gamma,
                self.epsilon, self.scale_after_norm)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,293 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
import glob | ||
import sys | ||
from matplotlib import pyplot as plt | ||
from batchnorm import ConvolutionalBatchNormalizer | ||
|
||
# Every JPEG one directory up, in any subfolder, in a stable order.
filenames = sorted(glob.glob("../colornet/*/*.jpg"))
batch_size = 1
num_epochs = 1e+9  # effectively "train forever"

# Step counter incremented by the optimizer's minimize() call below.
global_step = tf.Variable(0, name='global_step', trainable=False)
# NOTE(review): both placeholders are branched on with plain Python `if`
# at graph-construction time further below — confirm that is intended.
phase_train = tf.placeholder(tf.bool, name='phase_train')
uv = tf.placeholder(tf.uint8, name='uv')
|
||
|
||
def read_my_file_format(filename_queue, randomize=False):
    """Read one JPEG from the queue and return a float32 image in [0, 1].

    The file is decoded to 3 channels and randomly cropped to 224x224.
    When `randomize` is set, the crop is additionally flipped at random
    both left-right and up-down.
    """
    file_reader = tf.WholeFileReader()
    _, raw_contents = file_reader.read(filename_queue)
    image = tf.image.decode_jpeg(raw_contents, channels=3)
    image = tf.random_crop(image, (224, 224, 3))
    if randomize:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image, seed=None)
    # Rescale uint8 [0, 255] -> float32 [0, 1].
    return tf.div(tf.cast(image, tf.float32), 255)
|
||
|
||
def input_pipeline(filenames, batch_size, num_epochs=None):
    """Build a shuffled batch tensor over `filenames`.

    Files are enqueued in order (shuffle=False); the shuffling happens in
    the batch queue itself via min_after_dequeue.
    """
    queue = tf.train.string_input_producer(
        filenames, num_epochs=num_epochs, shuffle=False)
    single_example = read_my_file_format(queue, randomize=False)
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    return tf.train.shuffle_batch(
        [single_example], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
|
||
|
||
def batch_norm(x, depth, phase_train):
    """Batch-normalize `x` (a ...x`depth` feature map).

    Thin wrapper around ConvolutionalBatchNormalizer.

    Bug fix: the original created `update_assignments` (the EWMA apply op)
    and then discarded it, so the moving averages used for inference were
    never updated.  The op is now attached as a control dependency of the
    returned tensor, so it runs whenever the normalized output is
    evaluated.
    """
    with tf.variable_scope('batchnorm'):
        ewma = tf.train.ExponentialMovingAverage(decay=0.9999)
        bn = ConvolutionalBatchNormalizer(depth, 0.001, ewma, True)
        update_assignments = bn.get_assigner()
        x = bn.normalize(x, train=phase_train)
        # Ensure the EWMA update actually executes (previously a no-op).
        with tf.control_dependencies([update_assignments]):
            x = tf.identity(x)
    return x
|
||
|
||
def conv2d(_X, w, sigmoid=False, bn=False):
    """Stride-1, SAME-padded 2-D convolution with optional BN and activation.

    Args:
        _X: input feature map (NHWC).
        w: filter variable; its last dimension gives the output depth.
        sigmoid: when True, apply a sigmoid activation (used for the final
            chroma-prediction layer); otherwise apply ReLU.
        bn: when True, batch-normalize the convolution output first.
    """
    with tf.variable_scope('conv2d'):
        _X = tf.nn.conv2d(_X, w, [1, 1, 1, 1], 'SAME')
        if bn:
            _X = batch_norm(_X, w.get_shape()[3], phase_train)
        if sigmoid:
            return tf.sigmoid(_X)
        else:
            # Bug fix (dead code): the original applied tf.nn.relu and then
            # tf.maximum(0.01 * x, x).  After ReLU x >= 0, so the maximum
            # always returns x and the "leaky" term could never fire; the
            # composite is exactly ReLU.  Return ReLU directly — behavior
            # is identical, without the misleading leaky-ReLU expression.
            return tf.nn.relu(_X)
|
||
|
||
def colornet(_tensors):
    """Residual encoder predicting the two UV (chroma) channels.

    Network architecture http://tinyclouds.org/colorize/residual_encoder.png

    `_tensors` maps the VGG activations ("conv1_2".."conv4_3"), the
    3-channel "grayscale" input, and the "weights" dict to tensors.  Each
    stage batch-norms a VGG feature map, adds it to the bilinearly
    upsampled previous stage, and convolves down to fewer channels; the
    final sigmoid layer emits Bx224x224x2 values in (0, 1).
    """
    with tf.variable_scope('colornet'):
        # Bx28x28x512 -> batch norm -> 1x1 conv = Bx28x28x256
        conv1 = tf.nn.relu(tf.nn.conv2d(
            batch_norm(_tensors["conv4_3"], 512, phase_train),
            _tensors["weights"]["wc1"], [1, 1, 1, 1], 'SAME'))
        # upscale to 56x56x256
        conv1 = tf.image.resize_bilinear(conv1, (56, 56))
        conv1 = tf.add(conv1, batch_norm(
            _tensors["conv3_3"], 256, phase_train))

        # Bx56x56x256-> 3x3 conv = Bx56x56x128
        conv2 = conv2d(conv1, _tensors["weights"]['wc2'],
                       sigmoid=False, bn=True)
        # upscale to 112x112x128
        conv2 = tf.image.resize_bilinear(conv2, (112, 112))
        conv2 = tf.add(conv2, batch_norm(
            _tensors["conv2_2"], 128, phase_train))

        # Bx112x112x128 -> 3x3 conv = Bx112x112x64
        conv3 = conv2d(conv2, _tensors["weights"]['wc3'],
                       sigmoid=False, bn=True)
        # upscale to Bx224x224x64
        conv3 = tf.image.resize_bilinear(conv3, (224, 224))
        conv3 = tf.add(conv3, batch_norm(_tensors["conv1_2"], 64, phase_train))

        # Bx224x224x64 -> 3x3 conv = Bx224x224x3
        conv4 = conv2d(conv3, _tensors["weights"]['wc4'],
                       sigmoid=False, bn=True)
        # Residual connection from the replicated-grayscale input.
        conv4 = tf.add(conv4, batch_norm(
            _tensors["grayscale"], 3, phase_train))

        # Bx224x224x3 -> 3x3 conv = Bx224x224x3
        conv5 = conv2d(conv4, _tensors["weights"]['wc5'],
                       sigmoid=False, bn=True)
        # Bx224x224x3 -> 3x3 conv = Bx224x224x2 (sigmoid chroma output)
        conv6 = conv2d(conv5, _tensors["weights"]['wc6'],
                       sigmoid=True, bn=True)

    return conv6
|
||
|
||
def concat_images(imga, imgb):
    """Combines two color image ndarrays side-by-side.

    Returns a float32 canvas as tall as the taller input and as wide as
    both inputs together; any area not covered by either image is left
    zero (black).
    """
    height_a, width_a = imga.shape[:2]
    height_b, width_b = imgb.shape[:2]
    canvas = np.zeros(
        shape=(np.max([height_a, height_b]), width_a + width_b, 3),
        dtype=np.float32)
    canvas[:height_a, :width_a] = imga
    canvas[:height_b, width_a:width_a + width_b] = imgb
    return canvas
|
||
|
||
def rgb2yuv(rgb):
    """Convert RGB image into YUV https://en.wikipedia.org/wiki/YUV

    Implemented as a 1x1 convolution with the RGB->YUV matrix, followed
    by a +0.5 bias on the two chroma channels.
    """
    conversion_matrix = tf.constant(
        [[[[0.299, -0.169, 0.499],
           [0.587, -0.331, -0.418],
           [0.114, 0.499, -0.0813]]]])
    chroma_offset = tf.constant([0., 0.5, 0.5])

    yuv = tf.nn.conv2d(rgb, conversion_matrix, [1, 1, 1, 1], 'SAME')
    return tf.nn.bias_add(yuv, chroma_offset)
|
||
|
||
def yuv2rgb(yuv):
    """Convert YUV image into RGB https://en.wikipedia.org/wiki/YUV

    Expects values in [0, 1]; scales to [0, 255], applies the YUV->RGB
    matrix as a 1x1 convolution, clamps to the valid byte range, and
    rescales the result back into [0, 1].
    """
    scaled = tf.mul(yuv, 255)
    conversion_matrix = tf.constant(
        [[[[1., 1., 1.],
           [0., -0.34413999, 1.77199996],
           [1.40199995, -0.71414, 0.]]]])
    conversion_bias = tf.constant(
        [-179.45599365, 135.45983887, -226.81599426])
    rgb = tf.nn.conv2d(scaled, conversion_matrix, [1, 1, 1, 1], 'SAME')
    rgb = tf.nn.bias_add(rgb, conversion_bias)
    # Clamp to [0, 255] before normalizing back to [0, 1].
    lower = tf.zeros(rgb.get_shape(), dtype=tf.float32)
    upper = tf.mul(tf.ones(rgb.get_shape(), dtype=tf.float32), 255)
    rgb = tf.minimum(tf.maximum(rgb, lower), upper)
    return tf.div(rgb, 255)
|
||
|
||
# Load the pre-trained VGG16 model (serialized GraphDef protobuf).
with open("vgg/tensorflow-vgg16/vgg16-20160129.tfmodel", mode='rb') as f:
    fileContent = f.read()

graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

with tf.variable_scope('colornet'):
    # Store layers weight
    weights = {
        # 1x1 conv, 512 inputs, 256 outputs
        'wc1': tf.Variable(tf.truncated_normal([1, 1, 512, 256], stddev=0.01)),
        # 3x3 conv, 256 inputs, 128 outputs
        'wc2': tf.Variable(tf.truncated_normal([3, 3, 256, 128], stddev=0.01)),
        # 3x3 conv, 128 inputs, 64 outputs
        'wc3': tf.Variable(tf.truncated_normal([3, 3, 128, 64], stddev=0.01)),
        # 3x3 conv, 64 inputs, 3 outputs
        'wc4': tf.Variable(tf.truncated_normal([3, 3, 64, 3], stddev=0.01)),
        # 3x3 conv, 3 inputs, 3 outputs
        'wc5': tf.Variable(tf.truncated_normal([3, 3, 3, 3], stddev=0.01)),
        # 3x3 conv, 3 inputs, 2 outputs (the two predicted chroma channels)
        'wc6': tf.Variable(tf.truncated_normal([3, 3, 3, 2], stddev=0.01)),
    }
|
||
# Input pipeline: ground-truth RGB batch in [0, 1] plus its YUV form.
colorimage = input_pipeline(filenames, batch_size, num_epochs=num_epochs)
colorimage_yuv = rgb2yuv(colorimage)

# Grayscale variants: 1-channel, its 3-channel RGB replica, and the YUV
# of that replica (whose Y channel supplies luminance for the output).
grayscale = tf.image.rgb_to_grayscale(colorimage)
grayscale_rgb = tf.image.grayscale_to_rgb(grayscale)
grayscale_yuv = rgb2yuv(grayscale_rgb)
# Replicate the single luminance channel to 3 channels for VGG input.
grayscale = tf.concat(3, [grayscale, grayscale, grayscale])

# Feed the grayscale image through the pre-trained VGG16 graph.
tf.import_graph_def(graph_def, input_map={"images": grayscale})

graph = tf.get_default_graph()

# Intermediate VGG activations used as skip connections by colornet().
with tf.variable_scope('vgg'):
    conv1_2 = graph.get_tensor_by_name("import/conv1_2/Relu:0")
    conv2_2 = graph.get_tensor_by_name("import/conv2_2/Relu:0")
    conv3_3 = graph.get_tensor_by_name("import/conv3_3/Relu:0")
    conv4_3 = graph.get_tensor_by_name("import/conv4_3/Relu:0")

tensors = {
    "conv1_2": conv1_2,
    "conv2_2": conv2_2,
    "conv3_3": conv3_3,
    "conv4_3": conv4_3,
    "grayscale": grayscale,
    "weights": weights
}

# Construct model: predicted UV + input Y -> full YUV -> RGB prediction.
pred = colornet(tensors)
pred_yuv = tf.concat(3, [tf.split(3, 3, grayscale_yuv)[0], pred])
pred_rgb = yuv2rgb(pred_yuv)
|
||
# Per-pixel squared error between predicted UV and ground-truth UV.
loss = tf.square(tf.sub(pred, tf.concat(
    3, [tf.split(3, 3, colorimage_yuv)[1], tf.split(3, 3, colorimage_yuv)[2]])))

# NOTE(review): `uv` is a tf.placeholder, so these Python-level comparisons
# run once at graph-construction time, where a Tensor does not compare
# equal to an int — only the `else` branch is ever built and the `uv`
# value fed at run time is effectively ignored.  Confirm whether
# per-channel (U-only / V-only) losses were intended; that would need
# graph-level branching (e.g. tf.case) instead.
if uv == 1:
    loss = tf.split(3, 2, loss)[0]
elif uv == 2:
    loss = tf.split(3, 2, loss)[1]
else:
    loss = (tf.split(3, 2, loss)[0] + tf.split(3, 2, loss)[1]) / 2

# NOTE(review): `phase_train` is a Tensor and therefore always truthy in a
# Python `if`, so the optimizer is unconditionally constructed.
if phase_train:
    optimizer = tf.train.GradientDescentOptimizer(0.0001)
    opt = optimizer.minimize(
        loss, global_step=global_step, gate_gradients=optimizer.GATE_NONE)
|
||
# Summaries (legacy pre-1.0 TensorFlow summary API).
tf.histogram_summary("weights1", weights["wc1"])
tf.histogram_summary("weights2", weights["wc2"])
tf.histogram_summary("weights3", weights["wc3"])
tf.histogram_summary("weights4", weights["wc4"])
tf.histogram_summary("weights5", weights["wc5"])
tf.histogram_summary("weights6", weights["wc6"])
tf.histogram_summary("instant_loss", tf.reduce_mean(loss))
tf.image_summary("colorimage", colorimage, max_images=1)
tf.image_summary("pred_rgb", pred_rgb, max_images=1)
tf.image_summary("grayscale", grayscale_rgb, max_images=1)

# Saver.
saver = tf.train.Saver()

# Create the graph, etc.
init_op = tf.initialize_all_variables()

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables.
sess.run(init_op)

# Merge all summaries into one op and write TensorBoard logs to tb_log/.
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter("tb_log", sess.graph_def)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
|
||
try:
    while not coord.should_stop():
        # Run training steps: one optimizer step with uv=1 then one with
        # uv=2 (see the NOTE at the loss definition — the fed `uv` value
        # does not actually change the built graph).
        training_opt = sess.run(opt, feed_dict={phase_train: True, uv: 1})
        training_opt = sess.run(opt, feed_dict={phase_train: True, uv: 2})

        step = sess.run(global_step)

        # `step % 1 == 0` is always true: log on every iteration.
        if step % 1 == 0:
            pred_, pred_rgb_, colorimage_, grayscale_rgb_, cost, merged_ = sess.run(
                [pred, pred_rgb, colorimage, grayscale_rgb, loss, merged],
                feed_dict={phase_train: False, uv: 3})
            print {
                "step": step,
                "cost": np.mean(cost)
            }
            # Every 1000 steps save a grayscale|prediction|ground-truth
            # strip under summary/.
            if step % 1000 == 0:
                summary_image = concat_images(grayscale_rgb_[0], pred_rgb_[0])
                summary_image = concat_images(summary_image, colorimage_[0])
                plt.imsave("summary/" + str(step) + "_0", summary_image)

            sys.stdout.flush()
            writer.add_summary(merged_, step)
            writer.flush()
            # Periodic checkpoint.  NOTE(review): each loop iteration runs
            # the optimizer twice, so `step` advances by 2 and the odd
            # remainder 99998 fires once per 100k steps — confirm intent.
            if step % 100000 == 99998:
                save_path = saver.save(sess, "model.ckpt")
                print("Model saved in file: %s" % save_path)
                sys.stdout.flush()

except tf.errors.OutOfRangeError:
    # Raised by the input queue once num_epochs is exhausted.
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()
Binary file not shown.