diff --git a/viscy/data/hcs_ram.py b/viscy/data/hcs_ram.py
index 4ecf2f6f..d8b1b698 100644
--- a/viscy/data/hcs_ram.py
+++ b/viscy/data/hcs_ram.py
@@ -18,6 +18,8 @@
 from viscy.data.hcs import _read_norm_meta
 from viscy.data.typing import ChannelMap, DictTransform, Sample
+from multiprocessing import Manager
+from datetime import datetime
 
 _logger = logging.getLogger("lightning.pytorch")
 
 
@@ -96,17 +98,12 @@ def __init__(
         self.total_ch_names.extend(self.channels["target"])
         self.total_ch_idx.extend(self.target_ch_idx)
         self._position_mapping()
-<<<<<<< HEAD
-
-        self.cache_order = []
-        self.cache_record = torch.zeros(len(self.positions))
-        self._init_cache_dataset()
-
-        # Caching the dataset as two separate arrays
-        # self._init_cache_dataset()
-=======
+
+        # Cached dictionary with tensors
         self.cache_dict = {}
->>>>>>> parent of 8c13f49 (replacing dictionary with single array)
+        manager = Manager()
+        self.cache_dict = manager.dict()
+        self._cached_pos=[]
 
     def _position_mapping(self) -> None:
         self.position_keys = []
@@ -115,39 +112,15 @@ def _position_mapping(self) -> None:
         for pos in self.positions:
            self.position_keys.append(pos.data.name)
            self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos)
-<<<<<<< HEAD
-            self.position_shape_zyx = pos.data.shape[-3:]
-            self._cache_dtype = numpy_to_torch_dtype.get(pos.data.dtype, torch.float32)  # Default to torch.float32 if not found
-
-    def _init_cache_dataset(self) -> None:
-        _logger.info('Initializing cache array')
-        # FIXME assumes t=1
-        t_idx = 1
-        self.cache = torch.zeros(((len(self.positions),t_idx,len(self.total_ch_idx))+ self.position_shape_zyx),dtype=self._cache_dtype)
-        _logger.info(f'Cache shape: {self.cache.shape}')
-
-        #TODO Caching here to see if multiprocessing is faster
-        t=0
-
-        for i, pos in enumerate(self.positions):
-            _logger.info(f'Caching position {i}/{len(self.positions)}')
-            ## Insert the data into the cache
-            data = pos.data.oindex[slice(t, t + 1), self.total_ch_idx, :]
-            if data.dtype != np.float32:
-                data = data.astype(np.float32)
-            self.cache[i]= torch.from_numpy(data)
-            del data
-=======
 
     def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None:
         # Add the position to the cached_dict
         # TODO: hardcoding to t=0
-        self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(
-            self.positions[index]
-            .data.oindex[slice(t, t + 1), channel_index, :]
-            .astype(np.float32)
-        )
->>>>>>> parent of 8c13f49 (replacing dictionary with single array)
+        data =self.positions[index].data.oindex[slice(t, t + 1), channel_index, :]
+        if data.dtype != np.float32:
+            data = data.astype(np.float32)
+        self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(data)
+
 
     def _get_weight_map(self, position: Position) -> Tensor:
         # Get the weight map from the position for the MONAI weightedcrop transform
@@ -162,43 +135,20 @@ def __getitem__(self, index: int) -> Sample:
         ch_names = self.total_ch_names
 
         # Check if the sample is in the cache else add it
-<<<<<<< HEAD
-        # if self.cache_record[index]== 0:
-        #     # Flip the bit
-        #     self.cache_record[index]=1
-        #     self.cache_order.append(index)
-
-        #     # Stack the data
-        #     _logger.info(f'Adding {self.position_keys[index]} to cache')
-        #     _logger.info(f'Cache_order: {self.cache_order}')
-        #     _logger.info(f'caching index: {index}')
-
-        #     #FIX ME: hardcoding t=0 and make this part of function
-        #     t=0
-
-        #     # Insert the data into the cache
-        #     data = self.positions[index].data.oindex[slice(t, t + 1), ch_idx, :]
-        #     if data.dtype != np.float32:
-        #         data = data.astype(np.float32)
-        #     self.cache[index]= torch.from_numpy(data)
-        #     del data
-
-        # Get the sample from the cache
-        _logger.info(f'Getting sample {index} from cache')
-        sample_id = self.position_keys[index]
-        images = self.cache[index].unbind(dim=1)
-=======
         # Split the tensor into the channels
         sample_id = self.position_keys[index]
         if sample_id not in self.cache_dict:
-            logging.info(f"Adding {sample_id} to cache")
+            _logger.info(f"Adding {sample_id} to cache")
+            self._cached_pos.append(index)
+            _logger.info(f"Cached positions: {self._cached_pos}")
             self._cache_dataset(index, channel_index=ch_idx)
 
         # Get the sample from the cache
+        _logger.info('Getting sample from cache')
+        start_time = datetime.now()
         images = self.cache_dict[sample_id].unbind(dim=1)
->>>>>>> parent of 8c13f49 (replacing dictionary with single array)
         norm_meta = self.norm_meta_dict[str(sample_id)]
-
+        after_cache = datetime.now() - start_time
         sample_images = {k: v for k, v in zip(ch_names, images)}
 
         if self.target_ch_idx is not None:
@@ -209,7 +159,9 @@ def __getitem__(self, index: int) -> Sample:
         if norm_meta is not None:
             sample_images["norm_meta"] = norm_meta
         if self.transform:
+            before_transform = datetime.now()
             sample_images = self.transform(sample_images)
+            after_transform = datetime.now() - before_transform
         if "weight" in sample_images:
             del sample_images["weight"]
         sample = {
@@ -219,6 +171,11 @@ def __getitem__(self, index: int) -> Sample:
         }
         if self.target_ch_idx is not None:
             sample["target"] = _stack_channels(sample_images, self.channels, "target")
+
+        _logger.info(f"\nTime taken to cache: {after_cache}")
+        _logger.info(f"Time taken to transform: {after_transform}")
+        _logger.info(f"Time taken to get sample: {datetime.now() - start_time}\n")
+
         return sample
 
     def _load_sample(self, position: Position) -> Sample:
@@ -357,6 +314,7 @@ def train_dataloader(self) -> DataLoader:
             batch_size=self.batch_size // self.train_patches_per_stack,
             num_workers=self.num_workers,
             persistent_workers=bool(self.num_workers),
+            pin_memory=True,
             shuffle=True,
             timeout=self.timeout,
             collate_fn=_collate_samples,
@@ -369,6 +327,7 @@ def val_dataloader(self) -> DataLoader:
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             persistent_workers=bool(self.num_workers),
+            pin_memory=True,
             shuffle=False,
             timeout=self.timeout,
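
For context on the pattern this patch adopts (a `multiprocessing.Manager` dict shared by DataLoader workers and filled lazily on first access), here is a minimal, self-contained sketch. It is illustrative only and not part of the patch; the class name, tensor shapes, and item count are made up.

```python
# Illustrative sketch (not from viscy): a Manager-backed dict created in the
# main process is handed to DataLoader workers as a proxy, so every worker
# reads and writes the same cache instead of keeping its own copy.
from multiprocessing import Manager

import torch
from torch.utils.data import DataLoader, Dataset


class SharedCacheDataset(Dataset):
    def __init__(self, n_items: int = 4) -> None:
        self.n_items = n_items
        manager = Manager()
        self.cache_dict = manager.dict()  # shared across worker processes

    def __len__(self) -> int:
        return self.n_items

    def __getitem__(self, index: int) -> torch.Tensor:
        key = str(index)
        if key not in self.cache_dict:
            # Stand-in for the expensive on-disk read; cached once, reused by all workers.
            self.cache_dict[key] = torch.randn(2, 1, 8, 8)
        return self.cache_dict[key]


if __name__ == "__main__":
    loader = DataLoader(SharedCacheDataset(), batch_size=2, num_workers=2, pin_memory=True)
    for batch in loader:
        print(batch.shape)  # torch.Size([2, 2, 1, 8, 8])
```

The trade-off mirrored here is the same as in the diff: lookups and inserts go through the manager process (slower than a plain dict), but the cache survives across workers, which a per-worker `{}` would not.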