Skip to content

Commit

Permalink
Update CUB to latest master from https://github.com/NVlabs/cub
Browse files Browse the repository at this point in the history
  • Loading branch information
Spudz76 committed Oct 4, 2021
1 parent 2809c47 commit c753391
Show file tree
Hide file tree
Showing 25 changed files with 1,575 additions and 411 deletions.
102 changes: 47 additions & 55 deletions src/3rdparty/cub/agent/agent_radix_sort_downsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "../block/block_store.cuh"
#include "../block/block_radix_rank.cuh"
#include "../block/block_exchange.cuh"
#include "../block/radix_rank_sort_operations.cuh"
#include "../config.cuh"
#include "../util_type.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
Expand All @@ -56,16 +57,6 @@ namespace cub {
* Tuning policy types
******************************************************************************/

/**
* Radix ranking algorithm
*/
enum RadixRankAlgorithm
{
RADIX_RANK_BASIC,
RADIX_RANK_MEMOIZE,
RADIX_RANK_MATCH
};

/**
* Parameterizable tuning policy type for AgentRadixSortDownsweep
*/
Expand Down Expand Up @@ -137,6 +128,9 @@ struct AgentRadixSortDownsweep

RADIX_DIGITS = 1 << RADIX_BITS,
KEYS_ONLY = Equals<ValueT, NullType>::VALUE,
LOAD_WARP_STRIPED = RANK_ALGORITHM == RADIX_RANK_MATCH ||
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY ||
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR,
};

// Input iterator wrapper type (for applying cache modifiers)
Expand All @@ -148,10 +142,22 @@ struct AgentRadixSortDownsweep
BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, false, SCAN_ALGORITHM>,
typename If<(RANK_ALGORITHM == RADIX_RANK_MEMOIZE),
BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, true, SCAN_ALGORITHM>,
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>
typename If<(RANK_ALGORITHM == RADIX_RANK_MATCH),
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>,
typename If<(RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY),
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING,
SCAN_ALGORITHM, WARP_MATCH_ANY>,
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING,
SCAN_ALGORITHM, WARP_MATCH_ATOMIC_OR>
>::Type
>::Type
>::Type
>::Type BlockRadixRankT;

// Digit extractor type
typedef BFEDigitExtractor<KeyT> DigitExtractorT;


enum
{
/// Number of bin-starting offsets tracked per thread
Expand Down Expand Up @@ -184,11 +190,11 @@ struct AgentRadixSortDownsweep
typename BlockLoadValuesT::TempStorage load_values;
typename BlockRadixRankT::TempStorage radix_rank;

struct
struct KeysAndOffsets
{
UnsignedBits exchange_keys[TILE_ITEMS];
OffsetT relative_bin_offsets[RADIX_DIGITS];
};
} keys_and_offsets;

Uninitialized<ValueExchangeT> exchange_values;

Expand Down Expand Up @@ -216,11 +222,8 @@ struct AgentRadixSortDownsweep
// The global scatter base offset for each digit (valid in the first RADIX_DIGITS threads)
OffsetT bin_offset[BINS_TRACKED_PER_THREAD];

// The least-significant bit position of the current digit to extract
int current_bit;

// Number of bits in current digit
int num_bits;
// Digit extractor
DigitExtractorT digit_extractor;

// Whether to short-circuit
int short_circuit;
Expand All @@ -243,17 +246,17 @@ struct AgentRadixSortDownsweep
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
temp_storage.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM];
temp_storage.keys_and_offsets.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM];
}

CTA_SYNC();

#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
UnsignedBits key = temp_storage.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)];
UnsignedBits digit = BFE(key, current_bit, num_bits);
relative_bin_offsets[ITEM] = temp_storage.relative_bin_offsets[digit];
UnsignedBits key = temp_storage.keys_and_offsets.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)];
UnsignedBits digit = digit_extractor.Digit(key);
relative_bin_offsets[ITEM] = temp_storage.keys_and_offsets.relative_bin_offsets[digit];

// Un-twiddle
key = Traits<KeyT>::TwiddleOut(key);
Expand Down Expand Up @@ -303,16 +306,15 @@ struct AgentRadixSortDownsweep
}

/**
* Load a tile of keys (specialized for full tile, any ranking algorithm)
* Load a tile of keys (specialized for full tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<true> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
BlockLoadKeysT(temp_storage.load_keys).Load(
d_keys_in + block_offset, keys);
Expand All @@ -322,16 +324,15 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of keys (specialized for partial tile, any ranking algorithm)
* Load a tile of keys (specialized for partial tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<false> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -345,30 +346,29 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of keys (specialized for full tile, match ranking algorithm)
* Load a tile of keys (specialized for full tile, warp-striped load)
*/
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<true> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys);
}


/**
* Load a tile of keys (specialized for partial tile, match ranking algorithm)
* Load a tile of keys (specialized for partial tile, warp-striped load)
*/
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<false> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -377,17 +377,15 @@ struct AgentRadixSortDownsweep
LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys, valid_items, oob_item);
}


