Skip to content

Commit

Permalink
Turn TEST_[HALF|BF]_T into function-style macros and fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber authored Feb 3, 2025
1 parent 5088b64 commit 660b6a8
Show file tree
Hide file tree
Showing 16 changed files with 99 additions and 84 deletions.
16 changes: 8 additions & 8 deletions c2h/generators.cu
Original file line number Diff line number Diff line change
Expand Up @@ -478,15 +478,15 @@ template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, float* out, std::size_t element_size);
template void init_key_segments(
const c2h::device_vector<std::uint32_t>& segment_offsets, custom_type_state_t* out, std::size_t element_size);
#ifdef _CCCL_HAS_NVFP16
#if TEST_HALF_T()
template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, half_t* out, std::size_t element_size);
#endif // _CCCL_HAS_NVFP16
#endif // TEST_HALF_T()

#ifdef _CCCL_HAS_NVBF16
#if TEST_BF_T()
template void
init_key_segments(const c2h::device_vector<std::uint32_t>& segment_offsets, bfloat16_t* out, std::size_t element_size);
#endif // _CCCL_HAS_NVBF16
#endif // TEST_BF_T()
} // namespace detail

template <typename T>
Expand Down Expand Up @@ -552,15 +552,15 @@ INSTANTIATE(double);
INSTANTIATE(bool);
INSTANTIATE(char);

#ifdef _CCCL_HAS_NVFP16
#if TEST_HALF_T()
INSTANTIATE(half_t);
INSTANTIATE(__half);
#endif // _CCCL_HAS_NVFP16
#endif // TEST_HALF_T()

#ifdef _CCCL_HAS_NVBF16
#if TEST_BF_T()
INSTANTIATE(bfloat16_t);
INSTANTIATE(__nv_bfloat16);
#endif // _CCCL_HAS_NVBF16
#endif // TEST_BF_T()

#undef INSTANTIATE_RND
#undef INSTANTIATE_MOD
Expand Down
26 changes: 17 additions & 9 deletions c2h/include/c2h/extended_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,29 @@
#include <cuda/__cccl_config>

#ifndef TEST_HALF_T
# define TEST_HALF_T _CCCL_HAS_NVFP16
#endif
# if defined(_CCCL_HAS_NVFP16)
# define TEST_HALF_T() 1
# else // defined(_CCCL_HAS_NVFP16)
# define TEST_HALF_T() 0
# endif // defined(_CCCL_HAS_NVFP16)
#endif // TEST_HALF_T

#ifndef TEST_BF_T
# define TEST_BF_T _CCCL_HAS_NVBF16
#endif

#ifdef TEST_HALF_T
# if defined(_CCCL_HAS_NVBF16)
# define TEST_BF_T() 1
# else // defined(_CCCL_HAS_NVBF16)
# define TEST_BF_T() 0
# endif // defined(_CCCL_HAS_NVBF16)
#endif // TEST_BF_T

#if TEST_HALF_T()
# include <cuda_fp16.h>

# include <c2h/half.cuh>
#endif
#endif // TEST_HALF_T()

#ifdef TEST_BF_T
#if TEST_BF_T()
# include <cuda_bf16.h>

# include <c2h/bfloat16.cuh>
#endif
#endif // TEST_BF_T()
8 changes: 4 additions & 4 deletions cub/test/catch2_segmented_sort_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -265,21 +265,21 @@ struct unwrap_value_t_impl
using type = T;
};

#if TEST_HALF_T
#if TEST_HALF_T()
template <>
struct unwrap_value_t_impl<half_t>
{
using type = __half;
};
#endif
#endif // TEST_HALF_T()

#if TEST_BF_T
#if TEST_BF_T()
template <>
struct unwrap_value_t_impl<bfloat16_t>
{
using type = __nv_bfloat16;
};
#endif
#endif // TEST_BF_T()

template <typename T>
using unwrap_value_t = typename unwrap_value_t_impl<T>::type;
Expand Down
8 changes: 4 additions & 4 deletions cub/test/catch2_test_device_histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ auto cast_if_half_pointer(T* p) -> T*
return p;
}

