Commit

more consistent plotting and figures
pattonw committed Aug 20, 2024
1 parent a389ca8 commit c4f12b7
Showing 1 changed file with 111 additions and 42 deletions.
solution.py: 153 changes (111 additions, 42 deletions)
@@ -55,6 +55,7 @@
from dlmbl_unet import UNet
from tqdm import tqdm
import tifffile
import mwatershed as mws

from skimage.filters import threshold_otsu

@@ -71,7 +72,7 @@
# %%
# Create a custom label color map for showing instances
np.random.seed(1)
colors = [[0,0,0]] + [list(np.random.choice(range(256), size=3)) for _ in range(254)]
colors = [[0, 0, 0]] + [list(np.random.choice(range(256), size=3)) for _ in range(254)]
label_cmap = ListedColormap(colors)

# %% [markdown]
@@ -87,11 +88,12 @@
# <br> - As an example, the figure below shows the SDT (right) of the target mask (middle).

# %% [markdown]
# ![image](static/04_instance_sdt.png)
# ![image](static/figure2/04_instance_sdt.png)
#
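# %% [markdown]
# To build intuition before reading `compute_sdt` below, here is a minimal, self-contained
# sketch of a signed distance transform using `scipy.ndimage.distance_transform_edt`. It is
# illustrative only: it distinguishes foreground from background but ignores the boundaries
# between touching objects that `compute_sdt` also handles, and the names `toy_mask` and
# `toy_sdt` are made up for this example.

# %%
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import distance_transform_edt

toy_mask = np.zeros((64, 64), dtype=bool)
toy_mask[16:48, 16:48] = True  # a single square "cell"

inside = distance_transform_edt(toy_mask)  # distance to the background, inside the object
outside = distance_transform_edt(~toy_mask)  # distance to the object, outside of it
toy_sdt = np.tanh((inside - outside) / 5)  # signed and squashed to (-1, 1) with scale 5

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(toy_mask, cmap="gray")
axes[0].set_title("toy mask")
axes[1].imshow(toy_sdt, cmap="coolwarm", vmin=-1, vmax=1)
axes[1].set_title("toy SDT")
plt.show()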

# %%


def compute_sdt(labels: np.ndarray, scale: int = 5):
"""Function to compute a signed distance transform."""
dims = len(labels.shape)
@@ -132,6 +134,7 @@ def compute_sdt(labels: np.ndarray, scale: int = 5):
distances[labels == 0] *= -1
return distances


# %% [markdown]
# <div class="alert alert-block alert-info">
# <b>Task 1.1</b>: Explain the `compute_sdt` from the cell above.
@@ -166,6 +169,7 @@ def compute_sdt(labels: np.ndarray, scale: int = 5):
# <br> Note that the output of the signed distance transform is not binary, which is a significant difference from semantic segmentation.
# %%
# Visualize the signed distance transform using the function you wrote above.

root_dir = "tissuenet_data/train" # the directory with all the training samples
samples = os.listdir(root_dir)
idx = np.random.randint(len(samples) // 3) # take a random sample.
@@ -174,7 +178,7 @@ def compute_sdt(labels: np.ndarray, scale: int = 5):
os.path.join(root_dir, f"img_{idx}_cyto_masks.tif")
) # get the ground-truth label masks
sdt = compute_sdt(label)
plot_two(img[1], sdt, label="SDT")
plot_two(img, sdt, label="SDT")

# %% [markdown]
# <div class="alert alert-block alert-info">
@@ -254,16 +258,18 @@ def __getitem__(self, idx):
image = self.transform(image)
torch.manual_seed(seed)
mask = self.transform(mask)

# use the compute_sdt function to get the sdt
sdt = ...
assert sdt.shape == mask.shape
if self.img_transform is not None:
image = self.img_transform(image)
if self.return_mask is True:
return image, mask.unsqueeze(0), sdt.unsqueeze(0)
else:
return image, sdt.unsqueeze(0)


# %% tags=["solution"]
class SDTDataset(Dataset):
"""A PyTorch dataset to load cell images and nuclei masks."""
@@ -314,6 +320,7 @@ def __getitem__(self, idx):
torch.manual_seed(seed)
mask = self.transform(mask)
sdt = self.create_sdt_target(mask)
assert sdt.shape == mask.shape
if self.img_transform is not None:
image = self.img_transform(image)
if self.return_mask is True:
@@ -341,7 +348,7 @@ def create_sdt_target(self, mask):
idx = np.random.randint(len(train_data)) # take a random sample
img, sdt = train_data[idx] # get the image and its SDT target
print(img.shape, sdt.shape)
plot_two(img[1], sdt[0], label="SDT")
plot_two(img, sdt[0], label="SDT")

# %% [markdown]
# <div class="alert alert-block alert-info">
@@ -379,7 +386,7 @@ def create_sdt_target(self, mask):

# %% tags=["solution"]
unet = UNet(
depth=2,
depth=3,
in_channels=2,
out_channels=1,
final_activation=torch.nn.Tanh(),
@@ -418,7 +425,7 @@ def create_sdt_target(self, mask):
image = np.squeeze(image.cpu())
sdt = np.squeeze(sdt.cpu().numpy())
pred = np.squeeze(pred.cpu().detach().numpy())
plot_three(image[1], sdt, pred)
plot_three(image, sdt, pred)


# %% [markdown]
Expand Down Expand Up @@ -449,13 +456,13 @@ def create_sdt_target(self, mask):


def find_local_maxima(distance_transform, min_dist_between_points):

# Hint: Use `maximum_filter` to perform a maximum filter convolution on the distance_transform

seeds, number_of_seeds = ...

return seeds, number_of_seeds
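# %% [markdown]
# For reference, here is an illustrative sketch of one way such a seed finder can be written
# with `maximum_filter`: a pixel is kept as a seed if it equals the maximum of its neighborhood.
# The name `find_local_maxima_sketch` is made up for this example; the tagged solution cell
# below contains the version actually used in this exercise and may differ in details.

# %%
from scipy.ndimage import label, maximum_filter


def find_local_maxima_sketch(distance_transform, min_dist_between_points):
    # A pixel is a local maximum if it equals the maximum value in its neighborhood.
    max_filtered = maximum_filter(distance_transform, size=min_dist_between_points)
    maxima = (distance_transform == max_filtered) & (distance_transform > 0)
    # Label connected groups of maxima so that each seed gets a unique integer id.
    seeds, number_of_seeds = label(maxima)
    return seeds, number_of_seeds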


# %% tags=["solution"]
from scipy.ndimage import label, maximum_filter

@@ -656,14 +663,11 @@ def get_inner_mask(pred, threshold):
pred = np.squeeze(pred.cpu().detach().numpy())

# feel free to try different thresholds
thresh = threshold_otsu(pred)
thresh = ...

# get boundary mask
inner_mask = get_inner_mask(pred, threshold=thresh)

pred_labels = watershed_from_boundary_distance(
pred, inner_mask, id_offset=0, min_seed_distance=20
)
inner_mask = ...
pred_labels = ...
precision, recall, accuracy = evaluate(gt_labels, pred_labels)
precision_list.append(precision)
recall_list.append(recall)
@@ -701,11 +705,14 @@ def get_inner_mask(pred, threshold):
pred = np.squeeze(pred.cpu().detach().numpy())

# feel free to try different thresholds
thresh = ...
thresh = threshold_otsu(pred)

# get boundary mask
inner_mask = ...
pred_labels = ...
inner_mask = get_inner_mask(pred, threshold=thresh)

pred_labels = watershed_from_boundary_distance(
pred, inner_mask, id_offset=0, min_seed_distance=20
)
precision, recall, accuracy = evaluate(gt_labels, pred_labels)
precision_list.append(precision)
recall_list.append(recall)
@@ -715,6 +722,7 @@ def get_inner_mask(pred, threshold):
print(f"Mean Recall is {np.mean(recall_list):.3f}")
print(f"Mean Accuracy is {np.mean(accuracy_list):.3f}")


# %% [markdown]
# <hr style="height:2px;">
#
@@ -728,7 +736,7 @@ def get_inner_mask(pred, threshold):
# Here, we show the sum of the affinity in x and the affinity in y in the bottom right image.

# %% [markdown]
# ![image](static/05_instance_affinity.png)
# ![image](static/figure3/instance_affinity.png)
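# %% [markdown]
# Before modifying the dataset, it helps to see how affinities are derived from a label image:
# for each pixel and each offset, the affinity is 1 if the pixel and its neighbor at that offset
# share the same non-zero label, and 0 otherwise. The sketch below is illustrative only;
# `toy_affinities` is a made-up helper, and the notebook's `compute_affinities` (used by the
# dataset below) may differ in details such as offset convention and padding.

# %%
import numpy as np


def toy_affinities(labels_2d, offsets):
    """Illustrative sketch: per-offset affinities for a 2D label image."""
    affs = np.zeros((len(offsets),) + labels_2d.shape, dtype=np.float32)
    h, w = labels_2d.shape
    for i, (dy, dx) in enumerate(offsets):
        a = labels_2d[: h - dy, : w - dx]  # each pixel ...
        b = labels_2d[dy:, dx:]  # ... paired with its neighbor at offset (dy, dx)
        affs[i, : h - dy, : w - dx] = (a == b) & (a > 0)
    return affs


toy_labels = np.zeros((6, 6), dtype=int)
toy_labels[1:3, 1:5] = 1
toy_labels[3:5, 1:5] = 2  # two objects touching along a horizontal boundary
# The y-affinity (offset [1, 0]) drops to 0 across the boundary between the two objects:
print(toy_affinities(toy_labels, [[0, 1], [1, 0]])[1])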

# %% [markdown]
# Similar to the pipeline used for SDTs, we first need to modify the dataset to produce affinities.
@@ -741,7 +749,15 @@
class AffinityDataset(Dataset):
"""A PyTorch dataset to load cell images and nuclei masks"""

def __init__(self, root_dir, transform=None, img_transform=None, return_mask=False):
def __init__(
self,
root_dir,
transform=None,
img_transform=None,
return_mask=False,
weights: bool = False,
):
self.weights = weights
self.root_dir = root_dir # the directory with all the training samples
self.num_samples = len(os.listdir(self.root_dir)) // 3 # list the samples
self.return_mask = return_mask
@@ -788,13 +804,35 @@ def __getitem__(self, idx):
aff_mask = self.create_aff_target(mask)
if self.img_transform is not None:
image = self.img_transform(image)
if self.return_mask is True:
return image, mask, aff_mask

if self.weights:
weight = torch.zeros_like(aff_mask)
for channel in range(weight.shape[0]):
# class balancing: weight the 0s and 1s of each channel by the inverse of their frequency
weight[channel][aff_mask[channel] == 0] = np.clip(
weight[channel].numel()
/ 2
/ (weight[channel].numel() - aff_mask[channel].sum()),
0.1,
10.0,
)
weight[channel][aff_mask[channel] == 1] = np.clip(
weight[channel].numel() / 2 / aff_mask[channel].sum(), 0.1, 10.0
)

if self.return_mask is True:
return image, mask, aff_mask, weight
else:
return image, aff_mask, weight
else:
return image, aff_mask
if self.return_mask is True:
return image, mask, aff_mask
else:
return image, aff_mask

def create_aff_target(self, mask):
aff_target_array = compute_affinities(np.asarray(mask), [[0, 1], [1, 0]])
aff_target_array = compute_affinities(
np.asarray(mask), [[0, 1], [1, 0], [0, 5], [5, 0]]
)
aff_target = torch.from_numpy(aff_target_array)
return aff_target.float()

@@ -804,13 +842,14 @@ def create_aff_target(self, mask):
# %%
# Initialize the datasets

train_data = AffinityDataset("tissuenet_data/train", v2.RandomCrop(256))
train_data = AffinityDataset("tissuenet_data/train", v2.RandomCrop(256), weights=True)
train_loader = DataLoader(
train_data, batch_size=5, shuffle=True, num_workers=NUM_THREADS
)
idx = np.random.randint(len(train_data)) # take a random sample
img, affinity = train_data[idx] # get the image and the nuclei masks
plot_two(img[1], affinity[0+2] + affinity[1+2], label="AFFINITY")
img, affinity, weight = train_data[idx] # get the image, affinity target, and per-pixel weights
plot_two(img, affinity, label="AFFINITY")


# %% [markdown]
# <div class="alert alert-block alert-info">
@@ -833,11 +872,11 @@ def create_aff_target(self, mask):
# %% tags=["solution"]

unet = UNet(
depth=2,
depth=4,
in_channels=2,
out_channels=2,
out_channels=4,
final_activation=torch.nn.Sigmoid(),
num_fmaps=4,
num_fmaps=16,
fmap_inc_factor=3,
downsample_factor=2,
padding="same",
@@ -846,11 +885,14 @@ def create_aff_target(self, mask):
learning_rate = 1e-4

# choose a loss function
loss = torch.nn.MSELoss()
loss = torch.nn.MSELoss(reduction="none")  # keep per-pixel losses so they can be weighted

optimizer = torch.optim.Adam(unet.parameters(), lr=learning_rate)
plot_three(image[1], mask[0] + mask[1], pred[0 + 2] + pred[1 + 2], label="Affinity")

val_data = AffinityDataset("tissuenet_data/test", v2.RandomCrop(256))
val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=8)
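# %% [markdown]
# Because the loss is constructed without reduction, it returns per-pixel values that still need
# to be multiplied by the per-pixel weights and then averaged. The sketch below shows what one
# weighted training step might look like; `weighted_mse_step_sketch` is a made-up name, and the
# notebook's actual `train` function (used in the loop below, not shown here) may handle the
# weights differently.

# %%
def weighted_mse_step_sketch(model, optimizer, image, target, weight, loss_fn):
    # Assumes image, target, and weight are batched tensors already on the right device.
    optimizer.zero_grad()
    pred = model(image)
    per_pixel = loss_fn(pred, target)  # same shape as target, since no reduction is applied
    loss_value = (per_pixel * weight).mean()  # balance 0s and 1s via the per-pixel weights
    loss_value.backward()
    optimizer.step()
    return loss_value.item()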

# %%
for epoch in range(NUM_EPOCHS):
train(
unet,
@@ -871,20 +913,35 @@ def create_aff_target(self, mask):

unet.eval()
idx = np.random.randint(len(val_data)) # take a random sample
image, mask = val_data[idx] # get the image and the nuclei masks
image, affs = val_data[idx] # get the image and its affinity target
image = image.to(device)
pred = torch.squeeze(unet(torch.unsqueeze(image, dim=0)))

image = image.cpu()
mask = mask.cpu().numpy()
affs = affs.cpu().numpy()
pred = pred.cpu().detach().numpy()

plot_three(image[1], mask[0] + mask[1], pred[0] + pred[1], label="Affinity")
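# The (negative) biases below are added to the sigmoid outputs before running the mutex
# watershed. One way to read this: in `mws.agglom`, positive values act as attractive (merge)
# edges and negative values as repulsive (split) edges, processed in order of decreasing
# magnitude. A strongly negative bias therefore turns most long-range affinities into repulsive
# edges, while only confident short-range affinities stay attractive. The exact values are
# worth tuning on the validation data.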
bias_short = -0.9
bias_long = -0.95

pred_labels = mws.agglom(
np.array(
[
pred[0] + bias_short,
pred[1] + bias_short,
pred[2] + bias_long,
pred[3] + bias_long,
]
).astype(np.float64),
[[0, 1], [1, 0], [0, 5], [5, 0]],
)

plot_four(image, affs, pred, pred_labels, label="Affinity")

# %% [markdown]
# Let's also evaluate the model performance.

# %%

val_dataset = AffinityDataset("tissuenet_data/test", return_mask=True)
val_loader = DataLoader(
val_dataset, batch_size=1, shuffle=False, num_workers=NUM_THREADS
@@ -911,17 +968,29 @@ def create_aff_target(self, mask):

pred = np.squeeze(pred.cpu().detach().numpy())

# feel free to try different thresholds
thresh = threshold_otsu(pred)
# # feel free to try different thresholds
# thresh = threshold_otsu(pred)

# get boundary mask
inner_mask = 0.5 * (pred[0] + pred[1]) > thresh
# # get boundary mask
# inner_mask = 0.5 * (pred[0] + pred[1]) > thresh

boundary_distances = distance_transform_edt(inner_mask)
# boundary_distances = distance_transform_edt(inner_mask)

pred_labels = watershed_from_boundary_distance(
boundary_distances, inner_mask, id_offset=0, min_seed_distance=20
# pred_labels = watershed_from_boundary_distance(
# boundary_distances, inner_mask, id_offset=0, min_seed_distance=20
# )
pred_labels = mws.agglom(
np.array(
[
# apply the same (negative) biases as in the validation cell above
pred[0] + bias_short,
pred[1] + bias_short,
pred[2] + bias_long,
pred[3] + bias_long,
]
).astype(np.float64),
[[0, 1], [1, 0], [0, 5], [5, 0]],
)

precision, recall, accuracy = evaluate(gt_labels, pred_labels)
precision_list.append(precision)
recall_list.append(recall)
