# symbiont/symbiont/task_manager.py
#
# 386 lines
# 13 KiB
# Python

"""
Compound Task Manager
=====================
Manages the complete lifecycle of compound tasks:
- Tracks task state through planning and execution phases
- Executes subtasks in parallel while respecting dependencies
- Updates progress in real-time (via polling)
- Logs all executions to the immutable ledger
Architecture:
1. submit_compound_task() - plans task (sync, fast) then spawns background execution
2. Background thread executes subtasks in dependency-respecting waves
3. get_task_progress() - clients poll for current state
4. list_recent_tasks() - dashboard view of recent executions
"""
import json
import logging
import subprocess
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Dict, List, Any

from .planner import plan_task
# In-memory store of active/recent compound tasks, keyed by task id.
# submit_compound_task() evicts the oldest entry once _MAX_TASKS is reached,
# which keeps memory use bounded.
_tasks = {}
# Guards _tasks and the task/subtask dicts stored in it
_tasks_lock = threading.Lock()
_MAX_TASKS = 50
# Tier-to-model mapping (tier number -> model name passed to the Claude CLI)
TIER_MODELS = {1: "haiku", 2: "sonnet", 3: "opus"}
# Approximate flat cost per execution (USD), keyed by tier
# These are rough estimates - actual costs depend on token usage
TIER_COSTS = {1: 0.008, 2: 0.04, 3: 0.15}
# Ledger file path (append-only JSONL execution log)
LEDGER_PATH = Path("/data/symbiont/ledger.jsonl")
def _log_ledger(entry: Dict[str, Any], path: Optional[Path] = None) -> None:
    """
    Append one entry to the execution ledger (best effort).

    Args:
        entry: Dictionary with execution details (timestamp, model, tokens,
            cost, etc.). Values that are not JSON-native are stringified.
        path: Ledger file to append to. Defaults to the module-level
            LEDGER_PATH; mainly useful for tests and tooling.

    The ledger is a JSONL file (one JSON object per line) used for:
    - Cost tracking and billing
    - Audit trails
    - Performance analysis

    A failed write never raises, because ledger problems must not fail the
    task being executed. The failure is logged as a warning instead of being
    dropped silently.
    """
    target = Path(path) if path is not None else LEDGER_PATH
    try:
        # Serialize first so a bad entry fails before the file is touched
        line = json.dumps(entry, default=str) + "\n"
        # A missing directory would otherwise lose every entry
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "a", encoding="utf-8") as f:
            f.write(line)
    except (OSError, TypeError, ValueError) as exc:
        logging.getLogger(__name__).warning(
            "Ledger write to %s failed: %s", target, exc
        )
