TaskManager refactor
This commit is contained in:
335
backend/src/core/task_manager/manager.py
Normal file
335
backend/src/core/task_manager/manager.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# [DEF:TaskManagerModule:Module]
|
||||
# @SEMANTICS: task, manager, lifecycle, execution, state
|
||||
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. It uses a thread pool to run plugins asynchronously.
|
||||
# @LAYER: Core
|
||||
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
|
||||
# @INVARIANT: Task IDs are unique.
|
||||
# @CONSTRAINT: Must use belief_scope for logging.
|
||||
|
||||
# [SECTION: IMPORTS]
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from .models import Task, TaskStatus, LogEntry
|
||||
from .persistence import TaskPersistenceService
|
||||
from ..logger import logger, belief_scope
|
||||
# [/SECTION]
|
||||
|
||||
# [DEF:TaskManager:Class]
|
||||
# @SEMANTICS: task, manager, lifecycle, execution, state
|
||||
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
|
||||
class TaskManager:
|
||||
"""
|
||||
Manages the lifecycle of tasks, including their creation, execution, and state tracking.
|
||||
"""
|
||||
|
||||
# [DEF:TaskManager.__init__:Function]
|
||||
# @PURPOSE: Initialize the TaskManager with dependencies.
|
||||
# @PRE: plugin_loader is initialized.
|
||||
# @POST: TaskManager is ready to accept tasks.
|
||||
# @PARAM: plugin_loader - The plugin loader instance.
|
||||
def __init__(self, plugin_loader, max_workers: int = 5):
    """Set up the task registry, log subscribers, worker pool and event loop.

    Args:
        plugin_loader: Loader used to look up plugin instances by ID.
        max_workers: Size of the thread pool that runs synchronous
            plugins. Defaults to 5, the value that used to be hard-coded.
    """
    with belief_scope("TaskManager.__init__"):
        self.plugin_loader = plugin_loader
        # Registry of every known task, keyed by task ID.
        self.tasks: Dict[str, Task] = {}
        # Per-task list of queues that receive each new log entry.
        self.subscribers: Dict[str, List[asyncio.Queue]] = {}
        # Synchronous plugin.execute callables are handed to this pool so
        # that they do not block the event loop while they run.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.persistence_service = TaskPersistenceService()

        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread yet.
            # NOTE(review): calling get_event_loop() without a running loop
            # is deprecated in recent Python releases, and the loop obtained
            # here may differ from the one the application later runs --
            # confirm the manager is always built inside a running loop.
            self.loop = asyncio.get_event_loop()

        # Futures that paused tasks are awaiting, keyed by task ID.
        self.task_futures: Dict[str, asyncio.Future] = {}
|
||||
# [/DEF:TaskManager.__init__:Function]
|
||||
|
||||
# [DEF:TaskManager.create_task:Function]
|
||||
# @PURPOSE: Creates and queues a new task for execution.
|
||||
# @PRE: Plugin with plugin_id exists. Params are valid.
|
||||
# @POST: Task is created, added to registry, and scheduled for execution.
|
||||
# @PARAM: plugin_id (str) - The ID of the plugin to run.
|
||||
# @PARAM: params (Dict[str, Any]) - Parameters for the plugin.
|
||||
# @PARAM: user_id (Optional[str]) - ID of the user requesting the task.
|
||||
# @RETURN: Task - The created task instance.
|
||||
# @THROWS: ValueError if plugin not found or params invalid.
|
||||
async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
    """Register a new task and schedule it for execution on the manager's loop.

    Args:
        plugin_id: ID of the plugin the task should run.
        params: Parameters handed to the plugin; must be a dict.
        user_id: Optional ID of the user requesting the task.

    Returns:
        The newly created task, already stored in the registry.

    Raises:
        ValueError: If the plugin is unknown or ``params`` is not a dict.
    """
    with belief_scope("TaskManager.create_task", f"plugin_id={plugin_id}"):
        if not self.plugin_loader.has_plugin(plugin_id):
            logger.error(f"Plugin with ID '{plugin_id}' not found.")
            raise ValueError(f"Plugin with ID '{plugin_id}' not found.")

        if not isinstance(params, dict):
            logger.error("Task parameters must be a dictionary.")
            raise ValueError("Task parameters must be a dictionary.")

        task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
        self.tasks[task.id] = task
        logger.info(f"Task {task.id} created and scheduled for execution")

        runner = self.loop.create_task(self._run_task(task.id))
        # The event loop only keeps weak references to its tasks, so hold a
        # strong reference until the runner finishes; otherwise it could be
        # garbage-collected in the middle of execution.
        background = getattr(self, "_background_tasks", None)
        if background is None:
            background = self._background_tasks = set()
        background.add(runner)
        runner.add_done_callback(background.discard)
        return task
