Enable batch fetching in parallel #748

Draft · wants to merge 4 commits into base: master
Changes from 2 commits
45 changes: 43 additions & 2 deletions petastorm/pytorch.py
@@ -16,7 +16,9 @@
import decimal
# Must import pyarrow before torch. See: https://github.com/uber/petastorm/blob/master/docs/troubleshoot.rst
import re
import threading
import logging
from queue import Queue
import numpy as np
from six import PY2
from torch.utils.data.dataloader import default_collate
@@ -100,11 +102,38 @@ def decimal_friendly_collate(batch):
loader."


class BackgroundIterator(threading.Thread):
    """Prefetch iterator results. A thread iterates the original iterator and
    populates a queue. Iterating over this background iterator just consumes the underlying
    queue until no other result is available."""
    def __init__(self, iterator, prefetch=1000):
Collaborator:
Is it really prefetching? Setting the queue size does not guarantee that we will prefetch the data before the user starts consuming it, does it? Perhaps we should call the argument 'queue_size'?
Frankly, I am not sure how prefetching helps steady-state throughput. Wouldn't it just eliminate some hiccups when the training starts, at the expense of training starting a bit later? Isn't steady-state throughput the only important characteristic here?

Author:
You are right, prefetching is probably not the right word. As discussed in #740, the main motivation of this PR is to enable parallel batch building while training a model (otherwise the model always has to wait for a batch to become available, and this may take some time, especially if the dataset has a large number of columns). I have observed a ~3x speedup in data throughput with this change.

Collaborator:
Sure - I see how this can speed up the training. This is a good change.
In my understanding we are not really doing prefetching here (depending on the timing, the consumer might try to fetch the first batch before the thread has populated it, i.e. nothing was prefetched).
If you are ok with just changing the name from prefetching to queue size, everything will fall into place.

Collaborator:
Having a default of 1000 batches for the queue size may be a bit too much, given that a batch is a row group, and row groups of a couple of hundred MBs are common.

Author:
The loader's _iter_impl yields batches, not row groups, right? That is what is enqueued. It is true that depending on the queue size more or fewer row groups will be processed, but I expect this to be controlled via the queue size and the batch size.

Collaborator:
You are absolutely right.
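To put rough numbers on the memory point, an illustrative back-of-envelope, assuming each enqueued item is a single batch of batch_size rows (all values below are made-up examples, not measurements from this PR):

queue_size = 1000         # the default `prefetch` value in this diff
batch_size = 128          # example number of rows per batch
bytes_per_row = 4 * 1024  # example row width across all columns (~4 KB)
buffered_bytes = queue_size * batch_size * bytes_per_row
print(buffered_bytes / 2 ** 30)  # ~0.49 GiB held by a full queue with these example numbers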

        threading.Thread.__init__(self)
        self.queue = Queue(prefetch)
        self.iterator = iterator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.iterator:
            self.queue.put(item)
        self.queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item
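As a quick standalone illustration of the wrapper (class and argument names as in this diff; the generator is just an example source):

import time

def slow_source():
    for i in range(5):
        time.sleep(0.1)   # simulate expensive batch building
        yield i

# Items are built on the background thread while the consumer does other work;
# the bounded queue (here 2 slots) applies backpressure to the producer.
for item in BackgroundIterator(slow_source(), prefetch=2):
    print(item)           # prints 0..4 in order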


class LoaderBase(object):

    def __init__(self):
        self._in_iter = None
        self._error = None
        self._max_prefetch = 1

    def __iter__(self):
        if self._error is not None:
@@ -118,7 +147,10 @@ def __iter__(self):
        self._in_iter = True

        try:
-           for batch in self._iter_impl():
+           iterator = self._iter_impl()
+           if self._max_prefetch > 1:
+               iterator = BackgroundIterator(iterator, prefetch=self._max_prefetch)
Collaborator:
Will the BackgroundIterator be properly destroyed and the thread joined when we exit the function (either on a nominal exit or with an exception)?

Author:
I need to check this, thank you for the heads-up.

Collaborator:
I know it could be tricky, but it would be good to also cover these aspects in a unit test or two.
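A minimal sketch of such a test, covering only the nominal path where the iterator is fully exhausted (the early-exit and exception paths would need separate tests); the sizes are arbitrary example values:

def test_background_iterator_thread_exits_after_exhaustion():
    # BackgroundIterator is itself a threading.Thread, so it can be joined directly.
    it = BackgroundIterator(iter(range(10)), prefetch=4)
    assert list(it) == list(range(10))
    it.join(timeout=5)
    assert not it.is_alive()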

+           for batch in iterator:
                yield batch
        except Exception as e:
            self._error = e
@@ -264,7 +296,8 @@ class BatchedDataLoader(LoaderBase):

    def __init__(self, reader, batch_size=1,
                 transform_fn=None,
-                shuffling_queue_capacity=0):
+                shuffling_queue_capacity=0,
+                batch_max_prefetch=None):
        """
        Initializes a data loader object.

@@ -287,6 +320,9 @@ def __init__(self, reader, batch_size=1,
        :param transform_fn: an optional callable to convert batches from the reader to PyTorch tensors
        :param shuffling_queue_capacity: Queue capacity is passed to the underlying :class:`tf.RandomShuffleQueue`
            instance. If set to 0, no shuffling will be done.
        :param batch_max_prefetch: an optional int indicating the maximum number of batches to fetch in
            advance. This is especially useful when training models, in order to improve model data
            throughput.
        """
        super(BatchedDataLoader, self).__init__()
        self.reader = reader
@@ -298,6 +334,11 @@ def __init__(self, reader, batch_size=1,
        self.shuffling_queue_capacity = shuffling_queue_capacity
        self._in_iter = None

        # Optionally fetch batches in advance in a background thread.
        if batch_max_prefetch is not None:
            assert batch_max_prefetch > 0, "if set, batch_max_prefetch must be greater than or equal to 1"
            self._max_prefetch = batch_max_prefetch

    def _iter_impl(self):
        """
        The Data Loader iterator stops the for-loop when the reader runs out of samples.
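For reference, a minimal usage sketch of the new parameter, assuming a reader created with petastorm's make_batch_reader (the dataset URL and the sizes below are placeholders, not values from this PR):

from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

# batch_max_prefetch bounds the queue of ready batches built by the background thread.
with make_batch_reader('file:///tmp/example_dataset') as reader:
    loader = BatchedDataLoader(reader, batch_size=128, batch_max_prefetch=10)
    for batch in loader:
        pass  # feed `batch` to the training step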