Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TRD: PID: fixups #11664

Merged
merged 1 commit into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <array>
#include <unordered_map>
#include <string>
#include <iostream>

namespace o2
{
Expand All @@ -28,46 +29,71 @@ namespace trd
enum class PIDPolicy : unsigned int {
// Classical Algorithms
LQ1D = 0, ///< 1-Dimensional Likelihood model
LQ2D, ///< 2-Dimensional Likelihood model
LQ3D, ///< 3-Dimensional Likelihood model

#ifdef TRDPID_WITH_ONNX
// ML models
XGB, ///< XGBOOST
PY, ///< Pytorch
#endif

// Do not add anything after this!
NMODELS, ///< Count of all models
Test, ///< Load object for testing
Dummy, ///< Dummy object outputting -1.f
DEFAULT = Dummy, ///< The default option
};

inline std::ostream& operator<<(std::ostream& os, const PIDPolicy& policy)
{
std::string name;
switch (policy) {
case PIDPolicy::LQ1D:
name = "LQ1D";
break;
case PIDPolicy::LQ2D:
name = "LQ2D";
break;
case PIDPolicy::LQ3D:
name = "LQ3D";
break;
#ifdef TRDPID_WITH_ONNX
case PIDPolicy::XGB:
name = "XGBoost";
break;
case PIDPolicy::PY:
name = "PyTorch";
break;
#endif
case PIDPolicy::Dummy:
name = "Dummy";
break;
default:
name = "Default";
}
os << name;
return os;
}

/// Transform PID policy from string to enum.
static const std::unordered_map<std::string, PIDPolicy> PIDPolicyString{
// Classical Algorithms
{"LQ1D", PIDPolicy::LQ1D},
{"LQ2D", PIDPolicy::LQ2D},
{"LQ3D", PIDPolicy::LQ3D},

#ifdef TRDPID_WITH_ONNX
// ML models
{"XGB", PIDPolicy::XGB},
{"PY", PIDPolicy::PY},
#endif

// General
{"TEST", PIDPolicy::Test},
{"DUMMY", PIDPolicy::Dummy},
// Default
{"default", PIDPolicy::DEFAULT},
};

/// Transform PID policy from string to enum.
static const char* PIDPolicyEnum[] = {
"LQ1D",
"LQ3D",
"XGBoost",
"NMODELS",
"Test",
"Dummy",
"default(=TODO)"};

using PIDValue = float;

} // namespace trd
} // namespace o2

Expand Down
2 changes: 2 additions & 0 deletions Detectors/TRD/pid/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ if(ONNXRuntime_FOUND)
HEADERS include/TRDPID/PIDBase.h
include/TRDPID/PIDParameters.h
include/TRDPID/ML.h
include/TRDPID/LQND.h
include/TRDPID/Dummy.h)
else()
o2_target_root_dictionary(TRDPID
HEADERS include/TRDPID/PIDBase.h
include/TRDPID/PIDParameters.h
include/TRDPID/Dummy.h
include/TRDPID/LQND.h
LINKDEF src/TRDPIDNoMLLinkDef.h)
endif()

Expand Down
64 changes: 64 additions & 0 deletions Detectors/TRD/pid/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Particle Identification with TRD
## Usage
Activate PID during tracking with the '--with-pid' flag.

o2-trd-global-tracking --with-pid --policy ML

Specify a which algorithm (called policy) should be use.
Implemented are the following:

- LQ1D
- LQ2D
- LQ3D
- ML (every model, which is exported to the ONNX format):
- XGB (XGBoost model)
- NN (Pytorch model)
- Dummy (returns only -1)
- Test (one of the above)
- Default (one of the above, gets picked if '--policy' is unspecified)

## Implementation details
### Tracking workflow
Every TRDTrack gets a PID value set (mSignal), which then gets propergated to the AO2D writer.

### Basic Interface
The base interface for PID is defined in [here](include/TRDPID/PIDBase.h).
The 'init' function is such that each policy can specify what if anything it needs from the CCDB.
For the 'process' each policy defines how a TRDTrack gets assigned a PID value.
Additionally, the base class implements how to get the _corrected charges_ from the tracklets.
_Corrected charges_ means z-row merged and calibrated charges.

### Classical Likelihood
The classical LQND policies ([here](include/TRDPID/LQND.h)) require an array of lookup tables (LUTs) from the ccdb.
$N$ stands for the dimension.
Then the electron likelihood for layer $i$ is defined as this:

