Support for common HF multimodal models #276

Open
wants to merge 28 commits into base: main

Changes from 1 commit

Commits (28)
745f74d  support for hf multimodel (n1ck-guo, Oct 11, 2024)
4f537d3  Merge branch 'main' into hengguo/multimodel (wenhuach21, Oct 11, 2024)
8fd0278  update (n1ck-guo, Oct 12, 2024)
9ec7af9  Merge branch 'hengguo/multimodel' of https://github.com/intel/auto-ro… (n1ck-guo, Oct 12, 2024)
e1a7107  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 12, 2024)
169e698  update (n1ck-guo, Oct 12, 2024)
6acdecd  Merge branch 'hengguo/multimodel' of https://github.com/intel/auto-ro… (n1ck-guo, Oct 12, 2024)
b41a74f  Merge branch 'main' into hengguo/multimodel (n1ck-guo, Oct 12, 2024)
ccaa250  support in auto_round cmd (n1ck-guo, Oct 16, 2024)
7231913  update (n1ck-guo, Oct 17, 2024)
d97a64c  merge main (n1ck-guo, Oct 17, 2024)
ff02509  fix (n1ck-guo, Oct 17, 2024)
bec46eb  fix (n1ck-guo, Oct 18, 2024)
234b3bb  fix (n1ck-guo, Oct 18, 2024)
dd74350  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 18, 2024)
6de760f  change api (n1ck-guo, Oct 21, 2024)
00c92d1  Merge branch 'hengguo/multimodel' of https://github.com/intel/auto-ro… (n1ck-guo, Oct 21, 2024)
4b05fdd  Merge branch 'main' into hengguo/multimodel (n1ck-guo, Oct 21, 2024)
355a6d4  update (n1ck-guo, Oct 21, 2024)
c4f81c6  support cogvlm2 (n1ck-guo, Oct 22, 2024)
d7fa8d7  remove example, replaced by __main__ (n1ck-guo, Oct 22, 2024)
544241f  fix (n1ck-guo, Oct 22, 2024)
bbb6d4c  fix (n1ck-guo, Oct 23, 2024)
4107e5c  fix (n1ck-guo, Oct 23, 2024)
0e1f749  fix (n1ck-guo, Oct 23, 2024)
a75a3d5  calib for mllm (n1ck-guo, Oct 23, 2024)
815e85f  fix bug for cogvlm2 (n1ck-guo, Oct 24, 2024)
a06066f  refact api (n1ck-guo, Oct 24, 2024)
3 changes: 0 additions & 3 deletions auto_round/__main__.py
@@ -388,9 +388,6 @@ def tune_mllm(args):
print(
"Warning, activation quantization is an experiment feature")

if "marlin" in args.format and args.sym == False:
assert False, "marlin backend only supports sym quantization, please set --sym"

if args.act_bits <= 8 and args.deployment_device != "fake":
assert False, "only support fake mode for activation quantization currently"

4 changes: 2 additions & 2 deletions auto_round/autoround.py
@@ -534,11 +534,11 @@ def calib(self, nsamples, bs):
for key in data.keys():
data_new[key] = to_device(data[key], self.model.device)
if key == 'images':
data_new[key] = to_dtype(data[key], self.model.dtype)
data_new[key] = to_dtype(data_new[key], self.model.dtype)
input_ids = data_new["input_ids"]
if input_ids.shape[-1] < self.seqlen:
continue

self.model(**data_new)
try:
if isinstance(data_new, torch.Tensor):
self.model(data_new)
64 changes: 18 additions & 46 deletions auto_round/mllm/mllm_dataset.py
@@ -15,7 +15,7 @@
import os
import json
from typing import Dict
from enum import Enum, unique


import torch
from torch.utils.data import Dataset, DataLoader
@@ -34,45 +34,22 @@ def register(dataset):
return dataset
return register

@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
OBSERVATION = "observation"


def fill_content(target, **kwargs):
for name, value in kwargs.items():
target = target.replace("{{" + name + "}}", value, 1)
return target


def mllm_encode(sources, template: "Template"):
element = ""
for i, source in enumerate(sources):
if i == 0:
element += fill_content(template.format_system, content=template.default_system)
# if i > 0 and i % 2 ==0:
# element += fill_content(template.format_separator)

if source['role'] == Role.USER.value:
element += fill_content(template.format_user, content=source["content"])
elif source['role'] == Role.ASSISTANT.value:
element += fill_content(template.format_assistant, content=source["content"])
elif source['role'] == Role.OBSERVATION.value:
element += fill_content(template.format_observation, content=source["content"])
elif source['role'] == Role.FUNCTION.value:
element += fill_content(template.format_function, content=source["content"])
return element


@register_dataset("llava")
class LlavaDataset(Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(self, model_type_or_template, model, tokenzier, dataset_path, extra_data_dir, max_length) -> None:
def __init__(
self,
model_type_or_template,
model,
tokenzier,
dataset_path,
extra_data_dir,
max_length,
padding=True,
truncation=True,
) -> None:
super().__init__()
Reviewer comment (Contributor):

Please enhance the code by adding some comments, and add a README to this folder with instructions on how to prepare the llava dataset, support new models, etc. Additionally, it would be better to test the code with an unseen model to evaluate its generalization capabilities.
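As a possible starting point for such a README, here is a minimal self-contained sketch of the decorator-registry pattern this PR relies on (PLUGINS/regist_plugin in auto_round/mllm/plugin.py and register_dataset in mllm_dataset.py). The plugin class, the "my_new_vlm" model type, and the get_input body are hypothetical placeholders for illustration, not part of the diff:

PLUGINS = {}

def regist_plugin(name):
    # Same decorator-registry pattern as register_dataset in mllm_dataset.py.
    def register(plugin):
        PLUGINS[name] = plugin
        return plugin
    return register

@regist_plugin("my_new_vlm")  # hypothetical model type, not part of this PR
class MyNewVLMPlugin:
    @staticmethod
    def get_input(model, tokenizer, text, images, padding=True,
                  truncation=True, return_tensors="pt", max_length=None, **kwargs):
        # Placeholder body: a real plugin would build model-specific inputs here,
        # e.g. via the model's processor, as BasicPlugin does in the diff above.
        return {"text": text, "images": images}

print("my_new_vlm" in PLUGINS)  # True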

if isinstance(model_type_or_template, str):
assert model_type_or_template in TEMPLATES, f"{model_type_or_template} is not supported"
@@ -86,6 +63,8 @@ def __init__(self, model_type_or_template, model, tokenzier, dataset_path, extra
raise TypeError
self.tokenizer = tokenzier
self.questions = json.load(open(dataset_path, "r"))
self.padding = padding
self.truncation = truncation
self.extra_data_dir = extra_data_dir
self.max_length = max_length
self.role_mapping = {"human": "user", "gpt": "assistant"}
@@ -101,27 +80,21 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:

text = self.questions[i]["conversations"]
text = self.covert_conversations(text)
text = mllm_encode(text, self.template)

token_length = len(self.tokenizer(text).input_ids)
if token_length < self.max_length:
if self.tokenizer.pad_token:
text += self.tokenizer.pad_token * (self.max_length - token_length)
else:
text = self.tokenizer.decode(self.tokenizer(text).input_ids[:self.max_length])
text = self.template._encode(text)

image_fold = _extract_data_dir(self.extra_data_dir)
if isinstance(image_fold, dict):
image_fold = image_fold['image']
image = Image.open(os.path.join(image_fold, os.path.basename(self.questions[i]["image"])))
image = self.template.plugin.image_processor(os.path.join(image_fold, os.path.basename(self.questions[i]["image"])))

ret = self.template.plugin.get_input(
self.model,
self.tokenizer,
text=text,
images=image,
padding=True,
truncation=True,
padding=self.padding,
truncation=self.truncation,
return_tensors="pt",
max_length = self.max_length
)
@@ -141,7 +114,6 @@ def covert_conversations(self, data):
return new_data



def get_mllm_dataloader(
template_or_path,
model,
44 changes: 40 additions & 4 deletions auto_round/mllm/plugin.py
@@ -1,6 +1,8 @@
import torch
from transformers.data.data_collator import default_data_collator

from PIL import Image

PLUGINS = {}

def regist_plugin(name):
@@ -22,7 +24,16 @@ def get_input(
padding=True,
truncation=True,
return_tensors="pt",
max_length=None,
**kwargs):

token_length = len(tokenizer(text).input_ids)
if token_length < max_length:
if tokenizer.pad_token:
text += tokenizer.pad_token * (max_length - token_length)
else:
text = tokenizer.decode(tokenizer(text).input_ids[:max_length])

ret = tokenizer.processor(
text=text,
images=images,
@@ -39,6 +50,10 @@ def data_collator(batch):
def data_collator(batch):
return default_data_collator(batch)

@staticmethod
def image_processor(image_path):
return Image.open(image_path)

@regist_plugin("qwen2_vl")
class Qwen2VLPlugin(BasicPlugin):
def get_input(model, tokenizer, text, images, padding=True, truncation=True, return_tensors="pt", **kwargs):
@@ -59,7 +74,12 @@ def get_input(model, tokenizer, text, images, padding=True, truncation=True, ret

@regist_plugin("cogvlm2")
class CogVLM2Plugin(BasicPlugin):
def get_input(model, tokenizer, text, images, max_length=2048, **kwargs):
def get_input(
model, tokenizer, text, images, max_length=2048,
padding=True, truncation=True, **kwargs):
breakpoint()
padding_len = 2303
max_length += padding_len
input_data = model.build_conversation_input_ids(
tokenizer,
query=text,
@@ -70,13 +90,19 @@ def get_input(model, tokenizer, text, images, max_length=2048, **kwargs):
def pad_to_len(unpadded_tensor, pad_to_length, pad_value=0):
current_length = len(unpadded_tensor)
if current_length >= pad_to_length:
return unpadded_tensor[:pad_to_length]
return torch.cat(
if truncation:
return unpadded_tensor[:pad_to_length]
else:
return unpadded_tensor
if padding:
return torch.cat(
(unpadded_tensor,
torch.full([pad_to_length - current_length],
fill_value=pad_value,
dtype=unpadded_tensor.dtype,
device=unpadded_tensor.device)), dim=0)
else:
return unpadded_tensor
input_data['input_ids'] = pad_to_len(
input_data['input_ids'],
max_length,
@@ -98,6 +124,12 @@ def pad_to_len(unpadded_tensor, pad_to_length, pad_value=0):
max_length,
pad_value=-100
)
input_data = {
'input_ids': input_data['input_ids'].unsqueeze(0),
'token_type_ids': input_data['token_type_ids'].unsqueeze(0),
'attention_mask': input_data['attention_mask'].unsqueeze(0),
'images': [[input_data['images'][0]]] if input_data['images'] is not None else None,
}
return input_data

@staticmethod
Expand All @@ -110,4 +142,8 @@ def data_collator(batch):
batched_data[key] = torch.stack([item[key] for item in batch])
# else:
# raise ValueError("Unsupported datatype in custom collate_fn")
return batched_data
return batched_data

@staticmethod
def image_processor(image_path):
return Image.open(image_path).convert('RGB')
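For reference, a standalone sketch of the pad_to_len behaviour added for CogVLM2 above; padding and truncation are passed as explicit arguments here rather than captured from the enclosing get_input, purely for illustration:

import torch

def pad_to_len(unpadded_tensor, pad_to_length, pad_value=0, padding=True, truncation=True):
    # Mirrors the helper inside CogVLM2Plugin.get_input in the diff above.
    current_length = len(unpadded_tensor)
    if current_length >= pad_to_length:
        # Truncate only when truncation is enabled.
        return unpadded_tensor[:pad_to_length] if truncation else unpadded_tensor
    if padding:
        # Right-pad with pad_value up to pad_to_length.
        return torch.cat(
            (unpadded_tensor,
             torch.full([pad_to_length - current_length],
                        fill_value=pad_value,
                        dtype=unpadded_tensor.dtype,
                        device=unpadded_tensor.device)),
            dim=0)
    return unpadded_tensor

print(pad_to_len(torch.tensor([1, 2, 3]), 5))           # tensor([1, 2, 3, 0, 0])
print(pad_to_len(torch.tensor([1, 2, 3, 4, 5, 6]), 5))  # tensor([1, 2, 3, 4, 5])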
34 changes: 34 additions & 0 deletions auto_round/mllm/template.py
@@ -2,11 +2,27 @@
import json
from dataclasses import dataclass
from typing import Dict, Optional, List, Union, Sequence
from enum import Enum, unique


from .plugin import BasicPlugin, PLUGINS

TEMPLATES: Dict[str, "Template"] = {}

def fill_content(target, **kwargs):
for name, value in kwargs.items():
target = target.replace("{{" + name + "}}", value, 1)
return target


@unique
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
OBSERVATION = "observation"

@dataclass
class Template:
model_type: str
@@ -20,6 +36,24 @@ class Template:
replace_tokens: List[tuple]
plugin: "BasicPlugin"

def _encode(self, sources):
element = ""
for i, source in enumerate(sources):
if i == 0:
element += fill_content(self.format_system, content=self.default_system)
# if i > 0 and i % 2 ==0:
# element += fill_content(self.format_separator)

if source['role'] == Role.USER.value:
element += fill_content(self.format_user, content=source["content"])
elif source['role'] == Role.ASSISTANT.value:
element += fill_content(self.format_assistant, content=source["content"])
elif source['role'] == Role.OBSERVATION.value:
element += fill_content(self.format_observation, content=source["content"])
elif source['role'] == Role.FUNCTION.value:
element += fill_content(self.format_function, content=source["content"])
return element


def _register_template(
model_type: str,
1 change: 1 addition & 0 deletions auto_round/mllm/templates/cogvlm2.json
@@ -2,5 +2,6 @@
"model_type": "cogvlm2",
"format_user": "Question: {{content}} ",
"format_assistant": "Answer: {{content}}\n",
"replace_tokens": ["<image>\n", ""],
"plugin": "cogvlm2"
}
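To illustrate how the new Template._encode would render a conversation with these cogvlm2 format strings, here is a self-contained sketch; encode() mirrors only the user/assistant branches of _encode, and the sample conversation is invented for illustration:

def fill_content(target, **kwargs):
    # Same helper as in template.py above: fills each {{name}} placeholder once.
    for name, value in kwargs.items():
        target = target.replace("{{" + name + "}}", value, 1)
    return target

format_user = "Question: {{content}} "      # from cogvlm2.json above
format_assistant = "Answer: {{content}}\n"  # from cogvlm2.json above

def encode(sources):
    # Mirrors the user/assistant branches of Template._encode.
    element = ""
    for source in sources:
        if source["role"] == "user":
            element += fill_content(format_user, content=source["content"])
        elif source["role"] == "assistant":
            element += fill_content(format_assistant, content=source["content"])
    return element

conversation = [  # invented sample conversation
    {"role": "user", "content": "What is in the image?"},
    {"role": "assistant", "content": "A cat sitting on a sofa."},
]
print(encode(conversation))
# Question: What is in the image? Answer: A cat sitting on a sofa.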