diff --git a/include/random_walks/uniform_accelerated_billiard_walk.hpp b/include/random_walks/uniform_accelerated_billiard_walk.hpp index 960871017..e3a63fe06 100644 --- a/include/random_walks/uniform_accelerated_billiard_walk.hpp +++ b/include/random_walks/uniform_accelerated_billiard_walk.hpp @@ -16,6 +16,8 @@ #include #include +const double eps = 0.000000001; + template class Heap { public: @@ -24,29 +26,31 @@ class Heap { std::vector> vec; private: - void siftDown(int index) { + int siftDown(int index) { while((index << 1) + 1 < heap_size) { int child = (index << 1) + 1; - if(child + 1 < heap_size && heap[child + 1].first < heap[child].first) { + if(child + 1 < heap_size && heap[child + 1].first < heap[child].first - eps) { child += 1; } - if(heap[child].first < heap[index].first) + if(heap[child].first < heap[index].first - eps) { std::swap(heap[child], heap[index]); std::swap(vec[heap[child].second].second, vec[heap[index].second].second); index = child; } else { - return; + return index; } } + return index; } - void siftUp(int index) { - while(index > 0 && heap[(index - 1) >> 1].first > heap[index].first) { + int siftUp(int index) { + while(index > 0 && heap[(index - 1) >> 1].first - eps > heap[index].first) { std::swap(heap[(index - 1) >> 1], heap[index]); std::swap(vec[heap[(index - 1) >> 1].second].second, vec[heap[index].second].second); index = (index - 1) >> 1; } + return index; } public: @@ -79,7 +83,7 @@ class Heap { return heap[0]; } - void remove (const int index) { // takes the index from the heap + void remove (int index) { // takes the index from the heap if(index == -1) { return; } @@ -87,7 +91,8 @@ class Heap { std::swap(vec[heap[heap_size - 1].second].second, vec[heap[index].second].second); vec[heap[heap_size - 1].second].second = -1; heap_size -= 1; - siftDown(index); + index = siftDown(index); + siftUp(index); } void insert (const std::pair val) {