[runtime] add BRTBackend, add byteir.compile_from_string (#434)
as title
qingyunqu authored Aug 26, 2024
1 parent 7ff2f36 commit 6b93af6
Showing 4 changed files with 192 additions and 76 deletions.
2 changes: 1 addition & 1 deletion compiler/python/byteir/__init__.py
@@ -13,4 +13,4 @@
# ==============================================================================

from ._mlir_libs._byteir import *
from .compile import compile, DebugType
from .compile import compile, compile_from_string, DebugType
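With this export, the new entry point is available at the package root. A minimal sketch of the resulting import surface (assuming a built `byteir` Python package):

```python
# Sketch: compile_from_string is now importable alongside compile.
from byteir import compile, compile_from_string, DebugType
```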
53 changes: 40 additions & 13 deletions compiler/python/byteir/compile.py
@@ -1,7 +1,7 @@
from enum import Enum
from pathlib import Path
import os
from shutil import copymode
from typing import Union

from . import ir
from .passmanager import PassManager
@@ -366,9 +366,8 @@ def _compile_cpu(
if (module.operation.get_asm() != deserialized_module.operation.get_asm()):
raise ValueError("module asm has been changed after byre serialization")


def compile(
input_file_path: str,
def compile_from_string(
input_string_or_bytes: Union[str, bytes],
output_file_path: str,
entry_func: str = "main",
target: str = "cuda",
@@ -391,18 +390,12 @@ def compile(
gpu_arch_num = int(gpu_arch[3:])
if enable_tf32:
assert gpu_arch_num >= 80, "1xtf32 only support on gpu >= sm_80"
print(f"Compiling {os.path.basename(input_file_path)} to {gpu_arch}")
print(f"[ByteIR] Compiling to {gpu_arch}")
elif _device == "cpu":
print(f"Compiling {os.path.basename(input_file_path)} to {cpu_arch}")
print(f"[ByteIR] Compiling to {cpu_arch}")

### load from .mlir or .mlirbc
from byteir._mlir_libs._stablehlo import deserialize_portable_artifact
context = ir.Context()
if input_file_path.endswith(".mlirbc"):
module_bytes = deserialize_portable_artifact(open(input_file_path, "rb").read())
module = ir.Module.parse(module_bytes, context)
else:
module = ir.Module.parse(open(input_file_path, "r").read(), context)
module = ir.Module.parse(input_string_or_bytes, context)
_print_verbose(module, "// IR Dump Input MLIR:") if verbose else ...

### legalize stablehlo to mhlo
@@ -444,3 +437,37 @@ def compile(
_compile_fn(compile_options)
else:
raise NotImplementedError("not implemented target: {}".format(target))

def compile(
input_file_path: str,
output_file_path: str,
entry_func: str = "main",
target: str = "cuda",
gpu_arch: str = "local",
cpu_arch: str = "x86_64",
byre_serial_version: str = "1.0.0",
verbose: bool = False,
enable_tf32: bool = False,
parallelism: int = 1,
disable_byteir_ait_cache: bool = False,
**kwargs,
) -> None:
### load from .mlir or .mlirbc
from byteir._mlir_libs._stablehlo import deserialize_portable_artifact
if input_file_path.endswith(".mlirbc"):
module_bytes = deserialize_portable_artifact(open(input_file_path, "rb").read())
else:
module_bytes = open(input_file_path, "r").read()

compile_from_string(module_bytes,
output_file_path=output_file_path,
entry_func=entry_func,
target=target,
gpu_arch=gpu_arch,
cpu_arch=cpu_arch,
byre_serial_version=byre_serial_version,
verbose=verbose,
enable_tf32=enable_tf32,
parallelism=parallelism,
disable_byteir_ait_cache=disable_byteir_ait_cache,
**kwargs)
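After this refactor, `compile()` only handles file I/O (deserializing `.mlirbc` bytecode via `deserialize_portable_artifact`) and delegates everything else to `compile_from_string()`. A hedged usage sketch — file names and options below are illustrative, not taken from the commit:

```python
import byteir

# Compile from a file on disk; .mlirbc bytecode is deserialized automatically.
byteir.compile("model.stablehlo.mlir", "model.rt.mlir",
               entry_func="main", target="cuda", gpu_arch="local")

# Or skip the file I/O and compile an in-memory module string/bytes directly.
module_text = open("model.stablehlo.mlir", "r").read()
byteir.compile_from_string(module_text, "model.rt.mlir",
                           entry_func="main", target="cuda", gpu_arch="local")
```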
144 changes: 144 additions & 0 deletions runtime/python/brt/backend.py
@@ -0,0 +1,144 @@
import brt
from brt.utils import brt_dtype_to_torch_dtype
import torch

import time

# BRTBackend for static shape and single device
class BRTBackend:
def __init__(self, byre_file_path, device):
assert device == "cuda" or device == "cpu"
if device == "cuda":
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete
_allocator_alloc = caching_allocator_alloc
_allocator_delete = caching_allocator_delete
_stream = torch.cuda.current_stream()._as_parameter_.value
else:
_allocator_alloc = None
_allocator_delete = None
_stream = None
self.session = brt.Session(
device=device.upper(),
alloc_func=_allocator_alloc,
free_func=_allocator_delete,
)
self.session.load(byre_file_path)
self.req = self.session.new_request_context(_stream)
self.device = device

# for static shape model, just cache shape and dtype info
self.input_arg_offsets = self.session.get_input_arg_offsets()
self.input_shapes = []
self.input_dtypes = []
for offset in self.input_arg_offsets:
self.input_shapes.append(self.session.get_static_shape(offset))
self.input_dtypes.append(brt_dtype_to_torch_dtype(self.session.get_data_type(offset)))
self.output_arg_offsets = self.session.get_output_arg_offsets()
self.output_shapes = []
self.output_dtypes = []
for offset in self.output_arg_offsets:
self.output_shapes.append(self.session.get_static_shape(offset))
self.output_dtypes.append(brt_dtype_to_torch_dtype(self.session.get_data_type(offset)))

def _check_shape_dtype(self, tensors, shapes, dtypes):
assert len(tensors) == len(shapes)
assert len(tensors) == len(dtypes)
for tensor, shape, dtype in zip(tensors, shapes, dtypes):
assert list(shape) == list(tensor.shape)
assert dtype == tensor.dtype

def _bind_inputs(self, inputs):
inputOffsetAndData = []
for offset, input in zip(self.input_arg_offsets, inputs):
inputOffsetAndData.append((offset, input.data_ptr()))
self.req.bind_args(inputOffsetAndData)

def _bind_outputs(self, outputs):
outputOffsetAndData = []
for offset, output in zip(self.output_arg_offsets, outputs):
outputOffsetAndData.append((offset, output.data_ptr()))
self.req.bind_args(outputOffsetAndData)

def run(self, inputs, check=True):
if check:
self._check_shape_dtype(inputs, self.input_shapes, self.input_dtypes)

# alloc outputs
outputs = []
for shape, dtype in zip(self.output_shapes, self.output_dtypes):
outputs.append(torch.empty(shape, dtype=dtype, device=self.device))

self._bind_inputs(inputs)
self._bind_outputs(outputs)

# run
self.req.finish_io_binding()
self.req.run()
self.req.sync()

return outputs

def profile(self, inputs, check=True, warmup_trials=10, run_trials=50):
if check:
self._check_shape_dtype(inputs, self.input_shapes, self.input_dtypes)

# alloc outputs
outputs = []
for shape, dtype in zip(self.output_shapes, self.output_dtypes):
outputs.append(torch.empty(shape, dtype=dtype, device=self.device))

self._bind_inputs(inputs)
self._bind_outputs(outputs)
self.req.finish_io_binding()

# warmup
for _ in range(warmup_trials):
self.req.run()
self.req.sync()

start = time.time()
for _ in range(run_trials):
self.req.run()
self.req.sync()
end = time.time()
avg = ((end - start) * 1000) / run_trials

return outputs, avg

def run_with_outputs(self, inputs, outputs, check=True):
if check:
self._check_shape_dtype(inputs, self.input_shapes, self.input_dtypes)
self._check_shape_dtype(outputs, self.output_shapes, self.output_dtypes)

self._bind_inputs(inputs)
self._bind_outputs(outputs)

self.req.finish_io_binding()
self.req.run()
self.req.sync()

def profile_with_outputs(self, inputs, outputs, check=True, warmup_trials=10, run_trials=50):
if check:
self._check_shape_dtype(inputs, self.input_shapes, self.input_dtypes)
self._check_shape_dtype(outputs, self.output_shapes, self.output_dtypes)

self._bind_inputs(inputs)
self._bind_outputs(outputs)
self.req.finish_io_binding()

# warmup
for _ in range(warmup_trials):
self.req.run()
self.req.sync()

start = time.time()
for _ in range(run_trials):
self.req.run()
self.req.sync()
end = time.time()
avg = ((end - start) * 1000) / run_trials

return avg


# TODO: add BRTDynamicShapeBackend and BRTNCCLBackend
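For reference, a hedged sketch of how this backend might be driven from Python (the artifact path is illustrative, and a CUDA-enabled build of brt and torch is assumed):

```python
import torch
from brt.backend import BRTBackend

backend = BRTBackend("model.rt.mlir", device="cuda")

# Dummy inputs must match the static shapes/dtypes cached at load time.
inputs = [torch.ones(shape, dtype=dtype, device="cuda")
          for shape, dtype in zip(backend.input_shapes, backend.input_dtypes)]

outputs = backend.run(inputs)              # allocates outputs internally
outputs, avg_ms = backend.profile(inputs)  # also returns average latency in ms
```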
69 changes: 7 additions & 62 deletions tests/numerical_test/execute.py
@@ -15,6 +15,7 @@
from reporting import TestResult

import brt
from brt.backend import BRTBackend
import byteir
from byteir import ir
from byteir._backend_registry import get_target_device
@@ -100,62 +101,6 @@ def generate_torch_outputs(self, device="cpu") -> List[torch.Tensor]:
return outputs


class BRTBackend:
def __init__(self, device, brt_file_path):
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete

_allocator_alloc = caching_allocator_alloc if device == "cuda" else None
_allocator_delete = caching_allocator_delete if device == "cuda" else None
_stream = (
torch.cuda.current_stream()._as_parameter_.value
if device == "cuda"
else None
)
self.session = brt.Session(
device=device.upper(),
alloc_func=_allocator_alloc,
free_func=_allocator_delete,
)
self.session.load(brt_file_path)
self.req = self.session.new_request_context(_stream)

def execute(self, inputs, outputs):
# TODO(lyq): how to support dynamic shape?
assert len(self.session.get_input_arg_offsets()) == len(inputs)
assert len(self.session.get_output_arg_offsets()) == len(outputs)
for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
assert list(self.session.get_static_shape(offset)) == list(arg.shape)
self.req.bind_arg(offset, arg.data_ptr())
for offset, ret in zip(self.session.get_output_arg_offsets(), outputs):
assert list(self.session.get_static_shape(offset)) == list(ret.shape)
self.req.bind_arg(offset, ret.data_ptr())
self.req.finish_io_binding()
self.req.run()
self.req.sync()

def profile(self, inputs, outputs, warmup_trials=10, run_trials=50):
assert len(self.session.get_input_arg_offsets()) == len(inputs)
assert len(self.session.get_output_arg_offsets()) == len(outputs)
for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
assert list(self.session.get_static_shape(offset)) == list(arg.shape)
self.req.bind_arg(offset, arg.data_ptr())
for offset, ret in zip(self.session.get_output_arg_offsets(), outputs):
assert list(self.session.get_static_shape(offset)) == list(ret.shape)
self.req.bind_arg(offset, ret.data_ptr())
self.req.finish_io_binding()

for _ in range(warmup_trials):
self.req.run()
self.req.sync()

start = time.time()
for _ in range(run_trials):
self.req.run()
self.req.sync()
end = time.time()
return ((end - start) * 1000) / run_trials


def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical", unique_name=None, **kwargs):
if unique_name is None:
unique_name = os.path.basename(mhlo_file).split(".")[0] + "." + target
@@ -195,7 +140,7 @@ def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical",
# brt runtime
try:
cur_device = get_target_device(target)
brt_backend = BRTBackend(cur_device, output_mlir_file_name)
brt_backend = BRTBackend(output_mlir_file_name, cur_device)

torch_inputs = []
for np_input in np_inputs:
@@ -204,9 +149,9 @@ def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical",
torch_outputs = data_generator.generate_torch_outputs(cur_device)

if mode == "numerical":
brt_backend.execute(torch_inputs, torch_outputs)
brt_backend.run_with_outputs(torch_inputs, torch_outputs)
else:
avg_time = brt_backend.profile(torch_inputs, torch_outputs)
avg_time = brt_backend.profile_with_outputs(torch_inputs, torch_outputs)
return TestResult(
unique_name=unique_name,
compilation_error=None,
@@ -299,11 +244,11 @@ def compile_and_run_torch(test, target, workdir, verbose, mode="numerical"):

# runtime
try:
brt_backend = BRTBackend(cur_device, output_mlir_file_name)
brt_backend = BRTBackend(output_mlir_file_name, cur_device)
if mode == "numerical":
brt_backend.execute(torch_inputs, torch_outputs)
brt_backend.run_with_outputs(torch_inputs, torch_outputs)
else:
avg_time = brt_backend.profile(torch_inputs, torch_outputs)
avg_time = brt_backend.profile_with_outputs(torch_inputs, torch_outputs)
return TestResult(
unique_name=unique_name,
compilation_error=None,
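The call-site changes in the tests mirror the new API: the constructor argument order is now `(byre_file_path, device)`, and `execute`/`profile` become `run_with_outputs`/`profile_with_outputs`. A condensed before/after sketch of the call sites (variable names as in `execute.py`):

```python
# Before (local helper class removed from execute.py):
#   brt_backend = BRTBackend(cur_device, output_mlir_file_name)
#   brt_backend.execute(torch_inputs, torch_outputs)
#   avg_time = brt_backend.profile(torch_inputs, torch_outputs)

# After (shared brt.backend.BRTBackend; note the flipped argument order):
brt_backend = BRTBackend(output_mlir_file_name, cur_device)
brt_backend.run_with_outputs(torch_inputs, torch_outputs)
avg_time = brt_backend.profile_with_outputs(torch_inputs, torch_outputs)
```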
