Skip to content

Commit

Permalink
Turn c2 asm f32 kernels
Browse files Browse the repository at this point in the history
This is the first time that we exploit the broken dependency between gemm & igemm

PiperOrigin-RevId: 724255293
  • Loading branch information
alankelly authored and xnnpack-bot committed Feb 7, 2025
1 parent 114acd2 commit 57fac03
Show file tree
Hide file tree
Showing 89 changed files with 8,527 additions and 2,608 deletions.
176 changes: 176 additions & 0 deletions bench/f32-gemm-minmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,182 @@
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=1, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/1, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=2, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/2, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=3, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/3, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=4, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/4, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=5, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/5, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=6, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/6, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=7, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/7, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=8, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/8, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=9, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/9, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=10, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/10, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=11, nr=16, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/11, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=1, nr=32, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/1, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=2, nr=32, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/2, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=3, nr=32, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/3, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=4, nr=32, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/4, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast)

// Benchmark wrapper for
// xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast.
// Passes the kernel to GEMMBenchmark together with
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w, the arguments
// mr=5, nr=32, kr=2, sr=1, and benchmark::utils::CheckAVX512F.
// `net` is not referenced in this body.
// NOTE(review): CheckAVX512F is presumably an ISA-support gate and
// BENCHMARK_GEMM presumably registers this function; both are defined outside
// this view -- confirm.
static void f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/5, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast)
#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY


Expand Down
18 changes: 17 additions & 1 deletion cmake/gen/amd64_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

# AMD64 assembly microkernel sources that make up the PROD list.
# Each source appears once and the list is closed only after the final entry;
# a second, prematurely closed copy of the bf16 7x32c2 entry previously left
# the two f32 c2 sources outside the SET() call.
SET(PROD_AMD64_ASM_MICROKERNEL_SRCS
  src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
  src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
  src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S
  src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S)

SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
Expand Down Expand Up @@ -40,31 +42,45 @@ SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-6x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-7x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-8x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-9x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-11x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S
Expand Down
29 changes: 18 additions & 11 deletions gemm_compiler/avx512bf16_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class Avx512Bf16(isa.Avx512F):

def __init__(self):
  """Initialise the template by setting `self.c` to 2.

  NOTE(review): `function_name` in this class hard-codes a `c2` suffix;
  presumably `c` is that same factor -- confirm. The base-class constructor
  is not invoked here; verify against isa.Avx512F that this is intended.
  """
  self.c = 2

def isa(self):
  """Return the ISA identifier string for this template, 'avx512bf16'."""
  return 'avx512bf16'
Expand All @@ -33,27 +33,28 @@ def compute_asm(self):
def function_name(self, M, N, isa):
  """Build the microkernel symbol name for the given tile and ISA.

  Args:
    M: value placed before the `x` in the tile tag.
    N: value placed after the `x` in the tile tag.
    isa: ISA string inserted after `__asm_amd64_`.

  Returns:
    `xnn_bf16_f32_gemm_minmax_ukernel_{M}x{N}c2__asm_amd64_{isa}_broadcast`
    with the three placeholders substituted.
  """
  tile_tag = f'{M}x{N}c2'
  return f'xnn_bf16_f32_gemm_minmax_ukernel_{tile_tag}__asm_amd64_{isa}_broadcast'

def init_accumulators(self, M, N):
  """Return the base accumulator setup followed by a short-input check.

  Appends a `cmp rdx, 4` / `js inner_loop_tail` pair to whatever
  `super().init_accumulators(M, N)` returns, so the emitted code jumps to the
  `inner_loop_tail` label when `rdx - 4` is negative.
  """
  tail_check = """
# Are there at least 4 bytes?
cmp rdx, 4
js inner_loop_tail\n"""
  return super().init_accumulators(M, N) + tail_check

def outer_loop_prepare(self, M, N):
  """Return the assembly emitted ahead of the outer loop.

  The emitted code copies rdx into the k register and keeps only bit 1
  (`and ..., 0x2`), masks the kc register with `self.k_mask()`, and stores
  the k register on the stack at `rsp + M * 16 + self.c_ptr_stack_offset()`.

  The kc mask comes solely from `k_mask()`, so a subclass overriding
  `k_mask()` is not pre-empted by a second, hard-coded mask instruction.

  Args:
    M: value used to compute the stack offset of the saved k value.
    N: not referenced in this body.
  """
  k_register = self.k_register()
  kc_register = self.kc_register()
  offset = M * 16 + self.c_ptr_stack_offset()
  kmask = self.k_mask()
  return f"""
# Copy k and flip bit.
mov {k_register}, rdx
and {k_register}, 0x2
and {kc_register}, {kmask}
mov [rsp + {offset}], {k_register}\n"""

def init_accumulators(self, M, N):
  """Return the base accumulator setup followed by a short-input check.

  Appends a `cmp rdx, 4` / `js inner_loop_tail` pair to whatever
  `super().init_accumulators(M, N)` returns, so the emitted code jumps to the
  `inner_loop_tail` label when `rdx - 4` is negative.
  """
  tail_check = """
# Are there at least 4 bytes?
cmp rdx, 4
js inner_loop_tail\n"""
  return super().init_accumulators(M, N) + tail_check

def inner_loop_tail(self, M, N):
k_register = self.k_register()
nc_register = self.nc_register()
Expand All @@ -75,3 +76,9 @@ def inner_loop_tail(self, M, N):
else:
asm_string += self.inner_loop_small_M_N(M=M, N=N, tail=True)
return asm_string

def element_size(self):
  """Return 2.

  NOTE(review): presumably the size in bytes of one input element for this
  template -- confirm against the callers in the base class.
  """
  return 2

def k_mask(self):
  """Return the 64-bit mask literal "0xFFFFFFFFFFFFFFFD" as a string.

  The value has every bit set except bit 1 (0x2).
  """
  return "0xFFFFFFFFFFFFFFFD"
Loading

0 comments on commit 57fac03

Please sign in to comment.