diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py new file mode 100644 index 00000000..ef0d0322 --- /dev/null +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -0,0 +1,25 @@ +from .evaluators import BinarySegmentationEvaluator +from .losses import HotDistanceLoss +from .post_processors import ThresholdPostProcessor +from .predictors import HotDistancePredictor +from .task import Task + + +class HotDistanceTask(Task): + """This is just a Hot Distance Task that combine Binary and distance prediction.""" + + def __init__(self, task_config): + """Create a `HotDistanceTask` from a `HotDistanceTaskConfig`.""" + + self.predictor = HotDistancePredictor( + channels=task_config.channels, + scale_factor=task_config.scale_factor, + mask_distances=task_config.mask_distances, + ) + self.loss = HotDistanceLoss() + self.post_processor = ThresholdPostProcessor() + self.evaluator = BinarySegmentationEvaluator( + clip_distance=task_config.clip_distance, + tol_distance=task_config.tol_distance, + channels=task_config.channels, + ) diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py new file mode 100644 index 00000000..95122647 --- /dev/null +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -0,0 +1,47 @@ +import attr + +from .hot_distance_task import HotDistanceTask +from .task_config import TaskConfig + +from typing import List + + +class HotDistanceTaskConfig(TaskConfig): + """This is a Hot Distance task config used for generating and + evaluating signed distance transforms as a way of generating + segmentations. + + The advantage of generating distance transforms over regular + affinities is you can get a denser signal, i.e. 1 misclassified + pixel in an affinity prediction could merge 2 otherwise very + distinct objects, this cannot happen with distances. + """ + + task_type = HotDistanceTask + + channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."}) + clip_distance: float = attr.ib( + metadata={ + "help_text": "Maximum distance to consider for false positive/negatives." + }, + ) + tol_distance: float = attr.ib( + metadata={ + "help_text": "Tolerance distance for counting false positives/negatives" + }, + ) + scale_factor: float = attr.ib( + default=1, + metadata={ + "help_text": "The amount by which to scale distances before applying " + "a tanh normalization." + }, + ) + mask_distances: bool = attr.ib( + default=False, + metadata={ + "help_text": "Whether or not to mask out regions where the true distance to " + "object boundary cannot be known. This is anywhere that the distance to crop boundary " + "is less than the distance to object boundary." + }, + ) diff --git a/dacapo/experiments/tasks/losses/__init__.py b/dacapo/experiments/tasks/losses/__init__.py index b675faa9..f1db3586 100644 --- a/dacapo/experiments/tasks/losses/__init__.py +++ b/dacapo/experiments/tasks/losses/__init__.py @@ -2,3 +2,4 @@ from .mse_loss import MSELoss # noqa from .loss import Loss # noqa from .affinities_loss import AffinitiesLoss # noqa +from .hot_distance_loss import HotDistanceLoss # noqa diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py new file mode 100644 index 00000000..2e99ab5e --- /dev/null +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -0,0 +1,29 @@ +from .loss import Loss +import torch + + +# HotDistance is used for predicting hot and distance maps at the same time. +# The first half of the channels are the hot maps, the second half are the distance maps. +# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. +# Model should predict twice the number of channels as the target. +class HotDistanceLoss(Loss): + def compute(self, prediction, target, weight): + target_hot, target_distance = self.split(target) + prediction_hot, prediction_distance = self.split(prediction) + weight_hot, weight_distance = self.split(weight) + return self.hot_loss( + prediction_hot, target_hot, weight_hot + ) + self.distance_loss(prediction_distance, target_distance, weight_distance) + + def hot_loss(self, prediction, target, weight): + return torch.nn.BCELoss().forward(prediction * weight, target * weight) + + def distance_loss(self, prediction, target, weight): + return torch.nn.MSELoss().forward(prediction * weight, target * weight) + + def split(self, x): + assert ( + x.shape[0] % 2 == 0 + ), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance." + mid = x.shape[0] // 2 + return x[:mid], x[-mid:] diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py new file mode 100644 index 00000000..fc73cb0e --- /dev/null +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -0,0 +1,280 @@ +from dacapo.experiments.arraytypes.probabilities import ProbabilityArray +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.arraytypes import DistanceArray +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray +from dacapo.utils.balance_weights import balance_weights + +from funlib.geometry import Coordinate + +from scipy.ndimage.morphology import distance_transform_edt +import numpy as np +import torch + +import logging +from typing import List + +logger = logging.getLogger(__name__) + + +class HotDistancePredictor(Predictor): + """ + Predict signed distances and one hot embedding (as a proxy task) for a binary segmentation task. + Distances deep within background are pushed to -inf, distances deep within + the foreground object are pushed to inf. After distances have been + calculated they are passed through a tanh so that distances saturate at +-1. + Multiple classes can be predicted via multiple distance channels. The names + of each class that is being segmented can be passed in as a list of strings + in the channels argument. + """ + + def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): + self.channels = ( + channels * 2 + ) # one hot + distance (TODO: add hot/distance to channel names) + self.norm = "tanh" + self.dt_scale_factor = scale_factor + self.mask_distances = mask_distances + + self.max_distance = 1 * scale_factor + self.epsilon = 5e-2 # TODO: should be a config parameter + self.threshold = 0.8 # TODO: should be a config parameter + + @property + def embedding_dims(self): + return len(self.channels) + + @property + def classes(self): + return len(self.channels) // 2 + + def create_model(self, architecture): + if architecture.dims == 2: + head = torch.nn.Conv2d( + architecture.num_out_channels, self.embedding_dims, kernel_size=3 + ) + elif architecture.dims == 3: + head = torch.nn.Conv3d( + architecture.num_out_channels, self.embedding_dims, kernel_size=3 + ) + + return Model(architecture, head) + + def create_target(self, gt): + target = self.process(gt.data, gt.voxel_size, self.norm, self.dt_scale_factor) + return NumpyArray.from_np_array( + target, + gt.roi, + gt.voxel_size, + gt.axes, + ) + + def create_weight(self, gt, target, mask, moving_class_counts=None): + # balance weights independently for each channel + one_hot_weights, one_hot_moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi]], + moving_counts=moving_class_counts[: self.classes], + ) + + if self.mask_distances: + distance_mask = self.create_distance_mask( + target[target.roi][-self.classes :], + mask[target.roi], + target.voxel_size, + self.norm, + self.dt_scale_factor, + ) + else: + distance_mask = np.ones_like(target.data) + + distance_weights, distance_moving_class_counts = balance_weights( + gt[target.roi], + 2, + slab=tuple(1 if c == "c" else -1 for c in gt.axes), + masks=[mask[target.roi], distance_mask], + moving_counts=moving_class_counts[-self.classes :], + ) + + weights = np.concatenate((one_hot_weights, distance_weights)) + moving_class_counts = np.concatenate( + (one_hot_moving_class_counts, distance_moving_class_counts) + ) + return ( + NumpyArray.from_np_array( + weights, + gt.roi, + gt.voxel_size, + gt.axes, + ), + moving_class_counts, + ) + + @property + def output_array_type(self): + # technically this is a probability array + distance array, but it is only ever referenced for interpolatability (which is true for both) (TODO) + return ProbabilityArray(self.embedding_dims) + + def create_distance_mask( + self, + distances: np.ndarray, + mask: np.ndarray, + voxel_size: Coordinate, + normalize=None, + normalize_args=None, + ): + mask_output = mask.copy() + for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)): + tmp = np.zeros( + np.array(channel_mask.shape) + np.array((2,) * channel_mask.ndim), + dtype=channel_mask.dtype, + ) + slices = tmp.ndim * (slice(1, -1),) + tmp[slices] = channel_mask + boundary_distance = distance_transform_edt( + tmp, + sampling=voxel_size, + ) + if self.epsilon is None: + add = 0 + else: + add = self.epsilon + boundary_distance = self.__normalize( + boundary_distance[slices], normalize, normalize_args + ) + + channel_mask_output = mask_output[i] + logging.debug( + "Total number of masked in voxels before distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + channel_mask_output[ + np.logical_and( + np.clip(abs(channel_distance) + add, 0, self.threshold) + >= boundary_distance, + channel_distance >= 0, + ) + ] = 0 + logging.debug( + "Total number of masked in voxels after postive distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + channel_mask_output[ + np.logical_and( + np.clip(abs(channel_distance) + add, 0, self.threshold) + >= boundary_distance, + channel_distance <= 0, + ) + ] = 0 + logging.debug( + "Total number of masked in voxels after negative distance masking {0:}".format( + np.sum(channel_mask_output) + ) + ) + return mask_output + + def process( + self, + labels: np.ndarray, + voxel_size: Coordinate, + normalize=None, + normalize_args=None, + ): + all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 + for ii, channel in enumerate(labels): + boundaries = self.__find_boundaries(channel) + + # mark boundaries with 0 (not 1) + boundaries = 1.0 - boundaries + + if np.sum(boundaries == 0) == 0: + max_distance = min( + dim * vs / 2 for dim, vs in zip(channel.shape, voxel_size) + ) + if np.sum(channel) == 0: + distances = -np.ones(channel.shape, dtype=np.float32) * max_distance + else: + distances = np.ones(channel.shape, dtype=np.float32) * max_distance + else: + # get distances (voxel_size/2 because image is doubled) + distances = distance_transform_edt( + boundaries, sampling=tuple(float(v) / 2 for v in voxel_size) + ) + distances = distances.astype(np.float32) + + # restore original shape + downsample = (slice(None, None, 2),) * len(voxel_size) + distances = distances[downsample] + + # todo: inverted distance + distances[channel == 0] = -distances[channel == 0] + + if normalize is not None: + distances = self.__normalize(distances, normalize, normalize_args) + + all_distances[ii] = distances + + return np.concatenate((labels, all_distances)) + + def __find_boundaries(self, labels): + # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n + # shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 + # diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 + # bound.: 00000001000100000001000 2n - 1 + + logger.debug("computing boundaries for %s", labels.shape) + + dims = len(labels.shape) + in_shape = labels.shape + out_shape = tuple(2 * s - 1 for s in in_shape) + + boundaries = np.zeros(out_shape, dtype=bool) + + logger.debug("boundaries shape is %s", boundaries.shape) + + for d in range(dims): + logger.debug("processing dimension %d", d) + + shift_p = [slice(None)] * dims + shift_p[d] = slice(1, in_shape[d]) + + shift_n = [slice(None)] * dims + shift_n[d] = slice(0, in_shape[d] - 1) + + diff = (labels[tuple(shift_p)] - labels[tuple(shift_n)]) != 0 + + logger.debug("diff shape is %s", diff.shape) + + target = [slice(None, None, 2)] * dims + target[d] = slice(1, out_shape[d], 2) + + logger.debug("target slices are %s", target) + + boundaries[tuple(target)] = diff + + return boundaries + + def __normalize(self, distances, norm, normalize_args): + if norm == "tanh": + scale = normalize_args + return np.tanh(distances / scale) + else: + raise ValueError("Only tanh is supported for normalization") + + def gt_region_for_roi(self, target_spec): + if self.mask_distances: + gt_spec = target_spec.copy() + gt_spec.roi = gt_spec.roi.grow( + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + Coordinate((self.max_distance,) * gt_spec.voxel_size.dims), + ).snap_to_grid(gt_spec.voxel_size, mode="shrink") + else: + gt_spec = target_spec.copy() + return gt_spec + + def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + return Coordinate((self.max_distance,) * gt_voxel_size.dims)