symbiont_ex/lib/symbiont/router.ex

107 lines
3.5 KiB
Elixir

defmodule Symbiont.Router do
  @moduledoc """
  Task classifier — routes incoming tasks to the cheapest capable model tier.

  Uses Haiku to classify task complexity, then returns a routing decision:

    - Tier 1 (Haiku): simple extraction, formatting, classification
    - Tier 2 (Sonnet): content writing, code gen, moderate reasoning
    - Tier 3 (Opus): complex reasoning, strategy, full-context QA

  The classifier prompt is intentionally concise to minimize routing cost.
  Whenever the classifier's answer cannot be interpreted, the router falls
  back to Sonnet as the safe middle ground.
  """

  @type tier :: :haiku | :sonnet | :opus

  @typedoc """
  Routing decision returned by `classify/1`.

  `:routing_cost` is the `:estimated_cost_usd` the dispatcher reported for the
  classifier call itself, or `nil` when it did not report one.
  """
  @type routing :: %{
          tier: tier(),
          confidence: String.t(),
          reason: String.t(),
          routing_cost: number() | nil
        }

  # Approximate per-call cost of each tier, served by `tier_cost/1`.
  # NOTE(review): units are not stated here; presumably USD, matching the
  # dispatcher's `estimated_cost_usd` — confirm.
  @tier_costs %{
    haiku: 0.008,
    sonnet: 0.04,
    opus: 0.15
  }

  # Tier used whenever classification fails or yields an unrecognised value.
  @fallback_tier :sonnet

  @classifier_prompt "You are a task complexity classifier. Analyze the task and respond with ONLY valid JSON: " <>
                       ~s({"tier": 1|2|3, "confidence": "low"|"medium"|"high", "reason": "brief explanation"}\n\n) <>
                       "Tier 1 (Haiku): Simple extraction, formatting, classification, short Q&A\n" <>
                       "Tier 2 (Sonnet): Content writing, code generation, analysis, moderate reasoning\n" <>
                       "Tier 3 (Opus): Complex multi-step reasoning, strategy, architecture, edge cases\n\n" <>
                       "Task: "

  @doc """
  Classify a task and return the recommended tier.

  Sends the task, prefixed with the classifier prompt, to Haiku via
  `Symbiont.Dispatcher.invoke/2` and interprets the JSON reply.

  Returns `{:ok, routing}` on success. If the classifier replies but its answer
  cannot be parsed, still returns `{:ok, routing}` with `tier: :sonnet` and
  `confidence: "low"` (the routing cost is reported either way). Returns
  `{:error, {:classifier_failed, reason}}` when the dispatcher call fails or
  its reply carries no textual `:result`.
  """
  @spec classify(String.t()) :: {:ok, routing()} | {:error, term()}
  def classify(task) when is_binary(task) do
    case Symbiont.Dispatcher.invoke(:haiku, @classifier_prompt <> task) do
      {:ok, %{result: result_text} = response} when is_binary(result_text) ->
        # NOTE(review): assumes the dispatcher reports `:estimated_cost_usd`;
        # `Map.get/2` keeps a missing key from raising `KeyError`.
        routing_cost = Map.get(response, :estimated_cost_usd)

        classification =
          case parse_classification(result_text) do
            {:ok, parsed} ->
              parsed

            {:error, _reason} ->
              # Default to Sonnet (safe middle ground) if the reply is unusable.
              %{tier: @fallback_tier, confidence: "low", reason: "classification parse failed"}
          end

        # The classifier call was paid for regardless of whether parsing worked.
        {:ok, Map.put(classification, :routing_cost, routing_cost)}

      {:ok, unexpected} ->
        {:error, {:classifier_failed, {:unexpected_response, unexpected}}}

      {:error, reason} ->
        {:error, {:classifier_failed, reason}}
    end
  end

  @doc """
  Route and execute a task end-to-end.

  ## Options

    * `:force_tier` — skip classification and use this tier. Accepts an atom
      (`:haiku`, `:sonnet`, `:opus`) or the matching string (case-insensitive,
      surrounding whitespace ignored); unrecognised values fall back to `:sonnet`.

  Without `:force_tier`, the tier comes from `classify/1`; if classification
  returns an error, the task runs on `:sonnet`. Returns whatever
  `Symbiont.Dispatcher.invoke/2` returns for the chosen tier.
  """
  @spec route_and_execute(String.t(), keyword()) :: {:ok, map()} | {:error, term()}
  def route_and_execute(task, opts \\ []) do
    force_tier = Keyword.get(opts, :force_tier)

    tier = if force_tier, do: normalize_tier(force_tier), else: classified_tier(task)

    Symbiont.Dispatcher.invoke(tier, task)
  end

  @doc """
  Return the approximate cost per call for a tier.

  Tiers not in the cost table are priced like the fallback tier (`:sonnet`).
  """
  @spec tier_cost(tier()) :: float()
  def tier_cost(tier), do: Map.get(@tier_costs, tier, Map.fetch!(@tier_costs, @fallback_tier))

  # -- Private --

  # Tier suggested by the classifier, or the fallback tier if classify/1 errors.
  defp classified_tier(task) do
    case classify(task) do
      {:ok, %{tier: tier}} -> tier
      {:error, _reason} -> @fallback_tier
    end
  end

  # Turns the classifier's reply into a routing map (without :routing_cost).
  # The model is asked for bare JSON but may wrap it in prose or a code fence,
  # so each brace-delimited candidate is tried until one decodes to an object
  # carrying a "tier" key.
  defp parse_classification(text) do
    case json_candidates(text) do
      [] -> {:error, :no_json_found}
      candidates -> Enum.find_value(candidates, {:error, :invalid_json}, &decode_classification/1)
    end
  end

  # Brace-delimited substrings of `text`, most permissive first:
  #   * first `{` to last `}` — tolerates `}` inside string values;
  #   * first brace-free object — tolerates trailing prose containing `}`.
  # Regexes are inlined (not module attributes) so they are never stored as
  # compiled terms at compile time.
  defp json_candidates(text) do
    [~r/\{.*\}/s, ~r/\{[^}]+\}/]
    |> Enum.flat_map(&(Regex.run(&1, text) || []))
    |> Enum.uniq()
  end

  # {:ok, routing_map} when `json_str` decodes to an object with a "tier" key;
  # nil otherwise, so Enum.find_value/3 moves on to the next candidate.
  defp decode_classification(json_str) do
    case Jason.decode(json_str) do
      {:ok, %{"tier" => tier_value} = data} ->
        {:ok,
         %{
           tier: tier_from_value(tier_value),
           confidence: string_field(data, "confidence", "medium"),
           reason: string_field(data, "reason", "classified")
         }}

      _ ->
        nil
    end
  end

  # String value stored under `key`, or `default` when it is missing, null,
  # or not a string (keeps the routing map consistent with routing()).
  defp string_field(data, key, default) do
    case Map.get(data, key) do
      value when is_binary(value) -> value
      _ -> default
    end
  end

  # Maps the classifier's "tier" value to a tier atom. Accepts integers,
  # integral floats (2.0), numeric strings ("2") and tier names ("opus");
  # anything else falls back to :sonnet.
  defp tier_from_value(value) when is_integer(value), do: tier_from_number(value)

  defp tier_from_value(value) when is_float(value) and value == trunc(value),
    do: tier_from_number(trunc(value))

  defp tier_from_value(value) when is_binary(value) do
    case Integer.parse(String.trim(value)) do
      {number, ""} -> tier_from_number(number)
      _ -> normalize_tier(value)
    end
  end

  defp tier_from_value(_value), do: @fallback_tier

  defp tier_from_number(1), do: :haiku
  defp tier_from_number(2), do: :sonnet
  defp tier_from_number(3), do: :opus
  defp tier_from_number(_), do: @fallback_tier

  # Normalizes a tier given as an atom or a case-insensitive name string;
  # unrecognised values fall back to :sonnet.
  defp normalize_tier(tier) when tier in [:haiku, :sonnet, :opus], do: tier

  defp normalize_tier(tier) when is_binary(tier) do
    case tier |> String.trim() |> String.downcase() do
      "haiku" -> :haiku
      "sonnet" -> :sonnet
      "opus" -> :opus
      _ -> @fallback_tier
    end
  end

  defp normalize_tier(_tier), do: @fallback_tier
end