Enable batch fetching in parallel #748
base: master
Conversation
Looks like we are headed in the right direction! Left some questions.
petastorm/pytorch.py
Outdated
for batch in self._iter_impl():
iterator = self._iter_impl()
if self._max_prefetch > 1:
    iterator = BackgroundIterator(iterator, prefetch=self.max_prefetch)
Will the BackgroundIterator be properly destroyed and its thread joined when we exit the function (either on nominal exit or with an exception)?
Need to check this, thank you for the heads up.
I know it could be tricky, but it would be good to also test these aspects in a unit test or two.
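A sketch of what such a test could look like is below. Note that make_loader is a hypothetical helper (not part of petastorm) assumed to build a DataLoader with the background iterator enabled, and the thread name is an assumption:

import threading

import pytest


def test_background_thread_stops_when_consumer_raises():
    # Hypothetical helper; assumed to return a DataLoader wrapping a BackgroundIterator.
    loader = make_loader(max_prefetch=10)
    with pytest.raises(RuntimeError):
        for i, _batch in enumerate(loader):
            if i == 1:
                raise RuntimeError('simulated consumer failure')
    # After leaving the loop via the exception, no background iterator thread
    # should remain alive. A real test would likely poll/join with a timeout
    # here to avoid flakiness while the thread shuts down.
    assert not any(t.name == 'background_iterator' and t.is_alive()
                   for t in threading.enumerate())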
petastorm/pytorch.py
Outdated
"""Prefetch iterator results. A thread iterates the original iterator and | ||
populates a queue. Iterating over this background iterator just consumes the underlying | ||
queue until no other result is available.""" | ||
def __init__(self, iterator, prefetch=1000): |
Is it really prefetching? Setting the queue size does not guarantee that we will prefetch the data before the user starts consuming it, does it? Perhaps we should call the argument 'queue_size'?
Frankly, I am not sure how prefetching helps steady-state throughput. Wouldn't it just eliminate some hiccups when training starts, at the expense of training starting a bit later? Isn't steady-state throughput the only important characteristic here?
You are right, prefetching is probably not the right word. As discussed in #740, the main motivation of this PR is to enable building batches in parallel while training a model (otherwise the model always has to wait for a batch to become available, which may take some time, especially if the dataset has a large number of columns). I have observed a ~3x speedup in data throughput with this change.
Sure - I see how this can speed up the training. This is a good change.
In my understanding we are not really doing prefetching here (depending on the timing, the consumer might try to fetch the first batch before the thread has populated it, i.e. nothing was prefetched).
If you are ok with just changing the name from prefetching to queue size, everything will fall into place.
petastorm/pytorch.py
Outdated
"""Prefetch iterator results. A thread iterates the original iterator and | ||
populates a queue. Iterating over this background iterator just consumes the underlying | ||
queue until no other result is available.""" | ||
def __init__(self, iterator, prefetch=1000): |
Having a default of 1000 batches for the queue size may be a bit too much, given that a batch is a row group, and row groups of a couple of hundred MBs are common.
The loader's _iter_impl yields batches, not row groups, right? That is what gets enqueued. It is true that, depending on the queue size, more or fewer row groups will be processed, but I expect this to be controlled via the queue size and the batch size.
You are absolutely right.
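For a rough sense of the memory at stake, here is an illustrative back-of-the-envelope estimate; all the numbers are assumptions, not measurements from this PR:

# Worst case, the queue holds queue_size fully materialized batches.
queue_size = 1000        # default under discussion
batch_size = 128         # rows per batch (assumption)
num_columns = 200        # scalar float32 columns per row (assumption)
bytes_per_value = 4

worst_case_bytes = queue_size * batch_size * num_columns * bytes_per_value
print('~%.0f MiB buffered in the worst case' % (worst_case_bytes / 2 ** 20))  # ~98 MiB

With wide rows (images or large tensors) the same default could easily reach many GiB, which supports choosing a smaller default.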
Codecov Report
@@            Coverage Diff             @@
##           master     #748      +/-   ##
==========================================
- Coverage   86.27%   85.99%   -0.28%
==========================================
  Files          85       85
  Lines        5084     5111      +27
  Branches      787      791       +4
==========================================
+ Hits         4386     4395       +9
- Misses        559      575      +16
- Partials      139      141       +2
Continue to review full report at Codecov.
@selitvin could you please have a look? Thank you!
def run(self):
    while not self.stop.isSet():
        for item in self.iterator:
            self.queue.put(item)
I don't think we can use blocking puts and gets and end up with a solution that is robust to deadlocks. Let's see if this works:
- The producer ends up filling the queue and waits on a blocking put.
- The consumer fails and calls iterator.stop.set(); however, since the put inside the iterator thread is blocking, the event is never checked and the thread is never shut down.
Another scenario:
- The queue is empty, so the consumer waits on a blocking .get.
- However, the producer raises an exception. The thread dies and the consumer is stuck forever on the .get.
I think a robust implementation of a BackgroundIterator could get pretty tricky. All these edge cases need to be carefully tested, as these kinds of failures would be hard to catch in production.
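One possible shape for a deadlock-resistant producer/consumer pair is sketched below. This is not the PR's code; the timed put()/get() calls, the _DONE sentinel, and the error holder are illustrative assumptions. The idea is that neither side ever blocks indefinitely, so the stop event always gets re-checked and a producer failure still unblocks the consumer:

import queue
import threading

_DONE = object()  # sentinel marking end-of-stream or producer failure (assumption)


def producer(source, out_queue, stop_event, error_holder):
    """Fill out_queue from source, periodically re-checking stop_event."""
    try:
        for item in source:
            while not stop_event.is_set():
                try:
                    out_queue.put(item, timeout=0.1)
                    break
                except queue.Full:
                    continue  # queue still full: loop back and re-check stop_event
            if stop_event.is_set():
                return
    except Exception as e:
        error_holder.append(e)  # surface the failure to the consumer
    finally:
        try:
            out_queue.put(_DONE, timeout=0.1)  # unblock a consumer stuck on get()
        except queue.Full:
            pass


def consume(in_queue, worker, stop_event, error_holder):
    """Yield items until the sentinel arrives or the worker thread dies."""
    try:
        while True:
            try:
                item = in_queue.get(timeout=0.1)
            except queue.Empty:
                if not worker.is_alive() and in_queue.empty():
                    break  # producer died before it could enqueue the sentinel
                continue
            if item is _DONE:
                break
            yield item
        if error_holder:
            raise error_holder[0]
    finally:
        stop_event.set()  # make sure the producer exits even if the consumer bailed out


# Usage sketch:
q, stop, err = queue.Queue(maxsize=10), threading.Event(), []
t = threading.Thread(target=producer, args=(range(100), q, stop, err), daemon=True)
t.start()
assert sum(consume(q, t, stop, err)) == 4950
t.join(timeout=1)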
        yield batch
except Exception as e:
    self._error = e
    logger.error('Iteration on Petastorm DataLoader raised an error: %s', repr(e))
    raise
finally:
    self._in_iter = False
    if isinstance(iterator, BackgroundIterator):
        iterator.stop.set()
Let's make stop a private member (i.e. _stop) and add an API to the BackgroundIterator that performs the stop (encapsulation principle).
def __init__(self, iterator, queue_size=1000):
    threading.Thread.__init__(self)
    self.name = "background_iterator"
    self.queue = Queue(queue_size)
Let's mark all data members that are not intended to be exposed to BackgroundIterator users as private (_ prefix).
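Putting the two comments above together, the class interface might end up looking roughly like this; the exact names, the smaller default queue size, and the omission of the iteration methods are assumptions made for the sketch (the iteration logic itself is discussed in the producer/consumer sketch earlier in the thread):

import threading
from queue import Queue


class BackgroundIterator(threading.Thread):
    """Iterates the wrapped iterator on a background thread and buffers results."""

    def __init__(self, iterator, queue_size=100):
        super().__init__(name='background_iterator', daemon=True)
        # All data members are private; callers interact only with iteration and stop().
        self._iterator = iterator
        self._queue = Queue(queue_size)
        self._stop = threading.Event()

    def stop(self):
        """Public API requested above: signal the producer thread to shut down."""
        self._stop.set()

    # run(), __iter__() and __next__() are omitted here; they would consume
    # self._iterator into self._queue while honouring self._stop.

With this shape, the DataLoader's finally block would call iterator.stop() instead of reaching into iterator.stop.set().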
Just noticed this nice work. @jarandaf Thanks!
This is WIP. Happy to hear your thoughts on this @selitvin.