-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TMVAHelper: Fixing nightlies build (#388)
- Loading branch information
Showing
6 changed files
with
81 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,29 @@ | ||
|
||
|
||
find_package(TBB REQUIRED tbb) | ||
find_package(ROOT) | ||
find_package(ROOT COMPONENTS ROOTVecOps) | ||
find_package(ROOT COMPONENTS TMVA) | ||
|
||
|
||
message(STATUS "includes-------------------------- TEST: ${TBB_INCLUDE_DIRS}") | ||
find_package(TBB REQUIRED COMPONENTS tbb) | ||
find_package(ROOT REQUIRED COMPONENTS TMVA TMVAUtils ROOTVecOps) | ||
|
||
file(GLOB sources src/*.cc) | ||
file(GLOB headers *.h) | ||
|
||
|
||
fccanalyses_addon_build(TMVAHelper | ||
SOURCES ${headers} ${sources} | ||
# EXT_LIBS ROOT::ROOTVecOps ROOT::TMVA TBB::tbb | ||
INSTALL_COMPONENT tmvahelper) | ||
|
||
add_custom_command(TARGET TMVAHelper POST_BUILD | ||
COMMAND ${CMAKE_COMMAND} -E copy | ||
${CMAKE_CURRENT_SOURCE_DIR}/python/* | ||
${CMAKE_CURRENT_BINARY_DIR} | ||
EXT_LIBS ROOT::ROOTVecOps ROOT::TMVA ROOT::TMVAUtils | ||
TBB::tbb | ||
INSTALL_COMPONENT tmvahelper | ||
) | ||
|
||
target_link_libraries(TMVAHelper PRIVATE TBB::tbb) | ||
target_link_libraries(TMVAHelper PRIVATE ROOT::ROOTVecOps) | ||
target_link_libraries(TMVAHelper PRIVATE ROOT::TMVA) | ||
target_compile_features(TMVAHelper PRIVATE cxx_std_11) | ||
add_custom_command( | ||
TARGET TMVAHelper | ||
POST_BUILD | ||
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/* | ||
${CMAKE_CURRENT_BINARY_DIR} | ||
) | ||
|
||
install(FILES | ||
${CMAKE_CURRENT_LIST_DIR}/TMVAHelper.h | ||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/TMVAHelper | ||
) | ||
) | ||
|
||
file(GLOB _addon_python_files python/*.py) | ||
install(FILES ${_addon_python_files} DESTINATION ${CMAKE_INSTALL_PREFIX}/python/addons/TMVAHelper) | ||
install(FILES ${_addon_python_files} | ||
DESTINATION ${CMAKE_INSTALL_PREFIX}/python/addons/TMVAHelper | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,28 @@ | ||
#ifndef TMVAHelper_TMVAHelper_h | ||
#define TMVAHelper_TMVAHelper_h | ||
|
||
#include <tbb/task_arena.h> | ||
// ROOT | ||
#include "ROOT/RVec.hxx" | ||
#include "TMVA/RBDT.hxx" | ||
|
||
// TBB | ||
#include <tbb/task_arena.h> | ||
|
||
// std | ||
#include <string> | ||
#include <vector> | ||
|
||
class tmva_helper_xgb { | ||
public: | ||
explicit tmva_helper_xgb(const std::string &filename, const std::string &name, const unsigned &nvars, const unsigned int nslots = 1); | ||
virtual ~tmva_helper_xgb(); | ||
ROOT::VecOps::RVec<float> operator()(const ROOT::VecOps::RVec<float> vars); | ||
public: | ||
tmva_helper_xgb(const std::string &filename, const std::string &name, | ||
const unsigned int nslots = 1); | ||
~tmva_helper_xgb() {}; | ||
ROOT::VecOps::RVec<float> operator()(const ROOT::VecOps::RVec<float> vars); | ||
|
||
private: | ||
unsigned int nvars_; | ||
TMVA::Experimental::RBDT<> model_; | ||
std::vector<TMVA::Experimental::RBDT<>> interpreters_; | ||
private: | ||
// Default backend (template parameter) is: | ||
// TMVA::Experimental::BranchlessJittedForest<float> | ||
std::vector<TMVA::Experimental::RBDT<>> m_interpreters; | ||
}; | ||
|
||
|
||
#endif | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,38 @@ | ||
|
||
import ROOT | ||
import pathlib | ||
|
||
ROOT.gInterpreter.ProcessLine('#include "TMVAHelper/TMVAHelper.h"') | ||
ROOT.gSystem.Load("libTMVAHelper") | ||
|
||
class TMVAHelperXGB(): | ||
|
||
class TMVAHelperXGB(): | ||
def __init__(self, model_input, model_name, variables=[]): | ||
|
||
if len(variables) == 0: # try to get the variables from the model file (saved as a TList) | ||
# try to get the variables from the model file (saved as a TList) | ||
if len(variables) == 0: | ||
fIn = ROOT.TFile(model_input) | ||
variables_ = fIn.Get("variables") | ||
self.variables = [str(var.GetString()) for var in variables_] | ||
fIn.Close() | ||
else: | ||
else: | ||
self.variables = variables | ||
self.nvars = len(self.variables) | ||
self.model_input = model_input | ||
self.model_name = model_name | ||
self.nthreads = ROOT.GetThreadPoolSize() | ||
self.nthreads = ROOT.GetThreadPoolSize() | ||
|
||
self.tmva_helper = ROOT.tmva_helper_xgb(self.model_input, self.model_name, self.nvars, self.nthreads) | ||
self.tmva_helper = ROOT.tmva_helper_xgb(self.model_input, | ||
self.model_name, | ||
self.nthreads) | ||
self.var_col = f"tmva_vars_{self.model_name}" | ||
|
||
def run_inference(self, df, col_name = "mva_score"): | ||
def run_inference(self, df, col_name="mva_score"): | ||
|
||
# check if columns exist in the dataframe | ||
cols = df.GetColumnNames() | ||
for var in self.variables: | ||
if not var in cols: | ||
if var not in cols: | ||
raise Exception(f"Variable {var} not defined in dataframe.") | ||
|
||
vars_str = ', (float)'.join(self.variables) | ||
df = df.Define(self.var_col, f"ROOT::VecOps::RVec<float>{{{vars_str}}}") | ||
df = df.Define(self.var_col, | ||
f"ROOT::VecOps::RVec<float>{{{vars_str}}}") | ||
df = df.Define(col_name, self.tmva_helper, [self.var_col]) | ||
return df | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,24 @@ | ||
#include "TMVAHelper/TMVAHelper.h" | ||
|
||
tmva_helper_xgb::tmva_helper_xgb(const std::string &filename, | ||
const std::string &name, | ||
const unsigned int nslots) { | ||
|
||
tmva_helper_xgb::tmva_helper_xgb(const std::string &filename, const std::string &name, const unsigned &nvars, const unsigned int nslots) : | ||
model_(name, filename), nvars_(nvars) { | ||
|
||
const unsigned int nslots_actual = std::max(nslots, 1U); | ||
interpreters_.reserve(nslots_actual); | ||
for (unsigned int islot = 0; islot < nslots_actual; ++islot) { | ||
interpreters_.emplace_back(model_); | ||
} | ||
const unsigned int nslots_actual = std::max(nslots, 1U); | ||
m_interpreters.reserve(nslots_actual); | ||
for (unsigned int islot = 0; islot < nslots_actual; ++islot) { | ||
m_interpreters.emplace_back(TMVA::Experimental::RBDT<>(name, filename)); | ||
} | ||
} | ||
|
||
ROOT::VecOps::RVec<float> tmva_helper_xgb::operator()(const ROOT::VecOps::RVec<float> vars) { | ||
auto const tbb_slot = std::max(tbb::this_task_arena::current_thread_index(), 0); | ||
if (tbb_slot >= interpreters_.size()) { | ||
throw std::runtime_error("Not enough interpreters allocated for number of tbb threads"); | ||
} | ||
auto &interpreter_data = interpreters_[tbb_slot]; | ||
return interpreter_data.Compute(vars); | ||
ROOT::VecOps::RVec<float> | ||
tmva_helper_xgb::operator()(const ROOT::VecOps::RVec<float> vars) { | ||
auto const tbb_slot = | ||
std::max(tbb::this_task_arena::current_thread_index(), 0); | ||
if (tbb_slot >= m_interpreters.size()) { | ||
throw std::runtime_error( | ||
"Not enough interpreters allocated for number of tbb threads"); | ||
} | ||
auto &interpreter_data = m_interpreters[tbb_slot]; | ||
return interpreter_data.Compute(vars); | ||
} | ||
|
||
tmva_helper_xgb::~tmva_helper_xgb() {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters