204 lines
7.9 KiB
Python
Executable File
204 lines
7.9 KiB
Python
Executable File
# [DEF:TaskManagerModule:Module]
|
|
# @SEMANTICS: task, manager, lifecycle, execution, state
|
|
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. It uses a thread pool to run plugins asynchronously.
|
|
# @LAYER: Core
|
|
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
|
|
import asyncio
|
|
import uuid
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Dict, Any, List, Optional
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Assuming PluginBase and PluginConfig are defined in plugin_base.py
|
|
# from .plugin_base import PluginBase, PluginConfig # Not needed here, TaskManager interacts with the PluginLoader
|
|
|
|
# [DEF:TaskStatus:Enum]
# @SEMANTICS: task, status, state, enum
# @PURPOSE: Defines the possible states a task can be in during its lifecycle.
class TaskStatus(str, Enum):
    """Lifecycle states of a task.

    Inherits from ``str`` so values compare/serialize as plain strings
    (convenient for JSON responses and Pydantic models).
    """

    PENDING = "PENDING"                    # created, not yet started
    RUNNING = "RUNNING"                    # currently executing
    SUCCESS = "SUCCESS"                    # finished without error
    FAILED = "FAILED"                      # execution raised an exception
    AWAITING_MAPPING = "AWAITING_MAPPING"  # paused, waiting for resolve_task()

# [/DEF]
|
|
|
|
# [DEF:LogEntry:Class]
# @SEMANTICS: log, entry, record, pydantic
# @PURPOSE: A Pydantic model representing a single, structured log entry associated with a task.
class LogEntry(BaseModel):
    """A single structured log record attached to a task."""

    # Creation time, naive UTC.
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
    # datetime.now(timezone.utc) would be the replacement, but it yields
    # tz-aware values and changes serialized output — confirm before switching.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    # Severity label, e.g. "INFO" or "ERROR" (free-form, not validated here).
    level: str
    # Human-readable message text.
    message: str
    # Optional structured metadata (e.g. {"error_type": ...} for failures).
    context: Optional[Dict[str, Any]] = None
# [/DEF]
|
|
|
|
# [DEF:Task:Class]
# @SEMANTICS: task, job, execution, state, pydantic
# @PURPOSE: A Pydantic model representing a single execution instance of a plugin, including its status, parameters, and logs.
class Task(BaseModel):
    """One execution instance of a plugin: status, timing, params and logs."""

    # Unique task identifier (random UUID4, stored as a string).
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # ID of the plugin this task executes (resolved via the PluginLoader).
    plugin_id: str
    # Current lifecycle state; new tasks start as PENDING.
    status: TaskStatus = TaskStatus.PENDING
    # Set when execution starts / finishes; None until then (naive UTC).
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
    # Optional identifier of the user who created the task.
    user_id: Optional[str] = None
    # Structured log entries accumulated during execution.
    logs: List[LogEntry] = Field(default_factory=list)
    # Parameters handed to the plugin's execute() method.
    params: Dict[str, Any] = Field(default_factory=dict)

