-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathimage_history_buffer.py
55 lines (47 loc) · 2.3 KB
/
image_history_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
Module implementing the image history buffer described in `2.3. Updating Discriminator using a History of
Refined Images` of https://arxiv.org/pdf/1612.07828v1.pdf.
"""
import numpy as np
class ImageHistoryBuffer(object):
    """
    Bounded pool of previously generated images.

    New images are appended until the pool holds `max_size` images; after that, new images overwrite existing
    ones. The pool is shuffled in place after every insertion, so the leading entries returned by
    `get_from_image_history_buffer` are a random selection of the stored images.
    """

    def __init__(self, shape, max_size, batch_size):
        """
        Initialize the class's state.

        :param shape: Shape of the data to be stored in the image history buffer
                      (i.e. (0, img_height, img_width, img_channels)).
        :param max_size: Maximum number of images that can be stored in the image history buffer.
        :param batch_size: Batch size used to train GAN.
        """
        self.image_history_buffer = np.zeros(shape=shape)
        self.max_size = max_size
        self.batch_size = batch_size

    def add_to_image_history_buffer(self, images, nb_to_add=None):
        """
        To be called during training of GAN. By default add batch_size // 2 images to the image history buffer each
        time the generator generates a new batch of images.

        While the buffer has free capacity the new images are appended; any images that do not fit replace
        entries at the front of the (shuffled) buffer. The buffer is shuffled in place afterwards.

        :param images: Array of images (usually a batch) to be added to the image history buffer.
        :param nb_to_add: The number of images from `images` to add to the image history buffer
                          (batch_size // 2 when omitted or falsy).
        :raises AssertionError: If the buffer already holds more than `max_size` images.
        """
        if not nb_to_add:
            nb_to_add = self.batch_size // 2

        new_images = images[:nb_to_add]
        free_slots = self.max_size - len(self.image_history_buffer)
        if free_slots < 0:
            # Raised explicitly (rather than `assert False`) so the check survives `python -O`.
            raise AssertionError('image history buffer holds more than max_size images')

        to_append = new_images[:free_slots]
        to_replace = new_images[free_slots:]

        if len(to_append):
            # np.append returns a new array and leaves its input untouched, so the result must be stored.
            self.image_history_buffer = np.append(self.image_history_buffer, to_append, axis=0)

        # Anything that did not fit overwrites stored images; never write more rows than the buffer has.
        nb_to_replace = min(len(to_replace), len(self.image_history_buffer))
        if nb_to_replace:
            self.image_history_buffer[:nb_to_replace] = to_replace[:nb_to_replace]

        # Shuffles along the first axis only, i.e. reorders whole images.
        np.random.shuffle(self.image_history_buffer)

    def get_from_image_history_buffer(self, nb_to_get=None):
        """
        Get a random sample of images from the history buffer.

        :param nb_to_get: Number of images to get from the image history buffer (batch_size // 2 when omitted
                          or falsy).
        :return: Up to `nb_to_get` images from the front of the shuffled image history buffer, or an array with
                 zero images if the image history buffer is empty.
        """
        if not nb_to_get:
            nb_to_get = self.batch_size // 2

        # Slicing never raises IndexError (an empty buffer simply yields zero images). A copy is returned
        # because later insertions modify and shuffle the buffer in place, which would alter a view.
        return self.image_history_buffer[:nb_to_get].copy()