Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/github_actions/codecov/codecov-action-5
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Dec 6, 2024
2 parents 6da4f2c + 8813526 commit 27bd935
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Build package
run: hatch build
- name: Publish package
uses: pypa/gh-action-pypi-publish@fb13cb306901256ace3dab689990e13a5550ffaa
uses: pypa/gh-action-pypi-publish@15c56dba361d8335944d31a2ecd17d700fc7bcbc
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ dependencies = [
"pydantic_ome_ngff",
"xarray_ome_ngff",
"tensorstore",
"xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git",
# "xarray-tensorstore",
# "xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git",
"xarray-tensorstore",
"universal_pathlib>=0.2.0",
"fsspec[s3,http]",
"cellpose",
Expand Down
32 changes: 19 additions & 13 deletions src/cellmap_data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,25 @@ def __getitem__(self, indices: Sequence[int]) -> dict:

def refresh(self):
"""If the sampler is a Callable, refresh the DataLoader with the current sampler."""
if isinstance(self.sampler, Callable):
kwargs = self.default_kwargs.copy()
kwargs.update(
{
"dataset": self.dataset,
"batch_size": self.batch_size,
"num_workers": self.num_workers,
"collate_fn": self.collate_fn,
"shuffle": False,
}
)
kwargs["sampler"] = self.sampler()
self.loader = DataLoader(**kwargs)
kwargs = self.default_kwargs.copy()
kwargs.update(
{
"dataset": self.dataset,
"batch_size": self.batch_size,
"num_workers": self.num_workers,
"collate_fn": self.collate_fn,
}
)
if self.sampler is not None:
if isinstance(self.sampler, Callable):
kwargs["sampler"] = self.sampler()
else:
kwargs["sampler"] = self.sampler
elif self.is_train:
kwargs["shuffle"] = True
else:
kwargs["shuffle"] = False
self.loader = DataLoader(**kwargs)

def collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
"""Combine a list of dictionaries from different sources into a single dictionary for output."""
Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def sampling_box_shape(self) -> dict[str, int]:
for c, size in self._sampling_box_shape.items():
if size <= 0:
logger.warning(
f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1"
f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1 and padding."
)
self._sampling_box_shape[c] = 1
return self._sampling_box_shape
Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_data/dataset_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def sampling_box_shape(self) -> dict[str, int]:
for c, size in self._sampling_box_shape.items():
if size <= 0:
logger.warning(
f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1"
f"Sampling box shape is <= 0 for axis {c} with size {size}. Setting to 1 and padding"
)
self._sampling_box_shape[c] = 1
return self._sampling_box_shape
Expand Down
35 changes: 23 additions & 12 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,19 +352,24 @@ def class_counts(self) -> float:
except AttributeError:
# Get from cellmap-schemas metadata, then normalize by resolution
try:
# TODO: Make work with HDF5 files
bg_count = self.group[self.scale_level].attrs["cellmap"]["annotation"][
bg_count = self.group["s0"].attrs["cellmap"]["annotation"][
"complement_counts"
]["absent"]
for scale in self.group.attrs["multiscales"][0]["datasets"]:
if scale["path"] == "s0":
for transform in scale["coordinateTransformations"]:
if "scale" in transform:
s0_scale = transform["scale"]
break
break
self._class_counts = (
np.prod(self.group[self.scale_level].shape) - bg_count
) * np.prod(list(self.scale.values()))
self._bg_count = bg_count * np.prod(list(self.scale.values()))
np.prod(self.group["s0"].shape) - bg_count
) * np.prod(s0_scale)
self._bg_count = bg_count * np.prod(s0_scale)
except Exception as e:
# print(f"Error: {e}")
print(
"Unable to get class counts from metadata, falling back to giving foreground 1 pixel, and the rest to background."
)
print(f"Error: {e}")
print(f"Unable to get class counts for {self.path}")
# print("from metadata, falling back to giving foreground 1 pixel, and the rest to background.")
self._class_counts = np.prod(list(self.scale.values()))
self._bg_count = (
np.prod(self.group[self.scale_level].shape) - 1
Expand Down Expand Up @@ -778,6 +783,10 @@ def array(self) -> xarray.DataArray:
try:
array = array_future.result()
except ValueError as e:
if "ALREADY_EXISTS" in str(e):
raise FileExistsError(
f"Image already exists at {self.path}. Set overwrite=True to overwrite the image."
)
Warning(e)
UserWarning("Falling back to zarr3 driver")
spec["driver"] = "zarr3"
Expand Down Expand Up @@ -929,9 +938,11 @@ def __setitem__(
try:
self.array.loc[coords] = data
except ValueError as e:
print(
f"Writing to center {center} in image {self.path} failed. Coordinates: are not all within the image's bounds. Will drop out of bounds data."
)
# print(e)
# print(data.shape)
# print(
# f"Writing to center {center} in image {self.path} failed. Coordinates: are not all within the image's bounds. Will drop out of bounds data."
# )
# Crop data to match the number of coordinates matched in the image
slices = []
for coord in coords.values():
Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def validation_indices(self) -> Sequence[int]:
if i == 0:
offset = 0
else:
offset = self.cummulative_sizes[i - 1]
offset = self.cumulative_sizes[i - 1]
sample_indices = np.array(dataset.validation_indices) + offset
indices.extend(list(sample_indices))
except AttributeError:
Expand Down
18 changes: 14 additions & 4 deletions src/cellmap_data/transforms/targets/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,28 @@ class DistanceTransform(torch.nn.Module):
Attributes:
use_cuda (bool): Use CUDA.
clip (list): Clip the output to the specified range.
Methods:
_transform: Transform the input.
forward: Forward pass.
"""

def __init__(self, use_cuda: bool = False) -> None:
def __init__(self, use_cuda: bool = False, clip=[-torch.inf, torch.inf]) -> None:
"""
Initialize the distance transform.
Args:
use_cuda (bool, optional): Use CUDA. Defaults to False.
clip (list, optional): Clip the output to the specified range. Defaults to [-torch.inf, torch.inf].
Raises:
NotImplementedError: CUDA is not supported yet.
"""
UserWarning("This is still in development and may not work as expected")
super().__init__()
self.use_cuda = use_cuda
self.clip = clip
if self.use_cuda:
raise NotImplementedError(
"CUDA is not supported yet because testing did not return expected results."
Expand All @@ -46,7 +49,7 @@ def _transform(self, x: torch.Tensor) -> torch.Tensor:
)
# return transform_cuda(x)
else:
return transform(x)
return transform(x).clip(self.clip[0], self.clip[1])

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass."""
Expand All @@ -64,25 +67,28 @@ class SignedDistanceTransform(torch.nn.Module):
Attributes:
use_cuda (bool): Use CUDA.
clip (list): Clip the output to the specified range.
Methods:
_transform: Transform the input.
forward: Forward pass.
"""

def __init__(self, use_cuda: bool = False) -> None:
def __init__(self, use_cuda: bool = False, clip=[-torch.inf, torch.inf]) -> None:
"""
Initialize the signed distance transform.
Args:
use_cuda (bool, optional): Use CUDA. Defaults to False.
clip (list, optional): Clip the output to the specified range. Defaults to [-torch.inf, torch.inf].
Raises:
NotImplementedError: CUDA is not supported yet.
"""
UserWarning("This is still in development and may not work as expected")
super().__init__()
self.use_cuda = use_cuda
self.clip = clip
if self.use_cuda:
raise NotImplementedError(
"CUDA is not supported yet because testing did not return expected results."
Expand All @@ -96,7 +102,11 @@ def _transform(self, x: torch.Tensor) -> torch.Tensor:
)
# return transform_cuda(x) - transform_cuda(x.logical_not())
else:
return transform(x) - transform(x.logical_not())
# TODO: Fix this to be correct

return transform(x).clip(self.clip[0], self.clip[1]) - transform(
x.logical_not()
).clip(self.clip[0], self.clip[1])

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass."""
Expand Down

0 comments on commit 27bd935

Please sign in to comment.