Skip to content

Commit

Permalink
Refactor radix_sort tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Feb 3, 2025
1 parent 9b461f4 commit e2d81be
Showing 1 changed file with 11 additions and 126 deletions.
137 changes: 11 additions & 126 deletions cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,8 @@ struct policy_hub
SEGMENTED_RADIX_BITS - 1>;
};

/// SM90
struct Policy900 : ChainedPolicy<900, Policy900, Policy800>
template <typename OnesweepSmallKeyPolicySizes>
struct OnesweepSmallKeyTunedPolicy
{
static constexpr bool ONESWEEP = true;
static constexpr int ONESWEEP_RADIX_BITS = 8;
Expand Down Expand Up @@ -770,9 +770,6 @@ struct policy_hub

using OnesweepLargeKeyPolicy = ::cuda::std::_If<sizeof(KeyT) == 4, OnesweepPolicyKey32, OnesweepPolicyKey64>;

using OnesweepSmallKeyPolicySizes =
sm90_small_key_tuning<sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>;

using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy<
OnesweepSmallKeyPolicySizes::threads,
OnesweepSmallKeyPolicySizes::items,
Expand Down Expand Up @@ -854,128 +851,16 @@ struct policy_hub
SEGMENTED_RADIX_BITS - 1>;
};

// TODO(@gonidelis): refactor this so that it does not duplicate SM90.
struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900>
{
static constexpr bool ONESWEEP = true;
static constexpr int ONESWEEP_RADIX_BITS = 8;

using HistogramPolicy = AgentRadixSortHistogramPolicy<128, 16, 1, KeyT, ONESWEEP_RADIX_BITS>;
using ExclusiveSumPolicy = AgentRadixSortExclusiveSumPolicy<256, ONESWEEP_RADIX_BITS>;

private:
static constexpr int PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5;
static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5;
static constexpr int SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5;
static constexpr int OFFSET_64BIT = sizeof(OffsetT) == 8 ? 1 : 0;
static constexpr int FLOAT_KEYS = ::cuda::std::is_same<KeyT, float>::value ? 1 : 0;

using OnesweepPolicyKey32 = AgentRadixSortOnesweepPolicy<
384,
KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS
: (sizeof(ValueT) < 8 ? (OFFSET_64BIT ? 17 : 23) : (OFFSET_64BIT ? 29 : 30)),
DominantT,
1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT,
ONESWEEP_RADIX_BITS>;

using OnesweepPolicyKey64 = AgentRadixSortOnesweepPolicy<
384,
sizeof(ValueT) < 8 ? 30 : 24,
DominantT,
1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT,
ONESWEEP_RADIX_BITS>;

using OnesweepLargeKeyPolicy = ::cuda::std::_If<sizeof(KeyT) == 4, OnesweepPolicyKey32, OnesweepPolicyKey64>;

using OnesweepSmallKeyPolicySizes =
sm100_small_key_tuning<ValueT, sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>;

using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy<
OnesweepSmallKeyPolicySizes::threads,
OnesweepSmallKeyPolicySizes::items,
DominantT,
1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY,
BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT,
8>;

public:
using OnesweepPolicy = ::cuda::std::_If<sizeof(KeyT) < 4, OnesweepSmallKeyPolicy, OnesweepLargeKeyPolicy>;

// The Scan, Downsweep and Upsweep policies are never run on SM90, but we have to include them to prevent a
// compilation error: When we compile e.g. for SM70 **and** SM90, the host compiler will reach calls to those
// kernels, and instantiate them for MaxPolicy (which is Policy900) on the host, which will reach into the policies
// below to set the launch bounds. The device compiler pass will also compile all kernels for SM70 **and** SM90,
// even though only the Onesweep kernel is used on SM90.
using ScanPolicy =
AgentScanPolicy<512,
23,
OffsetT,
BLOCK_LOAD_WARP_TRANSPOSE,
LOAD_DEFAULT,
BLOCK_STORE_WARP_TRANSPOSE,
BLOCK_SCAN_RAKING_MEMOIZE>;

using DownsweepPolicy = AgentRadixSortDownsweepPolicy<
512,
23,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MATCH,
BLOCK_SCAN_WARP_SCANS,
PRIMARY_RADIX_BITS>;

using AltDownsweepPolicy = AgentRadixSortDownsweepPolicy<
(sizeof(KeyT) > 1) ? 256 : 128,
47,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
PRIMARY_RADIX_BITS - 1>;
struct Policy900
: ChainedPolicy<900, Policy900, Policy800>
, OnesweepSmallKeyTunedPolicy<sm90_small_key_tuning<sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>>
{};

using UpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 23, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS>;
using AltUpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 47, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1>;

using SingleTilePolicy = AgentRadixSortDownsweepPolicy<
256,
19,
DominantT,
BLOCK_LOAD_DIRECT,
LOAD_LDG,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
SINGLE_TILE_RADIX_BITS>;

using SegmentedPolicy = AgentRadixSortDownsweepPolicy<
192,
39,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
SEGMENTED_RADIX_BITS>;

using AltSegmentedPolicy = AgentRadixSortDownsweepPolicy<
384,
11,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
SEGMENTED_RADIX_BITS - 1>;
};
struct Policy1000
: ChainedPolicy<1000, Policy1000, Policy900>
, OnesweepSmallKeyTunedPolicy<
sm100_small_key_tuning<ValueT, sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>>
{};

using MaxPolicy = Policy1000;
};
Expand Down

0 comments on commit e2d81be

Please sign in to comment.