using multiprocessing manager
edyoshikun committed Oct 19, 2024
1 parent 4f18e50 commit da94915
Showing 1 changed file with 27 additions and 68 deletions.
95 changes: 27 additions & 68 deletions viscy/data/hcs_ram.py
@@ -18,6 +18,8 @@

from viscy.data.hcs import _read_norm_meta
from viscy.data.typing import ChannelMap, DictTransform, Sample
+from multiprocessing import Manager
+from datetime import datetime

_logger = logging.getLogger("lightning.pytorch")

@@ -96,17 +98,12 @@ def __init__(
        self.total_ch_names.extend(self.channels["target"])
        self.total_ch_idx.extend(self.target_ch_idx)
        self._position_mapping()
-<<<<<<< HEAD
-
-        self.cache_order = []
-        self.cache_record = torch.zeros(len(self.positions))
-        self._init_cache_dataset()
-
-        # Caching the dataset as two separate arrays
-        # self._init_cache_dataset()
-=======
-
-        # Cached dictionary with tensors
-        self.cache_dict = {}
->>>>>>> parent of 8c13f49 (replacing dictionary with single array)
+        manager = Manager()
+        self.cache_dict = manager.dict()
+        self._cached_pos = []

    def _position_mapping(self) -> None:
        self.position_keys = []
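A note on the change above: with `num_workers > 0`, each DataLoader worker runs in its own process, so writes to a plain `dict` (the previous `self.cache_dict = {}`) stay private to whichever worker made them, and the cache never fills from any other process's point of view. A `Manager().dict()` is a proxy to a dict held in a separate server process, so all workers see the same entries. A minimal sketch of the pattern, assuming a toy `SlowDataset` that is not from this repo:

```python
from multiprocessing import Manager

import torch
from torch.utils.data import DataLoader, Dataset


class SlowDataset(Dataset):
    """Illustrative dataset that caches expensive reads in a shared dict."""

    def __init__(self, n_items: int = 8):
        self.n_items = n_items
        manager = Manager()
        # Proxy to a dict living in a manager server process,
        # visible to every DataLoader worker.
        self.cache = manager.dict()

    def __len__(self) -> int:
        return self.n_items

    def __getitem__(self, index: int) -> torch.Tensor:
        if index not in self.cache:
            # Simulate an expensive read (e.g. decoding a zarr chunk).
            self.cache[index] = torch.full((2, 4, 4), float(index))
        return self.cache[index]


if __name__ == "__main__":
    loader = DataLoader(SlowDataset(), batch_size=2, num_workers=2)
    for _ in loader:
        pass  # later epochs over the same loader hit the shared cache
```

The convenience has a cost: every read and write of the proxy goes through IPC and pickling, so the shared dict trades per-access copies for cross-worker reuse.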
@@ -115,39 +112,15 @@ def _position_mapping(self) -> None:
        for pos in self.positions:
            self.position_keys.append(pos.data.name)
            self.norm_meta_dict[str(pos.data.name)] = _read_norm_meta(pos)
-<<<<<<< HEAD
-            self.position_shape_zyx = pos.data.shape[-3:]
-            self._cache_dtype = numpy_to_torch_dtype.get(pos.data.dtype, torch.float32)  # Default to torch.float32 if not found
-
-    def _init_cache_dataset(self) -> None:
-        _logger.info('Initializing cache array')
-        # FIXME assumes t=1
-        t_idx = 1
-        self.cache = torch.zeros(((len(self.positions), t_idx, len(self.total_ch_idx)) + self.position_shape_zyx), dtype=self._cache_dtype)
-        _logger.info(f'Cache shape: {self.cache.shape}')
-
-        # TODO Caching here to see if multiprocessing is faster
-        t = 0
-
-        for i, pos in enumerate(self.positions):
-            _logger.info(f'Caching position {i}/{len(self.positions)}')
-            # Insert the data into the cache
-            data = pos.data.oindex[slice(t, t + 1), self.total_ch_idx, :]
-            if data.dtype != np.float32:
-                data = data.astype(np.float32)
-            self.cache[i] = torch.from_numpy(data)
-            del data
-=======

    def _cache_dataset(self, index: int, channel_index: list[int], t: int = 0) -> None:
        # Add the position to the cached_dict
        # TODO: hardcoding to t=0
-        self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(
-            self.positions[index]
-            .data.oindex[slice(t, t + 1), channel_index, :]
-            .astype(np.float32)
-        )
->>>>>>> parent of 8c13f49 (replacing dictionary with single array)
+        data = self.positions[index].data.oindex[slice(t, t + 1), channel_index, :]
+        if data.dtype != np.float32:
+            data = data.astype(np.float32)
+        self.cache_dict[str(self.position_keys[index])] = torch.from_numpy(data)


    def _get_weight_map(self, position: Position) -> Tensor:
        # Get the weight map from the position for the MONAI weightedcrop transform
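The rewritten `_cache_dataset` body above materializes the array first and casts only when the source is not already `float32`, instead of calling `.astype(np.float32)` unconditionally; `torch.from_numpy` then wraps the resulting buffer without copying. A small sketch of the same guard, assuming a hypothetical `uint16` source array in place of the zarr-backed `pos.data`:

```python
import numpy as np
import torch

# Stand-in for pos.data.oindex[...]: microscopy data is often stored as uint16.
data = np.random.randint(0, 2**16, size=(1, 2, 4, 8, 8), dtype=np.uint16)

if data.dtype != np.float32:
    data = data.astype(np.float32)  # one explicit copy, only when needed

tensor = torch.from_numpy(data)  # zero-copy: shares memory with `data`
assert tensor.dtype == torch.float32
```

Note that inserting the tensor into the manager dict still pickles it across the process boundary, so the zero-copy saving applies only inside the worker that performs the caching.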
@@ -162,43 +135,20 @@ def __getitem__(self, index: int) -> Sample:
        ch_names = self.total_ch_names

        # Check if the sample is in the cache else add it
-<<<<<<< HEAD
-        # if self.cache_record[index] == 0:
-        #     # Flip the bit
-        #     self.cache_record[index] = 1
-        #     self.cache_order.append(index)
-
-        #     # Stack the data
-        #     _logger.info(f'Adding {self.position_keys[index]} to cache')
-        #     _logger.info(f'Cache_order: {self.cache_order}')
-        #     _logger.info(f'caching index: {index}')
-
-        #     # FIXME: hardcoding t=0 and make this part of function
-        #     t = 0
-
-        #     # Insert the data into the cache
-        #     data = self.positions[index].data.oindex[slice(t, t + 1), ch_idx, :]
-        #     if data.dtype != np.float32:
-        #         data = data.astype(np.float32)
-        #     self.cache[index] = torch.from_numpy(data)
-        #     del data
-
-        # Get the sample from the cache
-        _logger.info(f'Getting sample {index} from cache')
-        sample_id = self.position_keys[index]
-        images = self.cache[index].unbind(dim=1)
-=======
        # Split the tensor into the channels
        sample_id = self.position_keys[index]
        if sample_id not in self.cache_dict:
-            logging.info(f"Adding {sample_id} to cache")
+            _logger.info(f"Adding {sample_id} to cache")
+            self._cached_pos.append(index)
+            _logger.info(f"Cached positions: {self._cached_pos}")
            self._cache_dataset(index, channel_index=ch_idx)

        # Get the sample from the cache
        _logger.info('Getting sample from cache')
+        start_time = datetime.now()
        images = self.cache_dict[sample_id].unbind(dim=1)
->>>>>>> parent of 8c13f49 (replacing dictionary with single array)
        norm_meta = self.norm_meta_dict[str(sample_id)]

+        after_cache = datetime.now() - start_time
        sample_images = {k: v for k, v in zip(ch_names, images)}

        if self.target_ch_idx is not None:
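The resolved `__getitem__` is a lazy, cache-on-first-access read path: the first request for a position pays the full read and publishes the `(T, C, Z, Y, X)` tensor, and every access then splits the cached tensor into per-channel views with `unbind(dim=1)`. A condensed sketch of the flow, assuming a plain dict and made-up channel names and shapes:

```python
import torch

cache_dict: dict[str, torch.Tensor] = {}  # stands in for the manager dict
ch_names = ["phase", "nuclei"]


def get_sample(sample_id: str) -> dict[str, torch.Tensor]:
    if sample_id not in cache_dict:
        # First access: expensive load, e.g. _cache_dataset(index, ...).
        cache_dict[sample_id] = torch.zeros(1, len(ch_names), 4, 8, 8)
    # unbind(dim=1) yields one (T, Z, Y, X) view per channel, without copying.
    images = cache_dict[sample_id].unbind(dim=1)
    return dict(zip(ch_names, images))


sample = get_sample("A/1/0")
print(sample["phase"].shape)  # torch.Size([1, 4, 8, 8])
```

With several workers, two of them can race to cache the same position; both would write equivalent tensors, so the race costs duplicated work rather than corruption.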
@@ -209,7 +159,9 @@ def __getitem__(self, index: int) -> Sample:
        if norm_meta is not None:
            sample_images["norm_meta"] = norm_meta
        if self.transform:
+            before_transform = datetime.now()
            sample_images = self.transform(sample_images)
+            after_transform = datetime.now() - before_transform
        if "weight" in sample_images:
            del sample_images["weight"]
        sample = {
@@ -219,6 +171,11 @@
        }
        if self.target_ch_idx is not None:
            sample["target"] = _stack_channels(sample_images, self.channels, "target")

_logger.info(f"\nTime taken to cache: {after_cache}")
_logger.info(f"Time taken to transform: {after_transform}")
_logger.info(f"Time taken to get sample: {datetime.now() - start_time}\n")

return sample
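The `datetime.now()` deltas above bracket the two hot stages (cache fetch, transform) so slow batches can be attributed from the logs. One caveat in the committed code: `after_transform` is only assigned inside `if self.transform:`, so the final log line would raise `NameError` for a dataset without transforms. A sketch of the same instrumentation, assuming `time.perf_counter()` (monotonic, meant for interval timing) and binding both counters up front:

```python
import time

after_cache = 0.0
after_transform = 0.0  # bound even when no transform runs

start = time.perf_counter()
# ... fetch tensors from the cache ...
after_cache = time.perf_counter() - start

t0 = time.perf_counter()
# ... apply transforms, if any ...
after_transform = time.perf_counter() - t0

print(f"cache: {after_cache:.4f}s, transform: {after_transform:.4f}s")
```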

    def _load_sample(self, position: Position) -> Sample:
@@ -357,6 +314,7 @@ def train_dataloader(self) -> DataLoader:
            batch_size=self.batch_size // self.train_patches_per_stack,
            num_workers=self.num_workers,
            persistent_workers=bool(self.num_workers),
+            pin_memory=True,
            shuffle=True,
            timeout=self.timeout,
            collate_fn=_collate_samples,
@@ -369,6 +327,7 @@ def val_dataloader(self) -> DataLoader:
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=bool(self.num_workers),
+            pin_memory=True,
            shuffle=False,
            timeout=self.timeout,
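`pin_memory=True` in both loaders above makes the DataLoader copy each batch into page-locked host memory, which is what allows the subsequent host-to-device copy to run asynchronously. A minimal sketch of the intended pairing with `non_blocking=True` (the dataset and shapes are placeholders, not from this repo):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 2, 4, 8, 8))
loader = DataLoader(dataset, batch_size=8, num_workers=2, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
for (batch,) in loader:
    # Pinned source memory allows an async copy that overlaps with compute.
    batch = batch.to(device, non_blocking=True)
```

Pinned memory is a finite resource, so this typically pays off only when training on GPU with workers producing batches faster than they are transferred.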

