Skip to content

Commit

Permalink
Cagra support & update raft version (#207)
Browse files Browse the repository at this point in the history
* Add CAGRA support with latest RAFT

Update to RAFT 23.12
Update CAGRA integration to improve performance
Avoid post-filtering using RAFT's new filtering feature
Use RAFT's new device_resources_manager to simplify and optimize
resource initialization
Update build infratructure to build for all supported CUDA architectures
Refactor RAFT integration code to more cleanly separate RAFT code from
Knowhere code
Avoid exposing RAFT symbols in any Knowhere header

Signed-off-by: William Hicks <[email protected]>

* Fix todo items discovered during self-review

Signed-off-by: William Hicks <[email protected]>

* Update to mainline RAFT and run linters

Signed-off-by: William Hicks <[email protected]>

* Add workaround for refinement issue

Signed-off-by: William Hicks <[email protected]>

* support cagra

Signed-off-by: yusheng.ma <[email protected]>

* support cagra

Signed-off-by: yusheng.ma <[email protected]>

* support cagra

Signed-off-by: yusheng.ma <[email protected]>

---------

Signed-off-by: William Hicks <[email protected]>
Signed-off-by: yusheng.ma <[email protected]>
Co-authored-by: William Hicks <[email protected]>
  • Loading branch information
Presburger and wphicks authored Nov 24, 2023
1 parent 439a0dd commit 7d9e553
Show file tree
Hide file tree
Showing 34 changed files with 2,435 additions and 1,418 deletions.
43 changes: 16 additions & 27 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@
# License for the specific language governing permissions and limitations under
# the License

cmake_minimum_required(VERSION 3.23.0 FATAL_ERROR)
cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR)
project(knowhere CXX C)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/")
include(GNUInstallDirs)
include(ExternalProject)
include(cmake/utils/utils.cmake)

knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF)
if (WITH_RAFT)
set(CMAKE_CUDA_ARCHITECTURES RAPIDS)
include(cmake/libs/librapids.cmake)
project(knowhere CXX C CUDA)
include(cmake/libs/libraft.cmake)
endif()

knowhere_option(WITH_UT "Build with UT test" OFF)
knowhere_option(WITH_ASAN "Build with ASAN" OFF)
knowhere_option(WITH_DISKANN "Build with diskann index" OFF)
knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF)
knowhere_option(WITH_BENCHMARK "Build with benchmark" OFF)
knowhere_option(WITH_COVERAGE "Build with coverage" OFF)
knowhere_option(WITH_CCACHE "Build with ccache" ON)
Expand Down Expand Up @@ -64,18 +70,6 @@ endif()

list( APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}/)

if(WITH_RAFT)
if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "")
set(CMAKE_CUDA_ARCHITECTURES 86;80;75;70;61)
endif()
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
if(${CUDAToolkit_VERSION_MAJOR} GREATER 10)
# cuda11 support --threads for compile some large .cu more efficient
add_compile_options($<$<COMPILE_LANGUAGE:CUDA>:--threads=4>)
endif()
endif()

add_definitions(-DNOT_COMPILE_FOR_SWIG)

include(cmake/utils/compile_flags.cmake)
Expand All @@ -88,12 +82,8 @@ include_directories(thirdparty/faiss)
find_package(OpenMP REQUIRED)

find_package(folly REQUIRED)
set(FOLLY_LIBRARIES Folly::folly)
include_directories(${folly_INCLUDE_DIRS})

if(WITH_RAFT)
include(cmake/libs/libraft.cmake)
endif()

find_package(nlohmann_json REQUIRED)
find_package(glog REQUIRED)
Expand All @@ -113,8 +103,7 @@ if(WITH_COVERAGE)
endif()

