Skip to content

Commit

Permalink
feat: add rocm support (#913)
Browse files Browse the repository at this point in the history
* Added build configurations for Intel and AMD hardware

* Improved rocm build

* Added options for OneAPI and ROCm

* Build llama using icx

* [autofix.ci] apply automated fixes

* Fixed rocm image

* Build ROCm

* Tried to adjust compile flags for SYCL

* Removed references to oneAPI

* Provide info about the used device for ROCm

* Added ROCm documentation

* Addressed review comments

* Refactored to expose generic accelerator information

* Pull request cleanup

* cleanup

* cleanup

* Delete .github/workflows/docker-cuda.yml

* Delete .github/workflows/docker-rocm.yml

* Delete crates/tabby-common/src/api/accelerator.rs

* update

* cleanup

* update

* update

* update

* update

---------

Co-authored-by: Cromefire_ <[email protected]>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 29, 2023
1 parent 2b131ad commit 9c905e4
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 2 deletions.
1 change: 1 addition & 0 deletions crates/llama-cpp-bindings/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"

[features]
cuda = []
rocm = []

[build-dependencies]
cxx-build = "1.0"
Expand Down
37 changes: 36 additions & 1 deletion crates/llama-cpp-bindings/build.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::path::Path;
use std::{env, path::Path};

use cmake::Config;

Expand Down Expand Up @@ -32,6 +32,41 @@ fn main() {
println!("cargo:rustc-link-lib=cublas");
println!("cargo:rustc-link-lib=cublasLt");
}
if cfg!(feature = "rocm") {
let amd_gpu_targets: Vec<&str> = vec![
"gfx803",
"gfx900",
"gfx906:xnack-",
"gfx908:xnack-",
"gfx90a:xnack+",
"gfx90a:xnack-",
"gfx940",
"gfx941",
"gfx942",
"gfx1010",
"gfx1012",
"gfx1030",
"gfx1100",
"gfx1101",
"gfx1102",
];

let rocm_root = env::var("ROCM_ROOT").unwrap_or("/opt/rocm".to_string());
config.define("LLAMA_HIPBLAS", "ON");
config.define("CMAKE_C_COMPILER", format!("{}/llvm/bin/clang", rocm_root));
config.define(
"CMAKE_CXX_COMPILER",
format!("{}/llvm/bin/clang++", rocm_root),
);
config.define("AMDGPU_TARGETS", amd_gpu_targets.join(";"));
println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries");
println!("cargo:rustc-link-search=native={}/hip/lib", rocm_root);
println!("cargo:rustc-link-search=native={}/rocblas/lib", rocm_root);
println!("cargo:rustc-link-search=native={}/hipblas/lib", rocm_root);
println!("cargo:rustc-link-lib=amdhip64");
println!("cargo:rustc-link-lib=rocblas");
println!("cargo:rustc-link-lib=hipblas");
}

let dst = config.build();
println!("cargo:rustc-link-search=native={}/build", dst.display());
Expand Down
1 change: 1 addition & 0 deletions crates/tabby/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"
default = ["ee"]
ee = ["dep:tabby-webserver"]
cuda = ["llama-cpp-bindings/cuda"]
rocm = ["llama-cpp-bindings/rocm"]
experimental-http = ["dep:http-api-bindings"]

[dependencies]
Expand Down
15 changes: 14 additions & 1 deletion crates/tabby/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ pub enum Device {
#[strum(serialize = "cuda")]
Cuda,

#[cfg(feature = "rocm")]
#[strum(serialize = "rocm")]
Rocm,

#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
#[strum(serialize = "metal")]
Metal,
Expand All @@ -89,7 +93,16 @@ impl Device {
*self == Device::Cuda
}

#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
#[cfg(feature = "rocm")]
pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Rocm
}

#[cfg(not(any(
all(target_os = "macos", target_arch = "aarch64"),
feature = "cuda",
feature = "rocm",
)))]
pub fn ggml_use_gpu(&self) -> bool {
false
}
Expand Down

0 comments on commit 9c905e4

Please sign in to comment.