Skip to content

Commit

Permalink
[binding] refactor binding code
Browse files Browse the repository at this point in the history
Refactor the binding code into multiple files, and get rid of llambdas.
It's a lot cleaner/clearer now, I hope.
  • Loading branch information
milobanks committed Jul 7, 2024
1 parent d70f74f commit d94b7fa
Show file tree
Hide file tree
Showing 12 changed files with 316 additions and 428 deletions.
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ find_package(pybind11 CONFIG REQUIRED)
# Include the include/epiworld library
include_directories(include/epiworld)

python_add_library(_core MODULE src/main.cpp WITH_SOABI)
python_add_library(_core MODULE
src/models/model.cpp
src/models/seirconn.cpp
src/database.cpp
src/main.cpp
WITH_SOABI)
target_link_libraries(_core PRIVATE pybind11::headers)
target_compile_definitions(_core PRIVATE VERSION_INFO=${PROJECT_VERSION})

Expand Down
153 changes: 153 additions & 0 deletions src/database.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#include "database.hpp"

#include <pybind11/numpy.h>
#include <pybind11/stl.h>

using namespace epiworld;
using namespace epiworldpy;
namespace py = pybind11;

static py::dict get_hist_total(DataBase<int> &self) {
/* Lo, one of the times in modern C++ where the 'new' keyword
* isn't out of place. */
std::vector<std::string> states;
std::vector<int> *dates = new std::vector<int>();
std::vector<int> *counts = new std::vector<int>();

self.get_hist_total(dates, &states, counts);

/* Return to Python. */
py::capsule pyc_dates(
dates, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });
py::capsule pyc_counts(counts, [](void *x) {
delete reinterpret_cast<std::vector<int> *>(x);
});

/* TODO: Find a way to do a no-copy of a string vector. */
py::array py_states = py::array(py::cast(states));
py::array_t<int> py_dates(dates->size(), dates->data(), pyc_dates);
py::array_t<int> py_counts(counts->size(), counts->data(), pyc_counts);

py::dict ret("dates"_a = py_dates, "states"_a = py_states,
"counts"_a = py_counts);

return ret;
}

static std::vector<std::vector<std::map<int, int>>>
get_reproductive_number(DataBase<int> &self) {
MapVec_type<int, int> raw_rt = self.reproductive_number();
// viruses | dates | pairs
// C | C | V
std::vector<std::vector<std::map<int, int>>> viruses;

/* Reserve our spaces for our elements so we don't have to
* worry about it later. */
for (int virus_id = 0; virus_id < self.get_n_viruses(); virus_id++) {
std::vector<std::map<int, int>> dates(self.get_model()->today() + 1);
viruses.push_back(dates);
}

/* Load into pre-return. */
for (const auto &keyValue : raw_rt) {
const std::vector<int> &key = keyValue.first;
const int virus_id = key.at(0);
const int source = key.at(1);
const int exposure_date = key.at(2);
const int effective_rn = keyValue.second;

viruses[virus_id][exposure_date].insert({source, effective_rn});
}

/* TODO: There's lots room for optimization here, namely
* returning an array instead of a bunch of lists. */
return viruses;
}

static py::dict get_transmissions(DataBase<int> &self) {
/* Lo, one of the times in modern C++ where the 'new' keyword
* isn't out of place. */
std::vector<int> *dates = new std::vector<int>();
std::vector<int> *sources = new std::vector<int>();
std::vector<int> *targets = new std::vector<int>();
std::vector<int> *viruses = new std::vector<int>();
std::vector<int> *source_exposure_dates = new std::vector<int>();

self.get_transmissions(*dates, *sources, *targets, *viruses,
*source_exposure_dates);

/* Return to Python. */
py::capsule pyc_dates(
dates, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });
py::capsule pyc_sources(sources, [](void *x) {
delete reinterpret_cast<std::vector<int> *>(x);
});
py::capsule pyc_targets(targets, [](void *x) {
delete reinterpret_cast<std::vector<int> *>(x);
});
py::capsule pyc_viruses(viruses, [](void *x) {
delete reinterpret_cast<std::vector<int> *>(x);
});
py::capsule pyc_source_exposure_dates(source_exposure_dates, [](void *x) {
delete reinterpret_cast<std::vector<int> *>(x);
});

