-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[e2e] add mlp inference e2e example (#433)
as title
- Loading branch information
Showing
2 changed files
with
92 additions
and
0 deletions.
There are no files selected for viewing
43 changes: 43 additions & 0 deletions
43
frontends/torch-frontend/examples/inference/brt_backend.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import brt | ||
from brt.utils import brt_dtype_to_torch_dtype | ||
import torch | ||
|
||
class BRTBackend: | ||
def __init__(self, device, brt_file_path): | ||
assert device == "cuda" or device == "cpu" | ||
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete | ||
|
||
_allocator_alloc = caching_allocator_alloc if device == "cuda" else None | ||
_allocator_delete = caching_allocator_delete if device == "cuda" else None | ||
_stream = ( | ||
torch.cuda.current_stream()._as_parameter_.value | ||
if device == "cuda" | ||
else None | ||
) | ||
self.session = brt.Session( | ||
device=device.upper(), | ||
alloc_func=_allocator_alloc, | ||
free_func=_allocator_delete, | ||
) | ||
self.session.load(brt_file_path) | ||
self.req = self.session.new_request_context(_stream) | ||
self.device = device | ||
|
||
def execute(self, inputs): | ||
# TODO(lyq): how to support dynamic shape? | ||
assert len(self.session.get_input_arg_offsets()) == len(inputs) | ||
outputs = [] | ||
for offset, arg in zip(self.session.get_input_arg_offsets(), inputs): | ||
assert list(self.session.get_static_shape(offset)) == list(arg.shape) | ||
assert brt_dtype_to_torch_dtype(self.session.get_data_type(offset)) == arg.dtype | ||
self.req.bind_arg(offset, arg.data_ptr()) | ||
for offset in self.session.get_output_arg_offsets(): | ||
shape = self.session.get_static_shape(offset) | ||
outputs.append(torch.empty(shape, dtype=brt_dtype_to_torch_dtype(self.session.get_data_type(offset)), device=self.device)) | ||
self.req.bind_arg(offset, outputs[-1].data_ptr()) | ||
|
||
self.req.finish_io_binding() | ||
self.req.run() | ||
self.req.sync() | ||
|
||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import os | ||
|
||
import torch | ||
from torch import nn | ||
import torch_frontend | ||
import byteir | ||
|
||
from brt_backend import BRTBackend | ||
|
||
class MLP(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.linear1 = nn.Linear(10, 20) | ||
self.linear2 = nn.Linear(20, 20) | ||
self.linear3 = nn.Linear(20, 10) | ||
|
||
def forward(self, x): | ||
x = self.linear1(x) | ||
x = torch.nn.functional.relu(x) | ||
x = self.linear2(x) | ||
x = torch.nn.functional.relu(x) | ||
x = self.linear3(x) | ||
return x | ||
|
||
workspace = "./workspace" | ||
os.makedirs(workspace, exist_ok=True) | ||
|
||
model = MLP().cuda().eval() | ||
inputs = [torch.randn(2, 10).cuda()] | ||
traced_model = torch.jit.trace(model, inputs) | ||
|
||
stablehlo_file = workspace + "/model.stablehlo.mlir" | ||
byre_file = workspace + "/model.byre.mlir" | ||
module = torch_frontend.compile(traced_model, inputs, "stablehlo") | ||
with open(stablehlo_file, "w") as f: | ||
f.write(module.operation.get_asm()) | ||
|
||
byteir.compile(stablehlo_file, byre_file, entry_func="forward", target="cuda") | ||
|
||
backend = BRTBackend("cuda", byre_file) | ||
|
||
torch_outputs = model(*inputs) | ||
torch_jit_outputs = traced_model(*inputs) | ||
byteir_outputs = backend.execute(inputs) | ||
if len(byteir_outputs) == 1: | ||
byteir_outputs = byteir_outputs[0] | ||
|
||
torch.testing.assert_close(torch_outputs, torch_jit_outputs) | ||
torch.testing.assert_close(torch_outputs, byteir_outputs) |