Skip to content

Commit

Permalink
benchdnn: deconv: introduce --bia-dt support
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Jan 2, 2025
1 parent 3ebde37 commit d766266
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 37 deletions.
2 changes: 1 addition & 1 deletion scripts/verbose_converter/src/benchdnn_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def aux(self):
class ConvolutionConverter(
AlgorithmMixin,
TagTripletMixin,
MultiDataTypeMixin,
MultiDataTypeWithBiasMixin,
Converter,
):
driver: str = "conv"
Expand Down
18 changes: 16 additions & 2 deletions tests/benchdnn/deconv/bench_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void check_correctness(
const settings_t &s, driver_task_executor_t &task_executor) {
for_(const auto &i_dir : s.dir)
for_(const auto &i_dt : s.dt)
for_(const auto &i_bia_dt_ : s.bia_dt)
for_(const auto &i_stag : s.stag)
for_(const auto &i_wtag : s.wtag)
for_(const auto &i_dtag : s.dtag)
Expand All @@ -50,8 +51,20 @@ void check_correctness(
for_(const auto &i_ctx_init : s.ctx_init)
for_(const auto &i_ctx_exe : s.ctx_exe)
for (const auto &i_mb : s.mb) {
const prb_t prb(s.desc, i_dir, i_dt, i_stag, i_wtag, i_dtag, i_alg,
i_mb, i_attr, i_ctx_init, i_ctx_exe, s.impl_filter);
auto i_bia_dt = i_bia_dt_;
if (i_dir & FLAG_BIA) {
if (i_bia_dt != dnnl_data_type_undef) {
BENCHDNN_PRINT(0, "%s\n",
"Warning: `--dir=FWD_B,BWD_WB` options are "
"incompatible with `--bia-dt` option. To specify a "
"bias data type, use `--dir=FWD_D,FWD_I,BWD_W` values "
"intead.");
}
// Hard-code bias data type for directions with bias included.
i_bia_dt = dnnl_f32;
}
const prb_t prb(s.desc, i_dir, i_dt, i_bia_dt, i_stag, i_wtag, i_dtag,
i_alg, i_mb, i_attr, i_ctx_init, i_ctx_exe, s.impl_filter);
if (s.pattern && !match_regex(prb.str(), s.pattern)) return;

task_executor.submit(
Expand Down Expand Up @@ -85,6 +98,7 @@ int bench(int argc, char **argv) {
|| parse_batch(bench, argv[0])
|| parse_dir(s.dir, def.dir, argv[0])
|| parse_multi_dt(s.dt, def.dt, argv[0], "dt")
|| parse_dt(s.bia_dt, def.bia_dt, argv[0], "bia-dt")
|| parse_tag(s.stag, def.stag, argv[0], "stag")
|| parse_tag(s.wtag, def.wtag, argv[0], "wtag")
|| parse_tag(s.dtag, def.dtag, argv[0], "dtag")
Expand Down
35 changes: 22 additions & 13 deletions tests/benchdnn/deconv/deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,11 @@ dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
auto wei_d = dnn_mem_t::init_md(prb->ndims + prb->has_groups,
prb->wei_dims().data(), force_f32_dt ? dnnl_f32 : prb->get_dt(WEI),
prb->wtag);
auto bia_d = dnn_mem_t::init_md(1, prb->bia_dims().data(),
force_f32_dt ? dnnl_f32 : prb->get_dt(BIA), tag::any);
benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> bia_d {};
if (prb->bia_dt() != dnnl_data_type_undef) {
bia_d = dnn_mem_t::init_md(1, prb->bia_dims().data(),
force_f32_dt ? dnnl_f32 : prb->get_dt(BIA), tag::any);
}
auto dst_d = dnn_mem_t::init_md(prb->ndims, prb->dst_dims().data(),
force_f32_dt ? dnnl_f32 : prb->get_dt(DST), prb->dtag);

Expand All @@ -252,7 +255,6 @@ dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
case FWD_D:
case FWD_B:
case FWD_I:
if (prb->dir != FWD_B) bia_d.reset(nullptr);
TIME_C_PD(DNN_SAFE_STATUS(
dnnl_deconvolution_forward_primitive_desc_create(
&init_pd_args.pd, init_pd_args.engine,
Expand All @@ -275,7 +277,6 @@ dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
break;
case BWD_W:
case BWD_WB:
if (prb->dir == BWD_W) bia_d.reset(nullptr);
TIME_C_PD(DNN_SAFE_STATUS(
dnnl_deconvolution_backward_weights_primitive_desc_create(
&init_pd_args.pd, init_pd_args.engine, alg, src_d,
Expand Down Expand Up @@ -313,13 +314,21 @@ int init_prim_ref(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref,
update_cpu_ref_attrs(cpu_attr);
std::vector<std::vector<dnnl_data_type_t>> prim_ref_dt {
prb->dt, {dnnl_f32}};
if (is_cpu()) prim_ref_dt.erase(prim_ref_dt.begin());
// If there's no bias, undef data type should be used for prim_ref as well.
dnnl_data_type_t cpu_bia_dt
= prb->bia_dt() == dnnl_data_type_undef ? prb->bia_dt() : dnnl_f32;
std::vector<dnnl_data_type_t> prim_ref_bia_dt {prb->bia_dt(), cpu_bia_dt};
if (is_cpu()) {
prim_ref_dt.erase(prim_ref_dt.begin());
prim_ref_bia_dt.erase(prim_ref_bia_dt.begin());
}
dnnl_primitive_t prim_ref_ {};

for (const auto &prim_ref_dt_i : prim_ref_dt) {
prb_t prb_cpu {*prb, prb->dir, prim_ref_dt_i, tag::any, tag::any,
tag::any, DIRECT, prb->mb, cpu_attr, prb->ctx_init,
prb->ctx_exe, prb->impl_filter};
for_(const auto &prim_ref_dt_i : prim_ref_dt)
for (const auto &prim_ref_bia_dt_i : prim_ref_bia_dt) {
prb_t prb_cpu {*prb, prb->dir, prim_ref_dt_i, prim_ref_bia_dt_i,
tag::any, tag::any, tag::any, DIRECT, prb->mb, cpu_attr,
prb->ctx_init, prb->ctx_exe, prb->impl_filter};

init_pd_args_t<prb_t> init_pd_args(
/* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir,
Expand Down Expand Up @@ -354,9 +363,9 @@ int init_prim_ref(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref,
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
skip_unimplemented_data_type(
{prb->get_dt(SRC), prb->get_dt(WEI), prb->get_dt(DST)}, prb->dir,
res);
skip_unimplemented_data_type({prb->get_dt(SRC), prb->get_dt(WEI),
prb->get_dt(BIA), prb->get_dt(DST)},
prb->dir, res);
skip_unimplemented_sum_po(
prb->attr, res, dnnl_deconvolution, prb->get_dt(SRC));
skip_unimplemented_prelu_po(prb->attr, res, dnnl_deconvolution);
Expand Down Expand Up @@ -527,7 +536,7 @@ std::vector<data_kind_t> get_kinds_to_check(const prb_t *prb) {
check_kinds = {SRC};
} else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
check_kinds = {WEI};
if (prb->dir & FLAG_BIA) check_kinds.push_back(BIA);
if (prb->bia_dt() != dnnl_data_type_undef) check_kinds.push_back(BIA);
} else {
assert(!"unexpected!");
SAFE_V(FAIL);
Expand Down
28 changes: 15 additions & 13 deletions tests/benchdnn/deconv/deconv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ struct settings_t : public base_settings_t {

std::vector<dir_t> dir {FWD_B};
std::vector<std::vector<dnnl_data_type_t>> dt {{dnnl_f32}};
std::vector<dnnl_data_type_t> bia_dt {dnnl_data_type_undef};
std::vector<std::string> stag {tag::any}, wtag {tag::any}, dtag {tag::any};
std::vector<alg_t> alg {DIRECT};

Expand All @@ -115,29 +116,31 @@ struct settings_t : public base_settings_t {
void reset() { *this = settings_t(perf_template); }

bool has_single_setup() const override {
return dir.size() == 1 && dt.size() == 1 && stag.size() == 1
&& wtag.size() == 1 && dtag.size() == 1 && alg.size() == 1
&& base_settings_t::has_single_setup();
return dir.size() == 1 && dt.size() == 1 && bia_dt.size() == 1
&& stag.size() == 1 && wtag.size() == 1 && dtag.size() == 1
&& alg.size() == 1 && base_settings_t::has_single_setup();
}
};

struct prb_t : public desc_t {
// A ctor with common interface across all drivers.
prb_t(const settings_t &s)
: prb_t(s.desc, s.dir[0], s.dt[0], s.stag[0], s.wtag[0], s.dtag[0],
s.alg[0], s.mb[0], s.attributes.front(), s.ctx_init[0],
s.ctx_exe[0], s.impl_filter) {
: prb_t(s.desc, s.dir[0], s.dt[0], s.bia_dt[0], s.stag[0], s.wtag[0],
s.dtag[0], s.alg[0], s.mb[0], s.attributes.front(),
s.ctx_init[0], s.ctx_exe[0], s.impl_filter) {
SAFE_V(s.has_single_setup() ? OK : FAIL);
}

prb_t(const desc_t &desc, dir_t dir,
const std::vector<dnnl_data_type_t> &dt, const std::string &stag,
const std::string &wtag, const std::string &dtag, alg_t alg,
int64_t mb, const attr_t &attr, const thr_ctx_t &ctx_init,
const thr_ctx_t &ctx_exe, const impl_filter_t &impl_filter)
const std::vector<dnnl_data_type_t> &dt, dnnl_data_type_t bia_dt,
const std::string &stag, const std::string &wtag,
const std::string &dtag, alg_t alg, int64_t mb, const attr_t &attr,
const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe,
const impl_filter_t &impl_filter)
: desc_t(desc)
, dir(dir)
, dt(dt)
, bia_dt_(bia_dt)
, stag(stag)
, wtag(wtag)
, dtag(dtag)
Expand All @@ -162,6 +165,7 @@ struct prb_t : public desc_t {

dir_t dir;
std::vector<dnnl_data_type_t> dt;
dnnl_data_type_t bia_dt_; // `_` to avoid conflicting name with bia_dt().
std::string stag, wtag, dtag;
alg_t alg;
int64_t user_mb;
Expand All @@ -180,9 +184,7 @@ struct prb_t : public desc_t {

dnnl_data_type_t src_dt() const { return dt[0]; }
dnnl_data_type_t wei_dt() const { return dt[1]; }
dnnl_data_type_t bia_dt() const {
return is_integral_dt(wei_dt()) ? dnnl_f32 : wei_dt();
} // TODO: customize
dnnl_data_type_t bia_dt() const { return bia_dt_; }
dnnl_data_type_t dst_dt() const { return dt[2]; }
dnnl_data_type_t get_dt(data_kind_t data_kind) const;

Expand Down
2 changes: 2 additions & 0 deletions tests/benchdnn/deconv/deconv_aux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ std::string prb_t::set_repro_line() {

if (canonical || dir != def.dir[0]) s << "--dir=" << dir << " ";
if (canonical || !has_default_dts) s << "--dt=" << dt << " ";
if ((canonical || bia_dt_ != def.bia_dt[0]) && !(dir & FLAG_BIA))
s << "--bia-dt=" << bia_dt_ << " ";
if (canonical || stag != def.stag[0]) s << "--stag=" << stag << " ";
if (canonical || wtag != def.wtag[0]) s << "--wtag=" << wtag << " ";
if (canonical || dtag != def.dtag[0]) s << "--dtag=" << dtag << " ";
Expand Down
10 changes: 5 additions & 5 deletions tests/benchdnn/deconv/ref_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ void compute_ref_direct_fwd(const prb_t *prb, const args_t &args) {
float conv_res = 0;
ker(conv_res, g, mb, oc, od, oh, ow);

if (prb->dir & FLAG_BIA) {
if (prb->bia_dt() != dnnl_data_type_undef) {
const size_t bia_off = bia_off_f(prb, g, oc);
conv_res += ((float *)bia_m)[bia_off];
}
Expand Down Expand Up @@ -287,7 +287,7 @@ void compute_ref_direct_bwd_d(const prb_t *prb, const args_t &args) {
else
ker(conv_res, g, mb, ic, id, ih, iw);

if (prb->dir & FLAG_BIA) {
if (prb->bia_dt() != dnnl_data_type_undef) {
const size_t bia_off = (size_t)g * ICG + ic;
conv_res += ((float *)bia_m)[bia_off];
}
Expand Down Expand Up @@ -481,7 +481,7 @@ void compute_ref_bwd_w(
// entry problem which is transposed - `p_tr`. Simpler to use the kernel
// directly.
// Take original memories, not `ref_conv_args`.
if (prb->dir & FLAG_BIA) {
if (prb->bia_dt() != dnnl_data_type_undef) {
const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS);
const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
/* help compiler optimize the code */
Expand Down Expand Up @@ -512,8 +512,8 @@ void compute_ref_bwd_w(
void compute_ref(
const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
// Update prb descriptor to re-use convolution reference.
prb_t prb_tr((desc_t)*prb, prb->dir, prb->dt, prb->stag, prb->wtag,
prb->dtag, prb->alg, prb->mb, prb->attr, prb->ctx_init,
prb_t prb_tr((desc_t)*prb, prb->dir, prb->dt, prb->bia_dt(), prb->stag,
prb->wtag, prb->dtag, prb->alg, prb->mb, prb->attr, prb->ctx_init,
prb->ctx_exe, prb->impl_filter);
std::swap(prb_tr.ic, prb_tr.oc);
std::swap(prb_tr.ih, prb_tr.oh);
Expand Down
6 changes: 3 additions & 3 deletions tests/benchdnn/deconv/ref_wino.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ void compute_wino_ref_fwd(const prb_t *prb, const args_t &args) {
SAFE_V(prb->kh == 3 ? OK : FAIL);
SAFE_V(prb->kw == 3 ? OK : FAIL);

bool with_bias = prb->dir & FLAG_BIA;
bool with_bias = prb->bia_dt() != dnnl_data_type_undef;
const int64_t t_pad = prb->ph;
const int64_t l_pad = prb->pw;
const int64_t wp_max = prb->iw + l_pad;
Expand Down Expand Up @@ -431,7 +431,7 @@ void compute_wino_ref_bwd_d(const prb_t *prb, const args_t &args) {
const int64_t hp_max = prb->oh + t_pad;
const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles;

bool with_bias = prb->dir & FLAG_BIA;
bool with_bias = prb->bia_dt() != dnnl_data_type_undef;

benchdnn_parallel_nd(prb->mb, prb->oc, sp.h_tiles, sp.w_tiles,
[&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) {
Expand Down Expand Up @@ -621,7 +621,7 @@ void compute_wino_ref_bwd_w(const prb_t *prb, const args_t &args) {

free_scratchpad(&sp);

if (prb->dir & FLAG_BIA) compute_ref_bwd_bias(prb, args);
if (prb->bia_dt() != dnnl_data_type_undef) compute_ref_bwd_bias(prb, args);
}

} // namespace deconv

0 comments on commit d766266

Please sign in to comment.