Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing root TMVA for high level analysis #448

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ set(ROOT_REQUIRED_LIBRARIES
Gdml
Minuit
Spectrum
XMLIO)
XMLIO
TMVA)

# Auto schema evolution for ROOT
if (NOT DEFINED REST_SE)
Expand Down
66 changes: 66 additions & 0 deletions examples/tmva.rml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
<TRestManager>

<TRestDataSetTMVA name="TMVA" verboseLevel="info">
<observable name="tckAna_MaxTrack_XYZ_SigmaZ2"/>
<observable name="tckAna_MaxTrackEnergyBalanceXY"/>
<observable name="tckAna_MaxTrackEnergyRatio"/>
<observable name="tckAna_MaxTrack_XZ_SigmaX" />
<observable name="tckAna_MaxTrack_YZ_SigmaY" />
<observable name="tckAna_MaxTrackxySigmaBalance"/>

<addBackgroundCut name="ParamCut"/>
<addSignalCut name="ParamCut"/>
<addSignalCut name="EnergyCut"/>

<TRestCut name="EnergyCut" verboseLevel="info">
<cut name="c1" variable="calib_Energy" condition=">4" />
<cut name="c2" variable="calib_Energy" condition="<8" />
</TRestCut>

<TRestCut name="ParamCut" verboseLevel="info">
<cut name="c3" variable="tckAna_nTracks_X" condition=">0" />
<cut name="c4" variable="tckAna_nTracks_Y" condition=">0" />
<cut name="c5" variable="tckAna_MaxTrackEnergyRatio" condition="<0.1" />
<cut name="c8" variable="tckAna_MaxTrack_XYZ_SigmaZ2" condition="<20." />
<cut name="c9" variable="tckAna_MaxTrackEnergyBalanceXY" condition="<5"/>
<cut name="c10" variable="tckAna_MaxTrackEnergyBalanceXY" condition=">-5"/>
<cut name="c11" variable="tckAna_MaxTrackxySigmaBalance" condition=">-1"/>
<cut name="c12" variable="tckAna_MaxTrackxySigmaBalance" condition="<1"/>
</TRestCut>

<addMethod name="Likelihood" parameters="H:!V:TransformOutput:PDFInterpol=Spline2:NSmoothSig[0]=20:NSmoothBkg[0]=10:NSmoothBkg[1]=5:NSmooth=1:NAvEvtPerBin=10"/>
<addMethod name="LikelihoodKDE" parameters="!H:!V:!TransformOutput:PDFInterpol=KDE:KDEtype=Gauss:KDEiter=Adaptive:KDEFineFactor=0.3:KDEborder=None:NAvEvtPerBin=10"/>
<addMethod name="Fisher" parameters="H:!V:Fisher:VarTransform=None:CreateMVAPdfs:PDFInterpolMVAPdf=Spline2:NbinsMVAPdf=50:NsmoothMVAPdf=10"/>
<addMethod name="BDT" parameters="!V:NTrees=200:MinNodeSize=2.5%:MaxDepth=2:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20"/>
<addMethod name="MLP" parameters="!H:!V:NeuronType=tanh:VarTransform=N:NCycles=100:HiddenLayers=N+5:TestRate=5:!UseRegulator"/>

</TRestDataSetTMVA>


<TRestDataSetTMVAClassification name="TMVA" verboseLevel="info">
<observable name="tckAna_MaxTrack_XYZ_SigmaZ2"/>
<observable name="tckAna_MaxTrackEnergyBalanceXY"/>
<observable name="tckAna_MaxTrackEnergyRatio"/>
<observable name="tckAna_MaxTrack_XZ_SigmaX" />
<observable name="tckAna_MaxTrack_YZ_SigmaY" />
<observable name="tckAna_MaxTrackxySigmaBalance"/>

<parameter name="tmvaFile" value="/storage/iaxo/iaxo-lsc/analysis/ana/test/test/weights/TMVA_Classification_BDT.weights.xml" />
<parameter name="tmvaMethod" value="BDT" />
<addCut name="ParamCut"/>

<TRestCut name="ParamCut" verboseLevel="info">
<cut name="c3" variable="tckAna_nTracks_X" condition=">0" />
<cut name="c4" variable="tckAna_nTracks_Y" condition=">0" />
<cut name="c5" variable="tckAna_MaxTrackEnergyRatio" condition="<0.1" />
<cut name="c8" variable="tckAna_MaxTrack_XYZ_SigmaZ2" condition="<20." />
<cut name="c9" variable="tckAna_MaxTrackEnergyBalanceXY" condition="<5"/>
<cut name="c10" variable="tckAna_MaxTrackEnergyBalanceXY" condition=">-5"/>
<cut name="c11" variable="tckAna_MaxTrackxySigmaBalance" condition=">-1"/>
<cut name="c12" variable="tckAna_MaxTrackxySigmaBalance" condition="<1"/>

</TRestCut>

</TRestDataSetTMVAClassification>

</TRestManager>
91 changes: 91 additions & 0 deletions source/framework/analysis/inc/TRestDataSetTMVA.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*************************************************************************
* This file is part of the REST software framework. *
* *
* Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) *
* For more information see https://gifna.unizar.es/trex *
* *
* REST is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* REST is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have a copy of the GNU General Public License along with *
* REST in $REST_PATH/LICENSE. *
* If not, see https://www.gnu.org/licenses/. *
* For the list of contributors see $REST_PATH/CREDITS. *
*************************************************************************/

#ifndef REST_TRestDataSetTMVA
#define REST_TRestDataSetTMVA

#include "TH1F.h"
#include "TMVA/Types.h"
#include "TRestCut.h"
#include "TRestMetadata.h"

/// This class is meant to evaluate several TMVA methods in datasets
class TRestDataSetTMVA : public TRestMetadata {
private:
/// Name of the output file
std::string fOutputFileName = ""; //<

/// Name of the signal dataSet
std::string fDataSetSignal = ""; //<

/// Name of the background dataset
std::string fDataSetBackground = ""; //<

/// Name of the output path for the xml files
std::string fOutputPath = ""; //<

/// Vector containing different obserbable names
std::vector<std::string> fObsName; //<

/// Add method to compute TMVA, https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf for more details
std::vector<std::pair<std::string, std::string> > fMethod; //<

/// Cuts over background dataset for PDF selection
TRestCut* fBackgroundCut = nullptr; //<

/// Cuts over signal dataset for PDF selection
TRestCut* fSignalCut = nullptr; //<

/// If true display ROC curve after evaluating all methods
bool fDrawROCCurve = true; //<

/// Map with supported TMVA methods, please add more if something is missing
const std::map<std::string, TMVA::Types::EMVA> fMethodMap = {
//<
{"Likelihood", TMVA::Types::kLikelihood}, // Likelihood ("naive Bayes estimator")
{"LikelihoodKDE",
TMVA::Types::kLikelihood}, // Use a kernel density estimator to approximate the PDFs
{"Fisher", TMVA::Types::kFisher}, // Fisher discriminant (same as LD)
{"BDT", TMVA::Types::kBDT}, // Boosted Decision Trees
{"MLP", TMVA::Types::kMLP} // Multi-Layer Perceptron (Neural Network)
};

void Initialize() override;
void InitFromConfigFile() override;

public:
void PrintMetadata() override;

void ComputeTMVA();

inline void SetDataSetSignal(const std::string& dSName) { fDataSetSignal = dSName; }
inline void SetDataSetBackground(const std::string& dSName) { fDataSetBackground = dSName; }
inline void SetOutputFileName(const std::string& outName) { fOutputFileName = outName; }
inline void SetOutputPath(const std::string& outPath) { fOutputPath = outPath; }

TRestDataSetTMVA();
TRestDataSetTMVA(const char* configFilename, std::string name = "");
~TRestDataSetTMVA();

ClassDefOverride(TRestDataSetTMVA, 1);
};
#endif
70 changes: 70 additions & 0 deletions source/framework/analysis/inc/TRestDataSetTMVAClassification.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*************************************************************************
* This file is part of the REST software framework. *
* *
* Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) *
* For more information see https://gifna.unizar.es/trex *
* *
* REST is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* REST is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have a copy of the GNU General Public License along with *
* REST in $REST_PATH/LICENSE. *
* If not, see https://www.gnu.org/licenses/. *
* For the list of contributors see $REST_PATH/CREDITS. *
*************************************************************************/

#ifndef REST_TRestDataSetTMVAClassification
#define REST_TRestDataSetTMVAClassification

#include "TH1F.h"
#include "TRestCut.h"
#include "TRestMetadata.h"

/// This class is meant to classify a given dataset with a particular TMVA method
class TRestDataSetTMVAClassification : public TRestMetadata {
private:
/// Name of the output file
std::string fOutputFileName = ""; //<

/// Name of the dataSet to classify
std::string fDataSetName = ""; //<

/// Name of the TMVA method
std::string fTmvaMethod = ""; //<

/// Name of the TMVA weights file
std::string fTmvaFile = ""; //<

/// Vector containing different obserbable names
std::vector<std::string> fObsName; //<

/// Cuts over the dataset for PDF selection
TRestCut* fCut = nullptr; //<

void Initialize() override;
void InitFromConfigFile() override;

public:
void PrintMetadata() override;

void ClassifyTMVA();

inline void SetDataSet(const std::string& dSName) { fDataSetName = dSName; }
inline void SetTMVAMethod(const std::string& method) { fTmvaMethod = method; }
inline void SetTMVAFile(const std::string& file) { fTmvaFile = file; }
inline void SetOutputFileName(const std::string& outName) { fOutputFileName = outName; }

TRestDataSetTMVAClassification();
TRestDataSetTMVAClassification(const char* configFilename, std::string name = "");
~TRestDataSetTMVAClassification();

ClassDefOverride(TRestDataSetTMVAClassification, 1);
};
#endif
Loading