Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xe: fix various GEMM issues on XeLP #2373

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions src/gpu/intel/jit/emulation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
#ifndef GPU_INTEL_JIT_EMULATION_HPP
#define GPU_INTEL_JIT_EMULATION_HPP

#include "gpu/intel/jit/ngen/ngen_config.hpp"
#ifdef NGEN_ENABLE_SOURCE_LOCATION
#include <source_location>
#endif

#include <exception>

namespace dnnl {
Expand Down Expand Up @@ -66,9 +71,18 @@ struct EmulationState {
// Implementation wrapped as static methods in non-instantiated class.
// Clients should declare EmulationImplementation as a friend.
struct EmulationImplementation {
#ifdef NGEN_ENABLE_SOURCE_LOCATION
[[noreturn]] static void stub(
std::source_location where = std::source_location::current()) {
throw std::runtime_error(std::string("Unimplemented (at ")
+ std::string(where.file_name()) + ":"
+ std::to_string(where.line()) + ")");
}
#else
[[noreturn]] static void stub() {
throw std::runtime_error("Unimplemented");
}
#endif

template <typename DT, typename O>
static void applyDefaultType(O &op) {
Expand Down Expand Up @@ -247,32 +261,48 @@ struct EmulationImplementation {

bool dstQ = isQW(dst);
bool s0Q = isQW(src0);
bool s0D = isDW(src0);
bool isDF = (src0.getType() == DataType::df
&& dst.getType() == DataType::df);
bool unaligned = (mod.getExecSize() > 1 && src0.getHS() != 0
&& src0.getOffset() != dst.getOffset());

if ((dstQ && s0D) && strategy.emulate64) {
if (src0.getNeg()) stub();
bool s0Signed = isSigned(src0.getType());
RegData dstHi, dstLo;
splitToDW(dst, dstLo, dstHi);
g.mov(mod, dstLo, src0);
if (!s0Signed)
g.mov(mod, dstHi, 0);
else
g.asr(mod, dstHi, dstLo, uint16_t(31));
} else if (((dstQ || s0Q) && strategy.emulate64)
|| (isDF && unaligned && g.hardware >= ngen::HW::XeHP)) {
if (dstQ != s0Q) stub();

auto mod2x = mod;
mod2x.setExecSize(mod.getExecSize() * 2);

makeDWPair(dst, mod.getExecSize());
makeDWPair(src0, mod.getExecSize());
g.mov(mod2x, dst, src0);
bool emulateDF = isDF && unaligned && g.hardware >= ngen::HW::XeHP;

if ((strategy.emulate64 && dstQ) || emulateDF) {
switch (src0.getType()) {
case DataType::ub:
case DataType::uw:
case DataType::ud: {
RegData dstHi, dstLo;
splitToDW(dst, dstLo, dstHi);
g.mov(mod, dstLo, src0);
g.mov(mod, dstHi, 0);
break;
}
case DataType::d: {
if (src0.getNeg()) stub();
RegData dstHi, dstLo;
splitToDW(dst, dstLo, dstHi);
g.mov(mod, dstLo, src0);
g.asr(mod, dstHi, src0, uint16_t(31));
break;
}
case DataType::q:
case DataType::uq:
case DataType::df: {
if (dstQ != s0Q) stub();

auto mod2x = mod;
mod2x.setExecSize(mod.getExecSize() * 2);

makeDWPair(dst, mod.getExecSize());
makeDWPair(src0, mod.getExecSize());
g.mov(mod2x, dst, src0);
break;
}
default: stub(); break;
}
} else if (strategy.emulate64 && s0Q) {
stub();
} else if (dst.getType() == DataType::f
&& src0.getType() == DataType::bf
&& (src0.getHS() != 1 || mod.getExecSize() == 1)) {
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src,
const uint b_scale_dim = (NDIMS == 2) ? d1 : (NDIMS == 3) ? d2 : d3;
float b_scale = 1;
if (B_SCALES) load(&b_scale, b_scales + scale_stride * b_scale_dim);
acc *= a_scale * b_scale;
if (A_SCALES || B_SCALES) acc *= a_scale * b_scale;

if (bias) {
ACC_DATA_T b;
Expand Down
22 changes: 16 additions & 6 deletions src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,31 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx(
kernel_ctx, memory_desc_info_t::create(dst_md(0)), "DST", false);

int ndims = src_info.ndims;
bool is_int8 = src_md(1)->data_type == data_type::s8;
kernel_ctx.set_data_type(c_type);
//here SRC is output tensor of gemm call
def_data_type(kernel_ctx, is_int8 ? data_type::f32 : desc_.acc_type, "ACC");

kernel_ctx.define_int("NDIMS", ndims);
CHECK(def_attr_info(
kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
const auto &attr_scales = attr()->scales_;
const bool with_src_scales
= !attr_scales.get(DNNL_ARG_SRC).has_default_values();
const bool with_wei_scales
= !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values();
const bool with_dst_scales
= !attr_scales.get(DNNL_ARG_DST).has_default_values();
auto is_int_type = [](data_type_t t) {
return utils::one_of(t, data_type::s8, data_type::u8, data_type::s32);
};
data_type_t acc_type = desc_.acc_type;
if (desc_.acc_type == data_type::s32) {
if (with_src_scales || with_wei_scales
|| !is_int_type(bias_info.data_type)
|| !is_int_type(dst_md(0)->data_type)) {
acc_type = data_type::f32;
}
}
def_data_type(kernel_ctx, acc_type, "ACC");

kernel_ctx.define_int("NDIMS", ndims);
CHECK(def_attr_info(
kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
kernel_ctx.define_int("A_SCALES", with_src_scales);
kernel_ctx.define_int("B_SCALES", with_wei_scales);
kernel_ctx.define_int("C_SCALES", with_dst_scales);
Expand Down
2 changes: 2 additions & 0 deletions src/gpu/intel/ocl/ocl_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ DEF_load(float, char);
DEF_load(float, uchar);
DEF_load(int, char);
DEF_load(int, uchar);
DEF_load(int, int);
DEF_load(float, bf16);

// Writes
Expand All @@ -88,6 +89,7 @@ DEF_write(float, float);
DEF_write(char, int);
DEF_write(uchar, int);
DEF_write(bf16, int);
DEF_write(int, int);
DEF_write(float, int);
DEF_write(int, float);

Expand Down
Loading