From 9c905e4849f5447d4ffbc77c42e8a379d10b3a32 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Wed, 29 Nov 2023 11:27:03 +0800 Subject: [PATCH] feat: add rocm support (#913) * Added build configurations for Intel and AMD hardware * Improved rocm build * Added options for OneAPI and ROCm * Build llama using icx * [autofix.ci] apply automated fixes * Fixed rocm image * Build ROCm * Tried to adjust compile flags for SYCL * Removed references to oneAPI * Provide info about the used device for ROCm * Added ROCm documentation * Addressed review comments * Refactored to expose generic accelerator information * Pull request cleanup * cleanup * cleanup * Delete .github/workflows/docker-cuda.yml * Delete .github/workflows/docker-rocm.yml * Delete crates/tabby-common/src/api/accelerator.rs * update * cleanup * update * update * update * update --------- Co-authored-by: Cromefire_ Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- crates/llama-cpp-bindings/Cargo.toml | 1 + crates/llama-cpp-bindings/build.rs | 37 +++++++++++++++++++++++++++- crates/tabby/Cargo.toml | 1 + crates/tabby/src/main.rs | 15 ++++++++++- 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/crates/llama-cpp-bindings/Cargo.toml b/crates/llama-cpp-bindings/Cargo.toml index 4e8af380432..a054d36f73a 100644 --- a/crates/llama-cpp-bindings/Cargo.toml +++ b/crates/llama-cpp-bindings/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [features] cuda = [] +rocm = [] [build-dependencies] cxx-build = "1.0" diff --git a/crates/llama-cpp-bindings/build.rs b/crates/llama-cpp-bindings/build.rs index c6219478071..f9b825a99e6 100644 --- a/crates/llama-cpp-bindings/build.rs +++ b/crates/llama-cpp-bindings/build.rs @@ -1,4 +1,4 @@ -use std::path::Path; +use std::{env, path::Path}; use cmake::Config; @@ -32,6 +32,41 @@ fn main() { println!("cargo:rustc-link-lib=cublas"); println!("cargo:rustc-link-lib=cublasLt"); } + if cfg!(feature = "rocm") { + let amd_gpu_targets: Vec<&str> = vec![ + "gfx803", + "gfx900", + "gfx906:xnack-", + "gfx908:xnack-", + "gfx90a:xnack+", + "gfx90a:xnack-", + "gfx940", + "gfx941", + "gfx942", + "gfx1010", + "gfx1012", + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + ]; + + let rocm_root = env::var("ROCM_ROOT").unwrap_or("/opt/rocm".to_string()); + config.define("LLAMA_HIPBLAS", "ON"); + config.define("CMAKE_C_COMPILER", format!("{}/llvm/bin/clang", rocm_root)); + config.define( + "CMAKE_CXX_COMPILER", + format!("{}/llvm/bin/clang++", rocm_root), + ); + config.define("AMDGPU_TARGETS", amd_gpu_targets.join(";")); + println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries"); + println!("cargo:rustc-link-search=native={}/hip/lib", rocm_root); + println!("cargo:rustc-link-search=native={}/rocblas/lib", rocm_root); + println!("cargo:rustc-link-search=native={}/hipblas/lib", rocm_root); + println!("cargo:rustc-link-lib=amdhip64"); + println!("cargo:rustc-link-lib=rocblas"); + println!("cargo:rustc-link-lib=hipblas"); + } let dst = config.build(); println!("cargo:rustc-link-search=native={}/build", dst.display()); diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 577a118179f..98549330071 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" default = ["ee"] ee = ["dep:tabby-webserver"] cuda = ["llama-cpp-bindings/cuda"] +rocm = ["llama-cpp-bindings/rocm"] experimental-http = ["dep:http-api-bindings"] [dependencies] diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index fe7c4a4ad7b..e13f13849ad 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -69,6 +69,10 @@ pub enum Device { #[strum(serialize = "cuda")] Cuda, + #[cfg(feature = "rocm")] + #[strum(serialize = "rocm")] + Rocm, + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] #[strum(serialize = "metal")] Metal, @@ -89,7 +93,16 @@ impl Device { *self == Device::Cuda } - #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))] + #[cfg(feature = "rocm")] + pub fn ggml_use_gpu(&self) -> bool { + *self == Device::Rocm + } + + #[cfg(not(any( + all(target_os = "macos", target_arch = "aarch64"), + feature = "cuda", + feature = "rocm", + )))] pub fn ggml_use_gpu(&self) -> bool { false }