[Continuation] Merge EmbeddedLLM/vllm-rocm into vLLM main #1836
Merged
Commits (63)

- 43af310 port dtype_float16.cuh and cache_kernels.cu (pcmoritz)
- cc81866 port dtype_bfloat16.cuh (pcmoritz)
- 475b5e2 port attention_utils.cuh (pcmoritz)
- ddc496c port more kernels (pcmoritz)
- 5eaa7a1 fix typo (pcmoritz)
- f7273c6 add cuda_compat.h (pcmoritz)
- 99c3be7 Merge branch 'main' into port-to-rocm (pcmoritz)
- f8093dc sync branches (pcmoritz)
- 41df689 update (pcmoritz)
- 93be9c5 update (pcmoritz)
- d96fa3c fixes (pcmoritz)
- 421365b cleanup (pcmoritz)
- 06b800e update (pcmoritz)
- 2312beb update (pcmoritz)
- 2958b39 update (pcmoritz)
- 3f89734 fmt (pcmoritz)
- 5397a57 cleanup (pcmoritz)
- 90e02d2 refactor (pcmoritz)
- a420202 update (pcmoritz)
- b072182 Merge branch 'main' into port-to-rocm (pcmoritz)
- 2d1e435 detecting rocm and adding flag for compiling (iAmir97)
- e231b79 using asm volatile instead of hip api (iAmir97)
- 31bb335 using asm volatile for type casting of f16 (iAmir97)
- b027d06 Hipifying csrc file to accomodate rocm builds (kliuae)
- 9a1781c Checked CUDA ROCm Compatibility (#15) (tjtanaa)
- 0f67117 merged with latest upstream (kliuae)
- 7dbf2d4 format code (kliuae)
- 52ffcf0 downgrade torch requirement in toml to torch 2.0.1 to accommodate ROC… (kliuae)
- 27f0513 Merged changes from vllm main (kliuae)
- 5cce649 Merged with changes in vllm main (kliuae)
- 16d3ccc Updated Dockerfile, rocm installation guide and setuppy (kliuae)
- d764f9d Updated amd installation guide and dockerfile (kliuae)
- e798632 Added num_gpus for ray init in ROCm (kliuae)
- 0e8129f Synced torch version with vllm main in pyproject.toml (kliuae)
- 2b3821b Format code (kliuae)
- 0c8795a Merge branch 'main' into vllm-cuda-rocm-dev (kliuae)
- 5793f30 Updated dockerfile.rocm and requirements-rocm.txt (kliuae)
- b172cdd Disable mistral for ROCm (kliuae)
- 9cd5b18 Format code (kliuae)
- b86f88a Revert to cuda kernels (kliuae)
- 9727ab4 Merge remote-tracking branch 'pcmoritz/port-to-rocm' (kliuae)
- c4aa2af Port latest kernels to ROCm (kliuae)
- f8c304e Update readme (kliuae)
- e608c30 Cleaned up kernel code (kliuae)
- 951e225 Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize (kliuae)
- 25f9a97 Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize (kliuae)
- e984ada Updated ROCm warp size (kliuae)
- cc1195f Format code (kliuae)
- f92980e Check hip from wrapper (kliuae)
- 66b4aa1 Format code (kliuae)
- 4a0ecb8 Enable support for mistral models (kliuae)
- acf51a8 Fixed hip device attribute (kliuae)
- 4a52977 Format code (kliuae)
- 23a987a Restored awq file (kliuae)
- 8787a4e Format code (kliuae)
- 5911131 Merge latest vllm main (kliuae)
- 9fa8075 Updated rocm dockerfile (kliuae)
- 81e052d Update amd installation guide (kliuae)
- fb8ac26 Update vLLM Documentations (#18) (tjtanaa)
- 98f5487 Updated setup.py, vllm/utils.py and amd-installation doc (kliuae)
- d90187a Updated setup.py (kliuae)
- c840531 Format code (kliuae)
- 9dba1d8 Merge branch 'main' into vllm-cuda-rocm-mod (kliuae)
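Several commits above ("Hipifying csrc file to accomodate rocm builds", plus the `hipify_python.py` patch applied in the Dockerfile) rely on PyTorch's hipify tooling. As a rough mental model only (not PyTorch's actual implementation), hipify rewrites CUDA API identifiers in the source to their HIP equivalents:

```python
# Rough illustration of the idea behind hipify (NOT PyTorch's actual
# implementation): translate CUDA API identifiers to HIP equivalents.
CUDA_TO_HIP = {
    "cudaMalloc": "hipMalloc",
    "cudaFree": "hipFree",
    "cudaMemcpy": "hipMemcpy",
    "cudaStream_t": "hipStream_t",
}

def naive_hipify(source: str) -> str:
    """Textually substitute CUDA identifiers with their HIP counterparts."""
    for cuda_name, hip_name in CUDA_TO_HIP.items():
        source = source.replace(cuda_name, hip_name)
    return source

print(naive_hipify("cudaMalloc(&buf, n); cudaFree(buf);"))
```

The real tool is considerably more careful (it parses rather than blindly substituting), which is why the Dockerfile patches `hipify_python.py` before building flash-attention.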
`.gitignore`:

```diff
@@ -177,3 +177,7 @@ _build/
 # vim swap files
 *.swo
 *.swp
+
+# hip files generated by PyTorch
+*.hip
+*_hip*
```
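The two new ignore patterns cover the `.hip` sources and `*_hip*` intermediates that the hipify step emits. A quick check of what they match, using Python's `fnmatch` (which follows the same glob-style wildcards; real gitignore matching has extra rules such as directory handling):

```python
from fnmatch import fnmatch

# The two entries added to .gitignore in this PR.
patterns = ["*.hip", "*_hip*"]

# Hypothetical filenames, chosen to show one match per pattern.
candidates = ["cache_kernels.hip", "attention_hip.cuh", "cache_kernels.cu"]

ignored = [f for f in candidates if any(fnmatch(f, p) for p in patterns)]
print(ignored)  # the .cu source is kept; the hipified outputs are ignored
```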
`Dockerfile.rocm` (new file, 63 lines):

```dockerfile
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    build-essential \
    wget \
    unzip \
    nvidia-cuda-toolkit \
    tmux \
    && rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
ENV PYTORCH_ROCM_ARCH=gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101

# Install ROCm flash-attention
RUN mkdir libs \
    && cd libs \
    && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
    && cd flash-attention \
    && git checkout 3d2b6f5 \
    && git submodule update --init \
    && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
    && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
    && python3 setup.py install \
    && cd ..

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN pip install xformers==0.0.22.post7 --no-deps

RUN cd /app \
    && cd vllm \
    && pip install -U -r requirements-rocm.txt \
    && bash patch_xformers-0.0.22.post7.rocm.sh \
    && python3 setup.py install \
    && cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]

CMD ["/bin/bash"]
```

WoosukKwon marked this conversation as resolved.
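The `PYTORCH_ROCM_ARCH` value above is a semicolon-separated list of AMD GPU targets. A small sketch of how such a value can be split into individual `gfx` targets (the helper name is ours, not vLLM's):

```python
import os

# Value copied from the Dockerfile's ENV line above.
os.environ["PYTORCH_ROCM_ARCH"] = "gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101"

def rocm_arch_list():
    """Split PYTORCH_ROCM_ARCH into individual gfx targets."""
    raw = os.environ.get("PYTORCH_ROCM_ARCH", "")
    return [arch for arch in raw.split(";") if arch]

print(rocm_arch_list())
```

Note that the flash-attention build step does not use this list; it exports `GPU_ARCHS` from `amdgpu-offload-arch`, i.e. it builds only for the GPU visible at image-build time.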
(Diffs for the remaining changed files failed to load.)
Review comment:

> gfx900 can be removed from the list.
> Is gfx1101 tested?
Reply:

> PyTorch is already built into the pulled `rocm/pytorch` base image, so we probably don't need to overwrite this env variable. Thank you for pointing this out.
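Related to this exchange, the commit history includes "detecting rocm and adding flag for compiling" and "Check hip from wrapper". One common way to detect a ROCm PyTorch build is to check `torch.version.hip`, which is set on ROCm wheels and `None` on CUDA ones. The sketch below is our illustration (not the PR's code) and duck-types the version module so it runs without torch installed:

```python
from types import SimpleNamespace

def is_hip(torch_version) -> bool:
    """True when the given torch.version-like object was built for ROCm."""
    return getattr(torch_version, "hip", None) is not None

# Stand-ins for torch.version from a ROCm build and a CUDA build
# (version strings are made up for the example).
rocm_build = SimpleNamespace(hip="5.7.0", cuda=None)
cuda_build = SimpleNamespace(hip=None, cuda="12.1")

print(is_hip(rocm_build), is_hip(cuda_build))
```

In real code you would call `is_hip(torch.version)` once at import time and branch on the result, e.g. to pick HIP-compatible kernels.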