Uncolored clouds and PointNet baseline #132
base: main
@@ -1,37 +1,60 @@
<div align="center">

# Myria3D: Aerial Lidar HD Semantic Segmentation with Deep Learning
# Fork adapted to train/infer on non-colorized data
</div>

Myria3D is a deep learning library designed with a focused scope: the multiclass semantic segmentation of large-scale, high-density aerial Lidar point clouds.

<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a>
This fork adds an option to train on Lidar HD data without RGB attributes and implements a PointNet baseline. Two pretrained models, RandLaNet and PointNet, are provided as .ckpt files containing the best checkpoint of each model. Both models were trained on the same twelve Lidar HD tiles; the list of training tiles is in trained_model_assets/lidarhd_dataset_split.csv. Training metrics can be browsed on Comet: https://shorturl.at/QHZVd

[![](https://shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=303030)](https://github.com/ashleve/lightning-hydra-template)
Training was performed on a laptop with a 3070 Ti GPU (8 GB VRAM), 32 GB of RAM, and an i7-12700H CPU; batch sizes were adapted to these specifications.
The PointNet implementation subsamples each 50x50 m tile to 4096 points and upsamples the predictions back to the full cloud with 1-nearest-neighbor interpolation. RandLaNet remains unchanged.

[![CICD](https://github.com/IGNF/myria3d/actions/workflows/cicd.yaml/badge.svg)](https://github.com/IGNF/myria3d/actions/workflows/cicd.yaml)
[![Documentation Build](https://github.com/IGNF/myria3d/actions/workflows/gh-pages.yml/badge.svg)](https://github.com/IGNF/myria3d/actions/workflows/gh-pages.yml)
</div>
<br><br>
To train PointNet:

`python run.py experiment=PointNet_baseline`

Myria3D is a deep learning library designed with a focused scope: the multiclass semantic segmentation of large-scale, high-density aerial Lidar point clouds.

The library implements the training of 3D segmentation neural networks, with optimized data-processing and evaluation logic at fit time. Inference on unseen, large-scale point clouds is also supported.
It allows for the evaluation of single-class IoU on the full point cloud, which results in reliable model evaluation.
To train RandLaNet:

`python run.py experiment=RandLaNet_base_run_FR`
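
Any value in the Hydra configs shown further down can also be overridden directly from the command line. For example (illustrative values only, not tuned recommendations), the batch size and the epoch budget of an experiment can be changed like this:

`python run.py experiment=PointNet_baseline datamodule.batch_size=8 trainer.max_epochs=50`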

Myria3D is built upon [PyTorch](https://pytorch.org/). It keeps the standard data format
from [Pytorch-Geometric](https://pytorch-geometric.readthedocs.io/).
Its structure was bootstrapped from [this code template](https://github.com/ashleve/lightning-hydra-template),
which heavily relies on [Hydra](https://hydra.cc/) and [Pytorch-Lightning](https://github.com/PyTorchLightning/pytorch-lightning) to enable flexible and rapid iteration on deep learning experiments.

Although the library can be extended with new neural network architectures or new data signatures, it makes some opinionated choices in terms of neural network architecture, data processing logic, and inference logic. Indeed, it was initially built with the [French "Lidar HD" project](https://geoservices.ign.fr/lidarhd) in mind, with the ambition to map France in 3D with 10 pulse/m² aerial Lidar by 2025. The data will be openly available, including a semantic segmentation with a minimal number of classes: ground, vegetation, buildings, vehicles, bridges, others.

To infer with the PointNet model:

> → For installation and usage, please refer to [**Documentation**](https://ignf.github.io/myria3d/).

`python run.py task.task_name='predict' predict.src_las='/path/to/las' predict.output_dir='/path/to/output' datamodule.epsg=2154 predict.ckpt_path='${hydra:runtime.cwd}/trained_model_assets/pointnet_norgb_epoch_020.ckpt' trainer.accelerator=gpu predict.gpus=[0]`

> → A stable, production-ready version of Myria3D is tracked by a [Production Release](https://github.com/IGNF/myria3d/releases/tag/prod-release-tag). In the release's assets are a trained multiclass segmentation model as well as the necessary configuration file to perform inference on French "Lidar HD" data. Those assets are provided for convenience, and are subject to change over time to reflect the latest model training.

(to achieve better results, add `predict.subtile_overlap=25`)

To infer with the RandLaNet model:

`python run.py task.task_name='predict' predict.src_las='/path/to/las' predict.output_dir='/path/to/output' datamodule.epsg=2154 predict.ckpt_path='${hydra:runtime.cwd}/trained_model_assets/randlanet_norgb_epoch_028.ckpt' trainer.accelerator=gpu predict.gpus=[0]`

___

# Comparisons
**Reviewer comment:** I think that your comparisons can have a place in the documentation too, provided that you update the links so that the images are stored in our repo as well (in case yours is modified).

<p float="left">
Left to right: Ground Truth, PointNet, RandLaNet
</p>
<p float="left">
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex1/gt.png?raw=true" width="250" />
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex1/pointnet.png?raw=true" width="250" />
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex1/randlanet.png?raw=true" width="250" />
</p>
<p float="left">
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex2/gt.png?raw=true" width="250" />
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex2/pointnet.png?raw=true" width="250" />
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex2/randlanet.png?raw=true" width="250" />
</p>
<p float="left">
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex3/gt.png?raw=true" width="250" />
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex3/pointnet.png?raw=true" width="250" />
  <img src="https://github.com/Vynikal/myria3d/blob/lidar_hd/im/ex3/randlanet.png?raw=true" width="250" />
</p>
Please cite Myria3D if it helped your own research. Here is an example BibTeX entry:
```
@misc{gaydon2022myria3d,
@@ -41,4 +64,4 @@
  year={2022},
  note={IGN (French Mapping Agency)},
}
```
@@ -0,0 +1,21 @@
_convert_: all # For omegaconf structs to be converted to python dictionaries
# classification_preprocessing_dict = {source_class_code_int: target_class_code_int},
# 3: medium vegetation -> vegetation
# 4: high vegetation -> vegetation
# 0: no processing --> unclassified
# 66: synthetic points --> noise (synthetic points are useful for specific modelling tasks on already classified data).
# We set them to noise so that they are ignored during training.
# Codes that should not have been in the data: 100, 101. (note: 200 and 201 may have been reported too, leaving that for now)
classification_preprocessing_dict: {3: 5, 4: 5, 0: 1, 64: 1, 66: 65, 67: 65, 100: 1, 101: 1}

# classification_dict = {code_int: name_str, ...} and MUST be sorted (increasing order).
classification_dict: {1: "unclassified", 2: "ground", 5: "vegetation", 6: "building", 9: "water", 17: "bridge"}

# class_weights for the CrossEntropyLoss with format "[w1,w2,w3,...,wk]" with w_i a float, e.g. 1.0
# Balanced CE: arbitrary weights based on a heuristic.
# class_weights: normalized so they sum to the number of classes to preserve the scale of the CE loss
class_weights: [0.2,0.002,0.001,0.03,0.75,5.0]

# Input and output dims of the neural net are dataset dependent:
d_in: 3
num_classes: 6
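
For reference, the sketch below shows how a preprocessing dictionary like this one can be applied to remap raw classification codes before training. It is only an illustration of the mapping logic, not Myria3D's actual preprocessing code:

```python
import numpy as np

# The dict defined above: {source_class_code: target_class_code}; codes not listed are kept as-is.
classification_preprocessing_dict = {3: 5, 4: 5, 0: 1, 64: 1, 66: 65, 67: 65, 100: 1, 101: 1}


def remap_classification(codes: np.ndarray) -> np.ndarray:
    """Remap raw classification codes, leaving unlisted codes unchanged."""
    remapped = codes.copy()
    for source_code, target_code in classification_preprocessing_dict.items():
        remapped[codes == source_code] = target_code
    return remapped


print(remap_classification(np.array([0, 2, 3, 4, 6, 66, 100])))  # -> [ 1  2  5  5  6 65  1]
```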
@@ -0,0 +1,21 @@
# @package _global_
defaults:
  - override /datamodule/transforms/augmentations: light.yaml
  - override /model: pointnet_model.yaml

logger:
  comet:
    experiment_name: "PointNet-(BatchSize16xBudget(300pts-40000pts))"


# Smaller BS : 16 x 4096
datamodule:
  batch_size: 16

trainer:
  accelerator: gpu
  num_sanity_val_steps: 2
  min_epochs: 20
  max_epochs: 150
  accumulate_grad_batches: 3 # b/c larger clouds will not fit in memory with original Batch Size
  # gpus: [1]
@@ -0,0 +1,25 @@
# @package _global_
defaults:
  - override /datamodule/transforms/augmentations: light.yaml
  - override /model: pointnet_model.yaml
  - override /dataset_description: 20230601_lidarhd_norgb_pacasam_dataset.yaml

logger:
  comet:
    experiment_name: "PointNet-(BatchSize16xBudget(300pts-40000pts))"


# Smaller BS : 16 x 4096
datamodule:
  batch_size: 16
  points_pre_transform:
    _args_:
      - "${get_method:myria3d.pctl.points_pre_transform.lidar_hd_norgb.lidar_hd_norgb_pre_transform}"

trainer:
  accelerator: gpu
  num_sanity_val_steps: 2
  min_epochs: 20
  max_epochs: 150
  accumulate_grad_batches: 3 # b/c larger clouds will not fit in memory with original Batch Size
  # gpus: [1]
@@ -0,0 +1,22 @@
# @package _global_
defaults:
  - override /datamodule/transforms/augmentations: light.yaml
  - override /dataset_description: 20230601_lidarhd_norgb_pacasam_dataset.yaml

logger:
  comet:
    experiment_name: "RandLaNet_base_run_FR-(BatchSize10xBudget(300pts-40000pts))"

# Smaller BS : 10 x 40 000 (max) == 400 000 pts, i.e. the previous budget of 32 x 12 500 pts.
datamodule:
  batch_size: 10
  points_pre_transform:
    _args_:
      - "${get_method:myria3d.pctl.points_pre_transform.lidar_hd_norgb.lidar_hd_norgb_pre_transform}"

trainer:
  num_sanity_val_steps: 2
  min_epochs: 100
  max_epochs: 150
  accumulate_grad_batches: 3 # b/c larger clouds will not fit in memory with original Batch Size
  # gpus: [1]
@@ -0,0 +1,10 @@
defaults:
  - default.yaml

lr: 0.003933709606504788 # from a 200-step LR range test between 1e-4 and 3.0

neural_net_class_name: "PointNet"
neural_net_hparams:
  num_features: ${model.d_in} # number of non-positional input features; xyz is appended inside the network
  num_classes: ${model.num_classes}
  subsample: 4096
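
For orientation, the sketch below shows how these fields fit together once Hydra resolves the interpolations (d_in: 3 and num_classes: 6 from the dataset description above). It is only a schematic illustration of the pattern, not Myria3D's actual instantiation code, and the import path of the new class is hypothetical:

```python
# Hypothetical import path for the PointNet class added in this PR.
from myria3d.models.modules.point_net import PointNet

# Values as they would be resolved from pointnet_model.yaml plus the dataset description.
neural_net_class_name = "PointNet"
neural_net_hparams = {"num_features": 3, "num_classes": 6, "subsample": 4096}

# Schematic: look the class up by name and build it from its hyperparameters.
available_classes = {"PointNet": PointNet}
net = available_classes[neural_net_class_name](**neural_net_hparams)
```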
@@ -0,0 +1,126 @@
**Reviewer comment:** This needs testing, tests can be added in https://github.com/IGNF/myria3d/blob/main/tests/myria3d/models/modules/ (see https://github.com/IGNF/myria3d/blob/main/tests/myria3d/models/modules/test_randla_nets.py for an example)

import torch
import torch.nn as nn

import numpy as np
from sklearn.neighbors import NearestNeighbors

from myria3d.utils import utils  # noqa


class PointNet(nn.Module):
    """PointNet network for semantic segmentation."""

    def __init__(
        self,
        num_classes: int,
        num_features: int,
        subsample: int = 512,
        MLP_1: list = [32, 64],
        MLP_2: list = [64, 128, 256],
        MLP_3: list = [128, 64, 32],
    ):
        """
        Initialization.
        num_classes : int = number of classes
        num_features : int = number of input features (excluding xyz, which is added internally)
        subsample : int = number of points each cloud is randomly subsampled to
        MLP_1, MLP_2 and MLP_3 : int lists = widths of the layers of the
        multi-layer perceptrons, e.g. MLP_1 = [32, 64] or [16, 64, 128]
        """
        super(PointNet, self).__init__()  # necessary for all classes extending nn.Module
        self.subsample = subsample
        self.num_features = num_features + 3  # + 3 for xyz positions

        # Since we don't know the number of layers in the MLPs in advance,
        # we use loops to create the correct number of layers.

        m1 = MLP_1[-1]  # size of the first embedding F1
        m2 = MLP_2[-1]  # size of the second embedding F2

        # MLP_1: input [num_features x n] -> f1 [m1 x n]
        modules = []
        for i in range(len(MLP_1)):  # loop over the layers of MLP_1
            # note: for the first layer, in_channels is the input feature size
            modules.append(
                nn.Conv1d(in_channels=MLP_1[i - 1] if i > 0 else self.num_features,
                          out_channels=MLP_1[i], kernel_size=1))
            modules.append(nn.BatchNorm1d(MLP_1[i]))
            modules.append(nn.ReLU(True))
        # this turns the list of layers into a callable module
        self.MLP_1 = nn.Sequential(*modules)

        # MLP_2: f1 [m1 x n] -> f2 [m2 x n]
        modules = []
        for i in range(len(MLP_2)):
            modules.append(nn.Conv1d(in_channels=MLP_2[i - 1] if i > 0 else m1,
                                     out_channels=MLP_2[i], kernel_size=1))
            modules.append(nn.BatchNorm1d(MLP_2[i]))
            modules.append(nn.ReLU(True))
        self.MLP_2 = nn.Sequential(*modules)

        # MLP_3: [(m1 + m2) x n] -> output [num_classes x n]
        modules = []
        for i in range(len(MLP_3)):
            modules.append(nn.Conv1d(in_channels=MLP_3[i - 1] if i > 0 else m1 + m2,
                                     out_channels=MLP_3[i], kernel_size=1))
            modules.append(nn.BatchNorm1d(MLP_3[i]))
            modules.append(nn.ReLU(True))

        # note: the last layer has neither normalization nor activation
        modules.append(nn.Conv1d(MLP_3[-1], num_classes, 1))

        self.MLP_3 = nn.Sequential(*modules)

    def forward(self, x, pos, batch, ptr):
        """
        Forward pass producing class logits for each point of the input batch.
        x : [n_points, num_features] float tensor = per-point features
        pos : [n_points, 3] float tensor = per-point xyz positions
        batch : [n_points] long tensor = sample index of each point
        ptr : [n_batch + 1] long tensor = cumulative sample sizes
        output : [n_points, num_classes] float tensor = per-point class logits
        """
        n_batch = ptr.size(0) - 1
        input_all = torch.cat((pos, x), axis=1)
        input = torch.empty(n_batch, self.num_features, self.subsample)

        # Randomly subsample each cloud of the batch to a fixed number of points.
        for i_batch in range(n_batch):
            b_idx = np.where(batch.cpu() == i_batch)
            full_cloud = input_all[b_idx]
            n_full = full_cloud.shape[0]
            selected_points = np.random.choice(n_full, self.subsample)
            input_batch = full_cloud[selected_points]
            input[i_batch, :, :] = input_batch.T

        input = input.to(ptr.device)

        # embed points, equation (1)
        b1_out = self.MLP_1(input)

        # second point embeddings, equation (2)
        b2_out = self.MLP_2(b1_out)

        # max pooling, equation (3)
        G = torch.max(b2_out, 2, keepdim=True)[0]

        # concatenate f1 and G
        Gf1 = torch.cat((G.repeat(1, 1, self.subsample), b1_out), 1)

        # equation (4)
        pred = self.MLP_3(Gf1)
        pred = pred.permute(0, 2, 1).flatten(0, 1)

        # Upsample the predictions back to the full cloud with 1-nearest-neighbor interpolation.
        pos_sampled = input[:, :3, :].permute(0, 2, 1).flatten(0, 1)
        knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pos_sampled.cpu())
        _, closest_point = knn.kneighbors(pos.cpu())

        # remove the unneeded dimension (we only took one neighbor)
        closest_point = closest_point.squeeze()

        out = pred[closest_point, :]

        return out
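
For a quick sanity check, the snippet below runs the module on a small random batch laid out in the same (x, pos, batch, ptr) format that the forward pass expects. It is a standalone sketch rather than one of the repository's tests, and it assumes the `PointNet` class above is importable (the exact module path is not shown in this diff):

```python
import torch
from torch_geometric.data import Batch, Data

# from myria3d.models.modules.point_net import PointNet  # hypothetical import path for the class above

# Two random clouds with 3 non-positional features each, matching d_in: 3 / num_classes: 6.
clouds = [Data(pos=torch.rand(1000, 3), x=torch.rand(1000, 3)) for _ in range(2)]
pyg_batch = Batch.from_data_list(clouds)

net = PointNet(num_classes=6, num_features=3, subsample=512)
net.eval()
with torch.no_grad():
    logits = net(pyg_batch.x, pyg_batch.pos, pyg_batch.batch, pyg_batch.ptr)

print(logits.shape)  # torch.Size([2000, 6]) : one logit vector per input point
```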
@@ -0,0 +1,46 @@
**Reviewer comment:** This needs tests as well; they can be added in https://github.com/IGNF/myria3d/tree/main/tests/myria3d/pctl/transforms

# Function to turn points loaded via pdal into a pyg Data object, with additional channels.

import numpy as np
from torch_geometric.data import Data

RETURN_NUMBER_NORMALIZATION_MAX_VALUE = 10.0


def lidar_hd_norgb_pre_transform(points):
    """Turn pdal points into a torch-geometric Data object.

    Args:
        points: structured array of points loaded via pdal.

    Returns:
        Data: the point cloud formatted for later deep learning training.

    """
    # Positions and base features
    pos = np.asarray([points["X"], points["Y"], points["Z"]], dtype=np.float32).transpose()

    # Normalization of return-related features
    points["ReturnNumber"] = (points["ReturnNumber"]) / (RETURN_NUMBER_NORMALIZATION_MAX_VALUE)
    points["NumberOfReturns"] = (points["NumberOfReturns"]) / (
        RETURN_NUMBER_NORMALIZATION_MAX_VALUE
    )

    # Per-point features, without any RGB channels
    x = np.array([
        points[name]
        for name in [
            "Intensity",
            "ReturnNumber",
            "NumberOfReturns"
        ]
    ]).transpose()
    x_features_names = [
        "Intensity",
        "ReturnNumber",
        "NumberOfReturns"
    ]
    y = points["Classification"]

    data = Data(pos=pos, x=x, y=y, x_features_names=x_features_names)

    return data
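
A minimal sketch of how this pre-transform could be exercised on a single file through the pdal Python bindings; the file path is a placeholder, and this is not how Myria3D wires the function in internally (it is plugged in via the `points_pre_transform` override shown in the experiment configs above):

```python
import json

import pdal

# Import path taken from the experiment config override in this PR.
from myria3d.pctl.points_pre_transform.lidar_hd_norgb import lidar_hd_norgb_pre_transform

# Placeholder path to a Lidar HD tile.
pipeline = pdal.Pipeline(json.dumps([{"type": "readers.las", "filename": "/path/to/tile.las"}]))
pipeline.execute()
points = pipeline.arrays[0]  # numpy structured array with X, Y, Z, Intensity, ReturnNumber, ...

data = lidar_hd_norgb_pre_transform(points)
print(data)  # e.g. Data(pos=[N, 3], x=[N, 3], y=[N], x_features_names=[3])
```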
**Reviewer comment:** Can you please remove the parts that relate to your fork from this MR, as they won't be relevant once merged.