From 3b6f8a7fe9014cc46e1666ec21b24aa2abb6ce79 Mon Sep 17 00:00:00 2001 From: Felix Schlepper Date: Tue, 28 Feb 2023 11:41:28 +0100 Subject: [PATCH] TRD: PID: fixups This commit renames the factory function to better reflect its purpose. Additionally, z-row merging and charge correction have been added to the codebase to improve functionality. Pytorch policy and LQND policy have been added as new policies. The README has also been added to provide additional explanation of the code. As part of this update, the pid policy map and pidvalue alias have been removed. A print overload has been added for the policy enum to improve readability. Various minor fixes have also been made to improve overall code quality. Also included are various changes to the class layout and the ccdb object for LUTs. Signed-off-by: Felix Schlepper --- .../TRD/include/DataFormatsTRD/PID.h | 49 ++++-- Detectors/TRD/pid/CMakeLists.txt | 2 + Detectors/TRD/pid/README.md | 63 +++++++ Detectors/TRD/pid/include/TRDPID/Dummy.h | 4 +- Detectors/TRD/pid/include/TRDPID/LQND.h | 159 ++++++++++++++++++ Detectors/TRD/pid/include/TRDPID/ML.h | 52 ++++-- Detectors/TRD/pid/include/TRDPID/PIDBase.h | 28 ++- .../TRD/pid/include/TRDPID/PIDParameters.h | 2 + Detectors/TRD/pid/macros/CMakeLists.txt | 6 +- Detectors/TRD/pid/macros/ccdbModelUpload.C | 82 --------- Detectors/TRD/pid/macros/ccdbPIDUpload.C | 112 ++++++++++++ Detectors/TRD/pid/macros/makeTestLUTs.C | 40 +++++ Detectors/TRD/pid/src/ML.cxx | 124 +++++--------- Detectors/TRD/pid/src/PIDBase.cxx | 57 +++++-- Detectors/TRD/pid/src/TRDPIDLinkDef.h | 5 + Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h | 5 + .../workflow/src/TRDGlobalTrackingSpec.cxx | 21 ++- 17 files changed, 598 insertions(+), 213 deletions(-) create mode 100644 Detectors/TRD/pid/README.md create mode 100644 Detectors/TRD/pid/include/TRDPID/LQND.h delete mode 100644 Detectors/TRD/pid/macros/ccdbModelUpload.C create mode 100644 Detectors/TRD/pid/macros/ccdbPIDUpload.C create mode 100644 Detectors/TRD/pid/macros/makeTestLUTs.C diff --git a/DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h b/DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h index cbb62d8f1e164..5687b58c6ca1f 100644 --- a/DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h +++ b/DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h @@ -18,6 +18,7 @@ #include #include #include +#include namespace o2 { @@ -30,44 +31,64 @@ enum class PIDPolicy : unsigned int { LQ1D = 0, ///< 1-Dimensional Likelihood model LQ3D, ///< 3-Dimensional Likelihood model +#ifdef TRDPID_WITH_ONNX // ML models XGB, ///< XGBOOST + PY, ///< Pytorch +#endif // Do not add anything after this! NMODELS, ///< Count of all models - Test, ///< Load object for testing Dummy, ///< Dummy object outputting -1.f DEFAULT = Dummy, ///< The default option }; +inline std::ostream& operator<<(std::ostream& os, const PIDPolicy& policy) +{ + std::string name; + switch (policy) { + case PIDPolicy::LQ1D: + name = "LQ1D"; + break; + case PIDPolicy::LQ3D: + name = "LQ3D"; + break; +#ifdef TRDPID_WITH_ONNX + case PIDPolicy::XGB: + name = "XGBoost"; + break; + case PIDPolicy::PY: + name = "PyTorch"; + break; +#endif + case PIDPolicy::Dummy: + name = "Dummy"; + break; + default: + name = "Default"; + } + os << name; + return os; +} + /// Transform PID policy from string to enum. static const std::unordered_map PIDPolicyString{ // Classical Algorithms {"LQ1D", PIDPolicy::LQ1D}, {"LQ3D", PIDPolicy::LQ3D}, +#ifdef TRDPID_WITH_ONNX // ML models {"XGB", PIDPolicy::XGB}, + {"PY", PIDPolicy::PY}, +#endif // General - {"TEST", PIDPolicy::Test}, {"DUMMY", PIDPolicy::Dummy}, // Default {"default", PIDPolicy::DEFAULT}, }; -/// Transform PID policy from string to enum. -static const char* PIDPolicyEnum[] = { - "LQ1D", - "LQ3D", - "XGBoost", - "NMODELS", - "Test", - "Dummy", - "default(=TODO)"}; - -using PIDValue = float; - } // namespace trd } // namespace o2 diff --git a/Detectors/TRD/pid/CMakeLists.txt b/Detectors/TRD/pid/CMakeLists.txt index a3d5093c320fd..158103cf007c2 100644 --- a/Detectors/TRD/pid/CMakeLists.txt +++ b/Detectors/TRD/pid/CMakeLists.txt @@ -31,12 +31,14 @@ if(ONNXRuntime_FOUND) HEADERS include/TRDPID/PIDBase.h include/TRDPID/PIDParameters.h include/TRDPID/ML.h + include/TRDPID/LQND.h include/TRDPID/Dummy.h) else() o2_target_root_dictionary(TRDPID HEADERS include/TRDPID/PIDBase.h include/TRDPID/PIDParameters.h include/TRDPID/Dummy.h + include/TRDPID/LQND.h LINKDEF src/TRDPIDNoMLLinkDef.h) endif() diff --git a/Detectors/TRD/pid/README.md b/Detectors/TRD/pid/README.md new file mode 100644 index 0000000000000..a38add5220b69 --- /dev/null +++ b/Detectors/TRD/pid/README.md @@ -0,0 +1,63 @@ +# Particle Identification with TRD +## Usage +Activate PID during tracking with the '--with-pid' flag. + + o2-trd-global-tracking --with-pid --policy ML + +Specify a which algorithm (called policy) should be use. +Implemented are the following: + + - LQ1D + - LQ3D + - ML (every model, which is exported to the ONNX format): + - XGB (XGBoost model) + - NN (Pytorch model) + - Dummy (returns only -1) + - Test (one of the above) + - Default (one of the above, gets picked if '--policy' is unspecified) + +## Implementation details +### Tracking workflow +Every TRDTrack gets a PID value set (mSignal), which then gets propergated to the AO2D writer. + +### Basic Interface +The base interface for PID is defined in [here](include/TRDPID/PIDBase.h). +The 'init' function is such that each policy can specify what if anything it needs from the CCDB. +For the 'process' each policy defines how a TRDTrack gets assigned a PID value. +Additionally, the base class implements how to get the _corrected charges_ from the tracklets. +_Corrected charges_ means z-row merged and calibrated charges. + +### Classical Likelihood +The classical LQND policies ([here](include/TRDPID/LQND.h)) require an array of lookup tables (LUTs) from the ccdb. +$N$ stands for the dimension. +Then the electron likelihood for layer $i$ is defined as this: + +$$L_i(e|Q_i)=\frac{P(Q_i|e)}{P(Q_i|e)+P(Q_i|\pi)}$$ + +From the charge $Q_i$ the LUTs give the corresponding $L_i$. +The _combined electron likelihood_ is obtained by this formula: + +$$L(e|Q)=\frac{\prod_i L_i(e|Q_i)}{\prod_i L_i(e|Q_i) + \prod_i L_i(\pi|Q_i)}$$ + +where $L_i(\pi|Q_i)=1-L_i(e|Q_i)$. + + +Extension to higher dimensions is easy each tracklet has charges $Q_j$ which cover the integral of the pulse height spectrum in different slice ($j\in [0,1,2]$). +In our case $Q0$ covers the pulse height peak, $Q1$ the Transition Radiation peak and $Q2$ the plateau. +For each charge $j$ a LUT is available which gives the likelihood $L^e_j$. +For each layer $i$ the likelihood is then: + +$$L_i(e|Q)=\frac{\prod_j L_{i,j}(e|Q_j)}{\prod_j L_{i,j}(e|Q_j) + \prod_j L_{i,j}(\pi|Q_j)}$$ + +The combined electron likelihood is then: + +$$L(e|Q)=\frac{\prod_{i,j} L_{i,j}(e|Q_j)}{\prod_{i,j} L_{i,j}(e|Q_j) + \prod_{i,j} L_{i,j}(\pi|Q_j)}$$ + + +### Machine Learning +The ML policies ([here](include/TRDPID/ML.h)) are uploaded to the CCDB in the ONNX file format (most python machine learning libraries support this standardized format). +In O2 we leverage the ONNXRuntime to use these formats and calculate a PID value. +The models can thus be trained in python which is very convenient. +The code should take care of most of the annoying stuff. +Policies just have to specify how to get the electron likelihood from the ONNXRuntime output (each python library varies in that somewhat). +The 'prepareModelInput' prepares the TRDTracks as input. diff --git a/Detectors/TRD/pid/include/TRDPID/Dummy.h b/Detectors/TRD/pid/include/TRDPID/Dummy.h index 3158ad715fc61..4203f4c56953a 100644 --- a/Detectors/TRD/pid/include/TRDPID/Dummy.h +++ b/Detectors/TRD/pid/include/TRDPID/Dummy.h @@ -34,13 +34,13 @@ class Dummy final : public PIDBase using PIDBase::PIDBase; public: - ~Dummy() final = default; + ~Dummy() = default; /// Do absolutely nothing. void init(o2::framework::ProcessingContext& pc) final{}; /// Everything below 0.f indicates nothing available. - PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) final + float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final { return -1.f; }; diff --git a/Detectors/TRD/pid/include/TRDPID/LQND.h b/Detectors/TRD/pid/include/TRDPID/LQND.h new file mode 100644 index 0000000000000..e4be48f93d1f6 --- /dev/null +++ b/Detectors/TRD/pid/include/TRDPID/LQND.h @@ -0,0 +1,159 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +/// \file LQND.h +/// \brief This file provides the interface for loglikehood policies +/// \author Felix Schlepper + +#ifndef O2_TRD_LQND_H +#define O2_TRD_LQND_H + +#include "TGraph.h" +#include "TRDPID/PIDBase.h" +#include "DataFormatsTRD/PID.h" +#include "DataFormatsTRD/Constants.h" +#include "Framework/ProcessingContext.h" +#include "Framework/InputRecord.h" +#include "DataFormatsTRD/CalibratedTracklet.h" +#include "DetectorsBase/Propagator.h" +#include "Framework/Logger.h" + +#include +#include +#include +#include +#include + +namespace o2 +{ +namespace trd +{ +namespace detail +{ +/// Lookup Table class for ccdb upload +template +class LUT +{ + public: + LUT() = default; + LUT(std::vector p, std::vector l) : mIntervalsP(p), mLUTs(l) {} + + // + const TGraph& get(float p, bool isNegative, int iDim = 0) const + { + auto upper = std::upper_bound(mIntervalsP.begin(), mIntervalsP.end(), p); + if (upper == mIntervalsP.end()) { + // outside of momentum intervals, should not happen + return mLUTs[0]; + } + auto index = std::distance(mIntervalsP.begin(), upper); + index += (isNegative) ? 0 : mIntervalsP.size() * nDim; + if constexpr (nDim == 1) { + return mLUTs[index]; + } else { + if (iDim == 0) { + return mLUTs[index + 0]; + } else if (iDim == 1) { + return mLUTs[index + 1]; + } else { + return mLUTs[index + 2]; + } + } + } + + private: + std::vector mIntervalsP; ///< half-open interval upper bounds starting at 0, e.g., for {1.0,2.0,...} is (-inf,1.0], (1.0,2.0], (2.0, ...) + std::vector mLUTs; ///< corresponding likelihood lookup tables + + ClassDefNV(LUT, 1); +}; +} // namespace detail + +/// This is the ML Base class which defines the interface all machine learning +/// models. +template +class LQND : public PIDBase +{ + static_assert(nDim == 1 || nDim == 3, "Likelihood only for 1/3 dimension"); + using PIDBase::PIDBase; + + public: + ~LQND() = default; + + void init(o2::framework::ProcessingContext& pc) final + { + // retrieve lookup tables (LUTs) from ccdb + mLUTs = *(pc.inputs().get*>(Form("lq%ddlut", nDim))); + } + + float process(const TrackTRD& trkIn, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final + { + const auto& trkSeed = isTPCTRD ? input.getTPCTracks()[trkIn.getRefGlobalTrackId()].getParamOut() : input.getTPCITSTracks()[trkIn.getRefGlobalTrackId()].getParamOut(); // seeding track + auto trk = trkSeed; + + const auto isNegative = std::signbit(trkSeed.getSign()); // positive and negative charged particles are treated differently since ExB effects the charge distributions + const auto& trackletsRaw = input.getTRDTracklets(); + float lei0{1.f}, lei1{1.f}, lei2{1.f}, lpi0{1.f}, lpi1{1.f}, lpi2{1.f}; // likelihood per layer + for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) { + int trkltId = trkIn.getTrackletIndex(iLayer); + if (trkltId < 0) { // no tracklet attached + continue; + } + const auto xCalib = input.getTRDCalibratedTracklets()[trkIn.getTrackletIndex(iLayer)].getX(); + if (!o2::base::Propagator::Instance()->PropagateToXBxByBz(trk, xCalib, o2::base::Propagator::MAX_SIN_PHI, o2::base::Propagator::MAX_STEP, o2::base::Propagator::MatCorrType::USEMatCorrNONE)) { + LOGF(debug, "Track propagation failed in layer %i (pt=%f, xTrk=%f, xToGo=%f)", iLayer, trk.getPt(), trk.getX(), xCalib); + continue; + } + const auto snp = trk.getSnp(); + const auto tgl = trk.getTgl(); + const auto& trklt = trackletsRaw[trkltId]; + const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkIn, input, snp, tgl); // correct charges + if constexpr (nDim == 1) { + auto lut = mLUTs.get(trk.getP(), isNegative); + auto ll1{1.f}; + ll1 = lut.Eval(q0 + q1 + q2); + lei0 *= ll1; + lpi0 *= (1.f - ll1); + } else { + auto lut1 = mLUTs.get(trk.getP(), isNegative, 0); + auto lut2 = mLUTs.get(trk.getP(), isNegative, 1); + auto lut3 = mLUTs.get(trk.getP(), isNegative, 2); + auto ll1{1.f}; + auto ll2{1.f}; + auto ll3{1.f}; + ll1 = lut1.Eval(q0); + ll2 = lut2.Eval(q1); + ll3 = lut3.Eval(q2); + lei0 *= ll1; + lei1 *= ll2; + lei2 *= ll3; + lpi0 *= (1.f - ll1); + lpi1 *= (1.f - ll2); + lpi2 *= (1.f - ll3); + } + } + + return (lei0 * lei1 * lei2) / (lei0 * lei1 * lei2 + lpi0 * lpi1 * lpi2); // combined likelihood + } + + private: + detail::LUT mLUTs; ///< likelihood lookup tables + + ClassDefNV(LQND, 1); +}; + +using LQ1D = LQND<1>; +using LQ3D = LQND<3>; + +} // namespace trd +} // namespace o2 + +#endif diff --git a/Detectors/TRD/pid/include/TRDPID/ML.h b/Detectors/TRD/pid/include/TRDPID/ML.h index 59d902397dd4f..0b8d8ea8ea97b 100644 --- a/Detectors/TRD/pid/include/TRDPID/ML.h +++ b/Detectors/TRD/pid/include/TRDPID/ML.h @@ -40,31 +40,31 @@ class ML : public PIDBase public: void init(o2::framework::ProcessingContext& pc) final; - PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) final; + float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final; private: /// Return the electron likelihood. /// Different models have different ways to return the probability. - virtual PIDValue getELikelihood(const std::vector& tensorData) const noexcept = 0; + virtual inline float getELikelihood(const std::vector& tensorData) const noexcept = 0; /// Fetch a ML model from the ccdb via its binding - std::string fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const; - - /// Calculate pid value - template - PIDValue calculate(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks); + std::string fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const noexcept; /// Prepare model input /// Collect track properties in vector as flat array - template - std::vector prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks); + std::vector prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) const noexcept; /// Pretty print model shape std::string printShape(const std::vector& v) const noexcept; + /// Get DPL name + virtual inline std::string getName() const noexcept = 0; + // ONNX runtime Ort::Env mEnv{ORT_LOGGING_LEVEL_WARNING, "TRD-PID", - // integrate ORT logging into Fairlogger + // Integrate ORT logging into Fairlogger this way we can have + // all the nice logging while taking advantage of ORT telling us + // what to do. [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) { LOG(warn) << "Ort " << severity << ": [" << logid << "|" << category << "|" << code_location << "]: " << message << ((intptr_t)param == 3 ? " [valid]" : " [error]"); }, @@ -79,7 +79,7 @@ class ML : public PIDBase std::vector mOutputNames; ///< model output names std::vector> mOutputShapes; ///< output shape - ClassDefNV(ML, 1); + ClassDefOverride(ML, 1); }; /// XGBoost Model @@ -88,14 +88,40 @@ class XGB final : public ML using ML::ML; public: - ~XGB() final = default; + ~XGB() = default; private: - PIDValue getELikelihood(const std::vector& tensorData) const noexcept final; + /// XGBoost export is like this: + /// (label|eprob, 1-eprob). + inline float getELikelihood(const std::vector& tensorData) const noexcept + { + return tensorData[1].GetTensorData()[1]; + } + + inline std::string getName() const noexcept { return "xgb"; } ClassDefNV(XGB, 1); }; +/// PyTorch Model +class PY final : public ML +{ + using ML::ML; + + public: + ~PY() = default; + + private: + inline float getELikelihood(const std::vector& tensorData) const noexcept + { + return tensorData[0].GetTensorData()[0]; + } + + inline std::string getName() const noexcept { return "py"; } + + ClassDefNV(PY, 1); +}; + } // namespace trd } // namespace o2 diff --git a/Detectors/TRD/pid/include/TRDPID/PIDBase.h b/Detectors/TRD/pid/include/TRDPID/PIDBase.h index 38b486a02e699..a760a7df1496d 100644 --- a/Detectors/TRD/pid/include/TRDPID/PIDBase.h +++ b/Detectors/TRD/pid/include/TRDPID/PIDBase.h @@ -19,6 +19,9 @@ #include "Rtypes.h" #include "DataFormatsTRD/PID.h" #include "DataFormatsTRD/TrackTRD.h" +#include "DataFormatsTRD/Tracklet64.h" +#include "DataFormatsTRD/Constants.h" +#include "TRDBase/PadCalibrationsAliases.h" #include "TRDPID/PIDParameters.h" #include "Framework/ProcessingContext.h" #include "DataFormatsGlobalTracking/RecoContainer.h" @@ -34,7 +37,7 @@ namespace trd /// This is the PID Base class which defines the interface all other models /// must provide. /// -/// A 'policy' describes how a PID value (PIDValue) should be +/// A 'policy' describes how a PID value (float) should be /// calculated. For the classical algorithms there is no /// initialization needed since these work off LUTs. However, for ML /// models some initialization is needed, e.g. creating the @@ -52,18 +55,33 @@ class PIDBase virtual void init(o2::framework::ProcessingContext& pc) = 0; /// Calculate a PID for a given track. - virtual PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) = 0; + virtual float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const = 0; + + /// Set krypton calibration + void setLocalGainFactors(const LocalGainFactor* localGain) { mLocalGain = localGain; } protected: + /// Getter for pid information, applies Z-Row merging of tracklets and gain correction. + /// Some tracklets due to their inclination cross over two pads in z-row, where MCMs do not share ADC lanes. + /// This can be recovered in software, by taking the attached tracklets and looking for nearby tracklets. + /// Only modifies the tracklet if the flag is set. + std::array getCharges(const Tracklet64& tracklet, const int layer, const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, float snp, float tgl) const noexcept; + const TRDPIDParams& mParams{TRDPIDParams::Instance()}; ///< parameters - PIDPolicy mPolicy; ///< policy + const PIDPolicy mPolicy; ///< policy private: - ClassDefNV(PIDBase, 1); + /// Correct the charges of the tracklet + std::array correctCharges(const Tracklet64& trklt, float snp, float tgl) const noexcept; + + // correction factors + const LocalGainFactor* mLocalGain; ///< local gain factors from krypton calibration + + ClassDef(PIDBase, 1); }; /// Factory function to create a PID policy. -std::unique_ptr getTRDPIDBase(PIDPolicy policy); +std::unique_ptr getTRDPIDPolicy(PIDPolicy policy); } // namespace trd } // namespace o2 diff --git a/Detectors/TRD/pid/include/TRDPID/PIDParameters.h b/Detectors/TRD/pid/include/TRDPID/PIDParameters.h index 7014463962b73..e14cbdd599d16 100644 --- a/Detectors/TRD/pid/include/TRDPID/PIDParameters.h +++ b/Detectors/TRD/pid/include/TRDPID/PIDParameters.h @@ -24,9 +24,11 @@ namespace trd /// PID parameters. struct TRDPIDParams : public o2::conf::ConfigurableParamHelper { +#ifdef TRDPID_WITH_ONNX unsigned int numOrtThreads = 1; ///< ONNX Session threads unsigned int graphOptimizationLevel = 99; ///< ONNX GraphOptimization Level /// 0=Disable All, 1=Enable Basic, 2=Enable Extended, 99=Enable ALL +#endif // boilerplate O2ParamDef(TRDPIDParams, "TRDPIDParams"); diff --git a/Detectors/TRD/pid/macros/CMakeLists.txt b/Detectors/TRD/pid/macros/CMakeLists.txt index ba75551016e67..7c0fde086a4c7 100644 --- a/Detectors/TRD/pid/macros/CMakeLists.txt +++ b/Detectors/TRD/pid/macros/CMakeLists.txt @@ -9,6 +9,10 @@ # granted to it by virtue of its status as an Intergovernmental Organization # or submit itself to any jurisdiction. -o2_add_test_root_macro(ccdbModelUpload.C +o2_add_test_root_macro(ccdbPIDUpload.C + PUBLIC_LINK_LIBRARIES O2::TRDPID + LABELS trd) + +o2_add_test_root_macro(makeTestLUTs.C PUBLIC_LINK_LIBRARIES O2::TRDPID LABELS trd) diff --git a/Detectors/TRD/pid/macros/ccdbModelUpload.C b/Detectors/TRD/pid/macros/ccdbModelUpload.C deleted file mode 100644 index 921272b9bad90..0000000000000 --- a/Detectors/TRD/pid/macros/ccdbModelUpload.C +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019-2020 CERN and copyright holders of ALICE O2. -// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. -// All rights not expressly granted are reserved. -// -// This software is distributed under the terms of the GNU General Public -// License v3 (GPL Version 3), copied verbatim in the file "COPYING". -// -// In applying this license CERN does not waive the privileges and immunities -// granted to it by virtue of its status as an Intergovernmental Organization -// or submit itself to any jurisdiction. - -#if !defined(__CLING__) || defined(__ROOTCLING__) -// ROOT header -#include -// O2 header -#include "CCDB/CcdbApi.h" -#include "CCDB/BasicCCDBManager.h" -#include "CCDB/CcdbObjectInfo.h" - -#include -#include -#include -#include -#include -#include -#include - -#endif - -/// Upload an ONNX model to the ccdb. -/// This reads the file as a binary file and stores it as such. -void ccdbModelUpload(std::string inFileName, std::string ccdbPath = "TRD_test/PID/xgb") -{ - - o2::ccdb::CcdbApi ccdb; - // ccdb.init("http://alice-ccdb.cern.ch"); - // ccdb.init("http://localhost:8080"); - ccdb.init("http://ccdb-test.cern.ch:8080"); - // ccdb.init("http://o2-ccdb.internal"); - std::map metadata; - metadata["UploadedBy"] = "Felix Schlepper"; - metadata["EMail"] = "felix.schlepper@cern.ch"; - metadata["Description"] = "ONNX model for TRD PID"; - metadata["default"] = "false"; // tag default objects - metadata["Created"] = "1"; // tag default objects - - std::vector input; - std::ifstream inFile(inFileName, std::ios::binary); - if (!inFile.is_open()) { - std::cout << "Could not load file!" << std::endl; - return; - } - inFile.seekg(0, std::ios_base::end); - auto length = inFile.tellg(); - inFile.seekg(0, std::ios_base::beg); - input.resize(static_cast(length)); - inFile.read(reinterpret_cast(input.data()), length); - auto success = !inFile.fail() && length == inFile.gcount(); - if (!success) { - std::cout << "File loading went wrong!" << std::endl; - return; - } - inFile.close(); - - // for default objects: - // long timeStampStart = 0; - uint64_t timeStampStart = 1577833200000UL; // 1.1.2020 - uint64_t timeStampEnd = 2208985200000UL; // 1.1.2040 - - auto res = ccdb.storeAsBinaryFile(reinterpret_cast(input.data()), length, inFileName, "ONNX Model", ccdbPath, metadata, timeStampStart, timeStampEnd); - - if (res == 0) { - std::cout << "OK" << std::endl; - } else if (res == -1) { - std::cout << "ERROR: object bigger than maxSize" << std::endl; - } else if (res == -2) { - std::cout << "ERROR: curl initialization error" << std::endl; - } else { - std::cout << "ERROR: see curl error codes: " << res << std::endl; - } - return; -} diff --git a/Detectors/TRD/pid/macros/ccdbPIDUpload.C b/Detectors/TRD/pid/macros/ccdbPIDUpload.C new file mode 100644 index 0000000000000..96bc109e7475d --- /dev/null +++ b/Detectors/TRD/pid/macros/ccdbPIDUpload.C @@ -0,0 +1,112 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +#if !defined(__CLING__) || defined(__ROOTCLING__) +// STL headers +#include +#include +#include +#include +#include + +// O2 header +#include "CCDB/CcdbApi.h" +#include "CCDB/BasicCCDBManager.h" +#include "CCDB/CcdbObjectInfo.h" +#include "Framework/Logger.h" +#include "TRDPID/LQND.h" + +// ROOT header +#include +#include + +#endif + +constexpr int fileError{-42}; + +o2::ccdb::CcdbApi ccdb; +std::map metadata{{"UploadedBy", "Felix Schlepper"}, {"EMail", "felix.schlepper@cern.ch"}, {"default", "false"}, {"Created", "1"}}; + +/// Upload an ONNX model to the ccdb. +/// This reads the file as a binary file and stores it as such. +int ccdbONNXUpload(std::string inFileName, std::string ccdbPath, uint64_t timeStampStart, uint64_t timeStampEnd) +{ + metadata["Description"] = "ONNX model for TRD PID"; + + std::vector input; + std::ifstream inFile(inFileName, std::ios::binary); + if (!inFile.is_open()) { + LOG(error) << "Could not open file (" << inFileName << "!)"; + return fileError; + } + inFile.seekg(0, std::ios_base::end); + auto length = inFile.tellg(); + inFile.seekg(0, std::ios_base::beg); + input.resize(static_cast(length)); + inFile.read(reinterpret_cast(input.data()), length); + auto success = !inFile.fail() && length == inFile.gcount(); + if (!success) { + LOG(error) << "Could not read file (" << inFileName << "!)"; + return fileError; + } + inFile.close(); + + return ccdb.storeAsBinaryFile(reinterpret_cast(input.data()), length, inFileName, "ONNX Model | file read as binary string", ccdbPath, metadata, timeStampStart, timeStampEnd); +} + +/// Upload LQND LUTs as std::vector +template +int ccdbLQNDUpload(std::string inFileName, std::string ccdbPath, uint64_t timeStampStart, uint64_t timeStampEnd) +{ + metadata["Description"] = Form("LQ%dD model for TRD PID", dim); + + std::unique_ptr inFile(TFile::Open(inFileName.c_str())); + if (!inFile || inFile->IsZombie()) { + LOG(error) << "Could not open file (" << inFileName << "!)"; + return fileError; + } + // copy vector from file + auto luts = *(inFile->Get>("luts")); + + return ccdb.storeAsTFileAny(&luts, ccdbPath, metadata, timeStampStart, timeStampEnd); +} + +void ccdbPIDUpload(std::string inFileName, std::string ccdbPath, bool testCCDB = true, bool ml = false, int dim = 1, uint64_t timeStampStart = 1577833200000UL /* 1.1.2020 */, uint64_t timeStampEnd = 2208985200000UL /* 1.1.2040 */) +{ + if (testCCDB) { + ccdb.init("http://ccdb-test.cern.ch:8080"); + } else { + ccdb.init("http://alice-ccdb.cern.ch"); + } + + int res{0}; + if (ml) { + res = ccdbONNXUpload(inFileName, ccdbPath, timeStampStart, timeStampEnd); + } else { + if (dim == 1) { + res = ccdbLQNDUpload<1>(inFileName, ccdbPath, timeStampStart, timeStampEnd); + } else { + res = ccdbLQNDUpload<3>(inFileName, ccdbPath, timeStampStart, timeStampEnd); + } + } + + if (res == 0) { + LOG(info) << "Upload: OKAY"; + } else if (res == -1) { + LOG(error) << "object bigger than maxSize"; + } else if (res == -2) { + LOG(error) << "curl initialization error"; + } else if (res == fileError) { + LOG(error) << "File reading error"; + } else { + LOG(error) << "see curl error codes: " << res; + } +} diff --git a/Detectors/TRD/pid/macros/makeTestLUTs.C b/Detectors/TRD/pid/macros/makeTestLUTs.C new file mode 100644 index 0000000000000..95068ea15233a --- /dev/null +++ b/Detectors/TRD/pid/macros/makeTestLUTs.C @@ -0,0 +1,40 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +#if !defined(__CLING__) || defined(__ROOTCLING__) +// ROOT header +#include +#include + +#include "TRDPID/LQND.h" + +#include +#include +#endif + +constexpr int dim = 1; + +/// Generate very simple luts for testing +void makeTestLUTs() +{ + std::vector p{1.0, 2.0, 3.0, 100.0}; + std::vector g; + double x[4] = {0.0, 10, 70, 317}; + double y[4] = {0.0, 0.0, 1.0, 1.0}; + for (auto i = 0; i < dim * 2 * p.size(); ++i) { + g.emplace_back(4, x, y); + } + + o2::trd::detail::LUT luts(p, g); + + std::unique_ptr outFile(TFile::Open("LQND_LUTS.root", "RECREATE")); + outFile->WriteObject(&luts, "luts"); +} diff --git a/Detectors/TRD/pid/src/ML.cxx b/Detectors/TRD/pid/src/ML.cxx index c9e7a4b2b8693..5cbf47bc6b5fd 100644 --- a/Detectors/TRD/pid/src/ML.cxx +++ b/Detectors/TRD/pid/src/ML.cxx @@ -22,6 +22,8 @@ #include "Framework/ProcessingContext.h" #include "Framework/InputRecord.h" #include "Framework/Logger.h" +#include "DataFormatsTRD/CalibratedTracklet.h" +#include "DetectorsBase/Propagator.h" #include #include @@ -42,17 +44,10 @@ namespace trd void ML::init(o2::framework::ProcessingContext& pc) { - LOG(info) << "Finializing model initialization"; + LOG(info) << "Initializating pid policy"; // fetch the onnx model from the ccdb - std::string model_data; - switch (mPolicy) { - case PIDPolicy::Test: - model_data = fetchModelCCDB(pc, "mlTest"); - break; - default: - throw std::runtime_error("Could not load ML model from ccdb!"); - } + std::string model_data{fetchModelCCDB(pc, getName().c_str())}; // disable telemtry events mEnv.DisableTelemetryEvents(); @@ -89,37 +84,10 @@ void ML::init(o2::framework::ProcessingContext& pc) LOG(info) << "Finalization done"; } -PIDValue ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) -{ - if (isTPC) { - return calculate(trk, input); - } else { - return calculate(trk, input); - } -} - -std::string ML::fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const -{ - auto policyInt = static_cast(mPolicy); - // sanity checks - auto ref = pc.inputs().get(binding); - if (!ref.spec || !ref.payload) { - throw std::runtime_error(fmt::format("A ML model({}) with '{}' as binding does not exist!", PIDPolicyEnum[policyInt], binding)); - } - - // the model is in binary string format - auto model_data = pc.inputs().get(binding); - if (model_data.empty()) { - throw std::runtime_error(fmt::format("Did not get any data for {} model({}) from ccdb!", binding, PIDPolicyEnum[policyInt])); - } - return model_data; -} - -template -PIDValue ML::calculate(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) +float ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& inputTracks, bool isTPCTRD) const { try { - auto input = prepareModelInput(trkTRD, inputTracks); + auto input = prepareModelInput(trk, inputTracks); // create memory mapping to vector above auto inputTensor = Ort::Experimental::Value::CreateTensor(input.data(), input.size(), {static_cast(input.size()) / mInputShapes[0][1], mInputShapes[0][1]}); @@ -135,52 +103,57 @@ PIDValue ML::calculate(const TrackTRD& trkTRD, const o2::globaltracking::RecoCon } } -template -std::vector ML::prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) +std::string ML::fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const noexcept { - // input is [slope0, slope1, ..., slope5, charge0.0, charge0.1, charge0.2, charge1.0, ..., charge5.2, p] + // sanity checks + auto ref = pc.inputs().get(binding); + if (!ref.spec || !ref.payload) { + throw std::runtime_error(fmt::format("A ML model with '{}' as binding does not exist!", binding)); + } + + // the model is in binary string format + auto model_data = pc.inputs().get(binding); + if (model_data.empty()) { + throw std::runtime_error(fmt::format("Did not get any data for {} model from ccdb!", binding)); + } + return model_data; +} + +std::vector ML::prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) const noexcept +{ + // input is [charge0.0, charge0.1, charge0.2, charge1.0, ..., charge5.2, p0, ..., p5] std::vector in(mInputShapes[0][1]); const auto& trackletsRaw = inputTracks.getTRDTracklets(); - // std::fill(in.begin(), in.end(), 1.f); - auto id = trkTRD.getRefGlobalTrackId(); - in.back() = trkTRD.getP(); - // const auto& trkSeed = [&]() { - // if constexpr (isTPCTRD) { - // return mTracksInTPCTRD[id].getParamOut(); - // } else { - // return mTracksInITSTPCTRD[id].getParamOut(); - // } - // }; - - for (int iLayer = 0; iLayer < NLAYER; ++iLayer) { + auto trk = trdTRD; + for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) { int trkltId = trkTRD.getTrackletIndex(iLayer); if (trkltId < 0) { - /// easy fill with default values e.g. charge=-1., slope=0. - in[iLayer] = 0.f; - in[NLAYER + iLayer * NCHARGES + 0] = -1.f; - in[NLAYER + iLayer * NCHARGES + 1] = -1.f; - in[NLAYER + iLayer * NCHARGES + 2] = -1.f; + // no tracklet attached, fill with default values e.g. charge=-1., + in[iLayer * NCHARGES + 0] = -1.f; + in[iLayer * NCHARGES + 1] = -1.f; + in[iLayer * NCHARGES + 2] = -1.f; + in[18 + iLayer] = -1.f; continue; + } else { + const auto xCalib = input.getTRDCalibratedTracklets()[trkTRD.getTrackletIndex(iLayer)].getX(); + if (!o2::base::Propagator::Instance()->PropagateToXBxByBz(trk, xCalib, o2::base::Propagator::MAX_SIN_PHI, o2::base::Propagator::MAX_STEP, o2::base::Propagator::MatCorrType::USEMatCorrNONE)) { + LOGF(debug, "Track propagation failed in layer %i (pt=%f, xTrk=%f, xToGo=%f)", iLayer, trk.getPt(), trk.getX(), xCalib); + continue; + } + const auto snp = trk.getSnp(); + const auto tgl = trk.getTgl(); + const auto& trklt = trackletsRaw[trkltId]; + const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkTRD, input, snp, tgl); // correct charges + in[iLayer * NCHARGES + 0] = q0; + in[iLayer * NCHARGES + 1] = q1; + in[iLayer * NCHARGES + 2] = q2; + in[18 + iLayer] = trk.getP(); } - - auto trklt = trackletsRaw[trkltId]; - auto slope = trackletsRaw[trkltId].getSlopeBinSigned(); - auto q0 = trackletsRaw[trkltId].getQ0(); - auto q1 = trackletsRaw[trkltId].getQ1(); - auto q2 = trackletsRaw[trkltId].getQ2(); - - // TODO handel padrow crossing e.g. z-row merging - - in[iLayer] = slope; - in[NLAYER + iLayer * 3 + 0] = q0; - in[NLAYER + iLayer * 3 + 1] = q1; - in[NLAYER + iLayer * 3 + 2] = q2; } return in; } -// pretty prints a shape dimension vector std::string ML::printShape(const std::vector& v) const noexcept { std::stringstream ss(""); @@ -191,12 +164,5 @@ std::string ML::printShape(const std::vector& v) const noexcept return ss.str(); } -/// XGBoost export is like this: -/// (label|eprob, 1-eprob). -PIDValue XGB::getELikelihood(const std::vector& tensorData) const noexcept -{ - return tensorData[1].GetTensorData()[1]; -} - } // namespace trd } // namespace o2 diff --git a/Detectors/TRD/pid/src/PIDBase.cxx b/Detectors/TRD/pid/src/PIDBase.cxx index caf6912d28fc5..7ca8c51f9aaa4 100644 --- a/Detectors/TRD/pid/src/PIDBase.cxx +++ b/Detectors/TRD/pid/src/PIDBase.cxx @@ -14,31 +14,68 @@ #include "TRDPID/PIDBase.h" #include "DataFormatsTRD/PID.h" +#include "Framework/Logger.h" + #ifdef TRDPID_WITH_ONNX #include "TRDPID/ML.h" #endif +#include "TRDPID/LQND.h" #include "TRDPID/Dummy.h" -#include "Framework/Logger.h" -#include "fmt/format.h" namespace o2 { namespace trd { -std::unique_ptr getTRDPIDBase(PIDPolicy policy) +std::array PIDBase::getCharges(const Tracklet64& tracklet, const int layer, const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, float snp, float tgl) const noexcept { - auto policyInt = static_cast(policy); - LOG(info) << "Creating PID policy. Loading model " << PIDPolicyEnum[policyInt]; + // Check z-row merging needs to be performed to recover full charge information + if (trk.getIsCrossingNeighbor(layer) && trk.getHasNeighbor()) { // tracklet needs correction + for (const auto& trklt : input.getTRDTracklets()) { // search for nearby tracklet + if (std::abs(tracklet.getPadCol() - trklt.getPadCol()) <= 1 && std::abs(tracklet.getPadRow() - trklt.getPadRow()) == 1) { + if (tracklet.getTrackletWord() == trklt.getTrackletWord()) { // skip original tracklet + continue; + } + + // Add charge information + const auto [aQ0, aQ1, aQ2] = correctCharges(tracklet, snp, tgl); + const auto [bQ0, bQ1, bQ2] = correctCharges(tracklet, snp, tgl); + return {aQ0 + bQ0, aQ1 + bQ1, aQ2 + bQ2}; + } + } + } + + return correctCharges(tracklet, snp, tgl); +} + +std::array PIDBase::correctCharges(const Tracklet64& trklt, float snp, float tgl) const noexcept +{ + auto tphi = snp / std::sqrt((1.f - snp) + (1.f + snp)); + auto trackletLength = std::sqrt(1.f + tphi * tphi + tgl * tgl); + const float correction = mLocalGain->getValue(trklt.getHCID() / 2, trklt.getPadCol(), trklt.getPadRow()) * trackletLength; + return { + trklt.getQ0() / correction, + trklt.getQ1() / correction, + trklt.getQ2() / correction, + }; +} + +std::unique_ptr getTRDPIDPolicy(PIDPolicy policy) +{ + LOG(info) << "Creating PID policy. Loading model " << policy; switch (policy) { -#ifdef TRDPID_WITH_ONNX - case PIDPolicy::Test: - return std::make_unique(PIDPolicy::Test); + case PIDPolicy::LQ1D: + return std::make_unique(PIDPolicy::LQ1D); + case PIDPolicy::LQ3D: + return std::make_unique(PIDPolicy::LQ3D); +#ifdef TRDPID_WITH_ONNX // Add all policies that use ONNX in this ifdef + case PIDPolicy::XGB: + return std::make_unique(PIDPolicy::XGB); + case PIDPolicy::PY: + return std::make_unique(PIDPolicy::PY); #endif case PIDPolicy::Dummy: return std::make_unique(PIDPolicy::Dummy); - default: - throw std::invalid_argument(fmt::format("Cannot create this PID policy {}({})", PIDPolicyEnum[policyInt], policyInt)); } return nullptr; // cannot be reached } diff --git a/Detectors/TRD/pid/src/TRDPIDLinkDef.h b/Detectors/TRD/pid/src/TRDPIDLinkDef.h index 06996c711f3a3..f0054f2e823bf 100644 --- a/Detectors/TRD/pid/src/TRDPIDLinkDef.h +++ b/Detectors/TRD/pid/src/TRDPIDLinkDef.h @@ -18,7 +18,12 @@ #pragma link C++ class o2::trd::PIDBase + ; #pragma link C++ class o2::trd::ML + ; #pragma link C++ class o2::trd::Dummy + ; +#pragma link C++ class o2::trd::LQND < 1> + ; +#pragma link C++ class o2::trd::LQND < 3> + ; +#pragma link C++ class o2::trd::detail::LUT < 1> + ; +#pragma link C++ class o2::trd::detail::LUT < 3> + ; #pragma link C++ class o2::trd::TRDPIDParams + ; #pragma link C++ class o2::conf::ConfigurableParamHelper < o2::trd::TRDPIDParams> + ; +#pragma link C++ class std::vector < TGraph> + ; #endif diff --git a/Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h b/Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h index 25a0ccf6dce35..761e9bac99c6f 100644 --- a/Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h +++ b/Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h @@ -17,7 +17,12 @@ #pragma link C++ class o2::trd::PIDBase + ; #pragma link C++ class o2::trd::Dummy + ; +#pragma link C++ class o2::trd::LQND < 1> + ; +#pragma link C++ class o2::trd::LQND < 3> + ; +#pragma link C++ class o2::trd::detail::LUT < 1> + ; +#pragma link C++ class o2::trd::detail::LUT < 3> + ; #pragma link C++ class o2::trd::TRDPIDParams + ; #pragma link C++ class o2::conf::ConfigurableParamHelper < o2::trd::TRDPIDParams> + ; +#pragma link C++ class std::vector < TGraph> + ; #endif diff --git a/Detectors/TRD/workflow/src/TRDGlobalTrackingSpec.cxx b/Detectors/TRD/workflow/src/TRDGlobalTrackingSpec.cxx index 20cdd4c12ac0c..5b0e0fdb46586 100644 --- a/Detectors/TRD/workflow/src/TRDGlobalTrackingSpec.cxx +++ b/Detectors/TRD/workflow/src/TRDGlobalTrackingSpec.cxx @@ -121,8 +121,9 @@ void TRDGlobalTracking::updateTimeDependentParams(ProcessingContext& pc) /// Get the PID model if requested if (mWithPID) { - mBase = getTRDPIDBase(mPolicy); + mBase = getTRDPIDPolicy(mPolicy); mBase->init(pc); + mBase->setLocalGainFactors(pc.inputs().get("localgainfactors").get()); } } @@ -837,21 +838,27 @@ DataProcessorSpec getTRDGlobalTrackingSpec(bool useMC, GTrackID::mask_t src, boo // Request PID policy data if (withPID) { + // request policy switch (policy) { case PIDPolicy::LQ1D: - // inputs.emplace_back("LQ1D", "TRD", "LQ1D", 0, Lifetime::Condition, ccdbParamSpec("TRD/ppPID/LQ1D")); + inputs.emplace_back("lq1dlut", "TRD", "LQ1D", 0, Lifetime::Condition, ccdbParamSpec("TRD_test/PID/LQ1D")); break; case PIDPolicy::LQ3D: - // inputs.emplace_back("LQ3D", "TRD", "LQ3D", 0, Lifetime::Condition, ccdbParamSpec("TRD/ppPID/LQ3D")); + inputs.emplace_back("lq3dlut", "TRD", "LQ3D", 0, Lifetime::Condition, ccdbParamSpec("TRD_test/PID/LQ3D")); break; - case PIDPolicy::Test: - inputs.emplace_back("mlTest", "TRD", "MLTEST", 0, Lifetime::Condition, ccdbParamSpec("TRD_test/pid/xgb1")); +#ifdef TRDPID_WITH_ONNX + case PIDPolicy::XGB: + inputs.emplace_back("xgb", "TRD", "XGB", 0, Lifetime::Condition, ccdbParamSpec("TRD_test/PID_new/xgb")); break; + case PIDPolicy::PY: + inputs.emplace_back("py", "TRD", "py", 0, Lifetime::Condition, ccdbParamSpec("TRD_test/PID_new/py")); + break; +#endif case PIDPolicy::Dummy: break; - default: - throw std::runtime_error("Unable to load requested PID policy data!"); } + // request calibration data + inputs.emplace_back("localgainfactors", "TRD", "LOCALGAINFACTORS", 0, Lifetime::Condition, ccdbParamSpec("TRD/Calib/LocalGainFactor")); } if (GTrackID::includesSource(GTrackID::Source::ITSTPC, src)) {