Skip to content

Commit

Permalink
Add implementation and test
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Nov 11, 2024
1 parent 136eafd commit 622e3e3
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 8 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ if(BUILD_SHARED_LIBS)
src/distance/distance.cu
src/distance/pairwise_distance.cu
src/neighbors/brute_force.cu
src/neighbors/brute_force_serialize.cu
src/neighbors/cagra_build_float.cu
src/neighbors/cagra_build_half.cu
src/neighbors/cagra_build_int8.cu
Expand Down
115 changes: 108 additions & 7 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,6 @@ void search(raft::resources const& handle,
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output
*
*/
void serialize(raft::resources const& handle,
const std::string& filename,
Expand Down Expand Up @@ -440,6 +439,61 @@ void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);

/**
* Write the index to an output stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cuvs::neighbors::brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, os, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index brute force index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*/
void serialize(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::brute_force::index<half, float>& index,
bool include_dataset = true);

/**
* Write the index to an output stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cuvs::neighbors::brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, os, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index brute force index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*/
void serialize(raft::resources const& handle,
std::ostream& os,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);

/**
* Load index from file.
*
Expand Down Expand Up @@ -473,7 +527,7 @@ void deserialize(raft::resources const& handle,
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/brute_force_serialize.cuh>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
Expand All @@ -492,11 +546,58 @@ void deserialize(raft::resources const& handle,
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<float, float>* index);

/**@}*/

} // namespace raft::neighbors::brute_force

/**
* Load index from input stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = half; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, is, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] is input stream
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<half, float>* index);
/**
* Load index from input stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = float; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, is, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] is input stream
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<float, float>* index);
/**
* @}
*/
Expand Down
15 changes: 15 additions & 0 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@
#include <raft/core/copy.hpp>

