Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable batch fetching in parallel #748

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions petastorm/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import decimal
# Must import pyarrow before torch. See: https://github.com/uber/petastorm/blob/master/docs/troubleshoot.rst
import re
import threading
import logging
from queue import Queue
import numpy as np
from six import PY2
from torch.utils.data.dataloader import default_collate
Expand Down Expand Up @@ -100,11 +102,42 @@ def decimal_friendly_collate(batch):
loader."


class BackgroundIterator(threading.Thread):
"""Prefetch iterator results. A thread iterates the original iterator and
populates a queue. Iterating over this background iterator just consumes the underlying
queue until no other result is available."""
def __init__(self, iterator, queue_size=1000):
threading.Thread.__init__(self)
self.name = "background_iterator"
self.queue = Queue(queue_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's mark all data members that are not intended to be exposed to BackgroundIterator users as private (_ prefix).

self.iterator = iterator
self.stop = threading.Event()
self.start()

def run(self):
while not self.stop.isSet():
for item in self.iterator:
self.queue.put(item)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we can use blocking puts and gets and end up with a solution that is robust to deadlocks. Let's see if this works:

  • Producer ended up filling the queue and waits on a blocking put.
  • Consumer fails. Calls iterator.stop.set(), however since .put is blocking within an iterator queue, the event is never checked and the thread is not shut down.

Another scenario:

  • The queue is empty, hence consumer waits on a blocking .get.
  • However, producer raises an exception. The thread dies and the consumer is stuck forever on a .get.

I think a robust implementation for a BackgroundIterator could get pretty tricky. All these edge cases need to be carefully tested as these kind of failures would be hard to catch in production.

self.queue.put(None)
self.stop.set()
return

def __iter__(self):
return self

def __next__(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item


class LoaderBase(object):

def __init__(self):
self._in_iter = None
self._error = None
self._queue_size = 1

def __iter__(self):
if self._error is not None:
Expand All @@ -117,15 +150,20 @@ def __iter__(self):
logger.warning('Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.')
self._in_iter = True

iterator = self._iter_impl()
try:
for batch in self._iter_impl():
if self._queue_size > 1:
iterator = BackgroundIterator(iterator, queue_size=self._queue_size)
for batch in iterator:
yield batch
except Exception as e:
self._error = e
logger.error('Iteration on Petastorm DataLoader raise error: %s', repr(e))
raise
finally:
self._in_iter = False
if isinstance(iterator, BackgroundIterator):
iterator.stop.set()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make stop a private member (i.e. _stop) and add an API to the BackgroundIterator that performs stop (encapsulation principle).



class DataLoader(LoaderBase):
Expand Down Expand Up @@ -264,7 +302,8 @@ class BatchedDataLoader(LoaderBase):

def __init__(self, reader, batch_size=1,
transform_fn=None,
shuffling_queue_capacity=0):
shuffling_queue_capacity=0,
batch_queue_size=None):
"""
Initializes a data loader object.

Expand All @@ -287,6 +326,8 @@ def __init__(self, reader, batch_size=1,
:param transform_fn: an optional callable to convert batches from the reader to PyTorch tensors
:param shuffling_queue_capacity: Queue capacity is passed to the underlying :class:`tf.RandomShuffleQueue`
instance. If set to 0, no shuffling will be done.
:param batch_queue_size: an optional int indicating maximum number of batches to fetch in
parallel. This might be useful when training models in order to improve model data throughput.
"""
super(BatchedDataLoader, self).__init__()
self.reader = reader
Expand All @@ -298,6 +339,11 @@ def __init__(self, reader, batch_size=1,
self.shuffling_queue_capacity = shuffling_queue_capacity
self._in_iter = None

# fetch batches in parallel?
if batch_queue_size is not None:
assert batch_queue_size > 0, "if set, batch_queue_size must be greater or equal to 1"
self._queue_size = batch_queue_size

def _iter_impl(self):
"""
The Data Loader iterator stops the for-loop when reader runs out of samples.
Expand Down
41 changes: 40 additions & 1 deletion petastorm/tests/test_pytorch_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from decimal import Decimal
from packaging import version
import time
import threading

import numpy as np
import pyarrow as pa
Expand All @@ -10,7 +12,8 @@
from petastorm import make_reader, TransformSpec, make_batch_reader
from petastorm.pytorch import (_sanitize_pytorch_types, DataLoader,
BatchedDataLoader, decimal_friendly_collate,
InMemBatchedDataLoader, _load_rows_into_mem)
InMemBatchedDataLoader, _load_rows_into_mem,
BackgroundIterator)
from petastorm.tests.test_common import TestSchema

BASIC_DATA_LOADERS = [DataLoader, BatchedDataLoader]
Expand Down Expand Up @@ -331,3 +334,39 @@ def test_inmem_batched_dataloader_shuffle_per_epoch(synthetic_dataset, reader_fa

with pytest.raises(StopIteration):
next(it)


def test_background_iterator():
# number of iterator elements
N = int(1e6)

# wait some time for the queue to be filled
bit = BackgroundIterator(range(N), queue_size=1000)
time.sleep(1)

assert not bit.queue.empty()
# ensure the thread exists and is populating the queue
assert "background_iterator" in [t.name for t in threading.enumerate()]
# ensure we process the same number of elements present in original iterator
n = 0
for _ in bit:
n += 1
assert n == N
# ensure the thread stopped when original iterator was completed
time.sleep(1)
assert "background_iterator" not in [t.name for t in threading.enumerate()]


@pytest.mark.parametrize('reader_factory', ALL_READER_FLAVOR_FACTORIES)
def test_batched_dataloader_background_iterator_handle_exception(synthetic_dataset, reader_factory):
with BatchedDataLoader(reader_factory(synthetic_dataset.url, schema_fields=TORCH_BATCHABLE_FIELDS,
transform_spec=TransformSpec(_sensor_name_to_int)),
batch_queue_size=100) as loader:
try:
for _ in loader:
assert "background_iterator" in [t.name for t in threading.enumerate()]
raise RuntimeError()
except RuntimeError:
# ensure we wait enough for the thread to be stopped
time.sleep(1)
assert "background_iterator" not in [t.name for t in threading.enumerate()]