diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d4e7928f..ea7ed0c5e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,7 +79,8 @@ set(ROOT_REQUIRED_LIBRARIES Gdml Minuit Spectrum - XMLIO) + XMLIO + TMVA) # Auto schema evolution for ROOT if (NOT DEFINED REST_SE) diff --git a/examples/tmva.rml b/examples/tmva.rml new file mode 100644 index 000000000..fe965ce6c --- /dev/null +++ b/examples/tmva.rml @@ -0,0 +1,66 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/source/framework/analysis/inc/TRestDataSetTMVA.h b/source/framework/analysis/inc/TRestDataSetTMVA.h new file mode 100644 index 000000000..e650a16e1 --- /dev/null +++ b/source/framework/analysis/inc/TRestDataSetTMVA.h @@ -0,0 +1,91 @@ +/************************************************************************* + * This file is part of the REST software framework. * + * * + * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) * + * For more information see https://gifna.unizar.es/trex * + * * + * REST is free software: you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation, either version 3 of the License, or * + * (at your option) any later version. * + * * + * REST is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU General Public License for more details. * + * * + * You should have a copy of the GNU General Public License along with * + * REST in $REST_PATH/LICENSE. * + * If not, see https://www.gnu.org/licenses/. * + * For the list of contributors see $REST_PATH/CREDITS. * + *************************************************************************/ + +#ifndef REST_TRestDataSetTMVA +#define REST_TRestDataSetTMVA + +#include "TH1F.h" +#include "TMVA/Types.h" +#include "TRestCut.h" +#include "TRestMetadata.h" + +/// This class is meant to evaluate several TMVA methods in datasets +class TRestDataSetTMVA : public TRestMetadata { + private: + /// Name of the output file + std::string fOutputFileName = ""; //< + + /// Name of the signal dataSet + std::string fDataSetSignal = ""; //< + + /// Name of the background dataset + std::string fDataSetBackground = ""; //< + + /// Name of the output path for the xml files + std::string fOutputPath = ""; //< + + /// Vector containing different obserbable names + std::vector fObsName; //< + + /// Add method to compute TMVA, https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf for more details + std::vector > fMethod; //< + + /// Cuts over background dataset for PDF selection + TRestCut* fBackgroundCut = nullptr; //< + + /// Cuts over signal dataset for PDF selection + TRestCut* fSignalCut = nullptr; //< + + /// If true display ROC curve after evaluating all methods + bool fDrawROCCurve = true; //< + + /// Map with supported TMVA methods, please add more if something is missing + const std::map fMethodMap = { + //< + {"Likelihood", TMVA::Types::kLikelihood}, // Likelihood ("naive Bayes estimator") + {"LikelihoodKDE", + TMVA::Types::kLikelihood}, // Use a kernel density estimator to approximate the PDFs + {"Fisher", TMVA::Types::kFisher}, // Fisher discriminant (same as LD) + {"BDT", TMVA::Types::kBDT}, // Boosted Decision Trees + {"MLP", TMVA::Types::kMLP} // Multi-Layer Perceptron (Neural Network) + }; + + void Initialize() override; + void InitFromConfigFile() override; + + public: + void PrintMetadata() override; + + void ComputeTMVA(); + + inline void SetDataSetSignal(const std::string& dSName) { fDataSetSignal = dSName; } + inline void SetDataSetBackground(const std::string& dSName) { fDataSetBackground = dSName; } + inline void SetOutputFileName(const std::string& outName) { fOutputFileName = outName; } + inline void SetOutputPath(const std::string& outPath) { fOutputPath = outPath; } + + TRestDataSetTMVA(); + TRestDataSetTMVA(const char* configFilename, std::string name = ""); + ~TRestDataSetTMVA(); + + ClassDefOverride(TRestDataSetTMVA, 1); +}; +#endif diff --git a/source/framework/analysis/inc/TRestDataSetTMVAClassification.h b/source/framework/analysis/inc/TRestDataSetTMVAClassification.h new file mode 100644 index 000000000..9ca4ab43d --- /dev/null +++ b/source/framework/analysis/inc/TRestDataSetTMVAClassification.h @@ -0,0 +1,70 @@ +/************************************************************************* + * This file is part of the REST software framework. * + * * + * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) * + * For more information see https://gifna.unizar.es/trex * + * * + * REST is free software: you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation, either version 3 of the License, or * + * (at your option) any later version. * + * * + * REST is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU General Public License for more details. * + * * + * You should have a copy of the GNU General Public License along with * + * REST in $REST_PATH/LICENSE. * + * If not, see https://www.gnu.org/licenses/. * + * For the list of contributors see $REST_PATH/CREDITS. * + *************************************************************************/ + +#ifndef REST_TRestDataSetTMVAClassification +#define REST_TRestDataSetTMVAClassification + +#include "TH1F.h" +#include "TRestCut.h" +#include "TRestMetadata.h" + +/// This class is meant to classify a given dataset with a particular TMVA method +class TRestDataSetTMVAClassification : public TRestMetadata { + private: + /// Name of the output file + std::string fOutputFileName = ""; //< + + /// Name of the dataSet to classify + std::string fDataSetName = ""; //< + + /// Name of the TMVA method + std::string fTmvaMethod = ""; //< + + /// Name of the TMVA weights file + std::string fTmvaFile = ""; //< + + /// Vector containing different obserbable names + std::vector fObsName; //< + + /// Cuts over the dataset for PDF selection + TRestCut* fCut = nullptr; //< + + void Initialize() override; + void InitFromConfigFile() override; + + public: + void PrintMetadata() override; + + void ClassifyTMVA(); + + inline void SetDataSet(const std::string& dSName) { fDataSetName = dSName; } + inline void SetTMVAMethod(const std::string& method) { fTmvaMethod = method; } + inline void SetTMVAFile(const std::string& file) { fTmvaFile = file; } + inline void SetOutputFileName(const std::string& outName) { fOutputFileName = outName; } + + TRestDataSetTMVAClassification(); + TRestDataSetTMVAClassification(const char* configFilename, std::string name = ""); + ~TRestDataSetTMVAClassification(); + + ClassDefOverride(TRestDataSetTMVAClassification, 1); +}; +#endif diff --git a/source/framework/analysis/src/TRestDataSetTMVA.cxx b/source/framework/analysis/src/TRestDataSetTMVA.cxx new file mode 100644 index 000000000..1d957efde --- /dev/null +++ b/source/framework/analysis/src/TRestDataSetTMVA.cxx @@ -0,0 +1,371 @@ +/************************************************************************* + * This file is part of the REST software framework. * + * * + * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) * + * For more information see https://gifna.unizar.es/trex * + * * + * REST is free software: you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation, either version 3 of the License, or * + * (at your option) any later version. * + * * + * REST is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU General Public License for more details. * + * * + * You should have a copy of the GNU General Public License along with * + * REST in $REST_PATH/LICENSE. * + * If not, see https://www.gnu.org/licenses/. * + * For the list of contributors see $REST_PATH/CREDITS. * + *************************************************************************/ + +///////////////////////////////////////////////////////////////////////// +/// TRestDataSetTMVA is meant to evaluate different TMVA methods in datasets. +/// For more information about TMVA, check https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf +/// So far, only Likelihood, LikelihoodKDE, Fisher, BDT and MLP methods are +/// supported. TMVA requires a signal and a background dataset from which the +/// different TMVA methods are computed. The different methods are evaluated +/// in a set of observables that are provided in the RML file. Different cuts +/// can be performed in either the signal or the background datasets prior to +/// the TMVA evaluation. The output of this class is a root file which contains +/// a signal and a background tree with the cuts applied and the different observables +/// that are generated with the TMVA analysis. In addition, a folder is created +/// with different xml files that contain the output of the TMVA evaluation that +/// can be used to compute the TMVA classification via TRestDataSetTMVAClassification. +/// +/// A summary of the basic parameters is described below: +/// * **outputFileName**: Name of the output file +/// * **dataSetSignal**: Name of the dataset file containing the signal +/// * **dataSetBackground**: Name of the dataset file containing the background +/// * **outputPath**: Name of the output path with the evaluation results +/// * **drawROCCurve**: If true display the ROC curve for the evaluation of all methods +/// +/// The different observables for the TMVA analysis can be added with the following key: +/// \code +/// +/// \endcode +/// +/// * **name**: Name of the observable be computed +/// +/// The different signal and background cuts can be added awith the following key: +/// \code +/// +/// +/// +/// \endcode +/// +/// Where the cut name (e.g. ParamCut or Energy cut from above) have to be defined inside +/// the RML file, e.g.: +/// \code +/// +/// +/// +/// +/// +/// +/// +/// +/// \endcode +/// +/// Please, check TRestCut class for more info. +/// +/// The different TMVA methods can be added wit the following key: +/// \code +/// +/// \endcode +/// The different parameters for adding TMVA methods are described below: +/// * **name**: Name of the TMVA method, only Likelihood, LikelihoodKDE, Fisher, BDT and MLP +/// are supported so far. +/// * **parameters**: String parameters to be used in the TMVA method, for more information +/// check https://root.cern.ch/download/doc/tmva/TMVAUsersGuide.pdf +/// +/// ### Examples +/// Example of RML config file: +/// \code +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// \endcode +/// +/// Example to perform TMVA evaluation using restRoot: +/// \code +/// [0] TRestDataSetTMVA tmva("tmva.rml"); +/// [1] tmva.SetDataSetSignal("DataSetSignal.root"); +/// [2] tmva.SetDataSetBackground("DataSetBackground.root"); +/// [3] tmva.SetOutputFileName("MyDataSetEvaluation.root"); +/// [4] tmva.SetOutputPath("MyDataSetFiles"); +/// [5] tmva.ComputeTMVA(); +/// \endcode +/// +/// In addition it is possible to display TMVA results after evaluating all methods, +/// using root or restRoot; +/// \code +/// [0] TMVA::TMVAGui("MyDataSetEvaluation.root") +/// \endcode +/// +///---------------------------------------------------------------------- +/// +/// REST-for-Physics - Software for Rare Event Searches Toolkit +/// +/// History of developments: +/// +/// 2023-05: First implementation of TRestDataSetTMVA +/// JuanAn Garcia +/// +/// \class TRestDataSetTMVA +/// \author: JuanAn Garcia e-mail: juanangp@unizar.es +/// +///
+/// + +#include "TRestDataSetTMVA.h" + +#include "ROOT/RDFHelpers.hxx" +#include "TMVA/CrossValidation.h" +#include "TMVA/DataLoader.h" +#include "TMVA/Factory.h" +#include "TMVA/TMVAGui.h" +#include "TMVA/Tools.h" +#include "TRestDataSet.h" + +ClassImp(TRestDataSetTMVA); + +/////////////////////////////////////////////// +/// \brief Default constructor +/// +TRestDataSetTMVA::TRestDataSetTMVA() { Initialize(); } + +///////////////////////////////////////////// +/// \brief Constructor loading data from a config file +/// +/// If no configuration path is defined using TRestMetadata::SetConfigFilePath +/// the path to the config file must be specified using full path, absolute or +/// relative. +/// +/// The default behaviour is that the config file must be specified with +/// full path, absolute or relative. +/// +/// \param configFilename A const char* that defines the RML filename. +/// \param name The name of the metadata section. It will be used to find the +/// corresponding TRestDataSetTMVA section inside the RML. +/// +TRestDataSetTMVA::TRestDataSetTMVA(const char* configFilename, std::string name) + : TRestMetadata(configFilename) { + LoadConfigFromFile(fConfigFileName, name); + Initialize(); + + if (GetVerboseLevel() >= TRestStringOutput::REST_Verbose_Level::REST_Info) PrintMetadata(); +} + +/////////////////////////////////////////////// +/// \brief Default destructor +/// +TRestDataSetTMVA::~TRestDataSetTMVA() {} + +/////////////////////////////////////////////// +/// \brief Function to initialize input/output event members and define +/// the section name +/// +void TRestDataSetTMVA::Initialize() { SetSectionName(this->ClassName()); } + +/////////////////////////////////////////////// +/// \brief Function to initialize some variables from +/// configfile +/// +void TRestDataSetTMVA::InitFromConfigFile() { + Initialize(); + TRestMetadata::InitFromConfigFile(); + + TiXmlElement* obsDefinition = GetElement("observable"); + while (obsDefinition != nullptr) { + std::string obsName = GetFieldValue("name", obsDefinition); + if (obsName.empty() || obsName == "Not defined") { + RESTError << "< observable variable key does not contain a name!" << RESTendl; + exit(1); + } else { + fObsName.push_back(obsName); + } + + obsDefinition = GetNextElement(obsDefinition); + } + + TiXmlElement* cutele = GetElement("addBackgroundCut"); + while (cutele != nullptr) { + std::string cutName = GetParameter("name", cutele, ""); + if (!cutName.empty()) { + if (fBackgroundCut == nullptr) { + fBackgroundCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName); + } else { + fBackgroundCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName)); + } + } + cutele = GetNextElement(cutele); + } + + cutele = GetElement("addSignalCut"); + while (cutele != nullptr) { + std::string cutName = GetParameter("name", cutele, ""); + if (!cutName.empty()) { + if (fSignalCut == nullptr) { + fSignalCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName); + } else { + fSignalCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName)); + } + } + cutele = GetNextElement(cutele); + } + + TiXmlElement* method = GetElement("addMethod"); + while (method != nullptr) { + std::string name = GetParameter("name", method, ""); + std::string params = GetParameter("parameters", method, ""); + if (name.empty() || params.empty()) { + RESTWarning << "Empty method" << RESTendl; + } else { + fMethod.push_back(std::make_pair(name, params)); + } + method = GetNextElement(method); + } + + if (fObsName.empty()) { + RESTError << "No observables provided, exiting..." << RESTendl; + exit(1); + } + + if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", ""); +} + +///////////////////////////////////////////// +/// \brief This function computes the TMVA using +/// the different methods provided via config file +/// and the signal and background dataSets. The results +/// are stored in an output root file and a folder. Note +/// that it doesn't provide any usable dataset, just standard +/// root files. +/// +void TRestDataSetTMVA::ComputeTMVA() { + if (fOutputFileName.empty() || fOutputPath.empty() || fDataSetSignal.empty() || + fDataSetBackground.empty()) { + RESTError << "Empty output file name, path, signal or background files " << RESTendl; + PrintMetadata(); + exit(1); + } + + if (fMethod.empty()) { + RESTError << "No TMVA methods have been added " << RESTendl; + PrintMetadata(); + exit(1); + } + + // Add signal dataset + TRestDataSet signal; + signal.Import(fDataSetSignal); + auto dfSignal = signal.MakeCut(fSignalCut); + dfSignal.Snapshot("Signal", fOutputFileName); + + // Add background dataset + TRestDataSet bck; + bck.Import(fDataSetBackground); + auto dfBackground = bck.MakeCut(fBackgroundCut); + ROOT::RDF::RSnapshotOptions opt; + opt.fMode = "update"; + dfBackground.Snapshot("Background", fOutputFileName, "", opt); + + auto outputFile = TFile::Open(fOutputFileName.c_str(), "UPDATE"); + + auto signalTree = outputFile->Get("Signal"); + auto bckTree = outputFile->Get("Background"); + + TMVA::Factory factory("TMVA_Classification", outputFile, + "!V:ROC:!Silent:Color:AnalysisType=Classification"); + + TMVA::DataLoader loader(fOutputPath); + + // Add observables for the evaluation + for (const auto& obs : fObsName) loader.AddVariable(obs); + + loader.AddSignalTree(signalTree, 1.0); + loader.AddBackgroundTree(bckTree, 1.0); + loader.PrepareTrainingAndTestTree("", "", + ":SplitMode=Random" + ":NormMode=NumEvents" + ":!V"); + + // Add different TMVA methods + for (const auto& [name, params] : fMethod) { + auto it = fMethodMap.find(name); + if (it == fMethodMap.end()) { + RESTWarning << "Method " << name << " not supported " << RESTendl; + RESTWarning << "Currently supported methods: "; + for (const auto& [method, val] : fMethodMap) RESTWarning << method << ", "; + RESTWarning << RESTendl; + continue; + } + std::cout << "Added method " << name << " " << it->second << " " << params << std::endl; + factory.BookMethod(&loader, it->second, name.c_str(), params.c_str()); + } + + // Train, test and evaluate all methods + factory.TrainAllMethods(); + factory.TestAllMethods(); + factory.EvaluateAllMethods(); + + // Draw ROC curve + if (fDrawROCCurve && gApplication != nullptr && gApplication->IsRunning()) { + auto c1 = factory.GetROCCurve(&loader); + c1->Draw(); + } + + outputFile->Close(); +} + +///////////////////////////////////////////// +/// \brief Prints on screen the information about the metadata members of TRestDataSetTMVA +/// +void TRestDataSetTMVA::PrintMetadata() { + TRestMetadata::PrintMetadata(); + + RESTMetadata << " Observables to compute: " << RESTendl; + for (const auto& obs : fObsName) { + RESTMetadata << obs << RESTendl; + } + RESTMetadata << " TMVA Methods " << RESTendl; + for (const auto& [name, params] : fMethod) { + RESTMetadata << name << " " << params << RESTendl; + } + RESTMetadata << "----" << RESTendl; +} diff --git a/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx b/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx new file mode 100644 index 000000000..2f500f059 --- /dev/null +++ b/source/framework/analysis/src/TRestDataSetTMVAClassification.cxx @@ -0,0 +1,295 @@ +/************************************************************************* + * This file is part of the REST software framework. * + * * + * Copyright (C) 2016 GIFNA/TREX (University of Zaragoza) * + * For more information see https://gifna.unizar.es/trex * + * * + * REST is free software: you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation, either version 3 of the License, or * + * (at your option) any later version. * + * * + * REST is distributed in the hope that it will be useful, * + * but WITHOUT ANY WARRANTY; without even the implied warranty of * + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * + * GNU General Public License for more details. * + * * + * You should have a copy of the GNU General Public License along with * + * REST in $REST_PATH/LICENSE. * + * If not, see https://www.gnu.org/licenses/. * + * For the list of contributors see $REST_PATH/CREDITS. * + *************************************************************************/ + +///////////////////////////////////////////////////////////////////////// +/// TRestDataSetTMVAClassification performs the classification of a given +/// dataSet using as input the results of the TMVA evaluation methods +/// generated using TRestDataSetTMVA. Note that the observables used on +/// TRestDataSetTMVA and TRestDataSetTMVA needs to match. This class generates +/// as output a dataset with a new observable which is defined using the name of +/// the TMVA method that has been used to classify the dataset. Only one TMVA +/// method is classified in ClassifyTMVA. An output dataset is generated by +/// definining a new observable with the TMVA method e.g. BDT_score +/// +/// A summary of the basic parameters is described below: +/// * **dataSetName**: Name of the dataSet to be classified +/// * **tmvaFile**: Name of the xml input file with the tmva weigths +/// * **tmvaMethod**: Name of the TMVA method used to classify +/// * **outputFileName**: Name of the output dataset +/// +/// +/// The different observables for the TMVA classification can be added with the following key: +/// \code +/// +/// \endcode +/// * **name**: Name of the observable be computed +/// +/// Note that the observable names has to match the ones using for the evaluation of +/// a particular TMVA method +/// +/// Different cuts over the dataset can be added with the following key: +/// \code +/// +/// \endcode +/// +/// ### Examples +/// Example of RML config file: +/// \code +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// +/// \endcode +/// +/// Example of TRestDataSetTMVAClassification using restRoot: +/// \code +/// TRestDataSetTMVAClassification tmva("tmva.rml"); +/// tmva.SetDataSet("MyDataSet.root"); +/// tmva.SetOutputFileName("MyClassifiedDataSet.root"); +/// tmva.ClassifyTMVA(); +/// \endcode +/// +///---------------------------------------------------------------------- +/// +/// REST-for-Physics - Software for Rare Event Searches Toolkit +/// +/// History of developments: +/// +/// 2023-03: First implementation of TRestDataSetTMVAClassification +/// JuanAn Garcia +/// +/// \class TRestDataSetTMVAClassification +/// \author: JuanAn Garcia e-mail: juanangp@unizar.es +/// +///
+/// + +#include "TRestDataSetTMVAClassification.h" + +#include "ROOT/RDFHelpers.hxx" +#include "TMVA/CrossValidation.h" +#include "TMVA/DataLoader.h" +#include "TMVA/Factory.h" +#include "TMVA/RInferenceUtils.hxx" +#include "TMVA/RReader.hxx" +#include "TMVA/RTensorUtils.hxx" +#include "TMVA/Tools.h" +#include "TRestDataSet.h" + +ClassImp(TRestDataSetTMVAClassification); + +/////////////////////////////////////////////// +/// \brief Default constructor +/// +TRestDataSetTMVAClassification::TRestDataSetTMVAClassification() { Initialize(); } + +///////////////////////////////////////////// +/// \brief Constructor loading data from a config file +/// +/// If no configuration path is defined using TRestMetadata::SetConfigFilePath +/// the path to the config file must be specified using full path, absolute or +/// relative. +/// +/// The default behaviour is that the config file must be specified with +/// full path, absolute or relative. +/// +/// \param configFilename A const char* that defines the RML filename. +/// \param name The name of the metadata section. It will be used to find the +/// corresponding TRestDataSetTMVAClassification section inside the RML. +/// +TRestDataSetTMVAClassification::TRestDataSetTMVAClassification(const char* configFilename, std::string name) + : TRestMetadata(configFilename) { + LoadConfigFromFile(fConfigFileName, name); + Initialize(); + + if (GetVerboseLevel() >= TRestStringOutput::REST_Verbose_Level::REST_Info) PrintMetadata(); +} + +/////////////////////////////////////////////// +/// \brief Default destructor +/// +TRestDataSetTMVAClassification::~TRestDataSetTMVAClassification() {} + +/////////////////////////////////////////////// +/// \brief Function to initialize input/output event members and define +/// the section name +/// +void TRestDataSetTMVAClassification::Initialize() { SetSectionName(this->ClassName()); } + +/////////////////////////////////////////////// +/// \brief Function to initialize some variables from +/// configfile +/// +void TRestDataSetTMVAClassification::InitFromConfigFile() { + Initialize(); + TRestMetadata::InitFromConfigFile(); + + TiXmlElement* obsDefinition = GetElement("observable"); + while (obsDefinition != nullptr) { + std::string obsName = GetFieldValue("name", obsDefinition); + if (obsName.empty() || obsName == "Not defined") { + RESTError << "< observable variable key does not contain a name!" << RESTendl; + exit(1); + } else { + fObsName.push_back(obsName); + } + + obsDefinition = GetNextElement(obsDefinition); + } + + TiXmlElement* cutele = GetElement("addCut"); + while (cutele != nullptr) { + std::string cutName = GetParameter("name", cutele, ""); + if (!cutName.empty()) { + if (fCut == nullptr) { + fCut = (TRestCut*)InstantiateChildMetadata("TRestCut", cutName); + } else { + fCut->AddCut((TRestCut*)InstantiateChildMetadata("TRestCut", cutName)); + } + } + cutele = GetNextElement(cutele); + } + + if (fOutputFileName == "") fOutputFileName = GetParameter("outputFileName", ""); +} + +///////////////////////////////////////////// +/// \brief This function computes the TMVA classification +/// for a given dataSet. It requires a xml file with weigths +/// from the output of TRestDataSetTMVA to perform the +/// clasification for a given set of observables. This function +/// defines a new observable with the score of the TMVA method +/// provided in the input file that can be used for further +/// signal and background discrimination. +/// +void TRestDataSetTMVAClassification::ClassifyTMVA() { + PrintMetadata(); + + if (fObsName.empty()) { + RESTError << "No observables provided, exiting..." << RESTendl; + exit(1); + } + + TMVA::Reader reader("!Color:!Silent"); + std::vector var(fObsName.size()); + + // Add variables to the reader + for (unsigned int i = 0; i < fObsName.size(); i++) { + reader.AddVariable(fObsName[i].c_str(), &var[i]); + } + + // Book TMVA method + reader.BookMVA(fTmvaMethod.c_str(), fTmvaFile.c_str()); + + // Lambda for evaluation of the method + auto eval = [&reader = reader, &tmvaMethod = fTmvaMethod](const std::vector& val) { + return reader.EvaluateMVA(val, tmvaMethod.c_str()); + }; + + TRestDataSet dataSet; + dataSet.Import(fDataSetName); + + auto df = dataSet.MakeCut(fCut); + + std::string obsName = fTmvaMethod + "_score"; + + // Ugly but cannot pass vector size to ROOT::RDF::PassAsVec + switch (fObsName.size()) { + case 1: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<1, double>(eval), fObsName); + break; + case 2: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<2, double>(eval), fObsName); + break; + case 3: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<3, double>(eval), fObsName); + break; + case 4: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<4, double>(eval), fObsName); + break; + case 5: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<5, double>(eval), fObsName); + break; + case 6: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<6, double>(eval), fObsName); + break; + case 7: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<7, double>(eval), fObsName); + break; + case 8: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<8, double>(eval), fObsName); + break; + case 9: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<9, double>(eval), fObsName); + break; + case 10: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<10, double>(eval), fObsName); + break; + case 11: + df = df.Define(obsName.c_str(), ROOT::RDF::PassAsVec<11, double>(eval), fObsName); + break; + default: + RESTError << "Number of observables " << fObsName.size() << " is not supported" << RESTendl; + exit(1); + } + + dataSet.SetDataFrame(df); + + if (!fOutputFileName.empty()) { + if (TRestTools::GetFileNameExtension(fOutputFileName) == "root") { + dataSet.Export(fOutputFileName); + TFile* f = TFile::Open(fOutputFileName.c_str(), "UPDATE"); + this->Write(); + f->Close(); + } + } +} + +///////////////////////////////////////////// +/// \brief Prints on screen the information about the metadata members of TRestDataSetTMVAClassification +/// +void TRestDataSetTMVAClassification::PrintMetadata() { + TRestMetadata::PrintMetadata(); + + RESTMetadata << " Observables to compute: " << RESTendl; + for (size_t i = 0; i < fObsName.size(); i++) { + RESTMetadata << fObsName[i] << RESTendl; + } + RESTMetadata << "----" << RESTendl; +} diff --git a/source/framework/core/inc/TRestDataSet.h b/source/framework/core/inc/TRestDataSet.h index 94f0a490d..28d185f7a 100644 --- a/source/framework/core/inc/TRestDataSet.h +++ b/source/framework/core/inc/TRestDataSet.h @@ -169,8 +169,8 @@ class TRestDataSet : public TRestMetadata { inline void SetQuantity(const std::map& quantity) { fQuantity = quantity; } TRestDataSet& operator=(TRestDataSet& dS); - void Import(const std::string& fileName); void Import(std::vector fileNames); + void Import(const std::string& fileName, bool enableMT = true); void Export(const std::string& filename); ROOT::RDF::RNode MakeCut(const TRestCut* cut); diff --git a/source/framework/core/inc/TRestRun.h b/source/framework/core/inc/TRestRun.h index ffaf56b41..52cfe2b3d 100644 --- a/source/framework/core/inc/TRestRun.h +++ b/source/framework/core/inc/TRestRun.h @@ -92,6 +92,7 @@ class TRestRun : public TRestMetadata { TFile* MergeToOutputFile(std::vector filefullnames, std::string outputfilename = ""); TFile* FormOutputFile(); TFile* UpdateOutputFile(); + TFile* OpenAndUpdateOutputFile(); void PassOutputFile() { fOutputFile = fInputFile; diff --git a/source/framework/core/src/TRestDataSet.cxx b/source/framework/core/src/TRestDataSet.cxx index 430aba8aa..566b279c3 100644 --- a/source/framework/core/src/TRestDataSet.cxx +++ b/source/framework/core/src/TRestDataSet.cxx @@ -890,7 +890,7 @@ TRestDataSet& TRestDataSet::operator=(TRestDataSet& dS) { /// it import metadata info from the previous dataSet /// while it opens the analysis tree /// -void TRestDataSet::Import(const std::string& fileName) { +void TRestDataSet::Import(const std::string& fileName, bool enableMT) { if (TRestTools::GetFileNameExtension(fileName) != "root") { RESTError << "Datasets can only be imported from root files" << RESTendl; return; @@ -918,7 +918,7 @@ void TRestDataSet::Import(const std::string& fileName) { return; } - ROOT::EnableImplicitMT(); + if (enableMT) ROOT::EnableImplicitMT(); fDataSet = ROOT::RDataFrame("AnalysisTree", fileName); diff --git a/source/framework/core/src/TRestRun.cxx b/source/framework/core/src/TRestRun.cxx index 3f533206b..6ae1e0ea7 100644 --- a/source/framework/core/src/TRestRun.cxx +++ b/source/framework/core/src/TRestRun.cxx @@ -986,24 +986,25 @@ TString TRestRun::FormFormat(const TString& FilenameFormat) { TFile* TRestRun::MergeToOutputFile(vector filenames, string outputfilename) { RESTDebug << "TRestRun::FormOutputFile. target : " << outputfilename << RESTendl; string filename; - TFileMerger* m = new TFileMerger(false); + + TFileMerger m(false); if (outputfilename.empty()) { filename = fOutputFileName; RESTInfo << "Creating file : " << filename << RESTendl; - m->OutputFile(filename.c_str(), "RECREATE"); + m.OutputFile(filename.c_str(), "RECREATE"); } else { filename = outputfilename; - RESTInfo << "Updating file : " << filename << RESTendl; - m->OutputFile(filename.c_str(), "UPDATE"); + RESTInfo << "Creating file : " << filename << RESTendl; + m.OutputFile(filename.c_str(), "UPDATE"); } RESTDebug << "TRestRun::FormOutputFile. Starting to add files" << RESTendl; for (unsigned int i = 0; i < filenames.size(); i++) { - m->AddFile(filenames[i].c_str(), false); + m.AddFile(filenames[i].c_str(), false); } - if (m->Merge()) { + if (m.Merge()) { for (unsigned int i = 0; i < filenames.size(); i++) { remove(filenames[i].c_str()); } @@ -1013,8 +1014,6 @@ TFile* TRestRun::MergeToOutputFile(vector filenames, string outputfilena exit(1); } - delete m; - // we rename the created output file fOutputFileName = FormFormat(filename); rename(filename.c_str(), fOutputFileName); @@ -1081,6 +1080,17 @@ TFile* TRestRun::UpdateOutputFile() { return nullptr; } +/////////////////////////////////////////////// +/// \brief Open and update output file in case is closed +/// +TFile* TRestRun::OpenAndUpdateOutputFile() { + if (fOutputFile == nullptr) { + fOutputFile = TFile::Open(fOutputFileName, "UPDATE"); + } + + return UpdateOutputFile(); +} + /////////////////////////////////////////////// /// \brief Write this object into TFile and add a new entry in database ///