#if TEST_HALF_T
#if TEST_HALF_T()
auto cast_if_half_pointer(half_t* p) -> __half*
{
return reinterpret_cast<__half*>(p);
Expand All @@ -79,7 +79,7 @@ auto cast_if_half_pointer(const half_t* p) -> const __half*
{
return reinterpret_cast<const __half*>(p);
}
#endif
#endif // TEST_HALF_T()

template <typename T>
using caller_vector = c2h::
Expand Down Expand Up @@ -412,9 +412,9 @@ using types =
std::uint32_t,
std::int64_t,
std::uint64_t,
#if TEST_HALF_T
#if TEST_HALF_T()
half_t,
#endif
#endif // TEST_HALF_T()
float,
double>;

Expand Down
9 changes: 5 additions & 4 deletions cub/test/catch2_test_device_radix_sort_keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "catch2_radix_sort_helper.cuh"
#include "catch2_test_launch_helper.h"
#include <c2h/catch2_test_helper.h>
#include <c2h/extended_types.h>

// %PARAM% TEST_LAUNCH lid 0:1:2

Expand All @@ -70,12 +71,12 @@ using bit_window_key_types = c2h::type_list<cuda::std::uint8_t, cuda::std::int8_
using key_types = c2h::type_list<
cuda::std::uint16_t
, cuda::std::int16_t
#ifdef TEST_HALF_T
#if TEST_HALF_T()
, half_t
#endif
#ifdef TEST_BF_T
#endif // TEST_HALF_T()
#if TEST_BF_T()
, bfloat16_t
#endif
#endif // TEST_BF_T()
>;
// clang-format on
using bit_window_key_types = c2h::type_list<cuda::std::uint16_t, cuda::std::int16_t>;
Expand Down
13 changes: 6 additions & 7 deletions cub/test/catch2_test_device_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,13 @@ using full_type_list = c2h::type_list<type_pair<uchar3>, type_pair<ulonglong4>>;
// clang-format off
using full_type_list = c2h::type_list<
type_pair<custom_t>
#if TEST_HALF_T
, type_pair<half_t> // testing half
#endif
#if TEST_BF_T
, type_pair<bfloat16_t> // testing bf16

#if TEST_HALF_T()
, type_pair<half_t>
#endif // TEST_HALF_T()
#if TEST_BF_T()
, type_pair<bfloat16_t>
#endif // TEST_BF_T()
>;
#endif
// clang-format on
#elif TEST_TYPES == 4
// DPX SIMD instructions
Expand Down
25 changes: 15 additions & 10 deletions cub/test/catch2_test_device_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
#include <c2h/test_util_vec.h>
#include <nv/target>

#if TEST_HALF_T
#if TEST_HALF_T()
// Half support is provided by SM53+. We currently test against a few older architectures.
// The specializations below can be removed once we drop these architectures.

Expand Down Expand Up @@ -107,7 +107,12 @@ __host__ __device__ __forceinline__ //

return a;
}
#endif // TEST_HALF_T

CUB_NAMESPACE_END

#endif // TEST_HALF_T()

CUB_NAMESPACE_BEGIN

