Skip to content

Commit

Permalink
Added extended training statistics and exporting (2) (#18)
Browse files Browse the repository at this point in the history
* Fixed indentation + grouped headers in generated solution

* Added statistics calculation

* Improved statistics printing

* Added stats.json generation

* Cleanup

* Fixed gcc compilation

* Fixed average calculation

* Reverted indentation

* Reverted linter advice

* Extracted logic and added new metrics

* Reverted linter advice

* Refactoring

* Sensitivity --> recall

* Simplify

---------

Co-authored-by: Luca Di Leo <[email protected]>
  • Loading branch information
pierotofy and HeDo88TH authored Apr 12, 2023
1 parent 0083cd4 commit f8e35c4
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 34 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ include_directories(${EIGEN3_INCLUDE_DIR})
include_directories(vendor/ethz)

set(SOURCES classifier.cpp scale.cpp point_io.cpp randomforest.cpp features.cpp color.cpp labels.cpp)
set(HEADERS classifier.hpp scale.hpp point_io.hpp randomforest.hpp features.hpp color.hpp labels.hpp statistics.hpp)
if (WITH_GBT)
list(APPEND SOURCES gbm.cpp)
list(APPEND HEADERS gbm.hpp)
Expand All @@ -96,7 +97,7 @@ if (WITH_PDAL)
set(PDAL_LIB ${PDAL_LIBRARIES})
endif()

add_library(libopc OBJECT ${SOURCES})
add_library(libopc OBJECT ${SOURCES} ${HEADERS})

if (WITH_GBT)
add_dependencies(libopc lightgbm)
Expand Down
36 changes: 20 additions & 16 deletions classifier.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,27 @@

#include <vector>
#include <random>
#include <cmath>

#include "features.hpp"
#include "labels.hpp"
#include "constants.hpp"
#include "point_io.hpp"
#include "statistics.hpp"

enum Regularization { None, LocalSmooth };
Regularization parseRegularization(const std::string &regularization);

enum ClassifierType { RandomForest, GradientBoostedTrees };
ClassifierType fingerprint(const std::string &modelFile);


template <typename F, typename I>
void getTrainingData(const std::vector<std::string> &filenames,
double *startResolution,
int numScales,
double radius,
int maxSamples,
const int numScales,
const double radius,
const int maxSamples,
const std::vector<int> &asprsClasses,
F storeFeatures,
I init) {
Expand Down Expand Up @@ -112,12 +115,14 @@ void classifyData(PointSet &pointSet,
F evaluateFunc,
const std::vector<Feature *> &features,
const std::vector<Label> &labels,
Regularization regularization,
double regRadius,
bool useColors,
bool unclassifiedOnly,
bool evaluate,
const std::vector<int> &skip) {
const Regularization regularization,
const double regRadius,
const bool useColors,
const bool unclassifiedOnly,
const bool evaluate,
const std::vector<int> &skip,
const std::string &statsFile) {

std::cout << "Classifying..." << std::endl;
pointSet.base->labels.resize(pointSet.base->count());

Expand Down Expand Up @@ -214,7 +219,6 @@ void classifyData(PointSet &pointSet,
throw std::runtime_error("Invalid regularization");
}

std::size_t correct = 0;
if (!useColors && !pointSet.hasLabels()) pointSet.labels.resize(pointSet.count());
std::vector<bool> skipMap(255, false);
for (size_t i = 0; i < skip.size(); i++) {
Expand All @@ -224,6 +228,8 @@ void classifyData(PointSet &pointSet,

auto train2asprsCodes = getTrain2AsprsCodes();

Statistics stats(labels);

#pragma omp parallel for
for (long long int i = 0; i < pointSet.count(); i++) {
const size_t idx = pointSet.pointMap[i];
Expand All @@ -232,10 +238,7 @@ void classifyData(PointSet &pointSet,
auto label = labels[bestClass];

if (evaluate) {
if (pointSet.labels[i] == bestClass) {
#pragma omp atomic
correct++;
}
stats.record(bestClass, pointSet.labels[i]);
}

bool update = true;
Expand Down Expand Up @@ -266,8 +269,9 @@ void classifyData(PointSet &pointSet,
}

if (evaluate) {
float modelErr = (1.f - static_cast<float>(correct) / static_cast<float>(pointSet.count()));
std::cout << "Model error: " << std::setprecision(4) << (modelErr * 100.f) << "%" << std::endl;
stats.finalize();
stats.print();
if (!statsFile.empty()) stats.writeToFile(statsFile);
}
}

Expand Down
6 changes: 4 additions & 2 deletions gbm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ void classify(PointSet &pointSet,
const bool useColors,
const bool unclassifiedOnly,
const bool evaluate,
const std::vector<int> &skip) {
const std::vector<int> &skip,
const std::string &statsFile
) {

LightGBM::PredictionEarlyStopConfig early_stop_config;
auto earlyStop = LightGBM::CreatePredictionEarlyStopInstance("none", early_stop_config);
Expand All @@ -179,7 +181,7 @@ void classify(PointSet &pointSet,
[&booster, &earlyStop](const double *ft, double *probs) {
booster->Predict(ft, probs, &earlyStop);
},
features, labels, regularization, regRadius, useColors, unclassifiedOnly, evaluate, skip);
features, labels, regularization, regRadius, useColors, unclassifiedOnly, evaluate, skip, statsFile);
}

}
Expand Down
5 changes: 3 additions & 2 deletions gbm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ void classify(PointSet &pointSet,
const std::vector<Feature *> &features,
const std::vector<Label> &labels,
Regularization regularization = Regularization::None,
double regRadius = 2.5f,
double regRadius = 2.5,
bool useColors = false,
bool unclassifiedOnly = false,
bool evaluate = false,
const std::vector<int> &skip = {});
const std::vector<int> &skip = {},
const std::string &statsFile = "");

}

Expand Down
2 changes: 1 addition & 1 deletion labels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Label {
Color color;
public:
Label(const std::string &name, int trainingCode, int asprsCode, Color color = Color())
: name(name), trainingCode(trainingCode), asprsCode(asprsCode), color(color) {};
: name(name), trainingCode(trainingCode), asprsCode(asprsCode), color(color) {}

std::string getName() const { return name; }
int getTrainingCode() const { return trainingCode; }
Expand Down
19 changes: 16 additions & 3 deletions pcclassify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ int main(int argc, char **argv) {
("u,unclassified", "Only classify points that are labeled as unclassified and leave the others untouched", cxxopts::value<bool>()->default_value("false"))
("s,skip", "Do not apply these classification labels (comma separated) and leave them as-is", cxxopts::value<std::vector<int>>())
("e,eval", "If the input point cloud is labeled, enable accuracy evaluation", cxxopts::value<bool>()->default_value("false"))
("eval-result", "Write evaluation results cloud to ply file", cxxopts::value<std::string>()->default_value(""))
("stats-file", "Write evaluation statistics to json file", cxxopts::value<std::string>()->default_value(""))
("h,help", "Print usage")
;
options.parse_positional({ "input", "output", "model" });
Expand Down Expand Up @@ -98,18 +100,29 @@ int main(int argc, char **argv) {
const auto features = getFeatures(computeScales(numScales, pointSet, startResolution, radius));
std::cout << "Features: " << features.size() << std::endl;

const auto eval = result["eval"].as<bool>();
const auto evalResult = result["eval-result"].as<std::string>();
const auto statsFile = result["stats-file"].as<std::string>();
const auto regRadius = result["reg-radius"].as<double>();
const auto color = result["color"].as<bool>();
const auto unclassified = result["unclassified"].as<bool>();

if (ctype == RandomForest) {
rf::classify(*pointSet, rtrees, features, labels, regularization,
result["reg-radius"].as<double>(), result["color"].as<bool>(), result["unclassified"].as<bool>(), result["eval"].as<bool>(), skip);
regRadius, color, unclassified, eval, skip, statsFile);
}
#ifdef WITH_GBT
else {
gbm::classify(*pointSet, booster, features, labels, regularization,
result["reg-radius"].as<double>(), result["color"].as<bool>(), result["unclassified"].as<bool>(), result["eval"].as<bool>(), skip);
regRadius, color, unclassified, eval, skip, statsFile);
}
#endif

savePointSet(*pointSet, outputFile);
if (eval && !evalResult.empty())
{
savePointSet(*pointSet, evalResult);
}

}
catch (std::exception &e) {
std::cerr << "Error: " << e.what() << std::endl;
Expand Down
21 changes: 16 additions & 5 deletions pctrain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ int main(int argc, char **argv) {
("m,max-samples", "Approximate maximum number of samples for each input point cloud", cxxopts::value<int>()->default_value("100000"))
("radius", "Radius size to use for neighbor search (meters)", cxxopts::value<double>()->default_value(MKSTR(RADIUS)))
("e,eval", "Labeled point cloud to use for model accuracy evaluation", cxxopts::value<std::string>()->default_value(""))
("eval-result", "Path where to store evaluation results (PLY)", cxxopts::value<std::string>()->default_value(""))
("stats", "Path where to store evaluation statistics (JSON)", cxxopts::value<std::string>()->default_value(""))
("c,classifier", "Which classifier type to use (rf = Random Forest, gbt = Gradient Boosted Trees)", cxxopts::value<std::string>()->default_value("rf"))
("classes", "Train only these classification classes (comma separated IDs)", cxxopts::value<std::vector<int>>())
("h,help", "Print usage")
Expand Down Expand Up @@ -54,6 +56,10 @@ int main(int argc, char **argv) {
const auto radius = result["radius"].as<double>();
const auto maxSamples = result["max-samples"].as<int>();
const auto classifier = result["classifier"].as<std::string>();
const auto evalResult = result["eval-result"].as<std::string>();
const auto statsFile = result["stats"].as<std::string>();
const auto evalFilename = result["eval"].as<std::string>();

std::vector<int> classes = {};
if (result.count("classes")) classes = result["classes"].as<std::vector<int>>();

Expand Down Expand Up @@ -84,8 +90,8 @@ int main(int argc, char **argv) {
}
#endif

if (result["eval"].count()) {
const std::string evalFilename = result["eval"].as<std::string>();
if (!evalFilename.empty()) {

std::cout << "Evaluating on " << evalFilename << " ..." << std::endl;

const ClassifierType ctype = fingerprint(modelFilename);
Expand Down Expand Up @@ -113,16 +119,21 @@ int main(int argc, char **argv) {
std::cout << "Features: " << evalFeatures.size() << std::endl;

if (ctype == RandomForest) {
rf::classify(*evalPointSet, rtrees, evalFeatures, labels, Regularization::None, 2.5, true, false, true);
rf::classify(*evalPointSet, rtrees, evalFeatures, labels, Regularization::None, 2.5,
true, false, true, {}, statsFile);
}

#ifdef WITH_GBT
else {
gbm::classify(*evalPointSet, booster, evalFeatures, labels, Regularization::None, 2.5, true, false, true);
gbm::classify(*evalPointSet, booster, evalFeatures, labels, Regularization::None, 2.5,
true, false, true, {}, statsFile);
}
#endif

savePointSet(*evalPointSet, "evaluation_results.ply");
if (!evalResult.empty()) {
savePointSet(*evalPointSet, evalResult);
}

}
}
catch (std::exception &e) {
Expand Down
6 changes: 4 additions & 2 deletions randomforest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ RandomForest *train(const std::vector<std::string> &filenames,
const double radius,
const int maxSamples,
const std::vector<int> &classes) {

ForestParams params;
params.n_trees = numTrees;
params.max_depth = treeDepth;
Expand Down Expand Up @@ -69,12 +70,13 @@ void classify(PointSet &pointSet,
const bool useColors,
const bool unclassifiedOnly,
const bool evaluate,
const std::vector<int> &skip) {
const std::vector<int> &skip,
const std::string &statsFile) {
classifyData<float>(pointSet,
[&rtrees](const float *ft, float *probs) {
rtrees->evaluate(ft, probs);
},
features, labels, regularization, regRadius, useColors, unclassifiedOnly, evaluate, skip);
features, labels, regularization, regRadius, useColors, unclassifiedOnly, evaluate, skip, statsFile);
}

}
5 changes: 3 additions & 2 deletions randomforest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ void classify(PointSet &pointSet,
const std::vector<Feature *> &features,
const std::vector<Label> &labels,
Regularization regularization = Regularization::None,
double regRadius = 2.5f,
double regRadius = 2.5,
bool useColors = false,
bool unclassifiedOnly = false,
bool evaluate = false,
const std::vector<int> &skip = {});
const std::vector<int> &skip = {},
const std::string &statsFile = "");

}
#endif
Loading

0 comments on commit f8e35c4

Please sign in to comment.