diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 42957cd..d555758 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -33,7 +33,7 @@ jobs:
       - name: Build package
         run: hatch build
       - name: Publish package
-        uses: pypa/gh-action-pypi-publish@fb13cb306901256ace3dab689990e13a5550ffaa
+        uses: pypa/gh-action-pypi-publish@15c56dba361d8335944d31a2ecd17d700fc7bcbc
         with:
           user: __token__
           password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/pyproject.toml b/pyproject.toml
index 7e57327..a02e588 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,8 +31,8 @@ dependencies = [
     "pydantic_ome_ngff",
     "xarray_ome_ngff",
     "tensorstore",
-    "xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git",
-    # "xarray-tensorstore",
+    # "xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git",
+    "xarray-tensorstore",
     "universal_pathlib>=0.2.0",
     "fsspec[s3,http]",
     "cellpose",
diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py
index 900d52a..810addc 100644
--- a/src/cellmap_data/dataloader.py
+++ b/src/cellmap_data/dataloader.py
@@ -97,19 +97,25 @@ def __getitem__(self, indices: Sequence[int]) -> dict:

     def refresh(self):
         """If the sampler is a Callable, refresh the DataLoader with the current sampler."""
-        if isinstance(self.sampler, Callable):
-            kwargs = self.default_kwargs.copy()
-            kwargs.update(
-                {
-                    "dataset": self.dataset,
-                    "batch_size": self.batch_size,
-                    "num_workers": self.num_workers,
-                    "collate_fn": self.collate_fn,
-                    "shuffle": False,
-                }
-            )
-            kwargs["sampler"] = self.sampler()
-            self.loader = DataLoader(**kwargs)
+        kwargs = self.default_kwargs.copy()
+        kwargs.update(
+            {
+                "dataset": self.dataset,
+                "batch_size": self.batch_size,
+                "num_workers": self.num_workers,
+                "collate_fn": self.collate_fn,
+            }
+        )
+        if self.sampler is not None:
+            if isinstance(self.sampler, Callable):
+                kwargs["sampler"] = self.sampler()
+            else:
+                kwargs["sampler"] = self.sampler
+        elif self.is_train:
+            kwargs["shuffle"] = True
+        else:
+            kwargs["shuffle"] = False
+        self.loader = DataLoader(**kwargs)

     def collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
         """Combine a list of dictionaries from different sources into a single dictionary for output."""
diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py
index 58c86cb..3820b89 100644
--- a/src/cellmap_data/dataset.py
+++ b/src/cellmap_data/dataset.py
@@ -307,7 +307,7 @@ def sampling_box_shape(self) -> dict[str, int]:
         for c, size in self._sampling_box_shape.items():
             if size <= 0:
                 logger.warning(
-                    f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1"
+                    f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1 and padding."
                 )
                 self._sampling_box_shape[c] = 1
         return self._sampling_box_shape
diff --git a/src/cellmap_data/dataset_writer.py b/src/cellmap_data/dataset_writer.py
index 1a45304..1a523fe 100644
--- a/src/cellmap_data/dataset_writer.py
+++ b/src/cellmap_data/dataset_writer.py
@@ -222,7 +222,7 @@ def sampling_box_shape(self) -> dict[str, int]:
         for c, size in self._sampling_box_shape.items():
             if size <= 0:
                 logger.warning(
-                    f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1"
+                    f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1 and padding"
                 )
                 self._sampling_box_shape[c] = 1
         return self._sampling_box_shape
diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py
index fdf30ee..48a7048 100644
--- a/src/cellmap_data/image.py
+++ b/src/cellmap_data/image.py
@@ -352,19 +352,24 @@ def class_counts(self) -> float:
         except AttributeError:
             # Get from cellmap-schemas metadata, then normalize by resolution
             try:
-                # TODO: Make work with HDF5 files
-                bg_count = self.group[self.scale_level].attrs["cellmap"]["annotation"][
+                bg_count = self.group["s0"].attrs["cellmap"]["annotation"][
                     "complement_counts"
                 ]["absent"]
+                for scale in self.group.attrs["multiscales"][0]["datasets"]:
+                    if scale["path"] == "s0":
+                        for transform in scale["coordinateTransformations"]:
+                            if "scale" in transform:
+                                s0_scale = transform["scale"]
+                                break
+                        break
                 self._class_counts = (
-                    np.prod(self.group[self.scale_level].shape) - bg_count
-                ) * np.prod(list(self.scale.values()))
-                self._bg_count = bg_count * np.prod(list(self.scale.values()))
+                    np.prod(self.group["s0"].shape) - bg_count
+                ) * np.prod(s0_scale)
+                self._bg_count = bg_count * np.prod(s0_scale)
             except Exception as e:
-                # print(f"Error: {e}")
-                print(
-                    "Unable to get class counts from metadata, falling back to giving foreground 1 pixel, and the rest to background."
-                )
+                print(f"Error: {e}")
+                print(f"Unable to get class counts for {self.path}")
+                # print("from metadata, falling back to giving foreground 1 pixel, and the rest to background.")
                 self._class_counts = np.prod(list(self.scale.values()))
                 self._bg_count = (
                     np.prod(self.group[self.scale_level].shape) - 1
@@ -778,6 +783,10 @@ def array(self) -> xarray.DataArray:
         try:
             array = array_future.result()
         except ValueError as e:
+            if "ALREADY_EXISTS" in str(e):
+                raise FileExistsError(
+                    f"Image already exists at {self.path}. Set overwrite=True to overwrite the image."
+                )
             Warning(e)
             UserWarning("Falling back to zarr3 driver")
             spec["driver"] = "zarr3"
@@ -929,9 +938,11 @@ def __setitem__(
         try:
             self.array.loc[coords] = data
         except ValueError as e:
-            print(
-                f"Writing to center {center} in image {self.path} failed. Coordinates: are not all within the image's bounds. Will drop out of bounds data."
-            )
+            # print(e)
+            # print(data.shape)
+            # print(
+            #     f"Writing to center {center} in image {self.path} failed. Coordinates: are not all within the image's bounds. Will drop out of bounds data."
+            # )
             # Crop data to match the number of coordinates matched in the image
             slices = []
             for coord in coords.values():
diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py
index 96c5058..5bc0f19 100644
--- a/src/cellmap_data/multidataset.py
+++ b/src/cellmap_data/multidataset.py
@@ -157,7 +157,7 @@ def validation_indices(self) -> Sequence[int]:
                 if i == 0:
                     offset = 0
                 else:
-                    offset = self.cummulative_sizes[i - 1]
+                    offset = self.cumulative_sizes[i - 1]
                 sample_indices = np.array(dataset.validation_indices) + offset
                 indices.extend(list(sample_indices))
             except AttributeError:
diff --git a/src/cellmap_data/transforms/targets/distance.py b/src/cellmap_data/transforms/targets/distance.py
index 16b9851..6025af0 100644
--- a/src/cellmap_data/transforms/targets/distance.py
+++ b/src/cellmap_data/transforms/targets/distance.py
@@ -14,18 +14,20 @@ class DistanceTransform(torch.nn.Module):

     Attributes:
         use_cuda (bool): Use CUDA.
+        clip (list): Clip the output to the specified range.

     Methods:
         _transform: Transform the input.
         forward: Forward pass.
     """

-    def __init__(self, use_cuda: bool = False) -> None:
+    def __init__(self, use_cuda: bool = False, clip=[-torch.inf, torch.inf]) -> None:
         """
         Initialize the distance transform.

         Args:
             use_cuda (bool, optional): Use CUDA. Defaults to False.
+            clip (list, optional): Clip the output to the specified range. Defaults to [-torch.inf, torch.inf].

         Raises:
             NotImplementedError: CUDA is not supported yet.
@@ -33,6 +35,7 @@ def __init__(self, use_cuda: bool = False) -> None:
         UserWarning("This is still in development and may not work as expected")
         super().__init__()
         self.use_cuda = use_cuda
+        self.clip = clip
         if self.use_cuda:
             raise NotImplementedError(
                 "CUDA is not supported yet because testing did not return expected results."
@@ -46,7 +49,7 @@ def _transform(self, x: torch.Tensor) -> torch.Tensor:
             )
             # return transform_cuda(x)
         else:
-            return transform(x)
+            return transform(x).clip(self.clip[0], self.clip[1])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass."""
@@ -64,18 +67,20 @@ class SignedDistanceTransform(torch.nn.Module):

     Attributes:
         use_cuda (bool): Use CUDA.
+        clip (list): Clip the output to the specified range.

     Methods:
         _transform: Transform the input.
         forward: Forward pass.
     """

-    def __init__(self, use_cuda: bool = False) -> None:
+    def __init__(self, use_cuda: bool = False, clip=[-torch.inf, torch.inf]) -> None:
         """
         Initialize the signed distance transform.

         Args:
             use_cuda (bool, optional): Use CUDA. Defaults to False.
+            clip (list, optional): Clip the output to the specified range. Defaults to [-torch.inf, torch.inf].

         Raises:
             NotImplementedError: CUDA is not supported yet.
@@ -83,6 +88,7 @@ def __init__(self, use_cuda: bool = False) -> None:
         UserWarning("This is still in development and may not work as expected")
         super().__init__()
         self.use_cuda = use_cuda
+        self.clip = clip
         if self.use_cuda:
             raise NotImplementedError(
                 "CUDA is not supported yet because testing did not return expected results."
@@ -96,7 +102,11 @@ def _transform(self, x: torch.Tensor) -> torch.Tensor:
             )
             # return transform_cuda(x) - transform_cuda(x.logical_not())
         else:
-            return transform(x) - transform(x.logical_not())
+            # TODO: Fix this to be correct
+
+            return transform(x).clip(self.clip[0], self.clip[1]) - transform(
+                x.logical_not()
+            ).clip(self.clip[0], self.clip[1])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass."""