diff --git a/DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h b/DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h index cbb62d8f1e164..d99a5be14db4a 100644 --- a/DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h +++ b/DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h @@ -18,6 +18,7 @@ #include #include #include +#include namespace o2 { @@ -28,46 +29,71 @@ namespace trd enum class PIDPolicy : unsigned int { // Classical Algorithms LQ1D = 0, ///< 1-Dimensional Likelihood model + LQ2D, ///< 2-Dimensional Likelihood model LQ3D, ///< 3-Dimensional Likelihood model +#ifdef TRDPID_WITH_ONNX // ML models XGB, ///< XGBOOST + PY, ///< Pytorch +#endif // Do not add anything after this! NMODELS, ///< Count of all models - Test, ///< Load object for testing Dummy, ///< Dummy object outputting -1.f DEFAULT = Dummy, ///< The default option }; +inline std::ostream& operator<<(std::ostream& os, const PIDPolicy& policy) +{ + std::string name; + switch (policy) { + case PIDPolicy::LQ1D: + name = "LQ1D"; + break; + case PIDPolicy::LQ2D: + name = "LQ2D"; + break; + case PIDPolicy::LQ3D: + name = "LQ3D"; + break; +#ifdef TRDPID_WITH_ONNX + case PIDPolicy::XGB: + name = "XGBoost"; + break; + case PIDPolicy::PY: + name = "PyTorch"; + break; +#endif + case PIDPolicy::Dummy: + name = "Dummy"; + break; + default: + name = "Default"; + } + os << name; + return os; +} + /// Transform PID policy from string to enum. static const std::unordered_map PIDPolicyString{ // Classical Algorithms {"LQ1D", PIDPolicy::LQ1D}, + {"LQ2D", PIDPolicy::LQ2D}, {"LQ3D", PIDPolicy::LQ3D}, +#ifdef TRDPID_WITH_ONNX // ML models {"XGB", PIDPolicy::XGB}, + {"PY", PIDPolicy::PY}, +#endif // General - {"TEST", PIDPolicy::Test}, {"DUMMY", PIDPolicy::Dummy}, // Default {"default", PIDPolicy::DEFAULT}, }; -/// Transform PID policy from string to enum. -static const char* PIDPolicyEnum[] = { - "LQ1D", - "LQ3D", - "XGBoost", - "NMODELS", - "Test", - "Dummy", - "default(=TODO)"}; - -using PIDValue = float; - } // namespace trd } // namespace o2 diff --git a/Detectors/TRD/pid/CMakeLists.txt b/Detectors/TRD/pid/CMakeLists.txt index a3d5093c320fd..158103cf007c2 100644 --- a/Detectors/TRD/pid/CMakeLists.txt +++ b/Detectors/TRD/pid/CMakeLists.txt @@ -31,12 +31,14 @@ if(ONNXRuntime_FOUND) HEADERS include/TRDPID/PIDBase.h include/TRDPID/PIDParameters.h include/TRDPID/ML.h + include/TRDPID/LQND.h include/TRDPID/Dummy.h) else() o2_target_root_dictionary(TRDPID HEADERS include/TRDPID/PIDBase.h include/TRDPID/PIDParameters.h include/TRDPID/Dummy.h + include/TRDPID/LQND.h LINKDEF src/TRDPIDNoMLLinkDef.h) endif() diff --git a/Detectors/TRD/pid/README.md b/Detectors/TRD/pid/README.md new file mode 100644 index 0000000000000..27101dae09eb1 --- /dev/null +++ b/Detectors/TRD/pid/README.md @@ -0,0 +1,64 @@ +# Particle Identification with TRD +## Usage +Activate PID during tracking with the '--with-pid' flag. + + o2-trd-global-tracking --with-pid --policy ML + +Specify a which algorithm (called policy) should be use. +Implemented are the following: + + - LQ1D + - LQ2D + - LQ3D + - ML (every model, which is exported to the ONNX format): + - XGB (XGBoost model) + - NN (Pytorch model) + - Dummy (returns only -1) + - Test (one of the above) + - Default (one of the above, gets picked if '--policy' is unspecified) + +## Implementation details +### Tracking workflow +Every TRDTrack gets a PID value set (mSignal), which then gets propergated to the AO2D writer. + +### Basic Interface +The base interface for PID is defined in [here](include/TRDPID/PIDBase.h). +The 'init' function is such that each policy can specify what if anything it needs from the CCDB. +For the 'process' each policy defines how a TRDTrack gets assigned a PID value. +Additionally, the base class implements how to get the _corrected charges_ from the tracklets. +_Corrected charges_ means z-row merged and calibrated charges. + +### Classical Likelihood +The classical LQND policies ([here](include/TRDPID/LQND.h)) require an array of lookup tables (LUTs) from the ccdb. +$N$ stands for the dimension. +Then the electron likelihood for layer $i$ is defined as this: + +$$L_i(e|Q_i)=\frac{P(Q_i|e)}{P(Q_i|e)+P(Q_i|\pi)}$$ + +From the charge $Q_i$ the LUTs give the corresponding $L_i$. +The _combined electron likelihood_ is obtained by this formula: + +$$L(e|Q)=\frac{\prod_i L_i(e|Q_i)}{\prod_i L_i(e|Q_i) + \prod_i L_i(\pi|Q_i)}$$ + +where $L_i(\pi|Q_i)=1-L_i(e|Q_i)$. + + +Extension to higher dimensions is easy each tracklet has charges $Q_j$ which cover the integral of the pulse height spectrum in different slice ($j\in [0,1,2]$). +In our case $Q0$ covers the pulse height peak, $Q1$ the Transition Radiation peak and $Q2$ the plateau. +For each charge $j$ a LUT is available which gives the likelihood $L^e_j$. +For each layer $i$ the likelihood is then: + +$$L_i(e|Q)=\frac{\prod_j L_{i,j}(e|Q_j)}{\prod_j L_{i,j}(e|Q_j) + \prod_j L_{i,j}(\pi|Q_j)}$$ + +The combined electron likelihood is then: + +$$L(e|Q)=\frac{\prod_{i,j} L_{i,j}(e|Q_j)}{\prod_{i,j} L_{i,j}(e|Q_j) + \prod_{i,j} L_{i,j}(\pi|Q_j)}$$ + + +### Machine Learning +The ML policies ([here](include/TRDPID/ML.h)) are uploaded to the CCDB in the ONNX file format (most python machine learning libraries support this standardized format). +In O2 we leverage the ONNXRuntime to use these formats and calculate a PID value. +The models can thus be trained in python which is very convenient. +The code should take care of most of the annoying stuff. +Policies just have to specify how to get the electron likelihood from the ONNXRuntime output (each python library varies in that somewhat). +The 'prepareModelInput' prepares the TRDTracks as input. diff --git a/Detectors/TRD/pid/include/TRDPID/Dummy.h b/Detectors/TRD/pid/include/TRDPID/Dummy.h index 3158ad715fc61..4203f4c56953a 100644 --- a/Detectors/TRD/pid/include/TRDPID/Dummy.h +++ b/Detectors/TRD/pid/include/TRDPID/Dummy.h @@ -34,13 +34,13 @@ class Dummy final : public PIDBase using PIDBase::PIDBase; public: - ~Dummy() final = default; + ~Dummy() = default; /// Do absolutely nothing. void init(o2::framework::ProcessingContext& pc) final{}; /// Everything below 0.f indicates nothing available. - PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) final + float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final { return -1.f; }; diff --git a/Detectors/TRD/pid/include/TRDPID/LQND.h b/Detectors/TRD/pid/include/TRDPID/LQND.h new file mode 100644 index 0000000000000..7a898b25829c5 --- /dev/null +++ b/Detectors/TRD/pid/include/TRDPID/LQND.h @@ -0,0 +1,160 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +/// \file LQND.h +/// \brief This file provides the interface for loglikehood policies +/// \author Felix Schlepper + +#ifndef O2_TRD_LQND_H +#define O2_TRD_LQND_H + +#include "TGraph.h" +#include "TRDPID/PIDBase.h" +#include "DataFormatsTRD/PID.h" +#include "DataFormatsTRD/Constants.h" +#include "DataFormatsTRD/HelperMethods.h" +#include "Framework/ProcessingContext.h" +#include "Framework/InputRecord.h" +#include "DataFormatsTRD/CalibratedTracklet.h" +#include "DetectorsBase/Propagator.h" +#include "Framework/Logger.h" +#include "ReconstructionDataFormats/TrackParametrization.h" + +#include +#include +#include +#include +#include + +namespace o2 +{ +namespace trd +{ +namespace detail +{ +/// Lookup Table class for ccdb upload +template +class LUT +{ + public: + LUT() = default; + LUT(std::vector p, std::vector l) : mIntervalsP(p), mLUTs(l) {} + + // + const TGraph& get(float p, bool isNegative, int iDim = 0) const + { + auto upper = std::upper_bound(mIntervalsP.begin(), mIntervalsP.end(), p); + if (upper == mIntervalsP.end()) { + // outside of momentum intervals, should not happen + return mLUTs[0]; + } + auto index = std::distance(mIntervalsP.begin(), upper); + index += (isNegative) ? 0 : mIntervalsP.size() * nDim; + return mLUTs[index + iDim]; + } + + private: + std::vector mIntervalsP; ///< half-open interval upper bounds starting at 0, e.g., for {1.0,2.0,...} is (-inf,1.0], (1.0,2.0], (2.0, ...) + std::vector mLUTs; ///< corresponding likelihood lookup tables + + ClassDefNV(LUT, 1); +}; +} // namespace detail + +/// This is the ML Base class which defines the interface all machine learning +/// models. +template +class LQND : public PIDBase +{ + static_assert(nDim == 1 || nDim == 2 || nDim == 3, "Likelihood only for 1/2/3 dimension"); + using PIDBase::PIDBase; + + public: + ~LQND() = default; + + void init(o2::framework::ProcessingContext& pc) final + { + // retrieve lookup tables (LUTs) from ccdb + mLUTs = *(pc.inputs().get*>(Form("lq%ddlut", nDim))); + } + + float process(const TrackTRD& trkIn, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final + { + const auto& trkSeed = isTPCTRD ? input.getTPCTracks()[trkIn.getRefGlobalTrackId()].getParamOut() : input.getTPCITSTracks()[trkIn.getRefGlobalTrackId()].getParamOut(); // seeding track + auto trk = trkSeed; + + const auto isNegative = std::signbit(trkSeed.getSign()); // positive and negative charged particles are treated differently since ExB effects the charge distributions + const auto& trackletsRaw = input.getTRDTracklets(); + float lei0{1.f}, lei1{1.f}, lei2{1.f}, lpi0{1.f}, lpi1{1.f}, lpi2{1.f}; // likelihood per layer + for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) { + int trkltId = trkIn.getTrackletIndex(iLayer); + if (trkltId < 0) { // no tracklet attached + continue; + } + const auto xCalib = input.getTRDCalibratedTracklets()[trkIn.getTrackletIndex(iLayer)].getX(); + auto bz = o2::base::Propagator::Instance()->getNominalBz(); + const auto tgl = trk.getTgl(); + const auto snp = trk.getSnpAt(o2::math_utils::sector2Angle(HelperMethods::getSector(input.getTRDTracklets()[trkIn.getTrackletIndex(iLayer)].getDetector())), xCalib, bz); + const auto& trklt = trackletsRaw[trkltId]; + const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkIn, input, snp, tgl); // correct charges + if constexpr (nDim == 1) { + auto lut = mLUTs.get(trk.getP(), isNegative); + auto ll1{1.f}; + ll1 = lut.Eval(q0 + q1 + q2); + lei0 *= ll1; + lpi0 *= (1.f - ll1); + } else if (nDim == 2) { + auto lut1 = mLUTs.get(trk.getP(), isNegative, 0); + auto lut2 = mLUTs.get(trk.getP(), isNegative, 1); + auto ll1{1.f}; + auto ll2{1.f}; + ll1 = lut1.Eval(q0 + q2); + ll2 = lut2.Eval(q1); + lei0 *= ll1; + lei1 *= ll2; + lpi0 *= (1.f - ll1); + lpi1 *= (1.f - ll2); + } else { + auto lut1 = mLUTs.get(trk.getP(), isNegative, 0); + auto lut2 = mLUTs.get(trk.getP(), isNegative, 1); + auto lut3 = mLUTs.get(trk.getP(), isNegative, 2); + auto ll1{1.f}; + auto ll2{1.f}; + auto ll3{1.f}; + ll1 = lut1.Eval(q0); + ll2 = lut2.Eval(q1); + ll3 = lut3.Eval(q2); + lei0 *= ll1; + lei1 *= ll2; + lei2 *= ll3; + lpi0 *= (1.f - ll1); + lpi1 *= (1.f - ll2); + lpi2 *= (1.f - ll3); + } + } + + return (lei0 * lei1 * lei2) / (lei0 * lei1 * lei2 + lpi0 * lpi1 * lpi2); // combined likelihood + } + + private: + detail::LUT mLUTs; ///< likelihood lookup tables + + ClassDefNV(LQND, 1); +}; + +using LQ1D = LQND<1>; +using LQ2D = LQND<2>; +using LQ3D = LQND<3>; + +} // namespace trd +} // namespace o2 + +#endif diff --git a/Detectors/TRD/pid/include/TRDPID/ML.h b/Detectors/TRD/pid/include/TRDPID/ML.h index 59d902397dd4f..0b8d8ea8ea97b 100644 --- a/Detectors/TRD/pid/include/TRDPID/ML.h +++ b/Detectors/TRD/pid/include/TRDPID/ML.h @@ -40,31 +40,31 @@ class ML : public PIDBase public: void init(o2::framework::ProcessingContext& pc) final; - PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) final; + float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final; private: /// Return the electron likelihood. /// Different models have different ways to return the probability. - virtual PIDValue getELikelihood(const std::vector& tensorData) const noexcept = 0; + virtual inline float getELikelihood(const std::vector& tensorData) const noexcept = 0; /// Fetch a ML model from the ccdb via its binding - std::string fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const; - - /// Calculate pid value - template - PIDValue calculate(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks); + std::string fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const noexcept; /// Prepare model input /// Collect track properties in vector as flat array - template - std::vector prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks); + std::vector prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) const noexcept; /// Pretty print model shape std::string printShape(const std::vector& v) const noexcept; + /// Get DPL name + virtual inline std::string getName() const noexcept = 0; + // ONNX runtime Ort::Env mEnv{ORT_LOGGING_LEVEL_WARNING, "TRD-PID", - // integrate ORT logging into Fairlogger + // Integrate ORT logging into Fairlogger this way we can have + // all the nice logging while taking advantage of ORT telling us + // what to do. [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) { LOG(warn) << "Ort " << severity << ": [" << logid << "|" << category << "|" << code_location << "]: " << message << ((intptr_t)param == 3 ? " [valid]" : " [error]"); }, @@ -79,7 +79,7 @@ class ML : public PIDBase std::vector mOutputNames; ///< model output names std::vector> mOutputShapes; ///< output shape - ClassDefNV(ML, 1); + ClassDefOverride(ML, 1); }; /// XGBoost Model @@ -88,14 +88,40 @@ class XGB final : public ML using ML::ML; public: - ~XGB() final = default; + ~XGB() = default; private: - PIDValue getELikelihood(const std::vector& tensorData) const noexcept final; + /// XGBoost export is like this: + /// (label|eprob, 1-eprob). + inline float getELikelihood(const std::vector& tensorData) const noexcept + { + return tensorData[1].GetTensorData()[1]; + } + + inline std::string getName() const noexcept { return "xgb"; } ClassDefNV(XGB, 1); }; +/// PyTorch Model +class PY final : public ML +{ + using ML::ML; + + public: + ~PY() = default; + + private: + inline float getELikelihood(const std::vector& tensorData) const noexcept + { + return tensorData[0].GetTensorData()[0]; + } + + inline std::string getName() const noexcept { return "py"; } + + ClassDefNV(PY, 1); +}; + } // namespace trd } // namespace o2 diff --git a/Detectors/TRD/pid/include/TRDPID/PIDBase.h b/Detectors/TRD/pid/include/TRDPID/PIDBase.h index 38b486a02e699..a760a7df1496d 100644 --- a/Detectors/TRD/pid/include/TRDPID/PIDBase.h +++ b/Detectors/TRD/pid/include/TRDPID/PIDBase.h @@ -19,6 +19,9 @@ #include "Rtypes.h" #include "DataFormatsTRD/PID.h" #include "DataFormatsTRD/TrackTRD.h" +#include "DataFormatsTRD/Tracklet64.h" +#include "DataFormatsTRD/Constants.h" +#include "TRDBase/PadCalibrationsAliases.h" #include "TRDPID/PIDParameters.h" #include "Framework/ProcessingContext.h" #include "DataFormatsGlobalTracking/RecoContainer.h" @@ -34,7 +37,7 @@ namespace trd /// This is the PID Base class which defines the interface all other models /// must provide. /// -/// A 'policy' describes how a PID value (PIDValue) should be +/// A 'policy' describes how a PID value (float) should be /// calculated. For the classical algorithms there is no /// initialization needed since these work off LUTs. However, for ML /// models some initialization is needed, e.g. creating the @@ -52,18 +55,33 @@ class PIDBase virtual void init(o2::framework::ProcessingContext& pc) = 0; /// Calculate a PID for a given track. - virtual PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) = 0; + virtual float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const = 0; + + /// Set krypton calibration + void setLocalGainFactors(const LocalGainFactor* localGain) { mLocalGain = localGain; } protected: + /// Getter for pid information, applies Z-Row merging of tracklets and gain correction. + /// Some tracklets due to their inclination cross over two pads in z-row, where MCMs do not share ADC lanes. + /// This can be recovered in software, by taking the attached tracklets and looking for nearby tracklets. + /// Only modifies the tracklet if the flag is set. + std::array getCharges(const Tracklet64& tracklet, const int layer, const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, float snp, float tgl) const noexcept; + const TRDPIDParams& mParams{TRDPIDParams::Instance()}; ///< parameters - PIDPolicy mPolicy; ///< policy + const PIDPolicy mPolicy; ///< policy private: - ClassDefNV(PIDBase, 1); + /// Correct the charges of the tracklet + std::array correctCharges(const Tracklet64& trklt, float snp, float tgl) const noexcept; + + // correction factors + const LocalGainFactor* mLocalGain; ///< local gain factors from krypton calibration + + ClassDef(PIDBase, 1); }; /// Factory function to create a PID policy. -std::unique_ptr getTRDPIDBase(PIDPolicy policy); +std::unique_ptr getTRDPIDPolicy(PIDPolicy policy); } // namespace trd } // namespace o2 diff --git a/Detectors/TRD/pid/include/TRDPID/PIDParameters.h b/Detectors/TRD/pid/include/TRDPID/PIDParameters.h index 7014463962b73..e14cbdd599d16 100644 --- a/Detectors/TRD/pid/include/TRDPID/PIDParameters.h +++ b/Detectors/TRD/pid/include/TRDPID/PIDParameters.h @@ -24,9 +24,11 @@ namespace trd /// PID parameters. struct TRDPIDParams : public o2::conf::ConfigurableParamHelper { +#ifdef TRDPID_WITH_ONNX unsigned int numOrtThreads = 1; ///< ONNX Session threads unsigned int graphOptimizationLevel = 99; ///< ONNX GraphOptimization Level /// 0=Disable All, 1=Enable Basic, 2=Enable Extended, 99=Enable ALL +#endif // boilerplate O2ParamDef(TRDPIDParams, "TRDPIDParams"); diff --git a/Detectors/TRD/pid/macros/CMakeLists.txt b/Detectors/TRD/pid/macros/CMakeLists.txt index ba75551016e67..7c0fde086a4c7 100644 --- a/Detectors/TRD/pid/macros/CMakeLists.txt +++ b/Detectors/TRD/pid/macros/CMakeLists.txt @@ -9,6 +9,10 @@ # granted to it by virtue of its status as an Intergovernmental Organization # or submit itself to any jurisdiction. -o2_add_test_root_macro(ccdbModelUpload.C +o2_add_test_root_macro(ccdbPIDUpload.C + PUBLIC_LINK_LIBRARIES O2::TRDPID + LABELS trd) + +o2_add_test_root_macro(makeTestLUTs.C PUBLIC_LINK_LIBRARIES O2::TRDPID LABELS trd) diff --git a/Detectors/TRD/pid/macros/ccdbModelUpload.C b/Detectors/TRD/pid/macros/ccdbModelUpload.C deleted file mode 100644 index 921272b9bad90..0000000000000 --- a/Detectors/TRD/pid/macros/ccdbModelUpload.C +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019-2020 CERN and copyright holders of ALICE O2. -// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. -// All rights not expressly granted are reserved. -// -// This software is distributed under the terms of the GNU General Public -// License v3 (GPL Version 3), copied verbatim in the file "COPYING". -// -// In applying this license CERN does not waive the privileges and immunities -// granted to it by virtue of its status as an Intergovernmental Organization -// or submit itself to any jurisdiction. - -#if !defined(__CLING__) || defined(__ROOTCLING__) -// ROOT header -#include -// O2 header -#include "CCDB/CcdbApi.h" -#include "CCDB/BasicCCDBManager.h" -#include "CCDB/CcdbObjectInfo.h" - -#include -#include -#include -#include -#include -#include -#include - -#endif - -/// Upload an ONNX model to the ccdb. -/// This reads the file as a binary file and stores it as such. -void ccdbModelUpload(std::string inFileName, std::string ccdbPath = "TRD_test/PID/xgb") -{ - - o2::ccdb::CcdbApi ccdb; - // ccdb.init("http://alice-ccdb.cern.ch"); - // ccdb.init("http://localhost:8080"); - ccdb.init("http://ccdb-test.cern.ch:8080"); - // ccdb.init("http://o2-ccdb.internal"); - std::map metadata; - metadata["UploadedBy"] = "Felix Schlepper"; - metadata["EMail"] = "felix.schlepper@cern.ch"; - metadata["Description"] = "ONNX model for TRD PID"; - metadata["default"] = "false"; // tag default objects - metadata["Created"] = "1"; // tag default objects - - std::vector input; - std::ifstream inFile(inFileName, std::ios::binary); - if (!inFile.is_open()) { - std::cout << "Could not load file!" << std::endl; - return; - } - inFile.seekg(0, std::ios_base::end); - auto length = inFile.tellg(); - inFile.seekg(0, std::ios_base::beg); - input.resize(static_cast(length)); - inFile.read(reinterpret_cast(input.data()), length); - auto success = !inFile.fail() && length == inFile.gcount(); - if (!success) { - std::cout << "File loading went wrong!" << std::endl; - return; - } - inFile.close(); - - // for default objects: - // long timeStampStart = 0; - uint64_t timeStampStart = 1577833200000UL; // 1.1.2020 - uint64_t timeStampEnd = 2208985200000UL; // 1.1.2040 - - auto res = ccdb.storeAsBinaryFile(reinterpret_cast(input.data()), length, inFileName, "ONNX Model", ccdbPath, metadata, timeStampStart, timeStampEnd); - - if (res == 0) { - std::cout << "OK" << std::endl; - } else if (res == -1) { - std::cout << "ERROR: object bigger than maxSize" << std::endl; - } else if (res == -2) { - std::cout << "ERROR: curl initialization error" << std::endl; - } else { - std::cout << "ERROR: see curl error codes: " << res << std::endl; - } - return; -} diff --git a/Detectors/TRD/pid/macros/ccdbPIDUpload.C b/Detectors/TRD/pid/macros/ccdbPIDUpload.C new file mode 100644 index 0000000000000..96bc109e7475d --- /dev/null +++ b/Detectors/TRD/pid/macros/ccdbPIDUpload.C @@ -0,0 +1,112 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +#if !defined(__CLING__) || defined(__ROOTCLING__) +// STL headers +#include +#include +#include +#include +#include + +// O2 header +#include "CCDB/CcdbApi.h" +#include "CCDB/BasicCCDBManager.h" +#include "CCDB/CcdbObjectInfo.h" +#include "Framework/Logger.h" +#include "TRDPID/LQND.h" + +// ROOT header +#include +#include + +#endif + +constexpr int fileError{-42}; + +o2::ccdb::CcdbApi ccdb; +std::map metadata{{"UploadedBy", "Felix Schlepper"}, {"EMail", "felix.schlepper@cern.ch"}, {"default", "false"}, {"Created", "1"}}; + +/// Upload an ONNX model to the ccdb. +/// This reads the file as a binary file and stores it as such. +int ccdbONNXUpload(std::string inFileName, std::string ccdbPath, uint64_t timeStampStart, uint64_t timeStampEnd) +{ + metadata["Description"] = "ONNX model for TRD PID"; + + std::vector input; + std::ifstream inFile(inFileName, std::ios::binary); + if (!inFile.is_open()) { + LOG(error) << "Could not open file (" << inFileName << "!)"; + return fileError; + } + inFile.seekg(0, std::ios_base::end); + auto length = inFile.tellg(); + inFile.seekg(0, std::ios_base::beg); + input.resize(static_cast(length)); + inFile.read(reinterpret_cast(input.data()), length); + auto success = !inFile.fail() && length == inFile.gcount(); + if (!success) { + LOG(error) << "Could not read file (" << inFileName << "!)"; + return fileError; + } + inFile.close(); + + return ccdb.storeAsBinaryFile(reinterpret_cast(input.data()), length, inFileName, "ONNX Model | file read as binary string", ccdbPath, metadata, timeStampStart, timeStampEnd); +} + +/// Upload LQND LUTs as std::vector +template +int ccdbLQNDUpload(std::string inFileName, std::string ccdbPath, uint64_t timeStampStart, uint64_t timeStampEnd) +{ + metadata["Description"] = Form("LQ%dD model for TRD PID", dim); + + std::unique_ptr inFile(TFile::Open(inFileName.c_str())); + if (!inFile || inFile->IsZombie()) { + LOG(error) << "Could not open file (" << inFileName << "!)"; + return fileError; + } + // copy vector from file + auto luts = *(inFile->Get>("luts")); + + return ccdb.storeAsTFileAny(&luts, ccdbPath, metadata, timeStampStart, timeStampEnd); +} + +void ccdbPIDUpload(std::string inFileName, std::string ccdbPath, bool testCCDB = true, bool ml = false, int dim = 1, uint64_t timeStampStart = 1577833200000UL /* 1.1.2020 */, uint64_t timeStampEnd = 2208985200000UL /* 1.1.2040 */) +{ + if (testCCDB) { + ccdb.init("http://ccdb-test.cern.ch:8080"); + } else { + ccdb.init("http://alice-ccdb.cern.ch"); + } + + int res{0}; + if (ml) { + res = ccdbONNXUpload(inFileName, ccdbPath, timeStampStart, timeStampEnd); + } else { + if (dim == 1) { + res = ccdbLQNDUpload<1>(inFileName, ccdbPath, timeStampStart, timeStampEnd); + } else { + res = ccdbLQNDUpload<3>(inFileName, ccdbPath, timeStampStart, timeStampEnd); + } + } + + if (res == 0) { + LOG(info) << "Upload: OKAY"; + } else if (res == -1) { + LOG(error) << "object bigger than maxSize"; + } else if (res == -2) { + LOG(error) << "curl initialization error"; + } else if (res == fileError) { + LOG(error) << "File reading error"; + } else { + LOG(error) << "see curl error codes: " << res; + } +} diff --git a/Detectors/TRD/pid/macros/makeTestLUTs.C b/Detectors/TRD/pid/macros/makeTestLUTs.C new file mode 100644 index 0000000000000..95068ea15233a --- /dev/null +++ b/Detectors/TRD/pid/macros/makeTestLUTs.C @@ -0,0 +1,40 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +#if !defined(__CLING__) || defined(__ROOTCLING__) +// ROOT header +#include +#include + +#include "TRDPID/LQND.h" + +#include +#include +#endif + +constexpr int dim = 1; + +/// Generate very simple luts for testing +void makeTestLUTs() +{ + std::vector p{1.0, 2.0, 3.0, 100.0}; + std::vector g; + double x[4] = {0.0, 10, 70, 317}; + double y[4] = {0.0, 0.0, 1.0, 1.0}; + for (auto i = 0; i < dim * 2 * p.size(); ++i) { + g.emplace_back(4, x, y); + } + + o2::trd::detail::LUT luts(p, g); + + std::unique_ptr outFile(TFile::Open("LQND_LUTS.root", "RECREATE")); + outFile->WriteObject(&luts, "luts"); +} diff --git a/Detectors/TRD/pid/src/ML.cxx b/Detectors/TRD/pid/src/ML.cxx index c9e7a4b2b8693..2ce3b6876a893 100644 --- a/Detectors/TRD/pid/src/ML.cxx +++ b/Detectors/TRD/pid/src/ML.cxx @@ -15,13 +15,17 @@ #include "TRDPID/ML.h" #include "DataFormatsTRD/Constants.h" #include "DataFormatsTRD/Tracklet64.h" +#include "DataFormatsTRD/HelperMethods.h" #include "ReconstructionDataFormats/TrackTPCITS.h" #include "ReconstructionDataFormats/GlobalTrackID.h" +#include "ReconstructionDataFormats/TrackParametrization.h" #include "ReconstructionDataFormats/TrackParametrizationWithError.h" #include "DataFormatsTPC/TrackTPC.h" #include "Framework/ProcessingContext.h" #include "Framework/InputRecord.h" #include "Framework/Logger.h" +#include "DataFormatsTRD/CalibratedTracklet.h" +#include "DetectorsBase/Propagator.h" #include #include @@ -42,17 +46,10 @@ namespace trd void ML::init(o2::framework::ProcessingContext& pc) { - LOG(info) << "Finializing model initialization"; + LOG(info) << "Initializating pid policy"; // fetch the onnx model from the ccdb - std::string model_data; - switch (mPolicy) { - case PIDPolicy::Test: - model_data = fetchModelCCDB(pc, "mlTest"); - break; - default: - throw std::runtime_error("Could not load ML model from ccdb!"); - } + std::string model_data{fetchModelCCDB(pc, getName().c_str())}; // disable telemtry events mEnv.DisableTelemetryEvents(); @@ -89,37 +86,10 @@ void ML::init(o2::framework::ProcessingContext& pc) LOG(info) << "Finalization done"; } -PIDValue ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) -{ - if (isTPC) { - return calculate(trk, input); - } else { - return calculate(trk, input); - } -} - -std::string ML::fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const -{ - auto policyInt = static_cast(mPolicy); - // sanity checks - auto ref = pc.inputs().get(binding); - if (!ref.spec || !ref.payload) { - throw std::runtime_error(fmt::format("A ML model({}) with '{}' as binding does not exist!", PIDPolicyEnum[policyInt], binding)); - } - - // the model is in binary string format - auto model_data = pc.inputs().get(binding); - if (model_data.empty()) { - throw std::runtime_error(fmt::format("Did not get any data for {} model({}) from ccdb!", binding, PIDPolicyEnum[policyInt])); - } - return model_data; -} - -template -PIDValue ML::calculate(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) +float ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& inputTracks, bool isTPCTRD) const { try { - auto input = prepareModelInput(trkTRD, inputTracks); + auto input = prepareModelInput(trk, inputTracks); // create memory mapping to vector above auto inputTensor = Ort::Experimental::Value::CreateTensor(input.data(), input.size(), {static_cast(input.size()) / mInputShapes[0][1], mInputShapes[0][1]}); @@ -135,52 +105,54 @@ PIDValue ML::calculate(const TrackTRD& trkTRD, const o2::globaltracking::RecoCon } } -template -std::vector ML::prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) +std::string ML::fetchModelCCDB(o2::framework::ProcessingContext& pc, const char* binding) const noexcept { - // input is [slope0, slope1, ..., slope5, charge0.0, charge0.1, charge0.2, charge1.0, ..., charge5.2, p] + // sanity checks + auto ref = pc.inputs().get(binding); + if (!ref.spec || !ref.payload) { + throw std::runtime_error(fmt::format("A ML model with '{}' as binding does not exist!", binding)); + } + + // the model is in binary string format + auto model_data = pc.inputs().get(binding); + if (model_data.empty()) { + throw std::runtime_error(fmt::format("Did not get any data for {} model from ccdb!", binding)); + } + return model_data; +} + +std::vector ML::prepareModelInput(const TrackTRD& trkTRD, const o2::globaltracking::RecoContainer& inputTracks) const noexcept +{ + // input is [charge0.0, charge0.1, charge0.2, charge1.0, ..., charge5.2, p0, ..., p5] std::vector in(mInputShapes[0][1]); const auto& trackletsRaw = inputTracks.getTRDTracklets(); - // std::fill(in.begin(), in.end(), 1.f); - auto id = trkTRD.getRefGlobalTrackId(); - in.back() = trkTRD.getP(); - // const auto& trkSeed = [&]() { - // if constexpr (isTPCTRD) { - // return mTracksInTPCTRD[id].getParamOut(); - // } else { - // return mTracksInITSTPCTRD[id].getParamOut(); - // } - // }; - - for (int iLayer = 0; iLayer < NLAYER; ++iLayer) { + auto trk = trdTRD; + for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) { int trkltId = trkTRD.getTrackletIndex(iLayer); if (trkltId < 0) { - /// easy fill with default values e.g. charge=-1., slope=0. - in[iLayer] = 0.f; - in[NLAYER + iLayer * NCHARGES + 0] = -1.f; - in[NLAYER + iLayer * NCHARGES + 1] = -1.f; - in[NLAYER + iLayer * NCHARGES + 2] = -1.f; + // no tracklet attached, fill with default values e.g. charge=-1., + in[iLayer * NCHARGES + 0] = -1.f; + in[iLayer * NCHARGES + 1] = -1.f; + in[iLayer * NCHARGES + 2] = -1.f; + in[18 + iLayer] = -1.f; continue; + } else { + const auto xCalib = input.getTRDCalibratedTracklets()[trkTRD.getTrackletIndex(iLayer)].getX(); + auto bz = o2::base::Propagator::Instance()->getNominalBz(); + const auto tgl = trk.getTgl(); + const auto snp = trk.getSnpAt(o2::math_utils::sector2Angle(HelperMethods::getSector(input.getTRDTracklets()[trkIn.getTrackletIndex(iLayer)].getDetector())), xCalib, bz); + const auto& trklt = trackletsRaw[trkltId]; + const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkTRD, input, snp, tgl); // correct charges + in[iLayer * NCHARGES + 0] = q0; + in[iLayer * NCHARGES + 1] = q1; + in[iLayer * NCHARGES + 2] = q2; + in[18 + iLayer] = trk.getP(); } - - auto trklt = trackletsRaw[trkltId]; - auto slope = trackletsRaw[trkltId].getSlopeBinSigned(); - auto q0 = trackletsRaw[trkltId].getQ0(); - auto q1 = trackletsRaw[trkltId].getQ1(); - auto q2 = trackletsRaw[trkltId].getQ2(); - - // TODO handel padrow crossing e.g. z-row merging - - in[iLayer] = slope; - in[NLAYER + iLayer * 3 + 0] = q0; - in[NLAYER + iLayer * 3 + 1] = q1; - in[NLAYER + iLayer * 3 + 2] = q2; } return in; } -// pretty prints a shape dimension vector std::string ML::printShape(const std::vector& v) const noexcept { std::stringstream ss(""); @@ -191,12 +163,5 @@ std::string ML::printShape(const std::vector& v) const noexcept return ss.str(); } -/// XGBoost export is like this: -/// (label|eprob, 1-eprob). -PIDValue XGB::getELikelihood(const std::vector& tensorData) const noexcept -{ - return tensorData[1].GetTensorData()[1]; -} - } // namespace trd } // namespace o2 diff --git a/Detectors/TRD/pid/src/PIDBase.cxx b/Detectors/TRD/pid/src/PIDBase.cxx index caf6912d28fc5..f0ccd502352c7 100644 --- a/Detectors/TRD/pid/src/PIDBase.cxx +++ b/Detectors/TRD/pid/src/PIDBase.cxx @@ -14,31 +14,72 @@ #include "TRDPID/PIDBase.h" #include "DataFormatsTRD/PID.h" +#include "Framework/Logger.h" + #ifdef TRDPID_WITH_ONNX #include "TRDPID/ML.h" #endif +#include "TRDPID/LQND.h" #include "TRDPID/Dummy.h" -#include "Framework/Logger.h" -#include "fmt/format.h" namespace o2 { namespace trd { -std::unique_ptr getTRDPIDBase(PIDPolicy policy) +std::array PIDBase::getCharges(const Tracklet64& tracklet, const int layer, const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, float snp, float tgl) const noexcept { - auto policyInt = static_cast(policy); - LOG(info) << "Creating PID policy. Loading model " << PIDPolicyEnum[policyInt]; + // Check z-row merging needs to be performed to recover full charge information + if (trk.getIsCrossingNeighbor(layer) && trk.getHasNeighbor()) { // tracklet needs correction + for (const auto& trklt : input.getTRDTracklets()) { // search for nearby tracklet + if (std::abs(tracklet.getPadCol() - trklt.getPadCol()) <= 1 && std::abs(tracklet.getPadRow() - trklt.getPadRow()) == 1) { + if (tracklet.getTrackletWord() == trklt.getTrackletWord()) { // skip original tracklet + continue; + } + + // Add charge information + const auto [aQ0, aQ1, aQ2] = correctCharges(tracklet, snp, tgl); + const auto [bQ0, bQ1, bQ2] = correctCharges(tracklet, snp, tgl); + return {aQ0 + bQ0, aQ1 + bQ1, aQ2 + bQ2}; + } + } + } + + return correctCharges(tracklet, snp, tgl); +} + +std::array PIDBase::correctCharges(const Tracklet64& trklt, float snp, float tgl) const noexcept +{ + auto tphi = snp / std::sqrt((1.f - snp) + (1.f + snp)); + auto trackletLength = std::sqrt(1.f + tphi * tphi + tgl * tgl); + const float correction = mLocalGain->getValue(trklt.getHCID() / 2, trklt.getPadCol(), trklt.getPadRow()) * trackletLength; + return { + trklt.getQ0() / correction, + trklt.getQ1() / correction, + trklt.getQ2() / correction, + }; +} + +std::unique_ptr getTRDPIDPolicy(PIDPolicy policy) +{ + LOG(info) << "Creating PID policy. Loading model " << policy; switch (policy) { -#ifdef TRDPID_WITH_ONNX - case PIDPolicy::Test: - return std::make_unique(PIDPolicy::Test); + case PIDPolicy::LQ1D: + return std::make_unique(PIDPolicy::LQ1D); + case PIDPolicy::LQ2D: + return std::make_unique(PIDPolicy::LQ2D); + case PIDPolicy::LQ3D: + return std::make_unique(PIDPolicy::LQ3D); +#ifdef TRDPID_WITH_ONNX // Add all policies that use ONNX in this ifdef + case PIDPolicy::XGB: + return std::make_unique(PIDPolicy::XGB); + case PIDPolicy::PY: + return std::make_unique(PIDPolicy::PY); #endif case PIDPolicy::Dummy: return std::make_unique(PIDPolicy::Dummy); default: - throw std::invalid_argument(fmt::format("Cannot create this PID policy {}({})", PIDPolicyEnum[policyInt], policyInt)); + return nullptr; } return nullptr; // cannot be reached } diff --git a/Detectors/TRD/pid/src/TRDPIDLinkDef.h b/Detectors/TRD/pid/src/TRDPIDLinkDef.h index 06996c711f3a3..8003e92d1c676 100644 --- a/Detectors/TRD/pid/src/TRDPIDLinkDef.h +++ b/Detectors/TRD/pid/src/TRDPIDLinkDef.h @@ -18,7 +18,14 @@ #pragma link C++ class o2::trd::PIDBase + ; #pragma link C++ class o2::trd::ML + ; #pragma link C++ class o2::trd::Dummy + ; +#pragma link C++ class o2::trd::LQND < 1> + ; +#pragma link C++ class o2::trd::LQND < 2> + ; +#pragma link C++ class o2::trd::LQND < 3> + ; +#pragma link C++ class o2::trd::detail::LUT < 1> + ; +#pragma link C++ class o2::trd::detail::LUT < 2> + ; +#pragma link C++ class o2::trd::detail::LUT < 3> + ; #pragma link C++ class o2::trd::TRDPIDParams + ; #pragma link C++ class o2::conf::ConfigurableParamHelper < o2::trd::TRDPIDParams> + ; +#pragma link C++ class std::vector < TGraph> + ; #endif diff --git a/Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h b/Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h index 25a0ccf6dce35..9ae4419a2e498 100644 --- a/Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h +++ b/Detectors/TRD/pid/src/TRDPIDNoMLLinkDef.h @@ -17,7 +17,14 @@ #pragma link C++ class o2::trd::PIDBase + ; #pragma link C++ class o2::trd::Dummy + ; +#pragma link C++ class o2::trd::LQND < 1> + ; +#pragma link C++ class o2::trd::LQND < 2> + ; +#pragma link C++ class o2::trd::LQND < 3> + ; +#pragma link C++ class o2::trd::detail::LUT < 1> + ; +#pragma link C++ class o2::trd::detail::LUT < 2> + ; +#pragma link C++ class o2::trd::detail::LUT < 3> + ; #pragma link C++ class o2::trd::TRDPIDParams + ; #pragma link C++ class o2::conf::ConfigurableParamHelper < o2::trd::TRDPIDParams> + ; +#pragma link C++ class std::vector < TGraph> + ; #endif diff --git a/Detectors/TRD/workflow/src/TRDGlobalTrackingSpec.cxx b/Detectors/TRD/workflow/src/TRDGlobalTrackingSpec.cxx index 20cdd4c12ac0c..87c7037475236 100644 --- a/Detectors/TRD/workflow/src/TRDGlobalTrackingSpec.cxx +++ b/Detectors/TRD/workflow/src/TRDGlobalTrackingSpec.cxx @@ -121,8 +121,9 @@ void TRDGlobalTracking::updateTimeDependentParams(ProcessingContext& pc) /// Get the PID model if requested if (mWithPID) { - mBase = getTRDPIDBase(mPolicy); + mBase = getTRDPIDPolicy(mPolicy); mBase->init(pc); + mBase->setLocalGainFactors(pc.inputs().get("localgainfactors").get()); } } @@ -837,21 +838,32 @@ DataProcessorSpec getTRDGlobalTrackingSpec(bool useMC, GTrackID::mask_t src, boo // Request PID policy data if (withPID) { + // request policy switch (policy) { case PIDPolicy::LQ1D: - // inputs.emplace_back("LQ1D", "TRD", "LQ1D", 0, Lifetime::Condition, ccdbParamSpec("TRD/ppPID/LQ1D")); + inputs.emplace_back("lq1dlut", "TRD", "LQ1D", 0, Lifetime::Condition, ccdbParamSpec("TRD/PID/LQ1D")); + break; + case PIDPolicy::LQ2D: + inputs.emplace_back("lq2dlut", "TRD", "LQ2D", 0, Lifetime::Condition, ccdbParamSpec("TRD/PID/LQ2D")); break; case PIDPolicy::LQ3D: - // inputs.emplace_back("LQ3D", "TRD", "LQ3D", 0, Lifetime::Condition, ccdbParamSpec("TRD/ppPID/LQ3D")); + inputs.emplace_back("lq3dlut", "TRD", "LQ3D", 0, Lifetime::Condition, ccdbParamSpec("TRD/PID/LQ3D")); + break; +#ifdef TRDPID_WITH_ONNX + case PIDPolicy::XGB: + inputs.emplace_back("xgb", "TRD", "XGB", 0, Lifetime::Condition, ccdbParamSpec("TRD_test/PID_new/xgb")); break; - case PIDPolicy::Test: - inputs.emplace_back("mlTest", "TRD", "MLTEST", 0, Lifetime::Condition, ccdbParamSpec("TRD_test/pid/xgb1")); + case PIDPolicy::PY: + inputs.emplace_back("py", "TRD", "py", 0, Lifetime::Condition, ccdbParamSpec("TRD_test/PID_new/py")); break; +#endif case PIDPolicy::Dummy: break; default: throw std::runtime_error("Unable to load requested PID policy data!"); } + // request calibration data + inputs.emplace_back("localgainfactors", "TRD", "LOCALGAINFACTORS", 0, Lifetime::Condition, ccdbParamSpec("TRD/Calib/LocalGainFactor")); } if (GTrackID::includesSource(GTrackID::Source::ITSTPC, src)) {