This repository has been archived by the owner on Jan 3, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
example.py
127 lines (114 loc) · 4.15 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from functools import partial
from multiprocessing import Pool
from os import cpu_count
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as F
from numpy.core.fromnumeric import product
from skimage.segmentation import find_boundaries, mark_boundaries, slic
from skimage.segmentation.boundaries import find_boundaries
from torchvision.io import read_image
from torchvision.models.segmentation import fcn_resnet50
from torchvision.transforms.functional import convert_image_dtype
from torchvision.utils import draw_segmentation_masks, make_grid
from pytorch_superpixels.runtime import superpixelise
def show(imgs):
    """Display one image tensor, or a list of them, side by side in one row.

    Each image is converted with ``torchvision.transforms.functional.to_pil_image``
    after being detached from the autograd graph, so it must be a tensor that
    ``to_pil_image`` accepts. Axis ticks and labels are hidden. The call blocks
    on ``plt.show()``.

    Args:
        imgs: a single image tensor or a ``list`` of image tensors.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    # squeeze=False keeps `axs` 2-D (1 x ncols) even for a single image, so
    # axs[0] is always the row of Axes.
    _fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for ax, img in zip(axs[0], imgs):
        pil_img = F.to_pil_image(img.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.tight_layout()
    plt.show()
def _slic_all(pool, images_hwc, n_segments):
    """Run SLIC on every channels-last image in parallel; return the label maps.

    Uses ``start_label=0`` and ``slic_zero=True``, the same settings the example
    has always used; only ``n_segments`` varies between calls.

    Args:
        pool: a ``multiprocessing.Pool`` used to map SLIC over the images.
        images_hwc: list of (H, W, C) arrays (skimage expects channels last).
        n_segments: approximate number of superpixels per image.

    Returns:
        A list with one SLIC label map per input image, in input order.
    """
    func = partial(slic, n_segments=n_segments, start_label=0, slic_zero=True)
    return pool.map(func, images_hwc)


def _batch_class_masks(scores, num_classes):
    """Turn batched (N, C, H, W) class scores into boolean one-hot class masks.

    Softmax over dim 1 and then argmax per pixel; the result has shape
    (N, num_classes, H, W) where ``mask[n, c]`` is True for pixels whose winning
    class is ``c``. This is the form ``draw_segmentation_masks`` takes per image.
    """
    probs = torch.nn.functional.softmax(scores, dim=1)
    winners = probs.argmax(1)  # (N, H, W)
    return winners[:, None] == torch.arange(num_classes)[None, :, None, None]


def _superpixel_class_masks(outputs, sp_masks, num_classes):
    """Pool model scores within superpixels, then build one-hot class masks.

    Args:
        outputs: batched model scores, indexed below as (N, C, H, W).
        sp_masks: list of per-image SLIC label maps (one per batch element).
        num_classes: number of class channels in ``outputs``.

    Returns:
        Boolean masks of shape (N, num_classes, H, W), computed once for the
        whole batch instead of once per displayed image.
    """
    sp_labels = torch.from_numpy(np.stack(sp_masks))
    # NOTE(review): assumes superpixelise returns scores shaped like `outputs`
    # (batch, class, H, W) — consistent with how the result was indexed
    # originally; confirm against pytorch_superpixels.runtime.
    outputs_sp = superpixelise(outputs, sp_labels)
    return _batch_class_masks(outputs_sp, num_classes)


def _overlay_boundaries(image, sp_mask):
    """Paint superpixel boundaries onto a (3, H, W) image tensor in place.

    Boundary pixels get R=G=255 and B=0 (yellow). Returns ``image`` for
    convenience.
    """
    boundaries = torch.from_numpy(find_boundaries(sp_mask))
    image[0:2, boundaries] = 255
    image[2, boundaries] = 0
    return image


if __name__ == "__main__":
    # Class names in output-channel order for the pretrained segmentation model.
    sem_classes = [
        "__background__",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    ]
    sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}

    image_dims = [420, 640]
    images = [read_image(str(img)) for img in Path("data").glob("*.jpg")]
    if not images:
        # torch.stack([]) would otherwise fail with an unhelpful message.
        raise SystemExit("No *.jpg images found in ./data")
    images = [F.center_crop(image, image_dims) for image in images]
    batch_int = torch.stack(images)
    batch = convert_image_dtype(batch_int, dtype=torch.float)

    # permute because slic expects the last dimension to be channel
    images_hwc = [x.permute(1, 2, 0).numpy() for x in batch]
    # Leave one core free, but never request 0 workers (1-CPU machines) and
    # tolerate cpu_count() returning None.
    processes = max(1, (cpu_count() or 1) - 1)
    with Pool(processes=processes) as pool:
        masks_100sp = _slic_all(pool, images_hwc, 100)
        masks_1000sp = _slic_all(pool, images_hwc, 1000)

    model = fcn_resnet50(pretrained=True, progress=False)
    model = model.eval()
    normalized_batch = F.normalize(
        batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    )
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        # The pretrained model expects ImageNet-normalized input; the original
        # computed normalized_batch but fed the raw batch.
        outputs = model(normalized_batch)["out"]
        num_classes = outputs.shape[1]
        class_masks = _batch_class_masks(outputs, num_classes)
        class_masks_100sp = _superpixel_class_masks(
            outputs, masks_100sp, num_classes
        )
        class_masks_1000sp = _superpixel_class_masks(
            outputs, masks_1000sp, num_classes
        )

    to_show = []
    for i, image in enumerate(images):
        # before: plain per-pixel prediction
        to_show.append(
            draw_segmentation_masks(image, masks=class_masks[i], alpha=0.6)
        )
        # after: superpixel-pooled predictions with boundaries drawn on top
        for sp_class_masks, sp_masks in (
            (class_masks_100sp, masks_100sp),
            (class_masks_1000sp, masks_1000sp),
        ):
            overlay = draw_segmentation_masks(
                image, masks=sp_class_masks[i], alpha=0.6
            )
            to_show.append(_overlay_boundaries(overlay, sp_masks[i]))
    show(make_grid(to_show, nrow=6))