Port most vLLM kernels to ROCm #1313

Closed · wants to merge 23 commits

Conversation

@pcmoritz (Collaborator) commented Oct 10, 2023

This ports most vLLM kernels to ROCm (with the exception of the quantization_ops, which are not critical for running some of the models).

If you have a working ROCm installation, you can compile this with `python setup.py develop` or `python setup.py install` (on my local machine, `pip install -e .` is NOT working; I'm not sure whether that is generic or specific to my setup).

@pcmoritz mentioned this pull request Oct 10, 2023
@@ -64,66 +60,6 @@ def get_torch_arch_list() -> Set[str]:
return set(arch_list)


# First, check the TORCH_CUDA_ARCH_LIST environment variable.
@pcmoritz (Collaborator Author) commented on this diff:

Maybe we can put all of this code into a function? Then we can just skip it if `torch.version.hip` is set :)

If you have a different preference, let me know!
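
Something along these lines, as a rough sketch (the function name and structure are illustrative, not the actual setup.py code):

```python
# Rough sketch (illustrative names, not the actual setup.py code): wrap the
# CUDA arch handling in a function and skip it when building against ROCm/HIP.
import os
from typing import Optional, Set

import torch


def get_cuda_arch_list() -> Optional[Set[str]]:
    """Return the CUDA arch list, or None when building against ROCm/HIP."""
    if torch.version.hip is not None:
        # ROCm/HIP build: TORCH_CUDA_ARCH_LIST and compute capabilities
        # do not apply, so skip this block entirely.
        return None
    env_list = os.environ.get("TORCH_CUDA_ARCH_LIST", "")
    # ... the existing TORCH_CUDA_ARCH_LIST / compute-capability logic ...
    return {arch for arch in env_list.replace(" ", ";").split(";") if arch}
```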

@ardfork commented Oct 10, 2023

It does build, but I don't think I have enough VRAM to run any unquantized model. Also, note that it requires ROCm 5.7; `amd_hip_bf16.h` didn't exist before that.

Anyway, doesn't it have a hard dependency on xFormers? Just trying to call `vllm.entrypoints.openai.api_server` throws `ModuleNotFoundError: No module named 'xformers'`. All the xFormers parts need a different code path when it isn't installed.
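
For reference, the kind of optional-import guard being asked for might look roughly like this (a sketch only; the flag and wrapper are illustrative, not vLLM's actual code):

```python
# Sketch of an optional xformers dependency (illustrative, not vLLM's actual
# code): import it if available, otherwise fall back to a clear error so the
# rest of the package still imports on ROCm.
try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    xops = None
    XFORMERS_AVAILABLE = False


def memory_efficient_attention(query, key, value, attn_bias=None, scale=None):
    """Dispatch to xformers when installed; otherwise raise a clear error."""
    if not XFORMERS_AVAILABLE:
        raise NotImplementedError(
            "xformers is not installed; a non-xformers attention path "
            "(e.g. ROCm flash-attention) is needed here.")
    return xops.memory_efficient_attention_forward(
        query, key, value, attn_bias=attn_bias, p=0.0, scale=scale)
```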

@pcmoritz (Collaborator Author) commented Oct 10, 2023

The first step here will be to get one of the smaller models like OPT-125M working, which doesn't need AWQ. For dealing with xformers, my plan is to work with https://github.com/pcmoritz/vllm-public/tree/flash-attn for the time being :)

You are right that this needs ROCm 5.7, which the PyTorch nightlies already support.

@casper-hansen (Contributor) commented:

@pcmoritz Curious whether you would mind looking into porting the AWQ quantization kernels to ROCm too? It would be a benefit to everyone running quantized models.

@pcmoritz (Collaborator Author) commented:

@casper-hansen It depends on how the rest of the port goes and what the performance numbers for the rest look like. I hope we can get some help from AMD to port them, since it is likely more involved :)

@pcmoritz (Collaborator Author) commented:

@WoosukKwon Can you have a look at the PR? All the vLLM layers are now working on AMD hardware except the ones that depend on xformers (with the following patch, which I didn't want to merge). Make sure to have a closer look at and test the refactoring in `setup.py`, since I haven't tested it on NVIDIA hardware :)

(base) anyscale@phx2-pc2-017:/tmp/vllm-public$ git diff
diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index 0b3ad0a..44aa944 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -72,4 +72,4 @@ def test_gelu_fast(
     out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
     activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
-    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3)
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 31d78dd..82ecc09 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -12,7 +12,7 @@ from vllm.utils import get_max_shared_memory_bytes
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
-MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
+MAX_SEQ_LEN = 8192
 NUM_BLOCKS = 128  # Arbitrary values for testing
 PARTITION_SIZE = 512
 
diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index d660417..0f66a9e 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -170,5 +170,5 @@ def test_rotary_embedding(
     ref_key = ref_key.view(num_tokens, num_heads * head_size)
 
     # Compare the results.
-    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
-    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-3)
+    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-3)
(base) anyscale@phx2-pc2-017:/tmp/vllm-public$ ls tests/kernels/
conftest.py  __pycache__  test_activation.py  test_attention.py  test_cache.py  test_layernorm.py  test_pos_encoding.py
(base) anyscale@phx2-pc2-017:/tmp/vllm-public$ python -m pytest -s tests/kernels/test_activation.py 
WARNING 10-16 23:45:51 ray_utils.py:35] Failed to import Ray with ModuleNotFoundError("No module named 'msgpack'"). For distributed inference, please install Ray with `pip install ray pandas pyarrow`.
=========================================================================================================== test session starts ============================================================================================================
platform linux -- Python 3.11.4, pytest-7.4.2, pluggy-1.0.0
rootdir: /tmp/vllm-public
collected 108 items                                                                                                                                                                                                                        

tests/kernels/test_activation.py ............................................................................................................

=========================================================================================================== 108 passed in 0.50s ============================================================================================================
(base) anyscale@phx2-pc2-017:/tmp/vllm-public$ python -m pytest -s tests/kernels/test_cache.py
WARNING 10-16 23:46:02 ray_utils.py:35] Failed to import Ray with ModuleNotFoundError("No module named 'msgpack'"). For distributed inference, please install Ray with `pip install ray pandas pyarrow`.
=========================================================================================================== test session starts ============================================================================================================
platform linux -- Python 3.11.4, pytest-7.4.2, pluggy-1.0.0
rootdir: /tmp/vllm-public
collected 270 items                                                                                                                                                                                                                        

tests/kernels/test_cache.py ..............................................................................................................................................................................................................................................................................

=========================================================================================================== 270 passed in 6.21s ============================================================================================================
(base) anyscale@phx2-pc2-017:/tmp/vllm-public$ python -m pytest -s tests/kernels/test_layernorm.py
WARNING 10-16 23:46:15 ray_utils.py:35] Failed to import Ray with ModuleNotFoundError("No module named 'msgpack'"). For distributed inference, please install Ray with `pip install ray pandas pyarrow`.
=========================================================================================================== test session starts ============================================================================================================
platform linux -- Python 3.11.4, pytest-7.4.2, pluggy-1.0.0
rootdir: /tmp/vllm-public
collected 45 items                                                                                                                                                                                                                         

tests/kernels/test_layernorm.py .............................................

============================================================================================================ 45 passed in 0.44s ============================================================================================================
(base) anyscale@phx2-pc2-017:/tmp/vllm-public$ python -m pytest -s tests/kernels/test_pos_encoding.py 
WARNING 10-16 23:46:21 ray_utils.py:35] Failed to import Ray with ModuleNotFoundError("No module named 'msgpack'"). For distributed inference, please install Ray with `pip install ray pandas pyarrow`.
=========================================================================================================== test session starts ============================================================================================================
platform linux -- Python 3.11.4, pytest-7.4.2, pluggy-1.0.0
rootdir: /tmp/vllm-public
collected 864 items                                                                                                                                                                                                                        

tests/kernels/test_pos_encoding.py ................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

=========================================================================================================== 864 passed in 3.77s ============================================================================================================
(base) anyscale@phx2-pc2-017:/tmp/vllm-public$ python -m pytest -s tests/kernels/test_attention.py 
WARNING 10-16 23:46:37 ray_utils.py:35] Failed to import Ray with ModuleNotFoundError("No module named 'msgpack'"). For distributed inference, please install Ray with `pip install ray pandas pyarrow`.
=========================================================================================================== test session starts ============================================================================================================
platform linux -- Python 3.11.4, pytest-7.4.2, pluggy-1.0.0
rootdir: /tmp/vllm-public
collecting ... Deterministic: False
Performance Mode: True
collected 324 items                                                                                                                                                                                                                        

tests/kernels/test_attention.py ................................................................................................................................................................................................................................................................................................FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF

(the failures are the ones from xformers modules not being available)

@WoosukKwon (Collaborator) commented:

Hi @pcmoritz, huge thanks for the great work! I will definitely take a look. As for the FlashAttention and xformers, we will probably be able to collaborate with the AI teams at AMD.

@WoosukKwon (Collaborator) commented:

@pcmoritz @casper-hansen Speaking of the AWQ kernels, I believe they are for temporary use. We plan to implement much faster kernels (probably using Triton) in the near future.

@pcmoritz (Collaborator Author) commented:

@iAmir97 converted the script to AMD assembler instructions in https://github.com/pcmoritz/vllm-public/pull/1/files. Let me try whether I can cherry-pick the commits into this PR; if the tests still pass, that may be the better solution, since the performance might be better :)

@pcmoritz (Collaborator Author) commented:

The tests are passing, so I changed the code.

I'll be at SkyCamp tomorrow, @WoosukKwon; if you want and have time, we can chat/hack some more on this :)

@fsx950223 commented:

When I tried the flash_attn branch, I got `NameError: name 'BlockDiagonalCausalMask' is not defined`.

@fsx950223 commented:

Does the asm work as expected?

@pcmoritz (Collaborator Author) commented:

@fsx950223 It is passing the unit tests for the layers, so it seems to be working for me. On the `BlockDiagonalCausalMask` error: there are some more changes that need to be made to get flash attention working, see #1314. It is not working fully for me yet, but I'm planning to spend some more time on it. Let me know if you'd like to help; it should be possible to get it working :)

@pcmoritz (Collaborator Author) commented Oct 20, 2023

@sabreshao Thanks, I'm planning to work more on the flash attention integration. Currently I'm working with the latest master of https://github.com/ROCmSoftwarePlatform/flash-attention; let me know if you think using a different code base will make it more likely to succeed :) The focus right now is on correctness, not speed yet :)

@fsx950223 commented Oct 21, 2023

> @fsx950223 It is passing the unit tests for the layers, so seems to be working for me. On the BlockDiagonalCausalMask error -- there are some more changes that need to be made to get flash attention working, see #1314. It is not working fully for me yet but I'm also planning to spend some more time on it -- let me know if you'd like to help, it should be possible to get it working :)

Could you share the development environment with me (maybe a Dockerfile)? I ran into some issues in my local environment. Also, which type of GPU are you using, MI250 or MI300?

@jingqiao commented:

@pcmoritz @WoosukKwon @sabreshao

Below is a workable solution to enable Flash Attention for vLLM, per my testing. Please just treat it as a reference; I am sure you can make further optimizations.

To be specific, follow the instructions at https://github.com/ROCmSoftwarePlatform/flash-attention except that, instead of ROCm 5.4, we can use ROCm 5.7 (already supported by the PyTorch nightly), the same version @pcmoritz used above.

Concretely, one can replace the following code in vLLM:

        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
        )

with the following code (similar to the code here)

        from einops import repeat
        from flash_attn.bert_padding import unpad_input, pad_input
        from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

        # Use Flash Attention (with ROCm support)
        batch_size = len(input_metadata.prompt_lens)
        seq_len = input_metadata.max_prompt_len

        lengths = torch.tensor(input_metadata.prompt_lens, device=query.device).view(-1,1)
        attention_mask_bool = repeat(
            torch.arange(seq_len, device=query.device),
            's -> b s', b=batch_size
        ) < lengths

        qkv = torch.stack([query, key, value], dim=1).view(
            batch_size, seq_len, 3, self.num_heads, self.head_size
        )
        qkv_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(qkv, attention_mask_bool)

        out_unpad = flash_attn_unpadded_qkvpacked_func(
            qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p=0.0, causal=True
        )
        out = pad_input(out_unpad, indices, batch_size, seq_len).view(
                  -1, self.num_heads, self.head_size
        )

@WoosukKwon (Collaborator) commented:

Hi @pcmoritz, I got the following error when building vLLM:

In file included from csrc/attention/attention_kernels.hip:23:
    In file included from csrc/attention/attention_dtypes_hip.h:7:
    In file included from csrc/attention/dtype_bfloat16_hip.cuh:29:
    /opt/rocm-5.6.0/include/hip/hip_bf16.h:30:10: fatal error: 'hip/amd_detail/amd_hip_bf16.h' file not found
    #include <hip/amd_detail/amd_hip_bf16.h>
             ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    1 error generated when compiling for gfx1030.
    error: command '/opt/rocm-5.6.0/bin/hipcc' failed with exit code 1

Do you happen to know how to resolve this?

I'm using PyTorch 2.1.0 with ROCm 5.6.0. My local ROCm version is also 5.6.0.
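
(For context, as noted earlier in the thread, `amd_hip_bf16.h` only ships with ROCm 5.7 and later, so ROCm 5.6 cannot build this branch. One option is to fail early in `setup.py` with a clearer message; the following is a hedged sketch, where reading the version from `torch.version.hip` is an assumption and parsing `hipcc --version` would be an alternative.)

```python
# Hedged sketch: fail early when the detected HIP/ROCm version is older than
# 5.7, since hip/amd_detail/amd_hip_bf16.h is not available before that.
# Reading the version from torch.version.hip is an assumption.
import torch

if torch.version.hip is not None:
    hip_major, hip_minor = (int(v) for v in torch.version.hip.split(".")[:2])
    if (hip_major, hip_minor) < (5, 7):
        raise RuntimeError(
            f"Detected HIP {hip_major}.{hip_minor}, but this branch needs "
            "ROCm >= 5.7 (amd_hip_bf16.h is not shipped in older releases).")
```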

@pcmoritz (Collaborator Author) commented Oct 22, 2023 via email

@WoosukKwon (Collaborator) commented:

@pcmoritz I see. Then, which version of PyTorch did you use?

@pcmoritz (Collaborator Author) commented Oct 22, 2023 via email

@fxmarty commented Oct 24, 2023

@pcmoritz Thank you for the port. Are you aware of a way to avoid using `amd_hip_bf16.h` and to compile against ROCm 5.6?

@ehartford commented:

We need to support 5.6 because that's what PyTorch supports now.

@fxmarty commented Nov 2, 2023

@ehartford Yes, it would be nice. So far I have been using the nightly and it works quite nicely!

@ehartford commented:

OK, I figured out how to install both versions of ROCm so that I can use 5.7 with vLLM and 5.6 with PyTorch.

@ehartford commented:

What's preventing this from being merged?

@ehartford commented Nov 3, 2023

(vllm) eric@quixi1:~/vllm-public$ python setup.py install
Traceback (most recent call last):
  File "/home/eric/vllm-public/setup.py", line 28, in <module>
    raise RuntimeError(
RuntimeError: Cannot find CUDA_HOME. CUDA must be available to build the package.

@pcmoritz Do you know what I did wrong?

Of course there's no `CUDA_HOME`, because it's ROCm.

@iAmir97 (Contributor) commented Nov 3, 2023

@WoosukKwon Do you have any idea when the change of the AWQ kernels will happen? I'm planning to port them to ROCm.

@pcmoritz (Collaborator Author) commented Nov 4, 2023

@ehartford Are you on the right branch? In this branch, this should be handled (https://github.com/vllm-project/vllm/pull/1313/files#diff-60f61ab7a8d1910d86d9fda2261620314edcae5894d5aaa236b821c7256badd7R27); your error seems to indicate that the `raise RuntimeError` line is line 28, which is not the case in this branch :)
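
(For reference, the guard being referred to is roughly of this shape; this is a sketch of the idea, not the literal `setup.py` code from the branch. `CUDA_HOME` comes from `torch.utils.cpp_extension`.)

```python
# Rough sketch of the idea (not the literal setup.py code from the branch):
# only require CUDA_HOME for CUDA builds; a ROCm/HIP build does not need it.
import torch
from torch.utils.cpp_extension import CUDA_HOME

is_hip_build = torch.version.hip is not None

if not is_hip_build and CUDA_HOME is None:
    raise RuntimeError(
        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
```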

@Poisonsting commented:

Hi, I have a 7900 XTX and a burning desire to run AWQ models. I see there hasn't been much activity on this issue since October; are you still waiting for more mature AWQ kernels? Is there anything I can do to help?

@WoosukKwon (Collaborator) commented:

Closed, as we merged #1836, which is a superset of this PR. @pcmoritz Thanks for the amazing work!

@WoosukKwon closed this on Dec 11, 2023