Skip to content

Commit

Permalink
Fix issues in PVC example
Browse files: browse the repository at this point in the history
  • Loading branch information
aacostadiaz committed May 20, 2024
1 parent ce10e06 commit 1b95768
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
19 changes: 11 additions & 8 deletions examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,8 @@
template <typename T>
static void fill_matrix(std::vector<T> &M)
{
std::random_device dev;
std::mt19937 rng(dev());
std::uniform_real_distribution<float> dist(1.0, 2.0);
std::generate(std::begin(M), std::end(M), [&]
{ return static_cast<T>(dist(rng)); });
{ return static_cast<T>( 2*(rand() / double(RAND_MAX)) - 1 ); });
}

template <typename T>
Expand Down Expand Up @@ -208,7 +205,12 @@ struct ExampleRunner {

// Check if output from CUTLASS kernel and reference kernel are relatively equal or not
// need to set a larger error margin for comparison to succeed
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), 0.5f, 0.5f);
auto epsilon = static_cast<ElementOutput>(0.1f);
auto nonzero_floor = static_cast<ElementOutput>(0.1f);

bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(
block_ref_D.get(), block_D.get(), block_D.size(),
epsilon, nonzero_floor);

return passed;
}
Expand All @@ -219,7 +221,7 @@ struct ExampleRunner {
auto [M, N, K, L] = problem_shape_MNKL;

stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));

Expand Down Expand Up @@ -279,7 +281,7 @@ struct ExampleRunner {

// Verify that the result is correct
bool passed = verify(problem_size, options.alpha, options.beta);
std::cout << "PVC GEMM Example : " << (passed ? "Passed" : "Failed") << std::endl;
std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;

if (passed && options.iterations > 0) {
GPU_Clock timer;
Expand All @@ -291,7 +293,8 @@ struct ExampleRunner {

float cute_time = timer.seconds() / options.iterations;
double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
printf("PVC GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000);
std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000);
}

return;
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/relatively_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ CUTLASS_HOST_DEVICE
bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) {

#if defined (CUTLASS_ENABLE_SYCL)
using cutlass::abs;
using sycl::fabs;
#elif defined(__CUDACC_RTC__)
using cuda::std::abs;
#else
Expand Down

0 comments on commit 1b95768

Please sign in to comment.