[e2e] add mlp inference e2e example (#433)
as title
qingyunqu authored Aug 24, 2024
1 parent: e02e4d6 · commit: 7ff2f36
Showing 2 changed files with 92 additions and 0 deletions.
frontends/torch-frontend/examples/inference/brt_backend.py (43 additions, 0 deletions)
@@ -0,0 +1,43 @@
import brt
from brt.utils import brt_dtype_to_torch_dtype
import torch


class BRTBackend:
    """Thin wrapper that runs a compiled ByRE artifact through the BRT runtime."""

    def __init__(self, device, brt_file_path):
        assert device == "cuda" or device == "cpu"
        from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete

        # On CUDA, let BRT allocate through PyTorch's caching allocator and run
        # on the current CUDA stream; on CPU, fall back to BRT's defaults.
        _allocator_alloc = caching_allocator_alloc if device == "cuda" else None
        _allocator_delete = caching_allocator_delete if device == "cuda" else None
        _stream = (
            torch.cuda.current_stream()._as_parameter_.value
            if device == "cuda"
            else None
        )
        self.session = brt.Session(
            device=device.upper(),
            alloc_func=_allocator_alloc,
            free_func=_allocator_delete,
        )
        self.session.load(brt_file_path)
        self.req = self.session.new_request_context(_stream)
        self.device = device

    def execute(self, inputs):
        # TODO(lyq): how to support dynamic shape?
        assert len(self.session.get_input_arg_offsets()) == len(inputs)
        outputs = []
        # Bind each input tensor after checking its static shape and dtype
        # against the session's signature.
        for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
            assert list(self.session.get_static_shape(offset)) == list(arg.shape)
            assert brt_dtype_to_torch_dtype(self.session.get_data_type(offset)) == arg.dtype
            self.req.bind_arg(offset, arg.data_ptr())
        # Allocate an output tensor for each result and bind its buffer.
        for offset in self.session.get_output_arg_offsets():
            shape = self.session.get_static_shape(offset)
            outputs.append(
                torch.empty(
                    shape,
                    dtype=brt_dtype_to_torch_dtype(self.session.get_data_type(offset)),
                    device=self.device,
                )
            )
            self.req.bind_arg(offset, outputs[-1].data_ptr())

        self.req.finish_io_binding()
        self.req.run()
        self.req.sync()

        return outputs
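
For reference, a minimal standalone sketch of driving BRTBackend, assuming a ByRE artifact has already been produced; the path and the CPU target below are illustrative only (the constructor accepts either "cpu" or "cuda", per the assert above):

import torch

from brt_backend import BRTBackend

# Hypothetical artifact path; produce it with byteir.compile(...) first.
backend = BRTBackend("cpu", "./workspace/model.byre.mlir")

# Inputs must match the entry function's static shapes and dtypes;
# execute() asserts both against the loaded session.
inputs = [torch.randn(2, 10)]
outputs = backend.execute(inputs)
print(outputs[0].shape)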
frontends/torch-frontend/examples/inference/mlp.py (49 additions, 0 deletions)
@@ -0,0 +1,49 @@
import os

import torch
from torch import nn
import torch_frontend
import byteir

from brt_backend import BRTBackend


class MLP(nn.Module):
    """A small 3-layer MLP used as the end-to-end inference example."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 20)
        self.linear3 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        x = self.linear2(x)
        x = torch.nn.functional.relu(x)
        x = self.linear3(x)
        return x


workspace = "./workspace"
os.makedirs(workspace, exist_ok=True)

model = MLP().cuda().eval()
inputs = [torch.randn(2, 10).cuda()]
traced_model = torch.jit.trace(model, inputs)

# Lower the TorchScript module to StableHLO and save the IR to disk.
stablehlo_file = workspace + "/model.stablehlo.mlir"
byre_file = workspace + "/model.byre.mlir"
module = torch_frontend.compile(traced_model, inputs, "stablehlo")
with open(stablehlo_file, "w") as f:
    f.write(module.operation.get_asm())

# Compile the StableHLO module to a ByRE artifact targeting CUDA.
byteir.compile(stablehlo_file, byre_file, entry_func="forward", target="cuda")

backend = BRTBackend("cuda", byre_file)

# Run eager PyTorch, TorchScript, and the BRT runtime, then compare results.
torch_outputs = model(*inputs)
torch_jit_outputs = traced_model(*inputs)
byteir_outputs = backend.execute(inputs)
if len(byteir_outputs) == 1:
    byteir_outputs = byteir_outputs[0]

torch.testing.assert_close(torch_outputs, torch_jit_outputs)
torch.testing.assert_close(torch_outputs, byteir_outputs)
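
As an optional follow-up (not part of the commit), the two paths could also be micro-benchmarked against each other with CUDA events; a rough sketch, appended to the script above:

def time_cuda(fn, iters=100):
    # Warm up, then time with CUDA events so the GPU work is fully measured.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

print("eager PyTorch:", time_cuda(lambda: model(*inputs)), "ms")
print("BRT runtime  :", time_cuda(lambda: backend.execute(inputs)), "ms")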
