From 7ff2f368269964040ee461414b20d9e3e219dc8c Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Sat, 24 Aug 2024 10:34:57 +0800
Subject: [PATCH] [e2e] add mlp inference e2e example (#433)

as title
---
 .../examples/inference/brt_backend.py         | 43 ++++++++++++++++
 .../torch-frontend/examples/inference/mlp.py  | 49 +++++++++++++++++++
 2 files changed, 92 insertions(+)
 create mode 100644 frontends/torch-frontend/examples/inference/brt_backend.py
 create mode 100644 frontends/torch-frontend/examples/inference/mlp.py

diff --git a/frontends/torch-frontend/examples/inference/brt_backend.py b/frontends/torch-frontend/examples/inference/brt_backend.py
new file mode 100644
index 000000000..8d3279560
--- /dev/null
+++ b/frontends/torch-frontend/examples/inference/brt_backend.py
@@ -0,0 +1,43 @@
import brt
from brt.utils import brt_dtype_to_torch_dtype
import torch

class BRTBackend:
    def __init__(self, device, brt_file_path):
        assert device == "cuda" or device == "cpu"
        from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete

        # On GPU, route BRT allocations through PyTorch's CUDA caching allocator
        # and reuse the current CUDA stream; on CPU, fall back to the defaults.
        _allocator_alloc = caching_allocator_alloc if device == "cuda" else None
        _allocator_delete = caching_allocator_delete if device == "cuda" else None
        _stream = (
            torch.cuda.current_stream()._as_parameter_.value
            if device == "cuda"
            else None
        )
        self.session = brt.Session(
            device=device.upper(),
            alloc_func=_allocator_alloc,
            free_func=_allocator_delete,
        )
        self.session.load(brt_file_path)
        self.req = self.session.new_request_context(_stream)
        self.device = device

    def execute(self, inputs):
        # TODO(lyq): how to support dynamic shape?
        assert len(self.session.get_input_arg_offsets()) == len(inputs)
        outputs = []
        # Bind each input tensor after checking its static shape and dtype
        # against the session's signature.
        for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
            assert list(self.session.get_static_shape(offset)) == list(arg.shape)
            assert brt_dtype_to_torch_dtype(self.session.get_data_type(offset)) == arg.dtype
            self.req.bind_arg(offset, arg.data_ptr())
        # Allocate output tensors from the session's static shapes and bind
        # their device pointers so BRT writes results directly into them.
        for offset in self.session.get_output_arg_offsets():
            shape = self.session.get_static_shape(offset)
            outputs.append(torch.empty(shape, dtype=brt_dtype_to_torch_dtype(self.session.get_data_type(offset)), device=self.device))
            self.req.bind_arg(offset, outputs[-1].data_ptr())

        self.req.finish_io_binding()
        self.req.run()
        self.req.sync()

        return outputs

diff --git a/frontends/torch-frontend/examples/inference/mlp.py b/frontends/torch-frontend/examples/inference/mlp.py
new file mode 100644
index 000000000..3f3e8fb55
--- /dev/null
+++ b/frontends/torch-frontend/examples/inference/mlp.py
@@ -0,0 +1,49 @@
import os

import torch
from torch import nn
import torch_frontend
import byteir

from brt_backend import BRTBackend

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 20)
        self.linear3 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        x = self.linear2(x)
        x = torch.nn.functional.relu(x)
        x = self.linear3(x)
        return x

workspace = "./workspace"
os.makedirs(workspace, exist_ok=True)

model = MLP().cuda().eval()
inputs = [torch.randn(2, 10).cuda()]
traced_model = torch.jit.trace(model, inputs)

stablehlo_file = workspace + "/model.stablehlo.mlir"
byre_file = workspace + "/model.byre.mlir"
# Lower the TorchScript trace to a StableHLO module and serialize it to disk.
module = torch_frontend.compile(traced_model, inputs, "stablehlo")
with open(stablehlo_file, "w") as f:
    f.write(module.operation.get_asm())

byteir.compile(stablehlo_file, byre_file, entry_func="forward", target="cuda")
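
# byteir.compile lowers the serialized StableHLO module into a BYRE artifact
# (model.byre.mlir) that the BRT runtime can load and execute; BRTBackend below
# wraps that runtime session so its outputs can be checked against eager
# PyTorch and the TorchScript trace.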

backend = BRTBackend("cuda", byre_file)

torch_outputs = model(*inputs)
torch_jit_outputs = traced_model(*inputs)
byteir_outputs = backend.execute(inputs)
if len(byteir_outputs) == 1:
    byteir_outputs = byteir_outputs[0]

torch.testing.assert_close(torch_outputs, torch_jit_outputs)
torch.testing.assert_close(torch_outputs, byteir_outputs)
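
Beyond the correctness check above, the BRTBackend added by this patch can be reused for a rough latency measurement. The sketch below is illustrative only and not part of the patch: it assumes mlp.py has already been run so that ./workspace/model.byre.mlir exists, that the brt Python package is installed, and that a CUDA device is available; the helper name mean_latency_ms is hypothetical.

import time

import torch

from brt_backend import BRTBackend

def mean_latency_ms(fn, args, iters=100):
    # Warm up once, then report the mean wall-clock latency in ms over `iters` runs.
    fn(args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3

backend = BRTBackend("cuda", "./workspace/model.byre.mlir")
inputs = [torch.randn(2, 10).cuda()]
print(f"BRT mean latency: {mean_latency_ms(backend.execute, inputs):.3f} ms")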