# Reconstructed from patch f9b332056f82 ("refactor: update AsyncQueueManager to
# use a single RabbitMQ priority queue and enhance connection handling").

logger = logging.getLogger(__name__)

# Single queue name for all priorities: the broker itself orders messages via
# the x-max-priority queue argument, so no per-priority queues are needed.
DEFAULT_QUEUE_NAME = "task_queue"
MAX_PRIORITY = 10  # RabbitMQ priority 0-255; higher = more urgent


class QueuePriority(str, Enum):
    """Logical priority levels exposed to callers."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"


# Map enum to numeric priority for RabbitMQ (higher number = higher priority).
PRIORITY_MAP = {
    QueuePriority.HIGH: 10,
    QueuePriority.MEDIUM: 5,
    QueuePriority.LOW: 1,
}


class AsyncQueueManager:
    """Queue manager for agent orchestration using a single RabbitMQ priority queue."""

    def __init__(self, queue_name: str = DEFAULT_QUEUE_NAME):
        # Name of the single priority queue all messages are published to.
        self.queue_name = queue_name
        # message type -> handler callable (sync or async), see register_handler().
        self.handlers: Dict[str, Callable] = {}
        # Set True by start(), False by stop(); workers observe it to exit.
        self.running = False
        self.worker_tasks: list[asyncio.Task] = []
        self.connection: Optional[aio_pika.RobustConnection] = None
        self.channel: Optional[aio_pika.abc.AbstractChannel] = None

    async def connect(self) -> None:
        """Open a robust connection/channel and declare the priority queue.

        Raises:
            Exception: re-raised after logging when the broker is unreachable.
        """
        try:
            rabbitmq_url = getattr(
                settings, "rabbitmq_url", "amqp://guest:guest@localhost/"
            )
            self.connection = await aio_pika.connect_robust(rabbitmq_url)
            self.channel = await self.connection.channel()
            # Prefetch: broker sends at most this many unacked messages per consumer.
            await self.channel.set_qos(prefetch_count=1)
            # Single priority queue: broker orders by message priority, no polling.
            await self.channel.declare_queue(
                self.queue_name,
                durable=True,
                arguments={"x-max-priority": MAX_PRIORITY},
            )
            logger.info("Successfully connected to RabbitMQ (single priority queue)")
        except Exception as e:
            logger.error("Failed to connect to RabbitMQ: %s", e)
            raise

    async def start(self, num_workers: int = 3) -> None:
        """Start the queue processing workers (push-based consumers, no polling)."""
        if self.running:
            # Guard: a second start() would reconnect and spawn a duplicate
            # worker pool on top of the existing one.
            logger.warning("Queue manager already running; ignoring start()")
            return
        await self.connect()
        self.running = True

        for i in range(num_workers):
            task = asyncio.create_task(self._worker(f"worker-{i}"))
            self.worker_tasks.append(task)

        logger.info("Started %d async queue workers on %s", num_workers, self.queue_name)

    async def stop(self) -> None:
        """Stop the queue processing and close connections."""
        self.running = False
        for task in self.worker_tasks:
            task.cancel()
        # return_exceptions=True: collect CancelledError instead of raising.
        await asyncio.gather(*self.worker_tasks, return_exceptions=True)
        self.worker_tasks.clear()
        if self.channel:
            await self.channel.close()
        if self.connection:
            await self.connection.close()
        logger.info("Stopped all queue workers and closed connection")
delay: float = 0): - """Add a message to the queue""" - + async def enqueue( + self, + message: Dict[str, Any], + priority: QueuePriority = QueuePriority.MEDIUM, + delay: float = 0, + ) -> None: + """Add a message to the single priority queue.""" if delay > 0: await asyncio.sleep(delay) queue_item = { "id": message.get("id", f"msg_{datetime.now().timestamp()}"), - "priority": priority, - "data": message + "priority": priority.value, + "data": message, } - json_message = json.dumps(queue_item).encode() + json_body = json.dumps(queue_item).encode() + numeric_priority = PRIORITY_MAP[priority] + await self.channel.default_exchange.publish( - aio_pika.Message(body=json_message), - routing_key=self.queues[priority] + aio_pika.Message(body=json_body, priority=numeric_priority), + routing_key=self.queue_name, ) logger.info(f"Enqueued message {queue_item['id']} with priority {priority}") - def register_handler(self, message_type: str, handler: Callable): - """Register a handler for a specific message type""" + def register_handler(self, message_type: str, handler: Callable) -> None: + """Register a handler for a specific message type.""" self.handlers[message_type] = handler logger.info(f"Registered handler for message type: {message_type}") - async def _worker(self, worker_name: str): - """Worker coroutine to process queue items""" + async def _worker(self, worker_name: str) -> None: + """Worker: long-lived consumer on the single queue (push-based, no polling).""" logger.info(f"Started queue worker: {worker_name}") - # Each worker listens to all queues by priority - queues = [ - await self.channel.declare_queue(self.queues[priority], durable=True) - for priority in [QueuePriority.HIGH, QueuePriority.MEDIUM, QueuePriority.LOW] - ] - while self.running: - for queue in queues: - try: - message = await queue.get(no_ack=False, fail=False) - if message: - try: - item = json.loads(message.body.decode()) - await self._process_item(item, worker_name) - await message.ack() - except 
Exception as e: - logger.error(f"Error processing message: {e}") - await message.nack(requeue=False) - except asyncio.CancelledError: - logger.info(f"Worker {worker_name} cancelled") - return - except Exception as e: - logger.error(f"Worker {worker_name} error: {e}") - await asyncio.sleep(0.1) - - async def _process_item(self, item: Dict[str, Any], worker_name: str): - """Process a queue item""" + + queue = await self.channel.declare_queue( + self.queue_name, + durable=True, + arguments={"x-max-priority": MAX_PRIORITY}, + ) + + try: + async with queue.iterator() as queue_iter: + async for message in queue_iter: + if not self.running: + break + try: + item = json.loads(message.body.decode()) + await self._process_item(item, worker_name) + await message.ack() + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Error processing message: {e}") + await message.nack(requeue=False) + except asyncio.CancelledError: + logger.info(f"Worker {worker_name} cancelled") + except Exception as e: + logger.error(f"Worker {worker_name} error: {e}") + + async def _process_item(self, item: Dict[str, Any], worker_name: str) -> None: + """Process a queue item by message type.""" try: message_data = item["data"] message_type = message_data.get("type", "unknown") @@ -132,7 +151,9 @@ async def _process_item(self, item: Dict[str, Any], worker_name: str): handler = self.handlers.get(message_type) if handler: - logger.debug(f"Worker {worker_name} processing {item['id']} (type: {message_type})") + logger.debug( + f"Worker {worker_name} processing {item['id']} (type: {message_type})" + ) if asyncio.iscoroutinefunction(handler): await handler(message_data) else: