Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cellpose - don't merge #146

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f7233be
cellpose unet in progress
mzouink Feb 16, 2024
b98cf2d
cellpose target, task and post processor
mzouink Feb 16, 2024
e0ac0f3
Merge branch 'dev/main' into cellpose
rhoadesScholar Feb 28, 2024
e05cb03
fix: 🚀 Add cellpose requirement.
rhoadesScholar Feb 28, 2024
70aabd8
cellpose unet
mzouink Apr 3, 2024
d40fc43
black format
mzouink Apr 3, 2024
b639c3c
Merge branch 'dev/main' into cellpose
mzouink Apr 3, 2024
408d1ba
Merge branch 'cellpose' of github.com:janelia-cellmap/dacapo into cel…
rhoadesScholar Apr 3, 2024
c8029a0
cellpose
mzouink Apr 4, 2024
fbfff32
unit test
mzouink Apr 4, 2024
4e8c405
cellpose unet
mzouink Apr 4, 2024
dc62c66
upgrade tests to pytest8
mzouink Apr 4, 2024
0f18edf
:art: Format Python code with psf/black
mzouink Apr 4, 2024
c6ed780
Update pyproject.toml
mzouink Apr 4, 2024
68c6b1c
Format Python code with psf/black push (#233)
mzouink Apr 4, 2024
0069031
Merge branch 'cellpose_unet' into pytest_upgrade
mzouink Apr 4, 2024
76b306b
upgrade tests to pytest8 (#234)
mzouink Apr 4, 2024
c4c1c3a
Merge branch 'dev/main' into cellpose_unet
mzouink Apr 5, 2024
530e060
Merge branch 'dev/main' into cellpose_unet
mzouink Apr 5, 2024
f68c1b7
Merge branch 'main' into cellpose_unet
mzouink May 9, 2024
880946c
Merge branch 'main' into cellpose
mzouink May 9, 2024
3185f04
:art: Format Python code with psf/black
mzouink May 9, 2024
6fd984c
Merge branch 'cellpose_unet' into actions/black
mzouink May 9, 2024
d417485
Format Python code with psf/black push (#255)
mzouink May 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dacapo/experiments/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
DummyArchitecture,
) # noqa
from .cnnectome_unet_config import CNNectomeUNetConfig, CNNectomeUNet # noqa
from .cellpose_unet_config import CellposUNetConfig, CellposeUnet # noqa
75 changes: 75 additions & 0 deletions dacapo/experiments/architectures/cellpose_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from cellpose.resnet_torch import CPnet
from .architecture import Architecture
from funlib.geometry import Coordinate


# example
# nout = 4
# sz = 3
# self.net = CPnet(
# nbase, nout, sz, mkldnn=False, conv_3D=True, max_pool=True, diam_mean=30.0
# )
# currently the input channels are embedded in nbdase, but they should be passed as a separate parameternbase = [in_chan, 32, 64, 128, 256]
class CellposeUnet(Architecture):
def __init__(self, architecture_config):
super().__init__()
self._input_shape = Coordinate(architecture_config.input_shape)
self._nbase = architecture_config.nbase
self._sz = self._input_shape.dims
self._eval_shape_increase = Coordinate((0,) * self._sz)
self._nout = architecture_config.nout
print("conv_3D:", architecture_config.conv_3D)
self.unet = CPnet(
architecture_config.nbase,
architecture_config.nout,
self._sz,
architecture_config.mkldnn,
architecture_config.conv_3D,
architecture_config.max_pool,
architecture_config.diam_mean,
)
print(self.unet)

def forward(self, data):
"""
Forward pass of the CPnet model.

Args:
data (torch.Tensor): Input data.

Returns:
tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
"""
if self.unet.mkldnn:
data = data.to_mkldnn()
T0 = self.unet.downsample(data)
if self.unet.mkldnn:
style = self.unet.make_style(T0[-1].to_dense())
else:
style = self.unet.make_style(T0[-1])
# style0 = style
if not self.unet.style_on:
style = style * 0
T1 = self.unet.upsample(style, T0, self.unet.mkldnn)
# head layer
# T1 = self.unet.output(T1)
if self.unet.mkldnn:
T0 = [t0.to_dense() for t0 in T0]
T1 = T1.to_dense()
return T1

@property
def input_shape(self):
return self._input_shape

@property
def num_in_channels(self) -> int:
return self._nbase[0]

@property
def num_out_channels(self) -> int:
return self._nout

@property
def eval_shape_increase(self):
return self._eval_shape_increase
41 changes: 41 additions & 0 deletions dacapo/experiments/architectures/cellpose_unet_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import attr

from .architecture_config import ArchitectureConfig
from .cellpose_unet import CellposeUnet

from funlib.geometry import Coordinate

from typing import List, Optional


@attr.s
class CellposUNetConfig(ArchitectureConfig):
"""This class configures the CellPose based on
https://github.com/MouseLand/cellpose/blob/main/cellpose/resnet_torch.py
"""

architecture_type = CellposeUnet

input_shape: Coordinate = attr.ib(
metadata={
"help_text": "The shape of the data passed into the network during training."
}
)
nbase: List[int] = attr.ib(
metadata={
"help_text": "List of integers representing the number of channels in each layer of the downsample path."
}
)
nout: int = attr.ib(metadata={"help_text": "Number of output channels."})
mkldnn: Optional[bool] = attr.ib(
default=False, metadata={"help_text": "Whether to use MKL-DNN acceleration."}
)
conv_3D: bool = attr.ib(
default=False, metadata={"help_text": "Whether to use 3D convolution."}
)
max_pool: Optional[bool] = attr.ib(
default=True, metadata={"help_text": "Whether to use max pooling."}
)
diam_mean: Optional[float] = attr.ib(
default=30.0, metadata={"help_text": "Mean diameter of the cells."}
)
23 changes: 23 additions & 0 deletions dacapo/experiments/tasks/cellpose_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .evaluators import BinarySegmentationEvaluator
from .losses import CellposeLoss
from .post_processors import ThresholdPostProcessor
from .predictors import CellposePredictor
from .task import Task


class CellposeTask(Task):
def __init__(self, task_config):
self.predictor = CellposePredictor(
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
clipmin=task_config.clipmin,
clipmax=task_config.clipmax,
)
self.loss = CellposeLoss()
self.post_processor = ThresholdPostProcessor()
self.evaluator = BinarySegmentationEvaluator(
clip_distance=task_config.clip_distance,
tol_distance=task_config.tol_distance,
channels=task_config.channels,
)
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .loss import Loss # noqa
from .affinities_loss import AffinitiesLoss # noqa
from .hot_distance_loss import HotDistanceLoss # noqa
from .cellpose_loss import CellposeLoss # noqa
18 changes: 18 additions & 0 deletions dacapo/experiments/tasks/losses/cellpose_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .loss import Loss
import torch
from torch import nn

# TODO check support weights


class CellposeLoss(Loss):
def compute(self, prediction, target, weights=None):
"""loss function between true labels target and prediction prediction"""
criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
veci = 5.0 * target[:, 1:]
loss = criterion(prediction[:, :-1], veci)
loss /= 2.0
loss2 = criterion2(prediction[:, -1], (target[:, 0] > 0.5).float())
loss = loss + loss2
return loss
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from .cellpose_post_processor_parameters import CellposePostProcessorParameters
from .post_processor import PostProcessor
from dacapo.store import LocalArrayIdentifier, ZarrArray
import numpy as np
import zarr

from typing import Iterable

from cellpose.dynamics import compute_masks

# https://github.com/MouseLand/cellpose/blob/54b14fe567d885db293280b9b8d68dc50703d219/cellpose/models.py#L608


class CellposePostProcessor(PostProcessor):
def __init__(self, detection_threshold: float):
self.detection_threshold = detection_threshold

def enumerate_parameters(self) -> Iterable[CellposePostProcessorParameters]:
"""Enumerate all possible parameters of this post-processor. Should
return instances of ``PostProcessorParameters``."""

for i, min_size in enumerate(range(1, 11)):
yield CellposePostProcessorParameters(id=i, min_size=min_size)

def set_prediction(self, prediction_array_identifier: LocalArrayIdentifier):
self.prediction_array = ZarrArray.open_from_identifier(
prediction_array_identifier
)

def process(self, parameters, output_array_identifier):
# store some dummy data
f = zarr.open(str(output_array_identifier.container), "a")
f[output_array_identifier.dataset] = compute_masks(
self.prediction_array.data[:-1] / 5.0,
self.prediction_array.data[-1],
niter=200,
cellprob_threshold=self.detection_threshold,
do_3D=True,
min_size=parameters.min_size,
)[0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .post_processor_parameters import PostProcessorParameters
import attr


# TODO
@attr.s(frozen=True)
class CellposePostProcessorParameters(PostProcessorParameters):
min_size: int = attr.ib()
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/predictors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .affinities_predictor import AffinitiesPredictor # noqa
from .inner_distance_predictor import InnerDistancePredictor # noqa
from .hot_distance_predictor import HotDistancePredictor # noqa
from .cellpose_predictor import CellposePredictor # noqa
161 changes: 161 additions & 0 deletions dacapo/experiments/tasks/predictors/cellpose_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from .predictor import Predictor
from dacapo.experiments import Model
from dacapo.experiments.arraytypes import DistanceArray
from dacapo.experiments.datasplits.datasets.arrays import NumpyArray
from dacapo.utils.balance_weights import balance_weights

from funlib.geometry import Coordinate

from scipy.ndimage.morphology import distance_transform_edt
import numpy as np
import torch

import logging
from typing import List
from cellpose.dynamics import masks_to_flows_gpu_3d

logger = logging.getLogger(__name__)


# TODO currently CPnet have nout which is the head of the network, check how to change it in the predictor
class CellposePredictor(Predictor):
"""
Predict signed distances for a binary segmentation task.
Distances deep within background are pushed to -inf, distances deep within
the foreground object are pushed to inf. After distances have been
calculated they are passed through a tanh so that distances saturate at +-1.
Multiple classes can be predicted via multiple distance channels. The names
of each class that is being segmented can be passed in as a list of strings
in the channels argument.
"""

def __init__(
self,
channels: List[str],
scale_factor: float,
mask_distances: bool,
clipmin: float = 0.05,
clipmax: float = 0.95,
):
self.channels = channels
self.norm = "tanh"
self.dt_scale_factor = scale_factor
self.mask_distances = mask_distances

self.max_distance = 1 * scale_factor
self.epsilon = 5e-2
self.threshold = 0.8
self.clipmin = clipmin
self.clipmax = clipmax

@property
def embedding_dims(self):
return len(self.channels)

def create_model(self, architecture):
if isinstance(architecture, CellposeUnet):
head = torch.nn.Identity()

return Model(architecture, torch.nn.Identity())

def create_target(self, gt):
flows, _ = masks_to_flows_gpu_3d(gt)
# difussion = self.process(
# gt.data, gt.voxel_size, self.norm, self.dt_scale_factor
# )
return NumpyArray.from_np_array(
flows,
gt.roi,
gt.voxel_size,
gt.axes,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
# balance weights independently for each channel

weights, moving_class_counts = balance_weights(
gt[target.roi],
2,
slab=tuple(1 if c == "c" else -1 for c in gt.axes),
masks=[mask[target.roi]],
moving_counts=moving_class_counts,
clipmin=self.clipmin,
clipmax=self.clipmax,
)
return (
NumpyArray.from_np_array(
weights,
gt.roi,
gt.voxel_size,
gt.axes,
),
moving_class_counts,
)

@property
def output_array_type(self):
return DistanceArray(self.embedding_dims)

def process(
self,
labels: np.ndarray,
voxel_size: Coordinate,
normalize=None,
normalize_args=None,
):
all_distances = np.zeros(labels.shape, dtype=np.float32) - 1
for ii, channel in enumerate(labels):
boundaries = self.__find_boundaries(channel)

# mark boundaries with 0 (not 1)
boundaries = 1.0 - boundaries

if np.sum(boundaries == 0) == 0:
max_distance = min(
dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size)
)
if np.sum(channel) == 0:
distances = -np.ones(channel.shape, dtype=np.float32) * max_distance
else:
distances = np.ones(channel.shape, dtype=np.float32) * max_distance
else:
# get distances (voxel_size/2 because image is doubled)
distances = distance_transform_edt(
boundaries, sampling=tuple(float(v) / 2 for v in voxel_size)
)
distances = distances.astype(np.float32)

# restore original shape
downsample = (slice(None, None, 2),) * len(voxel_size)
distances = distances[downsample]

# todo: inverted distance
distances[channel == 0] = -distances[channel == 0]

if normalize is not None:
distances = self.__normalize(distances, normalize, normalize_args)

all_distances[ii] = distances

return all_distances

def __normalize(self, distances, norm, normalize_args):
if norm == "tanh":
scale = normalize_args
return np.tanh(distances / scale)
else:
raise ValueError("Only tanh is supported for normalization")

def gt_region_for_roi(self, target_spec):
if self.mask_distances:
gt_spec = target_spec.copy()
gt_spec.roi = gt_spec.roi.grow(
Coordinate((self.max_distance,) * gt_spec.voxel_size.dims),
Coordinate((self.max_distance,) * gt_spec.voxel_size.dims),
).snap_to_grid(gt_spec.voxel_size, mode="shrink")
else:
gt_spec = target_spec.copy()
return gt_spec

def padding(self, gt_voxel_size: Coordinate) -> Coordinate:
return Coordinate((self.max_distance,) * gt_voxel_size.dims)
Loading
Loading