# [/DEF]
|
|
|
|
# [DEF:TaskManager:Class]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
class TaskManager:
    """
    Manages the lifecycle of tasks, including their creation, execution, and
    state tracking.

    Synchronous plugin ``execute`` methods run in a thread pool so they do not
    block the event loop; asynchronous ones are awaited directly on the
    manager's loop so they can cooperate with :meth:`wait_for_resolution`.
    """

    def __init__(self, plugin_loader):
        # plugin_loader must provide has_plugin(id) and get_plugin(id).
        self.plugin_loader = plugin_loader
        self.tasks: Dict[str, Task] = {}
        # task_id -> list of queues receiving LogEntry objects in real time.
        self.subscribers: Dict[str, List[asyncio.Queue]] = {}
        self.executor = ThreadPoolExecutor(max_workers=5)  # For CPU-bound plugin execution
        # Prefer the running loop: asyncio.get_event_loop() with no running
        # loop has been deprecated since Python 3.10 and is slated for removal.
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            self.loop = asyncio.get_event_loop()
        # Per-task futures used to pause/resume AWAITING_MAPPING tasks.
        self.task_futures: Dict[str, asyncio.Future] = {}

    async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
        """
        Creates and queues a new task for execution.

        Raises:
            ValueError: if the plugin id is unknown or params is not a dict.
        """
        if not self.plugin_loader.has_plugin(plugin_id):
            raise ValueError(f"Plugin with ID '{plugin_id}' not found.")

        # Full param validation against the plugin schema happens at a higher
        # level (e.g. the API route); here we only guard against non-dicts.
        if not isinstance(params, dict):
            raise ValueError("Task parameters must be a dictionary.")

        task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
        self.tasks[task.id] = task
        self.loop.create_task(self._run_task(task.id))  # Schedule task for execution
        return task

    async def _run_task(self, task_id: str):
        """
        Internal method to execute a task and record its outcome.
        """
        task = self.tasks[task_id]
        plugin = self.plugin_loader.get_plugin(task.plugin_id)

        task.status = TaskStatus.RUNNING
        task.started_at = datetime.utcnow()
        self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'")

        try:
            # Pass the task id so the plugin can signal a pause
            # (see wait_for_resolution).
            params = {**task.params, "_task_id": task_id}
            if asyncio.iscoroutinefunction(plugin.execute):
                # Await async plugins on THIS loop. The previous approach —
                # asyncio.run(plugin.execute(...)) inside a worker thread —
                # ran them on a *different* event loop, so awaiting a future
                # created on the manager's loop (wait_for_resolution) could
                # never succeed; pause/resume was broken for async plugins.
                await plugin.execute(params)
            else:
                # Run sync (potentially CPU-bound) plugins in the thread pool
                # so the event loop stays responsive.
                await self.loop.run_in_executor(self.executor, plugin.execute, params)
            task.status = TaskStatus.SUCCESS
            self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'")
        except Exception as e:
            task.status = TaskStatus.FAILED
            self._add_log(task_id, "ERROR", f"Task failed: {e}", {"error_type": type(e).__name__})
        finally:
            task.finished_at = datetime.utcnow()
            # In a real system, you might notify clients via WebSocket here

    async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
        """
        Resumes a task that is awaiting mapping.

        Raises:
            ValueError: if the task is missing or not in AWAITING_MAPPING.
        """
        task = self.tasks.get(task_id)
        if not task or task.status != TaskStatus.AWAITING_MAPPING:
            raise ValueError("Task is not awaiting mapping.")

        # Merge the resolution into the task params and mark it running again.
        task.params.update(resolution_params)
        task.status = TaskStatus.RUNNING
        self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")

        # Wake the paused coroutine. The done() guard prevents an
        # InvalidStateError if the future was already resolved/cancelled.
        future = self.task_futures.get(task_id)
        if future is not None and not future.done():
            future.set_result(True)

    async def wait_for_resolution(self, task_id: str):
        """
        Pauses execution and waits for a resolution signal (resolve_task).
        Unknown task ids are ignored.
        """
        task = self.tasks.get(task_id)
        if not task:
            return

        task.status = TaskStatus.AWAITING_MAPPING
        future = self.loop.create_future()
        self.task_futures[task_id] = future

        try:
            await future
        finally:
            # pop() instead of del: tolerate the entry already being gone.
            self.task_futures.pop(task_id, None)

    def get_task(self, task_id: str) -> Optional[Task]:
        """
        Retrieves a task by its ID, or None if unknown.
        """
        return self.tasks.get(task_id)

    def get_all_tasks(self) -> List[Task]:
        """
        Retrieves all registered tasks.
        """
        return list(self.tasks.values())

    def get_task_logs(self, task_id: str) -> List[LogEntry]:
        """
        Retrieves logs for a specific task (empty list for unknown ids).
        """
        task = self.tasks.get(task_id)
        return task.logs if task else []

    def _add_log(self, task_id: str, level: str, message: str, context: Optional[Dict[str, Any]] = None):
        """
        Adds a log entry to a task and notifies subscribers.
        Safe to call from worker threads (uses call_soon_threadsafe).
        """
        task = self.tasks.get(task_id)
        if not task:
            return

        log_entry = LogEntry(level=level, message=message, context=context)
        task.logs.append(log_entry)

        # Notify subscribers on the event loop thread.
        for queue in self.subscribers.get(task_id, []):
            self.loop.call_soon_threadsafe(queue.put_nowait, log_entry)

    async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
        """
        Subscribes to real-time logs for a task; returns the queue to consume.
        """
        queue = asyncio.Queue()
        self.subscribers.setdefault(task_id, []).append(queue)
        return queue

    def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
        """
        Unsubscribes from real-time logs for a task.
        Idempotent: unknown task ids and already-removed queues are ignored
        (list.remove() would otherwise raise ValueError on a double call).
        """
        queues = self.subscribers.get(task_id)
        if not queues:
            return
        if queue in queues:
            queues.remove(queue)
        if not queues:
            del self.subscribers[task_id]

# [/DEF]