Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
fix sdp bug
Browse files: browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Jun 21, 2024
1 parent 8f0abc4 commit ea1267d
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 81 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ if (${LOG} STREQUAL "on")
endif ()

# For large registers mode, enable 256 registers for kernels
set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
# set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt")
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-codegen")
# Enable bank conflict reduction.
Expand Down
4 changes: 2 additions & 2 deletions include/experimental/group/gemm/compute_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ struct compute_policy_int4_dequantize<
quant_info_.weight_mem_layout == mem_layout::col_major;

static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16;
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32;
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 128 : 32;
static constexpr uint32_t block_size_x_a =
block_bytes_x_a / sizeof(dtype_mma_a);
static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32;
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 256 : 32;
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 128 : 32;
static constexpr uint32_t block_size_y_b =
block_bytes_y_b / sizeof(dtype_mma_b);

Expand Down
128 changes: 64 additions & 64 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
namespace gpu::xetla::subgroup {

namespace detail {
template <typename tile_t, typename payload_t, bool is_lsc_gather_ = false>
template <typename tile_t, typename payload_t, bool is_lsc_gather_ = true>
struct check_load_type {
static constexpr bool is_lsc_gather = is_lsc_gather_;
static constexpr bool is_global_block_2d =
Expand Down Expand Up @@ -465,73 +465,73 @@ tile_load(tile_t& tile, payload_t& payload) {
uint32_t offset_x = j * tile_desc::block_size_x;
auto reg_sub = tile.reg.xetla_select<tile_desc::block_elems, 1>(
(i * tile_desc::num_block_x + j) * tile_desc::block_elems);
// #pragma unroll
// for (uint32_t sub_block_offset = 0; sub_block_offset <
// (payload_t::mem_transpose ? tile_desc::block_size_x
// : tile_desc::block_size_y);
// sub_block_offset += num_channel) {
uint32_t sub_block_offset = 0;
xetla_vector<load_dtype, load_elems> reg_tmp = 0;
uint32_t address_offset = payload_t::mem_transpose
? (offset_x + sub_block_offset) * payload.pitch_in_bytes +
offset_y * sizeof(dtype)
: offset_x * sizeof(dtype) +
(offset_y + sub_block_offset) * payload.pitch_in_bytes;
xetla_mask<num_channel> pred = 1;
if constexpr (num_channel > 1) {
// For SDP load, need pred
const uint32_t sub_block_offset_x = payload.base_x + offset_x +
(payload_t::mem_transpose ? sub_block_offset : 0);
const uint32_t sub_block_offset_y = payload.base_y + offset_y +
(payload_t::mem_transpose ? 0 : sub_block_offset);
const auto offset_ch_dim =
payload_t::trans ? sub_block_offset_x : sub_block_offset_y;
const auto size_ch_dim =
payload_t::trans ? payload.width_in_elems : payload.height_in_elems;

pred = offset_ch_dim + num_channel > size_ch_dim
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
size_ch_dim)
: 1;
}
reg_tmp = xetla_load_global<
load_dtype,
payload_t::simd_exec_size,
data_size::default_size,
L1,
L2,
payload_t::num_channel>(
payload.base_ptr,
payload.channel_offset + payload.base_offset + address_offset,
pred);

if constexpr (
payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) {
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
#pragma unroll
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
iii * payload_t::simd_exec_size) =
reg_tmp.xetla_select<
payload_t::simd_exec_size,
payload_t::num_channel>(iii);
else // TODO (dingyi): Delete after driver fix
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
iii * payload_t::simd_exec_size) = 0;
for (uint32_t sub_block_offset = 0; sub_block_offset <
(payload_t::mem_transpose ? tile_desc::block_size_x
: tile_desc::block_size_y);
sub_block_offset += num_channel) {
// uint32_t sub_block_offset = 0;
xetla_vector<load_dtype, load_elems> reg_tmp = 0;
uint32_t address_offset = payload_t::mem_transpose
? (offset_x + sub_block_offset) * payload.pitch_in_bytes +
offset_y * sizeof(dtype)
: offset_x * sizeof(dtype) +
(offset_y + sub_block_offset) * payload.pitch_in_bytes;
xetla_mask<num_channel> pred = 1;
if constexpr (num_channel > 1) {
// For SDP load, need pred
const uint32_t sub_block_offset_x = payload.base_x + offset_x +
(payload_t::mem_transpose ? sub_block_offset : 0);
const uint32_t sub_block_offset_y = payload.base_y + offset_y +
(payload_t::mem_transpose ? 0 : sub_block_offset);
const auto offset_ch_dim =
payload_t::trans ? sub_block_offset_x : sub_block_offset_y;
const auto size_ch_dim = payload_t::trans ? payload.width_in_elems
: payload.height_in_elems;

pred = offset_ch_dim + num_channel > size_ch_dim
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
size_ch_dim)
: 1;
}
reg_tmp = xetla_load_global<
load_dtype,
payload_t::simd_exec_size,
data_size::default_size,
L1,
L2,
num_channel>(
payload.base_ptr,
payload.channel_offset + payload.base_offset + address_offset,
pred);

if constexpr (
payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) {
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
#pragma unroll
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
iii * payload_t::simd_exec_size) =
reg_tmp.xetla_select<
payload_t::simd_exec_size,
payload_t::num_channel>(iii);
else // TODO (dingyi): Delete after driver fix
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
iii * payload_t::simd_exec_size) = 0;
}
reg_sub
.xetla_select<load_elems * pack_factor, 1>(
sub_block_offset * tile_desc::block_size_x)
.xetla_format<load_dtype>() = reg_tmp_trans;
} else {
reg_sub
.xetla_select<load_elems * pack_factor, 1>(
sub_block_offset * tile_desc::block_size_x)
.xetla_format<load_dtype>() = reg_tmp;
}
reg_sub
.xetla_select<load_elems * pack_factor, 1>(
sub_block_offset * tile_desc::block_size_x)
.xetla_format<load_dtype>() = reg_tmp_trans;
} else {
reg_sub
.xetla_select<load_elems * pack_factor, 1>(
sub_block_offset * tile_desc::block_size_x)
.xetla_format<load_dtype>() = reg_tmp;
}
}
// }
}

