From 97c3f6b79d1b0961534da2b0eef09c1938c8cde8 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Thu, 13 Feb 2025 07:03:31 -0800 Subject: [PATCH] WaveCache.module_op already returns the MLIR module str --------- Signed-off-by: tyb0807 --- tests/kernel/wave/attention/evoformer_test.py | 2 +- tests/kernel/wave/wave_gemm_test.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernel/wave/attention/evoformer_test.py b/tests/kernel/wave/attention/evoformer_test.py index 7b740fe7d..5dc2689e4 100644 --- a/tests/kernel/wave/attention/evoformer_test.py +++ b/tests/kernel/wave/attention/evoformer_test.py @@ -140,7 +140,7 @@ def testEvoformerAttentionForward( if dump_generated_mlir: filename = f"wave_evoformer_{'x'.join(map(str, shape))}.mlir" with open(filename, "w") as f: - f.write(mb.module_op.get_asm()) + f.write(mb.module_op) eps = 1e-2 if output.dtype == torch.float16 else 5e-2 assert ( diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index f58e06f84..1ec9e5d1f 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -205,7 +205,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: if test_dump_generated_mlir: filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir" with open(filename, "w") as f: - f.write(mb.module_op.get_asm()) + f.write(mb.module_op) if run_bench: if dump_perf is not None: @@ -352,7 +352,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: if test_dump_generated_mlir: filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir" with open(filename, "w") as f: - f.write(mb.module_op.get_asm()) + f.write(mb.module_op) if run_bench: if dump_perf is not None: @@ -501,7 +501,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: if test_dump_generated_mlir: filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir" with open(filename, "w") as f: - f.write(mb.module_op.get_asm()) + f.write(mb.module_op) if run_bench: if dump_perf is not None: @@ -618,7 +618,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: if test_dump_generated_mlir: filename = f"wave_gemm_{'x'.join(map(str, shape))}_f8.mlir" with open(filename, "w") as f: - f.write(mb.module_op.get_asm()) + f.write(mb.module_op) if run_bench: if dump_perf is not None: @@ -733,7 +733,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: if test_dump_generated_mlir: filename = f"wave_gemm_{'x'.join(map(str, shape))}_f8.mlir" with open(filename, "w") as f: - f.write(mb.module_op.get_asm()) + f.write(mb.module_op) if run_bench: if dump_perf is not None: @@ -841,7 +841,7 @@ def repeat( if test_dump_generated_mlir: filename = f"wave_batched_gemm_{'x'.join(map(str, shape))}.mlir" with open(filename, "w") as f: - f.write(mb.module_op.get_asm()) + f.write(mb.module_op) if run_bench: if dump_perf is not None: