Enable batch fetching in parallel #748
base: master
Changes from 2 commits
@@ -16,7 +16,9 @@
 import decimal
 # Must import pyarrow before torch. See: https://github.com/uber/petastorm/blob/master/docs/troubleshoot.rst
 import re
+import threading
 import logging
+from queue import Queue
 import numpy as np
 from six import PY2
 from torch.utils.data.dataloader import default_collate
@@ -100,11 +102,38 @@ def decimal_friendly_collate(batch):
 loader."


+class BackgroundIterator(threading.Thread):
+    """Prefetch iterator results. A thread iterates the original iterator and
+    populates a queue. Iterating over this background iterator just consumes the underlying
+    queue until no other result is available."""
+    def __init__(self, iterator, prefetch=1000):

Review comment: Having a default value of 1000 batches for the queue size may be a bit too much, given that a batch is a row group, and row groups of a couple of hundred MBs are common.

Review comment: Loader's …

Reply: You are absolutely right.
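To put the reviewer's sizing concern in rough numbers (illustrative figures, not from the PR): with the default prefetch=1000 and buffered items on the order of 200 MB each, a full queue could hold roughly 1000 × 200 MB ≈ 200 GB, which is why a smaller default or a user-supplied queue size is being discussed.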
+        threading.Thread.__init__(self)
+        self.queue = Queue(prefetch)
+        self.iterator = iterator
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        for item in self.iterator:
+            self.queue.put(item)
+        self.queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+
 class LoaderBase(object):

     def __init__(self):
         self._in_iter = None
         self._error = None
+        self._max_prefetch = 1

     def __iter__(self):
         if self._error is not None:

@@ -118,7 +147,10 @@ def __iter__(self):
         self._in_iter = True

         try:
-            for batch in self._iter_impl():
+            iterator = self._iter_impl()
+            if self._max_prefetch > 1:
+                iterator = BackgroundIterator(iterator, prefetch=self._max_prefetch)

Review comment: Will the …

Reply: Need to check this, thank you for the heads up.

Review comment: I know it could be tricky, but it would be good to also test these aspects in a unit test or two.

+            for batch in iterator:
                 yield batch
         except Exception as e:
             self._error = e
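For context, here is a minimal, self-contained sketch of the pattern the BackgroundIterator above implements; the slow_batch_source generator and the timing numbers are illustrative assumptions, not part of the PR. The daemon=True flag means the producer thread never keeps the interpreter alive on exit, which is relevant to the thread-termination question raised above.

```python
import threading
import time
from queue import Queue


class BackgroundIterator(threading.Thread):
    """Iterate over `iterator` on a background thread, buffering items in a bounded queue."""

    def __init__(self, iterator, prefetch=1000):
        threading.Thread.__init__(self)
        self.queue = Queue(prefetch)  # bounded: the producer blocks once `prefetch` items are waiting
        self.iterator = iterator
        self.daemon = True            # a daemon thread does not keep the interpreter alive on exit
        self.start()

    def run(self):
        for item in self.iterator:
            self.queue.put(item)
        self.queue.put(None)          # sentinel: the underlying iterator is exhausted

    def __iter__(self):
        return self

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item


def slow_batch_source(num_batches, build_time=0.05):
    """Stand-in for the reader: each batch takes `build_time` seconds to build."""
    for i in range(num_batches):
        time.sleep(build_time)
        yield i


if __name__ == '__main__':
    # Batch building overlaps with the consumer's "training step" instead of
    # running strictly after it, which is where the throughput gain comes from.
    start = time.time()
    for batch in BackgroundIterator(slow_batch_source(20), prefetch=10):
        time.sleep(0.05)  # simulated training step
    print('elapsed: %.2fs' % (time.time() - start))  # roughly 1s overlapped vs ~2s sequential
```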
@@ -264,7 +296,8 @@ class BatchedDataLoader(LoaderBase):

     def __init__(self, reader, batch_size=1,
                  transform_fn=None,
-                 shuffling_queue_capacity=0):
+                 shuffling_queue_capacity=0,
+                 batch_max_prefetch=None):
         """
         Initializes a data loader object.

@@ -287,6 +320,9 @@ def __init__(self, reader, batch_size=1,
         :param transform_fn: an optional callable to convert batches from the reader to PyTorch tensors
         :param shuffling_queue_capacity: Queue capacity is passed to the underlying :class:`tf.RandomShuffleQueue`
             instance. If set to 0, no shuffling will be done.
+        :param batch_max_prefetch: an optional int indicating the maximum number of batches to fetch in
+            advance. This is especially useful when training models, in order to improve model data
+            throughput.
         """
         super(BatchedDataLoader, self).__init__()
         self.reader = reader

@@ -298,6 +334,11 @@ def __init__(self, reader, batch_size=1,
         self.shuffling_queue_capacity = shuffling_queue_capacity
         self._in_iter = None

+        # fetch batches in advance?
+        if batch_max_prefetch is not None:
+            assert batch_max_prefetch > 0, "if set, batch_max_prefetch must be greater than or equal to 1"
+            self._max_prefetch = batch_max_prefetch
+
     def _iter_impl(self):
         """
         The Data Loader iterator stops the for-loop when reader runs out of samples.
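For illustration, a usage sketch of the new argument; the dataset URL and the numeric values below are placeholders, while make_batch_reader and BatchedDataLoader are the existing petastorm APIs this PR extends.

```python
# Hypothetical usage of the batch_max_prefetch argument proposed in this PR.
# The dataset URL, batch size, and prefetch value are placeholders.
from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

with make_batch_reader('file:///tmp/my_parquet_dataset') as reader:
    loader = BatchedDataLoader(reader,
                               batch_size=256,
                               batch_max_prefetch=10)  # build up to 10 batches on a background thread
    for batch in loader:
        pass  # train_step(batch) would go here
```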
Review comment: Is it really prefetching? Setting the queue size does not guarantee that we will prefetch the data before the user starts consuming it, does it? Perhaps we should call the argument 'queue_size'?
Frankly, I am not sure how prefetching helps steady-state throughput. Wouldn't it just eliminate some hiccups when the training starts, at the expense of training starting a bit later? Isn't steady-state throughput the only important characteristic here?

Reply: You are right, "prefetching" is probably not the right word. As discussed in #740, the main motivation of this PR is to enable parallel batch building while training a model (otherwise the model always has to wait for a batch to become available, and this may take some time, especially if the dataset has a large number of columns). I have observed a ~3x speedup in data throughput with this change.

Review comment: Sure - I see how this can speed up the training. This is a good change.
In my understanding we are not really doing prefetching here (depending on the timing, the consumer might try to fetch the first batch before the thread has populated it, i.e. nothing was prefetched).
If you are OK with just changing the name from prefetching to queue size, everything will fall into place.
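Following the reviewer's suggestion to cover these aspects with unit tests, here is a sketch of what such tests could look like; the test names are assumptions, and the import assumes BackgroundIterator lives in petastorm/pytorch.py, where this PR defines it.

```python
# Sketch of unit tests covering the BackgroundIterator thread lifecycle,
# as suggested in the review above. Test names are illustrative.
import unittest

from petastorm.pytorch import BackgroundIterator  # defined by this PR


class BackgroundIteratorTest(unittest.TestCase):
    def test_yields_all_items_in_order(self):
        # The bounded queue must not drop or reorder items.
        result = list(BackgroundIterator(iter(range(100)), prefetch=10))
        self.assertEqual(result, list(range(100)))

    def test_thread_terminates_after_exhaustion(self):
        # Once the sentinel has been consumed, the producer thread should exit promptly.
        background = BackgroundIterator(iter(range(5)), prefetch=10)
        self.assertEqual(list(background), list(range(5)))
        background.join(timeout=5)
        self.assertFalse(background.is_alive())


if __name__ == '__main__':
    unittest.main()
```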