Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functor APIs #210

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open

Functor APIs #210

wants to merge 4 commits into from

Conversation

jiashuy
Copy link
Collaborator

@jiashuy jiashuy commented Feb 28, 2025

User story

When developing the erase_if and export_batch_if APIs, HKV supported customized predicate functors as input to check whether to erase/export or not.
From the perspective of HKV, value is an array of type V, it only stores the embedding, therefore the interface of predicate functors ignores the value.

As HKV’s user wants to erase/export the table after evaluating the value, we can also see similar usage in std::erase_if(std::map).
So we decided to develop new APIs export_batch_if_v2 and erase_if_v2 to support this feature.

And we think it's more general for HKV to evaluate the whole item [key, score, value] than [key, score].
So we will keep supporting export_batch_if and erase_if for a short term, but will deprecate it in the future.

Design

User predicate functor

The user needs to provide a functor whose template parameters and input and output need to be aligned with the following code.

Please note that the device functor assumes that each thread deals with a KV-pair.
The GroupSize is used when users want to evaluate the value using multi-threads, and HKV supports using a cooperative group to deal with cooperatively. However, the GroupSize is configured by HKV, so users don’t need to instantiate the device function. Instantiating the struct is enough.

Provide two use cases here:

Use case 1: evaluate key, score and partial value.

namespace cg = cooperative_groups;
template <class K, class V, class S>
struct PredFunctor {
  K pattern;
  S threshold;
  template<int GroupSize>
  __forceinline__ __device__ bool operator()(const K& key,
                                             const V* value,
                                             const S& score,
                                             const cg::tiled_partition<GroupSize>& g) {
    /* evaluate key, score and value. */
    return (key & 0x1 == pattern) && (score < threshold) && (value[2] < 1.0f);
  }
};

When users want to evaluate the value using more than one thread, param g comes in handy.
Use case 2: evaluate the whole value, if there exists item is not 0, then return true.

namespace cg = cooperative_groups;

template <class K, class V, class S>
struct PredFunctor {
  int dim;
  template<int GroupSize>
  __forceinline__ __device__ bool operator()(const K& key,
                                             const V* value,
                                             const S& score,
                                             const cg::tiled_partition<GroupSize>& g) {
    bool pred = false;
    for (int i = 0; i < g.size(); i ++) {
      auto cur_value = g.shfl(value, i);
      bool cur_pred = false;
      unsigned int vote = 0;
      /* evaluate one value cooperatively in one loop. */
      for (int j = g.thread_rank(); j < dim; j += g.size()) {
        if (cur_value[j] != 0) cur_pred = true;
        vote = g.ballot(cur_pred);
        if (vote != 0) break;
      }
      if (g.thread_rank() == i && vote != 0) pred = true;
    }
    return pred;
  }
};

APIs

  /**
   * @brief Exports a certain number of the key-value-score tuples which match @tparam PredFunctor A functor with template <K, V, S, int> defined an operator with signature:  __device__ (bool*)(const K&, const V*, const S&, const V*).
   *
   * @param n The maximum number of exported pairs.
   * @param offset The position of the key to remove.
   * @d_counter The number of elements dumped which is on device.
   * @param keys The keys to dump from GPU-accessible memory with shape (n).
   * @param values The values to dump from GPU-accessible memory with shape (n, DIM).
   * @param scores The scores to search on GPU-accessible memory with shape (n).
   * @parblock
   * If @p scores is `nullptr`, the score for each key will not be returned.
   * @endparblock
   *
   * @param stream The CUDA stream that is used to execute the operation.
   *
   * @return void
   *
   * @throw CudaException If the key-value size is too large for GPU shared memory. Reducing the value for @p n is currently required if this exception occurs.
   */

template <typename PredFunctor>
void export_batch_if_v2(PredFunctor& pred,
                        size_type n, const size_type offset,
                        size_type* d_counter,
                        key_type* keys,                // (n)
                        value_type* values,            // (n, DIM)
                        score_type* scores = nullptr,  // (n)
                        cudaStream_t stream = 0);



/**
 * @brief Erase the key-value-score tuples which match @tparam PredFunctor A functor with template <K, V, S, int> defined an operator with signature:  __device__ (bool*)(const K&, const V*, const S&, const V*).
* @param stream The CUDA stream that is used to execute the operation.
*
* @return The number of elements removed.
*/

template <typename PredFunctor>
size_type erase_if_v2(PredFunctor& pred, cudaStream_t stream = 0);

Copy link

@jiashuy jiashuy requested a review from shijieliu February 28, 2025 02:19
@jiashuy
Copy link
Collaborator Author

jiashuy commented Feb 28, 2025

/blossom-ci


// Using for_each API to simulate export_batch_if_v2 API.
template <class K, class V, class S>
struct ForEachExecutionFuncV4 {
Copy link
Collaborator Author

@jiashuy jiashuy Mar 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can also be used to count the matched key-values in the whole table.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant