benchdnn: introduce --bia-dt option for drivers supporting bias (fixed MFDNN-3031, fixed MFDNN-12936) #2335

Open · wants to merge 12 commits into base: main
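In brief: the verbose converter gains a MultiDataTypeWithBiasMixin so the conv, ip, and matmul converters all emit a `--bia-dt` option (matmul previously spelled it `--bia_dt`); the conv driver parses `--bia-dt` and warns when it is combined with `--dir=FWD_B`/`BWD_WB`, which force the default bias data type (f32, or f64 for f64 problems); several CPU implementations gain explicit bias data type dispatch checks; and the AMX bwd-data kernel's f32 accumulation path is fixed. A hypothetical invocation exercising the new knob (the problem descriptor here is illustrative): `--conv --dt=bf16:bf16:f32 --bia-dt=f32 --dir=FWD_D <descriptor>`.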
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -0,0 +1,3 @@
[tool.black]
line-length = 80
include = 'scripts\/.*\.pyi?$'
39 changes: 17 additions & 22 deletions scripts/verbose_converter/src/benchdnn_generator.py
@@ -56,16 +56,7 @@ def _get_dir(self):
if self.entry.prop_kind not in dirs:
return ""

dir = dirs[self.entry.prop_kind]
for md in self.entry.mds:
if md.arg != "bia" or md.data_type == "undef":
continue
if "FWD" in dir:
return "FWD_B"
if dir == "BWD_W":
return "BWD_WB"
break
return dir
return dirs[self.entry.prop_kind]

def _get_alg(self):
return self.entry.aux.get("alg")
@@ -388,6 +379,17 @@ def dts(self):
return "--dt=" + ":".join(dt for dt in dts if dt)


class MultiDataTypeWithBiasMixin(MultiDataTypeMixin):
@property
def dts(self):
dts = super().dts
for md in self.entry.mds:
if md.arg != "bia":
continue
return f"{dts} --bia-dt={md.data_type}".strip()
return dts


class NormalizationMixin:
entry: ir.Entry

@@ -435,7 +437,7 @@ def aux(self):
class ConvolutionConverter(
AlgorithmMixin,
TagTripletMixin,
MultiDataTypeMixin,
MultiDataTypeWithBiasMixin,
Converter,
):
driver: str = "conv"
@@ -482,7 +484,9 @@ def tags(self):
return "--tag=" + ":".join(tags)


class InnerProductConverter(TagTripletMixin, MultiDataTypeMixin, Converter):
class InnerProductConverter(
TagTripletMixin, MultiDataTypeWithBiasMixin, Converter
):
driver: str = "ip"


@@ -523,7 +527,7 @@ def aux(self):
return f"--alg={algs[alg]}"


class MatmulConverter(StridesMixin, MultiDataTypeMixin, Converter):
class MatmulConverter(StridesMixin, MultiDataTypeWithBiasMixin, Converter):
driver: str = "matmul"

@staticmethod
@@ -544,15 +548,6 @@ def bias_mask(self):
return f"--bia_mask={mask}"
return ""

@property
def dts(self):
dts = super().dts
for md in self.entry.mds:
if md.arg != "bia":
continue
return f"{dts} --bia_dt={md.data_type}".strip()
return dts

@property
def aux(self):
rt_dim_masks = self.entry.aux.get("runtime_dims_masks", "")
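For illustration, a minimal self-contained sketch of the cooperative-property pattern the new MultiDataTypeWithBiasMixin relies on. Entry and MD below are mock stand-ins for the converter's ir types, and the base dts string is hard-coded; this is not the converter's real code:

class MD:
    def __init__(self, arg, data_type):
        self.arg = arg
        self.data_type = data_type

class Entry:
    def __init__(self, mds):
        self.mds = mds

class MultiDataTypeMixin:
    @property
    def dts(self):
        # Stand-in for the real property, which joins per-tensor data types.
        return "--dt=f32:f32:f32"

class MultiDataTypeWithBiasMixin(MultiDataTypeMixin):
    @property
    def dts(self):
        dts = super().dts  # base flag string from MultiDataTypeMixin
        for md in self.entry.mds:
            if md.arg != "bia":
                continue
            # Append the bias type only when a bias md is present.
            return f"{dts} --bia-dt={md.data_type}".strip()
        return dts

class DemoConverter(MultiDataTypeWithBiasMixin):
    def __init__(self, entry):
        self.entry = entry

print(DemoConverter(Entry([MD("bia", "f32")])).dts)  # --dt=f32:f32:f32 --bia-dt=f32
print(DemoConverter(Entry([MD("src", "f32")])).dts)  # --dt=f32:f32:f32

Because the bias handling now lives in one mixin, the open-coded dts duplication previously carried by MatmulConverter (removed further down) disappears.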
31 changes: 25 additions & 6 deletions src/cpu/ref_deconvolution.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
* Copyright 2022 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -173,6 +173,16 @@ struct ref_deconvolution_fwd_t : public primitive_t {
alg_kind::deconvolution_direct,
alg_kind::deconvolution_winograd),
VERBOSE_BAD_ALGORITHM);
// This implementation checks data type requirements through the
// underlying convolution implementation; however, that convolution may
// be created without a bias, so the bias data type must be validated
// here when one was requested.
if (with_bias()) {
const auto bia_type = invariant_wei_md(1)->data_type;
VDISPATCH_DECONVOLUTION(utils::one_of(bia_type, f32, bf16, f16,
f8_e5m2, f8_e4m3),
VERBOSE_UNSUPPORTED_DT);
}
VDISPATCH_DECONVOLUTION(attr()->has_default_values(skip_mask),
VERBOSE_UNSUPPORTED_ATTR);
VDISPATCH_DECONVOLUTION(
@@ -464,18 +474,27 @@ struct ref_deconvolution_bwd_weights_t : public primitive_t {
status_t init(engine_t *engine) {
using namespace format_tag;
using namespace data_type;
auto src_type = desc()->src_desc.data_type;
auto dwei_type = desc()->diff_weights_desc.data_type;
auto ddst_type = desc()->diff_dst_desc.data_type;
auto src_type = invariant_src_md()->data_type;
auto wei_type = invariant_wei_md(0)->data_type;
auto dst_type = invariant_dst_md()->data_type;
VDISPATCH_DECONVOLUTION(
desc()->prop_kind == prop_kind::backward_weights,
VERBOSE_BAD_PROPKIND);
VDISPATCH_DECONVOLUTION(utils::one_of(src_type, f32, bf16, f16),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_DECONVOLUTION(ddst_type == src_type,
VDISPATCH_DECONVOLUTION(dst_type == src_type,
VERBOSE_INCONSISTENT_DT, "diff_dst", "src");
VDISPATCH_DECONVOLUTION(utils::one_of(dwei_type, src_type, f32),
VDISPATCH_DECONVOLUTION(utils::one_of(wei_type, src_type, f32),
VERBOSE_UNSUPPORTED_DT);
// This implementation checks data type requirements through the
// underlying convolution implementation; however, that convolution may
// be created without a bias, so the bias data type must be validated
// here when one was requested.
if (with_bias()) {
const auto bia_type = invariant_wei_md(1)->data_type;
VDISPATCH_DECONVOLUTION(utils::one_of(bia_type, f32, bf16, f16),
VERBOSE_UNSUPPORTED_DT);
}
VDISPATCH_DECONVOLUTION(utils::one_of(desc()->alg_kind,
alg_kind::deconvolution_direct,
alg_kind::deconvolution_winograd),
6 changes: 3 additions & 3 deletions src/cpu/x64/gemm_bf16_inner_product.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2024 Intel Corporation
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -65,6 +65,8 @@ struct gemm_bf16_inner_product_fwd_t : public primitive_t {
IMPLICATION(with_bias(),
one_of(weights_md(1)->data_type, f32, bf16)),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_INNER_PRODUCT(set_default_params() == status::success,
VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_INNER_PRODUCT(
attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops,
@@ -77,8 +79,6 @@ struct gemm_bf16_inner_product_fwd_t : public primitive_t {
VDISPATCH_INNER_PRODUCT(inner_product_utils::post_ops_ok(
attr()->post_ops_, &dst_md_),
VERBOSE_UNSUPPORTED_POSTOP);
VDISPATCH_INNER_PRODUCT(set_default_params() == status::success,
VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_INNER_PRODUCT(dense_gemm_consitency_check(
src_md(), weights_md(), dst_md()),
VERBOSE_INCOMPATIBLE_GEMM_FMT);
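(A plausible reason for moving set_default_params() earlier in this hunk: inner_product_utils::post_ops_ok inspects dst_md_, which needs its default format filled in before that check runs.)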
4 changes: 2 additions & 2 deletions src/cpu/x64/jit_avx512_core_amx_conv_kernel.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2024 Intel Corporation
* Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -3225,7 +3225,7 @@ void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_xf16(
vpslld(zmm_in, zmm_in, 16);
break;
case data_type::f16: vcvtph2ps(zmm_in_k, addr); break;
case data_type::f32: vaddps(zmm_in_k, addr); return;
case data_type::f32: vmovups(zmm_in_k, addr); break;
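// Note: the previous f32 path issued vaddps(zmm_in_k, addr) and returned
// early, so the sum bypassed the shared vaddps(zmm_out, zmm_in) below;
// loading with vmovups and falling through restores that accumulation.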
default: assert(!"Unsupported data type in xf16 conv");
}
vaddps(zmm_out, zmm_in);
6 changes: 3 additions & 3 deletions src/cpu/x64/jit_brdgmm_dw_conv.cpp
@@ -139,9 +139,9 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {
IMPLICATION(is_int8,
one_of(bia_type, data_type::undef, f32, s32, s8, u8)),
VERBOSE_UNSUPPORTED_BIAS_CFG);
VDISPATCH_CONV(
IMPLICATION(!is_int8,
one_of(bia_type, data_type::undef, src_type, dst_type)),
VDISPATCH_CONV(IMPLICATION(!is_int8,
one_of(bia_type, data_type::undef, data_type::f32,
src_type, dst_type)),
VERBOSE_UNSUPPORTED_BIAS_CFG);
VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct),
VERBOSE_BAD_ALGORITHM);
3 changes: 2 additions & 1 deletion src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp
@@ -91,7 +91,8 @@ status_t brgemm_convolution_bwd_strided_t<isa>::pd_t::init(engine_t *engine) {
| skip_mask_t::zero_points_runtime;

const bool is_f32_supported
= everyone_is(f32, diff_src_type, wei_type, diff_dst_type);
= everyone_is(f32, diff_src_type, wei_type, diff_dst_type)
&& IMPLICATION(with_bias(), bias_md_.data_type == f32);

const bool is_xf16_supported = one_of(wei_type, bf16, f16)
&& wei_type == diff_dst_type && one_of(diff_src_type, wei_type, f32)
24 changes: 21 additions & 3 deletions tests/benchdnn/conv/bench_conv.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2017-2024 Intel Corporation
* Copyright 2017-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ void check_correctness(
const settings_t &s, driver_task_executor_t &task_executor) {
for_(const auto &i_dir : s.dir)
for_(const auto &i_dt : s.dt)
for_(const auto &i_bia_dt_ : s.bia_dt)
for_(const auto &i_stag : s.stag)
for_(const auto &i_wtag : s.wtag)
for_(const auto &i_dtag : s.dtag)
@@ -52,8 +53,24 @@
for_(const auto &i_ctx_init : s.ctx_init)
for_(const auto &i_ctx_exe : s.ctx_exe)
for (const auto &i_mb : s.mb) {
const prb_t prb(s.desc, i_dir, i_dt, i_stag, i_wtag, i_dtag, i_strides,
i_alg, i_mb, i_attr, i_ctx_init, i_ctx_exe, s.impl_filter);
auto i_bia_dt = i_bia_dt_;
if (i_dir & FLAG_BIA) {
if (i_bia_dt != dnnl_data_type_undef) {
BENCHDNN_PRINT(0, "%s\n",
"Warning: `--dir=FWD_B,BWD_WB` options are "
"incompatible with `--bia-dt` option. To specify a "
"bias data type, use `--dir=FWD_D,FWD_I,BWD_W` values "
"intead.");
}
// f32 (or f64 for f64 problems) is used as the default bias data type
// for directions that include a bias.
const bool is_f64 = (i_dt.size() == 1 && i_dt[0] == dnnl_f64)
|| (i_dt.size() > 1 && i_dt[1] == dnnl_f64);
i_bia_dt = is_f64 ? dnnl_f64 : dnnl_f32;
}
const prb_t prb(s.desc, i_dir, i_dt, i_bia_dt, i_stag, i_wtag, i_dtag,
i_strides, i_alg, i_mb, i_attr, i_ctx_init, i_ctx_exe,
s.impl_filter);
if (s.pattern && !match_regex(prb.str(), s.pattern)) return;

bool has_dw_po = i_attr.post_ops.convolution_index() >= 0;
@@ -128,6 +145,7 @@ int bench(int argc, char **argv) {
|| parse_batch(bench, argv[0])
|| parse_dir(s.dir, def.dir, argv[0])
|| parse_multi_dt(s.dt, def.dt, argv[0], "dt")
|| parse_dt(s.bia_dt, def.bia_dt, argv[0], "bia-dt")
|| parse_tag(s.stag, def.stag, argv[0], "stag")
|| parse_tag(s.wtag, def.wtag, argv[0], "wtag")
|| parse_tag(s.dtag, def.dtag, argv[0], "dtag")
41 changes: 25 additions & 16 deletions tests/benchdnn/conv/conv.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2017-2024 Intel Corporation
* Copyright 2017-2025 Intel Corporation
* Copyright 2021 FUJITSU LIMITED
* Copyright 2021 Arm Ltd. and affiliates
*
@@ -222,8 +222,11 @@ dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
auto wei_d = dnn_mem_t::init_md(prb->ndims + prb->has_groups,
prb->wei_dims().data(), force_f32_dt ? dnnl_f32 : prb->get_dt(WEI),
prb->wtag, prb->strides[STRIDES_WEI]);
auto bia_d = dnn_mem_t::init_md(1, prb->bia_dims().data(),
force_f32_dt ? dnnl_f32 : prb->get_dt(BIA), tag::any);
benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> bia_d {};
if (prb->bia_dt() != dnnl_data_type_undef) {
bia_d = dnn_mem_t::init_md(1, prb->bia_dims().data(),
force_f32_dt ? dnnl_f32 : prb->get_dt(BIA), tag::any);
}
auto dst_d = dnn_mem_t::init_md(prb->ndims, prb->dst_dims().data(),
force_f32_dt ? dnnl_f32 : prb->get_dt(DST), prb->dtag,
prb->strides[STRIDES_DST]);
@@ -253,16 +256,14 @@ dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
| DNNL_ARG_WEIGHTS,
wei_mask);
}
const auto dw_bia_dt = prb->dir == FWD_B ? dnnl_f32 : dnnl_data_type_undef;
attr_args.prepare_dw_post_op(prb->attr, prb->get_dt(WEI), dw_bia_dt);
attr_args.prepare_dw_post_op(prb->attr, prb->get_dt(WEI), prb->bia_dt());
auto dnnl_attr = make_benchdnn_dnnl_wrapper(
create_dnnl_attr(prb->attr, attr_args));

switch (prb->dir) {
case FWD_D:
case FWD_B:
case FWD_I:
if (prb->dir != FWD_B) bia_d.reset(nullptr);
TIME_C_PD(DNN_SAFE_STATUS(
dnnl_convolution_forward_primitive_desc_create(
&init_pd_args.pd, init_pd_args.engine,
@@ -285,7 +286,6 @@ dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
break;
case BWD_W:
case BWD_WB:
if (prb->dir == BWD_W) bia_d.reset(nullptr);
TIME_C_PD(DNN_SAFE_STATUS(
dnnl_convolution_backward_weights_primitive_desc_create(
&init_pd_args.pd, init_pd_args.engine, alg, src_d,
@@ -325,13 +325,22 @@ int init_prim_ref(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref,
update_cpu_ref_attrs(cpu_attr);
std::vector<std::vector<dnnl_data_type_t>> prim_ref_dt {
prb->dt, {dnnl_f32}};
if (is_cpu()) prim_ref_dt.erase(prim_ref_dt.begin());
// If there's no bias, undef data type should be used for prim_ref as well.
dnnl_data_type_t cpu_bia_dt
= prb->bia_dt() == dnnl_data_type_undef ? prb->bia_dt() : dnnl_f32;
std::vector<dnnl_data_type_t> prim_ref_bia_dt {prb->bia_dt(), cpu_bia_dt};
if (is_cpu()) {
prim_ref_dt.erase(prim_ref_dt.begin());
prim_ref_bia_dt.erase(prim_ref_bia_dt.begin());
}
dnnl_primitive_t prim_ref_ {};

for (const auto &prim_ref_dt_i : prim_ref_dt) {
prb_t prb_cpu {*prb, prb->dir, prim_ref_dt_i, tag::any, tag::any,
tag::any, {vdims_t(STRIDES_SIZE)}, DIRECT, prb->mb, cpu_attr,
prb->ctx_init, prb->ctx_exe, prb->impl_filter};
for_(const auto &prim_ref_dt_i : prim_ref_dt)
for (const auto &prim_ref_bia_dt_i : prim_ref_bia_dt) {
prb_t prb_cpu {*prb, prb->dir, prim_ref_dt_i, prim_ref_bia_dt_i,
tag::any, tag::any, tag::any, {vdims_t(STRIDES_SIZE)}, DIRECT,
prb->mb, cpu_attr, prb->ctx_init, prb->ctx_exe,
prb->impl_filter};

init_pd_args_t<prb_t> init_pd_args(
/* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir,
@@ -366,9 +375,9 @@ int init_prim_ref(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref,
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
skip_unimplemented_data_type(
{prb->get_dt(SRC), prb->get_dt(WEI), prb->get_dt(DST)}, prb->dir,
res);
skip_unimplemented_data_type({prb->get_dt(SRC), prb->get_dt(WEI),
prb->get_dt(BIA), prb->get_dt(DST)},
prb->dir, res);
skip_unimplemented_sum_po(prb->attr, res, dnnl_convolution,
prb->get_dt(SRC), prb->get_dt(DST));
skip_unimplemented_prelu_po(prb->attr, res, dnnl_convolution);
@@ -550,7 +559,7 @@ std::vector<data_kind_t> get_kinds_to_check(const prb_t *prb) {
check_kinds = {SRC};
} else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
check_kinds = {WEI};
if (prb->dir & FLAG_BIA) check_kinds.push_back(BIA);
if (prb->bia_dt() != dnnl_data_type_undef) check_kinds.push_back(BIA);
} else {
assert(!"unexpected!");
SAFE_V(FAIL);