def _execute_subtask(subtask: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute a single subtask via the Claude CLI.

    Args:
        subtask: Subtask dict. Must contain "description"; may contain
            "tier_assigned" / "tier_hint" (tier 2 is used when neither is set).

    Returns:
        The same subtask dict, updated in place with model, status
        ("completed" or "failed"), result, cost (on success) and timestamps.

    Process:
        1. Resolve the tier once and derive both model and cost from it
        2. Invoke the Claude CLI with subprocess (120s timeout)
        3. Treat a non-zero exit code as a failure; otherwise parse the JSON
           output and extract the result text
        4. Record the outcome in the ledger
        5. Update subtask status and completion time

    Errors raised while running the CLI or parsing its output are caught and
    recorded on the subtask as status "failed" rather than propagated.
    """
    # Resolve the tier once so the model used and the cost charged always agree
    tier = subtask.get("tier_assigned") or subtask.get("tier_hint") or 2
    model = TIER_MODELS.get(tier, "sonnet")
    # NOTE(review): the ledger stores the subtask's own id under
    # "compound_task_id"; key kept unchanged for ledger compatibility.
    ledger_id = subtask.get("id", "unknown")
    _update_subtask(subtask, model=model, status="executing",
                    started_at=datetime.now(timezone.utc).isoformat())
    try:
        # Execute via Claude CLI with JSON output mode
        proc = subprocess.run(
            ['claude', '-p', '--model', model, '--output-format', 'json'],
            input=subtask["description"],
            capture_output=True,
            text=True,
            timeout=120,
        )
        if proc.returncode != 0:
            # Surface the CLI's own message instead of a bare JSON parse error
            detail = (proc.stderr or proc.stdout or "").strip()[:500]
            raise RuntimeError(
                f"claude CLI exited with code {proc.returncode}: {detail}"
            )
        output = json.loads(proc.stdout)
        response_text = output.get('result', output.get('content', str(output)))
        if not isinstance(response_text, str):
            # Structured or null payloads cannot be sliced/concatenated as text
            response_text = str(response_text)
        # Extract token counts for the ledger. Top-level keys are tried first;
        # NOTE(review): the CLI may nest them under "usage" - confirm against
        # the CLI version in use.
        usage = output.get('usage')
        if not isinstance(usage, dict):
            usage = {}
        tokens_in = output.get('input_tokens', usage.get('input_tokens', 0))
        tokens_out = output.get('output_tokens', usage.get('output_tokens', 0))
        cost = TIER_COSTS.get(tier, 0.04)
        truncated = response_text[:2000]
        if len(response_text) > 2000:
            truncated += "\n[TRUNCATED...]"
        finished_at = datetime.now(timezone.utc).isoformat()
        _update_subtask(subtask, status="completed", result=truncated,
                        cost=cost, completed_at=finished_at)
        # Log successful execution to ledger
        _log_ledger({
            "timestamp": finished_at,
            "model": model,
            "success": True,
            "input_tokens": tokens_in,
            "output_tokens": tokens_out,
            "estimated_cost_usd": cost,
            "prompt_preview": subtask["description"][:100],
            "compound_task_id": ledger_id,
        })
    except subprocess.TimeoutExpired:
        now = datetime.now(timezone.utc).isoformat()
        _update_subtask(subtask, status="failed",
                        result="Execution timed out (120s)", completed_at=now)
        _log_ledger({
            "timestamp": now, "model": model, "success": False,
            "error": "timeout", "compound_task_id": ledger_id,
        })
    except Exception as e:
        # Worker boundary: any other failure is recorded, never propagated
        now = datetime.now(timezone.utc).isoformat()
        _update_subtask(subtask, status="failed", result=f"Error: {e}",
                        completed_at=now)
        _log_ledger({
            "timestamp": now, "model": model, "success": False,
            "error": str(e), "compound_task_id": ledger_id,
        })
    return subtask
def _validate_dependencies(subtasks: List[Dict[str, Any]]) -> None:
    """
    Normalize every subtask's "depends_on" list in place.

    Args:
        subtasks: Subtask dicts; each must have an "index" key.

    After the call each subtask has a "depends_on" list that:
    - exists even if the key was missing or explicitly None
    - contains only indices of subtasks present in ``subtasks``
    - never references the subtask itself
    - has no duplicates (first occurrence order is kept)

    A scalar "depends_on" value is treated as a one-element list.
    """
    valid_indices = {st["index"] for st in subtasks}
    for st in subtasks:
        own_index = st["index"]
        raw = st.get("depends_on")
        if raw is None:
            raw = []
        elif not isinstance(raw, (list, tuple, set)):
            # Explicit None check above keeps a scalar 0 from being dropped
            raw = [raw]
        cleaned = []
        for dep in raw:
            if dep in valid_indices and dep != own_index and dep not in cleaned:
                cleaned.append(dep)
        st["depends_on"] = cleaned
def _update_subtask(subtask: Dict[str, Any], **updates) -> None:
    """Write the given keyword fields into ``subtask`` while holding the global task lock."""
    _tasks_lock.acquire()
    try:
        for field_name, value in updates.items():
            subtask[field_name] = value
    finally:
        _tasks_lock.release()
def _run_subtask_waves(subtasks: List[Dict[str, Any]]) -> None:
    """Run subtasks wave by wave; a wave is every subtask whose dependencies all succeeded."""
    finished = set()   # indices that reached a terminal state (completed or failed)
    succeeded = set()  # subset of ``finished`` whose status is "completed"
    with ThreadPoolExecutor(max_workers=4) as executor:
        while True:
            pending = [st for st in subtasks if st["index"] not in finished]
            if not pending:
                break
            failed = finished - succeeded
            ready = []
            blocked = []
            for st in pending:
                deps = set(st.get("depends_on", []))
                if deps & failed:
                    blocked.append(st)
                elif deps <= succeeded:
                    ready.append(st)
            # A failed dependency fails its dependents without executing them
            for st in blocked:
                _update_subtask(st, status="failed",
                                result="Blocked by failed dependency",
                                completed_at=datetime.now(timezone.utc).isoformat())
                finished.add(st["index"])
            if not ready:
                if blocked:
                    # Newly failed subtasks may block further dependents
                    continue
                # Nothing runnable and nothing newly blocked: the remaining
                # subtasks wait on each other and can never start
                for st in pending:
                    _update_subtask(st, status="failed",
                                    result="Blocked by unresolvable dependency",
                                    completed_at=datetime.now(timezone.utc).isoformat())
                break
            # Launch ready subtasks in parallel and wait for the whole wave
            futures = {}
            for st in ready:
                _update_subtask(st, status="queued")
                futures[executor.submit(_execute_subtask, st)] = st
            for future in as_completed(futures):
                st = futures[future]
                try:
                    future.result()
                except Exception as e:
                    _update_subtask(st, status="failed", result=str(e))
                finished.add(st["index"])
                if st["status"] == "completed":
                    succeeded.add(st["index"])


def _run_compound_task(task_id: str) -> None:
    """
    Background thread function: execute subtasks respecting dependency order.

    Args:
        task_id: ID of the compound task to execute. Unknown IDs are ignored.

    Execution strategy:
        1. Validate dependency graph
        2. Assign a tier to each subtask (from its "tier_hint", default 2)
        3. Execute subtasks in dependency-respecting waves
        4. A subtask is ready when all its dependencies have completed
           successfully; ready subtasks run in parallel (up to 4 workers)
        5. A subtask whose dependency failed is marked failed without running
        6. Repeat until every subtask is completed or failed
        7. Calculate total cost and finalize task status

    The task always reaches a final status ("completed" or "partial"), even
    if an unexpected error interrupts execution, so pollers never see a task
    stuck in "executing".
    """
    with _tasks_lock:
        task = _tasks.get(task_id)
        if not task:
            return
        task["status"] = "executing"
        subtasks = task["subtasks"]

    try:
        # Phase 0: Validate dependency graph
        _validate_dependencies(subtasks)
        # Phase 1: Routing - assign tier to each subtask
        for st in subtasks:
            _update_subtask(st, status="routing",
                            tier_assigned=st.get("tier_hint", 2))
        # Phase 2: Execution - run subtasks in waves respecting dependencies
        _run_subtask_waves(subtasks)
    except Exception as exc:
        # Thread boundary: record the failure on every unfinished subtask
        now = datetime.now(timezone.utc).isoformat()
        for st in subtasks:
            if st.get("status") not in ("completed", "failed"):
                _update_subtask(st, status="failed", result=f"Error: {exc}",
                                completed_at=now)

    # Phase 3: Finalization
    total_cost = sum(s.get("cost", 0) or 0 for s in subtasks)
    all_ok = all(s.get("status") == "completed" for s in subtasks)
    with _tasks_lock:
        task["status"] = "completed" if all_ok else "partial"
        task["completed_at"] = datetime.now(timezone.utc).isoformat()
        task["total_cost"] = total_cost
def submit_compound_task(prompt: str, auth_token: Optional[str] = None) -> Dict[str, Any]:
    """
    Plan and begin executing a compound task.

    Args:
        prompt: The user's request to decompose and execute
        auth_token: (Optional) authentication token for the submission.
            Accepted for interface compatibility; not used by this function.

    Returns:
        Immediate response with task ID for polling:
        {
            "id": "compound-{uuid}",
            "status": "planned",
            "subtask_count": N
        }
        The status is the one the planner produced; it is captured before
        the background thread starts, so it never reflects execution progress.

    Process:
        1. Plan the task via plan_task() (synchronous)
        2. Store task in memory, evicting the oldest task when at capacity
        3. Spawn background thread for async execution
        4. Return immediately to client for polling

    The client can then poll /task/{task_id}/progress to monitor execution.
    """
    # Phase 1: Plan (synchronous)
    task = plan_task(prompt)
    task_id = task["id"]
    # Build the response before the worker exists: once the thread runs it
    # rewrites task["status"], and a malformed plan should fail here rather
    # than after being stored and spawned.
    response = {
        "id": task_id,
        "status": task["status"],
        "subtask_count": len(task["subtasks"]),
    }
    with _tasks_lock:
        # Evict oldest task if we're at capacity
        if len(_tasks) >= _MAX_TASKS:
            oldest_key = min(_tasks, key=lambda k: _tasks[k].get("created_at", ""))
            del _tasks[oldest_key]
        _tasks[task_id] = task
    # Phase 2: Execute (async - in background thread)
    # The task will progress from "planned" -> "executing" -> "completed"/"partial"
    thread = threading.Thread(target=_run_compound_task, args=(task_id,), daemon=True)
    thread.start()
    return response
def get_task_progress(task_id: str) -> Optional[Dict[str, Any]]:
    """
    Return a snapshot of a compound task for polling clients and dashboards.

    Args:
        task_id: The task ID returned by submit_compound_task()

    Returns:
        An independent copy of the stored task (including every subtask), or
        None when the ID is unknown. The copy is produced by a JSON round
        trip, so values that are not JSON-native come back as strings.

    Snapshot layout:
        {
            "id": task_id,
            "prompt": original prompt,
            "status": "planned"|"executing"|"completed"|"partial",
            "reasoning": explanation from planner,
            "subtasks": [
                {
                    "id": subtask ID,
                    "index": 0,
                    "description": task description,
                    "tier_hint": 1|2|3,
                    "tier_assigned": 1|2|3,
                    "model": "haiku"|"sonnet"|"opus"|None,
                    "depends_on": [indices],
                    "status": "pending"|"routing"|"queued"|"executing"|"completed"|"failed",
                    "result": str or None,
                    "cost": float or None,
                    "started_at": ISO8601 or None,
                    "completed_at": ISO8601 or None
                },
                ...
            ],
            "created_at": ISO8601,
            "planned_at": ISO8601,
            "completed_at": ISO8601 or None,
            "total_cost": float
        }
    """
    with _tasks_lock:
        record = _tasks.get(task_id)
        if record:
            # Serialize while the lock is held so writers cannot interleave
            encoded = json.dumps(record, default=str)
            return json.loads(encoded)
        return None
def list_recent_tasks(limit: int = 20) -> List[Dict[str, Any]]:
    """
    Summarize the most recent compound tasks for the dashboard.

    Args:
        limit: Maximum number of tasks to return

    Returns:
        Task summaries ordered newest first by "created_at":
        [
            {
                "id": task ID,
                "prompt": first 100 characters of the prompt,
                "status": current status,
                "subtask_count": total subtasks,
                "completed_count": subtasks whose status is "completed",
                "total_cost": cumulative USD cost (0 until finalized),
                "created_at": ISO8601,
                "completed_at": ISO8601 or None
            },
            ...
        ]
    """
    def _created(record):
        return record.get("created_at", "")

    def _summarize(record):
        summary = {"id": record["id"]}
        summary["prompt"] = record["prompt"][:100]
        summary["status"] = record["status"]
        summary["subtask_count"] = len(record["subtasks"])
        summary["completed_count"] = len(
            [sub for sub in record["subtasks"] if sub["status"] == "completed"]
        )
        summary["total_cost"] = record.get("total_cost", 0)
        summary["created_at"] = record.get("created_at")
        summary["completed_at"] = record.get("completed_at")
        return summary

    summaries = []
    with _tasks_lock:
        newest_first = list(_tasks.values())
        newest_first.sort(key=_created, reverse=True)
        for record in newest_first[:limit]:
            summaries.append(_summarize(record))
        return summaries