
[Continuation] Merge EmbeddedLLM/vllm-rocm into vLLM main #1836

Merged: 63 commits merged into vllm-project:main on Dec 8, 2023

Conversation

@tjtanaa (Contributor) commented Nov 29, 2023

Add ROCm Support

  • Dynamic code path selection for CUDA or ROCm in PyTorch (see the sketch after this list)
  • Llama2 support
  • SqueezeLLM ROCm support
  • Add documentation (amd-installation.rst) describing how to set up the vLLM ROCm version
  • Run format.sh on all the code
  • Prepare amd.Dockerfile
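
As a rough sketch of what the dynamic code path selection can look like (assuming only standard PyTorch APIs; the `is_hip()` helper name is illustrative, not necessarily what this PR uses):

```python
import torch

def is_hip() -> bool:
    # torch.version.hip is a version string on ROCm builds and None on CUDA builds.
    return getattr(torch.version, "hip", None) is not None

if is_hip():
    print("ROCm/HIP build of PyTorch detected; taking the ROCm code path")
else:
    print("CUDA build of PyTorch detected; taking the CUDA code path")
```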

As too many changes have been made since #1749, the previous PR #1749 has been closed and is continued here.

PR Authors:
@kliuae
@iAmir97
@tjtanaa
@tanpinsiang

Contributor:
@pcmortiz

This pull request also incorporates the work from "Port most vLLM kernels to ROCm" (#1313) by @pcmoritz, which was not merged. We appreciate @pcmoritz's contribution.

@WoosukKwon added the rocm label on Dec 1, 2023
@WoosukKwon (Collaborator)

Thanks for the amazing work. I just verified the Dockerfile.rocm and ran the benchmark on the llama2-7b model on MI210. The throughput is:

Throughput: 0.89 requests/s, 424.90 tokens/s

Hi @hongxiayang, could you share what your benchmark settings are? I'm wondering because this seems quite a bit lower than what we got from A100-80GB GPUs on the ShareGPT benchmark.

@hongxiayang (Collaborator) commented Dec 7, 2023

> Hi @hongxiayang, could you share what your benchmark settings are? I'm wondering because this seems quite a bit lower than what we got from A100-80GB GPUs on the ShareGPT benchmark.

The setting is as below: 1 GPU only on MI210 (dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0).
MI250 will be better, and I could run the same benchmark to get those numbers.
@WoosukKwon, what is your setting on A100?

Namespace(backend='vllm', dataset='/app/dataset/ShareGPT_V3_unfiltered_cleaned_split.json', input_len=None, output_len=None, model='/app/model', tokenizer='/app/model', quantization=None, tensor_parallel_size=1, n=1, use_beam_search=False, num_prompts=1000, seed=0, hf_max_batch_size=None, trust_remote_code=False, max_model_len=None, dtype='auto')
INFO 12-05 22:40:38 llm_engine.py:73] Initializing an LLM engine with config: model='/app/model', tokenizer='/app/model', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0)
WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:
    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.0.1+gita61a294)
    Python  3.10.13 (you have 3.10.13)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/cuda/__init__.py:546: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
INFO 12-05 22:41:11 llm_engine.py:222] # GPU blocks: 5705, # CPU blocks: 512
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [18:45<00:00,  1.13s/it]
Throughput: 0.89 requests/s, 424.90 tokens/s
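
As a quick sanity check of the reported throughput (a back-of-the-envelope aside; the figures are taken from the log above):

```python
# 1000 prompts processed in 18:45 according to the progress bar above.
elapsed_s = 18 * 60 + 45   # 1125 seconds
num_prompts = 1000
print(f"Throughput: {num_prompts / elapsed_s:.2f} requests/s")  # ~0.89, matching the reported value
```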

Update the vLLM installation procedures on the AMD platform.
Update vLLM documentation.
@tjtanaa (Contributor, Author) commented Dec 7, 2023

@WoosukKwon It is ready for another review. Thank you very much.

@WoosukKwon (Collaborator)

@hongxiayang This is my benchmark result on llama-7b and ShareGPT (benchmark_throughput.py), which is quite different from your results but seems more reasonable.

| GPU | A100 | MI210x |
| --- | --- | --- |
| TFLOPs (FP16) | 312 | 181 |
| Memory capacity | 80 GB | 64 GB |
| Memory bandwidth | 1.9 TB/s | 1.6 TB/s |
| Throughput | 8.30 reqs/s | 5.25 reqs/s |
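
One rough way to read this table is to compare the measured throughput ratio against the raw hardware ratios (an illustrative aside; the thread itself does not spell out this reasoning):

```python
# MI210x-to-A100 ratios computed from the table above.
a100  = {"tflops": 312, "bw_tb_s": 1.9, "reqs_s": 8.30}
mi210 = {"tflops": 181, "bw_tb_s": 1.6, "reqs_s": 5.25}

print(f"compute ratio:    {mi210['tflops'] / a100['tflops']:.2f}")    # ~0.58
print(f"bandwidth ratio:  {mi210['bw_tb_s'] / a100['bw_tb_s']:.2f}")  # ~0.84
print(f"throughput ratio: {mi210['reqs_s'] / a100['reqs_s']:.2f}")    # ~0.63
```

The measured ratio (~0.63) falls between the compute and memory-bandwidth ratios, which is at least consistent with a workload bound by a mix of the two.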

@hongxiayang (Collaborator) commented Dec 7, 2023

> @hongxiayang This is my benchmark result on llama-7b and ShareGPT (benchmark_throughput.py), which is quite different from your results but seems more reasonable.
>
> | GPU | A100 | MI210x |
> | --- | --- | --- |
> | TFLOPs (FP16) | 312 | 181 |
> | Memory capacity | 80 GB | 64 GB |
> | Memory bandwidth | 1.9 TB/s | 1.6 TB/s |
> | Throughput | 8.30 reqs/s | 5.25 reqs/s |

Wow, is this one GPU or 8 GPUs? Your number is quite different from mine, and I am wondering whether we had the same parameters when running the test, e.g. INFO 12-05 22:40:38 llm_engine.py:73] Initializing an LLM engine with config: model='/app/model', tokenizer='/app/model', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, seed=0)

Edit: My MI210 was wonky when I ran the test. Your number is valid.

@WoosukKwon (Collaborator) left a comment

@tjtanaa Thanks again for the great work! I found the code super clean and well organized. I also like the detailed documentation and the provided docker image. I could run vLLM on MI210 very smoothly and the performance was great! Thanks a lot for the contribution.

Left some minor comments on the code style. Please take a look!

vllm/utils.py (outdated, resolved)
vllm/engine/ray_utils.py (resolved)
-v <path/to/model>:/app/model \
vllm-rocm \
bash

A collaborator replied:

I think we can keep this as is. I tried it out and it worked pretty smoothly!

setup.py (outdated, resolved)
vllm/utils.py (outdated, resolved)
setup.py (outdated, resolved)

 ROOT_DIR = os.path.dirname(__file__)

 MAIN_CUDA_VERSION = "12.1"

 # Supported NVIDIA GPU architectures.
-SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
Collaborator:

Just curious: which part of the code imposes this requirement? That is, why is gfx8 not supported? While I don't think we have to support it, I'd like to know why we don't.

Contributor Author (@tjtanaa):

The way we compiled this list of ROCm-supported archs is based on what AMD supports for ROCm and HIP. Furthermore, each arch has its own set of assembly instructions, so we also have to make sure the assembly instructions we currently use are supported by those archs.

To the best of our knowledge, these are the ARCH requirements of the different libraries:

  1. PyTorch: gfx900, gfx906, gfx908, gfx90a, gfx1030, gfx1101
  2. vLLM custom ops: gfx90a, gfx908, gfx906, gfx1030, gfx1100
  3. Flash-Attention-ROCm: gfx90a, gfx940, gfx941, gfx942

Should we use the intersection of all three ARCH requirements instead?
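
For reference, this is what that intersection would look like, using the three sets quoted above (an illustrative snippet, not code from the PR):

```python
pytorch_archs         = {"gfx900", "gfx906", "gfx908", "gfx90a", "gfx1030", "gfx1101"}
vllm_custom_op_archs  = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
flash_attn_rocm_archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}

# The strict intersection of all three leaves only gfx90a (MI200 series).
print(pytorch_archs & vllm_custom_op_archs & flash_attn_rocm_archs)  # {'gfx90a'}
```

Taking the strict intersection would therefore narrow support considerably compared to the ROCM_SUPPORTED_ARCHS set above.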

@WoosukKwon (Collaborator) commented Dec 8, 2023

@tjtanaa Thanks for the detailed explanation. Sorry, I have little background on this stuff. Maybe I should learn more about ROCm and AMD GPUs 😂

As far as I understand, the vLLM custom ops support every "recent" AMD GPU, and currently the supported GPU list is limited by ROCm Flash Attention. Is this correct?

@tjtanaa (Contributor, Author) commented Dec 8, 2023

@WoosukKwon We believe that in the near future, the supported GPU ARCHs are going to be restricted by ROCm Flash Attention.

@hongxiayang (Collaborator) commented Dec 8, 2023

fyi: The supported gfx arch for ROCm is documented here (as "LLVM target" column): https://rocm.docs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus.
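
ROCm LLVM target IDs can also carry feature suffixes (e.g. gfx90a:sramecc+:xnack-), so a membership check against the ROCM_SUPPORTED_ARCHS set from setup.py would first normalize them to the base gfx arch. A minimal sketch assuming that normalization approach (not code from this PR):

```python
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}

def base_arch(llvm_target: str) -> str:
    # Feature flags such as sramecc/xnack are appended after ':' in an LLVM target ID.
    return llvm_target.split(":")[0]

for target in ("gfx90a:sramecc+:xnack-", "gfx1100", "gfx803"):
    print(target, "->", base_arch(target) in ROCM_SUPPORTED_ARCHS)
```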

setup.py (outdated, resolved)
setup.py (outdated, resolved)
docs/source/getting_started/amd-installation.rst (outdated, resolved)
@WoosukKwon (Collaborator)

@hongxiayang It's LLaMA2-7B on a single MI210x. Basically, it should be the same setup as yours.

@tjtanaa (Contributor, Author) commented Dec 8, 2023

> @tjtanaa Thanks again for the great work! I found the code super clean and well organized. I also like the detailed documentation and the provided docker image. I could run vLLM on MI210 very smoothly and the performance was great! Thanks a lot for the contribution.
>
> Left some minor comments on the code style. Please take a look!

@WoosukKwon We have finished updating the code style and have replied to your questions regarding the supported ARCHs.

@WoosukKwon (Collaborator)

@tjtanaa @kliuae LGTM! Many thanks again for the wonderful work! This will be HUGE!!

@WoosukKwon merged commit 6ccc0bf into vllm-project:main on Dec 8, 2023
2 checks passed
@tjtanaa mentioned this pull request on Dec 10, 2023
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
Co-authored-by: Philipp Moritz <[email protected]>
Co-authored-by: Amir Balwel <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
Co-authored-by: kuanfu <[email protected]>
Co-authored-by: miloice <[email protected]>