Merge pull request #46 from janelia-cellmap/actions/black
Format Python code with psf/black
mzouink authored Feb 9, 2024
2 parents 5d77af0 + 232047c commit 8b9d44a
Showing 9 changed files with 47 additions and 41 deletions.
dacapo/apply.py (2 changes: 1 addition & 1 deletion)

@@ -10,4 +10,4 @@ def apply(run_name: str, iteration: int, dataset_name: str):
         iteration,
         dataset_name,
     )
-    raise NotImplementedError("This function is not yet implemented.")
+    raise NotImplementedError("This function is not yet implemented.")
dacapo/cli.py (2 changes: 1 addition & 1 deletion)

@@ -57,4 +57,4 @@ def validate(run_name, iteration):
     help="The name of the dataset to apply the run to.",
 )
 def apply(run_name, iteration, dataset_name):
-    dacapo.apply(run_name, iteration, dataset_name)
+    dacapo.apply(run_name, iteration, dataset_name)
dacapo/experiments/tasks/affinities_task.py (10 changes: 4 additions & 6 deletions)

@@ -12,12 +12,10 @@ def __init__(self, task_config):
         """Create a `DummyTask` from a `DummyTaskConfig`."""

         self.predictor = AffinitiesPredictor(
-            neighborhood=task_config.neighborhood,
-            lsds=task_config.lsds,
-            num_voxels=task_config.num_voxels,
-            downsample_lsds=task_config.downsample_lsds,
-            grow_boundary_iterations=task_config.grow_boundary_iterations,
+            neighborhood=task_config.neighborhood, lsds=task_config.lsds
         )
-        self.loss = AffinitiesLoss(len(task_config.neighborhood))
+        self.loss = AffinitiesLoss(
+            len(task_config.neighborhood), task_config.lsds_to_affs_weight_ratio
+        )
         self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood)
         self.evaluator = InstanceEvaluator()
dacapo/experiments/tasks/affinities_task_config.py (18 changes: 2 additions & 16 deletions)

@@ -30,23 +30,9 @@ class AffinitiesTaskConfig(TaskConfig):
             "It has been shown that lsds as an auxiliary task can help affinity predictions."
         },
     )
-    num_voxels: int = attr.ib(
-        default=20,
-        metadata={
-            "help_text": "The number of voxels to use for the gaussian sigma when computing lsds."
-        },
-    )
-    downsample_lsds: int = attr.ib(
+    lsds_to_affs_weight_ratio: float = attr.ib(
         default=1,
         metadata={
-            "help_text": "The amount to downsample the lsds. "
-            "This is useful for speeding up training and inference."
-        },
-    )
-    grow_boundary_iterations: int = attr.ib(
-        default=0,
-        metadata={
-            "help_text": "The number of iterations to run the grow boundaries algorithm. "
-            "This is useful for refining the boundaries of the affinities, and reducing merging of adjacent objects."
+            "help_text": "If training with lsds, set how much they should be weighted compared to affs."
         },
     )
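
Net effect of this config change: the three removed LSD knobs collapse into a single weight ratio that scales the auxiliary LSD loss against the affinities loss. A hedged construction sketch follows; only neighborhood, lsds, and lsds_to_affs_weight_ratio are confirmed by this diff, the import path and remaining fields are assumptions:

# Hypothetical usage; the import path and the `name` field are assumed.
from dacapo.experiments.tasks import AffinitiesTaskConfig

task_config = AffinitiesTaskConfig(
    name="affs_with_lsds",  # assumed field
    neighborhood=[(0, 0, 1), (0, 1, 0), (1, 0, 0)],  # nearest-neighbor 3D offsets
    lsds=True,
    lsds_to_affs_weight_ratio=0.5,  # LSD term counts half as much as the affs term
)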
dacapo/experiments/tasks/losses/affinities_loss.py (5 changes: 3 additions & 2 deletions)

@@ -3,8 +3,9 @@


 class AffinitiesLoss(Loss):
-    def __init__(self, num_affinities: int):
+    def __init__(self, num_affinities: int, lsds_to_affs_weight_ratio: float):
         self.num_affinities = num_affinities
+        self.lsds_to_affs_weight_ratio = lsds_to_affs_weight_ratio

     def compute(self, prediction, target, weight):
         affs, affs_target, affs_weight = (

@@ -21,7 +22,7 @@ def compute(self, prediction, target, weight):
         return (
             torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
             * affs_weight
-        ).mean() + (
+        ).mean() + self.lsds_to_affs_weight_ratio * (
             torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
             * aux_weight
         ).mean()
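
The total loss is therefore BCE(affs) plus lsds_to_affs_weight_ratio times MSE(sigmoid(lsds)), each term masked by its weight array before averaging; a ratio of 0 silences the auxiliary term entirely. A minimal sketch on dummy tensors (shapes and channel counts are assumptions; the formula mirrors the hunk above):

import torch

num_affinities = 3
ratio = 0.5  # lsds_to_affs_weight_ratio

# dummy batch: 3 affinity channels + 10 LSD channels on an 8x8 patch (assumed shapes)
affs = torch.randn(1, num_affinities, 8, 8)
affs_target = torch.randint(0, 2, affs.shape).float()
affs_weight = torch.ones_like(affs)

aux = torch.randn(1, 10, 8, 8)
aux_target = torch.rand(aux.shape)
aux_weight = torch.ones_like(aux)

loss = (
    torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target) * affs_weight
).mean() + ratio * (
    torch.nn.MSELoss(reduction="none")(torch.sigmoid(aux), aux_target) * aux_weight
).mean()
print(loss.item())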
dacapo/experiments/tasks/post_processors/watershed_post_processor.py
@@ -24,7 +24,7 @@ def enumerate_parameters(self):
         """Enumerate all possible parameters of this post-processor. Should
         return instances of ``PostProcessorParameters``."""

-        for i, bias in enumerate([0.1, 0.5, 0.9]):
+        for i, bias in enumerate([0.1, 0.25, 0.5, 0.75, 0.9]):
             yield WatershedPostProcessorParameters(id=i, bias=bias)

     def set_prediction(self, prediction_array_identifier):

@@ -44,9 +44,9 @@ def process(self, parameters, output_array_identifier):
         # if a previous segmentation is provided, it must have a "grid graph"
         # in its metadata.
         pred_data = self.prediction_array[self.prediction_array.roi]
-        affs = pred_data[: len(self.offsets)]
+        affs = pred_data[: len(self.offsets)].astype(np.float64)
         segmentation = mws.agglom(
-            affs - 0.5,
+            affs - parameters.bias,
             self.offsets,
         )
         # filter fragments

@@ -59,12 +59,17 @@ def process(self, parameters, output_array_identifier):
         for fragment, mean in zip(
             fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids)
         ):
-            if mean < 0.5:
+            if mean < parameters.bias:
                 filtered_fragments.append(fragment)

         filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype)
         replace = np.zeros_like(filtered_fragments)
-        segmentation = npi.remap(segmentation, filtered_fragments, replace)
+
+        # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input
+        if filtered_fragments.size > 0:
+            segmentation = npi.remap(
+                segmentation.flatten(), filtered_fragments, replace
+            ).reshape(segmentation.shape)

         output_array[self.prediction_array.roi] = segmentation
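
Two behavioral changes in this file: enumerate_parameters now sweeps five bias values instead of three, and that per-parameter bias replaces the hard-coded 0.5 both as the mutex watershed threshold shift and as the fragment-filter cutoff. The remap workaround is easy to check in isolation; a toy sketch on dummy arrays (only npi.remap and the flatten/reshape trick come from the diff, the data is made up):

import numpy as np
import numpy_indexed as npi

# Toy label array with three fragments; fragment 2 fell below the bias cutoff.
segmentation = np.array([[1, 1, 2], [3, 2, 2]], dtype=np.uint64)
filtered_fragments = np.array([2], dtype=segmentation.dtype)
replace = np.zeros_like(filtered_fragments)  # map filtered fragments to background

# remap is applied to the flattened array and reshaped back, per the DGA comment
if filtered_fragments.size > 0:
    segmentation = npi.remap(
        segmentation.flatten(), filtered_fragments, replace
    ).reshape(segmentation.shape)

print(segmentation)  # [[1 1 0] [3 0 0]]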
dacapo/experiments/trainers/gunpowder_trainer.py (27 changes: 18 additions & 9 deletions)

@@ -42,6 +42,11 @@ def __init__(self, trainer_config):
         self.mask_integral_downsample_factor = 4
         self.clip_raw = trainer_config.clip_raw

+        # Testing out if calculating multiple times and multiplying is necessary
+        self.add_predictor_nodes_to_dataset = (
+            trainer_config.add_predictor_nodes_to_dataset
+        )
+
         self.scheduler = None

     def create_optimizer(self, model):

@@ -146,13 +151,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
             for augment in self.augments:
                 dataset_source += augment.node(raw_key, gt_key, mask_key)

-            # Add predictor nodes to dataset_source
-            dataset_source += DaCapoTargetFilter(
-                task.predictor,
-                gt_key=gt_key,
-                weights_key=dataset_weight_key,
-                mask_key=mask_key,
-            )
+            if self.add_predictor_nodes_to_dataset:
+                # Add predictor nodes to dataset_source
+                dataset_source += DaCapoTargetFilter(
+                    task.predictor,
+                    gt_key=gt_key,
+                    weights_key=dataset_weight_key,
+                    mask_key=mask_key,
+                )

             dataset_sources.append(dataset_source)
         pipeline = tuple(dataset_sources) + gp.RandomProvider(weights)

@@ -162,11 +168,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
             task.predictor,
             gt_key=gt_key,
             target_key=target_key,
-            weights_key=datasets_weight_key,
+            weights_key=datasets_weight_key
+            if self.add_predictor_nodes_to_dataset
+            else weight_key,
             mask_key=mask_key,
         )

-        pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)
+        if self.add_predictor_nodes_to_dataset:
+            pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)

         # Trainer attributes:
         if self.num_data_fetchers > 1:
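
When add_predictor_nodes_to_dataset is enabled, each dataset branch computes its own weights and the Product node combines them with the weights from the shared, post-RandomProvider target filter; when disabled, that filter writes directly to weight_key and the Product node is skipped. A rough sketch of the combination step on plain numpy arrays (elementwise multiplication is an assumption based on the node name and the config help text):

import numpy as np

# Hypothetical per-voxel weights from the per-dataset predictor node
dataset_weights = np.array([1.0, 0.5, 1.0, 0.0])
# Hypothetical per-voxel weights from the shared target filter
datasets_weights = np.array([0.8, 1.0, 1.0, 1.0])

add_predictor_nodes_to_dataset = True
if add_predictor_nodes_to_dataset:
    # what Product(dataset_weight_key, datasets_weight_key, weight_key) would provide
    weights = dataset_weights * datasets_weights
else:
    # the target filter writes straight to weight_key
    weights = datasets_weights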
dacapo/experiments/trainers/gunpowder_trainer_config.py (7 changes: 7 additions & 0 deletions)

@@ -29,3 +29,10 @@ class GunpowderTrainerConfig(TrainerConfig):
     )
     min_masked: Optional[float] = attr.ib(default=0.15)
     clip_raw: bool = attr.ib(default=True)
+
+    add_predictor_nodes_to_dataset: Optional[bool] = attr.ib(
+        default=True,
+        metadata={
+            "help_text": "Whether to add a predictor node to dataset_source and apply product of weights"
+        },
+    )
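
A hedged sketch of opting out of the per-dataset predictor nodes; only min_masked, clip_raw, and the new flag appear in this diff, while the import path and the name field are assumptions:

from dacapo.experiments.trainers import GunpowderTrainerConfig  # import path assumed

trainer_config = GunpowderTrainerConfig(
    name="baseline_trainer",  # assumed field
    min_masked=0.15,
    clip_raw=True,
    # skip per-dataset predictor nodes and the weight Product
    add_predictor_nodes_to_dataset=False,
)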
dacapo/predict.py (2 changes: 1 addition & 1 deletion)

@@ -24,7 +24,7 @@ def predict(
     num_cpu_workers: int = 4,
     compute_context: ComputeContext = LocalTorch(),
     output_roi: Optional[Roi] = None,
-    output_dtype: np.dtype = np.float32, # type: ignore
+    output_dtype: np.dtype = np.float32,  # type: ignore
     overwrite: bool = False,
 ):
     # get the model's input and output size
