386 lines
13 KiB
Python
386 lines
13 KiB
Python
"""
|
|
Compound Task Manager
|
|
=====================
|
|
|
|
Manages the complete lifecycle of compound tasks:
|
|
- Tracks task state through planning and execution phases
|
|
- Executes subtasks in parallel while respecting dependencies
|
|
- Updates progress in real-time (via polling)
|
|
- Logs all executions to the immutable ledger
|
|
|
|
Architecture:
|
|
1. submit_compound_task() - plans task (sync, fast) then spawns background execution
|
|
2. Background thread executes subtasks in dependency-respecting waves
|
|
3. get_task_progress() - clients poll for current state
|
|
4. list_recent_tasks() - dashboard view of recent executions
|
|
"""
|
|
|
|
import json
|
|
import threading
|
|
import subprocess
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from typing import Optional, Dict, List, Any
|
|
from pathlib import Path
|
|
import time
|
|
|
|
from .planner import plan_task
|
|
|
|
|
|
# In-memory store for active/recent tasks
|
|
# Limited to _MAX_TASKS to prevent unbounded memory growth
|
|
_tasks = {}
|
|
_tasks_lock = threading.Lock()
|
|
_MAX_TASKS = 50
|
|
|
|
# Model-to-tier mapping
|
|
TIER_MODELS = {1: "haiku", 2: "sonnet", 3: "opus"}
|
|
|
|
# Approximate costs per execution (USD)
|
|
# These are rough estimates - actual costs depend on token usage
|
|
TIER_COSTS = {1: 0.008, 2: 0.04, 3: 0.15}
|
|
|
|
# Ledger file path (immutable execution log)
|
|
LEDGER_PATH = Path("/data/symbiont/ledger.jsonl")
|
|
|
|
|
|
def _log_ledger(entry: Dict[str, Any]) -> None:
|
|
"""
|
|
Append an entry to the immutable execution ledger.
|
|
|
|
Args:
|
|
entry: Dictionary with execution details (timestamp, model, tokens, cost, etc.)
|
|
|
|
The ledger is a jsonl file (one JSON object per line) used for:
|
|
- Cost tracking and billing
|
|
- Audit trails
|
|
- Performance analysis
|
|
"""
|
|
try:
|
|
with open(LEDGER_PATH, "a") as f:
|
|
f.write(json.dumps(entry) + "\n")
|
|
except Exception:
|
|
# Silently fail ledger writes to prevent task execution failures
|
|
pass
|
|
|
|
|
|
def _execute_subtask(subtask: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute a single subtask via the Claude CLI and record the outcome.

    Args:
        subtask: Subtask dict. Reads "description" (sent to the CLI on stdin),
            "tier_assigned" / "tier_hint" (model and cost selection) and "id"
            (ledger tag).

    Returns:
        The same dict, updated in place via _update_subtask with "model",
        "status" ("completed" or "failed"), "result", "started_at",
        "completed_at" and, on success, "cost".

    Process:
        1. Resolve the tier once; derive both the model and the cost estimate
           from it so the ledger never bills a different tier than was run.
        2. Invoke the CLI (120 s timeout). A non-zero exit code is reported as
           a failure with the tail of the CLI's stderr/stdout.
        3. Parse the JSON output and truncate the result to 2000 characters.
        4. Mark the subtask completed/failed and append a ledger entry.
    """
    tier = subtask.get("tier_assigned") or subtask.get("tier_hint", 2)
    model = TIER_MODELS.get(tier, "sonnet")
    # NOTE(review): this is the *subtask* id; the parent compound-task id is
    # not available here even though the ledger key says "compound_task_id".
    ledger_id = subtask.get("id", "unknown")
    _update_subtask(subtask, model=model, status="executing",
                    started_at=datetime.now(timezone.utc).isoformat())

    def fail(message: str, error: str) -> Dict[str, Any]:
        """Mark the subtask failed, log the failure, and return the subtask."""
        now = datetime.now(timezone.utc).isoformat()
        _update_subtask(subtask, status="failed", result=message, completed_at=now)
        _log_ledger({
            "timestamp": now, "model": model, "success": False,
            "error": error, "compound_task_id": ledger_id,
        })
        return subtask

    try:
        proc = subprocess.run(
            ["claude", "-p", "--model", model, "--output-format", "json"],
            input=subtask["description"],
            capture_output=True,
            text=True,
            timeout=120,
        )
        if proc.returncode != 0:
            # Surface the CLI's own diagnostics instead of a JSON decode error.
            detail = (proc.stderr or proc.stdout or "").strip()[-500:]
            raise RuntimeError(
                f"claude CLI exited with code {proc.returncode}: {detail}"
            )

        output = json.loads(proc.stdout)
        response_text = output.get("result", output.get("content", str(output)))

        # Token counts: accept top-level keys or a nested "usage" object.
        # NOTE(review): which layout the CLI emits may depend on its version.
        usage = output.get("usage")
        if not isinstance(usage, dict):
            usage = {}
        tokens_in = output.get("input_tokens", usage.get("input_tokens", 0))
        tokens_out = output.get("output_tokens", usage.get("output_tokens", 0))

        truncated = response_text[:2000]
        if len(response_text) > 2000:
            truncated += "\n[TRUNCATED...]"
    except subprocess.TimeoutExpired:
        return fail("Execution timed out (120s)", "timeout")
    except Exception as e:
        return fail(f"Error: {str(e)}", str(e))

    cost = TIER_COSTS.get(tier, 0.04)
    completed_at = datetime.now(timezone.utc).isoformat()
    _update_subtask(subtask, status="completed", result=truncated,
                    cost=cost, completed_at=completed_at)

    _log_ledger({
        "timestamp": completed_at,
        "model": model,
        "success": True,
        "input_tokens": tokens_in,
        "output_tokens": tokens_out,
        "estimated_cost_usd": cost,
        "prompt_preview": subtask["description"][:100],
        "compound_task_id": ledger_id,
    })
    return subtask
|
|
|
|
|
|
def _validate_dependencies(subtasks: List[Dict[str, Any]]) -> None:
|
|
"""Validate and clamp dependency indices to valid range."""
|
|
valid_indices = set(s["index"] for s in subtasks)
|
|
for st in subtasks:
|
|
deps = st.get("depends_on", [])
|
|
# Remove self-references and out-of-range indices
|
|
st["depends_on"] = [d for d in deps if d in valid_indices and d != st["index"]]
|
|
|
|
|
|
def _update_subtask(subtask: Dict[str, Any], **updates) -> None:
    """
    Write ``updates`` into ``subtask`` while holding ``_tasks_lock``.

    Taking the same lock that get_task_progress() holds while snapshotting
    means a poller never observes a half-applied batch of field changes.
    """
    with _tasks_lock:
        for field_name, new_value in updates.items():
            subtask[field_name] = new_value
|
|
|
|
|
|
def _run_waves(subtasks: List[Dict[str, Any]], max_workers: int = 4) -> None:
    """Run subtasks wave by wave until every one is completed or failed."""
    succeeded = set()
    failed = set()
    pending = list(subtasks)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        while pending:
            ready, waiting, blocked = [], [], []
            for st in pending:
                deps = set(st.get("depends_on") or [])
                if deps & failed:
                    blocked.append(st)
                elif deps <= succeeded:
                    ready.append(st)
                else:
                    waiting.append(st)

            # A failed dependency means this subtask can never run correctly.
            for st in blocked:
                _update_subtask(st, status="failed", result="Blocked by failed dependency")
                failed.add(st["index"])

            if not ready:
                if blocked:
                    # New failures may block further dependents: re-scan.
                    pending = waiting
                    continue
                # Nothing runnable and no failed dependency: the remaining
                # subtasks wait on each other (a dependency cycle).
                for st in waiting:
                    _update_subtask(st, status="failed",
                                    result="Blocked by unresolvable (cyclic) dependency")
                break

            futures = {}
            for st in ready:
                _update_subtask(st, status="queued")
                futures[executor.submit(_execute_subtask, st)] = st

            # The whole wave is awaited before the next one is scheduled.
            for future in as_completed(futures):
                st = futures[future]
                try:
                    future.result()
                except Exception as e:
                    _update_subtask(st, status="failed", result=str(e))

                status = st.get("status")
                if status == "completed":
                    succeeded.add(st["index"])
                else:
                    if status != "failed":
                        _update_subtask(st, status="failed",
                                        result="Subtask ended without a terminal status")
                    failed.add(st["index"])

            pending = waiting


def _run_compound_task(task_id: str) -> None:
    """
    Background-thread entry point: execute a compound task's subtasks in
    dependency order, then finalize the task record.

    Args:
        task_id: ID of a task stored in ``_tasks`` by submit_compound_task().
            Unknown IDs (e.g. already evicted) are ignored.

    Execution strategy:
        1. Sanitize the dependency graph (under the lock, so pollers never
           snapshot a half-rewritten subtask).
        2. Assign each subtask its tier ("routing").
        3. Run subtasks in waves: every subtask whose dependencies all
           *completed successfully* runs in parallel (up to 4 workers).
        4. Subtasks with a failed dependency are marked failed without being
           executed. Subtasks that can never become runnable (a cycle) are
           failed as well.
        5. Sum subtask costs and set the task to "completed" (every subtask
           completed) or "partial".

    An unexpected scheduler error is recorded in task["error"], and the task
    is still finalized, so it never stays stuck in "executing".
    """
    with _tasks_lock:
        task = _tasks.get(task_id)
        if not task:
            return
        task["status"] = "executing"

    subtasks = task["subtasks"]

    try:
        # Phase 0: validate dependency graph.
        with _tasks_lock:
            _validate_dependencies(subtasks)

        # Phase 1: routing - assign a tier to each subtask.
        for st in subtasks:
            _update_subtask(st, status="routing", tier_assigned=st.get("tier_hint", 2))

        # Phase 2: execution in dependency-respecting waves.
        _run_waves(subtasks)
    except Exception as e:
        # Never leave subtasks in a non-terminal state after an abort.
        for st in subtasks:
            if st.get("status") not in ("completed", "failed"):
                _update_subtask(st, status="failed", result=f"Aborted: {e}")
        with _tasks_lock:
            task["error"] = str(e)

    # Phase 3: finalization.
    total_cost = sum(s.get("cost", 0) or 0 for s in subtasks)
    all_ok = all(s.get("status") == "completed" for s in subtasks)

    with _tasks_lock:
        task["status"] = "completed" if all_ok else "partial"
        task["completed_at"] = datetime.now(timezone.utc).isoformat()
        task["total_cost"] = total_cost
|
|
|
|
|
|
def submit_compound_task(prompt: str, auth_token: Optional[str] = None) -> Dict[str, Any]:
    """
    Plan and begin executing a compound task.

    Args:
        prompt: The user's request to decompose and execute.
        auth_token: Optional authentication token for the submission.
            NOTE(review): accepted but not used by this function.

    Returns:
        Immediate response for polling, captured before the worker thread
        starts so "status" reflects the planned state:
            {
                "id": "compound-{uuid}",
                "status": "planned",
                "subtask_count": N
            }

    Process:
        1. Plan the task synchronously with plan_task().
        2. Store it in the in-memory registry. When the registry is full, the
           oldest finished ("completed"/"partial") task is evicted first. The
           oldest task overall is evicted only if none has finished.
        3. Spawn a daemon thread running _run_compound_task().
        4. Return immediately; the client polls /task/{task_id}/progress.
    """
    # Phase 1: Plan (synchronous).
    task = plan_task(prompt)
    task_id = task["id"]

    with _tasks_lock:
        if task_id not in _tasks:
            while _tasks and len(_tasks) >= _MAX_TASKS:
                finished = [k for k, t in _tasks.items()
                            if t.get("status") in ("completed", "partial")]
                candidates = finished or list(_tasks)
                oldest_key = min(candidates,
                                 key=lambda k: _tasks[k].get("created_at") or "")
                del _tasks[oldest_key]
        _tasks[task_id] = task
        # Snapshot the reply before the worker thread can mutate the task.
        response = {
            "id": task_id,
            "status": task["status"],
            "subtask_count": len(task["subtasks"]),
        }

    # Phase 2: Execute (async, in a background thread).
    # The task progresses "planned" -> "executing" -> "completed"/"partial".
    thread = threading.Thread(target=_run_compound_task, args=(task_id,), daemon=True)
    thread.start()

    return response
|
|
|
|
|
|
def get_task_progress(task_id: str) -> Optional[Dict[str, Any]]:
    """
    Return a detached snapshot of one compound task for polling/dashboards.

    Args:
        task_id: Identifier returned by submit_compound_task().

    Returns:
        None when the task is unknown (or evicted). Otherwise, a deep copy of
        the stored task, shaped like:
            {
                "id", "prompt", "reasoning",
                "status": "planned" | "executing" | "completed" | "partial",
                "subtasks": [
                    {
                        "id", "index", "description",
                        "tier_hint", "tier_assigned": 1 | 2 | 3,
                        "model": "haiku" | "sonnet" | "opus" | None,
                        "depends_on": [indices],
                        "status": "pending" | "routing" | "queued"
                                  | "executing" | "completed" | "failed",
                        "result": str | None,
                        "cost": float | None,
                        "started_at", "completed_at": ISO8601 | None,
                    },
                    ...
                ],
                "created_at", "planned_at": ISO8601,
                "completed_at": ISO8601 | None,
                "total_cost": float,
            }
    """
    with _tasks_lock:
        found = _tasks.get(task_id)
        if found:
            # Serialize under the lock so workers can't mutate mid-copy; the
            # JSON round-trip doubles as a deep copy (non-JSON values -> str).
            serialized = json.dumps(found, default=str)
        else:
            serialized = None

    if serialized is None:
        return None
    return json.loads(serialized)
|
|
|
|
|
|
def list_recent_tasks(limit: int = 20) -> List[Dict[str, Any]]:
    """
    List recent compound tasks for the dashboard, newest first.

    Args:
        limit: Maximum number of tasks to return. Values <= 0 return an empty
            list (a negative limit previously sliced off the oldest tasks).

    Returns:
        Task summaries ordered by "created_at" descending:
            [
                {
                    "id": task ID,
                    "prompt": first 100 characters of the prompt,
                    "status": current task status,
                    "subtask_count": total subtasks,
                    "completed_count": subtasks whose status is "completed"
                                       (failed subtasks are not counted),
                    "total_cost": cumulative USD cost (0 until finalized),
                    "created_at": ISO8601,
                    "completed_at": ISO8601 or None
                },
                ...
            ]
    """
    if limit <= 0:
        return []

    with _tasks_lock:
        # `or ""` keeps sorting safe if a task has created_at=None.
        newest_first = sorted(
            _tasks.values(),
            key=lambda t: t.get("created_at") or "",
            reverse=True,
        )
        return [
            {
                "id": t["id"],
                "prompt": (t.get("prompt") or "")[:100],
                "status": t["status"],
                "subtask_count": len(t["subtasks"]),
                "completed_count": sum(
                    1 for s in t["subtasks"] if s.get("status") == "completed"
                ),
                "total_cost": t.get("total_cost", 0),
                "created_at": t.get("created_at"),
                "completed_at": t.get("completed_at"),
            }
            for t in newest_first[:limit]
        ]
|