-
Notifications
You must be signed in to change notification settings - Fork 441
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit renames the factory function to better reflect its purpose. Additionally, z-row merging and charge correction have been added to the codebase to improve functionality. Pytorch policy and LQND policy have been added as new policies. The README has also been added to provide additional explanation of the code. As part of this update, the pid policy map and pidvalue alias have been removed. A print overload has been added for the policy enum to improve readability. Various minor fixes have also been made to improve overall code quality. Also included are various changes to the class layout and the ccdb object for LUTs. Signed-off-by: Felix Schlepper <[email protected]>
- Loading branch information
Showing
17 changed files
with
599 additions
and
212 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Particle Identification with TRD | ||
## Usage | ||
Activate PID during tracking with the '--with-pid' flag. | ||
|
||
o2-trd-global-tracking --with-pid --policy ML | ||
|
||
Specify a which algorithm (called policy) should be use. | ||
Implemented are the following: | ||
|
||
- LQ1D | ||
- LQ3D | ||
- ML (every model, which is exported to the ONNX format): | ||
- XGB (XGBoost model) | ||
- NN (Pytorch model) | ||
- Dummy (returns only -1) | ||
- Test (one of the above) | ||
- Default (one of the above, gets picked if '--policy' is unspecified) | ||
|
||
## Implementation details | ||
### Tracking workflow | ||
Every TRDTrack gets a PID value set (mSignal), which then gets propergated to the AO2D writer. | ||
|
||
### Basic Interface | ||
The base interface for PID is defined in [here](include/TRDPID/PIDBase.h). | ||
The 'init' function is such that each policy can specify what if anything it needs from the CCDB. | ||
For the 'process' each policy defines how a TRDTrack gets assigned a PID value. | ||
Additionally, the base class implements how to get the _corrected charges_ from the tracklets. | ||
_Corrected charges_ means z-row merged and calibrated charges. | ||
|
||
### Classical Likelihood | ||
The classical LQND policies ([here](include/TRDPID/LQND.h)) require an array of lookup tables (LUTs) from the ccdb. | ||
$N$ stands for the dimension. | ||
Then the electron likelihood for layer $i$ is defined as this: | ||
|
||
$$L_i(e|Q_i)=\frac{P(Q_i|e)}{P(Q_i|e)+P(Q_i|\pi)}$$ | ||
|
||
From the charge $Q_i$ the LUTs give the corresponding $L_i$. | ||
The _combined electron likelihood_ is obtained by this formula: | ||
|
||
$$L(e|Q)=\frac{\prod_i L_i(e|Q_i)}{\prod_i L_i(e|Q_i) + \prod_i L_i(\pi|Q_i)}$$ | ||
|
||
where $L_i(\pi|Q_i)=1-L_i(e|Q_i)$. | ||
|
||
|
||
Extension to higher dimensions is easy each tracklet has charges $Q_j$ which cover the integral of the pulse height spectrum in different slice ($j\in [0,1,2]$). | ||
In our case $Q0$ covers the pulse height peak, $Q1$ the Transition Radiation peak and $Q2$ the plateau. | ||
For each charge $j$ a LUT is available which gives the likelihood $L^e_j$. | ||
For each layer $i$ the likelihood is then: | ||
|
||
$$L_i(e|Q)=\frac{\prod_j L_{i,j}(e|Q_j)}{\prod_j L_{i,j}(e|Q_j) + \prod_j L_{i,j}(\pi|Q_j)}$$ | ||
|
||
The combined electron likelihood is then: | ||
|
||
$$L(e|Q)=\frac{\prod_{i,j} L_{i,j}(e|Q_j)}{\prod_{i,j} L_{i,j}(e|Q_j) + \prod_{i,j} L_{i,j}(\pi|Q_j)}$$ | ||
|
||
|
||
### Machine Learning | ||
The ML policies ([here](include/TRDPID/ML.h)) are uploaded to the CCDB in the ONNX file format (most python machine learning libraries support this standardized format). | ||
In O2 we leverage the ONNXRuntime to use these formats and calculate a PID value. | ||
The models can thus be trained in python which is very convenient. | ||
The code should take care of most of the annoying stuff. | ||
Policies just have to specify how to get the electron likelihood from the ONNXRuntime output (each python library varies in that somewhat). | ||
The 'prepareModelInput' prepares the TRDTracks as input. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
// Copyright 2019-2020 CERN and copyright holders of ALICE O2. | ||
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. | ||
// All rights not expressly granted are reserved. | ||
// | ||
// This software is distributed under the terms of the GNU General Public | ||
// License v3 (GPL Version 3), copied verbatim in the file "COPYING". | ||
// | ||
// In applying this license CERN does not waive the privileges and immunities | ||
// granted to it by virtue of its status as an Intergovernmental Organization | ||
// or submit itself to any jurisdiction. | ||
|
||
/// \file LQND.h | ||
/// \brief This file provides the interface for loglikehood policies | ||
/// \author Felix Schlepper | ||
|
||
#ifndef O2_TRD_LQND_H | ||
#define O2_TRD_LQND_H | ||
|
||
#include "TGraph.h" | ||
#include "TRDPID/PIDBase.h" | ||
#include "DataFormatsTRD/PID.h" | ||
#include "DataFormatsTRD/Constants.h" | ||
#include "Framework/ProcessingContext.h" | ||
#include "Framework/InputRecord.h" | ||
#include "DataFormatsTRD/CalibratedTracklet.h" | ||
#include "DetectorsBase/Propagator.h" | ||
#include "Framework/Logger.h" | ||
|
||
#include <memory> | ||
#include <vector> | ||
#include <array> | ||
#include <string> | ||
#include <numeric> | ||
|
||
namespace o2 | ||
{ | ||
namespace trd | ||
{ | ||
namespace detail | ||
{ | ||
/// Lookup Table class for ccdb upload | ||
template <int nDim> | ||
class LUT | ||
{ | ||
public: | ||
LUT() = default; | ||
LUT(std::vector<float> p, std::vector<TGraph> l) : mIntervalsP(p), mLUTs(l) {} | ||
|
||
// | ||
const TGraph& get(float p, bool isNegative, int iDim = 0) const | ||
{ | ||
auto upper = std::upper_bound(mIntervalsP.begin(), mIntervalsP.end(), p); | ||
if (upper == mIntervalsP.end()) { | ||
// outside of momentum intervals, should not happen | ||
return mLUTs[0]; | ||
} | ||
auto index = std::distance(mIntervalsP.begin(), upper); | ||
index += (isNegative) ? 0 : mIntervalsP.size() * nDim; | ||
if constexpr (nDim == 1) { | ||
return mLUTs[index]; | ||
} else { | ||
if (iDim == 0) { | ||
return mLUTs[index + 0]; | ||
} else if (iDim == 1) { | ||
return mLUTs[index + 1]; | ||
} else { | ||
return mLUTs[index + 2]; | ||
} | ||
} | ||
} | ||
|
||
private: | ||
std::vector<float> mIntervalsP; ///< half-open interval upper bounds starting at 0, e.g., for {1.0,2.0,...} is (-inf,1.0], (1.0,2.0], (2.0, ...) | ||
std::vector<TGraph> mLUTs; ///< corresponding likelihood lookup tables | ||
|
||
ClassDefNV(LUT, 1); | ||
}; | ||
} // namespace detail | ||
|
||
/// This is the ML Base class which defines the interface all machine learning | ||
/// models. | ||
template <int nDim> | ||
class LQND : public PIDBase | ||
{ | ||
static_assert(nDim == 1 || nDim == 3, "Likelihood only for 1/3 dimension"); | ||
using PIDBase::PIDBase; | ||
|
||
public: | ||
~LQND() = default; | ||
|
||
void init(o2::framework::ProcessingContext& pc) final | ||
{ | ||
// retrieve lookup tables (LUTs) from ccdb | ||
mLUTs = *(pc.inputs().get<detail::LUT<nDim>*>(Form("lq%ddlut", nDim))); | ||
} | ||
|
||
float process(const TrackTRD& trkIn, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final | ||
{ | ||
const auto& trkSeed = isTPCTRD ? input.getTPCTracks()[trkIn.getRefGlobalTrackId()].getParamOut() : input.getTPCITSTracks()[trkIn.getRefGlobalTrackId()].getParamOut(); // seeding track | ||
auto trk = trkSeed; | ||
|
||
const auto isNegative = std::signbit(trkSeed.getSign()); // positive and negative charged particles are treated differently since ExB effects the charge distributions | ||
const auto& trackletsRaw = input.getTRDTracklets(); | ||
float lei0{1.f}, lei1{1.f}, lei2{1.f}, lpi0{1.f}, lpi1{1.f}, lpi2{1.f}; // likelihood per layer | ||
for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) { | ||
int trkltId = trkIn.getTrackletIndex(iLayer); | ||
if (trkltId < 0) { // no tracklet attached | ||
continue; | ||
} | ||
const auto xCalib = input.getTRDCalibratedTracklets()[trkIn.getTrackletIndex(iLayer)].getX(); | ||
if (!o2::base::Propagator::Instance()->PropagateToXBxByBz(trk, xCalib, o2::base::Propagator::MAX_SIN_PHI, o2::base::Propagator::MAX_STEP, o2::base::Propagator::MatCorrType::USEMatCorrNONE)) { | ||
LOGF(debug, "Track propagation failed in layer %i (pt=%f, xTrk=%f, xToGo=%f)", iLayer, trk.getPt(), trk.getX(), xCalib); | ||
continue; | ||
} | ||
const auto snp = trk.getSnp(); | ||
const auto tgl = trk.getTgl(); | ||
const auto& trklt = trackletsRaw[trkltId]; | ||
const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkIn, input, snp, tgl); // correct charges | ||
if constexpr (nDim == 1) { | ||
auto lut = mLUTs.get(trk.getP(), isNegative); | ||
auto ll1{1.f}; | ||
ll1 = lut.Eval(q0 + q1 + q2); | ||
lei0 *= ll1; | ||
lpi0 *= (1.f - ll1); | ||
} else { | ||
auto lut1 = mLUTs.get(trk.getP(), isNegative, 0); | ||
auto lut2 = mLUTs.get(trk.getP(), isNegative, 1); | ||
auto lut3 = mLUTs.get(trk.getP(), isNegative, 2); | ||
auto ll1{1.f}; | ||
auto ll2{1.f}; | ||
auto ll3{1.f}; | ||
ll1 = lut1.Eval(q0); | ||
ll2 = lut2.Eval(q1); | ||
ll3 = lut3.Eval(q2); | ||
lei0 *= ll1; | ||
lei1 *= ll2; | ||
lei2 *= ll3; | ||
lpi0 *= (1.f - ll1); | ||
lpi1 *= (1.f - ll2); | ||
lpi2 *= (1.f - ll3); | ||
} | ||
} | ||
|
||
return (lei0 * lei1 * lei2) / (lei0 * lei1 * lei2 + lpi0 * lpi1 * lpi2); // combined likelihood | ||
} | ||
|
||
private: | ||
detail::LUT<nDim> mLUTs; ///< likelihood lookup tables | ||
|
||
ClassDefNV(LQND, 1); | ||
}; | ||
|
||
using LQ1D = LQND<1>; | ||
using LQ3D = LQND<3>; | ||
|
||
} // namespace trd | ||
} // namespace o2 | ||
|
||
#endif |
Oops, something went wrong.