398 lines
19 KiB
Python
398 lines
19 KiB
Python
# [DEF:TaskManagerModule:Module]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. Coroutine plugins are awaited on the event loop; synchronous plugins run in a thread pool.
# @LAYER: Core
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
# @INVARIANT: Task IDs are unique.
# @CONSTRAINT: Must use belief_scope for logging.
|
|
|
|
# [SECTION: IMPORTS]
|
|
import asyncio
|
|
from datetime import datetime
|
|
from typing import Dict, Any, List, Optional
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from .models import Task, TaskStatus, LogEntry
|
|
from .persistence import TaskPersistenceService
|
|
from ..logger import logger, belief_scope
|
|
# [/SECTION]
|
|
|
|
# [DEF:TaskManager:Class]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
class TaskManager:
    """
    Manages the lifecycle of tasks, including their creation, execution, and state tracking.

    Tasks are kept in an in-memory registry (``self.tasks``) and written through
    ``TaskPersistenceService`` whenever their state or logs change. Coroutine
    plugins are awaited on the event loop; synchronous plugins run in a thread
    pool. Log entries are fanned out to per-task ``asyncio.Queue`` subscribers.
    """

    # [DEF:__init__:Function]
    # @PURPOSE: Initialize the TaskManager with dependencies.
    # @PRE: plugin_loader is initialized.
    # @POST: TaskManager is ready to accept tasks.
    # @PARAM: plugin_loader - The plugin loader instance.
    def __init__(self, plugin_loader):
        with belief_scope("TaskManager.__init__"):
            self.plugin_loader = plugin_loader
            self.tasks: Dict[str, Task] = {}
            self.subscribers: Dict[str, List[asyncio.Queue]] = {}
            self.executor = ThreadPoolExecutor(max_workers=5)  # For CPU-bound plugin execution
            self.persistence_service = TaskPersistenceService()

            # Best-effort initial loop. The loop that later runs the application
            # may not exist yet, so every async entry point re-binds self.loop to
            # asyncio.get_running_loop() before using it.
            try:
                self.loop = asyncio.get_running_loop()
            except RuntimeError:
                try:
                    self.loop = asyncio.get_event_loop()
                except RuntimeError:
                    # Newer Python versions raise here instead of creating a loop.
                    self.loop = asyncio.new_event_loop()
            self.task_futures: Dict[str, asyncio.Future] = {}
            # Strong references to scheduled _run_task asyncio.Tasks. The event loop
            # only keeps weak references, so unreferenced tasks may be collected
            # mid-execution. Entries drop themselves via a done-callback.
            self._background_tasks = set()

            # Load persisted tasks on startup
            self.load_persisted_tasks()
    # [/DEF:__init__:Function]

    # [DEF:create_task:Function]
    # @PURPOSE: Creates and queues a new task for execution.
    # @PRE: Plugin with plugin_id exists. Params are valid.
    # @POST: Task is created, added to registry, persisted, and scheduled on the running event loop.
    # @PARAM: plugin_id (str) - The ID of the plugin to run.
    # @PARAM: params (Dict[str, Any]) - Parameters for the plugin.
    # @PARAM: user_id (Optional[str]) - ID of the user requesting the task.
    # @RETURN: Task - The created task instance.
    # @THROWS: ValueError if plugin not found or params invalid.
    async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
        with belief_scope("TaskManager.create_task", f"plugin_id={plugin_id}"):
            if not self.plugin_loader.has_plugin(plugin_id):
                logger.error(f"Plugin with ID '{plugin_id}' not found.")
                raise ValueError(f"Plugin with ID '{plugin_id}' not found.")

            if not isinstance(params, dict):
                logger.error("Task parameters must be a dictionary.")
                raise ValueError("Task parameters must be a dictionary.")

            task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
            self.tasks[task.id] = task
            self.persistence_service.persist_task(task)
            logger.info(f"Task {task.id} created and scheduled for execution")

            # Schedule on the loop that is actually running this coroutine; the
            # loop captured in __init__ may be a different, never-started one.
            self.loop = asyncio.get_running_loop()
            runner = self.loop.create_task(self._run_task(task.id))
            self._background_tasks.add(runner)
            runner.add_done_callback(self._background_tasks.discard)
            return task
    # [/DEF:create_task:Function]

    # [DEF:_run_task:Function]
    # @PURPOSE: Internal method to execute a task.
    # @PRE: None (a task removed from the registry before it starts is skipped).
    # @POST: Task is executed, status updated to SUCCESS or FAILED (FAILED also on cancellation).
    # @PARAM: task_id (str) - The ID of the task to run.
    async def _run_task(self, task_id: str):
        with belief_scope("TaskManager._run_task", f"task_id={task_id}"):
            self.loop = asyncio.get_running_loop()
            task = self.tasks.get(task_id)
            if task is None:
                # e.g. clear_tasks() removed the task before it got to run.
                logger.info(f"Task {task_id} was removed before execution; skipping")
                return

            try:
                plugin = self.plugin_loader.get_plugin(task.plugin_id)
                logger.info(f"Starting execution of task {task_id} for plugin '{plugin.name}'")
                task.status = TaskStatus.RUNNING
                # Naive UTC timestamps, kept naive because get_tasks compares them
                # against datetime.min.
                task.started_at = datetime.utcnow()
                self.persistence_service.persist_task(task)
                self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'")

                # The plugin receives its own task id so it can call back into the
                # manager (e.g. await_input / wait_for_input).
                params = {**task.params, "_task_id": task_id}

                if asyncio.iscoroutinefunction(plugin.execute):
                    task.result = await plugin.execute(params)
                else:
                    task.result = await self.loop.run_in_executor(
                        self.executor,
                        plugin.execute,
                        params,
                    )

                logger.info(f"Task {task_id} completed successfully")
                task.status = TaskStatus.SUCCESS
                self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'")
            except asyncio.CancelledError:
                # CancelledError is not an Exception subclass; without this branch
                # a cancelled task would stay RUNNING forever.
                logger.error(f"Task {task_id} was cancelled")
                task.status = TaskStatus.FAILED
                self._add_log(task_id, "ERROR", "Task cancelled", {"error_type": "CancelledError"})
                raise
            except Exception as e:
                logger.error(f"Task {task_id} failed: {e}")
                task.status = TaskStatus.FAILED
                self._add_log(task_id, "ERROR", f"Task failed: {e}", {"error_type": type(e).__name__})
            finally:
                task.finished_at = datetime.utcnow()
                # Do not resurrect a task that clear_tasks() deleted while it ran.
                if self.tasks.get(task_id) is task:
                    self.persistence_service.persist_task(task)
                logger.info(f"Task {task_id} execution finished with status: {task.status}")
    # [/DEF:_run_task:Function]

    # [DEF:_release_waiter:Function]
    # @PURPOSE: Completes the pending wait future of a task (if any), safely from any thread.
    # @PRE: None.
    # @POST: The task's future, if present and not yet done, receives result True on its own loop.
    # @PARAM: task_id (str) - The ID of the task whose waiter should resume.
    def _release_waiter(self, task_id: str) -> None:
        future = self.task_futures.get(task_id)
        if future is None:
            return

        def _complete() -> None:
            # The future may already be cancelled (clear_tasks) or resolved.
            if not future.done():
                future.set_result(True)

        owner_loop = future.get_loop()
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is owner_loop:
            _complete()
        else:
            # asyncio futures are not thread-safe; hand completion to their loop.
            owner_loop.call_soon_threadsafe(_complete)
    # [/DEF:_release_waiter:Function]

    # [DEF:resolve_task:Function]
    # @PURPOSE: Resumes a task that is awaiting mapping.
    # @PRE: Task exists and is in AWAITING_MAPPING state.
    # @POST: Task status updated to RUNNING, params updated, execution resumed.
    # @PARAM: task_id (str) - The ID of the task.
    # @PARAM: resolution_params (Dict[str, Any]) - Params to resolve the wait.
    # @THROWS: ValueError if task not found or not awaiting mapping.
    async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
        with belief_scope("TaskManager.resolve_task", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if task is None or task.status != TaskStatus.AWAITING_MAPPING:
                raise ValueError("Task is not awaiting mapping.")

            # Update task params with resolution
            task.params.update(resolution_params)
            task.status = TaskStatus.RUNNING
            self.persistence_service.persist_task(task)
            self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")

            # Wake the coroutine blocked in wait_for_resolution().
            self._release_waiter(task_id)
    # [/DEF:resolve_task:Function]

    # [DEF:wait_for_resolution:Function]
    # @PURPOSE: Pauses execution and waits for a resolution signal.
    # @PRE: Task exists (returns immediately otherwise).
    # @POST: Task is AWAITING_MAPPING and execution pauses until its future is set.
    # @PARAM: task_id (str) - The ID of the task.
    async def wait_for_resolution(self, task_id: str):
        with belief_scope("TaskManager.wait_for_resolution", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if task is None:
                return

            task.status = TaskStatus.AWAITING_MAPPING
            self.persistence_service.persist_task(task)
            self.loop = asyncio.get_running_loop()
            future = self.loop.create_future()
            self.task_futures[task_id] = future

            try:
                await future
            finally:
                # Only drop our own future; clear_tasks() may already have removed it.
                if self.task_futures.get(task_id) is future:
                    del self.task_futures[task_id]
    # [/DEF:wait_for_resolution:Function]

    # [DEF:wait_for_input:Function]
    # @PURPOSE: Pauses execution and waits for user input.
    # @PRE: Task exists (returns immediately otherwise).
    # @POST: Execution pauses until future is set via resume_task_with_password.
    # @PARAM: task_id (str) - The ID of the task.
    async def wait_for_input(self, task_id: str):
        with belief_scope("TaskManager.wait_for_input", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if task is None:
                return

            # Status is already set to AWAITING_INPUT by await_input().
            # NOTE(review): if resume_task_with_password runs (e.g. from another
            # thread) between await_input() and this call, no future exists yet
            # and this wait never completes — confirm callers cannot interleave.
            self.loop = asyncio.get_running_loop()
            future = self.loop.create_future()
            self.task_futures[task_id] = future

            try:
                await future
            finally:
                # Only drop our own future; clear_tasks() may already have removed it.
                if self.task_futures.get(task_id) is future:
                    del self.task_futures[task_id]
    # [/DEF:wait_for_input:Function]

    # [DEF:get_task:Function]
    # @PURPOSE: Retrieves a task by its ID.
    # @PRE: task_id is a string.
    # @POST: Returns Task object or None.
    # @PARAM: task_id (str) - ID of the task.
    # @RETURN: Optional[Task] - The task or None.
    def get_task(self, task_id: str) -> Optional[Task]:
        with belief_scope("TaskManager.get_task", f"task_id={task_id}"):
            return self.tasks.get(task_id)
    # [/DEF:get_task:Function]

    # [DEF:get_all_tasks:Function]
    # @PURPOSE: Retrieves all registered tasks.
    # @PRE: None.
    # @POST: Returns list of all Task objects.
    # @RETURN: List[Task] - All tasks.
    def get_all_tasks(self) -> List[Task]:
        with belief_scope("TaskManager.get_all_tasks"):
            return list(self.tasks.values())
    # [/DEF:get_all_tasks:Function]

    # [DEF:get_tasks:Function]
    # @PURPOSE: Retrieves tasks with pagination and optional status filter.
    # @PRE: limit and offset are non-negative integers.
    # @POST: Returns tasks sorted by started_at descending; never-started tasks sort last.
    # @PARAM: limit (int) - Maximum number of tasks to return.
    # @PARAM: offset (int) - Number of tasks to skip.
    # @PARAM: status (Optional[TaskStatus]) - Filter by task status.
    # @RETURN: List[Task] - List of tasks matching criteria.
    def get_tasks(self, limit: int = 10, offset: int = 0, status: Optional[TaskStatus] = None) -> List[Task]:
        with belief_scope("TaskManager.get_tasks"):
            tasks = list(self.tasks.values())
            if status is not None:
                tasks = [t for t in tasks if t.status == status]
            # Most recently started first; tasks without started_at go to the end.
            tasks.sort(key=lambda t: t.started_at or datetime.min, reverse=True)
            return tasks[offset:offset + limit]
    # [/DEF:get_tasks:Function]

    # [DEF:get_task_logs:Function]
    # @PURPOSE: Retrieves logs for a specific task.
    # @PRE: task_id is a string.
    # @POST: Returns list of LogEntry objects (empty if the task is unknown).
    # @PARAM: task_id (str) - ID of the task.
    # @RETURN: List[LogEntry] - List of log entries.
    def get_task_logs(self, task_id: str) -> List[LogEntry]:
        with belief_scope("TaskManager.get_task_logs", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            return task.logs if task else []
    # [/DEF:get_task_logs:Function]

    # [DEF:_add_log:Function]
    # @PURPOSE: Adds a log entry to a task and notifies subscribers.
    # @PRE: None (unknown task ids are ignored).
    # @POST: Log added to task, task persisted, entry queued for every subscriber.
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: level (str) - Log level.
    # @PARAM: message (str) - Log message.
    # @PARAM: context (Optional[Dict]) - Log context.
    def _add_log(self, task_id: str, level: str, message: str, context: Optional[Dict[str, Any]] = None):
        with belief_scope("TaskManager._add_log", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task:
                return

            log_entry = LogEntry(level=level, message=message, context=context)
            task.logs.append(log_entry)
            self.persistence_service.persist_task(task)

            # Notify subscribers. call_soon_threadsafe lets this run from executor
            # threads; iterate a snapshot in case the list changes meanwhile.
            for queue in list(self.subscribers.get(task_id, ())):
                self.loop.call_soon_threadsafe(queue.put_nowait, log_entry)
    # [/DEF:_add_log:Function]

    # [DEF:subscribe_logs:Function]
    # @PURPOSE: Subscribes to real-time logs for a task.
    # @PRE: task_id is a string.
    # @POST: Returns an asyncio.Queue for log entries; self.loop bound to the running loop.
    # @PARAM: task_id (str) - ID of the task.
    # @RETURN: asyncio.Queue - Queue for log entries.
    async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
        with belief_scope("TaskManager.subscribe_logs", f"task_id={task_id}"):
            # _add_log delivers via self.loop, which must be the loop owning the queue.
            self.loop = asyncio.get_running_loop()
            queue = asyncio.Queue()
            self.subscribers.setdefault(task_id, []).append(queue)
            return queue
    # [/DEF:subscribe_logs:Function]

    # [DEF:unsubscribe_logs:Function]
    # @PURPOSE: Unsubscribes from real-time logs for a task.
    # @PRE: task_id is a string, queue is asyncio.Queue.
    # @POST: Queue removed from subscribers; empty subscriber lists are dropped.
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: queue (asyncio.Queue) - Queue to remove.
    def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
        with belief_scope("TaskManager.unsubscribe_logs", f"task_id={task_id}"):
            queues = self.subscribers.get(task_id)
            if queues is None:
                return
            if queue in queues:
                queues.remove(queue)
            if not queues:
                del self.subscribers[task_id]
    # [/DEF:unsubscribe_logs:Function]

    # [DEF:load_persisted_tasks:Function]
    # @PURPOSE: Load persisted tasks using persistence service.
    # @PRE: None.
    # @POST: Up to 100 persisted tasks loaded into self.tasks; in-memory entries win on id clashes.
    def load_persisted_tasks(self) -> None:
        with belief_scope("TaskManager.load_persisted_tasks"):
            loaded_tasks = self.persistence_service.load_tasks(limit=100)
            for task in loaded_tasks:
                if task.id not in self.tasks:
                    self.tasks[task.id] = task
    # [/DEF:load_persisted_tasks:Function]

    # [DEF:await_input:Function]
    # @PURPOSE: Transition a task to AWAITING_INPUT state with input request.
    # @PRE: Task exists and is in RUNNING state.
    # @POST: Task status changed to AWAITING_INPUT, input_request set, persisted.
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: input_request (Dict) - Details about required input.
    # @THROWS: ValueError if task not found or not RUNNING.
    def await_input(self, task_id: str, input_request: Dict[str, Any]) -> None:
        with belief_scope("TaskManager.await_input", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")
            if task.status != TaskStatus.RUNNING:
                raise ValueError(f"Task {task_id} is not RUNNING (current: {task.status})")

            task.status = TaskStatus.AWAITING_INPUT
            task.input_required = True
            task.input_request = input_request
            self.persistence_service.persist_task(task)
            self._add_log(task_id, "INFO", "Task paused for user input", {"input_request": input_request})
    # [/DEF:await_input:Function]

    # [DEF:resume_task_with_password:Function]
    # @PURPOSE: Resume a task that is awaiting input with provided passwords.
    # @PRE: Task exists and is in AWAITING_INPUT state.
    # @POST: Task status changed to RUNNING, passwords injected, waiter released (thread-safe).
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: passwords (Dict[str, str]) - Mapping of database name to password.
    # @THROWS: ValueError if task not found, not awaiting input, or passwords invalid.
    def resume_task_with_password(self, task_id: str, passwords: Dict[str, str]) -> None:
        with belief_scope("TaskManager.resume_task_with_password", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")
            if task.status != TaskStatus.AWAITING_INPUT:
                raise ValueError(f"Task {task_id} is not AWAITING_INPUT (current: {task.status})")

            if not isinstance(passwords, dict) or not passwords:
                raise ValueError("Passwords must be a non-empty dictionary")

            task.params["passwords"] = passwords
            task.input_required = False
            task.input_request = None
            task.status = TaskStatus.RUNNING
            self.persistence_service.persist_task(task)
            # Only database names are logged, never the passwords themselves.
            self._add_log(task_id, "INFO", "Task resumed with passwords", {"databases": list(passwords.keys())})

            # This method is sync and may run off the event-loop thread.
            self._release_waiter(task_id)
    # [/DEF:resume_task_with_password:Function]

    # [DEF:clear_tasks:Function]
    # @PURPOSE: Clears tasks based on status filter.
    # @PRE: status is Optional[TaskStatus].
    # @POST: Tasks matching filter (or all non-active) cleared from registry and database.
    # @PARAM: status (Optional[TaskStatus]) - Filter by task status.
    # @RETURN: int - Number of tasks cleared.
    def clear_tasks(self, status: Optional[TaskStatus] = None) -> int:
        with belief_scope("TaskManager.clear_tasks"):
            # Without a filter, keep everything still executing or paused.
            active_statuses = (TaskStatus.RUNNING, TaskStatus.AWAITING_INPUT, TaskStatus.AWAITING_MAPPING)

            if status is not None:
                tasks_to_remove = [tid for tid, t in self.tasks.items() if t.status == status]
            else:
                tasks_to_remove = [tid for tid, t in self.tasks.items() if t.status not in active_statuses]

            for tid in tasks_to_remove:
                # Cancel a pending wait future (AWAITING_INPUT / AWAITING_MAPPING).
                future = self.task_futures.pop(tid, None)
                if future is not None:
                    future.cancel()
                del self.tasks[tid]

            if tasks_to_remove:
                self.persistence_service.delete_tasks(tasks_to_remove)

            logger.info(f"Cleared {len(tasks_to_remove)} tasks.")
            return len(tasks_to_remove)
    # [/DEF:clear_tasks:Function]

# [/DEF:TaskManager:Class]
|
|
# [/DEF:TaskManagerModule:Module]