|
||||
# [/DEF:TaskManager.create_task:Function]
|
||||
|
||||
# [DEF:TaskManager._run_task:Function]
|
||||
# @PURPOSE: Internal method to execute a task.
|
||||
# @PRE: Task exists in registry.
|
||||
# @POST: Task is executed, status updated to SUCCESS or FAILED.
|
||||
# @PARAM: task_id (str) - The ID of the task to run.
|
||||
async def _run_task(self, task_id: str):
    """Execute the plugin behind a task and record the outcome on the task.

    The task is marked RUNNING, the plugin's ``execute`` is awaited directly
    when it is a coroutine function or run in the thread pool otherwise, and
    the task ends as SUCCESS or FAILED. ``finished_at`` is always set.

    Args:
        task_id: ID of a task that is already present in the registry.

    Raises:
        KeyError: If ``task_id`` is not in the registry.
        asyncio.CancelledError: Re-raised after the task is marked FAILED.
    """
    with belief_scope("TaskManager._run_task", f"task_id={task_id}"):
        task = self.tasks[task_id]
        task.status = TaskStatus.RUNNING
        # Kept naive on purpose: get_tasks() sorts started_at against the
        # naive datetime.min, and naive/aware values cannot be compared.
        task.started_at = datetime.utcnow()

        try:
            # Looked up inside the try block so that a failing lookup marks
            # the task FAILED instead of leaving it stuck.
            plugin = self.plugin_loader.get_plugin(task.plugin_id)
            logger.info(f"Starting execution of task {task_id} for plugin '{plugin.name}'")
            self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'")

            # The plugin receives a copy of the params plus its own task ID.
            params = {**task.params, "_task_id": task_id}

            if asyncio.iscoroutinefunction(plugin.execute):
                await plugin.execute(params)
            else:
                await self.loop.run_in_executor(self.executor, plugin.execute, params)

            logger.info(f"Task {task_id} completed successfully")
            task.status = TaskStatus.SUCCESS
            self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'")
        except asyncio.CancelledError:
            # Cancellation is not an Exception subclass on current Python,
            # so without this branch a cancelled task would stay RUNNING.
            logger.error(f"Task {task_id} was cancelled")
            task.status = TaskStatus.FAILED
            self._add_log(task_id, "ERROR", "Task was cancelled", {"error_type": "CancelledError"})
            raise
        except Exception as e:
            logger.error(f"Task {task_id} failed: {e}")
            task.status = TaskStatus.FAILED
            self._add_log(task_id, "ERROR", f"Task failed: {e}", {"error_type": type(e).__name__})
        finally:
            task.finished_at = datetime.utcnow()
            logger.info(f"Task {task_id} execution finished with status: {task.status}")
|
||||
# [/DEF:TaskManager._run_task:Function]
|
||||
|
||||
# [DEF:TaskManager.resolve_task:Function]
|
||||
# @PURPOSE: Resumes a task that is awaiting mapping.
|
||||
# @PRE: Task exists and is in AWAITING_MAPPING state.
|
||||
# @POST: Task status updated to RUNNING, params updated, execution resumed.
|
||||
# @PARAM: task_id (str) - The ID of the task.
|
||||
# @PARAM: resolution_params (Dict[str, Any]) - Params to resolve the wait.
|
||||
# @THROWS: ValueError if task not found or not awaiting mapping.
|
||||
async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
    """Resume a task that is paused in the AWAITING_MAPPING state.

    Args:
        task_id: ID of the task to resume.
        resolution_params: Values merged into the task's params.

    Raises:
        ValueError: If the task is unknown or is not awaiting mapping.
    """
    with belief_scope("TaskManager.resolve_task", f"task_id={task_id}"):
        paused = self.tasks.get(task_id)
        if not (paused and paused.status == TaskStatus.AWAITING_MAPPING):
            raise ValueError("Task is not awaiting mapping.")

        # Merge the resolution into the params before flipping the status.
        paused.params.update(resolution_params)
        paused.status = TaskStatus.RUNNING
        self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")

        # Wake up whoever is awaiting this task's future, if anyone.
        gate = self.task_futures.get(task_id)
        if gate is not None:
            gate.set_result(True)
