diff --git a/.github/workflows/compiler-and-runtime-build.yml b/.github/workflows/compiler-and-runtime-build.yml
index c3542d752..d48d19eb4 100644
--- a/.github/workflows/compiler-and-runtime-build.yml
+++ b/.github/workflows/compiler-and-runtime-build.yml
@@ -15,6 +15,14 @@ on:
       - "scripts/frontends/**"
   workflow_dispatch:
 
+# Ensure that only a single job or workflow using the same
+# concurrency group will run at a time. This will cancel
+# any in-progress jobs in the same GitHub workflow and GitHub
+# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   clear_workspace:
     name: Clear workspace
diff --git a/.github/workflows/e2e_test.yaml b/.github/workflows/e2e_test.yaml
new file mode 100644
index 000000000..6433bbcad
--- /dev/null
+++ b/.github/workflows/e2e_test.yaml
@@ -0,0 +1,31 @@
+name: e2e Numerical test CI
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+  workflow_dispatch:
+
+# Ensure that only a single job or workflow using the same
+# concurrency group will run at a time. This will cancel
+# any in-progress jobs in the same GitHub workflow and GitHub
+# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  numerical_e2e_test:
+    name: e2e CI
+    runs-on: self-hosted
+    steps:
+      - name: clear workspace
+        run: rm -rf $GITHUB_WORKSPACE && mkdir $GITHUB_WORKSPACE
+      - name: Checkout byteir repo
+        uses: actions/checkout@v3
+      - name: Build and test e2e
+        run: ./scripts/e2e/build_and_test_e2e.sh ${{ secrets.LLVM_INSTALL_DIR }} ${{ secrets.TORCH_FRONTEND_LLVM_INSTALL_DIR }}
+        shell: bash
diff --git a/.github/workflows/torch-frontend-ci.yml b/.github/workflows/torch-frontend-ci.yml
index 9b20f7efc..c8a44bb3f 100644
--- a/.github/workflows/torch-frontend-ci.yml
+++ b/.github/workflows/torch-frontend-ci.yml
@@ -15,6 +15,14 @@ on:
      - "scripts/frontends/torch-frontend/**"
   workflow_dispatch:
 
+# Ensure that only a single job or workflow using the same
+# concurrency group will run at a time. This will cancel
+# any in-progress jobs in the same GitHub workflow and GitHub
+# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
 jobs:
   torch_frontend_build_and_test:
     name: torch-frontend CI
diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py
index 1708ac05a..70c1bc1f8 100644
--- a/compiler/python/byteir/compile.py
+++ b/compiler/python/byteir/compile.py
@@ -24,7 +24,7 @@ def _detect_cuda_with_nvidia_smi():
     sm_names = {
         "sm_70": ["V100"],
         "sm_75": ["T4", "Quadro T2000"],
-        "sm_80": ["PG509", "A100", "A10", "RTX 30", "A30", "RTX 40"],
+        "sm_80": ["PG509", "A100", "A10", "RTX 30", "A30", "RTX 40", "A16"],
         "sm_90": ["H100"],
     }
     for sm, names in sm_names.items():
@@ -222,6 +222,7 @@ def compile_cuda_with_ait(
     if verbose:
         _print_verbose(device_module, "// IR Dump After NVVM Codegen:")
     # write to output device ptx
+    assert _detect_cuda_with_nvidia_smi() is not None
     byteir.translate_to_ptx(device_module.operation, output_file_dir + "/" + output_file_name, _detect_cuda_with_nvidia_smi())
 
     with context:
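Note on the compile.py change above: _detect_cuda_with_nvidia_smi() returns None when no entry in sm_names matches the local GPU, so the new assert fails fast with a readable error instead of passing None into translate_to_ptx. Below is a minimal sketch of how such nvidia-smi-based detection typically works; detect_cuda_arch is a hypothetical stand-in, not the actual ByteIR implementation:

    import subprocess

    def detect_cuda_arch():
        # hypothetical stand-in for byteir's _detect_cuda_with_nvidia_smi
        sm_names = {
            "sm_70": ["V100"],
            "sm_75": ["T4", "Quadro T2000"],
            "sm_80": ["PG509", "A100", "A10", "RTX 30", "A30", "RTX 40", "A16"],
            "sm_90": ["H100"],
        }
        try:
            # prints one GPU name per line, e.g. "NVIDIA A16"
            out = subprocess.check_output(
                ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
                text=True)
        except (OSError, subprocess.CalledProcessError):
            return None  # no driver or no visible GPU
        for sm, names in sm_names.items():
            if any(name in out for name in names):
                return sm
        return None  # unrecognized GPU; callers must handle this

This is also why "A16" joins the sm_80 lists in both compile.py and main.py: nvidia-smi reports the card by that substring, and the A16 is an Ampere part that the sm_80 target covers.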
diff --git a/scripts/e2e/build_and_test_e2e.sh b/scripts/e2e/build_and_test_e2e.sh
new file mode 100755
index 000000000..29eb37653
--- /dev/null
+++ b/scripts/e2e/build_and_test_e2e.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUR_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+# path to byteir root
+ROOT_PROJ_DIR="$CUR_DIR/../.."
+
+LLVM_INSTALL_DIR="$1"
+
+TORCH_FRONTEND_LLVM_INSTALL_DIR="$2"
+
+pushd $ROOT_PROJ_DIR
+# build compiler
+bash scripts/compiler/build_and_lit_test.sh $LLVM_INSTALL_DIR
+# build runtime
+bash scripts/runtime/build_and_test.sh --cuda --python --no-test $LLVM_INSTALL_DIR
+# build torch_frontend
+bash scripts/frontends/torch-frontend/build_and_test.sh $TORCH_FRONTEND_LLVM_INSTALL_DIR
+
+pip3 install $ROOT_PROJ_DIR/external/AITemplate/python/dist/*.whl --force-reinstall
+pip3 install $ROOT_PROJ_DIR/compiler/build/python/dist/*.whl --force-reinstall
+pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl --force-reinstall
+pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl --force-reinstall
+pip3 install -r $ROOT_PROJ_DIR/frontends/torch-frontend/torch-requirements.txt
+
+python3 tests/numerical_test/main.py
+rm -rf ./local_test
+popd
diff --git a/scripts/runtime/build_and_test.sh b/scripts/runtime/build_and_test.sh
index 2fc490ce6..d0efa3ee5 100755
--- a/scripts/runtime/build_and_test.sh
+++ b/scripts/runtime/build_and_test.sh
@@ -71,6 +71,13 @@ cmake -GNinja \
 
 cmake --build "$BUILD_DIR" --target all --target install
 
+if [[ $BRT_ENABLE_PYTHON_BINDINGS == "ON" ]]; then
+  pushd $PROJ_DIR/python
+  # note: python packaging depends on `--target install`
+  python3 setup.py bdist_wheel
+  popd
+fi
+
 if [[ $BRT_USE_CUDA == "ON" ]] && [[ $BRT_ENABLE_ASAN == "ON" ]]; then
   export ASAN_OPTIONS=protect_shadow_gap=0
 fi
diff --git a/tests/numerical_test/execute.py b/tests/numerical_test/execute.py
index 7da9e24ba..46e67351a 100644
--- a/tests/numerical_test/execute.py
+++ b/tests/numerical_test/execute.py
@@ -86,6 +86,7 @@ def get_entry_func_name(interp):
 
 
 def compile_and_run_mlir(mhlo_file, target):
+    np.random.seed(0)
     try:
         interp = Interpreter.load_from_file(mhlo_file)
         np_inputs = generate_np_inputs(interp)
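The np.random.seed(0) added to compile_and_run_mlir pins the RNG state before generate_np_inputs draws test tensors, so the compiled run and the golden reference see identical inputs and CI results are reproducible from run to run. A minimal demonstration of the reseeding guarantee, using only numpy:

    import numpy as np

    np.random.seed(0)
    a = np.random.rand(2, 3).astype(np.float32)
    np.random.seed(0)
    b = np.random.rand(2, 3).astype(np.float32)
    assert np.array_equal(a, b)  # same seed, same stream, identical tensors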
"bmm_rrr_permute_f32.mlir", + "MatmulF32Module_basic", + "BatchMatmulAddF32Module_basic", + "BatchMatmulF32Module", ] @@ -59,7 +59,7 @@ def _detect_cuda_with_nvidia_smi(): sm_names = { 70: ["V100"], 75: ["T4", "Quadro T2000"], - 80: ["PG509", "A100", "A10", "RTX 30", "A30", "RTX 40"], + 80: ["PG509", "A100", "A10", "RTX 30", "A30", "RTX 40", "A16"], 90: ["H100"], } for sm, names in sm_names.items(): @@ -120,8 +120,8 @@ def main(): elif args.config == 'torch': results = run_torch_test(arch) failed = report_results(results) - # TODO(zzk): use test infra for dynamo tests - run_torch_dynamo_tests(arch) + # TODO(zzk): disable flash attn test for now + # run_torch_dynamo_tests(arch) sys.exit(1 if failed else 0) diff --git a/tests/numerical_test/mlir_tests/ops/bmm_rrr_add_f32.mlir b/tests/numerical_test/mlir_tests/ops/bmm_rrr_add_f32.mlir deleted file mode 100644 index c25558d16..000000000 --- a/tests/numerical_test/mlir_tests/ops/bmm_rrr_add_f32.mlir +++ /dev/null @@ -1,6 +0,0 @@ -func.func @bmm_rrr_add(%arg0 : tensor<32x256x256xf32>, %arg1 : tensor<32x256x128xf32>, %arg2 : tensor<1x32x256x128xf32>) -> tensor<1x32x256x128xf32> { - %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot} : (tensor<32x256x256xf32>, tensor<32x256x128xf32>) -> tensor<32x256x128xf32> - %1 = mhlo.reshape %0 : (tensor<32x256x128xf32>) -> tensor<1x32x256x128xf32> - %2 = mhlo.add %arg2, %1 : tensor<1x32x256x128xf32> - return %2 : tensor<1x32x256x128xf32> -} diff --git a/tests/numerical_test/mlir_tests/ops/bmm_rrr_f32.mlir b/tests/numerical_test/mlir_tests/ops/bmm_rrr_f32.mlir deleted file mode 100644 index 553c60d5b..000000000 --- a/tests/numerical_test/mlir_tests/ops/bmm_rrr_f32.mlir +++ /dev/null @@ -1,4 +0,0 @@ -func.func @bmm_rrr(%arg0 : tensor<32x256x256xf32>, %arg1 : tensor<32x256x128xf32>) -> tensor<32x256x128xf32> { - %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot} : (tensor<32x256x256xf32>, tensor<32x256x128xf32>) -> tensor<32x256x128xf32> - return %0 : tensor<32x256x128xf32> -} diff --git a/tests/numerical_test/mlir_tests/ops/bmm_rrr_permute_f32.mlir b/tests/numerical_test/mlir_tests/ops/bmm_rrr_permute_f32.mlir index baaa883ad..87cc5929d 100644 --- a/tests/numerical_test/mlir_tests/ops/bmm_rrr_permute_f32.mlir +++ b/tests/numerical_test/mlir_tests/ops/bmm_rrr_permute_f32.mlir @@ -1,6 +1,6 @@ -func.func @bmm_rrr(%arg0 : tensor<12x256x256xf32>, %arg1 : tensor<12x256x64xf32>) -> tensor<1x256x12x64xf32> { - %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot} : (tensor<12x256x256xf32>, tensor<12x256x64xf32>) -> tensor<12x256x64xf32> - %1 = mhlo.reshape %0 : (tensor<12x256x64xf32>) -> tensor<1x12x256x64xf32> - %2 = "mhlo.transpose"(%1) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x12x256x64xf32>) -> tensor<1x256x12x64xf32> - return %2 : tensor<1x256x12x64xf32> +func.func @bmm_rrr_permute_f32(%arg0: tensor<4x2x2xf32>, %arg1: tensor<4x2x2xf32>) -> tensor<2x2x2x2xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x2x2xf32>, tensor<4x2x2xf32>) -> tensor<4x2x2xf32> + %1 = mhlo.reshape %0 : (tensor<4x2x2xf32>) -> tensor<2x2x2x2xf32> + %2 = "mhlo.transpose"(%1) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> + return %2 : tensor<2x2x2x2xf32> } diff --git a/tests/numerical_test/mlir_tests/ops/dot_f32.mlir b/tests/numerical_test/mlir_tests/ops/dot_f32.mlir index 74c0646ae..e41c25455 100644 --- a/tests/numerical_test/mlir_tests/ops/dot_f32.mlir +++ 
diff --git a/tests/numerical_test/mlir_tests/ops/dot_f32.mlir b/tests/numerical_test/mlir_tests/ops/dot_f32.mlir
index 74c0646ae..e41c25455 100644
--- a/tests/numerical_test/mlir_tests/ops/dot_f32.mlir
+++ b/tests/numerical_test/mlir_tests/ops/dot_f32.mlir
@@ -1,4 +1,4 @@
-func.func @gemm_rrr_f32(%arg0 : tensor<256x256xf32>, %arg1 : tensor<256x256xf32>) -> tensor<256x256xf32> {
-  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<256x256xf32>, tensor<256x256xf32>) -> tensor<256x256xf32>
-  return %0 : tensor<256x256xf32>
+func.func @gemm_rrr_f32(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
+  return %0 : tensor<4x4xf32>
 }
diff --git a/tests/numerical_test/torch_e2e_testing/test_suite/basic.py b/tests/numerical_test/torch_e2e_testing/test_suite/basic.py
index 69ac697b5..483e434f7 100644
--- a/tests/numerical_test/torch_e2e_testing/test_suite/basic.py
+++ b/tests/numerical_test/torch_e2e_testing/test_suite/basic.py
@@ -58,3 +58,31 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: MatmulF32Module())
 def MatmulF32Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 6), tu.rand(6, 10))
+
+
+class BatchMatmulF32Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b):
+        return torch.bmm(a, b)
+
+
+@register_test_case(module_factory=lambda: BatchMatmulF32Module())
+def BatchMatmulF32Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 6), tu.rand(2, 6, 10))
+
+
+class BatchMatmulAddF32Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, a, b, c):
+        return c + torch.bmm(a, b)
+
+
+@register_test_case(module_factory=lambda: BatchMatmulAddF32Module())
+def BatchMatmulAddF32Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 6), tu.rand(2, 6, 10), tu.rand(2, 5, 10))
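The two new torch e2e cases replace the deleted bmm_rrr_f32.mlir and bmm_rrr_add_f32.mlir ops tests at the torch level. A quick standalone sanity check of what each module computes (assumes only that torch is installed):

    import torch

    a, b, c = torch.rand(2, 5, 6), torch.rand(2, 6, 10), torch.rand(2, 5, 10)
    assert torch.bmm(a, b).shape == (2, 5, 10)        # BatchMatmulF32Module
    assert (c + torch.bmm(a, b)).shape == (2, 5, 10)  # BatchMatmulAddF32Module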