Tiny fixes for the Cache & DatasetOptimizer (#18817)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <[email protected]>
3 people authored Oct 19, 2023
1 parent c68ff64 commit e7afe04
Showing 9 changed files with 114 additions and 87 deletions.
18 changes: 7 additions & 11 deletions src/lightning/data/streaming/cache.py
@@ -63,23 +63,19 @@ def __init__(
if not _TORCH_GREATER_EQUAL_2_1_0:
raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")

cache_dir = cache_dir if cache_dir else _try_create_cache_dir(name)
self._cache_dir = cache_dir = str(cache_dir) if cache_dir else _try_create_cache_dir(name)
if not remote_dir:
remote_dir, has_index_file = _find_remote_dir(name, version)

# When the index exists, we don't care about the chunk_size anymore.
if has_index_file and (chunk_size is None and chunk_bytes is None):
chunk_size = 2
self._writer = BinaryWriter(
str(cache_dir), chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression
)
self._reader = BinaryReader(
str(cache_dir),
remote_dir=remote_dir,
compression=compression,
item_loader=item_loader,
)
self._cache_dir = str(cache_dir)

if cache_dir:
os.makedirs(cache_dir, exist_ok=True)

self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression, item_loader=item_loader)
self._is_done = False
self._distributed_env = _DistributedEnv.detect()

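As a side note (not part of the diff): the net effect of this hunk is that the cache directory is normalized to a string, derived from `name` when it isn't given, and created on disk before the writer and reader receive it. A minimal standalone sketch of that behaviour, where `_fallback_cache_dir` is a hypothetical stand-in for `_try_create_cache_dir`:

import os
import tempfile
from typing import Optional

def _fallback_cache_dir(name: Optional[str]) -> Optional[str]:
    # Hypothetical stand-in for _try_create_cache_dir: derive a directory from the dataset name.
    return os.path.join(tempfile.gettempdir(), name) if name else None

def resolve_cache_dir(cache_dir: Optional[str], name: Optional[str]) -> Optional[str]:
    # Mirror the new constructor logic: coerce to str, fall back to a name-derived
    # directory, and make sure it exists before BinaryWriter/BinaryReader use it.
    resolved = str(cache_dir) if cache_dir else _fallback_cache_dir(name)
    if resolved:
        os.makedirs(resolved, exist_ok=True)
    return resolved

print(resolve_cache_dir(None, "my_dataset"))  # e.g. /tmp/my_dataset, now created on disk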
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/dataset.py
@@ -39,7 +39,7 @@ def __len__(self) -> int:
return len(self.cache)

def __getitem__(self, idx: int) -> Any:
return self.getitem(self.cache[idx])
return self.cache[idx]

def getitem(self, obj: Any) -> Any:
"""Override the getitem with your own logic to transform the cache object."""
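In other words, `__getitem__` now returns the raw cache entry and the `getitem` hook is left for subclasses to call explicitly. A small sketch under that reading (the subclass and data below are made up):

from typing import Any, List

class _ToyCacheDataset:
    # Stand-in for the streaming dataset: __getitem__ hands back the cached object as-is.
    def __init__(self, cache: List[Any]) -> None:
        self.cache = cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, idx: int) -> Any:
        return self.cache[idx]

    def getitem(self, obj: Any) -> Any:
        """Override the getitem with your own logic to transform the cache object."""
        return obj

class _UpperDataset(_ToyCacheDataset):
    def getitem(self, obj: Any) -> Any:
        return obj.upper()

ds = _UpperDataset(["a", "b"])
assert ds[0] == "a"              # raw cache object
assert ds.getitem(ds[0]) == "A"  # transform applied explicitly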
32 changes: 30 additions & 2 deletions src/lightning/data/streaming/dataset_optimizer.py
@@ -14,7 +14,7 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
from urllib import parse

from tqdm import tqdm
from tqdm.auto import tqdm

from lightning import seed_everything
from lightning.data.streaming import Cache
@@ -174,6 +174,7 @@ def __init__(
items: List[Any],
progress_queue: Queue,
error_queue: Queue,
stop_queue: Queue,
num_downloaders: int,
remove: bool,
chunk_size: Optional[int] = None,
@@ -200,6 +201,7 @@ def __init__
self.remover: Optional[Process] = None
self.downloaders: List[Process] = []
self.to_download_queues: List[Queue] = []
self.stop_queue = stop_queue
self.ready_to_process_queue: Queue = Queue()
self.remove_queue: Queue = Queue()
self.upload_queue: Queue = Queue()
@@ -208,6 +210,7 @@ def __init__
self.uploader: Optional[Process] = None
self._collected_items = 0
self._counter = 0
self._last_time = time()
self._index_counter = 0

def run(self) -> None:
@@ -249,6 +252,14 @@ def _loop(self) -> None:
assert self.uploader
self.upload_queue.put(None)
self.uploader.join()

if self.remove:
assert self.remover
self.remove_queue.put(None)
self.remover.join()

if self.progress_queue:
self.progress_queue.put((self.worker_index, self._counter))
return
continue

@@ -263,12 +274,20 @@ def _loop(self) -> None:
self._try_upload(chunk_filepath)

self._counter += 1
if self.progress_queue:

if self.progress_queue and (time() - self._last_time) > 1:
self.progress_queue.put((self.worker_index, self._counter))
self._last_time = time()

if self.remove:
self.remove_queue.put(self.paths[index])

try:
self.stop_queue.get(timeout=0.0001)
return
except Empty:
pass

def _set_environ_variables(self) -> None:
# set the optimizer global rank and world_size
os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(_get_node_rank() * self.num_workers + self.worker_index)
@@ -517,6 +536,7 @@ def __init__
self.workers_tracker: Dict[int, int] = {}
self.progress_queue: Optional[Queue] = None
self.error_queue: Queue = Queue()
self.stop_queues: List[Queue] = []
self.remote_src_dir = (
str(remote_src_dir)
if remote_src_dir is not None
@@ -619,6 +639,7 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
total = sum([len(w) for w in workers_user_items])
with tqdm(total=total, smoothing=0) as pbar:
for worker_idx, worker_user_items in enumerate(workers_user_items):
self.stop_queues.append(Queue())
new_total = sum([w.collected_items for w in self.workers])
pbar.update(new_total - current_total)
current_total = new_total
@@ -635,6 +656,7 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
worker_user_items,
None,
self.error_queue,
self.stop_queues[-1],
self.num_downloaders,
self.delete_cached_files,
(self.chunk_size if self.chunk_size else 2)
@@ -657,7 +679,9 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
def _create_process_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
self.progress_queue = Queue()
workers: List[DataWorkerProcess] = []
stop_queues: List[Queue] = []
for worker_idx, worker_user_items in enumerate(workers_user_items):
stop_queues.append(Queue())
worker = DataWorkerProcess(
worker_idx,
self.num_workers,
@@ -671,6 +695,7 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li
worker_user_items,
self.progress_queue,
self.error_queue,
stop_queues[-1],
self.num_downloaders,
self.delete_cached_files,
(self.chunk_size if self.chunk_size else 2)
@@ -684,6 +709,7 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li

# Note: Don't store within the loop as weakref aren't serializable
self.workers = workers
self.stop_queues = stop_queues

def _associated_items_to_workers(self, user_items: List[Any]) -> Tuple[List[int], List[List[Any]]]:
# Associate the items to the workers based on world_size and node_rank
@@ -738,6 +764,8 @@ def _cached_list_filepaths(self) -> List[str]:

def _signal_handler(self, signal: Any, frame: Any) -> None:
"""On temrination, we stop all the processes to avoid leaking RAM."""
for stop_queue in self.stop_queues:
stop_queue.put(None)
for w in self.workers:
w.join(0)
os._exit(0)
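Two mechanisms recur across these hunks: progress updates are throttled to roughly one per second, and each worker polls a dedicated stop queue so the signal handler can ask it to exit. A rough, self-contained sketch of that pattern (the function and variable names here are invented, not the optimizer's real API):

from multiprocessing import Queue
from queue import Empty
from time import time

def worker_loop(items, progress_queue: Queue, stop_queue: Queue, worker_index: int = 0) -> None:
    # Hypothetical worker loop: report progress at most ~once per second and
    # return early if a stop sentinel shows up (e.g. put there by a signal handler).
    counter = 0
    last_time = time()
    for _item in items:
        counter += 1  # ... process the item here ...
        if progress_queue and (time() - last_time) > 1:
            progress_queue.put((worker_index, counter))
            last_time = time()
        try:
            stop_queue.get(timeout=0.0001)
            return
        except Empty:
            pass
    progress_queue.put((worker_index, counter))  # final count once all items are done

progress, stop = Queue(), Queue()
worker_loop(range(5), progress, stop)
print(progress.get())  # (0, 5)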
29 changes: 23 additions & 6 deletions src/lightning/data/streaming/reader.py
@@ -12,6 +12,7 @@
# limitations under the License.

import os
import warnings
from threading import Lock, Thread
from time import sleep
from typing import Any, Dict, List, Optional, Tuple
@@ -23,6 +24,8 @@
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import _SERIALIZERS, Serializer

warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*")

if _TORCH_GREATER_EQUAL_2_1_0:
pass

@@ -40,7 +43,9 @@ def __init__(self, config: ChunksConfig) -> None:
def add(self, chunk_indices: List[int]) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
with self._lock:
self._chunks_index_to_be_processed.extend(chunk_indices)
for chunk_indice in chunk_indices:
if chunk_indice not in self._chunks_index_to_be_processed:
self._chunks_index_to_be_processed.append(chunk_indice)

def run(self) -> None:
while True:
@@ -75,6 +80,8 @@ def __init__(
"""
super().__init__()
warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*")

self._cache_dir = cache_dir
self._remote_dir = remote_dir

@@ -89,6 +96,7 @@ def __init__(
self._rank: Optional[int] = None
self._config: Optional[ChunksConfig] = None
self._prepare_thread: Optional[PrepareChunksThread] = None
self._chunks_index_to_be_processed: List[int] = []
self._item_loader = item_loader or PyTreeLoader()

def _get_chunk_index_from_index(self, index: int) -> int:
@@ -132,11 +140,20 @@ def read(self, index: ChunkedIndex) -> Any:
if self._config is None and self._try_load_config() is None:
raise Exception("The reader index isn't defined.")

# Create and start the prepare chunks thread
if index.chunk_indexes is not None and self._prepare_thread is None and self._config:
self._prepare_thread = PrepareChunksThread(self._config)
self._prepare_thread.start()
self._prepare_thread.add(index.chunk_indexes)
if self._config and self._config._remote_dir:
# Create and start the prepare chunks thread
if self._prepare_thread is None and self._config:
self._prepare_thread = PrepareChunksThread(self._config)
self._prepare_thread.start()
if index.chunk_indexes:
self._chunks_index_to_be_processed.extend(index.chunk_indexes)
self._prepare_thread.add(index.chunk_indexes)

# If the chunk_index isn't already in the download queue, add it.
if index.chunk_index not in self._chunks_index_to_be_processed:
assert self._prepare_thread
self._prepare_thread.add([index.chunk_index])
self._chunks_index_to_be_processed.append(index.chunk_index)

# Fetch the element
chunk_filepath, begin, _ = self.config[index]
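The reader-side change boils down to bookkeeping: both the reader and the `PrepareChunksThread` now remember which chunk indexes were already queued, so repeated reads never enqueue the same chunk download twice. A simplified sketch of that idea (the background thread is replaced by a plain list here):

from typing import List

class _DedupChunkQueue:
    # Simplified stand-in for the dedup logic: enqueue a chunk index only the
    # first time it is requested.
    def __init__(self) -> None:
        self._seen: List[int] = []
        self.queued: List[int] = []

    def add(self, chunk_indexes: List[int]) -> None:
        for chunk_index in chunk_indexes:
            if chunk_index not in self._seen:
                self._seen.append(chunk_index)
                self.queued.append(chunk_index)

q = _DedupChunkQueue()
q.add([0, 1, 1])
q.add([1, 2])
assert q.queued == [0, 1, 2]  # duplicates from repeated reads are dropped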
17 changes: 12 additions & 5 deletions src/lightning/data/streaming/serializers.py
@@ -15,7 +15,6 @@
import pickle
from abc import ABC, abstractmethod
from collections import OrderedDict
from io import BytesIO
from typing import Any, Optional, Tuple, Union

import numpy as np
@@ -36,6 +35,7 @@

if _TORCH_VISION_AVAILABLE:
from torchvision.io import decode_jpeg
from torchvision.transforms.functional import pil_to_tensor


class Serializer(ABC):
@@ -71,7 +71,8 @@ def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]:
ints = np.array([width, height, len(mode)], np.uint32)
return ints.tobytes() + mode + raw, None

def deserialize(self, data: bytes) -> Any:
@classmethod
def deserialize(cls, data: bytes) -> Any:
idx = 3 * 4
width, height, mode_size = np.frombuffer(data[:idx], np.uint32)
idx2 = idx + mode_size
@@ -113,10 +114,16 @@ def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]:
def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]:
if _TORCH_VISION_AVAILABLE:
array = torch.frombuffer(data, dtype=torch.uint8)
return decode_jpeg(array)
try:
return decode_jpeg(array)
except RuntimeError:
# Note: Some datasets like ImageNet contain PNG images with a JPEG extension, so we fall back to PIL
pass

inp = BytesIO(data)
return Image.open(inp)
img = PILSerializer.deserialize(data)
if _TORCH_VISION_AVAILABLE:
img = pil_to_tensor(img)
return img

def can_serialize(self, item: Any) -> bool:
return isinstance(item, JpegImageFile)
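The deserializer change follows a try-the-fast-path-then-fall-back pattern. A hedged sketch of the same idea, assuming torch, torchvision and Pillow are installed (unlike the diff, which routes the fallback through `PILSerializer.deserialize`, this sketch opens the bytes with PIL directly):

import io
import numpy as np
import torch
from PIL import Image
from torchvision.io import decode_jpeg
from torchvision.transforms.functional import pil_to_tensor

def deserialize_image_bytes(data: bytes) -> torch.Tensor:
    # Try the fast torchvision JPEG decoder first; if the payload isn't a real
    # JPEG (e.g. a PNG stored with a .jpeg extension), fall back to PIL and
    # convert the result to a tensor.
    try:
        return decode_jpeg(torch.frombuffer(bytearray(data), dtype=torch.uint8))
    except RuntimeError:
        return pil_to_tensor(Image.open(io.BytesIO(data)))

# A PNG payload masquerading as JPEG: decode_jpeg raises, PIL recovers it.
buf = io.BytesIO()
Image.fromarray(np.zeros((4, 4, 3), dtype=np.uint8)).save(buf, format="PNG")
print(deserialize_image_bytes(buf.getvalue()).shape)  # torch.Size([3, 4, 4])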
17 changes: 14 additions & 3 deletions src/lightning/data/streaming/writer.py
@@ -67,7 +67,7 @@
"""
self._cache_dir = cache_dir

if not os.path.exists(self._cache_dir):
if (isinstance(self._cache_dir, str) and not os.path.exists(self._cache_dir)) or self._cache_dir is None:
raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.")

if (chunk_size is None and chunk_bytes is None) or (chunk_size and chunk_bytes):
@@ -157,8 +157,8 @@ def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]:

if self._data_format is None:
self._data_format = data_format
elif self._data_format != data_format:
raise Exception(
elif self._data_format != data_format and self._should_raise(data_format, self._data_format):
raise ValueError(
f"The data format changed between items. Found {data_format} instead of {self._data_format}."
)

@@ -406,3 +406,14 @@ def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
else:
with open(os.path.join(self._cache_dir, f"{node_rank}-{_INDEX_FILENAME}"), "w") as f:
json.dump({"chunks": chunks_info, "config": config}, f, sort_keys=True)

def _should_raise(self, data_format_1: List[str], data_format_2: List[str]) -> bool:
if len(data_format_1) != len(data_format_2):
return True

def is_non_valid(f1: str, f2: str) -> bool:
if f1 in ["pil", "jpeg"] and f2 in ["pil", "jpeg"]:
return False
return f1 != f2

return any(is_non_valid(f1, f2) for f1, f2 in zip(data_format_1, data_format_2))
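
For clarity, the new compatibility check can be exercised on its own; a near-verbatim standalone copy (only `self` removed, example formats are made up):

from typing import List

def should_raise(data_format_1: List[str], data_format_2: List[str]) -> bool:
    # Formats must match field-by-field, except that "pil" and "jpeg" are
    # considered interchangeable.
    if len(data_format_1) != len(data_format_2):
        return True

    def is_non_valid(f1: str, f2: str) -> bool:
        if f1 in ["pil", "jpeg"] and f2 in ["pil", "jpeg"]:
            return False
        return f1 != f2

    return any(is_non_valid(f1, f2) for f1, f2 in zip(data_format_1, data_format_2))

assert should_raise(["pil", "int"], ["jpeg", "int"]) is False  # PIL vs. JPEG images: allowed
assert should_raise(["pil", "int"], ["pil", "str"]) is True    # any other mismatch still raises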