if constexpr (payload_t::trans) {
Expand Down
24 changes: 12 additions & 12 deletions include/subgroup/tile/impl/payload_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,24 +1152,24 @@ struct mem_payload_t<
uint32_t,
dtype>::type>::type;
static constexpr uint32_t pack_factor = sizeof(mem_dtype) / sizeof(dtype);
// for pvc, we can use simd16 or simd32
static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype);
static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype);
static constexpr uint32_t simd_channel =
((tile_bytes % max_store_bytes) == 0 &&
(block_bytes % max_store_bytes) == 0)
? 32
: 16;
static constexpr uint32_t num_channel = mem_transpose
? (simd_channel >= block_size_x) ? block_size_x : simd_channel
: (simd_channel >= block_size_y) ? block_size_y
: simd_channel;

static constexpr uint32_t simd_exec_size =
(mem_transpose ? block_size_y : block_size_x) >= pack_factor
? (mem_transpose ? block_size_y : block_size_x) / pack_factor
: 1;

// for pvc, we can use simd16 or simd32
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
static constexpr uint32_t max_bytes =
load_store_attr::max_load_vec_len;

static constexpr uint32_t simd_channel =
max_bytes / (simd_exec_size * sizeof(mem_dtype));

static constexpr uint32_t num_channel = mem_transpose
? std::min(block_size_x, simd_channel)
: std::min(block_size_y, simd_channel);

xetla_vector<uint32_t, num_channel> channel_offset;
xetla_vector<uint32_t, num_channel> step_x;
xetla_vector<uint32_t, num_channel> step_y;
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/gemv/int4/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class test_col_major_1 {
static constexpr size_t wg_n = 1;
static constexpr size_t sg_m = 1;
static constexpr size_t sg_n = 1;
static constexpr size_t sg_k = 1024 / 1;
static constexpr size_t sg_k = 512 / 1;
static constexpr size_t dequant_s = 128;
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
Expand All @@ -47,7 +47,7 @@ class test_col_major_1 {
static constexpr mem_layout layout_a = mem_layout::row_major;
static constexpr mem_layout layout_b = mem_layout::col_major;
static constexpr mma_engine mma_eng = mma_engine::fpu;
static constexpr gpu_arch arch = gpu_arch::XeHpc;
static constexpr gpu_arch arch = gpu_arch::XeLpg;
using data_type_a = fp16;
using data_type_b = int4x8;
using data_type_c = fp16;
Expand Down

0 comments on commit ea1267d

Please sign in to comment.