|
||||
# [/DEF:TaskManager.resolve_task:Function]
|
||||
|
||||
# [DEF:TaskManager.wait_for_resolution:Function]
|
||||
# @PURPOSE: Pauses execution and waits for a resolution signal.
|
||||
# @PRE: Task exists.
|
||||
# @POST: Execution pauses until future is set.
|
||||
# @PARAM: task_id (str) - The ID of the task.
|
||||
async def wait_for_resolution(self, task_id: str):
    """Mark a task AWAITING_MAPPING and block until its future is resolved.

    Returns immediately when the task is unknown. The future is removed
    from ``task_futures`` once the wait ends, however it ends.

    Args:
        task_id: ID of the task to pause.
    """
    with belief_scope("TaskManager.wait_for_resolution", f"task_id={task_id}"):
        task = self.tasks.get(task_id)
        if not task:
            return

        task.status = TaskStatus.AWAITING_MAPPING
        gate = self.loop.create_future()
        self.task_futures[task_id] = gate

        try:
            await gate
        finally:
            # Drop the entry even if the wait was interrupted.
            self.task_futures.pop(task_id, None)
|
||||
# [/DEF:TaskManager.wait_for_resolution:Function]
|
||||
|
||||
# [DEF:TaskManager.wait_for_input:Function]
|
||||
# @PURPOSE: Pauses execution and waits for user input.
|
||||
# @PRE: Task exists.
|
||||
# @POST: Execution pauses until future is set via resume_task_with_password.
|
||||
# @PARAM: task_id (str) - The ID of the task.
|
||||
async def wait_for_input(self, task_id: str):
    """Block until the task's future is resolved with user input.

    Returns immediately when the task is unknown. This method does not
    change the task's status; the caller is expected to have moved it to
    AWAITING_INPUT via ``await_input`` beforehand. The future is removed
    from ``task_futures`` once the wait ends, however it ends.

    Args:
        task_id: ID of the task to pause.
    """
    with belief_scope("TaskManager.wait_for_input", f"task_id={task_id}"):
        task = self.tasks.get(task_id)
        if not task:
            return

        gate = self.loop.create_future()
        self.task_futures[task_id] = gate

        try:
            await gate
        finally:
            # Drop the entry even if the wait was interrupted.
            self.task_futures.pop(task_id, None)
|
||||
# [/DEF:TaskManager.wait_for_input:Function]
|
||||
|
||||
# [DEF:TaskManager.get_task:Function]
|
||||
# @PURPOSE: Retrieves a task by its ID.
|
||||
# @PARAM: task_id (str) - ID of the task.
|
||||
# @RETURN: Optional[Task] - The task or None.
|
||||
def get_task(self, task_id: str) -> Optional[Task]:
    """Look up a task in the registry.

    Args:
        task_id: ID of the task.

    Returns:
        The matching task, or ``None`` when the ID is not registered.
    """
    found = self.tasks.get(task_id, None)
    return found
|
||||
# [/DEF:TaskManager.get_task:Function]
|
||||
|
||||
# [DEF:TaskManager.get_all_tasks:Function]
|
||||
# @PURPOSE: Retrieves all registered tasks.
|
||||
# @RETURN: List[Task] - All tasks.
|
||||
def get_all_tasks(self) -> List[Task]:
    """Return a new list holding every registered task, in registry order."""
    return [*self.tasks.values()]
|
||||
# [/DEF:TaskManager.get_all_tasks:Function]
|
||||
|
||||
# [DEF:TaskManager.get_tasks:Function]
|
||||
# @PURPOSE: Retrieves tasks with pagination and optional status filter.
|
||||
# @PRE: limit and offset are non-negative integers.
|
||||
# @POST: Returns a list of tasks sorted by started_at descending (most recent first).
|
||||
# @PARAM: limit (int) - Maximum number of tasks to return.
|
||||
# @PARAM: offset (int) - Number of tasks to skip.
|
||||
# @PARAM: status (Optional[TaskStatus]) - Filter by task status.
|
||||
# @RETURN: List[Task] - List of tasks matching criteria.
|
||||
def get_tasks(self, limit: int = 10, offset: int = 0, status: Optional[TaskStatus] = None) -> List[Task]:
    """Return one page of tasks, most recently started first.

    Tasks that have not started yet (``started_at`` unset) sort last.

    Args:
        limit: Maximum number of tasks to return; negative values are
            treated as 0.
        offset: Number of tasks to skip; negative values are treated as 0.
        status: When given, only tasks with exactly this status are kept.

    Returns:
        The requested slice of the filtered, sorted task list.
    """
    registered = self.tasks.values()
    # Compare against None so that a falsy status value still filters.
    if status is not None:
        page_source = [t for t in registered if t.status == status]
    else:
        page_source = list(registered)

    page_source.sort(key=lambda t: t.started_at or datetime.min, reverse=True)

    # Clamp so that negative values cannot wrap around in the slice and
    # silently return the wrong page.
    start = max(offset, 0)
    count = max(limit, 0)
    return page_source[start:start + count]
