Skip to content

Commit

Permalink
Minimize rebuilds of RAFT code due to knowhere changes (#264)
Browse files Browse the repository at this point in the history
* Minimize rebuilds of RAFT code due to knowhere changes

Signed-off-by: William Hicks <[email protected]>

* Revert accidental reversion of config defaults

Signed-off-by: William Hicks <[email protected]>

---------

Signed-off-by: William Hicks <[email protected]>
  • Loading branch information
wphicks authored Dec 8, 2023
1 parent 48986a8 commit edcc6c7
Show file tree
Hide file tree
Showing 14 changed files with 428 additions and 385 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

default_language_version:
python: python3.8
python: python3
exclude: '^thirdparty'
fail_fast: True
repos:
Expand All @@ -32,4 +32,5 @@ repos:
rev: v1.3.5
hooks:
- id: clang-format
types_or: [c, c++, cuda]
args: [-style=file]
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ include(cmake/utils/utils.cmake)

knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF)
if (WITH_RAFT)
set(CMAKE_CUDA_ARCHITECTURES RAPIDS)
if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "")
set(CMAKE_CUDA_ARCHITECTURES RAPIDS)
endif()
include(cmake/libs/librapids.cmake)
project(knowhere CXX C CUDA)
include(cmake/libs/libraft.cmake)
Expand Down Expand Up @@ -139,7 +141,7 @@ list(APPEND KNOWHERE_LINKER_LIBS Folly::folly)
add_library(knowhere SHARED ${KNOWHERE_SRCS})
add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS})
if(WITH_RAFT)
list(APPEND KNOWHERE_LINKER_LIBS raft::raft CUDA::cublas CUDA::cusparse CUDA::cusolver)
list(APPEND KNOWHERE_LINKER_LIBS raft::raft raft::compiled_static CUDA::cublas CUDA::cusparse CUDA::cusolver)
endif()
target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS})
target_include_directories(
Expand Down
8 changes: 5 additions & 3 deletions cmake/libs/libraft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# the License.

add_definitions(-DKNOWHERE_WITH_RAFT)
add_definitions(-DRAFT_EXPLICIT_INSTANTIATE_ONLY)
set(RAFT_VERSION "${RAPIDS_VERSION}")
set(RAFT_FORK "rapidsai")
set(RAFT_PINNED_TAG "branch-23.12")
set(RAFT_FORK "wphicks")
set(RAFT_PINNED_TAG "knowhere-2.4")


rapids_find_package(CUDAToolkit REQUIRED
Expand All @@ -38,7 +39,7 @@ function(find_and_configure_raft)
GLOBAL_TARGETS
raft::raft
COMPONENTS
${RAFT_COMPONENTS}
compiled_static
CPM_ARGS
GIT_REPOSITORY
https://github.com/${PKG_FORK}/raft.git
Expand All @@ -47,6 +48,7 @@ function(find_and_configure_raft)
SOURCE_SUBDIR
cpp
OPTIONS
"RAFT_COMPILE_LIBRARY ON"
"BUILD_TESTS OFF"
"BUILD_BENCH OFF"
"RAFT_USE_FAISS_STATIC OFF") # Turn this on to build FAISS into your binary
Expand Down
11 changes: 10 additions & 1 deletion src/common/raft/integration/cagra_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/raft/proto/raft_index_kind.hpp"
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "common/raft/integration/raft_knowhere_index.cuh"
#include "common/raft/proto/filtered_search_instantiation.cuh"
#include "common/raft/proto/raft_index_kind.hpp"

RAFT_FILTERED_SEARCH_EXTERN(cagra, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::cagra>,
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::cagra>,
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::cagra>,
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::cagra>,
raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type)

namespace raft_knowhere {
template struct raft_knowhere_index<raft_proto::raft_index_kind::cagra>;
} // namespace raft_knowhere
27 changes: 27 additions & 0 deletions src/common/raft/integration/cagra_instantiations.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "common/raft/integration/type_mappers.hpp"
#include "common/raft/proto/filtered_search_instantiation.cuh"
#include "common/raft/proto/raft_index_kind.hpp"

RAFT_FILTERED_SEARCH_INSTANTIATION(cagra, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::cagra>,
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::cagra>,
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::cagra>,
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::cagra>,
raft_knowhere::knowhere_bitset_data_type,
raft_knowhere::knowhere_bitset_indexing_type)
11 changes: 10 additions & 1 deletion src/common/raft/integration/ivf_flat_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/raft/proto/raft_index_kind.hpp"
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "common/raft/integration/raft_knowhere_index.cuh"
#include "common/raft/proto/filtered_search_instantiation.cuh"
#include "common/raft/proto/raft_index_kind.hpp"

RAFT_FILTERED_SEARCH_EXTERN(ivf_flat, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_flat>,
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::ivf_flat>,
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::ivf_flat>,
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_flat>,
raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type)

namespace raft_knowhere {
template struct raft_knowhere_index<raft_proto::raft_index_kind::ivf_flat>;
} // namespace raft_knowhere
27 changes: 27 additions & 0 deletions src/common/raft/integration/ivf_flat_instantiations.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "common/raft/integration/type_mappers.hpp"
#include "common/raft/proto/filtered_search_instantiation.cuh"
#include "common/raft/proto/raft_index_kind.hpp"

RAFT_FILTERED_SEARCH_INSTANTIATION(ivf_flat, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_flat>,
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::ivf_flat>,
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::ivf_flat>,
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_flat>,
raft_knowhere::knowhere_bitset_data_type,
raft_knowhere::knowhere_bitset_indexing_type)
11 changes: 10 additions & 1 deletion src/common/raft/integration/ivf_pq_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/raft/proto/raft_index_kind.hpp"
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "common/raft/integration/raft_knowhere_index.cuh"
#include "common/raft/proto/filtered_search_instantiation.cuh"
#include "common/raft/proto/raft_index_kind.hpp"

RAFT_FILTERED_SEARCH_EXTERN(ivf_pq, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_pq>,
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::ivf_pq>,
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::ivf_pq>,
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_pq>,
raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type)

namespace raft_knowhere {
template struct raft_knowhere_index<raft_proto::raft_index_kind::ivf_pq>;
} // namespace raft_knowhere
27 changes: 27 additions & 0 deletions src/common/raft/integration/ivf_pq_instantiations.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "common/raft/integration/type_mappers.hpp"
#include "common/raft/proto/filtered_search_instantiation.cuh"
#include "common/raft/proto/raft_index_kind.hpp"

RAFT_FILTERED_SEARCH_INSTANTIATION(ivf_pq, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_pq>,
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::ivf_pq>,
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::ivf_pq>,
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_pq>,
raft_knowhere::knowhere_bitset_data_type,
raft_knowhere::knowhere_bitset_indexing_type)
5 changes: 1 addition & 4 deletions src/common/raft/integration/raft_knowhere_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -324,14 +324,11 @@ struct raft_knowhere_index<IndexKind>::impl {
using data_type = raft_data_t<index_kind>;
using indexing_type = raft_indexing_t<index_kind>;
using input_indexing_type = raft_input_indexing_t<index_kind>;
using raft_index_type = raft_index_t<index_kind>;

impl() {
}

private:
using raft_index_type = raft_index_t<index_kind>;

public:
auto
is_trained() const {
return index_.has_value();
Expand Down
43 changes: 1 addition & 42 deletions src/common/raft/integration/raft_knowhere_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,52 +18,11 @@
#include <cstdint>

#include "common/raft/integration/raft_knowhere_config.hpp"
#include "common/raft/integration/type_mappers.hpp"
#include "common/raft/proto/raft_index_kind.hpp"

namespace raft_knowhere {

using knowhere_data_type = float;
using knowhere_indexing_type = std::int64_t;
using knowhere_bitset_data_type = std::uint8_t;
using knowhere_bitset_indexing_type = std::uint32_t;

namespace detail {

template <bool B, raft_proto::raft_index_kind IndexKind>
struct raft_io_type_mapper : std::false_type {};

template <>
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::ivf_flat> : std::true_type {
using data_type = float;
using indexing_type = std::int64_t;
using input_indexing_type = std::int64_t;
};

template <>
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::ivf_pq> : std::true_type {
using data_type = float;
using indexing_type = std::int64_t;
using input_indexing_type = std::int64_t;
};

template <>
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::cagra> : std::true_type {
using data_type = float;
using indexing_type = std::uint32_t;
using input_indexing_type = std::int64_t;
};

} // namespace detail

template <raft_proto::raft_index_kind IndexKind>
using raft_data_t = typename detail::raft_io_type_mapper<true, IndexKind>::data_type;

template <raft_proto::raft_index_kind IndexKind>
using raft_indexing_t = typename detail::raft_io_type_mapper<true, IndexKind>::indexing_type;

template <raft_proto::raft_index_kind IndexKind>
using raft_input_indexing_t = typename detail::raft_io_type_mapper<true, IndexKind>::input_indexing_type;

template <raft_proto::raft_index_kind IndexKind>
struct raft_knowhere_index {
auto static constexpr index_kind = IndexKind;
Expand Down
67 changes: 67 additions & 0 deletions src/common/raft/integration/type_mappers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <type_traits>

#include "common/raft/proto/raft_index_kind.hpp"

namespace raft_knowhere {

using knowhere_data_type = float;
using knowhere_indexing_type = std::int64_t;
using knowhere_bitset_data_type = std::uint8_t;
using knowhere_bitset_indexing_type = std::uint32_t;

namespace detail {

template <bool B, raft_proto::raft_index_kind IndexKind>
struct raft_io_type_mapper : std::false_type {};

template <>
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::ivf_flat> : std::true_type {
using data_type = float;
using indexing_type = std::int64_t;
using input_indexing_type = std::int64_t;
};

template <>
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::ivf_pq> : std::true_type {
using data_type = float;
using indexing_type = std::int64_t;
using input_indexing_type = std::uint32_t;
};

template <>
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::cagra> : std::true_type {
using data_type = float;
using indexing_type = std::uint32_t;
using input_indexing_type = std::int64_t;
};

} // namespace detail

template <raft_proto::raft_index_kind IndexKind>
using raft_data_t = typename detail::raft_io_type_mapper<true, IndexKind>::data_type;

template <raft_proto::raft_index_kind IndexKind>
using raft_indexing_t = typename detail::raft_io_type_mapper<true, IndexKind>::indexing_type;

template <raft_proto::raft_index_kind IndexKind>
using raft_input_indexing_t = typename detail::raft_io_type_mapper<true, IndexKind>::input_indexing_type;

} // namespace raft_knowhere
54 changes: 54 additions & 0 deletions src/common/raft/proto/filtered_search_instantiation.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/sample_filter.cuh>

#include "common/raft/proto/raft_index_kind.hpp"

namespace raft_proto {
namespace detail {
template <raft_proto::raft_index_kind K, typename T, typename IdxT>
using index_instantiation = std::conditional_t<
K == raft_proto::raft_index_kind::ivf_flat, raft::neighbors::ivf_flat::index<T, IdxT>,
std::conditional_t<
K == raft_proto::raft_index_kind::ivf_pq, raft::neighbors::ivf_pq::index<IdxT>,
std::conditional_t<K == raft_proto::raft_index_kind::cagra, raft::neighbors::cagra::index<T, IdxT>,
raft::neighbors::ivf_flat::index<T, IdxT>>>>;
} // namespace detail
} // namespace raft_proto

#define RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \
template void search_with_filtering<T, IdxT, raft::neighbors::filtering::bitset_filter<BitsetDataT, BitsetIdxT>>( \
raft::resources const&, search_params const&, \
raft_proto::detail::index_instantiation<raft_proto::raft_index_kind::index_type, T, IdxT> const&, \
raft::device_matrix_view<const T, InpIdxT>, raft::device_matrix_view<IdxT, InpIdxT>, \
raft::device_matrix_view<DistT, InpIdxT>, raft::neighbors::filtering::bitset_filter<BitsetDataT, BitsetIdxT>)

#define RAFT_FILTERED_SEARCH_INSTANTIATION(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \
namespace raft::neighbors::index_type { \
RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT); \
}

#define RAFT_FILTERED_SEARCH_EXTERN(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \
namespace raft::neighbors::index_type { \
RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT); \
}
Loading

0 comments on commit edcc6c7

Please sign in to comment.