Skip to content

Commit

Permalink
Merge pull request #23 from janelia-cellmap/array_writer
Browse files Browse the repository at this point in the history
Array writer
  • Loading branch information
rhoadesScholar authored Nov 4, 2024
2 parents ab544be + 86db200 commit 59dde48
Show file tree
Hide file tree
Showing 12 changed files with 1,036 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.10', '3.11']
python-version: ['3.11', '3.12']
platform: [ubuntu-latest, macos-latest, windows-latest]

steps:
Expand Down
7 changes: 6 additions & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
name: Python mypy

on: [push, pull_request]
on:
push:
branches: [ "main" ]
pull_request:

jobs:
static-analysis:
Expand All @@ -9,6 +12,8 @@ jobs:
steps:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.11
- name: Setup checkout
uses: actions/checkout@v4
- name: mypy
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ name: Test

on:
push:
branches: [ "main" ]
pull_request:

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.11", "3.12"]
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "hatchling.build"
name = "cellmap-data"
description = "Utility for loading CellMap data for machine learning training, utilizing PyTorch, Xarray, TensorStore, and PyDantic."
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
license = { text = "BSD 3-Clause License" }
authors = [
{ email = "[email protected]", name = "Jeff Rhoades" },
Expand All @@ -17,9 +17,8 @@ classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: BSD License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Typing :: Typed",
]
dynamic = ["version"]
Expand All @@ -32,8 +31,9 @@ dependencies = [
"pydantic_ome_ngff",
"xarray_ome_ngff",
"tensorstore",
"xarray-tensorstore @ git+https://github.com/rhoadesScholar/xarray-tensorstore.git",
# "xarray-tensorstore @ git+https://github.com/google/xarray-tensorstore.git",
"xarray-tensorstore",
# "xarray-tensorstore",
"universal_pathlib>=0.2.0",
"fsspec[s3,http]",
"cellpose",
Expand Down
3 changes: 2 additions & 1 deletion src/cellmap_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from .dataloader import CellMapDataLoader
from .datasplit import CellMapDataSplit
from .dataset import CellMapDataset
from .image import CellMapImage
from .dataset_writer import CellMapDatasetWriter
from .image import CellMapImage, EmptyImage, ImageWriter
from .subdataset import CellMapSubset
from . import transforms
from . import utils
8 changes: 4 additions & 4 deletions src/cellmap_data/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from torch.utils.data import DataLoader, Sampler
from torch.utils.data import DataLoader, Sampler, Subset
from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset
from .subdataset import CellMapSubset
from .dataset_writer import CellMapDatasetWriter
from typing import Callable, Iterable, Optional, Sequence


Expand Down Expand Up @@ -30,7 +30,7 @@ class CellMapDataLoader:

def __init__(
self,
dataset: CellMapMultiDataset | CellMapDataset | CellMapSubset,
dataset: CellMapMultiDataset | CellMapDataset | Subset | CellMapDatasetWriter,
classes: Iterable[str],
batch_size: int = 1,
num_workers: int = 0,
Expand Down Expand Up @@ -113,7 +113,7 @@ def refresh(self):

def collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
"""Combine a list of dictionaries from different sources into a single dictionary for output."""
outputs: dict[str, torch.Tensor] = {}
outputs = {}
for b in batch:
for key, value in b.items():
if key not in outputs:
Expand Down
4 changes: 3 additions & 1 deletion src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,16 +403,18 @@ def __len__(self) -> int:

def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
"""Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index."""
idx = np.array(idx)
idx[idx < 0] = len(self) + idx[idx < 0]
try:
center = np.unravel_index(
idx, [self.sampling_box_shape[c] for c in self.axis_order]
)
except ValueError:
# TODO: This is a hacky temprorary fix. Need to figure out why this is happening
logger.error(
f"Index {idx} out of bounds for dataset {self} of length {len(self)}"
)
logger.warning(f"Returning closest index in bounds")
# TODO: This is a hacky temprorary fix. Need to figure out why this is happening
center = [self.sampling_box_shape[c] - 1 for c in self.axis_order]
center = {
c: center[i] * self.largest_voxel_sizes[c] + self.sampling_box[c][0]
Expand Down
Loading

0 comments on commit 59dde48

Please sign in to comment.