From d11f257ee72b569e1d6ca24cfa104a8e57bc85ef Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Thu, 23 Jan 2025 16:10:19 +0800 Subject: [PATCH] Add GPU example for MiniCPM-o-2_6 (#12735) * Add init example for omni mode * Small fix * Small fix * Add chat example * Remove lagecy link * Further update link * Add readme * Small fix * Update main readme link * Update based on comments * Small fix * Small fix * Small fix --- README.md | 1 + README.zh-CN.md | 1 + .../Multimodal/MiniCPM-o-2_6/README.md | 164 ++++++++++++++++++ .../Multimodal/MiniCPM-o-2_6/chat.py | 119 +++++++++++++ .../Multimodal/MiniCPM-o-2_6/omni.py | 133 ++++++++++++++ python/llm/example/GPU/README.md | 2 +- 6 files changed, 419 insertions(+), 1 deletion(-) create mode 100644 python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/README.md create mode 100644 python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/chat.py create mode 100644 python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/omni.py diff --git a/README.md b/README.md index f76aa225a8d..8a06ec54bae 100644 --- a/README.md +++ b/README.md @@ -337,6 +337,7 @@ Over 70 models have been optimized/verified on `ipex-llm`, including *LLaMA/LLaM | MiniCPM-V-2 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2) | | MiniCPM-Llama3-V-2_5 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-Llama3-V-2_5) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) | | MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) | +| MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) | | StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) | | Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) | | Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) | diff --git a/README.zh-CN.md b/README.zh-CN.md index 53ec503a22f..850adb94557 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -337,6 +337,7 @@ See the demo of running [*Text-Generation-WebUI*](https://ipex-llm.readthedocs.i | MiniCPM-V-2 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2) | | MiniCPM-Llama3-V-2_5 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-Llama3-V-2_5) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) | | MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) | +| MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) | | StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) | | Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) | | Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) | diff --git a/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/README.md 
b/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/README.md
new file mode 100644
index 00000000000..c9d321081c8
--- /dev/null
+++ b/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/README.md
@@ -0,0 +1,164 @@
+# MiniCPM-o-2_6
+In this directory, you will find examples of how to apply IPEX-LLM INT4 optimizations to the MiniCPM-o-2_6 model on [Intel GPUs](../../../README.md). For illustration purposes, we use [openbmb/MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) as the reference MiniCPM-o-2_6 model.
+
+The examples below show how to apply IPEX-LLM optimizations to MiniCPM-o-2_6 for text, audio, image and video inputs.
+
+## 0. Requirements & Installation
+
+To run these examples with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../../README.md#requirements) for more information.
+
+### 0.1 Install IPEX-LLM
+
+- For **Intel Core™ Ultra Processors (Series 2) with processor number 2xxV (code name Lunar Lake)** on Windows:
+  ```cmd
+  conda create -n llm python=3.11 libuv
+  conda activate llm
+
+  :: or --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/lnl/cn/
+  pip install --pre --upgrade ipex-llm[xpu_lnl] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/lnl/us/
+  pip install torchaudio==2.3.1.post0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/lnl/us/
+  ```
+- For **Intel Arc B-Series GPU (code name Battlemage)** on Linux:
+  ```bash
+  conda create -n llm python=3.11
+  conda activate llm
+
+  # or --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
+  pip install --pre --upgrade ipex-llm[xpu-arc] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+  pip install torchaudio==2.3.1.post0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+  ```
+
+> [!NOTE]
+> Installation instructions for more Intel GPU platforms will be added in a later update.
+
+### 0.2 Install Required Packages for MiniCPM-o-2_6
+
+```bash
+conda activate llm
+
+# refer to: https://huggingface.co/openbmb/MiniCPM-o-2_6#usage
+pip install transformers==4.44.2 trl
+pip install librosa==0.9.0
+pip install soundfile==0.12.1
+pip install moviepy
+```
+
+### 0.3 Runtime Configuration
+
+- For **Intel Core™ Ultra Processors (Series 2) with processor number 2xxV (code name Lunar Lake)** on Windows:
+  ```cmd
+  set SYCL_CACHE_PERSISTENT=1
+  ```
+- For **Intel Arc B-Series GPU (code name Battlemage)** on Linux:
+  ```bash
+  unset OCL_ICD_VENDORS
+  export SYCL_CACHE_PERSISTENT=1
+  ```
+
+> [!NOTE]
+> Runtime configuration for more Intel GPU platforms will be added in a later update.
+
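+### 0.4 Model Loading with IPEX-LLM
+
+All of the examples below load and optimize the model in the same way. The snippet that follows is a condensed sketch of the loading code in [chat.py](./chat.py) and [omni.py](./omni.py) in this directory; it is shown here only to illustrate where the IPEX-LLM INT4 optimization is applied:
+
+```python
+from ipex_llm.transformers import AutoModel
+from transformers import AutoTokenizer
+
+model_path = "openbmb/MiniCPM-o-2_6"
+
+# `load_in_low_bit="sym_int4"` converts the language model layers to INT4,
+# while the vision ("vpm", "resampler") and audio ("apm") modules listed in
+# `modules_to_not_convert` are kept in their original precision
+model = AutoModel.from_pretrained(model_path,
+                                  load_in_low_bit="sym_int4",
+                                  optimize_model=True,
+                                  trust_remote_code=True,
+                                  attn_implementation='sdpa',
+                                  use_cache=True,
+                                  init_vision=True,
+                                  init_audio=True,
+                                  init_tts=False,
+                                  modules_to_not_convert=["apm", "vpm", "resampler"])
+
+# run the model on the Intel GPU in half precision
+model = model.half().to('xpu')
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+```
+
+Each example then calls `model.chat(...)` with the appropriate combination of text, audio, image or per-second video chunks as input; chat.py only initializes the vision/audio modules that its inputs require.
+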
+## 1. Example: Chat in Omni Mode
+In [omni.py](./omni.py), we show how to chat with MiniCPM-o-2_6 in omni mode with IPEX-LLM INT4 optimizations on Intel GPUs. In this example, the model takes a video as input and conducts inference based on both the images and the audio of that video.
+
+For example, if the video input shows a clip of an athlete swimming while the background audio asks "What is the athlete doing?", the model in omni mode should answer based on the video frames and the question in the audio.
+
+### 1.1 Running example
+
+```bash
+python omni.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --video-path VIDEO_PATH
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the Hugging Face repo id of the MiniCPM-o-2_6 model (e.g. `openbmb/MiniCPM-o-2_6`) to be downloaded, or the path to a local checkpoint folder. It defaults to `'openbmb/MiniCPM-o-2_6'`.
+- `--video-path VIDEO_PATH`: argument defining the video input.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+
+> [!TIP]
+> In omni mode, please make sure that the video input contains sound.
+
+> [!TIP]
+> You can safely ignore the warning `Some weights of the model checkpoint at xxx were not used when initializing MiniCPMO`.
+
+## 2. Example: Chat with text/audio/image input
+In [chat.py](./chat.py), we show how to chat with MiniCPM-o-2_6 based on text, audio or image input, or a combination of two of them, with IPEX-LLM INT4 optimizations on Intel GPUs.
+
+### 2.1 Running example
+
+- Chat with text input
+  ```bash
+  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT
+  ```
+
+- Chat with audio input
+  ```bash
+  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --audio-path AUDIO_PATH
+  ```
+
+- Chat with image input
+  ```bash
+  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --image-path IMAGE_PATH
+  ```
+
+- Chat with text + audio inputs
+  ```bash
+  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --audio-path AUDIO_PATH
+  ```
+
+- Chat with text + image inputs
+  ```bash
+  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --image-path IMAGE_PATH
+  ```
+
+- Chat with audio + image inputs
+  ```bash
+  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --audio-path AUDIO_PATH --image-path IMAGE_PATH
+  ```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the Hugging Face repo id of the MiniCPM-o-2_6 model (e.g. `openbmb/MiniCPM-o-2_6`) to be downloaded, or the path to a local checkpoint folder. It defaults to `'openbmb/MiniCPM-o-2_6'`.
+- `--prompt PROMPT`: argument defining the text input.
+- `--audio-path AUDIO_PATH`: argument defining the audio input.
+- `--image-path IMAGE_PATH`: argument defining the image input.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+
+> [!TIP]
+> You can safely ignore the warning `Some weights of the model checkpoint at xxx were not used when initializing MiniCPMO`.
+
+### 2.2 Sample Outputs
+
+#### [openbmb/MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6)
+
+The sample input image (fetched from the [COCO dataset](https://cocodataset.org/#explore?id=264959)) is:
+
+http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg + +And the sample audio is a person saying "What is in this image". + +- Chat with text + image inputs + ```log + Inference time: xxxx s + -------------------- Input Image Path -------------------- + 5602445367_3504763978_z.jpg + -------------------- Input Audio Path -------------------- + None + -------------------- Input Prompt -------------------- + What is in this image? + -------------------- Chat Output -------------------- + The image features a young child holding and displaying her white teddy bear. She is wearing a pink dress, which complements the color of the stuffed toy she + ``` + +- Chat with audio + image inputs: + ```log + Inference time: xxxx s + -------------------- Input Image Path -------------------- + 5602445367_3504763978_z.jpg + -------------------- Input Audio Path -------------------- + test_audio.wav + -------------------- Input Prompt -------------------- + None + -------------------- Chat Output -------------------- + In this image, there is a young girl holding and displaying her stuffed teddy bear. She appears to be the main subject of the photo, with her toy + ``` diff --git a/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/chat.py b/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/chat.py new file mode 100644 index 00000000000..5e3b8701a10 --- /dev/null +++ b/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/chat.py @@ -0,0 +1,119 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import time +import torch +import librosa +import argparse +from PIL import Image +from transformers import AutoTokenizer +from ipex_llm.transformers import AutoModel + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Chat with MiniCPM-o-2_6 with text/audio/image') + parser.add_argument('--repo-id-or-model-path', type=str, default="openbmb/MiniCPM-o-2_6", + help='The Hugging Face or ModelScope repo id for the MiniCPM-o-2_6 model to be downloaded' + ', or the path to the checkpoint folder') + parser.add_argument('--image-path', type=str, + help='The path to the image for inference.') + parser.add_argument('--audio-path', type=str, + help='The path to the audio for inference.') + parser.add_argument('--prompt', type=str, + help='Prompt for inference.') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + + args = parser.parse_args() + + model_path = args.repo_id_or_model_path + image_path = args.image_path + audio_path = args.audio_path + + modules_to_not_convert = [] + init_vision = False + init_audio = False + if image_path is not None and os.path.exists(image_path): + init_vision = True + modules_to_not_convert += ["vpm", "resampler"] + if audio_path is not None and os.path.exists(audio_path): + init_audio = True + modules_to_not_convert += ["apm"] + + # Load model in 4 bit, + # which convert the relevant layers in the model into INT4 format + model = AutoModel.from_pretrained(model_path, + load_in_low_bit="sym_int4", + optimize_model=True, + trust_remote_code=True, + attn_implementation='sdpa', + use_cache=True, + init_vision=init_vision, + init_audio=init_audio, + init_tts=False, + modules_to_not_convert=modules_to_not_convert) + + model = model.half().to('xpu') + + tokenizer = AutoTokenizer.from_pretrained(model_path, + trust_remote_code=True) + + + # The following code for generation is adapted from + # https://huggingface.co/openbmb/MiniCPM-o-2_6#addressing-various-audio-understanding-tasks and + # https://huggingface.co/openbmb/MiniCPM-o-2_6#chat-with-single-image + content = [] + if init_vision: + image_input = Image.open(image_path).convert('RGB') + content.append(image_input) + if args.prompt is not None: + content.append(args.prompt) + if init_audio: + audio_input, _ = librosa.load(audio_path, sr=16000, mono=True) + content.append(audio_input) + messages = [{'role': 'user', 'content': content}] + + + with torch.inference_mode(): + # ipex_llm model needs a warmup, then inference time can be accurate + model.chat( + msgs=messages, + tokenizer=tokenizer, + sampling=True, + max_new_tokens=args.n_predict, + ) + + st = time.time() + response = model.chat( + msgs=messages, + tokenizer=tokenizer, + sampling=True, + max_new_tokens=args.n_predict, + ) + torch.xpu.synchronize() + end = time.time() + + print(f'Inference time: {end-st} s') + print('-'*20, 'Input Image Path', '-'*20) + print(image_path) + print('-'*20, 'Input Audio Path', '-'*20) + print(audio_path) + print('-'*20, 'Input Prompt', '-'*20) + print(args.prompt) + print('-'*20, 'Chat Output', '-'*20) + print(response) + diff --git a/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/omni.py b/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/omni.py new file mode 100644 index 00000000000..7002f2f72aa --- /dev/null +++ b/python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/omni.py @@ -0,0 +1,133 @@ +# +# Copyright 2016 The BigDL Authors. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import time
+import math
+import torch
+import librosa
+import argparse
+import numpy as np
+from PIL import Image
+from moviepy import VideoFileClip
+from transformers import AutoTokenizer
+from ipex_llm.transformers import AutoModel
+
+
+# The video chunk function is adapted from https://huggingface.co/openbmb/MiniCPM-o-2_6#chat-inference
+def get_video_chunk_content(video_path, temp_audio_name, flatten=True):
+    video = VideoFileClip(video_path)
+    print('video_duration:', video.duration)
+
+    # extract the audio track of the video as a 16 kHz mono waveform
+    with open(temp_audio_name, 'wb') as temp_audio_file:
+        temp_audio_file_path = temp_audio_file.name
+        video.audio.write_audiofile(temp_audio_file_path, codec="pcm_s16le", fps=16000)
+        audio_np, sr = librosa.load(temp_audio_file_path, sr=16000, mono=True)
+    num_units = math.ceil(video.duration)
+
+    # for every second of video, take 1 frame + the matching 1s audio chunk
+    contents = []
+    for i in range(num_units):
+        frame = video.get_frame(i+1)
+        image = Image.fromarray(frame.astype(np.uint8))
+        audio = audio_np[sr*i:sr*(i+1)]
+        if flatten:
+            contents.extend(["<unit>", image, audio])
+        else:
+            contents.append(["<unit>", image, audio])
+
+    return contents
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Chat with MiniCPM-o-2_6 in Omni mode')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="openbmb/MiniCPM-o-2_6",
+                        help='The Hugging Face or ModelScope repo id for the MiniCPM-o-2_6 model to be downloaded'
+                             ', or the path to the checkpoint folder')
+    parser.add_argument('--video-path', type=str, required=True,
+                        help='The path to the video, which the model uses to conduct inference '
+                             'based on its images and audio.')
+    parser.add_argument('--n-predict', type=int, default=32,
+                        help='Max tokens to predict')
+
+    args = parser.parse_args()
+
+    model_path = args.repo_id_or_model_path
+    video_path = args.video_path
+
+    # Load model in 4 bit,
+    # which converts the relevant layers in the model into INT4 format
+    model = AutoModel.from_pretrained(model_path,
+                                      load_in_low_bit="sym_int4",
+                                      optimize_model=True,
+                                      trust_remote_code=True,
+                                      attn_implementation='sdpa',
+                                      use_cache=True,
+                                      init_vision=True,
+                                      init_audio=True,
+                                      init_tts=False,
+                                      modules_to_not_convert=["apm", "vpm", "resampler"])
+
+    model = model.half().to('xpu')
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path,
+                                              trust_remote_code=True)
+
+    # The following code for generation is adapted from
+    # https://huggingface.co/openbmb/MiniCPM-o-2_6#chat-inference
+    temp_audio_name = "temp_audio.wav"
+    contents = get_video_chunk_content(video_path, temp_audio_name)
+    messages = [{"role": "user", "content": contents}]
+
+    if os.path.exists(temp_audio_name):
+        os.remove(temp_audio_name)
+
+    with torch.inference_mode():
+        # the ipex_llm model needs a warmup run; timings after the warmup are accurate
+        model.chat(
+            msgs=messages,
+            tokenizer=tokenizer,
+            sampling=True,
+            temperature=0.5,
+            max_new_tokens=args.n_predict,
+            omni_input=True,  # required for omni-mode inference
+            use_tts_template=False,
+            generate_audio=False,
+            max_slice_nums=1,
+            use_image_id=False,
+        )
+
+        st = time.time()
+        response = model.chat(
+            msgs=messages,
+            tokenizer=tokenizer,
+            sampling=True,
+            temperature=0.5,
+            max_new_tokens=args.n_predict,
+            omni_input=True,  # required for omni-mode inference
+            use_tts_template=False,
+            generate_audio=False,
+            max_slice_nums=1,
+            use_image_id=False,
+        )
+        torch.xpu.synchronize()
+        end = time.time()
+
+    print(f'Inference time: {end-st} s')
+    print('-'*20, 'Input Video Path', '-'*20)
+    print(video_path)
+    print('-'*20, 'Chat Output', '-'*20)
+    print(response)
diff --git a/python/llm/example/GPU/README.md b/python/llm/example/GPU/README.md
index 41a86c7678b..46297fa18c3 100644
--- a/python/llm/example/GPU/README.md
+++ b/python/llm/example/GPU/README.md
@@ -36,4 +36,4 @@ This folder contains examples of running IPEX-LLM on Intel GPU:
 - Windows 10/11, with or without WSL
 
 ## Requirements
-To apply Intel GPU acceleration, there’re several steps for tools installation and environment preparation. See the [GPU installation guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) for mode details.
\ No newline at end of file
+To apply Intel GPU acceleration, there are several steps for tools installation and environment preparation. See the GPU installation guide for [Linux](../../../../docs/mddocs/Quickstart/install_linux_gpu.md) or [Windows](../../../../docs/mddocs/Quickstart/install_windows_gpu.md) for more details.
\ No newline at end of file