# [DEF:TaskManagerModule:Module]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. It uses a thread pool to run plugins asynchronously.
# @LAYER: Core
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
import asyncio
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

# Assuming PluginBase and PluginConfig are defined in plugin_base.py
# from .plugin_base import PluginBase, PluginConfig # Not needed here, TaskManager interacts with the PluginLoader


# [DEF:TaskStatus:Enum]
# @SEMANTICS: task, status, state, enum
# @PURPOSE: Defines the possible states a task can be in during its lifecycle.
class TaskStatus(str, Enum):
    """Lifecycle states of a task.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"
    AWAITING_MAPPING = "AWAITING_MAPPING"
# [/DEF]


# [DEF:LogEntry:Class]
# @SEMANTICS: log, entry, record, pydantic
# @PURPOSE: A Pydantic model representing a single, structured log entry associated with a task.
class LogEntry(BaseModel):
    """One structured log record attached to a task."""

    # Creation time; defaults to the naive UTC "now" at instantiation.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    # Severity label supplied by the caller (free-form string).
    level: str
    # Human-readable log text.
    message: str
    # Optional extra key/value data accompanying the message.
    context: Optional[Dict[str, Any]] = None
# [/DEF]


# [DEF:Task:Class]
# @SEMANTICS: task, job, execution, state, pydantic
# @PURPOSE: A Pydantic model representing a single execution instance of a plugin, including its status, parameters, and logs.
class Task(BaseModel):
    """A single execution instance of a plugin with its status, parameters and logs."""

    # Unique identifier; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Identifier of the plugin this task executes.
    plugin_id: str
    # Current lifecycle state; new tasks start as PENDING.
    status: TaskStatus = TaskStatus.PENDING
    # Set when execution begins / ends (naive UTC, written by TaskManager).
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
    # Optional identifier of the user who created the task.
    user_id: Optional[str] = None
    # Log entries accumulated while the task runs.
    logs: List[LogEntry] = Field(default_factory=list)
    # Parameters handed to the plugin's ``execute``.
    params: Dict[str, Any] = Field(default_factory=dict)
# [/DEF]


# [DEF:TaskManager:Class]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
class TaskManager:
    """
    Manages the lifecycle of tasks, including their creation, execution, and state tracking.

    Plugins are executed on a thread pool; task state and logs are kept in
    memory and log entries are fanned out to per-task subscriber queues.
    """

    def __init__(self, plugin_loader):
        """
        Create a manager bound to the current event loop.

        Args:
            plugin_loader: Object providing ``has_plugin(plugin_id)`` and
                ``get_plugin(plugin_id)``.
        """
        self.plugin_loader = plugin_loader
        self.tasks: Dict[str, Task] = {}
        self.subscribers: Dict[str, List[asyncio.Queue]] = {}
        self.executor = ThreadPoolExecutor(max_workers=5)  # For CPU-bound plugin execution
        self.loop = asyncio.get_event_loop()
        self.task_futures: Dict[str, asyncio.Future] = {}
        # Strong references to in-flight runner tasks. The event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks: set = set()
    # [/DEF]

    async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
        """
        Creates and queues a new task for execution.

        Args:
            plugin_id: Identifier of the plugin to run.
            params: Parameters passed to the plugin; must be a ``dict``.
            user_id: Optional identifier of the requesting user.

        Returns:
            The newly registered ``Task`` (execution is scheduled, not awaited).

        Raises:
            ValueError: If the plugin is unknown or ``params`` is not a dict.
        """
        if not self.plugin_loader.has_plugin(plugin_id):
            raise ValueError(f"Plugin with ID '{plugin_id}' not found.")

        # Validate params against plugin schema (this will be done at a higher level, e.g., API route)
        # For now, a basic check
        if not isinstance(params, dict):
            raise ValueError("Task parameters must be a dictionary.")

        task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
        self.tasks[task.id] = task

        # Schedule task for execution and keep a reference until it is done.
        runner = self.loop.create_task(self._run_task(task.id))
        self._background_tasks.add(runner)
        runner.add_done_callback(self._background_tasks.discard)
        return task

    @staticmethod
    def _invoke_plugin(plugin, params: Dict[str, Any]):
        """Run ``plugin.execute`` in the calling (worker) thread, driving it with its own loop if it is a coroutine function."""
        if asyncio.iscoroutinefunction(plugin.execute):
            return asyncio.run(plugin.execute(params))
        return plugin.execute(params)

    @staticmethod
    def _release_waiter(future: asyncio.Future) -> None:
        """Complete a resolution future unless it has already finished or been cancelled."""
        if not future.done():
            future.set_result(True)

    async def _run_task(self, task_id: str):
        """
        Internal method to execute a task.

        Marks the task RUNNING, runs the plugin on the thread pool, then
        records SUCCESS or FAILED. ``finished_at`` is always set.
        """
        task = self.tasks[task_id]
        task.status = TaskStatus.RUNNING
        task.started_at = datetime.utcnow()

        try:
            # Looked up inside the try block so a loader failure marks the
            # task FAILED instead of leaving it without a final state.
            plugin = self.plugin_loader.get_plugin(task.plugin_id)
            self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'")

            # Execute plugin in a separate thread to avoid blocking the event loop
            # if the plugin's execute method is synchronous and potentially CPU-bound.
            # Pass task_id to plugin so it can signal pause
            params = {**task.params, "_task_id": task_id}
            await self.loop.run_in_executor(self.executor, self._invoke_plugin, plugin, params)

            task.status = TaskStatus.SUCCESS
            self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'")
        except Exception as e:
            # Top-level boundary for plugin code: record the failure on the task.
            task.status = TaskStatus.FAILED
            self._add_log(task_id, "ERROR", f"Task failed: {e}", {"error_type": type(e).__name__})
        finally:
            task.finished_at = datetime.utcnow()
            # In a real system, you might notify clients via WebSocket here

    async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
        """
        Resumes a task that is awaiting mapping.

        Args:
            task_id: Identifier of the paused task.
            resolution_params: Values merged into the task's ``params``.

        Raises:
            ValueError: If the task is unknown or not in AWAITING_MAPPING.
        """
        task = self.tasks.get(task_id)
        if task is None or task.status != TaskStatus.AWAITING_MAPPING:
            raise ValueError("Task is not awaiting mapping.")

        # Update task params with resolution
        task.params.update(resolution_params)
        task.status = TaskStatus.RUNNING
        self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")

        # Signal the future to continue. The waiter may live on another
        # thread's event loop (async plugins run under asyncio.run in the
        # executor), so completion is scheduled on the future's own loop.
        future = self.task_futures.get(task_id)
        if future is not None:
            future.get_loop().call_soon_threadsafe(self._release_waiter, future)

    async def wait_for_resolution(self, task_id: str):
        """
        Pauses execution and waits for a resolution signal.

        Returns immediately if the task is unknown. Otherwise marks the task
        AWAITING_MAPPING and suspends until ``resolve_task`` signals it.
        """
        task = self.tasks.get(task_id)
        if not task:
            return

        # The future must belong to the loop that awaits it; awaiting a
        # future bound to a different loop raises RuntimeError.
        future = asyncio.get_running_loop().create_future()
        # Register before publishing the status so a resolver that observes
        # AWAITING_MAPPING always finds the future to signal.
        self.task_futures[task_id] = future
        task.status = TaskStatus.AWAITING_MAPPING

        try:
            await future
        finally:
            # Only drop the entry if it is still the future created here.
            if self.task_futures.get(task_id) is future:
                del self.task_futures[task_id]

    def get_task(self, task_id: str) -> Optional[Task]:
        """
        Retrieves a task by its ID, or ``None`` if it is not registered.
        """
        return self.tasks.get(task_id)

    def get_all_tasks(self) -> List[Task]:
        """
        Retrieves all registered tasks.
        """
        return list(self.tasks.values())

    def get_task_logs(self, task_id: str) -> List[LogEntry]:
        """
        Retrieves logs for a specific task (empty list for an unknown task).
        """
        task = self.tasks.get(task_id)
        return task.logs if task else []

    def _add_log(self, task_id: str, level: str, message: str, context: Optional[Dict[str, Any]] = None):
        """
        Adds a log entry to a task and notifies subscribers.

        Does nothing when the task is unknown.
        """
        task = self.tasks.get(task_id)
        if not task:
            return

        log_entry = LogEntry(level=level, message=message, context=context)
        task.logs.append(log_entry)

        # Notify subscribers. Iterate over a snapshot so a concurrent
        # subscribe/unsubscribe cannot disturb the iteration.
        for queue in tuple(self.subscribers.get(task_id, ())):
            self.loop.call_soon_threadsafe(queue.put_nowait, log_entry)

    async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
        """
        Subscribes to real-time logs for a task.

        Returns:
            A new queue that receives every ``LogEntry`` added to the task
            from now on.
        """
        queue = asyncio.Queue()
        self.subscribers.setdefault(task_id, []).append(queue)
        return queue

    def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
        """
        Unsubscribes from real-time logs for a task.

        Unknown task ids and queues that are not subscribed are ignored, so
        the call is safe to repeat.
        """
        queues = self.subscribers.get(task_id)
        if queues is None:
            return
        if queue in queues:
            queues.remove(queue)
        if not queues:
            del self.subscribers[task_id]