bugfix: fix pin memory device (#755)

Right now flashinfer does not specify the device when creating
pin-memory tensors, which causes an error when users change the
default PyTorch device.

We should explicitly set the device to `"cpu"` for pin-memory tensors.
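For illustration, a minimal sketch of the failure mode and the fix (not part of this commit; it assumes a CUDA-enabled PyTorch build):

```python
import torch

# The situation flashinfer cannot control: the user has changed the
# default device away from the CPU.
torch.set_default_device("cuda")

# Without an explicit device, the tensor is created on the default
# (CUDA) device, and only CPU tensors can be pinned, so this raises:
#   buf = torch.empty((1024,), dtype=torch.uint8, pin_memory=True)

# With the fix, the buffer is explicitly allocated and pinned on the host:
buf = torch.empty((1024,), dtype=torch.uint8, pin_memory=True, device="cpu")
assert buf.is_pinned()
```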

---------

Signed-off-by: youkaichao <[email protected]>
youkaichao authored Jan 27, 2025
1 parent aab8715 commit 5243043
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 6 additions & 2 deletions flashinfer/decode.py
```diff
@@ -651,7 +651,8 @@ def __init__(
             (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
         )
         self._pin_memory_int_workspace_buffer = torch.empty(
-            (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True
+            (8 * 1024 * 1024,), dtype=torch.uint8,
+            pin_memory=True, device="cpu",
         )

         if use_cuda_graph:
@@ -718,6 +719,7 @@ def reset_workspace_buffer(
         self._pin_memory_int_workspace_buffer = torch.empty(
             self._int_workspace_buffer.shape,
             dtype=self._int_workspace_buffer.dtype,
+            device="cpu",
             pin_memory=True,
         )

@@ -1277,7 +1279,8 @@ def __init__(
             (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
         )
         self._pin_memory_int_workspace_buffer = torch.empty(
-            (8 * 1024 * 1024,), dtype=torch.uint8, pin_memory=True
+            (8 * 1024 * 1024,), dtype=torch.uint8,
+            pin_memory=True, device="cpu",
         )

         if use_cuda_graph:
@@ -1330,6 +1333,7 @@ def reset_workspace_buffer(
         self._pin_memory_int_workspace_buffer = torch.empty(
             self._int_workspace_buffer.shape,
             dtype=self._int_workspace_buffer.dtype,
+            device="cpu",
             pin_memory=True,
         )
```
8 changes: 6 additions & 2 deletions flashinfer/prefill.py
```diff
@@ -1099,6 +1099,7 @@ def __init__(
         self._pin_memory_int_workspace_buffer = torch.empty(
             self._int_workspace_buffer.shape,
             dtype=self._int_workspace_buffer.dtype,
+            device="cpu",
             pin_memory=True,
         )
         self._use_cuda_graph = use_cuda_graph
@@ -1165,6 +1166,7 @@ def reset_workspace_buffer(
         self._pin_memory_int_workspace_buffer = torch.empty(
             self._int_workspace_buffer.shape,
             dtype=self._int_workspace_buffer.dtype,
+            device="cpu",
             pin_memory=True,
         )

@@ -1415,7 +1417,7 @@ def plan(
             if page_size != 1:
                 vector_sparse_indptr_host = torch.cat(
                     [
-                        torch.tensor([0], dtype=torch.int32),
+                        torch.tensor([0], dtype=torch.int32, device=kv_lens_arr_host.device),
                         torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32),
                     ],
                     dim=0,
```
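The `plan()` hunk above fixes the same class of bug in a different spot: `torch.cat` requires all inputs to be on the same device, and without an explicit device the `torch.tensor([0], ...)` literal is allocated on the default device while `kv_lens_arr_host` lives on the CPU. A minimal sketch of the mismatch (not part of this commit; assumes a CUDA-enabled build, and the length values are made up):

```python
import torch

torch.set_default_device("cuda")

# Stand-in for kv_lens_arr_host: a host-side (CPU) array of KV lengths.
kv_lens_arr_host = torch.tensor([3, 5, 2], dtype=torch.int32, device="cpu")

# Mismatch: torch.tensor([0]) lands on the CUDA default device, while
# cumsum of a CPU tensor stays on the CPU, so torch.cat raises:
#   torch.cat([torch.tensor([0], dtype=torch.int32),
#              torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32)])

# Fix: allocate the literal on the same device as the host array.
vector_sparse_indptr_host = torch.cat(
    [
        torch.tensor([0], dtype=torch.int32, device=kv_lens_arr_host.device),
        torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32),
    ],
    dim=0,
)
print(vector_sparse_indptr_host)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```

The remaining `flashinfer/prefill.py` hunks follow.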
```diff
@@ -1858,7 +1860,8 @@ def __init__(
             (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
         )
         self._pin_memory_int_workspace_buffer = torch.empty(
-            self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True
+            self._int_workspace_buffer.shape, dtype=torch.uint8,
+            pin_memory=True, device="cpu",
         )
         self._use_cuda_graph = use_cuda_graph
         if use_cuda_graph:
@@ -1911,6 +1914,7 @@ def reset_workspace_buffer(
         self._pin_memory_int_workspace_buffer = torch.empty(
             self._int_workspace_buffer.shape,
             dtype=self._int_workspace_buffer.dtype,
+            device="cpu",
             pin_memory=True,
         )
```
