Skip to content

Commit

Permalink
F32 C2 AVX512F GEMM
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723485230
  • Loading branch information
alankelly authored and xnnpack-bot committed Feb 7, 2025
1 parent 8695741 commit cb3a239
Show file tree
Hide file tree
Showing 84 changed files with 7,722 additions and 1,912 deletions.
176 changes: 176 additions & 0 deletions bench/f32-gemm-minmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,182 @@
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/1, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/2, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/3, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/4, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/5, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/6, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/7, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/8, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/9, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/10, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/11, /*nr=*/16, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/1, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/2, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/3, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/4, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast)

static void f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast,
xnn_init_f32_minmax_scalar_params,
xnn_pack_f32_gemm_goi_w,
/*mr=*/5, /*nr=*/32, /*kr=*/2, /*sr=*/1,
benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast)
#endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY


Expand Down
16 changes: 16 additions & 0 deletions cmake/gen/amd64_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,47 @@ SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-6x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-7x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-8x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-9x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-11x16c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S
Expand Down
29 changes: 18 additions & 11 deletions gemm_compiler/avx512bf16_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class Avx512Bf16(isa.Avx512F):

def __init__(self):
pass # Empty constructor
self.c = 2

def isa(self):
return 'avx512bf16'
Expand All @@ -33,27 +33,28 @@ def compute_asm(self):
def function_name(self, M, N, isa):
return f'xnn_bf16_f32_gemm_minmax_ukernel_{M}x{N}c2__asm_amd64_{isa}_broadcast'

def init_accumulators(self, M, N):
asm_string = super().init_accumulators(M, N)
asm_string += """
# Are there at least 4 bytes?
cmp rdx, 4
js inner_loop_tail\n"""

return asm_string

def outer_loop_prepare(self, M, N):
k_register = self.k_register()
kc_register = self.kc_register()
offset = M * 16 + self.c_ptr_stack_offset()
kmask = self.k_mask()
asm_string = f"""
# Copy k and flip bit.
mov {k_register}, rdx
and {k_register}, 0x2
and {kc_register}, 0xFFFFFFFFFFFFFFFD
and {kc_register}, {kmask}
mov [rsp + {offset}], {k_register}\n"""
return asm_string

def init_accumulators(self, M, N):
asm_string = super().init_accumulators(M, N)
asm_string += """
# Are there at least 4 bytes?
cmp rdx, 4
js inner_loop_tail\n"""

return asm_string

def inner_loop_tail(self, M, N):
k_register = self.k_register()
nc_register = self.nc_register()
Expand All @@ -75,3 +76,9 @@ def inner_loop_tail(self, M, N):
else:
asm_string += self.inner_loop_small_M_N(M=M, N=N, tail=True)
return asm_string

def element_size(self):
return 2

def k_mask(self):
return "0xFFFFFFFFFFFFFFFD"
Loading

0 comments on commit cb3a239

Please sign in to comment.