Skip to content

Commit

Permalink
Converted all std::async to thread pool dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
hosseinmoein committed Dec 5, 2023
1 parent a09754b commit a10ebfd
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 193 deletions.
74 changes: 38 additions & 36 deletions include/DataFrame/DataFrameFinancialVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,28 +187,29 @@ struct DoubleCrossOver {
operator() (const K &idx_begin, const K &idx_end,
const H &prices_begin, const H &prices_end) {

const size_type thread_level =
ThreadGranularity::get_thread_level();
size_type re_count1 = 0;
size_type re_count2 = 0;
const auto thread_level = ThreadGranularity::get_thread_level();
size_type re_count1 = 0;
size_type re_count2 = 0;

if (thread_level >= 2) {
if (thread_level > 0) {
std::future<size_type> fut1 =
std::async(std::launch::async,
&DoubleCrossOver::run_short_roller_<K, H>,
this,
std::cref(idx_begin),
std::cref(idx_end),
std::cref(prices_begin),
std::cref(prices_end));
ThreadGranularity::thr_pool_.dispatch(
false,
&DoubleCrossOver::run_short_roller_<K, H>,
this,
std::cref(idx_begin),
std::cref(idx_end),
std::cref(prices_begin),
std::cref(prices_end));
std::future<size_type> fut2 =
std::async(std::launch::async,
&DoubleCrossOver::run_long_roller_<K, H>,
this,
std::cref(idx_begin),
std::cref(idx_end),
std::cref(prices_begin),
std::cref(prices_end));
ThreadGranularity::thr_pool_.dispatch(
false,
&DoubleCrossOver::run_long_roller_<K, H>,
this,
std::cref(idx_begin),
std::cref(idx_end),
std::cref(prices_begin),
std::cref(prices_end));

re_count1 = fut1.get();
re_count2 = fut2.get();
Expand Down Expand Up @@ -318,26 +319,27 @@ struct BollingerBand {
operator() (const K &idx_begin, const K &idx_end,
const H &prices_begin, const H &prices_end) {

const size_type thread_level =
ThreadGranularity::get_thread_level();
const auto thread_level = ThreadGranularity::get_thread_level();

if (thread_level >= 2) {
if (thread_level > 0) {
std::future<void> fut1 =
std::async(std::launch::async,
&BollingerBand::run_mean_roller_<K, H>,
this,
std::cref(idx_begin),
std::cref(idx_end),
std::cref(prices_begin),
std::cref(prices_end));
ThreadGranularity::thr_pool_.dispatch(
false,
&BollingerBand::run_mean_roller_<K, H>,
this,
std::cref(idx_begin),
std::cref(idx_end),
std::cref(prices_begin),
std::cref(prices_end));
std::future<void> fut2 =
std::async(std::launch::async,
&BollingerBand::run_std_roller_<K, H>,
this,
std::cref(idx_begin),
std::cref(idx_end),
std::cref(prices_begin),
std::cref(prices_end));
ThreadGranularity::thr_pool_.dispatch(
false,
&BollingerBand::run_std_roller_<K, H>,
this,
std::cref(idx_begin),
std::cref(idx_end),
std::cref(prices_begin),
std::cref(prices_end));

fut1.get();
fut2.get();
Expand Down
35 changes: 6 additions & 29 deletions include/DataFrame/DataFrameStatsVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -1794,37 +1794,16 @@ struct AutoCorrVisitor {

vec_type<value_type> tmp_result(col_s - 4);
size_type lag = 1;
const size_type thread_level =
ThreadGranularity::get_thread_level();
vec_type<std::future<CorrResult>> futures(thread_level);
size_type thread_count = 0;

tmp_result[0] = 1.0;
while (lag < col_s - 4) {
if (thread_count >= thread_level) {
const auto result = get_auto_corr_(col_s, lag, column_begin);

tmp_result[result.first] = result.second;
}
else {
futures[thread_count] =
std::async(std::launch::async,
&AutoCorrVisitor::get_auto_corr_<H>,
this,
col_s,
lag,
std::cref(column_begin));
thread_count += 1;
}
lag += 1;
}

for (size_type i = 0; i < thread_count; ++i) {
const auto &result = futures[i].get();
while (lag < col_s - 4) {
const auto result = get_auto_corr_(col_s, lag, column_begin);

tmp_result[result.first] = result.second;
lag += 1;
}
tmp_result.swap(result_);
result_.swap(tmp_result);
}

DEFINE_PRE_POST
Expand All @@ -1839,10 +1818,8 @@ struct AutoCorrVisitor {
using CorrResult = std::pair<size_type, value_type>;

template<typename H>
inline CorrResult
get_auto_corr_(size_type col_s,
size_type lag,
const H &column_begin) const {
inline static CorrResult
get_auto_corr_(size_type col_s, size_type lag, const H &column_begin) {

CorrVisitor<value_type, index_type> corr { };
constexpr I dummy = I();
Expand Down
129 changes: 57 additions & 72 deletions include/DataFrame/Internals/DataFrame.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -377,91 +377,76 @@ fill_missing(const StlVecType<const char *> &col_names,
const StlVecType<T> &values,
int limit) {

const size_type count = col_names.size();
StlVecType<std::future<void>> futures(get_thread_level());
ThreadGranularity::size_type thread_count = 0;
if (fp == fill_policy::linear_extrapolate) {
char buffer [512];

snprintf(buffer, sizeof(buffer) - 1,
"DataFrame::fill_missing(): fill_policy %d not implemented",
static_cast<int>(fp));
throw NotImplemented(buffer);
}

const size_type count = col_names.size();
const ThreadGranularity::size_type thread_count = get_thread_level();
StlVecType<std::future<void>> futures;

if (thread_count > 0)
futures.reserve(count);
for (size_type i = 0; i < count; ++i) {
ColumnVecType<T> &vec = get_column<T>(col_names[i]);

if (fp == fill_policy::value) {
if (thread_count >= get_thread_level())
if (thread_count == 0) {
if (fp == fill_policy::value)
fill_missing_value_(vec, values[i], limit, indices_.size());
else {
futures[thread_count] =
std::async(std::launch::async,
&DataFrame::fill_missing_value_<T>,
std::ref(vec),
std::cref(values[i]),
limit,
indices_.size());
thread_count += 1;
}
}
else if (fp == fill_policy::fill_forward) {
if (thread_count >= get_thread_level())
else if (fp == fill_policy::fill_forward)
fill_missing_ffill_<T>(vec, limit, indices_.size());
else {
futures[thread_count] =
std::async(std::launch::async,
&DataFrame::fill_missing_ffill_<T>,
std::ref(vec),
limit,
indices_.size());
thread_count += 1;
}
}
else if (fp == fill_policy::fill_backward) {
if (thread_count >= get_thread_level())
else if (fp == fill_policy::fill_backward)
fill_missing_bfill_<T>(vec, limit);
else {
futures[thread_count] =
std::async(std::launch::async,
&DataFrame::fill_missing_bfill_<T>,
std::ref(vec),
limit);
thread_count += 1;
}
}
else if (fp == fill_policy::linear_interpolate) {
if (thread_count >= get_thread_level())
else if (fp == fill_policy::linear_interpolate)
fill_missing_linter_<T>(vec, indices_, limit);
else {
futures[thread_count] =
std::async(std::launch::async,
&DataFrame::fill_missing_linter_<T>,
std::ref(vec),
std::cref(indices_),
limit);
thread_count += 1;
}
}
else if (fp == fill_policy::mid_point) {
if (thread_count >= get_thread_level())
else if (fp == fill_policy::mid_point)
fill_missing_midpoint_<T>(vec, limit, indices_.size());
else {
futures[thread_count] =
std::async(std::launch::async,
&DataFrame::fill_missing_midpoint_<T>,
std::ref(vec),
limit,
indices_.size());
thread_count += 1;
}
}
else if (fp == fill_policy::linear_extrapolate) {
char buffer [512];

snprintf (
buffer, sizeof(buffer) - 1,
"DataFrame::fill_missing(): fill_policy %d is not implemented",
static_cast<int>(fp));
throw NotImplemented(buffer);
else {
if (fp == fill_policy::value)
futures.emplace_back(
thr_pool_.dispatch(false,
&DataFrame::fill_missing_value_<T>,
std::ref(vec),
std::cref(values[i]),
limit,
indices_.size()));
else if (fp == fill_policy::fill_forward)
futures.emplace_back(
thr_pool_.dispatch(false,
&DataFrame::fill_missing_ffill_<T>,
std::ref(vec),
limit,
indices_.size()));
else if (fp == fill_policy::fill_backward)
futures.emplace_back(
thr_pool_.dispatch(false,
&DataFrame::fill_missing_bfill_<T>,
std::ref(vec),
limit));
else if (fp == fill_policy::linear_interpolate)
futures.emplace_back(
thr_pool_.dispatch(false,
&DataFrame::fill_missing_linter_<T>,
std::ref(vec),
std::cref(indices_),
limit));
else if (fp == fill_policy::mid_point)
futures.emplace_back(
thr_pool_.dispatch(false,
&DataFrame::fill_missing_midpoint_<T>,
std::ref(vec),
limit,
indices_.size()));
}
}
for (auto &fut : futures) fut.get();

for (ThreadGranularity::size_type idx = 0; idx < thread_count; ++idx)
futures[idx].get();
return;
}

Expand Down
55 changes: 7 additions & 48 deletions include/DataFrame/Internals/DataFrame_shift.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -47,35 +47,14 @@ void DataFrame<I, H>::self_shift(size_type periods, shift_policy sp) {
if (periods > 0) [[likely]] {
if (sp == shift_policy::down || sp == shift_policy::up) [[likely]] {
vertical_shift_functor_<Ts ...> functor(periods, sp);
StlVecType<std::future<void>> futures(get_thread_level());
ThreadGranularity::size_type thread_count = 0;
const size_type data_size = data_.size();
const size_type num_cols = data_.size();

{
const SpinGuard guard(lock_);

for (size_type idx = 0; idx < data_size; ++idx) [[likely]] {
if (thread_count >= get_thread_level())
data_[idx].change(functor);
else {
auto to_be_called =
static_cast
<void(DataVec::*)(vertical_shift_functor_<Ts ...> &&)>
(&DataVec::template
change<vertical_shift_functor_<Ts ...>>);

futures[thread_count] =
std::async(std::launch::async,
to_be_called,
&(data_[idx]),
std::move(functor));
thread_count += 1;
}
}
for (size_type idx = 0; idx < num_cols; ++idx) [[likely]]
data_[idx].change(functor);
}
for (ThreadGranularity::size_type idx = 0;
idx < thread_count; ++idx)
futures[idx].get();
}
else if (sp == shift_policy::left) {
while (periods-- > 0)
Expand Down Expand Up @@ -132,35 +111,15 @@ void DataFrame<I, H>::self_rotate(size_type periods, shift_policy sp) {

if (periods > 0) {
if (sp == shift_policy::down || sp == shift_policy::up) [[likely]] {
rotate_functor_<Ts ...> functor(periods, sp);
StlVecType<std::future<void>> futures(get_thread_level());
ThreadGranularity::size_type thread_count = 0;
const size_type data_size = data_.size();
rotate_functor_<Ts ...> functor(periods, sp);
const size_type num_cols = data_.size();

{
const SpinGuard guard(lock_);

for (size_type idx = 0; idx < data_size; ++idx) [[likely]] {
if (thread_count >= get_thread_level())
data_[idx].change(functor);
else {
auto to_be_called =
static_cast
<void(H::*)(rotate_functor_<Ts ...> &&)>
(&H::template change<rotate_functor_<Ts ...>>);

futures[thread_count] =
std::async(std::launch::async,
to_be_called,
&(data_[idx]),
std::move(functor));
thread_count += 1;
}
}
for (size_type idx = 0; idx < num_cols; ++idx) [[likely]]
data_[idx].change(functor);
}
for (ThreadGranularity::size_type idx = 0;
idx < thread_count; ++idx)
futures[idx].get();
}
else if (sp == shift_policy::left) {
std::rotate(column_list_.begin(),
Expand Down
6 changes: 5 additions & 1 deletion include/DataFrame/Utils/Threads/SharedQueue.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <DataFrame/Utils/Threads/SharedQueue.h>

#include <chrono>

using namespace std::chrono_literals;

// ----------------------------------------------------------------------------

namespace hmdf
Expand Down Expand Up @@ -73,7 +77,7 @@ SharedQueue<T>::pop_front(bool wait_on_front) noexcept {
std::unique_lock<std::mutex> ul { mutex_ };

if (queue_.empty() && wait_on_front) {
while (queue_.empty()) cvx_.wait(ul);
while (queue_.empty()) cvx_.wait_for(ul, 2s);
}

if (! queue_.empty()) {
Expand Down
Loading

0 comments on commit a10ebfd

Please sign in to comment.