Uncolored clouds and PointNet baseline #132

Open · wants to merge 3 commits into base: main
61 changes: 42 additions & 19 deletions README.md
@@ -1,37 +1,60 @@
<div align="center">

# Myria3D: Aerial Lidar HD Semantic Segmentation with Deep Learning
# Fork adapted to train/infer on non-colorized data
Reviewer comment (Member): Can you remove the parts that are related to your fork from this MR please, as it won't be relevant once merged.

</div>

Myria3D is a deep learning library designed with a focused scope: the multiclass semantic segmentation of large-scale, high-density aerial Lidar point clouds.

<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
This fork adds an option to train on Lidar HD data without RGB attributes. It also implements the PointNet baseline. Two pretrained models, RandLaNet and PointNet, are available as .ckpt files containing the best version of each model. Both models were trained on the same twelve Lidar HD tiles; the list of tiles used for training is in trained_model_assets/lidarhd_dataset_split.csv. Training metrics can be viewed on Comet: https://shorturl.at/QHZVd

[![](https://shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=303030)](https://github.com/ashleve/lightning-hydra-template)
Training was performed on a laptop with a 3070 Ti GPU (8 GB VRAM), 32 GB of RAM, and an i7-12700H CPU; batch sizes were adapted to these specifications.
The PointNet implementation subsamples each 50x50 m tile to 4096 points, then upsamples the predictions back to the full cloud with 1-NN interpolation. RandLaNet remains unchanged.
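For illustration, here is a minimal, standalone sketch of that subsample-then-upsample idea (hypothetical names and shapes; the fork's actual implementation lives in myria3d/models/modules/pointnet.py below):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def subsample_then_upsample(pos: np.ndarray, n_sub: int = 4096, n_classes: int = 6) -> np.ndarray:
    """Draw n_sub random points, 'predict' on them, then copy each prediction
    back to every original point via its 1-nearest subsampled neighbor."""
    idx = np.random.choice(len(pos), n_sub, replace=True)
    pos_sub = pos[idx]                          # points actually fed to the network
    logits_sub = np.zeros((n_sub, n_classes))   # placeholder for per-point logits
    nn = NearestNeighbors(n_neighbors=1).fit(pos_sub)
    _, closest = nn.kneighbors(pos)             # nearest subsampled point for each original point
    return logits_sub[closest.squeeze()]        # logits for every original point
```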

[![CICD](https://github.com/IGNF/myria3d/actions/workflows/cicd.yaml/badge.svg)](https://github.com/IGNF/myria3d/actions/workflows/cicd.yaml)
[![Documentation Build](https://github.com/IGNF/myria3d/actions/workflows/gh-pages.yml/badge.svg)](https://github.com/IGNF/myria3d/actions/workflows/gh-pages.yml)
</div>
<br><br>
To train PointNet:
Reviewer comment (Member): Info on how to train is in the documentation; can you move your update to that part? e.g. here.
Same for the inference, see here.


`python run.py experiment=PointNet_baseline`

Myria3D is a deep learning library designed with a focused scope: the multiclass semantic segmentation of large-scale, high-density aerial Lidar point clouds.

The library implements the training of 3D segmentation neural networks, with optimized data-processing and evaluation logic at fit time. Inference on unseen, large-scale point clouds is also supported.
It allows for the evaluation of single-class IoU on the full point cloud, which results in reliable model evaluation.
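As an illustration of that full-cloud, per-class evaluation (a sketch, not Myria3D's actual evaluation code; the class count and torchmetrics API version are assumptions):

```python
import torch
from torchmetrics import JaccardIndex

# per-class IoU over every point of the cloud (no averaging across classes)
iou = JaccardIndex(task="multiclass", num_classes=6, average=None)

preds = torch.randint(0, 6, (1_000_000,))   # predicted class per point
target = torch.randint(0, 6, (1_000_000,))  # reference classification
print(iou(preds, target))                   # one IoU value per class
```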
To train RandLaNet:

`python run.py experiment=RandLaNet_base_run_FR`

Myria3D is built upon [PyTorch](https://pytorch.org/). It keeps the standard data format
from [Pytorch-Geometric](https://pytorch-geometric.readthedocs.io/).
Its structure was bootstrapped from [this code template](https://github.com/ashleve/lightning-hydra-template),
which heavily relies on [Hydra](https://hydra.cc/) and [Pytorch-Lightning](https://github.com/PyTorchLightning/pytorch-lightning) to enable flexible and rapid iterations of deep learning experiments.

Although the library can be extended with new neural network architectures or new data signatures, it makes some opinionated choices in terms of neural network architecture, data processing logic, and inference logic. Indeed, it was initially built with the [French "Lidar HD" project](https://geoservices.ign.fr/lidarhd) in mind, with the ambition to map France in 3D with 10 pulses/m² aerial Lidar by 2025. The data will be openly available, including a semantic segmentation with a minimal number of classes: ground, vegetation, buildings, vehicles, bridges, others.
To run inference with the PointNet model:

> &rarr; For installation and usage, please refer to [**Documentation**](https://ignf.github.io/myria3d/).
`python run.py task.task_name='predict' predict.src_las='/path/to/las' predict.output_dir='/path/to/output' datamodule.epsg=2154 predict.ckpt_path='${hydra:runtime.cwd}/trained_model_assets/pointnet_norgb_epoch_020.ckpt' trainer.accelerator=gpu predict.gpus=[0]`

> &rarr; A stable, production-ready version of Myria3D is tracked by a [Production Release](https://github.com/IGNF/myria3d/releases/tag/prod-release-tag). In the release's assets are a trained multiclass segmentation model as well as the necessary configuration file to perform inference on French "Lidar HD" data. Those assets are provided for convenience, and are subject to change in time to reflect latest model training.
(to achieve better results add `predict.subtile_overlap=25`)


To run inference with the RandLaNet model:

`python run.py task.task_name='predict' predict.src_las='/path/to/las' predict.output_dir='/path/to/output' datamodule.epsg=2154 predict.ckpt_path='${hydra:runtime.cwd}/trained_model_assets/randlanet_norgb_epoch_028.ckpt' trainer.accelerator=gpu predict.gpus=[0]`
___

# Comparisons
Reviewer comment (Member): I think that your comparisons could have a place in the documentation too, provided that you update the links so that the images are stored in our repo as well (in case yours is modified).

<p float="left">
----------------Ground Truth-------------------------------
PointNet----------------------------------
RandLaNet---------------
</p>
<p float="left">
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex1/gt.png?raw=true" width="250" />
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex1/pointnet.png?raw=true" width="250" />
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex1/randlanet.png?raw=true" width="250" />
</p>
<p float="left">
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex2/gt.png?raw=true" width="250" />
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex2/pointnet.png?raw=true" width="250" />
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex2/randlanet.png?raw=true" width="250" />
</p>
<p float="left">
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex3/gt.png?raw=true" width="250" />
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex3/pointnet.png?raw=true" width="250" />
<img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex3/randlanet.png?raw=true" width="250" />
</p>

Please cite Myria3D if it helped your own research. Here is an example BibTeX entry:
```
@misc{gaydon2022myria3d,
@@ -41,4 +64,4 @@
year={2022},
note={IGN (French Mapping Agency)},
}
```
```
21 changes: 21 additions & 0 deletions configs/dataset_description/20230601_lidarhd_norgb_pacasam_dataset.yaml
@@ -0,0 +1,21 @@
_convert_: all # For omegaconf struct to be converted to python dictionaries
# classification_preprocessing_dict = {source_class_code_int: target_class_code_int},
# 3: medium vegetation -> vegetation
# 4: high vegetation -> vegetation
# 0: no processing --> unclassified
# 66: synthetic points --> noise (synthetic points are useful for specific modelling task on already classified data).
# We set them to noise so that they are ignored during training.
# Codes that should not have been in the data: 100, 101. (note: 200 and 201 may have been reported too, leaving that for now)
classification_preprocessing_dict: {3: 5, 4: 5, 0: 1, 64: 1, 66: 65, 67: 65, 100: 1, 101: 1}

# classification_dict = {code_int: name_str, ...} and MUST be sorted (increasing order).
classification_dict: {1: "unclassified", 2: "ground", 5: "vegetation", 6: "building", 9: "water", 17: "bridge"}

# class_weights for the CrossEntropyLoss with format "[[w1,w2,w3...,wk]]" with w_i a float e.g. 1.0
# Balanced CE: arbitrary weights based on heuristic.
# class_weights: normalized so they sum to num of classes to preserve scale of CELoss
class_weights: [0.2,0.002,0.001,0.03,0.75,5.0]

# Input and output dims of neural net are dataset dependent:
d_in: 3
num_classes: 6
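For context, a minimal sketch of how such a remapping dict and the class weights behave (illustrative only; the actual remapping is applied elsewhere in Myria3D's preprocessing):

```python
import numpy as np

preprocessing = {3: 5, 4: 5, 0: 1, 64: 1, 66: 65, 67: 65, 100: 1, 101: 1}
class_weights = np.array([0.2, 0.002, 0.001, 0.03, 0.75, 5.0])

# Remap raw LAS classification codes; codes absent from the dict are kept as-is.
raw = np.array([2, 3, 4, 0, 6, 100])
remapped = np.array([preprocessing.get(int(c), int(c)) for c in raw])
print(remapped)             # [2 5 5 1 6 1]

# The weights are meant to sum to roughly the number of classes (6),
# which preserves the overall scale of the CrossEntropyLoss.
print(class_weights.sum())  # ~5.98
```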
21 changes: 21 additions & 0 deletions configs/experiment/PointNet_baseline.yaml
@@ -0,0 +1,21 @@
# @package _global_
defaults:
- override /datamodule/transforms/augmentations: light.yaml
- override /model: pointnet_model.yaml

logger:
comet:
experiment_name: "PointNet-(BatchSize16xBudget(300pts-40000pts))"


# Smaller BS : 16 x 4096
datamodule:
batch_size: 16

trainer:
accelerator: gpu
num_sanity_val_steps: 2
min_epochs: 20
max_epochs: 150
accumulate_grad_batches: 3 # b/c larger clouds will not fit in memory with original Batch Size
# gpus: [1]
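A quick note on what these settings imply (simple arithmetic; the point count assumes the 4096-point PointNet subsample mentioned in the README above):

```python
batch_size = 16               # clouds per forward pass
accumulate_grad_batches = 3   # gradients accumulated over 3 batches per optimizer step
points_per_cloud = 4096       # PointNet subsample size

effective_batch = batch_size * accumulate_grad_batches   # 48 clouds per optimizer step
points_per_step = effective_batch * points_per_cloud     # 196,608 points per optimizer step
print(effective_batch, points_per_step)
```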
25 changes: 25 additions & 0 deletions configs/experiment/PointNet_norgb_baseline.yaml
@@ -0,0 +1,25 @@
# @package _global_
defaults:
- override /datamodule/transforms/augmentations: light.yaml
- override /model: pointnet_model.yaml
- override /dataset_description: 20230601_lidarhd_norgb_pacasam_dataset.yaml

logger:
comet:
experiment_name: "PointNet-(BatchSize16xBudget(300pts-40000pts))"


# Smaller BS : 16 x 4096
datamodule:
batch_size: 16
points_pre_transform:
_args_:
- "${get_method:myria3d.pctl.points_pre_transform.lidar_hd_norgb.lidar_hd_norgb_pre_transform}"

trainer:
accelerator: gpu
num_sanity_val_steps: 2
min_epochs: 20
max_epochs: 150
accumulate_grad_batches: 3 # b/c larger clouds will not fit in memory with original Batch Size
# gpus: [1]
22 changes: 22 additions & 0 deletions configs/experiment/RandLaNet_base_norgb_run_FR.yaml
@@ -0,0 +1,22 @@
# @package _global_
defaults:
- override /datamodule/transforms/augmentations: light.yaml
- override /dataset_description: 20230601_lidarhd_norgb_pacasam_dataset.yaml

logger:
comet:
experiment_name: "RandLaNet_base_run_FR-(BatchSize10xBudget(300pts-40000pts))"

# Smaller BS : 10 x 40 000 (max) == 400 000 pts i.e. previous budget of 32 x 12 500pts.
datamodule:
batch_size: 10
points_pre_transform:
_args_:
- "${get_method:myria3d.pctl.points_pre_transform.lidar_hd_norgb.lidar_hd_norgb_pre_transform}"

trainer:
num_sanity_val_steps: 2
min_epochs: 100
max_epochs: 150
accumulate_grad_batches: 3 # b/c larger clouds will not fit in memory with original Batch Size
# gpus: [1]
10 changes: 10 additions & 0 deletions configs/model/pointnet_model.yaml
@@ -0,0 +1,10 @@
defaults:
- default.yaml

lr: 0.003933709606504788 # from a 200-step LR range test between 1e-4 and 3.0

neural_net_class_name: "PointNet"
neural_net_hparams:
num_features: ${model.d_in} # number of input features; the model adds +3 for xyz internally
num_classes: ${model.num_classes}
subsample: 4096
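The learning rate above looks like the output of an LR range test; here is a sketch of how such a value could be obtained with PyTorch Lightning's tuner (an assumption about methodology, and the exact tuner API depends on the Lightning version):

```python
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

def find_lr(model: LightningModule, datamodule: LightningDataModule) -> float:
    """Run a 200-step exponential LR sweep between 1e-4 and 3.0 and return the suggested value."""
    trainer = Trainer(accelerator="gpu", devices=1)
    lr_finder = trainer.tuner.lr_find(
        model, datamodule=datamodule, min_lr=1e-4, max_lr=3.0, num_training=200
    )
    return lr_finder.suggestion()
```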
3 changes: 2 additions & 1 deletion myria3d/models/model.py
@@ -5,11 +5,12 @@
from torch_geometric.nn import knn_interpolate

from myria3d.models.modules.pyg_randla_net import PyGRandLANet
from myria3d.models.modules.pointnet import PointNet
from myria3d.utils import utils

log = utils.get_logger(__name__)

MODEL_ZOO = [PyGRandLANet]
MODEL_ZOO = [PyGRandLANet, PointNet]


def get_neural_net_class(class_name: str) -> nn.Module:
126 changes: 126 additions & 0 deletions myria3d/models/modules/pointnet.py
@@ -0,0 +1,126 @@
import torch
import torch.nn as nn

import numpy as np
from sklearn.neighbors import NearestNeighbors
from myria3d.utils import utils # noqa

class PointNet(nn.Module):
"""
PointNet network for semantic segmentation
"""

def __init__(
self,
num_classes: int,
num_features: int,
subsample: int = 512,
MLP_1: list = [32,64],
MLP_2: list = [64,128,256],
MLP_3: list = [128,64,32],
):
"""
initialization function
num_classes : int = the number of class
num_features : int = number of input feature
MLP_1, MLP_2 and MLP_3 : int list = width of the layers of
multi-layer perceptrons. For example MLP_1 = [32, 64] or [16, 64, 128]
cuda : int = if 0 run on CPU (slow but easy to debug), if 1 on GPU
"""

super(PointNet, self).__init__() #necessary for all classes extending the module class
self.subsample = subsample
        self.num_features = num_features + 3  # +3 for the xyz positions

#since we don't know the number of layers in the MLPs, we need to use loops
#to create the correct number of layers

        m1 = MLP_1[-1]  # size of the first embedding f1
        m2 = MLP_2[-1]  # size of the second embedding f2

#MLP_1: input [num_features x n] -> f1 [m1 x n]
modules = []
for i in range(len(MLP_1)): #loop over the layer of MLP1
            # note: for the first layer, in_channels is self.num_features
modules.append(
nn.Conv1d(in_channels=MLP_1[i-1] if i>0 else self.num_features, #to handle i=0
out_channels=MLP_1[i], kernel_size=1))
modules.append(nn.BatchNorm1d(MLP_1[i]))
modules.append(nn.ReLU(True))
        # this transforms the list of layers into a callable module
self.MLP_1 = nn.Sequential(*modules)

#MLP_2: f1 [m1 x n] -> f2 [m2 x n]
modules = []
for i in range(len(MLP_2)):
modules.append(nn.Conv1d(in_channels=MLP_2[i-1] if i>0 else m1,
out_channels=MLP_2[i], kernel_size=1))
modules.append(nn.BatchNorm1d(MLP_2[i]))
modules.append(nn.ReLU(True))
self.MLP_2 = nn.Sequential(*modules)

#MLP_3: f1 [(m1 + m2) x n] -> output [k x n]
modules = []
for i in range(len(MLP_3)):
modules.append(nn.Conv1d(in_channels=MLP_3[i-1] if i>0 else m1+m2,
out_channels=MLP_3[i], kernel_size=1))
modules.append(nn.BatchNorm1d(MLP_3[i]))
modules.append(nn.ReLU(True))

        # note: the last layer has neither normalization nor activation
modules.append(nn.Conv1d(MLP_3[-1], num_classes, 1))

self.MLP_3 = nn.Sequential(*modules)

def forward(self, x, pos, batch, ptr):
"""
the forward function producing the embeddings for each point of 'input'
input : [n_batch, num_features, n_points] float array = input features
output : [n_batch, num_classes, n_points] float array = point class logits
"""
n_batch = ptr.size(0) - 1
n_points = len(pos)
input_all=torch.cat((pos,x), axis=1)
input=torch.Tensor(n_batch, self.num_features, self.subsample)
out=torch.Tensor(n_batch, self.num_features+1, n_points)

for i_batch in range(n_batch):
b_idx = np.where(batch.cpu()==i_batch)
full_cloud = input_all[b_idx]
n_full = full_cloud.shape[0]
selected_points = np.random.choice(n_full, self.subsample)
input_batch = full_cloud[selected_points]
input[i_batch,:,:] = input_batch.T

input = input.to(ptr.device)

#embed points, equation (1)
b1_out = self.MLP_1(input)

#second point embeddings equation (2)
b2_out = self.MLP_2(b1_out)

#maxpool, equation 3
G = torch.max(b2_out,2,keepdim=True)[0]

#concatenate f1 and G
Gf1 = torch.cat((G.repeat(1,1,self.subsample), b1_out),1)

#equation(4)
pred = self.MLP_3(Gf1)
pred = pred.permute(0,2,1).flatten(0,1)

pos_sampled = input[:,:3,:].permute(0,2,1).flatten(0,1)

        # 1-NN upsampling: each original point takes the logits of its nearest subsampled point.
        # Note: the neighbor search is done over all batch elements at once, which is fine as
        # long as the subtiles within a batch do not overlap spatially.
        knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(pos_sampled.cpu())

        _, closest_point = knn.kneighbors(pos.cpu())

        # remove unneeded dimension (we only took one neighbor)
        closest_point = closest_point.squeeze()

out = pred[closest_point,:]

return out
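For reference, a hypothetical smoke test of the module above with PyG-style inputs (tensor shapes, feature choices and values are assumptions, not code from this PR):

```python
import torch
from myria3d.models.modules.pointnet import PointNet

model = PointNet(num_classes=6, num_features=3, subsample=512).eval()

# two clouds of 1000 points each, with 3 features (e.g. Intensity, ReturnNumber, NumberOfReturns)
pos = torch.rand(2000, 3) * 50.0
x = torch.rand(2000, 3)
batch = torch.repeat_interleave(torch.arange(2), 1000)
ptr = torch.tensor([0, 1000, 2000])

with torch.no_grad():
    logits = model(x, pos, batch, ptr)
print(logits.shape)  # torch.Size([2000, 6])
```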

1 change: 1 addition & 0 deletions myria3d/pctl/dataset/hdf5.py
@@ -17,6 +17,7 @@
split_cloud_into_samples,
)
from myria3d.pctl.points_pre_transform.lidar_hd import lidar_hd_pre_transform
from myria3d.pctl.points_pre_transform.lidar_hd_norgb import lidar_hd_norgb_pre_transform
from myria3d.utils import utils

log = utils.get_logger(__name__)
1 change: 1 addition & 0 deletions myria3d/pctl/dataset/iterable.py
@@ -11,6 +11,7 @@
split_cloud_into_samples,
)
from myria3d.pctl.points_pre_transform.lidar_hd import lidar_hd_pre_transform
from myria3d.pctl.points_pre_transform.lidar_hd_norgb import lidar_hd_norgb_pre_transform


class InferenceDataset(IterableDataset):
46 changes: 46 additions & 0 deletions myria3d/pctl/points_pre_transform/lidar_hd_norgb.py
@@ -0,0 +1,46 @@
# function to turn points loaded via pdal into a pyg Data object, with additional channels
Reviewer comment (Member): This needs tests as well; they can be added in https://github.com/IGNF/myria3d/tree/main/tests/myria3d/pctl/transforms

import numpy as np
from torch_geometric.data import Data

RETURN_NUMBER_NORMALIZATION_MAX_VALUE = 10.0


def lidar_hd_norgb_pre_transform(points):
"""Turn pdal points into torch-geometric Data object.

Args:
las_filepath (str): path to the LAS file.

Returns:
Data: the point cloud formatted for later deep learning training.

"""
# Positions and base features
pos = np.asarray([points["X"], points["Y"], points["Z"]], dtype=np.float32).transpose()

# normalization
points["ReturnNumber"] = (points["ReturnNumber"]) / (RETURN_NUMBER_NORMALIZATION_MAX_VALUE)
points["NumberOfReturns"] = (points["NumberOfReturns"]) / (
RETURN_NUMBER_NORMALIZATION_MAX_VALUE
)

# todo
    x_features_names = [
        "Intensity",
        "ReturnNumber",
        "NumberOfReturns",
    ]
    x = np.array([points[name] for name in x_features_names]).transpose()
y = points["Classification"]

data = Data(pos=pos, x=x, y=y, x_features_names=x_features_names)

return data
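Following up on the reviewer's request, a possible test sketch for this pre-transform (the test file location and the fixture values are assumptions):

```python
# hypothetical location: tests/myria3d/pctl/transforms/test_lidar_hd_norgb.py
import numpy as np

from myria3d.pctl.points_pre_transform.lidar_hd_norgb import lidar_hd_norgb_pre_transform


def test_lidar_hd_norgb_pre_transform_shapes():
    n = 100
    # minimal pdal-like structured array with only the dimensions the function reads
    points = np.zeros(
        n,
        dtype=[
            ("X", "f8"), ("Y", "f8"), ("Z", "f8"),
            ("Intensity", "f4"),
            ("ReturnNumber", "f4"), ("NumberOfReturns", "f4"),
            ("Classification", "u1"),
        ],
    )
    points["ReturnNumber"] = 1
    points["NumberOfReturns"] = 1
    points["Classification"] = 2

    data = lidar_hd_norgb_pre_transform(points)

    assert data.pos.shape == (n, 3)
    assert data.x.shape == (n, 3)  # Intensity, ReturnNumber, NumberOfReturns
    assert len(data.y) == n
```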
