Skip to content

Commit

Permalink
Set Thread Num According to Cgroup
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Weizhi Xu <[email protected]>
  • Loading branch information
PwzXxm committed Nov 9, 2023
1 parent 25bf8a3 commit 2addb6e
Showing 1 changed file with 106 additions and 2 deletions.
108 changes: 106 additions & 2 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,122 @@
#include <omp.h>
#include <sys/resource.h>

#include <algorithm>
#include <cerrno>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "folly/executors/CPUThreadPoolExecutor.h"
#include "folly/futures/Future.h"
#include "knowhere/log.h"

namespace knowhere {

namespace fs = std::filesystem;
class CgroupCpuReader {
private:
static auto
split(const std::string& s, char delimiter) -> std::vector<std::string> {
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(s);
while (std::getline(tokenStream, token, delimiter)) {
tokens.push_back(token);
}
return tokens;
}

static auto
getCgroupCpuPath() -> fs::path {
std::ifstream fin("/proc/self/cgroup");
std::string line;
while (std::getline(fin, line)) {
auto fields = split(line, ':');
if (fields.size() >= 3) {
if (auto sub_sys = split(fields[1], ',');
std::find(sub_sys.cbegin(), sub_sys.cend(), "cpu") != sub_sys.cend()) {
return fields[2];
}
}
}
throw std::runtime_error("Unable to get cgroup file");
}

static auto
getCgroupMountPath() -> std::pair<fs::path, fs::path> {
std::ifstream fin("/proc/self/mountinfo");
std::string line;
while (std::getline(fin, line)) {
auto fields = split(line, ' ');
if (auto it = std::find(fields.cbegin(), fields.cend(), "-"); it != fields.cend()) {
if (*std::next(it) == "cgroup") {
auto sub_systems = split(*std::next(it, 3), ',');
if (std::find(sub_systems.cbegin(), sub_systems.cend(), "cpu") != sub_systems.cend()) {
return {fields[3], fields[4]};
}
}
}
}
throw std::runtime_error("Unable to get mount info");
}

CgroupCpuReader() = default;
~CgroupCpuReader() = default;

public:
CgroupCpuReader(const CgroupCpuReader&) = delete;
CgroupCpuReader(CgroupCpuReader&&) noexcept = delete;
auto
operator=(const CgroupCpuReader&) -> CgroupCpuReader& = delete;
auto
operator=(CgroupCpuReader&&) noexcept -> CgroupCpuReader& = delete;

static auto
GetCpuNum() -> int {
#ifdef __linux__
try {
auto readIntFromFile = [](const fs::path& path) -> int {
std::ifstream fin(path);
int ret;
if (fin >> ret) {
return ret;
}
throw std::runtime_error("Failed to get int value from " + path.generic_string());
};
auto [root, mount_path] = getCgroupMountPath();
auto cgroup_path = (mount_path / fs::relative(getCgroupCpuPath(), root));
auto quota_file_path = cgroup_path / "cpu.cfs_quota_us";
auto period_file_path = cgroup_path / "cpu.cfs_period_us";
int quota = readIntFromFile(quota_file_path);
// if no limit on cpu quota
if (quota < 0) {
return std::thread::hardware_concurrency();
} else if (quota == 0) {
throw std::runtime_error("Cpu quota is 0");
}
int period = readIntFromFile(period_file_path);
int cpu_num = quota / period;
return std::max(cpu_num, 1);
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "Failed to get cpu num from cgroups: " << e.what()
<< ". Fallback to hardware concurrency";
return std::thread::hardware_concurrency();
}
#else
return std::thread::hardware_concurrency();
#endif
}
};

class ThreadPool {
#ifdef __linux__
private:
Expand Down Expand Up @@ -138,7 +242,7 @@ class ThreadPool {
static std::shared_ptr<ThreadPool>
GetGlobalBuildThreadPool() {
if (global_build_thread_pool_size_ == 0) {
InitThreadPool(std::thread::hardware_concurrency(), global_build_thread_pool_size_);
InitThreadPool(CgroupCpuReader::GetCpuNum(), global_build_thread_pool_size_);
LOG_KNOWHERE_WARNING_ << "Global Build ThreadPool has not been initialized yet, init it with threads num: "
<< global_build_thread_pool_size_;
}
Expand All @@ -149,7 +253,7 @@ class ThreadPool {
static std::shared_ptr<ThreadPool>
GetGlobalSearchThreadPool() {
if (global_search_thread_pool_size_ == 0) {
InitThreadPool(std::thread::hardware_concurrency(), global_search_thread_pool_size_);
InitThreadPool(CgroupCpuReader::GetCpuNum(), global_search_thread_pool_size_);
LOG_KNOWHERE_WARNING_ << "Global Search ThreadPool has not been initialized yet, init it with threads num: "
<< global_search_thread_pool_size_;
}
Expand Down

0 comments on commit 2addb6e

Please sign in to comment.