feat: support mps device for demos #80

Open · wants to merge 1 commit into main
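This PR lets the three Gradio demos run on Apple Silicon by selecting the device and dtype once at import time (bfloat16 on CUDA, float32 on MPS and CPU), moving the model with .to(dtype).to(device), and guarding CUDA-only calls (torch.cuda.empty_cache, torch.cuda.manual_seed) behind a device check, with MPS equivalents where they exist. The per-device guards are repeated in each script; a rough sketch of how they could be factored into one helper module follows. This is a hypothetical consolidation for illustration, not code from the PR, and it assumes a PyTorch version that exposes torch.mps.manual_seed and torch.mps.empty_cache (roughly 2.0 or newer).

# Hypothetical helper module (illustration only, not part of this PR).
import numpy as np
import torch

def get_device_and_dtype():
    # bfloat16 on CUDA; float32 on MPS and CPU, where bfloat16 support is limited
    if torch.cuda.is_available():
        return 'cuda', torch.bfloat16
    if torch.backends.mps.is_available():
        return 'mps', torch.float32
    return 'cpu', torch.float32

def empty_cache(device):
    # Release cached allocator memory on the active accelerator, if any
    if device == 'cuda':
        torch.cuda.empty_cache()
    elif device == 'mps':
        torch.mps.empty_cache()

def set_seed(seed, device):
    # Seed the RNGs the demos rely on for reproducible sampling
    torch.manual_seed(seed)
    np.random.seed(seed)
    if device == 'cuda':
        torch.cuda.manual_seed(seed)
    elif device == 'mps':
        torch.mps.manual_seed(seed)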
44 changes: 30 additions & 14 deletions demo/app.py
@@ -6,6 +6,15 @@

import numpy as np

# Device and dtype configuration
def get_device_and_dtype():
if torch.cuda.is_available():
return 'cuda', torch.bfloat16 # use bfloat16 on CUDA devices
elif torch.backends.mps.is_available():
return 'mps', torch.float32
return 'cpu', torch.float32

device, dtype = get_device_and_dtype()

# Load model and processor
model_path = "deepseek-ai/Janus-1.3B"
Expand All @@ -15,22 +24,25 @@
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
language_config=language_config,
trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
vl_gpt = vl_gpt.to(dtype).to(device)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Multimodal Understanding function
@torch.inference_mode()
# Multimodal Understanding function
def multimodal_understanding(image, question, seed, top_p, temperature):
# Clear CUDA cache before generating
torch.cuda.empty_cache()
# Clear CUDA cache if using CUDA
if device == 'cuda':
torch.cuda.empty_cache()

# set seed
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
if device == 'cuda':
torch.cuda.manual_seed(seed)
elif device == 'mps':
torch.mps.manual_seed(seed)

conversation = [
{
@@ -44,8 +56,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

).to(device, dtype=dtype)

inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

@@ -74,16 +85,17 @@ def generate(input_ids,
cfg_weight: float = 5,
image_token_num_per_image: int = 576,
patch_size: int = 16):
# Clear CUDA cache before generating
torch.cuda.empty_cache()
# Clear CUDA cache if using CUDA
if device == 'cuda':
torch.cuda.empty_cache()

tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)

pkv = None
for i in range(image_token_num_per_image):
@@ -123,11 +135,15 @@ def generate_image(prompt,
seed=None,
guidance=5):
# Clear CUDA cache and avoid tracking gradients
torch.cuda.empty_cache()
if device == 'cuda':
torch.cuda.empty_cache()
# Set the seed for reproducible results
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if device == 'cuda':
torch.cuda.manual_seed(seed)
elif device == 'mps':
torch.mps.manual_seed(seed)
np.random.seed(seed)
width = 384
height = 384
47 changes: 28 additions & 19 deletions demo/app_janusflow.py
@@ -5,31 +5,41 @@
from diffusers.models import AutoencoderKL
import numpy as np

cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Device selection logic
if torch.cuda.is_available():
device = 'cuda'
dtype = torch.bfloat16
elif torch.backends.mps.is_available():
device = 'mps'
dtype = torch.float32 # use float32 on MPS devices
else:
device = 'cpu'
dtype = torch.float32

# Load model and processor
model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
vl_gpt = vl_gpt.to(dtype).to(device).eval()

# remember to use bfloat16 dtype, this vae doesn't work with fp16
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae = vae.to(torch.bfloat16).to(cuda_device).eval()
vae = vae.to(dtype).to(device).eval()

# Multimodal Understanding function
@torch.inference_mode()
# Multimodal Understanding function
def multimodal_understanding(image, question, seed, top_p, temperature):
# Clear CUDA cache before generating
torch.cuda.empty_cache()
if device == 'cuda':
torch.cuda.empty_cache()

# set seed
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
if device == 'cuda':
torch.cuda.manual_seed(seed)

conversation = [
{
@@ -43,8 +53,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

).to(device, dtype=dtype)

inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

@@ -73,7 +82,7 @@ def generate(
num_inference_steps: int = 30
):
# we generate 5 images at a time, *2 for CFG
tokens = torch.stack([input_ids] * 10).cuda()
tokens = torch.stack([input_ids] * 10).to(device)
tokens[5:, 1:] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
print(inputs_embeds.shape)
@@ -83,13 +92,13 @@

# generate with rectified flow ode
# step 1: encode with vision_gen_enc
z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
z = torch.randn((5, 4, 48, 48), dtype=dtype).to(device)

dt = 1.0 / num_inference_steps
dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
dt = torch.zeros_like(z).to(device).to(dtype) + dt

# step 2: run ode
attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(device)
attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
attention_mask = attention_mask.int()
for step in range(num_inference_steps):
@@ -108,8 +117,7 @@ def generate(
if step == 0:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=None)
attention_mask=attention_mask)
past_key_values = []
for kv_cache in past_key_values:
k, v = kv_cache[0], kv_cache[1]
@@ -118,8 +126,7 @@
else:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
attention_mask=attention_mask)
hidden_states = outputs.last_hidden_state

# transform hidden_states back to v
@@ -153,12 +160,14 @@ def generate_image(prompt,
seed=None,
guidance=5,
num_inference_steps=30):
# Clear CUDA cache and avoid tracking gradients
torch.cuda.empty_cache()
# Clear CUDA cache if using CUDA device
if device == 'cuda':
torch.cuda.empty_cache()
# Set the seed for reproducible results
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if device == 'cuda':
torch.cuda.manual_seed(seed)
np.random.seed(seed)

with torch.no_grad():
54 changes: 36 additions & 18 deletions demo/app_januspro.py
@@ -8,8 +8,17 @@
import numpy as np
import os
import time
# import spaces # Import spaces for ZeroGPU compatibility

# Device and dtype configuration
if torch.cuda.is_available():
device = 'cuda'
dtype = torch.bfloat16
elif torch.backends.mps.is_available():
device = 'mps'
dtype = torch.float32 # use float32 on MPS devices
else:
device = 'cpu'
dtype = torch.float32 # use float32 on CPU

# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
@@ -19,26 +28,26 @@
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
language_config=language_config,
trust_remote_code=True)
if torch.cuda.is_available():
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
vl_gpt = vl_gpt.to(torch.float16)
vl_gpt = vl_gpt.to(dtype).to(device)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@torch.inference_mode()
# @spaces.GPU(duration=120)
# Multimodal Understanding function
def multimodal_understanding(image, question, seed, top_p, temperature):
# Clear CUDA cache before generating
torch.cuda.empty_cache()
# Clear device cache
if device == 'cuda':
torch.cuda.empty_cache()
elif device == 'mps':
torch.mps.empty_cache()

# set seed
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
if device == 'cuda':
torch.cuda.manual_seed(seed)

conversation = [
{
@@ -52,8 +61,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

).to(device, dtype=dtype)

inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

@@ -82,16 +90,19 @@ def generate(input_ids,
cfg_weight: float = 5,
image_token_num_per_image: int = 576,
patch_size: int = 16):
# Clear CUDA cache before generating
torch.cuda.empty_cache()
# Clear device cache
if device == 'cuda':
torch.cuda.empty_cache()
elif device == 'mps':
torch.mps.empty_cache()

tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)

pkv = None
for i in range(image_token_num_per_image):
@@ -133,17 +144,24 @@ def unpack(dec, width, height, parallel_size=5):

@torch.inference_mode()
# @spaces.GPU(duration=120) # Specify a duration to avoid timeout
@torch.inference_mode()
def generate_image(prompt,
seed=None,
guidance=5,
t2i_temperature=1.0):
# Clear CUDA cache and avoid tracking gradients
torch.cuda.empty_cache()
# Clear device cache
if device == 'cuda':
torch.cuda.empty_cache()
elif device == 'mps':
torch.mps.empty_cache()

# Set the seed for reproducible results
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
if device == 'cuda':
torch.cuda.manual_seed(seed)

width = 384
height = 384
parallel_size = 5
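Before launching one of the demos on Apple Silicon, a quick sanity check (generic PyTorch usage, not something this PR adds) confirms that the installed build actually sees the MPS backend:

import torch

print(torch.backends.mps.is_built())      # True if this PyTorch build was compiled with MPS support
print(torch.backends.mps.is_available())  # True when an MPS device can actually be used

If an individual operator turns out to be unimplemented on MPS, setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 lets PyTorch fall back to the CPU for that op.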