Skip to content

Commit

Permalink
Renamed files to xe_*
Browse the repository at this point in the history
* Removed l2 workspace alignment
  • Loading branch information
muhammad-tanvir-1211 committed Sep 16, 2024
1 parent e08e740 commit 93d87fc
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 27 deletions.
7 changes: 3 additions & 4 deletions examples/sycl/pvc/pvc_gemm_streamk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
#include "cutlass/util/reference/device/tensor_compare.h"
#include "common.h"

#include "cutlass/gemm/kernel/intel_pvc_persistent_tile_scheduler_params_streamk.hpp"
#include "cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp"
using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -87,7 +87,6 @@ struct Options {
cmd.get_cmd_line_argument("n", n, 4096);
cmd.get_cmd_line_argument("k", k, 4096);
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("splits", splits, 16);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations, 100);
Expand Down Expand Up @@ -232,8 +231,8 @@ struct ExampleRunner {
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
hw_info,
{options.splits,
options.splitk ? cutlass::gemm::kernel::detail::PersistentTileSchedulerIntelPVCStreamKParams::DecompositionMode::SplitK :
cutlass::gemm::kernel::detail::PersistentTileSchedulerIntelPVCStreamKParams::DecompositionMode::StreamK}
options.splitk ? cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode::SplitK :
cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode::StreamK}
};

Gemm gemm_op;
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/gemm_universal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U

#if defined(SYCL_INTEL_TARGET)
#include "cutlass/gemm/kernel/intel_pvc_gemm.hpp"
#include "cutlass/gemm/kernel/intel_pvc_gemm_cooperative.hpp"
#include "cutlass/gemm/kernel/xe_gemm_cooperative.hpp"
#endif
////////////////////////////////////////////////////////////////////////////////
4 changes: 2 additions & 2 deletions include/cutlass/gemm/kernel/tile_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp"
#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp"
#if defined (SYCL_INTEL_TARGET)
#include "cutlass/gemm/kernel/intel_pvc_tile_scheduler_streamk.hpp"
#include "cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp"
#endif
////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -141,7 +141,7 @@ struct TileSchedulerSelector<
TileShape,
ClusterShape
> {
using Scheduler = PersistentTileSchedulerIntelPVCStreamK<TileShape>;
using Scheduler = PersistentTileSchedulerXeStreamK<TileShape>;
};
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ namespace kernel {
namespace detail {

////////////////////////////////////////////////////////////////////////////////
// Parameters for Intel PVC persistent stream-K scheduler
struct PersistentTileSchedulerIntelPVCStreamKParams {
// Parameters for Xe persistent stream-K scheduler
struct PersistentTileSchedulerXeStreamKParams {

// Strategies for computing reductions between work-groups computing portions of a given output tile
enum class ReductionMode {
Expand Down Expand Up @@ -88,7 +88,7 @@ struct PersistentTileSchedulerIntelPVCStreamKParams {
FastDivmodU64 divmod_blk_major_{};

// Divide up the number of stream-K tiles amongst G groups of stream-K units.
// Currently defaults to 1 since we don't create groups for PVC.
// Currently defaults to 1 since we don't create groups for Xe.
FastDivmodU64 divmod_sk_groups_{};

// Number of stream-K units in each group
Expand Down Expand Up @@ -464,7 +464,7 @@ struct PersistentTileSchedulerIntelPVCStreamKParams {
static size_t
get_barrier_workspace_size(uint64_t num_tiles, uint32_t barrier_bits) {
size_t workspace_bits = num_tiles * static_cast<size_t>(barrier_bits);
return round_up_to_l2_alignment(bits_to_bytes<size_t>(workspace_bits));
return bits_to_bytes<size_t>(workspace_bits);
}

// Calculates the size of the workspace needed for holding partial outputs from splits
Expand All @@ -473,7 +473,7 @@ struct PersistentTileSchedulerIntelPVCStreamKParams {
get_reduction_workspace_size(uint64_t num_tiles, GemmCoord tile_shape, uint32_t accumulator_bits, uint32_t num_accumulator_mtxs = 1) {
size_t output_tile_size = tile_shape.m() * tile_shape.n();
size_t workspace_bits = accumulator_bits * output_tile_size * num_tiles * num_accumulator_mtxs;
return round_up_to_l2_alignment(bits_to_bytes<size_t>(workspace_bits));
return bits_to_bytes<size_t>(workspace_bits);
}

static void
Expand Down Expand Up @@ -695,15 +695,6 @@ struct PersistentTileSchedulerIntelPVCStreamKParams {
sk_units_ = 0;
divmod_sk_units_per_group_ = FastDivmodU64(blocks_m * blocks_n * blocks_l);
}

private:
// Round up number of bytes to the nearest multiple of L2 cache line alignment
CUTLASS_HOST_DEVICE
static size_t
round_up_to_l2_alignment(size_t bytes) {
constexpr size_t L2CacheLineSizeBytes = 128u;
return (bytes + L2CacheLineSizeBytes - 1) / L2CacheLineSizeBytes * L2CacheLineSizeBytes;
}
};

////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/kernel/intel_pvc_persistent_tile_scheduler_params_streamk.hpp"
#include "cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp"

namespace cutlass::gemm::kernel::detail {

// Persistent Thread Block (TB) scheduler leveraging stream-K decomposition
template <
class TileShape
>
class PersistentTileSchedulerIntelPVCStreamK {
class PersistentTileSchedulerXeStreamK {
//
// Data members
//
Expand All @@ -59,7 +59,7 @@ class PersistentTileSchedulerIntelPVCStreamK {
// Use a dummy barrier manager to simply get the type used to store the barrier
using BarrierType = typename NamedBarrierManager<1>::T;

using Params = PersistentTileSchedulerIntelPVCStreamKParams;
using Params = PersistentTileSchedulerXeStreamKParams;
using ReductionMode = Params::ReductionMode;
using DecompositionMode = Params::DecompositionMode;

Expand Down Expand Up @@ -180,10 +180,10 @@ class PersistentTileSchedulerIntelPVCStreamK {
}

CUTLASS_HOST_DEVICE
PersistentTileSchedulerIntelPVCStreamK() { };
PersistentTileSchedulerXeStreamK() { };

CUTLASS_HOST_DEVICE
PersistentTileSchedulerIntelPVCStreamK(Params const& params_) : scheduler_params(params_) {
PersistentTileSchedulerXeStreamK(Params const& params_) : scheduler_params(params_) {
current_work_linear_idx_ = uint64_t(BlockIdxX());
}

Expand Down Expand Up @@ -324,7 +324,7 @@ template <int ThreadsPerBlock, class FrgTensorC>
int barrier_group_thread_idx = ThreadIdxX();

// Reductions use BlockStripedReduce with a width of BarrierManager::ThreadCount under the hood.
// Thus, the start of the reduction space is the same across all threads in a warp group.
// Thus, the start of the reduction space is the same across all threads in a work group.
int reduction_offset =
(cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * reduction_tile_idx * num_accumulator_mtxs) +
reduction_peer_offset;
Expand Down

0 comments on commit 93d87fc

Please sign in to comment.