knowhere_file_glob(GLOB_RECURSE KNOWHERE_SRCS src/common/*.cc src/index/*.cc
src/io/*.cc src/index/*.cu src/common/raft/*.cu
src/common/raft/*.cc)
src/io/*.cc src/common/*.cu src/index/*.cu src/io/*.cu)

set(KNOWHERE_LINKER_LIBS "")

Expand All @@ -127,13 +116,13 @@ else()
endif()

knowhere_file_glob(GLOB_RECURSE KNOWHERE_GPU_SRCS src/index/gpu/flat_gpu/*.cc
src/index/gpu/ivf_gpu/*.cc src/index/cagra/*.cu)
src/index/gpu/ivf_gpu/*.cc)
list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_GPU_SRCS})

if(NOT WITH_RAFT)
knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS src/index/ivf_raft/*.cc
src/index/ivf_raft/*.cu src/index/cagra/*.cu
src/common/raft/*.cu src/common/raft/*.cc)
knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS
src/common/raft/*.cu src/common/raft/*.cc
src/index/gpu_raft/*.cc)
list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_RAFT_SRCS})
endif()

Expand All @@ -145,12 +134,12 @@ list(APPEND KNOWHERE_LINKER_LIBS glog::glog)
list(APPEND KNOWHERE_LINKER_LIBS nlohmann_json::nlohmann_json)
list(APPEND KNOWHERE_LINKER_LIBS prometheus-cpp::core prometheus-cpp::push)
list(APPEND KNOWHERE_LINKER_LIBS fmt::fmt-header-only)
list(APPEND KNOWHERE_LINKER_LIBS ${FOLLY_LIBRARIES})
list(APPEND KNOWHERE_LINKER_LIBS Folly::folly)

add_library(knowhere SHARED ${KNOWHERE_SRCS})
add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS})
if(WITH_RAFT)
list(APPEND KNOWHERE_LINKER_LIBS raft::raft)
list(APPEND KNOWHERE_LINKER_LIBS raft::raft CUDA::cublas CUDA::cusparse CUDA::cusolver)
endif()
target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS})
target_include_directories(
Expand Down
72 changes: 72 additions & 0 deletions benchmark/hdf5/benchmark_float_qps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,60 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test {
}
}

void
test_cagra(const knowhere::Json& cfg) {
auto conf = cfg;

auto find_smallest_max_iters = [&](float expected_recall) -> int32_t {
auto ds_ptr = knowhere::GenDataSet(nq_, dim_, xq_);
auto left = 32;
auto right = 256;
auto max_iterations = left;

float recall;
while (left <= right) {
max_iterations = left + (right - left) / 2;
conf[knowhere::indexparam::MAX_ITERATIONS] = max_iterations;

auto result = index_.Search(*ds_ptr, conf, nullptr);
recall = CalcRecall(result.value()->GetIds(), nq_, topk_);
printf(
"[%0.3f s] iterate CAGRA param for recall %.4f: max_iterations=%d, k=%d, "
"R@=%.4f\n",
get_time_diff(), expected_recall, max_iterations, topk_, recall);
std::fflush(stdout);
if (std::abs(recall - expected_recall) <= 0.0001) {
return max_iterations;
}
if (recall < expected_recall) {
left = max_iterations + 1;
} else {
right = max_iterations - 1;
}
}
return left;
};

for (auto expected_recall : EXPECTED_RECALLs_) {
conf[knowhere::indexparam::ITOPK_SIZE] = ((int{topk_} + 32 - 1) / 32) * 32;
conf[knowhere::meta::TOPK] = topk_;
conf[knowhere::indexparam::MAX_ITERATIONS] = find_smallest_max_iters(expected_recall);

printf(
"\n[%0.3f s] %s | %s | k=%d, "
"R@=%.4f\n",
get_time_diff(), ann_test_name_.c_str(), index_type_.c_str(), topk_, expected_recall);
printf("================================================================================\n");
for (auto thread_num : THREAD_NUMs_) {
CALC_TIME_SPAN(task(conf, thread_num, nq_));
printf(" thread_num = %2d, elapse = %6.3fs, VPS = %.3f\n", thread_num, t_diff, nq_ / t_diff);
std::fflush(stdout);
}
printf("================================================================================\n");
printf("[%.3f s] Test '%s/%s' done\n\n", get_time_diff(), ann_test_name_.c_str(), index_type_.c_str());
}
}

void
test_hnsw(const knowhere::Json& cfg) {
auto conf = cfg;
Expand Down Expand Up @@ -219,6 +273,9 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test {
#ifdef KNOWHERE_WITH_GPU
knowhere::KnowhereConfig::InitGPUResource(GPU_DEVICE_ID, 2);
cfg_[knowhere::meta::DEVICE_ID] = GPU_DEVICE_ID;
#endif
#ifdef KNOWHERE_WITH_RAFT
knowhere::KnowhereConfig::SetRaftMemPool();
#endif
}

Expand Down Expand Up @@ -249,6 +306,9 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test {
// SCANN index params
const std::vector<int32_t> SCANN_REORDER_K = {256, 512, 768, 1024};
const std::vector<bool> SCANN_WITH_RAW_DATA = {true};

// CAGRA index params
const std::vector<int32_t> GRAPH_DEGREE_ = {32, 64};
};

TEST_F(Benchmark_float_qps, TEST_IVF_FLAT) {
Expand Down Expand Up @@ -342,3 +402,15 @@ TEST_F(Benchmark_float_qps, TEST_SCANN) {
}
}
}
TEST_F(Benchmark_float_qps, TEST_CAGRA) {
index_type_ = knowhere::IndexEnum::INDEX_RAFT_CAGRA;
knowhere::Json conf = cfg_;
for (auto gd : GRAPH_DEGREE_) {
conf[knowhere::indexparam::GRAPH_DEGREE] = gd;
conf[knowhere::indexparam::INTERMEDIATE_GRAPH_DEGREE] = gd;
conf[knowhere::indexparam::MAX_ITERATIONS] = 64;
std::string index_file_name = get_index_name({gd});
create_index(index_file_name, conf);
test_cagra(conf);
}
}
1 change: 0 additions & 1 deletion benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ class Benchmark_knowhere : public Benchmark_hdf5 {
bin->size = dim_ * nb_ * sizeof(float);
binary_set.Append("RAW_DATA", bin);
}

index.Deserialize(binary_set, conf);
}

Expand Down
21 changes: 7 additions & 14 deletions cmake/libs/libraft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,15 @@
# the License.

add_definitions(-DKNOWHERE_WITH_RAFT)
include(cmake/utils/fetch_rapids.cmake)
include(rapids-cmake)
include(rapids-cpm)
include(rapids-cuda)
include(rapids-export)
include(rapids-find)

rapids_cpm_init()

set(CMAKE_CUDA_FLAGS
"${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")

set(RAPIDS_VERSION 23.04)
set(RAFT_VERSION "${RAPIDS_VERSION}")
set(RAFT_FORK "rapidsai")
set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}")
set(RAFT_PINNED_TAG "branch-23.12")


rapids_find_package(CUDAToolkit REQUIRED
BUILD_EXPORT_SET knowhere-exports
INSTALL_EXPORT_SET knowhere-exports
)

function(find_and_configure_raft)
set(oneValueArgs VERSION FORK PINNED_TAG)
Expand Down
14 changes: 13 additions & 1 deletion cmake/utils/fetch_rapids.cmake → cmake/libs/librapids.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# License for the specific language governing permissions and limitations under
# the License.

set(RAPIDS_VERSION "23.04")
set(RAPIDS_VERSION 23.12)

if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
file(
Expand All @@ -22,3 +22,15 @@ if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)
endif()
include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake)

include(rapids-cpm) # Dependency tracking
include(rapids-find) # Wrappers for finding packages
include(rapids-cuda) # Common CMake CUDA logic

rapids_cuda_init_architectures(knowhere)
message(STATUS "INIT: ${CMAKE_CUDA_ARCHITECTURES}")

rapids_cpm_init()

set(CMAKE_CUDA_FLAGS
"${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")
29 changes: 29 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,35 @@ constexpr const char* M = "m"; // PQ param for IVFPQ
constexpr const char* SSIZE = "ssize";
constexpr const char* REORDER_K = "reorder_k";
constexpr const char* WITH_RAW_DATA = "with_raw_data";
// RAFT Params
constexpr const char* REFINE_RATIO = "refine_ratio";
// RAFT-specific IVF Params
constexpr const char* KMEANS_N_ITERS = "kmeans_n_iters";
constexpr const char* KMEANS_TRAINSET_FRACTION = "kmeans_trainset_fraction";
constexpr const char* ADAPTIVE_CENTERS = "adaptive_centers"; // IVF FLAT
constexpr const char* CODEBOOK_KIND = "codebook_kind"; // IVF PQ
constexpr const char* FORCE_RANDOM_ROTATION = "force_random_rotation"; // IVF PQ
constexpr const char* CONSERVATIVE_MEMORY_ALLOCATION = "conservative_memory_allocation"; // IVF PQ
constexpr const char* LUT_DTYPE = "lut_dtype"; // IVF PQ
constexpr const char* INTERNAL_DISTANCE_DTYPE = "internal_distance_dtype"; // IVF PQ
constexpr const char* PREFERRED_SHMEM_CARVEOUT = "preferred_shmem_carveout"; // IVF PQ

// CAGRA Params
constexpr const char* INTERMEDIATE_GRAPH_DEGREE = "intermediate_graph_degree";
constexpr const char* GRAPH_DEGREE = "graph_degree";
constexpr const char* ITOPK_SIZE = "itopk_size";
constexpr const char* MAX_QUERIES = "max_queries";
constexpr const char* BUILD_ALGO = "build_algo";
constexpr const char* SEARCH_ALGO = "search_algo";
constexpr const char* TEAM_SIZE = "team_size";
constexpr const char* SEARCH_WIDTH = "search_width";
constexpr const char* MIN_ITERATIONS = "min_iterations";
constexpr const char* MAX_ITERATIONS = "max_iterations";
constexpr const char* THREAD_BLOCK_SIZE = "thread_block_size";
constexpr const char* HASHMAP_MODE = "hashmap_mode";
constexpr const char* HASHMAP_MIN_BITLEN = "hashmap_min_bitlen";
constexpr const char* HASHMAP_MAX_FILL_RATE = "hashmap_max_fill_rate";
constexpr const char* NN_DESCENT_NITER = "nn_descent_niter";

// HNSW Params
constexpr const char* EFCONSTRUCTION = "efConstruction";
Expand Down
6 changes: 6 additions & 0 deletions include/knowhere/comp/knowhere_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class KnowhereConfig {
*/
static void
SetRaftMemPool(size_t init_size, size_t max_size);

/**
* Initialize RAFT with defaults
*/
static void
SetRaftMemPool();
};

} // namespace knowhere
Expand Down
Loading

0 comments on commit 7d9e553

Please sign in to comment.