Remove all references to gemm in Conv2d.
This path is now handled by a subgraph rewrite.

PiperOrigin-RevId: 724272709
alankelly authored and xnnpack-bot committed Feb 7, 2025
1 parent 8695741 commit da19f44
Showing 1 changed file with 3 additions and 188 deletions.
191 changes: 3 additions & 188 deletions src/operators/convolution-nhwc.c
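
Background for this change: a 1x1 convolution with unit stride and no padding over an NHWC tensor is exactly a matrix multiplication, which is why Conv2d previously carried a dedicated GEMM path. With the subgraph rewrite in place, such convolutions are converted to a GEMM-based node before the operator is created, so the operator-level path deleted below becomes dead code. The following minimal sketch (illustrative only, not XNNPACK internals) shows the equivalence the rewrite relies on:

    /* A 1x1 convolution with unit stride and no padding, viewed as a GEMM:
     * the batch * height * width pixels form the M dimension, input channels
     * form K, and output channels form N. Illustrative reference code. */
    #include <stddef.h>

    static void conv1x1_as_gemm(
        size_t pixels,          /* M = batch * output_height * output_width */
        size_t input_channels,  /* K */
        size_t output_channels, /* N */
        const float* input,     /* [pixels, input_channels], NHWC */
        const float* kernel,    /* [output_channels, input_channels] */
        const float* bias,      /* [output_channels], may be NULL */
        float* output)          /* [pixels, output_channels] */
    {
      for (size_t m = 0; m < pixels; m++) {
        for (size_t n = 0; n < output_channels; n++) {
          float acc = bias != NULL ? bias[n] : 0.0f;
          for (size_t k = 0; k < input_channels; k++) {
            acc += input[m * input_channels + k] * kernel[n * input_channels + k];
          }
          output[m * output_channels + n] = acc;
        }
      }
    }
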
@@ -286,7 +286,7 @@ static enum xnn_status create_dwconv_path(
   return status;
 }
 
-static enum xnn_status create_gemm_or_igemm(
+static enum xnn_status create_igemm(
   enum xnn_microkernel_type ukernel_type,
   uint32_t kernel_size,
   uint32_t groups,
@@ -298,7 +298,6 @@ static enum xnn_status create_gemm_or_igemm(
   uint32_t log2_input_element_size,
   uint32_t log2_filter_element_size,
   uint32_t bias_element_size,
-  xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w,
   xnn_pack_conv_kgo_w_fn pack_conv_kgo_w,
   xnn_pack_conv_goki_w_fn pack_conv_goki_w,
   const void* packing_params,
@@ -365,49 +364,6 @@ static enum xnn_status create_gemm_or_igemm(
     gemm_ukernels = &gemm_config->relu;
   }
   switch (ukernel_type) {
-    case xnn_microkernel_type_gemm:
-      if(!weights_already_cached) {
-        if (gemm_config->pack_weights_and_biases) {
-          size_t k_stride = round_up_po2(group_input_channels, kr * sr);
-          gemm_config->pack_weights_and_biases(
-              flags, gemm_config, group_input_channels, group_output_channels,
-              groups,
-              /*unused_block_size*/0,
-              k_stride,
-              /*accumulator_init=*/bias,
-              /*weights=*/kernel,
-              /*int_extra_data0_fn=*/(xnn_init_scale_params_fn)init_scale_params,
-              /*extra_data0=*/scale_params,
-              /*extra_data0_size=*/init_scale_params != NULL ? sizeof(float) : 0,
-              /*init_extra_data1_fn=*/
-              (xnn_init_scale_params_fn)init_kernel_scale_params,
-              /*extra_data1=*/(const void *) kernel_scale_params,
-              /*extra_data1_size=*/init_kernel_scale_params != NULL ? sizeof(float)
-                  : 0,
-              /*packed_weights_ptr=*/weights_ptr, packing_params);
-          // Kernel and bias have already been packed so prevent them from being
-          // packed again below.
-          weights_already_cached = true;
-        } else {
-          pack_gemm_goi_w(groups, group_output_channels, group_input_channels,
-              nr, kr, sr,
-              kernel, bias, /*scale=*/NULL, weights_ptr, gemm_config->nr * extra_weights_bytes,
-              packing_params);
-        }
-      }
-      convolution_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
-        .mr = mr,
-        .nr = nr,
-        .kr = kr,
-        .sr = sr,
-      };
-
-      assert(XNN_MAX_MR >= mr);
-      for (size_t i = 0; i < mr; i++) {
-        convolution_op->ukernel.gemm.gemm_cases[i] = gemm_ukernels->gemm[i];
-      }
-
-      break;
     case xnn_microkernel_type_igemm:
       if(!weights_already_cached) {
         if (flags & XNN_FLAG_DEPTHWISE_CONVOLUTION) {
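
The branch deleted above packed GEMM weights in GOI (groups, output channels, input channels) order into panels of nr output channels, either through the config-level pack_weights_and_biases hook or the pack_gemm_goi_w microkernel. A simplified sketch of that panel layout, assuming kr = sr = 1, float weights, and no extra per-channel data (the real packing functions also handle kr/sr interleaving, scales, and extra weight bytes):

    /* Simplified GOI packing sketch, assuming kr = sr = 1; illustrative only.
     * Each panel covers nr output channels: nr bias values first, then the
     * weights k-major with the nr channels interleaved, zero-padding the
     * tail panel. This is the order a GEMM microkernel streams through. */
    #include <stddef.h>

    static void pack_goi_sketch(
        size_t groups, size_t nc, size_t kc, size_t nr,
        const float* kernel,  /* [groups][nc][kc] */
        const float* bias,    /* [groups][nc], may be NULL */
        float* packed)        /* round_up(nc, nr) * (1 + kc) floats per group */
    {
      for (size_t g = 0; g < groups; g++) {
        for (size_t n0 = 0; n0 < nc; n0 += nr) {
          const size_t n_block = (nc - n0 < nr) ? (nc - n0) : nr;
          for (size_t n = 0; n < nr; n++) {  /* bias panel */
            *packed++ = (bias != NULL && n < n_block) ? bias[g * nc + n0 + n] : 0.0f;
          }
          for (size_t k = 0; k < kc; k++) {  /* interleaved weights */
            for (size_t n = 0; n < nr; n++) {
              *packed++ = (n < n_block) ? kernel[(g * nc + n0 + n) * kc + k] : 0.0f;
            }
          }
        }
      }
    }
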
@@ -523,7 +479,6 @@ static enum xnn_status create_convolution2d_nhwc(
   xnn_pack_vmulcaddc_w_fn pack_vmulcaddc_w,
   xnn_pack_dwconv_hwg_w_fn pack_dwconv_hwg_w,
   xnn_pack_dwconv_ghw_w_fn pack_dwconv_ghw_w,
-  xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w,
   xnn_pack_conv_kgo_w_fn pack_conv_kgo_w,
   xnn_pack_conv_goki_w_fn pack_conv_goki_w,
   const void* packing_params,
@@ -664,8 +619,6 @@ static enum xnn_status create_convolution2d_nhwc(
     ukernel_type = xnn_microkernel_type_vmulcaddc;
   } else if (group_input_channels == 1 && group_output_channels == 1 && dwconv_ukernel != NULL) {
     ukernel_type = xnn_microkernel_type_dwconv;
-  } else if (kernel_size == 1 && unit_subsampling && !any_padding && !dynamic_quantization) {
-    ukernel_type = xnn_microkernel_type_gemm;
   } else {
     ukernel_type = xnn_microkernel_type_igemm;
   }
@@ -701,15 +654,14 @@ static enum xnn_status create_convolution2d_nhwc(
       }
       break;
     }
-    case xnn_microkernel_type_gemm:
     case xnn_microkernel_type_igemm:
     {
-      status = create_gemm_or_igemm(
+      status = create_igemm(
         ukernel_type, kernel_size,
         groups, group_input_channels, group_output_channels,
         kernel, bias, flags,
         log2_input_element_size, log2_filter_element_size, bias_element_size,
-        pack_gemm_goi_w, pack_conv_kgo_w, pack_conv_goki_w, packing_params,
+        pack_conv_kgo_w, pack_conv_goki_w, packing_params,
        packed_weights_padding_byte, extra_weights_bytes,
         init_scale_params, scale_params, init_kernel_scale_params, kernel_scale_params,
         gemm_params, gemm_params_size, gemm_config,
@@ -851,7 +803,6 @@ enum xnn_status create_convolution2d_nhwc_qx8_f16_qc8w(
     (xnn_pack_vmulcaddc_w_fn) NULL,
     (xnn_pack_dwconv_hwg_w_fn) NULL,
     (xnn_pack_dwconv_ghw_w_fn) NULL,
-    (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
     (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w,
     (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w,
     /*packing_params=*/&packing_params,
@@ -1016,7 +967,6 @@ enum xnn_status create_convolution2d_nhwc_qx8_f32_qc8w(
     (xnn_pack_vmulcaddc_w_fn) NULL,
     (xnn_pack_dwconv_hwg_w_fn) NULL,
     (xnn_pack_dwconv_ghw_w_fn) NULL,
-    (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
     (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w,
     (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w,
     /*packing_params=*/&packing_params,
@@ -1224,7 +1174,6 @@ enum xnn_status xnn_create_convolution2d_nhwc_qu8(
     (xnn_pack_vmulcaddc_w_fn) NULL,
     (xnn_pack_dwconv_hwg_w_fn) xnn_pack_qu8_dwconv_hwg_w,
     (xnn_pack_dwconv_ghw_w_fn) xnn_pack_qu8_dwconv_ghw_w,
-    (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
     (xnn_pack_conv_kgo_w_fn) xnn_pack_qu8_conv_kgo_w,
     (xnn_pack_conv_goki_w_fn) xnn_pack_qu8_conv_goki_w,
     /*packing_params=*/&packing_params,
@@ -1368,7 +1317,6 @@ enum xnn_status xnn_create_convolution2d_nhwc_qs8(
     (xnn_pack_vmulcaddc_w_fn) NULL,
     (xnn_pack_dwconv_hwg_w_fn) xnn_pack_qs8_dwconv_hwg_w,
     (xnn_pack_dwconv_ghw_w_fn) xnn_pack_qs8_dwconv_ghw_w,
-    (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
     (xnn_pack_conv_kgo_w_fn) gemm_config->pack_igemm_kgo,
     (xnn_pack_conv_goki_w_fn) gemm_config->pack_igemm_goki,
     /*packing_params=*/&packing_params,
@@ -1520,7 +1468,6 @@ enum xnn_status create_convolution2d_nhwc_qx8_qc8w(
     (xnn_pack_vmulcaddc_w_fn) NULL,
     (xnn_pack_dwconv_hwg_w_fn) xnn_pack_qs8_dwconv_hwg_w,
     (xnn_pack_dwconv_ghw_w_fn) xnn_pack_qs8_dwconv_ghw_w,
-    (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
     (xnn_pack_conv_kgo_w_fn) gemm_config->pack_igemm_kgo,
     (xnn_pack_conv_goki_w_fn) gemm_config->pack_igemm_goki,
     /*packing_params=*/&packing_params,
@@ -1680,14 +1627,12 @@ enum xnn_status xnn_create_convolution2d_nhwc_f16(
   xnn_pack_vmulcaddc_w_fn pack_vmulcaddc_w = (xnn_pack_vmulcaddc_w_fn) xnn_pack_f16_vmulcaddc_w;
   xnn_pack_dwconv_hwg_w_fn pack_dwconv_hwg_w = (xnn_pack_dwconv_hwg_w_fn) xnn_pack_f16_dwconv_hwg_w;
   xnn_pack_dwconv_ghw_w_fn pack_dwconv_ghw_w = (xnn_pack_dwconv_ghw_w_fn) xnn_pack_f16_dwconv_ghw_w;
-  xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w = (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi;
   xnn_pack_conv_kgo_w_fn pack_conv_kgo_w = (xnn_pack_conv_kgo_w_fn) xnn_pack_f16_conv_kgo_w;
   xnn_pack_conv_goki_w_fn pack_conv_goki_w = (xnn_pack_conv_goki_w_fn) xnn_pack_f16_conv_goki_w;
   if (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) {
     pack_vmulcaddc_w = (xnn_pack_vmulcaddc_w_fn) xnn_pack_f32_to_f16_vmulcaddc_w;
     pack_dwconv_hwg_w = (xnn_pack_dwconv_hwg_w_fn) xnn_pack_f32_to_f16_dwconv_hwg_w;
     pack_dwconv_ghw_w = (xnn_pack_dwconv_ghw_w_fn) xnn_pack_f32_to_f16_dwconv_ghw_w;
-    pack_gemm_goi_w = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_f32_to_f16_gemm_goi_w;
     pack_conv_kgo_w = (xnn_pack_conv_kgo_w_fn) xnn_pack_f32_to_f16_conv_kgo_w;
     pack_conv_goki_w = (xnn_pack_conv_goki_w_fn) xnn_pack_f32_to_f16_conv_goki_w;
   }
@@ -1706,7 +1651,6 @@ enum xnn_status xnn_create_convolution2d_nhwc_f16(
     pack_vmulcaddc_w,
     pack_dwconv_hwg_w,
     pack_dwconv_ghw_w,
-    pack_gemm_goi_w,
     pack_conv_kgo_w,
     pack_conv_goki_w,
     /*packing_params=*/NULL,
@@ -1829,7 +1773,6 @@ enum xnn_status create_convolution2d_nhwc_f32(
     (xnn_pack_vmulcaddc_w_fn) xnn_pack_f32_vmulcaddc_w,
     (xnn_pack_dwconv_hwg_w_fn) xnn_pack_f32_dwconv_hwg_w,
     (xnn_pack_dwconv_ghw_w_fn) xnn_pack_f32_dwconv_ghw_w,
-    (xnn_packw_gemm_goi_ukernel_fn) gemm_config->pack_gemm_goi,
     (xnn_pack_conv_kgo_w_fn) xnn_pack_f32_conv_kgo_w,
     (xnn_pack_conv_goki_w_fn) xnn_pack_f32_conv_goki_w,
     /*packing_params=*/NULL,
@@ -1999,117 +1942,6 @@ static inline bool input_size_changed(xnn_operator_t convolution_op)
       convolution_op->input_width != convolution_op->last_input_width;
 }
 
-static enum xnn_status reshape_gemm(
-  xnn_operator_t convolution_op,
-  uint32_t log2_input_element_size,
-  uint32_t log2_filter_element_size,
-  uint32_t extra_weights_elements_size,
-  uint32_t log2_output_element_size,
-  size_t* workspace_size,
-  size_t* workspace_alignment,
-  size_t num_threads)
-{
-  // Convolution maps directly to GEMM and doesn't use indirection buffer.
-  const size_t batch_size = convolution_op->batch_size;
-
-  const size_t output_height = convolution_op->output_height;
-  const size_t output_width = convolution_op->output_width;
-  const size_t output_size = output_height * output_width;
-  const size_t batch_output_size = batch_size * output_size;
-
-  const size_t groups = convolution_op->groups;
-  const size_t group_input_channels = convolution_op->group_input_channels;
-  const size_t w_stride = extra_weights_elements_size +
-    (round_up_po2(group_input_channels, convolution_op->ukernel.gemm.kr * convolution_op->ukernel.gemm.sr) << log2_filter_element_size);
-  const size_t group_output_channels = convolution_op->group_output_channels;
-
-  uint32_t mr = convolution_op->ukernel.gemm.mr;
-  const uint32_t nr = convolution_op->ukernel.gemm.nr;
-  struct xnn_hmp_gemm_ukernel *gemm_cases = convolution_op->ukernel.gemm.gemm_cases;
-
-  #if XNN_ENABLE_GEMM_M_SPECIALIZATION
-  mr = xnn_get_heuristic_mr_gemm(batch_output_size, mr, nr, gemm_cases);
-  #else
-  if (batch_output_size == 1 && gemm_cases[0].function[XNN_UARCH_DEFAULT] != NULL) {
-    mr = 1;
-  }
-  #endif
-
-  struct xnn_hmp_gemm_ukernel gemm_ukernel = gemm_cases[mr - 1];
-
-  convolution_op->context.gemm.gemm.gemm = (struct gemm_context){
-    .k_scaled = group_input_channels << log2_input_element_size,
-    .a_stride = convolution_op->input_pixel_stride << log2_input_element_size,
-    .ga_stride = group_input_channels << log2_input_element_size,
-    .packed_w = packed_weights(convolution_op),
-    .w_stride = w_stride,
-    .gw_stride = w_stride * round_up(group_output_channels, nr),
-    .cm_stride = convolution_op->output_pixel_stride
-        << log2_output_element_size,
-    .cn_stride = nr << log2_output_element_size,
-    .gc_stride = group_output_channels << log2_output_element_size,
-    .log2_csize = log2_output_element_size,
-    .num_batch_dims = 1,
-    .ukernel = gemm_ukernel,
-    .mr = mr,
-  };
-  convolution_op->context.gemm.gemm.gemm.batch_dims_a[0] = groups;
-  convolution_op->context.gemm.gemm.gemm.batch_dims_b[0] = groups;
-  convolution_op->context.gemm.gemm.gemm.batch_strides_c[0] = 1;
-  memcpy(&convolution_op->context.gemm.gemm.gemm.params, &convolution_op->params, sizeof(convolution_op->context.gemm.gemm.gemm.params));
-  convolution_op->context.gemm.gemm.gemm.fused_params = &convolution_op->context.gemm.gemm.gemm.params;
-
-  // Compute the optimal tile size for this GEMM.
-  const size_t nc = xnn_gemm_best_tile_size(
-      /*num_groups=*/groups, /*m=*/batch_output_size,
-      /*n=*/group_output_channels,
-      /*m_stride=*/convolution_op->context.gemm.gemm.gemm.a_stride,
-      /*n_stride=*/convolution_op->context.gemm.gemm.gemm.w_stride,
-      /*cm_stride=*/convolution_op->context.gemm.gemm.gemm.cm_stride,
-      /*cn_stride=*/1 << log2_output_element_size, mr, nr, num_threads);
-
-  if (groups == 1) {
-    #if XNN_MAX_UARCH_TYPES > 1
-    if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
-      convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_with_uarch;
-      convolution_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
-    } else {
-      convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d;
-      convolution_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
-    }
-    #else
-    convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d;
-    convolution_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
-    #endif
-    convolution_op->compute[0].range[1] = batch_output_size;
-    convolution_op->compute[0].range[0] = group_output_channels;
-  } else {
-    #if XNN_MAX_UARCH_TYPES > 1
-    if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
-      convolution_op->compute[0].type = xnn_parallelization_type_3d_tile_2d_with_uarch;
-      convolution_op->compute[0].task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_grouped_gemm;
-    } else {
-      convolution_op->compute[0].type = xnn_parallelization_type_3d_tile_2d;
-      convolution_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
-    }
-    #else
-    convolution_op->compute[0].type = xnn_parallelization_type_3d_tile_2d;
-    convolution_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
-    #endif
-    convolution_op->compute[0].range[0] = groups;
-    convolution_op->compute[0].range[2] = batch_output_size;
-    convolution_op->compute[0].range[1] = group_output_channels;
-  }
-  convolution_op->compute[0].tile[1] = mr;
-  convolution_op->compute[0].tile[0] = nc;
-  convolution_op->state = xnn_run_state_needs_setup;
-
-  *workspace_size = 0;
-  *workspace_alignment = 1;
-
-  return xnn_status_success;
-}
-
 static enum xnn_status reshape_igemm(
   xnn_operator_t convolution_op,
   uint32_t log2_input_element_size,
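
The reshape_gemm function removed above mapped the convolution directly onto GEMM problem dimensions, skipped the indirection buffer, and chose a 2-D tiled parallelization for groups == 1 or a 3-D one for grouped convolutions. A condensed restatement of that mapping (illustrative; the deleted code additionally picked mr heuristically and derived the nc tile via xnn_gemm_best_tile_size):

    /* Condensed sketch of the GEMM shape the deleted reshape_gemm set up. */
    #include <stddef.h>

    struct gemm_shape {
      size_t groups;  /* independent GEMMs (grouped convolution) */
      size_t m;       /* rows:    batch * output_height * output_width */
      size_t n;       /* columns: group_output_channels */
      size_t k;       /* depth:   group_input_channels */
    };

    static struct gemm_shape conv1x1_gemm_shape(
        size_t batch_size, size_t output_height, size_t output_width,
        size_t groups, size_t group_input_channels,
        size_t group_output_channels)
    {
      /* groups == 1 ran a 2-D tiled GEMM task (mr x nc tiles over m x n);
       * groups > 1 added an untiled leading group dimension. */
      const struct gemm_shape shape = {
          .groups = groups,
          .m = batch_size * output_height * output_width,
          .n = group_output_channels,
          .k = group_input_channels,
      };
      return shape;
    }
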
@@ -2681,11 +2513,6 @@ static enum xnn_status reshape_convolution2d_nhwc(
 
   const size_t num_threads = pthreadpool_get_threads_count(threadpool);
   switch (convolution_op->ukernel.type) {
-    case xnn_microkernel_type_gemm:
-      return reshape_gemm(
-        convolution_op,
-        log2_input_element_size, log2_filter_element_size, extra_weights_elements_size, log2_output_element_size,
-        workspace_size, workspace_alignment, num_threads);
     case xnn_microkernel_type_igemm:
       return reshape_igemm(
         convolution_op,
@@ -2977,16 +2804,6 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_f32(
     threadpool);
 }
 
-static enum xnn_status setup_gemm(xnn_operator_t convolution_op)
-{
-  convolution_op->context.gemm.gemm.gemm.a = convolution_op->input;
-  convolution_op->context.gemm.gemm.gemm.c = convolution_op->output;
-  convolution_op->context.gemm.gemm.gemm.quantization_params = convolution_op->quantization_params;
-  convolution_op->state = xnn_run_state_ready;
-
-  return xnn_status_success;
-}
-
 static enum xnn_status setup_igemm(
   xnn_operator_t convolution_op,
   void* workspace,
@@ -3086,8 +2903,6 @@ static enum xnn_status setup_convolution2d_nhwc(
   convolution_op->quantization_params = quantization_params;
 
   switch (convolution_op->ukernel.type) {
-    case xnn_microkernel_type_gemm:
-      return setup_gemm(convolution_op);
     case xnn_microkernel_type_igemm:
       return setup_igemm(convolution_op, workspace, log2_input_element_size);
     case xnn_microkernel_type_dwconv: