Skip to content

Commit

Permalink
Add image repeat to the benchmark_serving.py to test hit/miss of MM cache
Browse files Browse the repository at this point in the history

Signed-off-by: Alexander Matveev <[email protected]>
  • Loading branch information
alexm-redhat committed Dec 13, 2024
1 parent 0fb851b commit d35febb
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
53 changes: 51 additions & 2 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,36 @@ def sample_sonnet_requests(
return sampled_requests


class ImageRepeater:
    """Randomly re-emit the previously returned image instead of a new one.

    Each call to :meth:`process` either passes the given image through or,
    with probability ``image_repeat_prob``, returns the image that was
    returned by the previous call. The callers in this file use it on both
    image objects and image URL strings, so ``process`` treats its argument
    as an opaque value.
    """

    def __init__(self, image_repeat_prob: float) -> None:
        """Create a repeater.

        Args:
            image_repeat_prob: Probability, in ``[0, 1]``, that a call to
                :meth:`process` repeats the previous image.

        Raises:
            ValueError: If ``image_repeat_prob`` is outside ``[0, 1]``
                (this also rejects NaN).
        """
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would let a negative weight reach
        # `random.choices`.
        if not 0.0 <= image_repeat_prob <= 1.0:
            raise ValueError(
                "image_repeat_prob must be within [0, 1], "
                f"got {image_repeat_prob!r}")

        # Outcomes (0 = use the new image, 1 = repeat) and their weights.
        self.no_yes = [0, 1]
        self.probs = [1.0 - image_repeat_prob, image_repeat_prob]

        # Image returned by the most recent `process` call, if any.
        self.prev_image = None
        # Number of `process` calls made so far.
        self.idx = 0

    def process(self, image):
        """Return ``image`` or, on a repeat draw, the previous image.

        The first call always returns ``image`` because there is nothing to
        repeat yet. Whatever is returned becomes the "previous" image for
        the next call.
        """
        self.idx += 1

        # Keep `random.choices` (one draw from the module-level RNG) so that
        # seeded runs produce the same sequence as before.
        repeat = random.choices(self.no_yes, self.probs)[0] == 1
        if not repeat or self.prev_image is None:
            # No repeat (or nothing to repeat yet): adopt the new image.
            self.prev_image = image

        return self.prev_image


def sample_mmmu_pro_vision_requests(
dataset,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
image_repeater: Optional[ImageRepeater] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
sampled_requests: List[Tuple[str, int, int, Dict[str,
Collection[str]]]] = []
Expand Down Expand Up @@ -233,6 +258,8 @@ def sample_mmmu_pro_vision_requests(
Image), ("Input image format must be `PIL.Image.Image`, "
f"given {type(data['image'])}.")
image: Image = data["image"]
if image_repeater is not None:
image = image_repeater.process(image)
image = image.convert("RGB")
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
Expand All @@ -257,8 +284,14 @@ def sample_hf_requests(
tokenizer: PreTrainedTokenizerBase,
random_seed: int,
fixed_output_len: Optional[int] = None,
image_repeat_prob: Optional[float] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:

# Init image repeater (if specified)
image_repeater = None
if image_repeat_prob is not None:
image_repeater = ImageRepeater(image_repeat_prob)

# Special case for MMMU-Pro vision dataset
if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision':
assert dataset_split == "test"
Expand All @@ -270,8 +303,11 @@ def sample_hf_requests(
"MMMU/MMMU_Pro vision dataset must have 'image' column.")
filter_func = lambda x: isinstance(x["image"], Image)
dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
return sample_mmmu_pro_vision_requests(dataset, num_requests,
tokenizer, fixed_output_len)
return sample_mmmu_pro_vision_requests(dataset,
num_requests,
tokenizer,
fixed_output_len,
image_repeater=image_repeater)

dataset = load_dataset(dataset_path,
name=dataset_subset,
Expand Down Expand Up @@ -305,6 +341,8 @@ def sample_hf_requests(

if "image" in data and isinstance(data["image"], Image):
image: Image = data["image"]
if image_repeater is not None:
image = image_repeater.process(image)
image = image.convert("RGB")
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
Expand All @@ -323,6 +361,9 @@ def sample_hf_requests(
else:
image_url = f"file://{data['image']}"

if image_repeater is not None:
image_url = image_repeater.process(image_url)

mm_content = {
"type": "image_url",
"image_url": {
Expand Down Expand Up @@ -854,6 +895,7 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer,
random_seed=args.seed,
fixed_output_len=args.hf_output_len,
image_repeat_prob=args.image_repeat_prob,
)

elif args.dataset_name == "random":
Expand Down Expand Up @@ -1222,5 +1264,12 @@ def main(args: argparse.Namespace):
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')

parser.add_argument(
'--image-repeat-prob',
type=float,
default=None,
help='Simulates the hit-ratio for multi-modal preprocessor cache'
' (if enabled)')

args = parser.parse_args()
main(args)
6 changes: 3 additions & 3 deletions vllm/v1/engine/mm_input_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def __init__(
self.mm_cache = LRUDictCache(MM_CACHE_SIZE)

# DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None
self.mm_debug_cache_hit_ratio_steps = 32
self.mm_cache_hits = 0
self.mm_cache_total = 0

def cache_hit_ratio(self, steps) -> None:
    """Print the running cache hit ratio every ``steps`` lookups.

    Output is produced only when at least one lookup has been counted and
    ``self.mm_cache_total`` is an exact multiple of ``steps``. Nothing is
    returned.

    Args:
        steps: Reporting interval, in number of counted lookups.
    """
    if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
        # `print` does not apply %-style formatting to extra positional
        # arguments, so interpolate the ratio explicitly.
        hit_ratio = self.mm_cache_hits / self.mm_cache_total
        print("MMInputMapper: cache_hit_ratio = %.2f " % hit_ratio)

def process_inputs(
self,
Expand Down

0 comments on commit d35febb

Please sign in to comment.