Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Jul 31, 2024
1 parent 3327a21 commit d2387f8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 20 deletions.
17 changes: 4 additions & 13 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,8 @@ tile_load(tile_t& tile, payload_t& payload) {
detail::getNextPowerOf2<uint32_t(tile_t::block_elems * sizeof(dtype))>();

using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
static constexpr uint32_t max_load_vec_len = std::min(
uint32_t(tile_t::block_elems * sizeof(dtype)),
load_store_attr::max_load_vec_len);
static constexpr uint32_t max_load_vec_len =
std::min(power2_block_elems, load_store_attr::max_aligned_load_vec_len);

static constexpr uint32_t max_load_vec_elems =
max_load_vec_len / sizeof(dtype);
Expand Down Expand Up @@ -501,10 +500,7 @@ tile_load(tile_t& tile, payload_t& payload) {
load_elems,
payload_t::vector_size,
L1,
L2>(
payload.base_ptr,
payload.channel_offset + payload.base_offset + address_offset,
mask);
L2>(payload.base_ptr, channel_offset + address_offset, mask);

if constexpr (
payload_t::vector_size > 1 && payload_t::num_channel > 1) {
Expand Down Expand Up @@ -684,12 +680,7 @@ tile_load(
: offset_x * sizeof(dtype) +
(offset_y + sub_block_y) * payload.pitch_in_bytes;

reg_tmp = xetla_load_global<
load_dtype,
load_elems,
1,
L1,
L2>(
reg_tmp = xetla_load_global<load_dtype, load_elems, 1, L1, L2>(
payload.base_ptr,
channel_offset + address_offset,
pred_x && pred_y);
Expand Down
14 changes: 7 additions & 7 deletions include/subgroup/tile/impl/payload_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2226,18 +2226,18 @@ struct prefetch_payload_t<

private:
// Fetches the entire CL.
// static constexpr uint32_t cacheline_elems = 64 / sizeof(dtype);
// static constexpr uint32_t mem_block_nums =
// (tile_desc::tile_size_x + cacheline_elems - 1) / cacheline_elems;
// static constexpr uint32_t num_coop_sg = num_coop_sg_;
static constexpr uint32_t cacheline_elems = 64 / sizeof(dtype);
static constexpr uint32_t mem_block_nums =
(tile_desc::tile_size_x + cacheline_elems - 1) / cacheline_elems;
static constexpr uint32_t num_coop_sg = num_coop_sg_;

// For mem_tile_nums < num_coop_sg cases, mem_tile_size_x will be CL length
// which might lead to illegal read.
// there are num_coop_sg threads to prefetch mem_block_nums
// each thread will prefetch mem_tile_size_x elements
// static constexpr uint32_t mem_tile_size_x = mem_block_nums > num_coop_sg
// ? (mem_block_nums + num_coop_sg - 1) / num_coop_sg* cacheline_elems
// : 0;
static constexpr uint32_t mem_tile_size_x = mem_block_nums > num_coop_sg
? (mem_block_nums + num_coop_sg - 1) / num_coop_sg* cacheline_elems
: 0;
using this_payload_t =
prefetch_payload_t<mem_desc_t, tile_desc, num_coop_sg_, arch_tag>;

Expand Down

0 comments on commit d2387f8

Please sign in to comment.