namespace cuvs::neighbors::brute_force {

template <typename T, typename DistT>
index<T, DistT>::index(raft::resources const& res)
// this constructor is just for a temporary index, for use in the deserialization
// api. all the parameters here will get replaced with loaded values - that aren't
// necessarily known ahead of time before deserialization.
// TODO: do we even need a handle here - could just construct one?
: cuvs::neighbors::index(),
metric_(cuvs::distance::DistanceType::L2Expanded),
dataset_(raft::make_device_matrix<T, int64_t>(res, 0, 0)),
norms_(std::nullopt),
metric_arg_(0)
{
}

template <typename T, typename DistT>
index<T, DistT>::index(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset,
Expand Down
167 changes: 167 additions & 0 deletions cpp/src/neighbors/brute_force_serialize.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuvs/neighbors/brute_force.hpp>
#include <raft/core/copy.cuh>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>

namespace cuvs::neighbors::brute_force {

int constexpr serialization_version = 0;

template <typename T, typename DistT>
void serialize(raft::resources const& handle,
std::ostream& os,
const index<T, DistT>& index,
bool include_dataset = true)
{
RAFT_LOG_DEBUG(
"Saving brute force index, size %zu, dim %u", static_cast<size_t>(index.size()), index.dim());

auto dtype_string = raft::detail::numpy_serializer::get_numpy_dtype<T>().to_string();
dtype_string.resize(4);
os << dtype_string;

raft::serialize_scalar(handle, os, serialization_version);
raft::serialize_scalar(handle, os, index.size());
raft::serialize_scalar(handle, os, index.dim());
raft::serialize_scalar(handle, os, index.metric());
raft::serialize_scalar(handle, os, index.metric_arg());
raft::serialize_scalar(handle, os, include_dataset);
if (include_dataset) { raft::serialize_mdspan(handle, os, index.dataset()); }
auto has_norms = index.has_norms();
raft::serialize_scalar(handle, os, has_norms);
if (has_norms) { raft::serialize_mdspan(handle, os, index.norms()); }
raft::resource::sync_stream(handle);
}

void serialize(raft::resources const& handle,
const std::string& filename,
const index<half, float>& index,
bool include_dataset)
{
auto os = std::ofstream{filename, std::ios::out | std::ios::binary};
RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str());
serialize<half, float>(handle, os, index, include_dataset);
}

void serialize(raft::resources const& handle,
const std::string& filename,
const index<float, float>& index,
bool include_dataset)
{
auto os = std::ofstream{filename, std::ios::out | std::ios::binary};
RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str());
serialize<float, float>(handle, os, index, include_dataset);
}

void serialize(raft::resources const& handle,
std::ostream& os,
const index<half, float>& index,
bool include_dataset)
{
serialize<half, float>(handle, os, index, include_dataset);
}

void serialize(raft::resources const& handle,
std::ostream& os,
const index<float, float>& index,
bool include_dataset)
{
serialize<float, float>(handle, os, index, include_dataset);
}

template <typename T, typename DistT>
auto deserialize(raft::resources const& handle, std::istream& is)
{
auto dtype_string = std::array<char, 4>{};
is.read(dtype_string.data(), 4);

auto ver = raft::deserialize_scalar<int>(handle, is);
if (ver != serialization_version) {
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
}
std::int64_t rows = raft::deserialize_scalar<size_t>(handle, is);
std::int64_t dim = raft::deserialize_scalar<size_t>(handle, is);
auto metric = raft::deserialize_scalar<cuvs::distance::DistanceType>(handle, is);
auto metric_arg = raft::deserialize_scalar<DistT>(handle, is);

auto dataset_storage = raft::make_host_matrix<T>(std::int64_t{}, std::int64_t{});
auto include_dataset = raft::deserialize_scalar<bool>(handle, is);
if (include_dataset) {
dataset_storage = raft::make_host_matrix<T>(rows, dim);
raft::deserialize_mdspan(handle, is, dataset_storage.view());
}

auto has_norms = raft::deserialize_scalar<bool>(handle, is);
auto norms_storage = has_norms ? std::optional{raft::make_host_vector<DistT, std::int64_t>(rows)}
: std::optional<raft::host_vector<DistT, std::int64_t>>{};
// TODO(wphicks): Use mdbuffer here when available
auto norms_storage_dev =
has_norms ? std::optional{raft::make_device_vector<DistT, std::int64_t>(handle, rows)}
: std::optional<raft::device_vector<DistT, std::int64_t>>{};
if (has_norms) {
raft::deserialize_mdspan(handle, is, norms_storage->view());
raft::copy(handle, norms_storage_dev->view(), norms_storage->view());
}

auto result = index<T, DistT>(handle,
raft::make_const_mdspan(dataset_storage.view()),
std::move(norms_storage_dev),
metric,
metric_arg);
raft::resource::sync_stream(handle);

return result;
}

void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<half, float>* index)
{
auto is = std::ifstream{filename, std::ios::in | std::ios::binary};
RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str());

*index = deserialize<half, float>(handle, is);
}

void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<float, float>* index)
{
auto is = std::ifstream{filename, std::ios::in | std::ios::binary};
RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str());

*index = deserialize<float, float>(handle, is);
}

void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<half, float>* index)
{
*index = deserialize<half, float>(handle, is);
}

void deserialize(raft::resources const& handle,
std::istream& is,
cuvs::neighbors::brute_force::index<float, float>* index)
{
*index = deserialize<float, float>(handle, is);
}

} // namespace cuvs::neighbors::brute_force
18 changes: 17 additions & 1 deletion cpp/test/neighbors/ann_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,28 @@ class AnnBruteForceTest : public ::testing::TestWithParam<AnnBruteForceInputs<Id
0.001f,
stream_,
true));

brute_force::serialize(handle_, std::string{"brute_force_index"}, idx, true);
auto index_loaded = brute_force::index<DataT, T>(handle_);
brute_force::deserialize(handle_, std::string{"brute_force_index"}, &index_loaded);

brute_force::search(handle_,
idx,
index_loaded,
search_queries_view,
indices_out_view,
dists_out_view,
cuvs::neighbors::filtering::none_sample_filter{});
raft::resource::sync_stream(handle_);

ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(indices_naive_dev.data(),
indices_bruteforce_dev.data(),
distances_naive_dev.data(),
distances_bruteforce_dev.data(),
ps.num_queries,
ps.k,
0.001f,
stream_,
true));
}
}

Expand Down

0 comments on commit 622e3e3

Please sign in to comment.