/**
* @brief Introduces the required NumericTraits for `c2h::custom_type_t`.
Expand Down Expand Up @@ -173,21 +178,21 @@ struct ExtendedFloatSum
return result;
}

#if TEST_HALF_T
#if TEST_HALF_T()
__host__ __device__ __half operator()(__half a, __half b) const
{
uint16_t result = this->operator()(half_t{a}, half_t(b)).raw();
return reinterpret_cast<__half&>(result);
}
#endif
#endif // TEST_HALF_T()

#if TEST_BF_T
#if TEST_BF_T()
__device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) const
{
uint16_t result = this->operator()(bfloat16_t{a}, bfloat16_t(b)).raw();
return reinterpret_cast<__nv_bfloat16&>(result);
}
#endif
#endif // TEST_BF_T()
};

template <class It>
Expand All @@ -196,7 +201,7 @@ inline It unwrap_it(It it)
return it;
}

#if TEST_HALF_T
#if TEST_HALF_T()
inline __half* unwrap_it(half_t* it)
{
return reinterpret_cast<__half*>(it);
Expand All @@ -209,9 +214,9 @@ inline thrust::constant_iterator<__half, OffsetT> unwrap_it(thrust::constant_ite
__half val = wrapped_val.operator __half();
return thrust::constant_iterator<__half, OffsetT>(val);
}
#endif
#endif // TEST_HALF_T()

#if TEST_BF_T
#if TEST_BF_T()
inline __nv_bfloat16* unwrap_it(bfloat16_t* it)
{
return reinterpret_cast<__nv_bfloat16*>(it);
Expand All @@ -224,7 +229,7 @@ thrust::constant_iterator<__nv_bfloat16, OffsetT> inline unwrap_it(thrust::const
__nv_bfloat16 val = wrapped_val.operator __nv_bfloat16();
return thrust::constant_iterator<__nv_bfloat16, OffsetT>(val);
}
#endif
#endif // TEST_BF_T()

template <typename T>
using unwrap_value_t = typename std::remove_reference<decltype(*unwrap_it(std::declval<T*>()))>::type;
Expand Down
8 changes: 4 additions & 4 deletions cub/test/catch2_test_device_reduce_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ using full_type_list = c2h::type_list<type_triple<uchar3, uchar3, custom_t>, typ
// clang-format off
using full_type_list = c2h::type_list<
type_triple<custom_t>
#if TEST_HALF_T
#if TEST_HALF_T()
, type_triple<half_t> // testing half
#endif
#if TEST_BF_T
#endif // TEST_HALF_T()
#if TEST_BF_T()
, type_triple<bfloat16_t> // testing bf16
#endif
#endif // TEST_BF_T()
>;
// clang-format on
#endif
Expand Down
8 changes: 4 additions & 4 deletions cub/test/catch2_test_device_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ using full_type_list = c2h::type_list<type_pair<uchar3>, type_pair<ulonglong4>>;
// clang-format off
using full_type_list = c2h::type_list<
type_pair<custom_t>
#if TEST_HALF_T
#if TEST_HALF_T()
, type_pair<half_t> // testing half
#endif
#if TEST_BF_T
#endif // TEST_HALF_T()
#if TEST_BF_T()
, type_pair<bfloat16_t> // testing bf16
#endif
#endif // TEST_BF_T()
>;
// clang-format on
#endif
Expand Down
8 changes: 4 additions & 4 deletions cub/test/catch2_test_device_scan_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ using full_type_list =
// clang-format off
using full_type_list = c2h::type_list<
type_quad<custom_t, custom_t, custom_t>
#if TEST_HALF_T
#if TEST_HALF_T()
, type_quad<half_t> // testing half
#endif
#if TEST_BF_T
#endif // TEST_HALF_T()
#if TEST_BF_T()
, type_quad<bfloat16_t> // testing bf16
#endif
#endif // TEST_BF_T()
>;
// clang-format on
#endif
Expand Down
9 changes: 5 additions & 4 deletions cub/test/catch2_test_device_segmented_radix_sort_keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "catch2_radix_sort_helper.cuh"
#include "catch2_test_launch_helper.h"
#include <c2h/catch2_test_helper.h>
#include <c2h/extended_types.h>

// TODO replace with DeviceSegmentedRadixSort::SortKeys interface once https://github.com/NVIDIA/cccl/issues/50 is
// addressed Temporary wrapper that allows specializing the DeviceSegmentedRadixSort algorithm for different offset
Expand Down Expand Up @@ -120,12 +121,12 @@ using bit_window_key_types = c2h::type_list<cuda::std::uint8_t, cuda::std::int8_
using key_types = c2h::type_list<
cuda::std::uint16_t
, cuda::std::int16_t
#ifdef TEST_HALF_T
#if TEST_HALF_T()
, half_t
#endif
#ifdef TEST_BF_T
#endif // TEST_HALF_T()
#if TEST_BF_T()
, bfloat16_t
#endif
#endif // TEST_BF_T()
>;
// clang-format on
using bit_window_key_types = c2h::type_list<cuda::std::uint16_t, cuda::std::int16_t>;
Expand Down
8 changes: 4 additions & 4 deletions cub/test/catch2_test_device_segmented_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ using full_type_list = c2h::type_list<type_pair<uchar3>, type_pair<ulonglong4>>;
// clang-format off
using full_type_list = c2h::type_list<
type_pair<custom_t>
#if TEST_HALF_T
#if TEST_HALF_T()
, type_pair<half_t> // testing half
#endif
#if TEST_BF_T
#endif // TEST_HALF_T()
#if TEST_BF_T()
, type_pair<bfloat16_t> // testing bf16
#endif
#endif // TEST_BF_T()
>;
// clang-format on
#endif
Expand Down
9 changes: 5 additions & 4 deletions cub/test/catch2_test_device_segmented_sort_keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "insert_nested_NVTX_range_guard.h"
// above header needs to be included first
#include <cub/device/device_segmented_sort.cuh>
#include <cub/util_type.cuh>

#include "catch2_radix_sort_helper.cuh"
#include "catch2_segmented_sort_helper.cuh"
Expand All @@ -43,14 +44,14 @@ using key_types =
c2h::type_list<bool,
std::uint8_t,
std::uint64_t
#if TEST_HALF_T
#if TEST_HALF_T()
,
half_t
#endif
#if TEST_BF_T
#endif // TEST_HALF_T()
#if TEST_BF_T()
,
bfloat16_t
#endif
#endif // TEST_BF_T()
>;

C2H_TEST("DeviceSegmentedSortKeys: No segments", "[keys][segmented][sort][device]")
Expand Down
8 changes: 4 additions & 4 deletions cub/test/catch2_test_device_segmented_sort_pairs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ using pair_types =
c2h::type_list<c2h::type_list<bool, std::uint8_t>,
c2h::type_list<std::int8_t, std::uint64_t>,
c2h::type_list<double, float>
#if TEST_HALF_T
#if TEST_HALF_T()
,
c2h::type_list<half_t, std::int8_t>
#endif
#if TEST_BF_T
#endif // TEST_HALF_T()
#if TEST_BF_T()
,
c2h::type_list<bfloat16_t, float>
#endif
#endif // TEST_BF_T()
>;

C2H_TEST("DeviceSegmentedSortPairs: No segments", "[pairs][segmented][sort][device]")
Expand Down
Loading

0 comments on commit 660b6a8

Please sign in to comment.