|
||||
# [/DEF:TaskManager.get_tasks:Function]
|
||||
|
||||
# [DEF:TaskManager.get_task_logs:Function]
|
||||
# @PURPOSE: Retrieves logs for a specific task.
|
||||
# @PARAM: task_id (str) - ID of the task.
|
||||
# @RETURN: List[LogEntry] - List of log entries.
|
||||
def get_task_logs(self, task_id: str) -> List[LogEntry]:
    """Return the log entries recorded on a task.

    Args:
        task_id: ID of the task.

    Returns:
        The task's own ``logs`` list, or a new empty list when the task
        is unknown.
    """
    owner = self.tasks.get(task_id)
    if not owner:
        return []
    return owner.logs
|
||||
# [/DEF:TaskManager.get_task_logs:Function]
|
||||
|
||||
# [DEF:TaskManager._add_log:Function]
|
||||
# @PURPOSE: Adds a log entry to a task and notifies subscribers.
|
||||
# @PRE: Task exists.
|
||||
# @POST: Log added to task and pushed to queues.
|
||||
# @PARAM: task_id (str) - ID of the task.
|
||||
# @PARAM: level (str) - Log level.
|
||||
# @PARAM: message (str) - Log message.
|
||||
# @PARAM: context (Optional[Dict]) - Log context.
|
||||
def _add_log(self, task_id: str, level: str, message: str, context: Optional[Dict[str, Any]] = None):
|
||||
task = self.tasks.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
|
||||
log_entry = LogEntry(level=level, message=message, context=context)
|
||||
task.logs.append(log_entry)
|
||||
|
||||
# Notify subscribers
|
||||
if task_id in self.subscribers:
|
||||
for queue in self.subscribers[task_id]:
|
||||
self.loop.call_soon_threadsafe(queue.put_nowait, log_entry)
|
||||
# [/DEF:TaskManager._add_log:Function]
|
||||
|
||||
# [DEF:TaskManager.subscribe_logs:Function]
|
||||
# @PURPOSE: Subscribes to real-time logs for a task.
|
||||
# @PARAM: task_id (str) - ID of the task.
|
||||
# @RETURN: asyncio.Queue - Queue for log entries.
|
||||
async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
    """Register a new queue that will receive log entries for a task.

    Args:
        task_id: ID of the task to follow.

    Returns:
        The newly created queue, already added to the task's subscribers.
    """
    inbox = asyncio.Queue()
    self.subscribers.setdefault(task_id, []).append(inbox)
    return inbox
|
||||
# [/DEF:TaskManager.subscribe_logs:Function]
|
||||
|
||||
# [DEF:TaskManager.unsubscribe_logs:Function]
|
||||
# @PURPOSE: Unsubscribes from real-time logs for a task.
|
||||
# @PARAM: task_id (str) - ID of the task.
|
||||
# @PARAM: queue (asyncio.Queue) - Queue to remove.
|
||||
def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
    """Remove a queue from a task's subscribers.

    Unknown task IDs and queues that are not subscribed are ignored. The
    task's subscriber entry is dropped once its list is empty.

    Args:
        task_id: ID of the task the queue was subscribed to.
        queue: The queue to remove.
    """
    listeners = self.subscribers.get(task_id)
    if listeners is None:
        return

    if queue in listeners:
        listeners.remove(queue)

    if not listeners:
        del self.subscribers[task_id]
|
||||
# [/DEF:TaskManager.unsubscribe_logs:Function]
|
||||
|
||||
# [DEF:TaskManager.persist_awaiting_input_tasks:Function]
|
||||
# @PURPOSE: Persist tasks in AWAITING_INPUT state using persistence service.
|
||||
def persist_awaiting_input_tasks(self) -> None:
    """Hand a snapshot of every registered task to the persistence service.

    NOTE(review): no status filtering happens here despite the name;
    presumably the persistence service selects the AWAITING_INPUT tasks
    itself -- confirm against TaskPersistenceService.persist_tasks.
    """
    snapshot = [*self.tasks.values()]
    self.persistence_service.persist_tasks(snapshot)
