ThreadPool: Spend less time busy waiting. (2nd Attempt) #23278

Open · wants to merge 3 commits into base: main
82 changes: 69 additions & 13 deletions include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
@@ -226,8 +226,8 @@ class ThreadPoolProfiler {
void LogStart() {};
void LogEnd(ThreadPoolEvent){};
void LogEndAndStart(ThreadPoolEvent){};
void LogStartAndCoreAndBlock(std::ptrdiff_t){};
void LogCoreAndBlock(std::ptrdiff_t){};
void LogStartAndCoreAndBlock(std::ptrdiff_t) {}
void LogCoreAndBlock(std::ptrdiff_t) {}
void LogThreadId(int) {};
void LogRun(int) {};
std::string DumpChildThreadStat() { return {}; }
@@ -766,13 +766,45 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
typedef std::function<void()> Task;
typedef RunQueue<Task, Tag, 1024> Queue;

// Class for waiting with exponential backoff.
// The constructor argument is the maximum number of spins in the backoff loop.
class ThreadPoolWaiter {
// Current number of spins in the backoff loop
unsigned pause_time_{0};
const unsigned max_backoff_;

public:
explicit ThreadPoolWaiter(unsigned max_backoff)
: max_backoff_(max_backoff) {
}

void wait() {
switch (max_backoff_) {
case 1:
onnxruntime::concurrency::SpinPause();
[[fallthrough]];
case 0:
// If max_backoff_ is zero, don't do any pausing.
return;
default:
// Exponential backoff
unsigned pause_time = pause_time_ + 1U;
for (unsigned i = 0; i < pause_time; ++i) {
onnxruntime::concurrency::SpinPause();
}
pause_time_ = (pause_time * 2U) % max_backoff_;
}
}
};

ThreadPoolTempl(const CHAR_TYPE* name, int num_threads, bool allow_spinning, Environment& env,
const ThreadOptions& thread_options)
const ThreadOptions& thread_options, bool is_hybrid)
: profiler_(num_threads, name),
env_(env),
num_threads_(num_threads),
allow_spinning_(allow_spinning),
set_denormal_as_zero_(thread_options.set_denormal_as_zero),
is_hybrid_{is_hybrid},
worker_data_(num_threads),
all_coprimes_(num_threads),
blocked_(0),
@@ -907,8 +939,9 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
// finish dispatch work. This avoids new tasks being started
// concurrently with us attempting to end the parallel section.
if (ps.dispatch_q_idx != -1) {
ThreadPoolWaiter waiter{4};
while (!ps.dispatch_done.load(std::memory_order_acquire)) {
onnxruntime::concurrency::SpinPause();
waiter.wait();
}
}

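Editorial aside (not part of the diff): the following is a minimal standalone sketch of the ThreadPoolWaiter introduced above. It copies the class's arithmetic but replaces onnxruntime::concurrency::SpinPause() with a counting stub so the per-call pause counts can be printed; the sketch namespace, the stub, and main() are illustrative assumptions, not onnxruntime code.

#include <cstdio>

namespace sketch {

unsigned g_pauses = 0;            // counts simulated pause instructions
void SpinPause() { ++g_pauses; }  // stand-in for onnxruntime::concurrency::SpinPause()

class ThreadPoolWaiter {
  unsigned pause_time_{0};  // current number of spins in the backoff loop
  const unsigned max_backoff_;

 public:
  explicit ThreadPoolWaiter(unsigned max_backoff) : max_backoff_(max_backoff) {}

  void wait() {
    switch (max_backoff_) {
      case 1:
        SpinPause();
        [[fallthrough]];
      case 0:
        return;  // a cap of zero means no pausing at all
      default: {
        // Exponential backoff, capped by the modulo below.
        unsigned pause_time = pause_time_ + 1U;
        for (unsigned i = 0; i < pause_time; ++i) {
          SpinPause();
        }
        pause_time_ = (pause_time * 2U) % max_backoff_;
      }
    }
  }
};

}  // namespace sketch

int main() {
  sketch::ThreadPoolWaiter waiter{8};  // same cap the worker loop below uses on non-hybrid parts
  for (int call = 1; call <= 6; ++call) {
    unsigned before = sketch::g_pauses;
    waiter.wait();
    std::printf("wait() call %d: %u pauses\n", call, sketch::g_pauses - before);
  }
  return 0;
}

With max_backoff = 8 the per-call pause counts come out as 1, 3, 7, 7, ...: the (pause_time * 2U) % max_backoff_ update caps the growth at the constructor argument, so a waiter never degenerates into arbitrarily long uninterrupted spin bursts. The call sites in this diff pick the cap (0, 1, 4, or 8) per wait loop, and choose different values for hybrid parts.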
@@ -928,17 +961,19 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
}
profiler_.LogEnd(ThreadPoolProfiler::WAIT_REVOKE);

ThreadPoolWaiter waiter{is_hybrid_ ? 0U : 1U};

// Wait for the dispatch task's own work...
if (ps.dispatch_q_idx > -1) {
while (!ps.work_done.load(std::memory_order_acquire)) {
onnxruntime::concurrency::SpinPause();
waiter.wait();
}
}

// ...and wait for any other tasks not revoked to finish their work
auto tasks_to_wait_for = tasks_started - ps.tasks_revoked;
while (ps.tasks_finished < tasks_to_wait_for) {
onnxruntime::concurrency::SpinPause();
waiter.wait();
}

// Clear status to allow the ThreadPoolParallelSection to be
@@ -1255,10 +1290,12 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

// Increase the worker count if needed. Each worker will pick up
// loops to execute from the current parallel section.
std::function<void(unsigned)> worker_fn = [&ps](unsigned par_idx) {
const auto is_hybrid = is_hybrid_;
std::function<void(unsigned)> worker_fn = [&ps, is_hybrid](unsigned par_idx) {
ThreadPoolWaiter waiter{is_hybrid ? 4U : 0U};
while (ps.active) {
if (ps.current_loop.load() == nullptr) {
onnxruntime::concurrency::SpinPause();
waiter.wait();
} else {
ps.workers_in_loop++;
ThreadPoolLoop* work_item = ps.current_loop;
@@ -1279,8 +1316,9 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

// Wait for workers to exit the loop
ps.current_loop = 0;
ThreadPoolWaiter waiter{is_hybrid_ ? 1U : 4U};
while (ps.workers_in_loop) {
onnxruntime::concurrency::SpinPause();
waiter.wait();
}
profiler_.LogEnd(ThreadPoolProfiler::WAIT);
}
Expand Down Expand Up @@ -1496,6 +1534,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
const unsigned num_threads_;
const bool allow_spinning_;
const bool set_denormal_as_zero_;
const bool is_hybrid_;
Eigen::MaxSizeVector<WorkerData> worker_data_;
Eigen::MaxSizeVector<Eigen::MaxSizeVector<unsigned>> all_coprimes_;
std::atomic<unsigned> blocked_; // Count of blocked workers, used as a termination condition
@@ -1535,13 +1574,30 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

assert(td.GetStatus() == WorkerData::ThreadStatus::Spinning);

constexpr int log2_spin = 20;
const int spin_count = allow_spinning_ ? (1ull << log2_spin) : 0;
const int steal_count = spin_count / 100;
// The exact values of spin_count and steal_count are somewhat
// arbitrary and were determined experimentally: these numbers
// yielded the best performance across a range of workloads and
// machines. Generally, the goal of tuning spin_count is to make
// the number as small as possible while leaving enough slack that
// a core doing the same amount of work as its neighbors does not
// sleep before they have all finished. The idea is that in
// pipelined workloads a thread that finishes a stage slightly
// ahead of its neighbors keeps spinning rather than sleeping, but
// when work is distributed unevenly it does not take too long
// to reach sleep, giving power (and thus frequency/performance)
// back to its neighbors. Since hybrid parts mix P and E cores, a
// lower value is chosen there: even with equally sized work items
// the compute times will not stay in sync, and typically the P
// cores finish first (and then sit waiting), which is essentially
// a priority inversion.
const int pref_spin_count = is_hybrid_ ? 5000 : 10000;
const int spin_count = allow_spinning_ ? pref_spin_count : 0;
const int steal_count = pref_spin_count / (is_hybrid_ ? 25 : 100);

SetDenormalAsZero(set_denormal_as_zero_);
profiler_.LogThreadId(thread_id);

ThreadPoolWaiter waiter{is_hybrid_ ? 1U : 8U};
while (!should_exit) {
Task t = q.PopFront();
if (!t) {
@@ -1557,7 +1613,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
if (spin_loop_status_.load(std::memory_order_relaxed) == SpinLoopStatus::kIdle) {
break;
}
onnxruntime::concurrency::SpinPause();
waiter.wait();
}

// Attempt to block
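Editorial aside (not part of the diff): for concreteness, the constants above work out to spin_count = 5000 and steal_count = 5000 / 25 = 200 when is_hybrid_ is set, and spin_count = 10000 with steal_count = 10000 / 100 = 100 otherwise. The removed code spun for 1 << 20 = 1,048,576 iterations with steal_count = 10,485, so the spin budget before a worker attempts to block drops by roughly two orders of magnitude.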
2 changes: 1 addition & 1 deletion include/onnxruntime/core/platform/threadpool.h
@@ -424,7 +424,7 @@ class ThreadPool {
ExtendedThreadPoolInterface* underlying_threadpool_ = nullptr;

// If used, underlying_threadpool_ is instantiated and owned by the ThreadPool.
std::unique_ptr<ThreadPoolTempl<Env> > extended_eigen_threadpool_;
std::unique_ptr<ThreadPoolTempl<Env>> extended_eigen_threadpool_;

// Force the thread pool to run in hybrid mode on a normal cpu.
bool force_hybrid_ = false;
20 changes: 8 additions & 12 deletions onnxruntime/core/common/threadpool.cc
@@ -390,12 +390,12 @@ ThreadPool::ThreadPool(Env* env,
assert(thread_options_.affinities.size() >= size_t(threads_to_create));
}

extended_eigen_threadpool_ =
std::make_unique<ThreadPoolTempl<Env> >(name,
threads_to_create,
low_latency_hint,
*env,
thread_options_);
extended_eigen_threadpool_ = std::make_unique<ThreadPoolTempl<Env>>(name,
threads_to_create,
low_latency_hint,
*env,
thread_options_,
force_hybrid_);
underlying_threadpool_ = extended_eigen_threadpool_.get();
}
}
@@ -665,15 +665,11 @@ std::string ThreadPool::StopProfiling(concurrency::ThreadPool* tp) {
}

void ThreadPool::EnableSpinning() {
if (extended_eigen_threadpool_) {
extended_eigen_threadpool_->EnableSpinning();
}
extended_eigen_threadpool_->EnableSpinning();
}

void ThreadPool::DisableSpinning() {
if (extended_eigen_threadpool_) {
extended_eigen_threadpool_->DisableSpinning();
}
extended_eigen_threadpool_->DisableSpinning();
}

// Return the number of threads created by the pool.