diff --git a/README.md b/README.md
index 5f20fc9f..51a134b6 100755
--- a/README.md
+++ b/README.md
@@ -1,37 +1,60 @@
+| Ground Truth | PointNet | RandLaNet |
+| ------------ | -------- | --------- |
+<!-- comparison image rows elided: original image cells could not be recovered -->
+
 Please cite Myria3D if it helped your own research. Here is an example BibTex entry:
 ```
 @misc{gaydon2022myria3d,
@@ -41,4 +64,4 @@ Please cite Myria3D if it helped your own research. Here is an example BibTex en
   year={2022},
   note={IGN (French Mapping Agency)},
 }
-```
\ No newline at end of file
+```
diff --git a/configs/dataset_description/20230601_lidarhd_norgb_pacasam_dataset.yaml b/configs/dataset_description/20230601_lidarhd_norgb_pacasam_dataset.yaml
new file mode 100644
index 00000000..b1485ef2
--- /dev/null
+++ b/configs/dataset_description/20230601_lidarhd_norgb_pacasam_dataset.yaml
@@ -0,0 +1,21 @@
+_convert_: all  # For the omegaconf struct to be converted to python dictionaries
+# classification_preprocessing_dict = {source_class_code_int: target_class_code_int}
+# 3: medium vegetation -> vegetation
+# 4: high vegetation -> vegetation
+# 0: no processing -> unclassified
+# 66: synthetic points -> noise (synthetic points are useful for specific modelling tasks on
+# already classified data). We set them to noise so that they are ignored during training.
+# Codes that should not have been in the data: 100, 101. (note: 200 and 201 may have been reported too, leaving that for now)
+classification_preprocessing_dict: {3: 5, 4: 5, 0: 1, 64: 1, 66: 65, 67: 65, 100: 1, 101: 1}
+
+# classification_dict = {code_int: name_str, ...} and MUST be sorted (increasing order).
+classification_dict: {1: "unclassified", 2: "ground", 5: "vegetation", 6: "building", 9: "water", 17: "bridge"}
+
+# class_weights for the CrossEntropyLoss, with format "[w1,w2,w3...,wk]" where each w_i is a float, e.g. 1.0.
+# Balanced CE: arbitrary weights based on a heuristic,
+# normalized so they sum to the number of classes to preserve the scale of the CE loss.
+class_weights: [0.2, 0.002, 0.001, 0.03, 0.75, 5.0]
+
+# Input and output dims of the neural net are dataset dependent:
+d_in: 3
+num_classes: 6
diff --git a/configs/experiment/PointNet_baseline.yaml b/configs/experiment/PointNet_baseline.yaml
new file mode 100644
index 00000000..bcbe0867
--- /dev/null
+++ b/configs/experiment/PointNet_baseline.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+defaults:
+  - override /datamodule/transforms/augmentations: light.yaml
+  - override /model: pointnet_model.yaml
+
+logger:
+  comet:
+    experiment_name: "PointNet-(BatchSize16xBudget(300pts-40000pts))"
+
+# Smaller batch size: 16 x 4096
+datamodule:
+  batch_size: 16
+
+trainer:
+  accelerator: gpu
+  num_sanity_val_steps: 2
+  min_epochs: 20
+  max_epochs: 150
+  accumulate_grad_batches: 3  # because larger clouds will not fit in memory with the original batch size
+  # gpus: [1]
diff --git a/configs/experiment/PointNet_norgb_baseline.yaml b/configs/experiment/PointNet_norgb_baseline.yaml
new file mode 100644
index 00000000..ad08aa00
--- /dev/null
+++ b/configs/experiment/PointNet_norgb_baseline.yaml
@@ -0,0 +1,25 @@
+# @package _global_
+defaults:
+  - override /datamodule/transforms/augmentations: light.yaml
+  - override /model: pointnet_model.yaml
+  - override /dataset_description: 20230601_lidarhd_norgb_pacasam_dataset.yaml
+
+logger:
+  comet:
+    experiment_name: "PointNet-(BatchSize16xBudget(300pts-40000pts))"
+
+# Smaller batch size: 16 x 4096
+datamodule:
+  batch_size: 16
+  points_pre_transform:
+    _args_:
+      - "${get_method:myria3d.pctl.points_pre_transform.lidar_hd_norgb.lidar_hd_norgb_pre_transform}"
+
+trainer:
+  accelerator: gpu
+  num_sanity_val_steps: 2
+  min_epochs: 20
+  max_epochs: 150
+  accumulate_grad_batches: 3  # because larger clouds will not fit in memory with the original batch size
+  # gpus: [1]
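The classification remapping above lends itself to a simple vectorized lookup. Below is a minimal, hypothetical sketch (not the actual Myria3D preprocessing code) of how such a `classification_preprocessing_dict` acts on an array of LAS classification codes:

```python
import numpy as np

# Same mapping as in the dataset description above.
classification_preprocessing_dict = {3: 5, 4: 5, 0: 1, 64: 1, 66: 65, 67: 65, 100: 1, 101: 1}

def remap_classification(codes: np.ndarray, mapping: dict) -> np.ndarray:
    """Remap source class codes to target codes; codes absent from the mapping are kept."""
    remapped = codes.copy()
    for source, target in mapping.items():
        # Mask on the original array so chained remappings cannot occur.
        remapped[codes == source] = target
    return remapped

codes = np.array([0, 2, 3, 4, 6, 66, 100])
print(remap_classification(codes, classification_preprocessing_dict))
# [ 1  2  5  5  6 65  1]
```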
diff --git a/configs/experiment/RandLaNet_base_norgb_run_FR.yaml b/configs/experiment/RandLaNet_base_norgb_run_FR.yaml
new file mode 100644
index 00000000..51f30dc8
--- /dev/null
+++ b/configs/experiment/RandLaNet_base_norgb_run_FR.yaml
@@ -0,0 +1,22 @@
+# @package _global_
+defaults:
+  - override /datamodule/transforms/augmentations: light.yaml
+  - override /dataset_description: 20230601_lidarhd_norgb_pacasam_dataset.yaml
+
+logger:
+  comet:
+    experiment_name: "RandLaNet_base_run_FR-(BatchSize10xBudget(300pts-40000pts))"
+
+# Smaller batch size: 10 x 40 000 pts (max) == 400 000 pts, i.e. the previous budget of 32 x 12 500 pts.
+datamodule:
+  batch_size: 10
+  points_pre_transform:
+    _args_:
+      - "${get_method:myria3d.pctl.points_pre_transform.lidar_hd_norgb.lidar_hd_norgb_pre_transform}"
+
+trainer:
+  num_sanity_val_steps: 2
+  min_epochs: 100
+  max_epochs: 150
+  accumulate_grad_batches: 3  # because larger clouds will not fit in memory with the original batch size
+  # gpus: [1]
\ No newline at end of file
diff --git a/configs/model/pointnet_model.yaml b/configs/model/pointnet_model.yaml
new file mode 100644
index 00000000..b1fb774b
--- /dev/null
+++ b/configs/model/pointnet_model.yaml
@@ -0,0 +1,10 @@
+defaults:
+  - default.yaml
+
+lr: 0.003933709606504788  # from a 200-step LR range test between 10^-4 and 3.0
+
+neural_net_class_name: "PointNet"
+neural_net_hparams:
+  num_features: ${model.d_in}  # number of non-position features; xyz is added inside the module
+  num_classes: ${model.num_classes}
+  subsample: 4096
diff --git a/myria3d/models/model.py b/myria3d/models/model.py
index f1d842d7..358dd653 100755
--- a/myria3d/models/model.py
+++ b/myria3d/models/model.py
@@ -5,11 +5,12 @@
 from torch_geometric.nn import knn_interpolate
 
 from myria3d.models.modules.pyg_randla_net import PyGRandLANet
+from myria3d.models.modules.pointnet import PointNet
 from myria3d.utils import utils
 
 log = utils.get_logger(__name__)
 
-MODEL_ZOO = [PyGRandLANet]
+MODEL_ZOO = [PyGRandLANet, PointNet]
 
 
 def get_neural_net_class(class_name: str) -> nn.Module:
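Registering `PointNet` in `MODEL_ZOO` is what lets `neural_net_class_name: "PointNet"` in the yaml resolve to the class. Here is a minimal sketch of how such a name-based lookup typically works; the stand-in classes are placeholders and the actual `get_neural_net_class` body is not shown in this diff:

```python
import torch.nn as nn

class PyGRandLANet(nn.Module):  # stand-in for myria3d.models.modules.pyg_randla_net
    pass

class PointNet(nn.Module):  # stand-in for myria3d.models.modules.pointnet
    pass

MODEL_ZOO = [PyGRandLANet, PointNet]

def get_neural_net_class(class_name: str) -> type:
    """Return the class from MODEL_ZOO whose __name__ matches class_name."""
    for neural_net_class in MODEL_ZOO:
        if neural_net_class.__name__ == class_name:
            return neural_net_class
    raise KeyError(f"Unknown neural net class name: {class_name}")

net_class = get_neural_net_class("PointNet")  # as set by neural_net_class_name in pointnet_model.yaml
print(net_class.__name__)  # PointNet
```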
diff --git a/myria3d/models/modules/pointnet.py b/myria3d/models/modules/pointnet.py
new file mode 100644
index 00000000..fcd0d197
--- /dev/null
+++ b/myria3d/models/modules/pointnet.py
@@ -0,0 +1,126 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from sklearn.neighbors import NearestNeighbors
+
+from myria3d.utils import utils  # noqa
+
+
+class PointNet(nn.Module):
+    """PointNet network for semantic segmentation."""
+
+    def __init__(
+        self,
+        num_classes: int,
+        num_features: int,
+        subsample: int = 512,
+        MLP_1: list = [32, 64],
+        MLP_2: list = [64, 128, 256],
+        MLP_3: list = [128, 64, 32],
+    ):
+        """Initialization function.
+
+        num_classes : int = the number of classes
+        num_features : int = the number of input features (excluding the xyz positions)
+        subsample : int = the number of points randomly sampled from each cloud
+        MLP_1, MLP_2 and MLP_3 : lists of int = widths of the layers of the
+            multi-layer perceptrons, e.g. MLP_1 = [32, 64] or [16, 64, 128]
+        """
+        super().__init__()
+        self.subsample = subsample
+        self.num_features = num_features + 3  # + 3 for the xyz positions
+
+        # Since the number of layers in the MLPs is not fixed, build them with loops.
+        m1 = MLP_1[-1]  # size of the first embedding f1
+        m2 = MLP_2[-1]  # size of the second embedding f2
+
+        # MLP_1: input [num_features x n] -> f1 [m1 x n]
+        modules = []
+        for i in range(len(MLP_1)):
+            # for the first layer, in_channels is the input feature size
+            modules.append(
+                nn.Conv1d(
+                    in_channels=MLP_1[i - 1] if i > 0 else self.num_features,
+                    out_channels=MLP_1[i],
+                    kernel_size=1,
+                )
+            )
+            modules.append(nn.BatchNorm1d(MLP_1[i]))
+            modules.append(nn.ReLU(True))
+        # this turns the list of layers into a callable module
+        self.MLP_1 = nn.Sequential(*modules)
+
+        # MLP_2: f1 [m1 x n] -> f2 [m2 x n]
+        modules = []
+        for i in range(len(MLP_2)):
+            modules.append(
+                nn.Conv1d(
+                    in_channels=MLP_2[i - 1] if i > 0 else m1,
+                    out_channels=MLP_2[i],
+                    kernel_size=1,
+                )
+            )
+            modules.append(nn.BatchNorm1d(MLP_2[i]))
+            modules.append(nn.ReLU(True))
+        self.MLP_2 = nn.Sequential(*modules)
+
+        # MLP_3: [f1, G] [(m1 + m2) x n] -> output [num_classes x n]
+        modules = []
+        for i in range(len(MLP_3)):
+            modules.append(
+                nn.Conv1d(
+                    in_channels=MLP_3[i - 1] if i > 0 else m1 + m2,
+                    out_channels=MLP_3[i],
+                    kernel_size=1,
+                )
+            )
+            modules.append(nn.BatchNorm1d(MLP_3[i]))
+            modules.append(nn.ReLU(True))
+        # note: the last layer has neither normalization nor activation
+        modules.append(nn.Conv1d(MLP_3[-1], num_classes, 1))
+        self.MLP_3 = nn.Sequential(*modules)
+
+    def forward(self, x, pos, batch, ptr):
+        """Produce class logits for each point of the input clouds.
+
+        x : [n_points, num_features] float tensor = point features
+        pos : [n_points, 3] float tensor = point positions
+        batch : [n_points] long tensor = batch item index of each point
+        ptr : [n_batch + 1] long tensor = cumulative sizes of the batch items
+        output : [n_points, num_classes] float tensor = point class logits
+        """
+        n_batch = ptr.size(0) - 1
+        input_all = torch.cat((pos, x), dim=1)
+        input = torch.empty(n_batch, self.num_features, self.subsample)
+
+        # subsample each cloud to a fixed number of points (with replacement)
+        for i_batch in range(n_batch):
+            full_cloud = input_all[batch == i_batch]
+            n_full = full_cloud.shape[0]
+            selected_points = np.random.choice(n_full, self.subsample)
+            input[i_batch, :, :] = full_cloud[selected_points].T
+
+        input = input.to(ptr.device)
+
+        # embed points, equation (1)
+        b1_out = self.MLP_1(input)
+
+        # second point embeddings, equation (2)
+        b2_out = self.MLP_2(b1_out)
+
+        # max pooling into a global descriptor, equation (3)
+        G = torch.max(b2_out, 2, keepdim=True)[0]
+
+        # concatenate f1 and G
+        Gf1 = torch.cat((G.repeat(1, 1, self.subsample), b1_out), 1)
+
+        # equation (4)
+        pred = self.MLP_3(Gf1)
+        pred = pred.permute(0, 2, 1).flatten(0, 1)
+
+        # propagate predictions back to the full clouds via nearest neighbors
+        # (note: neighbors are searched over the flattened batch, so batch items
+        # are assumed to occupy disjoint regions of space)
+        pos_sampled = input[:, :3, :].permute(0, 2, 1).flatten(0, 1)
+        knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pos_sampled.cpu())
+        _, closest_point = knn.kneighbors(pos.cpu())
+        # remove the unneeded dimension (we only took one neighbor)
+        closest_point = closest_point.squeeze()
+
+        return pred[closest_point, :]
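As a quick sanity check of the shapes involved, here is a minimal smoke test, assuming the flattened `(x, pos, batch, ptr)` batch format used above; the cloud sizes and random inputs are made up for illustration:

```python
import torch
from myria3d.models.modules.pointnet import PointNet

n_points_a, n_points_b = 3000, 5000  # two clouds of different sizes
num_features, num_classes = 3, 6     # matches the norgb dataset description

x = torch.rand(n_points_a + n_points_b, num_features)
pos = torch.rand(n_points_a + n_points_b, 3) * 50.0
batch = torch.cat([torch.zeros(n_points_a), torch.ones(n_points_b)]).long()
ptr = torch.tensor([0, n_points_a, n_points_a + n_points_b])

model = PointNet(num_classes=num_classes, num_features=num_features, subsample=4096)
model.eval()  # BatchNorm in eval mode for a single untrained pass
with torch.no_grad():
    logits = model(x, pos, batch, ptr)
print(logits.shape)  # torch.Size([8000, 6]) -- one logit vector per input point
```

Note that clouds smaller than `subsample` still work because `np.random.choice` samples with replacement by default.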
diff --git a/myria3d/pctl/dataset/hdf5.py b/myria3d/pctl/dataset/hdf5.py
index b41438b5..328c09dd 100644
--- a/myria3d/pctl/dataset/hdf5.py
+++ b/myria3d/pctl/dataset/hdf5.py
@@ -17,6 +17,7 @@
     split_cloud_into_samples,
 )
 from myria3d.pctl.points_pre_transform.lidar_hd import lidar_hd_pre_transform
+from myria3d.pctl.points_pre_transform.lidar_hd_norgb import lidar_hd_norgb_pre_transform
 from myria3d.utils import utils
 
 log = utils.get_logger(__name__)
diff --git a/myria3d/pctl/dataset/iterable.py b/myria3d/pctl/dataset/iterable.py
index 3d44a881..323004a2 100644
--- a/myria3d/pctl/dataset/iterable.py
+++ b/myria3d/pctl/dataset/iterable.py
@@ -11,6 +11,7 @@
     split_cloud_into_samples,
 )
 from myria3d.pctl.points_pre_transform.lidar_hd import lidar_hd_pre_transform
+from myria3d.pctl.points_pre_transform.lidar_hd_norgb import lidar_hd_norgb_pre_transform
 
 
 class InferenceDataset(IterableDataset):
diff --git a/myria3d/pctl/points_pre_transform/lidar_hd_norgb.py b/myria3d/pctl/points_pre_transform/lidar_hd_norgb.py
new file mode 100644
index 00000000..bc5badc0
--- /dev/null
+++ b/myria3d/pctl/points_pre_transform/lidar_hd_norgb.py
@@ -0,0 +1,46 @@
+# function to turn points loaded via pdal into a pyg Data object, with additional channels
+import numpy as np
+from torch_geometric.data import Data
+
+RETURN_NUMBER_NORMALIZATION_MAX_VALUE = 10.0
+
+
+def lidar_hd_norgb_pre_transform(points):
+    """Turn pdal points into a torch-geometric Data object.
+
+    Args:
+        points: numpy record array of points loaded via pdal.
+
+    Returns:
+        Data: the point cloud formatted for later deep learning training.
+
+    """
+    # Positions and base features
+    pos = np.asarray([points["X"], points["Y"], points["Z"]], dtype=np.float32).transpose()
+
+    # normalization
+    points["ReturnNumber"] = points["ReturnNumber"] / RETURN_NUMBER_NORMALIZATION_MAX_VALUE
+    points["NumberOfReturns"] = points["NumberOfReturns"] / RETURN_NUMBER_NORMALIZATION_MAX_VALUE
+
+    # stack the non-position features (no RGB channels in this variant)
+    x_features_names = [
+        "Intensity",
+        "ReturnNumber",
+        "NumberOfReturns",
+    ]
+    x = np.array([points[name] for name in x_features_names]).transpose()
+    y = points["Classification"]
+
+    data = Data(pos=pos, x=x, y=y, x_features_names=x_features_names)
+
+    return data
diff --git a/myria3d/pctl/transforms/transforms.py b/myria3d/pctl/transforms/transforms.py
index da314cad..9159fbd6 100755
--- a/myria3d/pctl/transforms/transforms.py
+++ b/myria3d/pctl/transforms/transforms.py
@@ -122,8 +122,9 @@ def __call__(self, data: Data):
         # Log transform to be less sensitive to large outliers - info is in lower values
         data.x[:, idx] = torch.log(data.x[:, idx] + 1)
         data.x[:, idx] = self.standardize_channel(data.x[:, idx])
-        idx = data.x_features_names.index("rgb_avg")
-        data.x[:, idx] = self.standardize_channel(data.x[:, idx])
+        if "rgb_avg" in data.x_features_names:
+            idx = data.x_features_names.index("rgb_avg")
+            data.x[:, idx] = self.standardize_channel(data.x[:, idx])
         return data
 
     def standardize_channel(self, channel_data: torch.Tensor, clamp_sigma: int = 3):
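To see what the new pre-transform produces, here is a small, hypothetical usage sketch with synthetic pdal-style points. The field names follow the LAS dimensions used above; the dtypes and values are chosen for illustration only (real pdal arrays carry the native LAS dtypes):

```python
import numpy as np
from myria3d.pctl.points_pre_transform.lidar_hd_norgb import lidar_hd_norgb_pre_transform

n = 1000
points = np.zeros(
    n,
    dtype=[
        ("X", np.float32), ("Y", np.float32), ("Z", np.float32),
        ("Intensity", np.float32), ("ReturnNumber", np.float32),
        ("NumberOfReturns", np.float32), ("Classification", np.uint8),
    ],
)
points["X"] = np.random.rand(n) * 50
points["Y"] = np.random.rand(n) * 50
points["Z"] = np.random.rand(n) * 20
points["ReturnNumber"] = 1
points["NumberOfReturns"] = 1
points["Classification"] = 2  # ground

data = lidar_hd_norgb_pre_transform(points)
print(data.pos.shape, data.x.shape)  # (1000, 3) (1000, 3): xyz positions + 3 features
```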