$$L_i(e|Q_i)=\frac{P(Q_i|e)}{P(Q_i|e)+P(Q_i|\pi)}$$

From the charge $Q_i$ the LUTs give the corresponding $L_i$.
The _combined electron likelihood_ is obtained by this formula:

$$L(e|Q)=\frac{\prod_i L_i(e|Q_i)}{\prod_i L_i(e|Q_i) + \prod_i L_i(\pi|Q_i)}$$

where $L_i(\pi|Q_i)=1-L_i(e|Q_i)$.


Extension to higher dimensions is easy each tracklet has charges $Q_j$ which cover the integral of the pulse height spectrum in different slice ($j\in [0,1,2]$).
In our case $Q0$ covers the pulse height peak, $Q1$ the Transition Radiation peak and $Q2$ the plateau.
For each charge $j$ a LUT is available which gives the likelihood $L^e_j$.
For each layer $i$ the likelihood is then:

$$L_i(e|Q)=\frac{\prod_j L_{i,j}(e|Q_j)}{\prod_j L_{i,j}(e|Q_j) + \prod_j L_{i,j}(\pi|Q_j)}$$

The combined electron likelihood is then:

$$L(e|Q)=\frac{\prod_{i,j} L_{i,j}(e|Q_j)}{\prod_{i,j} L_{i,j}(e|Q_j) + \prod_{i,j} L_{i,j}(\pi|Q_j)}$$


### Machine Learning
The ML policies ([here](include/TRDPID/ML.h)) are uploaded to the CCDB in the ONNX file format (most python machine learning libraries support this standardized format).
In O2 we leverage the ONNXRuntime to use these formats and calculate a PID value.
The models can thus be trained in python which is very convenient.
The code should take care of most of the annoying stuff.
Policies just have to specify how to get the electron likelihood from the ONNXRuntime output (each python library varies in that somewhat).
The 'prepareModelInput' prepares the TRDTracks as input.
4 changes: 2 additions & 2 deletions Detectors/TRD/pid/include/TRDPID/Dummy.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ class Dummy final : public PIDBase
using PIDBase::PIDBase;

public:
~Dummy() final = default;
~Dummy() = default;

/// Do absolutely nothing.
void init(o2::framework::ProcessingContext& pc) final{};

/// Everything below 0.f indicates nothing available.
PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) final
float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final
{
return -1.f;
};
Expand Down
160 changes: 160 additions & 0 deletions Detectors/TRD/pid/include/TRDPID/LQND.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
// All rights not expressly granted are reserved.
//
// This software is distributed under the terms of the GNU General Public
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file LQND.h
/// \brief This file provides the interface for loglikehood policies
/// \author Felix Schlepper

#ifndef O2_TRD_LQND_H
#define O2_TRD_LQND_H

#include "TGraph.h"
#include "TRDPID/PIDBase.h"
#include "DataFormatsTRD/PID.h"
#include "DataFormatsTRD/Constants.h"
#include "DataFormatsTRD/HelperMethods.h"
#include "Framework/ProcessingContext.h"
#include "Framework/InputRecord.h"
#include "DataFormatsTRD/CalibratedTracklet.h"
#include "DetectorsBase/Propagator.h"
#include "Framework/Logger.h"
#include "ReconstructionDataFormats/TrackParametrization.h"

#include <memory>
#include <vector>
#include <array>
#include <string>
#include <numeric>

