Skip to content

Commit

Permalink
Add serialization API to brute-force
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Nov 11, 2024
1 parent 29c0f5b commit 136eafd
Showing 1 changed file with 126 additions and 0 deletions.
126 changes: 126 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ struct index : cuvs::neighbors::index {
index& operator=(index&&) = default;
~index() = default;

/**
* @brief Construct an empty index.
*
* Constructs an empty index. This index will either need to be trained with `build`
* or loaded from a saved copy with `deserialize`
*/
index(raft::resources const& handle);

/** Construct a brute force index from dataset
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
Expand Down Expand Up @@ -371,6 +379,124 @@ void search(raft::resources const& handle,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
const cuvs::neighbors::filtering::base_filter& sample_filter =
cuvs::neighbors::filtering::none_sample_filter{});

/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output
*
*/
void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<half, float>& index,
bool include_dataset = true);
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = brute_force::build(...);`
* cuvs::neighbors::brute_force::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index brute force index
* @param[in] include_dataset whether to include the dataset in the serialized
* output
*
*/
void serialize(raft::resources const& handle,
const std::string& filename,
const cuvs::neighbors::brute_force::index<float, float>& index,
bool include_dataset = true);
/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/neighbors/brute_force.hpp>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = half; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, filename, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<half, float>* index);
/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <raft/neighbors/brute_force_serialize.cuh>
*
* raft::resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = float; // data element type
* brute_force::index<T, float> index(handle);
* cuvs::neighbors::brute_force::deserialize(handle, filename, index);
* @endcode
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
* @param[out] index brute force index
*
*/
void deserialize(raft::resources const& handle,
const std::string& filename,
cuvs::neighbors::brute_force::index<float, float>* index);

/**@}*/

} // namespace raft::neighbors::brute_force

/**
* @}
*/
Expand Down

0 comments on commit 136eafd

Please sign in to comment.