[dynamo] reduce host side overhead #328

Open · wants to merge 5 commits into main
@@ -1,6 +1,8 @@
import dataclasses
import functools
import logging
from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union
from brt.utils import brt_dtype_to_torch_dtype

import torch

@@ -42,56 +44,99 @@ def __init__(self, module_path_or_session, none_indices):
self._none_indices = none_indices
self._req = self._session.new_request_context(
torch.cuda.current_stream()._as_parameter_.value)
self.input_arg_offsets = self._session.get_input_arg_offsets()
self.output_arg_offsets = self._session.get_output_arg_offsets()

self.output_shape_and_dtype = [(
self._session.get_static_shape(offset),
brt_dtype_to_torch_dtype(self._session.get_data_type(offset)),
) for offset in self._session.get_output_arg_offsets()]

self._outs_len = len(self.output_arg_offsets)
self.static_shape_and_dtype = [
(self._session.get_static_shape(offset),
brt_dtype_to_torch_dtype(self._session.get_data_type(offset)))
for offset in self.output_arg_offsets
]

self.real_outs_index_map = self._get_outputs_index_map(
self._outs_len, self._none_indices)
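# Indices of inputs that arrive non-contiguous; recorded on the first call and
# reused on later calls (see the note after __call__ below).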
self.strited_inputs_index = None

def _get_outputs_index_map(self, out_lens: int, none_indices: List[int]):
res = []
none_lens = len(none_indices)
none_cnt = 0
for idx in range(out_lens + none_lens):
if none_cnt < none_lens and idx == none_indices[none_cnt]:
none_cnt += 1
continue
res.append(idx)

return res
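
As a quick illustration of what this map computes, here is a standalone copy of the logic with made-up values (the numbers are not taken from the PR):

# Standalone copy of the mapping logic above, for illustration only.
def outputs_index_map(out_lens, none_indices):
    res, none_cnt = [], 0
    for idx in range(out_lens + len(none_indices)):
        if none_cnt < len(none_indices) and idx == none_indices[none_cnt]:
            none_cnt += 1  # this slot stays None in the final result list
            continue
        res.append(idx)
    return res

# 3 real outputs with Nones expected at positions 1 and 3:
# real outputs land in slots 0, 2 and 4 of the returned result list.
assert outputs_index_map(3, [1, 3]) == [0, 2, 4]
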

@functools.lru_cache
def get_out_tensors(self, device):
Collaborator commented: Add some comments for this? For example, describe the limitations of this.

"""
The number of outputs is too large, which causes Torch to still take a significant amount of time even
with a memory pool. We just use a simple cache to reduce this overhead.

NB. One should notice that: We made an assumption here, the subgraph will just be called once per
training iteration. As the cached output tensor is not reusable inside a iteration.

TODO: We could implement a memory pool or just reduce the amount of torch tensor allocation, e.g. Just
alloc a large tensor and split this to output tensors.
"""
outputs_ptr = [None] * self._outs_len
results = [None] * (self._outs_len + len(self._none_indices))

for idx, shape_dty in enumerate(self.static_shape_and_dtype):
_out = torch.empty(shape_dty[0], dtype=shape_dty[1], device=device)
results[self.real_outs_index_map[idx]] = _out
outputs_ptr[idx] = _out.data_ptr()

return results, outputs_ptr
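
A minimal sketch of the TODO above (one large allocation, split into the output tensors). This is not part of the PR: the helper name, the uint8 backing buffer, the 256-byte alignment, and returning the backing pool are all assumptions; it only reuses the (shape, dtype) layout of static_shape_and_dtype.

import torch

def alloc_pooled_outputs(static_shape_and_dtype, device, align=256):
    # One flat byte buffer; every output tensor becomes a reinterpreted view into it,
    # so a single device allocation replaces one torch.empty call per output.
    entries, offset = [], 0
    for shape, dtype in static_shape_and_dtype:
        nbytes = torch.Size(shape).numel() * torch.empty(0, dtype=dtype).element_size()
        entries.append((offset, shape, dtype, nbytes))
        # Round each start offset up so the uint8 -> dtype reinterpretation stays legal.
        offset += (nbytes + align - 1) // align * align
    pool = torch.empty(offset, dtype=torch.uint8, device=device)
    # Cross-dtype Tensor.view needs a reasonably recent PyTorch (>= 1.10).
    outs = [pool[start:start + nbytes].view(dtype).view(list(shape))
            for start, shape, dtype, nbytes in entries]
    return outs, pool  # keep `pool` alive for as long as the views are in use

The 256-byte step keeps every view's storage offset divisible by any element-size ratio, matching the alignment that individually allocated CUDA tensors would get anyway.
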

def __call__(self, *inputs):
from brt.utils import brt_dtype_to_torch_dtype

log.debug(f"***** Run function compiled through byteir ******")

# FIXME: byteir requires all inputs to be on the device side, so move host-side tensors to the device.
# Preprocess strided tensors, since byteir does not support them yet.
new_inputs = []

for i in range(0, len(inputs)):
_t = inputs[i]
if not _t.is_cuda:
log.warning(f"device error: type={type(_t)}, {_t.device}")
_t = _t.to("cuda")
new_inputs.append(_t.contiguous())

device = new_inputs[0].device

results = [
torch.empty(
self._session.get_static_shape(offset),
dtype=brt_dtype_to_torch_dtype(
self._session.get_data_type(offset)),
device=device,
) for offset in self._session.get_output_arg_offsets()
]

for offset, input in zip(self._session.get_input_arg_offsets(),
new_inputs):
self._req.bind_arg(offset, input.data_ptr())
for offset, output in zip(self._session.get_output_arg_offsets(),
results):
self._req.bind_arg(offset, output.data_ptr())
new_inputs_ptr = [None] * len(inputs)

if self.strited_inputs_index is None:
self.strited_inputs_index = []
for i in range(0, len(inputs)):
_t = inputs[i]
if not _t.is_contiguous():
_t = _t.contiguous()
self.strited_inputs_index.append(i)
new_inputs_ptr[i] = _t.data_ptr()
else:
for i in range(0, len(inputs)):
new_inputs_ptr[i] = inputs[i].data_ptr()
for i in self.strited_inputs_index:
new_inputs_ptr[i] = inputs[i].contiguous().data_ptr()

device = inputs[0].device

results, outputs_ptr = self.get_out_tensors(device)

inputOffsetAndArg = [None] * len(new_inputs_ptr)
outputOffsetAndArg = [None] * len(outputs_ptr)
for idx, (offset, input_ptr) in enumerate(zip(self.input_arg_offsets, new_inputs_ptr)):
inputOffsetAndArg[idx] = (offset, input_ptr)
for idx, (offset, output_ptr) in enumerate(zip(self.output_arg_offsets, outputs_ptr)):
outputOffsetAndArg[idx] = (offset, output_ptr)
self._req.bind_args(inputOffsetAndArg)
self._req.bind_args(outputOffsetAndArg)
self._req.finish_io_binding()
self._req.run()
self._req.sync()

# add None results to return values
rets = []
none_cnt = 0
result_cnt = 0
for i in range(len(results) + len(self._none_indices)):
if none_cnt < len(
self._none_indices) and i == self._none_indices[none_cnt]:
rets.append(None)
none_cnt += 1
else:
rets.append(results[result_cnt])
result_cnt += 1
rets = results

if len(rets) == 1:
return rets[0]
return rets
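
A side note on the cached strided-input indices above: Tensor.contiguous() returns the same tensor without copying when the input is already contiguous, so after the first call only the recorded strided inputs pay for a copy (and a fresh data_ptr) per iteration. A tiny CPU-only check of that behaviour, with illustrative tensors:

import torch

x = torch.randn(4, 4)   # contiguous
y = x.t()               # non-contiguous (strided) view of the same storage
assert x.contiguous().data_ptr() == x.data_ptr()  # already contiguous: no copy
assert y.contiguous().data_ptr() != y.data_ptr()  # strided: a contiguous copy is made
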
@@ -48,6 +48,7 @@ def byteir_compiler(
partition_fn=byteir_partition_fn,
#partition_fn=min_cut_rematerialization_partition,
#partition_fn=default_partition,
keep_inference_input_mutations=False,
)

fake_mode = detect_fake_mode(
runtime/python/src/module.cc (24 additions, 0 deletions)
@@ -256,6 +256,30 @@ PYBIND11_MODULE(MODULE_NAME, m) {
THROW_ON_FAIL(
req.Context().BindArg(offset, reinterpret_cast<void *>(ptr)));
})
.def("bind_args",
[](ReqeustContextWithSession &req, py::list offset_and_args) {
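// Bind a whole batch of (offset, data_ptr) pairs in a single call, instead of
// one bind_arg call per argument, to cut per-argument Python/C++ crossing overhead.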
for (auto handle : offset_and_args) {
PyObject *obj = handle.ptr();
if (!PyTuple_Check(obj) || PyTuple_Size(obj) != 2) {
PyErr_SetString(PyExc_TypeError,
"expect pair of offset and arg");
return;
}

PyObject *offset = PyTuple_GetItem(obj, 0);
PyObject *arg = PyTuple_GetItem(obj, 1);
if (!PyLong_Check(offset)) {
PyErr_SetString(PyExc_TypeError, "offset should be integer");
return;
}
if (!PyLong_Check(arg)) {
PyErr_SetString(PyExc_TypeError, "arg should be integer");
return;
}
THROW_ON_FAIL(req.Context().BindArg(PyLong_AsSize_t(offset),
PyLong_AsVoidPtr(arg)));
}
})
.def("get_arg",
[](ReqeustContextWithSession &req, size_t offset) {
void *ptr = req.Context().GetArg(offset);