From edcc6c718f3f88f98257a8a61f960b6be1b57ba4 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Fri, 8 Dec 2023 04:22:30 -0500 Subject: [PATCH] Minimize rebuilds of RAFT code due to knowhere changes (#264) * Minimize rebuilds of RAFT code due to knowhere changes Signed-off-by: William Hicks * Revert accidental reversion of config defaults Signed-off-by: William Hicks --------- Signed-off-by: William Hicks --- .pre-commit-config.yaml | 3 +- CMakeLists.txt | 6 +- cmake/libs/libraft.cmake | 8 +- src/common/raft/integration/cagra_index.cu | 11 +- .../raft/integration/cagra_instantiations.cu | 27 + src/common/raft/integration/ivf_flat_index.cu | 11 +- .../integration/ivf_flat_instantiations.cu | 27 + src/common/raft/integration/ivf_pq_index.cu | 11 +- .../raft/integration/ivf_pq_instantiations.cu | 27 + .../raft/integration/raft_knowhere_index.cuh | 5 +- .../raft/integration/raft_knowhere_index.hpp | 43 +- src/common/raft/integration/type_mappers.hpp | 67 +++ .../proto/filtered_search_instantiation.cuh | 54 ++ src/common/raft/proto/raft_index.cuh | 513 +++++++----------- 14 files changed, 428 insertions(+), 385 deletions(-) create mode 100644 src/common/raft/integration/cagra_instantiations.cu create mode 100644 src/common/raft/integration/ivf_flat_instantiations.cu create mode 100644 src/common/raft/integration/ivf_pq_instantiations.cu create mode 100644 src/common/raft/integration/type_mappers.hpp create mode 100644 src/common/raft/proto/filtered_search_instantiation.cuh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84bd4df87..6c3c4925f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ # default_language_version: - python: python3.8 + python: python3 exclude: '^thirdparty' fail_fast: True repos: @@ -32,4 +32,5 @@ repos: rev: v1.3.5 hooks: - id: clang-format + types_or: [c, c++, cuda] args: [-style=file] diff --git a/CMakeLists.txt b/CMakeLists.txt index df6d9bf97..67456d5dd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,9 @@ include(cmake/utils/utils.cmake) knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF) if (WITH_RAFT) - set(CMAKE_CUDA_ARCHITECTURES RAPIDS) + if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "") + set(CMAKE_CUDA_ARCHITECTURES RAPIDS) + endif() include(cmake/libs/librapids.cmake) project(knowhere CXX C CUDA) include(cmake/libs/libraft.cmake) @@ -139,7 +141,7 @@ list(APPEND KNOWHERE_LINKER_LIBS Folly::folly) add_library(knowhere SHARED ${KNOWHERE_SRCS}) add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS}) if(WITH_RAFT) - list(APPEND KNOWHERE_LINKER_LIBS raft::raft CUDA::cublas CUDA::cusparse CUDA::cusolver) + list(APPEND KNOWHERE_LINKER_LIBS raft::raft raft::compiled_static CUDA::cublas CUDA::cusparse CUDA::cusolver) endif() target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS}) target_include_directories( diff --git a/cmake/libs/libraft.cmake b/cmake/libs/libraft.cmake index 514793933..0b804d997 100644 --- a/cmake/libs/libraft.cmake +++ b/cmake/libs/libraft.cmake @@ -14,9 +14,10 @@ # the License. add_definitions(-DKNOWHERE_WITH_RAFT) +add_definitions(-DRAFT_EXPLICIT_INSTANTIATE_ONLY) set(RAFT_VERSION "${RAPIDS_VERSION}") -set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "branch-23.12") +set(RAFT_FORK "wphicks") +set(RAFT_PINNED_TAG "knowhere-2.4") rapids_find_package(CUDAToolkit REQUIRED @@ -38,7 +39,7 @@ function(find_and_configure_raft) GLOBAL_TARGETS raft::raft COMPONENTS - ${RAFT_COMPONENTS} + compiled_static CPM_ARGS GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git @@ -47,6 +48,7 @@ function(find_and_configure_raft) SOURCE_SUBDIR cpp OPTIONS + "RAFT_COMPILE_LIBRARY ON" "BUILD_TESTS OFF" "BUILD_BENCH OFF" "RAFT_USE_FAISS_STATIC OFF") # Turn this on to build FAISS into your binary diff --git a/src/common/raft/integration/cagra_index.cu b/src/common/raft/integration/cagra_index.cu index 4879f5c39..d43900b50 100644 --- a/src/common/raft/integration/cagra_index.cu +++ b/src/common/raft/integration/cagra_index.cu @@ -14,8 +14,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "common/raft/proto/raft_index_kind.hpp" +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "common/raft/integration/raft_knowhere_index.cuh" +#include "common/raft/proto/filtered_search_instantiation.cuh" +#include "common/raft/proto/raft_index_kind.hpp" + +RAFT_FILTERED_SEARCH_EXTERN(cagra, raft_knowhere::raft_data_t, + raft_knowhere::raft_indexing_t, + raft_knowhere::raft_input_indexing_t, + raft_knowhere::raft_data_t, + raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type) + namespace raft_knowhere { template struct raft_knowhere_index; } // namespace raft_knowhere diff --git a/src/common/raft/integration/cagra_instantiations.cu b/src/common/raft/integration/cagra_instantiations.cu new file mode 100644 index 000000000..651fc893b --- /dev/null +++ b/src/common/raft/integration/cagra_instantiations.cu @@ -0,0 +1,27 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "common/raft/integration/type_mappers.hpp" +#include "common/raft/proto/filtered_search_instantiation.cuh" +#include "common/raft/proto/raft_index_kind.hpp" + +RAFT_FILTERED_SEARCH_INSTANTIATION(cagra, raft_knowhere::raft_data_t, + raft_knowhere::raft_indexing_t, + raft_knowhere::raft_input_indexing_t, + raft_knowhere::raft_data_t, + raft_knowhere::knowhere_bitset_data_type, + raft_knowhere::knowhere_bitset_indexing_type) diff --git a/src/common/raft/integration/ivf_flat_index.cu b/src/common/raft/integration/ivf_flat_index.cu index fc759075d..8d546048e 100644 --- a/src/common/raft/integration/ivf_flat_index.cu +++ b/src/common/raft/integration/ivf_flat_index.cu @@ -14,8 +14,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "common/raft/proto/raft_index_kind.hpp" +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "common/raft/integration/raft_knowhere_index.cuh" +#include "common/raft/proto/filtered_search_instantiation.cuh" +#include "common/raft/proto/raft_index_kind.hpp" + +RAFT_FILTERED_SEARCH_EXTERN(ivf_flat, raft_knowhere::raft_data_t, + raft_knowhere::raft_indexing_t, + raft_knowhere::raft_input_indexing_t, + raft_knowhere::raft_data_t, + raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type) + namespace raft_knowhere { template struct raft_knowhere_index; } // namespace raft_knowhere diff --git a/src/common/raft/integration/ivf_flat_instantiations.cu b/src/common/raft/integration/ivf_flat_instantiations.cu new file mode 100644 index 000000000..a5402083c --- /dev/null +++ b/src/common/raft/integration/ivf_flat_instantiations.cu @@ -0,0 +1,27 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "common/raft/integration/type_mappers.hpp" +#include "common/raft/proto/filtered_search_instantiation.cuh" +#include "common/raft/proto/raft_index_kind.hpp" + +RAFT_FILTERED_SEARCH_INSTANTIATION(ivf_flat, raft_knowhere::raft_data_t, + raft_knowhere::raft_indexing_t, + raft_knowhere::raft_input_indexing_t, + raft_knowhere::raft_data_t, + raft_knowhere::knowhere_bitset_data_type, + raft_knowhere::knowhere_bitset_indexing_type) diff --git a/src/common/raft/integration/ivf_pq_index.cu b/src/common/raft/integration/ivf_pq_index.cu index 9284a0930..9f46b3fc2 100644 --- a/src/common/raft/integration/ivf_pq_index.cu +++ b/src/common/raft/integration/ivf_pq_index.cu @@ -14,8 +14,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "common/raft/proto/raft_index_kind.hpp" +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "common/raft/integration/raft_knowhere_index.cuh" +#include "common/raft/proto/filtered_search_instantiation.cuh" +#include "common/raft/proto/raft_index_kind.hpp" + +RAFT_FILTERED_SEARCH_EXTERN(ivf_pq, raft_knowhere::raft_data_t, + raft_knowhere::raft_indexing_t, + raft_knowhere::raft_input_indexing_t, + raft_knowhere::raft_data_t, + raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type) + namespace raft_knowhere { template struct raft_knowhere_index; } // namespace raft_knowhere diff --git a/src/common/raft/integration/ivf_pq_instantiations.cu b/src/common/raft/integration/ivf_pq_instantiations.cu new file mode 100644 index 000000000..40936fda3 --- /dev/null +++ b/src/common/raft/integration/ivf_pq_instantiations.cu @@ -0,0 +1,27 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "common/raft/integration/type_mappers.hpp" +#include "common/raft/proto/filtered_search_instantiation.cuh" +#include "common/raft/proto/raft_index_kind.hpp" + +RAFT_FILTERED_SEARCH_INSTANTIATION(ivf_pq, raft_knowhere::raft_data_t, + raft_knowhere::raft_indexing_t, + raft_knowhere::raft_input_indexing_t, + raft_knowhere::raft_data_t, + raft_knowhere::knowhere_bitset_data_type, + raft_knowhere::knowhere_bitset_indexing_type) diff --git a/src/common/raft/integration/raft_knowhere_index.cuh b/src/common/raft/integration/raft_knowhere_index.cuh index 5a351c47b..ac0fec00f 100644 --- a/src/common/raft/integration/raft_knowhere_index.cuh +++ b/src/common/raft/integration/raft_knowhere_index.cuh @@ -324,14 +324,11 @@ struct raft_knowhere_index::impl { using data_type = raft_data_t; using indexing_type = raft_indexing_t; using input_indexing_type = raft_input_indexing_t; + using raft_index_type = raft_index_t; impl() { } - private: - using raft_index_type = raft_index_t; - - public: auto is_trained() const { return index_.has_value(); diff --git a/src/common/raft/integration/raft_knowhere_index.hpp b/src/common/raft/integration/raft_knowhere_index.hpp index f26a7f0c7..cd2dc8b43 100644 --- a/src/common/raft/integration/raft_knowhere_index.hpp +++ b/src/common/raft/integration/raft_knowhere_index.hpp @@ -18,52 +18,11 @@ #include #include "common/raft/integration/raft_knowhere_config.hpp" +#include "common/raft/integration/type_mappers.hpp" #include "common/raft/proto/raft_index_kind.hpp" namespace raft_knowhere { -using knowhere_data_type = float; -using knowhere_indexing_type = std::int64_t; -using knowhere_bitset_data_type = std::uint8_t; -using knowhere_bitset_indexing_type = std::uint32_t; - -namespace detail { - -template -struct raft_io_type_mapper : std::false_type {}; - -template <> -struct raft_io_type_mapper : std::true_type { - using data_type = float; - using indexing_type = std::int64_t; - using input_indexing_type = std::int64_t; -}; - -template <> -struct raft_io_type_mapper : std::true_type { - using data_type = float; - using indexing_type = std::int64_t; - using input_indexing_type = std::int64_t; -}; - -template <> -struct raft_io_type_mapper : std::true_type { - using data_type = float; - using indexing_type = std::uint32_t; - using input_indexing_type = std::int64_t; -}; - -} // namespace detail - -template -using raft_data_t = typename detail::raft_io_type_mapper::data_type; - -template -using raft_indexing_t = typename detail::raft_io_type_mapper::indexing_type; - -template -using raft_input_indexing_t = typename detail::raft_io_type_mapper::input_indexing_type; - template struct raft_knowhere_index { auto static constexpr index_kind = IndexKind; diff --git a/src/common/raft/integration/type_mappers.hpp b/src/common/raft/integration/type_mappers.hpp new file mode 100644 index 000000000..814f85382 --- /dev/null +++ b/src/common/raft/integration/type_mappers.hpp @@ -0,0 +1,67 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +#include "common/raft/proto/raft_index_kind.hpp" + +namespace raft_knowhere { + +using knowhere_data_type = float; +using knowhere_indexing_type = std::int64_t; +using knowhere_bitset_data_type = std::uint8_t; +using knowhere_bitset_indexing_type = std::uint32_t; + +namespace detail { + +template +struct raft_io_type_mapper : std::false_type {}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::int64_t; + using input_indexing_type = std::int64_t; +}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::int64_t; + using input_indexing_type = std::uint32_t; +}; + +template <> +struct raft_io_type_mapper : std::true_type { + using data_type = float; + using indexing_type = std::uint32_t; + using input_indexing_type = std::int64_t; +}; + +} // namespace detail + +template +using raft_data_t = typename detail::raft_io_type_mapper::data_type; + +template +using raft_indexing_t = typename detail::raft_io_type_mapper::indexing_type; + +template +using raft_input_indexing_t = typename detail::raft_io_type_mapper::input_indexing_type; + +} // namespace raft_knowhere diff --git a/src/common/raft/proto/filtered_search_instantiation.cuh b/src/common/raft/proto/filtered_search_instantiation.cuh new file mode 100644 index 000000000..09acadb00 --- /dev/null +++ b/src/common/raft/proto/filtered_search_instantiation.cuh @@ -0,0 +1,54 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include +#include + +#include "common/raft/proto/raft_index_kind.hpp" + +namespace raft_proto { +namespace detail { +template +using index_instantiation = std::conditional_t< + K == raft_proto::raft_index_kind::ivf_flat, raft::neighbors::ivf_flat::index, + std::conditional_t< + K == raft_proto::raft_index_kind::ivf_pq, raft::neighbors::ivf_pq::index, + std::conditional_t, + raft::neighbors::ivf_flat::index>>>; +} // namespace detail +} // namespace raft_proto + +#define RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \ + template void search_with_filtering>( \ + raft::resources const&, search_params const&, \ + raft_proto::detail::index_instantiation const&, \ + raft::device_matrix_view, raft::device_matrix_view, \ + raft::device_matrix_view, raft::neighbors::filtering::bitset_filter) + +#define RAFT_FILTERED_SEARCH_INSTANTIATION(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \ + namespace raft::neighbors::index_type { \ + RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT); \ + } + +#define RAFT_FILTERED_SEARCH_EXTERN(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \ + namespace raft::neighbors::index_type { \ + RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT); \ + } diff --git a/src/common/raft/proto/raft_index.cuh b/src/common/raft/proto/raft_index.cuh index ff25b7e04..89638db2e 100644 --- a/src/common/raft/proto/raft_index.cuh +++ b/src/common/raft/proto/raft_index.cuh @@ -20,17 +20,18 @@ #include #include #include -#include #include +#include +#include +#include #include #include #include #include #include #include -#include -#include -#include +#include + #include "common/raft/proto/raft_index_kind.hpp" namespace raft_proto { @@ -38,359 +39,211 @@ namespace raft_proto { auto static const RAFT_NAME = raft::RAFT_NAME; namespace detail { -template typename index_template> -struct template_matches_index_kind : std::false_type{}; +template typename index_template> +struct template_matches_index_kind : std::false_type {}; -template<> -struct template_matches_index_kind : std::true_type{}; +template <> +struct template_matches_index_kind : std::true_type {}; -template<> -struct template_matches_index_kind : std::true_type{}; +template <> +struct template_matches_index_kind : std::true_type {}; -template<> -struct template_matches_index_kind : std::true_type{}; +template <> +struct template_matches_index_kind : std::true_type {}; -template typename index_template> +template typename index_template> auto static constexpr template_matches_index_kind_v = template_matches_index_kind::value; -} +} // namespace detail template