TaskManager refactor

This commit is contained in:
2025-12-29 10:13:37 +03:00
parent 6962a78112
commit 4c9d554432
25 changed files with 2778 additions and 283 deletions

View File

@@ -0,0 +1,335 @@
# [DEF:TaskManagerModule:Module]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. It uses a thread pool to run plugins asynchronously.
# @LAYER: Core
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
# @INVARIANT: Task IDs are unique.
# @CONSTRAINT: Must use belief_scope for logging.
# [SECTION: IMPORTS]
import asyncio
from datetime import datetime
from typing import Dict, Any, List, Optional
from concurrent.futures import ThreadPoolExecutor
from .models import Task, TaskStatus, LogEntry
from .persistence import TaskPersistenceService
from ..logger import logger, belief_scope
# [/SECTION]
# [DEF:TaskManager:Class]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
class TaskManager:
"""
Manages the lifecycle of tasks, including their creation, execution, and state tracking.
"""
# [DEF:TaskManager.__init__:Function]
# @PURPOSE: Initialize the TaskManager with dependencies.
# @PRE: plugin_loader is initialized.
# @POST: TaskManager is ready to accept tasks.
# @PARAM: plugin_loader - The plugin loader instance.
def __init__(self, plugin_loader):
with belief_scope("TaskManager.__init__"):
self.plugin_loader = plugin_loader
self.tasks: Dict[str, Task] = {}
self.subscribers: Dict[str, List[asyncio.Queue]] = {}
self.executor = ThreadPoolExecutor(max_workers=5) # For CPU-bound plugin execution
self.persistence_service = TaskPersistenceService()
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.get_event_loop()
self.task_futures: Dict[str, asyncio.Future] = {}
# [/DEF:TaskManager.__init__:Function]
# [DEF:TaskManager.create_task:Function]
# @PURPOSE: Creates and queues a new task for execution.
# @PRE: Plugin with plugin_id exists. Params are valid.
# @POST: Task is created, added to registry, and scheduled for execution.
# @PARAM: plugin_id (str) - The ID of the plugin to run.
# @PARAM: params (Dict[str, Any]) - Parameters for the plugin.
# @PARAM: user_id (Optional[str]) - ID of the user requesting the task.
# @RETURN: Task - The created task instance.
# @THROWS: ValueError if plugin not found or params invalid.
async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
with belief_scope("TaskManager.create_task", f"plugin_id={plugin_id}"):
if not self.plugin_loader.has_plugin(plugin_id):
logger.error(f"Plugin with ID '{plugin_id}' not found.")
raise ValueError(f"Plugin with ID '{plugin_id}' not found.")
plugin = self.plugin_loader.get_plugin(plugin_id)
if not isinstance(params, dict):
logger.error("Task parameters must be a dictionary.")
raise ValueError("Task parameters must be a dictionary.")
task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
self.tasks[task.id] = task
logger.info(f"Task {task.id} created and scheduled for execution")
self.loop.create_task(self._run_task(task.id)) # Schedule task for execution
return task
# [/DEF:TaskManager.create_task:Function]
# [DEF:TaskManager._run_task:Function]
# @PURPOSE: Internal method to execute a task.
# @PRE: Task exists in registry.
# @POST: Task is executed, status updated to SUCCESS or FAILED.
# @PARAM: task_id (str) - The ID of the task to run.
async def _run_task(self, task_id: str):
with belief_scope("TaskManager._run_task", f"task_id={task_id}"):
task = self.tasks[task_id]
plugin = self.plugin_loader.get_plugin(task.plugin_id)
logger.info(f"Starting execution of task {task_id} for plugin '{plugin.name}'")
task.status = TaskStatus.RUNNING
task.started_at = datetime.utcnow()
self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'")
try:
# Execute plugin
params = {**task.params, "_task_id": task_id}
if asyncio.iscoroutinefunction(plugin.execute):
await plugin.execute(params)
else:
await self.loop.run_in_executor(
self.executor,
plugin.execute,
params
)
logger.info(f"Task {task_id} completed successfully")
task.status = TaskStatus.SUCCESS
self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'")
except Exception as e:
logger.error(f"Task {task_id} failed: {e}")
task.status = TaskStatus.FAILED
self._add_log(task_id, "ERROR", f"Task failed: {e}", {"error_type": type(e).__name__})
finally:
task.finished_at = datetime.utcnow()
logger.info(f"Task {task_id} execution finished with status: {task.status}")
# [/DEF:TaskManager._run_task:Function]
# [DEF:TaskManager.resolve_task:Function]
# @PURPOSE: Resumes a task that is awaiting mapping.
# @PRE: Task exists and is in AWAITING_MAPPING state.
# @POST: Task status updated to RUNNING, params updated, execution resumed.
# @PARAM: task_id (str) - The ID of the task.
# @PARAM: resolution_params (Dict[str, Any]) - Params to resolve the wait.
# @THROWS: ValueError if task not found or not awaiting mapping.
async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
with belief_scope("TaskManager.resolve_task", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task or task.status != TaskStatus.AWAITING_MAPPING:
raise ValueError("Task is not awaiting mapping.")
# Update task params with resolution
task.params.update(resolution_params)
task.status = TaskStatus.RUNNING
self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")
# Signal the future to continue
if task_id in self.task_futures:
self.task_futures[task_id].set_result(True)
# [/DEF:TaskManager.resolve_task:Function]
# [DEF:TaskManager.wait_for_resolution:Function]
# @PURPOSE: Pauses execution and waits for a resolution signal.
# @PRE: Task exists.
# @POST: Execution pauses until future is set.
# @PARAM: task_id (str) - The ID of the task.
async def wait_for_resolution(self, task_id: str):
with belief_scope("TaskManager.wait_for_resolution", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task: return
task.status = TaskStatus.AWAITING_MAPPING
self.task_futures[task_id] = self.loop.create_future()
try:
await self.task_futures[task_id]
finally:
if task_id in self.task_futures:
del self.task_futures[task_id]
# [/DEF:TaskManager.wait_for_resolution:Function]
# [DEF:TaskManager.wait_for_input:Function]
# @PURPOSE: Pauses execution and waits for user input.
# @PRE: Task exists.
# @POST: Execution pauses until future is set via resume_task_with_password.
# @PARAM: task_id (str) - The ID of the task.
async def wait_for_input(self, task_id: str):
with belief_scope("TaskManager.wait_for_input", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task: return
# Status is already set to AWAITING_INPUT by await_input()
self.task_futures[task_id] = self.loop.create_future()
try:
await self.task_futures[task_id]
finally:
if task_id in self.task_futures:
del self.task_futures[task_id]
# [/DEF:TaskManager.wait_for_input:Function]
# [DEF:TaskManager.get_task:Function]
# @PURPOSE: Retrieves a task by its ID.
# @PARAM: task_id (str) - ID of the task.
# @RETURN: Optional[Task] - The task or None.
def get_task(self, task_id: str) -> Optional[Task]:
return self.tasks.get(task_id)
# [/DEF:TaskManager.get_task:Function]
# [DEF:TaskManager.get_all_tasks:Function]
# @PURPOSE: Retrieves all registered tasks.
# @RETURN: List[Task] - All tasks.
def get_all_tasks(self) -> List[Task]:
return list(self.tasks.values())
# [/DEF:TaskManager.get_all_tasks:Function]
# [DEF:TaskManager.get_tasks:Function]
# @PURPOSE: Retrieves tasks with pagination and optional status filter.
# @PRE: limit and offset are non-negative integers.
# @POST: Returns a list of tasks sorted by start_time descending.
# @PARAM: limit (int) - Maximum number of tasks to return.
# @PARAM: offset (int) - Number of tasks to skip.
# @PARAM: status (Optional[TaskStatus]) - Filter by task status.
# @RETURN: List[Task] - List of tasks matching criteria.
def get_tasks(self, limit: int = 10, offset: int = 0, status: Optional[TaskStatus] = None) -> List[Task]:
tasks = list(self.tasks.values())
if status:
tasks = [t for t in tasks if t.status == status]
# Sort by start_time descending (most recent first)
tasks.sort(key=lambda t: t.started_at or datetime.min, reverse=True)
return tasks[offset:offset + limit]
# [/DEF:TaskManager.get_tasks:Function]
# [DEF:TaskManager.get_task_logs:Function]
# @PURPOSE: Retrieves logs for a specific task.
# @PARAM: task_id (str) - ID of the task.
# @RETURN: List[LogEntry] - List of log entries.
def get_task_logs(self, task_id: str) -> List[LogEntry]:
task = self.tasks.get(task_id)
return task.logs if task else []
# [/DEF:TaskManager.get_task_logs:Function]
# [DEF:TaskManager._add_log:Function]
# @PURPOSE: Adds a log entry to a task and notifies subscribers.
# @PRE: Task exists.
# @POST: Log added to task and pushed to queues.
# @PARAM: task_id (str) - ID of the task.
# @PARAM: level (str) - Log level.
# @PARAM: message (str) - Log message.
# @PARAM: context (Optional[Dict]) - Log context.
def _add_log(self, task_id: str, level: str, message: str, context: Optional[Dict[str, Any]] = None):
task = self.tasks.get(task_id)
if not task:
return
log_entry = LogEntry(level=level, message=message, context=context)
task.logs.append(log_entry)
# Notify subscribers
if task_id in self.subscribers:
for queue in self.subscribers[task_id]:
self.loop.call_soon_threadsafe(queue.put_nowait, log_entry)
# [/DEF:TaskManager._add_log:Function]
# [DEF:TaskManager.subscribe_logs:Function]
# @PURPOSE: Subscribes to real-time logs for a task.
# @PARAM: task_id (str) - ID of the task.
# @RETURN: asyncio.Queue - Queue for log entries.
async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
queue = asyncio.Queue()
if task_id not in self.subscribers:
self.subscribers[task_id] = []
self.subscribers[task_id].append(queue)
return queue
# [/DEF:TaskManager.subscribe_logs:Function]
# [DEF:TaskManager.unsubscribe_logs:Function]
# @PURPOSE: Unsubscribes from real-time logs for a task.
# @PARAM: task_id (str) - ID of the task.
# @PARAM: queue (asyncio.Queue) - Queue to remove.
def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
if task_id in self.subscribers:
if queue in self.subscribers[task_id]:
self.subscribers[task_id].remove(queue)
if not self.subscribers[task_id]:
del self.subscribers[task_id]
# [/DEF:TaskManager.unsubscribe_logs:Function]
# [DEF:TaskManager.persist_awaiting_input_tasks:Function]
# @PURPOSE: Persist tasks in AWAITING_INPUT state using persistence service.
def persist_awaiting_input_tasks(self) -> None:
self.persistence_service.persist_tasks(list(self.tasks.values()))
# [/DEF:TaskManager.persist_awaiting_input_tasks:Function]
# [DEF:TaskManager.load_persisted_tasks:Function]
# @PURPOSE: Load persisted tasks using persistence service.
def load_persisted_tasks(self) -> None:
loaded_tasks = self.persistence_service.load_tasks()
for task in loaded_tasks:
if task.id not in self.tasks:
self.tasks[task.id] = task
# [/DEF:TaskManager.load_persisted_tasks:Function]
# [DEF:TaskManager.await_input:Function]
# @PURPOSE: Transition a task to AWAITING_INPUT state with input request.
# @PRE: Task exists and is in RUNNING state.
# @POST: Task status changed to AWAITING_INPUT, input_request set, persisted.
# @PARAM: task_id (str) - ID of the task.
# @PARAM: input_request (Dict) - Details about required input.
# @THROWS: ValueError if task not found or not RUNNING.
def await_input(self, task_id: str, input_request: Dict[str, Any]) -> None:
with belief_scope("TaskManager.await_input", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
if task.status != TaskStatus.RUNNING:
raise ValueError(f"Task {task_id} is not RUNNING (current: {task.status})")
task.status = TaskStatus.AWAITING_INPUT
task.input_required = True
task.input_request = input_request
self._add_log(task_id, "INFO", "Task paused for user input", {"input_request": input_request})
self.persist_awaiting_input_tasks()
# [/DEF:TaskManager.await_input:Function]
# [DEF:TaskManager.resume_task_with_password:Function]
# @PURPOSE: Resume a task that is awaiting input with provided passwords.
# @PRE: Task exists and is in AWAITING_INPUT state.
# @POST: Task status changed to RUNNING, passwords injected, task resumed.
# @PARAM: task_id (str) - ID of the task.
# @PARAM: passwords (Dict[str, str]) - Mapping of database name to password.
# @THROWS: ValueError if task not found, not awaiting input, or passwords invalid.
def resume_task_with_password(self, task_id: str, passwords: Dict[str, str]) -> None:
with belief_scope("TaskManager.resume_task_with_password", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
if task.status != TaskStatus.AWAITING_INPUT:
raise ValueError(f"Task {task_id} is not AWAITING_INPUT (current: {task.status})")
if not isinstance(passwords, dict) or not passwords:
raise ValueError("Passwords must be a non-empty dictionary")
task.params["passwords"] = passwords
task.input_required = False
task.input_request = None
task.status = TaskStatus.RUNNING
self._add_log(task_id, "INFO", "Task resumed with passwords", {"databases": list(passwords.keys())})
if task_id in self.task_futures:
self.task_futures[task_id].set_result(True)
self.persist_awaiting_input_tasks()
# [/DEF:TaskManager.resume_task_with_password:Function]
# [/DEF:TaskManager:Class]
# [/DEF:TaskManagerModule:Module]