|
||||
# [/DEF:TaskManager.persist_awaiting_input_tasks:Function]
|
||||
|
||||
# [DEF:TaskManager.load_persisted_tasks:Function]
|
||||
# @PURPOSE: Load persisted tasks using persistence service.
|
||||
def load_persisted_tasks(self) -> None:
    """Merge tasks returned by the persistence service into the registry.

    A task already registered under the same ID is kept; the loaded copy
    is ignored in that case.
    """
    for restored in self.persistence_service.load_tasks():
        # setdefault only inserts when the ID is not registered yet.
        self.tasks.setdefault(restored.id, restored)
|
||||
# [/DEF:TaskManager.load_persisted_tasks:Function]
|
||||
|
||||
# [DEF:TaskManager.await_input:Function]
|
||||
# @PURPOSE: Transition a task to AWAITING_INPUT state with input request.
|
||||
# @PRE: Task exists and is in RUNNING state.
|
||||
# @POST: Task status changed to AWAITING_INPUT, input_request set, persisted.
|
||||
# @PARAM: task_id (str) - ID of the task.
|
||||
# @PARAM: input_request (Dict) - Details about required input.
|
||||
# @THROWS: ValueError if task not found or not RUNNING.
|
||||
def await_input(self, task_id: str, input_request: Dict[str, Any]) -> None:
    """Move a RUNNING task to AWAITING_INPUT and record what input it needs.

    The change is logged on the task and the persistence hook is invoked
    afterwards.

    Args:
        task_id: ID of the task to pause.
        input_request: Details about the required input; stored on the
            task and attached to the log entry.

    Raises:
        ValueError: If the task is unknown or is not RUNNING.
    """
    with belief_scope("TaskManager.await_input", f"task_id={task_id}"):
        target = self.tasks.get(task_id)
        if not target:
            raise ValueError(f"Task {task_id} not found")
        if target.status != TaskStatus.RUNNING:
            raise ValueError(f"Task {task_id} is not RUNNING (current: {target.status})")

        target.status = TaskStatus.AWAITING_INPUT
        target.input_required = True
        target.input_request = input_request

        log_context = {"input_request": input_request}
        self._add_log(task_id, "INFO", "Task paused for user input", log_context)

        self.persist_awaiting_input_tasks()
|
||||
# [/DEF:TaskManager.await_input:Function]
|
||||
|
||||
# [DEF:TaskManager.resume_task_with_password:Function]
|
||||
# @PURPOSE: Resume a task that is awaiting input with provided passwords.
|
||||
# @PRE: Task exists and is in AWAITING_INPUT state.
|
||||
# @POST: Task status changed to RUNNING, passwords injected, task resumed.
|
||||
# @PARAM: task_id (str) - ID of the task.
|
||||
# @PARAM: passwords (Dict[str, str]) - Mapping of database name to password.
|
||||
# @THROWS: ValueError if task not found, not awaiting input, or passwords invalid.
|
||||
def resume_task_with_password(self, task_id: str, passwords: Dict[str, str]) -> None:
    """Resume an AWAITING_INPUT task with the supplied passwords.

    The passwords are stored under ``params["passwords"]``, the input
    request is cleared, the task goes back to RUNNING, any pending future
    for the task is resolved, and the persistence hook is invoked. Only
    the dictionary keys are written to the task log, never the values.

    Args:
        task_id: ID of the task to resume.
        passwords: Non-empty mapping of database name to password.

    Raises:
        ValueError: If the task is unknown, is not AWAITING_INPUT, or
            ``passwords`` is not a non-empty dict.
    """
    with belief_scope("TaskManager.resume_task_with_password", f"task_id={task_id}"):
        target = self.tasks.get(task_id)
        if not target:
            raise ValueError(f"Task {task_id} not found")
        if target.status != TaskStatus.AWAITING_INPUT:
            raise ValueError(f"Task {task_id} is not AWAITING_INPUT (current: {target.status})")

        if not (isinstance(passwords, dict) and passwords):
            raise ValueError("Passwords must be a non-empty dictionary")

        target.params["passwords"] = passwords
        target.input_required = False
        target.input_request = None
        target.status = TaskStatus.RUNNING

        log_context = {"databases": list(passwords.keys())}
        self._add_log(task_id, "INFO", "Task resumed with passwords", log_context)

        # Wake up whoever is awaiting this task's future, if anyone.
        gate = self.task_futures.get(task_id)
        if gate is not None:
            gate.set_result(True)

        self.persist_awaiting_input_tasks()
|
||||
# [/DEF:TaskManager.resume_task_with_password:Function]
|
||||
|
||||
# [/DEF:TaskManager:Class]
|
||||
# [/DEF:TaskManagerModule:Module]
|
||||
Reference in New Issue
Block a user