Skip to content

Commit

Permalink
optimize multi gpu support (#146)
Browse files Browse the repository at this point in the history
Signed-off-by: Yusheng.Ma <[email protected]>
  • Loading branch information
Presburger authored Oct 18, 2023
1 parent 1dc6e21 commit 84ab6ab
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/common/raft/raft_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ gpu_device_manager::choose_with_load(size_t load) {
return std::distance(memory_load_.begin(), it);
}

void
gpu_device_manager::release_load(int device_id, size_t load) {
if (size_t(device_id) < memory_load_.size()) {
std::lock_guard<std::mutex> lock(mtx_);
memory_load_[device_id] -= load;
} else {
LOG_KNOWHERE_WARNING_ << "please check device id " << device_id;
}
}

gpu_device_manager::gpu_device_manager() {
int device_counts;
try {
Expand Down
6 changes: 6 additions & 0 deletions src/common/raft/raft_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ class gpu_device_manager {
random_choose() const;
int
choose_with_load(size_t load);
void
release_load(int device_id, size_t load);

private:
gpu_device_manager();
Expand Down Expand Up @@ -192,3 +194,7 @@ set_mem_pool_size(size_t init_size, size_t max_size) {
do { \
x = raft_utils::gpu_device_manager::instance().choose_with_load(load); \
} while (0)
#define RELEASE_DEVICE(x, load) \
do { \
raft_utils::gpu_device_manager::instance().release_load(x, load); \
} while (0)
7 changes: 7 additions & 0 deletions src/index/ivf_raft/ivf_raft.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ class RaftIvfIndexNode : public IndexNode {
// status
is.read((char*)(&this->device_id_), sizeof(this->device_id_));
MIN_LOAD_CHOOSE_DEVICE_WITH_ASSIGN(this->device_id_, binary->size);
load_ = binary->size;
raft_utils::device_setter with_this_device{this->device_id_};

raft_utils::init_gpu_resources();
Expand Down Expand Up @@ -574,11 +575,17 @@ class RaftIvfIndexNode : public IndexNode {
return knowhere::IndexEnum::INDEX_RAFT_IVFPQ;
}
}
virtual ~RaftIvfIndexNode() {
if (device_id_ >= 0) {
RELEASE_DEVICE(this->device_id_, this->load_);
}
}

private:
int device_id_ = -1;
int64_t dim_ = 0;
int64_t counts_ = 0;
size_t load_ = 0;
std::optional<T> gpu_index_;

template <typename raft_search_params_t>
Expand Down
1 change: 0 additions & 1 deletion tests/ut/test_get_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "knowhere/comp/knowhere_config.h"
#include "knowhere/factory.h"
#include "utils.h"

TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
using Catch::Approx;

Expand Down

0 comments on commit 84ab6ab

Please sign in to comment.