namespace o2
{
namespace trd
{
namespace detail
{
/// Lookup Table class for ccdb upload
template <int nDim>
class LUT
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the way to have the LUT as a class with a configurable number of dimensions. Would it be possible to make it accept not only 1 and 3D, but also things like 2 or 7D for example? I think we had such LUTs in the past in Run 2

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7D seems nonsensical since we only get 3 windows from the FEE anyways, this was different in Run 2 were we would get 7 windows, I think.
How would you combine the windows in the 2D case (Q0+Q2 and Q1, seems the most resonable to me?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok 7D is probably indeed not needed, but 2D I could imagine (although no idea which windows would be best at the moment). Could you not get rid of the whole if constexpr block of the class? For 1D it will anyhow be called with iDim = 0 only and so you could return mLUTS[index + iDim] always, no?

{
public:
LUT() = default;
LUT(std::vector<float> p, std::vector<TGraph> l) : mIntervalsP(p), mLUTs(l) {}

//
const TGraph& get(float p, bool isNegative, int iDim = 0) const
{
auto upper = std::upper_bound(mIntervalsP.begin(), mIntervalsP.end(), p);
if (upper == mIntervalsP.end()) {
// outside of momentum intervals, should not happen
return mLUTs[0];
}
auto index = std::distance(mIntervalsP.begin(), upper);
index += (isNegative) ? 0 : mIntervalsP.size() * nDim;
return mLUTs[index + iDim];
}

private:
std::vector<float> mIntervalsP; ///< half-open interval upper bounds starting at 0, e.g., for {1.0,2.0,...} is (-inf,1.0], (1.0,2.0], (2.0, ...)
std::vector<TGraph> mLUTs; ///< corresponding likelihood lookup tables

ClassDefNV(LUT, 1);
};
} // namespace detail

/// This is the ML Base class which defines the interface all machine learning
/// models.
template <int nDim>
class LQND : public PIDBase
{
static_assert(nDim == 1 || nDim == 2 || nDim == 3, "Likelihood only for 1/2/3 dimension");
using PIDBase::PIDBase;

public:
~LQND() = default;

void init(o2::framework::ProcessingContext& pc) final
{
// retrieve lookup tables (LUTs) from ccdb
mLUTs = *(pc.inputs().get<detail::LUT<nDim>*>(Form("lq%ddlut", nDim)));
}

float process(const TrackTRD& trkIn, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final
{
const auto& trkSeed = isTPCTRD ? input.getTPCTracks()[trkIn.getRefGlobalTrackId()].getParamOut() : input.getTPCITSTracks()[trkIn.getRefGlobalTrackId()].getParamOut(); // seeding track
auto trk = trkSeed;

const auto isNegative = std::signbit(trkSeed.getSign()); // positive and negative charged particles are treated differently since ExB effects the charge distributions
const auto& trackletsRaw = input.getTRDTracklets();
float lei0{1.f}, lei1{1.f}, lei2{1.f}, lpi0{1.f}, lpi1{1.f}, lpi2{1.f}; // likelihood per layer
for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) {
int trkltId = trkIn.getTrackletIndex(iLayer);
if (trkltId < 0) { // no tracklet attached
continue;
}
const auto xCalib = input.getTRDCalibratedTracklets()[trkIn.getTrackletIndex(iLayer)].getX();
auto bz = o2::base::Propagator::Instance()->getNominalBz();
const auto tgl = trk.getTgl();
const auto snp = trk.getSnpAt(o2::math_utils::sector2Angle(HelperMethods::getSector(input.getTRDTracklets()[trkIn.getTrackletIndex(iLayer)].getDetector())), xCalib, bz);
const auto& trklt = trackletsRaw[trkltId];
const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkIn, input, snp, tgl); // correct charges
if constexpr (nDim == 1) {
auto lut = mLUTs.get(trk.getP(), isNegative);
auto ll1{1.f};
ll1 = lut.Eval(q0 + q1 + q2);
lei0 *= ll1;
lpi0 *= (1.f - ll1);
} else if (nDim == 2) {
auto lut1 = mLUTs.get(trk.getP(), isNegative, 0);
auto lut2 = mLUTs.get(trk.getP(), isNegative, 1);
auto ll1{1.f};
auto ll2{1.f};
ll1 = lut1.Eval(q0 + q2);
ll2 = lut2.Eval(q1);
lei0 *= ll1;
lei1 *= ll2;
lpi0 *= (1.f - ll1);
lpi1 *= (1.f - ll2);
} else {
auto lut1 = mLUTs.get(trk.getP(), isNegative, 0);
auto lut2 = mLUTs.get(trk.getP(), isNegative, 1);
auto lut3 = mLUTs.get(trk.getP(), isNegative, 2);
auto ll1{1.f};
auto ll2{1.f};
auto ll3{1.f};
ll1 = lut1.Eval(q0);
ll2 = lut2.Eval(q1);
ll3 = lut3.Eval(q2);
lei0 *= ll1;
lei1 *= ll2;
lei2 *= ll3;
lpi0 *= (1.f - ll1);
lpi1 *= (1.f - ll2);
lpi2 *= (1.f - ll3);
}
}

return (lei0 * lei1 * lei2) / (lei0 * lei1 * lei2 + lpi0 * lpi1 * lpi2); // combined likelihood
}

private:
detail::LUT<nDim> mLUTs; ///< likelihood lookup tables

ClassDefNV(LQND, 1);
};

using LQ1D = LQND<1>;
using LQ2D = LQND<2>;
using LQ3D = LQND<3>;

} // namespace trd
} // namespace o2

#endif
Loading