Skip to content

Commit

Permalink
Merge branch 'HotFix_multiprocess'
Browse files Browse the repository at this point in the history
  • Loading branch information
ymd-h committed Feb 13, 2022
2 parents 07f9dbc + 0f39eb5 commit cbfae94
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 34 deletions.
18 changes: 18 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,24 @@ SegmentTree.cpp:
- $CXX -o test/SegmentTree test/SegmentTree.cpp
- test/SegmentTree


# Benchmark job: compiles test/segmenttree_bench.cpp with $CXX and then runs
# the resulting binary five times in a row (presumably so run-to-run variance
# is visible in the job log -- confirm).
# `needs` lets the job start once the two C++ unit-test jobs have finished;
# `artifacts: false` means nothing is downloaded from them.
cpp_bench:
image: *dev_image
stage: bench_mark_test
needs:
- job: ReplayBuffer.cpp
artifacts: false
- job: SegmentTree.cpp
artifacts: false
script:
- $CXX -o test/segmenttree_bench test/segmenttree_bench.cpp
- test/segmenttree_bench
- test/segmenttree_bench
- test/segmenttree_bench
- test/segmenttree_bench
- test/segmenttree_bench


ReplayBuffer:
<<: *py_setup
script:
Expand Down
9 changes: 1 addition & 8 deletions cpprb/ReplayBuffer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ namespace ymd {
sum{PowerOf2(buffer_size),[](auto a,auto b){ return a+b; },
Priority{0},
sum_ptr,sum_anychanged,initialize},
min{PowerOf2(buffer_size),[](Priority a,Priority b){ return std::min(a,b); },
min{PowerOf2(buffer_size),[](Priority a,Priority b){ return std::min(a,b); },
std::numeric_limits<Priority>::max(),
min_ptr,min_anychanged,initialize},
g{std::random_device{}()},
Expand Down Expand Up @@ -486,13 +486,6 @@ namespace ymd {
// Overwrite the stored `eps` member with the supplied value.
// `this->` is required because the parameter shadows the member name.
void set_eps(Priority eps){
this->eps = eps;
}

// Forward to weak_update_changed() on both the `sum` and the `min` tree.
// The calls are only compiled in when `MultiThread` is true; otherwise this
// function has an empty body.
void weak_update_changed(){
if constexpr (MultiThread) {
sum.weak_update_changed();
min.weak_update_changed();
}
}
};

template<typename Priority>
Expand Down
65 changes: 41 additions & 24 deletions cpprb/SegmentTree.hh
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ namespace ymd {
return tmp != buffer[i];
}

void update_all(){
void update_init(){
for(std::size_t i = access_index(0) -1, end = -1; i != end; --i){
update_buffer(i);
}
Expand All @@ -78,6 +78,22 @@ namespace ymd {
}
}

// Recompute the internal nodes after leaf values have been modified.
//
// Visits every node in the index range
//   [parent(access_index(0)), parent(access_index(buffer_size-1))]
// i.e. the nodes directly above the stored elements. Whenever
// update_buffer() reports that a node's value changed, the change is
// pushed further up through parent() until either a node stays unchanged
// or index 0 (the top of the tree) has been refreshed.
//
// In the MultiThread build the shared `any_changed` flag is cleared
// afterwards with release ordering.
void update_all(){
  constexpr std::size_t root = 0;

  // One past the last node that sits directly above an element slot.
  const auto stop = parent(access_index(buffer_size-1))+1;

  for(auto node = parent(access_index(0)); node != stop; ++node){
    auto cursor = node;
    auto changed = update_buffer(cursor);

    // Climb one level per iteration; stop early once a recomputed
    // value is identical to what was already stored.
    for(; (cursor != root) && changed; changed = update_buffer(cursor)){
      cursor = parent(cursor);
    }
  }

  if constexpr (MultiThread){
    any_changed->store(false,std::memory_order_release);
  }
}

public:
SegmentTree(std::size_t n,F f, T v = T{0},
T* buffer_ptr = nullptr,
Expand Down Expand Up @@ -105,7 +121,7 @@ namespace ymd {
if(initialize){
std::fill_n(buffer+access_index(0),n,v);

update_all();
update_init();
}
}
// Default constructor: delegates to the main constructor with n = 2 and
// addition as the combining function; all remaining arguments keep their
// default values.
SegmentTree(): SegmentTree{2,[](auto a,auto b){ return a+b; }} {}
Expand Down Expand Up @@ -142,8 +158,6 @@ namespace ymd {
constexpr const std::size_t zero = 0;
if(zero == max){ max = buffer_size; }

std::set<std::size_t> will_update{};

if constexpr (MultiThread){
if(N){ any_changed->store(true,std::memory_order_release); }
}
Expand All @@ -155,24 +169,17 @@ namespace ymd {
if constexpr (!MultiThread){
for(auto n = std::size_t(0); n < copy_N; ++n){
auto _i = access_index(i+n);
if(_i != 0){
will_update.insert(parent(_i));
auto updated = true;
while((_i != zero) && updated){
_i = parent(_i);
updated = update_buffer(_i);
}
}
}

N = (N > copy_N) ? N - copy_N: zero;
i = zero;
}

if constexpr (!MultiThread) {
while(!will_update.empty()){
i = *(will_update.rbegin());
auto updated = update_buffer(i);
will_update.erase(i);
if(i && updated){ will_update.insert(parent(i)); }
}
}
}

void set(std::size_t i,T v,std::size_t N,std::size_t max = std::size_t(0)){
Expand All @@ -190,7 +197,8 @@ namespace ymd {
}

auto largest_region_index(std::function<bool(T)> condition,
std::size_t n=std::size_t(0)) {
std::size_t n=std::size_t(0),
T init = T{0}) {
// max index of reduce( [0,index) ) -> true

constexpr const std::size_t zero = 0;
Expand All @@ -203,21 +211,30 @@ namespace ymd {
}
}

std::size_t min = zero;
auto max = (zero != n) ? n: buffer_size;
if(n == zero){ n = buffer_size; }
auto b = zero;

if(condition(buffer[b])){ return n-1; }

auto index = (min + max)/two;
auto min = zero;
auto max = buffer_size;
auto cond = condition;
auto red = init;

while(max - min > one){
if( condition(_reduce(zero,index,zero,zero,buffer_size)) ){
min = index;
auto b_left = child_left(b);
if(cond(buffer[b_left])){
min = (min + max) / two;
red = f(red, buffer[b_left]);
cond = [=](auto v){ return condition(f(red,v)); };
b = child_right(b);
}else{
max = index;
max = (min + max) / two;
b = b_left;
}
index = (min + max)/two;
}

return index;
return std::min(min, n-1);
}

void clear(T v = T{0}){
Expand Down
4 changes: 2 additions & 2 deletions test/PyReplayBuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def test_5_done(self):
class TestPrioritizedBase:
def test_weights(self):
self._check_ndarray(self.s['weights'],1,(self.batch_size,),"weights")
for w in self.s['weights']:
self.assertAlmostEqual(w,1.0)
np.testing.assert_allclose(self.s["weights"],
np.full_like(self.s["weights"], 1.0))

def test_indexes(self):
self._check_ndarray(self.s['indexes'],1,(self.batch_size,),"indexes")
Expand Down
105 changes: 105 additions & 0 deletions test/segmenttree_bench.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include <algorithm>
#include <iostream>
#include <iterator>
#include <chrono>
#include <vector>

#include <SegmentTree.hh>
#include <ReplayBuffer.hh>

// Short names for the two float-priority samplers that are benchmarked
// side by side in main().
using PER = ymd::CppPrioritizedSampler<float>;
// Thread-safe variant, going by the class name -- see ReplayBuffer.hh.
using MPPER = ymd::CppThreadSafePrioritizedSampler<float>;

// Run `F` `n` times back to back and print `fmt` followed by the total
// elapsed time in nanoseconds on a single line of std::cout.
//
//   F   : callable invoked with no arguments; any return value is ignored.
//   n   : number of invocations. The loop counter has the same type as `n`,
//         so a zero or negative count simply runs nothing.
//   fmt : label written verbatim in front of the measurement
//         (defaults to an empty label).
auto bench = [](auto&& F, auto n, const char* fmt = ""){
  // steady_clock is guaranteed monotonic, which is what interval timing
  // needs; high_resolution_clock may alias the adjustable system clock.
  using clock = std::chrono::steady_clock;

  const auto t1 = clock::now();
  for(decltype(n) i = 0; i < n; ++i){ F(); }
  const auto t2 = clock::now();

  // The flush from std::endl happens after t2 was taken, so it is not part
  // of the measured interval.
  std::cout << fmt
            << std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count()
            << std::endl;
};


// Benchmark driver.
//
// Times a fixed sequence of operations on two SegmentTree instantiations and
// on the two prioritized samplers (PER / MPPER), printing one
// "<label><nanoseconds>" line per measurement through bench().
// The steps are order dependent: each one operates on the state left behind
// by the previous ones.
//
// NOTE(review): argc and argv are never used.
int main(int argc, char** argv){
constexpr const auto buffer_size = 1000000ul;
// Tree size derived from buffer_size via ymd::PowerOf2 (presumably rounds up
// to a power of two -- confirm in the header).
constexpr const auto size = ymd::PowerOf2(buffer_size);

// Two float trees combining values by addition. `sum2` passes `true` as the
// second template argument (presumably the MultiThread flag -- confirm).
auto sum = ymd::SegmentTree<float>(size, [](auto a, auto b){ return a+b; });
auto sum2 = ymd::SegmentTree<float,true>(size, [](auto a, auto b){ return a+b; });

// --- sum: single-element writes at indices 0..9999 ---
bench([&,i=0, j=0]() mutable { sum.set(i++, j++); }, 10000, "sum1.set A: ");
// --- sum: 100 calls of 100 generated values each, starting at index 10000 ---
bench([&,i=100]() mutable { sum.set(100*(i++),
[j=0]()mutable{ return j++; },
100,
buffer_size); }, 100, "sum1.set B: ");
// --- sum: 10 calls of 1000 generated values each, starting at index 20000 ---
bench([&,i=20]() mutable { sum.set(1000*(i++),
[j=0]()mutable{ return j++; },
1000,
buffer_size); }, 10, "sum1.set C: ");
// The same reduce call is timed twice in a row (labels A and B).
bench([&]() mutable { sum.reduce(0, 30000); }, 1, "sum1.red A: ");
bench([&]() mutable { sum.reduce(0, 30000); }, 1, "sum1.red B: ");
// NOTE(review): the predicate increments `i` on every invocation. If
// largest_region_index() evaluates the predicate more than once per search,
// the threshold moves during a single search -- confirm this is intended.
bench([&,i=0]() mutable {
sum.largest_region_index([&](auto v){ return v <= 79.8*(i++); }, 30000);
},
10000, "sum1.lridx: ");

// --- sum2: identical sequence of operations ---
bench([&,i=0, j=0]() mutable { sum2.set(i++, j++); }, 10000, "sum2.set A: ");
bench([&,i=100]() mutable { sum2.set(100*(i++),
[j=0]()mutable{ return j++; },
100,
buffer_size); }, 100, "sum2.set B: ");
bench([&,i=20]() mutable { sum2.set(1000*(i++),
[j=0]()mutable{ return j++; },
1000,
buffer_size); }, 10, "sum2.set C: ");
bench([&]() mutable { sum2.reduce(0, 30000); }, 1, "sum2.red A: ");
bench([&]() mutable { sum2.reduce(0, 30000); }, 1, "sum2.red B: ");
// Same side-effecting predicate as in the `sum` benchmark above.
bench([&,i=0]() mutable {
sum2.largest_region_index([&](auto v){ return v <= 79.8*(i++); }, 30000);
},
10000, "sum2.lridx: ");

// Print two stored values from each tree side by side so the results of both
// variants can be compared by eye.
std::cout << sum.get(1) << " " << sum2.get(1) << std::endl;
std::cout << sum.get(10001) << " " << sum2.get(10001) << std::endl;

// ---- Prioritized sampler benchmarks ----

// alpha is passed to the sampler constructors, beta to sample() below.
constexpr const auto alpha = 0.5, beta = 0.4;
auto per = PER(buffer_size, alpha);
auto mpper = MPPER(buffer_size, alpha);

// 10000 precomputed priorities cycling through 0.02 * (0..320).
auto p = std::vector<float>{};
p.reserve(10000);
std::generate_n(std::back_inserter(p), 10000,
[i=0]() mutable { return 0.02*(i++ % 321); });

// Output vectors handed to sample(); capacity for 32 entries is reserved up
// front.
auto indexes = std::vector<size_t>{};
indexes.reserve(32);
auto weights = std::vector<float>{};
weights.reserve(32);

// --- PER: one priority per call at indices 0..9999 ---
bench([&, i=0,j=0]() mutable { per.set_priorities(i++, 0.02*(j++ % 321)); },
10000, " PER.add1: ");
// --- PER: 100 calls, each passing a 100-element slice of `p`, starting at
// index 10000 ---
bench([&, i=100,j=0]() mutable {
per.set_priorities(100*(i++), p.data()+100*(j++), 100, buffer_size);
},
100, " PER.addN: ");
// The same sample call is timed twice in a row (labels smpA and smpB).
bench([&]() mutable { per.sample(32,beta,weights,indexes,20000); },
1, " PER.smpA: ");
bench([&]() mutable { per.sample(32,beta,weights,indexes,20000); },
1, " PER.smpB: ");

// --- MPPER: identical sequence of operations ---
bench([&, i=0,j=0]() mutable { mpper.set_priorities(i++, 0.02*(j++ % 321)); },
10000, "MPPER.add1: ");
bench([&, i=100,j=0]() mutable {
mpper.set_priorities(100*(i++), p.data()+100*(j++), 100, buffer_size);
},
100, "MPPER.addN: ");
bench([&]() mutable { mpper.sample(32,beta,weights,indexes,20000); },
1, "MPPER.smpA: ");
bench([&]() mutable { mpper.sample(32,beta,weights,indexes,20000); },
1, "MPPER.smpB: ");

return 0;
}

0 comments on commit cbfae94

Please sign in to comment.