py::array py_dates(dates->size(), dates->data(), pyc_dates);
py::array py_sources(sources->size(), sources->data(), pyc_sources);
py::array py_targets(targets->size(), targets->data(), pyc_targets);
py::array py_viruses(viruses->size(), viruses->data(), pyc_viruses);
py::array py_source_exposure_dates(source_exposure_dates->size(),
source_exposure_dates->data(),
pyc_source_exposure_dates);

py::dict ret("dates"_a = py_dates, "sources"_a = py_sources,
"targets"_a = py_targets, "viruses"_a = py_viruses,
"source_exposure_dates"_a = py_source_exposure_dates);

return ret;
}

static py::dict get_generation_time(DataBase<int> &self) {
std::vector<int> *agents = new std::vector<int>();
std::vector<int> *viruses = new std::vector<int>();
std::vector<int> *times = new std::vector<int>();
std::vector<int> *gentimes = new std::vector<int>();

self.generation_time(*agents, *viruses, *times, *gentimes);

/* Return to Python. */
py::capsule pyc_agents(agents, [](void *x) {
delete reinterpret_cast<std::vector<int> *>(x);
});
py::capsule pyc_viruses(viruses, [](void *x) {
delete reinterpret_cast<std::vector<int> *>(x);
});
py::capsule pyc_times(
times, [](void *x) { delete reinterpret_cast<std::vector<int> *>(x); });
py::capsule pyc_gentimes(gentimes, [](void *x) {
delete reinterpret_cast<std::vector<int> *>(x);
});

py::array py_agents(agents->size(), agents->data(), pyc_agents);
py::array py_viruses(viruses->size(), viruses->data(), pyc_viruses);
py::array py_times(times->size(), times->data(), pyc_times);
py::array py_gentimes(gentimes->size(), gentimes->data(), pyc_gentimes);

py::dict ret("agents"_a = py_agents, "viruses"_a = py_viruses,
"times"_a = py_times, "gentimes"_a = py_gentimes);

return ret;
}

void epiworldpy::export_database(
py::class_<epiworld::DataBase<int>,
std::shared_ptr<epiworld::DataBase<int>>> &c) {
c.def("get_hist_total", &get_hist_total,
"Get historical totals for this model run.")
.def("get_reproductive_number", &get_reproductive_number,
"Get reproductive numbers over time for every virus in the model.")
.def("get_transmissions", &get_transmissions,
"Get transmission data over time for every virus in the model.")
.def("get_generation_time", &get_generation_time,
"Get generation times over time for every virus in the model.");
}
13 changes: 13 additions & 0 deletions src/database.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef EPIWORLDPY_DATABASE_HPP
#define EPIWORLDPY_DATABASE_HPP

#include "interface.hpp"
#include <pybind11/pybind11.h>

namespace epiworldpy {
void export_database(
pybind11::class_<epiworld::DataBase<int>,
std::shared_ptr<epiworld::DataBase<int>>> &c);
}

#endif /* EPIWORLDPY_DATABASE_HPP */
87 changes: 0 additions & 87 deletions src/epiworld-common.hpp

This file was deleted.

24 changes: 24 additions & 0 deletions src/interface.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef EPIWORLDPY_INTERFACE_HPP
#define EPIWORLDPY_INTERFACE_HPP

#include <pybind11/pybind11.h>

using namespace pybind11::literals;

inline void pyprinter(const char *fmt, ...) {
char buffer[1024];

va_list args;
va_start(args, fmt);
vsnprintf(&buffer[0], sizeof(buffer), fmt, args);
va_end(args);

pybind11::print(std::string(buffer), pybind11::arg("end") = "");
}

#define printf_epiworld pyprinter

/* Keep me at the bottom! */
#include "epiworld.hpp"

#endif /* EPIWORLPY_INTERFACE_HPP */
Loading

0 comments on commit d94b7fa

Please sign in to comment.