TaskManager refactor

This commit is contained in:
2025-12-29 10:13:37 +03:00
parent 6962a78112
commit 4c9d554432
25 changed files with 2778 additions and 283 deletions

View File

@@ -3,11 +3,11 @@
# @PURPOSE: Defines the FastAPI router for task-related endpoints, allowing clients to create, list, and get the status of tasks.
# @LAYER: UI (API)
# @RELATION: Depends on the TaskManager. It is included by the main app.
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from ...core.task_manager import TaskManager, Task
from ...core.task_manager import TaskManager, Task, TaskStatus, LogEntry
from ...dependencies import get_task_manager
router = APIRouter()
@@ -19,6 +19,9 @@ class CreateTaskRequest(BaseModel):
class ResolveTaskRequest(BaseModel):
resolution_params: Dict[str, Any]
class ResumeTaskRequest(BaseModel):
passwords: Dict[str, str]
@router.post("/", response_model=Task, status_code=status.HTTP_201_CREATED)
async def create_task(
request: CreateTaskRequest,
@@ -72,4 +75,19 @@ async def resolve_task(
return task_manager.get_task(task_id)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.post("/{task_id}/resume", response_model=Task)
async def resume_task(
task_id: str,
request: ResumeTaskRequest,
task_manager: TaskManager = Depends(get_task_manager)
):
"""
Resume a task that is awaiting input (e.g., passwords).
"""
try:
task_manager.resume_task_with_password(task_id, request.passwords)
return task_manager.get_task(task_id)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# [/DEF]

View File

@@ -1,226 +0,0 @@
# [DEF:TaskManagerModule:Module]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. It uses a thread pool to run plugins asynchronously.
# @LAYER: Core
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
import asyncio
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, Any, List, Optional
from concurrent.futures import ThreadPoolExecutor
from pydantic import BaseModel, Field
# Assuming PluginBase and PluginConfig are defined in plugin_base.py
# from .plugin_base import PluginBase, PluginConfig # Not needed here, TaskManager interacts with the PluginLoader
# [DEF:TaskStatus:Enum]
# @SEMANTICS: task, status, state, enum
# @PURPOSE: Defines the possible states a task can be in during its lifecycle.
class TaskStatus(str, Enum):
PENDING = "PENDING"
RUNNING = "RUNNING"
SUCCESS = "SUCCESS"
FAILED = "FAILED"
AWAITING_MAPPING = "AWAITING_MAPPING"
# [/DEF]
# [DEF:LogEntry:Class]
# @SEMANTICS: log, entry, record, pydantic
# @PURPOSE: A Pydantic model representing a single, structured log entry associated with a task.
class LogEntry(BaseModel):
timestamp: datetime = Field(default_factory=datetime.utcnow)
level: str
message: str
context: Optional[Dict[str, Any]] = None
# [/DEF]
# [DEF:Task:Class]
# @SEMANTICS: task, job, execution, state, pydantic
# @PURPOSE: A Pydantic model representing a single execution instance of a plugin, including its status, parameters, and logs.
class Task(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
plugin_id: str
status: TaskStatus = TaskStatus.PENDING
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
user_id: Optional[str] = None
logs: List[LogEntry] = Field(default_factory=list)
params: Dict[str, Any] = Field(default_factory=dict)
# [/DEF]
# [DEF:TaskManager:Class]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
class TaskManager:
"""
Manages the lifecycle of tasks, including their creation, execution, and state tracking.
"""
def __init__(self, plugin_loader):
self.plugin_loader = plugin_loader
self.tasks: Dict[str, Task] = {}
self.subscribers: Dict[str, List[asyncio.Queue]] = {}
self.executor = ThreadPoolExecutor(max_workers=5) # For CPU-bound plugin execution
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.get_event_loop()
self.task_futures: Dict[str, asyncio.Future] = {}
# [/DEF]
async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
"""
Creates and queues a new task for execution.
"""
from ..core.logger import logger
logger.info(f"TaskManager: Creating task for plugin '{plugin_id}' with params: {params}")
if not self.plugin_loader.has_plugin(plugin_id):
logger.error(f"TaskManager: Plugin with ID '{plugin_id}' not found.")
raise ValueError(f"Plugin with ID '{plugin_id}' not found.")
plugin = self.plugin_loader.get_plugin(plugin_id)
logger.info(f"TaskManager: Found plugin '{plugin.name}' for task creation")
# Validate params against plugin schema (this will be done at a higher level, e.g., API route)
# For now, a basic check
if not isinstance(params, dict):
logger.error("TaskManager: Task parameters must be a dictionary.")
raise ValueError("Task parameters must be a dictionary.")
task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
self.tasks[task.id] = task
logger.info(f"TaskManager: Task {task.id} created and scheduled for execution")
self.loop.create_task(self._run_task(task.id)) # Schedule task for execution
return task
async def _run_task(self, task_id: str):
"""
Internal method to execute a task.
"""
from ..core.logger import logger
task = self.tasks[task_id]
plugin = self.plugin_loader.get_plugin(task.plugin_id)
logger.info(f"TaskManager: Starting execution of task {task_id} for plugin '{plugin.name}'")
task.status = TaskStatus.RUNNING
task.started_at = datetime.utcnow()
self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'")
try:
# Execute plugin in a separate thread to avoid blocking the event loop
# if the plugin's execute method is synchronous and potentially CPU-bound.
# If the plugin's execute method is already async, this can be simplified.
# Pass task_id to plugin so it can signal pause
params = {**task.params, "_task_id": task_id}
logger.info(f"TaskManager: Executing plugin '{plugin.name}' with params: {params}")
if asyncio.iscoroutinefunction(plugin.execute):
logger.info(f"TaskManager: Executing async plugin '{plugin.name}'")
await plugin.execute(params)
else:
logger.info(f"TaskManager: Executing sync plugin '{plugin.name}' in executor")
await self.loop.run_in_executor(
self.executor,
plugin.execute,
params
)
logger.info(f"TaskManager: Task {task_id} completed successfully for plugin '{plugin.name}'")
task.status = TaskStatus.SUCCESS
self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'")
except Exception as e:
logger.error(f"TaskManager: Task {task_id} failed for plugin '{plugin.name}': {e}")
task.status = TaskStatus.FAILED
self._add_log(task_id, "ERROR", f"Task failed: {e}", {"error_type": type(e).__name__})
finally:
task.finished_at = datetime.utcnow()
logger.info(f"TaskManager: Task {task_id} execution finished with status: {task.status}")
# In a real system, you might notify clients via WebSocket here
async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
"""
Resumes a task that is awaiting mapping.
"""
task = self.tasks.get(task_id)
if not task or task.status != TaskStatus.AWAITING_MAPPING:
raise ValueError("Task is not awaiting mapping.")
# Update task params with resolution
task.params.update(resolution_params)
task.status = TaskStatus.RUNNING
self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")
# Signal the future to continue
if task_id in self.task_futures:
self.task_futures[task_id].set_result(True)
async def wait_for_resolution(self, task_id: str):
"""
Pauses execution and waits for a resolution signal.
"""
task = self.tasks.get(task_id)
if not task: return
task.status = TaskStatus.AWAITING_MAPPING
self.task_futures[task_id] = self.loop.create_future()
try:
await self.task_futures[task_id]
finally:
del self.task_futures[task_id]
def get_task(self, task_id: str) -> Optional[Task]:
"""
Retrieves a task by its ID.
"""
return self.tasks.get(task_id)
def get_all_tasks(self) -> List[Task]:
"""
Retrieves all registered tasks.
"""
return list(self.tasks.values())
def get_task_logs(self, task_id: str) -> List[LogEntry]:
"""
Retrieves logs for a specific task.
"""
task = self.tasks.get(task_id)
return task.logs if task else []
def _add_log(self, task_id: str, level: str, message: str, context: Optional[Dict[str, Any]] = None):
"""
Adds a log entry to a task and notifies subscribers.
"""
task = self.tasks.get(task_id)
if not task:
return
log_entry = LogEntry(level=level, message=message, context=context)
task.logs.append(log_entry)
# Notify subscribers
if task_id in self.subscribers:
for queue in self.subscribers[task_id]:
self.loop.call_soon_threadsafe(queue.put_nowait, log_entry)
async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
"""
Subscribes to real-time logs for a task.
"""
queue = asyncio.Queue()
if task_id not in self.subscribers:
self.subscribers[task_id] = []
self.subscribers[task_id].append(queue)
return queue
def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
"""
Unsubscribes from real-time logs for a task.
"""
if task_id in self.subscribers:
self.subscribers[task_id].remove(queue)
if not self.subscribers[task_id]:
del self.subscribers[task_id]

View File

@@ -0,0 +1,12 @@
# [DEF:TaskManagerPackage:Module]
# @SEMANTICS: task, manager, package, exports
# @PURPOSE: Exports the public API of the task manager package.
# @LAYER: Core
# @RELATION: Aggregates models and manager.
from .models import Task, TaskStatus, LogEntry
from .manager import TaskManager
__all__ = ["TaskManager", "Task", "TaskStatus", "LogEntry"]
# [/DEF:TaskManagerPackage:Module]

View File

@@ -0,0 +1,335 @@
# [DEF:TaskManagerModule:Module]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. It uses a thread pool to run plugins asynchronously.
# @LAYER: Core
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
# @INVARIANT: Task IDs are unique.
# @CONSTRAINT: Must use belief_scope for logging.
# [SECTION: IMPORTS]
import asyncio
from datetime import datetime
from typing import Dict, Any, List, Optional
from concurrent.futures import ThreadPoolExecutor
from .models import Task, TaskStatus, LogEntry
from .persistence import TaskPersistenceService
from ..logger import logger, belief_scope
# [/SECTION]
# [DEF:TaskManager:Class]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
class TaskManager:
"""
Manages the lifecycle of tasks, including their creation, execution, and state tracking.
"""
# [DEF:TaskManager.__init__:Function]
# @PURPOSE: Initialize the TaskManager with dependencies.
# @PRE: plugin_loader is initialized.
# @POST: TaskManager is ready to accept tasks.
# @PARAM: plugin_loader - The plugin loader instance.
def __init__(self, plugin_loader):
with belief_scope("TaskManager.__init__"):
self.plugin_loader = plugin_loader
self.tasks: Dict[str, Task] = {}
self.subscribers: Dict[str, List[asyncio.Queue]] = {}
self.executor = ThreadPoolExecutor(max_workers=5) # For CPU-bound plugin execution
self.persistence_service = TaskPersistenceService()
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.get_event_loop()
self.task_futures: Dict[str, asyncio.Future] = {}
# [/DEF:TaskManager.__init__:Function]
# [DEF:TaskManager.create_task:Function]
# @PURPOSE: Creates and queues a new task for execution.
# @PRE: Plugin with plugin_id exists. Params are valid.
# @POST: Task is created, added to registry, and scheduled for execution.
# @PARAM: plugin_id (str) - The ID of the plugin to run.
# @PARAM: params (Dict[str, Any]) - Parameters for the plugin.
# @PARAM: user_id (Optional[str]) - ID of the user requesting the task.
# @RETURN: Task - The created task instance.
# @THROWS: ValueError if plugin not found or params invalid.
async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
with belief_scope("TaskManager.create_task", f"plugin_id={plugin_id}"):
if not self.plugin_loader.has_plugin(plugin_id):
logger.error(f"Plugin with ID '{plugin_id}' not found.")
raise ValueError(f"Plugin with ID '{plugin_id}' not found.")
plugin = self.plugin_loader.get_plugin(plugin_id)
if not isinstance(params, dict):
logger.error("Task parameters must be a dictionary.")
raise ValueError("Task parameters must be a dictionary.")
task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
self.tasks[task.id] = task
logger.info(f"Task {task.id} created and scheduled for execution")
self.loop.create_task(self._run_task(task.id)) # Schedule task for execution
return task
# [/DEF:TaskManager.create_task:Function]
# [DEF:TaskManager._run_task:Function]
# @PURPOSE: Internal method to execute a task.
# @PRE: Task exists in registry.
# @POST: Task is executed, status updated to SUCCESS or FAILED.
# @PARAM: task_id (str) - The ID of the task to run.
async def _run_task(self, task_id: str):
with belief_scope("TaskManager._run_task", f"task_id={task_id}"):
task = self.tasks[task_id]
plugin = self.plugin_loader.get_plugin(task.plugin_id)
logger.info(f"Starting execution of task {task_id} for plugin '{plugin.name}'")
task.status = TaskStatus.RUNNING
task.started_at = datetime.utcnow()
self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'")
try:
# Execute plugin
params = {**task.params, "_task_id": task_id}
if asyncio.iscoroutinefunction(plugin.execute):
await plugin.execute(params)
else:
await self.loop.run_in_executor(
self.executor,
plugin.execute,
params
)
logger.info(f"Task {task_id} completed successfully")
task.status = TaskStatus.SUCCESS
self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'")
except Exception as e:
logger.error(f"Task {task_id} failed: {e}")
task.status = TaskStatus.FAILED
self._add_log(task_id, "ERROR", f"Task failed: {e}", {"error_type": type(e).__name__})
finally:
task.finished_at = datetime.utcnow()
logger.info(f"Task {task_id} execution finished with status: {task.status}")
# [/DEF:TaskManager._run_task:Function]
# [DEF:TaskManager.resolve_task:Function]
# @PURPOSE: Resumes a task that is awaiting mapping.
# @PRE: Task exists and is in AWAITING_MAPPING state.
# @POST: Task status updated to RUNNING, params updated, execution resumed.
# @PARAM: task_id (str) - The ID of the task.
# @PARAM: resolution_params (Dict[str, Any]) - Params to resolve the wait.
# @THROWS: ValueError if task not found or not awaiting mapping.
async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
with belief_scope("TaskManager.resolve_task", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task or task.status != TaskStatus.AWAITING_MAPPING:
raise ValueError("Task is not awaiting mapping.")
# Update task params with resolution
task.params.update(resolution_params)
task.status = TaskStatus.RUNNING
self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")
# Signal the future to continue
if task_id in self.task_futures:
self.task_futures[task_id].set_result(True)
# [/DEF:TaskManager.resolve_task:Function]
# [DEF:TaskManager.wait_for_resolution:Function]
# @PURPOSE: Pauses execution and waits for a resolution signal.
# @PRE: Task exists.
# @POST: Execution pauses until future is set.
# @PARAM: task_id (str) - The ID of the task.
async def wait_for_resolution(self, task_id: str):
with belief_scope("TaskManager.wait_for_resolution", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task: return
task.status = TaskStatus.AWAITING_MAPPING
self.task_futures[task_id] = self.loop.create_future()
try:
await self.task_futures[task_id]
finally:
if task_id in self.task_futures:
del self.task_futures[task_id]
# [/DEF:TaskManager.wait_for_resolution:Function]
# [DEF:TaskManager.wait_for_input:Function]
# @PURPOSE: Pauses execution and waits for user input.
# @PRE: Task exists.
# @POST: Execution pauses until future is set via resume_task_with_password.
# @PARAM: task_id (str) - The ID of the task.
async def wait_for_input(self, task_id: str):
with belief_scope("TaskManager.wait_for_input", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task: return
# Status is already set to AWAITING_INPUT by await_input()
self.task_futures[task_id] = self.loop.create_future()
try:
await self.task_futures[task_id]
finally:
if task_id in self.task_futures:
del self.task_futures[task_id]
# [/DEF:TaskManager.wait_for_input:Function]
# [DEF:TaskManager.get_task:Function]
# @PURPOSE: Retrieves a task by its ID.
# @PARAM: task_id (str) - ID of the task.
# @RETURN: Optional[Task] - The task or None.
def get_task(self, task_id: str) -> Optional[Task]:
return self.tasks.get(task_id)
# [/DEF:TaskManager.get_task:Function]
# [DEF:TaskManager.get_all_tasks:Function]
# @PURPOSE: Retrieves all registered tasks.
# @RETURN: List[Task] - All tasks.
def get_all_tasks(self) -> List[Task]:
return list(self.tasks.values())
# [/DEF:TaskManager.get_all_tasks:Function]
# [DEF:TaskManager.get_tasks:Function]
# @PURPOSE: Retrieves tasks with pagination and optional status filter.
# @PRE: limit and offset are non-negative integers.
# @POST: Returns a list of tasks sorted by start_time descending.
# @PARAM: limit (int) - Maximum number of tasks to return.
# @PARAM: offset (int) - Number of tasks to skip.
# @PARAM: status (Optional[TaskStatus]) - Filter by task status.
# @RETURN: List[Task] - List of tasks matching criteria.
def get_tasks(self, limit: int = 10, offset: int = 0, status: Optional[TaskStatus] = None) -> List[Task]:
tasks = list(self.tasks.values())
if status:
tasks = [t for t in tasks if t.status == status]
# Sort by start_time descending (most recent first)
tasks.sort(key=lambda t: t.started_at or datetime.min, reverse=True)
return tasks[offset:offset + limit]
# [/DEF:TaskManager.get_tasks:Function]
# [DEF:TaskManager.get_task_logs:Function]
# @PURPOSE: Retrieves logs for a specific task.
# @PARAM: task_id (str) - ID of the task.
# @RETURN: List[LogEntry] - List of log entries.
def get_task_logs(self, task_id: str) -> List[LogEntry]:
task = self.tasks.get(task_id)
return task.logs if task else []
# [/DEF:TaskManager.get_task_logs:Function]
# [DEF:TaskManager._add_log:Function]
# @PURPOSE: Adds a log entry to a task and notifies subscribers.
# @PRE: Task exists.
# @POST: Log added to task and pushed to queues.
# @PARAM: task_id (str) - ID of the task.
# @PARAM: level (str) - Log level.
# @PARAM: message (str) - Log message.
# @PARAM: context (Optional[Dict]) - Log context.
def _add_log(self, task_id: str, level: str, message: str, context: Optional[Dict[str, Any]] = None):
task = self.tasks.get(task_id)
if not task:
return
log_entry = LogEntry(level=level, message=message, context=context)
task.logs.append(log_entry)
# Notify subscribers
if task_id in self.subscribers:
for queue in self.subscribers[task_id]:
self.loop.call_soon_threadsafe(queue.put_nowait, log_entry)
# [/DEF:TaskManager._add_log:Function]
# [DEF:TaskManager.subscribe_logs:Function]
# @PURPOSE: Subscribes to real-time logs for a task.
# @PARAM: task_id (str) - ID of the task.
# @RETURN: asyncio.Queue - Queue for log entries.
async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
queue = asyncio.Queue()
if task_id not in self.subscribers:
self.subscribers[task_id] = []
self.subscribers[task_id].append(queue)
return queue
# [/DEF:TaskManager.subscribe_logs:Function]
# [DEF:TaskManager.unsubscribe_logs:Function]
# @PURPOSE: Unsubscribes from real-time logs for a task.
# @PARAM: task_id (str) - ID of the task.
# @PARAM: queue (asyncio.Queue) - Queue to remove.
def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
if task_id in self.subscribers:
if queue in self.subscribers[task_id]:
self.subscribers[task_id].remove(queue)
if not self.subscribers[task_id]:
del self.subscribers[task_id]
# [/DEF:TaskManager.unsubscribe_logs:Function]
# [DEF:TaskManager.persist_awaiting_input_tasks:Function]
# @PURPOSE: Persist tasks in AWAITING_INPUT state using persistence service.
def persist_awaiting_input_tasks(self) -> None:
self.persistence_service.persist_tasks(list(self.tasks.values()))
# [/DEF:TaskManager.persist_awaiting_input_tasks:Function]
# [DEF:TaskManager.load_persisted_tasks:Function]
# @PURPOSE: Load persisted tasks using persistence service.
def load_persisted_tasks(self) -> None:
loaded_tasks = self.persistence_service.load_tasks()
for task in loaded_tasks:
if task.id not in self.tasks:
self.tasks[task.id] = task
# [/DEF:TaskManager.load_persisted_tasks:Function]
# [DEF:TaskManager.await_input:Function]
# @PURPOSE: Transition a task to AWAITING_INPUT state with input request.
# @PRE: Task exists and is in RUNNING state.
# @POST: Task status changed to AWAITING_INPUT, input_request set, persisted.
# @PARAM: task_id (str) - ID of the task.
# @PARAM: input_request (Dict) - Details about required input.
# @THROWS: ValueError if task not found or not RUNNING.
def await_input(self, task_id: str, input_request: Dict[str, Any]) -> None:
with belief_scope("TaskManager.await_input", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
if task.status != TaskStatus.RUNNING:
raise ValueError(f"Task {task_id} is not RUNNING (current: {task.status})")
task.status = TaskStatus.AWAITING_INPUT
task.input_required = True
task.input_request = input_request
self._add_log(task_id, "INFO", "Task paused for user input", {"input_request": input_request})
self.persist_awaiting_input_tasks()
# [/DEF:TaskManager.await_input:Function]
# [DEF:TaskManager.resume_task_with_password:Function]
# @PURPOSE: Resume a task that is awaiting input with provided passwords.
# @PRE: Task exists and is in AWAITING_INPUT state.
# @POST: Task status changed to RUNNING, passwords injected, task resumed.
# @PARAM: task_id (str) - ID of the task.
# @PARAM: passwords (Dict[str, str]) - Mapping of database name to password.
# @THROWS: ValueError if task not found, not awaiting input, or passwords invalid.
def resume_task_with_password(self, task_id: str, passwords: Dict[str, str]) -> None:
with belief_scope("TaskManager.resume_task_with_password", f"task_id={task_id}"):
task = self.tasks.get(task_id)
if not task:
raise ValueError(f"Task {task_id} not found")
if task.status != TaskStatus.AWAITING_INPUT:
raise ValueError(f"Task {task_id} is not AWAITING_INPUT (current: {task.status})")
if not isinstance(passwords, dict) or not passwords:
raise ValueError("Passwords must be a non-empty dictionary")
task.params["passwords"] = passwords
task.input_required = False
task.input_request = None
task.status = TaskStatus.RUNNING
self._add_log(task_id, "INFO", "Task resumed with passwords", {"databases": list(passwords.keys())})
if task_id in self.task_futures:
self.task_futures[task_id].set_result(True)
self.persist_awaiting_input_tasks()
# [/DEF:TaskManager.resume_task_with_password:Function]
# [/DEF:TaskManager:Class]
# [/DEF:TaskManagerModule:Module]

View File

@@ -0,0 +1,67 @@
# [DEF:TaskManagerModels:Module]
# @SEMANTICS: task, models, pydantic, enum, state
# @PURPOSE: Defines the data models and enumerations used by the Task Manager.
# @LAYER: Core
# @RELATION: Used by TaskManager and API routes.
# @INVARIANT: Task IDs are immutable once created.
# @CONSTRAINT: Must use Pydantic for data validation.
# [SECTION: IMPORTS]
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
# [/SECTION]
# [DEF:TaskStatus:Enum]
# @SEMANTICS: task, status, state, enum
# @PURPOSE: Defines the possible states a task can be in during its lifecycle.
class TaskStatus(str, Enum):
PENDING = "PENDING"
RUNNING = "RUNNING"
SUCCESS = "SUCCESS"
FAILED = "FAILED"
AWAITING_MAPPING = "AWAITING_MAPPING"
AWAITING_INPUT = "AWAITING_INPUT"
# [/DEF:TaskStatus:Enum]
# [DEF:LogEntry:Class]
# @SEMANTICS: log, entry, record, pydantic
# @PURPOSE: A Pydantic model representing a single, structured log entry associated with a task.
class LogEntry(BaseModel):
timestamp: datetime = Field(default_factory=datetime.utcnow)
level: str
message: str
context: Optional[Dict[str, Any]] = None
# [/DEF:LogEntry:Class]
# [DEF:Task:Class]
# @SEMANTICS: task, job, execution, state, pydantic
# @PURPOSE: A Pydantic model representing a single execution instance of a plugin, including its status, parameters, and logs.
class Task(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
plugin_id: str
status: TaskStatus = TaskStatus.PENDING
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
user_id: Optional[str] = None
logs: List[LogEntry] = Field(default_factory=list)
params: Dict[str, Any] = Field(default_factory=dict)
input_required: bool = False
input_request: Optional[Dict[str, Any]] = None
# [DEF:Task.__init__:Function]
# @PURPOSE: Initializes the Task model and validates input_request for AWAITING_INPUT status.
# @PRE: If status is AWAITING_INPUT, input_request must be provided.
# @POST: Task instance is created or ValueError is raised.
# @PARAM: **data - Keyword arguments for model initialization.
def __init__(self, **data):
super().__init__(**data)
if self.status == TaskStatus.AWAITING_INPUT and not self.input_request:
raise ValueError("input_request is required when status is AWAITING_INPUT")
# [/DEF:Task.__init__:Function]
# [/DEF:Task:Class]
# [/DEF:TaskManagerModels:Module]

View File

@@ -0,0 +1,127 @@
# [DEF:TaskPersistenceModule:Module]
# @SEMANTICS: persistence, sqlite, task, storage
# @PURPOSE: Handles the persistence of tasks, specifically those awaiting user input, to a SQLite database.
# @LAYER: Core
# @RELATION: Used by TaskManager to save and load tasks.
# @INVARIANT: Database schema must match the Task model structure.
# @CONSTRAINT: Uses synchronous SQLite operations (blocking), should be used carefully.
# [SECTION: IMPORTS]
import sqlite3
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from .models import Task, TaskStatus
from ..logger import logger, belief_scope
# [/SECTION]
# [DEF:TaskPersistenceService:Class]
# @SEMANTICS: persistence, service, database
# @PURPOSE: Provides methods to save and load tasks from a local SQLite database.
class TaskPersistenceService:
def __init__(self, db_path: Optional[Path] = None):
if db_path is None:
self.db_path = Path(__file__).parent.parent.parent.parent / "migrations.db"
else:
self.db_path = db_path
self._ensure_db_exists()
# [DEF:TaskPersistenceService._ensure_db_exists:Function]
# @PURPOSE: Ensures the database directory and table exist.
# @PRE: None.
# @POST: Database file and table are created if they didn't exist.
def _ensure_db_exists(self) -> None:
with belief_scope("TaskPersistenceService._ensure_db_exists"):
self.db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(self.db_path))
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS persistent_tasks (
id TEXT PRIMARY KEY,
status TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
input_request JSON,
context JSON
)
""")
conn.commit()
conn.close()
# [/DEF:TaskPersistenceService._ensure_db_exists:Function]
# [DEF:TaskPersistenceService.persist_tasks:Function]
# @PURPOSE: Persists a list of tasks to the database.
# @PRE: Tasks list contains valid Task objects.
# @POST: Tasks matching the criteria (AWAITING_INPUT) are saved/updated in the DB.
# @PARAM: tasks (List[Task]) - The list of tasks to check and persist.
def persist_tasks(self, tasks: List[Task]) -> None:
with belief_scope("TaskPersistenceService.persist_tasks"):
conn = sqlite3.connect(str(self.db_path))
cursor = conn.cursor()
count = 0
for task in tasks:
if task.status == TaskStatus.AWAITING_INPUT:
cursor.execute("""
INSERT OR REPLACE INTO persistent_tasks
(id, status, created_at, updated_at, input_request, context)
VALUES (?, ?, ?, ?, ?, ?)
""", (
task.id,
task.status.value,
task.started_at.isoformat() if task.started_at else datetime.utcnow().isoformat(),
datetime.utcnow().isoformat(),
json.dumps(task.input_request) if task.input_request else None,
json.dumps(task.params)
))
count += 1
conn.commit()
conn.close()
logger.info(f"Persisted {count} tasks awaiting input.")
# [/DEF:TaskPersistenceService.persist_tasks:Function]
# [DEF:TaskPersistenceService.load_tasks:Function]
# @PURPOSE: Loads persisted tasks from the database.
# @PRE: Database exists.
# @POST: Returns a list of Task objects reconstructed from the DB.
# @RETURN: List[Task] - The loaded tasks.
def load_tasks(self) -> List[Task]:
with belief_scope("TaskPersistenceService.load_tasks"):
if not self.db_path.exists():
return []
conn = sqlite3.connect(str(self.db_path))
cursor = conn.cursor()
cursor.execute("SELECT id, status, created_at, input_request, context FROM persistent_tasks")
rows = cursor.fetchall()
loaded_tasks = []
for row in rows:
task_id, status, created_at, input_request_json, context_json = row
try:
task = Task(
id=task_id,
plugin_id="migration", # Default, assumes migration context for now
status=TaskStatus(status),
started_at=datetime.fromisoformat(created_at),
input_required=True,
input_request=json.loads(input_request_json) if input_request_json else None,
params=json.loads(context_json) if context_json else {}
)
loaded_tasks.append(task)
except Exception as e:
logger.error(f"Failed to load task {task_id}: {e}")
conn.close()
return loaded_tasks
# [/DEF:TaskPersistenceService.load_tasks:Function]
# [/DEF:TaskPersistenceService:Class]
# [/DEF:TaskPersistenceModule:Module]

View File

@@ -218,6 +218,41 @@ class MigrationPlugin(PluginBase):
logger.info(f"[MigrationPlugin][Success] Dashboard {title} imported.")
except Exception as exc:
# Check for password error
error_msg = str(exc)
if "Must provide a password for the database" in error_msg:
# Extract database name (assuming format: "Must provide a password for the database 'PostgreSQL'")
import re
match = re.search(r"database '([^']+)'", error_msg)
db_name = match.group(1) if match else "unknown"
# Get task manager
from ..dependencies import get_task_manager
tm = get_task_manager()
task_id = params.get("_task_id")
if task_id:
input_request = {
"type": "database_password",
"databases": [db_name],
"error_message": error_msg
}
tm.await_input(task_id, input_request)
# Wait for user input
await tm.wait_for_input(task_id)
# Resume with passwords
task = tm.get_task(task_id)
passwords = task.params.get("passwords", {})
# Retry import with password
if passwords:
logger.info(f"[MigrationPlugin][Action] Retrying import for {title} with provided passwords.")
to_c.import_dashboard(file_name=tmp_new_zip, dash_id=dash_id, dash_slug=dash_slug, passwords=passwords)
logger.info(f"[MigrationPlugin][Success] Dashboard {title} imported after password injection.")
continue
logger.error(f"[MigrationPlugin][Failure] Failed to migrate dashboard {title}: {exc}", exc_info=True)
logger.info("[MigrationPlugin][Exit] Migration finished.")