Skip to content

Commit

Permalink
TMVAHelper: Fixing nightlies build (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
kjvbrt authored Jul 18, 2024
1 parent 1f03903 commit ccc302a
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 87 deletions.
11 changes: 7 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ option(USE_EXTERNAL_CATCH2 "Link against an external Catch2 v3 static library, o

option(FCCANALYSES_DOCUMENTATION "Whether or not to create doxygen doc target." ON)

#--- Export compile commands --------------------------------------------------
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

#--- Set a better default for installation directory---------------------------
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/install" CACHE PATH "default install path" FORCE)
Expand All @@ -48,17 +51,17 @@ set(INSTALL_INCLUDE_DIR include CACHE PATH

#--- Declare C++ Standard -----------------------------------------------------

set(CMAKE_CXX_STANDARD 17 CACHE STRING "")
set(CMAKE_CXX_STANDARD 20 CACHE STRING "")
if(NOT CMAKE_CXX_STANDARD MATCHES "17|20")
message(FATAL_ERROR "Unsupported C++ standard: ${CMAKE_CXX_STANDARD}")
endif()
message (STATUS "C++ standard: ${CMAKE_CXX_STANDARD}")

#--- Dependencies -------------------------------------------------------------

find_package(ROOT COMPONENTS ROOTVecOps ROOTDataFrame REQUIRED)
find_package(EDM4HEP REQUIRED)
find_package(podio)
find_package(ROOT REQUIRED COMPONENTS ROOTVecOps ROOTDataFrame TMVA TMVAUtils)
find_package(EDM4HEP REQUIRED) # will find also podio
find_package(TBB REQUIRED COMPONENTS tbb)

# need to use our own FindFastJet.cmake
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})
Expand Down
38 changes: 15 additions & 23 deletions addons/TMVAHelper/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,37 +1,29 @@


find_package(TBB REQUIRED tbb)
find_package(ROOT)
find_package(ROOT COMPONENTS ROOTVecOps)
find_package(ROOT COMPONENTS TMVA)


message(STATUS "includes-------------------------- TEST: ${TBB_INCLUDE_DIRS}")
find_package(TBB REQUIRED COMPONENTS tbb)
find_package(ROOT REQUIRED COMPONENTS TMVA TMVAUtils ROOTVecOps)

file(GLOB sources src/*.cc)
file(GLOB headers *.h)


fccanalyses_addon_build(TMVAHelper
SOURCES ${headers} ${sources}
# EXT_LIBS ROOT::ROOTVecOps ROOT::TMVA TBB::tbb
INSTALL_COMPONENT tmvahelper)

add_custom_command(TARGET TMVAHelper POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_SOURCE_DIR}/python/*
${CMAKE_CURRENT_BINARY_DIR}
EXT_LIBS ROOT::ROOTVecOps ROOT::TMVA ROOT::TMVAUtils
TBB::tbb
INSTALL_COMPONENT tmvahelper
)

target_link_libraries(TMVAHelper PRIVATE TBB::tbb)
target_link_libraries(TMVAHelper PRIVATE ROOT::ROOTVecOps)
target_link_libraries(TMVAHelper PRIVATE ROOT::TMVA)
target_compile_features(TMVAHelper PRIVATE cxx_std_11)
add_custom_command(
TARGET TMVAHelper
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/python/*
${CMAKE_CURRENT_BINARY_DIR}
)

install(FILES
${CMAKE_CURRENT_LIST_DIR}/TMVAHelper.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/TMVAHelper
)
)

file(GLOB _addon_python_files python/*.py)
install(FILES ${_addon_python_files} DESTINATION ${CMAKE_INSTALL_PREFIX}/python/addons/TMVAHelper)
install(FILES ${_addon_python_files}
DESTINATION ${CMAKE_INSTALL_PREFIX}/python/addons/TMVAHelper
)
28 changes: 17 additions & 11 deletions addons/TMVAHelper/TMVAHelper.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
#ifndef TMVAHelper_TMVAHelper_h
#define TMVAHelper_TMVAHelper_h

#include <tbb/task_arena.h>
// ROOT
#include "ROOT/RVec.hxx"
#include "TMVA/RBDT.hxx"

// TBB
#include <tbb/task_arena.h>

// std
#include <string>
#include <vector>

class tmva_helper_xgb {
public:
explicit tmva_helper_xgb(const std::string &filename, const std::string &name, const unsigned &nvars, const unsigned int nslots = 1);
virtual ~tmva_helper_xgb();
ROOT::VecOps::RVec<float> operator()(const ROOT::VecOps::RVec<float> vars);
public:
tmva_helper_xgb(const std::string &filename, const std::string &name,
const unsigned int nslots = 1);
~tmva_helper_xgb() {};
ROOT::VecOps::RVec<float> operator()(const ROOT::VecOps::RVec<float> vars);

private:
unsigned int nvars_;
TMVA::Experimental::RBDT<> model_;
std::vector<TMVA::Experimental::RBDT<>> interpreters_;
private:
// Default backend (template parameter) is:
// TMVA::Experimental::BranchlessJittedForest<float>
std::vector<TMVA::Experimental::RBDT<>> m_interpreters;
};


#endif
#endif
25 changes: 12 additions & 13 deletions addons/TMVAHelper/python/TMVAHelper.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,38 @@

import ROOT
import pathlib

ROOT.gInterpreter.ProcessLine('#include "TMVAHelper/TMVAHelper.h"')
ROOT.gSystem.Load("libTMVAHelper")

class TMVAHelperXGB():

class TMVAHelperXGB():
def __init__(self, model_input, model_name, variables=[]):

if len(variables) == 0: # try to get the variables from the model file (saved as a TList)
# try to get the variables from the model file (saved as a TList)
if len(variables) == 0:
fIn = ROOT.TFile(model_input)
variables_ = fIn.Get("variables")
self.variables = [str(var.GetString()) for var in variables_]
fIn.Close()
else:
else:
self.variables = variables
self.nvars = len(self.variables)
self.model_input = model_input
self.model_name = model_name
self.nthreads = ROOT.GetThreadPoolSize()
self.nthreads = ROOT.GetThreadPoolSize()

self.tmva_helper = ROOT.tmva_helper_xgb(self.model_input, self.model_name, self.nvars, self.nthreads)
self.tmva_helper = ROOT.tmva_helper_xgb(self.model_input,
self.model_name,
self.nthreads)
self.var_col = f"tmva_vars_{self.model_name}"

def run_inference(self, df, col_name = "mva_score"):
def run_inference(self, df, col_name="mva_score"):

# check if columns exist in the dataframe
cols = df.GetColumnNames()
for var in self.variables:
if not var in cols:
if var not in cols:
raise Exception(f"Variable {var} not defined in dataframe.")

vars_str = ', (float)'.join(self.variables)
df = df.Define(self.var_col, f"ROOT::VecOps::RVec<float>{{{vars_str}}}")
df = df.Define(self.var_col,
f"ROOT::VecOps::RVec<float>{{{vars_str}}}")
df = df.Define(col_name, self.tmva_helper, [self.var_col])
return df

35 changes: 18 additions & 17 deletions addons/TMVAHelper/src/TMVAHelper.cc
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
#include "TMVAHelper/TMVAHelper.h"

tmva_helper_xgb::tmva_helper_xgb(const std::string &filename,
const std::string &name,
const unsigned int nslots) {

tmva_helper_xgb::tmva_helper_xgb(const std::string &filename, const std::string &name, const unsigned &nvars, const unsigned int nslots) :
model_(name, filename), nvars_(nvars) {

const unsigned int nslots_actual = std::max(nslots, 1U);
interpreters_.reserve(nslots_actual);
for (unsigned int islot = 0; islot < nslots_actual; ++islot) {
interpreters_.emplace_back(model_);
}
const unsigned int nslots_actual = std::max(nslots, 1U);
m_interpreters.reserve(nslots_actual);
for (unsigned int islot = 0; islot < nslots_actual; ++islot) {
m_interpreters.emplace_back(TMVA::Experimental::RBDT<>(name, filename));
}
}

ROOT::VecOps::RVec<float> tmva_helper_xgb::operator()(const ROOT::VecOps::RVec<float> vars) {
auto const tbb_slot = std::max(tbb::this_task_arena::current_thread_index(), 0);
if (tbb_slot >= interpreters_.size()) {
throw std::runtime_error("Not enough interpreters allocated for number of tbb threads");
}
auto &interpreter_data = interpreters_[tbb_slot];
return interpreter_data.Compute(vars);
ROOT::VecOps::RVec<float>
tmva_helper_xgb::operator()(const ROOT::VecOps::RVec<float> vars) {
auto const tbb_slot =
std::max(tbb::this_task_arena::current_thread_index(), 0);
if (tbb_slot >= m_interpreters.size()) {
throw std::runtime_error(
"Not enough interpreters allocated for number of tbb threads");
}
auto &interpreter_data = m_interpreters[tbb_slot];
return interpreter_data.Compute(vars);
}

tmva_helper_xgb::~tmva_helper_xgb() {}
31 changes: 12 additions & 19 deletions analyzers/dataframe/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@


# workaround for ROOT not properly exporting the VDT includes
find_package(Vdt)


message(STATUS "includes-------------------------- dataframe edm4hep: ${EDM4HEP_INCLUDE_DIRS}")
message(STATUS "includes-------------------------- dataframe podio : ${podio_INCLUDE_DIR}")
message(STATUS "includes-------------------------- dataframe delphes: ${DELPHES_INCLUDE_DIR}")
message(STATUS "includes-------------------------- dataframe delphes EXt TrkCov: ${DELPHES_EXTERNALS_TKCOV_INCLUDE_DIR}")
message(STATUS "includes-------------------------- dataframe delphes EXt: ${DELPHES_EXTERNALS_INCLUDE_DIR}")

include_directories(${DELPHES_INCLUDE_DIR}
${DELPHES_EXTERNALS_INCLUDE_DIR}
${DELPHES_EXTERNALS_TKCOV_INCLUDE_DIR}
)


# sources and headers
file(GLOB sources src/*.cc)
file(GLOB headers RELATIVE ${CMAKE_CURRENT_LIST_DIR} FCCAnalyses/*.h)

file(GLOB headers RELATIVE ${CMAKE_CURRENT_LIST_DIR} FCCAnalyses/*.h)
list(FILTER headers EXCLUDE REGEX "LinkDef.h")

if(NOT WITH_DD4HEP)
list(FILTER headers EXCLUDE REGEX "CaloNtupleizer.h")
list(FILTER sources EXCLUDE REGEX "CaloNtupleizer.cc")
endif()

if(NOT WITH_ONNX)
list(FILTER headers EXCLUDE REGEX "JetFlavourUtils.h")
list(FILTER sources EXCLUDE REGEX "JetFlavourUtils.cc")
Expand All @@ -38,14 +32,14 @@ if(NOT WITH_ACTS)
list(FILTER sources EXCLUDE REGEX "VertexFinderActs.cc")
endif()

message(STATUS "includes headers ${headers}")
message(STATUS "includes sources ${sources}")
message(STATUS "FCCAnalyses headers:\n ${headers}")
message(STATUS "FCCAnalyses sources:\n ${sources}")

message(STATUS "CMAKE_CURRENT_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}")
message(STATUS "CMAKE_INSTALL_INCLUDEDIR ${CMAKE_INSTALL_INCLUDEDIR}")

add_library(FCCAnalyses SHARED ${sources} ${headers} )
target_include_directories(FCCAnalyses PUBLIC
add_library(FCCAnalyses SHARED ${sources} ${headers})
target_include_directories(FCCAnalyses PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/addons>
$<INSTALL_INTERFACE:include>
Expand All @@ -58,7 +52,6 @@ target_include_directories(FCCAnalyses PUBLIC
message(STATUS " ====== DELPHES LIBRARY = " ${DELPHES_LIBRARY} )
message(STATUS " ====== DELPHES_EXTERNALS_TKCOV_INCLUDE_DIR = " ${DELPHES_EXTERNALS_TKCOV_INCLUDE_DIR} )


target_link_libraries(FCCAnalyses PUBLIC
ROOT::Physics
ROOT::MathCore
Expand All @@ -67,10 +60,10 @@ target_link_libraries(FCCAnalyses PUBLIC
EDM4HEP::edm4hep
EDM4HEP::edm4hepDict
podio::podio
${DELPHES_LIBRARY}
${ADDONS_LIBRARIES}
${DELPHES_LIBRARY}
gfortran # todo: why necessary?
)
)

if(WITH_DD4HEP)
target_link_libraries(FCCAnalyses PUBLIC DD4hep::DDCore)
Expand All @@ -88,15 +81,15 @@ ROOT_GENERATE_DICTIONARY(G__FCCAnalyses
${headers}
MODULE FCCAnalyses
LINKDEF FCCAnalyses/LinkDef.h
)
)

install(TARGETS FCCAnalyses
EXPORT FCCAnalysesTargets
RUNTIME DESTINATION "${INSTALL_BIN_DIR}" COMPONENT bin
LIBRARY DESTINATION "${INSTALL_LIB_DIR}" COMPONENT shlib
PUBLIC_HEADER DESTINATION "${INSTALL_INCLUDE_DIR}/FCCAnalyses"
COMPONENT dev
)
)

install(FILES
"${PROJECT_BINARY_DIR}/analyzers/dataframe/libFCCAnalyses.rootmap"
Expand Down

0 comments on commit ccc302a

Please sign in to comment.