From 73dbf143bccdcf29579024c17fdb7077d0a0fac1 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 28 Jun 2024 07:57:44 -0400 Subject: [PATCH] Refactor RequestGenerator to use threading and update test suite Refactored RequestGenerator class: - Replaced asyncio.Queue with Queue from the queue module for thread safety. - Utilized threading for background queue population to ensure non-blocking request generation. - Removed the start method; the constructor now starts the background thread automatically in async mode. - Ensured that the _populate_queue method runs in a background thread to keep the queue populated. - Implemented clean shutdown with the stop method joining the background thread. Updated unit tests: - Added test_request_generator_sync_constructor and test_request_generator_async_constructor to verify constructor behavior. - Added tests for __repr__ and __iter__ methods. - Added tests to ensure create_item raises NotImplementedError if not overridden. - Added tests to verify __iter__ calls create_item the expected number of times. Separated test files: - Created tests/unit/request/test_base.py for unit tests. - Created tests/integration/request/test_base.py for integration tests. Unit tests: - Verified the construction of the class with different input parameters. - Mocked AutoTokenizer for testing tokenizer initialization with both a class implementation and a string alias. - Ensured that the __iter__ method works correctly in both sync and async modes. - Verified that create_item is called the expected number of times. Integration tests: - Tested tokenizer construction with both a Hugging Face tokenizer and a string alias, ensuring the correct tokenizer is created. 
--- src/guidellm/request/base.py | 66 ++++++------- .../integration/{core => request}/__init__.py | 0 tests/integration/request/test_base.py | 23 +++++ tests/unit/request/test_base.py | 99 ++++++++++++++++--- 4 files changed, 138 insertions(+), 50 deletions(-) rename tests/integration/{core => request}/__init__.py (100%) create mode 100644 tests/integration/request/test_base.py diff --git a/src/guidellm/request/base.py b/src/guidellm/request/base.py index e5b2003..1a110b7 100644 --- a/src/guidellm/request/base.py +++ b/src/guidellm/request/base.py @@ -1,5 +1,7 @@ -import asyncio +import threading +import time from abc import ABC, abstractmethod +from queue import Empty, Full, Queue from typing import Iterator, Optional, Union from loguru import logger @@ -31,8 +33,8 @@ def __init__( ): self._async_queue_size = async_queue_size self._mode = mode - self._queue = asyncio.Queue(maxsize=async_queue_size) - self._stop_event = asyncio.Event() + self._queue = Queue(maxsize=async_queue_size) + self._stop_event = threading.Event() if tokenizer is not None: self._tokenizer = ( @@ -40,11 +42,20 @@ def __init__( if isinstance(tokenizer, str) else tokenizer ) - logger.info(f"Tokenizer initialized: {self._tokenizer}") + logger.info("Tokenizer initialized: {}", self._tokenizer) else: self._tokenizer = None logger.debug("No tokenizer provided") + if self._mode == "async": + self._thread = threading.Thread(target=self._populate_queue) + self._thread.daemon = True + self._thread.start() + logger.info( + "RequestGenerator started in async mode with queue size: {}", + self._async_queue_size, + ) + def __repr__(self) -> str: """ Return a string representation of the RequestGenerator. 
@@ -72,7 +83,7 @@ def __iter__(self) -> Iterator[TextGenerationRequest]: item = self._queue.get_nowait() self._queue.task_done() yield item - except asyncio.QueueEmpty: + except Empty: continue else: while not self._stop_event.is_set(): @@ -118,46 +129,31 @@ def create_item(self) -> TextGenerationRequest: """ raise NotImplementedError() - def start(self): - """ - Start the background task that populates the queue. - """ - if self.mode == "async": - try: - loop = asyncio.get_running_loop() - logger.info("Using existing event loop") - except RuntimeError: - raise RuntimeError("No running event loop found for async mode") - - loop.call_soon_threadsafe( - lambda: asyncio.create_task(self._populate_queue()) - ) - logger.info( - f"RequestGenerator started in async mode with queue size: " - f"{self._async_queue_size}" - ) - else: - logger.info("RequestGenerator started in sync mode") - def stop(self): """ Stop the background task that populates the queue. """ logger.info("Stopping RequestGenerator...") self._stop_event.set() + if self._mode == "async": + self._thread.join() logger.info("RequestGenerator stopped") - async def _populate_queue(self): + def _populate_queue(self): """ Populate the request queue in the background. """ while not self._stop_event.is_set(): - if self._queue.qsize() < self._async_queue_size: - item = self.create_item() - await self._queue.put(item) - logger.debug( - f"Item added to queue. Current queue size: {self._queue.qsize()}" - ) - else: - await asyncio.sleep(0.1) + try: + if self._queue.qsize() < self._async_queue_size: + item = self.create_item() + self._queue.put(item, timeout=0.1) + logger.debug( + "Item added to queue. 
Current queue size: {}", + self._queue.qsize(), + ) + else: + time.sleep(0.1) + except Full: + continue logger.info("RequestGenerator stopped populating queue") diff --git a/tests/integration/core/__init__.py b/tests/integration/request/__init__.py similarity index 100% rename from tests/integration/core/__init__.py rename to tests/integration/request/__init__.py diff --git a/tests/integration/request/test_base.py b/tests/integration/request/test_base.py new file mode 100644 index 0000000..a631909 --- /dev/null +++ b/tests/integration/request/test_base.py @@ -0,0 +1,23 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from guidellm.core.request import TextGenerationRequest +from guidellm.request.base import RequestGenerator + + +class TestRequestGenerator(RequestGenerator): + def create_item(self) -> TextGenerationRequest: + return TextGenerationRequest(prompt="Test prompt") + + +@pytest.mark.smoke +def test_request_generator_with_hf_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + generator = TestRequestGenerator(tokenizer=tokenizer) + assert generator.tokenizer == tokenizer + + +@pytest.mark.smoke +def test_request_generator_with_string_tokenizer(): + generator = TestRequestGenerator(tokenizer="bert-base-uncased") + assert isinstance(generator.tokenizer, PreTrainedTokenizerBase) + assert generator.tokenizer.name_or_path == "bert-base-uncased" diff --git a/tests/unit/request/test_base.py b/tests/unit/request/test_base.py index 31623ba..c0afdcd 100644 --- a/tests/unit/request/test_base.py +++ b/tests/unit/request/test_base.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock, patch + import pytest from guidellm.core.request import TextGenerationRequest @@ -10,15 +12,28 @@ def create_item(self) -> TextGenerationRequest: @pytest.mark.smoke -def test_request_generator_sync(): +def test_request_generator_sync_constructor(): generator = TestRequestGenerator(mode="sync") assert generator.mode == "sync" + 
assert generator.async_queue_size == 50 # Default value assert generator.tokenizer is None + +@pytest.mark.smoke +def test_request_generator_async_constructor(): + generator = TestRequestGenerator(mode="async", async_queue_size=10) + assert generator.mode == "async" + assert generator.async_queue_size == 10 + assert generator.tokenizer is None + generator.stop() + + +@pytest.mark.smoke +def test_request_generator_sync_iter(): + generator = TestRequestGenerator(mode="sync") items = [] for item in generator: items.append(item) - if len(items) == 5: break @@ -27,28 +42,30 @@ def test_request_generator_sync(): @pytest.mark.smoke -@pytest.mark.asyncio -def test_request_generator_async(): - generator = TestRequestGenerator(mode="async", async_queue_size=10) - assert generator.mode == "async" - assert generator.async_queue_size == 10 - assert generator.tokenizer is None - - generator.start() - +def test_request_generator_async_iter(): + generator = TestRequestGenerator(mode="async") items = [] for item in generator: items.append(item) - if len(items) == 5: break generator.stop() - assert generator._stop_event.is_set() - assert len(items) == 5 assert items[0].prompt == "Test prompt" - assert items[-1].prompt == "Test prompt" + + +@pytest.mark.regression +def test_request_generator_with_mock_tokenizer(): + mock_tokenizer = Mock() + generator = TestRequestGenerator(tokenizer=mock_tokenizer) + assert generator.tokenizer == mock_tokenizer + + with patch("guidellm.request.base.AutoTokenizer") as MockAutoTokenizer: + MockAutoTokenizer.from_pretrained.return_value = mock_tokenizer + generator = TestRequestGenerator(tokenizer="mock-tokenizer") + assert generator.tokenizer == mock_tokenizer + MockAutoTokenizer.from_pretrained.assert_called_with("mock-tokenizer") @pytest.mark.regression @@ -57,3 +74,55 @@ def test_request_generator_repr(): assert repr(generator) == ( "RequestGenerator(mode=sync, async_queue_size=100, tokenizer=None)" ) + + +@pytest.mark.regression +def 
test_request_generator_create_item_not_implemented(): + with pytest.raises(TypeError): + class IncompleteRequestGenerator(RequestGenerator): + pass + + IncompleteRequestGenerator() + + class IncompleteCreateItemGenerator(RequestGenerator): + def create_item(self): + super().create_item() + + generator = IncompleteCreateItemGenerator() + with pytest.raises(NotImplementedError): + generator.create_item() + + +@pytest.mark.regression +def test_request_generator_iter_calls_create_item(): + generator = TestRequestGenerator(mode="sync") + generator.create_item = Mock( + return_value=TextGenerationRequest(prompt="Mock prompt") + ) + + items = [] + for item in generator: + items.append(item) + if len(items) == 5: + break + + assert len(items) == 5 + generator.create_item.assert_called() + + +@pytest.mark.regression +def test_request_generator_async_iter_calls_create_item(): + generator = TestRequestGenerator(mode="sync") + generator.create_item = Mock( + return_value=TextGenerationRequest(prompt="Mock prompt") + ) + + items = [] + for item in generator: + items.append(item) + if len(items) == 5: + break + + generator.stop() + assert len(items) == 5 + generator.create_item.assert_called()