Remove all references to gemm in deconv
This allows us to break the dependency between gemm & igemm, meaning that we can select optimal kernels everywhere without unnecessary constraints.
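In essence (a minimal sketch pieced together from the diff below, not a verbatim excerpt of the changed files), the subconv2d path now unconditionally uses the IGEMM subtype and the IGEMM mr heuristic:

    /* reshape_deconvolution2d_nhwc: the conditional GEMM fast path is gone. */
    deconvolution_op->ukernel.subtype = xnn_microkernel_type_igemm;

    /* reshape_subconv2d_path: mr is always chosen by the IGEMM heuristic. */
    uint32_t mr = deconvolution_op->ukernel.igemm.mr;
    const uint32_t nr = deconvolution_op->ukernel.igemm.nr;
    #if XNN_ENABLE_GEMM_M_SPECIALIZATION
      mr = xnn_get_heuristic_mr_igemm(
          batch_size, mr, nr, deconvolution_op->ukernel.igemm.igemm_cases);
    #endif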

PiperOrigin-RevId: 724280843
alankelly authored and xnnpack-bot committed Feb 7, 2025
1 parent 114acd2 commit 3fa35de
Showing 3 changed files with 48 additions and 149 deletions.
37 changes: 0 additions & 37 deletions src/operator-run.c
@@ -885,43 +885,6 @@ void xnn_compute_grouped_subgemm2d(
&context->params);
}

void xnn_compute_subgemm2d(
const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t subkernel_index,
size_t slice_y,
size_t slice_x_start,
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size)
{
const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];

if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
return;
}

const size_t slice_width = subconvolution_params->slice_width;
if XNN_UNLIKELY(slice_x_start >= slice_width) {
return;
}
const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);

const size_t ax_stride = context->ax_stride;
const size_t cx_stride = context->cx_stride;
context->ukernel.function[XNN_UARCH_DEFAULT](
slice_x_size,
nc_block_size,
context->kc,
(const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
ax_stride,
(const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
(void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
cx_stride,
context->cn_stride,
&context->params);
}

void xnn_compute_grouped_subconv2d(
const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
150 changes: 48 additions & 102 deletions src/operators/deconvolution-nhwc.c
@@ -1481,15 +1481,9 @@ static enum xnn_status reshape_subconv2d_path(
const size_t output_size = output_height * output_width;
uint32_t mr = deconvolution_op->ukernel.igemm.mr;
const uint32_t nr = deconvolution_op->ukernel.igemm.nr;
bool use_gemm = deconvolution_op->ukernel.subtype == xnn_microkernel_type_gemm;
#if XNN_ENABLE_GEMM_M_SPECIALIZATION
if (use_gemm) {
mr = xnn_get_heuristic_mr_gemm(
batch_size, mr, nr, deconvolution_op->ukernel.igemm.gemm_cases);
} else {
mr = xnn_get_heuristic_mr_igemm(
batch_size, mr, nr, deconvolution_op->ukernel.igemm.igemm_cases);
}
mr = xnn_get_heuristic_mr_igemm(
batch_size, mr, nr, deconvolution_op->ukernel.igemm.igemm_cases);
#endif

const size_t input_pixel_stride = deconvolution_op->input_pixel_stride << log2_input_element_size;
@@ -1540,27 +1534,25 @@
}

if (any_size_change) {
if (!use_gemm) {
const size_t indirection_buffer_size = sizeof(void*) *
kernel_size * output_height * stride_width * round_up(divide_round_up(output_width, stride_width), mr);

const void** indirection_buffer =
(const void**) xnn_reallocate_memory(deconvolution_op->indirection_buffer, indirection_buffer_size);
if (indirection_buffer == NULL) {
xnn_log_error(
"failed to allocate %zu bytes for %s operator indirection buffer",
indirection_buffer_size, xnn_operator_type_to_string(deconvolution_op->type));
return xnn_status_out_of_memory;
}
deconvolution_op->indirection_buffer = indirection_buffer;
xnn_log_debug("allocated %zu bytes for indirection buffer in %s operator",
indirection_buffer_size, xnn_operator_type_to_string(deconvolution_op->type));

// Set a dummy input first, the actual input offset is calculated in setup when we have the input pointer.
// This offset must be aligned properly because inputs and input offsets need to be aligned.
deconvolution_op->input = (void*) ((uintptr_t) deconvolution_op->zero_buffer + XNN_ALLOCATION_ALIGNMENT);
xnn_indirection_init_subconv2d(deconvolution_op, mr, log2_input_element_size);
const size_t indirection_buffer_size = sizeof(void*) *
kernel_size * output_height * stride_width * round_up(divide_round_up(output_width, stride_width), mr);

const void** indirection_buffer =
(const void**) xnn_reallocate_memory(deconvolution_op->indirection_buffer, indirection_buffer_size);
if (indirection_buffer == NULL) {
xnn_log_error(
"failed to allocate %zu bytes for %s operator indirection buffer",
indirection_buffer_size, xnn_operator_type_to_string(deconvolution_op->type));
return xnn_status_out_of_memory;
}
deconvolution_op->indirection_buffer = indirection_buffer;
xnn_log_debug("allocated %zu bytes for indirection buffer in %s operator",
indirection_buffer_size, xnn_operator_type_to_string(deconvolution_op->type));

// Set a dummy input first, the actual input offset is calculated in setup when we have the input pointer.
// This offset must be aligned properly because inputs and input offsets need to be aligned.
deconvolution_op->input = (void*) ((uintptr_t) deconvolution_op->zero_buffer + XNN_ALLOCATION_ALIGNMENT);
xnn_indirection_init_subconv2d(deconvolution_op, mr, log2_input_element_size);

deconvolution_op->last_input = deconvolution_op->input;
deconvolution_op->last_input_height = input_height;
@@ -1576,44 +1568,23 @@
const uint32_t sr = deconvolution_op->ukernel.igemm.sr;
const size_t w_stride = stride_height * stride_width * extra_weights_element_size +
(round_up_po2(group_input_channels, kr * sr) * kernel_size << log2_filter_element_size);
if (use_gemm) {
struct xnn_hmp_gemm_ukernel* gemm_cases = deconvolution_op->ukernel.igemm.gemm_cases;
deconvolution_op->context.subgemm = (struct subgemm_context) {
.subconvolution_params = deconvolution_op->subconvolution_buffer,
.kc = group_input_channels << log2_input_element_size,
.ax_stride = input_pixel_stride,
.ay_stride = input_width * input_pixel_stride,
.cx_stride = stride_width * output_pixel_stride,
.cy_stride = stride_height * output_width * output_pixel_stride,
.cn_stride = nr << log2_output_element_size,
.ga_stride = group_input_channels << log2_input_element_size,
.gw_stride = w_stride * round_up(group_output_channels, nr),
.gc_stride = group_output_channels << log2_output_element_size,
.ba_stride = input_height * input_width * input_pixel_stride,
.bc_stride = output_size * output_pixel_stride,
.log2_csize = log2_output_element_size,
.ukernel = gemm_cases[mr - 1],
};
memcpy(&deconvolution_op->context.subgemm.params, params, params_size);
} else {
struct xnn_hmp_igemm_ukernel* igemm_cases = deconvolution_op->ukernel.igemm.igemm_cases;
deconvolution_op->context.subconv = (struct subconv_context) {
.subconvolution_params = deconvolution_op->subconvolution_buffer,
.kc = group_input_channels << log2_input_element_size,
.zero = deconvolution_op->zero_buffer,
.cx_stride = stride_width * output_pixel_stride,
.cy_stride = stride_height * output_width * output_pixel_stride,
.cn_stride = nr << log2_output_element_size,
.ga_stride = group_input_channels << log2_input_element_size,
.gw_stride = w_stride * round_up(group_output_channels, nr),
.gc_stride = group_output_channels << log2_output_element_size,
.ba_stride = input_height * input_width * input_pixel_stride,
.bc_stride = output_size * output_pixel_stride,
.log2_csize = log2_output_element_size,
.ukernel = igemm_cases[mr - 1],
};
memcpy(&deconvolution_op->context.subconv.params, params, params_size);
}
struct xnn_hmp_igemm_ukernel* igemm_cases = deconvolution_op->ukernel.igemm.igemm_cases;
deconvolution_op->context.subconv = (struct subconv_context) {
.subconvolution_params = deconvolution_op->subconvolution_buffer,
.kc = group_input_channels << log2_input_element_size,
.zero = deconvolution_op->zero_buffer,
.cx_stride = stride_width * output_pixel_stride,
.cy_stride = stride_height * output_width * output_pixel_stride,
.cn_stride = nr << log2_output_element_size,
.ga_stride = group_input_channels << log2_input_element_size,
.gw_stride = w_stride * round_up(group_output_channels, nr),
.gc_stride = group_output_channels << log2_output_element_size,
.ba_stride = input_height * input_width * input_pixel_stride,
.bc_stride = output_size * output_pixel_stride,
.log2_csize = log2_output_element_size,
.ukernel = igemm_cases[mr - 1],
};
memcpy(&deconvolution_op->context.subconv.params, params, params_size);

size_t nc = group_output_channels;
if (num_threads > 1) {
@@ -1637,14 +1608,10 @@
}
if (groups == 1) {
deconvolution_op->compute[igemm_compute_index].type = xnn_parallelization_type_5d_tile_2d;
if (use_gemm) {
deconvolution_op->compute[igemm_compute_index].task_5d_tile_2d = (pthreadpool_task_5d_tile_2d_t) xnn_compute_subgemm2d;
if (dynamic_quantization) {
deconvolution_op->compute[igemm_compute_index].task_5d_tile_2d = (pthreadpool_task_5d_tile_2d_t) xnn_compute_dqsubconv2d;
} else {
if (dynamic_quantization) {
deconvolution_op->compute[igemm_compute_index].task_5d_tile_2d = (pthreadpool_task_5d_tile_2d_t) xnn_compute_dqsubconv2d;
} else {
deconvolution_op->compute[igemm_compute_index].task_5d_tile_2d = (pthreadpool_task_5d_tile_2d_t) xnn_compute_subconv2d;
}
deconvolution_op->compute[igemm_compute_index].task_5d_tile_2d = (pthreadpool_task_5d_tile_2d_t) xnn_compute_subconv2d;
}
deconvolution_op->compute[igemm_compute_index].range[0] = batch_size;
deconvolution_op->compute[igemm_compute_index].range[1] = stride_height * stride_width;
@@ -1655,14 +1622,10 @@
deconvolution_op->compute[igemm_compute_index].tile[1] = nc;
} else {
deconvolution_op->compute[igemm_compute_index].type = xnn_parallelization_type_6d_tile_2d;
if (use_gemm) {
deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_subgemm2d;
if (dynamic_quantization) {
deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_dqsubconv2d;
} else {
if (dynamic_quantization) {
deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_dqsubconv2d;
} else {
deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_subconv2d;
}
deconvolution_op->compute[igemm_compute_index].task_6d_tile_2d = (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_subconv2d;
}
deconvolution_op->compute[igemm_compute_index].range[0] = batch_size;
deconvolution_op->compute[igemm_compute_index].range[1] = groups;
@@ -1760,19 +1723,7 @@ static enum xnn_status reshape_deconvolution2d_nhwc(
params, params_size, num_threads);
case xnn_microkernel_type_subconv2d:
{
const size_t mr = deconvolution_op->ukernel.igemm.mr;
const bool no_padding = (deconvolution_op->padding_top | deconvolution_op->padding_right | deconvolution_op->padding_bottom | deconvolution_op->padding_left) == 0;
const bool no_adjustment = (adjustment_height | adjustment_width) == 0;
const bool use_gemm = no_padding && no_adjustment &&
deconvolution_op->kernel_height == deconvolution_op->stride_height &&
deconvolution_op->kernel_width == deconvolution_op->stride_width &&
deconvolution_op->ukernel.igemm.gemm_cases[mr - 1].function[XNN_UARCH_DEFAULT] != NULL &&
!dynamic_quantization;
if (use_gemm) {
deconvolution_op->ukernel.subtype = xnn_microkernel_type_gemm;
} else {
deconvolution_op->ukernel.subtype = xnn_microkernel_type_igemm;
}
deconvolution_op->ukernel.subtype = xnn_microkernel_type_igemm;
return reshape_subconv2d_path(
deconvolution_op,
batch_size,
@@ -2047,7 +1998,6 @@ static enum xnn_status setup_subconv2d_path(
{
assert(deconvolution_op->ukernel.type == xnn_microkernel_type_subconv2d);

bool use_gemm = deconvolution_op->ukernel.subtype == xnn_microkernel_type_gemm;
const size_t stride_height = deconvolution_op->stride_height;
const size_t stride_width = deconvolution_op->stride_width;

@@ -2063,14 +2013,10 @@
deconvolution_op->last_output = output;
}

if (use_gemm) {
deconvolution_op->context.subgemm.a = input;
} else {
deconvolution_op->context.subconv.a_offset = (size_t) ((uintptr_t) input - (uintptr_t) deconvolution_op->last_input);
deconvolution_op->context.subconv.zero_size = deconvolution_op->zero_size;
deconvolution_op->context.subconv.zero_buffers = deconvolution_op->zero_buffers;
deconvolution_op->context.subconv.quantization_params = deconvolution_op->quantization_params;
}
deconvolution_op->context.subconv.a_offset = (size_t) ((uintptr_t) input - (uintptr_t) deconvolution_op->last_input);
deconvolution_op->context.subconv.zero_size = deconvolution_op->zero_size;
deconvolution_op->context.subconv.zero_buffers = deconvolution_op->zero_buffers;
deconvolution_op->context.subconv.quantization_params = deconvolution_op->quantization_params;

deconvolution_op->state = xnn_run_state_ready;
return xnn_status_success;
10 changes: 0 additions & 10 deletions src/xnnpack/compute.h
@@ -658,16 +658,6 @@ struct subgemm_context {
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size);

XNN_PRIVATE void xnn_compute_subgemm2d(
const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t batch_index,
size_t subkernel_index,
size_t slice_y,
size_t slice_x_start,
size_t nc_block_start,
size_t slice_x_max,
size_t nc_block_size);
#endif

struct subconv_context {