Add calibration_dataset_concat_size option/mode (#1257)
* add calibration_data_concat and calibration_data_concat_context

* rename to calibration_dataset_concat_size

* rename

* fix new line input_ids

* fix new_line_input_ids_len

* use " "

* set calibration_dataset_concat_size default 2048

* set calibration_dataset_concat_size default 2048

* cleanup
LRL-ModelCloud authored Feb 12, 2025
1 parent 16e86ae commit 301081f
Showing 2 changed files with 74 additions and 2 deletions.
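
For context, here is a minimal sketch of how the new option is intended to be used from the public API, assuming the usual GPTQModel load/quantize flow. The model id, calibration strings, and quantization settings below are illustrative placeholders, not part of this commit:

```python
from gptqmodel import GPTQModel, QuantizeConfig

# Placeholder calibration data; real calibration sets are typically much larger.
calibration_dataset = [
    "GPTQModel is an LLM quantization toolkit.",
    "Calibration samples can now be packed into fixed-size token blocks.",
]

model = GPTQModel.load("meta-llama/Llama-3.2-1B", QuantizeConfig(bits=4, group_size=128))
model.quantize(
    calibration_dataset,
    # Added by this commit: pack calibration samples into blocks of this many tokens.
    calibration_dataset_concat_size=2048,
    batch_size=1,
)
```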
2 changes: 2 additions & 0 deletions gptqmodel/models/_const.py
@@ -178,3 +178,5 @@ def get_best_device(backend: BACKEND = BACKEND.AUTO) -> torch.device:
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

EXPERT_INDEX_PLACEHOLDER = "{expert_index}"

CALIBRATION_DATASET_CONCAT_CHAR = " "
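
The single-space separator defined above is tokenized at runtime and its token ids are spliced between concatenated samples. A quick illustration of what the separator looks like after tokenization (`gpt2` is used purely as an example tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works; gpt2 is just an example
sep = tokenizer(" ", return_tensors="pt")

# These ids and mask values are what gets inserted between two concatenated samples.
print(sep["input_ids"][0].tolist())       # [220] for gpt2
print(sep["attention_mask"][0].tolist())  # [1]
```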
74 changes: 72 additions & 2 deletions gptqmodel/models/base.py
@@ -52,7 +52,7 @@
)
from ..utils.progress import ProgressBar
from ..utils.torch import torch_empty_cache
from ._const import CPU, DEFAULT_MAX_SHARD_SIZE, DEVICE, SUPPORTS_MODULE_TYPES
from ._const import CPU, DEFAULT_MAX_SHARD_SIZE, DEVICE, SUPPORTS_MODULE_TYPES, CALIBRATION_DATASET_CONCAT_CHAR
from .loader import ModelLoader
from .writer import (
QUANT_LOG_DAMP,
@@ -171,6 +171,8 @@ def __init__(
    def prepare_dataset(
        self,
        calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[List[int]]],
        # Setting a fixed calibration_dataset_concat_size may improve the performance of the quantized model.
        calibration_dataset_concat_size: Optional[int] = None,
        batch_size: int = 1,
    ):
        if isinstance(calibration_dataset[0], (str, list)) or (isinstance(calibration_dataset[0], list) and all(isinstance(x, int) for x in calibration_dataset[0])):
@@ -217,6 +219,72 @@ def _convert_tensor_to_list(tensor):
                }
            )

        if calibration_dataset_concat_size:
            # Concatenate calibration samples into fixed-size blocks of
            # calibration_dataset_concat_size tokens, joined by CALIBRATION_DATASET_CONCAT_CHAR.
            concatenated_data = []
            input_ids_buff = []
            attention_mask_buff = []
            current_length = 0

            new_line = self.tokenizer(CALIBRATION_DATASET_CONCAT_CHAR, return_tensors="pt")
            new_line_input_ids = _convert_tensor_to_list(new_line["input_ids"])[0]
            new_line_attention_mask = _convert_tensor_to_list(new_line["attention_mask"])[0]
            new_line_input_ids_len = len(new_line_input_ids)

            for example in new_calibration_dataset:
                input_ids = example["input_ids"][0]
                attention_mask = example["attention_mask"][0]

                if current_length + len(input_ids) + new_line_input_ids_len >= calibration_dataset_concat_size:
                    if len(input_ids_buff) > 0:
                        remaining_space = calibration_dataset_concat_size - current_length
                        # if there is remaining space, add the remaining input to the current block
                        if remaining_space > 0:
                            input_ids_buff.extend(new_line_input_ids)
                            input_ids_buff.extend(input_ids[:remaining_space - new_line_input_ids_len])
                            attention_mask_buff.extend(new_line_attention_mask)
                            attention_mask_buff.extend(attention_mask[:remaining_space - new_line_input_ids_len])

                            concatenated_data.append({
                                "input_ids": [input_ids_buff],
                                "attention_mask": [attention_mask_buff]
                            })
                        else:
                            # if there is no remaining space, add the current block to the concatenated data
                            concatenated_data.append({
                                "input_ids": [input_ids_buff],
                                "attention_mask": [attention_mask_buff]
                            })

                        input_ids_buff = input_ids[:calibration_dataset_concat_size]
                        attention_mask_buff = attention_mask[:calibration_dataset_concat_size]
                        current_length = len(input_ids_buff)
                    else:
                        input_ids_buff = input_ids[:calibration_dataset_concat_size]
                        attention_mask_buff = attention_mask[:calibration_dataset_concat_size]
                        current_length = len(input_ids_buff)
                else:
                    if len(input_ids_buff) > 0:
                        input_ids_buff.extend(new_line_input_ids)
                        attention_mask_buff.extend(new_line_attention_mask)
                        current_length += new_line_input_ids_len

                    input_ids_buff.extend(input_ids)
                    attention_mask_buff.extend(attention_mask)
                    current_length += len(input_ids)

            # Pad the final (partial) block to the target size and emit it.
            if input_ids_buff:
                padding_length = calibration_dataset_concat_size - len(input_ids_buff)
                if padding_length > 0:
                    input_ids_buff.extend([self.tokenizer.pad_token_id] * padding_length)
                    attention_mask_buff.extend([0] * padding_length)
                concatenated_data.append({
                    "input_ids": [input_ids_buff],
                    "attention_mask": [attention_mask_buff]
                })

            new_calibration_dataset = concatenated_data

        new_calibration_dataset_batched = [
            collate_data(new_calibration_dataset[start: start + batch_size], self.tokenizer.pad_token_id)
            for start in range(0, len(new_calibration_dataset), batch_size)
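
To make the intent of the block above easier to follow, here is a small self-contained sketch of the packing idea on plain Python lists. It mirrors the approach of the commit (join samples with a separator, fill the remaining space of a block from the head of the next sample, pad the final block), but it is a simplified illustration rather than the library code:

```python
def pack(samples, sep_ids, block_size, pad_id):
    """Pack token lists into fixed-size blocks (simplified sketch of the commit's logic)."""
    blocks, buff = [], []
    for ids in samples:
        if len(buff) + len(sep_ids) + len(ids) >= block_size:
            if buff:
                remaining = block_size - len(buff)
                if remaining > 0:
                    # Fill the leftover space with the separator and the head of the next sample.
                    buff += sep_ids + ids[:remaining - len(sep_ids)]
                blocks.append(buff)
            buff = ids[:block_size]  # the next sample (truncated) starts a new block
        else:
            if buff:
                buff += sep_ids
            buff += ids
    if buff:
        buff += [pad_id] * (block_size - len(buff))  # pad the final partial block
        blocks.append(buff)
    return blocks


# Toy run: separator id 0, pad id 9, 6-token blocks.
print(pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], sep_ids=[0], block_size=6, pad_id=9))
# -> [[1, 2, 3, 0, 4, 5], [4, 5, 0, 6, 7, 8], [6, 7, 8, 9, 9, 9]]
```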
@@ -229,6 +297,8 @@ def _convert_tensor_to_list(tensor):
    def quantize(
        self,
        calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]],
        # Setting a fixed calibration_dataset_concat_size may improve the performance of the quantized model.
        calibration_dataset_concat_size: Optional[int] = None,
        batch_size: int = 1,
        calibration_enable_gpu_cache: bool = True,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
@@ -307,7 +377,7 @@ def quantize(
        if BITBLAS_AVAILABLE is False:
            raise ValueError(BITBLAS_INSTALL_HINT)

        calibration_dataset = self.prepare_dataset(calibration_dataset, batch_size,)
        calibration_dataset = self.prepare_dataset(calibration_dataset, calibration_dataset_concat_size, batch_size)

        # Calculate the average length of the input_ids
        total_input_ids_length = 0
