Skip to content

Commit

Permalink
TRD: PID: fixups
Browse files Browse the repository at this point in the history
This commit renames the factory function to better reflect its purpose.
Additionally, z-row merging and charge correction have been added to the codebase to improve functionality.
Pytorch policy and LQND policy have been added as new policies.
The README has also been added to provide additional explanation of the code.
As part of this update, the pid policy map and pidvalue alias have been
removed.
A print overload has been added for the policy enum to improve readability.
Various minor fixes have also been made to improve overall code quality.
Also included are various changes to the class layout and the ccdb object for LUTs.

Signed-off-by: Felix Schlepper <[email protected]>
  • Loading branch information
f3sch committed Jul 18, 2023
1 parent d077610 commit 3b6f8a7
Show file tree
Hide file tree
Showing 17 changed files with 598 additions and 213 deletions.
49 changes: 35 additions & 14 deletions DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <array>
#include <unordered_map>
#include <string>
#include <iostream>

namespace o2
{
Expand All @@ -30,44 +31,64 @@ enum class PIDPolicy : unsigned int {
LQ1D = 0, ///< 1-Dimensional Likelihood model
LQ3D, ///< 3-Dimensional Likelihood model

#ifdef TRDPID_WITH_ONNX
// ML models
XGB, ///< XGBOOST
PY, ///< Pytorch
#endif

// Do not add anything after this!
NMODELS, ///< Count of all models
Test, ///< Load object for testing
Dummy, ///< Dummy object outputting -1.f
DEFAULT = Dummy, ///< The default option
};

inline std::ostream& operator<<(std::ostream& os, const PIDPolicy& policy)
{
std::string name;
switch (policy) {
case PIDPolicy::LQ1D:
name = "LQ1D";
break;
case PIDPolicy::LQ3D:
name = "LQ3D";
break;
#ifdef TRDPID_WITH_ONNX
case PIDPolicy::XGB:
name = "XGBoost";
break;
case PIDPolicy::PY:
name = "PyTorch";
break;
#endif
case PIDPolicy::Dummy:
name = "Dummy";
break;
default:
name = "Default";
}
os << name;
return os;
}

/// Transform PID policy from string to enum.
static const std::unordered_map<std::string, PIDPolicy> PIDPolicyString{
// Classical Algorithms
{"LQ1D", PIDPolicy::LQ1D},
{"LQ3D", PIDPolicy::LQ3D},

#ifdef TRDPID_WITH_ONNX
// ML models
{"XGB", PIDPolicy::XGB},
{"PY", PIDPolicy::PY},
#endif

// General
{"TEST", PIDPolicy::Test},
{"DUMMY", PIDPolicy::Dummy},
// Default
{"default", PIDPolicy::DEFAULT},
};

/// Transform PID policy from string to enum.
static const char* PIDPolicyEnum[] = {
"LQ1D",
"LQ3D",
"XGBoost",
"NMODELS",
"Test",
"Dummy",
"default(=TODO)"};

using PIDValue = float;

} // namespace trd
} // namespace o2

Expand Down
2 changes: 2 additions & 0 deletions Detectors/TRD/pid/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ if(ONNXRuntime_FOUND)
HEADERS include/TRDPID/PIDBase.h
include/TRDPID/PIDParameters.h
include/TRDPID/ML.h
include/TRDPID/LQND.h
include/TRDPID/Dummy.h)
else()
o2_target_root_dictionary(TRDPID
HEADERS include/TRDPID/PIDBase.h
include/TRDPID/PIDParameters.h
include/TRDPID/Dummy.h
include/TRDPID/LQND.h
LINKDEF src/TRDPIDNoMLLinkDef.h)
endif()

Expand Down
63 changes: 63 additions & 0 deletions Detectors/TRD/pid/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Particle Identification with TRD
## Usage
Activate PID during tracking with the '--with-pid' flag.

o2-trd-global-tracking --with-pid --policy ML

Specify a which algorithm (called policy) should be use.
Implemented are the following:

- LQ1D
- LQ3D
- ML (every model, which is exported to the ONNX format):
- XGB (XGBoost model)
- NN (Pytorch model)
- Dummy (returns only -1)
- Test (one of the above)
- Default (one of the above, gets picked if '--policy' is unspecified)

## Implementation details
### Tracking workflow
Every TRDTrack gets a PID value set (mSignal), which then gets propergated to the AO2D writer.

### Basic Interface
The base interface for PID is defined in [here](include/TRDPID/PIDBase.h).
The 'init' function is such that each policy can specify what if anything it needs from the CCDB.
For the 'process' each policy defines how a TRDTrack gets assigned a PID value.
Additionally, the base class implements how to get the _corrected charges_ from the tracklets.
_Corrected charges_ means z-row merged and calibrated charges.

### Classical Likelihood
The classical LQND policies ([here](include/TRDPID/LQND.h)) require an array of lookup tables (LUTs) from the ccdb.
$N$ stands for the dimension.
Then the electron likelihood for layer $i$ is defined as this:

$$L_i(e|Q_i)=\frac{P(Q_i|e)}{P(Q_i|e)+P(Q_i|\pi)}$$

From the charge $Q_i$ the LUTs give the corresponding $L_i$.
The _combined electron likelihood_ is obtained by this formula:

$$L(e|Q)=\frac{\prod_i L_i(e|Q_i)}{\prod_i L_i(e|Q_i) + \prod_i L_i(\pi|Q_i)}$$

where $L_i(\pi|Q_i)=1-L_i(e|Q_i)$.


Extension to higher dimensions is easy each tracklet has charges $Q_j$ which cover the integral of the pulse height spectrum in different slice ($j\in [0,1,2]$).
In our case $Q0$ covers the pulse height peak, $Q1$ the Transition Radiation peak and $Q2$ the plateau.
For each charge $j$ a LUT is available which gives the likelihood $L^e_j$.
For each layer $i$ the likelihood is then:

$$L_i(e|Q)=\frac{\prod_j L_{i,j}(e|Q_j)}{\prod_j L_{i,j}(e|Q_j) + \prod_j L_{i,j}(\pi|Q_j)}$$

The combined electron likelihood is then:

$$L(e|Q)=\frac{\prod_{i,j} L_{i,j}(e|Q_j)}{\prod_{i,j} L_{i,j}(e|Q_j) + \prod_{i,j} L_{i,j}(\pi|Q_j)}$$


### Machine Learning
The ML policies ([here](include/TRDPID/ML.h)) are uploaded to the CCDB in the ONNX file format (most python machine learning libraries support this standardized format).
In O2 we leverage the ONNXRuntime to use these formats and calculate a PID value.
The models can thus be trained in python which is very convenient.
The code should take care of most of the annoying stuff.
Policies just have to specify how to get the electron likelihood from the ONNXRuntime output (each python library varies in that somewhat).
The 'prepareModelInput' prepares the TRDTracks as input.
4 changes: 2 additions & 2 deletions Detectors/TRD/pid/include/TRDPID/Dummy.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ class Dummy final : public PIDBase
using PIDBase::PIDBase;

public:
~Dummy() final = default;
~Dummy() = default;

/// Do absolutely nothing.
void init(o2::framework::ProcessingContext& pc) final{};

/// Everything below 0.f indicates nothing available.
PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) final
float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final
{
return -1.f;
};
Expand Down
159 changes: 159 additions & 0 deletions Detectors/TRD/pid/include/TRDPID/LQND.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
// All rights not expressly granted are reserved.
//
// This software is distributed under the terms of the GNU General Public
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file LQND.h
/// \brief This file provides the interface for loglikehood policies
/// \author Felix Schlepper

#ifndef O2_TRD_LQND_H
#define O2_TRD_LQND_H

#include "TGraph.h"
#include "TRDPID/PIDBase.h"
#include "DataFormatsTRD/PID.h"
#include "DataFormatsTRD/Constants.h"
#include "Framework/ProcessingContext.h"
#include "Framework/InputRecord.h"
#include "DataFormatsTRD/CalibratedTracklet.h"
#include "DetectorsBase/Propagator.h"
#include "Framework/Logger.h"

#include <memory>
#include <vector>
#include <array>
#include <string>
#include <numeric>

