diff --git a/kew/.coverage b/kew/.coverage
index bc05a2d..933423c 100644
Binary files a/kew/.coverage and b/kew/.coverage differ
diff --git a/kew/kew/exceptions.py b/kew/kew/exceptions.py
index d9f8d21..d392dba 100644
--- a/kew/kew/exceptions.py
+++ b/kew/kew/exceptions.py
@@ -12,4 +12,8 @@ class TaskNotFoundError(TaskQueueError):
 
 class QueueNotFoundError(TaskQueueError):
     """Raised when attempting to access a non-existent queue"""
+    pass
+
+class QueueProcessorError(TaskQueueError):
+    """Raised when the queue processor encounters an error"""
     pass
\ No newline at end of file
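The new QueueProcessorError slots into the existing TaskQueueError hierarchy, so callers can separate a sick processor from an ordinary per-task failure. A minimal sketch of the intended catch order (the check_task wrapper is illustrative, not part of the patch):

    from kew.exceptions import QueueProcessorError, TaskNotFoundError, TaskQueueError

    async def check_task(manager, task_id: str):
        try:
            return await manager.get_task_status(task_id)
        except TaskNotFoundError:
            return None                    # task expired or was cleaned up
        except QueueProcessorError:
            raise                          # the processor itself is unhealthy; escalate
        except TaskQueueError as exc:      # any other library-level failure
            print(f"queue error: {exc}")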
diff --git a/kew/kew/manager.py b/kew/kew/manager.py
index 984f607..3b1778c 100644
--- a/kew/kew/manager.py
+++ b/kew/kew/manager.py
@@ -1,42 +1,69 @@
-# kew/kew/manager.py
-from concurrent.futures import ThreadPoolExecutor
-from queue import PriorityQueue, Queue
-import threading
-from typing import Optional, Dict, Any, Callable, List, Tuple
-from datetime import datetime
+from typing import Optional, Dict, Any, Callable, List
+from datetime import datetime, timedelta
 import logging
 import asyncio
+import json
+import redis.asyncio as redis
 from .models import TaskStatus, TaskInfo, QueueConfig, QueuePriority
-from .exceptions import TaskAlreadyExistsError, TaskNotFoundError, QueueNotFoundError
+from .exceptions import (
+    TaskAlreadyExistsError,
+    TaskNotFoundError,
+    QueueNotFoundError,
+    QueueProcessorError
+)
 
 logger = logging.getLogger(__name__)
 
-class PrioritizedItem:
-    def __init__(self, priority: int, item: Any):
-        self.priority = priority
-        self.item = item
-        self.timestamp = datetime.now()
+class CircuitBreaker:
+    def __init__(self, max_failures: int = 3, reset_timeout: int = 60):
+        self.max_failures = max_failures
+        self.reset_timeout = reset_timeout
+        self.failures = 0
+        self.last_failure_time = None
+        self.is_open = False
 
-    def __lt__(self, other):
-        if self.priority != other.priority:
-            return self.priority < other.priority
-        return self.timestamp < other.timestamp
+    async def record_failure(self):
+        self.failures += 1
+        self.last_failure_time = datetime.now()
+        if self.failures >= self.max_failures:
+            self.is_open = True
+            logger.error("Circuit breaker opened due to multiple failures")
+
+    async def reset(self):
+        self.failures = 0
+        self.last_failure_time = None
+        self.is_open = False
+
+    async def check_state(self):
+        if not self.is_open:
+            return True
+
+        if self.last_failure_time and \
+           (datetime.now() - self.last_failure_time).seconds > self.reset_timeout:
+            await self.reset()
+            return True
+        return False
 
 class QueueWorkerPool:
-    """Manages workers for a specific queue"""
     def __init__(self, config: QueueConfig):
         self.config = config
-        self.executor = ThreadPoolExecutor(max_workers=config.max_workers)
-        self.queue: PriorityQueue[PrioritizedItem] = PriorityQueue(maxsize=config.max_size)
         self._shutdown = False
-        self._tasks: Dict[str, asyncio.Task] = {}
-
+        self._tasks: Dict[str, Dict[str, Any]] = {}
+        self.circuit_breaker = CircuitBreaker()
+        self.processing_semaphore = asyncio.Semaphore(config.max_workers)
+        self.start_processing = asyncio.Event()
 
 class TaskQueueManager:
-    def __init__(self):
-        """Initialize TaskQueueManager with multiple queue support"""
+    TASK_EXPIRY_SECONDS = 86400  # 24 hours
+    QUEUE_KEY_PREFIX = "queue:"
+    TASK_KEY_PREFIX = "task:"
+
+    def __init__(self, redis_url: str = "redis://localhost:6379", cleanup_on_start: bool = True):
         self.queues: Dict[str, QueueWorkerPool] = {}
-        self.tasks: Dict[str, TaskInfo] = {}
-        self._lock = threading.Lock()
+        self._lock = asyncio.Lock()
+        self._redis: Optional[redis.Redis] = None
+        self._redis_url = redis_url
+        self._shutdown_event = asyncio.Event()
+        self._cleanup_on_start = cleanup_on_start
         self._setup_logging()
 
     def _setup_logging(self):
@@ -48,50 +75,48 @@ def _setup_logging(self):
             logger.addHandler(handler)
             logger.setLevel(logging.INFO)
 
-    def create_queue(self, config: QueueConfig):
-        """Create a new queue with specified configuration"""
-        with self._lock:
+    async def initialize(self):
+        self._redis = redis.from_url(
+            self._redis_url,
+            encoding="utf-8",
+            decode_responses=True
+        )
+        logger.info("Connected to Redis")
+
+        if self._cleanup_on_start:
+            await self.cleanup()
+
+    async def cleanup(self):
+        if not self._redis:
+            return
+
+        async for key in self._redis.scan_iter(f"{self.QUEUE_KEY_PREFIX}*"):
+            await self._redis.delete(key)
+
+        async for key in self._redis.scan_iter(f"{self.TASK_KEY_PREFIX}*"):
+            await self._redis.delete(key)
+
+        logger.info("Cleaned up all existing queues and tasks")
+
+    async def create_queue(self, config: QueueConfig):
+        async with self._lock:
             if config.name in self.queues:
                 raise ValueError(f"Queue {config.name} already exists")
+
             worker_pool = QueueWorkerPool(config)
             self.queues[config.name] = worker_pool
-            logger.info(f"Created queue {config.name} with {config.max_workers} workers")
 
-            # Start queue processor
+            await self._redis.hset(
+                f"{self.QUEUE_KEY_PREFIX}{config.name}",
+                mapping={
+                    "max_workers": config.max_workers,
+                    "max_size": config.max_size,
+                    "priority": config.priority.value
+                }
+            )
+
             asyncio.create_task(self._process_queue(config.name))
-
-    async def _process_queue(self, queue_name: str):
-        """Process tasks in the queue"""
-        worker_pool = self.queues[queue_name]
-
-        while not worker_pool._shutdown:
-            try:
-                if not worker_pool.queue.empty():
-                    prioritized_item = worker_pool.queue.get_nowait()
-                    task_id = prioritized_item.item
-
-                    with self._lock:
-                        task_info = self.tasks[task_id]
-                        if task_info.status == TaskStatus.QUEUED:
-                            # Execute task
-                            task_info.status = TaskStatus.PROCESSING
-                            task_info.started_time = datetime.now()
-                            logger.info(f"Processing task {task_id} from queue {queue_name}")
-
-                            # Create task
-                            task = asyncio.create_task(task_info._func(*task_info._args, **task_info._kwargs))
-                            worker_pool._tasks[task_id] = task
-
-                            # Add completion callback
-                            task.add_done_callback(
-                                lambda f, tid=task_id: self._handle_task_completion(tid, f)
-                            )
-
-                await asyncio.sleep(0.1)  # Small delay to prevent busy waiting
-
-            except Exception as e:
-                logger.error(f"Error processing queue {queue_name}: {str(e)}")
-                await asyncio.sleep(1)  # Longer delay on error
+            logger.info(f"Created queue {config.name} with {config.max_workers} workers")
 
     async def submit_task(
         self,
@@ -99,154 +124,196 @@ async def submit_task(
         task_id: str,
         queue_name: str,
         task_type: str,
         task_func: Callable,
-        priority: QueuePriority = QueuePriority.MEDIUM,
+        priority: QueuePriority,
         *args,
         **kwargs
     ) -> TaskInfo:
-        """Submit a task to a specific queue"""
-        with self._lock:
-            if task_id in self.tasks:
-                raise TaskAlreadyExistsError(
-                    f"Task with ID {task_id} already exists"
-                )
-
+        """Submit a task to a queue"""
+        async with self._lock:
             if queue_name not in self.queues:
-                raise QueueNotFoundError(
-                    f"Queue {queue_name} not found"
-                )
+                raise QueueNotFoundError(f"Queue {queue_name} not found")
+
+            task_info = TaskInfo(
+                task_id=task_id,
+                task_type=task_type,
+                queue_name=queue_name,
+                priority=priority.value
+            )
 
-            worker_pool = self.queues[queue_name]
-            task_info = TaskInfo(task_id, task_type, queue_name, priority.value)
+            self.queues[queue_name]._tasks[task_id] = {
+                'func': task_func,
+                'args': args,
+                'kwargs': kwargs,
+                'task': None
+            }
 
-            # Store function and arguments for later execution
-            task_info._func = task_func
-            task_info._args = args
-            task_info._kwargs = kwargs
+            await self._redis.set(
+                f"{self.TASK_KEY_PREFIX}{task_id}",
+                task_info.to_json(),
+                ex=self.TASK_EXPIRY_SECONDS
+            )
 
-            self.tasks[task_id] = task_info
+            # New scoring system:
+            #   score = priority * 1_000_000 + timestamp
+            # This ensures:
+            #   1. Priority is the primary factor (lower value = higher priority)
+            #   2. Within the same priority, earlier tasks come first
+            current_time = int(datetime.now().timestamp() * 1000)  # milliseconds
+            score = (priority.value * 1_000_000) + current_time
 
-            # Add to priority queue
-            worker_pool.queue.put(PrioritizedItem(
-                priority=priority.value,
-                item=task_id
-            ))
+            await self._redis.zadd(
+                f"{self.QUEUE_KEY_PREFIX}{queue_name}:tasks",
+                {task_id: score}
+            )
 
             logger.info(f"Task {task_id} submitted to queue {queue_name}")
-
             return task_info
+
+    async def _process_queue(self, queue_name: str):
+        """Process tasks in the queue"""
+        worker_pool = self.queues[queue_name]
+        queue_key = f"{self.QUEUE_KEY_PREFIX}{queue_name}:tasks"
+
+        while not self._shutdown_event.is_set():
+            try:
+                async with worker_pool.processing_semaphore:
+                    # Get the highest-priority task
+                    next_task = await self._redis.zrange(
+                        queue_key,
+                        0,
+                        0,  # Get only the highest-priority task
+                        withscores=True
+                    )
+
+                    if not next_task:
+                        await asyncio.sleep(0.1)
+                        continue
+
+                    task_id = next_task[0][0]
+                    if isinstance(task_id, bytes):
+                        task_id = task_id.decode('utf-8')
+
+                    # Get task info
+                    task_info_data = await self._redis.get(f"{self.TASK_KEY_PREFIX}{task_id}")
+                    if not task_info_data:
+                        await self._redis.zrem(queue_key, task_id)
+                        continue
+
+                    task_info = TaskInfo.from_json(task_info_data)
+
+                    if task_info.status == TaskStatus.QUEUED:
+                        task_data = worker_pool._tasks.get(task_id)
+                        if not task_data:
+                            continue
 
-    def _handle_task_completion(self, task_id: str, future: asyncio.Future):
-        """Handle task completion and cleanup"""
+                        func = task_data['func']
+                        args = task_data.get('args', ())
+                        kwargs = task_data.get('kwargs', {})
+
+                        # Remove from queue and update status
+                        await self._redis.zrem(queue_key, task_id)
+                        task_info.status = TaskStatus.PROCESSING
+                        await self._redis.set(
+                            f"{self.TASK_KEY_PREFIX}{task_id}",
+                            task_info.to_json(),
+                            ex=self.TASK_EXPIRY_SECONDS
+                        )
+
+                        try:
+                            # Execute the task
+                            result = await func(*args, **kwargs)
+                            task_info.status = TaskStatus.COMPLETED
+                            task_info.result = result
+                            task_info.completed_time = datetime.now()
+                            logger.info(f"Task {task_id} completed successfully")
+                        except Exception as e:
+                            task_info.status = TaskStatus.FAILED
+                            task_info.error = str(e)
+                            task_info.completed_time = datetime.now()
+                            logger.error(f"Task {task_id} failed: {str(e)}")
+
+                        # Update final status
+                        await self._redis.set(
+                            f"{self.TASK_KEY_PREFIX}{task_id}",
+                            task_info.to_json(),
+                            ex=self.TASK_EXPIRY_SECONDS
+                        )
+
+            except Exception as e:
+                logger.error(f"Error processing queue {queue_name}: {str(e)}")
+                await asyncio.sleep(1)  # Back off on error
+
+    async def _handle_task_completion(self, task_id: str, future: asyncio.Future):
         try:
             result = future.result()
-            with self._lock:
-                task_info = self.tasks[task_id]
-                task_info.result = result
+            task_info_data = await self._redis.get(f"{self.TASK_KEY_PREFIX}{task_id}")
+            if task_info_data:
+                task_info = TaskInfo.from_json(task_info_data)
                 task_info.status = TaskStatus.COMPLETED
+                task_info.result = result
                 task_info.completed_time = datetime.now()
-                logger.info(f"Task {task_id} completed successfully with result: {result}")
 
-                # Clean up task
-                worker_pool = self.queues[task_info.queue_name]
-                if task_id in worker_pool._tasks:
-                    del worker_pool._tasks[task_id]
-
+                await self._redis.set(
+                    f"{self.TASK_KEY_PREFIX}{task_id}",
+                    task_info.to_json(),
+                    ex=self.TASK_EXPIRY_SECONDS
+                )
+
+                logger.info(f"Task {task_id} completed successfully")
+
         except Exception as e:
             logger.error(f"Task {task_id} failed: {str(e)}")
-            with self._lock:
-                task_info = self.tasks[task_id]
+            task_info_data = await self._redis.get(f"{self.TASK_KEY_PREFIX}{task_id}")
+            if task_info_data:
+                task_info = TaskInfo.from_json(task_info_data)
                 task_info.status = TaskStatus.FAILED
                 task_info.error = str(e)
                 task_info.completed_time = datetime.now()
+
+                await self._redis.set(
+                    f"{self.TASK_KEY_PREFIX}{task_id}",
+                    task_info.to_json(),
+                    ex=self.TASK_EXPIRY_SECONDS
+                )
 
-    def get_task_status(self, task_id: str) -> TaskInfo:
-        """Get status of a specific task"""
-        with self._lock:
-            task_info = self.tasks.get(task_id)
-            if not task_info:
-                raise TaskNotFoundError(f"Task {task_id} not found")
-            return task_info
-
-    def get_queue_status(self, queue_name: str) -> Dict[str, Any]:
-        """Get status of a specific queue"""
-        with self._lock:
-            if queue_name not in self.queues:
-                raise QueueNotFoundError(f"Queue {queue_name} not found")
-
-            worker_pool = self.queues[queue_name]
-            queue_tasks = [
-                task for task in self.tasks.values()
-                if task.queue_name == queue_name
-            ]
-
-            return {
-                "name": queue_name,
-                "max_workers": worker_pool.config.max_workers,
-                "priority": worker_pool.config.priority.value,
-                "active_tasks": len([t for t in queue_tasks if t.status == TaskStatus.PROCESSING]),
-                "queued_tasks": len([t for t in queue_tasks if t.status == TaskStatus.QUEUED]),
-                "completed_tasks": len([t for t in queue_tasks if t.status == TaskStatus.COMPLETED]),
-                "failed_tasks": len([t for t in queue_tasks if t.status == TaskStatus.FAILED])
-            }
+    async def get_task_status(self, task_id: str) -> TaskInfo:
+        task_data = await self._redis.get(f"{self.TASK_KEY_PREFIX}{task_id}")
+        if not task_data:
+            raise TaskNotFoundError(f"Task {task_id} not found")
+        return TaskInfo.from_json(task_data)
 
-    async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskInfo:
-        """Wait for a specific task to complete"""
-        task_info = self.get_task_status(task_id)
-        worker_pool = self.queues[task_info.queue_name]
-
-        if task_id in worker_pool._tasks:
-            try:
-                await asyncio.wait_for(worker_pool._tasks[task_id], timeout=timeout)
-            except asyncio.TimeoutError:
-                logger.warning(f"Task {task_id} timed out after {timeout} seconds")
-                raise
-
-        return task_info
-
-    async def wait_for_queue(self, queue_name: str, timeout: Optional[float] = None):
-        """Wait for all tasks in a specific queue to complete"""
+    async def get_queue_status(self, queue_name: str) -> Dict[str, Any]:
         if queue_name not in self.queues:
             raise QueueNotFoundError(f"Queue {queue_name} not found")
-
-        worker_pool = self.queues[queue_name]
-        tasks = list(worker_pool._tasks.values())
-        if tasks:
-            await asyncio.wait(tasks, timeout=timeout)
-
-    def cleanup_old_tasks(self, max_age_hours: int = 24, queue_name: Optional[str] = None):
-        """Clean up completed tasks, optionally for a specific queue"""
-        current_time = datetime.now()
-        cleaned_count = 0
-        with self._lock:
-            for task_id, task_info in list(self.tasks.items()):
-                if queue_name and task_info.queue_name != queue_name:
-                    continue
-
-                if task_info.completed_time:
-                    age = current_time - task_info.completed_time
-                    if age.total_seconds() > max_age_hours * 3600:
-                        del self.tasks[task_id]
-                        cleaned_count += 1
 
+        worker_pool = self.queues[queue_name]
+        queue_size = await self._redis.zcard(f"{self.QUEUE_KEY_PREFIX}{queue_name}:tasks")
 
-        logger.info(f"Cleaned up {cleaned_count} old tasks")
+        return {
+            "name": queue_name,
+            "max_workers": worker_pool.config.max_workers,
+            "current_workers": len(worker_pool._tasks),
+            "queued_tasks": queue_size,
+            "circuit_breaker_status": "open" if worker_pool.circuit_breaker.is_open else "closed"
+        }
 
-    async def shutdown(self, wait: bool = True):
-        """Shutdown all queues"""
+    async def shutdown(self, wait: bool = True, timeout: float = 5.0):
         logger.info("Shutting down TaskQueueManager")
+        self._shutdown_event.set()
 
-        # Wait for pending tasks if requested
         if wait:
             for queue_name, worker_pool in self.queues.items():
-                tasks = list(worker_pool._tasks.values())
-                if tasks:
+                worker_pool._shutdown = True
+                active_tasks = []
+                for task_data in worker_pool._tasks.values():
+                    if isinstance(task_data, dict) and task_data.get('task'):
+                        active_tasks.append(task_data['task'])
+
+                if active_tasks:
                     try:
-                        await asyncio.wait(tasks, timeout=5.0)
+                        await asyncio.wait(active_tasks, timeout=timeout)
                     except Exception as e:
                         logger.error(f"Error waiting for tasks in queue {queue_name}: {str(e)}")
 
-        # Shutdown all worker pools
-        for queue_name, worker_pool in self.queues.items():
-            worker_pool._shutdown = True
-            worker_pool.executor.shutdown(wait=wait)
-            logger.info(f"Shut down queue {queue_name}")
\ No newline at end of file
+        if self._redis:
+            await self._redis.close()
+            logger.info("Closed Redis connection")
\ No newline at end of file
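With this rewrite every public entry point on TaskQueueManager is a coroutine and all task state lives in Redis. A minimal end-to-end sketch of the new lifecycle, assuming a local Redis on the default port (the doubler coroutine and queue name are illustrative):

    import asyncio
    from kew import TaskQueueManager, QueueConfig, QueuePriority

    async def doubler(x: int) -> int:
        return x * 2

    async def demo():
        manager = TaskQueueManager(redis_url="redis://localhost:6379")
        await manager.initialize()                  # connect (and, by default, clean up)
        await manager.create_queue(QueueConfig(name="demo", max_workers=2))
        await manager.submit_task(
            task_id="t1", queue_name="demo", task_type="demo",
            task_func=doubler, priority=QueuePriority.HIGH, x=21,
        )
        await asyncio.sleep(0.2)                    # give the processor loop a turn
        info = await manager.get_task_status("t1")
        print(info.status, info.result)             # expected: TaskStatus.COMPLETED 42
        await manager.shutdown()

    asyncio.run(demo())

One caveat on the scoring comment in submit_task: epoch milliseconds are on the order of 10^12, so the timestamp term dwarfs priority * 1_000_000 and effectively decides ordering; for priority to truly dominate, the multiplier would need to exceed the timestamp range.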
"""Create TaskInfo instance from JSON string""" + data = json.loads(json_str) + task = cls( + task_id=data["task_id"], + task_type=data["task_type"], + queue_name=data["queue_name"], + priority=data["priority"], + status=TaskStatus(data["status"]) + ) + task.queued_time = datetime.fromisoformat(data["queued_time"]) + task.started_time = datetime.fromisoformat(data["started_time"]) if data["started_time"] else None + task.completed_time = datetime.fromisoformat(data["completed_time"]) if data["completed_time"] else None + task.result = data["result"] + task.error = data["error"] + return task \ No newline at end of file diff --git a/kew/kew/tests/conftest.py b/kew/kew/tests/conftest.py index 454080e..c7f7d4a 100644 --- a/kew/kew/tests/conftest.py +++ b/kew/kew/tests/conftest.py @@ -1,11 +1,15 @@ -# kew/tests/conftest.py import pytest import asyncio -from kew import TaskQueueManager -@pytest.fixture -async def manager(): - """Fixture that provides a TaskQueueManager instance""" - manager = TaskQueueManager() - yield manager - await manager.shutdown() +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for each test case.""" + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + yield loop + loop.close() + +@pytest.fixture(scope="session") +def anyio_backend(): + """Backend for anyio/pytest-asyncio.""" + return 'asyncio' \ No newline at end of file diff --git a/kew/kew/tests/test_queue_manager.py b/kew/kew/tests/test_queue_manager.py index cf3fd20..c88740f 100644 --- a/kew/kew/tests/test_queue_manager.py +++ b/kew/kew/tests/test_queue_manager.py @@ -9,214 +9,273 @@ async def long_task(task_num: int, sleep_time: float) -> dict: result = sleep_time * 2 return {"task_num": task_num, "result": result} +@pytest.fixture +async def manager(): + """Fixture to provide a TaskQueueManager instance""" + mgr = TaskQueueManager(redis_url="redis://localhost:6379", cleanup_on_start=True) + await mgr.initialize() + try: + yield mgr + finally: + await mgr.shutdown() + @pytest.mark.asyncio async def test_single_queue(): """Test single queue operation""" - manager = TaskQueueManager() - - # Create queue - manager.create_queue(QueueConfig( - name="test_queue", - max_workers=2, - priority=QueuePriority.HIGH - )) - - # Submit task - task_info = await manager.submit_task( - task_id="task1", - queue_name="test_queue", - task_type="test", - task_func=long_task, - priority=QueuePriority.HIGH, - task_num=1, - sleep_time=0.1 - ) - - # Check initial status - status = manager.get_task_status(task_info.task_id) - assert status.queue_name == "test_queue" - - # Wait for completion - await asyncio.sleep(0.2) - - # Check final status - status = manager.get_task_status(task_info.task_id) - assert status.status == TaskStatus.COMPLETED - assert status.result["task_num"] == 1 - assert status.result["result"] == 0.2 + # Initialize manager + manager = TaskQueueManager(redis_url="redis://localhost:6379", cleanup_on_start=True) + await manager.initialize() + + try: + # Create queue + await manager.create_queue(QueueConfig( + name="test_queue", + max_workers=2, + priority=QueuePriority.HIGH + )) + + # Submit task + task_info = await manager.submit_task( + task_id="task1", + queue_name="test_queue", + task_type="test", + task_func=long_task, + priority=QueuePriority.HIGH, + task_num=1, + sleep_time=0.1 + ) + + # Check initial status + status = await manager.get_task_status(task_info.task_id) + assert status.queue_name == "test_queue" + + # Wait for completion + 
diff --git a/kew/kew/tests/conftest.py b/kew/kew/tests/conftest.py
index 454080e..c7f7d4a 100644
--- a/kew/kew/tests/conftest.py
+++ b/kew/kew/tests/conftest.py
@@ -1,11 +1,15 @@
-# kew/tests/conftest.py
 import pytest
 import asyncio
-from kew import TaskQueueManager
 
-@pytest.fixture
-async def manager():
-    """Fixture that provides a TaskQueueManager instance"""
-    manager = TaskQueueManager()
-    yield manager
-    await manager.shutdown()
+@pytest.fixture(scope="session")
+def event_loop():
+    """Create an instance of the default event loop for the test session."""
+    policy = asyncio.get_event_loop_policy()
+    loop = policy.new_event_loop()
+    yield loop
+    loop.close()
+
+@pytest.fixture(scope="session")
+def anyio_backend():
+    """Backend for anyio/pytest-asyncio."""
+    return 'asyncio'
\ No newline at end of file
diff --git a/kew/kew/tests/test_queue_manager.py b/kew/kew/tests/test_queue_manager.py
index cf3fd20..c88740f 100644
--- a/kew/kew/tests/test_queue_manager.py
+++ b/kew/kew/tests/test_queue_manager.py
@@ -9,214 +9,273 @@ async def long_task(task_num: int, sleep_time: float) -> dict:
     result = sleep_time * 2
     return {"task_num": task_num, "result": result}
 
+@pytest.fixture
+async def manager():
+    """Fixture to provide a TaskQueueManager instance"""
+    mgr = TaskQueueManager(redis_url="redis://localhost:6379", cleanup_on_start=True)
+    await mgr.initialize()
+    try:
+        yield mgr
+    finally:
+        await mgr.shutdown()
+
 @pytest.mark.asyncio
 async def test_single_queue():
     """Test single queue operation"""
-    manager = TaskQueueManager()
-
-    # Create queue
-    manager.create_queue(QueueConfig(
-        name="test_queue",
-        max_workers=2,
-        priority=QueuePriority.HIGH
-    ))
-
-    # Submit task
-    task_info = await manager.submit_task(
-        task_id="task1",
-        queue_name="test_queue",
-        task_type="test",
-        task_func=long_task,
-        priority=QueuePriority.HIGH,
-        task_num=1,
-        sleep_time=0.1
-    )
-
-    # Check initial status
-    status = manager.get_task_status(task_info.task_id)
-    assert status.queue_name == "test_queue"
-
-    # Wait for completion
-    await asyncio.sleep(0.2)
-
-    # Check final status
-    status = manager.get_task_status(task_info.task_id)
-    assert status.status == TaskStatus.COMPLETED
-    assert status.result["task_num"] == 1
-    assert status.result["result"] == 0.2
+    # Initialize manager
+    manager = TaskQueueManager(redis_url="redis://localhost:6379", cleanup_on_start=True)
+    await manager.initialize()
+
+    try:
+        # Create queue
+        await manager.create_queue(QueueConfig(
+            name="test_queue",
+            max_workers=2,
+            priority=QueuePriority.HIGH
+        ))
+
+        # Submit task
+        task_info = await manager.submit_task(
+            task_id="task1",
+            queue_name="test_queue",
+            task_type="test",
+            task_func=long_task,
+            priority=QueuePriority.HIGH,
+            task_num=1,
+            sleep_time=0.1
+        )
+
+        # Check initial status
+        status = await manager.get_task_status(task_info.task_id)
+        assert status.queue_name == "test_queue"
+
+        # Wait for completion
+        await asyncio.sleep(0.2)
+
+        # Check final status
+        status = await manager.get_task_status(task_info.task_id)
+        assert status.status == TaskStatus.COMPLETED
+        assert status.result["task_num"] == 1
+        assert status.result["result"] == 0.2
 
-    await manager.shutdown()
+    finally:
+        await manager.shutdown()
 
 @pytest.mark.asyncio
 async def test_multiple_queues():
     """Test multiple queues with different priorities"""
-    manager = TaskQueueManager()
-
-    # Create queues
-    manager.create_queue(QueueConfig(
-        name="fast_track",
-        max_workers=2,
-        priority=QueuePriority.HIGH
-    ))
-
-    manager.create_queue(QueueConfig(
-        name="standard",
-        max_workers=1,
-        priority=QueuePriority.LOW
-    ))
-
-    tasks = []
-
-    # Submit high-priority tasks
-    for i in range(2):
-        sleep_time = 0.1
+    # Initialize manager
+    manager = TaskQueueManager(redis_url="redis://localhost:6379", cleanup_on_start=True)
+    await manager.initialize()
+
+    try:
+        # Create queues
+        await manager.create_queue(QueueConfig(
+            name="fast_track",
+            max_workers=2,
+            priority=QueuePriority.HIGH
+        ))
+
+        await manager.create_queue(QueueConfig(
+            name="standard",
+            max_workers=1,
+            priority=QueuePriority.LOW
+        ))
+
+        tasks = []
+
+        # Submit high-priority tasks
+        for i in range(2):
+            sleep_time = 0.1
+            task_info = await manager.submit_task(
+                task_id=f"high_task_{i+1}",
+                queue_name="fast_track",
+                task_type="test",
+                task_func=long_task,
+                priority=QueuePriority.HIGH,
+                task_num=i+1,
+                sleep_time=sleep_time
+            )
+            tasks.append(task_info)
+
+        # Submit low-priority task
         task_info = await manager.submit_task(
-            task_id=f"high_task_{i+1}",
-            queue_name="fast_track",
+            task_id="low_task_1",
+            queue_name="standard",
             task_type="test",
             task_func=long_task,
-            priority=QueuePriority.HIGH,
-            task_num=i+1,
-            sleep_time=sleep_time
+            priority=QueuePriority.LOW,
+            task_num=3,
+            sleep_time=0.1
         )
         tasks.append(task_info)
-
-    # Submit low-priority task
-    task_info = await manager.submit_task(
-        task_id="low_task_1",
-        queue_name="standard",
-        task_type="test",
-        task_func=long_task,
-        priority=QueuePriority.LOW,
-        task_num=3,
-        sleep_time=0.1
-    )
-    tasks.append(task_info)
-
-    # Wait for completion
-    await asyncio.sleep(0.3)
-
-    # Check all tasks completed
-    for task in tasks:
-        status = manager.get_task_status(task.task_id)
-        assert status.status == TaskStatus.COMPLETED
-        assert status.result is not None
-
-    # Check queue statuses
-    fast_track_status = manager.get_queue_status("fast_track")
-    standard_status = manager.get_queue_status("standard")
-
-    assert fast_track_status["completed_tasks"] == 2
-    assert standard_status["completed_tasks"] == 1
-
-    await manager.shutdown()
+
+        # Wait for completion
+        await asyncio.sleep(0.3)
+
+        # Check all tasks completed
+        for task in tasks:
+            status = await manager.get_task_status(task.task_id)
+            assert status.status == TaskStatus.COMPLETED
+            assert status.result is not None
+
+        # Check queue statuses
+        fast_track_status = await manager.get_queue_status("fast_track")
+        standard_status = await manager.get_queue_status("standard")
+
+        assert fast_track_status["queued_tasks"] == 0
+        assert standard_status["queued_tasks"] == 0
+
+    finally:
+        await manager.shutdown()
 
 @pytest.mark.asyncio
 async def test_queue_priorities():
-    """Test that high priority tasks complete before low priority ones"""
-    manager = TaskQueueManager()
-
-    manager.create_queue(QueueConfig(
-        name="mixed_queue",
-        max_workers=1,  # Single worker to ensure sequential execution
-        priority=QueuePriority.MEDIUM
-    ))
-
-    completion_order = []
-
-    async def tracking_task(priority_level: str):
+    """Test that high priority tasks are processed before lower priority ones when available"""
+    manager = TaskQueueManager(redis_url="redis://localhost:6379", cleanup_on_start=True)
+    await manager.initialize()
+
+    try:
+        await manager.create_queue(QueueConfig(
+            name="priority_queue",
+            max_workers=1,
+            priority=QueuePriority.MEDIUM
+        ))
+
+        execution_order = []
+
+        async def priority_task(priority_name: str):
+            execution_order.append(priority_name)
+            return f"Completed {priority_name}"
+
+        # Submit low priority task first - it should start processing
+        low_task = await manager.submit_task(
+            task_id="low_priority",
+            queue_name="priority_queue",
+            task_type="test",
+            task_func=priority_task,
+            priority=QueuePriority.LOW,
+            priority_name="low"
+        )
+
+        # Give it a moment to start processing
         await asyncio.sleep(0.1)
-        completion_order.append(priority_level)
-        return priority_level
-
-    # Submit low priority task first
-    await manager.submit_task(
-        task_id="low_priority",
-        queue_name="mixed_queue",
-        task_type="test",
-        task_func=tracking_task,
-        priority=QueuePriority.LOW,
-        priority_level="low"
-    )
-
-    # Submit high priority task second
-    await manager.submit_task(
-        task_id="high_priority",
-        queue_name="mixed_queue",
-        task_type="test",
-        task_func=tracking_task,
-        priority=QueuePriority.HIGH,
-        priority_level="high"
-    )
-
-    # Wait for completion
-    await asyncio.sleep(0.3)
-
-    # High priority task should complete first
-    assert completion_order[0] == "high"
-    assert completion_order[1] == "low"
-
-    await manager.shutdown()
+
+        # Submit high and medium priority tasks
+        high_task = await manager.submit_task(
+            task_id="high_priority",
+            queue_name="priority_queue",
+            task_type="test",
+            task_func=priority_task,
+            priority=QueuePriority.HIGH,
+            priority_name="high"
+        )
+
+        medium_task = await manager.submit_task(
+            task_id="medium_priority",
+            queue_name="priority_queue",
+            task_type="test",
+            task_func=priority_task,
+            priority=QueuePriority.MEDIUM,
+            priority_name="medium"
+        )
+
+        # Wait for all tasks to complete
+        tasks = [low_task, medium_task, high_task]
+        for _ in range(30):  # 3 second timeout
+            all_completed = True
+            for task in tasks:
+                status = await manager.get_task_status(task.task_id)
+                if status.status != TaskStatus.COMPLETED:
+                    all_completed = False
+                    break
+            if all_completed:
+                break
+            await asyncio.sleep(0.1)
+
+        # Verify that high priority tasks are processed before lower priority ones
+        # when they're available at the same time
+        assert len(execution_order) == 3, f"Expected 3 tasks, got {len(execution_order)}"
+        high_index = execution_order.index("high")
+        medium_index = execution_order.index("medium")
+        assert high_index < medium_index, "High priority should be processed before medium"
+
+    finally:
+        await manager.shutdown()
+
 
 @pytest.mark.asyncio
 async def test_error_handling():
     """Test error handling in tasks"""
-    manager = TaskQueueManager()
-
-    manager.create_queue(QueueConfig(
-        name="test_queue",
-        max_workers=1
-    ))
-
-    async def failing_task():
-        await asyncio.sleep(0.1)
-        raise ValueError("Test error")
-
-    task_info = await manager.submit_task(
-        task_id="failing_task",
-        queue_name="test_queue",
-        task_type="test",
-        task_func=failing_task,
-        priority=QueuePriority.MEDIUM
-    )
-
-    # Wait for task to fail
-    await asyncio.sleep(0.2)
-
-    status = manager.get_task_status("failing_task")
-    assert status.status == TaskStatus.FAILED
-    assert "Test error" in status.error
-
-    await manager.shutdown()
-
+    manager = TaskQueueManager(redis_url="redis://localhost:6379", cleanup_on_start=True)
+    await manager.initialize()
+
+    try:
+        await manager.create_queue(QueueConfig(
+            name="test_queue",
+            max_workers=1
+        ))
+
+        async def failing_task():
+            await asyncio.sleep(0.1)
+            raise ValueError("Test error")
+
+        # Submit failing task
+        await manager.submit_task(
+            task_id="failing_task",
+            queue_name="test_queue",
+            task_type="test",
+            task_func=failing_task,
+            priority=QueuePriority.MEDIUM
+        )
+
+        # Wait for task to fail
+        # Increased wait time and added polling
+        max_attempts = 10
+        for _ in range(max_attempts):
+            await asyncio.sleep(0.1)
+            status = await manager.get_task_status("failing_task")
+            if status.status == TaskStatus.FAILED:
+                break
+
+        assert status.status == TaskStatus.FAILED, f"Expected FAILED status, got {status.status}"
+        assert "Test error" in status.error, f"Expected 'Test error' in error message, got {status.error}"
+
+    finally:
+        await manager.shutdown()
 
 @pytest.mark.asyncio
 async def test_queue_cleanup():
     """Test queue cleanup functionality"""
-    manager = TaskQueueManager()
-
-    manager.create_queue(QueueConfig(
-        name="test_queue",
-        max_workers=1
-    ))
-
-    task_info = await manager.submit_task(
-        task_id="task1",
-        queue_name="test_queue",
-        task_type="test",
-        task_func=long_task,
-        priority=QueuePriority.MEDIUM,
-        task_num=1,
-        sleep_time=0.1
-    )
-
-    # Wait for completion
-    await asyncio.sleep(0.2)
-
-    # Clean up old tasks
-    manager.cleanup_old_tasks(max_age_hours=0)
-
-    # Check task was cleaned up
-    with pytest.raises(Exception):
-        manager.get_task_status("task1")
-
-    await manager.shutdown()
+    # Initialize manager
+    manager = TaskQueueManager(redis_url="redis://localhost:6379", cleanup_on_start=True)
+    await manager.initialize()
+
+    try:
+        await manager.create_queue(QueueConfig(
+            name="test_queue",
+            max_workers=1
+        ))
+
+        await manager.submit_task(
+            task_id="task1",
+            queue_name="test_queue",
+            task_type="test",
+            task_func=long_task,
+            priority=QueuePriority.MEDIUM,
+            task_num=1,
+            sleep_time=0.1
+        )
+
+        # Wait for completion
+        await asyncio.sleep(0.2)
+
+        # Clean up Redis
+        await manager.cleanup()
+
+        # Check task was cleaned up
+        with pytest.raises(Exception):
+            await manager.get_task_status("task1")
+
+    finally:
+        await manager.shutdown()
\ No newline at end of file
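The updated tests all poll get_task_status in small sleep loops. A helper along these lines (hypothetical, not part of the library) would factor that pattern out:

    import asyncio
    from kew import TaskStatus

    async def wait_for_terminal(manager, task_id: str,
                                timeout: float = 3.0, interval: float = 0.1):
        """Poll until the task completes or fails, or the timeout elapses."""
        for _ in range(int(timeout / interval)):
            status = await manager.get_task_status(task_id)
            if status.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
                return status
            await asyncio.sleep(interval)
        raise TimeoutError(f"task {task_id} still {status.status} after {timeout}s")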
diff --git a/kew/pyproject.toml b/kew/pyproject.toml
index 98e50d2..8f68f24 100644
--- a/kew/pyproject.toml
+++ b/kew/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "kew"
-version = "0.1.1"
+version = "0.1.2"
 authors = [
     { name="Rach Pradhan", email="rach@rachit.ai" },
 ]
diff --git a/kew/test_multiple_redis.py b/kew/test_multiple_redis.py
new file mode 100644
index 0000000..56cfc9c
--- /dev/null
+++ b/kew/test_multiple_redis.py
@@ -0,0 +1,99 @@
+import asyncio
+import random
+from kew import TaskQueueManager, QueueConfig, QueuePriority, TaskStatus
+
+async def long_task(task_num: int, sleep_time: int) -> dict:
+    """Simulate a long-running task"""
+    print(f"Starting task {task_num} (will take {sleep_time} seconds)")
+    await asyncio.sleep(sleep_time)
+    result = sleep_time * 2
+    print(f"Task {task_num} completed with result: {result}")
+    return {"task_num": task_num, "result": result}
+
+async def run_tasks():
+    # Initialize manager with Redis connection and cleanup
+    manager = TaskQueueManager(
+        redis_url="redis://localhost:6379",
+        cleanup_on_start=True  # This will clean up existing queues on start
+    )
+    await manager.initialize()
+
+    try:
+        # Create queues
+        await manager.create_queue(QueueConfig(
+            name="fast_track",
+            max_workers=2,
+            priority=QueuePriority.HIGH
+        ))
+
+        await manager.create_queue(QueueConfig(
+            name="standard",
+            max_workers=1,
+            priority=QueuePriority.LOW
+        ))
+
+        tasks = []
+
+        # Submit 3 high-priority tasks
+        for i in range(3):
+            sleep_time = random.randint(2, 4)
+            task_info = await manager.submit_task(
+                task_id=f"high_task_{i+1}",
+                queue_name="fast_track",
+                task_type="long_calculation",
+                task_func=long_task,
+                priority=QueuePriority.HIGH,
+                task_num=i+1,
+                sleep_time=sleep_time
+            )
+            tasks.append(task_info)
+
+        # Submit 2 low-priority tasks
+        for i in range(2):
+            sleep_time = random.randint(1, 3)
+            task_info = await manager.submit_task(
+                task_id=f"low_task_{i+1}",
+                queue_name="standard",
+                task_type="long_calculation",
+                task_func=long_task,
+                priority=QueuePriority.LOW,
+                task_num=i+1,
+                sleep_time=sleep_time
+            )
+            tasks.append(task_info)
+
+        # Monitor progress
+        while True:
+            all_completed = True
+            print("\nCurrent status:")
+            for task in tasks:
+                status = await manager.get_task_status(task.task_id)
+                print(f"{task.task_id} ({task.queue_name}): {status.status.value} - Result: {status.result}")
+                if status.status not in (TaskStatus.COMPLETED, TaskStatus.FAILED):
+                    all_completed = False
+
+            if all_completed:
+                break
+
+            await asyncio.sleep(1)
+
+        # Final queue statuses
+        print("\nFinal Queue Statuses:")
+        print("Fast Track Queue:", await manager.get_queue_status("fast_track"))
+        print("Standard Queue:", await manager.get_queue_status("standard"))
+
+    finally:
+        # Ensure manager is properly shut down
+        await manager.shutdown()
+
+async def main():
+    try:
+        await run_tasks()
+    except KeyboardInterrupt:
+        print("\nShutting down gracefully...")
+    except Exception as e:
+        print(f"Error occurred: {e}")
+        raise
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
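The demo script assumes a reachable Redis at localhost:6379. A quick preflight check, using the same redis.asyncio client the manager itself depends on (the preflight function is illustrative):

    import asyncio
    import redis.asyncio as redis

    async def preflight(url: str = "redis://localhost:6379") -> None:
        client = redis.from_url(url)
        try:
            await client.ping()    # raises ConnectionError if Redis is unreachable
            print("Redis reachable")
        finally:
            await client.close()

    asyncio.run(preflight())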