/**
* Load a tile of values (specialized for full tile, any ranking algorithm)
* Load a tile of values (specialized for full tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<true> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
BlockLoadValuesT(temp_storage.load_values).Load(
d_values_in + block_offset, values);
Expand All @@ -397,15 +395,14 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of values (specialized for partial tile, any ranking algorithm)
* Load a tile of values (specialized for partial tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<false> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -419,28 +416,27 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of items (specialized for full tile, match ranking algorithm)
* Load a tile of values (specialized for full tile, warp-striped load)
*/
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<true> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values);
}


/**
* Load a tile of items (specialized for partial tile, match ranking algorithm)
* Load a tile of values (specialized for partial tile, warp-striped load)
*/
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<false> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -449,7 +445,6 @@ struct AgentRadixSortDownsweep
LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values, valid_items);
}


/**
* Truck along associated values
*/
Expand All @@ -470,7 +465,7 @@ struct AgentRadixSortDownsweep
block_offset,
valid_items,
Int2Type<FULL_TILE>(),
Int2Type<RANK_ALGORITHM>());
Int2Type<LOAD_WARP_STRIPED>());

ScatterValues<FULL_TILE>(
values,
Expand Down Expand Up @@ -515,7 +510,7 @@ struct AgentRadixSortDownsweep
valid_items,
default_key,
Int2Type<FULL_TILE>(),
Int2Type<RANK_ALGORITHM>());
Int2Type<LOAD_WARP_STRIPED>());

// Twiddle key bits if necessary
#pragma unroll
Expand All @@ -529,8 +524,7 @@ struct AgentRadixSortDownsweep
BlockRadixRankT(temp_storage.radix_rank).RankKeys(
keys,
ranks,
current_bit,
num_bits,
digit_extractor,
exclusive_digit_prefix);

CTA_SYNC();
Expand Down Expand Up @@ -586,7 +580,7 @@ struct AgentRadixSortDownsweep
if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
{
bin_offset[track] -= exclusive_digit_prefix[track];
temp_storage.relative_bin_offsets[bin_idx] = bin_offset[track];
temp_storage.keys_and_offsets.relative_bin_offsets[bin_idx] = bin_offset[track];
bin_offset[track] += inclusive_digit_prefix[track];
}
}
Expand Down Expand Up @@ -677,8 +671,7 @@ struct AgentRadixSortDownsweep
d_values_in(d_values_in),
d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)),
d_values_out(d_values_out),
current_bit(current_bit),
num_bits(num_bits),
digit_extractor(current_bit, num_bits),
short_circuit(1)
{
#pragma unroll
Expand Down Expand Up @@ -717,8 +710,7 @@ struct AgentRadixSortDownsweep
d_values_in(d_values_in),
d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)),
d_values_out(d_values_out),
current_bit(current_bit),
num_bits(num_bits),
digit_extractor(current_bit, num_bits),
short_circuit(1)
{
#pragma unroll
Expand Down
19 changes: 9 additions & 10 deletions src/3rdparty/cub/agent/agent_radix_sort_upsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "../thread/thread_load.cuh"
#include "../warp/warp_reduce.cuh"
#include "../block/block_load.cuh"
#include "../block/radix_rank_sort_operations.cuh"
#include "../config.cuh"
#include "../util_type.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
Expand Down Expand Up @@ -121,7 +122,7 @@ struct AgentRadixSortUpsweep
PACKING_RATIO = sizeof(PackedCounter) / sizeof(DigitCounter),
LOG_PACKING_RATIO = Log2<PACKING_RATIO>::VALUE,

LOG_COUNTER_LANES = CUB_MAX(0, RADIX_BITS - LOG_PACKING_RATIO),
LOG_COUNTER_LANES = CUB_MAX(0, int(RADIX_BITS) - int(LOG_PACKING_RATIO)),
COUNTER_LANES = 1 << LOG_COUNTER_LANES,

// To prevent counter overflow, we must periodically unpack and aggregate the
Expand All @@ -139,6 +140,9 @@ struct AgentRadixSortUpsweep
// Input iterator wrapper type (for applying cache modifiers)
typedef CacheModifiedInputIterator<LOAD_MODIFIER, UnsignedBits, OffsetT> KeysItr;

// Digit extractor type
typedef BFEDigitExtractor<KeyT> DigitExtractorT;

/**
* Shared memory storage layout
*/
Expand Down Expand Up @@ -167,12 +171,8 @@ struct AgentRadixSortUpsweep
// Input and output device pointers
KeysItr d_keys_in;

// The least-significant bit position of the current digit to extract
int current_bit;

// Number of bits in current digit
int num_bits;

// Digit extractor
DigitExtractorT digit_extractor;


//---------------------------------------------------------------------
Expand Down Expand Up @@ -217,7 +217,7 @@ struct AgentRadixSortUpsweep
UnsignedBits converted_key = Traits<KeyT>::TwiddleIn(key);

// Extract current digit bits
UnsignedBits digit = BFE(converted_key, current_bit, num_bits);
UnsignedBits digit = digit_extractor.Digit(converted_key);

// Get sub-counter offset
UnsignedBits sub_counter = digit & (PACKING_RATIO - 1);
Expand Down Expand Up @@ -342,8 +342,7 @@ struct AgentRadixSortUpsweep
:
temp_storage(temp_storage.Alias()),
d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
current_bit(current_bit),
num_bits(num_bits)
digit_extractor(current_bit, num_bits)
{}


Expand Down
Loading

0 comments on commit c753391

Please sign in to comment.