From f8e35c4684fdaf9816f95221d99b53446b2d55b4 Mon Sep 17 00:00:00 2001
From: Piero Toffanin <pt@masseranolabs.com>
Date: Wed, 12 Apr 2023 16:02:55 -0400
Subject: [PATCH] Added extended training statistics and exporting (2) (#18)

* Fixed indentation + grouped headers in generated solution

* Added statistics calculation

* Improved statistics printing

* Added stats.json generation

* Cleanup

* Fixed gcc compilation

* Fixed average calculation

* Reverted indentation

* Reverted linter advice

* Extracted logic and added new metrics

* Reverted linter advice

* Refactoring

* Sensitivity --> recall

* Simplify

---------

Co-authored-by: Luca Di Leo <admin@lucadileo.it>
---
 CMakeLists.txt   |   3 +-
 classifier.hpp   |  36 ++++----
 gbm.cpp          |   6 +-
 gbm.hpp          |   5 +-
 labels.hpp       |   2 +-
 pcclassify.cpp   |  19 ++++-
 pctrain.cpp      |  21 +++--
 randomforest.cpp |   6 +-
 randomforest.hpp |   5 +-
 statistics.hpp   | 215 +++++++++++++++++++++++++++++++++++++++++++++++
 10 files changed, 284 insertions(+), 34 deletions(-)
 create mode 100644 statistics.hpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ead4c86..18f8d8a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -85,6 +85,7 @@ include_directories(${EIGEN3_INCLUDE_DIR})
 include_directories(vendor/ethz)
 
 set(SOURCES classifier.cpp scale.cpp point_io.cpp randomforest.cpp features.cpp color.cpp labels.cpp)
+set(HEADERS classifier.hpp scale.hpp point_io.hpp randomforest.hpp features.hpp color.hpp labels.hpp statistics.hpp)
 if (WITH_GBT)
     list(APPEND SOURCES gbm.cpp)
     list(APPEND HEADERS gbm.hpp)
@@ -96,7 +97,7 @@ if (WITH_PDAL)
     set(PDAL_LIB ${PDAL_LIBRARIES})
 endif()
 
-add_library(libopc OBJECT ${SOURCES})
+add_library(libopc OBJECT ${SOURCES} ${HEADERS})
 
 if (WITH_GBT)
     add_dependencies(libopc lightgbm)
diff --git a/classifier.hpp b/classifier.hpp
index 6a28624..4b84686 100644
--- a/classifier.hpp
+++ b/classifier.hpp
@@ -3,11 +3,13 @@
 
 #include <vector>
 #include <random>
+#include <cmath>
 
 #include "features.hpp"
 #include "labels.hpp"
 #include "constants.hpp"
 #include "point_io.hpp"
+#include "statistics.hpp"
 
 enum Regularization { None, LocalSmooth };
 Regularization parseRegularization(const std::string &regularization);
@@ -15,12 +17,13 @@ Regularization parseRegularization(const std::string &regularization);
 enum ClassifierType { RandomForest, GradientBoostedTrees };
 ClassifierType fingerprint(const std::string &modelFile);
 
+
 template <typename F, typename I>
 void getTrainingData(const std::vector<std::string> &filenames,
     double *startResolution,
-    int numScales,
-    double radius,
-    int maxSamples,
+    const int numScales,
+    const double radius,
+    const int maxSamples,
     const std::vector<int> &asprsClasses,
     F storeFeatures,
     I init) {
@@ -112,12 +115,14 @@ void classifyData(PointSet &pointSet,
     F evaluateFunc,
     const std::vector<Feature *> &features,
     const std::vector<Label> &labels,
-    Regularization regularization,
-    double regRadius,
-    bool useColors,
-    bool unclassifiedOnly,
-    bool evaluate,
-    const std::vector<int> &skip) {
+    const Regularization regularization,
+    const double regRadius,
+    const bool useColors,
+    const bool unclassifiedOnly,
+    const bool evaluate,
+    const std::vector<int> &skip,
+    const std::string &statsFile) {
+
     std::cout << "Classifying..." << std::endl;
     pointSet.base->labels.resize(pointSet.base->count());
 
@@ -214,7 +219,6 @@ void classifyData(PointSet &pointSet,
         throw std::runtime_error("Invalid regularization");
     }
 
-    std::size_t correct = 0;
     if (!useColors && !pointSet.hasLabels()) pointSet.labels.resize(pointSet.count());
     std::vector<bool> skipMap(255, false);
     for (size_t i = 0; i < skip.size(); i++) {
@@ -224,6 +228,8 @@ void classifyData(PointSet &pointSet,
 
     auto train2asprsCodes = getTrain2AsprsCodes();
 
+    Statistics stats(labels);
+
     #pragma omp parallel for
     for (long long int i = 0; i < pointSet.count(); i++) {
         const size_t idx = pointSet.pointMap[i];
@@ -232,10 +238,7 @@ void classifyData(PointSet &pointSet,
         auto label = labels[bestClass];
 
         if (evaluate) {
-            if (pointSet.labels[i] == bestClass) {
-                #pragma omp atomic
-                correct++;
-            }
+            stats.record(bestClass, pointSet.labels[i]);
         }
 
         bool update = true;
@@ -266,8 +269,9 @@ void classifyData(PointSet &pointSet,
     }
 
     if (evaluate) {
-        float modelErr = (1.f - static_cast<float>(correct) / static_cast<float>(pointSet.count()));
-        std::cout << "Model error: " << std::setprecision(4) << (modelErr * 100.f) << "%" << std::endl;
+        stats.finalize();
+        stats.print();
+        if (!statsFile.empty()) stats.writeToFile(statsFile);
     }
 }
 
diff --git a/gbm.cpp b/gbm.cpp
index f0ac9cc..508113e 100644
--- a/gbm.cpp
+++ b/gbm.cpp
@@ -170,7 +170,9 @@ void classify(PointSet &pointSet,
     const bool useColors,
     const bool unclassifiedOnly,
     const bool evaluate,
-    const std::vector<int> &skip) {
+    const std::vector<int> &skip,
+    const std::string &statsFile
+) {
 
     LightGBM::PredictionEarlyStopConfig early_stop_config;
     auto earlyStop = LightGBM::CreatePredictionEarlyStopInstance("none", early_stop_config);
@@ -179,7 +181,7 @@ void classify(PointSet &pointSet,
         [&booster, &earlyStop](const double *ft, double *probs) {
             booster->Predict(ft, probs, &earlyStop);
         },
-        features, labels, regularization, regRadius, useColors, unclassifiedOnly, evaluate, skip);
+        features, labels, regularization, regRadius, useColors, unclassifiedOnly, evaluate, skip, statsFile);
 }
 
 }
diff --git a/gbm.hpp b/gbm.hpp
index 3948067..5abcbbf 100644
--- a/gbm.hpp
+++ b/gbm.hpp
@@ -47,11 +47,12 @@ void classify(PointSet &pointSet,
     const std::vector<Feature *> &features,
     const std::vector<Label> &labels,
     Regularization regularization = Regularization::None,
-    double regRadius = 2.5f,
+    double regRadius = 2.5,
     bool useColors = false,
     bool unclassifiedOnly = false,
     bool evaluate = false,
-    const std::vector<int> &skip = {});
+    const std::vector<int> &skip = {},
+    const std::string &statsFile = "");
 
 }
 
diff --git a/labels.hpp b/labels.hpp
index e5886de..3f82e9e 100644
--- a/labels.hpp
+++ b/labels.hpp
@@ -17,7 +17,7 @@ class Label {
     Color color;
 public:
     Label(const std::string &name, int trainingCode, int asprsCode, Color color = Color())
-        : name(name), trainingCode(trainingCode), asprsCode(asprsCode), color(color) {};
+        : name(name), trainingCode(trainingCode), asprsCode(asprsCode), color(color) {}
 
     std::string getName() const { return name; }
     int getTrainingCode() const { return trainingCode; }
diff --git a/pcclassify.cpp b/pcclassify.cpp
index ba0b8b2..954e4c8 100644
--- a/pcclassify.cpp
+++ b/pcclassify.cpp
@@ -21,6 +21,8 @@ int main(int argc, char **argv) {
         ("u,unclassified", "Only classify points that are labeled as unclassified and leave the others untouched", cxxopts::value<bool>()->default_value("false"))
         ("s,skip", "Do not apply these classification labels (comma separated) and leave them as-is", cxxopts::value<std::vector<int>>())
         ("e,eval", "If the input point cloud is labeled, enable accuracy evaluation", cxxopts::value<bool>()->default_value("false"))
+        ("eval-result", "Write evaluation results cloud to ply file", cxxopts::value<std::string>()->default_value(""))
+        ("stats-file", "Write evaluation statistics to json file", cxxopts::value<std::string>()->default_value(""))
         ("h,help", "Print usage")
         ;
     options.parse_positional({ "input", "output", "model" });
@@ -98,18 +100,29 @@ int main(int argc, char **argv) {
         const auto features = getFeatures(computeScales(numScales, pointSet, startResolution, radius));
         std::cout << "Features: " << features.size() << std::endl;
 
+        const auto eval = result["eval"].as<bool>();
+        const auto evalResult = result["eval-result"].as<std::string>();
+        const auto statsFile = result["stats-file"].as<std::string>();
+        const auto regRadius = result["reg-radius"].as<double>();
+        const auto color = result["color"].as<bool>();
+        const auto unclassified = result["unclassified"].as<bool>();
+
         if (ctype == RandomForest) {
             rf::classify(*pointSet, rtrees, features, labels, regularization,
-                result["reg-radius"].as<double>(), result["color"].as<bool>(), result["unclassified"].as<bool>(), result["eval"].as<bool>(), skip);
+                regRadius, color, unclassified, eval, skip, statsFile);
         }
         #ifdef WITH_GBT
         else {
             gbm::classify(*pointSet, booster, features, labels, regularization,
-                result["reg-radius"].as<double>(), result["color"].as<bool>(), result["unclassified"].as<bool>(), result["eval"].as<bool>(), skip);
+                regRadius, color, unclassified, eval, skip, statsFile);
         }
         #endif
 
-        savePointSet(*pointSet, outputFile);
+        if (eval && !evalResult.empty())
+        {
+            savePointSet(*pointSet, evalResult);
+        }
+        
     }
     catch (std::exception &e) {
         std::cerr << "Error: " << e.what() << std::endl;
diff --git a/pctrain.cpp b/pctrain.cpp
index 7ab7f4c..8f38c31 100644
--- a/pctrain.cpp
+++ b/pctrain.cpp
@@ -22,6 +22,8 @@ int main(int argc, char **argv) {
         ("m,max-samples", "Approximate maximum number of samples for each input point cloud", cxxopts::value<int>()->default_value("100000"))
         ("radius", "Radius size to use for neighbor search (meters)", cxxopts::value<double>()->default_value(MKSTR(RADIUS)))
         ("e,eval", "Labeled point cloud to use for model accuracy evaluation", cxxopts::value<std::string>()->default_value(""))
+        ("eval-result", "Path where to store evaluation results (PLY)", cxxopts::value<std::string>()->default_value(""))
+        ("stats", "Path where to store evaluation statistics (JSON)", cxxopts::value<std::string>()->default_value(""))
         ("c,classifier", "Which classifier type to use (rf = Random Forest, gbt = Gradient Boosted Trees)", cxxopts::value<std::string>()->default_value("rf"))
         ("classes", "Train only these classification classes (comma separated IDs)", cxxopts::value<std::vector<int>>())
         ("h,help", "Print usage")
@@ -54,6 +56,10 @@ int main(int argc, char **argv) {
         const auto radius = result["radius"].as<double>();
         const auto maxSamples = result["max-samples"].as<int>();
         const auto classifier = result["classifier"].as<std::string>();
+        const auto evalResult = result["eval-result"].as<std::string>();
+        const auto statsFile = result["stats"].as<std::string>();
+        const auto evalFilename = result["eval"].as<std::string>();
+
         std::vector<int> classes = {};
         if (result.count("classes")) classes = result["classes"].as<std::vector<int>>();
 
@@ -84,8 +90,8 @@ int main(int argc, char **argv) {
         }
         #endif
 
-        if (result["eval"].count()) {
-            const std::string evalFilename = result["eval"].as<std::string>();
+        if (!evalFilename.empty()) {
+            
             std::cout << "Evaluating on " << evalFilename << " ..." << std::endl;
 
             const ClassifierType ctype = fingerprint(modelFilename);
@@ -113,16 +119,21 @@ int main(int argc, char **argv) {
             std::cout << "Features: " << evalFeatures.size() << std::endl;
 
             if (ctype == RandomForest) {
-                rf::classify(*evalPointSet, rtrees, evalFeatures, labels, Regularization::None, 2.5, true, false, true);
+                rf::classify(*evalPointSet, rtrees, evalFeatures, labels, Regularization::None, 2.5, 
+                    true, false, true, {}, statsFile);
             }
 
             #ifdef WITH_GBT
             else {
-                gbm::classify(*evalPointSet, booster, evalFeatures, labels, Regularization::None, 2.5, true, false, true);
+                gbm::classify(*evalPointSet, booster, evalFeatures, labels, Regularization::None, 2.5, 
+                    true, false, true, {}, statsFile);
             }
             #endif
 
-            savePointSet(*evalPointSet, "evaluation_results.ply");
+            if (!evalResult.empty()) {
+                savePointSet(*evalPointSet, evalResult);
+            }
+
         }
     }
     catch (std::exception &e) {
diff --git a/randomforest.cpp b/randomforest.cpp
index cc3d9f0..b190f38 100644
--- a/randomforest.cpp
+++ b/randomforest.cpp
@@ -10,6 +10,7 @@ RandomForest *train(const std::vector<std::string> &filenames,
     const double radius,
     const int maxSamples,
     const std::vector<int> &classes) {
+
     ForestParams params;
     params.n_trees = numTrees;
     params.max_depth = treeDepth;
@@ -69,12 +70,13 @@ void classify(PointSet &pointSet,
     const bool useColors,
     const bool unclassifiedOnly,
     const bool evaluate,
-    const std::vector<int> &skip) {
+    const std::vector<int> &skip,
+    const std::string &statsFile) {
     classifyData<float>(pointSet,
         [&rtrees](const float *ft, float *probs) {
             rtrees->evaluate(ft, probs);
         },
-        features, labels, regularization, regRadius, useColors, unclassifiedOnly, evaluate, skip);
+        features, labels, regularization, regRadius, useColors, unclassifiedOnly, evaluate, skip, statsFile);
 }
 
 }
diff --git a/randomforest.hpp b/randomforest.hpp
index f2c9728..565246d 100644
--- a/randomforest.hpp
+++ b/randomforest.hpp
@@ -34,11 +34,12 @@ void classify(PointSet &pointSet,
     const std::vector<Feature *> &features,
     const std::vector<Label> &labels,
     Regularization regularization = Regularization::None,
-    double regRadius = 2.5f,
+    double regRadius = 2.5,
     bool useColors = false,
     bool unclassifiedOnly = false,
     bool evaluate = false,
-    const std::vector<int> &skip = {});
+    const std::vector<int> &skip = {},
+    const std::string &statsFile = "");
 
 }
 #endif
diff --git a/statistics.hpp b/statistics.hpp
new file mode 100644
index 0000000..27ed158
--- /dev/null
+++ b/statistics.hpp
@@ -0,0 +1,215 @@
+
+#ifndef STATISTICS_H
+#define STATISTICS_H
+
+#include <vector>
+#include <map>
+
+class Statistics{
+    struct Counts{
+        size_t tp = 0; // True positives
+        size_t fp = 0; // False positives
+        size_t fn = 0; // False negatives
+    };
+
+    struct LabelStat{
+        std::string name;
+        double accuracy;
+        double recall;
+        double precision;
+        double f1;
+        Counts counts;
+
+        LabelStat(const std::string &name, const double accuracy, const double recall, const double precision, const double f1, const Counts &counts) : 
+            name(name), accuracy(accuracy), recall(recall), precision(precision), f1(f1), counts(counts) {}
+    };
+
+    std::map<int, Counts> stats;
+    const std::vector<Label> &labels;
+    size_t totalSamples = 0;
+
+    double accuracy;
+    double avgAccuracy;
+    double avgRecall;
+    double avgPrecision;
+    double avgF1;
+
+    std::vector<LabelStat> labelStats;
+public:
+    Statistics(const std::vector<Label> &labels) : labels(labels){
+        for (auto &label : labels){
+            stats[label.getTrainingCode()] = Counts();
+        }
+    }
+
+    inline void record(const int predicted, const int truth){
+        if (predicted == truth){
+            #pragma omp atomic
+            stats[predicted].tp++;
+        }else{
+            #pragma omp atomic
+            stats[predicted].fp++;
+
+            #pragma omp atomic
+            stats[truth].fn++;
+        }
+
+        #pragma omp atomic
+        totalSamples++;
+    }
+
+    void finalize(){
+        const auto cnt = labels.size();
+
+        auto accuracyCount = 0;
+        auto recallCount = 0;
+        auto precisionCount = 0;
+        auto f1Count = 0;
+
+        size_t sumTp = 0;
+
+        double sumAccuracy = 0.0;
+        double sumF1 = 0.0;
+        double sumRecall = 0.0;
+        double sumPrecision = 0.0;
+
+        for (size_t i = 0; i < cnt; ++i){
+            auto label = labels[i];
+            const Counts &cnts = stats[label.getTrainingCode()];
+            sumTp += cnts.tp;
+
+            const auto tp = static_cast<double>(cnts.tp);
+
+            const double accuracy = tp / (cnts.tp + cnts.fn + cnts.fp);
+
+            if (!std::isnan(accuracy)) {
+                sumAccuracy += accuracy;
+                accuracyCount++;
+            }
+
+            const double precision = tp / (cnts.tp + cnts.fp);
+
+            if (!std::isnan(precision)) {
+                sumPrecision += precision;
+                precisionCount++;
+            }
+
+            const double recall = tp / (cnts.tp + cnts.fn);
+
+            if (!std::isnan(recall)) {
+                sumRecall += recall;
+                recallCount++;
+            }
+
+            const double f1 = 2 * (precision * recall) / (precision + recall);
+
+            if (!std::isnan(f1)) {
+                sumF1 += f1;
+                f1Count++;
+            }
+
+            labelStats.emplace_back(label.getName(), accuracy, recall, precision, f1, cnts);
+        }
+
+        accuracy = static_cast<double>(sumTp) / totalSamples;
+
+        avgAccuracy = sumAccuracy / accuracyCount;
+        avgF1 = sumF1 / f1Count;
+        avgPrecision = sumPrecision / precisionCount;
+        avgRecall = sumRecall / recallCount;
+    }
+
+    void print() const{
+        std::cout << "Statistics:" << std::endl;
+        std::cout << "  Accuracy: " << std::fixed << std::setprecision(2) << accuracy * 100 << "%" << std::endl << std::endl;
+
+        std::cout << "  " << std::setw(24) << "Label " << " | " << std::setw(10) << "Accuracy" << " | " << std::setw(11) << "Recall" << " | " << std::setw(10) << "Precision" << " | " << std::setw(10) << "F1" << " | "  << std::endl;
+        std::cout << "  " << std::setw(24) << std::string(24, '-') << " | ";
+        std::cout << std::setw(10) << std::string(10, '-') << " | ";
+        std::cout << std::setw(10) << std::string(11, '-') << " | ";
+        std::cout << std::setw(10) << std::string(10, '-') << " | ";
+        std::cout << std::setw(10) << std::string(10, '-') << " | "  << std::endl;
+
+        for (const auto &label : labelStats){
+
+            if (std::isnan(label.accuracy) && std::isnan(label.f1)) continue;
+
+            std::cout << "  " << std::setw(24) << label.name << " | ";
+
+            if (!std::isnan(label.accuracy))
+                std::cout << std::setw(9) << std::fixed << std::setprecision(2) << label.accuracy * 100 << "% | ";
+            else
+                std::cout << std::setw(9) << "N/A" << " | ";
+
+            if (!std::isnan(label.recall))
+                std::cout << std::setw(10) << std::fixed << std::setprecision(2) << label.recall * 100 << "% | ";
+            else
+                std::cout << std::setw(10) << "N/A" << " | ";
+
+            if (!std::isnan(label.precision))
+                std::cout << std::setw(9) << std::fixed << std::setprecision(2) << label.precision * 100 << "% | ";
+            else
+                std::cout << std::setw(9) << "N/A" << " | ";
+            
+            if (!std::isnan(label.f1))
+                std::cout << std::setw(10) << std::fixed << std::setprecision(2) << label.f1 << " | ";
+            else
+                std::cout << std::setw(10) << "N/A" << " | ";
+
+            std::cout << std::endl;
+        }
+
+        std::cout << "  " << std::setw(24) << "(Average)" << " | ";
+        std::cout << std::setw(9) << std::fixed << std::setprecision(2) << avgAccuracy * 100 << "% | ";
+        std::cout << std::setw(10) << std::fixed << std::setprecision(2) << avgRecall * 100 << "% | ";
+        std::cout << std::setw(9) << std::fixed << std::setprecision(2) << avgPrecision * 100 << "% | ";
+        std::cout << std::setw(10) << std::fixed << std::setprecision(2) << avgF1 << " | " << std::endl;
+
+        std::cout << std::endl;
+    }
+
+
+    void writeToFile(const std::string &jsonFile){
+        std::ofstream o(jsonFile);
+        if (!o.is_open()){
+            std::cerr << "Unable to create stats file" << std::endl;
+            return;
+        }
+
+        json j = json{
+            {"accuracy", accuracy}
+        };
+
+        for (const auto &label : labelStats){
+            if (std::isnan(label.accuracy) && std::isnan(label.f1)) continue;
+            const auto name = label.name;
+
+            if (!std::isnan(label.accuracy))
+                j["labels"][name]["accuracy"] = label.accuracy;
+            else
+                j["labels"][name]["accuracy"] = nullptr;
+
+            if (!std::isnan(label.precision))
+                j["labels"][name]["precision"] = label.precision;
+            else
+                j["labels"][name]["precision"] = nullptr;
+
+            if (!std::isnan(label.recall))
+                j["labels"][name]["recall"] = label.recall;
+            else
+                j["labels"][name]["recall"] = nullptr;
+
+            if (!std::isnan(label.f1))
+                j["labels"][name]["f1"] = label.f1;
+            else
+                j["labels"][name]["f1"] = nullptr;
+        }
+
+        o << j.dump(4);
+        o.close();
+        std::cout << "Statistics saved to " << jsonFile << std::endl;
+    }
+};
+
+
+#endif