diff --git a/README.md b/README.md
index 13f67f6..8285863 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,6 @@ The code provided is compatible with [nuScenes](https://www.nuscenes.org/lidar-s
[PV-RCNN finetuned on KITTI](https://github.com/valeoai/SLidR/releases/download/v1.0/pvrcnn_slidr.pt)
-
## Reproducing the results
### Pre-computing the superpixels (required)
@@ -131,6 +130,13 @@ SLidR |81.9 |51.6 |68.5 |**
*As reimplemented in [ONCE](https://arxiv.org/abs/2106.11037)
+## Visualizations
+
+For visualization you need a pre-training containing both 2D & 3D models. We provide the raw [SR-UNet & ResNet50 pre-trained on nuScenes](https://github.com/valeoai/SLidR/releases/download/v1.1/minkunet_slidr_1gpu_raw.pt).
+The image part of the pre-trained weights are identical for almost all layers to those of [MoCov2](https://github.com/facebookresearch/moco) (He et al.)
+
+The [visualization code](utils/visualization.ipynb) allows to assess the similarities between points and pixels, as shown in the article.
+
## Acknowledgment
diff --git a/model/__init__.py b/model/__init__.py
index f3f7deb..9be3f0d 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -5,5 +5,5 @@
MinkUNet = None
try:
from model.spconv_backbone import VoxelNet
-except ImportError:
+except (ImportError, AttributeError):
VoxelNet = None
diff --git a/model/modules/common.py b/model/modules/common.py
index c86aa78..c7b0f92 100644
--- a/model/modules/common.py
+++ b/model/modules/common.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import collections
+import collections.abc as collections
from enum import Enum
import MinkowskiEngine as ME
diff --git a/utils/visualization.ipynb b/utils/visualization.ipynb
new file mode 100644
index 0000000..ae2f2ff
--- /dev/null
+++ b/utils/visualization.ipynb
@@ -0,0 +1,694 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "8573e6d5",
+ "metadata": {},
+ "source": [
+ "# Visualization code - SLidR"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "07c21201",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.chdir('../')\n",
+ "import torch\n",
+ "import plotly.express as px\n",
+ "import plotly.graph_objects as go\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import MinkowskiEngine as ME\n",
+ "from datetime import datetime as dt\n",
+ "from torch.utils.data import DataLoader\n",
+ "from pretrain.model_builder import make_model\n",
+ "from pretrain.dataloader_nuscenes import NuScenesMatchDataset, minkunet_collate_pair_fn\n",
+ "from utils.transforms import make_transforms_asymmetrical_val\n",
+ "\n",
+ "\n",
+ "np.random.seed(0)\n",
+ "\n",
+ "\n",
+ "def generate_config():\n",
+ " dataset = \"nuscenes\"\n",
+ " cylindrical_coordinates = True\n",
+ " voxel_size = 0.1\n",
+ " use_intensity = True\n",
+ " kernel_size = 3\n",
+ " model_n_out = 64\n",
+ " bn_momentum = 0.05\n",
+ " model_points = \"minkunet\"\n",
+ " image_weights = \"moco_v2\"\n",
+ " images_encoder = \"resnet50\"\n",
+ " decoder = \"dilation\"\n",
+ " training = \"validate\"\n",
+ " transforms_clouds = [\"Rotation\", \"FlipAxis\"]\n",
+ " transforms_mixed = [\"DropCuboids\", \"ResizedCrop\", \"FlipHorizontal\"]\n",
+ " losses = [\"loss_superpixels_average\"]\n",
+ " superpixels_type = \"slic\"\n",
+ " dataset_skip_step = 1\n",
+ " resume_path = \"weights/minkunet_slidr_1gpu_raw.pt\"\n",
+ "\n",
+ " # WARNING: DO NOT CHANGE THE FOLLOWING PARAMETERS\n",
+ " # ===============================================\n",
+ " if dataset.lower() == \"nuscenes\":\n",
+ " dataset_root = \"/datasets/nuscenes/\"\n",
+ " crop_size = (224, 416)\n",
+ " crop_ratio = (14.0 / 9.0, 17.0 / 9.0)\n",
+ " elif dataset.lower() == \"kitti\":\n",
+ " dataset_root = \"/datasets/semantic_kitti/\"\n",
+ " crop_size = (192, 672)\n",
+ " crop_ratio = (3., 4.)\n",
+ " else:\n",
+ " raise Exception(f\"Dataset Unknown: {dataset}\")\n",
+ "\n",
+ " datetime = dt.today().strftime(\"%d%m%y-%H%M\")\n",
+ " \n",
+ " normalize_features = True\n",
+ "\n",
+ " config = locals().copy()\n",
+ " return config\n",
+ "\n",
+ "config = generate_config()\n",
+ "\n",
+ "mixed_transforms_val = make_transforms_asymmetrical_val(config)\n",
+ "dataset = NuScenesMatchDataset(\n",
+ " phase=\"val\",\n",
+ " shuffle=False,\n",
+ " cloud_transforms=None,\n",
+ " mixed_transforms=mixed_transforms_val,\n",
+ " config=config,\n",
+ ")\n",
+ "\n",
+ "dataloader = DataLoader(\n",
+ " dataset,\n",
+ " batch_size=1,\n",
+ " shuffle=True,\n",
+ " num_workers=0,\n",
+ " collate_fn=minkunet_collate_pair_fn,\n",
+ " pin_memory=True,\n",
+ " drop_last=False,\n",
+ " worker_init_fn=lambda id:0\n",
+ ")\n",
+ "dl = iter(dataloader)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dec060aa",
+ "metadata": {},
+ "source": [
+ "## Load the 2D & 3D NN"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ec4887b5",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "model_points, model_images = make_model(config)\n",
+ "\n",
+ "checkpoint = torch.load(config[\"resume_path\"], map_location='cpu')\n",
+ "try:\n",
+ " model_points.load_state_dict(checkpoint[\"model_points\"])\n",
+ "except KeyError:\n",
+ " weights = {\n",
+ " k.replace(\"model_points.\", \"\"): v\n",
+ " for k, v in checkpoint[\"state_dict\"].items()\n",
+ " if k.startswith(\"model_points.\")\n",
+ " }\n",
+ " model_points.load_state_dict(weights)\n",
+ "\n",
+ "try:\n",
+ " model_images.load_state_dict(checkpoint[\"model_images\"])\n",
+ "except KeyError:\n",
+ " weights = {\n",
+ " k.replace(\"model_images.\", \"\"): v\n",
+ " for k, v in checkpoint[\"state_dict\"].items()\n",
+ " if k.startswith(\"model_images.\")\n",
+ " }\n",
+ " model_images.load_state_dict(weights)\n",
+ "model_points = model_points.cuda().eval()\n",
+ "model_images = model_images.cuda().eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c2cd3e81",
+ "metadata": {},
+ "source": [
+ "## Plotly code for dynamic figures"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "94fc4845",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def dynamic_heatmap(points, dist, image, save_path=None):\n",
+ " dist -= dist.min()\n",
+ " dist = dist / dist.max()\n",
+ " fig = go.FigureWidget(\n",
+ " data=[\n",
+ " dict(\n",
+ " type='image',\n",
+ " z=image,\n",
+ " hoverinfo='skip',\n",
+ " opacity=1.\n",
+ " ),\n",
+ " dict(\n",
+ " type='scattergl',\n",
+ " x=points[:, 0],\n",
+ " y=points[:, 1],\n",
+ " mode='markers',\n",
+ " marker={'color': '#0000ff'},\n",
+ " marker_size=10,\n",
+ " marker_line_width=1,\n",
+ " hovertemplate=''\n",
+ " ),\n",
+ " ] +\n",
+ " [dict(type='heatmap', z=dist[:,:,i], zmin=0., zmax=1., showscale=False, visible=False, hoverinfo='skip', opacity=.5) for i in range(len(points))],\n",
+ " )\n",
+ " fig.layout.hovermode = 'closest'\n",
+ " fig.layout.xaxis.visible = False\n",
+ " fig.layout.yaxis.visible = False\n",
+ " fig.layout.showlegend = False\n",
+ " fig.layout.width = 416\n",
+ " fig.layout.height = 224\n",
+ " fig.layout.plot_bgcolor=\"rgba(0,0,0,0)\"\n",
+ " fig.layout.margin=go.layout.Margin(\n",
+ " l=0, #left margin\n",
+ " r=0, #right margin\n",
+ " b=0, #bottom margin\n",
+ " t=0, #top margin\n",
+ " )\n",
+ " scatter = fig.data[1]\n",
+ "\n",
+ " def click_fn(trace, points, selector):\n",
+ " ind = points.point_inds[0]\n",
+ " c = ['#0000ff'] * dist.shape[2]\n",
+ " opacity = [0.] * dist.shape[2]\n",
+ " c[ind] = '#ff0000'\n",
+ " opacity[ind] = 1.\n",
+ " if fig.data[ind + 2].visible is False:\n",
+ " with fig.batch_update():\n",
+ " scatter.marker.color = c\n",
+ " scatter.marker.opacity = opacity\n",
+ " for i in range(dist.shape[2]):\n",
+ " fig.data[i + 2].visible = False\n",
+ " fig.data[ind + 2].visible = True\n",
+ " fig.update_xaxes(range=[0., 415.])\n",
+ " fig.update_yaxes(range=[223, 0.])\n",
+ " scatter.on_click(click_fn)\n",
+ " return fig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8ceafa38",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def dynamic_heatmap_3d(query, points, dist_3d, image, save_path=None):\n",
+ " dist_3d -= dist_3d.min()\n",
+ " dist_3d = dist_3d / dist_3d.max()\n",
+ " fig = go.FigureWidget(\n",
+ " data=[\n",
+ " dict(\n",
+ " type='image',\n",
+ " z=image,\n",
+ " hoverinfo='skip'\n",
+ " ),\n",
+ " dict(\n",
+ " type='scattergl',\n",
+ " x=query[:, 0],\n",
+ " y=query[:, 1],\n",
+ " mode='markers',\n",
+ " marker={'color': '#0000ff'},\n",
+ " marker_size=10,\n",
+ " marker_line_width=1,\n",
+ " hovertemplate=''\n",
+ " ),\n",
+ " ] +\n",
+ " [dict(type='scatter', mode=\"markers\", x=points[:, 0], y=points[:, 1], marker_color=dist_3d[i], marker_size=10, visible=False, hoverinfo='skip', opacity=0.5) for i in range(len(query))],\n",
+ " )\n",
+ " fig.layout.hovermode = 'closest'\n",
+ " fig.layout.xaxis.visible = False\n",
+ " fig.layout.yaxis.visible = False\n",
+ " fig.layout.showlegend = False\n",
+ " fig.layout.width = 416\n",
+ " fig.layout.height = 224\n",
+ " fig.layout.plot_bgcolor=\"rgba(0,0,0,0)\"\n",
+ " fig.layout.margin=go.layout.Margin(\n",
+ " l=0, #left margin\n",
+ " r=0, #right margin\n",
+ " b=0, #bottom margin\n",
+ " t=0, #top margin\n",
+ " )\n",
+ " scatter = fig.data[1]\n",
+ "\n",
+ " def click_fn(trace, points, selector):\n",
+ " ind = points.point_inds[0]\n",
+ " c = ['#0000ff'] * dist_3d.shape[0]\n",
+ " opacity = [0.] * dist_3d.shape[0]\n",
+ " c[ind] = '#ff0000'\n",
+ " opacity[ind] = 1.\n",
+ " if fig.data[ind + 2].visible is False:\n",
+ " with fig.batch_update():\n",
+ " scatter.marker.color = c\n",
+ " scatter.marker.opacity = opacity\n",
+ " for i in range(dist_3d.shape[0]):\n",
+ " fig.data[i + 2].visible = False\n",
+ " fig.data[ind + 2].visible = True\n",
+ " fig.update_xaxes(range=[0., 415.])\n",
+ " fig.update_yaxes(range=[223, 0.])\n",
+ " scatter.on_click(click_fn)\n",
+ " return fig"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ef2d1083",
+ "metadata": {},
+ "source": [
+ "## Process one batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0550e215",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ " image_id = 0\n",
+ " batch = next(dl)\n",
+ " sparse_input = ME.SparseTensor(batch[\"sinput_F\"].cuda(), batch[\"sinput_C\"].cuda())\n",
+ " output_points = model_points(sparse_input).F\n",
+ " output_images = model_images(batch[\"input_I\"].cuda())\n",
+ " image = batch[\"input_I\"][image_id].permute(1,2,0) * 255\n",
+ " mask = batch[\"pairing_images\"][:,0] == image_id\n",
+ " superpixels = batch[\"superpixels\"][image_id]\n",
+ " points = np.flip(batch[\"pairing_images\"][mask, 1:].numpy(), axis=1)\n",
+ " points_features = output_points[batch[\"pairing_points\"][mask]]\n",
+ " image_features = output_images[image_id].permute(1,2,0)\n",
+ " pairing_images = batch[\"pairing_images\"][mask, 1:]\n",
+ " pairing_points = batch[\"pairing_points\"][mask]\n",
+ " dist_2d_3d = (1+torch.matmul(image_features, points_features.T))/2\n",
+ " dist_2d_3d = dist_2d_3d.cpu().numpy()\n",
+ " dist_3d_3d = (1+torch.matmul(points_features, points_features.T).cpu().numpy())/2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8a52151",
+ "metadata": {},
+ "source": [
+ "## Show the front camera for this batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2a3c9943",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plt.figure(figsize=(8.32,4.48))\n",
+ "ax = fig.add_axes([0, 0, 1, 1])\n",
+ "plt.axis('off')\n",
+ "ax.imshow(image/255)\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f19eae50",
+ "metadata": {},
+ "source": [
+ "## Show the associated projected 3D points"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "180ed015",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plt.figure(figsize=(8.32,4.48))\n",
+ "ax = fig.add_axes([0, 0, 1, 1])\n",
+ "plt.axis('off')\n",
+ "ax.scatter(points[:, 0], points[:, 1], color='black', s=15)\n",
+ "ax.imshow(np.zeros((224,416,4)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6959bde7",
+ "metadata": {},
+ "source": [
+ "## Dynamic 2D features\n",
+ "Clicking on a projected 3D point (in blue) will show a similarity map for the 2D features in the image relative to this point"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4d8871a5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "candidates_ind = np.random.choice(points.shape[0], 10, replace=False)\n",
+ "dynamic_heatmap(points[candidates_ind], dist_2d_3d[:,:,candidates_ind], image)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2e943738",
+ "metadata": {},
+ "source": [
+ "## Dynamic 3D features\n",
+ "Clicking on a projected 3D point (in blue) will show a similarity map for other 3D points' features, relative to this point"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a7fdce3d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "candidates_ind = np.random.choice(points.shape[0], 25, replace=False)\n",
+ "dynamic_heatmap_3d(points[candidates_ind], points, (dist_3d_3d[candidates_ind]), image, save_path=None)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2a8f32ee",
+ "metadata": {},
+ "source": [
+ "## PCA coloring"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "74cc0f77",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.decomposition import PCA\n",
+ "pca = PCA(3)\n",
+ "y = pca.fit_transform(points_features.cpu().numpy())\n",
+ "y = y - y.min(0)\n",
+ "y = y / y.max(0)\n",
+ "x = pca.transform(image_features.view(-1, 64).cpu().numpy())\n",
+ "x = x - x.min(0)\n",
+ "x = x / x.max(0)\n",
+ "fmap = x.reshape((224,416,3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d5448b37",
+ "metadata": {},
+ "source": [
+ "The following figures show a PCA coloring (in RGB) for the 2D (first figure) or 3D (second figure) features. The same PCA is used for both, so the colors are corresponding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8581302d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fn = lambda x: f\"rgb({x[0]}, {x[1]}, {x[2]})\"\n",
+ "cmap = list(map(fn, (y*255).astype(np.int32)))\n",
+ "fig = go.FigureWidget(\n",
+ " data=[\n",
+ " dict(\n",
+ " type='image',\n",
+ " z=image,\n",
+ " hoverinfo='skip'\n",
+ " ),\n",
+ " dict(\n",
+ " type='image',\n",
+ " z=fmap*255,\n",
+ " hoverinfo='skip',\n",
+ " opacity=0.5\n",
+ " )\n",
+ " ]\n",
+ ")\n",
+ "fig.layout.xaxis.visible = False\n",
+ "fig.layout.yaxis.visible = False\n",
+ "fig.layout.showlegend = False\n",
+ "fig.layout.width = 416\n",
+ "fig.layout.height = 224\n",
+ "fig.layout.plot_bgcolor=\"rgba(0,0,0,0)\"\n",
+ "fig.layout.margin=go.layout.Margin(\n",
+ " l=0, #left margin\n",
+ " r=0, #right margin\n",
+ " b=0, #bottom margin\n",
+ " t=0, #top margin\n",
+ ")\n",
+ "fig.update_xaxes(range=[0., 415.])\n",
+ "fig.update_yaxes(range=[223, 0.])\n",
+ "fig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ac6aabf0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fn = lambda x: f\"rgb({x[0]}, {x[1]}, {x[2]})\"\n",
+ "cmap = list(map(fn, (y*255).astype(np.int32)))\n",
+ "fig = go.FigureWidget(\n",
+ " data=[\n",
+ " dict(\n",
+ " type='image',\n",
+ " z=image,\n",
+ " hoverinfo='skip'\n",
+ " ),\n",
+ " dict(type='scatter', mode=\"markers\", x=points[:, 0], y=points[:, 1], marker_color=cmap, marker_size=10, visible=True, hoverinfo='skip', opacity=0.5)\n",
+ " ]\n",
+ ")\n",
+ "fig.layout.xaxis.visible = False\n",
+ "fig.layout.yaxis.visible = False\n",
+ "fig.layout.showlegend = False\n",
+ "fig.layout.width = 416\n",
+ "fig.layout.height = 224\n",
+ "fig.layout.plot_bgcolor=\"rgba(0,0,0,0)\"\n",
+ "fig.layout.margin=go.layout.Margin(\n",
+ " l=0, #left margin\n",
+ " r=0, #right margin\n",
+ " b=0, #bottom margin\n",
+ " t=0, #top margin\n",
+ ")\n",
+ "fig.update_xaxes(range=[0., 415.])\n",
+ "fig.update_yaxes(range=[223, 0.])\n",
+ "fig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "80180bc5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plt.figure(figsize=(8.32,4.48))\n",
+ "ax = fig.add_axes([0, 0, 1, 1])\n",
+ "plt.axis('off')\n",
+ "ax.scatter(points[:, 0], points[:, 1], color=y, s=15)\n",
+ "ax.imshow(np.zeros((224,416,4)))\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "96a0583c",
+ "metadata": {},
+ "source": [
+ "## Pooling the PCA coloring by superpixels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7351fd84",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "m = tuple(pairing_images.cpu().T.long())\n",
+ "\n",
+ "superpixels_I = superpixels.flatten()\n",
+ "idx_P = torch.arange(pairing_points.shape[0], device=superpixels.device)\n",
+ "total_pixels = superpixels_I.shape[0]\n",
+ "idx_I = torch.arange(total_pixels, device=superpixels.device)\n",
+ "\n",
+ "one_hot_P = torch.sparse_coo_tensor(\n",
+ " torch.stack((\n",
+ " superpixels[m], idx_P\n",
+ " ), dim=0),\n",
+ " torch.ones(pairing_points.shape[0], device=superpixels.device),\n",
+ " (superpixels.max() + 1, pairing_points.shape[0])\n",
+ ")\n",
+ "\n",
+ "one_hot_I = torch.sparse_coo_tensor(\n",
+ " torch.stack((\n",
+ " superpixels_I, idx_I\n",
+ " ), dim=0),\n",
+ " torch.ones(total_pixels, device=superpixels.device),\n",
+ " (superpixels.max() + 1, total_pixels)\n",
+ ")\n",
+ "\n",
+ "k = one_hot_P @ points_features.cpu()\n",
+ "k = k / (torch.sparse.sum(one_hot_P, 1).to_dense()[:, None] + 1e-6)\n",
+ "k = pca.transform(k.cpu().numpy())\n",
+ "k = k - k.min(0)\n",
+ "k = k / k.max(0)\n",
+ "q = one_hot_I @ image_features.cpu().flatten(0, 1)\n",
+ "q = q / (torch.sparse.sum(one_hot_I, 1).to_dense()[:, None] + 1e-6)\n",
+ "q = pca.transform(q.cpu().numpy())\n",
+ "q = q - q.min(0)\n",
+ "q = q / q.max(0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3dc371c5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plt.figure(figsize=(8.32,4.48))\n",
+ "ax = fig.add_axes([0, 0, 1, 1])\n",
+ "plt.axis('off')\n",
+ "ax.imshow(q[superpixels.numpy()])\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ea04373c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plt.figure(figsize=(8.32,4.48))\n",
+ "ax = fig.add_axes([0, 0, 1, 1])\n",
+ "plt.axis('off')\n",
+ "ax.scatter(points[:, 0], points[:, 1], color=k[superpixels[m]], s=15)\n",
+ "ax.imshow(np.zeros((224,416,4)))\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b4858c35",
+ "metadata": {},
+ "source": [
+ "## Showing superpixels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "08508c05",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "scene_index = np.random.randint(850)\n",
+ "current_sample_token = dataloader.dataset.nusc.scene[scene_index]['first_sample_token']\n",
+ "data = dataloader.dataset.nusc.get(\"sample\", current_sample_token)['data']\n",
+ "cam_info = dataloader.dataset.nusc.get(\"sample_data\", data['CAM_FRONT_RIGHT'])\n",
+ "token = cam_info['token']\n",
+ "filename = cam_info['filename']\n",
+ "im = plt.imread(f\"/datasets_master/nuscenes/{filename}\")\n",
+ "fig = plt.figure(figsize=(8,4.5))\n",
+ "ax = fig.add_axes([0, 0, 1, 1])\n",
+ "plt.axis('off')\n",
+ "plt.imshow(im)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6e7e524d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "sp = np.array(Image.open(f\"/datasets_master/nuscenes_corentin/superpixels_slic/{token}.png\"))\n",
+ "from skimage.segmentation import mark_boundaries\n",
+ "compound_image = np.zeros((900,1600,3))\n",
+ "for i in range(sp.max()):\n",
+ " ma = sp==i\n",
+ " compound_image[ma] = np.average(im[ma], 0) / 255\n",
+ "compound_image = mark_boundaries(compound_image, sp, color=(1., 1., 1.))\n",
+ "fig = plt.figure(figsize=(8,4.5))\n",
+ "ax = fig.add_axes([0, 0, 1, 1])\n",
+ "plt.axis('off')\n",
+ "plt.imshow(compound_image)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "19fbdc94",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plt.figure(figsize=(8.32,4.48))\n",
+ "ax = fig.add_axes([0, 0, 1, 1])\n",
+ "plt.axis('off')\n",
+ "ax.scatter(points[:, 0], points[:, 1], color=np.array(image[list(np.flip(points, 1).T)] / 255), s=15)\n",
+ "ax.imshow(np.zeros((224,416,4)))\n",
+ "fig.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.10"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}