
Commit

Import clean version of my code
pavelgonchar committed Apr 19, 2016
0 parents commit 1a251a3
Showing 13 changed files with 378 additions and 0 deletions.
20 changes: 20 additions & 0 deletions README.md
@@ -0,0 +1,20 @@
# colornet
A neural network that colorizes grayscale images.

Results
-------

|               Grayscale               |               Prediction               |            Ground Truth           |
|---|---|---|

![grayscale-pred-groundtruth](summary/3000_0.png?raw=true "grayscale-pred-groundtruth-3000")

![grayscale-pred-groundtruth](summary/7000_0.png?raw=true "grayscale-pred-groundtruth-7000")


Sources
-------
- [Automatic Colorization](http://tinyclouds.org/colorize/)
- [Hypercolumns for Object Segmentation and Fine-grained Localization](http://arxiv.org/pdf/1411.5752v2.pdf)
- [ILSVRC-2014 model (VGG team) with 16 weight layers](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md) and [Tensorflow version](https://github.com/ry/tensorflow-vgg16)
- [YUV from Wikipedia](https://en.wikipedia.org/wiki/YUV) (the RGB→YUV mapping used in `train.py`; see the sketch below)
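
The network works in YUV space: it keeps the luminance (Y) of the input and predicts only the chrominance (U, V). As a reference, here is a minimal NumPy sketch of the per-pixel RGB→YUV mapping, using the same coefficients as `rgb2yuv()` in `train.py`. It is not part of this commit; the names `RGB2YUV`, `BIAS`, and `rgb_to_yuv` are illustrative.

```python
import numpy as np

# Rows index the input channel (R, G, B); columns index the output (Y, U, V).
# Same coefficients as the 1x1 convolution filter in rgb2yuv() in train.py.
RGB2YUV = np.array([[0.299, -0.169, 0.499],
                    [0.587, -0.331, -0.418],
                    [0.114, 0.499, -0.0813]])
BIAS = np.array([0.0, 0.5, 0.5])  # shift U and V into [0, 1]

def rgb_to_yuv(pixel):
    """pixel: length-3 array of R, G, B values scaled to [0, 1]."""
    return pixel.dot(RGB2YUV) + BIAS

print(rgb_to_yuv(np.array([1.0, 0.0, 0.0])))  # pure red -> Y~0.30, U~0.33, V~1.00
```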
65 changes: 65 additions & 0 deletions batchnorm.py
@@ -0,0 +1,65 @@
"""A helper class for managing batch normalization state.
This class is designed to simplify adding batch normalization
(http://arxiv.org/pdf/1502.03167v3.pdf) to your model by
managing the state variables associated with it.
Important use note: The function get_assigner() returns
an op that must be executed to save the updated state.
A suggested way to do this is to make execution of the
model optimizer force it, e.g., by:
update_assignments = tf.group(bn1.get_assigner(),
bn2.get_assigner())
with tf.control_dependencies([optimizer]):
optimizer = tf.group(update_assignments)
"""

import tensorflow as tf


class ConvolutionalBatchNormalizer(object):
    """Helper class that groups the normalization logic and variables.

    Use:
        ewma = tf.train.ExponentialMovingAverage(decay=0.99)
        bn = ConvolutionalBatchNormalizer(depth, 0.001, ewma, True)
        update_assignments = bn.get_assigner()
        x = bn.normalize(y, train=training?)
    (the output x will be batch-normalized).
    """

    def __init__(self, depth, epsilon, ewma_trainer, scale_after_norm):
        self.mean = tf.Variable(tf.constant(0.0, shape=[depth]),
                                trainable=False)
        self.variance = tf.Variable(tf.constant(1.0, shape=[depth]),
                                    trainable=False)
        self.beta = tf.Variable(tf.constant(0.0, shape=[depth]))
        self.gamma = tf.Variable(tf.constant(1.0, shape=[depth]))
        self.ewma_trainer = ewma_trainer
        self.epsilon = epsilon
        self.scale_after_norm = scale_after_norm

    def get_assigner(self):
        """Returns an EWMA apply op that must be invoked after optimization."""
        return self.ewma_trainer.apply([self.mean, self.variance])

    def normalize(self, x, train=True):
        """Returns a batch-normalized version of x."""
        if train:
            mean, variance = tf.nn.moments(x, [0, 1, 2])
            assign_mean = self.mean.assign(mean)
            assign_variance = self.variance.assign(variance)
            with tf.control_dependencies([assign_mean, assign_variance]):
                return tf.nn.batch_norm_with_global_normalization(
                    x, mean, variance, self.beta, self.gamma,
                    self.epsilon, self.scale_after_norm)
        else:
            mean = self.ewma_trainer.average(self.mean)
            variance = self.ewma_trainer.average(self.variance)
            local_beta = tf.identity(self.beta)
            local_gamma = tf.identity(self.gamma)
            return tf.nn.batch_norm_with_global_normalization(
                x, mean, variance, local_beta, local_gamma,
                self.epsilon, self.scale_after_norm)
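
For context, a minimal end-to-end sketch of the usage pattern described in the module docstring above. It is not part of this commit: the placeholder names, layer shapes, and learning rate are illustrative, and it assumes the same TensorFlow 0.x API used in `train.py`.

```python
import tensorflow as tf
from batchnorm import ConvolutionalBatchNormalizer

# Illustrative shapes only: one 3x3 conv layer with 16 output channels.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
labels = tf.placeholder(tf.float32, [None, 224, 224, 16])
w = tf.Variable(tf.truncated_normal([3, 3, 3, 16], stddev=0.01))
conv = tf.nn.conv2d(images, w, [1, 1, 1, 1], 'SAME')

ewma = tf.train.ExponentialMovingAverage(decay=0.99)
bn = ConvolutionalBatchNormalizer(16, 0.001, ewma, True)
update_assignments = bn.get_assigner()      # EWMA update op for mean/variance
normalized = bn.normalize(conv, train=True)

loss = tf.reduce_mean(tf.square(normalized - labels))
optimizer = tf.train.GradientDescentOptimizer(0.0001).minimize(loss)

# Force the EWMA update to run whenever the optimizer runs,
# as suggested in the module docstring.
with tf.control_dependencies([optimizer]):
    train_op = tf.group(update_assignments)
```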
Binary file added summary/1000_0.png
Binary file added summary/2000_0.png
Binary file added summary/3000_0.png
Binary file added summary/4000_0.png
Binary file added summary/5000_0.png
Binary file added summary/6000_0.png
Binary file added summary/7000_0.png
Binary file added summary/8000_0.png
Binary file added summary/9000_0.png
293 changes: 293 additions & 0 deletions train.py
@@ -0,0 +1,293 @@
import tensorflow as tf
import numpy as np
import glob
import sys
from matplotlib import pyplot as plt
from batchnorm import ConvolutionalBatchNormalizer

filenames = sorted(glob.glob("../colornet/*/*.jpg"))
batch_size = 1
num_epochs = 1e+9

global_step = tf.Variable(0, name='global_step', trainable=False)
phase_train = tf.placeholder(tf.bool, name='phase_train')
uv = tf.placeholder(tf.uint8, name='uv')


def read_my_file_format(filename_queue, randomize=False):
    reader = tf.WholeFileReader()
    key, file = reader.read(filename_queue)
    uint8image = tf.image.decode_jpeg(file, channels=3)
    uint8image = tf.random_crop(uint8image, (224, 224, 3))
    if randomize:
        uint8image = tf.image.random_flip_left_right(uint8image)
        uint8image = tf.image.random_flip_up_down(uint8image, seed=None)
    float_image = tf.div(tf.cast(uint8image, tf.float32), 255)
    return float_image


def input_pipeline(filenames, batch_size, num_epochs=None):
    filename_queue = tf.train.string_input_producer(
        filenames, num_epochs=num_epochs, shuffle=False)
    example = read_my_file_format(filename_queue, randomize=False)
    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    example_batch = tf.train.shuffle_batch(
        [example], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return example_batch


def batch_norm(x, depth, phase_train):
    with tf.variable_scope('batchnorm'):
        ewma = tf.train.ExponentialMovingAverage(decay=0.9999)
        bn = ConvolutionalBatchNormalizer(depth, 0.001, ewma, True)
        update_assignments = bn.get_assigner()
        x = bn.normalize(x, train=phase_train)
    return x


def conv2d(_X, w, sigmoid=False, bn=False):
    with tf.variable_scope('conv2d'):
        _X = tf.nn.conv2d(_X, w, [1, 1, 1, 1], 'SAME')
        if bn:
            _X = batch_norm(_X, w.get_shape()[3], phase_train)
        if sigmoid:
            return tf.sigmoid(_X)
        else:
            _X = tf.nn.relu(_X)
            return tf.maximum(0.01 * _X, _X)


def colornet(_tensors):
    """
    Network architecture http://tinyclouds.org/colorize/residual_encoder.png
    """
    with tf.variable_scope('colornet'):
        # Bx28x28x512 -> batch norm -> 1x1 conv = Bx28x28x256
        conv1 = tf.nn.relu(tf.nn.conv2d(
            batch_norm(_tensors["conv4_3"], 512, phase_train),
            _tensors["weights"]["wc1"], [1, 1, 1, 1], 'SAME'))
        # upscale to 56x56x256
        conv1 = tf.image.resize_bilinear(conv1, (56, 56))
        conv1 = tf.add(conv1, batch_norm(
            _tensors["conv3_3"], 256, phase_train))

        # Bx56x56x256 -> 3x3 conv = Bx56x56x128
        conv2 = conv2d(conv1, _tensors["weights"]['wc2'], sigmoid=False, bn=True)
        # upscale to 112x112x128
        conv2 = tf.image.resize_bilinear(conv2, (112, 112))
        conv2 = tf.add(conv2, batch_norm(
            _tensors["conv2_2"], 128, phase_train))

        # Bx112x112x128 -> 3x3 conv = Bx112x112x64
        conv3 = conv2d(conv2, _tensors["weights"]['wc3'], sigmoid=False, bn=True)
        # upscale to Bx224x224x64
        conv3 = tf.image.resize_bilinear(conv3, (224, 224))
        conv3 = tf.add(conv3, batch_norm(_tensors["conv1_2"], 64, phase_train))

        # Bx224x224x64 -> 3x3 conv = Bx224x224x3
        conv4 = conv2d(conv3, _tensors["weights"]['wc4'], sigmoid=False, bn=True)
        conv4 = tf.add(conv4, batch_norm(
            _tensors["grayscale"], 3, phase_train))

        # Bx224x224x3 -> 3x3 conv = Bx224x224x3
        conv5 = conv2d(conv4, _tensors["weights"]['wc5'], sigmoid=False, bn=True)
        # Bx224x224x3 -> 3x3 conv = Bx224x224x2
        conv6 = conv2d(conv5, _tensors["weights"]['wc6'], sigmoid=True, bn=True)

    return conv6


def concat_images(imga, imgb):
    """
    Combines two color image ndarrays side-by-side.
    """
    ha, wa = imga.shape[:2]
    hb, wb = imgb.shape[:2]
    max_height = np.max([ha, hb])
    total_width = wa + wb
    new_img = np.zeros(shape=(max_height, total_width, 3), dtype=np.float32)
    new_img[:ha, :wa] = imga
    new_img[:hb, wa:wa + wb] = imgb
    return new_img


def rgb2yuv(rgb):
    """
    Convert RGB image into YUV https://en.wikipedia.org/wiki/YUV
    """
    rgb2yuv_filter = tf.constant(
        [[[[0.299, -0.169, 0.499],
           [0.587, -0.331, -0.418],
           [0.114, 0.499, -0.0813]]]])
    rgb2yuv_bias = tf.constant([0., 0.5, 0.5])

    temp = tf.nn.conv2d(rgb, rgb2yuv_filter, [1, 1, 1, 1], 'SAME')
    temp = tf.nn.bias_add(temp, rgb2yuv_bias)

    return temp


def yuv2rgb(yuv):
    """
    Convert YUV image into RGB https://en.wikipedia.org/wiki/YUV
    """
    yuv = tf.mul(yuv, 255)
    yuv2rgb_filter = tf.constant(
        [[[[1., 1., 1.],
           [0., -0.34413999, 1.77199996],
           [1.40199995, -0.71414, 0.]]]])
    yuv2rgb_bias = tf.constant([-179.45599365, 135.45983887, -226.81599426])
    temp = tf.nn.conv2d(yuv, yuv2rgb_filter, [1, 1, 1, 1], 'SAME')
    temp = tf.nn.bias_add(temp, yuv2rgb_bias)
    temp = tf.maximum(temp, tf.zeros(temp.get_shape(), dtype=tf.float32))
    temp = tf.minimum(temp, tf.mul(
        tf.ones(temp.get_shape(), dtype=tf.float32), 255))
    temp = tf.div(temp, 255)
    return temp


with open("vgg/tensorflow-vgg16/vgg16-20160129.tfmodel", mode='rb') as f:
fileContent = f.read()

graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

with tf.variable_scope('colornet'):
    # Store layer weights
    weights = {
        # 1x1 conv, 512 inputs, 256 outputs
        'wc1': tf.Variable(tf.truncated_normal([1, 1, 512, 256], stddev=0.01)),
        # 3x3 conv, 256 inputs, 128 outputs
        'wc2': tf.Variable(tf.truncated_normal([3, 3, 256, 128], stddev=0.01)),
        # 3x3 conv, 128 inputs, 64 outputs
        'wc3': tf.Variable(tf.truncated_normal([3, 3, 128, 64], stddev=0.01)),
        # 3x3 conv, 64 inputs, 3 outputs
        'wc4': tf.Variable(tf.truncated_normal([3, 3, 64, 3], stddev=0.01)),
        # 3x3 conv, 3 inputs, 3 outputs
        'wc5': tf.Variable(tf.truncated_normal([3, 3, 3, 3], stddev=0.01)),
        # 3x3 conv, 3 inputs, 2 outputs
        'wc6': tf.Variable(tf.truncated_normal([3, 3, 3, 2], stddev=0.01)),
    }

colorimage = input_pipeline(filenames, batch_size, num_epochs=num_epochs)
colorimage_yuv = rgb2yuv(colorimage)

grayscale = tf.image.rgb_to_grayscale(colorimage)
grayscale_rgb = tf.image.grayscale_to_rgb(grayscale)
grayscale_yuv = rgb2yuv(grayscale_rgb)
grayscale = tf.concat(3, [grayscale, grayscale, grayscale])

tf.import_graph_def(graph_def, input_map={"images": grayscale})

graph = tf.get_default_graph()

with tf.variable_scope('vgg'):
    conv1_2 = graph.get_tensor_by_name("import/conv1_2/Relu:0")
    conv2_2 = graph.get_tensor_by_name("import/conv2_2/Relu:0")
    conv3_3 = graph.get_tensor_by_name("import/conv3_3/Relu:0")
    conv4_3 = graph.get_tensor_by_name("import/conv4_3/Relu:0")

tensors = {
    "conv1_2": conv1_2,
    "conv2_2": conv2_2,
    "conv3_3": conv3_3,
    "conv4_3": conv4_3,
    "grayscale": grayscale,
    "weights": weights
}

# Construct model
pred = colornet(tensors)
pred_yuv = tf.concat(3, [tf.split(3, 3, grayscale_yuv)[0], pred])
pred_rgb = yuv2rgb(pred_yuv)

loss = tf.square(tf.sub(pred, tf.concat(
    3, [tf.split(3, 3, colorimage_yuv)[1], tf.split(3, 3, colorimage_yuv)[2]])))

# Note: `uv` is a tf.placeholder, so these Python comparisons are resolved
# once at graph-construction time rather than per feed_dict value; here they
# fall through to the averaged U/V loss in the else branch.
if uv == 1:
    loss = tf.split(3, 2, loss)[0]
elif uv == 2:
    loss = tf.split(3, 2, loss)[1]
else:
    loss = (tf.split(3, 2, loss)[0] + tf.split(3, 2, loss)[1]) / 2

# Likewise, `phase_train` is a placeholder, so this branch is evaluated at
# graph-construction time and the optimizer op is always created.
if phase_train:
    optimizer = tf.train.GradientDescentOptimizer(0.0001)
    opt = optimizer.minimize(
        loss, global_step=global_step, gate_gradients=optimizer.GATE_NONE)

# Summaries
tf.histogram_summary("weights1", weights["wc1"])
tf.histogram_summary("weights2", weights["wc2"])
tf.histogram_summary("weights3", weights["wc3"])
tf.histogram_summary("weights4", weights["wc4"])
tf.histogram_summary("weights5", weights["wc5"])
tf.histogram_summary("weights6", weights["wc6"])
tf.histogram_summary("instant_loss", tf.reduce_mean(loss))
tf.image_summary("colorimage", colorimage, max_images=1)
tf.image_summary("pred_rgb", pred_rgb, max_images=1)
tf.image_summary("grayscale", grayscale_rgb, max_images=1)

# Saver.
saver = tf.train.Saver()

# Create the graph, etc.
init_op = tf.initialize_all_variables()

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables.
sess.run(init_op)

merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter("tb_log", sess.graph_def)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps, once for the U channel and once for V
        training_opt = sess.run(opt, feed_dict={phase_train: True, uv: 1})
        training_opt = sess.run(opt, feed_dict={phase_train: True, uv: 2})

        step = sess.run(global_step)

        if step % 1 == 0:
            pred_, pred_rgb_, colorimage_, grayscale_rgb_, cost, merged_ = sess.run(
                [pred, pred_rgb, colorimage, grayscale_rgb, loss, merged],
                feed_dict={phase_train: False, uv: 3})
            print {
                "step": step,
                "cost": np.mean(cost)
            }
            if step % 1000 == 0:
                summary_image = concat_images(grayscale_rgb_[0], pred_rgb_[0])
                summary_image = concat_images(summary_image, colorimage_[0])
                plt.imsave("summary/" + str(step) + "_0", summary_image)

            sys.stdout.flush()
            writer.add_summary(merged_, step)
            writer.flush()
            if step % 100000 == 99998:
                save_path = saver.save(sess, "model.ckpt")
                print("Model saved in file: %s" % save_path)
                sys.stdout.flush()

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()
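
For readers skimming the graph construction above: the network predicts only the two chrominance channels, and the colorized output keeps the luminance of the grayscale input (`pred_yuv = tf.concat(3, [tf.split(3, 3, grayscale_yuv)[0], pred])`). A minimal NumPy sketch of that recombination step for a single image; it is not part of this commit, and the function and argument names are illustrative.

```python
import numpy as np

def recombine(grayscale_yuv, predicted_uv):
    """grayscale_yuv: HxWx3 YUV image of the grayscale input;
    predicted_uv: HxWx2 chrominance predicted by colornet.
    Keeps the input's Y channel and appends the predicted U and V,
    mirroring the pred_yuv concatenation in train.py for one image."""
    y = grayscale_yuv[:, :, :1]                         # luminance from the input
    return np.concatenate([y, predicted_uv], axis=2)    # HxWx3 YUV result
```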
Binary file not shown.
