From e19ee14dec4b0cf61938d391ad81625ce4811ef5 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 12 Sep 2024 15:47:38 -0700 Subject: [PATCH 01/20] caching dataloader --- viscy/data/hcs_ram.py | 207 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 viscy/data/hcs_ram.py diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py new file mode 100644 index 00000000..f9e0fa82 --- /dev/null +++ b/viscy/data/hcs_ram.py @@ -0,0 +1,207 @@ +import logging +import math +import os +import re +import tempfile +from pathlib import Path +from typing import Callable, Literal, Sequence + +import numpy as np +import torch +import zarr +from imageio import imread +from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.data import set_track_meta +from monai.data.utils import collate_meta_tensor +from monai.transforms import ( + CenterSpatialCropd, + Compose, + MapTransform, + MultiSampleTrait, + RandAffined, +) +from torch import Tensor +from torch.utils.data import DataLoader, Dataset + +from viscy.data.typing import ChannelMap, DictTransform, HCSStackIndex, NormMeta, Sample +from viscy.data.hcs import _read_norm_meta +from tqdm import tqdm + +_logger = logging.getLogger("lightning.pytorch") + +# TODO: cache the norm metadata when caching the dataset + + +def _stack_channels( + sample_images: list[dict[str, Tensor]] | dict[str, Tensor], + channels: ChannelMap, + key: str, +) -> Tensor | list[Tensor]: + """Stack single-channel images into a multi-channel tensor.""" + if not isinstance(sample_images, list): + return torch.stack([sample_images[ch][0] for ch in channels[key]]) + # training time + return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] + + +class CachedDataset(Dataset): + """ + A dataset that caches the data in RAM. + It relies on the `__getitem__` method to load the data on the 1st epoch. 
+ """ + + def __init__( + self, + positions: list[Position], + channels: ChannelMap, + transform: DictTransform | None = None, + ): + super().__init__() + self.positions = positions + self.channels = channels + self.transform = transform + + self.source_ch_idx = [ + positions[0].get_channel_index(c) for c in channels["source"] + ] + self.target_ch_idx = ( + [positions[0].get_channel_index(c) for c in channels["target"]] + if "target" in channels + else None + ) + self._position_mapping() + self.cache_dict = {} + + def _position_mapping(self) -> None: + self.position_keys = [] + self.norm_meta_dict = {} + + for pos in self.positions: + self.position_keys.append(pos.data.name) + self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) + + def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: + # Add the position to the cached_dict + # TODO: hardcoding to t=0 + self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( + self.positions[index] + .data.oindex[slice(t, t + 1), channel_index, :] + .astype(np.float32) + ) + + def _get_weight_map(self, position: Position) -> Tensor: + # Get the weight map from the position for the MONAI weightedcrop transform + raise NotImplementedError + + def __len__(self) -> int: + return len(self.positions) + + def __getitem__(self, index: int) -> Sample: + + ch_names = self.channels["source"].copy() + ch_idx = self.source_ch_idx.copy() + if self.target_ch_idx is not None: + ch_names.extend(self.channels["target"]) + ch_idx.extend(self.target_ch_idx) + + # Check if the sample is in the cache else add it + # Split the tensor into the channels + sample_id = self.position_keys[index] + if sample_id not in self.cache_dict: + logging.debug(f"Adding {sample_id} to cache") + self._cache_dataset(index, channel_index=ch_idx) + + # Get the sample from the cache + images = self.cache_dict[sample_id].unbind(dim=1) + norm_meta = self.norm_meta_dict[str(sample_id)] + + sample_images = {k: v for k, v in zip(ch_names, images)} + + if self.target_ch_idx is not None: + # FIXME: this uses the first target channel as weight for performance + # since adding a reference to a tensor does not copy + # maybe write a weight map in preprocessing to use more information? 
+ sample_images["weight"] = sample_images[self.channels["target"][0]] + if norm_meta is not None: + sample_images["norm_meta"] = norm_meta + if self.transform: + sample_images = self.transform(sample_images) + if "weight" in sample_images: + del sample_images["weight"] + sample = { + "index": sample_id, + "source": _stack_channels(sample_images, self.channels, "source"), + "norm_meta": norm_meta, + } + if self.target_ch_idx is not None: + sample["target"] = _stack_channels(sample_images, self.channels, "target") + return sample + + def _load_sample(self, position: Position) -> Sample: + source, target = self.channel_map.source, self.channel_map.target + source_data = self._load_channel_data(position, source) + target_data = self._load_channel_data(position, target) + sample = {"source": source_data, "target": target_data} + return sample + + +class CachedDataloader(LightningDataModule): + def __init__( + self, + data_path: str, + source_channel: str | Sequence[str], + target_channel: str | Sequence[str], + split_ratio: float = 0.8, + batch_size: int = 16, + num_workers: int = 8, + architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "UNeXt2", + yx_patch_size: tuple[int, int] = (256, 256), + normalizations: list[MapTransform] = [], + augmentations: list[MapTransform] = [], + ): + super().__init__() + self.data_path = data_path + self.source_channel = source_channel + self.target_channel = target_channel + self.batch_size = batch_size + self.num_workers = num_workers + self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True + self.split_ratio = split_ratio + self.yx_patch_size = yx_patch_size + self.normalizations = normalizations + self.augmentations = augmentations + + @property + def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: + return { + "channels": {"source": self.source_channel}, + } + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: + dataset_settings = self._base_dataset_settings + if stage in ("fit", "validate"): + self._setup_fit(dataset_settings) + elif stage == "test": + self._setup_test(dataset_settings) + elif stage == "predict": + self._setup_predict(dataset_settings) + else: + raise NotImplementedError(f"Stage {stage} is not supported") + + def _setup_fit(self, dataset_settings: dict) -> None: + """ + Setup the train and validation datasets. 
+ """ + train_transform, val_transform = self._fit_transform() + dataset_settings["channels"]["target"] = self.target_channel + # Load the plate + plate = open_ome_zarr(self.data_path) + + pass + + def _setup_test(self) -> None: + pass + + def _setup_val(self) -> None: + pass From d31978d928af8820f3e7b3db2a8f0aa4ac3a9fc4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 25 Sep 2024 17:27:04 -0700 Subject: [PATCH 02/20] caching data module --- viscy/data/hcs_ram.py | 88 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 10 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index f9e0fa82..5e72f6f6 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -98,7 +98,6 @@ def __len__(self) -> int: return len(self.positions) def __getitem__(self, index: int) -> Sample: - ch_names = self.channels["source"].copy() ch_idx = self.source_ch_idx.copy() if self.target_ch_idx is not None: @@ -109,7 +108,7 @@ def __getitem__(self, index: int) -> Sample: # Split the tensor into the channels sample_id = self.position_keys[index] if sample_id not in self.cache_dict: - logging.debug(f"Adding {sample_id} to cache") + logging.info(f"Adding {sample_id} to cache") self._cache_dataset(index, channel_index=ch_idx) # Get the sample from the cache @@ -146,7 +145,7 @@ def _load_sample(self, position: Position) -> Sample: return sample -class CachedDataloader(LightningDataModule): +class CachedDataModule(LightningDataModule): def __init__( self, data_path: str, @@ -159,6 +158,7 @@ def __init__( yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], + z_window_size: int = 1, ): super().__init__() self.data_path = data_path @@ -171,6 +171,7 @@ def __init__( self.yx_patch_size = yx_patch_size self.normalizations = normalizations self.augmentations = augmentations + self.z_window_size = z_window_size @property def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: @@ -183,12 +184,53 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: if stage in ("fit", "validate"): self._setup_fit(dataset_settings) elif stage == "test": - self._setup_test(dataset_settings) + raise NotImplementedError("Test stage is not supported") elif stage == "predict": - self._setup_predict(dataset_settings) + raise NotImplementedError("Predict stage is not supported") else: raise NotImplementedError(f"Stage {stage} is not supported") + def _train_transform(self) -> list[Callable]: + if self.augmentations: + for aug in self.augmentations: + if isinstance(aug, MultiSampleTrait): + num_samples = aug.cropper.num_samples + if self.batch_size % num_samples != 0: + raise ValueError( + "Batch size must be divisible by `num_samples` per stack. " + f"Got batch size {self.batch_size} and " + f"number of samples {num_samples} for " + f"transform type {type(aug)}." + ) + self.train_patches_per_stack = num_samples + return list(self.augmentations) + + def _fit_transform(self) -> tuple[Compose, Compose]: + """(normalization -> maybe augmentation -> center crop) + Deterministic center crop as the last step of training and validation.""" + # TODO: These have a fixed order for now... 
() + final_crop = [ + CenterSpatialCropd( + keys=self.source_channel + self.target_channel, + roi_size=( + self.z_window_size, + self.yx_patch_size[0], + self.yx_patch_size[1], + ), + ) + ] + train_transform = Compose( + self.normalizations + self._train_transform() + final_crop + ) + val_transform = Compose(self.normalizations + final_crop) + return train_transform, val_transform + + def _set_fit_global_state(self, num_positions: int) -> torch.Tensor: + # disable metadata tracking in MONAI for performance + set_track_meta(False) + # shuffle positions, randomness is handled globally + return torch.randperm(num_positions) + def _setup_fit(self, dataset_settings: dict) -> None: """ Setup the train and validation datasets. @@ -197,11 +239,37 @@ def _setup_fit(self, dataset_settings: dict) -> None: dataset_settings["channels"]["target"] = self.target_channel # Load the plate plate = open_ome_zarr(self.data_path) + # shuffle positions, randomness is handled globally + positions = [pos for _, pos in plate.positions()] + shuffled_indices = self._set_fit_global_state(len(positions)) + positions = list(positions[i] for i in shuffled_indices) + num_train_fovs = int(len(positions) * self.split_ratio) - pass + self.train_dataset = CachedDataset( + positions[:num_train_fovs], + transform=train_transform, + **dataset_settings, + ) + self.val_dataset = CachedDataset( + positions[num_train_fovs:], + transform=val_transform, + **dataset_settings, + ) - def _setup_test(self) -> None: - pass + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size // self.train_patches_per_stack, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, + ) - def _setup_val(self) -> None: - pass + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=False, + ) From 041d73837665bb3d465b48d29d63dc0621931cf4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 26 Sep 2024 21:06:49 -0700 Subject: [PATCH 03/20] black --- viscy/data/hcs_ram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 5e72f6f6..8a813dce 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -224,7 +224,7 @@ def _fit_transform(self) -> tuple[Compose, Compose]: ) val_transform = Compose(self.normalizations + final_crop) return train_transform, val_transform - + def _set_fit_global_state(self, num_positions: int) -> torch.Tensor: # disable metadata tracking in MONAI for performance set_track_meta(False) From 7f76174917513a3cee4e565008cc7d7b60527cfe Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 26 Sep 2024 21:13:25 -0700 Subject: [PATCH 04/20] ruff --- viscy/data/hcs_ram.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 8a813dce..74240deb 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -1,32 +1,22 @@ import logging -import math -import os -import re -import tempfile -from pathlib import Path from typing import Callable, Literal, Sequence import numpy as np import torch -import zarr -from imageio import imread -from iohub.ngff import ImageArray, Plate, Position, open_ome_zarr +from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta 
-from monai.data.utils import collate_meta_tensor from monai.transforms import ( CenterSpatialCropd, Compose, MapTransform, MultiSampleTrait, - RandAffined, ) from torch import Tensor from torch.utils.data import DataLoader, Dataset -from viscy.data.typing import ChannelMap, DictTransform, HCSStackIndex, NormMeta, Sample from viscy.data.hcs import _read_norm_meta -from tqdm import tqdm +from viscy.data.typing import ChannelMap, DictTransform, Sample _logger = logging.getLogger("lightning.pytorch") From 85ea7915a82dba31c733d01b4842bc8ff1e7f9aa Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 27 Sep 2024 18:18:20 -0700 Subject: [PATCH 05/20] Bump torch to 2.4.1 (#174) * update torch >2.4.1 * black * ruff --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index dc263580..d07187fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ "iohub==0.1.0", - "torch>=2.1.2", + "torch>=2.4.1", "timm>=0.9.5", "tensorboard>=2.13.0", "lightning>=2.3.0", From 18385813b1097ddf736d09cce8bb9e4d99745ef6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Sat, 28 Sep 2024 11:05:41 -0700 Subject: [PATCH 06/20] adding timeout to ram_dataloader --- viscy/data/hcs_ram.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 74240deb..fb1e1de8 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -32,9 +32,11 @@ def _stack_channels( if not isinstance(sample_images, list): return torch.stack([sample_images[ch][0] for ch in channels[key]]) # training time + # sample_images is a list['Phase3D'].shape = (1,3,256,256) return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -149,6 +151,7 @@ def __init__( normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], z_window_size: int = 1, + timeout: int = 600, ): super().__init__() self.data_path = data_path @@ -162,6 +165,7 @@ def __init__( self.normalizations = normalizations self.augmentations = augmentations self.z_window_size = z_window_size + self.timeout = timeout @property def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: @@ -253,6 +257,7 @@ def train_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), shuffle=True, + timeout=self.timeout ) def val_dataloader(self) -> DataLoader: @@ -262,4 +267,5 @@ def val_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), shuffle=False, + timeout=self.timeout ) From f5c01a31ce4cff5f79f65d46b072fba9594af55d Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 4 Oct 2024 13:07:14 +0200 Subject: [PATCH 07/20] bandaid to cached dataloader --- viscy/data/hcs_ram.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index fb1e1de8..bec0b0da 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -117,7 +117,8 @@ def __getitem__(self, index: int) -> Sample: if norm_meta is not None: sample_images["norm_meta"] = norm_meta if self.transform: - sample_images = self.transform(sample_images) + # FIX ME: check why the transforms return a list? 
+ sample_images = self.transform(sample_images)[0] if "weight" in sample_images: del sample_images["weight"] sample = { @@ -185,6 +186,11 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: raise NotImplementedError(f"Stage {stage} is not supported") def _train_transform(self) -> list[Callable]: + """ Set the train augmentations + + + """ + if self.augmentations: for aug in self.augmentations: if isinstance(aug, MultiSampleTrait): @@ -197,6 +203,10 @@ def _train_transform(self) -> list[Callable]: f"transform type {type(aug)}." ) self.train_patches_per_stack = num_samples + else: + self.augmentations=[] + + _logger.info(f'Training augmentations: {self.augmentations}') return list(self.augmentations) def _fit_transform(self) -> tuple[Compose, Compose]: From 26a06b86491d7fea807afee3b170cd04d61aa5c4 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 4 Oct 2024 09:02:20 -0700 Subject: [PATCH 08/20] fixing the dataloader using torch collate_fn --- viscy/data/hcs_ram.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index bec0b0da..2e4f4f2c 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -14,6 +14,7 @@ ) from torch import Tensor from torch.utils.data import DataLoader, Dataset +from monai.data.utils import collate_meta_tensor from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample @@ -35,6 +36,25 @@ def _stack_channels( # sample_images is a list['Phase3D'].shape = (1,3,256,256) return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] +def _collate_samples(batch: Sequence[Sample]) -> Sample: + """Collate samples into a batch sample. + + :param Sequence[Sample] batch: a sequence of dictionaries, + where each key may point to a value of a single tensor or a list of tensors, + as is the case with ``train_patches_per_stack > 1``. + :return Sample: Batch sample (dictionary of tensors) + """ + collated: Sample = {} + for key in batch[0].keys(): + data = [] + for sample in batch: + if isinstance(sample[key], Sequence): + data.extend(sample[key]) + else: + data.append(sample[key]) + collated[key] = collate_meta_tensor(data) + return collated + class CachedDataset(Dataset): @@ -118,7 +138,7 @@ def __getitem__(self, index: int) -> Sample: sample_images["norm_meta"] = norm_meta if self.transform: # FIX ME: check why the transforms return a list? 
- sample_images = self.transform(sample_images)[0] + sample_images = self.transform(sample_images) if "weight" in sample_images: del sample_images["weight"] sample = { @@ -206,7 +226,7 @@ def _train_transform(self) -> list[Callable]: else: self.augmentations=[] - _logger.info(f'Training augmentations: {self.augmentations}') + _logger.debug(f'Training augmentations: {self.augmentations}') return list(self.augmentations) def _fit_transform(self) -> tuple[Compose, Compose]: @@ -267,7 +287,9 @@ def train_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), shuffle=True, - timeout=self.timeout + timeout=self.timeout, + collate_fn=_collate_samples, + drop_last=True ) def val_dataloader(self) -> DataLoader: @@ -277,5 +299,6 @@ def val_dataloader(self) -> DataLoader: num_workers=self.num_workers, persistent_workers=bool(self.num_workers), shuffle=False, - timeout=self.timeout + timeout=self.timeout, + ) From f2ff43c0ac8b1370c3f93401b881170c930e725c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 16 Oct 2024 17:52:28 -0700 Subject: [PATCH 09/20] replacing dictionary with single array --- viscy/data/hcs_ram.py | 62 ++++++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 2e4f4f2c..c8f0227c 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -83,24 +83,39 @@ def __init__( else None ) self._position_mapping() - self.cache_dict = {} + + self.cache_order = [] + self.cache_record = torch.zeros(len(self.positions)) + # Caching the dataset as two separate arrays + # self._init_cache_dataset() def _position_mapping(self) -> None: self.position_keys = [] + self.position_shape_tczyx= (1,1,1,1,1) self.norm_meta_dict = {} for pos in self.positions: self.position_keys.append(pos.data.name) self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) + # FIX: Use the position shape + self.position_shape_zyx = pos.data.shape[-3:] - def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: - # Add the position to the cached_dict - # TODO: hardcoding to t=0 - self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( - self.positions[index] - .data.oindex[slice(t, t + 1), channel_index, :] - .astype(np.float32) - ) + def _init_cache_dataset(self, t_idx=1, ch_idx=1) -> None: + _logger.info('Initializing cache array') + # FIXME assumes t=1 + self.cache = torch.zeros(((len(self.positions),t_idx,len(ch_idx),)+ self.position_shape_zyx)) + + + # def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: + # # Add the position to the cached_dict + # # TODO: hardcoding to t=0 + # _logger.info(f'Adding {self.position_keys[index]} to cache') + + # # self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( + # # self.positions[index] + # # .data.oindex[slice(t, t + 1), channel_index, :] + # # .astype(np.float32) + # # ) def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform @@ -117,14 +132,30 @@ def __getitem__(self, index: int) -> Sample: ch_idx.extend(self.target_ch_idx) # Check if the sample is in the cache else add it - # Split the tensor into the channels - sample_id = self.position_keys[index] - if sample_id not in self.cache_dict: - logging.info(f"Adding {sample_id} to cache") - self._cache_dataset(index, channel_index=ch_idx) + if self.cache_record[index]==0: + #if all entries of 
self.cache_record are zero + if self.cache_record.sum()==0: + #FIXME hardcoding t_idx=1 + self._init_cache_dataset(ch_idx=ch_idx,t_idx=1) + + # Flip the bit + self.cache_record[index]=1 + self.cache_order.append(index) + # Stack the data + _logger.info(f'Adding {self.position_keys[index]} to cache') + _logger.info(f'Cache_order: {self.cache_order}') + _logger.info(f'caching index: {index}') + #FIX ME: hardcoding t=0 and make this part of function + t=0 + # Insert the data into the cache + self.cache[index]=torch.from_numpy(self.positions[index] + .data.oindex[slice(t, t + 1), ch_idx, :] + .astype(np.float32)) # Get the sample from the cache - images = self.cache_dict[sample_id].unbind(dim=1) + # images = self.cache_dict[sample_id].unbind(dim=1) + sample_id = self.position_keys[index] + images = self.cache[index].unbind(dim=1) norm_meta = self.norm_meta_dict[str(sample_id)] sample_images = {k: v for k, v in zip(ch_names, images)} @@ -207,7 +238,6 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: def _train_transform(self) -> list[Callable]: """ Set the train augmentations - """ From 5fb96d75ace8775d8a152ebee6b900f9ed1be59a Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 18 Oct 2024 10:04:34 -0700 Subject: [PATCH 10/20] loading prior to epoch 0 --- viscy/data/hcs_ram.py | 104 ++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 45 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index c8f0227c..98bfaba0 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -23,6 +23,15 @@ # TODO: cache the norm metadata when caching the dataset +# Map the NumPy dtype to the corresponding PyTorch dtype +numpy_to_torch_dtype = { + np.dtype('float32'): torch.float32, + np.dtype('float64'): torch.float64, + np.dtype('int32'): torch.int32, + np.dtype('int64'): torch.int64, + np.dtype('uint8'): torch.int8, + np.dtype('uint16'): torch.int16, +} def _stack_channels( sample_images: list[dict[str, Tensor]] | dict[str, Tensor], @@ -54,9 +63,7 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: data.append(sample[key]) collated[key] = collate_meta_tensor(data) return collated - - - + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. 
@@ -82,10 +89,18 @@ def __init__( if "target" in channels else None ) + # Get total num channels + self.total_ch_names = self.channels["source"].copy() + self.total_ch_idx = self.source_ch_idx.copy() + if self.target_ch_idx is not None: + self.total_ch_names.extend(self.channels["target"]) + self.total_ch_idx.extend(self.target_ch_idx) self._position_mapping() self.cache_order = [] self.cache_record = torch.zeros(len(self.positions)) + self._init_cache_dataset() + # Caching the dataset as two separate arrays # self._init_cache_dataset() @@ -97,25 +112,27 @@ def _position_mapping(self) -> None: for pos in self.positions: self.position_keys.append(pos.data.name) self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) - # FIX: Use the position shape self.position_shape_zyx = pos.data.shape[-3:] + self._cache_dtype = numpy_to_torch_dtype.get(pos.data.dtype, torch.float32) # Default to torch.float32 if not found - def _init_cache_dataset(self, t_idx=1, ch_idx=1) -> None: + def _init_cache_dataset(self) -> None: _logger.info('Initializing cache array') # FIXME assumes t=1 - self.cache = torch.zeros(((len(self.positions),t_idx,len(ch_idx),)+ self.position_shape_zyx)) - - - # def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: - # # Add the position to the cached_dict - # # TODO: hardcoding to t=0 - # _logger.info(f'Adding {self.position_keys[index]} to cache') - - # # self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( - # # self.positions[index] - # # .data.oindex[slice(t, t + 1), channel_index, :] - # # .astype(np.float32) - # # ) + t_idx = 1 + self.cache = torch.zeros(((len(self.positions),t_idx,len(self.total_ch_idx))+ self.position_shape_zyx),dtype=self._cache_dtype) + _logger.info(f'Cache shape: {self.cache.shape}') + + #TODO Caching here to see if multiprocessing is faster + t=0 + + for i, pos in enumerate(self.positions): + _logger.info(f'Caching position {i}/{len(self.positions)}') + ## Insert the data into the cache + data = pos.data.oindex[slice(t, t + 1), self.total_ch_idx, :] + if data.dtype != np.float32: + data = data.astype(np.float32) + self.cache[i]= torch.from_numpy(data) + del data def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform @@ -125,35 +142,33 @@ def __len__(self) -> int: return len(self.positions) def __getitem__(self, index: int) -> Sample: - ch_names = self.channels["source"].copy() - ch_idx = self.source_ch_idx.copy() - if self.target_ch_idx is not None: - ch_names.extend(self.channels["target"]) - ch_idx.extend(self.target_ch_idx) - + #FIXME replace this after debugging + ch_idx = self.total_ch_idx + ch_names = self.total_ch_names + # Check if the sample is in the cache else add it - if self.cache_record[index]==0: - #if all entries of self.cache_record are zero - if self.cache_record.sum()==0: - #FIXME hardcoding t_idx=1 - self._init_cache_dataset(ch_idx=ch_idx,t_idx=1) - - # Flip the bit - self.cache_record[index]=1 - self.cache_order.append(index) - # Stack the data - _logger.info(f'Adding {self.position_keys[index]} to cache') - _logger.info(f'Cache_order: {self.cache_order}') - _logger.info(f'caching index: {index}') - #FIX ME: hardcoding t=0 and make this part of function - t=0 - # Insert the data into the cache - self.cache[index]=torch.from_numpy(self.positions[index] - .data.oindex[slice(t, t + 1), ch_idx, :] - .astype(np.float32)) + # if self.cache_record[index]== 0: + # # Flip the bit + # self.cache_record[index]=1 + # 
self.cache_order.append(index) + + # # Stack the data + # _logger.info(f'Adding {self.position_keys[index]} to cache') + # _logger.info(f'Cache_order: {self.cache_order}') + # _logger.info(f'caching index: {index}') + + # #FIX ME: hardcoding t=0 and make this part of function + # t=0 + + # # Insert the data into the cache + # data = self.positions[index].data.oindex[slice(t, t + 1), ch_idx, :] + # if data.dtype != np.float32: + # data = data.astype(np.float32) + # self.cache[index]= torch.from_numpy(data) + # del data # Get the sample from the cache - # images = self.cache_dict[sample_id].unbind(dim=1) + _logger.info(f'Getting sample {index} from cache') sample_id = self.position_keys[index] images = self.cache[index].unbind(dim=1) norm_meta = self.norm_meta_dict[str(sample_id)] @@ -168,7 +183,6 @@ def __getitem__(self, index: int) -> Sample: if norm_meta is not None: sample_images["norm_meta"] = norm_meta if self.transform: - # FIX ME: check why the transforms return a list? sample_images = self.transform(sample_images) if "weight" in sample_images: del sample_images["weight"] From 848cd63ce06b4dd5c36548bd109ef87f6b232176 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 18 Oct 2024 17:32:51 -0700 Subject: [PATCH 11/20] Revert "replacing dictionary with single array" This reverts commit 8c13f49498eb862e9f94518be727f47682d2cdcf. --- viscy/data/hcs_ram.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 98bfaba0..4ecf2f6f 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -96,6 +96,7 @@ def __init__( self.total_ch_names.extend(self.channels["target"]) self.total_ch_idx.extend(self.target_ch_idx) self._position_mapping() +<<<<<<< HEAD self.cache_order = [] self.cache_record = torch.zeros(len(self.positions)) @@ -103,15 +104,18 @@ def __init__( # Caching the dataset as two separate arrays # self._init_cache_dataset() +======= + self.cache_dict = {} +>>>>>>> parent of 8c13f49 (replacing dictionary with single array) def _position_mapping(self) -> None: self.position_keys = [] - self.position_shape_tczyx= (1,1,1,1,1) self.norm_meta_dict = {} for pos in self.positions: self.position_keys.append(pos.data.name) self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) +<<<<<<< HEAD self.position_shape_zyx = pos.data.shape[-3:] self._cache_dtype = numpy_to_torch_dtype.get(pos.data.dtype, torch.float32) # Default to torch.float32 if not found @@ -133,6 +137,17 @@ def _init_cache_dataset(self) -> None: data = data.astype(np.float32) self.cache[i]= torch.from_numpy(data) del data +======= + + def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: + # Add the position to the cached_dict + # TODO: hardcoding to t=0 + self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( + self.positions[index] + .data.oindex[slice(t, t + 1), channel_index, :] + .astype(np.float32) + ) +>>>>>>> parent of 8c13f49 (replacing dictionary with single array) def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform @@ -147,6 +162,7 @@ def __getitem__(self, index: int) -> Sample: ch_names = self.total_ch_names # Check if the sample is in the cache else add it +<<<<<<< HEAD # if self.cache_record[index]== 0: # # Flip the bit # self.cache_record[index]=1 @@ -171,6 +187,16 @@ def __getitem__(self, index: int) -> Sample: _logger.info(f'Getting sample {index} from cache') sample_id = 
self.position_keys[index] images = self.cache[index].unbind(dim=1) +======= + # Split the tensor into the channels + sample_id = self.position_keys[index] + if sample_id not in self.cache_dict: + logging.info(f"Adding {sample_id} to cache") + self._cache_dataset(index, channel_index=ch_idx) + + # Get the sample from the cache + images = self.cache_dict[sample_id].unbind(dim=1) +>>>>>>> parent of 8c13f49 (replacing dictionary with single array) norm_meta = self.norm_meta_dict[str(sample_id)] sample_images = {k: v for k, v in zip(ch_names, images)} @@ -252,6 +278,7 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: def _train_transform(self) -> list[Callable]: """ Set the train augmentations + """ From f7e57ae03a2ceb4ffb5347bb6b1d5d9107a8a50c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 18 Oct 2024 18:11:01 -0700 Subject: [PATCH 12/20] using multiprocessing manager --- viscy/data/hcs_ram.py | 95 ++++++++++++------------------------------- 1 file changed, 27 insertions(+), 68 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 4ecf2f6f..d8b1b698 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -18,6 +18,8 @@ from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample +from multiprocessing import Manager +from datetime import datetime _logger = logging.getLogger("lightning.pytorch") @@ -96,17 +98,12 @@ def __init__( self.total_ch_names.extend(self.channels["target"]) self.total_ch_idx.extend(self.target_ch_idx) self._position_mapping() -<<<<<<< HEAD - - self.cache_order = [] - self.cache_record = torch.zeros(len(self.positions)) - self._init_cache_dataset() - - # Caching the dataset as two separate arrays - # self._init_cache_dataset() -======= + + # Cached dictionary with tensors self.cache_dict = {} ->>>>>>> parent of 8c13f49 (replacing dictionary with single array) + manager = Manager() + self.cache_dict = manager.dict() + self._cached_pos=[] def _position_mapping(self) -> None: self.position_keys = [] @@ -115,39 +112,15 @@ def _position_mapping(self) -> None: for pos in self.positions: self.position_keys.append(pos.data.name) self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos) -<<<<<<< HEAD - self.position_shape_zyx = pos.data.shape[-3:] - self._cache_dtype = numpy_to_torch_dtype.get(pos.data.dtype, torch.float32) # Default to torch.float32 if not found - - def _init_cache_dataset(self) -> None: - _logger.info('Initializing cache array') - # FIXME assumes t=1 - t_idx = 1 - self.cache = torch.zeros(((len(self.positions),t_idx,len(self.total_ch_idx))+ self.position_shape_zyx),dtype=self._cache_dtype) - _logger.info(f'Cache shape: {self.cache.shape}') - - #TODO Caching here to see if multiprocessing is faster - t=0 - - for i, pos in enumerate(self.positions): - _logger.info(f'Caching position {i}/{len(self.positions)}') - ## Insert the data into the cache - data = pos.data.oindex[slice(t, t + 1), self.total_ch_idx, :] - if data.dtype != np.float32: - data = data.astype(np.float32) - self.cache[i]= torch.from_numpy(data) - del data -======= def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: # Add the position to the cached_dict # TODO: hardcoding to t=0 - self.cache_dict[str(self.position_keys[index])] = torch.from_numpy( - self.positions[index] - .data.oindex[slice(t, t + 1), channel_index, :] - .astype(np.float32) - ) ->>>>>>> parent of 8c13f49 (replacing dictionary with single array) + data 
=self.positions[index].data.oindex[slice(t, t + 1), channel_index, :] + if data.dtype != np.float32: + data = data.astype(np.float32) + self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(data) + def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform @@ -162,43 +135,20 @@ def __getitem__(self, index: int) -> Sample: ch_names = self.total_ch_names # Check if the sample is in the cache else add it -<<<<<<< HEAD - # if self.cache_record[index]== 0: - # # Flip the bit - # self.cache_record[index]=1 - # self.cache_order.append(index) - - # # Stack the data - # _logger.info(f'Adding {self.position_keys[index]} to cache') - # _logger.info(f'Cache_order: {self.cache_order}') - # _logger.info(f'caching index: {index}') - - # #FIX ME: hardcoding t=0 and make this part of function - # t=0 - - # # Insert the data into the cache - # data = self.positions[index].data.oindex[slice(t, t + 1), ch_idx, :] - # if data.dtype != np.float32: - # data = data.astype(np.float32) - # self.cache[index]= torch.from_numpy(data) - # del data - - # Get the sample from the cache - _logger.info(f'Getting sample {index} from cache') - sample_id = self.position_keys[index] - images = self.cache[index].unbind(dim=1) -======= # Split the tensor into the channels sample_id = self.position_keys[index] if sample_id not in self.cache_dict: - logging.info(f"Adding {sample_id} to cache") + _logger.info(f"Adding {sample_id} to cache") + self._cached_pos.append(index) + _logger.info(f"Cached positions: {self._cached_pos}") self._cache_dataset(index, channel_index=ch_idx) # Get the sample from the cache + _logger.info('Getting sample from cache') + start_time = datetime.now() images = self.cache_dict[sample_id].unbind(dim=1) ->>>>>>> parent of 8c13f49 (replacing dictionary with single array) norm_meta = self.norm_meta_dict[str(sample_id)] - + after_cache = datetime.now() - start_time sample_images = {k: v for k, v in zip(ch_names, images)} if self.target_ch_idx is not None: @@ -209,7 +159,9 @@ def __getitem__(self, index: int) -> Sample: if norm_meta is not None: sample_images["norm_meta"] = norm_meta if self.transform: + before_transform = datetime.now() sample_images = self.transform(sample_images) + after_transform = datetime.now() - before_transform if "weight" in sample_images: del sample_images["weight"] sample = { @@ -219,6 +171,11 @@ def __getitem__(self, index: int) -> Sample: } if self.target_ch_idx is not None: sample["target"] = _stack_channels(sample_images, self.channels, "target") + + _logger.info(f"\nTime taken to cache: {after_cache}") + _logger.info(f"Time taken to transform: {after_transform}") + _logger.info(f"Time taken to get sample: {datetime.now() - start_time}\n") + return sample def _load_sample(self, position: Position) -> Sample: @@ -357,6 +314,7 @@ def train_dataloader(self) -> DataLoader: batch_size=self.batch_size // self.train_patches_per_stack, num_workers=self.num_workers, persistent_workers=bool(self.num_workers), + pin_memory=True, shuffle=True, timeout=self.timeout, collate_fn=_collate_samples, @@ -369,6 +327,7 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=bool(self.num_workers), + pin_memory=True, shuffle=False, timeout=self.timeout, From c4797b4529dfd846127dab7bac58231fa19a1b7f Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 21 Oct 2024 14:13:44 -0700 Subject: [PATCH 13/20] add sharded distributed sampler --- 
viscy/data/distributed.py | 51 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 viscy/data/distributed.py diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py new file mode 100644 index 00000000..bd3ab618 --- /dev/null +++ b/viscy/data/distributed.py @@ -0,0 +1,51 @@ +"""Utilities for DDP training.""" + +import math + +import torch +from torch.utils.data.distributed import DistributedSampler + + +class ShardedDistributedSampler(DistributedSampler): + def _sharded_randperm(self, generator): + """Generate a sharded random permutation of indices.""" + indices = torch.tensor(range(len(self.dataset))) + permuted = torch.stack( + [ + torch.randperm(self.num_samples, generator=generator) + + i * self.num_samples + for i in range(self.num_replicas) + ], + dim=1, + ).reshape(-1) + return indices[permuted].tolist() + + def __iter__(self): + """Modified __iter__ method to shard data across distributed ranks.""" + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = self._sharded_randperm(g) + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) From 2c31e7d09653bb831c3cdd022ba782988bd93893 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 21 Oct 2024 14:14:05 -0700 Subject: [PATCH 14/20] add example script for ddp caching --- viscy/scripts/shared_dict.py | 121 +++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 viscy/scripts/shared_dict.py diff --git a/viscy/scripts/shared_dict.py b/viscy/scripts/shared_dict.py new file mode 100644 index 00000000..b29d4d86 --- /dev/null +++ b/viscy/scripts/shared_dict.py @@ -0,0 +1,121 @@ +from multiprocessing.managers import DictProxy + +import torch +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.utilities import rank_zero_info +from torch.distributed import get_rank +from torch.multiprocessing import Manager +from torch.utils.data import DataLoader, Dataset, Subset + +from viscy.data.distributed import ShardedDistributedSampler + + +class CachedDataset(Dataset): + def __init__(self, shared_dict: DictProxy, length: int): + self.rank = get_rank() + print(f"=== Initializing cache pool for rank {self.rank} ===") + self.shared_dict = shared_dict + self.length = length + + def __getitem__(self, index): + if index not in self.shared_dict: + print(f"* Adding {index} to cache dict on rank {self.rank}") + self.shared_dict[index] = torch.tensor(index).float()[None] + return self.shared_dict[index] + + def __len__(self): + return self.length + + +class CachedDataModule(LightningDataModule): + def __init__( + self, + length: int, + split_ratio: float, + batch_size: int, + num_workers: int, + persistent_workers: bool, + ): + super().__init__() + self.length = length + self.split_ratio = split_ratio + self.batch_size = batch_size + self.num_workers 
= num_workers + self.persistent_workers = persistent_workers + + def setup(self, stage): + if stage != "fit": + raise NotImplementedError("Only fit stage is supported.") + shared_dict = Manager().dict() + dataset = CachedDataset(shared_dict, self.length) + split_idx = int(self.length * self.split_ratio) + self.train_dataset = Subset(dataset, range(0, split_idx)) + self.val_dataset = Subset(dataset, range(split_idx, self.length)) + + def train_dataloader(self): + sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=self.persistent_workers, + drop_last=False, + sampler=sampler, + ) + + def val_dataloader(self): + sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=self.persistent_workers, + drop_last=False, + sampler=sampler, + ) + + +class DummyModel(LightningModule): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(1, 1) + + def forward(self, x): + return self.layer(x) + + def on_train_start(self): + rank_zero_info("=== Starting training ===") + + def on_train_epoch_start(self): + rank_zero_info(f"=== Starting training epoch {self.current_epoch} ===") + + def training_step(self, batch, batch_idx): + loss = torch.nn.functional.mse_loss(self.layer(batch), torch.zeros_like(batch)) + return loss + + def validation_step(self, batch, batch_idx): + loss = torch.nn.functional.mse_loss(self.layer(batch), torch.zeros_like(batch)) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-3) + + +trainer = Trainer( + max_epochs=5, + strategy="ddp", + accelerator="cpu", + devices=3, + use_distributed_sampler=False, + enable_progress_bar=False, + logger=False, + enable_checkpointing=False, +) + +data_module = CachedDataModule( + length=50, batch_size=2, split_ratio=0.6, num_workers=4, persistent_workers=False +) +model = DummyModel() +trainer.fit(model, data_module) From 5300b4a932c6bf334a7943182c5d65c464019985 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 21 Oct 2024 16:32:34 -0700 Subject: [PATCH 15/20] format and lint --- viscy/data/hcs_ram.py | 50 +++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index d8b1b698..e891beda 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -1,4 +1,6 @@ import logging +from datetime import datetime +from multiprocessing import Manager from typing import Callable, Literal, Sequence import numpy as np @@ -6,6 +8,7 @@ from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta +from monai.data.utils import collate_meta_tensor from monai.transforms import ( CenterSpatialCropd, Compose, @@ -14,12 +17,9 @@ ) from torch import Tensor from torch.utils.data import DataLoader, Dataset -from monai.data.utils import collate_meta_tensor from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample -from multiprocessing import Manager -from datetime import datetime _logger = logging.getLogger("lightning.pytorch") @@ -27,14 +27,15 @@ # Map the NumPy dtype to the corresponding PyTorch dtype numpy_to_torch_dtype = { - np.dtype('float32'): torch.float32, - 
np.dtype('float64'): torch.float64, - np.dtype('int32'): torch.int32, - np.dtype('int64'): torch.int64, - np.dtype('uint8'): torch.int8, - np.dtype('uint16'): torch.int16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("uint8"): torch.int8, + np.dtype("uint16"): torch.int16, } + def _stack_channels( sample_images: list[dict[str, Tensor]] | dict[str, Tensor], channels: ChannelMap, @@ -47,6 +48,7 @@ def _stack_channels( # sample_images is a list['Phase3D'].shape = (1,3,256,256) return [torch.stack([im[ch][0] for ch in channels[key]]) for im in sample_images] + def _collate_samples(batch: Sequence[Sample]) -> Sample: """Collate samples into a batch sample. @@ -65,7 +67,8 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: data.append(sample[key]) collated[key] = collate_meta_tensor(data) return collated - + + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -98,12 +101,12 @@ def __init__( self.total_ch_names.extend(self.channels["target"]) self.total_ch_idx.extend(self.target_ch_idx) self._position_mapping() - + # Cached dictionary with tensors self.cache_dict = {} manager = Manager() self.cache_dict = manager.dict() - self._cached_pos=[] + self._cached_pos = [] def _position_mapping(self) -> None: self.position_keys = [] @@ -116,12 +119,11 @@ def _position_mapping(self) -> None: def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None: # Add the position to the cached_dict # TODO: hardcoding to t=0 - data =self.positions[index].data.oindex[slice(t, t + 1), channel_index, :] + data = self.positions[index].data.oindex[slice(t, t + 1), channel_index, :] if data.dtype != np.float32: data = data.astype(np.float32) self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(data) - def _get_weight_map(self, position: Position) -> Tensor: # Get the weight map from the position for the MONAI weightedcrop transform raise NotImplementedError @@ -130,10 +132,10 @@ def __len__(self) -> int: return len(self.positions) def __getitem__(self, index: int) -> Sample: - #FIXME replace this after debugging + # FIXME replace this after debugging ch_idx = self.total_ch_idx ch_names = self.total_ch_names - + # Check if the sample is in the cache else add it # Split the tensor into the channels sample_id = self.position_keys[index] @@ -144,7 +146,7 @@ def __getitem__(self, index: int) -> Sample: self._cache_dataset(index, channel_index=ch_idx) # Get the sample from the cache - _logger.info('Getting sample from cache') + _logger.info("Getting sample from cache") start_time = datetime.now() images = self.cache_dict[sample_id].unbind(dim=1) norm_meta = self.norm_meta_dict[str(sample_id)] @@ -234,10 +236,7 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: raise NotImplementedError(f"Stage {stage} is not supported") def _train_transform(self) -> list[Callable]: - """ Set the train augmentations - - - """ + """Set the train augmentations""" if self.augmentations: for aug in self.augmentations: @@ -252,9 +251,9 @@ def _train_transform(self) -> list[Callable]: ) self.train_patches_per_stack = num_samples else: - self.augmentations=[] - - _logger.debug(f'Training augmentations: {self.augmentations}') + self.augmentations = [] + + _logger.debug(f"Training augmentations: {self.augmentations}") return list(self.augmentations) def _fit_transform(self) -> tuple[Compose, Compose]: @@ -318,7 +317,7 @@ def 
train_dataloader(self) -> DataLoader: shuffle=True, timeout=self.timeout, collate_fn=_collate_samples, - drop_last=True + drop_last=True, ) def val_dataloader(self) -> DataLoader: @@ -330,5 +329,4 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - ) From 8a8b4b017b42d0dc680e6c523be9138748985b50 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 22 Oct 2024 09:40:26 -0700 Subject: [PATCH 16/20] addding the custom distrb sampler to hcs_ram.py --- viscy/data/hcs_ram.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index e891beda..7e406179 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from multiprocessing import Manager +from multiprocessing.managers import DictProxy from typing import Callable, Literal, Sequence import numpy as np @@ -20,6 +21,9 @@ from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample +from viscy.data.distributed import ShardedDistributedSampler +from torch.distributed import get_rank +import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -68,7 +72,10 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated - +def is_ddp_enabled() -> bool: + """Check if distributed data parallel (DDP) is initialized.""" + return dist.is_available() and dist.is_initialized() + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -77,11 +84,17 @@ class CachedDataset(Dataset): def __init__( self, + shared_dict: DictProxy, positions: list[Position], channels: ChannelMap, transform: DictTransform | None = None, ): super().__init__() + if is_ddp_enabled(): + self.rank = dist.get_rank() + _logger.info(f"=== Initializing cache pool for rank {self.rank} ===") + + self.cache_dict = shared_dict self.positions = positions self.channels = channels self.transform = transform @@ -103,9 +116,7 @@ def __init__( self._position_mapping() # Cached dictionary with tensors - self.cache_dict = {} - manager = Manager() - self.cache_dict = manager.dict() + # TODO: Delete after testing self._cached_pos = [] def _position_mapping(self) -> None: @@ -132,18 +143,13 @@ def __len__(self) -> int: return len(self.positions) def __getitem__(self, index: int) -> Sample: - # FIXME replace this after debugging - ch_idx = self.total_ch_idx - ch_names = self.total_ch_names - # Check if the sample is in the cache else add it - # Split the tensor into the channels sample_id = self.position_keys[index] if sample_id not in self.cache_dict: _logger.info(f"Adding {sample_id} to cache") self._cached_pos.append(index) _logger.info(f"Cached positions: {self._cached_pos}") - self._cache_dataset(index, channel_index=ch_idx) + self._cache_dataset(index, channel_index=self.total_ch_idx) # Get the sample from the cache _logger.info("Getting sample from cache") @@ -151,7 +157,7 @@ def __getitem__(self, index: int) -> Sample: images = self.cache_dict[sample_id].unbind(dim=1) norm_meta = self.norm_meta_dict[str(sample_id)] after_cache = datetime.now() - start_time - sample_images = {k: v for k, v in zip(ch_names, images)} + sample_images = {k: v for k, v in zip(self.total_ch_names, images)} if self.target_ch_idx is not None: # FIXME: this uses the first target channel as weight for performance @@ -296,31 +302,36 @@ def _setup_fit(self, 
dataset_settings: dict) -> None: positions = list(positions[i] for i in shuffled_indices) num_train_fovs = int(len(positions) * self.split_ratio) + shared_dict = Manager().dict() self.train_dataset = CachedDataset( + shared_dict, positions[:num_train_fovs], transform=train_transform, **dataset_settings, ) self.val_dataset = CachedDataset( + shared_dict, positions[num_train_fovs:], transform=val_transform, **dataset_settings, ) def train_dataloader(self) -> DataLoader: + sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) return DataLoader( self.train_dataset, batch_size=self.batch_size // self.train_patches_per_stack, num_workers=self.num_workers, persistent_workers=bool(self.num_workers), pin_memory=True, - shuffle=True, + shuffle=False, timeout=self.timeout, collate_fn=_collate_samples, drop_last=True, ) def val_dataloader(self) -> DataLoader: + sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) return DataLoader( self.val_dataset, batch_size=self.batch_size, From 49764faa6f433d3a7bad9f70334b924a402244f9 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 22 Oct 2024 09:52:57 -0700 Subject: [PATCH 17/20] adding sampler to val train dataloader --- viscy/data/hcs_ram.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index 7e406179..cedc3403 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -75,7 +75,7 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() - + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -328,6 +328,7 @@ def train_dataloader(self) -> DataLoader: timeout=self.timeout, collate_fn=_collate_samples, drop_last=True, + sampler=sampler, ) def val_dataloader(self) -> DataLoader: @@ -340,4 +341,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, + sampler=sampler ) From 1fe54913bf045cab07b1537e735b000784f2a4a8 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 22 Oct 2024 10:49:38 -0700 Subject: [PATCH 18/20] fix divisibility of the last shard --- viscy/data/distributed.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py index bd3ab618..68e6d39e 100644 --- a/viscy/data/distributed.py +++ b/viscy/data/distributed.py @@ -1,35 +1,40 @@ """Utilities for DDP training.""" +from __future__ import annotations + import math +from typing import TYPE_CHECKING import torch +import torch.distributed from torch.utils.data.distributed import DistributedSampler +if TYPE_CHECKING: + from torch import Generator + class ShardedDistributedSampler(DistributedSampler): - def _sharded_randperm(self, generator): - """Generate a sharded random permutation of indices.""" - indices = torch.tensor(range(len(self.dataset))) - permuted = torch.stack( - [ - torch.randperm(self.num_samples, generator=generator) - + i * self.num_samples - for i in range(self.num_replicas) - ], - dim=1, - ).reshape(-1) - return indices[permuted].tolist() + def _sharded_randperm(self, max_size: int, generator: Generator) -> list[int]: + """Generate a sharded random permutation of indices. 
+ Overlap may occur in between the last two shards to maintain divisibility.""" + sharded_randperm = [ + torch.randperm(self.num_samples, generator=generator) + + min(i * self.num_samples, max_size - self.num_samples) + for i in range(self.num_replicas) + ] + indices = torch.stack(sharded_randperm, dim=1).reshape(-1) + return indices.tolist() def __iter__(self): """Modified __iter__ method to shard data across distributed ranks.""" + max_size = len(self.dataset) # type: ignore[arg-type] if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) - indices = self._sharded_randperm(g) + indices = self._sharded_randperm(max_size, g) else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] - + indices = list(range(max_size)) if not self.drop_last: # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) From 0b005cfb6407b018adc7104cf184667a108a2990 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 22 Oct 2024 11:04:19 -0700 Subject: [PATCH 19/20] hcs_ram format and lint --- viscy/data/hcs_ram.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index cedc3403..aa24f28e 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -6,6 +6,7 @@ import numpy as np import torch +import torch.distributed as dist from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta @@ -19,11 +20,9 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample -from viscy.data.distributed import ShardedDistributedSampler -from torch.distributed import get_rank -import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -72,10 +71,12 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated + def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() + class CachedDataset(Dataset): """ A dataset that caches the data in RAM. 
@@ -341,5 +342,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - sampler=sampler + sampler=sampler, ) From daa686028575868d7f87f34955f7451d9fcfac19 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 24 Oct 2024 10:49:49 -0700 Subject: [PATCH 20/20] path for if not ddp --- viscy/data/hcs_ram.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py index aa24f28e..a9ff25d3 100644 --- a/viscy/data/hcs_ram.py +++ b/viscy/data/hcs_ram.py @@ -6,7 +6,6 @@ import numpy as np import torch -import torch.distributed as dist from iohub.ngff import Position, open_ome_zarr from lightning.pytorch import LightningDataModule from monai.data import set_track_meta @@ -20,9 +19,11 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset -from viscy.data.distributed import ShardedDistributedSampler from viscy.data.hcs import _read_norm_meta from viscy.data.typing import ChannelMap, DictTransform, Sample +from viscy.data.distributed import ShardedDistributedSampler +from torch.distributed import get_rank +import torch.distributed as dist _logger = logging.getLogger("lightning.pytorch") @@ -71,12 +72,10 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: collated[key] = collate_meta_tensor(data) return collated - def is_ddp_enabled() -> bool: """Check if distributed data parallel (DDP) is initialized.""" return dist.is_available() and dist.is_initialized() - class CachedDataset(Dataset): """ A dataset that caches the data in RAM. @@ -318,7 +317,11 @@ def _setup_fit(self, dataset_settings: dict) -> None: ) def train_dataloader(self) -> DataLoader: - sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + if is_ddp_enabled(): + sampler = ShardedDistributedSampler(self.train_dataset, shuffle=True) + else: + sampler = None + _logger.info("Using standard sampler for non-distributed training") return DataLoader( self.train_dataset, batch_size=self.batch_size // self.train_patches_per_stack, @@ -333,7 +336,12 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: - sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + if is_ddp_enabled(): + sampler = ShardedDistributedSampler(self.val_dataset, shuffle=False) + else: + sampler = None + _logger.info("Using standard sampler for non-distributed validation") + return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -342,5 +350,5 @@ def val_dataloader(self) -> DataLoader: pin_memory=True, shuffle=False, timeout=self.timeout, - sampler=sampler, + sampler=sampler )
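
Usage note (not part of the patch series above): a minimal sketch of how the RAM-caching data module added in viscy/data/hcs_ram.py might be driven after these changes. The store path and the target channel name are placeholders, "Phase3D" is only borrowed from a comment in the diff, and the z window depth is an assumption; only calls that appear in the diffs are used.

    from viscy.data.hcs_ram import CachedDataModule

    # Placeholder inputs; point these at a real HCS OME-Zarr plate and its channels.
    dm = CachedDataModule(
        data_path="/path/to/plate.zarr",
        source_channel=["Phase3D"],      # placeholder source channel name
        target_channel=["Nuclei"],       # placeholder target channel name
        batch_size=16,
        num_workers=8,
        yx_patch_size=(256, 256),
        z_window_size=5,                 # assumes stacks with at least 5 z-slices
    )
    dm.setup(stage="fit")                # builds the train/val CachedDataset split
    val_loader = dm.val_dataloader()
    # train_dataloader() divides batch_size by train_patches_per_stack, which is only
    # assigned when a MultiSampleTrait crop is included in `augmentations`, so pass a
    # multi-sample crop transform before requesting the training loader.

As of the last patch, each position is cached as a tensor in a multiprocessing Manager dict shared with the dataloader workers, and the custom ShardedDistributedSampler is only attached when torch.distributed is initialized, so single-process runs fall back to the default sampler while under DDP each rank is meant to read, and therefore cache, only its own shard of positions.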