diff --git a/petastorm/pytorch.py b/petastorm/pytorch.py
index 2550b408..3fb0ce2f 100644
--- a/petastorm/pytorch.py
+++ b/petastorm/pytorch.py
@@ -16,7 +16,9 @@ import decimal
 # Must import pyarrow before torch. See: https://github.com/uber/petastorm/blob/master/docs/troubleshoot.rst
 import re
+import threading
 import logging
+from queue import Queue
 
 import numpy as np
 from six import PY2
 from torch.utils.data.dataloader import default_collate
@@ -100,11 +102,42 @@ def decimal_friendly_collate(batch):
 loader."
 
 
+class BackgroundIterator(threading.Thread):
+    """Prefetches results from an iterator. A background thread consumes the original iterator and
+    populates a queue. Iterating over this background iterator drains the underlying
+    queue until no further results are available."""
+    def __init__(self, iterator, queue_size=1000):
+        threading.Thread.__init__(self)
+        self.name = "background_iterator"
+        self.queue = Queue(queue_size)
+        self.iterator = iterator
+        self.stop = threading.Event()
+        self.start()
+
+    def run(self):
+        while not self.stop.is_set():
+            for item in self.iterator:
+                self.queue.put(item)
+            self.queue.put(None)
+            self.stop.set()
+        return
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+
 class LoaderBase(object):
 
     def __init__(self):
         self._in_iter = None
         self._error = None
+        self._queue_size = 1
 
     def __iter__(self):
         if self._error is not None:
@@ -117,8 +150,11 @@ def __iter__(self):
             logger.warning('Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.')
         self._in_iter = True
 
+        iterator = self._iter_impl()
         try:
-            for batch in self._iter_impl():
+            if self._queue_size > 1:
+                iterator = BackgroundIterator(iterator, queue_size=self._queue_size)
+            for batch in iterator:
                 yield batch
         except Exception as e:
             self._error = e
@@ -126,6 +162,8 @@ def __iter__(self):
             raise
         finally:
             self._in_iter = False
+            if isinstance(iterator, BackgroundIterator):
+                iterator.stop.set()
 
 
 class DataLoader(LoaderBase):
@@ -264,7 +302,8 @@ class BatchedDataLoader(LoaderBase):
 
     def __init__(self, reader, batch_size=1,
                  transform_fn=None,
-                 shuffling_queue_capacity=0):
+                 shuffling_queue_capacity=0,
+                 batch_queue_size=None):
         """
         Initializes a data loader object.
 
@@ -287,6 +326,8 @@ def __init__(self, reader, batch_size=1,
         :param transform_fn: an optional callable to convert batches from the reader to PyTorch tensors
         :param shuffling_queue_capacity: Queue capacity is passed to the underlying :class:`tf.RandomShuffleQueue`
             instance. If set to 0, no shuffling will be done.
+        :param batch_queue_size: an optional int specifying the maximum number of batches to prefetch in a
+            background thread. This can be useful during model training to improve data throughput.
         """
         super(BatchedDataLoader, self).__init__()
         self.reader = reader
@@ -298,6 +339,11 @@ def __init__(self, reader, batch_size=1,
         self.shuffling_queue_capacity = shuffling_queue_capacity
         self._in_iter = None
 
+        # Prefetch batches in a background thread when requested.
+        if batch_queue_size is not None:
+            assert batch_queue_size > 0, "if set, batch_queue_size must be greater than or equal to 1"
+            self._queue_size = batch_queue_size
+
     def _iter_impl(self):
         """
         The Data Loader iterator stops the for-loop when reader runs out of samples.
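For context, a hedged sketch of how the new `batch_queue_size` option might be used from training code. The dataset URL, batch size, and queue size below are made-up placeholders, not part of this change; `make_batch_reader` and the `with ... as loader` context-manager usage come from the existing petastorm API.

```python
from petastorm import make_batch_reader
from petastorm.pytorch import BatchedDataLoader

# Hypothetical usage: prefetch up to 100 batches in a background thread
# while the training loop consumes them.
with make_batch_reader('file:///tmp/example_parquet_dataset', num_epochs=1) as reader:
    with BatchedDataLoader(reader, batch_size=128, batch_queue_size=100) as loader:
        for batch in loader:
            pass  # feed `batch` to the model here
```

Leaving `batch_queue_size` as `None` (the default) keeps the existing single-threaded behavior, since `_queue_size` stays at 1 and `__iter__` never wraps the underlying iterator in a `BackgroundIterator`.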
diff --git a/petastorm/tests/test_pytorch_dataloader.py b/petastorm/tests/test_pytorch_dataloader.py
index 384a695e..3f27a8ce 100644
--- a/petastorm/tests/test_pytorch_dataloader.py
+++ b/petastorm/tests/test_pytorch_dataloader.py
@@ -1,5 +1,7 @@
 from decimal import Decimal
 from packaging import version
+import time
+import threading
 
 import numpy as np
 import pyarrow as pa
@@ -10,7 +12,8 @@ from petastorm import make_reader, TransformSpec, make_batch_reader
 from petastorm.pytorch import (_sanitize_pytorch_types, DataLoader, BatchedDataLoader, decimal_friendly_collate,
-                               InMemBatchedDataLoader, _load_rows_into_mem)
+                               InMemBatchedDataLoader, _load_rows_into_mem,
+                               BackgroundIterator)
 from petastorm.tests.test_common import TestSchema
 
 BASIC_DATA_LOADERS = [DataLoader, BatchedDataLoader]
@@ -331,3 +334,39 @@ def test_inmem_batched_dataloader_shuffle_per_epoch(synthetic_dataset, reader_fa
 
     with pytest.raises(StopIteration):
         next(it)
+
+
+def test_background_iterator():
+    # number of iterator elements
+    N = int(1e6)
+
+    # give the background thread some time to fill the queue
+    bit = BackgroundIterator(range(N), queue_size=1000)
+    time.sleep(1)
+
+    assert not bit.queue.empty()
+    # ensure the thread exists and is populating the queue
+    assert "background_iterator" in [t.name for t in threading.enumerate()]
+    # ensure we process the same number of elements as the original iterator
+    n = 0
+    for _ in bit:
+        n += 1
+    assert n == N
+    # ensure the thread stopped once the original iterator was exhausted
+    time.sleep(1)
+    assert "background_iterator" not in [t.name for t in threading.enumerate()]
+
+
+@pytest.mark.parametrize('reader_factory', ALL_READER_FLAVOR_FACTORIES)
+def test_batched_dataloader_background_iterator_handle_exception(synthetic_dataset, reader_factory):
+    with BatchedDataLoader(reader_factory(synthetic_dataset.url, schema_fields=TORCH_BATCHABLE_FIELDS,
+                                          transform_spec=TransformSpec(_sensor_name_to_int)),
+                           batch_queue_size=100) as loader:
+        try:
+            for _ in loader:
+                assert "background_iterator" in [t.name for t in threading.enumerate()]
+                raise RuntimeError()
+        except RuntimeError:
+            # wait long enough for the background thread to stop
+            time.sleep(1)
+            assert "background_iterator" not in [t.name for t in threading.enumerate()]
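As a side note, here is a minimal standalone sketch of the producer-thread-plus-sentinel pattern that `BackgroundIterator` relies on: a worker thread pushes items into a bounded `Queue`, appends `None` as an end-of-data sentinel, and sets an `Event` so the consumer's cleanup path can ask the thread to wind down. The `produce` helper and the numbers are illustrative only, not petastorm code.

```python
import threading
from queue import Queue


def produce(source, queue, stop):
    # Illustrative producer: push items until the source is exhausted or we are told to stop.
    for item in source:
        if stop.is_set():
            return
        queue.put(item)
    queue.put(None)   # sentinel marking end of data
    stop.set()


queue, stop = Queue(maxsize=1000), threading.Event()
worker = threading.Thread(target=produce, args=(range(10), queue, stop), daemon=True)
worker.start()

total = 0
for item in iter(queue.get, None):  # drain the queue until the sentinel arrives
    total += item
worker.join()
print(total)  # 45
```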