-
Notifications
You must be signed in to change notification settings - Fork 452
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support valley
- Loading branch information
Showing
12 changed files
with
235 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from . import (baai, baichuan, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, microsoft, | ||
minicpm, mistral, mllm, mplug, openbuddy, qwen, skywork, telechat, yi) | ||
minicpm, mistral, mllm, mplug, openbuddy, qwen, skywork, telechat, valley, yi) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
import os | ||
import sys | ||
from functools import partial, wraps | ||
from typing import Any, Dict | ||
|
||
from transformers import AutoConfig | ||
|
||
from swift.llm import TemplateType | ||
from ..constant import MLLMModelType | ||
from ..model_arch import ModelArch | ||
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, | ||
get_model_tokenizer_with_flash_attn, register_model) | ||
from ..utils import ModelInfo, git_clone_github, safe_snapshot_download | ||
|
||
|
||
def get_model_tokenizer_valley(model_dir: str, | ||
model_info: ModelInfo, | ||
model_kwargs: Dict[str, Any], | ||
load_model: bool = True, | ||
**kwargs): | ||
llm_model_type = kwargs.pop('llm_model_type') | ||
local_repo_path = kwargs.get('local_repo_path') | ||
if not local_repo_path: | ||
repo_path = 'https://github.com/bytedance/Valley.git' | ||
local_repo_path = git_clone_github(repo_path) | ||
sys.path.append(os.path.join(local_repo_path)) | ||
|
||
if llm_model_type == 'valley': | ||
from valley_eagle.model.language_model.valley_qwen2 import ValleyQwen2ForCausalLM, ValleyConfig | ||
model_config = ValleyConfig.from_pretrained(model_dir) | ||
model_config.mm_vision_tower = safe_snapshot_download('AI-ModelScope/siglip-so400m-patch14-384') | ||
model_config.eagle_vision_tower = safe_snapshot_download('Qwen/Qwen2-VL-7B-Instruct') | ||
automodel_class = ValleyQwen2ForCausalLM | ||
|
||
kwargs['model_config'] = model_config | ||
kwargs['automodel_class'] = automodel_class | ||
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) | ||
model.generation_config.repetition_penalty = 1.0 # Otherwise, Error. Same for original code. | ||
if model is not None: | ||
from transformers import AutoProcessor, SiglipImageProcessor | ||
tokenizer.image_processor = SiglipImageProcessor.from_pretrained(model.config.mm_vision_tower) | ||
tokenizer.qwen2vl_processor = AutoProcessor.from_pretrained( | ||
model.config.eagle_vision_tower, max_pixels=1280 * 28 * 28) | ||
tokenizer.image_processor.crop_size = tokenizer.image_processor.size['height'] | ||
return model, tokenizer | ||
|
||
|
||
register_model( | ||
ModelMeta( | ||
MLLMModelType.valley, | ||
[ | ||
ModelGroup([ | ||
Model('bytedance-research/Valley-Eagle-7B'), | ||
], ), | ||
], | ||
TemplateType.valley, | ||
partial(get_model_tokenizer_valley, llm_model_type='valley'), | ||
architectures=['ValleyQwen2ForCausalLM'], | ||
model_arch=ModelArch.valley, | ||
requires=['transformers>=4.42', 'av'], | ||
tags=['vision'], | ||
)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from . import (deepseek, emu3, gemma, glm, got_ocr, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, | ||
minicpm, molmo, mplug, openbuddy, pixtral, qwen, yi) | ||
minicpm, molmo, mplug, openbuddy, pixtral, qwen, valley, yi) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
import io | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List, Literal, Optional | ||
|
||
import torch | ||
from PIL import Image | ||
|
||
from ..base import Template | ||
from ..constant import MLLMTemplateType | ||
from ..register import register_template | ||
from ..template_inputs import StdTemplateInputs | ||
from ..utils import Context | ||
from .utils import ChatmlTemplateMeta | ||
|
||
|
||
@dataclass | ||
class ValleyTemplateMeta(ChatmlTemplateMeta): | ||
auto_add_bos: bool = False | ||
default_system: Optional[str] = ('You are Valley, a large language and vision assistant trained by ByteDance.' | ||
'You are able to understand the visual content or video that the user provides,' | ||
' and assist the user with a variety of tasks using natural language.' | ||
'Follow the instructions carefully and explain your answers in detail.') | ||
|
||
|
||
class ValleyTemplate(Template): | ||
skip_prompt = True | ||
use_model = True | ||
|
||
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, | ||
inputs: StdTemplateInputs) -> List[Context]: | ||
# assert media_type == 'image' | ||
if media_type == 'video': | ||
from ..vision_utils import load_video_valley | ||
return self.replace_video2image(load_video_valley, inputs, lambda i: [[151665, -200, 151666]]) | ||
return [[151665, -200, 151666]] | ||
|
||
def preprocess_images(self, image_binary_list): | ||
from valley_eagle.util.mm_utils import process_anyres_image | ||
|
||
def byte2image(byte_data): | ||
return Image.open(io.BytesIO(byte_data)) | ||
|
||
images = [] | ||
for binary in image_binary_list: | ||
if isinstance(binary, Image.Image): | ||
images.append(binary.convert('RGB')) | ||
elif isinstance(binary, bytes): | ||
images.append(byte2image(binary)) | ||
else: | ||
raise ValueError('unsupported type') | ||
video_pad = [] | ||
for img in images: | ||
if self.model.config.anyres: | ||
image = process_anyres_image(img, self.tokenizer.image_processor, self.model.config.grid_pinpoints) | ||
else: | ||
image = self.tokenizer.image_processor(img, return_tensors='pt')['pixel_values'][0] | ||
video_pad.append(image) | ||
|
||
if not self.model.config.anyres: | ||
video = torch.stack(video_pad, dim=0) | ||
else: | ||
video = [torch.stack(img, dim=0) for img in video_pad] | ||
return video | ||
|
||
def process_images(self, inputs, images_binary): | ||
import re | ||
from qwen_vl_utils import fetch_image | ||
|
||
if inputs.messages[-1]['role'] == 'user': | ||
text = inputs.messages[-1]['content'] | ||
elif len(inputs.messages) > 1 and inputs.messages[-2]['role'] == 'user': | ||
text = inputs.messages[-2]['content'] | ||
else: | ||
text = '' | ||
video_images_tensor = self.preprocess_images(images_binary) | ||
img_length = len(video_images_tensor) | ||
video_images_tensor = [video_images_tensor] | ||
if img_length: | ||
images = [[item.to(self.model.device).to(self.model.dtype) for item in img] for img in video_images_tensor] | ||
|
||
messages_qwen = [] | ||
image_list = [] | ||
if isinstance(images_binary[0], Image.Image): | ||
images_pil = [img.convert('RGB') for img in images_binary] | ||
elif isinstance(images_binary[0], bytes): | ||
images_pil = [Image.open(io.BytesIO(img)).convert('RGB') for img in images_binary] | ||
image_sizes = torch.tensor([[x.size for x in images_pil]]) | ||
for image_file in images_pil: | ||
image = fetch_image({'image': image_file}) | ||
image_list.append(image) | ||
messages_qwen.append({'role': 'user', 'content': [{'type': 'text', 'text': text}]}) | ||
messages_qwen.append({'role': 'assistant', 'content': [{'type': 'text', 'text': ''}]}) | ||
text = self.tokenizer.qwen2vl_processor.apply_chat_template( | ||
messages_qwen[:-1], tokenize=False, add_generation_prompt=True) | ||
text_segs = re.split('<image>', text) | ||
text = '<|vision_start|><|image_pad|><|vision_end|>'.join(text_segs[:len(image_list) + 1]) + ''.join( | ||
text_segs[len(image_list) + 1:]) | ||
data_dict_qwen2vl = self.tokenizer.qwen2vl_processor( | ||
text=[text], images=image_list, padding=True, return_tensors='pt') | ||
results = {} | ||
|
||
results['images'] = images | ||
results['image_sizes'] = image_sizes | ||
results['pixel_values'] = data_dict_qwen2vl['pixel_values'].to(self.model.device) | ||
results['image_grid_thw'] = data_dict_qwen2vl['image_grid_thw'].to(self.model.device) | ||
return results | ||
|
||
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: | ||
encoded = super()._encode(inputs) | ||
images = inputs.images or [] | ||
input_ids = encoded['input_ids'] | ||
labels = encoded['labels'] | ||
if images: | ||
results = self.process_images(inputs, images) | ||
encoded['images'] = results['images'] | ||
encoded['image_sizes'] = results['image_sizes'] | ||
encoded['pixel_values'] = results['pixel_values'] | ||
encoded['image_grid_thw'] = results['image_grid_thw'] | ||
encoded['input_ids'] = input_ids | ||
encoded['labels'] = labels | ||
return encoded | ||
|
||
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: | ||
res = super()._data_collator(batch, padding_to=padding_to) | ||
if 'images' in batch[0]: | ||
res['images'] = sum([b['images'] for b in batch if 'images' in b], start=[]) | ||
res['image_sizes'] = torch.concat([b['image_sizes'] for b in batch if 'image_sizes' in b], dim=0) | ||
for media_type in ['image', 'video']: | ||
grid_thw = [b[f'{media_type}_grid_thw'] for b in batch if b.get(f'{media_type}_grid_thw') is not None] | ||
if grid_thw: | ||
res[f'{media_type}_grid_thw'] = torch.concat(grid_thw) | ||
return res | ||
|
||
|
||
register_template(ValleyTemplateMeta( | ||
MLLMTemplateType.valley, | ||
template_cls=ValleyTemplate, | ||
)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters