Skip to content

Commit

Permalink
Remove redundant team-size and dataset-block-dim parameters from the …
Browse files Browse the repository at this point in the history
…data descriptor
  • Loading branch information
achirkin committed Mar 20, 2024
1 parent 5174811 commit 16ddb13
Show file tree
Hide file tree
Showing 156 changed files with 966 additions and 1,116 deletions.
10 changes: 3 additions & 7 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,8 @@ void launch_vpq_search_main_core(
DatasetT,
8 /*PQ bit*/,
2 /* Subspace dimension*/,
0,
DistanceT,
InternalIdxT,
0>;
InternalIdxT>;
dataset_desc_t dataset_desc(vpq_dset->data.data_handle(),
vpq_dset->encoded_row_length(),
vpq_dset->pq_dim(),
Expand All @@ -200,10 +198,8 @@ void launch_vpq_search_main_core(
DatasetT,
8 /*PQ bit*/,
4 /* Subspace dimension*/,
0,
DistanceT,
InternalIdxT,
0>;
InternalIdxT>;
dataset_desc_t dataset_desc(vpq_dset->data.data_handle(),
vpq_dset->encoded_row_length(),
vpq_dset->pq_dim(),
Expand Down Expand Up @@ -266,7 +262,7 @@ void search_main(raft::resources const& res,
strided_dset != nullptr) {
// Set TEAM_SIZE and DATASET_BLOCK_SIZE to zero tentatively since these parameters cannot be
// determined here. They are set just before kernel launch.
using dataset_desc_t = standard_dataset_descriptor_t<T, InternalIdxT, 0, 0, DistanceT>;
using dataset_desc_t = standard_dataset_descriptor_t<T, InternalIdxT, DistanceT>;
// Search using a plain (strided) row-major dataset
const dataset_desc_t dataset_desc(strided_dset->view().data_handle(),
strided_dset->n_rows(),
Expand Down
36 changes: 6 additions & 30 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
}
}

const auto norm2 = dataset_desc.compute_similarity(query_buffer, seed_index, valid_i);
const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, seed_index, valid_i);

if (valid_i && (norm2 < best_norm2_team_local)) {
best_norm2_team_local = norm2;
Expand Down Expand Up @@ -152,8 +153,8 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(
INDEX_T child_id = invalid_index;
if (valid_i) { child_id = result_child_indices_ptr[i]; }

const auto norm2 =
dataset_desc.compute_similarity(query_buffer, child_id, child_id != invalid_index);
const auto norm2 = dataset_desc.template compute_similarity<DATASET_BLOCK_DIM, TEAM_SIZE>(
query_buffer, child_id, child_id != invalid_index);

// Store the distance
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
Expand Down Expand Up @@ -181,18 +182,12 @@ struct dataset_descriptor_base_t {
dataset_descriptor_base_t(const INDEX_T size, const std::uint32_t dim) : size(size), dim(dim) {}
};

template <class DATA_T_,
class INDEX_T,
std::uint32_t DATASET_BLOCK_DIM_ = 0,
std::uint32_t TEAM_SIZE_ = 0,
class DISTANCE_T = float>
template <class DATA_T_, class INDEX_T, class DISTANCE_T = float>
struct standard_dataset_descriptor_t
: public dataset_descriptor_base_t<float, DISTANCE_T, INDEX_T> {
using LOAD_T = device::LOAD_128BIT_T;
using DATA_T = DATA_T_;
using QUERY_T = typename dataset_descriptor_base_t<float, DISTANCE_T, INDEX_T>::QUERY_T;
static const std::uint32_t DATASET_BLOCK_DIM = DATASET_BLOCK_DIM_;
static const std::uint32_t TEAM_SIZE = TEAM_SIZE_;

const DATA_T* const ptr;
const std::size_t ld;
Expand All @@ -210,6 +205,7 @@ struct standard_dataset_descriptor_t
static const std::uint32_t smem_buffer_size_in_byte = 0;
__device__ void set_smem_ptr(void* const){};

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
const bool valid) const
Expand Down Expand Up @@ -255,24 +251,4 @@ struct standard_dataset_descriptor_t
}
};

template <std::uint32_t DATASET_BLOCK_DIM_OUT,
std::uint32_t TEAM_SIZE_OUT,
std::uint32_t DATASET_BLOCK_DIM_IN,
std::uint32_t TEAM_SIZE_IN,
class DATA_T,
class INDEX_T,
class DISTANCE_T>
standard_dataset_descriptor_t<DATA_T, INDEX_T, DATASET_BLOCK_DIM_OUT, TEAM_SIZE_OUT, DISTANCE_T>
set_compute_template_params(
standard_dataset_descriptor_t<DATA_T, INDEX_T, DATASET_BLOCK_DIM_IN, TEAM_SIZE_IN, DISTANCE_T>&
desc_in)
{
return standard_dataset_descriptor_t<DATA_T,
INDEX_T,
DATASET_BLOCK_DIM_OUT,
TEAM_SIZE_OUT,
DISTANCE_T>(
desc_in.ptr, desc_in.size, desc_in.dim, desc_in.ld);
}

} // namespace raft::neighbors::cagra::detail
104 changes: 28 additions & 76 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,14 @@ namespace raft::neighbors::cagra::detail {
template <class DATA_T_,
class CODE_BOOK_T_,
unsigned PQ_BITS,
unsigned PQ_CODE_BOOK_DIM,
unsigned DATASET_BLOCK_DIM_,
unsigned PQ_LEN,
class DISTANCE_T,
class INDEX_T,
unsigned TEAM_SIZE_>
class INDEX_T>
struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T> {
using LOAD_T = device::LOAD_128BIT_T;
using DATA_T = DATA_T_;
using CODE_BOOK_T = CODE_BOOK_T_;
using QUERY_T = typename dataset_descriptor_base_t<half, DISTANCE_T, INDEX_T>::QUERY_T;
static const std::uint32_t DATASET_BLOCK_DIM = DATASET_BLOCK_DIM_;
static const std::uint32_t TEAM_SIZE = TEAM_SIZE_;

const std::uint8_t* encoded_dataset_ptr;
const std::uint32_t encoded_dataset_dim;
Expand All @@ -50,23 +46,22 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
// Set on device
CODE_BOOK_T* smem_pq_code_book_ptr;
static const std::uint32_t smem_buffer_size_in_byte =
(1 << PQ_BITS) * PQ_CODE_BOOK_DIM * utils::size_of<CODE_BOOK_T>();
(1 << PQ_BITS) * PQ_LEN * utils::size_of<CODE_BOOK_T>();

__device__ void set_smem_ptr(void* const smem_ptr)
{
smem_pq_code_book_ptr = reinterpret_cast<CODE_BOOK_T*>(smem_ptr);

// Copy PQ table
if constexpr (std::is_same<CODE_BOOK_T, half>::value) {
for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_CODE_BOOK_DIM;
i += blockDim.x * 2) {
for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) {
half2 buf2;
buf2.x = pq_code_book_ptr[i];
buf2.y = pq_code_book_ptr[i + 1];
(reinterpret_cast<half2*>(smem_pq_code_book_ptr + i))[0] = buf2;
}
} else {
for (unsigned i = threadIdx.x; i < (1 << PQ_BITS) * PQ_CODE_BOOK_DIM; i += blockDim.x) {
for (unsigned i = threadIdx.x; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x) {
// TODO: vectorize
smem_pq_code_book_ptr[i] = pq_code_book_ptr[i];
}
Expand All @@ -93,6 +88,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
{
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T node_id,
const bool valid) const
Expand All @@ -104,9 +100,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
encoded_dataset_ptr + (static_cast<std::uint64_t>(encoded_dataset_dim) * node_id)));
if (PQ_BITS == 8) {
for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) {
constexpr unsigned vlen = 4; // **** DO NOT CHANGE ****
constexpr unsigned nelem = raft::div_rounding_up_unsafe<unsigned>(
DATASET_BLOCK_DIM / PQ_CODE_BOOK_DIM, TEAM_SIZE * vlen);
constexpr unsigned vlen = 4; // **** DO NOT CHANGE ****
constexpr unsigned nelem =
raft::div_rounding_up_unsafe<unsigned>(DATASET_BLOCK_DIM / PQ_LEN, TEAM_SIZE * vlen);
// Loading PQ codes
uint32_t pq_codes[nelem];
#pragma unroll
Expand All @@ -119,18 +115,18 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
4 + k));
}
//
if constexpr ((std::is_same<CODE_BOOK_T, half>::value) && (PQ_CODE_BOOK_DIM % 2 == 0)) {
if constexpr ((std::is_same<CODE_BOOK_T, half>::value) && (PQ_LEN % 2 == 0)) {
// **** Use half2 for distance computation ****
half2 norm2{0, 0};
#pragma unroll
for (std::uint32_t e = 0; e < nelem; e++) {
const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= n_subspace) break;
// Loading VQ code-book
raft::TxN_t<half2, vlen / 2> vq_vals[PQ_CODE_BOOK_DIM];
raft::TxN_t<half2, vlen / 2> vq_vals[PQ_LEN];
#pragma unroll
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m += 1) {
const uint32_t d = (vlen * m) + (PQ_CODE_BOOK_DIM * k) + elem_offset;
for (std::uint32_t m = 0; m < PQ_LEN; m += 1) {
const uint32_t d = (vlen * m) + (PQ_LEN * k) + elem_offset;
if (d >= dim) break;
vq_vals[m].load(
reinterpret_cast<const half2*>(vq_code_book_ptr + d + (dim * vq_code)), 0);
Expand All @@ -139,11 +135,11 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
std::uint32_t pq_code = pq_codes[e];
#pragma unroll
for (std::uint32_t v = 0; v < vlen; v++) {
if (PQ_CODE_BOOK_DIM * (v + k) >= dim) break;
if (PQ_LEN * (v + k) >= dim) break;
#pragma unroll
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m += 2) {
const std::uint32_t d1 = m + (PQ_CODE_BOOK_DIM * v);
const std::uint32_t d = d1 + (PQ_CODE_BOOK_DIM * k);
for (std::uint32_t m = 0; m < PQ_LEN; m += 2) {
const std::uint32_t d1 = m + (PQ_LEN * v);
const std::uint32_t d = d1 + (PQ_LEN * k);
// Loading query vector in smem
half2 diff2 = (reinterpret_cast<const half2*>(
query_ptr))[device::swizzling<std::uint32_t, DATASET_BLOCK_DIM / 2>(d / 2)];
Expand All @@ -164,10 +160,10 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
if (k >= n_subspace) break;
// Loading VQ code-book
raft::TxN_t<CODE_BOOK_T, vlen> vq_vals[PQ_CODE_BOOK_DIM];
raft::TxN_t<CODE_BOOK_T, vlen> vq_vals[PQ_LEN];
#pragma unroll
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m++) {
const std::uint32_t d = (vlen * m) + (PQ_CODE_BOOK_DIM * k) + elem_offset;
for (std::uint32_t m = 0; m < PQ_LEN; m++) {
const std::uint32_t d = (vlen * m) + (PQ_LEN * k) + elem_offset;
if (d >= dim) break;
// Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device
// memory)
Expand All @@ -178,15 +174,15 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
std::uint32_t pq_code = pq_codes[e];
#pragma unroll
for (std::uint32_t v = 0; v < vlen; v++) {
if (PQ_CODE_BOOK_DIM * (v + k) >= dim) break;
raft::TxN_t<CODE_BOOK_T, PQ_CODE_BOOK_DIM> pq_vals;
pq_vals.load(reinterpret_cast<const half2*>(smem_pq_code_book_ptr +
PQ_CODE_BOOK_DIM * (pq_code & 0xff)),
0); // (from L1$ or smem)
if (PQ_LEN * (v + k) >= dim) break;
raft::TxN_t<CODE_BOOK_T, PQ_LEN> pq_vals;
pq_vals.load(
reinterpret_cast<const half2*>(smem_pq_code_book_ptr + PQ_LEN * (pq_code & 0xff)),
0); // (from L1$ or smem)
#pragma unroll
for (std::uint32_t m = 0; m < PQ_CODE_BOOK_DIM; m++) {
const std::uint32_t d1 = m + (PQ_CODE_BOOK_DIM * v);
const std::uint32_t d = d1 + (PQ_CODE_BOOK_DIM * k);
for (std::uint32_t m = 0; m < PQ_LEN; m++) {
const std::uint32_t d1 = m + (PQ_LEN * v);
const std::uint32_t d = d1 + (PQ_LEN * k);
// if (d >= dataset_dim) break;
DISTANCE_T diff = query_ptr[d]; // (from smem)
diff -= pq_scale * static_cast<float>(pq_vals.data[m]);
Expand All @@ -207,48 +203,4 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
}
};

template <std::uint32_t DATASET_BLOCK_DIM_OUT,
std::uint32_t TEAM_SIZE_OUT,
std::uint32_t DATASET_BLOCK_DIM_IN,
std::uint32_t TEAM_SIZE_IN,
class DATA_T,
class INDEX_T,
class DISTANCE_T,
class CODE_BOOK_T,
unsigned PQ_BITS,
unsigned PQ_CODE_BOOK_DIM>
cagra_q_dataset_descriptor_t<DATA_T,
CODE_BOOK_T,
PQ_BITS,
PQ_CODE_BOOK_DIM,
DATASET_BLOCK_DIM_OUT,
DISTANCE_T,
INDEX_T,
TEAM_SIZE_OUT>
set_compute_template_params(cagra_q_dataset_descriptor_t<DATA_T,
CODE_BOOK_T,
PQ_BITS,
PQ_CODE_BOOK_DIM,
DATASET_BLOCK_DIM_IN,
DISTANCE_T,
INDEX_T,
TEAM_SIZE_IN>& desc_in)
{
return cagra_q_dataset_descriptor_t<DATA_T,
CODE_BOOK_T,
PQ_BITS,
PQ_CODE_BOOK_DIM,
DATASET_BLOCK_DIM_OUT,
DISTANCE_T,
INDEX_T,
TEAM_SIZE_OUT>(desc_in.encoded_dataset_ptr,
desc_in.encoded_dataset_dim,
desc_in.n_subspace,
desc_in.vq_code_book_ptr,
desc_in.vq_scale,
desc_in.pq_code_book_ptr,
desc_in.pq_scale,
desc_in.size,
desc_in.dim);
}
} // namespace raft::neighbors::cagra::detail
Loading

0 comments on commit 16ddb13

Please sign in to comment.