[Feature Request] Default Backend Should be changed into tilelang instead of tvm before v0.0.1 release #252

Closed
3 tasks done
LeiWang1999 opened this issue Nov 28, 2024 · 5 comments

LeiWang1999 commented Nov 28, 2024

We propose changing the default backend from tvm to tilelang before the v0.0.1 release. The tilelang backend has demonstrated compatibility with all current operators (e.g., Matmul, Flash Attention) and offers significant performance advantages.

TODO Items

  • Implement all tilelang templates for current operators
  • Make all operator tests pass.
  • Evaluate performance across the operator variants.
@LeiWang1999

See the release plan in #150.

LeiWang1999 self-assigned this Dec 8, 2024
LeiWang1999 added the enhancement (New feature or request) label Dec 9, 2024

LeiWang1999 commented Dec 12, 2024

| config | tir | tl | tir/tl |
| --- | --- | --- | --- |
| 1-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 0.303 | 0.304 | 0.996710526 |
| 1-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 0.087 | 0.087 | 1 |
| 1-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 0.089 | 0.09 | 0.988888889 |
| 1-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 0.17 | 0.154 | 1.103896104 |
| 1-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 0.051 | 0.051 | 1 |
| 32-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 0.326 | 0.348 | 0.936781609 |
| 32-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 0.159 | 0.611 | 0.260229133 |
| 32-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 0.166 | 0.611 | 0.271685761 |
| 32-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 0.177 | 0.186 | 0.951612903 |
| 32-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 0.113 | 0.146 | 0.773972603 |
| 128-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 0.423 | 0.453 | 0.933774834 |
| 128-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 0.428 | 0.405 | 1.056790123 |
| 128-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 0.472 | 0.421 | 1.121140143 |
| 128-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 0.313 | 0.233 | 1.343347639 |
| 128-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 0.273 | 0.295 | 0.925423729 |
| 512-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 1.097 | 1.364 | 0.804252199 |
| 512-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 1.388 | 1.206 | 1.150912106 |
| 512-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 1.544 | 1.278 | 1.208137715 |
| 512-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 0.779 | 0.629 | 1.238473768 |
| 512-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 0.85 | 0.797 | 1.066499373 |
| 16384-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 37.259 | 36.038 | 1.033880903 |
| 16384-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 33.878 | 33.896 | 0.999468964 |
| 16384-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 37.392 | 37.301 | 1.002439613 |
| 16384-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 23.506 | 17.365 | 1.353642384 |
| 16384-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 24.608 | 21.205 | 1.160481019 |
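
The last column of the table above matches the tir timing divided by the tl timing, so a value above 1 means tl is faster. A minimal sketch that recomputes it for three rows copied from the table and flags rows whose ratio falls below 0.95; the helper name `speedup` and the 0.95 cutoff are illustrative assumptions, not part of BitBLAS:

```python
# Rows copied from the table above: (config, tir time, tl time).
rows = [
    ("1-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None", 0.303, 0.304),
    ("32-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None", 0.159, 0.611),
    ("16384-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None", 23.506, 17.365),
]


def speedup(tir_time, tl_time):
    """Return tir time divided by tl time (a value above 1 means tl is faster)."""
    return tir_time / tl_time


THRESHOLD = 0.95  # illustrative cutoff for "tl is noticeably slower"

ratios = {config: speedup(tir, tl) for config, tir, tl in rows}
regressions = [config for config, ratio in ratios.items() if ratio < THRESHOLD]

for config, ratio in ratios.items():
    print(f"{config}: {ratio:.9f}")
print("rows where tl is noticeably slower:", regressions)
```

With these three rows, only the m=32 float16/int4 case is flagged.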

@LeiWang1999

In most benchmark cases, the tl backend matches or outperforms the tir backend. The exception is dequantized matmul at m=32 or m=128, where tl is slower than tir, mainly because block reduction has not yet been implemented for tl.

To reproduce the results:

#!/usr/bin/env bash

set -euo pipefail

test_shapes="$(python3 <<'EOF'
import json

benchmark_shapes = [
    (1, 16384, 16384),
    (32, 16384, 16384),
    (128, 16384, 16384),
    (512, 16384, 16384),
    (16384, 16384, 16384),
]

op_configs = [
    {"A_dtype": "float16", "W_dtype": "float16", "accum_dtype": "float16", "out_dtype": "float16"},
    {"A_dtype": "float16", "W_dtype": "int4",    "accum_dtype": "float16", "out_dtype": "float16", "with_scaling": False},
    {"A_dtype": "float16", "W_dtype": "int4",    "accum_dtype": "float16", "out_dtype": "float16", "with_scaling": True,  "group_size": -1},
    {"A_dtype": "int8",    "W_dtype": "int8",    "accum_dtype": "int32",   "out_dtype": "int8"},
    {"A_dtype": "int8",    "W_dtype": "int2",    "accum_dtype": "int32",   "out_dtype": "int8"},
]

op_config = "MatmulConfig"
op_class = "Matmul"

configs = []
for shape in benchmark_shapes:
    for config in op_configs:
        input_args = list(shape)
        input_args.append(config["A_dtype"])
        input_args.append(config["W_dtype"])
        input_args.append(config["out_dtype"])
        input_args.append(config["accum_dtype"])
        input_args.append("nt") # layout
        input_args.append(False) # with_bias
        input_args.append(-1 if "group_size" not in config else config["group_size"])
        input_args.append(False if "with_scaling" not in config else config["with_scaling"])
        input_args.append(False if "with_zeros" not in config else config["with_zeros"])
        input_args.append(None if "zeros_mode" not in config else config["zeros_mode"])
        
        configs.append([op_config, op_class, input_args])

print(json.dumps(configs))
EOF
)"

echo "Running benchmark with test shapes:"
python3 -c "import json; configs = json.loads('$test_shapes'); [print(c) for c in configs]"

mkdir -p benchmark_logs

# backends=("tir" "tl")
backends=("tl")

for backend in "${backends[@]}"; do
    log_file="benchmark_logs/${backend}_benchmark.log"
    echo "Running benchmark for backend '${backend}'"
    cmd="python ./benchmark/operators/benchmark_bitblas_matmul.py --backend ${backend} --test_shapes '${test_shapes}'"
    echo "Running command: $cmd"
    bash -c "$cmd 2>&1 | tee ${log_file}"
    echo "Logs for backend '${backend}' written to ${log_file}"
done
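
The config labels in the benchmark tables look like the script's `input_args` joined with `-`, which would explain why `group_size=-1` shows up as `--1`. A small sketch under that assumption; the `config_label` helper is hypothetical, and how `benchmark_bitblas_matmul.py` actually builds its labels is not shown in this issue:

```python
def config_label(shape, config):
    """Join the benchmark arguments with '-' in the order the script appends them."""
    input_args = list(shape)
    input_args += [
        config["A_dtype"],
        config["W_dtype"],
        config["out_dtype"],
        config["accum_dtype"],
    ]
    input_args.append("nt")                               # layout
    input_args.append(False)                              # with_bias
    input_args.append(config.get("group_size", -1))       # -1 when unset
    input_args.append(config.get("with_scaling", False))
    input_args.append(config.get("with_zeros", False))
    input_args.append(config.get("zeros_mode", None))
    return "-".join(str(arg) for arg in input_args)


label = config_label(
    (32, 16384, 16384),
    {"A_dtype": "float16", "W_dtype": "int4", "accum_dtype": "float16",
     "out_dtype": "float16", "with_scaling": True, "group_size": -1},
)
print(label)
```

Reading a label back is then positional: M, N, K, the four dtypes, layout, `with_bias`, `group_size`, `with_scaling`, `with_zeros`, `zeros_mode`.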

@LeiWang1999

TL with Split K Support Performance:

| config | tir | tl | tl+splitk | tir/tl | tir/(tl+splitk) |
| --- | --- | --- | --- | --- | --- |
| 1-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 0.303 | 0.304 | 0.303 | 0.996710526 | 1 |
| 1-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 0.087 | 0.087 | 0.087 | 1 | 1 |
| 1-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 0.089 | 0.09 | 0.09 | 0.988888889 | 0.988888889 |
| 1-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 0.17 | 0.154 | 0.154 | 1.103896104 | 1.103896104 |
| 1-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 0.051 | 0.051 | 0.051 | 1 | 1 |
| 32-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 0.326 | 0.348 | 0.331 | 0.936781609 | 0.98489426 |
| 32-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 0.159 | 0.611 | 0.117 | 0.260229133 | 1.358974359 |
| 32-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 0.166 | 0.611 | 0.124 | 0.271685761 | 1.338709677 |
| 32-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 0.177 | 0.186 | 0.171 | 0.951612903 | 1.035087719 |
| 32-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 0.113 | 0.146 | 0.108 | 0.773972603 | 1.046296296 |
| 128-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 0.423 | 0.453 | 0.434 | 0.933774834 | 0.974654378 |
| 128-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 0.428 | 0.405 | 0.323 | 1.056790123 | 1.325077399 |
| 128-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 0.472 | 0.421 | 0.349 | 1.121140143 | 1.35243553 |
| 128-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 0.313 | 0.233 | 0.221 | 1.343347639 | 1.416289593 |
| 128-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 0.273 | 0.295 | 0.199 | 0.925423729 | 1.371859296 |
| 512-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 1.097 | 1.364 | 1.175 | 0.804252199 | 0.933617021 |
| 512-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 1.388 | 1.206 | 1.154 | 1.150912106 | 1.202772964 |
| 512-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 1.544 | 1.278 | 1.21 | 1.208137715 | 1.276033058 |
| 512-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 0.779 | 0.629 | 0.794 | 1.238473768 | 0.981108312 |
| 512-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 0.85 | 0.797 | 0.942 | 1.066499373 | 0.902335456 |
| 16384-16384-16384-float16-float16-float16-float16-nt-False--1-False-False-None | 37.259 | 36.038 | 35.585 | 1.033880903 | 1.047042293 |
| 16384-16384-16384-float16-int4-float16-float16-nt-False--1-False-False-None | 33.878 | 33.896 | 33.514 | 0.999468964 | 1.010861133 |
| 16384-16384-16384-float16-int4-float16-float16-nt-False--1-True-False-None | 37.392 | 37.301 | 37.847 | 1.002439613 | 0.987977911 |
| 16384-16384-16384-int8-int8-int8-int32-nt-False--1-False-False-None | 23.506 | 17.365 | 16.707 | 1.353642384 | 1.406955168 |
| 16384-16384-16384-int8-int2-int8-int32-nt-False--1-False-False-None | 24.608 | 21.205 | 19.441 | 1.160481019 | 1.265778509 |
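
For context, split-K partitions the reduction (K) dimension into chunks, computes a partial product per chunk, and sums the partials. A conceptual pure-Python sketch of the idea follows; it is not the tilelang kernel, and the function names and the `splits` parameter are illustrative assumptions:

```python
def matmul_nt(A, W, k_start, k_end):
    """Partial product over K in [k_start, k_end); "nt" layout, so W is N x K."""
    return [
        [sum(a_row[k] * w_row[k] for k in range(k_start, k_end)) for w_row in W]
        for a_row in A
    ]


def matmul_split_k(A, W, splits):
    """Split the K range into up to `splits` chunks and sum the partial results."""
    M, N, K = len(A), len(W), len(A[0])
    chunk = (K + splits - 1) // splits  # ceiling division
    C = [[0] * N for _ in range(M)]
    # On a GPU each chunk could be handled by a separate thread block;
    # here the chunks are simply processed one after another.
    for k_start in range(0, K, chunk):
        partial = matmul_nt(A, W, k_start, min(k_start + chunk, K))
        for i in range(M):
            for j in range(N):
                C[i][j] += partial[i][j]
    return C


A = [[1, 2, 3, 4]]                   # M=1, K=4
W = [[1, 0, 1, 0], [2, 2, 2, 2]]     # N=2, K=4
result = matmul_split_k(A, W, splits=2)
print(result)  # [[4, 20]]
```

The result equals the unsplit product; only the amount of independent work per output tile changes.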

Let's open a new pull request to change the default backend to tilelang.

@LeiWang1999

Closed, as this has been modified in PR #270.
