diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6d4e7928f..ea7ed0c5e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -79,7 +79,8 @@ set(ROOT_REQUIRED_LIBRARIES
Gdml
Minuit
Spectrum
- XMLIO)
+ XMLIO
+ TMVA)
# Auto schema evolution for ROOT
if (NOT DEFINED REST_SE)
diff --git a/examples/tmva.rml b/examples/tmva.rml
new file mode 100644
index 000000000..fe965ce6c
--- /dev/null
+++ b/examples/tmva.rml
@@ -0,0 +1,66 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/source/framework/analysis/inc/TRestDataSetTMVA.h b/source/framework/analysis/inc/TRestDataSetTMVA.h
new file mode 100644
index 000000000..e650a16e1
--- /dev/null
+++ b/source/framework/analysis/inc/TRestDataSetTMVA.h
@@ -0,0 +1,91 @@
+/*************************************************************************
+ * This file is part of the REST software framework. *
+ * *
+ * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) *
+ * For more information see https://gifna.unizar.es/trex *
+ * *
+ * REST is free software: you can redistribute it and/or modify *
+ * it under the terms of the GNU General Public License as published by *
+ * the Free Software Foundation, either version 3 of the License, or *
+ * (at your option) any later version. *
+ * *
+ * REST is distributed in the hope that it will be useful, *
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of *
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
+ * GNU General Public License for more details. *
+ * *
+ * You should have a copy of the GNU General Public License along with *
+ * REST in $REST_PATH/LICENSE. *
+ * If not, see https://www.gnu.org/licenses/. *
+ * For the list of contributors see $REST_PATH/CREDITS. *
+ *************************************************************************/
+
+#ifndef REST_TRestDataSetTMVA
+#define REST_TRestDataSetTMVA
+
+#include "TH1F.h"
+#include "TMVA/Types.h"
+#include "TRestCut.h"
+#include "TRestMetadata.h"
+
+/// This class is meant to evaluate several TMVA methods in datasets
+class TRestDataSetTMVA : public TRestMetadata {
+ private:
+ /// Name of the output file
+ std::string fOutputFileName = ""; //<
+
+ /// Name of the signal dataSet
+ std::string fDataSetSignal = ""; //<
+
+ /// Name of the background dataset
+ std::string fDataSetBackground = ""; //<
+
+ /// Name of the output path for the xml files
+ std::string fOutputPath = ""; //<
+
+ /// Vector containing different obserbable names
+ std::vector fObsName; //<
+
+ /// Add method to compute TMVA, https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf for more details
+ std::vector > fMethod; //<
+
+ /// Cuts over background dataset for PDF selection
+ TRestCut* fBackgroundCut = nullptr; //<
+
+ /// Cuts over signal dataset for PDF selection
+ TRestCut* fSignalCut = nullptr; //<
+
+ /// If true display ROC curve after evaluating all methods
+ bool fDrawROCCurve = true; //<
+
+ /// Map with supported TMVA methods, please add more if something is missing
+ const std::map fMethodMap = {
+ //<
+ {"Likelihood", TMVA::Types::kLikelihood}, // Likelihood ("naive Bayes estimator")
+ {"LikelihoodKDE",
+ TMVA::Types::kLikelihood}, // Use a kernel density estimator to approximate the PDFs
+ {"Fisher", TMVA::Types::kFisher}, // Fisher discriminant (same as LD)
+ {"BDT", TMVA::Types::kBDT}, // Boosted Decision Trees
+ {"MLP", TMVA::Types::kMLP} // Multi-Layer Perceptron (Neural Network)
+ };
+
+ void Initialize() override;
+ void InitFromConfigFile() override;
+
+ public:
+ void PrintMetadata() override;
+
+ void ComputeTMVA();
+
+ inline void SetDataSetSignal(const std::string& dSName) { fDataSetSignal = dSName; }
+ inline void SetDataSetBackground(const std::string& dSName) { fDataSetBackground = dSName; }
+ inline void SetOutputFileName(const std::string& outName) { fOutputFileName = outName; }
+ inline void SetOutputPath(const std::string& outPath) { fOutputPath = outPath; }
+
+ TRestDataSetTMVA();
+ TRestDataSetTMVA(const char* configFilename, std::string name = "");
+ ~TRestDataSetTMVA();
+
+ ClassDefOverride(TRestDataSetTMVA, 1);
+};
+#endif
diff --git a/source/framework/analysis/inc/TRestDataSetTMVAClassification.h b/source/framework/analysis/inc/TRestDataSetTMVAClassification.h
new file mode 100644
index 000000000..9ca4ab43d
--- /dev/null
+++ b/source/framework/analysis/inc/TRestDataSetTMVAClassification.h
@@ -0,0 +1,70 @@
+/*************************************************************************
+ * This file is part of the REST software framework. *
+ * *
+ * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) *
+ * For more information see https://gifna.unizar.es/trex *
+ * *
+ * REST is free software: you can redistribute it and/or modify *
+ * it under the terms of the GNU General Public License as published by *
+ * the Free Software Foundation, either version 3 of the License, or *
+ * (at your option) any later version. *
+ * *
+ * REST is distributed in the hope that it will be useful, *
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of *
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
+ * GNU General Public License for more details. *
+ * *
+ * You should have a copy of the GNU General Public License along with *
+ * REST in $REST_PATH/LICENSE. *
+ * If not, see https://www.gnu.org/licenses/. *
+ * For the list of contributors see $REST_PATH/CREDITS. *
+ *************************************************************************/
+
+#ifndef REST_TRestDataSetTMVAClassification
+#define REST_TRestDataSetTMVAClassification
+
+#include "TH1F.h"
+#include "TRestCut.h"
+#include "TRestMetadata.h"
+
+/// This class is meant to classify a given dataset with a particular TMVA method
+class TRestDataSetTMVAClassification : public TRestMetadata {
+ private:
+ /// Name of the output file
+ std::string fOutputFileName = ""; //<
+
+ /// Name of the dataSet to classify
+ std::string fDataSetName = ""; //<
+
+ /// Name of the TMVA method
+ std::string fTmvaMethod = ""; //<
+
+ /// Name of the TMVA weights file
+ std::string fTmvaFile = ""; //<
+
+ /// Vector containing different obserbable names
+ std::vector fObsName; //<
+
+ /// Cuts over the dataset for PDF selection
+ TRestCut* fCut = nullptr; //<
+
+ void Initialize() override;
+ void InitFromConfigFile() override;
+
+ public:
+ void PrintMetadata() override;
+
+ void ClassifyTMVA();
+
+ inline void SetDataSet(const std::string& dSName) { fDataSetName = dSName; }
+ inline void SetTMVAMethod(const std::string& method) { fTmvaMethod = method; }
+ inline void SetTMVAFile(const std::string& file) { fTmvaFile = file; }
+ inline void SetOutputFileName(const std::string& outName) { fOutputFileName = outName; }
+
+ TRestDataSetTMVAClassification();
+ TRestDataSetTMVAClassification(const char* configFilename, std::string name = "");
+ ~TRestDataSetTMVAClassification();
+
+ ClassDefOverride(TRestDataSetTMVAClassification, 1);
+};
+#endif
diff --git a/source/framework/analysis/src/TRestDataSetTMVA.cxx b/source/framework/analysis/src/TRestDataSetTMVA.cxx
new file mode 100644
index 000000000..1d957efde
--- /dev/null
+++ b/source/framework/analysis/src/TRestDataSetTMVA.cxx
@@ -0,0 +1,371 @@
+/*************************************************************************
+ * This file is part of the REST software framework. *
+ * *
+ * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) *
+ * For more information see https://gifna.unizar.es/trex *
+ * *
+ * REST is free software: you can redistribute it and/or modify *
+ * it under the terms of the GNU General Public License as published by *
+ * the Free Software Foundation, either version 3 of the License, or *
+ * (at your option) any later version. *
+ * *
+ * REST is distributed in the hope that it will be useful, *
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of *
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
+ * GNU General Public License for more details. *
+ * *
+ * You should have a copy of the GNU General Public License along with *
+ * REST in $REST_PATH/LICENSE. *
+ * If not, see https://www.gnu.org/licenses/. *
+ * For the list of contributors see $REST_PATH/CREDITS. *
+ *************************************************************************/
+
+/////////////////////////////////////////////////////////////////////////
+/// TRestDataSetTMVA is meant to evaluate different TMVA methods in datasets.
+/// For more information about TMVA, check https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf
+/// So far, only Likelihood, LikelihoodKDE, Fisher, BDT and MLP methods are
+/// supported. TMVA requires a signal and a background dataset from which the
+/// different TMVA methods are computed. The different methods are evaluated
+/// in a set of observables that are provided in the RML file. Different cuts
+/// can be performed in either the signal or the background datasets prior to
+/// the TMVA evaluation. The output of this class is a root file which contains
+/// a signal and a background tree with the cuts applied and the different observables
+/// that are generated with the TMVA analysis. In addition, a folder is created
+/// with different xml files that contain the output of the TMVA evaluation that
+/// can be used to compute the TMVA classification via TRestDataSetTMVAClassification.
+///
+/// A summary of the basic parameters is described below:
+/// * **outputFileName**: Name of the output file
+/// * **dataSetSignal**: Name of the dataset file containing the signal
+/// * **dataSetBackground**: Name of the dataset file containing the background
+/// * **outputPath**: Name of the output path with the evaluation results
+/// * **drawROCCurve**: If true display the ROC curve for the evaluation of all methods
+///
+/// The different observables for the TMVA analysis can be added with the following key:
+/// \code
+///
+/// \endcode
+///
+/// * **name**: Name of the observable be computed
+///
+/// The different signal and background cuts can be added awith the following key:
+/// \code
+///
+///
+///
+/// \endcode
+///
+/// Where the cut name (e.g. ParamCut or Energy cut from above) have to be defined inside
+/// the RML file, e.g.:
+/// \code
+///
+///
+///
+///
+///
+///
+///
+///
+/// \endcode
+///
+/// Please, check TRestCut class for more info.
+///
+/// The different TMVA methods can be added wit the following key:
+/// \code
+///
+/// \endcode
+/// The different parameters for adding TMVA methods are described below:
+/// * **name**: Name of the TMVA method, only Likelihood, LikelihoodKDE, Fisher, BDT and MLP
+/// are supported so far.
+/// * **parameters**: String parameters to be used in the TMVA method, for more information
+/// check https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf
+///
+/// ### Examples
+/// Example of RML config file:
+/// \code
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+/// \endcode
+///
+/// Example to perform TMVA evaluation using restRoot:
+/// \code
+/// [0] TRestDataSetTMVA tmva("tmva.rml");
+/// [1] tmva.SetDataSetSignal("DataSetSignal.root");
+/// [2] tmva.SetDataSetBackground("DataSetBackground.root");
+/// [3] tmva.SetOutputFileName("MyDataSetEvaluation.root");
+/// [4] tmva.SetOutputPath("MyDataSetFiles");
+/// [5] tmva.ComputeTMVA();
+/// \endcode
+///
+/// In addition it is possible to display TMVA results after evaluating all methods,
+/// using root or restRoot;
+/// \code
+/// [0] TMVA::TMVAGui("MyDataSetEvaluation.root")
+/// \endcode
+///
+///----------------------------------------------------------------------
+///
+/// REST-for-Physics - Software for Rare Event Searches Toolkit
+///
+/// History of developments:
+///
+/// 2023-05: First implementation of TRestDataSetTMVA
+/// JuanAn Garcia
+///
+/// \class TRestDataSetTMVA
+/// \author: JuanAn Garcia e-mail: juanangp@unizar.es
+///
+///
+///
+
+#include "TRestDataSetTMVA.h"
+
+#include "ROOT/RDFHelpers.hxx"
+#include "TMVA/CrossValidation.h"
+#include "TMVA/DataLoader.h"
+#include "TMVA/Factory.h"
+#include "TMVA/TMVAGui.h"
+#include "TMVA/Tools.h"
+#include "TRestDataSet.h"
+
+ClassImp(TRestDataSetTMVA);
+
+///////////////////////////////////////////////
+/// \brief Default constructor
+///
+TRestDataSetTMVA::TRestDataSetTMVA() { Initialize(); }
+
+/////////////////////////////////////////////
+/// \brief Constructor loading data from a config file
+///
+/// If no configuration path is defined using TRestMetadata::SetConfigFilePath
+/// the path to the config file must be specified using full path, absolute or
+/// relative.
+///
+/// The default behaviour is that the config file must be specified with
+/// full path, absolute or relative.
+///
+/// \param configFilename A const char* that defines the RML filename.
+/// \param name The name of the metadata section. It will be used to find the
+/// corresponding TRestDataSetTMVA section inside the RML.
+///
+TRestDataSetTMVA::TRestDataSetTMVA(const char* configFilename, std::string name)
+ : TRestMetadata(configFilename) {
+ LoadConfigFromFile(fConfigFileName, name);
+ Initialize();
+
+ if (GetVerboseLevel() >= TRestStringOutput::REST_Verbose_Level::REST_Info) PrintMetadata();
+}
+
+///////////////////////////////////////////////
+/// \brief Default destructor
+///
+TRestDataSetTMVA::~TRestDataSetTMVA() {}
+
+///////////////////////////////////////////////
+/// \brief Function to initialize input/output event members and define
+/// the section name
+///
+void TRestDataSetTMVA::Initialize() { SetSectionName(this->ClassName()); }
+
+///////////////////////////////////////////////
+/// \brief Function to initialize some variables from
+/// configfile
+///
+void TRestDataSetTMVA::InitFromConfigFile() {
+ Initialize();
+ TRestMetadata::InitFromConfigFile();
+
+ TiXmlElement* obsDefinition = GetElement("observable");
+ while (obsDefinition != nullptr) {
+ std::string obsName = GetFieldValue("name", obsDefinition);
+ if (obsName.empty() || obsName == "Not defined") {
+ RESTError << "< observable variable key does not contain a name!" << RESTendl;
+ exit(1);
+ } else {
+ fObsName.push_back(obsName);
+ }
+
+ obsDefinition = GetNextElement(obsDefinition);
+ }
+
+ TiXmlElement* cutele = GetElement("addBackgroundCut");
+ while (cutele != nullptr) {
+ std::string cutName = GetParameter("name", cutele, "");
+ if (!cutName.empty()) {
+ if (fBackgroundCut == nullptr) {
+ fBackgroundCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName);
+ } else {
+ fBackgroundCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName));
+ }
+ }
+ cutele = GetNextElement(cutele);
+ }
+
+ cutele = GetElement("addSignalCut");
+ while (cutele != nullptr) {
+ std::string cutName = GetParameter("name", cutele, "");
+ if (!cutName.empty()) {
+ if (fSignalCut == nullptr) {
+ fSignalCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName);
+ } else {
+ fSignalCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName));
+ }
+ }
+ cutele = GetNextElement(cutele);
+ }
+
+ TiXmlElement* method = GetElement("addMethod");
+ while (method != nullptr) {
+ std::string name = GetParameter("name", method, "");
+ std::string params = GetParameter("parameters", method, "");
+ if (name.empty() || params.empty()) {
+ RESTWarning << "Empty method" << RESTendl;
+ } else {
+ fMethod.push_back(std::make_pair(name, params));
+ }
+ method = GetNextElement(method);
+ }
+
+ if (fObsName.empty()) {
+ RESTError << "No observables provided, exiting..." << RESTendl;
+ exit(1);
+ }
+
+ if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", "");
+}
+
+/////////////////////////////////////////////
+/// \brief This function computes the TMVA using
+/// the different methods provided via config file
+/// and the signal and background dataSets. The results
+/// are stored in an output root file and a folder. Note
+/// that it doesn't provide any usable dataset, just standard
+/// root files.
+///
+void TRestDataSetTMVA::ComputeTMVA() {
+ if (fOutputFileName.empty() || fOutputPath.empty() || fDataSetSignal.empty() ||
+ fDataSetBackground.empty()) {
+ RESTError << "Empty output file name, path, signal or background files " << RESTendl;
+ PrintMetadata();
+ exit(1);
+ }
+
+ if (fMethod.empty()) {
+ RESTError << "No TMVA methods have been added " << RESTendl;
+ PrintMetadata();
+ exit(1);
+ }
+
+ // Add signal dataset
+ TRestDataSet signal;
+ signal.Import(fDataSetSignal);
+ auto dfSignal = signal.MakeCut(fSignalCut);
+ dfSignal.Snapshot("Signal", fOutputFileName);
+
+ // Add background dataset
+ TRestDataSet bck;
+ bck.Import(fDataSetBackground);
+ auto dfBackground = bck.MakeCut(fBackgroundCut);
+ ROOT::RDF::RSnapshotOptions opt;
+ opt.fMode = "update";
+ dfBackground.Snapshot("Background", fOutputFileName, "", opt);
+
+ auto outputFile = TFile::Open(fOutputFileName.c_str(), "UPDATE");
+
+ auto signalTree = outputFile->Get("Signal");
+ auto bckTree = outputFile->Get("Background");
+
+ TMVA::Factory factory("TMVA_Classification", outputFile,
+ "!V:ROC:!Silent:Color:AnalysisType=Classification");
+
+ TMVA::DataLoader loader(fOutputPath);
+
+ // Add observables for the evaluation
+ for (const auto& obs : fObsName) loader.AddVariable(obs);
+
+ loader.AddSignalTree(signalTree, 1.0);
+ loader.AddBackgroundTree(bckTree, 1.0);
+ loader.PrepareTrainingAndTestTree("", "",
+ ":SplitMode=Random"
+ ":NormMode=NumEvents"
+ ":!V");
+
+ // Add different TMVA methods
+ for (const auto& [name, params] : fMethod) {
+ auto it = fMethodMap.find(name);
+ if (it == fMethodMap.end()) {
+ RESTWarning << "Method " << name << " not supported " << RESTendl;
+ RESTWarning << "Currently supported methods: ";
+ for (const auto& [method, val] : fMethodMap) RESTWarning << method << ", ";
+ RESTWarning << RESTendl;
+ continue;
+ }
+ std::cout << "Added method " << name << " " << it->second << " " << params << std::endl;
+ factory.BookMethod(&loader, it->second, name.c_str(), params.c_str());
+ }
+
+ // Train, test and evaluate all methods
+ factory.TrainAllMethods();
+ factory.TestAllMethods();
+ factory.EvaluateAllMethods();
+
+ // Draw ROC curve
+ if (fDrawROCCurve && gApplication != nullptr && gApplication->IsRunning()) {
+ auto c1 = factory.GetROCCurve(&loader);
+ c1->Draw();
+ }
+
+ outputFile->Close();
+}
+
+/////////////////////////////////////////////
+/// \brief Prints on screen the information about the metadata members of TRestDataSetTMVA
+///
+void TRestDataSetTMVA::PrintMetadata() {
+ TRestMetadata::PrintMetadata();
+
+ RESTMetadata << " Observables to compute: " << RESTendl;
+ for (const auto& obs : fObsName) {
+ RESTMetadata << obs << RESTendl;
+ }
+ RESTMetadata << " TMVA Methods " << RESTendl;
+ for (const auto& [name, params] : fMethod) {
+ RESTMetadata << name << " " << params << RESTendl;
+ }
+ RESTMetadata << "----" << RESTendl;
+}
diff --git a/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx b/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx
new file mode 100644
index 000000000..2f500f059
--- /dev/null
+++ b/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx
@@ -0,0 +1,295 @@
+/*************************************************************************
+ * This file is part of the REST software framework. *
+ * *
+ * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) *
+ * For more information see https://gifna.unizar.es/trex *
+ * *
+ * REST is free software: you can redistribute it and/or modify *
+ * it under the terms of the GNU General Public License as published by *
+ * the Free Software Foundation, either version 3 of the License, or *
+ * (at your option) any later version. *
+ * *
+ * REST is distributed in the hope that it will be useful, *
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of *
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
+ * GNU General Public License for more details. *
+ * *
+ * You should have a copy of the GNU General Public License along with *
+ * REST in $REST_PATH/LICENSE. *
+ * If not, see https://www.gnu.org/licenses/. *
+ * For the list of contributors see $REST_PATH/CREDITS. *
+ *************************************************************************/
+
+/////////////////////////////////////////////////////////////////////////
+/// TRestDataSetTMVAClassification performs the classification of a given
+/// dataSet using as input the results of the TMVA evaluation methods
+/// generated using TRestDataSetTMVA. Note that the observables used on
+/// TRestDataSetTMVA and TRestDataSetTMVA needs to match. This class generates
+/// as output a dataset with a new observable which is defined using the name of
+/// the TMVA method that has been used to classify the dataset. Only one TMVA
+/// method is classified in ClassifyTMVA. An output dataset is generated by
+/// definining a new observable with the TMVA method e.g. BDT_score
+///
+/// A summary of the basic parameters is described below:
+/// * **dataSetName**: Name of the dataSet to be classified
+/// * **tmvaFile**: Name of the xml input file with the tmva weigths
+/// * **tmvaMethod**: Name of the TMVA method used to classify
+/// * **outputFileName**: Name of the output dataset
+///
+///
+/// The different observables for the TMVA classification can be added with the following key:
+/// \code
+///
+/// \endcode
+/// * **name**: Name of the observable be computed
+///
+/// Note that the observable names has to match the ones using for the evaluation of
+/// a particular TMVA method
+///
+/// Different cuts over the dataset can be added with the following key:
+/// \code
+///
+/// \endcode
+///
+/// ### Examples
+/// Example of RML config file:
+/// \code
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+///
+/// \endcode
+///
+/// Example of TRestDataSetTMVAClassification using restRoot:
+/// \code
+/// TRestDataSetTMVAClassification tmva("tmva.rml");
+/// tmva.SetDataSet("MyDataSet.root");
+/// tmva.SetOutputFileName("MyClassifiedDataSet.root");
+/// tmva.ClassifyTMVA();
+/// \endcode
+///
+///----------------------------------------------------------------------
+///
+/// REST-for-Physics - Software for Rare Event Searches Toolkit
+///
+/// History of developments:
+///
+/// 2023-03: First implementation of TRestDataSetTMVAClassification
+/// JuanAn Garcia
+///
+/// \class TRestDataSetTMVAClassification
+/// \author: JuanAn Garcia e-mail: juanangp@unizar.es
+///
+///
+///
+
+#include "TRestDataSetTMVAClassification.h"
+
+#include "ROOT/RDFHelpers.hxx"
+#include "TMVA/CrossValidation.h"
+#include "TMVA/DataLoader.h"
+#include "TMVA/Factory.h"
+#include "TMVA/RInferenceUtils.hxx"
+#include "TMVA/RReader.hxx"
+#include "TMVA/RTensorUtils.hxx"
+#include "TMVA/Tools.h"
+#include "TRestDataSet.h"
+
+ClassImp(TRestDataSetTMVAClassification);
+
+///////////////////////////////////////////////
+/// \brief Default constructor
+///
+TRestDataSetTMVAClassification::TRestDataSetTMVAClassification() { Initialize(); }
+
+/////////////////////////////////////////////
+/// \brief Constructor loading data from a config file
+///
+/// If no configuration path is defined using TRestMetadata::SetConfigFilePath
+/// the path to the config file must be specified using full path, absolute or
+/// relative.
+///
+/// The default behaviour is that the config file must be specified with
+/// full path, absolute or relative.
+///
+/// \param configFilename A const char* that defines the RML filename.
+/// \param name The name of the metadata section. It will be used to find the
+/// corresponding TRestDataSetTMVAClassification section inside the RML.
+///
+TRestDataSetTMVAClassification::TRestDataSetTMVAClassification(const char* configFilename, std::string name)
+ : TRestMetadata(configFilename) {
+ LoadConfigFromFile(fConfigFileName, name);
+ Initialize();
+
+ if (GetVerboseLevel() >= TRestStringOutput::REST_Verbose_Level::REST_Info) PrintMetadata();
+}
+
+///////////////////////////////////////////////
+/// \brief Default destructor
+///
+TRestDataSetTMVAClassification::~TRestDataSetTMVAClassification() {}
+
+///////////////////////////////////////////////
+/// \brief Function to initialize input/output event members and define
+/// the section name
+///
+void TRestDataSetTMVAClassification::Initialize() { SetSectionName(this->ClassName()); }
+
+///////////////////////////////////////////////
+/// \brief Function to initialize some variables from
+/// configfile
+///
+void TRestDataSetTMVAClassification::InitFromConfigFile() {
+ Initialize();
+ TRestMetadata::InitFromConfigFile();
+
+ TiXmlElement* obsDefinition = GetElement("observable");
+ while (obsDefinition != nullptr) {
+ std::string obsName = GetFieldValue("name", obsDefinition);
+ if (obsName.empty() || obsName == "Not defined") {
+ RESTError << "< observable variable key does not contain a name!" << RESTendl;
+ exit(1);
+ } else {
+ fObsName.push_back(obsName);
+ }
+
+ obsDefinition = GetNextElement(obsDefinition);
+ }
+
+ TiXmlElement* cutele = GetElement("addCut");
+ while (cutele != nullptr) {
+ std::string cutName = GetParameter("name", cutele, "");
+ if (!cutName.empty()) {
+ if (fCut == nullptr) {
+ fCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName);
+ } else {
+ fCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName));
+ }
+ }
+ cutele = GetNextElement(cutele);
+ }
+
+ if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", "");
+}
+
+/////////////////////////////////////////////
+/// \brief This function computes the TMVA classification
+/// for a given dataSet. It requires a xml file with weigths
+/// from the output of TRestDataSetTMVA to perform the
+/// clasification for a given set of observables. This function
+/// defines a new observable with the score of the TMVA method
+/// provided in the input file that can be used for further
+/// signal and background discrimination.
+///
+void TRestDataSetTMVAClassification::ClassifyTMVA() {
+ PrintMetadata();
+
+ if (fObsName.empty()) {
+ RESTError << "No observables provided, exiting..." << RESTendl;
+ exit(1);
+ }
+
+ TMVA::Reader reader("!Color:!Silent");
+ std::vector var(fObsName.size());
+
+ // Add variables to the reader
+ for (unsigned int i = 0; i < fObsName.size(); i++) {
+ reader.AddVariable(fObsName[i].c_str(), &var[i]);
+ }
+
+ // Book TMVA method
+ reader.BookMVA(fTmvaMethod.c_str(), fTmvaFile.c_str());
+
+ // Lambda for evaluation of the method
+ auto eval = [&reader = reader, &tmvaMethod = fTmvaMethod](const std::vector& val) {
+ return reader.EvaluateMVA(val, tmvaMethod.c_str());
+ };
+
+ TRestDataSet dataSet;
+ dataSet.Import(fDataSetName);
+
+ auto df = dataSet.MakeCut(fCut);
+
+ std::string obsName = fTmvaMethod + "_score";
+
+ // Ugly but cannot pass vector size to ROOT::RDF::PassAsVec
+ switch (fObsName.size()) {
+ case 1:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<1, double>(eval), fObsName);
+ break;
+ case 2:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<2, double>(eval), fObsName);
+ break;
+ case 3:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<3, double>(eval), fObsName);
+ break;
+ case 4:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<4, double>(eval), fObsName);
+ break;
+ case 5:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<5, double>(eval), fObsName);
+ break;
+ case 6:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<6, double>(eval), fObsName);
+ break;
+ case 7:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<7, double>(eval), fObsName);
+ break;
+ case 8:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<8, double>(eval), fObsName);
+ break;
+ case 9:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<9, double>(eval), fObsName);
+ break;
+ case 10:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<10, double>(eval), fObsName);
+ break;
+ case 11:
+ df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<11, double>(eval), fObsName);
+ break;
+ default:
+ RESTError << "Number of observables " << fObsName.size() << " is not supported" << RESTendl;
+ exit(1);
+ }
+
+ dataSet.SetDataFrame(df);
+
+ if (!fOutputFileName.empty()) {
+ if (TRestTools::GetFileNameExtension(fOutputFileName) == "root") {
+ dataSet.Export(fOutputFileName);
+ TFile* f = TFile::Open(fOutputFileName.c_str(), "UPDATE");
+ this->Write();
+ f->Close();
+ }
+ }
+}
+
+/////////////////////////////////////////////
+/// \brief Prints on screen the information about the metadata members of TRestDataSetTMVAClassification
+///
+void TRestDataSetTMVAClassification::PrintMetadata() {
+ TRestMetadata::PrintMetadata();
+
+ RESTMetadata << " Observables to compute: " << RESTendl;
+ for (size_t i = 0; i < fObsName.size(); i++) {
+ RESTMetadata << fObsName[i] << RESTendl;
+ }
+ RESTMetadata << "----" << RESTendl;
+}
diff --git a/source/framework/core/inc/TRestDataSet.h b/source/framework/core/inc/TRestDataSet.h
index 94f0a490d..28d185f7a 100644
--- a/source/framework/core/inc/TRestDataSet.h
+++ b/source/framework/core/inc/TRestDataSet.h
@@ -169,8 +169,8 @@ class TRestDataSet : public TRestMetadata {
inline void SetQuantity(const std::map& quantity) { fQuantity = quantity; }
TRestDataSet& operator=(TRestDataSet& dS);
- void Import(const std::string& fileName);
void Import(std::vector fileNames);
+ void Import(const std::string& fileName, bool enableMT = true);
void Export(const std::string& filename);
ROOT::RDF::RNode MakeCut(const TRestCut* cut);
diff --git a/source/framework/core/inc/TRestRun.h b/source/framework/core/inc/TRestRun.h
index ffaf56b41..52cfe2b3d 100644
--- a/source/framework/core/inc/TRestRun.h
+++ b/source/framework/core/inc/TRestRun.h
@@ -92,6 +92,7 @@ class TRestRun : public TRestMetadata {
TFile* MergeToOutputFile(std::vector filefullnames, std::string outputfilename = "");
TFile* FormOutputFile();
TFile* UpdateOutputFile();
+ TFile* OpenAndUpdateOutputFile();
void PassOutputFile() {
fOutputFile = fInputFile;
diff --git a/source/framework/core/src/TRestDataSet.cxx b/source/framework/core/src/TRestDataSet.cxx
index 430aba8aa..566b279c3 100644
--- a/source/framework/core/src/TRestDataSet.cxx
+++ b/source/framework/core/src/TRestDataSet.cxx
@@ -890,7 +890,7 @@ TRestDataSet& TRestDataSet::operator=(TRestDataSet& dS) {
/// it import metadata info from the previous dataSet
/// while it opens the analysis tree
///
-void TRestDataSet::Import(const std::string& fileName) {
+void TRestDataSet::Import(const std::string& fileName, bool enableMT) {
if (TRestTools::GetFileNameExtension(fileName) != "root") {
RESTError << "Datasets can only be imported from root files" << RESTendl;
return;
@@ -918,7 +918,7 @@ void TRestDataSet::Import(const std::string& fileName) {
return;
}
- ROOT::EnableImplicitMT();
+ if (enableMT) ROOT::EnableImplicitMT();
fDataSet = ROOT::RDataFrame("AnalysisTree", fileName);
diff --git a/source/framework/core/src/TRestRun.cxx b/source/framework/core/src/TRestRun.cxx
index 3f533206b..6ae1e0ea7 100644
--- a/source/framework/core/src/TRestRun.cxx
+++ b/source/framework/core/src/TRestRun.cxx
@@ -986,24 +986,25 @@ TString TRestRun::FormFormat(const TString& FilenameFormat) {
TFile* TRestRun::MergeToOutputFile(vector filenames, string outputfilename) {
RESTDebug << "TRestRun::FormOutputFile. target : " << outputfilename << RESTendl;
string filename;
- TFileMerger* m = new TFileMerger(false);
+
+ TFileMerger m(false);
if (outputfilename.empty()) {
filename = fOutputFileName;
RESTInfo << "Creating file : " << filename << RESTendl;
- m->OutputFile(filename.c_str(), "RECREATE");
+ m.OutputFile(filename.c_str(), "RECREATE");
} else {
filename = outputfilename;
- RESTInfo << "Updating file : " << filename << RESTendl;
- m->OutputFile(filename.c_str(), "UPDATE");
+ RESTInfo << "Creating file : " << filename << RESTendl;
+ m.OutputFile(filename.c_str(), "UPDATE");
}
RESTDebug << "TRestRun::FormOutputFile. Starting to add files" << RESTendl;
for (unsigned int i = 0; i < filenames.size(); i++) {
- m->AddFile(filenames[i].c_str(), false);
+ m.AddFile(filenames[i].c_str(), false);
}
- if (m->Merge()) {
+ if (m.Merge()) {
for (unsigned int i = 0; i < filenames.size(); i++) {
remove(filenames[i].c_str());
}
@@ -1013,8 +1014,6 @@ TFile* TRestRun::MergeToOutputFile(vector filenames, string outputfilena
exit(1);
}
- delete m;
-
// we rename the created output file
fOutputFileName = FormFormat(filename);
rename(filename.c_str(), fOutputFileName);
@@ -1081,6 +1080,17 @@ TFile* TRestRun::UpdateOutputFile() {
return nullptr;
}
+///////////////////////////////////////////////
+/// \brief Open and update output file in case is closed
+///
+TFile* TRestRun::OpenAndUpdateOutputFile() {
+ if (fOutputFile == nullptr) {
+ fOutputFile = TFile::Open(fOutputFileName, "UPDATE");
+ }
+
+ return UpdateOutputFile();
+}
+
///////////////////////////////////////////////
/// \brief Write this object into TFile and add a new entry in database
///