Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
sync_ipex(update prefetch)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Jul 31, 2024
1 parent d2387f8 commit 15a2b0c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 14 deletions.
84 changes: 80 additions & 4 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,14 @@ class gemm_t<

int scale_prefetch_addr_i = args.inner_loop_start;
int tile_k_idx = args.inner_loop_start;
uint32_t prefetch_stages =
stages < args.inner_loop_count ? stages : args.inner_loop_count;
uint32_t prefetch_compute_stages =
stages < args.inner_loop_count ? args.inner_loop_count - stages : 0;
uint32_t compute_stages =
stages < args.inner_loop_count ? stages : args.inner_loop_count;
SW_BARRIER();
#pragma unroll
for (uint32_t i = 0; i < stages; i++) {
for (uint32_t i = 0; i < prefetch_stages; i++) {
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
matA_prefetch_payload);
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
Expand Down Expand Up @@ -542,8 +547,7 @@ class gemm_t<
}
}
}

for (uint32_t i = 0; i < args.inner_loop_count; i++) {
for (uint32_t i = 0; i < prefetch_compute_stages; i++) {
if constexpr (enable_periodic_sync) {
if ((i % sync_freq) == 0) {
if constexpr (wg_size_x > 1) {
Expand Down Expand Up @@ -651,6 +655,78 @@ class gemm_t<
}
}
SW_BARRIER();
for (uint32_t i = 0; i < compute_stages; i++) {
if constexpr (enable_periodic_sync) {
if ((i % sync_freq) == 0) {
if constexpr (wg_size_x > 1) {
nbarrier_a.arrive();
}
if constexpr (arch_tag >= gpu_arch::XeHpc) {
if constexpr (wg_size_y > 1) {
nbarrier_b.arrive();
}
}
}
}
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
matA, matA_payload);
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
matB, matB_payload);
// subgroup::tile_load<cache_hint::uncached, cache_hint::uncached>(
// matB, matB_payload);
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
scale, scale_payload);
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
zero_pt, zero_pt_payload);
}
tile_k_idx++;
SW_BARRIER();
matA_payload.template update_tdesc<update_dir_a>(matA_t::tile_size_x);
matB_payload.template update_tdesc<update_dir_b>(matB_t::tile_size_y);
if (tile_k_idx % scale_addr_update_freq == 0) {
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
}
if constexpr (compute_policy::quant_mode != quant_mode::I4_SYM) {
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
zero_pt_t::tile_size_y);
}
}
SW_BARRIER();
matA_acc_t matA_acc;
matB_acc_t matB_acc;
if constexpr (is_vnni_tiled_a) {
subgroup::vnni_reverse(matA);
}
subgroup::elemwise_cvt(matA_acc, matA);

dequantize(matB_acc, matB, scale, zero_pt, dequantize_args);
SW_BARRIER();
if constexpr (is_gemv) {
tile_mma::mma(
matAcc, matAcc, matC, matB_acc, matA_acc, i == compute_stages - 1);
} else {
if constexpr (is_col_major_b) {
tile_transpose(matB_acc);
}
tile_mma::mma(matC, matC, matB_acc, matA_acc);
}
SW_BARRIER();
if constexpr (enable_periodic_sync) {
if ((i % sync_freq) == 0) {
if constexpr (wg_size_x > 1) {
nbarrier_a.wait();
}
if constexpr (arch_tag >= gpu_arch::XeHpc) {
if constexpr (wg_size_y > 1) {
nbarrier_b.wait();
}
}
}
}
}
SW_BARRIER();
}

private:
Expand Down
18 changes: 8 additions & 10 deletions include/subgroup/tile/impl/payload_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1655,12 +1655,11 @@ struct prefetch_payload_t<
reg_layout_>,
num_coop_sg_,
arch_tag_,
std::enable_if_t<
(!arch_has_2d_load_store<arch_tag_>) &&
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
mem_layout_ == mem_layout::row_major) ||
((block_size_x_ != 1 || tile_size_x_ != 1) &&
mem_layout_ == mem_layout::col_major))>> {
std::enable_if_t<(!arch_has_2d_load_store<arch_tag_>)&&(
((block_size_y_ != 1 || tile_size_y_ != 1) &&
mem_layout_ == mem_layout::row_major) ||
((block_size_x_ != 1 || tile_size_x_ != 1) &&
mem_layout_ == mem_layout::col_major))>> {
using dtype = native_type_t<dtype_>;
using mem_desc_t =
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
Expand Down Expand Up @@ -1902,10 +1901,9 @@ struct prefetch_payload_t<
reg_layout_>,
num_coop_sg_,
arch_tag_,
std::enable_if_t<
(arch_has_2d_load_store<arch_tag_>) &&
(((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
std::enable_if_t<(arch_has_2d_load_store<arch_tag_>)&&(
((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
using dtype = dtype_;
using mem_desc_t =
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
Expand Down

0 comments on commit 15a2b0c

Please sign in to comment.