Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MPS compatibility and dtype fixes #81

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 87 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,17 +287,33 @@ generate(
```

### Gradio Demo
We have deployed online demo in [Huggingface](https://huggingface.co/spaces/deepseek-ai/Janus-Pro-7B).
We have deployed an online demo in [Huggingface](https://huggingface.co/spaces/deepseek-ai/Janus-Pro-7B).

For the local Gradio demo, you can run one of the following commands:

For the local gradio demo, you can run with the following command:

**For standard CUDA-based inference:**

```
pip install -e .[gradio]

python demo/app_januspro.py
```

**For Apple Silicon (MPS) users (experimental):**

```
pip install -e .[gradio]

python demo/app_januspro_mps.py
```

This version includes optimizations for Apple Silicon (MPS), using `torch.float16` instead of `torch.bfloat16`.

*Note:* This is an experimental script contributed by the community and has not been officially tested by the DeepSeek team. Please share feedback if you encounter issues!



Have Fun!

</details>
Expand Down Expand Up @@ -710,11 +726,76 @@ Have Fun!

</details>

## 4. License

## 4. Community Contributions

This repository welcomes community contributions that improve the model’s usability across different platforms.

---

### **🔹 Apple Silicon (MPS) Compatibility & Performance Fixes**
**Issue:** The original `app_januspro.py` script ran inference **on the CPU instead of utilizing the MPS (Metal Performance Shaders) backend**, leading to slow performance. Additionally, dtype mismatches between **bfloat16 (input) and float16 (bias)** caused runtime errors.

**Solution:**
- A new script, **`app_januspro_mps.py`**, has been added, optimized for **Apple MPS**.
- The script **prioritizes MPS acceleration when available**, significantly improving performance.
- It **remains compatible with CUDA and CPU**, though **further community testing is encouraged**.

**Key Improvements:**
- **Automatic device selection** (`cuda`, `mps`, or `cpu`).
- **Ensures dtype consistency**:
- **MPS:** `float16` (to prevent dtype mismatches).
- **CUDA:** `bfloat16` (or `float16`, if preferred).
- **CPU:** `float32` (fallback for compatibility).
- **Fixes dtype mismatches** that previously caused crashes on MPS.

**Usage:**
For Apple Silicon (MPS) users, try the new script:
```sh
pip install -e .[gradio]

python demo/app_januspro_mps.py

💡 This script is an experimental addition. If it performs well across all platforms, the community can consider merging improvements into the main app_januspro.py.

---

### **🔹 Fix for RuntimeError: Mismatched DType (`bfloat16` vs `half`)**
**Issue:** Running Janus-Pro-7B on **Apple's MPS backend** previously resulted in:

```
RuntimeError: Input type (c10::BFloat16) and bias type (c10::Half) should be the same
```
This was caused by **dtype mismatches**:
- The **Upsample block** converted inputs to `float32`, applied interpolation, then cast them to `bfloat16`.
- Meanwhile, **convolution layers used `float16`**, leading to an error.
- **Apple’s MPS has partial support for `bfloat16`**, contributing to instability.

**Solution:**
- We **standardized all tensor dtypes to `torch.float16` on MPS** to prevent mismatches.
- This change ensures **stable execution across MPS and CUDA**.

### **🔹 Fixes Applied to `vq_model.py`**
**Issue:** Additional dtype mismatches were identified in **`vq_model.py`**, specifically in the **Upsample module**, where tensor operations introduced unnecessary conversions.

**Solution:**
- The **Upsample module now preserves the original input tensor dtype**, ensuring dtype consistency throughout the pipeline.
- This prevents unexpected dtype mismatches across **MPS, CUDA, and CPU environments**.

### **🔹 Status & Call for Community Testing**
The **new `app_januspro_mps.py` script and dtype fixes in `vq_model.py` have been tested successfully on Apple Silicon**. While initial results indicate improved performance, **further validation from the DeepSeek community is encouraged**.

🚀 If you have expertise in **PyTorch, Apple MPS, or GPU optimization**, your feedback and improvements are welcome!

If you encounter issues, please **open an Issue or submit a Pull Request**.


## 5. License

This code repository is licensed under [the MIT License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-CODE). The use of Janus models is subject to [DeepSeek Model License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-MODEL).

## 5. Citation

## 6. Citation

```bibtex
@misc{chen2025januspro,
Expand All @@ -738,6 +819,7 @@ This code repository is licensed under [the MIT License](https://github.com/deep
}
```

## 6. Contact

## 7. Contact

If you have any questions, please raise an issue or contact us at [[email protected]](mailto:[email protected]).
Loading