Add support for Reading GIF files #46

Open · wants to merge 7 commits into master
4 changes: 4 additions & 0 deletions .gitignore
@@ -2,6 +2,7 @@
*.log
*.pyc
build/*.jar
.coverage

docs/_site
docs/api
@@ -18,6 +19,9 @@ src_managed/
project/boot/
project/plugins/project/

# spark
metastore_db

# intellij
.idea/

84 changes: 80 additions & 4 deletions python/sparkdl/image/imageIO.py
@@ -34,6 +34,10 @@
StructField("nChannels", IntegerType(), False),
StructField("data", BinaryType(), False)])

gifSchema = StructType([StructField("filePath", StringType(), False),
StructField("frameNum", IntegerType(), True),
StructField("gifFrame", imageSchema, True)])


# ImageType class for holding metadata about images stored in DataFrames.
# fields:
@@ -199,23 +203,77 @@ def _decodeImage(imageData):
    image = imageArrayToStruct(imgArray, mode.sparkMode)
    return image


def _decodeGif(gifData):
    """
    Decode compressed GIF data into a sequence of images.

    :param gifData: (bytes, bytearray) compressed GIF data in a PIL-compatible format.
    :return: list of tuples (frameNum, imageStruct), where frameNum is the
        zero-indexed frame number and imageStruct is a DataFrame Row matching
        imageSchema; [(None, None)] if the data cannot be decoded as a GIF.
    """
    try:
        img = Image.open(BytesIO(gifData))
    except IOError:
        return [(None, None)]

    if img.format.lower() == "gif":
        mode = pilModeLookup["RGB"]
    else:
        warn("Image file does not appear to be a GIF")
        return [(None, None)]

    frames = []
    i = 0
    mypalette = img.getpalette()
    try:
        while True:
            # Some GIFs only store a palette on the first frame; reuse it.
            if not img.getpalette() and mypalette:
                img.putpalette(mypalette)
            newImg = Image.new("RGB", img.size)
            newImg.paste(img)

            newImgArray = np.asarray(newImg)
            newImage = imageArrayToStruct(newImgArray, mode.sparkMode)
            frames.append((i, newImage))

            i += 1
            img.seek(img.tell() + 1)
    except EOFError:
        # seek() past the last frame raises EOFError: end of sequence
        pass

    return frames

# Creating a UDF on import can cause SparkContext issues sometimes.
# decodeImage = udf(_decodeImage, imageSchema)

def filesToRDD(sc, path, numPartitions=None):
    """
    Read files from a directory to an RDD.

    :param sc: SparkContext.
    :param path: str, path to files.
    :param numPartitions: int, number of partitions to use for reading files.
    :return: RDD of tuples: (filePath: str, fileData: bytearray)
    """
    numPartitions = numPartitions or sc.defaultParallelism
    rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions)
    return rdd.map(lambda x: (x[0], bytearray(x[1])))


def filesToDF(sc, path, numPartitions=None):
    """
    Read files from a directory to a DataFrame.

    :param sc: SparkContext.
    :param path: str, path to files.
-   :param numPartition: int, number or partitions to use for reading files.
+   :param numPartitions: int, number of partitions to use for reading files.
    :return: DataFrame, with columns: (filePath: str, fileData: BinaryType)
    """
    numPartitions = numPartitions or sc.defaultParallelism
    schema = StructType([StructField("filePath", StringType(), False),
                         StructField("fileData", BinaryType(), False)])
-   rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions)
-   rdd = rdd.map(lambda x: (x[0], bytearray(x[1])))
+   rdd = filesToRDD(sc, path, numPartitions)
    return rdd.toDF(schema)


@@ -235,3 +293,21 @@ def _readImages(imageDirectory, numPartition, sc):
    decodeImage = udf(_decodeImage, imageSchema)
    imageData = filesToDF(sc, imageDirectory, numPartitions=numPartition)
    return imageData.select("filePath", decodeImage("fileData").alias("image"))


def readGifs(gifDirectory, numPartition=None):
    """
    Read a directory of GIFs (or a single GIF) into a DataFrame.

    :param gifDirectory: str, file path.
    :param numPartition: int, number of partitions to use for reading files.
    :return: DataFrame, with columns: (filePath: str, frameNum: int, gifFrame: imageSchema).
    """
    return _readGifs(gifDirectory, numPartition, SparkContext.getOrCreate())


def _readGifs(gifDirectory, numPartition, sc):
    gifsRDD = filesToRDD(sc, gifDirectory, numPartitions=numPartition)
    framesRDD = gifsRDD.flatMap(lambda x: [(x[0], i, frame) for (i, frame) in _decodeGif(x[1])])
    return framesRDD.toDF(gifSchema)
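
For context, here is a minimal usage sketch of the new reader (an editor's illustration, not part of the diff; it assumes an active SparkSession and a hypothetical gifs/ directory of GIF files):

    from sparkdl.image import imageIO

    # One row per decoded frame, with the gifSchema columns:
    # filePath (str), frameNum (int), gifFrame (image struct).
    framesDF = imageIO.readGifs("gifs/", numPartition=4)

    # Frames that failed to decode come back as null structs, so filter
    # them out before touching the image data.
    validDF = framesDF.filter(framesDF.gifFrame.isNotNull())
    validDF.select("filePath", "frameNum", "gifFrame.height", "gifFrame.width").show()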
96 changes: 95 additions & 1 deletion python/tests/image/test_imageIO.py
@@ -26,7 +26,7 @@
from sparkdl.image import imageIO
from ..tests import SparkDLTestCase

-# Create dome fake image data to work with
+# Create some fake image data to work with
def create_image_data():
    # Random image-like data
    array = np.random.randint(0, 256, (10, 11, 3), 'uint8')
@@ -173,4 +173,98 @@ def test_filesTODF(self):
        self.assertEqual(type(first.fileData), bytearray)


# Create some fake GIF data to work with
def create_gif_data():
    # Random GIF-like data
    arrays2D = [np.random.randint(0, 256, (10, 11), 'uint8') for _ in range(3)]
    arrays3D = [np.dstack((a, a, a)) for a in arrays2D]
    # Create frames in P mode because Pillow always reads GIFs as P or L images
    frames = [PIL.Image.fromarray(a, mode='P') for a in arrays2D]

    # Compress as GIF
    gifFile = BytesIO()
    frames[0].save(gifFile, 'gif', save_all=True, append_images=frames[1:], optimize=False)
    gifFile.seek(0)

    # Get GIF data as stream
    gifData = gifFile.read()
    return arrays3D, gifData

gifArray, gifData = create_gif_data()
frameArray = gifArray[0]


class BinaryGifFilesMock(object):

    defaultParallelism = 4

    def __init__(self, sc):
        self.sc = sc

    def binaryFiles(self, path, minPartitions=None):
        gifsData = [["file/path", gifData],
                    ["another/file/path", gifData],
                    ["bad/gif", b"badGifData"]]
        rdd = self.sc.parallelize(gifsData)
        if minPartitions is not None:
            rdd = rdd.repartition(minPartitions)
        return rdd


class TestReadGifs(SparkDLTestCase):
    @classmethod
    def setUpClass(cls):
        super(TestReadGifs, cls).setUpClass()
        cls.binaryFilesMock = BinaryGifFilesMock(cls.sc)

    @classmethod
    def tearDownClass(cls):
        super(TestReadGifs, cls).tearDownClass()
        cls.binaryFilesMock = None

    def test_decodeGif(self):
        badFrames = imageIO._decodeGif(b"xxx")
        self.assertEqual(badFrames, [(None, None)])
        gifFrames = imageIO._decodeGif(gifData)
        self.assertIsNotNone(gifFrames)
        self.assertEqual(len(gifFrames), 3)
        self.assertEqual(len(gifFrames[0][1]), len(imageIO.imageSchema.names))
        # Accessing every schema field should not raise on a decoded frame.
        for n in imageIO.imageSchema.names:
            gifFrames[0][1][n]

    def test_gif_round_trip(self):
        # Test round trip: array -> GIF frame -> sparkImg -> array
        rdd = self.session.sparkContext.parallelize([bytearray(gifData)])

        # Convert to GIF frames
        framesRDD = rdd.flatMap(lambda x: [f[1] for f in imageIO._decodeGif(x)])
        framesDF = framesRDD.toDF(imageIO.imageSchema)
        row = framesDF.first()

        testArray = imageIO.imageStructToArray(row)
        self.assertEqual(testArray.shape, frameArray.shape)
        self.assertEqual(testArray.dtype, frameArray.dtype)
        self.assertTrue(np.all(frameArray == testArray))

    def test_readGifs(self):
        # Test that reading a directory of GIFs yields one row per frame.
        gifDF = imageIO._readGifs("some/path", 2, self.binaryFilesMock)
        self.assertTrue("filePath" in gifDF.schema.names)
        self.assertTrue("frameNum" in gifDF.schema.names)
        self.assertTrue("gifFrame" in gifDF.schema.names)

        # The DF should have 6 frames (2 GIFs, 3 frames each) and 1 null for the bad GIF.
        self.assertEqual(gifDF.count(), 7)
        validGifs = gifDF.filter(col("gifFrame").isNotNull())
        self.assertEqual(validGifs.count(), 6)

        frame = validGifs.first().gifFrame
        self.assertEqual(frame.height, frameArray.shape[0])
        self.assertEqual(frame.width, frameArray.shape[1])
        self.assertEqual(imageIO.imageType(frame).nChannels, frameArray.shape[2])
        self.assertEqual(frame.data, frameArray.tobytes())


# TODO: make unit tests for arrayToImageRow on arrays of varying shapes, channels, dtypes.
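
As a side note on the frame-iteration pattern that _decodeGif relies on, here is a standalone sketch (an editor's illustration using plain Pillow and NumPy, no Spark; the in-memory GIF exists only for demonstration):

    import numpy as np
    from io import BytesIO
    from PIL import Image

    # Build a tiny two-frame GIF in memory, purely for illustration.
    frames = [Image.fromarray(np.random.randint(0, 256, (4, 4), 'uint8')) for _ in range(2)]
    buf = BytesIO()
    frames[0].save(buf, 'gif', save_all=True, append_images=frames[1:])
    buf.seek(0)

    # Pillow exposes GIF frames through seek(); stepping past the last
    # frame raises EOFError, which is how _decodeGif ends its loop.
    img = Image.open(buf)
    nFrames = 1
    try:
        while True:
            img.seek(img.tell() + 1)  # advance to the next frame
            nFrames += 1
    except EOFError:
        pass
    assert nFrames == 2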