namespace o2
{
namespace trd
{
namespace detail
{
/// Lookup Table class for ccdb upload
template <int nDim>
class LUT
{
public:
LUT() = default;
LUT(std::vector<float> p, std::vector<TGraph> l) : mIntervalsP(p), mLUTs(l) {}

//
const TGraph& get(float p, bool isNegative, int iDim = 0) const
{
auto upper = std::upper_bound(mIntervalsP.begin(), mIntervalsP.end(), p);
if (upper == mIntervalsP.end()) {
// outside of momentum intervals, should not happen
return mLUTs[0];
}
auto index = std::distance(mIntervalsP.begin(), upper);
index += (isNegative) ? 0 : mIntervalsP.size() * nDim;
if constexpr (nDim == 1) {
return mLUTs[index];
} else {
if (iDim == 0) {
return mLUTs[index + 0];
} else if (iDim == 1) {
return mLUTs[index + 1];
} else {
return mLUTs[index + 2];
}
}
}

private:
std::vector<float> mIntervalsP; ///< half-open interval upper bounds starting at 0, e.g., for {1.0,2.0,...} is (-inf,1.0], (1.0,2.0], (2.0, ...)
std::vector<TGraph> mLUTs; ///< corresponding likelihood lookup tables

ClassDefNV(LUT, 1);
};
} // namespace detail

/// This is the ML Base class which defines the interface all machine learning
/// models.
template <int nDim>
class LQND : public PIDBase
{
static_assert(nDim == 1 || nDim == 3, "Likelihood only for 1/3 dimension");
using PIDBase::PIDBase;

public:
~LQND() = default;

void init(o2::framework::ProcessingContext& pc) final
{
// retrieve lookup tables (LUTs) from ccdb
mLUTs = *(pc.inputs().get<detail::LUT<nDim>*>(Form("lq%ddlut", nDim)));
}

float process(const TrackTRD& trkIn, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final
{
const auto& trkSeed = isTPCTRD ? input.getTPCTracks()[trkIn.getRefGlobalTrackId()].getParamOut() : input.getTPCITSTracks()[trkIn.getRefGlobalTrackId()].getParamOut(); // seeding track
auto trk = trkSeed;

const auto isNegative = std::signbit(trkSeed.getSign()); // positive and negative charged particles are treated differently since ExB effects the charge distributions
const auto& trackletsRaw = input.getTRDTracklets();
float lei0{1.f}, lei1{1.f}, lei2{1.f}, lpi0{1.f}, lpi1{1.f}, lpi2{1.f}; // likelihood per layer
for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) {
int trkltId = trkIn.getTrackletIndex(iLayer);
if (trkltId < 0) { // no tracklet attached
continue;
}
const auto xCalib = input.getTRDCalibratedTracklets()[trkIn.getTrackletIndex(iLayer)].getX();
if (!o2::base::Propagator::Instance()->PropagateToXBxByBz(trk, xCalib, o2::base::Propagator::MAX_SIN_PHI, o2::base::Propagator::MAX_STEP, o2::base::Propagator::MatCorrType::USEMatCorrNONE)) {
LOGF(debug, "Track propagation failed in layer %i (pt=%f, xTrk=%f, xToGo=%f)", iLayer, trk.getPt(), trk.getX(), xCalib);
continue;
}
const auto snp = trk.getSnp();
const auto tgl = trk.getTgl();
const auto& trklt = trackletsRaw[trkltId];
const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkIn, input, snp, tgl); // correct charges
if constexpr (nDim == 1) {
auto lut = mLUTs.get(trk.getP(), isNegative);
auto ll1{1.f};
ll1 = lut.Eval(q0 + q1 + q2);
lei0 *= ll1;
lpi0 *= (1.f - ll1);
} else {
auto lut1 = mLUTs.get(trk.getP(), isNegative, 0);
auto lut2 = mLUTs.get(trk.getP(), isNegative, 1);
auto lut3 = mLUTs.get(trk.getP(), isNegative, 2);
auto ll1{1.f};
auto ll2{1.f};
auto ll3{1.f};
ll1 = lut1.Eval(q0);
ll2 = lut2.Eval(q1);
ll3 = lut3.Eval(q2);
lei0 *= ll1;
lei1 *= ll2;
lei2 *= ll3;
lpi0 *= (1.f - ll1);
lpi1 *= (1.f - ll2);
lpi2 *= (1.f - ll3);
}
}

return (lei0 * lei1 * lei2) / (lei0 * lei1 * lei2 + lpi0 * lpi1 * lpi2); // combined likelihood
}

private:
detail::LUT<nDim> mLUTs; ///< likelihood lookup tables

ClassDefNV(LQND, 1);
};

using LQ1D = LQND<1>;
using LQ3D = LQND<3>;

} // namespace trd
} // namespace o2

#endif
Loading

0 comments on commit 3b6f8a7

Please sign in to comment.