diff --git a/examples/localcowork/.env.example b/examples/localcowork/.env.example index 16e64bd..b18bd1e 100644 --- a/examples/localcowork/.env.example +++ b/examples/localcowork/.env.example @@ -11,6 +11,7 @@ # Text model API endpoint (OpenAI-compatible). Set by start-model.sh. # Default when using LFM2 via llama-server: http://localhost:8080/v1 # Default when using Ollama: http://localhost:11434/v1 +# Default when using LM Studio: http://localhost:1234/v1 # LOCALCOWORK_MODEL_ENDPOINT=http://localhost:8080/v1 # Vision model endpoint (for AI-powered OCR — optional, falls back to Tesseract) diff --git a/examples/localcowork/.git-hooks/pre-commit b/examples/localcowork/.git-hooks/pre-commit index d97f222..02cbe62 100755 --- a/examples/localcowork/.git-hooks/pre-commit +++ b/examples/localcowork/.git-hooks/pre-commit @@ -28,6 +28,20 @@ fi WARNINGS=() +# ── Check 0: Shell scripts pass shellcheck ─────────────────────────────────── + +for file in $STAGED_FILES; do + case "$file" in + *.sh) + if command -v shellcheck >/dev/null 2>&1; then + if ! 
shellcheck -s bash "$file" >/dev/null 2>&1; then + WARNINGS+=("shellcheck failed for $file (run: shellcheck -s bash $file)") + fi + fi + ;; + esac +done + # ── Check 1: Source files changed but PROGRESS.yaml not staged ────────────── HAS_SOURCE_CHANGES=false diff --git a/examples/localcowork/.github/workflows/shellcheck.yml b/examples/localcowork/.github/workflows/shellcheck.yml new file mode 100644 index 0000000..1738149 --- /dev/null +++ b/examples/localcowork/.github/workflows/shellcheck.yml @@ -0,0 +1,20 @@ +name: Shellcheck + +on: + push: + paths: + - "**.sh" + pull_request: + paths: + - "**.sh" + +jobs: + shellcheck: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Run shellcheck + uses: ludeeus/action-shellcheck@master + env: + SHELLCHECK_OPTS: "-s bash -S error" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml new file mode 100644 index 0000000..34a3fa8 --- /dev/null +++ b/src-tauri/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "localcowork" +version = "0.1.0" +description = "LocalCowork — on-device AI agent desktop app" +authors = ["LocalCowork Contributors"] +license = "MIT" +edition = "2021" +rust-version = "1.77" + +[build-dependencies] +tauri-build = { version = "2", features = [] } + +[dependencies] +tauri = { version = "2", features = ["devtools"] } +tauri-plugin-shell = "2" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +serde_yaml = "0.9" +tokio = { version = "1", features = ["full"] } +anyhow = "1" +thiserror = "2" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } +uuid = { version = "1", features = ["v4"] } +chrono = { version = "0.4", features = ["serde"] } +reqwest = { version = "0.12", features = ["json", "stream"] } +rusqlite = { version = "0.32", features = ["bundled"] } +futures = "0.3" +dirs = "6" +sysinfo = "0.33" +sha2 = "0.10" +tauri-plugin-dialog = "2" + +[dev-dependencies] +tempfile = "3" + +[features] +default = ["custom-protocol"] 
+custom-protocol = ["tauri/custom-protocol"] + +[profile.release] +strip = true +lto = true +codegen-units = 1 +panic = "abort" diff --git a/src-tauri/build.rs b/src-tauri/build.rs new file mode 100644 index 0000000..261851f --- /dev/null +++ b/src-tauri/build.rs @@ -0,0 +1,3 @@ +fn main() { + tauri_build::build(); +} diff --git a/src-tauri/capabilities/default.json b/src-tauri/capabilities/default.json new file mode 100644 index 0000000..9102361 --- /dev/null +++ b/src-tauri/capabilities/default.json @@ -0,0 +1,15 @@ +{ + "identifier": "default", + "description": "Default capabilities for LocalCowork", + "windows": ["main"], + "permissions": [ + "core:default", + "shell:allow-open", + "shell:allow-execute", + "shell:allow-spawn", + "shell:allow-stdin-write", + "dialog:default", + "dialog:allow-open", + "dialog:allow-save" + ] +} diff --git a/src-tauri/entitlements.plist b/src-tauri/entitlements.plist new file mode 100644 index 0000000..8ce248b --- /dev/null +++ b/src-tauri/entitlements.plist @@ -0,0 +1,19 @@ + + + + + + com.apple.security.cs.allow-unsigned-executable-memory + + + com.apple.security.cs.disable-library-validation + + + com.apple.security.network.client + + + com.apple.security.files.user-selected.read-write + + + diff --git a/src-tauri/icons/128x128.png b/src-tauri/icons/128x128.png new file mode 100644 index 0000000..6a66724 Binary files /dev/null and b/src-tauri/icons/128x128.png differ diff --git a/src-tauri/icons/128x128@2x.png b/src-tauri/icons/128x128@2x.png new file mode 100644 index 0000000..323b36e Binary files /dev/null and b/src-tauri/icons/128x128@2x.png differ diff --git a/src-tauri/icons/32x32.png b/src-tauri/icons/32x32.png new file mode 100644 index 0000000..5568400 Binary files /dev/null and b/src-tauri/icons/32x32.png differ diff --git a/src-tauri/icons/icon.icns b/src-tauri/icons/icon.icns new file mode 100644 index 0000000..caa531a Binary files /dev/null and b/src-tauri/icons/icon.icns differ diff --git a/src-tauri/icons/icon.ico 
b/src-tauri/icons/icon.ico new file mode 100644 index 0000000..4107408 Binary files /dev/null and b/src-tauri/icons/icon.ico differ diff --git a/src-tauri/mcp-servers.json b/src-tauri/mcp-servers.json new file mode 100644 index 0000000..6df9c17 --- /dev/null +++ b/src-tauri/mcp-servers.json @@ -0,0 +1,3 @@ +{ + "servers": {} +} diff --git a/src-tauri/src/agent_core/conversation.rs b/src-tauri/src/agent_core/conversation.rs new file mode 100644 index 0000000..e21f2a0 --- /dev/null +++ b/src-tauri/src/agent_core/conversation.rs @@ -0,0 +1,841 @@ +//! ConversationManager — persistent conversation history with context window management. +//! +//! Responsibilities: +//! - Store and retrieve conversation messages (SQLite) +//! - Track token usage per message +//! - Enforce context window budget (32k default) via eviction +//! - Maintain session summaries for evicted turns +//! - Build `Vec` for the inference client + +use crate::inference::types::{ + ChatMessage, FunctionCallResponse, Role, ToolCallResponse, +}; + +use super::database::AgentDatabase; +use super::errors::AgentError; +use super::tokens; +use super::types::{ + ConversationMessage, ContextBudget, NewMessage, NewUndoEntry, SessionSummary, UndoEntry, +}; + +// ─── Constants ────────────────────────────────────────────────────────────── + +/// Default total context window size (tokens). +const DEFAULT_CONTEXT_WINDOW: u32 = 32_768; + +/// Tokens reserved for the system prompt (rules + few-shot examples). +const SYSTEM_PROMPT_BUDGET: u32 = 900; + +/// Default tokens reserved for tool definitions. +/// +/// Used when the actual tool definition tokens haven't been measured yet. +/// This is a conservative fallback — the real value should be computed from +/// the serialized tool definitions and set via `set_tool_definitions_budget()`. +const DEFAULT_TOOL_DEFINITIONS_BUDGET: u32 = 2_000; + +/// Tokens reserved for the model's output response. 
+/// +/// Every production agent reserves space for the model to generate its +/// response. Without this, the context window could be 100% filled with +/// input, leaving no room for output. +/// +/// Note: The PRD's "Active file/document content" budget (~9,500 tokens) +/// was a static reservation for a ProactiveContextor feature that hasn't +/// been built yet. When that feature is implemented, it will dynamically +/// claim tokens from the conversation budget — not from a phantom static +/// reservation that wastes 29% of the context window. +const OUTPUT_RESERVATION: u32 = 2_000; + +/// Safety buffer — never fill these tokens. +const SAFETY_BUFFER: u32 = 768; + +/// When remaining tokens drop below this, trigger eviction. +/// +/// Set to 5,000 so eviction fires well before the agent loop's +/// `MIN_ROUND_TOKEN_BUDGET` (1,500) gate kills the loop. With the +/// old value of 1,000, eviction never triggered because the budget +/// gate always fired first, making eviction effectively dead. +const EVICTION_THRESHOLD: u32 = 5_000; + +/// Number of most recent turns to keep in full detail during eviction. +const FULL_DETAIL_TURNS: usize = 10; + +/// Maximum tokens allowed for the session summary. +/// +/// Without a cap, each eviction cycle appends to the summary, which can +/// grow to 2,000+ tokens after 3 cycles — eating into the space eviction +/// was supposed to free. The cap keeps the most recent portion. +const MAX_SUMMARY_TOKENS: u32 = 500; + +// ─── ConversationManager ──────────────────────────────────────────────────── + +/// Manages conversation history, token budgets, and context window eviction. +pub struct ConversationManager { + /// SQLite database handle. + db: AgentDatabase, + /// Total context window size (configurable per model). + context_window: u32, + /// Actual tokens consumed by tool definitions (measured, not estimated). + /// + /// Set by `set_tool_definitions_budget()` after tool definitions are built. 
+    /// Falls back to `DEFAULT_TOOL_DEFINITIONS_BUDGET` if not set.
+    tool_definitions_budget: u32,
+    /// Actual tokens consumed by the system prompt (measured, not estimated).
+    ///
+    /// Set by `set_system_prompt_budget()` after the dynamic system prompt is built.
+    /// Falls back to `SYSTEM_PROMPT_BUDGET` if not set.
+    system_prompt_budget: u32,
+}
+
+impl ConversationManager {
+    /// Create a new ConversationManager backed by the given database.
+    pub fn new(db: AgentDatabase) -> Self {
+        Self {
+            db,
+            context_window: DEFAULT_CONTEXT_WINDOW,
+            tool_definitions_budget: DEFAULT_TOOL_DEFINITIONS_BUDGET,
+            system_prompt_budget: SYSTEM_PROMPT_BUDGET,
+        }
+    }
+
+    /// Override the context window size (e.g., from model config).
+    pub fn set_context_window(&mut self, size: u32) {
+        self.context_window = size;
+    }
+
+    /// Set the actual tool definitions token budget based on measured serialization.
+    ///
+    /// This should be called after tool definitions are built (in `send_message`)
+    /// so the budget calculation uses the real cost instead of the default estimate.
+    pub fn set_tool_definitions_budget(&mut self, tokens: u32) {
+        self.tool_definitions_budget = tokens;
+    }
+
+    /// Set the actual system prompt token budget based on the dynamic prompt.
+    ///
+    /// Called in `start_session` after building the prompt from the MCP registry.
+    /// Ensures the context budget display reflects the real prompt size.
+    pub fn set_system_prompt_budget(&mut self, tokens: u32) {
+        self.system_prompt_budget = tokens;
+    }
+
+    /// Access the underlying database (for ToolRouter/audit operations).
+    pub fn db(&self) -> &AgentDatabase {
+        &self.db
+    }
+
+    // ─── Session Management ─────────────────────────────────────────────
+
+    /// Start a new conversation session.
+    ///
+    /// Creates the session record and inserts the system prompt as the first
+    /// message. Returns `Ok(())` on success; the session ID is supplied by the caller.
+ pub fn new_session( + &self, + session_id: &str, + system_prompt: &str, + ) -> Result<(), AgentError> { + self.db.create_session(session_id)?; + + let token_count = tokens::estimate_system_prompt_tokens(system_prompt); + let msg = NewMessage { + role: Role::System, + content: Some(system_prompt.to_string()), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + self.db.insert_message(session_id, &msg, token_count)?; + Ok(()) + } + + // ─── Message Operations ───────────────────────────────────────────── + + /// Add a user message to the conversation. + pub fn add_user_message( + &self, + session_id: &str, + content: &str, + ) -> Result { + let token_count = tokens::estimate_tokens(content) + 4; // overhead + let msg = NewMessage { + role: Role::User, + content: Some(content.to_string()), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + self.db.insert_message(session_id, &msg, token_count) + } + + /// Add an assistant text message to the conversation. + pub fn add_assistant_message( + &self, + session_id: &str, + content: &str, + ) -> Result { + let token_count = tokens::estimate_tokens(content) + 4; + let msg = NewMessage { + role: Role::Assistant, + content: Some(content.to_string()), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + self.db.insert_message(session_id, &msg, token_count) + } + + /// Add an assistant message that contains tool calls. 
+ pub fn add_tool_call_message( + &self, + session_id: &str, + tool_calls: &[crate::inference::types::ToolCall], + ) -> Result { + // Estimate tokens for tool calls + let mut token_count: u32 = 4; // overhead + for tc in tool_calls { + token_count += 10; // per-call overhead + token_count += tokens::estimate_tokens(&tc.name); + token_count += tokens::estimate_tokens( + &serde_json::to_string(&tc.arguments).unwrap_or_default(), + ); + } + + let msg = NewMessage { + role: Role::Assistant, + content: None, + tool_calls: Some(tool_calls.to_vec()), + tool_call_id: None, + tool_result: None, + }; + self.db.insert_message(session_id, &msg, token_count) + } + + /// Add a tool result message to the conversation. + pub fn add_tool_result_message( + &self, + session_id: &str, + tool_call_id: &str, + result: &serde_json::Value, + ) -> Result { + // Use the plain string if the value is a String, otherwise JSON-encode it. + // This avoids double-serialization (wrapping "text" as "\"text\"") which + // confuses local LLMs into thinking the tool result is empty/malformed. + let result_str = match result.as_str() { + Some(s) => s.to_string(), + None => serde_json::to_string(result).unwrap_or_default(), + }; + let token_count = tokens::estimate_tokens(&result_str) + 4; + + let msg = NewMessage { + role: Role::Tool, + content: Some(result_str), + tool_calls: None, + tool_call_id: Some(tool_call_id.to_string()), + tool_result: Some(result.clone()), + }; + self.db.insert_message(session_id, &msg, token_count) + } + + /// Get the full conversation history for a session. + pub fn get_history( + &self, + session_id: &str, + ) -> Result, AgentError> { + self.db.get_messages(session_id) + } + + /// Get the N most recent messages. + pub fn get_recent( + &self, + session_id: &str, + n: usize, + ) -> Result, AgentError> { + self.db.get_recent_messages(session_id, n) + } + + // ─── Context Window Management ────────────────────────────────────── + + /// Get the current context budget snapshot. 
+    pub fn get_budget(&self, session_id: &str) -> Result<ContextBudget, AgentError> {
+        let conversation_tokens = self.db.total_message_tokens(session_id)?;
+        let total = self.context_window;
+        // Fixed overhead: system prompt + tool definitions + output reservation + safety.
+        let overhead = self.system_prompt_budget
+            + self.tool_definitions_budget
+            + OUTPUT_RESERVATION
+            + SAFETY_BUFFER;
+        // saturating_sub so an over-committed window reports 0 rather than wrapping.
+        let remaining = total
+            .saturating_sub(overhead)
+            .saturating_sub(conversation_tokens);
+
+        Ok(ContextBudget {
+            total,
+            system_prompt: self.system_prompt_budget,
+            tool_definitions: self.tool_definitions_budget,
+            conversation_history: conversation_tokens,
+            output_reservation: OUTPUT_RESERVATION,
+            remaining,
+        })
+    }
+
+    /// Check if eviction is needed and perform it.
+    ///
+    /// Evicts the oldest non-system messages until remaining tokens are
+    /// above the threshold. Evicted messages are summarized into the
+    /// session summary. Returns the number of tokens freed (0 if no
+    /// eviction was needed or possible).
+    pub fn evict_if_needed(&self, session_id: &str) -> Result<u32, AgentError> {
+        let budget = self.get_budget(session_id)?;
+
+        if budget.remaining >= EVICTION_THRESHOLD {
+            return Ok(0); // No eviction needed
+        }
+
+        let message_count = self.db.message_count(session_id)?;
+        if message_count <= FULL_DETAIL_TURNS + 1 {
+            // +1 for system prompt
+            return Ok(0); // Not enough messages to evict
+        }
+
+        // Evict messages beyond the full-detail window
+        let evict_count = message_count - FULL_DETAIL_TURNS - 1;
+        let evicted = self.db.delete_oldest_messages(session_id, evict_count)?;
+
+        // Build a summary from evicted messages
+        let mut summary_parts = Vec::new();
+        let mut files: Vec<String> = Vec::new();
+
+        for msg in &evicted {
+            let line = tokens::summarize_turn(&msg.role, msg.content.as_deref());
+            summary_parts.push(line);
+
+            // Track file paths mentioned in tool calls
+            if let Some(ref tc) = msg.tool_calls {
+                for call in tc {
+                    if let Some(path) = call.arguments.get("path").and_then(|v| v.as_str()) {
+                        if !files.contains(&path.to_string()) {
+                            files.push(path.to_string());
+                        }
+                    }
+                }
+            }
+        }
+
+        let summary_text = summary_parts.join("\n");
+        let evicted_tokens: u32 = evicted.iter().map(|m| m.token_count).sum();
+
+        // Update session summary (append to existing, then cap)
+        let existing = self.db.get_session_summary(session_id)?;
+        let full_summary = match existing {
+            Some(s) => format!("{}\n{}", s.summary_text, summary_text),
+            None => summary_text,
+        };
+
+        // Cap summary to prevent it from consuming the space eviction freed
+        let summary_tokens = tokens::estimate_tokens(&full_summary);
+        let capped_summary = if summary_tokens > MAX_SUMMARY_TOKENS {
+            let target_chars = (MAX_SUMMARY_TOKENS as f64 * 3.2) as usize;
+            let mut start = full_summary.len().saturating_sub(target_chars);
+            // BUG FIX: `len() - target_chars` is a byte offset and may fall
+            // inside a multi-byte UTF-8 character (summaries carry arbitrary
+            // user/tool text). Slicing a &str at a non-char-boundary panics,
+            // so advance to the next boundary first.
+            while start < full_summary.len() && !full_summary.is_char_boundary(start) {
+                start += 1;
+            }
+            format!("[earlier context omitted]\n{}", &full_summary[start..])
+        } else {
+            full_summary
+        };
+
+        self.db.update_session_summary(
+            session_id,
+            &capped_summary,
+            &files,
+            &[], // decisions are tracked separately
+        )?;
+
+        Ok(evicted_tokens)
+    }
+
+    /// Build the `Vec<ChatMessage>` to send to the inference client.
+    ///
+    /// Includes: session summary (if any) + system prompt + recent messages.
+ pub fn build_chat_messages( + &self, + session_id: &str, + ) -> Result, AgentError> { + let messages = self.db.get_messages(session_id)?; + let summary = self.db.get_session_summary(session_id)?; + + let mut chat_messages = Vec::new(); + + for msg in &messages { + match msg.role { + Role::System => { + // Prepend session summary to system prompt + let mut content = msg.content.clone().unwrap_or_default(); + if let Some(ref s) = summary { + content = format!( + "{content}\n\n## Previous conversation summary:\n{}", + s.summary_text + ); + } + chat_messages.push(ChatMessage { + role: Role::System, + content: Some(content), + tool_call_id: None, + tool_calls: None, + }); + } + Role::User => { + chat_messages.push(ChatMessage { + role: Role::User, + content: msg.content.clone(), + tool_call_id: None, + tool_calls: None, + }); + } + Role::Assistant => { + let tool_calls = msg.tool_calls.as_ref().map(|calls| { + calls + .iter() + .map(|tc| ToolCallResponse { + id: tc.id.clone(), + r#type: "function".to_string(), + function: FunctionCallResponse { + name: tc.name.clone(), + arguments: serde_json::to_string(&tc.arguments) + .unwrap_or_default(), + }, + }) + .collect() + }); + chat_messages.push(ChatMessage { + role: Role::Assistant, + content: msg.content.clone(), + tool_call_id: None, + tool_calls, + }); + } + Role::Tool => { + chat_messages.push(ChatMessage { + role: Role::Tool, + content: msg.content.clone(), + tool_call_id: msg.tool_call_id.clone(), + tool_calls: None, + }); + } + } + } + + Ok(chat_messages) + } + + /// Build a windowed `Vec` optimized for multi-step workflows. 
+ /// + /// Implements a 3-tier message strategy to minimize token waste: + /// - **Tier 1 (recent)**: Last `recent_window` messages sent verbatim + /// - **Tier 2 (middle)**: Tool results compressed to one-line summaries; + /// user/assistant messages kept verbatim + /// - **Tier 3 (evicted)**: Already handled by session summary + /// + /// This prevents stale tool results from consuming context. A 6,000-char + /// OCR result from round 2 is compressed to ~50 chars in rounds 4+. + pub fn build_windowed_chat_messages( + &self, + session_id: &str, + recent_window: usize, + ) -> Result, AgentError> { + let messages = self.db.get_messages(session_id)?; + let summary = self.db.get_session_summary(session_id)?; + + let total = messages.len(); + // Window start index: everything before this is Tier 2 (compressed) + // +1 to account for system prompt at index 0 + let window_start = if total > recent_window + 1 { + total - recent_window + } else { + 1 // include everything after system prompt + }; + + let mut chat_messages = Vec::new(); + + for (i, msg) in messages.iter().enumerate() { + match msg.role { + Role::System => { + // Prepend session summary to system prompt (same as build_chat_messages) + let mut content = msg.content.clone().unwrap_or_default(); + if let Some(ref s) = summary { + content = format!( + "{content}\n\n## Previous conversation summary:\n{}", + s.summary_text + ); + } + chat_messages.push(ChatMessage { + role: Role::System, + content: Some(content), + tool_call_id: None, + tool_calls: None, + }); + } + Role::Tool if i < window_start => { + // Tier 2: compress old tool results to one-line summary + let compressed = tokens::summarize_tool_result( + msg.tool_call_id.as_deref().unwrap_or("tool"), + &msg.tool_result.clone().unwrap_or(serde_json::Value::Null), + ); + chat_messages.push(ChatMessage { + role: Role::Tool, + content: Some(compressed), + tool_call_id: msg.tool_call_id.clone(), + tool_calls: None, + }); + } + Role::User => { + 
chat_messages.push(ChatMessage { + role: Role::User, + content: msg.content.clone(), + tool_call_id: None, + tool_calls: None, + }); + } + Role::Assistant => { + let tool_calls = msg.tool_calls.as_ref().map(|calls| { + calls + .iter() + .map(|tc| ToolCallResponse { + id: tc.id.clone(), + r#type: "function".to_string(), + function: FunctionCallResponse { + name: tc.name.clone(), + arguments: serde_json::to_string(&tc.arguments) + .unwrap_or_default(), + }, + }) + .collect() + }); + chat_messages.push(ChatMessage { + role: Role::Assistant, + content: msg.content.clone(), + tool_call_id: None, + tool_calls, + }); + } + Role::Tool => { + // Tier 1: recent tool results — send verbatim + chat_messages.push(ChatMessage { + role: Role::Tool, + content: msg.content.clone(), + tool_call_id: msg.tool_call_id.clone(), + tool_calls: None, + }); + } + } + } + + Ok(chat_messages) + } + + // ─── Undo Stack (delegates to DB) ─────────────────────────────────── + + /// Push a new entry onto the undo stack. + pub fn push_undo( + &self, + session_id: &str, + entry: &NewUndoEntry, + ) -> Result { + self.db.push_undo_entry(session_id, entry) + } + + /// Get the current undo stack for a session. + pub fn get_undo_stack( + &self, + session_id: &str, + ) -> Result, AgentError> { + self.db.get_undo_stack(session_id) + } + + /// Mark an undo entry as undone. + pub fn mark_undone(&self, undo_id: i64) -> Result<(), AgentError> { + self.db.mark_undone(undo_id) + } + + /// Get the session summary. 
+ pub fn get_session_summary( + &self, + session_id: &str, + ) -> Result, AgentError> { + self.db.get_session_summary(session_id) + } +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent_core::database::AgentDatabase; + + fn test_manager() -> ConversationManager { + let db = AgentDatabase::open(":memory:").unwrap(); + ConversationManager::new(db) + } + + #[test] + fn test_new_session() { + let mgr = test_manager(); + mgr.new_session("s1", "You are a helpful assistant.").unwrap(); + + let history = mgr.get_history("s1").unwrap(); + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, Role::System); + } + + #[test] + fn test_add_messages() { + let mgr = test_manager(); + mgr.new_session("s1", "system").unwrap(); + + mgr.add_user_message("s1", "Hello").unwrap(); + mgr.add_assistant_message("s1", "Hi there!").unwrap(); + + let history = mgr.get_history("s1").unwrap(); + assert_eq!(history.len(), 3); // system + user + assistant + assert_eq!(history[1].role, Role::User); + assert_eq!(history[2].role, Role::Assistant); + } + + #[test] + fn test_add_tool_call_and_result() { + let mgr = test_manager(); + mgr.new_session("s1", "system").unwrap(); + mgr.add_user_message("s1", "list files").unwrap(); + + let tool_calls = vec![crate::inference::types::ToolCall { + id: "call_1".to_string(), + name: "filesystem.list_dir".to_string(), + arguments: serde_json::json!({"path": "/tmp"}), + }]; + mgr.add_tool_call_message("s1", &tool_calls).unwrap(); + + let result = serde_json::json!({"files": ["a.txt", "b.txt"]}); + mgr.add_tool_result_message("s1", "call_1", &result).unwrap(); + + let history = mgr.get_history("s1").unwrap(); + assert_eq!(history.len(), 4); // system + user + assistant(tool_calls) + tool + assert_eq!(history[3].role, Role::Tool); + } + + #[test] + fn test_get_budget() { + let mgr = test_manager(); + mgr.new_session("s1", "system prompt").unwrap(); + + let budget 
= mgr.get_budget("s1").unwrap(); + assert_eq!(budget.total, DEFAULT_CONTEXT_WINDOW); + assert!(budget.remaining > 0); + assert!(budget.conversation_history > 0); // system prompt has tokens + } + + #[test] + fn test_eviction_not_needed() { + let mgr = test_manager(); + mgr.new_session("s1", "system").unwrap(); + mgr.add_user_message("s1", "hello").unwrap(); + + let evicted = mgr.evict_if_needed("s1").unwrap(); + assert_eq!(evicted, 0); + } + + #[test] + fn test_eviction_with_many_messages() { + let mgr = test_manager(); + // Use a tiny context window to force eviction + // We can't set context_window on test_manager easily, so let's + // just test that the logic works with enough messages. + mgr.new_session("s1", "system").unwrap(); + + // Add many large messages to exceed budget + for i in 0..50 { + let large_content = format!("message {i}: {}", "x".repeat(500)); + mgr.add_user_message("s1", &large_content).unwrap(); + } + + let count_before = mgr.get_history("s1").unwrap().len(); + assert!(count_before > FULL_DETAIL_TURNS); + + let evicted = mgr.evict_if_needed("s1").unwrap(); + + // With default 32k window, eviction may or may not trigger depending + // on actual token counts. Let's just verify no error. + // The important thing is the logic path works. 
+ let _ = evicted; + } + + #[test] + fn test_build_chat_messages() { + let mgr = test_manager(); + mgr.new_session("s1", "You are helpful.").unwrap(); + mgr.add_user_message("s1", "Hello").unwrap(); + mgr.add_assistant_message("s1", "Hi!").unwrap(); + + let chat = mgr.build_chat_messages("s1").unwrap(); + assert_eq!(chat.len(), 3); + assert_eq!(chat[0].role, Role::System); + assert_eq!(chat[1].role, Role::User); + assert_eq!(chat[2].role, Role::Assistant); + } + + #[test] + fn test_build_chat_messages_with_summary() { + let mgr = test_manager(); + mgr.new_session("s1", "You are helpful.").unwrap(); + + // Manually set a session summary + mgr.db().update_session_summary( + "s1", + "User previously asked about files in /tmp.", + &["/tmp/file.txt".to_string()], + &[], + ).unwrap(); + + let chat = mgr.build_chat_messages("s1").unwrap(); + let system_content = chat[0].content.as_ref().unwrap(); + assert!(system_content.contains("Previous conversation summary")); + assert!(system_content.contains("files in /tmp")); + } + + #[test] + fn test_undo_stack() { + let mgr = test_manager(); + mgr.new_session("s1", "system").unwrap(); + + let entry = NewUndoEntry { + tool_name: "filesystem.move_file".to_string(), + action_type: "move".to_string(), + original_state: serde_json::json!({"path": "/old/file.txt"}), + new_state: serde_json::json!({"path": "/new/file.txt"}), + }; + let id = mgr.push_undo("s1", &entry).unwrap(); + + let stack = mgr.get_undo_stack("s1").unwrap(); + assert_eq!(stack.len(), 1); + assert_eq!(stack[0].tool_name, "filesystem.move_file"); + + mgr.mark_undone(id).unwrap(); + let stack = mgr.get_undo_stack("s1").unwrap(); + assert_eq!(stack.len(), 0); + } + + #[test] + fn test_budget_after_optimization() { + // After removing the phantom ACTIVE_CONTEXT_BUDGET (9,500) and + // replacing with OUTPUT_RESERVATION (2,000), a fresh session + // should have significantly more remaining budget. 
+ let mgr = test_manager(); + mgr.new_session("s1", "short prompt").unwrap(); + + let budget = mgr.get_budget("s1").unwrap(); + // With 32K total, overhead = 500 (system) + 2000 (tools) + 2000 (output) + 768 (safety) + // = 5,268. Remaining = 32,768 - 5,268 - conversation_tokens ≈ 27,000+ + assert!( + budget.remaining > 20_000, + "remaining should be >20K after optimization, got {}", + budget.remaining + ); + assert_eq!(budget.output_reservation, 2_000); + } + + #[test] + fn test_build_windowed_compresses_old_tool_results() { + let mgr = test_manager(); + mgr.new_session("s1", "system").unwrap(); + + // Simulate a 3-round workflow so old tool results are clearly + // outside the recent window: + // 0: system + // 1: user ("process files") + // 2: assistant (tool_call 1) + // 3: tool (result 1 — large, should be compressed) + // 4: assistant (tool_call 2) + // 5: tool (result 2 — large, should be compressed) + // 6: assistant (tool_call 3) + // 7: tool (result 3 — recent, keep verbatim) + // 8: user ("continue") + // + // With window=4: window_start = 9 - 4 = 5 + // Messages at index < 5 are Tier 2 → tool results compressed + // Messages at index >= 5 are Tier 1 → verbatim + mgr.add_user_message("s1", "process files").unwrap(); + + for i in 1..=3 { + let tc = vec![crate::inference::types::ToolCall { + id: format!("call_{i}"), + name: "ocr.extract_text_from_image".to_string(), + arguments: serde_json::json!({"path": format!("/tmp/img{i}.png")}), + }]; + mgr.add_tool_call_message("s1", &tc).unwrap(); + + let large_result = serde_json::json!({"text": "x".repeat(200)}); + mgr.add_tool_result_message("s1", &format!("call_{i}"), &large_result) + .unwrap(); + } + + mgr.add_user_message("s1", "continue").unwrap(); + + // Full build: all messages verbatim + let full = mgr.build_chat_messages("s1").unwrap(); + // Windowed build with window=4: old tool results compressed + let windowed = mgr.build_windowed_chat_messages("s1", 4).unwrap(); + + // Both should have same number of 
messages + assert_eq!(full.len(), windowed.len()); + assert_eq!(full.len(), 9); // system + user + 3*(tc+result) + user + + // The first tool result (index 3) should be compressed in windowed + let full_tool = full[3].content.as_ref().unwrap(); + let windowed_tool = windowed[3].content.as_ref().unwrap(); + + assert!( + windowed_tool.len() < full_tool.len(), + "windowed tool result ({} chars) should be shorter than full ({} chars)", + windowed_tool.len(), + full_tool.len() + ); + // Compressed result should contain the summarization marker + assert!( + windowed_tool.starts_with('['), + "compressed result should be a summary bracket: {}", + windowed_tool + ); + + // The last tool result (index 7) should be verbatim (in recent window) + let full_recent = full[7].content.as_ref().unwrap(); + let windowed_recent = windowed[7].content.as_ref().unwrap(); + assert_eq!( + full_recent, windowed_recent, + "recent tool result should be verbatim" + ); + } + + #[test] + fn test_summary_capped_on_eviction() { + let mut mgr = test_manager(); + // Use a small context window to force eviction + mgr.set_context_window(4_000); + mgr.new_session("s1", "system").unwrap(); + + // Add enough large messages to trigger eviction + for i in 0..30 { + let content = format!("large message {i}: {}", "x".repeat(300)); + mgr.add_user_message("s1", &content).unwrap(); + } + + let evicted = mgr.evict_if_needed("s1").unwrap(); + assert!(evicted > 0, "eviction should have triggered"); + + // Check that summary exists and is capped + let summary = mgr.get_session_summary("s1").unwrap(); + assert!(summary.is_some(), "summary should exist after eviction"); + + let summary_text = summary.unwrap().summary_text; + let summary_tokens = tokens::estimate_tokens(&summary_text); + // The capped summary should be at most MAX_SUMMARY_TOKENS + some overhead + // from the "[earlier context omitted]" prefix + assert!( + summary_tokens <= MAX_SUMMARY_TOKENS + 20, + "summary should be capped at ~{} tokens, got {}", + 
MAX_SUMMARY_TOKENS, + summary_tokens + ); + } +} diff --git a/src-tauri/src/agent_core/database.rs b/src-tauri/src/agent_core/database.rs new file mode 100644 index 0000000..0cdf69d --- /dev/null +++ b/src-tauri/src/agent_core/database.rs @@ -0,0 +1,787 @@ +//! SQLite database for conversation history, sessions, and undo stack. +//! +//! Uses `rusqlite` in synchronous mode (Tauri commands run on a thread pool). +//! WAL mode is enabled for concurrent reads during streaming. + +use rusqlite::{params, Connection, OptionalExtension}; + +use super::errors::AgentError; +use super::types::{ + AuditEntry, AuditStatus, ConversationMessage, NewMessage, NewUndoEntry, Session, + SessionSummary, UndoEntry, +}; +use crate::inference::types::{Role, ToolCall}; + +// ─── Database ─────────────────────────────────────────────────────────────── + +/// SQLite database handle for the agent core. +pub struct AgentDatabase { + conn: Connection, +} + +impl AgentDatabase { + /// Open (or create) the agent database at the given path. + /// + /// Pass `":memory:"` for an in-memory database (tests). + pub fn open(path: &str) -> Result { + let conn = Connection::open(path)?; + + // Enable WAL mode for concurrent reads + conn.execute_batch("PRAGMA journal_mode=WAL;")?; + conn.execute_batch("PRAGMA foreign_keys=ON;")?; + + let db = Self { conn }; + db.create_tables()?; + Ok(db) + } + + /// Create all required tables if they don't exist. 
+ fn create_tables(&self) -> Result<(), AgentError> { + self.conn.execute_batch( + " + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + last_activity TEXT NOT NULL DEFAULT (datetime('now')), + summary TEXT, + files_touched TEXT DEFAULT '[]', + decisions_made TEXT DEFAULT '[]' + ); + + CREATE TABLE IF NOT EXISTS conversation_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + timestamp TEXT NOT NULL DEFAULT (datetime('now')), + role TEXT NOT NULL, + content TEXT, + tool_calls TEXT, + tool_call_id TEXT, + tool_result TEXT, + token_count INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ); + + CREATE INDEX IF NOT EXISTS idx_messages_session + ON conversation_messages(session_id, id); + + CREATE TABLE IF NOT EXISTS undo_stack ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + timestamp TEXT NOT NULL DEFAULT (datetime('now')), + tool_name TEXT NOT NULL, + action_type TEXT NOT NULL, + original_state TEXT NOT NULL, + new_state TEXT NOT NULL, + undone INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ); + + CREATE INDEX IF NOT EXISTS idx_undo_session + ON undo_stack(session_id, undone); + + CREATE TABLE IF NOT EXISTS audit_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + timestamp TEXT NOT NULL DEFAULT (datetime('now')), + tool_name TEXT NOT NULL, + arguments TEXT, + result TEXT, + result_status TEXT NOT NULL, + user_confirmed INTEGER NOT NULL DEFAULT 0, + execution_time_ms INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ); + + CREATE INDEX IF NOT EXISTS idx_audit_session + ON audit_log(session_id); + ", + )?; + Ok(()) + } + + // ─── Sessions ─────────────────────────────────────────────────────── + + /// Create a new session with the given ID. 
+ pub fn create_session(&self, session_id: &str) -> Result<(), AgentError> { + self.conn.execute( + "INSERT INTO sessions (id) VALUES (?1)", + params![session_id], + )?; + Ok(()) + } + + /// Get a session by ID. + pub fn get_session(&self, session_id: &str) -> Result, AgentError> { + let result = self + .conn + .query_row( + "SELECT id, created_at, last_activity, summary, files_touched, decisions_made + FROM sessions WHERE id = ?1", + params![session_id], + |row| { + Ok(Session { + id: row.get(0)?, + created_at: row.get(1)?, + last_activity: row.get(2)?, + summary: row.get(3)?, + files_touched: parse_json_array(row.get::<_, String>(4)?), + decisions_made: parse_json_array(row.get::<_, String>(5)?), + }) + }, + ) + .optional()?; + Ok(result) + } + + /// Update the session's last activity timestamp. + pub fn touch_session(&self, session_id: &str) -> Result<(), AgentError> { + self.conn.execute( + "UPDATE sessions SET last_activity = datetime('now') WHERE id = ?1", + params![session_id], + )?; + Ok(()) + } + + /// Update the session summary. + pub fn update_session_summary( + &self, + session_id: &str, + summary: &str, + files: &[String], + decisions: &[String], + ) -> Result<(), AgentError> { + let files_json = serde_json::to_string(files)?; + let decisions_json = serde_json::to_string(decisions)?; + self.conn.execute( + "UPDATE sessions SET summary = ?2, files_touched = ?3, decisions_made = ?4 + WHERE id = ?1", + params![session_id, summary, files_json, decisions_json], + )?; + Ok(()) + } + + /// Get the session summary for context window inclusion. 
+ pub fn get_session_summary( + &self, + session_id: &str, + ) -> Result, AgentError> { + let session = self.get_session(session_id)?; + match session { + Some(s) if s.summary.is_some() => Ok(Some(SessionSummary { + session_id: s.id, + summary_text: s.summary.unwrap_or_default(), + files_touched: s.files_touched, + decisions_made: s.decisions_made, + })), + _ => Ok(None), + } + } + + /// List all sessions, ordered by most recent activity first. + pub fn list_sessions(&self) -> Result, AgentError> { + let mut stmt = self.conn.prepare( + "SELECT id, created_at, last_activity, summary, files_touched, decisions_made + FROM sessions + ORDER BY last_activity DESC", + )?; + + let rows = stmt.query_map([], |row| { + Ok(Session { + id: row.get(0)?, + created_at: row.get(1)?, + last_activity: row.get(2)?, + summary: row.get(3)?, + files_touched: parse_json_array(row.get::<_, String>(4)?), + decisions_made: parse_json_array(row.get::<_, String>(5)?), + }) + })?; + + let mut sessions = Vec::new(); + for row in rows { + sessions.push(row?); + } + Ok(sessions) + } + + /// Delete a session and all its associated messages, undo entries, and audit log. + pub fn delete_session(&self, session_id: &str) -> Result<(), AgentError> { + self.conn.execute( + "DELETE FROM audit_log WHERE session_id = ?1", + params![session_id], + )?; + self.conn.execute( + "DELETE FROM undo_stack WHERE session_id = ?1", + params![session_id], + )?; + self.conn.execute( + "DELETE FROM conversation_messages WHERE session_id = ?1", + params![session_id], + )?; + self.conn.execute( + "DELETE FROM sessions WHERE id = ?1", + params![session_id], + )?; + Ok(()) + } + + // ─── Messages ─────────────────────────────────────────────────────── + + /// Insert a new message into the conversation history. 
+ pub fn insert_message( + &self, + session_id: &str, + msg: &NewMessage, + token_count: u32, + ) -> Result { + let role_str = role_to_str(&msg.role); + let tool_calls_json = msg + .tool_calls + .as_ref() + .map(|tc| serde_json::to_string(tc).unwrap_or_default()); + let tool_result_json = msg + .tool_result + .as_ref() + .map(|r| serde_json::to_string(r).unwrap_or_default()); + + self.conn.execute( + "INSERT INTO conversation_messages + (session_id, role, content, tool_calls, tool_call_id, tool_result, token_count) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", + params![ + session_id, + role_str, + msg.content, + tool_calls_json, + msg.tool_call_id, + tool_result_json, + token_count, + ], + )?; + + self.touch_session(session_id)?; + Ok(self.conn.last_insert_rowid()) + } + + /// Get all messages for a session, ordered by ID (chronological). + pub fn get_messages( + &self, + session_id: &str, + ) -> Result, AgentError> { + let mut stmt = self.conn.prepare( + "SELECT id, session_id, timestamp, role, content, + tool_calls, tool_call_id, tool_result, token_count + FROM conversation_messages + WHERE session_id = ?1 + ORDER BY id ASC", + )?; + + let rows = stmt.query_map(params![session_id], |row| { + Ok(row_to_message(row)) + })?; + + let mut messages = Vec::new(); + for row in rows { + messages.push(row?); + } + Ok(messages) + } + + /// Get the N most recent messages for a session. 
+ pub fn get_recent_messages( + &self, + session_id: &str, + limit: usize, + ) -> Result, AgentError> { + let mut stmt = self.conn.prepare( + "SELECT id, session_id, timestamp, role, content, + tool_calls, tool_call_id, tool_result, token_count + FROM conversation_messages + WHERE session_id = ?1 + ORDER BY id DESC + LIMIT ?2", + )?; + + let rows = stmt.query_map(params![session_id, limit as i64], |row| { + Ok(row_to_message(row)) + })?; + + let mut messages = Vec::new(); + for row in rows { + messages.push(row?); + } + // Reverse so oldest is first + messages.reverse(); + Ok(messages) + } + + /// Get the total token count for all messages in a session. + pub fn total_message_tokens(&self, session_id: &str) -> Result { + let total: i64 = self.conn.query_row( + "SELECT COALESCE(SUM(token_count), 0) + FROM conversation_messages WHERE session_id = ?1", + params![session_id], + |row| row.get(0), + )?; + Ok(total as u32) + } + + /// Count messages in a session. + pub fn message_count(&self, session_id: &str) -> Result { + let count: i64 = self.conn.query_row( + "SELECT COUNT(*) FROM conversation_messages WHERE session_id = ?1", + params![session_id], + |row| row.get(0), + )?; + Ok(count as usize) + } + + /// Delete the oldest N messages from a session (for eviction). 
+ pub fn delete_oldest_messages( + &self, + session_id: &str, + count: usize, + ) -> Result, AgentError> { + // First, fetch the messages we're about to delete + let mut stmt = self.conn.prepare( + "SELECT id, session_id, timestamp, role, content, + tool_calls, tool_call_id, tool_result, token_count + FROM conversation_messages + WHERE session_id = ?1 AND role != 'system' + ORDER BY id ASC + LIMIT ?2", + )?; + + let rows = stmt.query_map(params![session_id, count as i64], |row| { + Ok(row_to_message(row)) + })?; + + let mut evicted = Vec::new(); + for row in rows { + evicted.push(row?); + } + + // Delete them + if !evicted.is_empty() { + let ids: Vec = evicted.iter().map(|m| m.id).collect(); + let placeholders: Vec = ids.iter().map(|_| "?".to_string()).collect(); + let sql = format!( + "DELETE FROM conversation_messages WHERE id IN ({})", + placeholders.join(",") + ); + let params: Vec> = + ids.iter().map(|id| Box::new(*id) as Box).collect(); + self.conn.execute( + &sql, + rusqlite::params_from_iter(params.iter().map(|p| p.as_ref())), + )?; + } + + Ok(evicted) + } + + // ─── Undo Stack ───────────────────────────────────────────────────── + + /// Push a new undo entry. + pub fn push_undo_entry( + &self, + session_id: &str, + entry: &NewUndoEntry, + ) -> Result { + let original = serde_json::to_string(&entry.original_state)?; + let new_state = serde_json::to_string(&entry.new_state)?; + + self.conn.execute( + "INSERT INTO undo_stack + (session_id, tool_name, action_type, original_state, new_state) + VALUES (?1, ?2, ?3, ?4, ?5)", + params![ + session_id, + entry.tool_name, + entry.action_type, + original, + new_state, + ], + )?; + Ok(self.conn.last_insert_rowid()) + } + + /// Get all non-undone entries in the undo stack for a session. 
+ pub fn get_undo_stack(&self, session_id: &str) -> Result, AgentError> { + let mut stmt = self.conn.prepare( + "SELECT id, session_id, timestamp, tool_name, action_type, + original_state, new_state, undone + FROM undo_stack + WHERE session_id = ?1 AND undone = 0 + ORDER BY id DESC", + )?; + + let rows = stmt.query_map(params![session_id], |row| { + Ok(UndoEntry { + id: row.get(0)?, + session_id: row.get(1)?, + timestamp: row.get(2)?, + tool_name: row.get(3)?, + action_type: row.get(4)?, + original_state: parse_json_value(row.get::<_, String>(5)?), + new_state: parse_json_value(row.get::<_, String>(6)?), + undone: row.get::<_, i32>(7)? != 0, + }) + })?; + + let mut entries = Vec::new(); + for row in rows { + entries.push(row?); + } + Ok(entries) + } + + /// Mark an undo entry as undone. + pub fn mark_undone(&self, undo_id: i64) -> Result<(), AgentError> { + let updated = self.conn.execute( + "UPDATE undo_stack SET undone = 1 WHERE id = ?1", + params![undo_id], + )?; + if updated == 0 { + return Err(AgentError::UndoFailed { + undo_id, + reason: "undo entry not found".to_string(), + }); + } + Ok(()) + } + + // ─── Audit Log ────────────────────────────────────────────────────── + + /// Insert an audit log entry. 
+ #[allow(clippy::too_many_arguments)] + pub fn insert_audit_entry( + &self, + session_id: &str, + tool_name: &str, + arguments: &serde_json::Value, + result: Option<&serde_json::Value>, + status: AuditStatus, + user_confirmed: bool, + execution_time_ms: u64, + ) -> Result { + let args_json = serde_json::to_string(arguments)?; + let result_json = result.map(|r| serde_json::to_string(r).unwrap_or_default()); + + self.conn.execute( + "INSERT INTO audit_log + (session_id, tool_name, arguments, result, result_status, + user_confirmed, execution_time_ms) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", + params![ + session_id, + tool_name, + args_json, + result_json, + status.as_str(), + user_confirmed as i32, + execution_time_ms as i64, + ], + )?; + Ok(self.conn.last_insert_rowid()) + } + + /// Get audit log entries for a session. + pub fn get_audit_entries( + &self, + session_id: &str, + ) -> Result, AgentError> { + let mut stmt = self.conn.prepare( + "SELECT id, session_id, timestamp, tool_name, arguments, + result, result_status, user_confirmed, execution_time_ms + FROM audit_log + WHERE session_id = ?1 + ORDER BY id ASC", + )?; + + let rows = stmt.query_map(params![session_id], |row| { + Ok(AuditEntry { + id: row.get(0)?, + session_id: row.get(1)?, + timestamp: row.get(2)?, + tool_name: row.get(3)?, + arguments: parse_json_value(row.get::<_, String>(4)?), + result: row + .get::<_, Option>(5)? + .map(parse_json_value), + result_status: AuditStatus::parse( + &row.get::<_, String>(6)?, + ), + user_confirmed: row.get::<_, i32>(7)? != 0, + execution_time_ms: row.get::<_, i64>(8)? as u64, + }) + })?; + + let mut entries = Vec::new(); + for row in rows { + entries.push(row?); + } + Ok(entries) + } +} + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +/// Convert a rusqlite row to a ConversationMessage. 
+fn row_to_message(row: &rusqlite::Row<'_>) -> ConversationMessage { + ConversationMessage { + id: row.get(0).unwrap_or(0), + session_id: row.get(1).unwrap_or_default(), + timestamp: row.get(2).unwrap_or_default(), + role: str_to_role(&row.get::<_, String>(3).unwrap_or_default()), + content: row.get(4).unwrap_or(None), + tool_calls: row + .get::<_, Option>(5) + .unwrap_or(None) + .and_then(|s| serde_json::from_str::>(&s).ok()), + tool_call_id: row.get(6).unwrap_or(None), + tool_result: row + .get::<_, Option>(7) + .unwrap_or(None) + .map(parse_json_value), + token_count: row.get::<_, i32>(8).unwrap_or(0) as u32, + } +} + +/// Parse a JSON string into a Vec, defaulting to empty. +fn parse_json_array(json: String) -> Vec { + serde_json::from_str(&json).unwrap_or_default() +} + +/// Parse a JSON string into a serde_json::Value, defaulting to null. +fn parse_json_value(json: String) -> serde_json::Value { + serde_json::from_str(&json).unwrap_or(serde_json::Value::Null) +} + +/// Convert a Role to its string representation. +fn role_to_str(role: &Role) -> &'static str { + match role { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + Role::Tool => "tool", + } +} + +/// Parse a string into a Role. 
+fn str_to_role(s: &str) -> Role { + match s { + "system" => Role::System, + "user" => Role::User, + "assistant" => Role::Assistant, + "tool" => Role::Tool, + _ => Role::User, + } +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::inference::types::Role; + + fn test_db() -> AgentDatabase { + AgentDatabase::open(":memory:").unwrap() + } + + #[test] + fn test_create_and_get_session() { + let db = test_db(); + db.create_session("test-session-1").unwrap(); + + let session = db.get_session("test-session-1").unwrap(); + assert!(session.is_some()); + let s = session.unwrap(); + assert_eq!(s.id, "test-session-1"); + assert!(s.summary.is_none()); + } + + #[test] + fn test_session_not_found() { + let db = test_db(); + let session = db.get_session("nonexistent").unwrap(); + assert!(session.is_none()); + } + + #[test] + fn test_insert_and_get_messages() { + let db = test_db(); + db.create_session("s1").unwrap(); + + let msg1 = NewMessage { + role: Role::User, + content: Some("hello".to_string()), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + let msg2 = NewMessage { + role: Role::Assistant, + content: Some("hi there".to_string()), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + + db.insert_message("s1", &msg1, 10).unwrap(); + db.insert_message("s1", &msg2, 15).unwrap(); + + let messages = db.get_messages("s1").unwrap(); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].role, Role::User); + assert_eq!(messages[1].role, Role::Assistant); + assert_eq!(messages[0].token_count, 10); + } + + #[test] + fn test_get_recent_messages() { + let db = test_db(); + db.create_session("s1").unwrap(); + + for i in 0..10 { + let msg = NewMessage { + role: Role::User, + content: Some(format!("message {i}")), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + db.insert_message("s1", &msg, 5).unwrap(); + } + + let recent = 
db.get_recent_messages("s1", 3).unwrap(); + assert_eq!(recent.len(), 3); + assert!(recent[0].content.as_ref().unwrap().contains("message 7")); + assert!(recent[2].content.as_ref().unwrap().contains("message 9")); + } + + #[test] + fn test_total_message_tokens() { + let db = test_db(); + db.create_session("s1").unwrap(); + + let msg = NewMessage { + role: Role::User, + content: Some("test".to_string()), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + db.insert_message("s1", &msg, 10).unwrap(); + db.insert_message("s1", &msg, 20).unwrap(); + + assert_eq!(db.total_message_tokens("s1").unwrap(), 30); + } + + #[test] + fn test_delete_oldest_messages() { + let db = test_db(); + db.create_session("s1").unwrap(); + + // Insert system + 5 user messages + let sys = NewMessage { + role: Role::System, + content: Some("system prompt".to_string()), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + db.insert_message("s1", &sys, 50).unwrap(); + + for i in 0..5 { + let msg = NewMessage { + role: Role::User, + content: Some(format!("msg {i}")), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + db.insert_message("s1", &msg, 10).unwrap(); + } + + // Delete 2 oldest non-system messages + let evicted = db.delete_oldest_messages("s1", 2).unwrap(); + assert_eq!(evicted.len(), 2); + assert!(evicted[0].content.as_ref().unwrap().contains("msg 0")); + assert!(evicted[1].content.as_ref().unwrap().contains("msg 1")); + + // 4 remain (1 system + 3 user) + assert_eq!(db.message_count("s1").unwrap(), 4); + } + + #[test] + fn test_undo_stack() { + let db = test_db(); + db.create_session("s1").unwrap(); + + let entry = NewUndoEntry { + tool_name: "filesystem.move_file".to_string(), + action_type: "move".to_string(), + original_state: serde_json::json!({"path": "/old"}), + new_state: serde_json::json!({"path": "/new"}), + }; + let id = db.push_undo_entry("s1", &entry).unwrap(); + + let stack = db.get_undo_stack("s1").unwrap(); + 
assert_eq!(stack.len(), 1); + assert_eq!(stack[0].tool_name, "filesystem.move_file"); + assert!(!stack[0].undone); + + // Mark as undone + db.mark_undone(id).unwrap(); + let stack = db.get_undo_stack("s1").unwrap(); + assert_eq!(stack.len(), 0); // Filtered out + } + + #[test] + fn test_audit_log() { + let db = test_db(); + db.create_session("s1").unwrap(); + + let args = serde_json::json!({"path": "/tmp"}); + let result = serde_json::json!({"files": ["a.txt"]}); + + db.insert_audit_entry( + "s1", + "filesystem.list_dir", + &args, + Some(&result), + AuditStatus::Success, + false, + 42, + ) + .unwrap(); + + let entries = db.get_audit_entries("s1").unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].tool_name, "filesystem.list_dir"); + assert_eq!(entries[0].result_status, AuditStatus::Success); + assert!(!entries[0].user_confirmed); + assert_eq!(entries[0].execution_time_ms, 42); + } + + #[test] + fn test_update_session_summary() { + let db = test_db(); + db.create_session("s1").unwrap(); + + db.update_session_summary( + "s1", + "User asked to organize files in /tmp", + &["file.txt".to_string()], + &["moved files to /archive".to_string()], + ) + .unwrap(); + + let summary = db.get_session_summary("s1").unwrap(); + assert!(summary.is_some()); + let s = summary.unwrap(); + assert!(s.summary_text.contains("organize files")); + assert_eq!(s.files_touched, vec!["file.txt"]); + } +} diff --git a/src-tauri/src/agent_core/errors.rs b/src-tauri/src/agent_core/errors.rs new file mode 100644 index 0000000..b57b9e6 --- /dev/null +++ b/src-tauri/src/agent_core/errors.rs @@ -0,0 +1,67 @@ +//! Agent Core error types. + +use thiserror::Error; + +/// Errors that can occur during agent core operations. +#[derive(Debug, Error)] +pub enum AgentError { + /// Database operation failed. + #[error("database error: {reason}")] + DatabaseError { reason: String }, + + /// Session not found. 
+ #[error("session not found: '{session_id}'")] + SessionNotFound { session_id: String }, + + /// Context window budget exceeded. + #[error("context window budget exceeded: {used} / {limit} tokens")] + ContextOverflow { used: u32, limit: u32 }, + + /// Token counting failed. + #[error("token estimation error: {reason}")] + TokenEstimationError { reason: String }, + + /// Tool execution error (wraps McpError). + #[error("tool execution failed: {reason}")] + ToolExecutionError { reason: String }, + + /// Tool call rejected by user. + #[error("tool call '{tool_name}' rejected by user")] + ToolCallRejected { tool_name: String }, + + /// Undo operation failed. + #[error("undo failed for entry {undo_id}: {reason}")] + UndoFailed { undo_id: i64, reason: String }, + + /// No undo entries available. + #[error("no undo entries in session '{session_id}'")] + NoUndoEntries { session_id: String }, + + /// Confirmation channel error. + #[error("confirmation channel error: {reason}")] + ConfirmationError { reason: String }, + + /// Audit log error. + #[error("audit log error: {reason}")] + AuditError { reason: String }, + + /// Serialization error. + #[error("serialization error: {reason}")] + SerializationError { reason: String }, +} + +impl From for AgentError { + fn from(e: rusqlite::Error) -> Self { + AgentError::DatabaseError { + reason: e.to_string(), + } + } +} + +impl From for AgentError { + fn from(e: serde_json::Error) -> Self { + AgentError::SerializationError { + reason: e.to_string(), + } + } +} diff --git a/src-tauri/src/agent_core/mod.rs b/src-tauri/src/agent_core/mod.rs new file mode 100644 index 0000000..8fd7be0 --- /dev/null +++ b/src-tauri/src/agent_core/mod.rs @@ -0,0 +1,38 @@ +//! Agent Core — orchestration layer for LocalCowork. +//! +//! Submodules: +//! - `conversation`: Conversation history and context management +//! - `tool_router`: Dispatches model tool calls to MCP servers +//! - `tokens`: Token estimation for context window budgets +//! 
- `database`: SQLite persistence for sessions, messages, undo stack, audit +//! - `permissions`: Tiered permission grants (once / session / always) +//! - `response_analysis`: Detect incomplete tasks, deflection (FM-3), completion +//! - `orchestrator`: Dual-model pipeline — planner (24B) + router (1.2B) (ADR-009) +//! - `plan_parser`: Bracket + JSON plan output parsers for the orchestrator +//! - `tool_prefilter`: RAG pre-filter for tool selection (ADR-010 / ADR-009) +//! - `types`: Shared types across the agent core +//! - `errors`: Agent-level error types + +pub mod conversation; +pub mod database; +pub mod errors; +pub mod orchestrator; +pub mod plan_parser; +pub mod plan_templates; +pub mod permissions; +pub mod response_analysis; +pub mod tokens; +pub mod tool_prefilter; +pub mod tool_router; +pub mod types; + +// Re-exports for convenience +pub use conversation::ConversationManager; +pub use database::AgentDatabase; +pub use errors::AgentError; +pub use permissions::PermissionStore; +pub use tool_router::ToolRouter; +pub use types::{ + AuditEntry, AuditStatus, ConfirmationRequest, ConfirmationResponse, ContextBudget, + ConversationMessage, NewMessage, NewUndoEntry, Session, SessionSummary, UndoEntry, +}; diff --git a/src-tauri/src/agent_core/orchestrator.rs b/src-tauri/src/agent_core/orchestrator.rs new file mode 100644 index 0000000..2222db2 --- /dev/null +++ b/src-tauri/src/agent_core/orchestrator.rs @@ -0,0 +1,1997 @@ +//! Dual-model orchestrator: Planner (LFM2-24B-A2B) + Router (LFM2.5-1.2B-Router-FT). +//! +//! Architecture (ADR-009): +//! 1. **Plan** — planner model (MoE, ~2B active) decomposes user request into steps +//! 2. **Execute** — router model selects and calls one tool per step (RAG pre-filtered K=15) +//! 3. **Synthesize** — planner model generates user-facing summary from step results +//! +//! Each step is a clean single-turn interaction with the router model — no +//! conversation history, no context accumulation. 
This preserves the 78% +//! single-step accuracy that degrades to 8% in multi-turn context. + +use serde::{Deserialize, Serialize}; +use std::sync::Mutex; + +use tokio::sync::Mutex as TokioMutex; + +use crate::agent_core::plan_parser::{parse_bracket_plan, parse_json_plan}; +use crate::agent_core::tokens::truncate_utf8; +use crate::agent_core::tool_prefilter::ToolEmbeddingIndex; +use crate::inference::client::InferenceClient; +use crate::inference::config::{ModelsConfig, OrchestratorConfig}; +use crate::inference::types::{ + ChatMessage, Role, SamplingOverrides, ToolCall, +}; +use crate::mcp_client::client::McpClient; + +// ─── Types ────────────────────────────────────────────────────────────────── + +/// A single step in the execution plan (from the planner model). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlanStep { + pub step_number: u32, + /// Self-contained instruction for the router model. + pub description: String, + /// Hint: which MCP server is likely needed (e.g., "filesystem"). + #[serde(default)] + pub expected_server: Option, + /// Hint: key parameter values from the user's request. + #[serde(default)] + pub hint_params: Option, +} + +/// Structured plan output from the planner model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StepPlan { + /// Whether the request needs tool calls at all. + pub needs_tools: bool, + /// Direct response when no tools needed. + #[serde(default)] + pub direct_response: Option, + /// Ordered sequence of tool steps. + #[serde(default)] + pub steps: Vec, +} + +/// Result of executing a single plan step. +#[derive(Debug, Clone, Serialize)] +pub struct StepExecutionResult { + pub step_number: u32, + pub description: String, + pub tool_called: Option, + pub tool_arguments: Option, + pub tool_result: Option, + pub success: bool, + pub error: Option, +} + +/// Full orchestration result. 
+#[derive(Debug, Clone)] +pub struct OrchestrationResult { + pub step_results: Vec, + pub synthesis: String, + pub all_steps_succeeded: bool, + /// True if the orchestrator aborted and the caller should fall back. + pub fell_back: bool, +} + +// ─── Orchestrator Entry Point ─────────────────────────────────────────────── + +/// Execute the dual-model orchestration pipeline. +/// +/// Returns `Ok(result)` on success. If `result.fell_back` is true, the caller +/// should fall through to the single-model agent loop. +#[allow(clippy::too_many_arguments)] +pub async fn orchestrate_dual_model( + session_id: &str, + user_message: &str, + conversation_history: &[ChatMessage], + models_config: &ModelsConfig, + orch_config: &OrchestratorConfig, + app_handle: &tauri::AppHandle, + conv_state: &Mutex, + mcp_state: &TokioMutex, +) -> Result { + // Create separate clients for planner and router + let mut planner = InferenceClient::from_config_with_model( + models_config.clone(), + &orch_config.planner_model, + ) + .map_err(|e| format!("planner client error: {e}"))?; + + let mut router = InferenceClient::from_config_with_model( + models_config.clone(), + &orch_config.router_model, + ) + .map_err(|e| format!("router client error: {e}"))?; + + tracing::info!( + planner = planner.current_model_name(), + router = router.current_model_name(), + "orchestrator: starting dual-model pipeline" + ); + + // ── Phase 0: Template match (M1) ──────────────────────────────────── + // Check if the user's message matches a known use case pattern. + // If matched, skip the planner entirely — use the pre-built plan. 
+ let mut plan_is_template = false; + + // ── Phase 1: Plan ─────────────────────────────────────────────────── + let mut plan = if let Some(template_plan) = + crate::agent_core::plan_templates::try_template_match(user_message) + { + tracing::info!( + steps = template_plan.steps.len(), + "orchestrator: using template-matched plan (skipping planner)" + ); + plan_is_template = true; + template_plan + } else { + match plan_steps(&mut planner, user_message, conversation_history).await { + Ok(plan) => plan, + Err(e) => { + tracing::warn!(error = %e, "orchestrator: plan failed — falling back"); + return Ok(OrchestrationResult { + step_results: Vec::new(), + synthesis: String::new(), + all_steps_succeeded: false, + fell_back: true, + }); + } + } + }; + + // Post-plan decomposition check (Fix F11): only for model-generated plans. + // Template plans are pre-decomposed and don't need this check. + if !plan_is_template && plan_needs_decomposition(&plan, user_message) { + tracing::info!( + original_steps = plan.steps.len(), + "orchestrator: plan under-decomposed — re-planning with stronger prompt" + ); + let retry_message = format!( + "{}\n\n\ + CRITICAL: This request requires MULTIPLE steps across DIFFERENT servers. \ + You MUST break it into separate steps. Each step calls ONE tool from ONE server. \ + Do NOT combine scanning, reading, and task creation into a single step. \ + Look for these signals in the request: \"and\", \"then\", \"create a task\", \ + \"scan for X and Y\" — each signals a separate step.", + user_message + ); + if let Ok(retry_plan) = + plan_steps(&mut planner, &retry_message, conversation_history).await + { + if retry_plan.needs_tools && retry_plan.steps.len() > plan.steps.len() { + tracing::info!( + new_steps = retry_plan.steps.len(), + "orchestrator: re-plan produced more steps — using new plan" + ); + plan = retry_plan; + } + } + } + + // If no tools needed, stream the direct response. 
+ // Note: stream-complete is NOT emitted here — the caller (chat.rs) + // handles persistence and the properly-formatted ChatMessage emission. + if !plan.needs_tools { + let response = plan.direct_response.unwrap_or_default(); + let _ = tauri::Emitter::emit(app_handle, "stream-token", &response); + return Ok(OrchestrationResult { + step_results: Vec::new(), + synthesis: response, + all_steps_succeeded: true, + fell_back: false, + }); + } + + let _ = tauri::Emitter::emit(app_handle, "plan-created", &plan.steps); + + tracing::info!(step_count = plan.steps.len(), "orchestrator: plan created"); + + // ── Build tool embedding index ────────────────────────────────────── + let tool_pairs: Vec<(String, String)> = { + let mcp = mcp_state.lock().await; + mcp.registry.tool_name_description_pairs() + }; + + let tool_index = match ToolEmbeddingIndex::build( + router.current_base_url(), + &tool_pairs, + ) + .await + { + Ok(index) => index, + Err(e) => { + tracing::warn!(error = %e, "orchestrator: tool index build failed — falling back"); + return Ok(OrchestrationResult { + step_results: Vec::new(), + synthesis: String::new(), + all_steps_succeeded: false, + fell_back: true, + }); + } + }; + + tracing::info!(tool_count = tool_index.len(), "orchestrator: tool index built"); + + // ── Plan validation gate (Improvement I4) ───────────────────────── + { + let mcp = mcp_state.lock().await; + for step in &plan.steps { + if let Some(ref server) = step.expected_server { + let prefix = format!("{}.", server); + let has_tools = mcp + .registry + .tool_name_description_pairs() + .iter() + .any(|(name, _)| name.starts_with(&prefix)); + if !has_tools { + tracing::warn!( + step = step.step_number, + server = %server, + "orchestrator: plan references unknown server" + ); + } + } + } + } + + // ── Phase 2: Execute each step ────────────────────────────────────── + let mut step_results: Vec = Vec::new(); + let mut any_critical_failure = false; + let total_steps = plan.steps.len(); + + for step 
in &plan.steps { + // Richer step progress events (Improvement I3) + let _ = tauri::Emitter::emit( + app_handle, + "step-executing", + &serde_json::json!({ + "step_number": step.step_number, + "total_steps": total_steps, + "description": step.description, + "server": step.expected_server, + }), + ); + + let result = execute_step( + step, + &step_results, + &mut router, + &tool_index, + orch_config, + mcp_state, + ) + .await; + + let _ = tauri::Emitter::emit( + app_handle, + "step-completed", + &serde_json::json!({ + "step_number": step.step_number, + "total_steps": total_steps, + "success": result.success, + "tool_called": result.tool_called, + "result_preview": result.tool_result.as_deref() + .map(|r| truncate_utf8(r, 200)), + }), + ); + + if !result.success { + tracing::warn!( + step = step.step_number, + error = result.error.as_deref().unwrap_or("unknown"), + "orchestrator: step failed" + ); + // Check if subsequent steps reference this step's result + let step_ref = format!("step {}", step.step_number); + let is_critical = plan.steps.iter().any(|s| { + s.step_number > step.step_number + && s.description.to_lowercase().contains(&step_ref) + }); + + if is_critical { + any_critical_failure = true; + step_results.push(result); + break; + } + } + + step_results.push(result); + } + + // If a critical step failed, fall back to single-model mode + if any_critical_failure { + tracing::warn!("orchestrator: critical step failed — falling back"); + return Ok(OrchestrationResult { + step_results, + synthesis: String::new(), + all_steps_succeeded: false, + fell_back: true, + }); + } + + // ── Phase 3: Synthesize ───────────────────────────────────────────── + let synthesis = synthesize_response( + &mut planner, + user_message, + &step_results, + app_handle, + ) + .await + .unwrap_or_else(|e| { + tracing::warn!(error = %e, "orchestrator: synthesis failed"); + // Build a basic summary from step results + step_results + .iter() + .filter(|r| r.success) + .map(|r| { + format!( + 
"- {}: {}", + r.description, + r.tool_result.as_deref().unwrap_or("done") + ) + }) + .collect::>() + .join("\n") + }); + + // Persist the synthesized response + { + let mgr = conv_state.lock().map_err(|e| format!("Lock error: {e}"))?; + let _ = mgr.add_assistant_message(session_id, &synthesis); + } + + // Note: stream-complete is NOT emitted here — the caller (chat.rs) + // handles the properly-formatted ChatMessage emission to avoid + // duplicating the message format in two places. + + let all_succeeded = step_results.iter().all(|r| r.success); + + Ok(OrchestrationResult { + step_results, + synthesis, + all_steps_succeeded: all_succeeded, + fell_back: false, + }) +} + +// ─── Phase 1: Plan ────────────────────────────────────────────────────────── + +/// System prompt for the planner model — uses bracket-format calls that LFM2-24B-A2B +/// can reliably produce (JSON output had 94% parse failure rate in benchmarks). +const PLANNER_SYSTEM_PROMPT: &str = r#"You are a task planner for LocalCowork. Given a user request, decompose it into a sequence of tool-calling steps. You do NOT call tools yourself. Output your plan using bracket function calls. 
+ +Available capability areas (servers): +- filesystem: list, read, write, move, copy, delete, search files +- document: extract text from PDF/DOCX, convert formats, diff, create PDF/DOCX +- ocr: extract text from images/screenshots, extract structured data +- data: CSV/SQLite operations, deduplication, anomaly detection +- knowledge: semantic search across indexed documents, RAG Q&A +- security: PII/secrets scanning, file encryption, duplicate finding +- task: create/update/list tasks, daily briefing +- calendar: list events, create events, find free slots +- email: draft/send emails, search, summarize threads +- meeting: transcribe audio, extract action items, generate minutes +- audit: tool usage logs, session summaries +- clipboard: read/write system clipboard +- system: system info, open apps, take screenshots + +Rules: +1. Use bracket function calls to build the plan. No prose before or after. +2. If the request does NOT require tools, call: [plan.respond(message="your direct answer")] +3. Each step description must be COMPLETE and self-contained. +4. Include file paths, search terms, and specifics from the user message in each step. +5. For steps needing a prior result, write: "Using the result from step N, ..." +6. Maximum 10 steps. +7. End with [plan.done()] + +DECOMPOSITION RULES (critical): +- Each step calls EXACTLY ONE tool from ONE server. Never combine multiple operations. +- If the user says "scan for SSNs and API keys", that is TWO steps: one for PII scanning, one for secrets scanning. +- If the user says "do X and then create a task", that is at least TWO steps: the action + task creation. +- If scanning multiple files, create one step to list/discover them, then steps to scan each file type. +- Keywords that signal separate steps: "and", "then", "also", "follow up", "create a task". +- NEVER collapse a multi-server workflow into one step. When in doubt, create MORE steps. 
+ +Examples: + +Simple single-step: +[plan.add_step(step=1, server="filesystem", description="List all files in /Users/chintan/Downloads")] +[plan.done()] + +Two-server chain (filesystem + task): +[plan.add_step(step=1, server="filesystem", description="Read the file /Users/chintan/Projects/localCoWork/tests/fixtures/uc4/downloads/quarterly_report.txt")] +[plan.add_step(step=2, server="task", description="Using the content from step 1, create a task titled 'Review Q4 numbers' with due date Friday, including key findings from the quarterly report in the description")] +[plan.done()] + +Multi-server workflow (filesystem + security + task): +[plan.add_step(step=1, server="filesystem", description="List all files in /Users/chintan/Projects/localCoWork/tests/fixtures/uc3/sample_files/")] +[plan.add_step(step=2, server="security", description="Using the result from step 1, scan each file found for PII (SSNs, phone numbers, addresses)")] +[plan.add_step(step=3, server="security", description="Using the result from step 1, scan each file found for secrets (API keys, passwords, tokens)")] +[plan.add_step(step=4, server="task", description="Using the results from steps 2 and 3, create a follow-up task to remediate any sensitive files found, including the file paths and findings in the description")] +[plan.done()] + +Document analysis with knowledge search and email (filesystem + document + knowledge + email): +[plan.add_step(step=1, server="filesystem", description="List all PDF and DOCX files in /Users/chintan/Documents/Contracts/")] +[plan.add_step(step=2, server="document", description="Using the result from step 1, extract text from the contract file found")] +[plan.add_step(step=3, server="knowledge", description="Using the extracted text from step 2, search the knowledge base for similar clauses or related documents")] +[plan.add_step(step=4, server="email", description="Using the findings from steps 2 and 3, draft an email summarizing the key contract points and any 
related precedents found")] +[plan.done()] + +File scan with OCR, PII detection, and remediation (filesystem + ocr + security + task + email): +[plan.add_step(step=1, server="filesystem", description="List all files in /Users/chintan/Downloads/ including PDFs and images")] +[plan.add_step(step=2, server="ocr", description="Using the result from step 1, extract text from any image files found (PNG, JPG, screenshots)")] +[plan.add_step(step=3, server="security", description="Using the results from steps 1 and 2, scan all extracted content for PII (SSNs, credit card numbers, phone numbers)")] +[plan.add_step(step=4, server="task", description="Using the results from step 3, create a remediation task listing each file with PII findings and recommended actions")] +[plan.add_step(step=5, server="email", description="Using the results from steps 3 and 4, draft a notification email summarizing the PII scan findings and the remediation task created")] +[plan.done()] + +Meeting processing with tasks, calendar, and follow-up (meeting + task + calendar + knowledge + email): +[plan.add_step(step=1, server="meeting", description="Transcribe the audio file /Users/chintan/Recordings/standup-2026-02-19.m4a")] +[plan.add_step(step=2, server="meeting", description="Using the transcript from step 1, extract action items and commitments from the meeting")] +[plan.add_step(step=3, server="task", description="Using the action items from step 2, create a task for each commitment with the assigned person and due date")] +[plan.add_step(step=4, server="calendar", description="Using the tasks from step 3, find free time slots this week to schedule focused work blocks for the high-priority tasks")] +[plan.add_step(step=5, server="knowledge", description="Using the transcript from step 1, index the meeting notes in the knowledge base for future search")] +[plan.add_step(step=6, server="email", description="Using the action items from step 2 and tasks from step 3, draft a meeting summary email 
to attendees with the action items and deadlines")]
[plan.done()]

For non-tool requests:
[plan.respond(message="The answer to your question is...")]"#;

/// Call the planner model to decompose the request into steps.
///
/// Sends the planner system prompt plus a short window of conversation history,
/// then parses the model output: bracket format first (primary for LFM2-24B-A2B),
/// JSON as a fallback for models that emit it.
///
/// Returns the parsed `StepPlan`, or a `String` error describing both parse
/// failures plus the raw model output for debugging.
async fn plan_steps(
    planner: &mut InferenceClient,
    user_message: &str,
    conversation_history: &[ChatMessage],
) -> Result<StepPlan, String> {
    let mut messages = vec![ChatMessage {
        role: Role::System,
        content: Some(PLANNER_SYSTEM_PROMPT.to_string()),
        tool_call_id: None,
        tool_calls: None,
    }];

    // Include recent conversation history for context (last 6 turns max)
    let history_window = conversation_history
        .iter()
        .filter(|m| m.role != Role::System)
        .rev()
        .take(6)
        .collect::<Vec<_>>();
    for msg in history_window.into_iter().rev() {
        messages.push(msg.clone());
    }

    messages.push(ChatMessage {
        role: Role::User,
        content: Some(user_message.to_string()),
        tool_call_id: None,
        tool_calls: None,
    });

    // Near-greedy sampling: planning must be deterministic, not creative.
    let sampling = SamplingOverrides {
        temperature: Some(0.1),
        top_p: Some(0.2),
    };

    let result = planner
        .chat_completion(messages, None, Some(sampling))
        .await
        .map_err(|e| format!("planner inference error: {e}"))?;

    let text = result.token.unwrap_or_default();
    let trimmed = text.trim();

    // Try bracket-format parsing first (primary for LFM2-24B-A2B)
    if let Some(plan) = parse_bracket_plan(trimmed) {
        tracing::info!("orchestrator: parsed bracket-format plan");
        return Ok(plan);
    }

    // Fall back to JSON parsing (for models that support it)
    match parse_json_plan(trimmed) {
        Ok(plan) => {
            tracing::info!("orchestrator: parsed JSON-format plan (fallback)");
            Ok(plan)
        }
        Err(json_err) => Err(format!(
            "failed to parse plan (bracket and JSON both failed)\n\
             {json_err}\n\
             Raw output: {trimmed}"
        )),
    }
}

/// Check if a plan likely under-decomposed a compound request.
+/// +/// If the user message contains signals of a multi-step workflow (e.g., "scan AND +/// create a task", "read file THEN create task") but the planner only produced one +/// step, we should re-plan with a stronger decomposition prompt. +fn plan_needs_decomposition(plan: &StepPlan, user_message: &str) -> bool { + // Already multi-step or no-tool — no re-plan needed + if plan.steps.len() > 1 || !plan.needs_tools { + return false; + } + + let lower = user_message.to_lowercase(); + + // Explicit compound keywords (user says "do X and then do Y") + let compound_signals = [ + " and then ", + " and create ", + " then create ", + " then tell ", + " then make ", + " also ", + " follow up ", + " and scan ", + " and a task", + ", create a task", + ", then ", + ]; + for signal in &compound_signals { + if lower.contains(signal) { + return true; + } + } + + // Multi-operation pairs (user mentions two distinct action types) + let multi_op_pairs = [ + ("scan", "create"), + ("read", "create"), + ("list", "create"), + ("scan", "task"), + ("search", "task"), + ("extract", "task"), + ("read", "task"), + // Security + task combinations + ("ssn", "task"), + ("pii", "task"), + ("secret", "task"), + ("api key", "task"), + // Scan for multiple things + ("ssn", "api key"), + ("pii", "secret"), + ]; + for (a, b) in &multi_op_pairs { + if lower.contains(a) && lower.contains(b) { + return true; + } + } + + false +} + +// ─── Phase 2: Execute ─────────────────────────────────────────────────────── + +/// Build a system prompt for the router that matches the fine-tuning training format. +/// +/// The fine-tuned router was trained with tools as a numbered text list in the system +/// prompt (`generate_training_data_v2.py` lines 281-290). Sending tools via the OpenAI +/// `tools` JSON parameter causes llama-server to reformat them via its chat template, +/// which the 1.2B model has never seen — causing 0% tool call rate. 
+fn build_router_system_prompt( + filtered_names: &[String], + mcp: &McpClient, +) -> String { + let mut tool_lines = Vec::new(); + for (i, name) in filtered_names.iter().enumerate() { + let desc = mcp + .registry + .get_tool(name) + .map(|d| d.description.clone()) + .unwrap_or_default(); + tool_lines.push(format!("{}. {} — {}", i + 1, name, desc)); + } + + format!( + "You are LocalCowork, a desktop AI assistant that runs entirely on-device. \ + You have access to the following tools. ALWAYS call exactly one tool using \ + bracket syntax: [server.tool(param=\"value\")]. NEVER ask questions. \ + NEVER say you cannot help. ALWAYS select the most appropriate tool.\n\n\ + Available tools:\n{}", + tool_lines.join("\n") + ) +} + +/// Server-aware adaptive tool selection (Improvement I1). +/// +/// Guarantees all tools from the planner's hinted server are included in the +/// candidate set. Remaining slots are filled by RAG similarity. +async fn adaptive_filter( + tool_index: &ToolEmbeddingIndex, + step: &PlanStep, + router_base_url: &str, + description: &str, + base_k: usize, + mcp: &McpClient, +) -> Vec { + // Start with all tools from the hinted server (if available) + let mut selected: Vec = Vec::new(); + if let Some(ref server) = step.expected_server { + let prefix = format!("{}.", server); + let server_tools: Vec = mcp + .registry + .tool_name_description_pairs() + .iter() + .filter(|(name, _)| name.starts_with(&prefix)) + .map(|(name, _)| name.clone()) + .collect(); + selected.extend(server_tools); + } + + // Fill remaining slots with RAG-filtered tools (excluding already-selected) + let remaining_k = base_k.saturating_sub(selected.len()); + if remaining_k > 0 { + // Fetch more than needed so we can deduplicate + let fetch_k = base_k * 2; + if let Ok((rag_names, _)) = + tool_index.filter(router_base_url, description, fetch_k).await + { + for name in rag_names { + if selected.len() >= base_k { + break; + } + if !selected.contains(&name) { + selected.push(name); + } 
+ } + } + } + + selected +} + +/// Last-resort extraction: if the router produced text mentioning exactly one +/// tool name but without proper bracket syntax, construct the call (Improvement I2). +fn extract_fallback_tool_call( + response_text: &str, + filtered_names: &[String], +) -> Option { + let mentioned: Vec<&str> = filtered_names + .iter() + .filter(|name| response_text.contains(name.as_str())) + .map(|s| s.as_str()) + .collect(); + + if mentioned.len() == 1 { + let name = mentioned[0]; + // Try to extract simple path= or query= arguments from the response text + let args = extract_inline_args(response_text); + tracing::info!( + tool = name, + args_keys = ?args.as_object().map(|o| o.keys().collect::>()), + "fallback: extracted tool call from router text" + ); + Some(ToolCall { + id: format!("call_{}", uuid::Uuid::new_v4()), + name: name.to_string(), + arguments: args, + }) + } else { + None + } +} + +/// Extract inline `key="value"` arguments from free-form text. +fn extract_inline_args(text: &str) -> serde_json::Value { + let mut map = serde_json::Map::new(); + + // Scan for key="value" patterns (common in router output) + let mut i = 0; + let bytes = text.as_bytes(); + while i < bytes.len() { + // Look for '="' which signals a key="value" pair + if i + 1 < bytes.len() && bytes[i] == b'=' && bytes[i + 1] == b'"' { + // Walk backwards to find the key (alphanumeric + underscore) + let eq_pos = i; + let mut key_start = eq_pos; + while key_start > 0 + && (bytes[key_start - 1].is_ascii_alphanumeric() || bytes[key_start - 1] == b'_') + { + key_start -= 1; + } + let key = &text[key_start..eq_pos]; + + // Walk forward to find the closing quote + let val_start = i + 2; + if let Some(val_end_offset) = text[val_start..].find('"') { + let val = &text[val_start..val_start + val_end_offset]; + if !key.is_empty() && !val.is_empty() { + map.insert( + key.to_string(), + serde_json::Value::String(val.to_string()), + ); + } + i = val_start + val_end_offset + 1; + } else { + 
i += 2; + } + } else { + i += 1; + } + } + + serde_json::Value::Object(map) +} + +/// Construct tool arguments from the step description context instead of relying +/// on the router's arguments (which are often hallucinated from training data). +/// +/// The 1.2B router is excellent at tool selection (100% in Phase 2c) but poor at +/// argument construction — it regurgitates memorized example paths like +/// `~/Documents/example.txt` instead of extracting the actual path from the user's +/// message. This function extracts arguments from the step description, which +/// contains the user's actual intent and specific paths/values. +fn construct_args_from_context( + tool_name: &str, + step_description: &str, + mcp: &McpClient, +) -> serde_json::Value { + let mut args = serde_json::Map::new(); + + // Get the tool's parameter schema from MCP registry + let schema = mcp.registry.get_tool(tool_name).map(|t| &t.params_schema); + + if let Some(schema) = schema { + if let Some(props) = schema.get("properties").and_then(|p| p.as_object()) { + for (key, _prop_schema) in props { + if let Some(value) = extract_param_value(key, step_description) { + args.insert(key.clone(), value); + } + } + } + } + + serde_json::Value::Object(args) +} + +/// Extract a parameter value from the step description based on the parameter name. 
+fn extract_param_value( + param_name: &str, + description: &str, +) -> Option { + match param_name { + // Path-like parameters: extract file/directory paths from description + "path" | "file_path" | "dir_path" | "directory" | "source" | "destination" => { + extract_path_from_text(description).map(serde_json::Value::String) + } + // Title-like parameters: extract from quoted text or keywords + "title" | "name" => extract_title_from_text(description).map(serde_json::Value::String), + // Due date parameters + "due" | "due_date" => extract_date_from_text(description).map(serde_json::Value::String), + // Description/details parameters — use the step description itself + "description" | "details" | "body" | "content" | "text" => { + // Don't auto-fill content fields from step description — let the router handle it + // or let the tool use defaults. Step description is meta-info, not content. + None + } + _ => None, + } +} + +/// Extract a file/directory path from natural language text. +/// +/// Priority order: +/// 1. Explicit paths (starting with `/` or `~/`) +/// 2. Backtick-quoted paths +/// 3. 
Well-known directory references ("Downloads folder") +pub(crate) fn extract_path_from_text(text: &str) -> Option { + // Priority 1: Backtick-quoted paths (most explicit) + let mut search_from = 0; + while let Some(start) = text[search_from..].find('`') { + let abs_start = search_from + start + 1; + if let Some(end) = text[abs_start..].find('`') { + let content = &text[abs_start..abs_start + end]; + if content.contains('/') { + return Some(content.to_string()); + } + search_from = abs_start + end + 1; + } else { + break; + } + } + + // Priority 2: Absolute/home-relative paths in the text + for word in text.split_whitespace() { + let clean = word.trim_matches(|c: char| { + c == '`' || c == '\'' || c == '"' || c == ',' || c == ')' + }); + if (clean.starts_with('/') || clean.starts_with("~/")) && clean.len() > 2 { + return Some(clean.to_string()); + } + } + + // Priority 3: Well-known directory references + let lower = text.to_lowercase(); + if lower.contains("downloads folder") + || lower.contains("downloads directory") + || lower.contains("my downloads") + || (lower.contains("downloads") && lower.contains("folder")) + { + return Some("~/Downloads".to_string()); + } + if lower.contains("documents folder") || lower.contains("documents directory") { + return Some("~/Documents".to_string()); + } + if lower.contains("desktop folder") + || lower.contains("desktop directory") + || lower.contains("my desktop") + { + return Some("~/Desktop".to_string()); + } + if lower.contains("home folder") || lower.contains("home directory") { + return Some("~".to_string()); + } + + None +} + +/// Extract a title from text — looks for quoted strings or "titled X" patterns. 
+fn extract_title_from_text(text: &str) -> Option { + // Look for quoted strings (single or double) + for quote in ['"', '\''] { + let mut search_from = 0; + while let Some(start) = text[search_from..].find(quote) { + let abs_start = search_from + start + 1; + if let Some(end) = text[abs_start..].find(quote) { + let content = &text[abs_start..abs_start + end]; + // Skip very short or very long content + if content.len() >= 3 && content.len() <= 200 { + return Some(content.to_string()); + } + search_from = abs_start + end + 1; + } else { + break; + } + } + } + + // Look for "titled X" or "called X" pattern + for prefix in ["titled ", "called ", "named "] { + if let Some(idx) = text.to_lowercase().find(prefix) { + let after = &text[idx + prefix.len()..]; + let end = after + .find([',', '.', '\n']) + .unwrap_or(after.len()); + let title = after[..end].trim().trim_matches('\'').trim_matches('"'); + if !title.is_empty() && title.len() <= 200 { + return Some(title.to_string()); + } + } + } + + None +} + +/// Extract a date reference from text. 
+fn extract_date_from_text(text: &str) -> Option { + let lower = text.to_lowercase(); + + // Day names + for day in [ + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "saturday", + "sunday", + ] { + if lower.contains(day) { + // Capitalize first letter for the output + let capitalized = format!("{}{}", &day[..1].to_uppercase(), &day[1..]); + return Some(capitalized); + } + } + + // Relative dates + if lower.contains("tomorrow") { + return Some("tomorrow".to_string()); + } + if lower.contains("today") { + return Some("today".to_string()); + } + if lower.contains("next week") { + return Some("next week".to_string()); + } + if lower.contains("end of week") || lower.contains("end of the week") { + return Some("Friday".to_string()); + } + + // ISO date patterns (YYYY-MM-DD) + for word in text.split_whitespace() { + let clean = word.trim_matches(|c: char| !c.is_ascii_alphanumeric() && c != '-'); + if clean.len() == 10 + && clean.as_bytes().get(4) == Some(&b'-') + && clean.as_bytes().get(7) == Some(&b'-') + { + return Some(clean.to_string()); + } + } + + None +} + +/// Check if a value looks like a placeholder from training data rather than a +/// real user-supplied argument. The 1.2B router often regurgitates example +/// values from its fine-tuning data instead of extracting from the user message. 
+fn is_placeholder_value(value: &serde_json::Value) -> bool { + if let Some(s) = value.as_str() { + let lower = s.to_lowercase(); + // Generic example paths from training data + lower == "~/documents/example.txt" + || lower == "~/documents/" + || lower == "~/documents" + || lower.contains("/example.txt") + || lower.contains("/example/") + // Literal placeholder strings + || lower == "value" + || lower == "content" + || lower == "text" + || lower == "description" + || lower == "placeholder" + // Schema-literal patterns (router copies from schema description) + || lower.contains("iso 8601") + || lower == "due date" + || lower == "search query" + || lower == "file path" + || lower == "directory path" + // Empty-ish + || s.trim().is_empty() + } else { + false + } +} + +/// Merge two argument objects with smart override logic. +/// +/// - **Path-like keys** (`path`, `file_path`, etc.) always prefer `primary` because +/// the router almost always hallucinates training-data paths. +/// - **Non-path keys** (`title`, `name`, etc.) only override the router's value +/// if the router produced a placeholder (detected by `is_placeholder_value`). +/// - Keys absent in `secondary` are filled from `primary`. +/// - Keys absent in `primary` are kept from `secondary`. 
+fn merge_args(primary: &serde_json::Value, secondary: &serde_json::Value) -> serde_json::Value { + let mut merged = serde_json::Map::new(); + + // Start with secondary (router's args — may be wrong but fills gaps) + if let Some(obj) = secondary.as_object() { + for (k, v) in obj { + merged.insert(k.clone(), v.clone()); + } + } + + // Selectively override with primary (context-extracted args) + if let Some(obj) = primary.as_object() { + for (k, v) in obj { + // Only consider non-empty extracted values + let is_meaningful = match v { + serde_json::Value::String(s) => !s.is_empty(), + serde_json::Value::Null => false, + _ => true, + }; + if !is_meaningful { + continue; + } + + let is_path_key = matches!( + k.as_str(), + "path" + | "file_path" + | "dir_path" + | "directory" + | "source" + | "destination" + | "folder" + ); + + if is_path_key { + // Path keys: ALWAYS override — router paths are almost always wrong + merged.insert(k.clone(), v.clone()); + } else if !merged.contains_key(k) { + // Key not in router args: fill from context extraction + merged.insert(k.clone(), v.clone()); + } else { + // Key exists in router args: only override if router value is a placeholder + let router_is_placeholder = merged + .get(k) + .map(is_placeholder_value) + .unwrap_or(true); + if router_is_placeholder { + merged.insert(k.clone(), v.clone()); + } + // Otherwise keep the router's value (it's probably correct) + } + } + } + + serde_json::Value::Object(merged) +} + +/// Execute a single plan step using the router model. +/// +/// Tools are presented in the system prompt as a numbered text list matching the +/// fine-tuning training format. The router's bracket-format output is parsed by +/// `parse_non_streaming_response` → `parse_bracket_tool_calls`. 
+async fn execute_step( + step: &PlanStep, + prior_results: &[StepExecutionResult], + router: &mut InferenceClient, + tool_index: &ToolEmbeddingIndex, + config: &OrchestratorConfig, + mcp_state: &TokioMutex, +) -> StepExecutionResult { + let description = + interpolate_prior_results(step.step_number, &step.description, prior_results); + + // Adaptive tool selection: server hint → RAG fill (Improvement I1) + let filtered_names = { + let mcp = mcp_state.lock().await; + adaptive_filter( + tool_index, + step, + router.current_base_url(), + &description, + config.router_top_k as usize, + &mcp, + ) + .await + }; + + if filtered_names.is_empty() { + return StepExecutionResult { + step_number: step.step_number, + description, + tool_called: None, + tool_arguments: None, + tool_result: None, + success: false, + error: Some("no tools available after filtering".to_string()), + }; + } + + // Build the system prompt matching training format (Fix F1 — the critical fix) + let router_system = { + let mcp = mcp_state.lock().await; + build_router_system_prompt(&filtered_names, &mcp) + }; + + tracing::info!( + step = step.step_number, + filtered_tool_count = filtered_names.len(), + filtered_tools = ?filtered_names.iter().take(5).collect::>(), + "router: step tools selected" + ); + + let sampling = SamplingOverrides { + temperature: Some(0.1), + top_p: Some(0.1), + }; + + let mut last_response_text = String::new(); + + // Try up to step_retries times + for attempt in 0..config.step_retries { + let prompt = if attempt == 0 { + description.clone() + } else { + // Enhanced retry with bracket format example (Fix F2) + format!( + "{}\n\nIMPORTANT: You MUST respond with exactly one tool call in bracket \ + format. 
Example: [filesystem.list_dir(path=\"/Users/chintan/Downloads\")]\n\ + Choose from: {}", + description, + filtered_names + .iter() + .take(5) + .cloned() + .collect::>() + .join(", ") + ) + }; + + let messages = vec![ + ChatMessage { + role: Role::System, + content: Some(router_system.clone()), + tool_call_id: None, + tool_calls: None, + }, + ChatMessage { + role: Role::User, + content: Some(prompt), + tool_call_id: None, + tool_calls: None, + }, + ]; + + // Pass tools: None — tools are in the system prompt, not the API parameter. + // parse_non_streaming_response() handles bracket-format extraction from text. + let result = match router.chat_completion(messages, None, Some(sampling)).await { + Ok(chunk) => chunk, + Err(e) => { + tracing::warn!( + step = step.step_number, + attempt = attempt, + error = %e, + "router inference error" + ); + continue; + } + }; + + // Diagnostic logging (Fix F5) + let response_text = result.token.as_deref().unwrap_or(""); + let has_native_tool_calls = result.tool_calls.is_some(); + tracing::info!( + step = step.step_number, + attempt = attempt, + response_text_len = response_text.len(), + response_text_preview = %truncate_utf8(response_text, 200), + has_native_tool_calls, + "router raw response" + ); + last_response_text = response_text.to_string(); + + // Check for tool calls in the response (bracket-parsed or native) + if let Some(ref tool_calls) = result.tool_calls { + if let Some(tc) = tool_calls.first() { + // Override router's args with context-extracted args (Fix F6). + // The 1.2B router is great at tool selection but hallucinates + // arguments from training data. Extract real args from the + // step description which contains the user's actual paths/values. 
+ let overridden_args = { + let mcp_ref = mcp_state.lock().await; + let context_args = + construct_args_from_context(&tc.name, &description, &mcp_ref); + merge_args(&context_args, &tc.arguments) + }; + + tracing::info!( + step = step.step_number, + tool = %tc.name, + router_args = %tc.arguments, + final_args = %overridden_args, + "router selected tool" + ); + + // Execute the tool via MCP with overridden args + let mut mcp = mcp_state.lock().await; + let tool_result = + match mcp.call_tool(&tc.name, overridden_args.clone()).await { + Ok(res) if res.success => { + let text = res + .result + .and_then(|v| { + v.get("text") + .and_then(|t| t.as_str()) + .map(|s| s.to_string()) + .or_else(|| serde_json::to_string(&v).ok()) + }) + .unwrap_or_else(|| "ok".to_string()); + tracing::info!( + step = step.step_number, + tool = %tc.name, + result_len = text.len(), + result_preview = %truncate_utf8(&text, 200), + "step tool execution succeeded" + ); + text + } + Ok(res) => { + let err = + res.error.unwrap_or_else(|| "tool failed".to_string()); + tracing::warn!( + step = step.step_number, + tool = %tc.name, + error = %err, + "step tool execution failed" + ); + err + } + Err(e) => { + tracing::warn!( + step = step.step_number, + tool = %tc.name, + error = %e, + "step tool MCP error" + ); + format!("MCP error: {e}") + } + }; + + return StepExecutionResult { + step_number: step.step_number, + description: description.clone(), + tool_called: Some(tc.name.clone()), + tool_arguments: Some(overridden_args), + tool_result: Some(tool_result), + success: true, + error: None, + }; + } + } + + tracing::info!( + step = step.step_number, + attempt = attempt, + "router returned no tool call — retrying" + ); + } + + // All retries exhausted — try fallback extraction (Improvement I2) + if !last_response_text.is_empty() { + if let Some(tc) = extract_fallback_tool_call(&last_response_text, &filtered_names) { + // Apply same argument override as normal path (Fix F6) + let overridden_args = { + let 
mcp_ref = mcp_state.lock().await; + let context_args = + construct_args_from_context(&tc.name, &description, &mcp_ref); + merge_args(&context_args, &tc.arguments) + }; + + tracing::info!( + step = step.step_number, + tool = %tc.name, + final_args = %overridden_args, + "router fallback: extracted tool from text" + ); + let mut mcp = mcp_state.lock().await; + let tool_result = match mcp.call_tool(&tc.name, overridden_args.clone()).await { + Ok(res) if res.success => { + let text = res + .result + .and_then(|v| { + v.get("text") + .and_then(|t| t.as_str()) + .map(|s| s.to_string()) + .or_else(|| serde_json::to_string(&v).ok()) + }) + .unwrap_or_else(|| "ok".to_string()); + tracing::info!( + step = step.step_number, + tool = %tc.name, + result_len = text.len(), + result_preview = %truncate_utf8(&text, 200), + "fallback tool execution succeeded" + ); + text + } + Ok(res) => res.error.unwrap_or_else(|| "tool failed".to_string()), + Err(e) => format!("MCP error: {e}"), + }; + + return StepExecutionResult { + step_number: step.step_number, + description: description.clone(), + tool_called: Some(tc.name.clone()), + tool_arguments: Some(overridden_args), + tool_result: Some(tool_result), + success: true, + error: None, + }; + } + } + + StepExecutionResult { + step_number: step.step_number, + description, + tool_called: None, + tool_arguments: None, + tool_result: None, + success: false, + error: Some(format!( + "router failed to produce a tool call after {} attempts", + config.step_retries + )), + } +} + +/// Condense a step execution result into a 1-2 line summary. +/// +/// Extracts the key information (tool name, outcome, key data) rather than +/// dumping the full result text. Keeps the router's context clean and focused. 
+fn condense_step_result(result: &StepExecutionResult) -> String { + let tool = result.tool_called.as_deref().unwrap_or("unknown"); + + match &result.tool_result { + Some(text) if result.success => { + let summary = if text.len() <= 200 { + text.clone() + } else { + format!("{}... ({} chars total)", truncate_utf8(text, 150), text.len()) + }; + format!("Step {} ({}) succeeded: {}", result.step_number, tool, summary) + } + _ if !result.success => { + let err = result.error.as_deref().unwrap_or("unknown error"); + format!("Step {} ({}) failed: {}", result.step_number, tool, err) + } + _ => { + format!("Step {} ({}) succeeded", result.step_number, tool) + } + } +} + +/// Enhance the step description with prior step results (M3). +/// +/// Three forwarding mechanisms: +/// 1. **Immediate predecessor**: Step N always gets step N-1's condensed result, +/// regardless of whether the description references it explicitly. +/// 2. **Explicit references**: If the description mentions "step M", that step's +/// condensed result is also included (preserving existing behavior). +/// 3. **Deduplication**: Each step's result appears at most once in the context block. +fn interpolate_prior_results( + step_number: u32, + description: &str, + prior_results: &[StepExecutionResult], +) -> String { + if prior_results.is_empty() { + return description.to_string(); + } + + let mut context_lines: Vec = Vec::new(); + let mut included_steps: Vec = Vec::new(); + + // 1. Always include the immediately preceding step's result + if let Some(prev) = prior_results.iter().rfind(|r| r.step_number == step_number - 1) { + if prev.success { + context_lines.push(condense_step_result(prev)); + included_steps.push(prev.step_number); + } + } + + // 2. 
Include any explicitly referenced steps (e.g., "step 2" in step 5's description) + let lower_desc = description.to_lowercase(); + for prior in prior_results { + if !prior.success || included_steps.contains(&prior.step_number) { + continue; + } + let step_ref = format!("step {}", prior.step_number); + if lower_desc.contains(&step_ref) { + context_lines.push(condense_step_result(prior)); + included_steps.push(prior.step_number); + } + } + + // 3. Build enhanced description with a clean [Prior step context] block + if context_lines.is_empty() { + description.to_string() + } else { + format!( + "{}\n\n[Prior step context]:\n{}", + description, + context_lines.join("\n") + ) + } +} + +// ─── Phase 3: Synthesize ──────────────────────────────────────────────────── + +/// Generate a user-facing summary from accumulated step results. +async fn synthesize_response( + planner: &mut InferenceClient, + user_message: &str, + step_results: &[StepExecutionResult], + app_handle: &tauri::AppHandle, +) -> Result { + let results_summary: String = step_results + .iter() + .map(|r| { + if r.success { + format!( + "Step {}: {} → {} → {}", + r.step_number, + r.description, + r.tool_called.as_deref().unwrap_or("none"), + r.tool_result + .as_deref() + .map(|s| truncate_utf8(s, 500)) + .unwrap_or("ok") + ) + } else { + format!( + "Step {}: {} → FAILED: {}", + r.step_number, + r.description, + r.error.as_deref().unwrap_or("unknown") + ) + } + }) + .collect::>() + .join("\n"); + + let synthesis_prompt = format!( + "The user asked: \"{user_message}\"\n\n\ + The following tool actions were executed:\n{results_summary}\n\n\ + Provide a clear, helpful summary of what was done and the results. \ + Be concise. Summarize ONLY the results that succeeded. For failed steps, \ + honestly report that the action could not be completed. NEVER fabricate results." + ); + + let messages = vec![ + ChatMessage { + role: Role::System, + content: Some( + "You are LocalCowork, an on-device AI assistant. 
Summarize the tool \ + results for the user. Be clear and concise. Only report what actually \ + happened." + .to_string(), + ), + tool_call_id: None, + tool_calls: None, + }, + ChatMessage { + role: Role::User, + content: Some(synthesis_prompt), + tool_call_id: None, + tool_calls: None, + }, + ]; + + let sampling = SamplingOverrides { + temperature: Some(0.7), + top_p: Some(0.9), + }; + + // Stream the synthesis to the frontend + use futures::StreamExt; + + let stream = planner + .chat_completion_stream(messages, None, Some(sampling)) + .await + .map_err(|e| format!("synthesis streaming error: {e}"))?; + + futures::pin_mut!(stream); + + let mut full_text = String::new(); + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if let Some(ref token) = chunk.token { + full_text.push_str(token); + let _ = tauri::Emitter::emit(app_handle, "stream-token", token); + } + } + Err(e) => { + tracing::warn!(error = %e, "synthesis stream error"); + break; + } + } + } + + if full_text.is_empty() { + return Err("synthesis produced empty response".to_string()); + } + + Ok(full_text) +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn interpolate_no_prior_results() { + let desc = "List files in /tmp"; + let result = interpolate_prior_results(1, desc, &[]); + assert_eq!(result, desc); + } + + #[test] + fn interpolate_with_explicit_step_reference() { + let desc = "Using the result from step 1, extract text from image"; + let prior = vec![StepExecutionResult { + step_number: 1, + description: "List files".into(), + tool_called: Some("filesystem.list_dir".into()), + tool_arguments: None, + tool_result: Some("[\"file1.png\", \"file2.png\"]".into()), + success: true, + error: None, + }]; + let result = interpolate_prior_results(2, desc, &prior); + assert!(result.contains("[Prior step context]")); + assert!(result.contains("file1.png")); + } + + 
#[test] + fn interpolate_skips_failed_steps() { + let desc = "Using the result from step 1, continue"; + let prior = vec![StepExecutionResult { + step_number: 1, + description: "Failed step".into(), + tool_called: None, + tool_arguments: None, + tool_result: None, + success: false, + error: Some("timeout".into()), + }]; + let result = interpolate_prior_results(2, desc, &prior); + assert!(!result.contains("[Prior step context]")); + } + + #[test] + fn interpolate_always_forwards_predecessor() { + // Step 2 should get step 1's result even without explicit "step 1" reference + let desc = "Extract text from the document"; + let prior = vec![StepExecutionResult { + step_number: 1, + description: "List files".into(), + tool_called: Some("filesystem.list_dir".into()), + tool_arguments: None, + tool_result: Some("[\"report.pdf\", \"notes.txt\"]".into()), + success: true, + error: None, + }]; + let result = interpolate_prior_results(2, desc, &prior); + assert!(result.contains("[Prior step context]"), "should include context block"); + assert!(result.contains("report.pdf"), "should include predecessor result"); + } + + #[test] + fn interpolate_deduplicates_predecessor_and_explicit() { + // Step 2 references "step 1" explicitly, and step 1 is also the predecessor. + // The result should appear only ONCE. + let desc = "Using the result from step 1, extract text"; + let prior = vec![StepExecutionResult { + step_number: 1, + description: "List files".into(), + tool_called: Some("filesystem.list_dir".into()), + tool_arguments: None, + tool_result: Some("[\"a.txt\"]".into()), + success: true, + error: None, + }]; + let result = interpolate_prior_results(2, desc, &prior); + let context_count = result.matches("filesystem.list_dir").count(); + assert_eq!(context_count, 1, "predecessor + explicit should not duplicate"); + } + + #[test] + fn interpolate_includes_explicit_and_predecessor() { + // Step 3 references "step 1" explicitly; step 2 is the predecessor. + // Both should appear. 
+ let desc = "Using the results from step 1, create a task"; + let prior = vec![ + StepExecutionResult { + step_number: 1, + description: "Scan files".into(), + tool_called: Some("security.scan_for_pii".into()), + tool_arguments: None, + tool_result: Some("Found 3 files with SSNs".into()), + success: true, + error: None, + }, + StepExecutionResult { + step_number: 2, + description: "Scan secrets".into(), + tool_called: Some("security.scan_for_secrets".into()), + tool_arguments: None, + tool_result: Some("Found 1 API key".into()), + success: true, + error: None, + }, + ]; + let result = interpolate_prior_results(3, desc, &prior); + assert!(result.contains("Found 1 API key"), "should include step 2 (predecessor)"); + assert!(result.contains("Found 3 files"), "should include step 1 (referenced)"); + } + + #[test] + fn interpolate_skips_failed_predecessor() { + let desc = "Continue processing"; + let prior = vec![StepExecutionResult { + step_number: 1, + description: "Failed step".into(), + tool_called: None, + tool_arguments: None, + tool_result: None, + success: false, + error: Some("timeout".into()), + }]; + let result = interpolate_prior_results(2, desc, &prior); + assert!(!result.contains("[Prior step context]")); + } + + #[test] + fn condense_short_result_includes_full_text() { + let step = StepExecutionResult { + step_number: 1, + description: "List files".into(), + tool_called: Some("filesystem.list_dir".into()), + tool_arguments: None, + tool_result: Some("[\"a.txt\", \"b.pdf\"]".into()), + success: true, + error: None, + }; + let condensed = condense_step_result(&step); + assert!(condensed.contains("filesystem.list_dir")); + assert!(condensed.contains("succeeded")); + assert!(condensed.contains("a.txt")); + } + + #[test] + fn condense_long_result_truncates() { + let long_text = "x".repeat(500); + let step = StepExecutionResult { + step_number: 2, + description: "Extract text".into(), + tool_called: Some("document.extract_text".into()), + tool_arguments: None, 
+ tool_result: Some(long_text), + success: true, + error: None, + }; + let condensed = condense_step_result(&step); + assert!(condensed.len() < 300, "condensed should be much shorter than 500"); + assert!(condensed.contains("500 chars total")); + } + + #[test] + fn condense_failed_result() { + let step = StepExecutionResult { + step_number: 3, + description: "Scan PII".into(), + tool_called: Some("security.scan_for_pii".into()), + tool_arguments: None, + tool_result: None, + success: false, + error: Some("file not found".into()), + }; + let condensed = condense_step_result(&step); + assert!(condensed.contains("failed")); + assert!(condensed.contains("file not found")); + } + + #[test] + fn extract_inline_args_basic() { + let text = r#"I'll call filesystem.list_dir with path="/Users/chintan/Downloads""#; + let args = extract_inline_args(text); + assert_eq!( + args.get("path").and_then(|v| v.as_str()), + Some("/Users/chintan/Downloads") + ); + } + + #[test] + fn extract_inline_args_multiple() { + let text = r#"query="find documents" path="/tmp/data""#; + let args = extract_inline_args(text); + assert_eq!( + args.get("query").and_then(|v| v.as_str()), + Some("find documents") + ); + assert_eq!( + args.get("path").and_then(|v| v.as_str()), + Some("/tmp/data") + ); + } + + #[test] + fn extract_inline_args_empty_when_no_patterns() { + let text = "Just some plain text without any arguments"; + let args = extract_inline_args(text); + assert!(args.as_object().unwrap().is_empty()); + } + + #[test] + fn extract_fallback_single_mention() { + let names = vec![ + "filesystem.list_dir".to_string(), + "filesystem.read_file".to_string(), + "document.extract_text".to_string(), + ]; + let text = "I would use filesystem.list_dir to list the Downloads folder"; + let result = extract_fallback_tool_call(text, &names); + assert!(result.is_some()); + assert_eq!(result.unwrap().name, "filesystem.list_dir"); + } + + #[test] + fn extract_fallback_multiple_mentions_returns_none() { + let names 
= vec![ + "filesystem.list_dir".to_string(), + "filesystem.read_file".to_string(), + ]; + let text = "I could use filesystem.list_dir or filesystem.read_file"; + let result = extract_fallback_tool_call(text, &names); + assert!(result.is_none()); + } + + #[test] + fn extract_fallback_no_mention_returns_none() { + let names = vec!["filesystem.list_dir".to_string()]; + let text = "I don't know which tool to use"; + let result = extract_fallback_tool_call(text, &names); + assert!(result.is_none()); + } + + // ─── Fix F6: Argument extraction tests ────────────────────────────── + + #[test] + fn extract_path_absolute() { + let text = "List all files in /Users/chintan/Downloads"; + assert_eq!( + extract_path_from_text(text), + Some("/Users/chintan/Downloads".to_string()) + ); + } + + #[test] + fn extract_path_tilde() { + let text = "Read ~/Projects/localCoWork/README.md"; + assert_eq!( + extract_path_from_text(text), + Some("~/Projects/localCoWork/README.md".to_string()) + ); + } + + #[test] + fn extract_path_backtick() { + let text = "Scan files in `tests/fixtures/uc3/sample_files/` for PII"; + assert_eq!( + extract_path_from_text(text), + Some("tests/fixtures/uc3/sample_files/".to_string()) + ); + } + + #[test] + fn extract_path_downloads_folder() { + let text = "What files are in my Downloads folder?"; + assert_eq!( + extract_path_from_text(text), + Some("~/Downloads".to_string()) + ); + } + + #[test] + fn extract_path_desktop() { + let text = "Show me what's on my Desktop"; + assert_eq!( + extract_path_from_text(text), + Some("~/Desktop".to_string()) + ); + } + + #[test] + fn extract_path_none_when_no_path() { + let text = "Create a new task to review the report"; + assert_eq!(extract_path_from_text(text), None); + } + + #[test] + fn extract_title_quoted() { + let text = "Create a task titled 'Review Q4 numbers' with due date Friday"; + assert_eq!( + extract_title_from_text(text), + Some("Review Q4 numbers".to_string()) + ); + } + + #[test] + fn 
extract_title_double_quoted() { + let text = r#"Create a task titled "Fix the login bug" by tomorrow"#; + assert_eq!( + extract_title_from_text(text), + Some("Fix the login bug".to_string()) + ); + } + + #[test] + fn extract_date_friday() { + let text = "Review Q4 numbers by Friday"; + assert_eq!( + extract_date_from_text(text), + Some("Friday".to_string()) + ); + } + + #[test] + fn extract_date_tomorrow() { + let text = "Follow up on this tomorrow"; + assert_eq!( + extract_date_from_text(text), + Some("tomorrow".to_string()) + ); + } + + #[test] + fn extract_date_iso() { + let text = "Schedule the meeting for 2026-03-15"; + assert_eq!( + extract_date_from_text(text), + Some("2026-03-15".to_string()) + ); + } + + #[test] + fn merge_args_path_always_overrides() { + // Path keys always prefer primary (context-extracted) over secondary (router) + let primary = serde_json::json!({"path": "~/Downloads"}); + let secondary = serde_json::json!({"path": "~/Documents/example.txt", "recursive": true}); + let merged = merge_args(&primary, &secondary); + assert_eq!(merged.get("path").unwrap(), "~/Downloads"); + assert_eq!(merged.get("recursive").unwrap(), true); + } + + #[test] + fn merge_args_empty_primary_keeps_secondary() { + let primary = serde_json::json!({}); + let secondary = serde_json::json!({"path": "/tmp/file.txt"}); + let merged = merge_args(&primary, &secondary); + assert_eq!(merged.get("path").unwrap(), "/tmp/file.txt"); + } + + #[test] + fn merge_args_null_primary_keeps_secondary() { + let primary = serde_json::json!({"path": null}); + let secondary = serde_json::json!({"path": "/tmp/file.txt"}); + let merged = merge_args(&primary, &secondary); + assert_eq!(merged.get("path").unwrap(), "/tmp/file.txt"); + } + + #[test] + fn merge_args_preserves_good_router_title() { + // Router produced a real title (not a placeholder) — keep it + let primary = serde_json::json!({"title": "Some extracted title"}); + let secondary = serde_json::json!({"title": "Review Q4 
numbers"}); + let merged = merge_args(&primary, &secondary); + // Router's title "Review Q4 numbers" is not a placeholder, so it should be kept + assert_eq!(merged.get("title").unwrap(), "Review Q4 numbers"); + } + + #[test] + fn merge_args_overrides_placeholder_title() { + // Router produced a placeholder title — override with context-extracted + let primary = serde_json::json!({"title": "Follow up on sensitive files"}); + let secondary = serde_json::json!({"title": "content"}); + let merged = merge_args(&primary, &secondary); + assert_eq!(merged.get("title").unwrap(), "Follow up on sensitive files"); + } + + #[test] + fn merge_args_overrides_placeholder_due_date() { + // Router produced a schema-literal due date — override + let primary = serde_json::json!({"due_date": "Friday"}); + let secondary = serde_json::json!({"due_date": "Due date (ISO 8601)"}); + let merged = merge_args(&primary, &secondary); + assert_eq!(merged.get("due_date").unwrap(), "Friday"); + } + + #[test] + fn merge_args_fills_missing_router_keys() { + // Router didn't produce a due_date — fill from primary + let primary = serde_json::json!({"due_date": "Friday"}); + let secondary = serde_json::json!({"title": "Review Q4 numbers"}); + let merged = merge_args(&primary, &secondary); + assert_eq!(merged.get("title").unwrap(), "Review Q4 numbers"); + assert_eq!(merged.get("due_date").unwrap(), "Friday"); + } + + // ─── Fix F10: Placeholder detection tests ─────────────────────────── + + #[test] + fn placeholder_detects_example_path() { + assert!(is_placeholder_value(&serde_json::json!( + "~/Documents/example.txt" + ))); + } + + #[test] + fn placeholder_detects_schema_literal() { + assert!(is_placeholder_value(&serde_json::json!( + "Due date (ISO 8601)" + ))); + } + + #[test] + fn placeholder_detects_generic_words() { + assert!(is_placeholder_value(&serde_json::json!("content"))); + assert!(is_placeholder_value(&serde_json::json!("value"))); + 
assert!(is_placeholder_value(&serde_json::json!("description"))); + } + + #[test] + fn placeholder_rejects_real_values() { + assert!(!is_placeholder_value(&serde_json::json!( + "Review Q4 numbers" + ))); + assert!(!is_placeholder_value(&serde_json::json!("Friday"))); + assert!(!is_placeholder_value(&serde_json::json!( + "/Users/chintan/Downloads" + ))); + } + + #[test] + fn placeholder_detects_empty_string() { + assert!(is_placeholder_value(&serde_json::json!(""))); + assert!(is_placeholder_value(&serde_json::json!(" "))); + } + + // ─── Fix F11: Plan decomposition check tests ──────────────────────── + + fn single_step_plan() -> StepPlan { + StepPlan { + needs_tools: true, + direct_response: None, + steps: vec![PlanStep { + step_number: 1, + description: "Do something".into(), + expected_server: Some("security".into()), + hint_params: None, + }], + } + } + + fn multi_step_plan() -> StepPlan { + StepPlan { + needs_tools: true, + direct_response: None, + steps: vec![ + PlanStep { + step_number: 1, + description: "Scan files".into(), + expected_server: Some("security".into()), + hint_params: None, + }, + PlanStep { + step_number: 2, + description: "Create task".into(), + expected_server: Some("task".into()), + hint_params: None, + }, + ], + } + } + + #[test] + fn decomp_triggers_on_scan_and_task() { + let plan = single_step_plan(); + let msg = "Scan files for SSNs and API keys, then create a task to follow up"; + assert!(plan_needs_decomposition(&plan, msg)); + } + + #[test] + fn decomp_triggers_on_read_and_create() { + let plan = single_step_plan(); + let msg = "Read the quarterly report and create a task to review the numbers"; + assert!(plan_needs_decomposition(&plan, msg)); + } + + #[test] + fn decomp_triggers_on_ssn_and_api_key() { + let plan = single_step_plan(); + let msg = "Scan for SSN and API key issues in my documents"; + assert!(plan_needs_decomposition(&plan, msg)); + } + + #[test] + fn decomp_skips_already_multi_step() { + let plan = multi_step_plan(); 
+ let msg = "Scan files and create a task"; + assert!(!plan_needs_decomposition(&plan, msg)); + } + + #[test] + fn decomp_skips_simple_request() { + let plan = single_step_plan(); + let msg = "What files are in my Downloads folder?"; + assert!(!plan_needs_decomposition(&plan, msg)); + } + + #[test] + fn decomp_skips_no_tools_plan() { + let plan = StepPlan { + needs_tools: false, + direct_response: Some("Hello!".into()), + steps: vec![], + }; + let msg = "Hello, how are you?"; + assert!(!plan_needs_decomposition(&plan, msg)); + } +} diff --git a/src-tauri/src/agent_core/permissions.rs b/src-tauri/src/agent_core/permissions.rs new file mode 100644 index 0000000..880bf6a --- /dev/null +++ b/src-tauri/src/agent_core/permissions.rs @@ -0,0 +1,350 @@ +//! Permission Store — tiered permission grants for tool execution. +//! +//! Supports three tiers: +//! - **Allow Once** (default Confirmed) — no grant stored, ask every time. +//! - **Allow for Session** — grant lives until `clear_session()` is called. +//! - **Always Allow** — persisted to platform data dir / `permissions.json`. +//! +//! The ToolRouter checks `PermissionStore::check()` before entering the +//! confirmation flow. If the tool has an active grant, confirmation is skipped. + +use std::collections::HashMap; +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +// ─── Types ────────────────────────────────────────────────────────────────── + +/// Scope of a permission grant. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PermissionScope { + /// Valid until the session ends. + Session, + /// Persisted across restarts. + Always, +} + +/// A single permission grant for a tool. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PermissionGrant { + /// The fully-qualified tool name (e.g. "filesystem.write_file"). + pub tool_name: String, + /// Whether this is a session or persistent grant. 
+ pub scope: PermissionScope, + /// ISO 8601 timestamp when the grant was created. + pub granted_at: String, +} + +/// Result of checking a tool's permission status. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PermissionStatus { + /// Tool has an active grant — skip confirmation. + Allowed, + /// No grant — proceed with normal confirmation flow. + NeedsConfirmation, +} + +// ─── Persistent Format ────────────────────────────────────────────────────── + +/// On-disk format for `permissions.json`. +#[derive(Debug, Default, Serialize, Deserialize)] +struct PersistedGrants { + /// Version for forward compatibility. + version: u32, + /// Tool name → grant mapping. + grants: HashMap, +} + +// ─── PermissionStore ──────────────────────────────────────────────────────── + +/// Manages session and persistent permission grants. +pub struct PermissionStore { + /// Grants that expire when the session ends. + session_grants: HashMap, + /// Grants persisted to disk (loaded on startup, saved on mutation). + persistent_grants: HashMap, + /// Path to the permissions JSON file. + persist_path: PathBuf, +} + +impl Default for PermissionStore { + fn default() -> Self { + Self::new() + } +} + +impl PermissionStore { + /// Create a new PermissionStore and load any persisted grants. + pub fn new() -> Self { + let persist_path = Self::default_persist_path(); + let mut store = Self { + session_grants: HashMap::new(), + persistent_grants: HashMap::new(), + persist_path, + }; + store.load_from_disk(); + store + } + + /// Create a PermissionStore for testing (in-memory only, no disk I/O). + #[cfg(test)] + pub fn new_in_memory() -> Self { + Self { + session_grants: HashMap::new(), + persistent_grants: HashMap::new(), + persist_path: PathBuf::from("/dev/null"), + } + } + + /// Check if a tool has an active permission grant. 
+ pub fn check(&self, tool_name: &str) -> PermissionStatus { + if self.persistent_grants.contains_key(tool_name) + || self.session_grants.contains_key(tool_name) + { + PermissionStatus::Allowed + } else { + PermissionStatus::NeedsConfirmation + } + } + + /// Grant a permission for a tool. + pub fn grant(&mut self, tool_name: &str, scope: PermissionScope) { + let grant = PermissionGrant { + tool_name: tool_name.to_string(), + scope, + granted_at: chrono::Utc::now().to_rfc3339(), + }; + + match scope { + PermissionScope::Session => { + self.session_grants.insert(tool_name.to_string(), grant); + } + PermissionScope::Always => { + self.persistent_grants + .insert(tool_name.to_string(), grant); + self.save_to_disk(); + } + } + + tracing::info!( + tool = tool_name, + scope = ?scope, + "permission granted" + ); + } + + /// Revoke a persistent permission grant. + pub fn revoke(&mut self, tool_name: &str) -> bool { + let removed_persistent = self.persistent_grants.remove(tool_name).is_some(); + let removed_session = self.session_grants.remove(tool_name).is_some(); + + if removed_persistent { + self.save_to_disk(); + } + + let removed = removed_persistent || removed_session; + if removed { + tracing::info!(tool = tool_name, "permission revoked"); + } + removed + } + + /// List all persistent grants (for the Settings UI). + pub fn list_persistent(&self) -> Vec<&PermissionGrant> { + let mut grants: Vec<&PermissionGrant> = self.persistent_grants.values().collect(); + grants.sort_by(|a, b| a.tool_name.cmp(&b.tool_name)); + grants + } + + /// Clear all session grants (called when a session ends). + pub fn clear_session(&mut self) { + let count = self.session_grants.len(); + self.session_grants.clear(); + if count > 0 { + tracing::info!(cleared = count, "session permissions cleared"); + } + } + + // ─── Persistence ──────────────────────────────────────────────────── + + /// Default path: platform-standard data directory / `permissions.json`. 
+ fn default_persist_path() -> PathBuf { + crate::data_dir().join("permissions.json") + } + + /// Load persistent grants from disk. + fn load_from_disk(&mut self) { + if !self.persist_path.exists() { + return; + } + + match std::fs::read_to_string(&self.persist_path) { + Ok(content) => match serde_json::from_str::(&content) { + Ok(persisted) => { + tracing::info!( + count = persisted.grants.len(), + path = %self.persist_path.display(), + "loaded persistent permissions" + ); + self.persistent_grants = persisted.grants; + } + Err(e) => { + tracing::warn!( + error = %e, + path = %self.persist_path.display(), + "failed to parse permissions file, starting fresh" + ); + } + }, + Err(e) => { + tracing::warn!( + error = %e, + path = %self.persist_path.display(), + "failed to read permissions file" + ); + } + } + } + + /// Save persistent grants to disk (atomic write). + fn save_to_disk(&self) { + let persisted = PersistedGrants { + version: 1, + grants: self.persistent_grants.clone(), + }; + + let content = match serde_json::to_string_pretty(&persisted) { + Ok(c) => c, + Err(e) => { + tracing::error!(error = %e, "failed to serialize permissions"); + return; + } + }; + + // Ensure parent directory exists + if let Some(parent) = self.persist_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + + // Write to temp file, then rename for atomicity + let tmp_path = self.persist_path.with_extension("json.tmp"); + if let Err(e) = std::fs::write(&tmp_path, &content) { + tracing::error!(error = %e, "failed to write permissions temp file"); + return; + } + if let Err(e) = std::fs::rename(&tmp_path, &self.persist_path) { + tracing::error!(error = %e, "failed to rename permissions file"); + return; + } + + tracing::debug!( + count = self.persistent_grants.len(), + "saved persistent permissions" + ); + } +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn 
test_check_returns_needs_confirmation_by_default() { + let store = PermissionStore::new_in_memory(); + assert_eq!( + store.check("filesystem.write_file"), + PermissionStatus::NeedsConfirmation + ); + } + + #[test] + fn test_session_grant_allows_tool() { + let mut store = PermissionStore::new_in_memory(); + store.grant("filesystem.write_file", PermissionScope::Session); + assert_eq!( + store.check("filesystem.write_file"), + PermissionStatus::Allowed + ); + } + + #[test] + fn test_always_grant_allows_tool() { + let mut store = PermissionStore::new_in_memory(); + store.grant("filesystem.write_file", PermissionScope::Always); + assert_eq!( + store.check("filesystem.write_file"), + PermissionStatus::Allowed + ); + } + + #[test] + fn test_clear_session_removes_session_grants_only() { + let mut store = PermissionStore::new_in_memory(); + store.grant("tool_a", PermissionScope::Session); + store.grant("tool_b", PermissionScope::Always); + + store.clear_session(); + + assert_eq!(store.check("tool_a"), PermissionStatus::NeedsConfirmation); + assert_eq!(store.check("tool_b"), PermissionStatus::Allowed); + } + + #[test] + fn test_revoke_removes_grant() { + let mut store = PermissionStore::new_in_memory(); + store.grant("tool_a", PermissionScope::Always); + assert_eq!(store.check("tool_a"), PermissionStatus::Allowed); + + let removed = store.revoke("tool_a"); + assert!(removed); + assert_eq!(store.check("tool_a"), PermissionStatus::NeedsConfirmation); + } + + #[test] + fn test_revoke_nonexistent_returns_false() { + let mut store = PermissionStore::new_in_memory(); + assert!(!store.revoke("nonexistent")); + } + + #[test] + fn test_list_persistent_returns_sorted() { + let mut store = PermissionStore::new_in_memory(); + store.grant("zzz.tool", PermissionScope::Always); + store.grant("aaa.tool", PermissionScope::Always); + store.grant("mmm.tool", PermissionScope::Session); // should not appear + + let grants = store.list_persistent(); + assert_eq!(grants.len(), 2); + 
assert_eq!(grants[0].tool_name, "aaa.tool"); + assert_eq!(grants[1].tool_name, "zzz.tool"); + } + + #[test] + fn test_permission_scope_serialization() { + let session = PermissionScope::Session; + let json = serde_json::to_string(&session).unwrap(); + assert_eq!(json, "\"session\""); + + let always = PermissionScope::Always; + let json = serde_json::to_string(&always).unwrap(); + assert_eq!(json, "\"always\""); + } + + #[test] + fn test_grant_overwrites_existing() { + let mut store = PermissionStore::new_in_memory(); + store.grant("tool_a", PermissionScope::Session); + store.grant("tool_a", PermissionScope::Always); + + // Should be in persistent, not session + assert_eq!(store.list_persistent().len(), 1); + assert_eq!(store.check("tool_a"), PermissionStatus::Allowed); + + // Clear session — should still be allowed via persistent + store.clear_session(); + assert_eq!(store.check("tool_a"), PermissionStatus::Allowed); + } +} diff --git a/src-tauri/src/agent_core/plan_parser.rs b/src-tauri/src/agent_core/plan_parser.rs new file mode 100644 index 0000000..564366f --- /dev/null +++ b/src-tauri/src/agent_core/plan_parser.rs @@ -0,0 +1,342 @@ +//! Plan output parsers for the orchestrator planner model. +//! +//! Supports two formats: +//! - **Bracket format** (primary): `[plan.add_step(step=1, server="fs", description="...")]` +//! This matches LFM2-24B-A2B's native bracket tool-call syntax. +//! - **JSON format** (fallback): `{"needs_tools":true,"steps":[...]}` +//! For models that can produce structured JSON. +//! +//! The bracket parser is tried first because LFM2-24B-A2B had a 94% JSON parse +//! failure rate in orchestrator benchmarks — the model naturally produces bracket +//! syntax, not raw JSON. + +use crate::agent_core::orchestrator::{PlanStep, StepPlan}; + +// ─── Bracket-Format Parser ────────────────────────────────────────────────── + +/// Parse bracket-format plan output from LFM2-24B-A2B. 
+/// +/// Extracts calls in the form: +/// `[plan.add_step(step=1, server="filesystem", description="...")]` +/// `[plan.respond(message="direct answer")]` +/// `[plan.done()]` +pub fn parse_bracket_plan(text: &str) -> Option { + let mut steps: Vec = Vec::new(); + let mut direct_response: Option = None; + let mut found_any_call = false; + + for line in text.lines() { + let line = line.trim(); + + // Match [plan.add_step(...)] + if let Some(inner) = extract_bracket_call(line, "plan.add_step") { + found_any_call = true; + if let Some(step) = parse_add_step_args(inner) { + steps.push(step); + } + } + + // Match [plan.respond(message="...")] + if let Some(inner) = extract_bracket_call(line, "plan.respond") { + found_any_call = true; + if let Some(msg) = extract_named_string_arg(inner, "message") { + direct_response = Some(msg); + } + } + + // Match [plan.done()] — signals end of plan + if extract_bracket_call(line, "plan.done").is_some() { + found_any_call = true; + } + } + + if !found_any_call { + return None; + } + + if let Some(ref response) = direct_response { + Some(StepPlan { + needs_tools: false, + direct_response: Some(response.clone()), + steps: Vec::new(), + }) + } else if !steps.is_empty() { + Some(StepPlan { + needs_tools: true, + direct_response: None, + steps, + }) + } else { + None + } +} + +// ─── JSON-Format Parser ───────────────────────────────────────────────────── + +/// Parse JSON-format plan output (fallback for models that produce JSON). +pub fn parse_json_plan(text: &str) -> Result { + let json_str = extract_json(text); + serde_json::from_str::(json_str) + .map_err(|e| format!("failed to parse plan JSON: {e}")) +} + +/// Extract JSON from text that may be wrapped in markdown code fences. 
+fn extract_json(text: &str) -> &str { + if let Some(start) = text.find('{') { + if let Some(end) = text.rfind('}') { + return &text[start..=end]; + } + } + text +} + +// ─── Bracket Argument Extractors ───────────────────────────────────────────── + +/// Extract the inner arguments from a bracket call like `[fn_name(args)]`. +/// +/// Returns the argument string between parentheses, or `None` if the line +/// doesn't match the expected function name. +fn extract_bracket_call<'a>(line: &'a str, fn_name: &str) -> Option<&'a str> { + let pattern = format!("[{fn_name}("); + let start = line.find(&pattern)?; + let args_start = start + pattern.len(); + + // Find matching closing )] + let rest = &line[args_start..]; + let close = rest.rfind(")]")?; + + Some(&rest[..close]) +} + +/// Parse arguments from a `plan.add_step(step=1, server="fs", description="...")` call. +fn parse_add_step_args(args: &str) -> Option { + let step_num = extract_named_int_arg(args, "step").unwrap_or(1); + let server = extract_named_string_arg(args, "server"); + let description = extract_named_string_arg(args, "description")?; + + Some(PlanStep { + step_number: step_num, + description, + expected_server: server, + hint_params: None, + }) +} + +/// Extract a named string argument like `key="value"` from a comma-separated arg list. +/// +/// Handles escaped quotes within the value. 
+pub fn extract_named_string_arg(args: &str, key: &str) -> Option { + let pattern = format!("{key}=\""); + let start = args.find(&pattern)?; + let value_start = start + pattern.len(); + let rest = &args[value_start..]; + + // Find the closing quote, handling escaped quotes + let mut end = 0; + let bytes = rest.as_bytes(); + while end < bytes.len() { + if bytes[end] == b'"' && (end == 0 || bytes[end - 1] != b'\\') { + break; + } + end += 1; + } + + if end >= bytes.len() { + return None; + } + + let value = &rest[..end]; + Some(value.replace("\\\"", "\"")) +} + +/// Extract a named integer argument like `step=1` from a comma-separated arg list. +pub fn extract_named_int_arg(args: &str, key: &str) -> Option { + let pattern = format!("{key}="); + let start = args.find(&pattern)?; + let value_start = start + pattern.len(); + let rest = &args[value_start..]; + + // Read digits until non-digit + let digits: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect(); + digits.parse().ok() +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + // ─── Bracket-format plan parsing ───────────────────────────────────── + + #[test] + fn parse_bracket_single_step() { + let text = r#"[plan.add_step(step=1, server="filesystem", description="List files in /tmp")] +[plan.done()]"#; + let plan = parse_bracket_plan(text).unwrap(); + assert!(plan.needs_tools); + assert_eq!(plan.steps.len(), 1); + assert_eq!(plan.steps[0].step_number, 1); + assert_eq!(plan.steps[0].expected_server.as_deref(), Some("filesystem")); + assert_eq!(plan.steps[0].description, "List files in /tmp"); + } + + #[test] + fn parse_bracket_multi_step() { + let text = r#"[plan.add_step(step=1, server="filesystem", description="List all PDF files in /Users/chintan/Downloads")] +[plan.add_step(step=2, server="document", description="Using the result from step 1, extract text from the first PDF")] +[plan.add_step(step=3, 
server="knowledge", description="Index the extracted text for semantic search")] +[plan.done()]"#; + let plan = parse_bracket_plan(text).unwrap(); + assert!(plan.needs_tools); + assert_eq!(plan.steps.len(), 3); + assert_eq!(plan.steps[0].step_number, 1); + assert_eq!(plan.steps[1].step_number, 2); + assert_eq!(plan.steps[2].step_number, 3); + assert_eq!(plan.steps[2].expected_server.as_deref(), Some("knowledge")); + assert!(plan.steps[1].description.contains("result from step 1")); + } + + #[test] + fn parse_bracket_direct_response() { + let text = r#"[plan.respond(message="The capital of France is Paris.")]"#; + let plan = parse_bracket_plan(text).unwrap(); + assert!(!plan.needs_tools); + assert_eq!( + plan.direct_response.as_deref(), + Some("The capital of France is Paris.") + ); + assert!(plan.steps.is_empty()); + } + + #[test] + fn parse_bracket_with_escaped_quotes() { + let text = + r#"[plan.add_step(step=1, server="filesystem", description="Search for files named \"report.pdf\" in Downloads")] +[plan.done()]"#; + let plan = parse_bracket_plan(text).unwrap(); + assert_eq!(plan.steps.len(), 1); + assert_eq!( + plan.steps[0].description, + r#"Search for files named "report.pdf" in Downloads"# + ); + } + + #[test] + fn parse_bracket_no_server_hint() { + let text = + r#"[plan.add_step(step=1, description="Find all receipts from last month")] +[plan.done()]"#; + let plan = parse_bracket_plan(text).unwrap(); + assert_eq!(plan.steps.len(), 1); + assert!(plan.steps[0].expected_server.is_none()); + } + + #[test] + fn parse_bracket_with_surrounding_text() { + let text = r#"Here is the plan: +[plan.add_step(step=1, server="filesystem", description="List files in /tmp")] +[plan.done()]"#; + let plan = parse_bracket_plan(text).unwrap(); + assert_eq!(plan.steps.len(), 1); + } + + #[test] + fn parse_bracket_returns_none_for_plain_text() { + let text = "I can help you with that! 
Let me list the files."; + assert!(parse_bracket_plan(text).is_none()); + } + + #[test] + fn parse_bracket_returns_none_for_empty() { + assert!(parse_bracket_plan("").is_none()); + } + + // ─── Extract bracket call internals ────────────────────────────────── + + #[test] + fn extract_bracket_call_basic() { + let line = r#"[plan.add_step(step=1, server="fs", description="test")]"#; + let inner = extract_bracket_call(line, "plan.add_step").unwrap(); + assert_eq!(inner, r#"step=1, server="fs", description="test""#); + } + + #[test] + fn extract_bracket_call_no_args() { + let line = "[plan.done()]"; + let inner = extract_bracket_call(line, "plan.done").unwrap(); + assert_eq!(inner, ""); + } + + #[test] + fn extract_bracket_call_wrong_name() { + let line = "[plan.add_step(step=1)]"; + assert!(extract_bracket_call(line, "plan.respond").is_none()); + } + + // ─── Named argument extraction ─────────────────────────────────────── + + #[test] + fn extract_string_arg_basic() { + let args = r#"step=1, server="filesystem", description="List files""#; + assert_eq!( + extract_named_string_arg(args, "server"), + Some("filesystem".to_string()) + ); + assert_eq!( + extract_named_string_arg(args, "description"), + Some("List files".to_string()) + ); + } + + #[test] + fn extract_string_arg_missing() { + let args = r#"step=1, description="test""#; + assert!(extract_named_string_arg(args, "server").is_none()); + } + + #[test] + fn extract_int_arg_basic() { + let args = r#"step=3, server="filesystem""#; + assert_eq!(extract_named_int_arg(args, "step"), Some(3)); + } + + #[test] + fn extract_int_arg_missing() { + let args = r#"server="filesystem""#; + assert!(extract_named_int_arg(args, "step").is_none()); + } + + // ─── JSON fallback parsing ─────────────────────────────────────────── + + #[test] + fn parse_json_step_plan() { + let json = r#"{"needs_tools":true,"steps":[{"step_number":1,"description":"List files in /tmp","expected_server":"filesystem"}]}"#; + let plan = 
parse_json_plan(json).unwrap(); + assert!(plan.needs_tools); + assert_eq!(plan.steps.len(), 1); + assert_eq!(plan.steps[0].expected_server.as_deref(), Some("filesystem")); + } + + #[test] + fn parse_json_no_tools_plan() { + let json = r#"{"needs_tools":false,"direct_response":"The answer is 42."}"#; + let plan = parse_json_plan(json).unwrap(); + assert!(!plan.needs_tools); + assert_eq!(plan.direct_response.as_deref(), Some("The answer is 42.")); + } + + #[test] + fn extract_json_from_markdown() { + let wrapped = "```json\n{\"needs_tools\":true,\"steps\":[]}\n```"; + let result = extract_json(wrapped); + assert_eq!(result, "{\"needs_tools\":true,\"steps\":[]}"); + } + + #[test] + fn extract_json_bare() { + let bare = "{\"needs_tools\":false}"; + let result = extract_json(bare); + assert_eq!(result, bare); + } +} diff --git a/src-tauri/src/agent_core/plan_templates.rs b/src-tauri/src/agent_core/plan_templates.rs new file mode 100644 index 0000000..32c9ffb --- /dev/null +++ b/src-tauri/src/agent_core/plan_templates.rs @@ -0,0 +1,306 @@ +//! Template-based plan decomposition for known use case patterns. +//! +//! Before calling the planner model, the orchestrator checks if the user's +//! message matches a known use case pattern (from PRD UC-1 through UC-10). +//! If matched, a pre-built `StepPlan` is returned directly, bypassing the +//! planner model call entirely. This guarantees correct multi-step decomposition +//! for common workflows and saves ~2-3s of planner latency. + +use crate::agent_core::orchestrator::{PlanStep, StepPlan}; + +/// Attempt to match the user's message against known use case templates. +/// +/// Returns `Some(StepPlan)` if a high-confidence match is found, `None` otherwise. +/// The caller should fall through to the planner model when `None` is returned. +/// +/// Match confidence requires at least 3 signal keyword groups for multi-step +/// templates to avoid false positives on simple requests. 
+pub fn try_template_match(user_message: &str) -> Option { + let lower = user_message.to_lowercase(); + + // UC-4: Download triage (5 steps) — checked before UC-1 because UC-4's + // "download" keyword is more specific; without this ordering, generic + // file-management messages can false-positive into UC-1. + let uc4_score = keyword_score(&lower, &[ + &["download"], + &["organize", "classify", "sort", "clean up", "triage"], + &["move", "file", "rename"], + &["pii", "sensitive", "scan", "security"], + &["task", "follow up", "remediat"], + ]); + if uc4_score >= 3 { + tracing::info!(score = uc4_score, "template match: UC-4 download triage"); + return Some(build_uc4_download_triage_template(user_message)); + } + + // UC-1: Receipt reconciliation (4 steps) + let uc1_score = keyword_score(&lower, &[ + &["receipt", "invoice", "expense"], + &["folder", "directory", "files in"], + &["organize", "reconcil", "spreadsheet", "csv", "categoriz"], + &["scan", "extract", "ocr"], + ]); + if uc1_score >= 3 { + tracing::info!(score = uc1_score, "template match: UC-1 receipt reconciliation"); + return Some(build_uc1_receipt_template(user_message)); + } + + // UC-7: Contract copilot (3 steps) + let uc7_score = keyword_score(&lower, &[ + &["contract", "nda", "agreement", "legal"], + &["compare", "diff", "review", "analyz"], + &["email", "draft", "send", "counsel"], + ]); + if uc7_score >= 3 { + tracing::info!(score = uc7_score, "template match: UC-7 contract copilot"); + return Some(build_uc7_contract_copilot_template(user_message)); + } + + None +} + +/// Score a message against keyword groups. +/// +/// Each group is a set of synonymous terms. A group is "matched" if ANY term +/// in it appears in the message. Returns the count of matched groups. +fn keyword_score(lower_message: &str, groups: &[&[&str]]) -> usize { + groups + .iter() + .filter(|group| group.iter().any(|kw| lower_message.contains(kw))) + .count() +} + +/// Extract a file/directory path hint from the user message. 
+/// +/// Delegates to the orchestrator's existing path extraction logic. +fn extract_path_hint(user_message: &str) -> Option { + crate::agent_core::orchestrator::extract_path_from_text(user_message) +} + +// ─── Template Builders ────────────────────────────────────────────────────── + +/// UC-1: Receipt Reconciliation — list → extract → write CSV → create task. +fn build_uc1_receipt_template(user_message: &str) -> StepPlan { + let path = extract_path_hint(user_message).unwrap_or_else(|| "~/Downloads".to_string()); + StepPlan { + needs_tools: true, + direct_response: None, + steps: vec![ + PlanStep { + step_number: 1, + description: format!( + "List all files in {path} to find receipts, invoices, and expense documents" + ), + expected_server: Some("filesystem".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 2, + description: "Using the result from step 1, extract text from each receipt \ + or invoice file (OCR for images, text extraction for PDFs)" + .to_string(), + expected_server: Some("document".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 3, + description: "Using the extracted text from step 2, write the structured \ + receipt data (vendor, date, amount, category) to a CSV spreadsheet" + .to_string(), + expected_server: Some("data".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 4, + description: "Using the results from step 3, create a follow-up task to \ + review the reconciled receipts and flag any anomalies" + .to_string(), + expected_server: Some("task".to_string()), + hint_params: None, + }, + ], + } +} + +/// UC-4: Download Triage — list → extract → scan PII → move → create task. 
+fn build_uc4_download_triage_template(user_message: &str) -> StepPlan { + let path = extract_path_hint(user_message).unwrap_or_else(|| "~/Downloads".to_string()); + StepPlan { + needs_tools: true, + direct_response: None, + steps: vec![ + PlanStep { + step_number: 1, + description: format!( + "List all files in {path} to identify what needs to be triaged" + ), + expected_server: Some("filesystem".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 2, + description: "Using the result from step 1, extract text from document files \ + (PDFs, DOCX) to understand their content for classification" + .to_string(), + expected_server: Some("document".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 3, + description: "Using the result from step 1, scan all files for PII \ + (SSNs, credit card numbers) and secrets (API keys, passwords)" + .to_string(), + expected_server: Some("security".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 4, + description: format!( + "Using the results from steps 2 and 3, move files from {path} \ + to appropriate categorized folders" + ), + expected_server: Some("filesystem".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 5, + description: "Using the results from steps 3 and 4, create a remediation \ + task for any files with PII or security findings" + .to_string(), + expected_server: Some("task".to_string()), + hint_params: None, + }, + ], + } +} + +/// UC-7: Contract Copilot — extract text → search knowledge → draft email. 
+fn build_uc7_contract_copilot_template(user_message: &str) -> StepPlan { + let _path = extract_path_hint(user_message); + StepPlan { + needs_tools: true, + direct_response: None, + steps: vec![ + PlanStep { + step_number: 1, + description: "Extract text from the contract or NDA document provided" + .to_string(), + expected_server: Some("document".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 2, + description: "Using the extracted text from step 1, search the knowledge \ + base for similar clauses or related contract precedents" + .to_string(), + expected_server: Some("knowledge".to_string()), + hint_params: None, + }, + PlanStep { + step_number: 3, + description: "Using the analysis from steps 1 and 2, draft an email \ + summarizing the key findings, risk flags, and recommendations" + .to_string(), + expected_server: Some("email".to_string()), + hint_params: None, + }, + ], + } +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn uc1_receipt_match() { + let msg = "Scan and organize the receipts in my ~/Documents/Expenses folder \ + and extract the data into a CSV spreadsheet"; + let result = try_template_match(msg); + assert!(result.is_some(), "should match UC-1"); + let plan = result.unwrap(); + assert_eq!(plan.steps.len(), 4); + assert_eq!(plan.steps[0].expected_server.as_deref(), Some("filesystem")); + assert_eq!(plan.steps[1].expected_server.as_deref(), Some("document")); + assert_eq!(plan.steps[2].expected_server.as_deref(), Some("data")); + assert_eq!(plan.steps[3].expected_server.as_deref(), Some("task")); + } + + #[test] + fn uc4_download_triage_match() { + let msg = "Triage my downloads folder — organize files, scan for sensitive PII, \ + move them to the right place, and create a task for anything flagged"; + let result = try_template_match(msg); + assert!(result.is_some(), "should match UC-4"); + let plan = result.unwrap(); + 
assert_eq!(plan.steps.len(), 5); + assert_eq!(plan.steps[0].expected_server.as_deref(), Some("filesystem")); + assert_eq!(plan.steps[2].expected_server.as_deref(), Some("security")); + assert_eq!(plan.steps[4].expected_server.as_deref(), Some("task")); + } + + #[test] + fn uc7_contract_copilot_match() { + let msg = "Review the NDA contract, compare it against our legal knowledge base, \ + and draft an email to counsel with your analysis"; + let result = try_template_match(msg); + assert!(result.is_some(), "should match UC-7"); + let plan = result.unwrap(); + assert_eq!(plan.steps.len(), 3); + assert_eq!(plan.steps[0].expected_server.as_deref(), Some("document")); + assert_eq!(plan.steps[1].expected_server.as_deref(), Some("knowledge")); + assert_eq!(plan.steps[2].expected_server.as_deref(), Some("email")); + } + + #[test] + fn no_match_simple_request() { + let msg = "List my Downloads folder"; + assert!(try_template_match(msg).is_none(), "simple request should not match"); + } + + #[test] + fn no_match_partial_keywords() { + let msg = "Organize my files"; + assert!(try_template_match(msg).is_none(), "only 1-2 keyword groups should not match"); + } + + #[test] + fn uc1_extracts_path() { + let msg = "Scan the receipts in /Users/chintan/Expenses and extract to CSV"; + let result = try_template_match(msg); + assert!(result.is_some()); + let plan = result.unwrap(); + assert!( + plan.steps[0].description.contains("/Users/chintan/Expenses"), + "should extract path from message" + ); + } + + #[test] + fn keyword_score_all_groups_match() { + let score = keyword_score( + "scan receipts in folder and organize into csv", + &[ + &["receipt", "invoice"], + &["folder", "directory"], + &["organize", "csv"], + ], + ); + assert_eq!(score, 3); + } + + #[test] + fn keyword_score_partial_match() { + let score = keyword_score( + "organize my files", + &[ + &["receipt", "invoice"], + &["folder", "directory"], + &["organize", "csv"], + ], + ); + assert_eq!(score, 1); + } +} diff --git 
a/src-tauri/src/agent_core/response_analysis.rs new file mode 100644 index 0000000..8af6a2b --- /dev/null +++ b/src-tauri/src/agent_core/response_analysis.rs
//! Response analysis for the agent loop.
//!
//! Detects incomplete tasks, conversational deflection (FM-3), and completion
//! summaries in model text responses. Used by the agent loop in `commands/chat.rs`
//! to decide whether to inject continuation prompts or exit.

/// Detect if a model's text response indicates an incomplete task.
///
/// A local model often "fatigues" mid-task and produces a partial summary
/// after processing 3-4 items out of 7+. This function looks for signals that
/// the model stopped before finishing: mentions of remaining files, "next" steps,
/// partial progress reports, etc.
///
/// Returns `true` if the response suggests the task is unfinished and the model
/// should be prompted to continue.
pub fn is_incomplete_response(text: &str) -> bool {
    // Phrases a model emits when it stops partway through a batch of work.
    const INCOMPLETE_SIGNALS: [&str; 16] = [
        "remaining",
        "left to process",
        "still need to",
        "continue with",
        "next file",
        "next screenshot",
        "more files",
        "not yet processed",
        "will process",
        "haven't processed",
        "need to rename",
        "files left",
        "let me continue",
        "i'll continue",
        "proceeding to",
        "moving on to",
    ];

    let lower = text.to_lowercase();

    // A genuine completion summary is never treated as unfinished work,
    // even if it happens to contain one of the incomplete signals.
    if is_completion_summary(&lower) {
        return false;
    }

    INCOMPLETE_SIGNALS.iter().any(|signal| lower.contains(signal))
}

/// Detect if a model's text response is a conversational deflection (FM-3).
///
/// Called AFTER [`is_incomplete_response`] returns `false`. Identifies cases
/// where the model received tool results but responds with a question or
/// narration instead of calling the next tool.
+/// +/// Three detection layers: +/// - **Result presentation guard**: text presenting tool results without +/// a question is NOT deflection (e.g., "Here are the files in your folder"). +/// - **Layer A**: Explicit deflection phrases — model asks user instead of acting. +/// - **Layer B**: Short-question heuristic for novel formulations. +/// +/// Gated on `round > 0 AND tool_call_count > 0` so that round-0 text +/// responses and direct answers are never flagged. +/// +/// # Arguments +/// * `text` — The model's text response. +/// * `round` — Current agent loop round (0-based). +/// * `tool_call_count` — Total tool calls executed so far in this loop. +pub fn is_deflection_response(text: &str, round: usize, tool_call_count: usize) -> bool { + // Gate: only check after at least one tool has been executed + if round == 0 || tool_call_count == 0 { + return false; + } + + // Trust gate: after 3+ tool calls the model has done meaningful work. + // Its text response is a summary or answer, not a deflection. + if tool_call_count >= 3 { + return false; + } + + let lower = text.to_lowercase(); + + // If the response is a genuine completion summary, don't flag it + if is_completion_summary(&lower) { + return false; + } + + // Result presentation guard: if the model is presenting tool results + // (file listings, scan findings, data summaries) WITHOUT asking a + // question or deferring to the user, this is expected behavior. 
+ if is_presenting_results(&lower) && !text.contains('?') && !has_deferral(&lower) { + return false; + } + + // Layer A: Explicit deflection patterns — model asks user instead of acting + let deflection_patterns = [ + "what would you like", + "how would you like", + "how should i", + "what should i", + "would you like me to", + "shall i", + "do you want me to", + "let me know", + "please let me know", + "i can help you", + "i can assist", + "what do you think", + "which one", + "which files", + "here are some options", + ]; + + for pattern in &deflection_patterns { + if lower.contains(pattern) { + return true; + } + } + + // Layer A extension: Pure narration with no substance — model describes + // what it sees without presenting actionable results. Patterns that + // present results ("i found the following", "here are the files") are + // handled by the result-presentation guard above, not here. + let narration_patterns = [ + "i see the files", + "i see the following", + "i notice", + ]; + + for pattern in &narration_patterns { + if lower.contains(pattern) { + return true; + } + } + + // Layer B: Short-question heuristic — catches novel deflection formulations. + // If the response is short (<300 chars), contains a question mark, and we've + // already executed tools, the model is likely asking instead of acting. + if text.len() < 300 && text.contains('?') { + return true; + } + + false +} + +/// Check if text contains deferral phrases that hand control back to the user. +/// +/// These phrases indicate the model is waiting for instructions rather than +/// proceeding autonomously. Used alongside the result-presentation guard +/// to catch "I found X files. Let me know which ones you want." patterns. 
+fn has_deferral(lower: &str) -> bool { + let deferral_phrases = [ + "let me know", + "please let me know", + "would you like", + "shall i", + "do you want", + "which one", + "which files", + "which of", + ]; + + for phrase in &deferral_phrases { + if lower.contains(phrase) { + return true; + } + } + + false +} + +/// Check if a model's text is presenting tool results to the user. +/// +/// This distinguishes "Here are the files in your Downloads folder: ..." +/// (presenting results — correct behavior) from "I see files on your Desktop. +/// What would you like me to do?" (deflection — incorrect). +/// +/// Called by [`is_deflection_response`] to guard against false positives when +/// the model is doing exactly what the user asked: calling a tool and +/// reporting the results. +fn is_presenting_results(lower: &str) -> bool { + let result_patterns = [ + "here is", // "Here is what I found" + "here are", // "Here are the files" + "here's what", // "Here's what the scan found" + "i found", // "I found 5 files" + "the scan found", // "The scan found 2 secrets" + "the scan shows", // "The scan shows no issues" + "the results", // "The results show" + "contains", // "The folder contains" + "total of", // "A total of 12 files" + "no files", // "No files found" + "no results", // "No results" + "the folder is", // "The folder is empty" + "the directory", // "The directory contains" + "found in", // "Found in the folder" + "listed below", // "Files listed below" + ]; + + for p in &result_patterns { + if lower.contains(p) { + return true; + } + } + + false +} + +/// Check if text matches known task-completion signals. +/// +/// Used by both [`is_incomplete_response`] and [`is_deflection_response`] +/// to avoid false positives on legitimate summaries. 
+fn is_completion_summary(lower: &str) -> bool { + let complete_signals = [ + "all files have been", + "all screenshots have been", + "completed all", + "finished processing", + "all done", + "successfully renamed all", + "processed all", + "no more files", + "task complete", + "that's all", + "here is a summary", + "here's what i did", + "here's what was done", + "i've completed", + "i have completed", + "summary of changes", + "all files processed", + ]; + + for signal in &complete_signals { + if lower.contains(signal) { + return true; + } + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── is_deflection_response tests ────────────────────────────────── + + #[test] + fn deflection_question_after_tool() { + // Model asks the user what to do — clear deflection + let text = "I see the files on your Desktop. How would you like me to process them?"; + assert!(is_deflection_response(text, 1, 1)); + } + + #[test] + fn deflection_narration_with_question() { + // Narration + question = deflection (the "?" makes it past result guard) + let text = "I found the following files in your Downloads folder. \ + Let me know which ones you'd like to rename."; + assert!(is_deflection_response(text, 1, 1)); + } + + #[test] + fn deflection_shall_i_pattern() { + let text = "There are 7 screenshots. Shall I extract text from each one?"; + assert!(is_deflection_response(text, 1, 1)); + } + + #[test] + fn deflection_short_question_heuristic() { + let text = "I see 7 files. Which ones should I process?"; + assert!(is_deflection_response(text, 1, 1)); + } + + #[test] + fn no_deflection_on_round_zero() { + let text = "What would you like me to do?"; + assert!(!is_deflection_response(text, 0, 0)); + } + + #[test] + fn no_deflection_on_zero_tools() { + let text = "How would you like me to process them?"; + assert!(!is_deflection_response(text, 1, 0)); + } + + #[test] + fn no_deflection_on_completion_summary() { + let text = "All files have been renamed successfully. 
Here's what I did: \ + renamed 3 screenshots based on their OCR content."; + assert!(!is_deflection_response(text, 5, 7)); + } + + #[test] + fn no_deflection_on_legitimate_answer() { + let text = "The meeting notes contain the following action items: \ + 1. Prepare the quarterly report by Friday. \ + 2. Schedule a follow-up with the design team. \ + 3. Review the updated budget spreadsheet and send comments."; + assert!(!is_deflection_response(text, 1, 1)); + } + + // ── Result presentation — NOT deflection ───────────────────────── + + #[test] + fn no_deflection_when_presenting_file_listing() { + // This is exactly the Test 1 scenario: model lists files after list_dir + let text = "Here are the files in your Downloads folder: \ + DEMO CARD styles.png, benchmark_results.csv, \ + Liquid AI Notes.pdf, and 12 others."; + assert!(!is_deflection_response(text, 1, 1)); + } + + #[test] + fn no_deflection_when_reporting_scan_results() { + // Model reports scan findings without asking a question + let text = "The scan found 2 files containing secrets: \ + .env (AWS key), config.yaml (API token)."; + assert!(!is_deflection_response(text, 1, 1)); + } + + #[test] + fn no_deflection_when_folder_empty() { + let text = "No files were found matching your criteria in the folder."; + assert!(!is_deflection_response(text, 1, 1)); + } + + #[test] + fn no_deflection_here_is_what_i_found() { + let text = "Here is what I found in your Documents folder: 3 PDF files, \ + 2 spreadsheets, and 1 text file containing notes."; + assert!(!is_deflection_response(text, 1, 1)); + } + + #[test] + fn deflection_result_plus_question() { + // Presenting results BUT also asking a question — still deflection + let text = "I found 5 files. 
Which ones should I process?"; + assert!(is_deflection_response(text, 1, 1)); + } + + // ── Trust gate tests ───────────────────────────────────────────── + + #[test] + fn no_deflection_after_three_tool_calls() { + // After 3+ tool calls, model has done real work — trust it + let text = "I found the following text in your screenshots: ..."; + assert!(!is_deflection_response(text, 4, 3)); + assert!(!is_deflection_response(text, 8, 5)); + assert!(!is_deflection_response(text, 14, 13)); + } + + #[test] + fn deflection_still_fires_with_few_tool_calls() { + // With 1-2 tool calls, deflection detection still active for questions + let text = "I found the following files. Which ones should I process?"; + assert!(is_deflection_response(text, 1, 1)); + assert!(is_deflection_response(text, 2, 2)); + } + + // ── is_incomplete_response tests ────────────────────────────────── + + #[test] + fn incomplete_remaining_files() { + assert!(is_incomplete_response( + "I've processed 3 files. There are 4 remaining." + )); + } + + #[test] + fn incomplete_next_file() { + assert!(is_incomplete_response("Moving on to the next file now.")); + } + + #[test] + fn complete_all_done() { + assert!(!is_incomplete_response( + "All done! I renamed all 7 screenshots." + )); + } + + #[test] + fn complete_finished() { + assert!(!is_incomplete_response("Finished processing all files.")); + } + + #[test] + fn neutral_text_not_incomplete() { + assert!(!is_incomplete_response( + "The file contains a quarterly revenue report." 
+ )); + } + + // ── is_completion_summary tests ─────────────────────────────────── + + #[test] + fn summary_detected() { + assert!(is_completion_summary( + "all files have been renamed successfully" + )); + } + + #[test] + fn summary_not_detected() { + assert!(!is_completion_summary("i see the files on your desktop")); + } + + // ── is_presenting_results tests ────────────────────────────────── + + #[test] + fn presenting_file_list() { + assert!(is_presenting_results( + "here are the files in your downloads folder" + )); + } + + #[test] + fn presenting_scan_findings() { + assert!(is_presenting_results("the scan found 2 secrets")); + } + + #[test] + fn presenting_empty_results() { + assert!(is_presenting_results("no files matching your criteria")); + } + + #[test] + fn not_presenting_pure_question() { + assert!(!is_presenting_results( + "what would you like me to do with these" + )); + } + +} diff --git a/src-tauri/src/agent_core/tokens.rs b/src-tauri/src/agent_core/tokens.rs new file mode 100644 index 0000000..732667b --- /dev/null +++ b/src-tauri/src/agent_core/tokens.rs @@ -0,0 +1,291 @@ +//! Token estimation for context window management. +//! +//! Uses character-based heuristics calibrated for LLM tokenizers: +//! - English prose: ~3.2 chars/token (conservative — overestimate is safer) +//! - JSON/structured content: ~2.8 chars/token (denser due to punctuation, short keys) +//! +//! A more accurate tokenizer (tiktoken-rs) can replace this when the model +//! is finalized. + +use crate::inference::types::{ChatMessage, Role}; + +// ─── Constants ────────────────────────────────────────────────────────────── + +/// Average characters per token for English prose. +/// +/// Calibrated conservatively — most LLM tokenizers produce ~3.5-4.0 chars/token +/// for English text. We use 3.2 to err on the side of overestimation, which is +/// safer than underestimating and overflowing the context window. 
+const CHARS_PER_TOKEN: f64 = 3.2; + +/// Average characters per token for JSON/structured content. +/// +/// JSON tokenizes more densely than prose due to punctuation, short keys, +/// braces, and colons. Tool call arguments, tool results, and schema +/// definitions all fall into this category. +const JSON_CHARS_PER_TOKEN: f64 = 2.8; + +/// Per-message overhead (role label, formatting tokens). +const MESSAGE_OVERHEAD_TOKENS: u32 = 4; + +/// Overhead for tool call JSON structure (per call). +const TOOL_CALL_OVERHEAD_TOKENS: u32 = 10; + +// ─── UTF-8 Safe Truncation ────────────────────────────────────────────────── + +/// Truncate a string to at most `max_bytes` bytes on a valid UTF-8 char boundary. +/// +/// Returns a `&str` that is always valid UTF-8 and at most `max_bytes` long. +/// If the byte at `max_bytes` is inside a multi-byte character, the slice is +/// shortened to the preceding character boundary. +pub(crate) fn truncate_utf8(s: &str, max_bytes: usize) -> &str { + if s.len() <= max_bytes { + return s; + } + // Walk backward to find a valid char boundary + let mut end = max_bytes; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + &s[..end] +} + +// ─── Public API ───────────────────────────────────────────────────────────── + +/// Estimate the token count for a string of natural language text. +pub fn estimate_tokens(text: &str) -> u32 { + let chars = text.len() as f64; + (chars / CHARS_PER_TOKEN).ceil() as u32 +} + +/// Estimate the token count for JSON/structured content. +/// +/// JSON tokenizes more densely than prose, so this uses a tighter ratio. +/// Use this for tool call arguments, tool results, and schema definitions. +pub fn estimate_json_tokens(json_text: &str) -> u32 { + let chars = json_text.len() as f64; + (chars / JSON_CHARS_PER_TOKEN).ceil() as u32 +} + +/// Estimate the token count for a `ChatMessage`. +/// +/// Accounts for content, tool calls, and per-message overhead. 
/// Uses the JSON-specific estimator for tool call arguments (which are
/// always JSON) and the prose estimator for natural language content.
pub fn estimate_message_tokens(message: &ChatMessage) -> u32 {
    // Every message pays a fixed overhead for role label / formatting tokens.
    let mut total = MESSAGE_OVERHEAD_TOKENS;

    // Content tokens — use prose estimator for user/assistant text,
    // JSON estimator for tool results (role == Tool), since tool results
    // are serialized JSON and tokenize more densely.
    if let Some(ref content) = message.content {
        total += match message.role {
            Role::Tool => estimate_json_tokens(content),
            _ => estimate_tokens(content),
        };
    }

    // Tool call tokens — arguments are always JSON. Each call also pays a
    // fixed structural overhead (braces, "id"/"type" fields, etc.).
    if let Some(ref calls) = message.tool_calls {
        for call in calls {
            total += TOOL_CALL_OVERHEAD_TOKENS;
            total += estimate_tokens(&call.function.name);
            total += estimate_json_tokens(&call.function.arguments);
        }
    }

    // Tool call ID (for tool-role messages) — echoed back verbatim, so it
    // costs tokens too.
    if let Some(ref id) = message.tool_call_id {
        total += estimate_tokens(id);
    }

    total
}

/// Estimate the token count for a system prompt string.
///
/// Includes the same per-message overhead as any other message.
pub fn estimate_system_prompt_tokens(prompt: &str) -> u32 {
    MESSAGE_OVERHEAD_TOKENS + estimate_tokens(prompt)
}

/// Estimate token count for tool definitions in OpenAI format.
///
/// Tool definitions are JSON, so we use the JSON-specific estimator.
/// Serialization failure falls back to the empty string (0 tokens) — an
/// underestimate, but serializing `serde_json::Value` cannot realistically fail.
pub fn estimate_tool_definitions_tokens(tools: &[serde_json::Value]) -> u32 {
    let json = serde_json::to_string(tools).unwrap_or_default();
    estimate_json_tokens(&json)
}

/// Estimate token count for a raw string that will be included as content.
///
/// Thin alias over `estimate_tokens`, kept for call-site readability.
pub fn estimate_content_tokens(content: &str) -> u32 {
    estimate_tokens(content)
}

/// Summarize a tool result into a one-line string.
///
/// Used when evicting old conversation turns to reduce token usage.
+pub fn summarize_tool_result(tool_name: &str, result: &serde_json::Value) -> String { + let result_str = serde_json::to_string(result).unwrap_or_default(); + let token_count = estimate_tokens(&result_str); + + if token_count <= 50 { + // Short enough to keep as-is + format!("[{tool_name} returned: {result_str}]") + } else { + // Summarize to one line + let preview = truncate_utf8(&result_str, 100); + format!( + "[{tool_name} returned ~{token_count} tokens: {preview}...]" + ) + } +} + +/// Build a one-line summary of a conversation turn for eviction. +/// +/// Captures the user's request and the assistant's response type. +pub fn summarize_turn(role: &Role, content: Option<&str>) -> String { + match role { + Role::User => { + let text = content.unwrap_or("[empty]"); + let preview = truncate_utf8(text, 80); + format!("User: {preview}") + } + Role::Assistant => { + let text = content.unwrap_or("[tool calls]"); + let preview = truncate_utf8(text, 80); + format!("Assistant: {preview}") + } + Role::Tool => { + let text = content.unwrap_or("[result]"); + let preview = truncate_utf8(text, 60); + format!("Tool result: {preview}") + } + Role::System => "System prompt".to_string(), + } +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::inference::types::{FunctionCallResponse, ToolCallResponse}; + + #[test] + fn test_estimate_tokens_empty() { + assert_eq!(estimate_tokens(""), 0); + } + + #[test] + fn test_estimate_tokens_short() { + // "hello" = 5 chars → ceil(5/3.2) = 2 + assert_eq!(estimate_tokens("hello"), 2); + } + + #[test] + fn test_estimate_tokens_longer() { + // 100 chars → ceil(100/3.2) = 32 + let text = "a".repeat(100); + assert_eq!(estimate_tokens(&text), 32); + } + + #[test] + fn test_estimate_json_tokens() { + // 16 chars → ceil(16/2.8) = 6 + let json = r#"{"path": "/tmp"}"#; + assert_eq!(estimate_json_tokens(json), 6); + } + + #[test] + fn 
test_estimate_message_tokens_content_only() { + let msg = ChatMessage { + role: Role::User, + content: Some("Hello, world!".to_string()), // 13 chars → ceil(13/3.2) = 5 + tool_call_id: None, + tool_calls: None, + }; + let tokens = estimate_message_tokens(&msg); + // 4 overhead + 5 content = 9 + assert_eq!(tokens, 9); + } + + #[test] + fn test_estimate_message_tokens_with_tool_calls() { + let msg = ChatMessage { + role: Role::Assistant, + content: None, + tool_call_id: None, + tool_calls: Some(vec![ToolCallResponse { + id: "call_1".to_string(), + r#type: "function".to_string(), + function: FunctionCallResponse { + name: "filesystem.list_dir".to_string(), + arguments: r#"{"path": "/tmp"}"#.to_string(), + }, + }]), + }; + let tokens = estimate_message_tokens(&msg); + // 4 overhead + 10 tool_call_overhead + name_tokens + args_tokens > 4 + assert!(tokens > 4); + } + + #[test] + fn test_summarize_tool_result_short() { + let result = serde_json::json!({"files": ["a.txt"]}); + let summary = summarize_tool_result("filesystem.list_dir", &result); + assert!(summary.starts_with("[filesystem.list_dir returned:")); + } + + #[test] + fn test_summarize_tool_result_long() { + let long_data: Vec = (0..200).map(|i| format!("file_{i}.txt")).collect(); + let result = serde_json::json!({"files": long_data}); + let summary = summarize_tool_result("filesystem.list_dir", &result); + assert!(summary.contains("tokens:")); + assert!(summary.ends_with("...]")); + } + + #[test] + fn test_summarize_turn_user() { + let summary = summarize_turn(&Role::User, Some("List all files in /tmp")); + assert!(summary.starts_with("User: List all")); + } + + #[test] + fn test_summarize_turn_assistant_none() { + let summary = summarize_turn(&Role::Assistant, None); + assert_eq!(summary, "Assistant: [tool calls]"); + } + + #[test] + fn test_truncate_utf8_ascii() { + assert_eq!(truncate_utf8("hello world", 5), "hello"); + } + + #[test] + fn test_truncate_utf8_within_multibyte() { + // '═' is U+2550, encoded as 3 
bytes: 0xE2, 0x95, 0x90 + let text = "═══"; // 9 bytes total + // Cutting at byte 4 lands inside the second '═' (bytes 3..6) + assert_eq!(truncate_utf8(text, 4), "═"); + // Cutting at byte 6 is exactly at a boundary + assert_eq!(truncate_utf8(text, 6), "══"); + } + + #[test] + fn test_truncate_utf8_no_truncation_needed() { + assert_eq!(truncate_utf8("short", 100), "short"); + } + + #[test] + fn test_summarize_tool_result_unicode_no_panic() { + // Reproduces the crash: audit report with box-drawing chars + let report = format!( + "{}{}", + "═".repeat(50), + "AUDIT REPORT — Session test-session" + ); + let result = serde_json::json!({"report": report}); + // This must NOT panic + let summary = summarize_tool_result("audit.generate_audit_report", &result); + assert!(summary.contains("audit.generate_audit_report")); + } +} diff --git a/src-tauri/src/agent_core/tool_prefilter.rs b/src-tauri/src/agent_core/tool_prefilter.rs new file mode 100644 index 0000000..54afdc5 --- /dev/null +++ b/src-tauri/src/agent_core/tool_prefilter.rs @@ -0,0 +1,362 @@ +//! RAG pre-filter for tool selection (ADR-010 Phase 2 / ADR-009). +//! +//! Embeds tool descriptions at startup via the router model's `/embeddings` +//! endpoint, then for each user query, embeds the query and selects the +//! top-K tools by cosine similarity. +//! +//! Ported from TypeScript `tests/model-behavior/benchmark-lfm.ts` (lines 270-414). + +use reqwest::Client as HttpClient; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +// ─── Error Type ───────────────────────────────────────────────────────────── + +/// Errors from the tool pre-filter. 
+#[derive(Debug, thiserror::Error)] +pub enum ToolPreFilterError { + #[error("embedding request failed (HTTP {status}): {body}")] + HttpError { status: u16, body: String }, + + #[error("embedding request failed: {reason}")] + RequestFailed { reason: String }, + + #[error("empty embedding response for {count} inputs")] + EmptyResponse { count: usize }, + + #[error("dimension mismatch: expected {expected}, got {actual}")] + DimensionMismatch { expected: usize, actual: usize }, +} + +// ─── Embedding Response Types ─────────────────────────────────────────────── + +/// Raw embedding item from the `/embeddings` endpoint. +#[derive(Debug, Deserialize)] +struct RawEmbeddingItem { + index: usize, + embedding: serde_json::Value, // number[] or number[][] (per-token) +} + +/// Parsed embedding response (array of items). +#[derive(Debug, Deserialize)] +struct EmbeddingResponse { + data: Vec, +} + +// ─── Tool Embedding Index ─────────────────────────────────────────────────── + +/// Pre-computed tool embedding index for RAG pre-filtering. +/// +/// Built once at orchestrator startup, reused for every query. +#[derive(Debug, Clone)] +pub struct ToolEmbeddingIndex { + /// Tool names in registration order. + tool_names: Vec, + /// L2-normalized embeddings, one per tool. Shape: `[n_tools][n_dim]`. + embeddings: Vec>, +} + +/// A scored tool result from the pre-filter. +#[derive(Debug, Clone, Serialize)] +pub struct ScoredTool { + pub name: String, + pub score: f32, +} + +impl ToolEmbeddingIndex { + /// Build the index by embedding all tool descriptions. + /// + /// Each tool is embedded as `"name: description"` text. The embeddings are + /// mean-pooled (if per-token) and L2-normalized for cosine similarity. 
+ pub async fn build( + endpoint: &str, + tools: &[(String, String)], // (name, description) pairs + ) -> Result { + if tools.is_empty() { + return Ok(Self { + tool_names: Vec::new(), + embeddings: Vec::new(), + }); + } + + let texts: Vec = tools + .iter() + .map(|(name, desc)| format!("{name}: {desc}")) + .collect(); + + let raw = embed_texts(endpoint, &texts).await?; + let embeddings: Vec> = raw.into_iter().map(l2_normalize).collect(); + let tool_names: Vec = tools.iter().map(|(n, _)| n.clone()).collect(); + + Ok(Self { + tool_names, + embeddings, + }) + } + + /// Select the top-K tool names by cosine similarity to the query. + /// + /// Returns `(selected_names, scored_tools)` where scored_tools is sorted + /// descending by score for debugging. + pub async fn filter( + &self, + endpoint: &str, + query: &str, + top_k: usize, + ) -> Result<(Vec, Vec), ToolPreFilterError> { + if self.tool_names.is_empty() { + return Ok((Vec::new(), Vec::new())); + } + + let raw_query = embed_texts(endpoint, &[query.to_string()]).await?; + let query_emb = l2_normalize( + raw_query + .into_iter() + .next() + .ok_or(ToolPreFilterError::EmptyResponse { count: 1 })?, + ); + + // Score all tools + let mut scored: Vec = self + .tool_names + .iter() + .enumerate() + .map(|(i, name)| ScoredTool { + name: name.clone(), + score: cosine_similarity(&query_emb, &self.embeddings[i]), + }) + .collect(); + + // Sort descending by score + scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)); + + // Take top-K + let k = top_k.min(scored.len()); + let selected: Vec = scored[..k].iter().map(|s| s.name.clone()).collect(); + + Ok((selected, scored)) + } + + /// Number of tools in the index. + pub fn len(&self) -> usize { + self.tool_names.len() + } + + /// Whether the index is empty. 
+ pub fn is_empty(&self) -> bool { + self.tool_names.is_empty() + } +} + +// ─── Embedding Helpers ────────────────────────────────────────────────────── + +/// Embed a batch of texts via the `/embeddings` endpoint. +/// +/// The LFM2 `/embeddings` endpoint returns per-token embeddings (2D arrays). +/// We mean-pool them into a single vector per text. +async fn embed_texts( + endpoint: &str, + texts: &[String], +) -> Result>, ToolPreFilterError> { + let http = HttpClient::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|e| ToolPreFilterError::RequestFailed { + reason: format!("failed to build HTTP client: {e}"), + })?; + + let url = format!("{endpoint}/embeddings"); + let body = serde_json::json!({ "input": texts }); + + let response = http + .post(&url) + .json(&body) + .send() + .await + .map_err(|e| ToolPreFilterError::RequestFailed { + reason: format!("embedding request to {url}: {e}"), + })?; + + if !response.status().is_success() { + let status = response.status().as_u16(); + let body_text = response + .text() + .await + .unwrap_or_else(|_| "unknown".to_string()); + return Err(ToolPreFilterError::HttpError { + status, + body: body_text, + }); + } + + let result: EmbeddingResponse = + response + .json() + .await + .map_err(|e| ToolPreFilterError::RequestFailed { + reason: format!("failed to parse embedding response: {e}"), + })?; + + if result.data.is_empty() { + return Err(ToolPreFilterError::EmptyResponse { + count: texts.len(), + }); + } + + // Sort by index to ensure order matches input + let mut items = result.data; + items.sort_by_key(|item| item.index); + + items + .into_iter() + .map(|item| mean_pool_embedding(&item.embedding)) + .collect() +} + +/// Mean-pool per-token embeddings into a single vector. +/// +/// If already pooled (1D array of numbers), returns as-is. +/// If 2D (per-token), averages across the token dimension. 
+fn mean_pool_embedding(embedding: &serde_json::Value) -> Result, ToolPreFilterError> { + match embedding { + serde_json::Value::Array(arr) if arr.is_empty() => Ok(Vec::new()), + + // 1D: already pooled — [f32, f32, ...] + serde_json::Value::Array(arr) if arr[0].is_f64() => { + Ok(arr.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect()) + } + + // 2D: per-token — [[f32, ...], [f32, ...], ...] + serde_json::Value::Array(arr) if arr[0].is_array() => { + let tokens: Vec> = arr + .iter() + .filter_map(|row| { + row.as_array().map(|r| { + r.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect() + }) + }) + .collect(); + + if tokens.is_empty() { + return Ok(Vec::new()); + } + + let n_tokens = tokens.len(); + let n_dim = tokens[0].len(); + let mut result = vec![0.0_f32; n_dim]; + + for token in &tokens { + for (d, val) in token.iter().enumerate() { + if d < n_dim { + result[d] += val; + } + } + } + + for val in &mut result { + *val /= n_tokens as f32; + } + + Ok(result) + } + + _ => Err(ToolPreFilterError::RequestFailed { + reason: "unexpected embedding format (expected number[] or number[][])".to_string(), + }), + } +} + +/// L2-normalize a vector. Returns the normalized copy. +fn l2_normalize(vec: Vec) -> Vec { + let norm: f32 = vec.iter().map(|v| v * v).sum::().sqrt(); + if norm > 0.0 { + vec.into_iter().map(|v| v / norm).collect() + } else { + vec + } +} + +/// Cosine similarity between two L2-normalized vectors (= dot product). 
+fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b.iter()).map(|(x, y)| x * y).sum() +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cosine_identical_vectors() { + let v = l2_normalize(vec![1.0, 2.0, 3.0]); + let score = cosine_similarity(&v, &v); + assert!((score - 1.0).abs() < 1e-5, "identical vectors should have similarity ~1.0"); + } + + #[test] + fn cosine_orthogonal_vectors() { + let a = l2_normalize(vec![1.0, 0.0, 0.0]); + let b = l2_normalize(vec![0.0, 1.0, 0.0]); + let score = cosine_similarity(&a, &b); + assert!(score.abs() < 1e-5, "orthogonal vectors should have similarity ~0.0"); + } + + #[test] + fn cosine_opposite_vectors() { + let a = l2_normalize(vec![1.0, 0.0]); + let b = l2_normalize(vec![-1.0, 0.0]); + let score = cosine_similarity(&a, &b); + assert!((score + 1.0).abs() < 1e-5, "opposite vectors should have similarity ~-1.0"); + } + + #[test] + fn l2_normalize_zero_vector() { + let v = l2_normalize(vec![0.0, 0.0, 0.0]); + assert_eq!(v, vec![0.0, 0.0, 0.0], "zero vector unchanged"); + } + + #[test] + fn l2_normalize_unit_vector() { + let v = l2_normalize(vec![3.0, 4.0]); + let norm: f32 = v.iter().map(|x| x * x).sum::().sqrt(); + assert!((norm - 1.0).abs() < 1e-5, "normalized vector should have unit norm"); + assert!((v[0] - 0.6).abs() < 1e-5); + assert!((v[1] - 0.8).abs() < 1e-5); + } + + #[test] + fn mean_pool_1d_passthrough() { + let embedding = serde_json::json!([1.0, 2.0, 3.0]); + let result = mean_pool_embedding(&embedding).unwrap(); + assert_eq!(result, vec![1.0, 2.0, 3.0]); + } + + #[test] + fn mean_pool_2d_averaging() { + let embedding = serde_json::json!([[1.0, 2.0], [3.0, 4.0]]); + let result = mean_pool_embedding(&embedding).unwrap(); + assert_eq!(result, vec![2.0, 3.0]); // mean of [1,3]=2, mean of [2,4]=3 + } + + #[test] + fn empty_index_filter_returns_empty() { + let index = ToolEmbeddingIndex { + tool_names: 
Vec::new(), + embeddings: Vec::new(), + }; + // Can't call async filter in sync test, but verify construction + assert!(index.is_empty()); + assert_eq!(index.len(), 0); + } + + #[test] + fn index_len_and_empty() { + let index = ToolEmbeddingIndex { + tool_names: vec!["a".into(), "b".into()], + embeddings: vec![vec![1.0, 0.0], vec![0.0, 1.0]], + }; + assert_eq!(index.len(), 2); + assert!(!index.is_empty()); + } +} diff --git a/src-tauri/src/agent_core/tool_router.rs b/src-tauri/src/agent_core/tool_router.rs new file mode 100644 index 0000000..1e0cc8a --- /dev/null +++ b/src-tauri/src/agent_core/tool_router.rs @@ -0,0 +1,729 @@ +//! ToolRouter — dispatches model tool calls to MCP servers. +//! +//! The ToolRouter is the bridge between the LLM's tool call decisions and the +//! MCP server ecosystem. It handles: +//! - Validation (tool exists, arguments match schema) +//! - Confirmation flow (read-only → auto, mutable → confirm, destructive → warn) +//! - Execution via McpClient +//! - Retry with exponential backoff for transient errors +//! - Audit logging of every tool execution +//! - Undo stack entries for mutable/destructive actions + +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; + +use crate::agent_core::tokens::truncate_utf8; +use crate::inference::types::ToolCall; +use crate::mcp_client::errors::McpError; +use crate::mcp_client::types::ToolCallResult; +use crate::mcp_client::McpClient; + +use super::conversation::ConversationManager; +use super::permissions::{PermissionScope, PermissionStatus, PermissionStore}; +use super::types::{ + AuditStatus, ConfirmationRequest, ConfirmationResponse, NewUndoEntry, +}; + +// ─── Constants ────────────────────────────────────────────────────────────── + +/// Maximum retry attempts for transient tool execution errors. +const MAX_RETRIES: u32 = 2; + +/// Base delay between retries (doubles each attempt). 
+const RETRY_BASE_DELAY: Duration = Duration::from_millis(500); + +// ─── ToolRouter ───────────────────────────────────────────────────────────── + +/// Dispatches tool calls from the model to MCP servers and manages +/// the human-in-the-loop confirmation flow with tiered permissions. +pub struct ToolRouter { + /// Sender for confirmation requests (to the frontend). + confirm_tx: mpsc::Sender, + /// Receiver for confirmation responses (from the frontend). + confirm_rx: mpsc::Receiver, + /// Tiered permission grants (session + persistent). + pub permissions: PermissionStore, +} + +impl ToolRouter { + /// Create a new ToolRouter with confirmation channels. + /// + /// The caller must wire the other end of these channels to the frontend + /// (via Tauri IPC events). + pub fn new( + confirm_tx: mpsc::Sender, + confirm_rx: mpsc::Receiver, + ) -> Self { + Self { + confirm_tx, + confirm_rx, + permissions: PermissionStore::new(), + } + } + + /// Create a ToolRouter for testing (no confirmation flow — auto-confirms). + #[cfg(test)] + pub fn new_auto_confirm() -> (Self, mpsc::Sender, mpsc::Receiver) { + let (req_tx, req_rx) = mpsc::channel(16); + let (resp_tx, resp_rx) = mpsc::channel(16); + ( + Self { + confirm_tx: req_tx, + confirm_rx: resp_rx, + permissions: PermissionStore::new_in_memory(), + }, + resp_tx, + req_rx, + ) + } + + // ─── Dispatch ─────────────────────────────────────────────────────── + + /// Dispatch a batch of tool calls from the model. + /// + /// Processes tool calls sequentially (model expects ordered results). + /// Returns a result for each tool call. 
+ pub async fn dispatch_tool_calls( + &mut self, + tool_calls: &[ToolCall], + session_id: &str, + mcp_client: &mut McpClient, + conversation: &ConversationManager, + ) -> Vec { + let mut results = Vec::new(); + + for tc in tool_calls { + let result = self + .dispatch_single(tc, session_id, mcp_client, conversation) + .await; + results.push(result); + } + + results + } + + /// Dispatch a single tool call with full lifecycle: + /// validate → confirm → execute → audit → undo. + pub async fn dispatch_single( + &mut self, + tool_call: &ToolCall, + session_id: &str, + mcp_client: &mut McpClient, + conversation: &ConversationManager, + ) -> ToolCallResult { + let start = Instant::now(); + + // 1. Validate + if let Err(e) = mcp_client.registry.validate_tool_call( + &tool_call.name, + &tool_call.arguments, + ) { + return self.log_and_return_error( + &tool_call.name, + &tool_call.arguments, + session_id, + conversation, + AuditStatus::Error, + false, + start, + &format!("validation failed: {e}"), + ); + } + + // 2. Check confirmation requirements + let needs_confirmation = mcp_client.registry.requires_confirmation(&tool_call.name); + let supports_undo = mcp_client.registry.supports_undo(&tool_call.name); + + // 3. Permission check — skip confirmation if tool has an active grant + if needs_confirmation + && self.permissions.check(&tool_call.name) == PermissionStatus::Allowed + { + tracing::debug!( + tool = %tool_call.name, + "skipping confirmation — permission granted" + ); + // Fall through to execution + } else if needs_confirmation { + // 4. 
Confirmation flow + let preview = generate_preview(&tool_call.name, &tool_call.arguments); + let is_destructive = is_destructive_action(&tool_call.name); + + let request = ConfirmationRequest { + request_id: uuid::Uuid::new_v4().to_string(), + tool_name: tool_call.name.clone(), + arguments: tool_call.arguments.clone(), + preview, + confirmation_required: true, + undo_supported: supports_undo, + is_destructive, + }; + + // Send confirmation request to frontend + if self.confirm_tx.send(request).await.is_err() { + return self.log_and_return_error( + &tool_call.name, + &tool_call.arguments, + session_id, + conversation, + AuditStatus::Error, + false, + start, + "failed to send confirmation request to frontend", + ); + } + + // Wait for user response + match self.confirm_rx.recv().await { + Some(ConfirmationResponse::Confirmed) => { + // Allow Once — proceed with execution, no grant stored + } + Some(ConfirmationResponse::ConfirmedForSession) => { + // Allow for Session — grant + proceed + self.permissions + .grant(&tool_call.name, PermissionScope::Session); + } + Some(ConfirmationResponse::ConfirmedAlways) => { + // Always Allow — persistent grant + proceed + self.permissions + .grant(&tool_call.name, PermissionScope::Always); + } + Some(ConfirmationResponse::EditedAndConfirmed { new_arguments }) => { + // Execute with modified arguments + return self + .execute_tool( + &tool_call.name, + &new_arguments, + &tool_call.id, + session_id, + mcp_client, + conversation, + supports_undo, + start, + ) + .await; + } + Some(ConfirmationResponse::Rejected) => { + return self.log_and_return_error( + &tool_call.name, + &tool_call.arguments, + session_id, + conversation, + AuditStatus::RejectedByUser, + false, + start, + "user rejected the tool call", + ); + } + None => { + return self.log_and_return_error( + &tool_call.name, + &tool_call.arguments, + session_id, + conversation, + AuditStatus::Error, + false, + start, + "confirmation channel closed", + ); + } + } + } + + // 4. 
Execute + self.execute_tool( + &tool_call.name, + &tool_call.arguments, + &tool_call.id, + session_id, + mcp_client, + conversation, + supports_undo, + start, + ) + .await + } + + // ─── Execution ────────────────────────────────────────────────────── + + /// Execute a tool call with retry logic. + #[allow(clippy::too_many_arguments)] + async fn execute_tool( + &self, + tool_name: &str, + arguments: &serde_json::Value, + _tool_call_id: &str, + session_id: &str, + mcp_client: &mut McpClient, + conversation: &ConversationManager, + supports_undo: bool, + start: Instant, + ) -> ToolCallResult { + let mut last_error: Option = None; + + for attempt in 0..=MAX_RETRIES { + if attempt > 0 { + let delay = RETRY_BASE_DELAY * 2u32.pow(attempt - 1); + tokio::time::sleep(delay).await; + } + + match mcp_client.call_tool(tool_name, arguments.clone()).await { + Ok(result) => { + let elapsed = start.elapsed().as_millis() as u64; + + // Audit log + let _ = conversation.db().insert_audit_entry( + session_id, + tool_name, + arguments, + result.result.as_ref(), + if result.success { + AuditStatus::Success + } else { + AuditStatus::Error + }, + true, // user confirmed (or auto-confirmed) + elapsed, + ); + + // Undo stack + if supports_undo && result.success { + let undo = NewUndoEntry { + tool_name: tool_name.to_string(), + action_type: infer_action_type(tool_name), + original_state: capture_original_state(tool_name, arguments), + new_state: capture_new_state(tool_name, &result), + }; + let _ = conversation.push_undo(session_id, &undo); + } + + return ToolCallResult { + tool_name: tool_name.to_string(), + success: result.success, + result: result.result, + error: result.error, + execution_time_ms: elapsed, + }; + } + Err(e) => { + if is_retriable_mcp_error(&e) && attempt < MAX_RETRIES { + last_error = Some(e.to_string()); + continue; + } + + return self.log_and_return_error( + tool_name, + arguments, + session_id, + conversation, + AuditStatus::Error, + true, + start, + &e.to_string(), 
+ ); + } + } + } + + // All retries exhausted + self.log_and_return_error( + tool_name, + arguments, + session_id, + conversation, + AuditStatus::Error, + true, + start, + &last_error.unwrap_or_else(|| "all retries exhausted".to_string()), + ) + } + + // ─── Helpers ──────────────────────────────────────────────────────── + + /// Log an error result to the audit log and return a ToolCallResult. + #[allow(clippy::too_many_arguments)] + fn log_and_return_error( + &self, + tool_name: &str, + arguments: &serde_json::Value, + session_id: &str, + conversation: &ConversationManager, + status: AuditStatus, + user_confirmed: bool, + start: Instant, + error_msg: &str, + ) -> ToolCallResult { + let elapsed = start.elapsed().as_millis() as u64; + + let _ = conversation.db().insert_audit_entry( + session_id, + tool_name, + arguments, + None, + status, + user_confirmed, + elapsed, + ); + + ToolCallResult { + tool_name: tool_name.to_string(), + success: false, + result: None, + error: Some(error_msg.to_string()), + execution_time_ms: elapsed, + } + } +} + +// ─── Free Functions ───────────────────────────────────────────────────────── + +/// Check if an MCP error is retriable (transient). +fn is_retriable_mcp_error(err: &McpError) -> bool { + matches!( + err, + McpError::Timeout { .. } | McpError::ServerCrashed { .. } | McpError::TransportError { .. } + ) +} + +/// Determine if a tool action is destructive (delete, overwrite). +pub fn is_destructive_action(tool_name: &str) -> bool { + let name = tool_name.split('.').next_back().unwrap_or(tool_name); + matches!( + name, + "delete_file" | "delete_collection" | "delete_task" | "send_draft" + ) +} + +/// Infer the action type from a tool name (for undo entries). 
+fn infer_action_type(tool_name: &str) -> String { + let name = tool_name.split('.').next_back().unwrap_or(tool_name); + if name.starts_with("move") { + "move".to_string() + } else if name.starts_with("delete") { + "delete".to_string() + } else if name.starts_with("create") || name.starts_with("write") { + "create".to_string() + } else { + "write".to_string() + } +} + +/// Generate a human-readable preview for a tool call. +pub fn generate_preview(tool_name: &str, arguments: &serde_json::Value) -> String { + let name = tool_name.split('.').next_back().unwrap_or(tool_name); + + match name { + "write_file" => { + let path = arguments + .get("path") + .and_then(|v| v.as_str()) + .unwrap_or(""); + format!("Write to file: {path}") + } + "move_file" => { + let src = arguments + .get("source") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let dst = arguments + .get("destination") + .and_then(|v| v.as_str()) + .unwrap_or(""); + format!("Move: {src} → {dst}") + } + "delete_file" => { + let path = arguments + .get("path") + .and_then(|v| v.as_str()) + .unwrap_or(""); + format!("Delete file: {path}") + } + "copy_file" => { + let src = arguments + .get("source") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let dst = arguments + .get("destination") + .and_then(|v| v.as_str()) + .unwrap_or(""); + format!("Copy: {src} → {dst}") + } + "create_pdf" | "create_docx" => { + let path = arguments + .get("output_path") + .and_then(|v| v.as_str()) + .unwrap_or(""); + format!("Create document: {path}") + } + "create_task" => { + let title = arguments + .get("title") + .and_then(|v| v.as_str()) + .unwrap_or(""); + format!("Create task: {title}") + } + "send_draft" => { + let to = arguments + .get("to") + .and_then(|v| v.as_str()) + .unwrap_or(""); + format!("Send email to: {to}") + } + _ => { + // Generic preview + let args_preview = serde_json::to_string(arguments) + .unwrap_or_default(); + let truncated = if args_preview.len() > 100 { + format!("{}...", truncate_utf8(&args_preview, 
            // (tail of `generate_preview` — the function's head lies before this
            // chunk; the tokens below are unchanged)
                100))
            } else {
                args_preview
            };
            format!("Execute {tool_name}: {truncated}")
        }
    }
}

/// Capture the original state before a mutable action (for undo).
///
/// Only the final segment of a qualified tool name ("server.tool") is
/// matched; unknown tools fall back to recording the raw tool name and
/// arguments so at least a record of the action survives.
fn capture_original_state(tool_name: &str, arguments: &serde_json::Value) -> serde_json::Value {
    let name = tool_name.split('.').next_back().unwrap_or(tool_name);
    match name {
        // For a move, the undo-relevant path is the *source* location.
        "move_file" => serde_json::json!({
            "path": arguments.get("source"),
        }),
        "delete_file" => serde_json::json!({
            "path": arguments.get("path"),
        }),
        // NOTE(review): `existed_before` is hard-coded `true` here — presumably
        // the caller refines or only invokes this when the file exists; confirm.
        "write_file" => serde_json::json!({
            "path": arguments.get("path"),
            "existed_before": true,
        }),
        _ => serde_json::json!({
            "tool": tool_name,
            "arguments": arguments,
        }),
    }
}

/// Capture the new state after a mutable action (for undo).
fn capture_new_state(tool_name: &str, result: &ToolCallResult) -> serde_json::Value {
    serde_json::json!({
        "tool": tool_name,
        "success": result.success,
        "result": result.result,
    })
}

// ─── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_destructive_action() {
        assert!(is_destructive_action("filesystem.delete_file"));
        assert!(is_destructive_action("email.send_draft"));
        assert!(!is_destructive_action("filesystem.list_dir"));
        assert!(!is_destructive_action("filesystem.write_file"));
    }

    #[test]
    fn test_infer_action_type() {
        assert_eq!(infer_action_type("filesystem.move_file"), "move");
        assert_eq!(infer_action_type("filesystem.delete_file"), "delete");
        assert_eq!(infer_action_type("filesystem.create_pdf"), "create");
        assert_eq!(infer_action_type("filesystem.write_file"), "create");
        // NOTE(review): copy_file mapping to "write" (not "create") looks
        // surprising — confirm this is intentional in infer_action_type.
        assert_eq!(infer_action_type("filesystem.copy_file"), "write");
    }

    #[test]
    fn test_generate_preview_write() {
        let args = serde_json::json!({"path": "/tmp/file.txt", "content": "hello"});
        let preview = generate_preview("filesystem.write_file", &args);
        assert!(preview.contains("/tmp/file.txt"));
    }

    #[test]
    fn test_generate_preview_move() {
        let args = serde_json::json!({"source": "/old/a.txt", "destination": "/new/a.txt"});
        let preview = generate_preview("filesystem.move_file", &args);
        assert!(preview.contains("/old/a.txt"));
        assert!(preview.contains("/new/a.txt"));
    }

    #[test]
    fn test_generate_preview_delete() {
        let args = serde_json::json!({"path": "/tmp/old.txt"});
        let preview = generate_preview("filesystem.delete_file", &args);
        assert!(preview.contains("Delete file"));
        assert!(preview.contains("/tmp/old.txt"));
    }

    #[test]
    fn test_generate_preview_generic() {
        let args = serde_json::json!({"query": "SELECT * FROM users"});
        let preview = generate_preview("data.query_sqlite", &args);
        assert!(preview.contains("data.query_sqlite"));
    }

    #[test]
    fn test_capture_original_state_move() {
        let args = serde_json::json!({"source": "/a.txt", "destination": "/b.txt"});
        let state = capture_original_state("filesystem.move_file", &args);
        assert!(state.get("path").is_some());
    }

    #[test]
    fn test_is_retriable_mcp_error() {
        assert!(is_retriable_mcp_error(&McpError::Timeout {
            tool: "t".into(),
            timeout_ms: 1000,
        }));
        assert!(is_retriable_mcp_error(&McpError::ServerCrashed {
            name: "s".into(),
            reason: "gone".into(),
        }));
        assert!(!is_retriable_mcp_error(&McpError::UnknownTool {
            name: "x".into(),
        }));
    }

    #[tokio::test]
    async fn test_dispatch_rejected_tool_call() {
        use crate::agent_core::database::AgentDatabase;
        use crate::agent_core::conversation::ConversationManager;
        use crate::mcp_client::types::{McpServersConfig, McpToolDefinition};
        use std::collections::HashMap;

        // Set up infrastructure
        let db = AgentDatabase::open(":memory:").unwrap();
        let conv = ConversationManager::new(db);
        conv.new_session("s1", "system").unwrap();

        let mcp_config = McpServersConfig {
            servers: HashMap::new(),
        };
        let mut mcp = McpClient::new(mcp_config, None);

        // Register a tool that requires confirmation
        mcp.registry.register_server_tools(
            "filesystem",
            vec![McpToolDefinition {
                name: "write_file".to_string(),
                description: "Write a file".to_string(),
                params_schema: serde_json::json!({
                    "type": "object",
                    "properties": {"path": {"type": "string"}, "content": {"type": "string"}},
                    "required": ["path", "content"]
                }),
                returns_schema: serde_json::json!({}),
                confirmation_required: true,
                undo_supported: true,
            }],
        );

        // Create router with auto-reject
        let (mut router, resp_tx, mut req_rx) = ToolRouter::new_auto_confirm();

        let tc = ToolCall {
            id: "call_1".to_string(),
            name: "filesystem.write_file".to_string(),
            arguments: serde_json::json!({"path": "/tmp/test.txt", "content": "hello"}),
        };

        // Spawn a task that rejects the confirmation
        tokio::spawn(async move {
            let _req = req_rx.recv().await.unwrap();
            resp_tx.send(ConfirmationResponse::Rejected).await.unwrap();
        });

        let result = router
            .dispatch_single(&tc, "s1", &mut mcp, &conv)
            .await;

        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("rejected"));

        // Check audit log
        let entries = conv.db().get_audit_entries("s1").unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].result_status, AuditStatus::RejectedByUser);
    }

    #[tokio::test]
    async fn test_dispatch_validation_failure() {
        use crate::agent_core::database::AgentDatabase;
        use crate::agent_core::conversation::ConversationManager;
        use crate::mcp_client::types::McpServersConfig;
        use std::collections::HashMap;

        let db = AgentDatabase::open(":memory:").unwrap();
        let conv = ConversationManager::new(db);
        conv.new_session("s1", "system").unwrap();

        let mcp_config = McpServersConfig {
            servers: HashMap::new(),
        };
        let mut mcp = McpClient::new(mcp_config, None);

        let (mut router, _resp_tx, _req_rx) = ToolRouter::new_auto_confirm();

        // Tool doesn't exist — should fail validation
        let tc = ToolCall {
            id: "call_1".to_string(),
            name: "nonexistent.tool".to_string(),
            arguments: serde_json::json!({}),
        };

        let result = router.dispatch_single(&tc, "s1", &mut mcp, &conv).await;
        assert!(!result.success);
        assert!(result.error.as_ref().unwrap().contains("validation failed"));
    }

    #[test]
    fn test_permission_check_skips_confirmation() {
        // Verify that a granted permission changes the check result
        let (mut router, _resp_tx, _req_rx) = ToolRouter::new_auto_confirm();

        // Initially needs confirmation
        assert_eq!(
            router.permissions.check("filesystem.write_file"),
            PermissionStatus::NeedsConfirmation
        );

        // Grant session permission
        router
            .permissions
            .grant("filesystem.write_file", PermissionScope::Session);

        // Now allowed
        assert_eq!(
            router.permissions.check("filesystem.write_file"),
            PermissionStatus::Allowed
        );

        // Clear session — back to needing confirmation
        router.permissions.clear_session();
        assert_eq!(
            router.permissions.check("filesystem.write_file"),
            PermissionStatus::NeedsConfirmation
        );
    }

    #[test]
    fn test_permission_always_grant_survives_session_clear() {
        let (mut router, _resp_tx, _req_rx) = ToolRouter::new_auto_confirm();

        router
            .permissions
            .grant("filesystem.write_file", PermissionScope::Always);
        router.permissions.clear_session();

        // Always grant should survive session clear
        assert_eq!(
            router.permissions.check("filesystem.write_file"),
            PermissionStatus::Allowed
        );
    }
}

// ────────────────────────────────────────────────────────────────────────────
// (next file in this chunk: src-tauri/src/agent_core/types.rs)
//! Shared types for the agent core.
//!
//! Conversation messages, session metadata, undo entries, and confirmation
//! types used across the ConversationManager and ToolRouter.
+ +use serde::{Deserialize, Serialize}; + +use crate::inference::types::{Role, ToolCall}; + +// ─── Conversation Messages ────────────────────────────────────────────────── + +/// A single message stored in conversation history. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConversationMessage { + /// Auto-incremented row ID (set by DB on insert). + pub id: i64, + /// Which session this message belongs to. + pub session_id: String, + /// ISO 8601 timestamp. + pub timestamp: String, + /// Message role: system, user, assistant, or tool. + pub role: Role, + /// Text content (user messages, assistant text, system prompt). + pub content: Option, + /// Tool calls made by the assistant in this message. + pub tool_calls: Option>, + /// For `tool` role: the ID of the tool call this result belongs to. + pub tool_call_id: Option, + /// For `tool` role: the JSON result from executing the tool. + pub tool_result: Option, + /// Estimated token count for this message. + pub token_count: u32, +} + +/// Builder for creating conversation messages without specifying DB fields. +#[derive(Debug, Clone)] +pub struct NewMessage { + /// Message role. + pub role: Role, + /// Text content. + pub content: Option, + /// Tool calls (assistant messages). + pub tool_calls: Option>, + /// Tool call ID (tool result messages). + pub tool_call_id: Option, + /// Tool result (tool result messages). + pub tool_result: Option, +} + +// ─── Sessions ─────────────────────────────────────────────────────────────── + +/// Metadata for a conversation session. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Session { + /// Unique session identifier. + pub id: String, + /// ISO 8601 creation timestamp. + pub created_at: String, + /// ISO 8601 last activity timestamp. + pub last_activity: String, + /// Rolling session summary (populated after eviction). + pub summary: Option, + /// File paths touched during this session. 
+ pub files_touched: Vec, + /// High-level decisions made during this session. + pub decisions_made: Vec, +} + +/// Session summary for context window inclusion. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionSummary { + /// The session this summary belongs to. + pub session_id: String, + /// Human-readable summary of past interactions. + pub summary_text: String, + /// Files that have been mentioned or modified. + pub files_touched: Vec, + /// Decisions the user or assistant has made. + pub decisions_made: Vec, +} + +// ─── Undo Stack ───────────────────────────────────────────────────────────── + +/// An entry in the undo stack for a mutable/destructive tool action. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UndoEntry { + /// Auto-incremented row ID. + pub id: i64, + /// Session this entry belongs to. + pub session_id: String, + /// ISO 8601 timestamp. + pub timestamp: String, + /// The tool that was executed. + pub tool_name: String, + /// Category of the action: "move", "delete", "create", "write". + pub action_type: String, + /// Serialized original state before the action. + pub original_state: serde_json::Value, + /// Serialized new state after the action. + pub new_state: serde_json::Value, + /// Whether this entry has been undone. + pub undone: bool, +} + +/// Input for creating a new undo entry (no DB fields). +#[derive(Debug, Clone)] +pub struct NewUndoEntry { + /// The tool that was executed. + pub tool_name: String, + /// Category of the action. + pub action_type: String, + /// Original state before the action. + pub original_state: serde_json::Value, + /// New state after the action. + pub new_state: serde_json::Value, +} + +// ─── Confirmation ─────────────────────────────────────────────────────────── + +/// Request sent to the frontend for user confirmation. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ConfirmationRequest { + /// Unique request ID for matching responses. + pub request_id: String, + /// The tool being called. + pub tool_name: String, + /// The arguments to the tool. + pub arguments: serde_json::Value, + /// Human-readable preview of what will happen. + pub preview: String, + /// Whether this tool requires confirmation. + pub confirmation_required: bool, + /// Whether the action supports undo. + pub undo_supported: bool, + /// Whether this is a destructive action (delete, overwrite). + pub is_destructive: bool, +} + +/// Response from the frontend after user decision. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ConfirmationResponse { + /// User confirmed the action (Allow Once). + Confirmed, + /// User confirmed for the remainder of this session (Allow for Session). + ConfirmedForSession, + /// User confirmed permanently — never ask again (Always Allow). + ConfirmedAlways, + /// User rejected the action. + Rejected, + /// User edited the arguments before confirming. + #[serde(rename = "edited")] + EditedAndConfirmed { + /// Modified arguments. + new_arguments: serde_json::Value, + }, +} + +// ─── Audit Log ────────────────────────────────────────────────────────────── + +/// A single entry in the tool execution audit log. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditEntry { + /// Auto-incremented row ID. + pub id: i64, + /// Session this entry belongs to. + pub session_id: String, + /// ISO 8601 timestamp. + pub timestamp: String, + /// The tool that was executed. + pub tool_name: String, + /// Arguments passed to the tool. + pub arguments: serde_json::Value, + /// Result returned by the tool (if successful). + pub result: Option, + /// Execution status. + pub result_status: AuditStatus, + /// Whether the user confirmed this action. 
+ pub user_confirmed: bool, + /// How long the execution took (ms). + pub execution_time_ms: u64, +} + +/// Status of a tool execution in the audit log. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AuditStatus { + /// Tool executed successfully. + Success, + /// Tool execution returned an error. + Error, + /// User rejected the tool call. + RejectedByUser, + /// Tool call was skipped (e.g., auto-confirm for read-only). + Skipped, +} + +impl AuditStatus { + /// Convert to database string representation. + pub fn as_str(&self) -> &'static str { + match self { + AuditStatus::Success => "success", + AuditStatus::Error => "error", + AuditStatus::RejectedByUser => "rejected_by_user", + AuditStatus::Skipped => "skipped", + } + } + + /// Parse from database string representation. + pub fn parse(s: &str) -> Self { + match s { + "success" => AuditStatus::Success, + "error" => AuditStatus::Error, + "rejected_by_user" => AuditStatus::RejectedByUser, + "skipped" => AuditStatus::Skipped, + _ => AuditStatus::Error, + } + } +} + +// ─── Context Window Budget ────────────────────────────────────────────────── + +/// Snapshot of the current context window token usage. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ContextBudget { + /// Total context window size in tokens. + pub total: u32, + /// Tokens used by the system prompt. + pub system_prompt: u32, + /// Tokens used by tool definitions. + pub tool_definitions: u32, + /// Tokens used by conversation history. + pub conversation_history: u32, + /// Tokens reserved for the model's output response. + pub output_reservation: u32, + /// Remaining tokens (safety buffer excluded). 
+ pub remaining: u32, +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_audit_status_roundtrip() { + for status in [ + AuditStatus::Success, + AuditStatus::Error, + AuditStatus::RejectedByUser, + AuditStatus::Skipped, + ] { + assert_eq!(AuditStatus::parse(status.as_str()), status); + } + } + + #[test] + fn test_audit_status_unknown_defaults_to_error() { + assert_eq!(AuditStatus::parse("unknown"), AuditStatus::Error); + } + + #[test] + fn test_confirmation_response_serialization() { + // Tagged enum with camelCase: {"type": "confirmed"} + let confirmed = ConfirmationResponse::Confirmed; + let json = serde_json::to_string(&confirmed).unwrap(); + assert_eq!(json, r#"{"type":"confirmed"}"#); + + let session = ConfirmationResponse::ConfirmedForSession; + let json = serde_json::to_string(&session).unwrap(); + assert_eq!(json, r#"{"type":"confirmedForSession"}"#); + + let always = ConfirmationResponse::ConfirmedAlways; + let json = serde_json::to_string(&always).unwrap(); + assert_eq!(json, r#"{"type":"confirmedAlways"}"#); + + let rejected = ConfirmationResponse::Rejected; + let json = serde_json::to_string(&rejected).unwrap(); + assert_eq!(json, r#"{"type":"rejected"}"#); + + let edited = ConfirmationResponse::EditedAndConfirmed { + new_arguments: serde_json::json!({"path": "/tmp/new"}), + }; + let json = serde_json::to_string(&edited).unwrap(); + assert!(json.contains(r#""type":"edited""#)); + assert!(json.contains("/tmp/new")); + } + + #[test] + fn test_confirmation_response_deserialization() { + // Frontend sends {"type": "confirmed"} etc. 
+ let confirmed: ConfirmationResponse = + serde_json::from_str(r#"{"type":"confirmed"}"#).unwrap(); + assert!(matches!(confirmed, ConfirmationResponse::Confirmed)); + + let session: ConfirmationResponse = + serde_json::from_str(r#"{"type":"confirmedForSession"}"#).unwrap(); + assert!(matches!(session, ConfirmationResponse::ConfirmedForSession)); + + let edited: ConfirmationResponse = + serde_json::from_str(r#"{"type":"edited","new_arguments":{"path":"/tmp"}}"#).unwrap(); + assert!(matches!(edited, ConfirmationResponse::EditedAndConfirmed { .. })); + } + + #[test] + fn test_confirmation_request_serialization() { + let req = ConfirmationRequest { + request_id: "r1".to_string(), + tool_name: "filesystem.write_file".to_string(), + arguments: serde_json::json!({"path": "/tmp/test.txt"}), + preview: "Write to file: /tmp/test.txt".to_string(), + confirmation_required: true, + undo_supported: true, + is_destructive: false, + }; + let json = serde_json::to_string(&req).unwrap(); + // camelCase: requestId, toolName, undoSupported, isDestructive + assert!(json.contains("requestId")); + assert!(json.contains("toolName")); + assert!(json.contains("undoSupported")); + assert!(json.contains("isDestructive")); + assert!(json.contains("confirmationRequired")); + // Should NOT contain snake_case + assert!(!json.contains("request_id")); + assert!(!json.contains("tool_name")); + } + + #[test] + fn test_new_message_builder() { + let msg = NewMessage { + role: Role::User, + content: Some("hello".to_string()), + tool_calls: None, + tool_call_id: None, + tool_result: None, + }; + assert_eq!(msg.role, Role::User); + assert_eq!(msg.content.as_deref(), Some("hello")); + } +} diff --git a/src-tauri/src/commands/chat.rs b/src-tauri/src/commands/chat.rs new file mode 100644 index 0000000..0e9ee4c --- /dev/null +++ b/src-tauri/src/commands/chat.rs @@ -0,0 +1,3672 @@ +//! Tauri IPC commands for the chat interface. +//! +//! These commands are called from the React frontend via `invoke()`. +//! 
They bridge the frontend to the agent core (ConversationManager, +//! ToolRouter, and InferenceClient). + +use std::sync::Mutex; + +use futures::StreamExt; +use serde::Serialize; +use uuid::Uuid; + +use crate::agent_core::permissions::{PermissionScope, PermissionStatus, PermissionStore}; +// NOTE: response_analysis functions (is_incomplete_response, is_deflection_response) +// remain in the codebase and are tested, but are no longer called from the agent loop. +// They are available for the Orchestrator (ADR-009) or re-enablement via config. +// Tests below still exercise them for regression coverage. +use crate::agent_core::tokens::truncate_utf8; +use crate::agent_core::tool_router::{generate_preview, is_destructive_action}; +use crate::agent_core::{AuditStatus, ConfirmationRequest, ConfirmationResponse}; +use crate::agent_core::ConversationManager; +use crate::inference::config::{find_config_path, load_models_config}; +use crate::inference::types::{SamplingOverrides, ToolDefinition}; +use crate::inference::InferenceClient; +use crate::mcp_client::{CategoryRegistry, McpClient, ToolResolution}; +use crate::{PendingConfirmation, TokioMutex}; + +// ─── Two-Pass Tool Selection ──────────────────────────────────────────────── + +/// Tracks the two-pass tool selection state within the agent loop. +/// +/// On `Categories` phase, the model sees ~15 category meta-tools (~1,500 tokens). +/// On `Expanded`, the model sees real tools from selected categories. +/// On `Flat` (legacy), all tools are sent every turn (~8,670 tokens). +#[derive(Debug, Clone)] +enum ToolSelectionPhase { + /// First turn: model sees category meta-tools. + Categories { + /// The category registry used for expansion. + cat_registry: CategoryRegistry, + }, + /// Subsequent turns: model sees real tools from selected categories. + Expanded { + /// Category names that were selected (retained for diagnostics). + _selected_categories: Vec, + }, + /// Legacy flat mode: all tools every turn. 
+ Flat, +} + +/// Minimum number of registered tools to activate two-pass mode. +/// Below this threshold, flat mode is used regardless of config. +/// Set to 30 because category meta-tools confuse LFM2-24B-A2B at ≤21 tools +/// (model responds with text instead of calling tools). Two-pass is only +/// worthwhile at 67+ tools where it saves ~7k tokens/turn. +const TWO_PASS_MIN_TOOLS: usize = 30; + +// ─── Response Types ───────────────────────────────────────────────────────── + +/// Session start response. +#[derive(Debug, Serialize)] +pub struct SessionInfo { + pub session_id: String, + /// Whether this is a newly created session or a resumed one. + pub resumed: bool, +} + +// ─── System prompt ────────────────────────────────────────────────────────── + +/// Identity and intro — static portion of the system prompt. +/// +/// Kept short: research shows small LLMs perform better with concise identity +/// statements. The capabilities section (dynamic) is inserted after this. +const SYSTEM_PROMPT_INTRO: &str = "\ +You are LocalCowork, a private on-device AI assistant. You call tools to help the user."; + +/// Behavioral rules and few-shot examples — dynamic portion of the system prompt. 
+/// +/// Optimized for small LLMs (24B MoE) based on research: +/// - XML section tags for clear structure (models parse sections, not paragraphs) +/// - Pre-computed relative dates (model doesn't need to reason about "today") +/// - ≤6 rules (small models lose track beyond 5-7 instructions) +/// - Calendar/date example in position 1 (primacy effect) +/// - Concrete dates in examples (no indirection) +fn system_prompt_rules(today: &str, tomorrow: &str, week_start: &str, week_end: &str) -> String { + let home = dirs::home_dir() + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_else(|| { + if cfg!(target_os = "windows") { + r"C:\Users\user".to_string() + } else if cfg!(target_os = "macos") { + "/Users/user".to_string() + } else { + "/home/user".to_string() + } + }); + + format!("\ +\n\ +1. Use fully-qualified tool names: filesystem.list_dir, NOT list_dir.\n\ +2. Use absolute paths: {home}/Documents/file.txt, NOT ~/Documents/file.txt. \ +If a WORKING FOLDER is set, use ONLY the exact paths listed there.\n\ +3. READ tools: call immediately. WRITE tools: call directly (system shows confirmation).\n\ +4. After a scan returns results, present findings and STOP. Do NOT auto-chain to \ +mutable tools (encrypt, delete, move) without the user asking.\n\ +5. Never call the same tool with the same arguments twice. Use the result you already have.\n\ +6. Be concise. 
Respond after 1-3 tool calls unless the user asked for exhaustive processing.\n\ +\n\n\ +\n\ +Example 1 — calendar query (use pre-computed dates, never ask the user):\n\ + User: \"What's on my calendar today?\"\n\ + You call: calendar.list_events({{\"start_date\": \"{today}\", \"end_date\": \"{today}\"}})\n\ + User: \"Any meetings tomorrow?\"\n\ + You call: calendar.list_events({{\"start_date\": \"{tomorrow}\", \"end_date\": \"{tomorrow}\"}})\n\ + User: \"What do I have this week?\"\n\ + You call: calendar.list_events({{\"start_date\": \"{week_start}\", \"end_date\": \"{week_end}\"}})\n\ + WRONG: Asking the user what today's date is. You already know it.\n\n\ +Example 2 — file listing:\n\ + User: \"List my Documents folder.\"\n\ + You call: filesystem.list_dir({{\"path\": \"{home}/Documents\"}})\n\n\ +Example 3 — security scan:\n\ + User: \"Scan for secrets.\"\n\ + You call: security.scan_for_secrets({{\"path\": \"{home}/Projects\"}})\n\ + Then: present findings and STOP.\n\ +") +} + +/// Build the system prompt with dynamic tool capabilities from the MCP registry. +/// +/// Structure (optimized for small LLMs): +/// ```text +/// Identity (1 line) +/// block with pre-computed dates +/// from MCP registry +/// consolidated behavioral rules +/// few-shot examples (calendar first) +/// ``` +/// +/// Key optimizations for 24B MoE models: +/// - XML section tags (research: small models parse structured prompts better) +/// - Pre-computed relative dates (today, tomorrow, week range) — no reasoning needed +/// - Date block at position 2 (high primacy) and repeated in rules reminder +/// - ≤6 rules instead of 12 (small models lose track beyond 5-7) +fn build_system_prompt( + registry: &crate::mcp_client::registry::ToolRegistry, + two_pass_active: bool, +) -> String { + use chrono::{Datelike, Duration}; + + let capabilities = registry.capability_summary(); + + // Pre-compute all relative dates so the model never needs to reason about them. 
+ // Research: small LLMs fail at date arithmetic; pre-computing eliminates the problem. + let now = chrono::Local::now(); + let today = now.format("%Y-%m-%d").to_string(); + let day_of_week = now.format("%A").to_string(); // e.g. "Monday" + let tomorrow = (now + Duration::days(1)).format("%Y-%m-%d").to_string(); + + // Compute Monday (start) and Sunday (end) of the current week + let weekday_num = now.weekday().num_days_from_monday(); // Mon=0, Sun=6 + let week_start = (now - Duration::days(weekday_num as i64)) + .format("%Y-%m-%d") + .to_string(); + let week_end = (now + Duration::days((6 - weekday_num) as i64)) + .format("%Y-%m-%d") + .to_string(); + + let time_str = now.format("%H:%M").to_string(); + + // Date block — prominent, structured, with pre-computed values. + // Placed immediately after identity for maximum primacy. + let date_block = format!( + "\n\ + today = {today} ({day_of_week})\n\ + tomorrow = {tomorrow}\n\ + this_week = {week_start} to {week_end}\n\ + current_time = {time_str}\n\ + Use these exact values when the user says \"today\", \"tomorrow\", \"this week\".\n\ + NEVER ask the user for a date.\n\ + " + ); + + let rules = system_prompt_rules(&today, &tomorrow, &week_start, &week_end); + + if two_pass_active { + let two_pass_instruction = "\n\nIMPORTANT: You will first see category-level tools \ + (like file_browse, image_ocr, data_analysis, etc.). Call 1-3 categories that match \ + the user's request. You will then receive the specific tools within those categories. \ + Always select the categories FIRST before trying to use specific tools. 
\ + After selecting categories and receiving the expanded tools, call the minimum \ + tools needed to answer the user's question, then provide your response."; + format!( + "{SYSTEM_PROMPT_INTRO}\n\n\ + {date_block}\n\n\ + \n{capabilities}\n\ + {two_pass_instruction}\n\n\ + {rules}" + ) + } else { + format!( + "{SYSTEM_PROMPT_INTRO}\n\n\ + {date_block}\n\n\ + \n{capabilities}\n\n\n\ + {rules}" + ) + } +} + +/// Maximum number of tool-call round-trips per user message. +/// +/// Each round allows one model response + one set of tool executions. +/// Complex tasks (e.g., OCR on 10 files) may use many rounds. +/// The model gets one call per tool per round (it can batch multiple +/// tool calls in a single response, but typically does one at a time). +const MAX_TOOL_ROUNDS: usize = 10; + +/// Maximum consecutive empty responses before forcing a summary. +/// +/// If the model returns 0 text AND 0 tool calls this many times in a row, +/// it's stuck (likely due to context confusion or timeout). We inject a +/// summary prompt to force text output. +const MAX_EMPTY_RETRIES: usize = 2; + +/// Maximum consecutive rounds with ALL tool calls failing before injecting +/// a corrective hint. +/// +/// When the model repeatedly calls the same non-existent tool (e.g., +/// `filesystem.rename_file` instead of `filesystem.move_file`), this +/// prevents burning all 20 rounds on the same error. After this many +/// consecutive all-error rounds, we inject a hint telling the model +/// which tools actually exist. +const MAX_CONSECUTIVE_ERROR_ROUNDS: usize = 2; + +/// Maximum times a single tool can fail before it's removed from the tool +/// definitions and the model is told to stop retrying. +/// +/// This catches the case where the model alternates between a succeeding tool +/// and a failing one — the per-round counter (`consecutive_error_rounds`) resets +/// on every success, so this per-tool counter is the only thing that can break +/// that loop. 
+const MAX_SAME_TOOL_FAILURES: usize = 3; + +/// Maximum consecutive duplicate tool calls (same tool name with identical +/// arguments) before the agent loop breaks. +/// +/// When the model gets stuck calling the same tool repeatedly with identical +/// params (e.g., `list_directory("~/Downloads")` 3× in a row), the loop +/// should detect this and exit. +/// +/// Note: `consecutive_duplicate_count()` returns 1 for the first occurrence, +/// so a threshold of 2 means "one genuine duplicate" (the tool was called +/// twice with identical args). Before reaching this hard break, the soft +/// interception in the tool execution loop will skip the redundant call and +/// inject a "you already have these results" nudge, giving the model a +/// chance to produce text. +const MAX_DUPLICATE_TOOL_CALLS: usize = 2; + +/// Minimum remaining token budget to start a new agent loop round. +/// +/// If the context window has fewer than this many tokens remaining, the +/// agent loop exits early rather than risk context overflow and degraded +/// model quality. Set to accommodate a model response (~500 tokens) plus +/// a tool result (~1000 tokens). +const MIN_ROUND_TOKEN_BUDGET: u32 = 1500; + +/// Configuration for tool result compression. +const COMPRESSION_THRESHOLD_CHARS: usize = 3_000; +const MAX_TOOL_RESULT_CHARS: usize = 6_000; + +/// Truncate a tool result if it exceeds `MAX_TOOL_RESULT_CHARS`. +/// +/// Preserves the beginning of the result (which usually contains the most +/// useful information) and appends a truncation notice. 
+fn truncate_tool_result(result: &str, tool_name: &str) -> String { + if result.len() <= MAX_TOOL_RESULT_CHARS { + return result.to_string(); + } + + // Try smart compression for known tool types + if let Some(summary) = compress_tool_result(result, tool_name) { + if summary.len() <= MAX_TOOL_RESULT_CHARS { + tracing::info!( + tool = %tool_name, + original_len = result.len(), + compressed_len = summary.len(), + "tool result compressed via smart summary" + ); + return summary; + } + } + + // Fall back to simple truncation + let truncated = &result[..MAX_TOOL_RESULT_CHARS]; + tracing::warn!( + tool = %tool_name, + original_len = result.len(), + truncated_to = MAX_TOOL_RESULT_CHARS, + "tool result truncated — exceeded MAX_TOOL_RESULT_CHARS" + ); + format!( + "{truncated}\n\n[... truncated: showing first {MAX_TOOL_RESULT_CHARS} of {} chars]", + result.len() + ) +} + +/// Compress tool results using smart extraction for known data patterns. +/// +/// For directory listings, extracts just filenames and counts. +/// For search results, extracts matches and count. +/// For JSON/structured data, extracts key summaries. +/// +/// Returns None if compression isn't beneficial for this tool type. 
+fn compress_tool_result(result: &str, tool_name: &str) -> Option { + // Only compress for read/search type operations + let compressible_tools = [ + "list_dir", + "search_files", + "scan_for_secrets", + "scan_for_pii", + "query_knowledge", + "list_events", + "list_tasks", + "search_emails", + ]; + + let is_compressible = compressible_tools + .iter() + .any(|t| tool_name.contains(t)); + + if !is_compressible || result.len() < COMPRESSION_THRESHOLD_CHARS { + return None; + } + + // For directory listings: extract just file/dir names and summary + if tool_name.contains("list_dir") { + return compress_directory_listing(result); + } + + // For search results: extract match counts and key matches + if tool_name.contains("search") || tool_name.contains("scan") { + return compress_search_results(result); + } + + // For JSON-like results: extract key fields + if result.starts_with('{') || result.starts_with('[') { + return compress_json_result(result); + } + + None +} + +/// Compress a directory listing to just names and counts. 
/// Compress a directory listing to just names and counts.
///
/// Recognizes "📁 name" / "📄 name" lines plus the parenthetical format
/// "name (123 KB)"; caps output at 20 directories and 30 files.
fn compress_directory_listing(result: &str) -> Option<String> {
    let lines: Vec<&str> = result.lines().collect();
    if lines.is_empty() {
        return None;
    }

    let mut files = Vec::new();
    let mut dirs = Vec::new();

    for line in &lines {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        if trimmed.starts_with("📁") {
            if let Some(name) = trimmed.strip_prefix("📁").map(|s| s.trim()) {
                dirs.push(name.to_string());
            }
        } else if trimmed.starts_with("📄") {
            if let Some(name) = trimmed.strip_prefix("📄").map(|s| s.trim()) {
                files.push(name);
            }
        } else {
            // Try to extract name from parenthetical format: "name (123 KB)"
            if let Some(paren_idx) = trimmed.find('(') {
                let name = trimmed[..paren_idx].trim();
                if !name.is_empty() {
                    files.push(name);
                }
            }
        }
    }

    let total = files.len() + dirs.len();
    let mut summary = format!("Total: {} items ({} files, {} directories)\n\n", total, files.len(), dirs.len());

    if !dirs.is_empty() {
        summary.push_str("Directories:\n");
        for d in dirs.iter().take(20) {
            summary.push_str(&format!(" 📁 {}\n", d));
        }
        if dirs.len() > 20 {
            summary.push_str(&format!(" ... and {} more\n", dirs.len() - 20));
        }
        summary.push('\n');
    }

    if !files.is_empty() {
        summary.push_str("Files:\n");
        for f in files.iter().take(30) {
            summary.push_str(&format!(" 📄 {}\n", f));
        }
        if files.len() > 30 {
            summary.push_str(&format!(" ... and {} more\n", files.len() - 30));
        }
    }

    Some(summary)
}
/// Compress search/scan results to match counts and key findings.
///
/// Looks for a "found N matches"-style count (case-insensitively) and keeps
/// up to 10 finding lines; otherwise falls back to the first 15 lines.
fn compress_search_results(result: &str) -> Option<String> {
    let lower = result.to_lowercase();

    // Try to extract match count
    let count_patterns = [
        ("found ", " matches"),
        ("matches: ", ""),
        ("results: ", ""),
        ("total: ", " items"),
    ];

    for (prefix, suffix) in &count_patterns {
        if let Some(idx) = lower.find(prefix) {
            let start = idx + prefix.len();
            // BUGFIX: `idx` was located in the lowercased copy; for non-ASCII
            // input, lowercasing can shift byte offsets, so `start` may not be
            // a char boundary (or even in range) in the original string.
            // Guard before slicing to avoid a panic.
            if !result.is_char_boundary(start) {
                continue;
            }
            let after_prefix = &result[start..];
            let end_idx = after_prefix
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(after_prefix.len());
            let count = &after_prefix[..end_idx];
            if !count.is_empty() && count.len() <= 6 {
                // Found a count, now get first few matches
                let matches: Vec<&str> = result
                    .lines()
                    .filter(|l| {
                        let l_lower = l.to_lowercase();
                        !l_lower.contains("found")
                            && !l_lower.contains("total")
                            && !l_lower.contains("scan")
                            && !l_lower.contains("error")
                            && l.trim().len() > 3
                    })
                    .take(10)
                    .collect();

                let mut summary = format!("{}{}{}\n\n", prefix, count, suffix);
                if !matches.is_empty() {
                    summary.push_str("Key findings:\n");
                    for m in matches {
                        let trimmed = m.trim();
                        if trimmed.len() > 100 {
                            // BUGFIX: floor the cut to a char boundary so
                            // multi-byte UTF-8 text can't cause a slice panic.
                            let mut cut = 100;
                            while !trimmed.is_char_boundary(cut) {
                                cut -= 1;
                            }
                            summary.push_str(&format!(" {}\n", &trimmed[..cut]));
                        } else {
                            summary.push_str(&format!(" {}\n", trimmed));
                        }
                    }
                }
                return Some(summary);
            }
        }
    }

    // Fallback: just take first 15 lines
    let lines: Vec<&str> = result.lines().take(15).collect();
    if lines.is_empty() {
        return None;
    }

    let summary = format!(
        "[... {} lines total ...]\n\n{}",
        result.lines().count(),
        lines.join("\n")
    );
    Some(summary)
}
+fn compress_json_result(result: &str) -> Option { + let parsed: serde_json::Value = serde_json::from_str(result).ok()?; + + // For arrays, extract count and first few items + if let Some(arr) = parsed.as_array() { + if arr.is_empty() { + return Some("[] (empty)".to_string()); + } + + let count = arr.len(); + let mut summary = format!("[{} items]\n\n", count); + + for (i, item) in arr.iter().take(5).enumerate() { + summary.push_str(&format!("{}. ", i + 1)); + if let Some(obj) = item.as_object() { + // Extract common "name" or "text" fields + if let Some(name) = obj.get("name").and_then(|v| v.as_str()) { + summary.push_str(&format!("name: {}", name)); + } else if let Some(text) = obj.get("text").and_then(|v| v.as_str()) { + let preview = if text.len() > 50 { + format!("{}...", &text[..50]) + } else { + text.to_string() + }; + summary.push_str(&preview); + } else if let Some(path) = obj.get("path").and_then(|v| v.as_str()) { + summary.push_str(&format!("path: {}", path)); + } else { + // Just stringify the object + let s = serde_json::to_string(item).ok()?; + let preview = if s.len() > 80 { + format!("{}...", &s[..80]) + } else { + s + }; + summary.push_str(&preview); + } + } else if let Some(s) = item.as_str() { + let preview = if s.len() > 80 { + format!("{}...", &s[..80]) + } else { + s.to_string() + }; + summary.push_str(&preview); + } + summary.push('\n'); + } + + if count > 5 { + summary.push_str(&format!("... 
and {} more\n", count - 5)); + } + + return Some(summary); + } + + // For objects, extract key fields + if let Some(obj) = parsed.as_object() { + let key_fields = ["text", "content", "name", "path", "message", "result", "total", "count"]; + let mut summary = String::new(); + + for key in &key_fields { + if let Some(val) = obj.get(*key) { + if let Some(s) = val.as_str() { + summary.push_str(&format!("{}: {}\n", key, s)); + } else if let Some(n) = val.as_u64() { + summary.push_str(&format!("{}: {}\n", key, n)); + } + } + } + + // If we extracted nothing useful, return original + if summary.is_empty() { + // Just return first 500 chars + let preview = if result.len() > 500 { + format!("{}...", &result[..500]) + } else { + result.to_string() + }; + return Some(preview); + } + + return Some(summary); + } + + None +} + +// ─── Tool Definitions ────────────────────────────────────────────────────── + +/// Built-in tool definitions (filesystem operations handled in-process). +fn builtin_tool_definitions() -> Vec { + vec![ + ToolDefinition { + r#type: "function".to_string(), + function: crate::inference::types::FunctionDefinition { + name: "list_directory".to_string(), + description: "List files and directories at the given path. \ + Returns name, type (file/dir), size, and modification date \ + for each entry. Use ~/path for home-relative paths." + .to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Directory path to list, e.g. ~/Desktop" + } + }, + "required": ["path"] + }), + }, + }, + ToolDefinition { + r#type: "function".to_string(), + function: crate::inference::types::FunctionDefinition { + name: "read_file".to_string(), + description: "Read the text contents of a file at the given path. \ + Returns the file content as a string. Only works for text files." 
+ .to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "File path to read, e.g. ~/Desktop/notes.txt" + } + }, + "required": ["path"] + }), + }, + }, + ] +} + +/// Build merged tool definitions: built-in + MCP tools from the registry. +/// +/// Built-in tools (`list_directory`, `read_file`) are suppressed when the MCP +/// registry already contains their equivalents (`filesystem.list_dir`, +/// `filesystem.read_file`). This avoids confusing the model with near-duplicate +/// tools, which causes it to pick the wrong one or get stuck in loops. +fn build_all_tool_definitions(mcp_client: &McpClient) -> Vec { + // Map of built-in tool name → MCP equivalent that supersedes it + let builtin_mcp_equivalents: &[(&str, &str)] = &[ + ("list_directory", "filesystem.list_dir"), + ("read_file", "filesystem.read_file"), + ]; + + // Only include built-ins whose MCP equivalent is NOT in the registry + let mut tools: Vec = builtin_tool_definitions() + .into_iter() + .filter(|tool| { + let name = &tool.function.name; + !builtin_mcp_equivalents.iter().any(|(builtin, mcp)| { + name == builtin && mcp_client.registry.get_tool(mcp).is_some() + }) + }) + .collect(); + + // Append MCP tool definitions from the registry + let mcp_tools = mcp_client.registry.to_openai_tools(); + for mcp_tool_json in mcp_tools { + if let Ok(tool_def) = serde_json::from_value::(mcp_tool_json) { + tools.push(tool_def); + } + } + + tools +} + +/// Build tool definitions from category meta-tools (two-pass mode). +/// +/// Each category becomes a synthetic OpenAI function with a single `"intent"` +/// parameter. The model calls these to signal which capability areas it needs. +/// Built-in tools (`list_directory`, `read_file`) are always included. 
+fn build_category_tool_definitions(cat_registry: &CategoryRegistry) -> Vec { + let mut tools = builtin_tool_definitions(); + + let cat_tools = cat_registry.to_openai_tools(); + for cat_json in cat_tools { + if let Ok(tool_def) = serde_json::from_value::(cat_json) { + tools.push(tool_def); + } + } + + tools +} + +// ─── Tool Execution ────────────────────────────────────────────────────── + +/// Execute a built-in tool call and return the result as a string. +fn execute_builtin_tool(name: &str, arguments: &serde_json::Value) -> String { + match name { + "list_directory" => { + let path = arguments + .get("path") + .and_then(|v| v.as_str()) + .unwrap_or("."); + match super::filesystem::list_directory(path.to_string()) { + Ok(entries) => { + if entries.is_empty() { + "Directory is empty.".to_string() + } else { + let mut lines = Vec::new(); + for e in &entries { + let type_icon = if e.entry_type == "dir" { + "📁" + } else { + "📄" + }; + let size_str = if e.entry_type == "dir" { + String::new() + } else { + format_file_size(e.size) + }; + lines.push(format!( + "{} {} {}", + type_icon, e.name, size_str + )); + } + lines.join("\n") + } + } + Err(e) => format!("Error: {e}"), + } + } + "read_file" => { + let path = arguments + .get("path") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let resolved = if path.starts_with('~') { + if let Some(home) = dirs::home_dir() { + home.join(path.strip_prefix("~/").unwrap_or(path)) + } else { + std::path::PathBuf::from(path) + } + } else { + std::path::PathBuf::from(path) + }; + match std::fs::read_to_string(&resolved) { + Ok(content) => { + if content.len() > 8000 { + format!( + "{}\n\n[... 
truncated, showing first ~8000 chars of {} total]", + truncate_utf8(&content, 8000), + content.len() + ) + } else { + content + } + } + Err(e) => format!("Error reading file: {e}"), + } + } + _ => format!("Unknown built-in tool: {name}"), + } +} + +// ─── Tool Execution Outcome ───────────────────────────────────────────────── + +/// Typed result from executing a single tool call in the agent loop. +/// +/// Preserves the success/failure distinction through types instead of string +/// matching. The agent loop uses this to: +/// - Feed the right text back to the model (via `model_text()`) +/// - Track error patterns for loop detection (via `is_error()`) +/// - Build correction hints from `ToolResolution` suggestions +#[derive(Debug)] +#[allow(dead_code)] // tool_name fields used for Debug output and future ToolRouter integration +enum ToolExecutionOutcome { + /// Tool executed successfully. + Success { tool_name: String, text: String }, + + /// Tool exists but returned an application-level error + /// (e.g., "file not found", "permission denied"). + ToolError { tool_name: String, text: String }, + + /// Tool name not found in the registry. `resolution` carries the + /// registry's analysis (suggestions, nearest matches, etc.). + UnknownTool { + tool_name: String, + resolution: ToolResolution, + text: String, + }, + + /// Infrastructure error: timeout, server crash, transport failure. + InfraError { tool_name: String, text: String }, +} + +impl ToolExecutionOutcome { + /// The text to feed back to the model as the tool result message. + fn model_text(&self) -> &str { + match self { + Self::Success { text, .. } + | Self::ToolError { text, .. } + | Self::UnknownTool { text, .. } + | Self::InfraError { text, .. } => text, + } + } + + /// Whether this outcome represents an error (any variant except Success). + fn is_error(&self) -> bool { + !matches!(self, Self::Success { .. }) + } +} + +/// Minimum similarity score (0.0–1.0) for auto-correcting tool names. 
+/// +/// Below this threshold, the registry returns `NotFound` instead of +/// `Corrected`. Set conservatively to avoid correcting to the wrong tool. +const TOOL_RESOLUTION_THRESHOLD: f64 = 0.5; + +/// Execute a tool call: built-in tools run in-process, MCP tools route +/// through the McpClient. +/// +/// Tool names are resolved via `ToolRegistry::resolve()` which handles: +/// - Exact matches (tool exists as-is) +/// - Unprefixed names (model dropped the `server.` prefix) +/// - Fuzzy correction (model hallucinated a similar name) +/// +/// Results are capped at `MAX_TOOL_RESULT_CHARS` to prevent a single large +/// result from consuming the entire context window budget. +async fn execute_tool( + name: &str, + arguments: &serde_json::Value, + mcp_client: &mut McpClient, +) -> ToolExecutionOutcome { + // Built-in tools (handled in-process for speed) + if name == "list_directory" || name == "read_file" { + let text = truncate_tool_result(&execute_builtin_tool(name, arguments), name); + return ToolExecutionOutcome::Success { + tool_name: name.to_string(), + text, + }; + } + + // Resolve tool name via the registry (exact → unprefixed → fuzzy) + let resolution = mcp_client.registry.resolve(name, TOOL_RESOLUTION_THRESHOLD); + + let resolved_name = match &resolution { + ToolResolution::Exact(n) => n.clone(), + ToolResolution::Unprefixed { resolved, original } => { + tracing::info!( + original = %original, + resolved = %resolved, + "resolved unprefixed tool name" + ); + resolved.clone() + } + ToolResolution::Corrected { + resolved, + original, + score, + } => { + tracing::info!( + original = %original, + resolved = %resolved, + score = score, + "auto-corrected tool name via edit distance" + ); + resolved.clone() + } + ToolResolution::NotFound { + original, + suggestions, + } => { + let text = if suggestions.is_empty() { + format!( + "Unknown tool: '{original}'. Use fully-qualified names \ + (e.g., filesystem.list_dir, security.scan_for_secrets)." 
+ ) + } else { + format!( + "Unknown tool: '{original}'. Did you mean: {}?", + suggestions.join(", ") + ) + }; + return ToolExecutionOutcome::UnknownTool { + tool_name: original.clone(), + resolution, + text, + }; + } + }; + + // Track whether we auto-corrected the name so we can annotate errors. + let correction_context: Option = match &resolution { + ToolResolution::Corrected { + original, resolved, .. + } => Some(format!( + "NOTE: '{original}' does not exist. Auto-corrected to '{resolved}'. " + )), + _ => None, + }; + + // Expand `~` prefixes in string arguments before MCP dispatch. + // Built-in tools handle tilde themselves; MCP servers expect absolute paths. + let expanded_arguments = expand_tilde_in_arguments(arguments); + + // Execute via MCP + match mcp_client + .call_tool(&resolved_name, expanded_arguments) + .await + { + Ok(result) => { + let raw_text = if result.success { + extract_mcp_result_text(&result.result) + } else { + result + .error + .unwrap_or_else(|| "Tool execution failed".to_string()) + }; + let text = truncate_tool_result(&raw_text, &resolved_name); + if result.success { + ToolExecutionOutcome::Success { + tool_name: resolved_name, + text, + } + } else { + // Prepend correction context so the model understands the + // mis-dispatch: e.g. "rename_file does not exist, corrected + // to move_file. ". 
+ let annotated = if let Some(ctx) = &correction_context { + format!("{ctx}{text}") + } else { + text + }; + ToolExecutionOutcome::ToolError { + tool_name: resolved_name, + text: annotated, + } + } + } + Err(e) => { + let base = format!("MCP error for '{resolved_name}': {e}"); + let text = if let Some(ctx) = &correction_context { + format!("{ctx}{base}") + } else { + base + }; + ToolExecutionOutcome::InfraError { + tool_name: resolved_name, + text, + } + } + } +} + +// `is_incomplete_response` and `is_deflection_response` are now in +// `agent_core::response_analysis` — no longer called from the agent loop, +// but still tested for regression coverage and available for the Orchestrator. + +/// Detect when a model's final text claims task completion but tool history +/// disagrees — i.e., the model confabulated a summary. +/// +/// This catches the pattern where the model says "I've successfully renamed +/// all 9 files" but `move_file` never appeared in `tool_call_history`. +/// +/// Returns `true` when the response looks like a confabulated completion. +/// +/// NOTE: Currently only used by tests. The agent loop no longer calls this +/// (continuation heuristics were removed in favour of trusting the model). +/// Retained for the Orchestrator (ADR-009) and regression test coverage. +#[cfg(test)] +fn has_unverified_completion(text: &str, tool_call_history: &[String]) -> bool { + let lower = text.to_lowercase(); + + // Only trigger on text that claims the task is done. + let claims_done = [ + "successfully", + "completed", + "all files", + "renamed", + "processed all", + "all done", + "task complete", + "finished processing", + ]; + let claims_completion = claims_done.iter().any(|s| lower.contains(s)); + if !claims_completion { + return false; + } + + // Mutable operations the model might claim to have done. + // If the model claims completion but never called any of these, it confabulated. + // This list covers all mutable tools across all 13 MCP servers. 
+ let mutable_tools = [ + // Filesystem + "move_file", + "write_file", + "copy_file", + "create_dir", + "move_to_trash", + "rename_file", + // Task management + "create_task", + "update_task", + "delete_task", + "complete_task", + // Calendar + "create_event", + "update_event", + "delete_event", + // Email + "send_email", + "draft_email", + // Security + "encrypt_file", + "decrypt_file", + "propose_cleanup", + // Knowledge + "index_document", + "delete_index", + // Document + "convert_document", + "merge_documents", + ]; + + let called_any_mutable = tool_call_history + .iter() + .any(|t| mutable_tools.iter().any(|m| t.contains(m))); + + // If model claims done AND actually called mutable tools → not confabulated. + if called_any_mutable { + return false; + } + + // If the model never called any mutable tool but claims completion, it + // MAY be confabulated. However, we need to distinguish two cases: + // + // 1. Read-only task genuinely complete: "What files are in Downloads?" → + // model calls list_dir, says "all done" → NOT confabulation. + // + // 2. Mutable task not executed: "Rename all screenshots" → model calls + // list_dir + OCR but says "all files renamed" → IS confabulation. + // + // Heuristic: check if the completion text specifically claims a mutable + // action (rename, create, move, delete, write, send, encrypt, etc.). + // Generic "all done" / "completed" without mutable verbs is likely a + // legitimate read-only task completion. + // Edge case: ZERO tool calls but model claims completion — always confabulated. + // The model literally did nothing but claims to have finished. + if tool_call_history.is_empty() { + return true; + } + + // The model called tools but none were mutable. Check if the completion + // text specifically claims a mutable action (rename, create, move, etc.). + // Generic "all done" / "completed" without mutable verbs is likely a + // legitimate read-only task completion. 
+ let mutable_action_claims = [ + "renamed", + "moved", + "deleted", + "created", + "written", + "sent", + "encrypted", + "decrypted", + "copied", + "converted", + "merged", + "updated", + "modified", + "saved", + ]; + + let claims_mutable_action = mutable_action_claims.iter().any(|v| lower.contains(v)); + + // Only confabulation if model claims a mutable action it never performed. + // "All done" after read-only work → not confabulation (let it exit). + // "Successfully renamed all files" after only reading → confabulation. + claims_mutable_action +} + +/// Detect if the model is stuck calling the same tool with the same arguments. +/// +/// Returns the number of consecutive times the last tool call signature has +/// repeated. The caller compares this against `MAX_DUPLICATE_TOOL_CALLS`. +/// +/// A "signature" is `"tool_name|arguments_json"` — if the model calls +/// `list_directory(path="~/Downloads")` three rounds in a row, this returns 3. +fn consecutive_duplicate_count(history: &[(String, String)]) -> usize { + if history.is_empty() { + return 0; + } + let last = &history[history.len() - 1]; + let mut count = 1; + for entry in history.iter().rev().skip(1) { + if entry.0 == last.0 && entry.1 == last.1 { + count += 1; + } else { + break; + } + } + count +} + +/// Format a correction hint from the `ToolResolution` data collected during +/// a round where all tool calls failed. +/// +/// Uses the suggestions already computed by `ToolRegistry::resolve()` — no +/// extra registry queries needed. +fn format_correction_hint(unknown_tools: &[(String, ToolResolution)]) -> String { + if unknown_tools.is_empty() { + return "TOOL ERROR: All tool calls in this round failed. \ + Check your tool names and try again." + .to_string(); + } + + let mut parts = Vec::new(); + for (name, resolution) in unknown_tools { + match resolution { + ToolResolution::NotFound { suggestions, .. } if !suggestions.is_empty() => { + parts.push(format!( + "'{name}' does not exist. 
Did you mean: {}?", + suggestions.join(", ") + )); + } + _ => { + parts.push(format!("'{name}' does not exist.")); + } + } + } + + format!( + "TOOL ERROR: {}. Use ONLY tools listed in your available tools.", + parts.join(" ") + ) +} + +/// Expand `~` or `~/` prefixes to the user's home directory in any string +/// argument value that looks like a file path. +/// +/// MCP servers expect absolute paths. The LLM frequently generates `~/...` +/// despite system-prompt rules. Rather than relying on each MCP server to +/// handle tildes, we expand them centrally before dispatch. +/// +/// Also fixes cross-platform path hallucination: +/// - `/home//...` on macOS → `/Users//...` +/// - `/Users/{user}/...` (placeholder) → real home dir +/// - `/Users//...` → real home dir +/// +/// Only replaces `~` or `~/...` at the start of a string value. Values like +/// `~other_user/` or `~suffix` are left untouched (we can't resolve those). +fn expand_tilde_in_arguments(args: &serde_json::Value) -> serde_json::Value { + match args { + serde_json::Value::Object(map) => { + let mut out = serde_json::Map::new(); + for (k, v) in map { + out.insert(k.clone(), expand_tilde_in_arguments(v)); + } + serde_json::Value::Object(out) + } + serde_json::Value::String(s) => { + if let Some(fixed) = fix_path_string(s) { + serde_json::Value::String(fixed) + } else { + serde_json::Value::String(s.clone()) + } + } + serde_json::Value::Array(arr) => { + serde_json::Value::Array(arr.iter().map(expand_tilde_in_arguments).collect()) + } + other => other.clone(), + } +} + +/// Fix a single path string: tilde expansion + cross-platform path correction. +/// +/// Returns `Some(fixed)` if the path was modified, `None` if no fix was needed. 
+/// +/// The model hallucinates paths in several forms: +/// - `~/Documents` → tilde shorthand +/// - `Projects` → bare relative dir name +/// - `/home/user/...` → wrong OS prefix (Linux on macOS/Windows) +/// - `/Users/{user}/...` → template placeholders +/// - `C:\Users\{user}\...` → template placeholders (Windows) +/// +/// All corrections use `std::path::Path::join` so separators are always +/// correct for the target platform. +fn fix_path_string(s: &str) -> Option { + use std::path::MAIN_SEPARATOR; + + let home = dirs::home_dir()?; + let home_str = home.to_string_lossy(); + + // ── 1. Tilde expansion: ~/... → /... ────────────────────────────── + if s.starts_with("~/") || s.starts_with("~\\") { + let rest = &s[2..]; + return Some(home.join(rest).to_string_lossy().into_owned()); + } + if s == "~" { + return Some(home_str.into_owned()); + } + + // ── 2. Bare relative path that matches a well-known home subdirectory ─── + // Model outputs "Projects" or "Downloads" instead of an absolute path. + // Guard: skip strings that look like absolute paths or URLs. + let looks_absolute = s.starts_with('/') + || s.starts_with('\\') + || (s.len() >= 3 && s.as_bytes()[1] == b':'); // C:\ or D:\ + if !looks_absolute && !s.contains("://") { + let first_segment = s.split(&['/', '\\'][..]).next().unwrap_or(s); + let well_known = [ + "Desktop", + "Documents", + "Downloads", + "Projects", + "Pictures", + "Music", + "Videos", // Windows + "Movies", // macOS + "Library", // macOS + ]; + if well_known.iter().any(|d| d.eq_ignore_ascii_case(first_segment)) { + return Some(home.join(s).to_string_lossy().into_owned()); + } + } + + // ── 3. Foreign OS home prefix → real home dir ───────────────────────── + // LLMs hallucinate Linux-style /home/... on macOS/Windows and + // macOS-style /Users/... on Linux/Windows. A foreign prefix means + // the entire path is hallucinated — rewrite any username unconditionally. 
+ let foreign_prefixes: &[&str] = if cfg!(target_os = "macos") { + &["/home/"] // /Users/ is native on macOS — handled separately below + } else if cfg!(target_os = "linux") { + &["/Users/"] // /home/ is native on Linux — handled separately below + } else { + &["/home/", "/Users/"] // both are foreign on Windows + }; + + for prefix in foreign_prefixes { + if let Some(after_prefix) = s.strip_prefix(prefix) { + if let Some(slash_idx) = after_prefix.find('/') { + let rest = &after_prefix[slash_idx + 1..]; + return Some(home.join(rest).to_string_lossy().into_owned()); + } + } + } + + // ── 4. Native OS home prefix with template placeholder ────────────── + // /Users/{user}/... on macOS, /home/{user}/... on Linux. + // Only rewrite if the "username" is a known template placeholder — + // never silently replace a real username on a multi-user system. + let native_prefix: &str = if cfg!(target_os = "macos") { + "/Users/" + } else if cfg!(target_os = "linux") { + "/home/" + } else { + "" // Windows native prefix handled in section 5 + }; + + if !native_prefix.is_empty() && s.starts_with(native_prefix) { + // Already matches our home dir — nothing to fix + if s.starts_with(&*home_str) { + return None; + } + + let after_prefix = &s[native_prefix.len()..]; + if let Some(slash_idx) = after_prefix.find('/') { + let placeholder = &after_prefix[..slash_idx]; + let rest = &after_prefix[slash_idx + 1..]; + + let is_template = + (placeholder.starts_with('{') && placeholder.ends_with('}')) + || (placeholder.starts_with('<') && placeholder.ends_with('>')) + || (placeholder.starts_with('[') && placeholder.ends_with(']')); + + if is_template { + return Some(home.join(rest).to_string_lossy().into_owned()); + } + + // Common LLM placeholder words (not real usernames) + let placeholder_lower = placeholder.to_ascii_lowercase(); + let known_placeholders = ["user", "username", "your_name", "me"]; + if known_placeholders.contains(&placeholder_lower.as_str()) { + return 
Some(home.join(rest).to_string_lossy().into_owned()); + } + } + } + + // ── 5. Windows C:\Users\{placeholder}\... ─────────────────────────────── + let win_prefix = "C:\\Users\\"; + let win_prefix_fwd = "C:/Users/"; // model may use forward slashes on Windows too + for prefix in &[win_prefix, win_prefix_fwd] { + if let Some(after_prefix) = s.strip_prefix(prefix) { + // Already matches our home dir — nothing to fix + if s.starts_with(&*home_str) { + return None; + } + + let sep_idx = after_prefix.find(&['/', '\\'][..]); + + if let Some(idx) = sep_idx { + let placeholder = &after_prefix[..idx]; + let rest = &after_prefix[idx + 1..]; + + let is_template = + (placeholder.starts_with('{') && placeholder.ends_with('}')) + || (placeholder.starts_with('<') && placeholder.ends_with('>')) + || (placeholder.starts_with('[') && placeholder.ends_with(']')); + + if is_template { + return Some(home.join(rest).to_string_lossy().into_owned()); + } + + let placeholder_lower = placeholder.to_ascii_lowercase(); + let known_placeholders = ["user", "username", "your_name", "me"]; + if known_placeholders.contains(&placeholder_lower.as_str()) { + return Some(home.join(rest).to_string_lossy().into_owned()); + } + } + } + } + + // Suppress unused-variable warning on platforms where MAIN_SEPARATOR is `/` + let _ = MAIN_SEPARATOR; + + None +} + +/// Extract readable text from an MCP tool result. +/// +/// MCP results follow the format: `{ "content": [{ "type": "text", "text": "..." }] }` +/// The `text` field may itself be a JSON-serialized result object (e.g. from Python +/// pydantic `.model_dump()` + `json.dumps()`), so we attempt to extract a human-readable +/// summary from known fields like "text", "content", "message", or "result". 
+fn extract_mcp_result_text(result: &Option) -> String { + let Some(value) = result else { + return "No result returned.".to_string(); + }; + + // Try standard MCP content format + if let Some(content_arr) = value.get("content").and_then(|c| c.as_array()) { + let texts: Vec<&str> = content_arr + .iter() + .filter_map(|item| item.get("text").and_then(|t| t.as_str())) + .collect(); + if !texts.is_empty() { + let raw = texts.join("\n"); + // The text might be a JSON-serialized tool result (e.g. from json.dumps). + // Try to parse it and extract human-readable content. + return unwrap_tool_result_json(&raw); + } + } + + // Fallback: stringify the entire result + match serde_json::to_string_pretty(value) { + Ok(s) => s, + Err(_) => format!("{value:?}"), + } +} + +/// If `raw` is a JSON object with known text fields, extract and format them +/// for human readability. Otherwise return the original string unchanged. +/// +/// This handles the case where Python MCP servers serialize their result model +/// via `json.dumps(result.model_dump())`, producing strings like: +/// `{"text": "extracted text...", "confidence": 0.9, "engine": "lfm_vision"}` +fn unwrap_tool_result_json(raw: &str) -> String { + let Ok(parsed) = serde_json::from_str::(raw) else { + return raw.to_string(); // Not JSON, return as-is + }; + + let obj = match parsed.as_object() { + Some(o) => o, + None => return raw.to_string(), // JSON but not an object + }; + + // Look for a primary text field in priority order + for key in &["text", "content", "message", "result", "output"] { + if let Some(val) = obj.get(*key).and_then(|v| v.as_str()) { + if !val.is_empty() { + // Build a summary with the primary text and any useful metadata + let mut parts = vec![val.to_string()]; + for meta_key in &["engine", "confidence", "language", "page_count"] { + if let Some(meta_val) = obj.get(*meta_key) { + let display = match meta_val { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Number(n) => 
n.to_string(), + serde_json::Value::Bool(b) => b.to_string(), + other => other.to_string(), + }; + parts.push(format!("[{meta_key}: {display}]")); + } + } + return parts.join("\n"); + } + } + } + + // JSON object but no recognized text field — return the formatted JSON + raw.to_string() +} + +/// Format bytes into human-readable size. +fn format_file_size(bytes: u64) -> String { + if bytes < 1024 { + format!("({bytes} B)") + } else if bytes < 1024 * 1024 { + format!("({:.1} KB)", bytes as f64 / 1024.0) + } else { + format!("({:.1} MB)", bytes as f64 / (1024.0 * 1024.0)) + } +} + +/// Emit context budget to the frontend. +fn emit_context_budget( + app_handle: &tauri::AppHandle, + mgr: &ConversationManager, + session_id: &str, +) { + use tauri::Emitter; + if let Ok(budget) = mgr.get_budget(session_id) { + let _ = app_handle.emit( + "context-budget", + serde_json::json!({ + "total": budget.total, + "systemPrompt": budget.system_prompt, + "toolDefinitions": budget.tool_definitions, + "conversationHistory": budget.conversation_history, + "outputReservation": budget.output_reservation, + "remaining": budget.remaining, + }), + ); + } +} + +// ─── Commands ─────────────────────────────────────────────────────────────── + +/// Start or resume a chat session. +/// +/// On first launch, creates a new session. On subsequent app opens, +/// returns the most recent session that has user messages. +/// If explicitly called with `force_new = true`, always creates a new session. +#[tauri::command] +pub async fn start_session( + force_new: Option, + state: tauri::State<'_, Mutex>, + mcp_state: tauri::State<'_, TokioMutex>, +) -> Result { + // Phase 1: Check for resumable sessions (lock ConversationManager, then drop). + // std::sync::MutexGuard is !Send, so it MUST be dropped before any .await. 
+ { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + + if force_new != Some(true) { + if let Ok(sessions) = mgr.db().list_sessions() { + for session in &sessions { + if let Ok(count) = mgr.db().message_count(&session.id) { + if count > 1 { + tracing::info!( + session_id = %session.id, + message_count = count, + "resuming existing session" + ); + return Ok(SessionInfo { + session_id: session.id.clone(), + resumed: true, + }); + } + } + } + } + } + } // mgr lock dropped here — safe to .await below + + // Phase 2: Build dynamic system prompt from MCP registry (async lock). + // Check if two-pass mode should be noted in the system prompt. + let system_prompt = { + let mcp = mcp_state.lock().await; + let cwd = std::env::current_dir().unwrap_or_default(); + let two_pass_active = if let Ok(cfg_path) = find_config_path(&cwd) { + load_models_config(&cfg_path) + .ok() + .and_then(|cfg| cfg.two_pass_tool_selection) + .unwrap_or(false) + && mcp.registry.len() > TWO_PASS_MIN_TOOLS + } else { + false + }; + build_system_prompt(&mcp.registry, two_pass_active) + }; + + // Phase 3: Create the new session (re-acquire ConversationManager). + let session_id = Uuid::new_v4().to_string(); + + { + let mut mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + + mgr.new_session(&session_id, &system_prompt) + .map_err(|e| format!("Failed to create session: {e}"))?; + + // Set accurate system prompt budget from the actual dynamic prompt + let actual_prompt_tokens = + crate::agent_core::tokens::estimate_system_prompt_tokens(&system_prompt); + mgr.set_system_prompt_budget(actual_prompt_tokens); + + tracing::info!( + session_id = %session_id, + prompt_tokens = actual_prompt_tokens, + "new chat session created with dynamic system prompt" + ); + } + + Ok(SessionInfo { + session_id, + resumed: false, + }) +} + +/// Send a user message and get an assistant response. +/// +/// Implements the agent loop: +/// 1. Persist user message, build history +/// 2. 
Call LLM with tool definitions (built-in + MCP) +/// 3. If model returns tool calls → execute them → feed results back → repeat +/// 4. When model returns text → stream it to frontend +#[tauri::command] +#[allow(clippy::too_many_arguments)] +pub async fn send_message( + session_id: String, + content: String, + working_directory: Option, + app_handle: tauri::AppHandle, + state: tauri::State<'_, Mutex>, + mcp_state: tauri::State<'_, TokioMutex>, + permission_state: tauri::State<'_, TokioMutex>, + pending_confirm: tauri::State<'_, PendingConfirmation>, + sampling_state: tauri::State<'_, TokioMutex>, + in_flight: tauri::State<'_, crate::InFlightRequests>, +) -> Result<(), String> { + use tauri::Emitter; + + // Request deduplication: check if there's already a request in flight for this session + { + let mut in_flight_guard = in_flight.lock().await; + if in_flight_guard.get(&session_id) == Some(&true) { + tracing::warn!(session_id = %session_id, "duplicate request ignored"); + return Ok(()); // Silently ignore duplicate request + } + in_flight_guard.insert(session_id.clone(), true); + } + + // Generate trace ID for this request (for correlation across logs) + let trace_id = uuid::Uuid::new_v4().to_string()[..8].to_string(); + + tracing::info!(trace_id = %trace_id, session_id = %session_id, content_len = content.len(), "starting message processing"); + + // Read sampling config once at the start of this request. + let sampling_cfg = sampling_state.lock().await.clone(); + let tool_turn_sampling = SamplingOverrides { + temperature: Some(sampling_cfg.tool_temperature), + top_p: Some(sampling_cfg.tool_top_p), + }; + let conversational_sampling = SamplingOverrides { + temperature: Some(sampling_cfg.conversational_temperature), + top_p: Some(sampling_cfg.conversational_top_p), + }; + + // 1. 
Persist user message and build conversation history + let mut messages = { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + + mgr.add_user_message(&session_id, &content) + .map_err(|e| format!("Failed to save user message: {e}"))?; + + let evicted = mgr + .evict_if_needed(&session_id) + .map_err(|e| format!("Eviction error: {e}"))?; + if evicted > 0 { + tracing::info!(evicted_tokens = evicted, "evicted old messages"); + } + + mgr.build_chat_messages(&session_id) + .map_err(|e| format!("Failed to build messages: {e}"))? + }; + + // 1b. Inject date context directly into the user message when temporal words + // are detected. Small LLMs (24B) have strong training priors for 2023/2024 + // dates and will ignore system prompt dates. Putting the date IN the user + // message forces the model to see it as part of the query itself. + { + use chrono::{Datelike, Duration}; + + let content_lower = content.to_lowercase(); + let has_temporal = content_lower.contains("today") + || content_lower.contains("tomorrow") + || content_lower.contains("this week") + || content_lower.contains("next week") + || content_lower.contains("yesterday") + || content_lower.contains("calendar") + || content_lower.contains("schedule") + || content_lower.contains("meeting"); + + if has_temporal { + let now = chrono::Local::now(); + let today_str = now.format("%Y-%m-%d").to_string(); + let tomorrow_str = (now + Duration::days(1)).format("%Y-%m-%d").to_string(); + let weekday_num = now.weekday().num_days_from_monday(); + let week_start = (now - Duration::days(weekday_num as i64)) + .format("%Y-%m-%d") + .to_string(); + let week_end = (now + Duration::days((6 - weekday_num) as i64)) + .format("%Y-%m-%d") + .to_string(); + + let date_prefix = format!( + "[Today is {today_str}. Tomorrow is {tomorrow_str}. 
\ + This week is {week_start} to {week_end}.]\n" + ); + + // Find the last user message and prepend the date context + if let Some(last_user_msg) = messages + .iter_mut() + .rev() + .find(|m| m.role == crate::inference::types::Role::User) + { + if let Some(ref mut msg_content) = last_user_msg.content { + let original = msg_content.clone(); + msg_content.clear(); + msg_content.push_str(&date_prefix); + msg_content.push_str(&original); + tracing::info!( + date_injected = %today_str, + "injected date context into user message" + ); + } + } + } + } + + // 1b2. Inject working folder PATH (not file listing) into the user message. + // Same strategy as date injection: small LLMs ignore system prompt paths + // and hallucinate /path/to/... from training data. Putting the folder + // path IN the user message makes it impossible to ignore. + // + // IMPORTANT: Only the path goes here, NOT the file listing. If we put + // files in the user message, the model skips tool calls (it already has + // the answer) and the user never sees the tool trace UI. The full file + // listing stays in the system prompt to guide tool argument selection. + if let Some(ref dir) = working_directory { + if let Some(last_user_msg) = messages + .iter_mut() + .rev() + .find(|m| m.role == crate::inference::types::Role::User) + { + if let Some(ref mut msg_content) = last_user_msg.content { + let folder_prefix = format!( + "[Working folder: {dir}. 
Use tools on files in this folder.]\n" + ); + + // Prepend — but AFTER any date prefix that may already be there + let original = msg_content.clone(); + msg_content.clear(); + if original.starts_with("[Today is") { + // Date prefix exists — insert folder after it + if let Some(newline_pos) = original.find("]\n") { + let after_date = newline_pos + 2; // skip "]\n" + msg_content.push_str(&original[..after_date]); + msg_content.push_str(&folder_prefix); + msg_content.push_str(&original[after_date..]); + } else { + msg_content.push_str(&folder_prefix); + msg_content.push_str(&original); + } + } else { + msg_content.push_str(&folder_prefix); + msg_content.push_str(&original); + } + + tracing::info!( + working_directory = %dir, + "injected working folder into user message" + ); + } + } + } + + // 1c. Inject working directory context + file listing into the system message. + // This is a per-request overlay — not persisted in the DB — so it + // automatically reflects the user's current folder selection. + // Including the actual file listing is a product-level optimization: + // same pattern as Cowork's project indexing — the model sees concrete + // file names without needing to call list_dir first. + const MAX_FOLDER_ENTRIES: usize = 50; + + if let Some(ref dir) = working_directory { + let mut file_count: usize = 0; + if let Some(system_msg) = messages.first_mut() { + if system_msg.role == crate::inference::types::Role::System { + if let Some(ref mut content) = system_msg.content { + // Build the working folder context block with XML tags + let mut folder_ctx = format!( + "\n\ + Use ONLY the file paths listed below. \ + Do NOT invent or guess paths." 
+ ); + + // List directory contents (skip hidden files, cap at 50) + if let Ok(entries) = std::fs::read_dir(dir) { + let mut files: Vec = entries + .filter_map(|e| e.ok()) + .filter(|e| { + !e.file_name() + .to_string_lossy() + .starts_with('.') + }) + .map(|e| { + let full_path = + e.path().to_string_lossy().into_owned(); + if e.path().is_dir() { + format!(" {full_path}/") + } else { + format!(" {full_path}") + } + }) + .collect(); + files.sort(); + + let total = files.len(); + file_count = total; + if total > MAX_FOLDER_ENTRIES { + files.truncate(MAX_FOLDER_ENTRIES); + files.push(format!( + " (and {} more files...)", + total - MAX_FOLDER_ENTRIES + )); + } + if !files.is_empty() { + folder_ctx.push_str("\nFiles:\n"); + folder_ctx.push_str(&files.join("\n")); + } + } + + folder_ctx.push_str("\n\n"); + + // RECENCY REMINDER: shorter repetition block for the end. + // Research shows repeating key instructions at the end improves + // accuracy for smaller models (primacy + recency positions). + let mut folder_reminder = format!( + "\n\n\n\ + working_folder = {dir}\n\ + Files:" + ); + if let Ok(entries) = std::fs::read_dir(dir) { + let mut files: Vec = entries + .filter_map(|e| e.ok()) + .filter(|e| { + !e.file_name() + .to_string_lossy() + .starts_with('.') + }) + .map(|e| { + let full_path = + e.path().to_string_lossy().into_owned(); + format!(" {full_path}") + }) + .collect(); + files.sort(); + if files.len() > MAX_FOLDER_ENTRIES { + files.truncate(MAX_FOLDER_ENTRIES); + } + if !files.is_empty() { + folder_reminder.push('\n'); + folder_reminder.push_str(&files.join("\n")); + } + } + folder_reminder.push_str( + "\nUse ONLY these paths. Do NOT invent paths.\n\ + " + ); + + // SANDWICH PATTERN: insert working folder at TOP and BOTTOM + // of system prompt. The model sees the file paths at the + // strongest positions (primacy + recency). 
+ let original = content.clone(); + content.clear(); + + // TOP: Insert after the first paragraph (identity intro) + if let Some(pos) = original.find("\n\n") { + content.push_str(&original[..pos]); + content.push_str("\n\n"); + content.push_str(&folder_ctx); + content.push_str(&original[pos..]); + } else { + content.push_str(&folder_ctx); + content.push_str("\n\n"); + content.push_str(&original); + } + + // BOTTOM: Append reminder at the very end + content.push_str(&folder_reminder); + } + } + } + tracing::info!( + working_directory = %dir, + file_count, + "injected working folder into system prompt" + ); + } + + // 2. Create inference client and build merged tool list + let cwd = std::env::current_dir().unwrap_or_default(); + let config_path = + find_config_path(&cwd).map_err(|e| format!("Config error: {e}"))?; + let config = + load_models_config(&config_path).map_err(|e| format!("Config error: {e}"))?; + let mut client = InferenceClient::from_config(config.clone()) + .map_err(|e| format!("Inference client error: {e}"))?; + + // 2a. Build tool definitions — either flat (all tools) or category meta-tools. + // Two-pass mode sends ~15 categories on the first turn (~1,500 tokens) + // instead of all ~67 tools (~8,670 tokens). Selected categories are + // expanded to real tools on subsequent turns. 
+ let (mut tool_phase, mut tools) = { + let mcp = mcp_state.lock().await; + let use_two_pass = config.two_pass_tool_selection.unwrap_or(false) + && mcp.registry.len() > TWO_PASS_MIN_TOOLS; + + if use_two_pass { + let cat_registry = CategoryRegistry::build(&mcp.registry); + let cat_tools = build_category_tool_definitions(&cat_registry); + tracing::info!( + category_count = cat_registry.len(), + tool_count_saved = mcp.registry.len(), + "two-pass mode: sending category meta-tools instead of all tools" + ); + ( + ToolSelectionPhase::Categories { cat_registry }, + cat_tools, + ) + } else { + let all_tools = build_all_tool_definitions(&mcp); + (ToolSelectionPhase::Flat, all_tools) + } + }; + + // Measure actual tool definition tokens and update the budget. + // The default TOOL_DEFINITIONS_BUDGET (2000) was calibrated for stub schemas. + // With real JSON Schema from zod-to-json-schema, 15 tools consume 5000-8000+ + // tokens. Using the measured value ensures accurate eviction timing. + { + let tools_json: Vec = tools + .iter() + .filter_map(|t| serde_json::to_value(t).ok()) + .collect(); + let actual_tool_tokens = + crate::agent_core::tokens::estimate_tool_definitions_tokens(&tools_json); + + tracing::info!( + tool_count = tools.len(), + tool_tokens = actual_tool_tokens, + two_pass = matches!(tool_phase, ToolSelectionPhase::Categories { .. }), + "measured actual tool definition tokens" + ); + + let mut mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + mgr.set_tool_definitions_budget(actual_tool_tokens); + } + + // Response text — set by either the orchestrator or the agent loop. + let mut full_response = String::new(); + // Set to true when the orchestrator already persisted the response to DB. + let mut already_persisted = false; + + // 2b. Dual-model orchestrator (ADR-009) — if enabled, try the planner+router + // pipeline before falling into the single-model agent loop. 
+ if let Some(ref orch_config) = config.orchestrator { + if orch_config.enabled { + tracing::info!("orchestrator enabled — attempting dual-model pipeline"); + match crate::agent_core::orchestrator::orchestrate_dual_model( + &session_id, + &content, + &messages, + &config, + orch_config, + &app_handle, + &state, + &mcp_state, + ) + .await + { + Ok(result) if !result.fell_back => { + // Fix F3: Check if orchestrator "succeeded" but no tools were + // actually called. This happens when the router fails to produce + // bracket-format tool calls for every step. + let any_tool_called = result + .step_results + .iter() + .any(|r| r.tool_called.is_some()); + + if !result.all_steps_succeeded && !any_tool_called { + tracing::warn!( + session_id = %session_id, + failed_steps = result.step_results.len(), + "orchestrator: no tools called — falling back to single-model" + ); + // Fall through to single-model agent loop + } else { + tracing::info!( + steps = result.step_results.len(), + all_succeeded = result.all_steps_succeeded, + tools_called = any_tool_called, + "orchestrator completed — skipping single-model loop" + ); + // Set the response so the normal completion path (step 5) + // emits the properly-formatted stream-complete event. + // The orchestrator already persisted the message to the DB. + full_response = result.synthesis; + already_persisted = true; + } + } + Ok(_) => { + tracing::warn!( + "orchestrator fell back — continuing to single-model agent loop" + ); + } + Err(e) => { + tracing::warn!( + error = %e, + "orchestrator error — continuing to single-model agent loop" + ); + } + } + } + } + + // 3. Agent loop: call model → execute tools → repeat + // Variables used by both the agent loop and the force-summary path. 
+ let mut empty_response_count: usize = 0; + let mut tool_call_history: Vec = Vec::new(); + + // ── Turn-level tool call accumulator ────────────────────────────── + // The bracket format emits one tool call per inference round, so a + // multi-tool response spans multiple rounds: + // assistant(toolCalls:[A]) → tool(resultA) → assistant(toolCalls:[B]) → ... + // + // To present this as a single "2 tools executed" block in the UI, we + // accumulate all tool calls under a stable message ID and re-emit the + // growing list on each round. The frontend upserts by ID. + let turn_message_id = chrono::Utc::now().timestamp_millis(); + let mut turn_tool_calls: Vec = Vec::new(); + + // Skip entirely if the orchestrator already produced a response. + if full_response.is_empty() { + + // Track (tool_name, arguments) pairs to detect duplicate calls + let mut tool_call_signatures: Vec<(String, String)> = Vec::new(); + let mut consecutive_error_rounds: usize = 0; + let mut tool_failure_counts: std::collections::HashMap = + std::collections::HashMap::new(); + + for round in 0..MAX_TOOL_ROUNDS { + // ── Token budget gate ────────────────────────────────────────── + // Before each LLM call, check that we have enough remaining + // tokens for a productive round. If not, break early to avoid + // context overflow and degraded model quality. 
+ { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + let budget = mgr + .get_budget(&session_id) + .map_err(|e| format!("Budget error: {e}"))?; + if budget.remaining < MIN_ROUND_TOKEN_BUDGET { + tracing::warn!( + round = round, + remaining = budget.remaining, + threshold = MIN_ROUND_TOKEN_BUDGET, + "token budget exhausted — ending agent loop" + ); + break; + } + } + + tracing::info!( + session_id = %session_id, + round = round, + message_count = messages.len(), + total_content_bytes = messages.iter() + .map(|m| m.content.as_deref().unwrap_or("").len()) + .sum::(), + "=== AGENT LOOP ROUND START ===" + ); + + let mut round_text = String::new(); + let mut tool_calls_detected: Vec = Vec::new(); + + // Measure model inference time (from request to full response parsed). + let inference_start = std::time::Instant::now(); + + match client + .chat_completion_stream(messages.clone(), Some(tools.clone()), Some(tool_turn_sampling)) + .await + { + Ok(stream) => { + futures::pin_mut!(stream); + + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if let Some(token) = &chunk.token { + round_text.push_str(token); + if tool_calls_detected.is_empty() { + let _ = app_handle.emit( + "stream-token", + token.clone(), + ); + } + } + if let Some(ref calls) = chunk.tool_calls { + for tc in calls { + if !tool_calls_detected + .iter() + .any(|existing| existing.id == tc.id) + { + tool_calls_detected.push(tc.clone()); + } + } + } + } + Err(e) => { + tracing::warn!( + round = round, + error = %e, + "stream error in agent loop" + ); + // Don't abort the whole loop — treat as empty + // response and let retry logic handle it + break; + } + } + } + } + Err(e) => { + let fallback = + crate::inference::client::static_fallback_response(); + if let Some(token) = &fallback.token { + full_response = token.clone(); + let _ = app_handle.emit("stream-token", token.clone()); + } + tracing::warn!(error = %e, "all models unavailable, using 
static fallback"); + break; + } + } + + let inference_time_ms = inference_start.elapsed().as_millis() as u64; + + tracing::info!( + session_id = %session_id, + round = round, + round_text_len = round_text.len(), + tool_calls_count = tool_calls_detected.len(), + tool_names = ?tool_calls_detected.iter().map(|tc| tc.name.as_str()).collect::>(), + inference_time_ms = inference_time_ms, + "=== MODEL RESPONSE ===" + ); + + // ── Handle empty response (0 text AND 0 tool calls) ──────── + // This is abnormal — typically caused by timeout, context overflow, + // or model confusion. Retry a limited number of times, then force + // a summary. + if tool_calls_detected.is_empty() && round_text.trim().is_empty() { + empty_response_count += 1; + tracing::warn!( + round = round, + empty_count = empty_response_count, + max_retries = MAX_EMPTY_RETRIES, + "model returned empty response (0 text, 0 tools)" + ); + + if empty_response_count >= MAX_EMPTY_RETRIES { + tracing::warn!("max empty retries reached — forcing summary"); + break; + } + + // Inject a nudge prompt instead of retrying with identical messages. + // Retrying unchanged context causes the same stall. A new user message + // gives the model fresh input to work from. + let nudge = if tool_call_history.is_empty() { + "You returned an empty response. Please answer the user's question \ + or call the appropriate tool now." + .to_string() + } else { + format!( + "You returned an empty response after processing {} tool call(s). \ + If there are more files to process, call the next tool now. 
\ + If the task is complete, provide a final summary of what was done.", + tool_call_history.len() + ) + }; + + messages.push(crate::inference::types::ChatMessage { + role: crate::inference::types::Role::User, + content: Some(nudge), + tool_call_id: None, + tool_calls: None, + }); + + tracing::info!( + round = round, + tools_completed = tool_call_history.len(), + "injected nudge prompt after empty response" + ); + continue; + } + + // Reset empty counter on any successful response + empty_response_count = 0; + + // ── Text response (0 tool calls) — accept and exit ───────── + // When the model returns text without tool calls, it has decided + // the task is complete. Trust the model's judgment and exit. + // + // This is the same pattern as Claude Code: model produces text → + // loop ends. If the user wants more, they say "continue." + // + // Previously, heuristic detectors (is_incomplete_response, + // has_unverified_completion, is_deflection_response) would + // second-guess the model and inject continuation prompts. These + // caused more harm than good — a valid 324-char system info + // summary would trigger "FM-3 deflection" because it contained + // "let me know", causing the model to spiral into unnecessary + // tool calls and produce a worse answer. + // + // Multi-step tasks that need continuation belong in the + // Orchestrator (ADR-009), not in heuristic string-matching. + if tool_calls_detected.is_empty() { + full_response.push_str(&round_text); + break; + } + + // ── Two-pass category expansion ───────────────────────────── + // If we're in Categories phase and the model called category meta-tools, + // expand them to real tools for subsequent rounds. Category "tool calls" + // are NOT executed — they just tell us which capability areas are needed. 
+ if let ToolSelectionPhase::Categories { ref cat_registry } = tool_phase { + let mut selected_categories: Vec = Vec::new(); + let mut direct_tool_calls: Vec = Vec::new(); + + for tc in &tool_calls_detected { + if cat_registry.is_category(&tc.name) { + selected_categories.push(tc.name.clone()); + } else { + // Model called a real tool directly — handle gracefully + direct_tool_calls.push(tc.clone()); + } + } + + if !selected_categories.is_empty() { + // Expand categories to real tool names + let expanded_names = cat_registry.expand_categories(&selected_categories); + + // Build expanded tool definitions from the live registry + let expanded_defs = { + let mcp = mcp_state.lock().await; + let mut defs = builtin_tool_definitions(); + let mcp_tools = mcp.registry.to_openai_tools_filtered(&expanded_names); + for tool_json in mcp_tools { + if let Ok(td) = + serde_json::from_value::(tool_json) + { + defs.push(td); + } + } + defs + }; + + tracing::info!( + session_id = %session_id, + round = round, + categories = ?selected_categories, + expanded_tool_count = expanded_defs.len(), + "two-pass: expanded categories to real tools" + ); + + // Update token budget for the expanded (smaller) tool set + { + let tools_json: Vec = expanded_defs + .iter() + .filter_map(|t| serde_json::to_value(t).ok()) + .collect(); + let expanded_tokens = + crate::agent_core::tokens::estimate_tool_definitions_tokens( + &tools_json, + ); + let mut mgr = + state.lock().map_err(|e| format!("Lock error: {e}"))?; + mgr.set_tool_definitions_budget(expanded_tokens); + tracing::info!( + expanded_tool_tokens = expanded_tokens, + "updated token budget for expanded tools" + ); + } + + // Transition phase and update tools + tool_phase = ToolSelectionPhase::Expanded { + _selected_categories: selected_categories.clone(), + }; + tools = expanded_defs; + + // Inject an assistant message noting the category selection + // (in-memory only — not persisted, same pattern as continuation prompts) + let cat_text = format!( 
+ "Selected capability areas: {}. Now proceeding with specific tools.", + selected_categories.join(", ") + ); + messages.push(crate::inference::types::ChatMessage { + role: crate::inference::types::Role::Assistant, + content: Some(cat_text), + tool_call_id: None, + tool_calls: None, + }); + + // If the model also called real tools directly, process them + if !direct_tool_calls.is_empty() { + tracing::info!( + direct_tool_count = direct_tool_calls.len(), + "two-pass: model also called real tools directly — \ + processing as fallback" + ); + tool_calls_detected = direct_tool_calls; + // Fall through to normal tool execution below + } else { + // Re-prompt with the expanded real tools — no tool execution + // this round. The model will now see the specific tools. + continue; + } + } + // If no categories were selected (model called only real tools), + // fall through to normal execution — graceful degradation. + } + + // ── Tool execution round ────────────────────────────────────── + + if !round_text.is_empty() { + let _ = app_handle.emit("stream-clear", ()); + } + + tracing::info!( + round = round, + tool_count = tool_calls_detected.len(), + "executing tool calls" + ); + + // Persist the assistant's tool-call message + { + let mgr = + state.lock().map_err(|e| format!("Lock error: {e}"))?; + mgr.add_tool_call_message(&session_id, &tool_calls_detected) + .map_err(|e| format!("Failed to save tool call: {e}"))?; + } + + // ── Accumulate tool calls for the turn ───────────────────────── + // Push this round's calls into the turn-level accumulator, then + // emit ALL accumulated calls under the same stable message ID. + // The frontend upserts by ID, so the ToolTrace grows in-place + // rather than spawning a new block each round. 
+ for tc in &tool_calls_detected { + turn_tool_calls.push(serde_json::json!({ + "id": tc.id, + "name": tc.name, + "arguments": tc.arguments, + })); + } + + let _ = app_handle.emit( + "tool-call", + serde_json::json!({ + "id": turn_message_id, + "sessionId": session_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "role": "assistant", + "toolCalls": turn_tool_calls, + "tokenCount": 10, + }), + ); + + // Execute each tool and collect typed outcomes. + let mut round_error_count: usize = 0; + let round_call_count = tool_calls_detected.len(); + let mut round_unknown_tools: Vec<(String, ToolResolution)> = Vec::new(); + + for tc in &tool_calls_detected { + // Auto-inject session_id into audit tool arguments so the model + // doesn't need to guess it. Audit tools expect a session_id param + // that matches the agent_core audit log's session column. + // Always override — the model often hallucinates placeholder values + // like "SESSION_ID_FROM_CURRENT_CONTEXT" or tool_call_ids. + let mut effective_arguments = if tc.name.starts_with("audit.") { + let mut args = tc.arguments.clone(); + if let Some(obj) = args.as_object_mut() { + obj.insert( + "session_id".to_string(), + serde_json::Value::String(session_id.clone()), + ); + } + args + } else { + tc.arguments.clone() + }; + + // ── HITL confirmation check ────────────────────────────── + // Built-in tools (list_directory, read_file) are always read-only. + // MCP tools check the registry's confirmation_required metadata. + // If the user has previously granted permission, skip the dialog. 
+ let is_builtin = tc.name == "list_directory" || tc.name == "read_file"; + let needs_confirmation = !is_builtin && { + let mcp = mcp_state.lock().await; + mcp.registry.requires_confirmation(&tc.name) + }; + + let mut user_confirmed = !needs_confirmation; + + if needs_confirmation { + // Check if permission was previously granted + let already_allowed = { + let perms = permission_state.lock().await; + perms.check(&tc.name) == PermissionStatus::Allowed + }; + + if already_allowed { + user_confirmed = true; + tracing::debug!( + tool = %tc.name, + "skipping confirmation — permission granted" + ); + } else { + // Build and emit a confirmation request + let supports_undo = { + let mcp = mcp_state.lock().await; + mcp.registry.supports_undo(&tc.name) + }; + let preview = generate_preview(&tc.name, &effective_arguments); + let is_destructive = is_destructive_action(&tc.name); + + let request = ConfirmationRequest { + request_id: Uuid::new_v4().to_string(), + tool_name: tc.name.clone(), + arguments: effective_arguments.clone(), + preview, + confirmation_required: true, + undo_supported: supports_undo, + is_destructive, + }; + + tracing::info!( + tool = %tc.name, + request_id = %request.request_id, + is_destructive, + "awaiting user confirmation" + ); + + // Create a oneshot channel for this confirmation + let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); + { + let mut pending = pending_confirm.lock().await; + *pending = Some(resp_tx); + } + + // Emit confirmation-request event to frontend + let _ = app_handle.emit("confirmation-request", &request); + + // Wait for user response (blocks the agent loop) + match resp_rx.await { + Ok(ConfirmationResponse::Rejected) => { + tracing::info!( + tool = %tc.name, + "tool call rejected by user" + ); + // Write rejection to audit log + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let _ = mgr.db().insert_audit_entry( + &session_id, + &tc.name, + &effective_arguments, + None, + 
AuditStatus::RejectedByUser, + false, + 0, + ); + } + + let rejection_text = + format!("Tool '{}' was rejected by the user.", tc.name); + + // Emit rejection result to frontend + let _ = app_handle.emit( + "tool-result", + serde_json::json!({ + "id": chrono::Utc::now().timestamp_millis(), + "sessionId": session_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "role": "tool", + "content": rejection_text, + "toolCallId": tc.id, + "toolResult": { + "success": false, + "result": rejection_text, + "toolCallId": tc.id, + "toolName": tc.name, + }, + "tokenCount": rejection_text.len() / 4, + }), + ); + + // Persist rejection so the model knows + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let result_json = + serde_json::Value::String(rejection_text); + mgr.add_tool_result_message( + &session_id, + &tc.id, + &result_json, + ) + .map_err(|e| { + format!("Failed to save tool result: {e}") + })?; + } + + // Add to conversation history for the LLM + messages.push(crate::inference::types::ChatMessage { + role: crate::inference::types::Role::Tool, + content: Some(format!( + "Tool '{}' was rejected by the user.", + tc.name + )), + tool_call_id: Some(tc.id.clone()), + tool_calls: None, + }); + + round_error_count += 1; + tool_call_history.push(tc.name.clone()); + tool_call_signatures.push(( + tc.name.clone(), + tc.arguments.to_string(), + )); + continue; + } + Ok(ConfirmationResponse::ConfirmedForSession) => { + let mut perms = permission_state.lock().await; + perms.grant(&tc.name, PermissionScope::Session); + user_confirmed = true; + } + Ok(ConfirmationResponse::ConfirmedAlways) => { + let mut perms = permission_state.lock().await; + perms.grant(&tc.name, PermissionScope::Always); + user_confirmed = true; + } + Ok(ConfirmationResponse::Confirmed) => { + user_confirmed = true; + } + Ok(ConfirmationResponse::EditedAndConfirmed { + new_arguments, + }) => { + effective_arguments = new_arguments; + user_confirmed = true; + } + Err(_) => { + 
tracing::warn!( + tool = %tc.name, + "confirmation channel closed — skipping tool" + ); + continue; + } + } + } + } + + // ── Duplicate call interception ────────────────────────── + // If the model is requesting the exact same tool+args as a + // previous call in this conversation, skip execution entirely. + // Instead, feed back a short "you already have this" nudge so + // the model transitions to summarising the results it already + // has. This is cheaper and more robust than executing the + // duplicate and relying on post-hoc detection to break the loop. + let call_sig = (tc.name.clone(), tc.arguments.to_string()); + let is_duplicate = tool_call_signatures.contains(&call_sig); + + if is_duplicate { + tracing::info!( + session_id = %session_id, + round = round, + tool = %tc.name, + "skipping duplicate tool call — returning cached nudge" + ); + + let nudge = format!( + "You already called {} with these exact arguments. \ + The results are in the conversation above. \ + Summarize those results for the user now.", + tc.name + ); + + // Record the signature so the hard-break counter still works + tool_call_history.push(tc.name.clone()); + tool_call_signatures.push(call_sig); + + // Push the nudge as the tool result so the model sees it + messages.push(crate::inference::types::ChatMessage { + role: crate::inference::types::Role::Tool, + content: Some(nudge.clone()), + tool_call_id: Some(tc.id.clone()), + tool_calls: None, + }); + + // Persist so windowed rebuild includes it + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let nudge_json = serde_json::Value::String(nudge); + mgr.add_tool_result_message(&session_id, &tc.id, &nudge_json) + .map_err(|e| format!("Failed to save nudge: {e}"))?; + } + + continue; // skip to next tool call (or next round) + } + + // ── Execute tool with Error Boundary ─────────────────────── + let tool_start = std::time::Instant::now(); + let outcome: ToolExecutionOutcome = { + let mut mcp = 
mcp_state.lock().await; + // Error boundary: wrap in a timeout to prevent hung tool executions + // The try_read_with_timeout helper handles both success and error cases + match tokio::time::timeout( + std::time::Duration::from_secs(120), // 2 min timeout per tool + execute_tool(&tc.name, &effective_arguments, &mut mcp), + ) + .await + { + Ok(result) => result, + Err(_elapsed) => { + tracing::error!( + tool = %tc.name, + timeout_secs = 120, + "tool execution timed out — caught by error boundary" + ); + ToolExecutionOutcome::InfraError { + tool_name: tc.name.clone(), + text: format!( + "Tool '{}' timed out after 120 seconds. Please try again or use a different tool.", + tc.name + ), + } + } + } + }; + let execution_time_ms = tool_start.elapsed().as_millis() as u64; + + let is_error = outcome.is_error(); + let result_text = outcome.model_text().to_string(); + + // ── Audit log write ────────────────────────────────────── + // Record every tool execution in the audit_log table so + // audit.get_tool_log / audit.generate_audit_report can read them. + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let audit_status = if is_error { + AuditStatus::Error + } else { + AuditStatus::Success + }; + let result_val = serde_json::Value::String(result_text.clone()); + if let Err(e) = mgr.db().insert_audit_entry( + &session_id, + &tc.name, + &effective_arguments, + Some(&result_val), + audit_status, + user_confirmed, + execution_time_ms, + ) { + tracing::warn!( + session_id = %session_id, + tool = %tc.name, + error = %e, + "failed to write audit log entry" + ); + } + } + + if is_error { + round_error_count += 1; + *tool_failure_counts.entry(tc.name.clone()).or_default() += 1; + } + + // Collect UnknownTool resolutions for correction hints + if let ToolExecutionOutcome::UnknownTool { + ref tool_name, + ref resolution, + .. 
+ } = outcome + { + round_unknown_tools.push((tool_name.clone(), resolution.clone())); + } + + tool_call_history.push(tc.name.clone()); + tool_call_signatures.push(( + tc.name.clone(), + tc.arguments.to_string(), + )); + + if is_error { + tracing::warn!( + session_id = %session_id, + tool = %tc.name, + tool_call_id = %tc.id, + result_len = result_text.len(), + result_preview = %truncate_utf8(&result_text, 200), + execution_time_ms = execution_time_ms, + tools_completed = tool_call_history.len(), + "tool call FAILED" + ); + } else { + tracing::info!( + session_id = %session_id, + tool = %tc.name, + tool_call_id = %tc.id, + result_len = result_text.len(), + execution_time_ms = execution_time_ms, + tools_completed = tool_call_history.len(), + user_confirmed, + "tool execution complete" + ); + } + + let _ = app_handle.emit( + "tool-result", + serde_json::json!({ + "id": chrono::Utc::now().timestamp_millis(), + "sessionId": session_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "role": "tool", + "content": result_text, + "toolCallId": tc.id, + "toolResult": { + "success": !is_error, + "result": result_text, + "toolCallId": tc.id, + "toolName": tc.name, + "executionTimeMs": execution_time_ms, + "inferenceTimeMs": inference_time_ms, + }, + "tokenCount": result_text.len() / 4, + }), + ); + + // Persist tool result in conversation + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + let result_json = serde_json::Value::String(result_text); + mgr.add_tool_result_message( + &session_id, + &tc.id, + &result_json, + ) + .map_err(|e| format!("Failed to save tool result: {e}"))?; + } + } + + // ── Consecutive error round tracking ───────────────────────── + // If ALL tool calls in this round errored, the model may be stuck + // in a loop calling a non-existent tool (e.g., filesystem.rename_file). + // After MAX_CONSECUTIVE_ERROR_ROUNDS, inject a corrective hint using + // the suggestions already computed by ToolRegistry::resolve(). 
+ if round_error_count > 0 && round_error_count == round_call_count { + consecutive_error_rounds += 1; + tracing::warn!( + session_id = %session_id, + round = round, + consecutive_error_rounds = consecutive_error_rounds, + failed_tools = ?tool_calls_detected.iter().map(|tc| tc.name.as_str()).collect::<Vec<_>>(), + "all tool calls in round failed" + ); + + if consecutive_error_rounds >= MAX_CONSECUTIVE_ERROR_ROUNDS { + let hint = format_correction_hint(&round_unknown_tools); + + tracing::info!( + round = round, + hint_len = hint.len(), + "injecting tool correction hint after repeated failures" + ); + + // Persist the corrective hint as a user message + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + mgr.add_user_message(&session_id, &hint) + .map_err(|e| format!("Failed to save hint: {e}"))?; + } + + // Reset counter so the model gets another chance + consecutive_error_rounds = 0; + } + } else { + // At least one tool succeeded — reset the counter + consecutive_error_rounds = 0; + } + + // ── Per-tool failure circuit breaker ────────────────────────── + // Even when the per-round counter resets (because the model alternates + // between a succeeding tool and a failing one), the per-tool counter + // keeps accumulating. Once a tool hits MAX_SAME_TOOL_FAILURES, remove + // it from the definitions and inject a hard stop hint. + let stuck_tools: Vec<String> = tool_failure_counts + .iter() + .filter(|(_, &count)| count >= MAX_SAME_TOOL_FAILURES) + .map(|(name, _)| name.clone()) + .collect(); + + if !stuck_tools.is_empty() { + let hint = format!( + "STOP: The following tools have each failed {} or more times and have been \ + removed: {}. Do NOT attempt to call them again. 
Respond to the user with \ + what you know so far, or try a completely different approach.", + MAX_SAME_TOOL_FAILURES, + stuck_tools.join(", ") + ); + + tracing::warn!( + session_id = %session_id, + round = round, + stuck_tools = ?stuck_tools, + "per-tool failure limit reached — removing stuck tools from definitions" + ); + + // Remove stuck tools from the active tool definitions + tools.retain(|t| !stuck_tools.contains(&t.function.name)); + + // Clear the counters for removed tools so we don't re-trigger + for name in &stuck_tools { + tool_failure_counts.remove(name); + } + + // Inject the hint as a user message + { + let mgr = state + .lock() + .map_err(|e| format!("Lock error: {e}"))?; + mgr.add_user_message(&session_id, &hint) + .map_err(|e| format!("Failed to save stuck-tool hint: {e}"))?; + } + } + + // ── Duplicate tool call detection ───────────────────────────── + // If the model is calling the same tool with the same arguments + // repeatedly (e.g., list_directory("~/Downloads") 3× in a row), + // the results won't change. Break to prevent wasting rounds. + let dup_count = consecutive_duplicate_count(&tool_call_signatures); + if dup_count >= MAX_DUPLICATE_TOOL_CALLS { + tracing::warn!( + session_id = %session_id, + round = round, + duplicate_count = dup_count, + tool = %tool_call_signatures.last().map(|(n, _)| n.as_str()).unwrap_or("?"), + "duplicate tool call detected — model is stuck, breaking loop" + ); + break; + } + + // ── Mid-loop eviction ─────────────────────────────────────── + // After persisting tool results, check if context window needs + // eviction before the next round. This prevents unbounded growth + // during long multi-step workflows. 
+ { + let mgr = + state.lock().map_err(|e| format!("Lock error: {e}"))?; + let evicted = mgr + .evict_if_needed(&session_id) + .map_err(|e| format!("Eviction error: {e}"))?; + if evicted > 0 { + tracing::info!( + round = round, + evicted_tokens = evicted, + "mid-loop eviction" + ); + } + } + + // Rebuild messages (windowed — compress old tool results to save tokens) + messages = { + let mgr = + state.lock().map_err(|e| format!("Lock error: {e}"))?; + mgr.build_windowed_chat_messages(&session_id, 4) + .map_err(|e| format!("Failed to build messages: {e}"))? + }; + } + + } // end if full_response.is_empty() (skip agent loop when orchestrator succeeded) + + // 4. If the agent loop finished without generating text, force a + // summary. This can happen when: + // - All rounds were used on tool calls (normal for large batches) + // - Model returned empty responses (timeout / context overflow) + // - Streaming errors caused early exit + // + // Strategy: inject a short, explicit "summarize now" user message + // and call the model WITHOUT tools, so it MUST produce text. + if full_response.is_empty() { + tracing::info!( + session_id = %session_id, + rounds_used = empty_response_count, + tool_calls_total = tool_call_history.len(), + "forcing summary — injecting summarize prompt" + ); + + // Inject a constrained summary instruction that prevents confabulation. + // The model MUST only report results it actually received from tools. + let summary_instruction = crate::inference::types::ChatMessage { + role: crate::inference::types::Role::User, + content: Some( + "Based on the tool results above, provide a concise summary.\n\ + CRITICAL RULES:\n\ + - ONLY report results you actually received from tool calls above.\n\ + - If a file was not processed, say 'not processed' — do NOT guess or invent results.\n\ + - If no tool results are visible, say 'I was unable to complete the task.'\n\ + Do NOT call any more tools." 
+ .to_string(), + ), + tool_call_id: None, + tool_calls: None, + }; + messages.push(summary_instruction); + + match client + .chat_completion_stream(messages, None, Some(conversational_sampling)) // No tools → model MUST produce text + .await + { + Ok(stream) => { + futures::pin_mut!(stream); + while let Some(chunk_result) = stream.next().await { + if let Ok(chunk) = chunk_result { + if let Some(token) = &chunk.token { + full_response.push_str(token); + let _ = app_handle.emit("stream-token", token.clone()); + } + } + } + } + Err(e) => { + tracing::warn!(error = %e, "summary call failed"); + } + } + + // If even the summary call returned nothing, use a static fallback + if full_response.is_empty() { + tracing::warn!("summary call also returned empty — using static fallback text"); + full_response = "I processed the requested files using the tools above. \ + You can see the individual results in the tool trace. \ + Please ask a follow-up question if you'd like me to continue." + .to_string(); + let _ = app_handle.emit("stream-token", full_response.clone()); + } + } + + // 5. Persist final assistant text response + // (skip if the orchestrator already persisted it) + { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + + if !full_response.is_empty() && !already_persisted { + mgr.add_assistant_message(&session_id, &full_response) + .map_err(|e| format!("Failed to save assistant message: {e}"))?; + } + + emit_context_budget(&app_handle, &mgr, &session_id); + } + + // 5. 
Emit the complete message + let message = serde_json::json!({ + "id": chrono::Utc::now().timestamp_millis(), + "sessionId": session_id, + "timestamp": chrono::Utc::now().to_rfc3339(), + "role": "assistant", + "content": full_response, + "tokenCount": full_response.len() / 4, + }); + + let _ = app_handle.emit("stream-complete", message); + + // Release in-flight lock + { + let mut in_flight_guard = in_flight.lock().await; + in_flight_guard.insert(session_id.clone(), false); + } + + Ok(()) +} + +/// Respond to a confirmation request from the agent loop. +/// +/// The frontend calls this when the user clicks Confirm/Cancel on a +/// confirmation dialog. The response is forwarded to the agent loop +/// via the pending oneshot channel. +#[tauri::command] +pub async fn respond_to_confirmation( + request_id: String, + response: serde_json::Value, + pending: tauri::State<'_, PendingConfirmation>, +) -> Result<(), String> { + tracing::info!( + request_id = %request_id, + response = %response, + "confirmation response received" + ); + + let parsed: ConfirmationResponse = serde_json::from_value(response) + .map_err(|e| format!("Invalid confirmation response: {e}"))?; + + let mut lock = pending.lock().await; + if let Some(tx) = lock.take() { + // oneshot::Sender::send returns Err if receiver was dropped + tx.send(parsed).map_err(|_| { + "Confirmation channel closed — agent loop may have timed out".to_string() + })?; + } else { + tracing::warn!( + request_id = %request_id, + "no pending confirmation — response ignored" + ); + } + + Ok(()) +} + +// ─── Tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::agent_core::response_analysis::is_incomplete_response; + + #[test] + fn test_unwrap_tool_result_json_extracts_text() { + // Simulates what Python MCP servers send: json.dumps(result.model_dump()) + let raw = r#"{"text": "LocalCowork OCR Test\nInvoice #12345", "confidence": 0.9, "engine": "lfm_vision"}"#; 
+ let result = unwrap_tool_result_json(raw); + assert!(result.starts_with("LocalCowork OCR Test")); + assert!(result.contains("[engine: lfm_vision]")); + assert!(result.contains("[confidence: 0.9]")); + } + + #[test] + fn test_unwrap_tool_result_json_plain_text() { + let raw = "Just a plain text result"; + let result = unwrap_tool_result_json(raw); + assert_eq!(result, "Just a plain text result"); + } + + #[test] + fn test_unwrap_tool_result_json_no_text_field() { + let raw = r#"{"headers": ["col1", "col2"], "rows": [["a", "b"]]}"#; + let result = unwrap_tool_result_json(raw); + // No recognized text field, should return raw JSON + assert_eq!(result, raw); + } + + #[test] + fn test_extract_mcp_result_text_with_content_array() { + let value = serde_json::json!({ + "content": [{"type": "text", "text": "{\"text\": \"hello\", \"engine\": \"tesseract\"}"}] + }); + let result = extract_mcp_result_text(&Some(value)); + assert!(result.starts_with("hello")); + assert!(result.contains("[engine: tesseract]")); + } + + #[test] + fn test_extract_mcp_result_text_none() { + let result = extract_mcp_result_text(&None); + assert_eq!(result, "No result returned."); + } + + #[test] + fn test_truncate_tool_result_short() { + let result = truncate_tool_result("short result", "test_tool"); + assert_eq!(result, "short result"); + } + + #[test] + fn test_truncate_tool_result_long() { + let long = "x".repeat(10_000); + let result = truncate_tool_result(&long, "test_tool"); + assert!(result.len() < long.len()); + assert!(result.contains("[... truncated: showing first 6000 of 10000 chars]")); + } + + #[test] + fn test_is_incomplete_response_remaining() { + assert!(is_incomplete_response( + "I've processed 3 files. There are 4 remaining screenshots to rename." + )); + } + + #[test] + fn test_is_incomplete_response_next_file() { + assert!(is_incomplete_response( + "Renamed screenshot 1. Moving on to the next file." 
+ )); + } + + #[test] + fn test_is_incomplete_response_complete() { + assert!(!is_incomplete_response( + "All screenshots have been renamed successfully." + )); + } + + #[test] + fn test_is_incomplete_response_no_signals() { + // No incomplete or complete signals — defaults to false (task done) + assert!(!is_incomplete_response( + "Here is the result of your request." + )); + } + + /// Helper to create an McpClient with registered tools for testing. + fn mcp_client_with_tools(tools: Vec<(&str, &str)>) -> McpClient { + use crate::mcp_client::types::{McpServersConfig, McpToolDefinition}; + + let config = McpServersConfig { + servers: std::collections::HashMap::new(), + }; + let mut client = McpClient::new(config, None); + + // Group tools by server name and register them + let mut server_tools: std::collections::HashMap<&str, Vec<McpToolDefinition>> = + std::collections::HashMap::new(); + for (server, tool) in tools { + server_tools + .entry(server) + .or_default() + .push(McpToolDefinition { + name: tool.to_string(), + description: format!("Test tool: {tool}"), + params_schema: serde_json::json!({"type": "object", "properties": {}}), + returns_schema: serde_json::json!({}), + confirmation_required: false, + undo_supported: false, + }); + } + for (server, defs) in server_tools { + client.registry.register_server_tools(server, defs); + } + + client + } + + #[test] + fn test_resolve_exact_match() { + let client = mcp_client_with_tools(vec![("filesystem", "move_file")]); + let resolution = client.registry.resolve("filesystem.move_file", 0.5); + assert!(matches!(resolution, ToolResolution::Exact(_))); + assert_eq!(resolution.resolved_name(), Some("filesystem.move_file")); + } + + #[test] + fn test_resolve_unprefixed() { + let client = mcp_client_with_tools(vec![ + ("filesystem", "move_file"), + ("filesystem", "copy_file"), + ("ocr", "extract_text_from_image"), + ]); + let resolution = client.registry.resolve("move_file", 0.5); + assert!(matches!(resolution, ToolResolution::Unprefixed { .. 
})); + assert_eq!(resolution.resolved_name(), Some("filesystem.move_file")); + } + + #[test] + fn test_resolve_unknown_unprefixed() { + let client = mcp_client_with_tools(vec![("filesystem", "move_file")]); + let resolution = client.registry.resolve("nonexistent_tool", 0.5); + assert!(matches!(resolution, ToolResolution::NotFound { .. })); + assert_eq!(resolution.resolved_name(), None); + } + + #[test] + fn test_resolve_wrong_server_prefix() { + let client = mcp_client_with_tools(vec![("filesystem", "move_file")]); + // "wrong_server" doesn't exist — no same-server tools to match against + let resolution = client.registry.resolve("wrong_server.move_file", 0.5); + assert!(matches!(resolution, ToolResolution::NotFound { .. })); + } + + #[test] + fn test_resolve_ambiguous_unprefixed() { + let client = mcp_client_with_tools(vec![ + ("ocr", "process"), + ("document", "process"), + ]); + // Ambiguous — two servers have "process" + let resolution = client.registry.resolve("process", 0.5); + assert!(matches!(resolution, ToolResolution::NotFound { .. 
})); + } + + #[test] + fn test_build_system_prompt_includes_server_names() { + use crate::mcp_client::registry::ToolRegistry; + use crate::mcp_client::types::McpToolDefinition; + + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "filesystem", + vec![McpToolDefinition { + name: "list_dir".to_string(), + description: "List directory".to_string(), + params_schema: serde_json::json!({"type": "object"}), + returns_schema: serde_json::json!({}), + confirmation_required: false, + undo_supported: false, + }], + ); + registry.register_server_tools( + "email", + vec![McpToolDefinition { + name: "send_draft".to_string(), + description: "Send draft".to_string(), + params_schema: serde_json::json!({"type": "object"}), + returns_schema: serde_json::json!({}), + confirmation_required: true, + undo_supported: false, + }], + ); + + let prompt = build_system_prompt(&registry, false); + assert!(prompt.contains("filesystem (1)")); + assert!(prompt.contains("email (1)")); + assert!(prompt.contains("2 tools across 2 servers")); + assert!(prompt.contains("LocalCowork")); + // Should include XML-tagged rules section + assert!(prompt.contains("")); + assert!(prompt.contains("fully-qualified tool names")); + } + + #[test] + fn test_build_system_prompt_empty_registry() { + use crate::mcp_client::registry::ToolRegistry; + + let registry = ToolRegistry::new(); + let prompt = build_system_prompt(&registry, false); + assert!(prompt.contains("No MCP tools currently available")); + // Should still include the rules and examples sections + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(prompt.contains("filesystem.list_dir")); + } + + #[test] + fn test_build_system_prompt_with_two_pass() { + use crate::mcp_client::registry::ToolRegistry; + + let registry = ToolRegistry::new(); + let prompt = build_system_prompt(&registry, true); + assert!(prompt.contains("category-level tools")); + assert!(prompt.contains("file_browse")); + } + + #[test] + fn 
test_build_system_prompt_has_precomputed_dates() { + use crate::mcp_client::registry::ToolRegistry; + + let registry = ToolRegistry::new(); + let prompt = build_system_prompt(&registry, false); + // Must contain the block with pre-computed dates + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(prompt.contains("today =")); + assert!(prompt.contains("tomorrow =")); + assert!(prompt.contains("this_week =")); + assert!(prompt.contains("NEVER ask the user for a date")); + } + + // ── has_unverified_completion tests ────────────────────────────── + + #[test] + fn test_unverified_completion_claims_done_no_mutable_calls() { + // Model says "all files renamed" but history has no move_file + let history = vec![ + "filesystem.list_dir".to_string(), + "ocr.extract_text_from_image".to_string(), + "ocr.extract_text_from_image".to_string(), + ]; + assert!(has_unverified_completion( + "I've successfully renamed all 9 files.", + &history, + )); + } + + #[test] + fn test_unverified_completion_claims_done_with_mutable_calls() { + // Model says "all files renamed" AND move_file is in history — genuine + let history = vec![ + "filesystem.list_dir".to_string(), + "ocr.extract_text_from_image".to_string(), + "filesystem.move_file".to_string(), + ]; + assert!(!has_unverified_completion( + "I've successfully renamed all 9 files.", + &history, + )); + } + + #[test] + fn test_unverified_completion_no_completion_claim() { + // Model doesn't claim completion — no confabulation check needed + let history = vec!["filesystem.list_dir".to_string()]; + assert!(!has_unverified_completion( + "Here are the files I found on your desktop.", + &history, + )); + } + + #[test] + fn test_unverified_completion_empty_history() { + // Empty tool history + completion claim = confabulation + assert!(has_unverified_completion( + "All done! 
Finished processing everything.", + &[], + )); + } + + #[test] + fn test_unverified_completion_write_file_counts_as_mutable() { + // write_file is a mutable operation — should count + let history = vec!["filesystem.write_file".to_string()]; + assert!(!has_unverified_completion( + "Task complete. All files processed.", + &history, + )); + } + + #[test] + fn test_unverified_completion_create_task_counts_as_mutable() { + // create_task should now be recognized as mutable + let history = vec![ + "filesystem.read_file".to_string(), + "task.create_task".to_string(), + ]; + assert!(!has_unverified_completion( + "Successfully created the task.", + &history, + )); + } + + #[test] + fn test_unverified_completion_read_only_generic_done() { + // Read-only task (list files) saying "all done" — NOT confabulation. + // The model legitimately completed a read-only request. + let history = vec![ + "filesystem.list_dir".to_string(), + ]; + assert!(!has_unverified_completion( + "All done! Here are the files in your Downloads folder.", + &history, + )); + } + + #[test] + fn test_unverified_completion_read_only_claims_rename() { + // Read-only tools but claims "renamed" → confabulation + let history = vec![ + "filesystem.list_dir".to_string(), + "ocr.extract_text_from_image".to_string(), + ]; + assert!(has_unverified_completion( + "I've successfully renamed all 9 files.", + &history, + )); + } + + #[test] + fn test_unverified_completion_scan_then_complete() { + // Security scan (read-only) followed by "completed" → not confabulation + // (it's a genuinely complete read-only scan task) + let history = vec![ + "security.scan_for_pii".to_string(), + "security.scan_for_secrets".to_string(), + ]; + assert!(!has_unverified_completion( + "All done! 
Here's what I found in the scan.", + &history, + )); + } + + // ── consecutive_duplicate_count tests ──────────────────────────── + + #[test] + fn test_duplicate_count_empty() { + let history: Vec<(String, String)> = vec![]; + assert_eq!(consecutive_duplicate_count(&history), 0); + } + + #[test] + fn test_duplicate_count_single() { + let history = vec![("list_dir".into(), r#"{"path":"~/Downloads"}"#.into())]; + assert_eq!(consecutive_duplicate_count(&history), 1); + } + + #[test] + fn test_duplicate_count_three_identical() { + let history = vec![ + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ]; + assert_eq!(consecutive_duplicate_count(&history), 3); + } + + #[test] + fn test_duplicate_count_different_args() { + let history = vec![ + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ("list_dir".into(), r#"{"path":"~/Documents"}"#.into()), + ]; + assert_eq!(consecutive_duplicate_count(&history), 1); + } + + #[test] + fn test_duplicate_count_interrupted_by_different_tool() { + let history = vec![ + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ("read_file".into(), r#"{"path":"file.txt"}"#.into()), + ("list_dir".into(), r#"{"path":"~/Downloads"}"#.into()), + ]; + // Only the last consecutive run counts (just 1) + assert_eq!(consecutive_duplicate_count(&history), 1); + } + + // ── expand_tilde_in_arguments tests ────────────────────────────── + + #[test] + fn test_expand_tilde_simple_path() { + let args = serde_json::json!({"path": "~/Documents/file.txt"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert!(!path.starts_with('~'), "tilde should be expanded: {path}"); + assert!(path.ends_with("/Documents/file.txt")); + } + + #[test] + fn test_expand_tilde_bare() { + let args = serde_json::json!({"path": "~"}); + let expanded = expand_tilde_in_arguments(&args); 
+ let path = expanded["path"].as_str().unwrap(); + assert!(!path.starts_with('~')); + assert!(!path.is_empty()); + } + + #[test] + fn test_expand_tilde_leaves_absolute_paths() { + let args = serde_json::json!({"path": "/Users/chintan/Documents/file.txt"}); + let expanded = expand_tilde_in_arguments(&args); + assert_eq!( + expanded["path"].as_str().unwrap(), + "/Users/chintan/Documents/file.txt" + ); + } + + #[test] + fn test_expand_tilde_leaves_other_user() { + // ~other_user/... should NOT be expanded + let args = serde_json::json!({"path": "~other_user/file.txt"}); + let expanded = expand_tilde_in_arguments(&args); + assert_eq!(expanded["path"].as_str().unwrap(), "~other_user/file.txt"); + } + + #[test] + fn test_expand_tilde_nested_object() { + let args = serde_json::json!({ + "source": "~/Desktop/a.png", + "destination": "/tmp/b.png", + "options": {"backup": "~/backup/"} + }); + let expanded = expand_tilde_in_arguments(&args); + assert!(!expanded["source"].as_str().unwrap().starts_with('~')); + assert_eq!(expanded["destination"].as_str().unwrap(), "/tmp/b.png"); + assert!(!expanded["options"]["backup"].as_str().unwrap().starts_with('~')); + } + + #[test] + fn test_expand_tilde_non_string_values() { + let args = serde_json::json!({"count": 42, "flag": true, "path": "~/file"}); + let expanded = expand_tilde_in_arguments(&args); + assert_eq!(expanded["count"], 42); + assert_eq!(expanded["flag"], true); + assert!(!expanded["path"].as_str().unwrap().starts_with('~')); + } + + #[test] + fn test_expand_tilde_array_values() { + let args = serde_json::json!({"paths": ["~/a.txt", "/b.txt", "~/c.txt"]}); + let expanded = expand_tilde_in_arguments(&args); + let paths = expanded["paths"].as_array().unwrap(); + assert!(!paths[0].as_str().unwrap().starts_with('~')); + assert_eq!(paths[1].as_str().unwrap(), "/b.txt"); + assert!(!paths[2].as_str().unwrap().starts_with('~')); + } + + // ── fix_path_string: cross-platform path correction tests ─────── + + /// Helper: build the 
expected path using Path::join (platform-correct). + fn expected_home_join(suffix: &str) -> String { + dirs::home_dir() + .unwrap() + .join(suffix) + .to_string_lossy() + .into_owned() + } + + #[cfg(target_os = "macos")] + #[test] + fn test_fix_foreign_os_prefix() { + // On macOS, /home/ is foreign — any username is hallucinated + let args = serde_json::json!({"path": "/home/chintan/Downloads"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert_eq!(path, expected_home_join("Downloads")); + } + + #[cfg(target_os = "macos")] + #[test] + fn test_native_prefix_real_username_not_rewritten() { + // On macOS, /Users//... should NOT be rewritten + // (could be a legitimate multi-user path) + let args = serde_json::json!({"path": "/Users/admin/shared/notes.txt"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert_eq!(path, "/Users/admin/shared/notes.txt", "Real username should not be rewritten"); + } + + #[cfg(target_os = "macos")] + #[test] + fn test_native_prefix_template_user() { + // /Users/{user}/Downloads on macOS — template on native prefix + let args = serde_json::json!({"path": "/Users/{user}/Downloads"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert!( + !path.contains("{user}"), + "Placeholder should be replaced: {path}" + ); + assert_eq!(path, expected_home_join("Downloads")); + } + + #[cfg(target_os = "macos")] + #[test] + fn test_native_prefix_template_username() { + // /Users/{username}/Documents on macOS + let args = serde_json::json!({"path": "/Users/{username}/Documents"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert!( + !path.contains("{username}"), + "Placeholder should be replaced: {path}" + ); + assert_eq!(path, expected_home_join("Documents")); + } + + #[cfg(target_os = "macos")] + #[test] + fn 
test_native_prefix_angle_bracket() { + // /Users//Downloads on macOS + let args = serde_json::json!({"path": "/Users//Downloads"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert!( + !path.contains(""), + "Angle-bracket placeholder should be replaced: {path}" + ); + assert_eq!(path, expected_home_join("Downloads")); + } + + #[cfg(target_os = "macos")] + #[test] + fn test_native_prefix_square_bracket() { + // /Users/[USER]/Documents/Projects on macOS + let args = serde_json::json!({"path": "/Users/[USER]/Documents/Projects"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert!( + !path.contains("[USER]"), + "Square-bracket placeholder should be replaced: {path}" + ); + assert_eq!(path, expected_home_join("Documents/Projects")); + } + + #[cfg(target_os = "macos")] + #[test] + fn test_native_prefix_known_placeholder_word() { + // /Users/user/Documents on macOS — "user" is a known placeholder + let args = serde_json::json!({"path": "/Users/user/Documents"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert_eq!(path, expected_home_join("Documents")); + } + + #[test] + fn test_fix_bare_relative_path() { + // Model generates just "Projects" instead of an absolute path + let args = serde_json::json!({"path": "Projects"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert_eq!(path, expected_home_join("Projects")); + } + + #[test] + fn test_fix_bare_downloads_relative_path() { + // Model generates "Downloads" + let args = serde_json::json!({"path": "Downloads"}); + let expanded = expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert_eq!(path, expected_home_join("Downloads")); + } + + #[test] + fn test_tilde_expansion() { + let args = serde_json::json!({"path": "~/Documents/file.txt"}); + let expanded = 
expand_tilde_in_arguments(&args); + let path = expanded["path"].as_str().unwrap(); + assert_eq!(path, expected_home_join("Documents/file.txt")); + } + + #[test] + fn test_no_fix_for_correct_path() { + // Already-correct absolute path should not be modified + let home = dirs::home_dir().unwrap(); + let correct = home.join("Documents").join("test.txt"); + let correct_str = correct.to_string_lossy().into_owned(); + let args = serde_json::json!({"path": correct_str}); + let expanded = expand_tilde_in_arguments(&args); + assert_eq!(expanded["path"].as_str().unwrap(), correct_str); + } + + #[test] + fn test_no_fix_for_urls() { + // URL-like strings should not be modified + let args = serde_json::json!({"url": "https://example.com/Documents/file"}); + let expanded = expand_tilde_in_arguments(&args); + assert_eq!( + expanded["url"].as_str().unwrap(), + "https://example.com/Documents/file" + ); + } + + // ─── Tool Result Compression Tests (PR #59) ───────────────────────────── + + #[test] + fn test_truncate_tool_result_small() { + let result = "short result"; + let truncated = truncate_tool_result(result, "test_tool"); + assert_eq!(truncated, result); + } + + #[test] + fn test_truncate_tool_result_large_no_compression() { + // Large result that doesn't match compression patterns + let result = "x".repeat(10000); + let truncated = truncate_tool_result(&result, "unknown_tool"); + assert!(truncated.contains("truncated")); + assert!(truncated.len() < result.len()); + } + + #[test] + fn test_compress_directory_listing() { + let listing = r#"📁 src/ +📁 tests/ +📄 file1.txt (1 KB) +📄 file2.txt (2 KB) +📁 subdir/ +📄 long_file_name.txt (100 KB)"#; + let compressed = compress_directory_listing(listing); + assert!(compressed.is_some()); + let summary = compressed.unwrap(); + assert!(summary.contains("Total:")); + assert!(summary.contains("Directories:")); + assert!(summary.contains("Files:")); + } + + #[test] + fn test_compress_directory_listing_empty() { + let compressed = 
compress_directory_listing(""); + assert!(compressed.is_none()); + } + + #[test] + fn test_compress_search_results_with_count() { + let result = "Scanning files...\nFound 42 matches\n\n/path/to/file1.txt: line 10\n/path/to/file2.txt: line 20"; + let compressed = compress_search_results(result); + assert!(compressed.is_some()); + let summary = compressed.unwrap(); + assert!(summary.contains("42")); + assert!(summary.contains("Key findings")); + } + + #[test] + fn test_compress_search_results_no_count() { + // Search result without clear count pattern + let result = "File A\nFile B\nFile C\nFile D\nFile E"; + let compressed = compress_search_results(result); + assert!(compressed.is_some()); + } + + #[test] + fn test_compress_json_result_array() { + let json = r#"[ + {"name": "item1", "size": 100}, + {"name": "item2", "size": 200}, + {"name": "item3", "size": 300} + ]"#; + let compressed = compress_json_result(json); + assert!(compressed.is_some()); + let summary = compressed.unwrap(); + assert!(summary.contains("3 items")); + assert!(summary.contains("item1")); + } + + #[test] + fn test_compress_json_result_object() { + let json = r#"{"name": "test", "content": "hello world", "count": 42}"#; + let compressed = compress_json_result(json); + assert!(compressed.is_some()); + let summary = compressed.unwrap(); + assert!(summary.contains("name:")); + assert!(summary.contains("test")); + } + + #[test] + fn test_compress_json_result_empty_array() { + let json = "[]"; + let compressed = compress_json_result(json); + assert!(compressed.is_some()); + assert!(compressed.unwrap().contains("empty")); + } + + #[test] + fn test_compress_tool_result_skips_small() { + // Small results should not be compressed + let result = "short"; + let compressed = compress_tool_result(result, "list_dir"); + assert!(compressed.is_none()); + } + + #[test] + fn test_compress_tool_result_non_compressible_tool() { + // Non-compressible tools should return None even for large results + let result = 
"x".repeat(10000); + let compressed = compress_tool_result(&result, "write_file"); + assert!(compressed.is_none()); + } + + #[test] + fn test_compress_tool_result_compressible_tool() { + // Compressible tools should attempt compression (need 3000+ chars) + let result = "📁 folder1\n📁 folder2\n📄 file1.txt (1 KB)\n📄 file2.txt (2 KB)\n".repeat(500); + let compressed = compress_tool_result(&result, "list_dir"); + assert!(compressed.is_some()); + } + + #[test] + fn test_builtin_tool_definitions() { + let tools = builtin_tool_definitions(); + assert!(!tools.is_empty()); + + let names: Vec<String> = tools.iter() + .map(|t| t.function.name.clone()) + .collect(); + + assert!(names.contains(&"list_directory".to_string())); + assert!(names.contains(&"read_file".to_string())); + } + + #[test] + fn test_builtin_tool_definitions_have_descriptions() { + let tools = builtin_tool_definitions(); + for tool in &tools { + assert!(!tool.function.description.is_empty()); + assert!(!tool.function.parameters.is_null()); + } + } +} diff --git a/src-tauri/src/commands/filesystem.rs b/src-tauri/src/commands/filesystem.rs new file mode 100644 index 0000000..bd8071e --- /dev/null +++ b/src-tauri/src/commands/filesystem.rs @@ -0,0 +1,115 @@ +//! Tauri IPC commands for filesystem browsing. +//! +//! These stub commands enable the File Browser UI to be developed +//! and tested independently of the MCP server integration. +//! In the full integration, these will dispatch to the filesystem +//! MCP server via the McpClient. + +use serde::Serialize; + +/// Check if a file is hidden (cross-platform). +/// +/// On Unix: files starting with '.' are hidden by convention. +/// On Windows: files with the `FILE_ATTRIBUTE_HIDDEN` attribute are hidden. 
+fn is_hidden(name: &str, _metadata: &std::fs::Metadata) -> bool { + #[cfg(not(target_os = "windows"))] + { + name.starts_with('.') + } + #[cfg(target_os = "windows")] + { + use std::os::windows::fs::MetadataExt; + const FILE_ATTRIBUTE_HIDDEN: u32 = 0x2; + _metadata.file_attributes() & FILE_ATTRIBUTE_HIDDEN != 0 + } +} + +/// A single file/directory entry. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct FileEntry { + pub name: String, + pub path: String, + pub entry_type: String, + pub size: u64, + pub modified: String, +} + +/// List directory contents. +/// +/// Returns entries sorted: directories first, then files, both alphabetically. +#[tauri::command] +pub fn list_directory(path: String) -> Result, String> { + let dir_path = if path.starts_with('~') { + let home = dirs::home_dir().ok_or("Cannot resolve home directory")?; + home.join(path.strip_prefix("~/").unwrap_or(&path)) + } else { + std::path::PathBuf::from(&path) + }; + + if !dir_path.exists() { + return Err(format!("Directory not found: {path}")); + } + if !dir_path.is_dir() { + return Err(format!("Not a directory: {path}")); + } + + let mut entries = Vec::new(); + let read_dir = std::fs::read_dir(&dir_path) + .map_err(|e| format!("Failed to read directory: {e}"))?; + + for entry_result in read_dir { + let entry = entry_result.map_err(|e| format!("Failed to read entry: {e}"))?; + let metadata = entry.metadata().map_err(|e| format!("Failed to read metadata: {e}"))?; + + let name = entry.file_name().to_string_lossy().to_string(); + + // Skip hidden files + // Unix: dot-prefix convention. Windows: FILE_ATTRIBUTE_HIDDEN attribute. 
+ if is_hidden(&name, &metadata) { + continue; + } + + let entry_type = if metadata.is_dir() { + "dir".to_string() + } else if metadata.file_type().is_symlink() { + "symlink".to_string() + } else { + "file".to_string() + }; + + let size = metadata.len(); + let modified = metadata + .modified() + .map(|t| { + chrono::DateTime::::from(t).to_rfc3339() + }) + .unwrap_or_default(); + + entries.push(FileEntry { + name, + path: entry.path().to_string_lossy().to_string(), + entry_type, + size, + modified, + }); + } + + // Sort: directories first, then files, both alphabetically + entries.sort_by(|a, b| { + let a_is_dir = a.entry_type == "dir"; + let b_is_dir = b.entry_type == "dir"; + b_is_dir + .cmp(&a_is_dir) + .then_with(|| a.name.to_lowercase().cmp(&b.name.to_lowercase())) + }); + + Ok(entries) +} + +/// Get the user's home directory path. +#[tauri::command] +pub fn get_home_dir() -> Result { + let home = dirs::home_dir().ok_or("Cannot resolve home directory")?; + Ok(home.to_string_lossy().to_string()) +} diff --git a/src-tauri/src/commands/hardware.rs b/src-tauri/src/commands/hardware.rs new file mode 100644 index 0000000..e60d09c --- /dev/null +++ b/src-tauri/src/commands/hardware.rs @@ -0,0 +1,222 @@ +//! Tauri IPC commands for hardware detection. +//! +//! Detects CPU, RAM, GPU, and OS details to recommend the optimal +//! inference runtime and quantization level for the local LLM. + +use serde::Serialize; +use sysinfo::System; + +/// GPU information detected on the system. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct GpuInfo { + pub vendor: String, + pub model: String, + pub vram_gb: Option, +} + +/// Complete hardware profile for the local machine. 
+#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct HardwareInfo { + pub cpu_vendor: String, + pub cpu_model: String, + pub cpu_cores: u32, + pub cpu_threads: u32, + pub ram_total_gb: f64, + pub ram_available_gb: f64, + pub os_name: String, + pub os_version: String, + pub arch: String, + pub gpu: Option, + pub recommended_runtime: String, + pub recommended_quantization: String, +} + +/// Detect whether the system has Apple Silicon. +fn is_apple_silicon() -> bool { + cfg!(target_os = "macos") && std::env::consts::ARCH == "aarch64" +} + +/// Recommend an inference runtime based on detected hardware. +fn recommend_runtime(gpu: &Option) -> String { + if is_apple_silicon() { + return "MLX".to_string(); + } + if let Some(ref g) = gpu { + let vendor_lower = g.vendor.to_lowercase(); + if vendor_lower.contains("nvidia") { + return "vLLM".to_string(); + } + } + "llama.cpp".to_string() +} + +/// Recommend a quantization level based on total RAM. +fn recommend_quantization(ram_total_gb: f64) -> String { + if ram_total_gb >= 32.0 { + "Q8_0".to_string() + } else if ram_total_gb >= 16.0 { + "Q4_K_M".to_string() + } else { + "Q4_0".to_string() + } +} + +/// Detect GPU information. +/// +/// - macOS Apple Silicon: reports the integrated GPU. +/// - Windows: queries WMI via `wmic` for GPU name and VRAM. +/// - Linux: parses `lspci` output for VGA controllers. +fn detect_gpu() -> Option { + if is_apple_silicon() { + return Some(GpuInfo { + vendor: "Apple".to_string(), + model: "Apple Silicon (Unified Memory)".to_string(), + vram_gb: None, + }); + } + + #[cfg(target_os = "windows")] + { + return detect_gpu_windows(); + } + + #[cfg(target_os = "linux")] + { + return detect_gpu_linux(); + } + + #[cfg(not(any(target_os = "windows", target_os = "linux")))] + { + None + } +} + +/// Windows GPU detection via `wmic`. +/// +/// Parses `wmic path win32_VideoController get Name,AdapterRAM /format:csv`. 
+#[cfg(target_os = "windows")] +fn detect_gpu_windows() -> Option { + let output = std::process::Command::new("wmic") + .args(["path", "win32_VideoController", "get", "Name,AdapterRAM", "/format:csv"]) + .output() + .ok()?; + + let text = String::from_utf8_lossy(&output.stdout); + // CSV format: Node,AdapterRAM,Name (first non-empty data line) + for line in text.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with("Node") { + continue; + } + let parts: Vec<&str> = line.split(',').collect(); + if parts.len() >= 3 { + let adapter_ram_str = parts[1].trim(); + let name = parts[2].trim().to_string(); + if name.is_empty() { + continue; + } + let vram_bytes: u64 = adapter_ram_str.parse().unwrap_or(0); + let vram_gb = if vram_bytes > 0 { + Some((vram_bytes as f64 / (1024.0 * 1024.0 * 1024.0) * 10.0).round() / 10.0) + } else { + None + }; + let vendor = if name.to_lowercase().contains("nvidia") { + "NVIDIA" + } else if name.to_lowercase().contains("amd") || name.to_lowercase().contains("radeon") + { + "AMD" + } else if name.to_lowercase().contains("intel") { + "Intel" + } else { + "Unknown" + }; + return Some(GpuInfo { + vendor: vendor.to_string(), + model: name, + vram_gb, + }); + } + } + None +} + +/// Linux GPU detection via `lspci`. +#[cfg(target_os = "linux")] +fn detect_gpu_linux() -> Option { + let output = std::process::Command::new("lspci") + .output() + .ok()?; + + let text = String::from_utf8_lossy(&output.stdout); + for line in text.lines() { + if line.contains("VGA") || line.contains("3D controller") { + // Format: "01:00.0 VGA compatible controller: NVIDIA Corporation ..." 
+ let desc = line.splitn(2, ": ").nth(1).unwrap_or(line).trim(); + let vendor = if desc.to_lowercase().contains("nvidia") { + "NVIDIA" + } else if desc.to_lowercase().contains("amd") || desc.to_lowercase().contains("radeon") + { + "AMD" + } else if desc.to_lowercase().contains("intel") { + "Intel" + } else { + "Unknown" + }; + return Some(GpuInfo { + vendor: vendor.to_string(), + model: desc.to_string(), + vram_gb: None, + }); + } + } + None +} + +/// Detect hardware capabilities of the local machine. +/// +/// Returns CPU, RAM, GPU, OS details, and recommendations for +/// the optimal inference runtime and model quantization. +#[tauri::command] +pub async fn detect_hardware() -> Result { + let mut sys = System::new_all(); + sys.refresh_all(); + + let cpus = sys.cpus(); + let (cpu_vendor, cpu_model) = if let Some(cpu) = cpus.first() { + (cpu.vendor_id().to_string(), cpu.brand().to_string()) + } else { + ("Unknown".to_string(), "Unknown".to_string()) + }; + + let cpu_cores = sys.physical_core_count().unwrap_or(0) as u32; + let cpu_threads = cpus.len() as u32; + + let ram_total_gb = sys.total_memory() as f64 / (1024.0 * 1024.0 * 1024.0); + let ram_available_gb = sys.available_memory() as f64 / (1024.0 * 1024.0 * 1024.0); + + let os_name = System::name().unwrap_or_else(|| "Unknown".to_string()); + let os_version = System::os_version().unwrap_or_else(|| "Unknown".to_string()); + let arch = std::env::consts::ARCH.to_string(); + + let gpu = detect_gpu(); + let recommended_runtime = recommend_runtime(&gpu); + let recommended_quantization = recommend_quantization(ram_total_gb); + + Ok(HardwareInfo { + cpu_vendor, + cpu_model, + cpu_cores, + cpu_threads, + ram_total_gb: (ram_total_gb * 10.0).round() / 10.0, + ram_available_gb: (ram_available_gb * 10.0).round() / 10.0, + os_name, + os_version, + arch, + gpu, + recommended_runtime, + recommended_quantization, + }) +} diff --git a/src-tauri/src/commands/mod.rs b/src-tauri/src/commands/mod.rs new file mode 100644 index 
0000000..f6c0402 --- /dev/null +++ b/src-tauri/src/commands/mod.rs @@ -0,0 +1,20 @@ +//! Tauri IPC commands exposed to the React frontend. +//! +//! Each command is callable via `invoke("command_name", { args })` from +//! the frontend TypeScript code. + +pub mod chat; +pub mod filesystem; +pub mod hardware; +pub mod model_download; +pub mod ollama; +pub mod python_env; +pub mod python_env_startup; +pub mod session; +pub mod settings; + +/// Placeholder IPC command for initial Tauri shell verification. +#[tauri::command] +pub fn greet(name: &str) -> String { + format!("Hello, {}! LocalCowork is running.", name) +} diff --git a/src-tauri/src/commands/model_download.rs b/src-tauri/src/commands/model_download.rs new file mode 100644 index 0000000..dd5d822 --- /dev/null +++ b/src-tauri/src/commands/model_download.rs @@ -0,0 +1,203 @@ +//! Tauri IPC commands for model downloading and verification. +//! +//! Provides streaming download with progress events, SHA-256 +//! verification, and model directory management. + +use futures::StreamExt; +use serde::Serialize; +use sha2::{Digest, Sha256}; +use std::path::PathBuf; +use std::time::Instant; +use tauri::Emitter; +use tokio::io::AsyncWriteExt; + +/// Progress update emitted during model download. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelDownloadProgress { + pub bytes_downloaded: u64, + pub bytes_total: u64, + pub percent: f64, + pub speed_mbps: f64, + pub eta_seconds: u64, +} + +/// Result of a completed model download. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelDownloadResult { + pub success: bool, + pub model_path: String, + pub sha256: String, + pub size_bytes: u64, +} + +/// Get the default model directory (platform-standard data dir / models/). +/// +/// Creates the directory if it does not exist. 
+#[tauri::command] +pub async fn get_model_dir() -> Result { + let model_dir = crate::data_dir().join("models"); + tokio::fs::create_dir_all(&model_dir) + .await + .map_err(|e| format!("Failed to create model directory: {e}"))?; + Ok(model_dir.to_string_lossy().to_string()) +} + +/// Download a model file with streaming progress events. +/// +/// Fetches the model from the given URL, writes to `target_dir`, and +/// emits `model-download-progress` events throughout. After download +/// completes, computes the SHA-256 hash of the file. +#[tauri::command] +pub async fn download_model( + url: String, + target_dir: String, + app_handle: tauri::AppHandle, +) -> Result { + let client = reqwest::Client::new(); + + let response = client + .get(&url) + .send() + .await + .map_err(|e| format!("Failed to start download: {e}"))?; + + if !response.status().is_success() { + return Err(format!( + "Download failed with status: {}", + response.status() + )); + } + + let bytes_total = response.content_length().unwrap_or(0); + + let filename = url + .split('/') + .next_back() + .unwrap_or("model.gguf") + .to_string(); + + let target_path = PathBuf::from(&target_dir).join(&filename); + + tokio::fs::create_dir_all(&target_dir) + .await + .map_err(|e| format!("Failed to create target directory: {e}"))?; + + let mut file = tokio::fs::File::create(&target_path) + .await + .map_err(|e| format!("Failed to create file: {e}"))?; + + let mut stream = response.bytes_stream(); + let mut bytes_downloaded: u64 = 0; + let start_time = Instant::now(); + let mut last_emit = Instant::now(); + + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result + .map_err(|e| format!("Download stream error: {e}"))?; + + file.write_all(&chunk) + .await + .map_err(|e| format!("Failed to write chunk: {e}"))?; + + bytes_downloaded += chunk.len() as u64; + + // Emit progress at most every 100ms to avoid flooding + if last_emit.elapsed().as_millis() >= 100 + || bytes_downloaded == bytes_total + 
{ + let elapsed_secs = start_time.elapsed().as_secs_f64(); + let speed_mbps = if elapsed_secs > 0.0 { + (bytes_downloaded as f64 / (1024.0 * 1024.0)) / elapsed_secs + } else { + 0.0 + }; + + let percent = if bytes_total > 0 { + (bytes_downloaded as f64 / bytes_total as f64) * 100.0 + } else { + 0.0 + }; + + let eta_seconds = if speed_mbps > 0.0 && bytes_total > 0 { + let remaining_mb = + (bytes_total - bytes_downloaded) as f64 / (1024.0 * 1024.0); + (remaining_mb / speed_mbps) as u64 + } else { + 0 + }; + + let progress = ModelDownloadProgress { + bytes_downloaded, + bytes_total, + percent: (percent * 10.0).round() / 10.0, + speed_mbps: (speed_mbps * 100.0).round() / 100.0, + eta_seconds, + }; + + let _ = app_handle.emit("model-download-progress", &progress); + last_emit = Instant::now(); + } + } + + file.flush() + .await + .map_err(|e| format!("Failed to flush file: {e}"))?; + drop(file); + + let sha256 = compute_sha256(&target_path) + .await + .map_err(|e| format!("Failed to compute SHA-256: {e}"))?; + + Ok(ModelDownloadResult { + success: true, + model_path: target_path.to_string_lossy().to_string(), + sha256, + size_bytes: bytes_downloaded, + }) +} + +/// Verify a downloaded model file against an expected SHA-256 hash. +#[tauri::command] +pub async fn verify_model( + path: String, + expected_sha256: String, +) -> Result { + let file_path = PathBuf::from(&path); + if !file_path.exists() { + return Err(format!("File not found: {path}")); + } + + let actual = compute_sha256(&file_path) + .await + .map_err(|e| format!("Failed to compute SHA-256: {e}"))?; + + Ok(actual.to_lowercase() == expected_sha256.to_lowercase()) +} + +/// Compute the SHA-256 hash of a file, reading in 8 KB chunks. 
+async fn compute_sha256(path: &PathBuf) -> Result { + use tokio::io::AsyncReadExt; + + let mut file = tokio::fs::File::open(path) + .await + .map_err(|e| format!("Failed to open file for hashing: {e}"))?; + + let mut hasher = Sha256::new(); + let mut buffer = vec![0u8; 8192]; + + loop { + let bytes_read = file + .read(&mut buffer) + .await + .map_err(|e| format!("Failed to read file: {e}"))?; + if bytes_read == 0 { + break; + } + hasher.update(&buffer[..bytes_read]); + } + + let hash = hasher.finalize(); + Ok(format!("{hash:x}")) +} diff --git a/src-tauri/src/commands/ollama.rs b/src-tauri/src/commands/ollama.rs new file mode 100644 index 0000000..dc2f853 --- /dev/null +++ b/src-tauri/src/commands/ollama.rs @@ -0,0 +1,214 @@ +//! Tauri IPC commands for Ollama integration. +//! +//! Provides model listing, status checking, and pull (download) +//! operations against a running Ollama instance. + +use serde::{Deserialize, Serialize}; +use tauri::Emitter; + +/// Base URL for the Ollama HTTP API. +const OLLAMA_API_BASE: &str = "http://localhost:11434"; + +/// Information about a single Ollama model. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct OllamaModelInfo { + pub name: String, + pub size_bytes: u64, + pub parameter_size: String, + pub quantization_level: String, +} + +/// Progress update emitted while pulling an Ollama model. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct OllamaPullProgress { + pub status: String, + pub total: u64, + pub completed: u64, + pub percent: f64, +} + +/// Raw Ollama `/api/tags` response shape. +#[derive(Debug, Deserialize)] +struct OllamaTagsResponse { + models: Option>, +} + +/// Raw model entry from Ollama tags API. +#[derive(Debug, Deserialize)] +struct OllamaTagModel { + name: String, + size: u64, + details: Option, +} + +/// Details sub-object from Ollama tags API. 
+#[derive(Debug, Deserialize)] +struct OllamaTagModelDetails { + parameter_size: Option, + quantization_level: Option, +} + +/// Raw progress line from Ollama `/api/pull` streaming response. +#[derive(Debug, Deserialize)] +struct OllamaPullLine { + status: Option, + total: Option, + completed: Option, +} + +/// Base URL for the llama.cpp health endpoint (matches _models/config.yaml). +const LLAMA_SERVER_HEALTH: &str = "http://localhost:8080/health"; + +/// Check whether a llama-server instance is running on port 8080. +/// +/// Returns `true` if the llama.cpp `/health` endpoint responds with 2xx. +#[tauri::command] +pub async fn check_llama_server_status() -> Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(3)) + .build() + .map_err(|e| format!("HTTP client error: {e}"))?; + + match client.get(LLAMA_SERVER_HEALTH).send().await { + Ok(resp) => Ok(resp.status().is_success()), + Err(_) => Ok(false), + } +} + +/// Check whether Ollama is running and reachable. +/// +/// Returns `true` if the Ollama API at localhost:11434 responds. +#[tauri::command] +pub async fn check_ollama_status() -> Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(3)) + .build() + .map_err(|e| format!("HTTP client error: {e}"))?; + + match client.get(format!("{OLLAMA_API_BASE}/api/tags")).send().await { + Ok(resp) => Ok(resp.status().is_success()), + Err(_) => Ok(false), + } +} + +/// List all models currently available in the local Ollama instance. +/// +/// Queries `GET /api/tags` and returns a simplified model list. 
+#[tauri::command] +pub async fn list_ollama_models() -> Result, String> { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(5)) + .build() + .map_err(|e| format!("HTTP client error: {e}"))?; + + let response = client + .get(format!("{OLLAMA_API_BASE}/api/tags")) + .send() + .await + .map_err(|e| format!("Cannot reach Ollama: {e}"))?; + + if !response.status().is_success() { + return Err(format!( + "Ollama API returned status: {}", + response.status() + )); + } + + let tags: OllamaTagsResponse = response + .json() + .await + .map_err(|e| format!("Failed to parse Ollama response: {e}"))?; + + let models = tags + .models + .unwrap_or_default() + .into_iter() + .map(|m| { + let details = m.details.unwrap_or(OllamaTagModelDetails { + parameter_size: None, + quantization_level: None, + }); + OllamaModelInfo { + name: m.name, + size_bytes: m.size, + parameter_size: details.parameter_size.unwrap_or_default(), + quantization_level: details.quantization_level.unwrap_or_default(), + } + }) + .collect(); + + Ok(models) +} + +/// Pull (download) a model via Ollama's streaming API. +/// +/// Streams progress events as `ollama-pull-progress` Tauri events. +/// The model name should be in Ollama format (e.g., "gpt-oss:20b"). 
+#[tauri::command] +pub async fn pull_ollama_model( + model_name: String, + app_handle: tauri::AppHandle, +) -> Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(3600)) + .build() + .map_err(|e| format!("HTTP client error: {e}"))?; + + let response = client + .post(format!("{OLLAMA_API_BASE}/api/pull")) + .json(&serde_json::json!({ "name": model_name, "stream": true })) + .send() + .await + .map_err(|e| format!("Cannot reach Ollama: {e}"))?; + + if !response.status().is_success() { + return Err(format!( + "Ollama pull failed with status: {}", + response.status() + )); + } + + // Stream the NDJSON response line by line + use futures::StreamExt; + let mut stream = response.bytes_stream(); + let mut buffer = Vec::new(); + + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result.map_err(|e| format!("Stream error: {e}"))?; + buffer.extend_from_slice(&chunk); + + // Process complete lines (NDJSON — each JSON object ends with \n) + while let Some(pos) = buffer.iter().position(|&b| b == b'\n') { + let line_bytes: Vec = buffer.drain(..=pos).collect(); + let line = String::from_utf8_lossy(&line_bytes); + let trimmed = line.trim(); + + if trimmed.is_empty() { + continue; + } + + if let Ok(pull_line) = serde_json::from_str::(trimmed) { + let total = pull_line.total.unwrap_or(0); + let completed = pull_line.completed.unwrap_or(0); + let percent = if total > 0 { + (completed as f64 / total as f64) * 100.0 + } else { + 0.0 + }; + + let progress = OllamaPullProgress { + status: pull_line.status.unwrap_or_default(), + total, + completed, + percent: (percent * 10.0).round() / 10.0, + }; + + let _ = app_handle.emit("ollama-pull-progress", &progress); + } + } + } + + Ok(true) +} diff --git a/src-tauri/src/commands/python_env.rs b/src-tauri/src/commands/python_env.rs new file mode 100644 index 0000000..562c7f7 --- /dev/null +++ b/src-tauri/src/commands/python_env.rs @@ -0,0 +1,401 @@ +//! 
IPC commands for Python MCP server environment provisioning. +//! +//! Ensures per-server `.venv` directories exist with all dependencies installed. +//! Called during onboarding (Setup step) and from Settings > Servers (Repair). +//! +//! The flow: +//! 1. Detect Python servers by scanning `mcp-servers/` for `pyproject.toml` +//! 2. For each server: check `.venv/bin/python` (or `Scripts\python.exe` on Windows) +//! 3. If missing: create venv → pip install -e . +//! 4. Emit progress events via Tauri so the frontend can show real-time status + +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use tauri::Emitter; +use tokio::process::Command; + +// ─── Types ────────────────────────────────────────────────────────────────── + +/// Status of a single Python server's environment. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PythonEnvStatus { + /// Server name (e.g., "ocr", "document"). + pub server: String, + /// Whether the venv is ready (exists + deps installed). + pub ready: bool, + /// Error message if provisioning failed. + pub error: Option, +} + +/// Progress event emitted during venv provisioning. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PythonEnvProgress { + /// Server name. + pub server: String, + /// Current stage: "checking", "creating_venv", "installing_deps", "done", "failed". + pub stage: String, + /// Human-readable message. + pub message: String, +} + +// ─── Platform Helpers ─────────────────────────────────────────────────────── + +/// Return the platform-specific venv binary subdirectory name. +pub(super) fn venv_bin_dir() -> &'static str { + if cfg!(target_os = "windows") { + "Scripts" + } else { + "bin" + } +} + +/// Return the Python executable name for this platform. +pub(super) fn python_executable() -> &'static str { + if cfg!(target_os = "windows") { + "python.exe" + } else { + "python" + } +} + +/// Return the pip executable name for this platform. 
+pub(super) fn pip_executable() -> &'static str { + if cfg!(target_os = "windows") { + "pip.exe" + } else { + "pip" + } +} + +/// Find the system python3 command. +/// +/// Checks `python3` first, then falls back to `python`. +pub(super) async fn find_system_python() -> Result { + for candidate in ["python3", "python"] { + let result = Command::new(candidate) + .arg("--version") + .output() + .await; + if let Ok(output) = result { + if output.status.success() { + return Ok(candidate.to_string()); + } + } + } + Err("Python 3 not found. Install Python 3.11+ from https://python.org".to_string()) +} + +// ─── Core Provisioning ────────────────────────────────────────────────────── + +/// Check if a Python server's venv is already provisioned. +pub(super) fn is_venv_ready(server_dir: &Path) -> bool { + let venv_dir = server_dir.join(".venv"); + let python_path = venv_dir.join(venv_bin_dir()).join(python_executable()); + python_path.exists() +} + +/// Provision a single Python server's venv. +/// +/// 1. Creates `.venv` via `python3 -m venv` +/// 2. Upgrades pip +/// 3. 
Runs `pip install -e .` to install the server's deps +async fn provision_server_env( + server_name: &str, + server_dir: &Path, + app: &tauri::AppHandle, +) -> Result<(), String> { + let venv_dir = server_dir.join(".venv"); + let bin_dir = venv_dir.join(venv_bin_dir()); + let pip_path = bin_dir.join(pip_executable()); + + // Stage 1: Create venv + emit_progress(app, server_name, "creating_venv", "Creating virtual environment..."); + + let system_python = find_system_python().await?; + + let create_output = Command::new(&system_python) + .args(["-m", "venv"]) + .arg(&venv_dir) + .current_dir(server_dir) + .output() + .await + .map_err(|e| format!("Failed to run python -m venv: {e}"))?; + + if !create_output.status.success() { + let stderr = String::from_utf8_lossy(&create_output.stderr); + return Err(format!("Failed to create venv: {}", stderr.trim())); + } + + // Stage 2: Upgrade pip (prevents compatibility issues with hatchling) + emit_progress(app, server_name, "installing_deps", "Upgrading pip..."); + + let pip_upgrade = Command::new(&pip_path) + .args(["install", "--quiet", "--upgrade", "pip"]) + .current_dir(server_dir) + .output() + .await + .map_err(|e| format!("Failed to upgrade pip: {e}"))?; + + if !pip_upgrade.status.success() { + let stderr = String::from_utf8_lossy(&pip_upgrade.stderr); + tracing::warn!( + server = server_name, + stderr = %stderr.trim(), + "pip upgrade failed (non-fatal, continuing with existing pip)" + ); + } + + // Stage 3: Install server dependencies + emit_progress( + app, + server_name, + "installing_deps", + "Installing dependencies...", + ); + + let install_output = Command::new(&pip_path) + .args(["install", "--quiet", "-e", "."]) + .current_dir(server_dir) + .output() + .await + .map_err(|e| format!("Failed to run pip install: {e}"))?; + + if !install_output.status.success() { + let stderr = String::from_utf8_lossy(&install_output.stderr); + return Err(format!( + "Failed to install dependencies: {}", + stderr.trim() + )); + } + 
+ emit_progress(app, server_name, "done", "Ready"); + Ok(()) +} + +/// Emit a progress event to the frontend. +fn emit_progress(app: &tauri::AppHandle, server: &str, stage: &str, message: &str) { + let payload = PythonEnvProgress { + server: server.to_string(), + stage: stage.to_string(), + message: message.to_string(), + }; + let _ = app.emit("python-env-progress", &payload); + tracing::info!( + server = server, + stage = stage, + message = message, + "python env provisioning" + ); +} + +// ─── Discovery Helper ─────────────────────────────────────────────────────── + +/// Find all Python MCP server directories (those with `pyproject.toml`). +fn discover_python_servers(project_root: &Path) -> Vec<(String, PathBuf)> { + let mcp_dir = project_root.join("mcp-servers"); + let mut servers = Vec::new(); + + let entries = match std::fs::read_dir(&mcp_dir) { + Ok(e) => e, + Err(_) => return servers, + }; + + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_dir() { + continue; + } + + let name = match path.file_name().and_then(|n| n.to_str()) { + Some(n) => n.to_string(), + None => continue, + }; + + // Skip internal/hidden directories + if name.starts_with('_') || name.starts_with('.') { + continue; + } + + // Only Python servers (have pyproject.toml) + if path.join("pyproject.toml").exists() { + servers.push((name, path)); + } + } + + servers +} + +// ─── Tauri IPC Commands ───────────────────────────────────────────────────── + +/// Ensure a single Python server's venv is provisioned. +/// +/// Idempotent: if the venv already exists and has python, returns immediately. +/// Creates the venv and installs deps if missing. 
+#[tauri::command] +pub async fn ensure_python_server_env( + server_name: String, + app: tauri::AppHandle, +) -> Result { + let project_root = crate::resolve_project_root(); + let server_dir = project_root.join("mcp-servers").join(&server_name); + + if !server_dir.join("pyproject.toml").exists() { + return Ok(PythonEnvStatus { + server: server_name, + ready: false, + error: Some("Not a Python server (no pyproject.toml)".to_string()), + }); + } + + emit_progress(&app, &server_name, "checking", "Checking environment..."); + + if is_venv_ready(&server_dir) { + emit_progress(&app, &server_name, "done", "Already provisioned"); + return Ok(PythonEnvStatus { + server: server_name, + ready: true, + error: None, + }); + } + + match provision_server_env(&server_name, &server_dir, &app).await { + Ok(()) => Ok(PythonEnvStatus { + server: server_name, + ready: true, + error: None, + }), + Err(e) => { + emit_progress(&app, &server_name, "failed", &e); + Ok(PythonEnvStatus { + server: server_name, + ready: false, + error: Some(e), + }) + } + } +} + +/// Ensure all Python MCP servers have their venvs provisioned. +/// +/// Discovers Python servers from `mcp-servers/`, provisions each sequentially, +/// and returns aggregate status. Continues past individual failures. 
+#[tauri::command] +pub async fn ensure_all_python_envs( + app: tauri::AppHandle, +) -> Result, String> { + let project_root = crate::resolve_project_root(); + let servers = discover_python_servers(&project_root); + + tracing::info!( + count = servers.len(), + servers = ?servers.iter().map(|(n, _)| n.as_str()).collect::>(), + "provisioning Python server environments" + ); + + let mut results = Vec::new(); + + for (name, server_dir) in &servers { + emit_progress(&app, name, "checking", "Checking environment..."); + + if is_venv_ready(server_dir) { + emit_progress(&app, name, "done", "Already provisioned"); + results.push(PythonEnvStatus { + server: name.clone(), + ready: true, + error: None, + }); + continue; + } + + match provision_server_env(name, server_dir, &app).await { + Ok(()) => { + results.push(PythonEnvStatus { + server: name.clone(), + ready: true, + error: None, + }); + } + Err(e) => { + emit_progress(&app, name, "failed", &e); + tracing::warn!( + server = %name, + error = %e, + "failed to provision Python server env (non-fatal)" + ); + results.push(PythonEnvStatus { + server: name.clone(), + ready: false, + error: Some(e), + }); + } + } + } + + let ready_count = results.iter().filter(|r| r.ready).count(); + tracing::info!( + ready = ready_count, + total = results.len(), + "Python server environment provisioning complete" + ); + + Ok(results) +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn test_discover_python_servers() { + let tmp = TempDir::new().unwrap(); + let mcp = tmp.path().join("mcp-servers"); + std::fs::create_dir(&mcp).unwrap(); + + // Python server + let ocr = mcp.join("ocr"); + std::fs::create_dir(&ocr).unwrap(); + std::fs::write(ocr.join("pyproject.toml"), "[project]\nname = \"ocr\"").unwrap(); + + // TypeScript server (should be skipped) + let fs_srv = mcp.join("filesystem"); + std::fs::create_dir(&fs_srv).unwrap(); + 
std::fs::write(fs_srv.join("package.json"), "{}").unwrap(); + + // Hidden dir (should be skipped) + let hidden = mcp.join("_shared"); + std::fs::create_dir(&hidden).unwrap(); + std::fs::write(hidden.join("pyproject.toml"), "").unwrap(); + + let servers = discover_python_servers(tmp.path()); + assert_eq!(servers.len(), 1); + assert_eq!(servers[0].0, "ocr"); + } + + #[test] + fn test_is_venv_ready_false_when_missing() { + let tmp = TempDir::new().unwrap(); + assert!(!is_venv_ready(tmp.path())); + } + + #[test] + fn test_is_venv_ready_true_when_python_exists() { + let tmp = TempDir::new().unwrap(); + let venv_bin = tmp.path().join(".venv").join(venv_bin_dir()); + std::fs::create_dir_all(&venv_bin).unwrap(); + std::fs::write(venv_bin.join(python_executable()), "").unwrap(); + assert!(is_venv_ready(tmp.path())); + } + + #[test] + fn test_venv_bin_dir_platform() { + let dir = venv_bin_dir(); + if cfg!(target_os = "windows") { + assert_eq!(dir, "Scripts"); + } else { + assert_eq!(dir, "bin"); + } + } +} diff --git a/src-tauri/src/commands/python_env_startup.rs b/src-tauri/src/commands/python_env_startup.rs new file mode 100644 index 0000000..3512a54 --- /dev/null +++ b/src-tauri/src/commands/python_env_startup.rs @@ -0,0 +1,135 @@ +//! Startup-time Python venv provisioning (called from `lib.rs`). +//! +//! Unlike the IPC commands in `python_env.rs`, this runs before the frontend +//! is connected and logs progress to `agent.log` instead of emitting events. + +use std::path::{Path, PathBuf}; +use tokio::process::Command; + +use super::python_env::{find_system_python, is_venv_ready, pip_executable, venv_bin_dir}; + +// ─── Discovery ────────────────────────────────────────────────────────────── + +/// Find all Python MCP server directories (those with `pyproject.toml`). 
+fn discover_python_servers(project_root: &Path) -> Vec<(String, PathBuf)> { + let mcp_dir = project_root.join("mcp-servers"); + let mut servers = Vec::new(); + + let entries = match std::fs::read_dir(&mcp_dir) { + Ok(e) => e, + Err(_) => return servers, + }; + + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_dir() { + continue; + } + + let name = match path.file_name().and_then(|n| n.to_str()) { + Some(n) => n.to_string(), + None => continue, + }; + + if name.starts_with('_') || name.starts_with('.') { + continue; + } + + if path.join("pyproject.toml").exists() { + servers.push((name, path)); + } + } + + servers +} + +// ─── Startup Provisioning ─────────────────────────────────────────────────── + +/// Provision missing Python venvs at app startup. +/// +/// Scans `mcp-servers/` for Python servers without `.venv` directories, +/// creates venvs and installs dependencies. Idempotent — skips servers +/// that already have a working venv. +pub async fn provision_missing_venvs(project_root: &Path) { + let servers = discover_python_servers(project_root); + + let missing: Vec<_> = servers + .iter() + .filter(|(_, dir)| !is_venv_ready(dir)) + .collect(); + + if missing.is_empty() { + tracing::info!("all Python server venvs already provisioned"); + return; + } + + tracing::info!( + count = missing.len(), + servers = ?missing.iter().map(|(n, _)| n.as_str()).collect::>(), + "provisioning missing Python server venvs at startup" + ); + + let system_python = match find_system_python().await { + Ok(p) => p, + Err(e) => { + tracing::error!(error = %e, "cannot provision Python venvs — python not found"); + return; + } + }; + + for (name, server_dir) in &missing { + let venv_dir = server_dir.join(".venv"); + let bin_dir = venv_dir.join(venv_bin_dir()); + let pip_path = bin_dir.join(pip_executable()); + + tracing::info!(server = %name, "creating venv..."); + + let create_result = Command::new(&system_python) + .args(["-m", "venv"]) + .arg(&venv_dir) + 
.current_dir(server_dir) + .output() + .await; + + match create_result { + Ok(output) if output.status.success() => {} + Ok(output) => { + let stderr = String::from_utf8_lossy(&output.stderr); + tracing::warn!(server = %name, stderr = %stderr.trim(), "failed to create venv"); + continue; + } + Err(e) => { + tracing::warn!(server = %name, error = %e, "failed to run python -m venv"); + continue; + } + } + + // Upgrade pip (non-fatal) + let _ = Command::new(&pip_path) + .args(["install", "--quiet", "--upgrade", "pip"]) + .current_dir(server_dir) + .output() + .await; + + tracing::info!(server = %name, "installing dependencies..."); + + let install_result = Command::new(&pip_path) + .args(["install", "--quiet", "-e", "."]) + .current_dir(server_dir) + .output() + .await; + + match install_result { + Ok(output) if output.status.success() => { + tracing::info!(server = %name, "venv provisioned successfully"); + } + Ok(output) => { + let stderr = String::from_utf8_lossy(&output.stderr); + tracing::warn!(server = %name, stderr = %stderr.trim(), "pip install failed"); + } + Err(e) => { + tracing::warn!(server = %name, error = %e, "failed to run pip install"); + } + } + } +} diff --git a/src-tauri/src/commands/session.rs b/src-tauri/src/commands/session.rs new file mode 100644 index 0000000..b37c712 --- /dev/null +++ b/src-tauri/src/commands/session.rs @@ -0,0 +1,209 @@ +//! Tauri IPC commands for session management. +//! +//! These commands let the frontend list, load, and delete conversation +//! sessions, as well as query the current context window budget. + +use serde::Serialize; +use std::sync::Mutex; + +use crate::agent_core::ConversationManager; +use crate::agent_core::tokens::truncate_utf8; + +// ─── Response Types ───────────────────────────────────────────────────────── + +/// Summary of a session for the sidebar list. 
+#[derive(Debug, Serialize)] +pub struct SessionListItem { + pub id: String, + pub created_at: String, + pub last_activity: String, + pub message_count: usize, + pub preview: Option, +} + +/// Context budget snapshot sent to the frontend. +#[derive(Debug, Serialize)] +pub struct ContextBudgetResponse { + pub total: u32, + pub system_prompt: u32, + pub tool_definitions: u32, + pub conversation_history: u32, + pub output_reservation: u32, + pub remaining: u32, +} + +// ─── Commands ─────────────────────────────────────────────────────────────── + +/// List sessions that have actual user content (not just system prompt). +/// +/// Excludes empty sessions to keep the sidebar clean. Sessions are +/// sorted by most recent activity first. +#[tauri::command] +pub fn list_sessions( + state: tauri::State<'_, Mutex>, +) -> Result, String> { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + let sessions = mgr.db().list_sessions().map_err(|e| format!("{e}"))?; + + let mut items = Vec::new(); + for session in sessions { + let count = mgr.db().message_count(&session.id).unwrap_or(0); + + // Skip sessions with only a system prompt (no user interaction) + if count <= 1 { + continue; + } + + // Get first user message as preview + let preview = mgr + .get_history(&session.id) + .ok() + .and_then(|msgs| { + msgs.iter() + .find(|m| m.role == crate::inference::types::Role::User) + .and_then(|m| m.content.clone()) + }) + .map(|s| { + if s.len() > 80 { + format!("{}…", truncate_utf8(&s, 77)) + } else { + s + } + }); + + items.push(SessionListItem { + id: session.id, + created_at: session.created_at, + last_activity: session.last_activity, + message_count: count, + preview, + }); + } + + Ok(items) +} + +/// Load a session's conversation history for display. +/// +/// Returns messages with full metadata including toolCallId and toolCalls +/// so the frontend can properly render ToolTrace components. 
+#[tauri::command] +pub fn load_session( + session_id: String, + state: tauri::State<'_, Mutex>, +) -> Result, String> { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + let history = mgr.get_history(&session_id).map_err(|e| format!("{e}"))?; + + let messages: Vec = history + .iter() + .filter(|m| m.role != crate::inference::types::Role::System) + .map(|m| { + let mut msg = serde_json::json!({ + "id": m.id, + "sessionId": m.session_id, + "timestamp": m.timestamp, + "role": format!("{:?}", m.role).to_lowercase(), + "content": m.content, + "tokenCount": m.token_count, + }); + + // Include tool_call_id and toolResult for tool result messages + // so ToolTrace can correlate them and show results. + if let Some(ref tc_id) = m.tool_call_id { + let obj = msg.as_object_mut().unwrap(); + obj.insert( + "toolCallId".to_string(), + serde_json::Value::String(tc_id.clone()), + ); + // Include toolResult so ToolTrace can show result status + obj.insert( + "toolResult".to_string(), + serde_json::json!({ + "success": true, + "result": m.content, + "toolCallId": tc_id, + }), + ); + } + + // Include tool_calls for assistant messages + if let Some(ref calls) = m.tool_calls { + let tc_json: Vec = calls + .iter() + .map(|tc| { + serde_json::json!({ + "id": tc.id, + "name": tc.name, + "arguments": tc.arguments, + }) + }) + .collect(); + msg.as_object_mut() + .unwrap() + .insert("toolCalls".to_string(), serde_json::Value::Array(tc_json)); + } + + msg + }) + .collect(); + + Ok(messages) +} + +/// Delete a session and all its data. +#[tauri::command] +pub fn delete_session( + session_id: String, + state: tauri::State<'_, Mutex>, +) -> Result<(), String> { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + mgr.db() + .delete_session(&session_id) + .map_err(|e| format!("{e}")) +} + +/// Get the current context window budget for a session. 
+#[tauri::command] +pub fn get_context_budget( + session_id: String, + state: tauri::State<'_, Mutex>, +) -> Result { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + let budget = mgr.get_budget(&session_id).map_err(|e| format!("{e}"))?; + + Ok(ContextBudgetResponse { + total: budget.total, + system_prompt: budget.system_prompt, + tool_definitions: budget.tool_definitions, + conversation_history: budget.conversation_history, + output_reservation: budget.output_reservation, + remaining: budget.remaining, + }) +} + +/// Clean up orphan empty sessions (only system prompt, no user messages). +/// +/// Called on app startup to remove sessions from previous launches that +/// were created but never used. +#[tauri::command] +pub fn cleanup_empty_sessions( + state: tauri::State<'_, Mutex>, +) -> Result { + let mgr = state.lock().map_err(|e| format!("Lock error: {e}"))?; + let sessions = mgr.db().list_sessions().map_err(|e| format!("{e}"))?; + + let mut cleaned = 0u32; + for session in &sessions { + if let Ok(count) = mgr.db().message_count(&session.id) { + if count <= 1 && mgr.db().delete_session(&session.id).is_ok() { + cleaned += 1; + } + } + } + + if cleaned > 0 { + tracing::info!(cleaned = cleaned, "cleaned up empty sessions"); + } + Ok(cleaned) +} diff --git a/src-tauri/src/commands/settings.rs b/src-tauri/src/commands/settings.rs new file mode 100644 index 0000000..6c8d20f --- /dev/null +++ b/src-tauri/src/commands/settings.rs @@ -0,0 +1,726 @@ +//! Tauri IPC commands for the Settings panel. +//! +//! Reads model configuration from `_models/config.yaml` (the same source +//! of truth used by the inference client at runtime) and provides live +//! MCP server status from the running McpClient. 
+ +use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, Ordering}; + +use serde::{Deserialize, Serialize}; + +static SETTINGS_CHANGED: AtomicBool = AtomicBool::new(false); + +pub fn settings_changed() { + SETTINGS_CHANGED.store(true, Ordering::SeqCst); +} + +pub fn has_settings_changed() -> bool { + SETTINGS_CHANGED.swap(false, Ordering::SeqCst) +} + +/// Unified app settings that persist across restarts. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AppSettings { + /// Currently active model key from _models/config.yaml + pub active_model_key: Option, + /// Allowed filesystem paths for sandboxed operations + pub allowed_paths: Vec, + /// UI theme preference + pub theme: String, + /// Whether to show tool traces + pub show_tool_traces: bool, + /// Sampling config (integrated from existing system) + pub sampling: SamplingConfig, +} + +impl Default for AppSettings { + fn default() -> Self { + Self { + active_model_key: None, + allowed_paths: Vec::new(), + theme: "system".to_string(), + show_tool_traces: true, + sampling: SamplingConfig::default(), + } + // Default allowed paths + } +} + +impl AppSettings { + const FILE_NAME: &'static str = "settings.json"; + + fn persist_path() -> PathBuf { + crate::data_dir().join(Self::FILE_NAME) + } + + pub fn load_or_default() -> Self { + let path = Self::persist_path(); + if !path.exists() { + return Self::default(); + } + match std::fs::read_to_string(&path) { + Ok(content) => match serde_json::from_str::(&content) { + Ok(settings) => { + tracing::info!(path = %path.display(), "loaded app settings"); + settings + } + Err(e) => { + tracing::warn!(error = %e, "failed to parse settings, using defaults"); + Self::default() + } + }, + Err(e) => { + tracing::warn!(error = %e, "failed to read settings, using defaults"); + Self::default() + } + } + } + + pub fn save(&self) { + let path = Self::persist_path(); + let content = match serde_json::to_string_pretty(self) { + Ok(c) => 
c, + Err(e) => { + tracing::error!(error = %e, "failed to serialize settings"); + return; + } + }; + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + let tmp_path = path.with_extension("json.tmp"); + if let Err(e) = std::fs::write(&tmp_path, &content) { + tracing::error!(error = %e, "failed to write settings temp file"); + return; + } + if let Err(e) = std::fs::rename(&tmp_path, &path) { + tracing::error!(error = %e, "failed to rename settings file"); + return; + } + settings_changed(); + tracing::debug!("saved app settings"); + } + + pub fn export_to_json(&self) -> Result { + serde_json::to_string_pretty(self).map_err(|e| format!("export failed: {}", e)) + } + + pub fn import_from_json(json: &str) -> Result { + let settings: Self = + serde_json::from_str(json).map_err(|e| format!("invalid settings JSON: {}", e))?; + settings.sampling.save(); + settings.save(); + Ok(settings) + } +} + +/// Model configuration exposed to the frontend. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelConfigInfo { + pub key: String, + pub display_name: String, + pub runtime: String, + pub base_url: String, + pub context_window: u32, + pub temperature: f64, + pub max_tokens: u32, + pub estimated_vram_gb: Option, + pub capabilities: Vec, + pub tool_call_format: String, +} + +/// Models overview returned to the frontend. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelsOverviewInfo { + pub active_model: String, + pub models: Vec, + pub fallback_chain: Vec, +} + +/// MCP server status. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct McpServerStatusInfo { + pub name: String, + pub status: String, + pub tool_count: u32, + pub tool_names: Vec, + pub last_check: String, + pub error: Option, +} + +/// Get the models configuration overview. 
+/// +/// Reads from `_models/config.yaml` using the same config loader +/// that the inference client uses at runtime. +#[tauri::command] +pub fn get_models_config() -> Result { + let cwd = std::env::current_dir().unwrap_or_default(); + let config_path = crate::inference::config::find_config_path(&cwd) + .map_err(|e| format!("Config not found: {e}"))?; + let config = crate::inference::config::load_models_config(&config_path) + .map_err(|e| format!("Config load error: {e}"))?; + + let models: Vec = config + .models + .iter() + .map(|(key, m)| ModelConfigInfo { + key: key.clone(), + display_name: m.display_name.clone(), + runtime: m.runtime.clone(), + base_url: m.base_url.clone(), + context_window: m.context_window, + temperature: f64::from(m.temperature), + max_tokens: m.max_tokens, + estimated_vram_gb: m.estimated_vram_gb.map(f64::from), + capabilities: m.capabilities.clone(), + tool_call_format: format!("{:?}", m.tool_call_format), + }) + .collect(); + + Ok(ModelsOverviewInfo { + active_model: config.active_model.clone(), + models, + fallback_chain: config.fallback_chain.clone(), + }) +} + +/// Get the status of all MCP servers from the running McpClient. +/// +/// Queries actual server state — no hardcoded stubs. Returns configured +/// servers with their running status and tool count. 
+#[tauri::command] +pub async fn get_mcp_servers_status( + mcp_state: tauri::State<'_, crate::TokioMutex>, +) -> Result, String> { + let mcp = mcp_state.lock().await; + let now = chrono::Utc::now().to_rfc3339(); + + let configured = mcp.configured_servers(); + let mut statuses: Vec = configured + .into_iter() + .map(|name| { + let is_running = mcp.is_server_running(&name); + let tool_count = mcp.registry.tools_for_server(&name) as u32; + let tool_names = mcp.registry.tool_names_for_server(&name); + + McpServerStatusInfo { + status: if is_running { + "initialized".to_string() + } else { + "failed".to_string() + }, + tool_count, + tool_names, + last_check: now.clone(), + error: if is_running { + None + } else { + Some("Server not running".to_string()) + }, + name, + } + }) + .collect(); + + statuses.sort_by(|a, b| a.name.cmp(&b.name)); + Ok(statuses) +} + +// ─── Permission Grant Management ──────────────────────────────────────────── + +/// A permission grant exposed to the frontend. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionGrantInfo { + pub tool_name: String, + pub scope: String, + pub granted_at: String, +} + +/// List all persistent permission grants. +/// +/// Reads from the PermissionStore in Tauri state. +#[tauri::command] +pub async fn list_permission_grants( + perms: tauri::State<'_, crate::TokioMutex>, +) -> Result, String> { + let store = perms.lock().await; + let grants = store + .list_persistent() + .into_iter() + .map(|g| PermissionGrantInfo { + tool_name: g.tool_name.clone(), + scope: format!("{:?}", g.scope).to_lowercase(), + granted_at: g.granted_at.clone(), + }) + .collect(); + Ok(grants) +} + +/// Revoke a persistent permission grant by tool name. +/// +/// Removes the grant from the PermissionStore and persists the change to disk. 
+#[tauri::command] +pub async fn revoke_permission( + tool_name: String, + perms: tauri::State<'_, crate::TokioMutex>, +) -> Result { + let mut store = perms.lock().await; + let removed = store.revoke(&tool_name); + tracing::info!(tool = %tool_name, removed, "revoke_permission"); + Ok(removed) +} + +// ─── Sampling Configuration ───────────────────────────────────────────────── + +/// Runtime sampling hyperparameters exposed to the frontend. +/// +/// Persisted to `sampling_config.json` in the app data directory. +/// The agent loop reads these at the start of each `send_message` call +/// instead of using hardcoded constants. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SamplingConfig { + pub tool_temperature: f32, + pub tool_top_p: f32, + pub conversational_temperature: f32, + pub conversational_top_p: f32, +} + +impl Default for SamplingConfig { + fn default() -> Self { + Self { + tool_temperature: 0.1, + tool_top_p: 0.2, + conversational_temperature: 0.7, + conversational_top_p: 0.9, + } + } +} + +impl SamplingConfig { + /// Load from disk or return defaults. + pub fn load_or_default() -> Self { + let path = Self::persist_path(); + if !path.exists() { + return Self::default(); + } + match std::fs::read_to_string(&path) { + Ok(content) => match serde_json::from_str::(&content) { + Ok(cfg) => { + tracing::info!(path = %path.display(), "loaded sampling config"); + cfg + } + Err(e) => { + tracing::warn!(error = %e, "failed to parse sampling config, using defaults"); + Self::default() + } + }, + Err(e) => { + tracing::warn!(error = %e, "failed to read sampling config, using defaults"); + Self::default() + } + } + } + + /// Save to disk (atomic write). 
+ pub fn save(&self) { + let path = Self::persist_path(); + let content = match serde_json::to_string_pretty(self) { + Ok(c) => c, + Err(e) => { + tracing::error!(error = %e, "failed to serialize sampling config"); + return; + } + }; + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + } + let tmp_path = path.with_extension("json.tmp"); + if let Err(e) = std::fs::write(&tmp_path, &content) { + tracing::error!(error = %e, "failed to write sampling config temp file"); + return; + } + if let Err(e) = std::fs::rename(&tmp_path, &path) { + tracing::error!(error = %e, "failed to rename sampling config file"); + return; + } + tracing::debug!("saved sampling config"); + } + + fn persist_path() -> PathBuf { + crate::data_dir().join("sampling_config.json") + } +} + +/// Get the current sampling configuration. +#[tauri::command] +pub async fn get_sampling_config( + state: tauri::State<'_, crate::TokioMutex>, +) -> Result { + let cfg = state.lock().await; + Ok(cfg.clone()) +} + +/// Update the sampling configuration and persist to disk. +#[tauri::command] +pub async fn update_sampling_config( + config: SamplingConfig, + state: tauri::State<'_, crate::TokioMutex>, +) -> Result { + let mut cfg = state.lock().await; + *cfg = config; + cfg.save(); + tracing::info!( + tool_temp = cfg.tool_temperature, + tool_top_p = cfg.tool_top_p, + conv_temp = cfg.conversational_temperature, + conv_top_p = cfg.conversational_top_p, + "sampling config updated" + ); + Ok(cfg.clone()) +} + +/// Reset the sampling configuration to defaults and persist. +#[tauri::command] +pub async fn reset_sampling_config( + state: tauri::State<'_, crate::TokioMutex>, +) -> Result { + let mut cfg = state.lock().await; + *cfg = SamplingConfig::default(); + cfg.save(); + tracing::info!("sampling config reset to defaults"); + Ok(cfg.clone()) +} + +// ─── Unified App Settings ──────────────────────────────────────────────────── + +/// Get the current app settings. 
+#[tauri::command] +pub fn get_app_settings() -> AppSettings { + AppSettings::load_or_default() +} + +/// Update app settings and persist to disk. +#[tauri::command] +pub fn update_app_settings(settings: AppSettings) -> AppSettings { + settings.save(); + tracing::info!( + active_model = ?settings.active_model_key, + theme = %settings.theme, + allowed_paths = settings.allowed_paths.len(), + "app settings updated" + ); + settings +} + +/// Add an allowed path to settings. +#[tauri::command] +pub fn add_allowed_path(path: String) -> AppSettings { + let mut settings = AppSettings::load_or_default(); + if !settings.allowed_paths.contains(&path) { + settings.allowed_paths.push(path.clone()); + settings.save(); + tracing::info!(path = %path, "allowed path added"); + } + settings +} + +/// Remove an allowed path from settings. +#[tauri::command] +pub fn remove_allowed_path(path: String) -> AppSettings { + let mut settings = AppSettings::load_or_default(); + let path_clone = path.clone(); + settings.allowed_paths.retain(|p| p != &path); + settings.save(); + tracing::info!(path = %path_clone, "allowed path removed"); + settings +} + +/// Export settings to JSON string. +#[tauri::command] +pub fn export_settings() -> Result { + let settings = AppSettings::load_or_default(); + settings.export_to_json() +} + +/// Import settings from JSON string. +#[tauri::command] +pub fn import_settings(json: String) -> Result { + AppSettings::import_from_json(&json) +} + +/// Check if settings have changed since last check (for file watching). +#[tauri::command] +pub fn poll_settings_changed() -> bool { + has_settings_changed() +} + +// ─── Config Hot Reload ────────────────────────────────────────────────────── + +use std::sync::atomic::AtomicU64; +use std::time::SystemTime; + +static CONFIG_LAST_MODIFIED: AtomicU64 = AtomicU64::new(0); + +/// Check if config file has been modified since last check. 
+#[tauri::command] +pub fn check_config_reload() -> Result { + let cwd = std::env::current_dir().unwrap_or_default(); + let config_path = crate::inference::config::find_config_path(&cwd) + .map_err(|e| format!("Config not found: {e}"))?; + + let metadata = std::fs::metadata(&config_path) + .map_err(|e| format!("Failed to read config metadata: {}", e))?; + + let modified = metadata.modified() + .map_err(|e| format!("Failed to get modification time: {}", e))?; + + let modified_secs = modified + .duration_since(SystemTime::UNIX_EPOCH) + .map_err(|e| format!("Time error: {}", e))? + .as_secs(); + + let last_modified = CONFIG_LAST_MODIFIED.load(Ordering::SeqCst); + + if modified_secs > last_modified { + CONFIG_LAST_MODIFIED.store(modified_secs, Ordering::SeqCst); + Ok(true) + } else { + Ok(false) + } +} + +/// Force reload the model config (for manual refresh). +#[tauri::command] +pub fn reload_model_config() -> Result { + get_models_config() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::Ordering; + + #[test] + fn test_sampling_config_default() { + let cfg = SamplingConfig::default(); + assert_eq!(cfg.tool_temperature, 0.1); + assert_eq!(cfg.tool_top_p, 0.2); + assert_eq!(cfg.conversational_temperature, 0.7); + assert_eq!(cfg.conversational_top_p, 0.9); + } + + #[test] + fn test_sampling_config_serialization() { + let cfg = SamplingConfig { + tool_temperature: 0.5, + tool_top_p: 0.3, + conversational_temperature: 0.8, + conversational_top_p: 0.95, + }; + let json = serde_json::to_string(&cfg).unwrap(); + assert!(json.contains("0.5")); + assert!(json.contains("0.3")); + assert!(json.contains("0.8")); + assert!(json.contains("0.95")); + } + + #[test] + fn test_sampling_config_deserialization() { + let json = r#"{ + "toolTemperature": 0.3, + "toolTopP": 0.4, + "conversationalTemperature": 0.6, + "conversationalTopP": 0.8 + }"#; + let cfg: SamplingConfig = serde_json::from_str(json).unwrap(); + assert_eq!(cfg.tool_temperature, 0.3); + 
assert_eq!(cfg.tool_top_p, 0.4); + assert_eq!(cfg.conversational_temperature, 0.6); + assert_eq!(cfg.conversational_top_p, 0.8); + } + + #[test] + fn test_app_settings_default() { + let settings = AppSettings::default(); + assert_eq!(settings.active_model_key, None); + assert!(settings.allowed_paths.is_empty()); + assert_eq!(settings.theme, "system"); + assert!(settings.show_tool_traces); + // Sampling should be default + assert_eq!(settings.sampling.tool_temperature, 0.1); + } + + #[test] + fn test_app_settings_serialization() { + let mut settings = AppSettings::default(); + settings.active_model_key = Some("test-model".to_string()); + settings.allowed_paths = vec!["/home/user/docs".to_string()]; + settings.theme = "dark".to_string(); + settings.show_tool_traces = false; + + let json = serde_json::to_string(&settings).unwrap(); + assert!(json.contains("test-model")); + assert!(json.contains("dark")); + assert!(json.contains("docs")); + } + + #[test] + fn test_app_settings_deserialization() { + let json = r#"{ + "activeModelKey": "lm-studio-model", + "allowedPaths": ["/tmp", "/var"], + "theme": "light", + "showToolTraces": false, + "sampling": { + "toolTemperature": 0.2, + "toolTopP": 0.3, + "conversationalTemperature": 0.8, + "conversationalTopP": 0.9 + } + }"#; + let settings: AppSettings = serde_json::from_str(json).unwrap(); + assert_eq!(settings.active_model_key, Some("lm-studio-model".to_string())); + assert_eq!(settings.allowed_paths.len(), 2); + assert_eq!(settings.theme, "light"); + assert!(!settings.show_tool_traces); + } + + #[test] + fn test_config_last_modified_atomic() { + // Test that CONFIG_LAST_MODIFIED is properly initialized + let initial = CONFIG_LAST_MODIFIED.load(Ordering::SeqCst); + assert_eq!(initial, 0); + + // Store a value and verify + CONFIG_LAST_MODIFIED.store(12345, Ordering::SeqCst); + let after = CONFIG_LAST_MODIFIED.load(Ordering::SeqCst); + assert_eq!(after, 12345); + + // Reset + CONFIG_LAST_MODIFIED.store(0, Ordering::SeqCst); + 
} + + #[test] + fn test_settings_changed_atomic() { + // Test the SETTINGS_CHANGED flag + settings_changed(); + assert!(has_settings_changed()); + assert!(!has_settings_changed()); // Should clear after check + + // Setting it again should work + settings_changed(); + assert!(has_settings_changed()); + } + + #[test] + fn test_model_config_info_fields() { + let info = ModelConfigInfo { + key: "test-key".to_string(), + display_name: "Test Model".to_string(), + runtime: "lm-studio".to_string(), + base_url: "http://localhost:1234/v1".to_string(), + context_window: 32768, + temperature: 0.7, + max_tokens: 4096, + estimated_vram_gb: Some(24.0), + capabilities: vec!["chat".to_string(), "tools".to_string()], + tool_call_format: "json".to_string(), + }; + + assert_eq!(info.key, "test-key"); + assert_eq!(info.runtime, "lm-studio"); + assert_eq!(info.context_window, 32768); + } + + #[test] + fn test_models_overview_info_serialization() { + let overview = ModelsOverviewInfo { + active_model: "qwen2.5".to_string(), + models: vec![ + ModelConfigInfo { + key: "qwen2.5".to_string(), + display_name: "Qwen 2.5".to_string(), + runtime: "ollama".to_string(), + base_url: "http://localhost:11434/v1".to_string(), + context_window: 32768, + temperature: 0.7, + max_tokens: 4096, + estimated_vram_gb: Some(20.0), + capabilities: vec!["chat".to_string()], + tool_call_format: "json".to_string(), + } + ], + fallback_chain: vec!["gpt-oss".to_string()], + }; + + let json = serde_json::to_string(&overview).unwrap(); + assert!(json.contains("qwen2.5")); + assert!(json.contains("ollama")); + } + + #[test] + fn test_mcp_server_status_info() { + let status = McpServerStatusInfo { + name: "filesystem".to_string(), + status: "initialized".to_string(), + tool_count: 10, + tool_names: vec!["list_dir".to_string(), "read_file".to_string()], + last_check: "2024-01-01T00:00:00Z".to_string(), + error: None, + }; + + assert_eq!(status.name, "filesystem"); + assert_eq!(status.status, "initialized"); + 
assert_eq!(status.tool_count, 10); + + // Test with error + let status_with_error = McpServerStatusInfo { + error: Some("Connection refused".to_string()), + ..status + }; + assert!(status_with_error.error.is_some()); + } + + #[test] + fn test_permission_grant_info() { + let grant = PermissionGrantInfo { + tool_name: "filesystem.write_file".to_string(), + scope: "session".to_string(), + granted_at: "2024-01-01T12:00:00Z".to_string(), + }; + + assert_eq!(grant.tool_name, "filesystem.write_file"); + assert_eq!(grant.scope, "session"); + } + + #[test] + fn test_app_settings_export_import_roundtrip() { + let original = AppSettings { + active_model_key: Some("test-model".to_string()), + allowed_paths: vec!["/home/user".to_string()], + theme: "dark".to_string(), + show_tool_traces: true, + sampling: SamplingConfig { + tool_temperature: 0.15, + tool_top_p: 0.25, + conversational_temperature: 0.75, + conversational_top_p: 0.85, + }, + }; + + let json = original.export_to_json().unwrap(); + let imported = AppSettings::import_from_json(&json).unwrap(); + + assert_eq!(imported.active_model_key, original.active_model_key); + assert_eq!(imported.allowed_paths, original.allowed_paths); + assert_eq!(imported.theme, original.theme); + assert_eq!(imported.sampling.tool_temperature, original.sampling.tool_temperature); + } +} diff --git a/src-tauri/src/inference/client.rs b/src-tauri/src/inference/client.rs new file mode 100644 index 0000000..e028fd6 --- /dev/null +++ b/src-tauri/src/inference/client.rs @@ -0,0 +1,870 @@ +//! OpenAI-compatible inference client. +//! +//! Sends chat completion requests to a local LLM endpoint and streams back +//! tokens and tool calls. Handles the fallback chain when the primary model +//! is unavailable. 
+ +use std::time::Duration; + +use futures::future::Either; +use futures::Stream; +use reqwest::Client as HttpClient; +use uuid::Uuid; + +use super::config::{ModelConfig, ModelsConfig, ToolCallFormat}; +use super::errors::InferenceError; +use super::streaming::{parse_non_streaming_response, parse_sse_stream}; +use super::tool_call_parser::{extract_tool_call_from_error, repair_malformed_tool_call_json}; +use super::types::{ + ChatCompletionRequest, ChatMessage, SamplingOverrides, StreamChunk, ToolCall, ToolDefinition, +}; + +// ─── Constants ─────────────────────────────────────────────────────────────── + +/// TCP connection timeout. +const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); + +/// Total request timeout for non-streaming calls. +const REQUEST_TIMEOUT: Duration = Duration::from_secs(30); + +/// Total request timeout for streaming calls. +/// +/// Streaming responses from local models can take a long time, especially +/// when the context window is large (18+ messages). The model needs time +/// to process the full context before emitting the first token. A 30s +/// timeout causes silent stream termination that looks like "empty response" +/// to the agent loop. +const STREAM_REQUEST_TIMEOUT: Duration = Duration::from_secs(180); + +// ─── InferenceClient ───────────────────────────────────────────────────────── + +/// Client for the local LLM inference endpoint. +/// +/// Created from `ModelsConfig` and holds the current model configuration. +/// Provides streaming and non-streaming chat completion methods. +pub struct InferenceClient { + /// HTTP client for non-streaming requests (30s timeout). + http: HttpClient, + /// HTTP client for streaming requests (180s timeout). + http_stream: HttpClient, + /// The full models configuration (for fallback chain). + config: ModelsConfig, + /// The current model key (e.g., "qwen25-32b"). + current_model_key: String, + /// The current model configuration. 
+ current_model: ModelConfig, + /// Models that have already been tried and failed. + exhausted_models: Vec, +} + +impl InferenceClient { + /// Create a new inference client from the models configuration. + /// + /// Resolves the active model from config. Does NOT check connectivity — + /// that happens on the first request. + pub fn from_config(config: ModelsConfig) -> Result { + let (key, model) = super::config::resolve_active_model(&config)?; + + let http = HttpClient::builder() + .connect_timeout(CONNECT_TIMEOUT) + .timeout(REQUEST_TIMEOUT) + .build() + .map_err(|e| InferenceError::ConnectionFailed { + endpoint: model.base_url.clone(), + reason: format!("failed to build HTTP client: {e}"), + })?; + + let http_stream = HttpClient::builder() + .connect_timeout(CONNECT_TIMEOUT) + .timeout(STREAM_REQUEST_TIMEOUT) + .build() + .map_err(|e| InferenceError::ConnectionFailed { + endpoint: model.base_url.clone(), + reason: format!("failed to build streaming HTTP client: {e}"), + })?; + + Ok(Self { + http, + http_stream, + config, + current_model_key: key, + current_model: model, + exhausted_models: Vec::new(), + }) + } + + /// Create an inference client targeting a specific model by key. + /// + /// Unlike [`from_config`] which resolves the active model + fallback chain, + /// this constructor pins the client to a specific model. Used by the + /// orchestrator (ADR-009) to create separate planner and router clients. + pub fn from_config_with_model( + config: ModelsConfig, + model_key: &str, + ) -> Result { + let model = config + .models + .get(model_key) + .ok_or_else(|| InferenceError::ConfigError { + reason: format!("model '{model_key}' not found in config"), + })? 
+ .clone(); + + let http = HttpClient::builder() + .connect_timeout(CONNECT_TIMEOUT) + .timeout(REQUEST_TIMEOUT) + .build() + .map_err(|e| InferenceError::ConnectionFailed { + endpoint: model.base_url.clone(), + reason: format!("failed to build HTTP client: {e}"), + })?; + + let http_stream = HttpClient::builder() + .connect_timeout(CONNECT_TIMEOUT) + .timeout(STREAM_REQUEST_TIMEOUT) + .build() + .map_err(|e| InferenceError::ConnectionFailed { + endpoint: model.base_url.clone(), + reason: format!("failed to build streaming HTTP client: {e}"), + })?; + + Ok(Self { + http, + http_stream, + config, + current_model_key: model_key.to_string(), + current_model: model, + exhausted_models: Vec::new(), + }) + } + + /// The base URL of the current model's endpoint. + pub fn current_base_url(&self) -> &str { + &self.current_model.base_url + } + + /// The name of the currently selected model. + pub fn current_model_name(&self) -> &str { + &self.current_model.display_name + } + + /// The tool call format of the current model. + pub fn tool_call_format(&self) -> ToolCallFormat { + self.current_model.tool_call_format + } + + /// The context window size of the current model. + pub fn context_window(&self) -> u32 { + self.current_model.context_window + } + + // ─── Chat Completion (streaming) ───────────────────────────────────── + + /// Send a streaming chat completion request. + /// + /// Returns a `Stream` of `StreamChunk`s. Each chunk contains either a + /// text token, tool calls, or both. + /// + /// If the current model is unavailable, automatically tries the fallback + /// chain before returning an error. When Ollama returns HTTP 500 due to + /// malformed JSON in tool call arguments, attempts client-side repair + /// before triggering the fallback chain. 
+ pub async fn chat_completion_stream( + &mut self, + messages: Vec, + tools: Option>, + sampling: Option, + ) -> Result< + impl Stream>, + InferenceError, + > { + let mut last_error: Option = None; + + for _attempt in 0..=self.remaining_fallbacks() { + match self.try_stream_request(&messages, &tools, sampling.as_ref()).await { + Ok(stream) => return Ok(Either::Left(stream)), + Err(e) if e.is_tool_call_parse_error() => { + // Ollama returned HTTP 500 because the model generated + // malformed JSON in tool call arguments. Try to repair the + // JSON client-side before falling back to the next model. + if let Some(repaired) = Self::try_repair_from_error(&e) { + tracing::info!( + tool = %repaired.tool_calls.as_ref() + .and_then(|tc| tc.first()) + .map(|tc| tc.name.as_str()) + .unwrap_or("unknown"), + "repaired malformed JSON tool call" + ); + return Ok(Either::Right(futures::stream::once( + async { Ok(repaired) }, + ))); + } + // Repair failed — continue to fallback chain + tracing::warn!("tool call JSON repair failed, falling back"); + last_error = Some(e); + if self.try_next_fallback().is_err() { + break; + } + } + Err(e) if Self::is_retriable(&e) => { + last_error = Some(e); + if self.try_next_fallback().is_err() { + break; // No more fallbacks + } + } + Err(e) => return Err(e), // Non-retriable error + } + } + + Err(last_error.unwrap_or(InferenceError::AllModelsUnavailable { + attempted: self.exhausted_models.clone(), + })) + } + + /// Attempt a single streaming request to the current model. 
+ async fn try_stream_request( + &self, + messages: &[ChatMessage], + tools: &Option>, + sampling: Option<&SamplingOverrides>, + ) -> Result>, InferenceError> { + let url = format!("{}/chat/completions", self.current_model.base_url); + let model_name = self + .current_model + .model_name + .clone() + .unwrap_or_else(|| self.current_model_key.clone()); + + let temperature = sampling + .and_then(|s| s.temperature) + .unwrap_or(self.current_model.temperature); + let top_p = sampling.and_then(|s| s.top_p); + + // Enable JSON response format when the model config opts in AND + // tools are present. This sends `response_format: {"type":"json_object"}` + // which triggers Ollama's GBNF grammar enforcement for valid JSON output. + let response_format = if self.current_model.force_json_response && tools.is_some() { + Some(super::types::ResponseFormat { + r#type: "json_object".to_string(), + }) + } else { + None + }; + + let body = ChatCompletionRequest { + model: model_name, + messages: messages.to_vec(), + tools: tools.clone(), + tool_choice: tools.as_ref().map(|_| "auto".to_string()), + temperature, + top_p, + max_tokens: self.current_model.max_tokens, + stream: true, + response_format, + }; + + // Log the request metadata (not the full body — it can be huge) + tracing::info!( + url = %url, + model = %body.model, + message_count = body.messages.len(), + has_tools = body.tools.is_some(), + tool_count = body.tools.as_ref().map(|t| t.len()).unwrap_or(0), + max_tokens = body.max_tokens, + stream = body.stream, + "=== LLM REQUEST ===" + ); + + let response = self + .http_stream + .post(&url) + .json(&body) + .header("Accept", "text/event-stream") + .send() + .await + .map_err(|e| { + if e.is_connect() { + InferenceError::ConnectionFailed { + endpoint: url.clone(), + reason: e.to_string(), + } + } else if e.is_timeout() { + InferenceError::Timeout { duration_secs: 5 } + } else { + InferenceError::ConnectionFailed { + endpoint: url.clone(), + reason: e.to_string(), + } + } + 
})?; + + let status = response.status(); + if !status.is_success() { + let body_text = response.text().await.unwrap_or_default(); + return Err(InferenceError::HttpError { + status: status.as_u16(), + body: body_text, + }); + } + + Ok(parse_sse_stream(response, self.current_model.tool_call_format)) + } + + // ─── Tool Call Repair ────────────────────────────────────────────────── + + /// Attempt to repair a malformed tool call from an Ollama HTTP 500 error. + /// + /// Extracts the raw JSON from the error body, applies repair heuristics, + /// and builds a synthetic `StreamChunk` with the repaired tool call. + /// Returns `None` if the error body doesn't match or repair fails. + fn try_repair_from_error(err: &InferenceError) -> Option { + let body = err.error_body()?; + let (_tool_name, raw_args) = extract_tool_call_from_error(body)?; + let repaired_args = repair_malformed_tool_call_json(&raw_args)?; + + // Build a synthetic tool call. The tool name is empty because + // Ollama's error body doesn't include it — the agent loop resolves + // the name from the conversation context (the model declared intent + // before Ollama attempted to parse the arguments). + let tool_call = ToolCall { + id: format!("call_{}", Uuid::new_v4()), + name: _tool_name, + arguments: repaired_args, + }; + + Some(StreamChunk { + token: None, + tool_calls: Some(vec![tool_call]), + finish_reason: Some("tool_calls".to_string()), + }) + } + + // ─── Chat Completion (non-streaming) ───────────────────────────────── + + /// Send a non-streaming chat completion request. + /// + /// Returns a single `StreamChunk` with the complete response. 
+ pub async fn chat_completion( + &mut self, + messages: Vec, + tools: Option>, + sampling: Option, + ) -> Result { + let url = format!("{}/chat/completions", self.current_model.base_url); + let model_name = self + .current_model + .model_name + .clone() + .unwrap_or_else(|| self.current_model_key.clone()); + + let temperature = sampling + .as_ref() + .and_then(|s| s.temperature) + .unwrap_or(self.current_model.temperature); + let top_p = sampling.as_ref().and_then(|s| s.top_p); + + let response_format = if self.current_model.force_json_response && tools.is_some() { + Some(super::types::ResponseFormat { + r#type: "json_object".to_string(), + }) + } else { + None + }; + + let body = ChatCompletionRequest { + model: model_name, + messages, + tools: tools.clone(), + tool_choice: tools.as_ref().map(|_| "auto".to_string()), + temperature, + top_p, + max_tokens: self.current_model.max_tokens, + stream: false, + response_format, + }; + + let response = self + .http + .post(&url) + .json(&body) + .send() + .await + .map_err(|e| InferenceError::ConnectionFailed { + endpoint: url.clone(), + reason: e.to_string(), + })?; + + let status = response.status(); + if !status.is_success() { + let body_text = response.text().await.unwrap_or_default(); + return Err(InferenceError::HttpError { + status: status.as_u16(), + body: body_text, + }); + } + + let body_text = response.text().await.map_err(|e| InferenceError::StreamError { + reason: format!("failed to read response body: {e}"), + })?; + + parse_non_streaming_response(&body_text, self.current_model.tool_call_format) + } + + // ─── Health Check ──────────────────────────────────────────────────── + + /// Check if the current model endpoint is reachable. + /// + /// Sends a lightweight request to verify connectivity. Does not consume + /// inference tokens. 
+ pub async fn health_check(&self) -> Result { + let url = format!("{}/models", self.current_model.base_url); + + match self.http.get(&url).timeout(CONNECT_TIMEOUT).send().await { + Ok(resp) => Ok(resp.status().is_success()), + Err(e) => { + if e.is_connect() || e.is_timeout() { + // Provide helpful hint for LM Studio users + if self.current_model.base_url.contains("1234") { + tracing::debug!("LM Studio may not be running - connection to port 1234 failed"); + } + } + Ok(false) + } + } + } + + /// Get detailed model status including endpoint info. + pub async fn get_status(&self) -> super::types::ModelStatus { + let url = format!("{}/models", self.current_model.base_url); + match self.http.get(&url).timeout(CONNECT_TIMEOUT).send().await { + Ok(resp) if resp.status().is_success() => super::types::ModelStatus { + key: self.current_model_key.clone(), + display_name: self.current_model.display_name.clone(), + base_url: self.current_model.base_url.clone(), + healthy: true, + model_name: self.current_model.model_name.clone().or_else(|| Some(self.current_model_key.clone())), + error: None, + }, + Ok(resp) => super::types::ModelStatus { + key: self.current_model_key.clone(), + display_name: self.current_model.display_name.clone(), + base_url: self.current_model.base_url.clone(), + healthy: false, + model_name: None, + error: Some(format!("HTTP {}", resp.status())), + }, + Err(e) => super::types::ModelStatus { + key: self.current_model_key.clone(), + display_name: self.current_model.display_name.clone(), + base_url: self.current_model.base_url.clone(), + healthy: false, + model_name: None, + error: Some(e.to_string()), + }, + } + } + + // ─── Fallback Chain ─────────────────────────────────────────────────────── + + /// Move to the next model in the fallback chain. + /// + /// Returns `Err` if no more fallbacks are available. 
+ pub fn try_next_fallback(&mut self) -> Result<(), InferenceError> { + self.exhausted_models.push(self.current_model_key.clone()); + + for key in &self.config.fallback_chain { + if self.exhausted_models.contains(key) || key == "static_response" { + continue; + } + if let Some(model) = self.config.models.get(key) { + self.current_model_key = key.clone(); + self.current_model = model.clone(); + return Ok(()); + } + } + + Err(InferenceError::AllModelsUnavailable { + attempted: self.exhausted_models.clone(), + }) + } + + /// Number of remaining fallback models. + fn remaining_fallbacks(&self) -> usize { + self.config + .fallback_chain + .iter() + .filter(|k| !self.exhausted_models.contains(k) && k.as_str() != "static_response") + .count() + } + + /// Whether an error should trigger a fallback attempt. + /// + /// HTTP 404 is included because Ollama returns 404 when a model isn't + /// pulled/installed — the next model in the chain may still be available. + /// + /// HTTP 500 is included because local model servers (Ollama, llama.cpp) + /// return 500 when the model generates malformed JSON in tool call + /// arguments — this is a transient model error, not a permanent server + /// failure. Retrying (or falling back) is the correct behavior. + fn is_retriable(err: &InferenceError) -> bool { + matches!( + err, + InferenceError::ConnectionFailed { .. } + | InferenceError::Timeout { .. } + | InferenceError::HttpError { status: 404, .. } + | InferenceError::HttpError { status: 500, .. } + | InferenceError::HttpError { status: 502..=504, .. } + ) + } +} + +// ─── Static Response Fallback ──────────────────────────────────────────────── + +/// Generate the static response used when all models are unavailable. +pub fn static_fallback_response() -> StreamChunk { + StreamChunk { + token: Some( + "The model server is not running. 
\ + Start it with: ./scripts/start-model.sh\n\n\ + If using Ollama instead, run: ollama serve" + .to_string(), + ), + tool_calls: None, + finish_reason: Some("stop".to_string()), + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + fn test_config() -> ModelsConfig { + let mut models = HashMap::new(); + models.insert( + "model-a".to_string(), + ModelConfig { + display_name: "Model A".to_string(), + runtime: "ollama".to_string(), + model_name: Some("model-a:latest".to_string()), + model_path: None, + base_url: "http://localhost:11111/v1".to_string(), + context_window: 4096, + tool_call_format: ToolCallFormat::NativeJson, + temperature: 0.7, + max_tokens: 1024, + estimated_vram_gb: None, + capabilities: vec!["text".to_string()], + force_json_response: false, + role: None, + }, + ); + models.insert( + "model-b".to_string(), + ModelConfig { + display_name: "Model B".to_string(), + runtime: "ollama".to_string(), + model_name: Some("model-b:latest".to_string()), + model_path: None, + base_url: "http://localhost:22222/v1".to_string(), + context_window: 8192, + tool_call_format: ToolCallFormat::Pythonic, + temperature: 0.5, + max_tokens: 2048, + estimated_vram_gb: None, + capabilities: vec!["text".to_string()], + force_json_response: false, + role: None, + }, + ); + models.insert( + "lmstudio-model".to_string(), + ModelConfig { + display_name: "LM Studio Model".to_string(), + runtime: "lmstudio".to_string(), + model_name: Some("lmstudio/default".to_string()), + model_path: None, + base_url: "http://localhost:1234/v1".to_string(), + context_window: 32768, + tool_call_format: ToolCallFormat::NativeJson, + temperature: 0.7, + max_tokens: 4096, + estimated_vram_gb: Some(8.0), + capabilities: vec!["text".to_string(), "tool_calling".to_string()], + force_json_response: false, + role: None, + }, + ); + + ModelsConfig { + active_model: "model-a".to_string(), + 
models_dir: None, + models, + fallback_chain: vec![ + "model-a".to_string(), + "model-b".to_string(), + "static_response".to_string(), + ], + orchestrator: None, + two_pass_tool_selection: None, + enabled_servers: None, + enabled_tools: None, + } + } + + #[test] + fn test_from_config_selects_active_model() { + let client = InferenceClient::from_config(test_config()).unwrap(); + assert_eq!(client.current_model_key, "model-a"); + assert_eq!(client.current_model_name(), "Model A"); + } + + #[test] + fn test_fallback_chain() { + let mut client = InferenceClient::from_config(test_config()).unwrap(); + assert_eq!(client.current_model_key, "model-a"); + + // Fallback to model-b + client.try_next_fallback().unwrap(); + assert_eq!(client.current_model_key, "model-b"); + assert_eq!(client.tool_call_format(), ToolCallFormat::Pythonic); + + // No more fallbacks + let result = client.try_next_fallback(); + assert!(result.is_err()); + } + + #[test] + fn test_lmstudio_model_config() { + let config = test_config(); + // Create client targeting LM Studio model directly + let client = InferenceClient::from_config_with_model(config, "lmstudio-model").unwrap(); + assert_eq!(client.current_model_key, "lmstudio-model"); + assert_eq!(client.current_model_name(), "LM Studio Model"); + assert_eq!(client.current_base_url(), "http://localhost:1234/v1"); + } + + #[test] + fn test_remaining_fallbacks() { + let client = InferenceClient::from_config(test_config()).unwrap(); + // model-a (current, in chain) + model-b = 2 remaining + assert_eq!(client.remaining_fallbacks(), 2); + } + + #[test] + fn test_is_retriable() { + assert!(InferenceClient::is_retriable( + &InferenceError::ConnectionFailed { + endpoint: "".into(), + reason: "".into() + } + )); + assert!(InferenceClient::is_retriable(&InferenceError::Timeout { + duration_secs: 5 + })); + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 404, + body: "model not found".into() + })); + 
assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 500, + body: "malformed JSON".into() + })); + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 503, + body: "".into() + })); + assert!(!InferenceClient::is_retriable( + &InferenceError::HttpError { + status: 400, + body: "".into() + } + )); + assert!(!InferenceClient::is_retriable( + &InferenceError::ToolCallParseError { + raw_response: "".into(), + reason: "".into() + } + )); + } + + #[test] + fn test_is_retriable_connection_failed() { + assert!(InferenceClient::is_retriable( + &InferenceError::ConnectionFailed { + endpoint: "localhost".into(), + reason: "connection refused".into() + } + )); + } + + #[test] + fn test_is_retriable_timeout() { + assert!(InferenceClient::is_retriable(&InferenceError::Timeout { duration_secs: 5 })); + } + + #[test] + fn test_is_retriable_404() { + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 404, + body: "not found".into() + })); + } + + #[test] + fn test_is_retriable_500() { + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 500, + body: "internal error".into() + })); + } + + #[test] + fn test_is_retriable_502() { + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 502, + body: "bad gateway".into() + })); + } + + #[test] + fn test_is_retriable_503() { + assert!(InferenceClient::is_retriable(&InferenceError::HttpError { + status: 503, + body: "service unavailable".into() + })); + } + + #[test] + fn test_is_retriable_400_not_retriable() { + // HTTP 400 should NOT be retriable + assert!(!InferenceClient::is_retriable(&InferenceError::HttpError { + status: 400, + body: "bad request".into() + })); + } + + #[test] + fn test_is_retriable_401_not_retriable() { + // HTTP 401 should NOT be retriable + assert!(!InferenceClient::is_retriable(&InferenceError::HttpError { + status: 401, + body: "unauthorized".into() + })); + } + + #[test] + fn 
test_is_retriable_403_not_retriable() { + // HTTP 403 should NOT be retriable + assert!(!InferenceClient::is_retriable(&InferenceError::HttpError { + status: 403, + body: "forbidden".into() + })); + } + + #[test] + fn test_is_retriable_tool_call_error_not_retriable() { + // Tool call parse error should NOT be retriable (it's a model issue) + assert!(!InferenceClient::is_retriable(&InferenceError::ToolCallParseError { + raw_response: "invalid".into(), + reason: "bad json".into() + })); + } + + #[test] + fn test_static_fallback_response() { + let chunk = static_fallback_response(); + assert!(chunk.token.is_some()); + assert!(chunk.tool_calls.is_none()); + assert_eq!(chunk.finish_reason.as_deref(), Some("stop")); + } + + #[test] + fn test_try_repair_from_error_success() { + // Simulate the exact Ollama HTTP 500 error with malformed JSON + let err = InferenceError::HttpError { + status: 500, + body: r#"{"error":{"message":"error parsing tool call: raw='{\"create_dirs\":true,\"destination\":\"\"/Users/chintan/Desktop/file.png\",\"source\":\"/tmp/file.png\"}', err=invalid character '/' after object key:value pair"}}"#.to_string(), + }; + + let result = InferenceClient::try_repair_from_error(&err); + assert!(result.is_some(), "should repair the malformed JSON"); + + let chunk = result.unwrap(); + assert!(chunk.token.is_none()); + assert!(chunk.tool_calls.is_some()); + assert_eq!(chunk.finish_reason.as_deref(), Some("tool_calls")); + + let calls = chunk.tool_calls.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].arguments["destination"], "/Users/chintan/Desktop/file.png"); + assert_eq!(calls[0].arguments["create_dirs"], true); + } + + #[test] + fn test_try_repair_from_error_non_tool_call_error() { + // A regular HTTP 500 that isn't a tool call parse error + let err = InferenceError::HttpError { + status: 500, + body: "internal server error".to_string(), + }; + assert!(InferenceClient::try_repair_from_error(&err).is_none()); + } + + #[test] + fn 
test_try_repair_from_error_non_http_error() { + let err = InferenceError::Timeout { duration_secs: 30 }; + assert!(InferenceClient::try_repair_from_error(&err).is_none()); + } + + #[test] + fn test_lmstudio_base_url_construction() { + // Test LM Studio URL construction with different ports + let client = InferenceClient::from_config_with_model(test_config(), "lmstudio-model").unwrap(); + assert_eq!(client.current_base_url(), "http://localhost:1234/v1"); + } + + #[test] + fn test_fallback_chain_exhausted_error() { + let mut client = InferenceClient::from_config(test_config()).unwrap(); + + // Exhaust all fallbacks + client.try_next_fallback().unwrap(); // model-b + let result = client.try_next_fallback(); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, InferenceError::AllModelsUnavailable { .. })); + } + + #[test] + fn test_current_model_name_for_lmstudio() { + let client = InferenceClient::from_config_with_model(test_config(), "lmstudio-model").unwrap(); + assert_eq!(client.current_model_name(), "LM Studio Model"); + } + + #[test] + fn test_current_model_name_for_ollama() { + let client = InferenceClient::from_config(test_config()).unwrap(); + assert_eq!(client.current_model_name(), "Model A"); + } + + #[test] + fn test_tool_call_format_json() { + // Test that LM Studio uses NativeJson format + let config = test_config(); + let client = InferenceClient::from_config_with_model(config, "lmstudio-model").unwrap(); + assert_eq!(client.tool_call_format(), ToolCallFormat::NativeJson); + } + + #[test] + fn test_tool_call_format_pythonic() { + // Test that model-b uses Pythonic format + let config = test_config(); + let mut client = InferenceClient::from_config(config).unwrap(); + client.try_next_fallback().unwrap(); // model-b + assert_eq!(client.tool_call_format(), ToolCallFormat::Pythonic); + } +} diff --git a/src-tauri/src/inference/config.rs b/src-tauri/src/inference/config.rs new file mode 100644 index 0000000..de5a12f --- 
/dev/null +++ b/src-tauri/src/inference/config.rs @@ -0,0 +1,343 @@ +//! Model configuration loading and validation. +//! +//! Reads `_models/config.yaml` and resolves environment variables. +//! Config is the single source of truth for model endpoints, formats, and +//! fallback chains. + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +use serde::Deserialize; + +use super::errors::InferenceError; + +// ─── Public Types ──────────────────────────────────────────────────────────── + +/// Which tool-call format the model emits. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolCallFormat { + /// Standard OpenAI JSON tool calls (Qwen, GPT, etc.). + NativeJson, + /// Pythonic `Tool: … Arguments: …` format (LFM2.5). + Pythonic, + /// Bracket format: `[server.tool(args)]` or `<|tool_call_start|>…<|tool_call_end|>` (LFM2-24B-A2B). + Bracket, +} + +/// A single model's runtime configuration. +#[derive(Debug, Clone, Deserialize)] +pub struct ModelConfig { + pub display_name: String, + pub runtime: String, + #[serde(default)] + pub model_name: Option, + #[serde(default)] + pub model_path: Option, + pub base_url: String, + pub context_window: u32, + pub tool_call_format: ToolCallFormat, + pub temperature: f32, + pub max_tokens: u32, + pub estimated_vram_gb: Option, + #[serde(default)] + pub capabilities: Vec, + /// When `true`, sends `response_format: {"type":"json_object"}` on tool-calling + /// turns. This triggers Ollama's GBNF grammar enforcement for valid JSON output. + /// Disabled by default — enable after live testing with the target model. + #[serde(default)] + pub force_json_response: bool, + /// Optional role hint (e.g., "tool_router"). Used by the orchestrator to + /// identify models by purpose rather than name. + #[serde(default)] + pub role: Option, +} + +/// Dual-model orchestrator configuration (ADR-009). 
+/// +/// When enabled, GPT-OSS-20B plans multi-step workflows and LFM2-1.2B-Tool +/// executes each step with a RAG pre-filtered tool set. +#[derive(Debug, Clone, Deserialize)] +pub struct OrchestratorConfig { + /// Whether the dual-model orchestrator is active. + #[serde(default)] + pub enabled: bool, + /// Model key for the planner (e.g., "gpt-oss-20b"). + #[serde(default)] + pub planner_model: String, + /// Model key for the tool router (e.g., "lfm2-1.2b-tool"). + #[serde(default)] + pub router_model: String, + /// Top-K tools for RAG pre-filter per step (default: 15). + #[serde(default = "default_router_top_k")] + pub router_top_k: u32, + /// Maximum number of steps the planner can produce (default: 10). + #[serde(default = "default_max_plan_steps")] + pub max_plan_steps: u32, + /// Maximum retries per step if the router fails to produce a tool call. + #[serde(default = "default_step_retries")] + pub step_retries: u32, +} + +fn default_router_top_k() -> u32 { + 15 +} +fn default_max_plan_steps() -> u32 { + 10 +} +fn default_step_retries() -> u32 { + 3 +} + +/// Top-level model registry (mirrors `_models/config.yaml`). +#[derive(Debug, Clone, Deserialize)] +pub struct ModelsConfig { + pub active_model: String, + #[serde(default)] + pub models_dir: Option, + pub models: HashMap, + #[serde(default)] + pub fallback_chain: Vec, + /// Dual-model orchestrator settings (ADR-009). When absent, orchestration + /// is disabled and the single-model agent loop runs as before. + #[serde(default)] + pub orchestrator: Option, + /// Enable two-pass category-based tool selection (Tier 1.5). + /// + /// When `true` and >20 MCP tools are registered, the first agent turn + /// sends ~15 category meta-tools (~1,500 tokens) instead of all tools + /// (~8,670 tokens). The model selects 2-3 categories, then subsequent + /// turns use only those categories' real tools. + /// + /// Default: `false` (flat mode — all tools every turn). 
+ #[serde(default)] + pub two_pass_tool_selection: Option, + /// Optional allowlist of MCP server names to start. + /// + /// When set, only servers whose names appear in this list are started. + /// All others are skipped during discovery. This reduces the tool count + /// sent to the model, improving accuracy and reducing token usage. + /// + /// Example: `["security", "audit", "document", "ocr", "email", "system", "clipboard", "filesystem"]` + /// + /// Default: `None` (all discovered servers are started). + #[serde(default)] + pub enabled_servers: Option>, + /// Optional allowlist of fully-qualified tool names to expose to the model. + /// + /// When set, only tools whose names appear in this list are kept in the + /// registry after server startup. All other tools are removed. This allows + /// curating a tight, high-accuracy tool surface from servers that each + /// expose more tools than needed for a specific demo or deployment. + /// + /// Tool names are fully-qualified: `"server.tool"` (e.g., `"filesystem.list_dir"`). + /// + /// Default: `None` (all tools from started servers are exposed). + #[serde(default)] + pub enabled_tools: Option>, +} + +// ─── Loading ───────────────────────────────────────────────────────────────── + +/// Resolve a config path relative to the project root. +/// +/// Searches upward from `start` for `_models/config.yaml`. Falls back to +/// `LOCALCOWORK_PROJECT_ROOT` env var if set. +pub fn find_config_path(start: &Path) -> Result { + // 1. Check env var + if let Ok(root) = std::env::var("LOCALCOWORK_PROJECT_ROOT") { + let candidate = PathBuf::from(&root).join("_models/config.yaml"); + if candidate.exists() { + return Ok(candidate); + } + } + + // 2. 
Walk upward from `start` + let mut dir = start.to_path_buf(); + loop { + let candidate = dir.join("_models/config.yaml"); + if candidate.exists() { + return Ok(candidate); + } + if !dir.pop() { + break; + } + } + + Err(InferenceError::ConfigError { + reason: "could not find _models/config.yaml".into(), + }) +} + +/// Load and parse the models configuration file. +/// +/// Performs environment-variable interpolation on string values matching +/// `${VAR_NAME}` or `${VAR_NAME:-default}`. +pub fn load_models_config(path: &Path) -> Result { + let raw = std::fs::read_to_string(path).map_err(|e| InferenceError::ConfigError { + reason: format!("failed to read {}: {e}", path.display()), + })?; + + let interpolated = interpolate_env_vars(&raw); + + let config: ModelsConfig = + serde_yaml::from_str(&interpolated).map_err(|e| InferenceError::ConfigError { + reason: format!("failed to parse config: {e}"), + })?; + + Ok(config) +} + +/// Resolve the active model configuration, respecting the fallback chain. +/// +/// Returns `(model_key, ModelConfig)` for the first available model. +/// "Available" here means it exists in the config — actual connectivity is +/// checked at runtime by the client. 
+pub fn resolve_active_model(config: &ModelsConfig) -> Result<(String, ModelConfig), InferenceError> { + // Try the explicitly active model first + if let Some(model) = config.models.get(&config.active_model) { + return Ok((config.active_model.clone(), model.clone())); + } + + // Walk the fallback chain + for key in &config.fallback_chain { + if key == "static_response" { + continue; // handled by the client as a special case + } + if let Some(model) = config.models.get(key) { + return Ok((key.clone(), model.clone())); + } + } + + Err(InferenceError::ConfigError { + reason: format!( + "active model '{}' not found in config and no fallback available", + config.active_model + ), + }) +} + +// ─── Env-var interpolation ─────────────────────────────────────────────────── + +/// Replace `${VAR}` and `${VAR:-default}` in a string. +fn interpolate_env_vars(input: &str) -> String { + let mut result = String::with_capacity(input.len()); + let mut chars = input.chars().peekable(); + + while let Some(ch) = chars.next() { + if ch == '$' && chars.peek() == Some(&'{') { + chars.next(); // consume '{' + let mut var_expr = String::new(); + for c in chars.by_ref() { + if c == '}' { + break; + } + var_expr.push(c); + } + let resolved = resolve_var_expr(&var_expr); + result.push_str(&resolved); + } else { + result.push(ch); + } + } + + result +} + +/// Resolve a variable expression like `VAR` or `VAR:-default`. +fn resolve_var_expr(expr: &str) -> String { + if let Some(idx) = expr.find(":-") { + let var_name = &expr[..idx]; + let default = &expr[idx + 2..]; + std::env::var(var_name).unwrap_or_else(|_| expand_tilde(default)) + } else { + std::env::var(expr).unwrap_or_default() + } +} + +/// Expand a leading `~` to the user's home directory. +/// +/// Uses `dirs::home_dir()` for cross-platform support (works on macOS, +/// Linux, and Windows where `$HOME` may not be set). 
+fn expand_tilde(path: &str) -> String { + if let Some(rest) = path.strip_prefix('~') { + if let Some(home) = dirs::home_dir() { + return format!("{}{rest}", home.display()); + } + } + path.to_string() +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_interpolate_env_vars_with_default() { + // When env var is NOT set, use default + std::env::remove_var("__TEST_NONEXISTENT_VAR__"); + let input = "${__TEST_NONEXISTENT_VAR__:-/fallback/path}"; + let result = interpolate_env_vars(input); + assert_eq!(result, "/fallback/path"); + } + + #[test] + fn test_interpolate_env_vars_with_value() { + std::env::set_var("__TEST_INFERENCE_VAR__", "/custom/path"); + let input = "${__TEST_INFERENCE_VAR__:-/fallback/path}"; + let result = interpolate_env_vars(input); + assert_eq!(result, "/custom/path"); + std::env::remove_var("__TEST_INFERENCE_VAR__"); + } + + #[test] + fn test_interpolate_no_vars() { + let input = "plain text with no variables"; + assert_eq!(interpolate_env_vars(input), input); + } + + #[test] + fn test_expand_tilde() { + let result = expand_tilde("~/Documents"); + assert!(!result.starts_with('~'), "tilde should be expanded"); + assert!(result.ends_with("/Documents")); + } + + #[test] + fn test_resolve_active_model_not_found() { + let config = ModelsConfig { + active_model: "nonexistent".into(), + models_dir: None, + models: HashMap::new(), + fallback_chain: vec![], + orchestrator: None, + two_pass_tool_selection: None, + enabled_servers: None, + enabled_tools: None, + }; + let result = resolve_active_model(&config); + assert!(result.is_err()); + } + + #[test] + fn test_force_json_response_default_is_false() { + // Config YAML without force_json_response should default to false + let yaml = r#" + active_model: test + models: + test: + display_name: "Test Model" + runtime: ollama + base_url: "http://localhost:11434/v1" + context_window: 4096 + tool_call_format: 
native_json + temperature: 0.7 + max_tokens: 1024 + "#; + let config: ModelsConfig = serde_yaml::from_str(yaml).unwrap(); + let model = config.models.get("test").unwrap(); + assert!(!model.force_json_response, "force_json_response should default to false"); + } +} diff --git a/src-tauri/src/inference/errors.rs b/src-tauri/src/inference/errors.rs new file mode 100644 index 0000000..a20055f --- /dev/null +++ b/src-tauri/src/inference/errors.rs @@ -0,0 +1,133 @@ +//! Inference error types. +//! +//! All errors implement `std::error::Error` via `thiserror`. Structured logging +//! is the caller's responsibility — these types carry the context needed to build +//! meaningful log entries. + +use thiserror::Error; + +/// Errors that can occur during inference operations. +#[derive(Debug, Error)] +pub enum InferenceError { + /// TCP/HTTP connection to the model endpoint failed. + #[error("connection failed to {endpoint}: {reason}")] + ConnectionFailed { + endpoint: String, + reason: String, + }, + + /// The model endpoint did not respond within the configured timeout. + #[error("inference timeout after {duration_secs}s")] + Timeout { + duration_secs: u64, + }, + + /// Failed to parse a tool call from the model's response. + #[error("tool call parse error: {reason}")] + ToolCallParseError { + raw_response: String, + reason: String, + }, + + /// The model returned a tool name that is not in the registry. + #[error("unknown tool: {name}")] + UnknownTool { + name: String, + }, + + /// Every model in the fallback chain was unavailable. + #[error("all models unavailable (tried: {})", attempted.join(", "))] + AllModelsUnavailable { + attempted: Vec, + }, + + /// Non-2xx HTTP response from the model endpoint. + #[error("HTTP {status}: {body}")] + HttpError { + status: u16, + body: String, + }, + + /// SSE stream parsing or chunk-level error. + #[error("stream error: {reason}")] + StreamError { + reason: String, + }, + + /// Configuration loading or validation error. 
+ #[error("config error: {reason}")] + ConfigError { + reason: String, + }, +} + +impl InferenceError { + /// Check if this error is an Ollama tool call parse failure (HTTP 500). + /// + /// Ollama returns HTTP 500 with `"error parsing tool call"` when the model + /// generates malformed JSON in tool call arguments. These errors are + /// candidates for client-side JSON repair. + pub fn is_tool_call_parse_error(&self) -> bool { + matches!( + self, + InferenceError::HttpError { status: 500, body } + if body.contains("error parsing tool call") + ) + } + + /// Extract the error body text, if this is an `HttpError`. + pub fn error_body(&self) -> Option<&str> { + match self { + InferenceError::HttpError { body, .. } => Some(body), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_tool_call_parse_error_true() { + let err = InferenceError::HttpError { + status: 500, + body: r#"{"error":{"message":"error parsing tool call: raw='{...}', err=invalid"}}"# + .to_string(), + }; + assert!(err.is_tool_call_parse_error()); + } + + #[test] + fn test_is_tool_call_parse_error_false_different_body() { + let err = InferenceError::HttpError { + status: 500, + body: "internal server error".to_string(), + }; + assert!(!err.is_tool_call_parse_error()); + } + + #[test] + fn test_is_tool_call_parse_error_false_different_status() { + let err = InferenceError::HttpError { + status: 404, + body: "error parsing tool call".to_string(), + }; + assert!(!err.is_tool_call_parse_error()); + } + + #[test] + fn test_error_body_http_error() { + let err = InferenceError::HttpError { + status: 500, + body: "test body".to_string(), + }; + assert_eq!(err.error_body(), Some("test body")); + } + + #[test] + fn test_error_body_non_http() { + let err = InferenceError::Timeout { duration_secs: 5 }; + assert!(err.error_body().is_none()); + } +} diff --git a/src-tauri/src/inference/mod.rs b/src-tauri/src/inference/mod.rs new file mode 100644 index 0000000..d25a374 --- 
/dev/null +++ b/src-tauri/src/inference/mod.rs @@ -0,0 +1,25 @@ +//! Inference Client — OpenAI-compatible API client for local LLM inference. +//! +//! This module handles all communication with the local model endpoint: +//! - Streaming and non-streaming chat completions +//! - Tool call parsing (native JSON + Pythonic formats) +//! - SSE stream parsing +//! - Fallback chain management +//! - Model configuration loading from `_models/config.yaml` +//! +//! The client speaks the OpenAI Chat Completions API, making the model +//! interchangeable via config. Switching from Qwen to LFM2.5 is a config +//! change, not a code change. + +pub mod client; +pub mod config; +pub mod errors; +pub mod streaming; +pub mod tool_call_parser; +pub mod types; + +// Re-exports for convenience +pub use client::InferenceClient; +pub use config::{ModelConfig, ModelsConfig, ToolCallFormat}; +pub use errors::InferenceError; +pub use types::{ChatMessage, Role, SamplingOverrides, StreamChunk, ToolCall, ToolDefinition}; diff --git a/src-tauri/src/inference/streaming.rs b/src-tauri/src/inference/streaming.rs new file mode 100644 index 0000000..86f1147 --- /dev/null +++ b/src-tauri/src/inference/streaming.rs @@ -0,0 +1,567 @@ +//! SSE streaming response parser for OpenAI-compatible chat completions. +//! +//! Reads a `reqwest::Response` as a byte stream, splits on SSE boundaries +//! (`data: …\n\n`), parses each chunk as JSON, and accumulates tool calls +//! across multiple deltas. + +use futures::stream::{self, Stream, StreamExt}; +use serde::Deserialize; +use uuid::Uuid; + +use super::config::ToolCallFormat; +use super::errors::InferenceError; +use super::tool_call_parser::{ + parse_bracket_tool_calls, parse_native_json_tool_call, parse_pythonic_tool_calls, +}; +use super::types::{ChatCompletionChunk, StreamChunk, ToolCall}; + +// ─── SSE line parser ───────────────────────────────────────────────────────── + +/// Parse raw SSE bytes into `StreamChunk`s. 
///
/// This is the main entry point for streaming. It:
/// 1. Splits the HTTP body into SSE events
/// 2. Parses each `data:` line as a `ChatCompletionChunk`
/// 3. Accumulates tool call fragments across deltas
/// 4. Emits complete `StreamChunk`s for each event
pub fn parse_sse_stream(
    response: reqwest::Response,
    tool_call_format: ToolCallFormat,
) -> impl Stream<Item = Result<StreamChunk, InferenceError>> {
    let byte_stream = response.bytes_stream();

    // Buffer for incomplete SSE lines across chunk boundaries
    let state = StreamState::new(tool_call_format);

    stream::unfold(
        (byte_stream, state, String::new()),
        |(mut byte_stream, mut state, mut buffer)| async move {
            loop {
                // Check if we have a complete SSE event in the buffer.
                // NOTE(review): only "\n\n" is treated as an event boundary —
                // an endpoint emitting "\r\n\r\n" would stall until end of
                // stream; confirm all supported runtimes use bare LF.
                if let Some(event_end) = buffer.find("\n\n") {
                    let event = buffer[..event_end].to_string();
                    buffer = buffer[event_end + 2..].to_string();

                    match state.process_event(&event) {
                        Ok(Some(chunk)) => return Some((Ok(chunk), (byte_stream, state, buffer))),
                        Ok(None) => continue, // [DONE] or keep-alive
                        Err(e) => return Some((Err(e), (byte_stream, state, buffer))),
                    }
                }

                // Need more data from the stream
                match byte_stream.next().await {
                    Some(Ok(bytes)) => {
                        // NOTE(review): from_utf8_lossy on an arbitrary read
                        // boundary can mangle a multi-byte character split
                        // across two reads — confirm endpoints flush whole
                        // events per write.
                        let text = String::from_utf8_lossy(&bytes);
                        buffer.push_str(&text);
                    }
                    Some(Err(e)) => {
                        return Some((
                            Err(InferenceError::StreamError {
                                reason: format!("stream read error: {e}"),
                            }),
                            (byte_stream, state, buffer),
                        ));
                    }
                    None => {
                        // Stream ended — check for any remaining buffer content
                        if !buffer.trim().is_empty() {
                            match state.process_event(buffer.trim()) {
                                Ok(Some(chunk)) => {
                                    buffer.clear();
                                    return Some((Ok(chunk), (byte_stream, state, buffer)));
                                }
                                Ok(None) => return None,
                                Err(e) => return Some((Err(e), (byte_stream, state, buffer))),
                            }
                        }
                        return None;
                    }
                }
            }
        },
    )
}

// ─── Stream State ──────────────────────────────────────────────────────────

/// Mutable state for accumulating tool call fragments across SSE events.
struct StreamState {
    tool_call_format: ToolCallFormat,
    /// Accumulated content for Pythonic format parsing.
    accumulated_content: String,
    /// In-progress tool calls (native_json): `(index, id, name, arguments_buffer)`.
    pending_tool_calls: Vec<(u32, Option<String>, String, String)>,
}

impl StreamState {
    fn new(tool_call_format: ToolCallFormat) -> Self {
        Self {
            tool_call_format,
            accumulated_content: String::new(),
            pending_tool_calls: Vec::new(),
        }
    }

    /// Process a single SSE event string (may contain multiple `data:` lines).
    ///
    /// Returns `Ok(None)` for keep-alives/comments, `Ok(Some(_))` when the
    /// event produced output, and `StreamError` when the payload is not JSON.
    fn process_event(&mut self, event: &str) -> Result<Option<StreamChunk>, InferenceError> {
        let mut data_content = String::new();

        for line in event.lines() {
            if let Some(data) = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:")) {
                let data = data.trim();
                if data == "[DONE]" {
                    // Stream complete — finalize any pending pythonic calls.
                    // NOTE(review): returning here discards `data:` lines that
                    // preceded [DONE] within the SAME event — confirm servers
                    // never coalesce payload and terminator into one event.
                    return self.finalize();
                }
                // NOTE(review): multiple `data:` lines are concatenated with
                // no separator; the SSE spec joins them with "\n". Harmless
                // for the single-line JSON chunks these endpoints emit.
                data_content.push_str(data);
            }
            // Ignore non-data lines (comments, event types, etc.)
        }

        if data_content.is_empty() {
            return Ok(None); // Keep-alive or comment
        }

        let chunk: ChatCompletionChunk =
            serde_json::from_str(&data_content).map_err(|e| InferenceError::StreamError {
                reason: format!("failed to parse SSE chunk: {e} (data: {data_content})"),
            })?;

        self.process_chunk(chunk)
    }

    /// Process a parsed `ChatCompletionChunk`.
    ///
    /// Only the first choice is consumed (requests are made with n=1).
    fn process_chunk(
        &mut self,
        chunk: ChatCompletionChunk,
    ) -> Result<Option<StreamChunk>, InferenceError> {
        let choice = match chunk.choices.first() {
            Some(c) => c,
            None => return Ok(None),
        };

        let mut result = StreamChunk {
            token: None,
            tool_calls: None,
            finish_reason: choice.finish_reason.clone(),
        };

        // Handle text content — use only `content`, ignore `reasoning`.
        // Reasoning/thinking models (Qwen3, GPT-OSS) stream chain-of-thought
        // in `reasoning` and the actual answer in `content`. We only surface
        // `content` to the user; reasoning tokens are silently discarded.
        if let Some(ref content) = choice.delta.content {
            if !content.is_empty() {
                result.token = Some(content.clone());
                self.accumulated_content.push_str(content);
            }
        }

        // Handle native tool call deltas
        if let Some(ref tool_calls) = choice.delta.tool_calls {
            for tc in tool_calls {
                // Deltas without an index are assumed to belong to call 0.
                let index = tc.index.unwrap_or(0);

                // Find or create the pending tool call for this index
                let pending = self
                    .pending_tool_calls
                    .iter_mut()
                    .find(|(idx, _, _, _)| *idx == index);

                match pending {
                    Some((_, ref mut id, ref mut name, ref mut args)) => {
                        // Append to existing — name/argument fragments arrive
                        // split across multiple deltas.
                        if let Some(ref f) = tc.function {
                            if let Some(ref n) = f.name {
                                name.push_str(n);
                            }
                            if let Some(ref a) = f.arguments {
                                args.push_str(a);
                            }
                        }
                        if tc.id.is_some() {
                            *id = tc.id.clone();
                        }
                    }
                    None => {
                        // New tool call
                        let name = tc
                            .function
                            .as_ref()
                            .and_then(|f| f.name.clone())
                            .unwrap_or_default();
                        let args = tc
                            .function
                            .as_ref()
                            .and_then(|f| f.arguments.clone())
                            .unwrap_or_default();
                        self.pending_tool_calls
                            .push((index, tc.id.clone(), name, args));
                    }
                }
            }
        }

        // If finish_reason is "tool_calls" (native) or "stop" (might have text-based calls),
        // finalize the tool calls
        if let Some(ref reason) = result.finish_reason {
            if reason == "tool_calls" {
                result.tool_calls = Some(self.finalize_native_tool_calls()?);
            } else if reason == "stop" {
                // Check accumulated content for text-based tool call formats
                match self.tool_call_format {
                    ToolCallFormat::Pythonic => {
                        let calls = parse_pythonic_tool_calls(&self.accumulated_content)?;
                        if !calls.is_empty() {
                            result.tool_calls = Some(calls);
                            result.finish_reason = Some("tool_calls".into());
                        }
                    }
                    ToolCallFormat::Bracket => {
                        let calls = parse_bracket_tool_calls(&self.accumulated_content)?;
                        if !calls.is_empty() {
                            result.tool_calls = Some(calls);
                            result.finish_reason = Some("tool_calls".into());
                        }
                    }
                    ToolCallFormat::NativeJson => {} // Native handles via structured deltas
                }
            }
        }

        Ok(Some(result))
    }

    /// Finalize accumulated native JSON tool calls.
    ///
    /// Drains `pending_tool_calls`; each entry's concatenated argument buffer
    /// must now be complete, valid JSON.
    fn finalize_native_tool_calls(&mut self) -> Result<Vec<ToolCall>, InferenceError> {
        let pending = std::mem::take(&mut self.pending_tool_calls);
        let mut calls = Vec::with_capacity(pending.len());

        for (_index, id, name, args) in pending {
            calls.push(parse_native_json_tool_call(id.as_deref(), &name, &args)?);
        }

        Ok(calls)
    }

    /// Finalize the stream — emit any remaining tool calls.
    fn finalize(&mut self) -> Result<Option<StreamChunk>, InferenceError> {
        // Check for pending native tool calls
        if !self.pending_tool_calls.is_empty() {
            let calls = self.finalize_native_tool_calls()?;
            return Ok(Some(StreamChunk {
                token: None,
                tool_calls: Some(calls),
                finish_reason: Some("tool_calls".into()),
            }));
        }

        // Check for text-based tool calls in accumulated content
        if !self.accumulated_content.is_empty() {
            let calls = match self.tool_call_format {
                ToolCallFormat::Pythonic => parse_pythonic_tool_calls(&self.accumulated_content)?,
                ToolCallFormat::Bracket => parse_bracket_tool_calls(&self.accumulated_content)?,
                ToolCallFormat::NativeJson => Vec::new(),
            };
            if !calls.is_empty() {
                return Ok(Some(StreamChunk {
                    token: None,
                    tool_calls: Some(calls),
                    finish_reason: Some("tool_calls".into()),
                }));
            }
        }

        Ok(None)
    }
}

// ─── Helpers ───────────────────────────────────────────────────────────────

/// Parse a non-streaming response body into tool calls and content.
///
/// Used for fallback when streaming is not supported by the endpoint.
+pub fn parse_non_streaming_response( + body: &str, + format: ToolCallFormat, +) -> Result { + #[derive(Deserialize)] + struct NonStreamResponse { + choices: Vec, + } + + #[derive(Deserialize)] + struct NonStreamChoice { + message: NonStreamMessage, + finish_reason: Option, + } + + #[derive(Deserialize)] + struct NonStreamMessage { + content: Option, + /// Reasoning/thinking content from models like Qwen3, GPT-OSS. + /// Deserialized to prevent serde unknown-field errors, but not used — + /// `content` holds the actual answer. See ADR comment on reasoning models. + #[allow(dead_code)] + reasoning: Option, + tool_calls: Option>, + } + + #[derive(Deserialize)] + struct NonStreamToolCall { + id: Option, + function: NonStreamFunction, + } + + #[derive(Deserialize)] + struct NonStreamFunction { + name: String, + arguments: String, + } + + let resp: NonStreamResponse = + serde_json::from_str(body).map_err(|e| InferenceError::StreamError { + reason: format!("failed to parse non-streaming response: {e}"), + })?; + + let choice = resp.choices.first().ok_or(InferenceError::StreamError { + reason: "empty choices array".into(), + })?; + + // Use `content` only. Reasoning/thinking models (Qwen3, GPT-OSS via Ollama) + // put chain-of-thought in `reasoning` and the actual answer in `content`. + // If `content` is empty (model exhausted max_tokens during reasoning), we + // treat it as an empty response — the caller handles the retry/fallback. 
+ let content = choice.message.content.clone().filter(|c| !c.is_empty()); + + // Check for native tool calls in the response + let mut tool_calls = Vec::new(); + if let Some(ref tcs) = choice.message.tool_calls { + for tc in tcs { + let id = tc + .id + .clone() + .unwrap_or_else(|| format!("call_{}", Uuid::new_v4())); + let args: serde_json::Value = serde_json::from_str(&tc.function.arguments) + .map_err(|e| InferenceError::ToolCallParseError { + raw_response: tc.function.arguments.clone(), + reason: format!("invalid JSON: {e}"), + })?; + tool_calls.push(ToolCall { + id, + name: tc.function.name.clone(), + arguments: args, + }); + } + } + + // Check for text-based tool calls in content (pythonic or bracket format) + if tool_calls.is_empty() { + if let Some(ref text) = content { + let parsed = match format { + ToolCallFormat::Pythonic => parse_pythonic_tool_calls(text)?, + ToolCallFormat::Bracket => parse_bracket_tool_calls(text)?, + ToolCallFormat::NativeJson => Vec::new(), + }; + if !parsed.is_empty() { + tool_calls = parsed; + } + } + } + + let finish_reason = if !tool_calls.is_empty() { + Some("tool_calls".into()) + } else { + choice.finish_reason.clone() + }; + + Ok(StreamChunk { + token: content, + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, + finish_reason, + }) +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_non_streaming_with_content() { + let body = r#"{ + "choices": [{ + "message": {"role": "assistant", "content": "Hello, world!"}, + "finish_reason": "stop" + }] + }"#; + + let chunk = parse_non_streaming_response(body, ToolCallFormat::NativeJson).unwrap(); + assert_eq!(chunk.token.as_deref(), Some("Hello, world!")); + assert!(chunk.tool_calls.is_none()); + assert_eq!(chunk.finish_reason.as_deref(), Some("stop")); + } + + #[test] + fn test_parse_non_streaming_with_tool_calls() { + let body = r#"{ + "choices": 
[{ + "message": { + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": "call_abc", + "type": "function", + "function": { + "name": "filesystem.list_dir", + "arguments": "{\"path\": \"/tmp\"}" + } + }] + }, + "finish_reason": "tool_calls" + }] + }"#; + + let chunk = parse_non_streaming_response(body, ToolCallFormat::NativeJson).unwrap(); + let calls = chunk.tool_calls.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.list_dir"); + assert_eq!(calls[0].arguments["path"], "/tmp"); + } + + #[test] + fn test_parse_non_streaming_pythonic() { + let body = r#"{ + "choices": [{ + "message": { + "role": "assistant", + "content": "Let me check.\n\nTool: filesystem.list_dir\nArguments: {\"path\": \"/tmp\"}" + }, + "finish_reason": "stop" + }] + }"#; + + let chunk = parse_non_streaming_response(body, ToolCallFormat::Pythonic).unwrap(); + let calls = chunk.tool_calls.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.list_dir"); + assert_eq!(chunk.finish_reason.as_deref(), Some("tool_calls")); + } + + #[test] + fn test_parse_non_streaming_empty_choices() { + let body = r#"{"choices": []}"#; + let result = parse_non_streaming_response(body, ToolCallFormat::NativeJson); + assert!(result.is_err()); + } + + /// Reasoning models (Qwen3, GPT-OSS via Ollama) return both `content` and + /// `reasoning`. We use `content` only — `reasoning` is chain-of-thought. + #[test] + fn test_parse_non_streaming_reasoning_model_uses_content() { + let body = r#"{ + "choices": [{ + "message": { + "role": "assistant", + "content": "{\"needs_tools\":true,\"steps\":[]}", + "reasoning": "Let me think about this task... The user wants..." 
+ }, + "finish_reason": "stop" + }] + }"#; + + let chunk = parse_non_streaming_response(body, ToolCallFormat::NativeJson).unwrap(); + assert_eq!( + chunk.token.as_deref(), + Some("{\"needs_tools\":true,\"steps\":[]}"), + "should use content, not reasoning" + ); + } + + /// When a reasoning model exhausts max_tokens during thinking, `content` is + /// empty and `reasoning` has the chain-of-thought. We return None for content + /// (not the reasoning text) so callers can handle the incomplete response. + #[test] + fn test_parse_non_streaming_reasoning_model_empty_content() { + let body = r#"{ + "choices": [{ + "message": { + "role": "assistant", + "content": "", + "reasoning": "Let me think step by step about this..." + }, + "finish_reason": "length" + }] + }"#; + + let chunk = parse_non_streaming_response(body, ToolCallFormat::NativeJson).unwrap(); + assert!( + chunk.token.is_none(), + "empty content should be None, not fallback to reasoning" + ); + assert_eq!(chunk.finish_reason.as_deref(), Some("length")); + } + + /// Deserialization should not fail when `reasoning` field is present. + #[test] + fn test_parse_non_streaming_reasoning_field_deserialized() { + let body = r#"{ + "choices": [{ + "message": { + "role": "assistant", + "content": "Hello!", + "reasoning": "Quick response needed." 
+ }, + "finish_reason": "stop" + }] + }"#; + + let chunk = parse_non_streaming_response(body, ToolCallFormat::NativeJson).unwrap(); + assert_eq!(chunk.token.as_deref(), Some("Hello!")); + } + + #[test] + fn test_parse_non_streaming_bracket() { + let body = r#"{ + "choices": [{ + "message": { + "role": "assistant", + "content": "I'll list the directory.\n\n[filesystem.list_dir(path=\"/tmp\")]" + }, + "finish_reason": "stop" + }] + }"#; + + let chunk = parse_non_streaming_response(body, ToolCallFormat::Bracket).unwrap(); + let calls = chunk.tool_calls.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.list_dir"); + assert_eq!(calls[0].arguments["path"], "/tmp"); + assert_eq!(chunk.finish_reason.as_deref(), Some("tool_calls")); + } + + #[test] + fn test_parse_non_streaming_bracket_special_tokens() { + let body = r#"{ + "choices": [{ + "message": { + "role": "assistant", + "content": "<|tool_call_start|>[filesystem.search_files(pattern=\"*.pdf\", path=\"/home\")]<|tool_call_end|>" + }, + "finish_reason": "stop" + }] + }"#; + + let chunk = parse_non_streaming_response(body, ToolCallFormat::Bracket).unwrap(); + let calls = chunk.tool_calls.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.search_files"); + assert_eq!(calls[0].arguments["pattern"], "*.pdf"); + } +} diff --git a/src-tauri/src/inference/tool_call_parser.rs b/src-tauri/src/inference/tool_call_parser.rs new file mode 100644 index 0000000..a0319ec --- /dev/null +++ b/src-tauri/src/inference/tool_call_parser.rs @@ -0,0 +1,1077 @@ +//! Tool call parsing — normalizes model output to `ToolCall` structs. +//! +//! Supports three formats (per `_models/config.yaml` `tool_call_format`): +//! +//! 1. **native_json** — Standard OpenAI JSON tool calls (Qwen, GPT). +//! The model returns `tool_calls` in the response delta with function name +//! and JSON-encoded arguments. +//! +//! 2. **pythonic** — LFM2.5-style text-based calls. +//! 
The model emits lines like:
//! ```text
//! Tool: filesystem.list_dir
//! Arguments: {"path": "/Users/chintan"}
//! ```
//! These must be parsed and converted to the standard `ToolCall` struct.
//!
//! 3. **bracket** — LFM2-24B-A2B bracket format.
//! The model emits tool calls in text content using brackets:
//! ```text
//! [filesystem.list_dir(path="/tmp")]
//! ```
//! Or with special tokens:
//! ```text
//! <|tool_call_start|>[filesystem.list_dir(path="/tmp")]<|tool_call_end|>
//! ```

use uuid::Uuid;

use super::config::ToolCallFormat;
use super::errors::InferenceError;
use super::types::ToolCall;

// ─── Native JSON Parsing ─────────────────────────────────────────────────────

/// Parse a tool call from the accumulated streaming deltas (native_json format).
///
/// `name` and `arguments_json` are the concatenated values from all chunks
/// for a single tool call index.
pub fn parse_native_json_tool_call(
    id: Option<&str>,
    name: &str,
    arguments_json: &str,
) -> Result<ToolCall, InferenceError> {
    // A call without a name is unusable — surface the raw payload for debugging.
    if name.is_empty() {
        return Err(InferenceError::ToolCallParseError {
            raw_response: arguments_json.to_string(),
            reason: "empty tool name".into(),
        });
    }

    // The accumulated fragments must form one valid JSON document.
    let arguments: serde_json::Value = serde_json::from_str(arguments_json)
        .map_err(|e| InferenceError::ToolCallParseError {
            raw_response: arguments_json.to_string(),
            reason: format!("invalid JSON arguments: {e}"),
        })?;

    // Servers may omit the call id; synthesize one so downstream
    // bookkeeping always has a unique handle.
    let call_id = match id {
        Some(s) => s.to_string(),
        None => format!("call_{}", Uuid::new_v4()),
    };

    Ok(ToolCall {
        id: call_id,
        name: name.to_string(),
        arguments,
    })
}

// ─── Pythonic Format Parsing ─────────────────────────────────────────────────

/// Extract tool calls from text content that uses the Pythonic format.
///
/// Looks for patterns like:
/// ```text
/// Tool: server.tool_name
/// Arguments: {"key": "value"}
/// ```
///
/// Returns all tool calls found in the text.
+pub fn parse_pythonic_tool_calls(text: &str) -> Result, InferenceError> { + let mut calls = Vec::new(); + let lines: Vec<&str> = text.lines().collect(); + + let mut i = 0; + while i < lines.len() { + let line = lines[i].trim(); + + // Look for "Tool: " lines + if let Some(name) = line.strip_prefix("Tool:").or_else(|| line.strip_prefix("tool:")) { + let tool_name = name.trim().to_string(); + + if tool_name.is_empty() { + i += 1; + continue; + } + + // The next line should be "Arguments: " + let arguments = if i + 1 < lines.len() { + let next_line = lines[i + 1].trim(); + if let Some(args_str) = next_line + .strip_prefix("Arguments:") + .or_else(|| next_line.strip_prefix("arguments:")) + { + let args_str = args_str.trim(); + i += 1; // consume the arguments line + + serde_json::from_str(args_str).map_err(|e| { + InferenceError::ToolCallParseError { + raw_response: args_str.to_string(), + reason: format!("invalid Pythonic arguments JSON: {e}"), + } + })? + } else { + // No arguments line — use empty object + serde_json::Value::Object(serde_json::Map::new()) + } + } else { + serde_json::Value::Object(serde_json::Map::new()) + }; + + calls.push(ToolCall { + id: format!("call_{}", Uuid::new_v4()), + name: tool_name, + arguments, + }); + } + + i += 1; + } + + Ok(calls) +} + +// ─── Bracket Format Parsing ───────────────────────────────────────────────── + +/// Extract tool calls from text that uses the bracket format (LFM2-24B-A2B). +/// +/// Supports multiple modes (tried in order, first match wins): +/// +/// 1. **Special tokens**: `<|tool_call_start|>[server.tool(args)]<|tool_call_end|>` +/// 2. **Bare bracket**: `[server.tool_name(key="value")]` +/// 3. **Backtick mention**: `` `server.tool_name` `` (no arguments) +/// +/// Arguments inside parens are parsed as Python-style kwargs and converted to JSON. 
+pub fn parse_bracket_tool_calls(text: &str) -> Result, InferenceError> { + // Mode 1: Special token markers + let mut calls = parse_bracket_special_tokens(text)?; + if !calls.is_empty() { + return Ok(calls); + } + + // Mode 2: Bare bracket [server.tool_name(args)] + calls = parse_bracket_bare(text)?; + if !calls.is_empty() { + return Ok(calls); + } + + // Mode 3: Backtick or bare mention of tool names (no arguments) + calls = parse_bracket_mention(text)?; + + Ok(calls) +} + +/// Mode 1: Parse `<|tool_call_start|>…<|tool_call_end|>` blocks. +fn parse_bracket_special_tokens(text: &str) -> Result, InferenceError> { + const START_TAG: &str = "<|tool_call_start|>"; + const END_TAG: &str = "<|tool_call_end|>"; + + let mut calls = Vec::new(); + let mut search_from = 0; + + while let Some(start_offset) = text[search_from..].find(START_TAG) { + let abs_start = search_from + start_offset + START_TAG.len(); + if let Some(end_offset) = text[abs_start..].find(END_TAG) { + let block = text[abs_start..abs_start + end_offset].trim(); + search_from = abs_start + end_offset + END_TAG.len(); + + // Strip optional outer brackets + let inner = if block.starts_with('[') && block.ends_with(']') { + &block[1..block.len() - 1] + } else { + block + }; + + if inner.is_empty() { + continue; + } + + if let Some(call) = parse_bracket_expression(inner)? { + calls.push(call); + } + } else { + break; + } + } + + Ok(calls) +} + +/// Mode 2: Parse bare `[server.tool_name(args)]` patterns. 
+fn parse_bracket_bare(text: &str) -> Result, InferenceError> { + let mut calls = Vec::new(); + + // Match [word.word_name(anything)] — the tool name must be dotted + // Use a manual scan to handle nested parens correctly + let bytes = text.as_bytes(); + let mut i = 0; + while i < bytes.len() { + if bytes[i] == b'[' { + // Find the matching close bracket + if let Some(close) = find_matching_bracket(text, i) { + let inner = &text[i + 1..close]; + // Check if it matches tool_name(args) pattern + if let Some(paren) = inner.find('(') { + let name = inner[..paren].trim(); + if is_dotted_tool_name(name) { + let args_str = if inner.ends_with(')') { + &inner[paren + 1..inner.len() - 1] + } else { + &inner[paren + 1..] + }; + let arguments = parse_bracket_args(args_str); + calls.push(ToolCall { + id: format!("call_{}", Uuid::new_v4()), + name: name.to_string(), + arguments, + }); + } + } + i = close + 1; + } else { + i += 1; + } + } else { + i += 1; + } + } + + Ok(calls) +} + +/// Mode 3: Parse backtick mentions `` `server.tool` `` or bare `server.tool` references. +fn parse_bracket_mention(text: &str) -> Result, InferenceError> { + let mut calls = Vec::new(); + + // Try backtick first: `server.tool_name` + let mut search_from = 0; + while let Some(start) = text[search_from..].find('`') { + let abs_start = search_from + start + 1; + if let Some(end_offset) = text[abs_start..].find('`') { + let name = &text[abs_start..abs_start + end_offset]; + if is_dotted_tool_name(name) { + calls.push(ToolCall { + id: format!("call_{}", Uuid::new_v4()), + name: name.to_string(), + arguments: serde_json::Value::Object(serde_json::Map::new()), + }); + // Only take the first backtick mention + return Ok(calls); + } + search_from = abs_start + end_offset + 1; + } else { + break; + } + } + + Ok(calls) +} + +/// Common file extensions — used to reject filenames that structurally +/// resemble dotted tool names (e.g., `original_me.png`, `report.txt`). 
///
/// This is a closed, stable set (file extensions rarely change) unlike
/// server prefixes which evolve as new MCP servers are added. Actual
/// tool name validation against the runtime registry happens downstream
/// in `ToolRegistry::resolve()`.
const FILE_EXTENSIONS: &[&str] = &[
    // Images
    "png", "jpg", "jpeg", "gif", "bmp", "svg", "webp", "ico", "tiff",
    // Documents
    "pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx", "odt",
    // Text / data
    "txt", "md", "csv", "json", "xml", "yaml", "yml", "toml",
    // Web / code
    "html", "htm", "css", "js", "ts", "py", "rs", "go", "rb",
    // Archives
    "zip", "tar", "gz", "bz", "rar", "dmg", "iso",
    // Media
    "mp3", "mp4", "wav", "avi", "mov", "mkv", "flac",
    // Misc
    "log", "bak", "tmp", "swp",
];

/// Check if a string looks structurally like a dotted tool name
/// (e.g., `filesystem.list_dir`).
///
/// This is a **syntactic pre-filter** for the bracket parser — it does not
/// validate against the runtime tool registry. False positives that slip
/// through are caught downstream by `ToolRegistry::resolve()`.
///
/// Rules:
/// 1. Exactly one dot separating two parts
/// 2. Both parts are non-empty, lowercase ASCII letters + underscores only
/// 3. The suffix (part after the dot) must not be a known file extension
fn is_dotted_tool_name(name: &str) -> bool {
    let Some((server, tool)) = name.split_once('.') else {
        return false;
    };
    // split_once splits on the FIRST dot only; reject a second dot so
    // names like "a.b.c" still fail rule 1.
    if tool.contains('.') || server.is_empty() || tool.is_empty() {
        return false;
    }

    let well_formed = |s: &str| s.chars().all(|c| c.is_ascii_lowercase() || c == '_');

    well_formed(server) && well_formed(tool) && !FILE_EXTENSIONS.contains(&tool)
}

// Find the matching `]` for a `[` at position `start`.
+/// +/// Brackets inside quoted strings are ignored to handle cases like: +/// `[task.create_task(assignments="[team]")]` where the inner `[` is part +/// of a string value, not a nested bracket. +fn find_matching_bracket(text: &str, start: usize) -> Option { + let bytes = text.as_bytes(); + let mut depth = 0; + let mut in_string = false; + let mut string_char = 0u8; + let mut i = start; + + while i < bytes.len() { + let b = bytes[i]; + + // Track string boundaries (skip brackets inside quoted strings) + if !in_string && (b == b'"' || b == b'\'') { + in_string = true; + string_char = b; + i += 1; + continue; + } + if in_string { + if b == string_char && (i == 0 || bytes[i - 1] != b'\\') { + in_string = false; + } + i += 1; + continue; + } + + // Only count brackets outside strings + match b { + b'[' => depth += 1, + b']' => { + depth -= 1; + if depth == 0 { + return Some(i); + } + } + _ => {} + } + i += 1; + } + None +} + +/// Parse a single bracket expression like `server.tool(key="value", key2=123)`. +fn parse_bracket_expression(expr: &str) -> Result, InferenceError> { + let paren_idx = match expr.find('(') { + Some(idx) => idx, + None => { + // No parens — just a tool name + let name = expr.trim(); + if name.is_empty() { + return Ok(None); + } + return Ok(Some(ToolCall { + id: format!("call_{}", Uuid::new_v4()), + name: name.to_string(), + arguments: serde_json::Value::Object(serde_json::Map::new()), + })); + } + }; + + let tool_name = expr[..paren_idx].trim(); + if tool_name.is_empty() { + return Ok(None); + } + + // Extract args between ( and final ) + let args_str = if expr.ends_with(')') { + &expr[paren_idx + 1..expr.len() - 1] + } else { + &expr[paren_idx + 1..] + }; + + let arguments = parse_bracket_args(args_str); + + Ok(Some(ToolCall { + id: format!("call_{}", Uuid::new_v4()), + name: tool_name.to_string(), + arguments, + })) +} + +/// Parse Python-style kwargs like `key="value", key2=123` into a JSON object. 
+/// +/// Handles: string values (single/double quoted), numeric values, booleans, +/// and raw JSON. Falls back to empty object if kwargs parsing finds nothing. +fn parse_bracket_args(raw: &str) -> serde_json::Value { + let raw = raw.trim(); + if raw.is_empty() { + return serde_json::Value::Object(serde_json::Map::new()); + } + + // Try parsing as raw JSON first (some models emit JSON in brackets) + if raw.starts_with('{') { + if let Ok(v) = serde_json::from_str::(raw) { + return v; + } + } + + // Parse Python-style kwargs: key="value", key2=123 + let mut map = serde_json::Map::new(); + let mut remaining = raw; + + while !remaining.is_empty() { + remaining = remaining.trim_start_matches([',', ' '].as_ref()).trim(); + if remaining.is_empty() { + break; + } + + // Find key=value + let eq_idx = match remaining.find('=') { + Some(idx) => idx, + None => break, + }; + + let key = remaining[..eq_idx].trim().trim_matches('"').trim_matches('\''); + remaining = &remaining[eq_idx + 1..]; + + // Parse the value + let (value, rest) = parse_bracket_value(remaining); + map.insert(key.to_string(), value); + remaining = rest; + } + + if map.is_empty() { + // Fallback: wrap as empty object + serde_json::Value::Object(serde_json::Map::new()) + } else { + serde_json::Value::Object(map) + } +} + +/// Parse a single value from a kwargs expression. Returns `(value, remaining_str)`. 
+fn parse_bracket_value(input: &str) -> (serde_json::Value, &str) { + let input = input.trim(); + + // Quoted string (double or single) + if input.starts_with('"') || input.starts_with('\'') { + let quote = input.as_bytes()[0] as char; + let mut end = 1; + let mut escaped = false; + for ch in input[1..].chars() { + if escaped { + escaped = false; + end += ch.len_utf8(); + continue; + } + if ch == '\\' { + escaped = true; + end += 1; + continue; + } + if ch == quote { + let val = &input[1..end]; + let rest = &input[end + 1..]; + return (serde_json::Value::String(val.to_string()), rest); + } + end += ch.len_utf8(); + } + // Unterminated string — take everything + return (serde_json::Value::String(input[1..].to_string()), ""); + } + + // Find the next comma or end + let end_idx = input.find(',').unwrap_or(input.len()); + let val_str = input[..end_idx].trim().trim_end_matches(')'); + + // Try numeric + if let Ok(n) = val_str.parse::() { + return (serde_json::Value::Number(n.into()), &input[end_idx..]); + } + if let Ok(n) = val_str.parse::() { + if let Some(num) = serde_json::Number::from_f64(n) { + return (serde_json::Value::Number(num), &input[end_idx..]); + } + } + + // Boolean / null + match val_str.to_lowercase().as_str() { + "true" => return (serde_json::Value::Bool(true), &input[end_idx..]), + "false" => return (serde_json::Value::Bool(false), &input[end_idx..]), + "none" | "null" => return (serde_json::Value::Null, &input[end_idx..]), + _ => {} + } + + // Fallback: treat as string + (serde_json::Value::String(val_str.to_string()), &input[end_idx..]) +} + +// ─── Malformed JSON Repair ────────────────────────────────────────────────── + +/// Extract tool name and raw arguments JSON from an Ollama HTTP 500 error body. +/// +/// Ollama's Harmony parser returns errors like: +/// ```json +/// {"error":{"message":"error parsing tool call: raw='{...}', err=..."}} +/// ``` +/// +/// Returns `Some((tool_name, raw_arguments))` if the error matches, `None` otherwise. 
+pub fn extract_tool_call_from_error(error_body: &str) -> Option<(String, String)> { + // Parse the error JSON to extract the message + let parsed: serde_json::Value = serde_json::from_str(error_body).ok()?; + let message = parsed + .get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str())?; + + // Must be a tool call parse error + if !message.contains("error parsing tool call") { + return None; + } + + // Extract raw='{...}' — find the boundaries + let raw_start = message.find("raw='")?; + let raw_content_start = raw_start + 5; // skip "raw='" + let raw_end = message[raw_content_start..].rfind("', err=")?; + let raw_json = &message[raw_content_start..raw_content_start + raw_end]; + + // Try to extract tool name from the raw JSON keys or from a + // best-effort parse. The raw JSON is the arguments object, so the tool + // name comes from Ollama's Harmony format. We look for known tool + // argument patterns to infer the tool. + // For now, we return an empty tool name — the caller must resolve it + // from context (e.g., the last tool call the model was attempting). + // + // In practice, Ollama's error doesn't include the tool name in the raw + // field. The tool name is available in the Harmony channel marker + // (`to=functions.`) which is in the full response but not the + // error message. We return empty and let the caller handle it. + Some((String::new(), raw_json.to_string())) +} + +/// Attempt to repair malformed JSON arguments from a model tool call. +/// +/// Common malformations observed in production: +/// 1. Double quotes: `"key":"value"` (extra quote before value) +/// 2. Trailing commas: `{"a":1,}` +/// 3. Missing closing brace (unbalanced) +/// 4. Unescaped control characters in string values +/// +/// Returns `Some(value)` if repair succeeds, `None` if irreparable. 
+pub fn repair_malformed_tool_call_json(raw: &str) -> Option { + // First, try parsing as-is (maybe it's already valid) + if let Ok(v) = serde_json::from_str::(raw) { + return Some(v); + } + + let mut repaired = raw.to_string(); + + // Repair 1: Fix double-quote patterns like `":"` → `":"` + // The model sometimes generates `"key":"value"` instead of `"key":"value"` + // Detect `":"` (quote-colon-quote-quote) and collapse to `":"` + repaired = repair_double_quotes(&repaired); + if let Ok(v) = serde_json::from_str::(&repaired) { + return Some(v); + } + + // Repair 2: Remove trailing commas before closing braces/brackets + repaired = repair_trailing_commas(&repaired); + if let Ok(v) = serde_json::from_str::(&repaired) { + return Some(v); + } + + // Repair 3: Balance braces — append missing closing braces + repaired = repair_unbalanced_braces(&repaired); + if let Ok(v) = serde_json::from_str::(&repaired) { + return Some(v); + } + + // Repair 4: Strip control characters (except \n, \r, \t which are valid in JSON strings) + repaired = repair_control_characters(&repaired); + if let Ok(v) = serde_json::from_str::(&repaired) { + return Some(v); + } + + None +} + +/// Fix double-quote patterns: `":"` → `":"` +/// +/// Scans character-by-character to find `":"` sequences and collapses the +/// extra quote. 
This handles the exact failure observed in production: +/// `"destination":""/Users/...` → `"destination":"/Users/...` +fn repair_double_quotes(input: &str) -> String { + let bytes = input.as_bytes(); + let mut result = Vec::with_capacity(bytes.len()); + let mut i = 0; + + while i < bytes.len() { + // Look for pattern: `":"` (colon followed by two quotes) + if i + 2 < bytes.len() + && bytes[i] == b':' + && bytes[i + 1] == b'"' + && bytes[i + 2] == b'"' + { + // Check this isn't a legitimate empty string `:""` + // Empty string would be followed by `,` or `}` or end + if i + 3 < bytes.len() && bytes[i + 3] != b',' && bytes[i + 3] != b'}' { + // Extra quote — skip it: emit `:"` and skip the second `"` + result.push(b':'); + result.push(b'"'); + i += 3; // skip `:""` + continue; + } + } + result.push(bytes[i]); + i += 1; + } + + String::from_utf8(result).unwrap_or_else(|_| input.to_string()) +} + +/// Remove trailing commas before `}` or `]`. +fn repair_trailing_commas(input: &str) -> String { + let mut result = String::with_capacity(input.len()); + let chars: Vec = input.chars().collect(); + let mut i = 0; + + while i < chars.len() { + if chars[i] == ',' { + // Look ahead past whitespace for `}` or `]` + let mut j = i + 1; + while j < chars.len() && chars[j].is_whitespace() { + j += 1; + } + if j < chars.len() && (chars[j] == '}' || chars[j] == ']') { + // Skip the trailing comma + i += 1; + continue; + } + } + result.push(chars[i]); + i += 1; + } + + result +} + +/// Append closing braces to balance unmatched opening braces. 
+fn repair_unbalanced_braces(input: &str) -> String { + let mut brace_depth: i32 = 0; + let mut in_string = false; + let mut escape_next = false; + + for ch in input.chars() { + if escape_next { + escape_next = false; + continue; + } + if ch == '\\' && in_string { + escape_next = true; + continue; + } + if ch == '"' { + in_string = !in_string; + continue; + } + if !in_string { + if ch == '{' { + brace_depth += 1; + } else if ch == '}' { + brace_depth -= 1; + } + } + } + + if brace_depth > 0 { + let mut result = input.to_string(); + for _ in 0..brace_depth { + result.push('}'); + } + result + } else { + input.to_string() + } +} + +/// Remove non-printable control characters that break JSON parsing. +/// Preserves `\n`, `\r`, `\t` which are valid in JSON strings. +fn repair_control_characters(input: &str) -> String { + input + .chars() + .filter(|&c| !c.is_control() || c == '\n' || c == '\r' || c == '\t') + .collect() +} + +/// Parse tool calls from accumulated content, using the configured format. 
+pub fn parse_tool_calls( + format: ToolCallFormat, + content: &str, + native_calls: &[(Option, String, String)], +) -> Result, InferenceError> { + match format { + ToolCallFormat::NativeJson => { + let mut calls = Vec::new(); + for (id, name, args) in native_calls { + calls.push(parse_native_json_tool_call( + id.as_deref(), + name, + args, + )?); + } + Ok(calls) + } + ToolCallFormat::Pythonic => parse_pythonic_tool_calls(content), + ToolCallFormat::Bracket => parse_bracket_tool_calls(content), + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_native_json_valid() { + let result = parse_native_json_tool_call( + Some("call_123"), + "filesystem.list_dir", + r#"{"path": "/tmp"}"#, + ) + .unwrap(); + + assert_eq!(result.id, "call_123"); + assert_eq!(result.name, "filesystem.list_dir"); + assert_eq!(result.arguments["path"], "/tmp"); + } + + #[test] + fn test_parse_native_json_generates_id() { + let result = parse_native_json_tool_call( + None, + "filesystem.read_file", + r#"{"path": "/etc/hosts"}"#, + ) + .unwrap(); + + assert!(result.id.starts_with("call_")); + assert_eq!(result.name, "filesystem.read_file"); + } + + #[test] + fn test_parse_native_json_empty_name() { + let result = parse_native_json_tool_call(None, "", r#"{}"#); + assert!(result.is_err()); + } + + #[test] + fn test_parse_native_json_invalid_json() { + let result = parse_native_json_tool_call(None, "test.tool", "not json"); + assert!(result.is_err()); + if let Err(InferenceError::ToolCallParseError { reason, .. 
}) = result { + assert!(reason.contains("invalid JSON")); + } + } + + #[test] + fn test_parse_pythonic_single_call() { + let text = "Tool: filesystem.list_dir\nArguments: {\"path\": \"/tmp\"}"; + let calls = parse_pythonic_tool_calls(text).unwrap(); + + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.list_dir"); + assert_eq!(calls[0].arguments["path"], "/tmp"); + } + + #[test] + fn test_parse_pythonic_multiple_calls() { + let text = "\ +I'll list the directory and then read a file. + +Tool: filesystem.list_dir +Arguments: {\"path\": \"/tmp\"} + +Tool: filesystem.read_file +Arguments: {\"path\": \"/tmp/test.txt\"}"; + + let calls = parse_pythonic_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "filesystem.list_dir"); + assert_eq!(calls[1].name, "filesystem.read_file"); + } + + #[test] + fn test_parse_pythonic_no_arguments() { + let text = "Tool: system.get_info\nSome other text"; + let calls = parse_pythonic_tool_calls(text).unwrap(); + + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "system.get_info"); + assert!(calls[0].arguments.is_object()); + } + + #[test] + fn test_parse_pythonic_no_tool_calls() { + let text = "Just a regular response with no tool calls."; + let calls = parse_pythonic_tool_calls(text).unwrap(); + assert!(calls.is_empty()); + } + + #[test] + fn test_parse_tool_calls_native() { + let native = vec![( + Some("id1".to_string()), + "test.tool".to_string(), + r#"{"key": "val"}"#.to_string(), + )]; + + let calls = parse_tool_calls(ToolCallFormat::NativeJson, "", &native).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "test.tool"); + } + + #[test] + fn test_parse_tool_calls_pythonic() { + let content = "Tool: test.tool\nArguments: {\"key\": \"val\"}"; + let calls = parse_tool_calls(ToolCallFormat::Pythonic, content, &[]).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "test.tool"); + } + + // ─── JSON Repair Tests 
────────────────────────────────────────────── + + #[test] + fn test_repair_double_quote() { + // The exact failure from agent.log: `"destination":""/Users/...` + let raw = r#"{"create_dirs":true,"destination":""/Users/chintan/Desktop/file.png","source":"/Users/chintan/Desktop/file.png"}"#; + let result = repair_malformed_tool_call_json(raw); + assert!(result.is_some(), "should repair double-quote pattern"); + let v = result.unwrap(); + assert_eq!(v["destination"], "/Users/chintan/Desktop/file.png"); + assert_eq!(v["source"], "/Users/chintan/Desktop/file.png"); + assert_eq!(v["create_dirs"], true); + } + + #[test] + fn test_repair_trailing_comma() { + let raw = r#"{"path": "/tmp", "recursive": true,}"#; + let result = repair_malformed_tool_call_json(raw); + assert!(result.is_some(), "should repair trailing comma"); + assert_eq!(result.unwrap()["path"], "/tmp"); + } + + #[test] + fn test_repair_missing_closing_brace() { + let raw = r#"{"path": "/tmp", "recursive": true"#; + let result = repair_malformed_tool_call_json(raw); + assert!(result.is_some(), "should repair missing closing brace"); + assert_eq!(result.unwrap()["path"], "/tmp"); + } + + #[test] + fn test_repair_already_valid() { + let raw = r#"{"path": "/tmp"}"#; + let result = repair_malformed_tool_call_json(raw); + assert!(result.is_some()); + assert_eq!(result.unwrap()["path"], "/tmp"); + } + + #[test] + fn test_repair_irreparable() { + let raw = "this is not json at all and cannot be repaired"; + let result = repair_malformed_tool_call_json(raw); + assert!(result.is_none(), "should return None for irreparable input"); + } + + #[test] + fn test_extract_tool_call_from_error_valid() { + let body = r#"{"error":{"message":"error parsing tool call: raw='{\"path\":\"\"/tmp\"}', err=invalid character '/' after object key:value pair"}}"#; + let result = extract_tool_call_from_error(body); + assert!(result.is_some()); + let (name, raw) = result.unwrap(); + assert!(name.is_empty(), "tool name not available in 
Ollama error"); + assert!(raw.contains("path")); + } + + #[test] + fn test_extract_tool_call_from_error_non_matching() { + let body = r#"{"error":{"message":"model not found"}}"#; + let result = extract_tool_call_from_error(body); + assert!(result.is_none()); + } + + #[test] + fn test_extract_tool_call_from_error_invalid_json() { + let result = extract_tool_call_from_error("not json"); + assert!(result.is_none()); + } + + // ─── Bracket Format Tests ───────────────────────────────────────────── + + #[test] + fn test_bracket_special_tokens() { + let text = r#"<|tool_call_start|>[filesystem.list_dir(path="/tmp")]<|tool_call_end|>"#; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.list_dir"); + assert_eq!(calls[0].arguments["path"], "/tmp"); + } + + #[test] + fn test_bracket_special_tokens_multiple() { + let text = r#"I'll help you. +<|tool_call_start|>[filesystem.list_dir(path="/tmp")]<|tool_call_end|> +<|tool_call_start|>[filesystem.read_file(path="/tmp/test.txt")]<|tool_call_end|>"#; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "filesystem.list_dir"); + assert_eq!(calls[1].name, "filesystem.read_file"); + } + + #[test] + fn test_bracket_bare() { + let text = r#"I'll search for the file. 
[filesystem.search_files(pattern="*.pdf", path="/home")]"#; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.search_files"); + assert_eq!(calls[0].arguments["pattern"], "*.pdf"); + assert_eq!(calls[0].arguments["path"], "/home"); + } + + #[test] + fn test_bracket_no_args() { + let text = "<|tool_call_start|>system.get_system_info<|tool_call_end|>"; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "system.get_system_info"); + assert!(calls[0].arguments.is_object()); + } + + #[test] + fn test_bracket_backtick_mention() { + let text = "You should use `filesystem.list_dir` to browse the directory."; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.list_dir"); + } + + #[test] + fn test_bracket_numeric_and_bool_args() { + let text = r#"[data.query_sqlite(query="SELECT *", limit=50, verbose=true)]"#; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "data.query_sqlite"); + assert_eq!(calls[0].arguments["query"], "SELECT *"); + assert_eq!(calls[0].arguments["limit"], 50); + assert_eq!(calls[0].arguments["verbose"], true); + } + + #[test] + fn test_bracket_no_tool_calls() { + let text = "Just a regular response with no tool calls at all."; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert!(calls.is_empty()); + } + + #[test] + fn test_bracket_json_args() { + let text = r#"<|tool_call_start|>[filesystem.list_dir({"path": "/tmp"})]<|tool_call_end|>"#; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.list_dir"); + assert_eq!(calls[0].arguments["path"], "/tmp"); + } + + #[test] + fn test_parse_tool_calls_bracket() { + let content = r#"[filesystem.list_dir(path="/tmp")]"#; + let calls = 
parse_tool_calls(ToolCallFormat::Bracket, content, &[]).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "filesystem.list_dir"); + } + + #[test] + fn test_is_dotted_tool_name() { + // Valid tool names — all pass structural check + assert!(is_dotted_tool_name("filesystem.list_dir")); + assert!(is_dotted_tool_name("ocr.extract_text_from_image")); + assert!(is_dotted_tool_name("document.extract_text")); + assert!(is_dotted_tool_name("knowledge.search")); + assert!(is_dotted_tool_name("meeting.transcribe")); + assert!(is_dotted_tool_name("security.scan_for_pii")); + assert!(is_dotted_tool_name("calendar.list_events")); + assert!(is_dotted_tool_name("email.send")); + assert!(is_dotted_tool_name("task.create")); + assert!(is_dotted_tool_name("data.query_sqlite")); + assert!(is_dotted_tool_name("audit.list_entries")); + assert!(is_dotted_tool_name("clipboard.read")); + assert!(is_dotted_tool_name("system.get_system_info")); + + // Valid structurally — unknown prefixes pass the syntactic check. + // Downstream ToolRegistry::resolve() handles semantic validation. 
+ assert!(is_dotted_tool_name("unknown.tool_name")); + assert!(is_dotted_tool_name("custom.something")); + + // Invalid — not dotted at all + assert!(!is_dotted_tool_name("not_dotted")); + + // Invalid — structural issues + assert!(!is_dotted_tool_name(".starts_with_dot")); + assert!(!is_dotted_tool_name("ends_with_dot.")); + assert!(!is_dotted_tool_name("has.two.dots")); + assert!(!is_dotted_tool_name("has.UPPER")); + + // Invalid — filenames (suffix is a file extension) + assert!(!is_dotted_tool_name("original_me.png")); + assert!(!is_dotted_tool_name("screenshot_2026.pdf")); + assert!(!is_dotted_tool_name("report_final.txt")); + assert!(!is_dotted_tool_name("my_photo.jpg")); + assert!(!is_dotted_tool_name("archive.zip")); + assert!(!is_dotted_tool_name("audio_clip.mp3")); + assert!(!is_dotted_tool_name("config.yaml")); + assert!(!is_dotted_tool_name("debug.log")); + } + + // ─── Fix F8: Brackets inside quoted strings ───────────────────────── + + #[test] + fn find_matching_bracket_simple() { + let text = "[filesystem.list_dir(path=\"/tmp\")]"; + assert_eq!(find_matching_bracket(text, 0), Some(text.len() - 1)); + } + + #[test] + fn find_matching_bracket_with_inner_bracket_in_string() { + // The "[" inside the quoted string should NOT be counted as a nested bracket + let text = r#"[task.create_task(assignments="[team]")]"#; + assert_eq!(find_matching_bracket(text, 0), Some(text.len() - 1)); + } + + #[test] + fn find_matching_bracket_with_unmatched_bracket_in_value() { + // This was the Phase 2c Session 2 failure: assignments="[" + let text = r#"[task.create_task(title="Review Q4", assignments="[")]"#; + assert_eq!(find_matching_bracket(text, 0), Some(text.len() - 1)); + } + + #[test] + fn parse_bracket_with_brackets_in_args() { + let text = r#"[task.create_task(title="Review Q4", assignments="[team]")]"#; + let calls = parse_bracket_tool_calls(text).unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "task.create_task"); + assert_eq!( + 
calls[0].arguments.get("title").and_then(|v| v.as_str()),
            Some("Review Q4")
        );
    }
}
diff --git a/src-tauri/src/inference/types.rs b/src-tauri/src/inference/types.rs
new file mode 100644
index 0000000..ea42915
--- /dev/null
+++ b/src-tauri/src/inference/types.rs
@@ -0,0 +1,438 @@
//! Shared types for the inference client.
//!
//! These mirror the OpenAI Chat Completions API types, used for both
//! request building and response parsing.

use serde::{Deserialize, Serialize};

// ─── Request Types ───────────────────────────────────────────────────────────

/// A single message in the conversation.
///
/// Serialization notes for OpenAI-compatible local models:
/// - `content` must be `""` (not `null`) for assistant messages with tool calls.
///   Many local models (Ollama, llama.cpp) misinterpret `null` content and fail
///   to recognize the tool call round-trip pattern.
/// - `tool_call_id` and `tool_calls` are skipped when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: Role,
    #[serde(serialize_with = "serialize_content")]
    pub content: Option<String>,
    /// Tool call results are sent back as `tool` role messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Assistant messages may contain tool calls (OpenAI response-shaped
    /// objects — NOTE(review): element type reconstructed as
    /// `ToolCallResponse`; confirm against call sites).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallResponse>>,
}

/// Custom serializer for `content`: emit `""` instead of `null` when `None`.
///
/// OpenAI's API accepts `null` content, but many local LLM runtimes
/// (Ollama, llama.cpp, vLLM) reject or mishandle `null` content fields.
/// Using `""` (empty string) is universally safe.
fn serialize_content<S>(value: &Option<String>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    match value {
        Some(s) => serializer.serialize_str(s),
        None => serializer.serialize_str(""),
    }
}

/// Message role.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, + Tool, +} + +/// Tool definition sent in the request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolDefinition { + pub r#type: String, + pub function: FunctionDefinition, +} + +/// Function definition within a tool. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionDefinition { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +/// Structured output format hint for the model. +/// +/// When set to `json_object`, instructs Ollama to use GBNF grammar +/// enforcement to guarantee valid JSON output. This is opt-in and +/// experimental — only enable after live testing with the target model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseFormat { + /// The format type. Currently only `"json_object"` is supported. + pub r#type: String, +} + +/// Request body for `POST /v1/chat/completions`. +#[derive(Debug, Clone, Serialize)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + pub temperature: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + pub max_tokens: u32, + pub stream: bool, + /// Optional structured output format. When set, the model backend + /// (Ollama/llama.cpp) uses grammar constraints to enforce valid output. + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, +} + +/// Optional sampling parameter overrides for a single inference call. +/// +/// When provided, these override the model config defaults. +/// Used to lower temperature/top_p for tool-calling turns (more deterministic) +/// and raise them for conversational turns (more creative). 
+#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SamplingOverrides { + /// Override temperature (0.0 = deterministic, 1.0 = creative). + pub temperature: Option, + /// Override top_p (nucleus sampling threshold). + pub top_p: Option, +} + +// ─── Response Types ────────────────────────────────────────────────────────── + +/// A parsed tool call extracted from the model's response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + /// Unique ID for this tool call (generated if the model doesn't provide one). + pub id: String, + /// Fully qualified tool name, e.g. `"filesystem.list_dir"`. + pub name: String, + /// Validated JSON arguments. + pub arguments: serde_json::Value, +} + +/// Tool call as returned in the OpenAI response format. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallResponse { + pub id: String, + pub r#type: String, + pub function: FunctionCallResponse, +} + +/// Function call details in a response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionCallResponse { + pub name: String, + pub arguments: String, +} + +/// A single chunk from the streaming response. +#[derive(Debug, Clone)] +pub struct StreamChunk { + /// Incremental text token (if this chunk carries text). + pub token: Option, + /// Tool calls detected in this chunk (accumulated). + pub tool_calls: Option>, + /// Why the model stopped: `"stop"`, `"tool_calls"`, or `None` (still going). + pub finish_reason: Option, +} + +/// Raw SSE chunk from the OpenAI API. +#[derive(Debug, Clone, Deserialize)] +pub struct ChatCompletionChunk { + #[allow(dead_code)] + pub id: Option, + pub choices: Vec, +} + +/// A single choice within a streaming chunk. +#[derive(Debug, Clone, Deserialize)] +pub struct ChunkChoice { + pub delta: ChunkDelta, + pub finish_reason: Option, +} + +/// The delta (incremental update) within a chunk choice. 
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkDelta {
    #[serde(default)]
    pub content: Option<String>,
    /// Reasoning/thinking content from models like Qwen3 and GPT-OSS.
    /// Deserialized to prevent serde unknown-field errors, but not used for
    /// streaming output — `content` holds the actual answer after reasoning
    /// completes. Reasoning tokens are silently discarded.
    #[serde(default)]
    #[allow(dead_code)]
    pub reasoning: Option<String>,
    #[serde(default)]
    pub tool_calls: Option<Vec<ChunkToolCall>>,
}

/// A tool call fragment within a streaming delta.
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkToolCall {
    /// Position in the accumulated tool-call list — NOTE(review): integer
    /// width was lost in extraction; `usize` assumed — confirm upstream.
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<ChunkFunction>,
}

/// A function call fragment within a streaming tool call.
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkFunction {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_top_p_omitted_when_none() {
        let req = ChatCompletionRequest {
            model: "test".to_string(),
            messages: vec![],
            tools: None,
            tool_choice: None,
            temperature: 0.7,
            top_p: None,
            max_tokens: 1024,
            stream: false,
            response_format: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("top_p"), "top_p should be omitted when None");
    }

    #[test]
    fn test_top_p_included_when_some() {
        let req = ChatCompletionRequest {
            model: "test".to_string(),
            messages: vec![],
            tools: None,
            tool_choice: None,
            temperature: 0.1,
            top_p: Some(0.2),
            max_tokens: 1024,
            stream: false,
            response_format: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(
            json.contains("\"top_p\":0.2"),
            "top_p should appear in JSON when Some"
        );
    }

    #[test]
    fn test_response_format_omitted_when_none() {
        let req = ChatCompletionRequest {
            model: "test".to_string(),
            messages: vec![],
            tools: None,
            tool_choice: None,
+ temperature: 0.7, + top_p: None, + max_tokens: 1024, + stream: false, + response_format: None, + }; + let json = serde_json::to_string(&req).unwrap(); + assert!( + !json.contains("response_format"), + "response_format should be omitted when None" + ); + } + + #[test] + fn test_response_format_included_when_set() { + let req = ChatCompletionRequest { + model: "test".to_string(), + messages: vec![], + tools: None, + tool_choice: None, + temperature: 0.7, + top_p: None, + max_tokens: 1024, + stream: false, + response_format: Some(ResponseFormat { + r#type: "json_object".to_string(), + }), + }; + let json = serde_json::to_string(&req).unwrap(); + assert!( + json.contains("\"response_format\""), + "response_format should appear in JSON when Some" + ); + assert!( + json.contains("\"json_object\""), + "type should be json_object" + ); + } + + #[test] + fn test_sampling_overrides_default() { + let overrides = SamplingOverrides::default(); + assert!(overrides.temperature.is_none()); + assert!(overrides.top_p.is_none()); + } + + #[test] + fn test_sampling_overrides_with_values() { + let overrides = SamplingOverrides { + temperature: Some(0.5), + top_p: Some(0.9), + }; + assert!(overrides.temperature.is_some()); + assert_eq!(overrides.temperature.unwrap(), 0.5); + } + + #[test] + fn test_sampling_overrides_serialization() { + let overrides = SamplingOverrides { + temperature: Some(0.3), + top_p: Some(0.7), + }; + let json = serde_json::to_string(&overrides).unwrap(); + assert!(json.contains("0.3")); + assert!(json.contains("0.7")); + } + + #[test] + fn test_sampling_overrides_deserialization() { + let json = r#"{"temperature": 0.4, "topP": 0.8}"#; + let overrides: SamplingOverrides = serde_json::from_str(json).unwrap(); + assert_eq!(overrides.temperature, Some(0.4)); + assert_eq!(overrides.top_p, Some(0.8)); + } +} + +/// Model status information for health monitoring. 
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelStatus {
    /// Stable config key identifying the model entry.
    pub key: String,
    pub display_name: String,
    pub base_url: String,
    pub healthy: bool,
    /// Backend-reported model name; `None` when the backend is unreachable.
    pub model_name: Option<String>,
    /// Human-readable failure description; `None` when healthy.
    pub error: Option<String>,
}

#[cfg(test)]
mod model_status_tests {
    use super::*;

    #[test]
    fn test_model_status_healthy() {
        let status = ModelStatus {
            key: "model-a".to_string(),
            display_name: "Model A".to_string(),
            base_url: "http://localhost:11434".to_string(),
            healthy: true,
            model_name: Some("model-a:latest".to_string()),
            error: None,
        };

        assert!(status.healthy);
        assert!(status.error.is_none());
        assert!(status.model_name.is_some());
    }

    #[test]
    fn test_model_status_unhealthy() {
        let status = ModelStatus {
            key: "model-a".to_string(),
            display_name: "Model A".to_string(),
            base_url: "http://localhost:11434".to_string(),
            healthy: false,
            model_name: None,
            error: Some("connection refused".to_string()),
        };

        assert!(!status.healthy);
        assert!(status.error.is_some());
        assert!(status.model_name.is_none());
    }

    #[test]
    fn test_model_status_serialization() {
        let status = ModelStatus {
            key: "qwen".to_string(),
            display_name: "Qwen 2.5".to_string(),
            base_url: "http://localhost:1234/v1".to_string(),
            healthy: true,
            model_name: Some("qwen2.5:14b".to_string()),
            error: None,
        };

        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("qwen"));
        assert!(json.contains("healthy"));
        assert!(json.contains("true"));
    }

    #[test]
    fn test_model_status_deserialization() {
        let json = r#"{
            "key": "lm-studio-model",
            "displayName": "LM Studio Model",
            "baseUrl": "http://localhost:1234/v1",
            "healthy": true,
            "modelName": "model-name",
            "error": null
        }"#;

        let status: ModelStatus = serde_json::from_str(json).unwrap();
        assert_eq!(status.key, "lm-studio-model");
        assert!(status.healthy);
        assert_eq!(status.model_name,
Some("model-name".to_string())); + } + + #[test] + fn test_model_status_with_error_deserialization() { + let json = r#"{ + "key": "ollama-model", + "displayName": "Ollama Model", + "baseUrl": "http://localhost:11434", + "healthy": false, + "modelName": null, + "error": "timeout after 30s" + }"#; + + let status: ModelStatus = serde_json::from_str(json).unwrap(); + assert!(!status.healthy); + assert_eq!(status.error, Some("timeout after 30s".to_string())); + } + + #[test] + fn test_model_status_display() { + let status = ModelStatus { + key: "test".to_string(), + display_name: "Test Model".to_string(), + base_url: "http://localhost:8080".to_string(), + healthy: true, + model_name: None, + error: None, + }; + + let debug_str = format!("{:?}", status); + assert!(debug_str.contains("Test Model")); + } +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs new file mode 100644 index 0000000..c8ec075 --- /dev/null +++ b/src-tauri/src/lib.rs @@ -0,0 +1,915 @@ +pub mod agent_core; +pub mod commands; +pub mod inference; +pub mod mcp_client; + +use std::collections::HashMap; + +use agent_core::{AgentDatabase, ConversationManager, ConfirmationResponse, PermissionStore}; +use commands::settings::SamplingConfig; +use mcp_client::McpClient; +use tauri::Manager; + +/// Pending confirmation channel — holds a oneshot sender while the agent loop +/// awaits a user response via the ConfirmationDialog. +pub type PendingConfirmation = + TokioMutex>>; + +/// In-flight request tracker — prevents duplicate requests for the same session. +pub type InFlightRequests = TokioMutex>; + +/// Async mutex for types that require `.await` inside their methods. +pub type TokioMutex = tokio::sync::Mutex; + +/// Return the platform-standard data directory for LocalCowork. 
+/// +/// - macOS: `~/Library/Application Support/com.localcowork.app/` +/// - Windows: `{FOLDERID_RoamingAppData}\localcowork\` +/// - Linux: `$XDG_DATA_HOME/com.localcowork.app/` (fallback `~/.local/share/...`) +/// +/// Falls back to `~/.localcowork/` only if none of the above can be resolved. +pub(crate) fn data_dir() -> std::path::PathBuf { + if let Some(dir) = dirs::data_dir() { + return dir.join("com.localcowork.app"); + } + dirs::home_dir() + .unwrap_or_else(|| std::path::PathBuf::from(".")) + .join(".localcowork") +} + +/// Returns the cache directory for the app (embedding indexes, etc.). +#[allow(dead_code)] +pub(crate) fn cache_dir() -> std::path::PathBuf { + data_dir().join("cache") +} + +/// Initialize the tracing subscriber — writes structured logs to the app data directory. +/// +/// On each app startup: +/// 1. Rotates existing logs (agent.log → agent.log.1 → .2 → .3, keeps last 3). +/// 2. Opens a fresh agent.log with a line-flushing writer for crash resilience. +/// 3. Logs a startup banner with the data directory path for discoverability. +/// +/// Returns early on error since logging is not critical to app functionality. 
+fn init_tracing() { + use tracing_subscriber::fmt; + use tracing_subscriber::EnvFilter; + + let log_dir = data_dir(); + let _ = std::fs::create_dir_all(&log_dir); + + let log_path = log_dir.join("agent.log"); + + // Rotate: agent.log.2 → .3, .1 → .2, agent.log → .1 + rotate_log_file(&log_path, 3); + + let log_file = match std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&log_path) + { + Ok(file) => file, + // Use eprintln since tracing isn't initialized yet + Err(e) => { + eprintln!("failed to open agent.log: {}", e); + return; + } + }; + + let flushing_writer = FlushingWriter::new(log_file); + + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("localcowork=info,warn")); + + fmt::fmt() + .with_env_filter(filter) + .with_writer(flushing_writer) + .with_ansi(false) + .with_target(true) + .with_thread_ids(false) + .init(); + + // Startup banner — makes it easy to find the right log file + tracing::info!( + version = env!("CARGO_PKG_VERSION"), + data_dir = %log_dir.display(), + log_file = %log_path.display(), + pid = std::process::id(), + "=== LocalCowork starting ===" + ); +} + +/// Rotate log files: `agent.log` → `agent.log.1` → `.2` → … → `.{keep}`. +/// +/// Oldest file beyond `keep` is deleted. Missing files in the chain are skipped. +fn rotate_log_file(base_path: &std::path::Path, keep: u32) { + // Delete the oldest + let oldest = format!("{}.{keep}", base_path.display()); + let _ = std::fs::remove_file(&oldest); + + // Shift: .{n-1} → .{n} + for i in (1..keep).rev() { + let from = format!("{}.{i}", base_path.display()); + let to = format!("{}.{}", base_path.display(), i + 1); + let _ = std::fs::rename(&from, &to); + } + + // Current → .1 + if base_path.exists() { + let to = format!("{}.1", base_path.display()); + let _ = std::fs::rename(base_path, &to); + } +} + +/// A writer that wraps `std::fs::File` and flushes after every write. +/// +/// `tracing-subscriber` buffers log output internally. 
Without explicit +/// flushing, log entries may sit in OS buffers and be lost on crash. +/// This wrapper ensures each log line is on disk immediately. +/// +/// Performance impact is minimal for a desktop app (~100 log lines/minute). +#[derive(Clone)] +struct FlushingWriter { + file: std::sync::Arc>, +} + +impl FlushingWriter { + fn new(file: std::fs::File) -> Self { + Self { + file: std::sync::Arc::new(std::sync::Mutex::new(file)), + } + } +} + +impl std::io::Write for FlushingWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut f = self.file.lock().map_err(|e| { + std::io::Error::other(format!("lock poisoned: {e}")) + })?; + let n = std::io::Write::write(&mut *f, buf)?; + std::io::Write::flush(&mut *f)?; + Ok(n) + } + + fn flush(&mut self) -> std::io::Result<()> { + let mut f = self.file.lock().map_err(|e| { + std::io::Error::other(format!("lock poisoned: {e}")) + })?; + std::io::Write::flush(&mut *f) + } +} + +impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for FlushingWriter { + type Writer = FlushingWriter; + + fn make_writer(&'a self) -> Self::Writer { + self.clone() + } +} + +/// Resolve the path for the agent SQLite database. +/// +/// Uses the platform-standard data directory (creates it if needed). +fn resolve_db_path() -> String { + let dir = data_dir(); + if !dir.exists() { + let _ = std::fs::create_dir_all(&dir); + } + dir.join("agent.db").to_string_lossy().into_owned() +} + +/// Resolve the MCP servers configuration using auto-discovery + optional overrides. +/// +/// 1. Auto-discovers servers by scanning `mcp-servers/` for `package.json` (TS) +/// or `pyproject.toml` (Python) markers. +/// 2. Loads `mcp-servers.json` as optional overrides (missing file is fine). +/// 3. Merges: override entries fully replace discovered entries. +/// 4. Resolves relative paths, venvs, and injects vision model env vars. +fn resolve_mcp_config() -> mcp_client::types::McpServersConfig { + let project_root = resolve_project_root(); + + // 1. 
Auto-discover servers from mcp-servers/ directory + let mcp_servers_dir = project_root.join("mcp-servers"); + let discovered = mcp_client::discovery::discover_servers(&mcp_servers_dir); + tracing::info!( + discovered = discovered.len(), + servers = ?discovered.keys().collect::>(), + "auto-discovered MCP servers" + ); + + // 2. Load optional override file + let overrides = load_override_file(&project_root); + + // 3. Merge: overrides win + let mut merged = mcp_client::discovery::merge_configs(discovered, overrides); + + // 4. Filter by enabled_servers allowlist from _models/config.yaml (if set) + filter_by_enabled_servers(&mut merged, &project_root); + + let mut config = mcp_client::types::McpServersConfig { servers: merged }; + + // 5. Post-process: resolve paths, venvs, inject vision env vars + resolve_paths_and_env(&mut config, &project_root); + + tracing::info!( + server_count = config.servers.len(), + servers = ?config.servers.keys().collect::>(), + "final MCP server config" + ); + + config +} + +/// Filter discovered servers by the `enabled_servers` allowlist in `_models/config.yaml`. +/// +/// When `enabled_servers` is set, only servers whose names appear in the list +/// are kept. All others are removed. When absent or empty, all servers pass through. +fn filter_by_enabled_servers( + servers: &mut std::collections::HashMap, + project_root: &std::path::Path, +) { + let config_path = project_root.join("_models/config.yaml"); + let content = match std::fs::read_to_string(&config_path) { + Ok(c) => c, + Err(_) => return, // No config file — skip filtering + }; + + // Parse just enough YAML to extract enabled_servers without requiring + // the full ModelsConfig (which needs model configs to be valid). 
+ let yaml: serde_json::Value = match serde_yaml::from_str(&content) { + Ok(v) => v, + Err(_) => return, + }; + + let enabled = match yaml.get("enabled_servers").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => return, // Field absent — no filtering + }; + + let allowlist: std::collections::HashSet = enabled + .iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + + if allowlist.is_empty() { + return; + } + + let before = servers.len(); + servers.retain(|name, _| allowlist.contains(name)); + let after = servers.len(); + + tracing::info!( + before, + after, + enabled = ?allowlist, + "filtered MCP servers by enabled_servers allowlist" + ); +} + +/// Filter tools by the `enabled_tools` allowlist in `_models/config.yaml`. +/// +/// When `enabled_tools` is set, only tools whose fully-qualified names appear +/// in the list are kept in the registry. All others are removed. This allows +/// curating a tight tool surface for specific demos or deployments. +/// +/// Must be called AFTER `McpClient::start_all()` has populated the registry. +fn filter_tools_by_allowlist(mcp_client: &mut McpClient, project_root: &std::path::Path) { + let config_path = project_root.join("_models/config.yaml"); + let content = match std::fs::read_to_string(&config_path) { + Ok(c) => c, + Err(_) => return, // No config file — skip filtering + }; + + let yaml: serde_json::Value = match serde_yaml::from_str(&content) { + Ok(v) => v, + Err(_) => return, + }; + + let enabled = match yaml.get("enabled_tools").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => return, // Field absent — no filtering + }; + + let allowlist: std::collections::HashSet = enabled + .iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + + if allowlist.is_empty() { + return; + } + + mcp_client.registry.retain_tools(&allowlist); +} + +/// Determine the project root directory. +/// +/// Resolution order: +/// 1. 
`mcp-servers/` relative to cwd (dev mode, running from project root). +/// 2. `../mcp-servers/` relative to cwd (dev mode, running from `src-tauri/`). +/// 3. `mcp-servers/` relative to the executable (packaged app). +/// 4. Fallback: cwd parent directory. +pub(crate) fn resolve_project_root() -> std::path::PathBuf { + let cwd = std::env::current_dir().unwrap_or_default(); + + // Dev mode: cwd is the project root + if cwd.join("mcp-servers").is_dir() { + return cwd; + } + + // Dev mode: cwd is src-tauri/ + if cwd.join("..").join("mcp-servers").is_dir() { + return cwd.join("..").canonicalize().unwrap_or(cwd); + } + + // Packaged app: check relative to the executable location. + // macOS: .app/Contents/MacOS/localcowork → .app/Contents/Resources/ + // Windows: install_dir/localcowork.exe → install_dir/ + // Linux: install_dir/localcowork → install_dir/ + if let Ok(exe) = std::env::current_exe() { + if let Some(exe_dir) = exe.parent() { + // macOS .app bundle: Resources/ is a sibling of MacOS/ + let macos_resources = exe_dir.join("../Resources"); + if macos_resources.join("mcp-servers").is_dir() { + if let Ok(resolved) = macos_resources.canonicalize() { + return resolved; + } + } + // Flat layout (Windows/Linux or dev binary) + if exe_dir.join("mcp-servers").is_dir() { + return exe_dir.to_path_buf(); + } + } + } + + // Last resort: cwd parent + cwd.parent() + .unwrap_or(std::path::Path::new(".")) + .to_path_buf() +} + +/// Load the optional `mcp-servers.json` override file. +/// +/// Returns an empty map if the file doesn't exist or can't be parsed. 
+fn load_override_file( + project_root: &std::path::Path, +) -> std::collections::HashMap { + let candidates = [ + project_root.join("src-tauri/mcp-servers.json"), + project_root.join("mcp-servers.json"), + ]; + + for path in &candidates { + if let Ok(content) = std::fs::read_to_string(path) { + match serde_json::from_str::(&content) { + Ok(cfg) => { + tracing::info!( + path = %path.display(), + count = cfg.servers.len(), + "loaded MCP override config" + ); + return cfg.servers; + } + Err(e) => { + tracing::warn!( + path = %path.display(), + error = %e, + "failed to parse MCP override config" + ); + } + } + } + } + + std::collections::HashMap::new() +} + +/// Resolve relative paths, venvs, and inject vision env vars into all server configs. +fn resolve_paths_and_env( + config: &mut mcp_client::types::McpServersConfig, + project_root: &std::path::Path, +) { + for server_config in config.servers.values_mut() { + // Resolve relative cwd to absolute + if let Some(ref cwd) = server_config.cwd { + if !std::path::Path::new(cwd).is_absolute() { + let abs_cwd = project_root.join(cwd); + server_config.cwd = Some(abs_cwd.to_string_lossy().into_owned()); + } + } + + // Resolve venv: rewrite command to venv binary and inject env vars + if let Some(ref venv) = server_config.venv { + let base_dir = server_config + .cwd + .as_ref() + .map(std::path::PathBuf::from) + .unwrap_or_else(|| project_root.to_path_buf()); + + let abs_venv = if std::path::Path::new(venv).is_absolute() { + std::path::PathBuf::from(venv) + } else { + base_dir.join(venv) + }; + // Windows venvs use Scripts\ instead of bin/ + let venv_bin = if cfg!(target_os = "windows") { + abs_venv.join("Scripts") + } else { + abs_venv.join("bin") + }; + let venv_command = venv_bin.join(&server_config.command); + + if venv_command.exists() { + server_config.command = venv_command.to_string_lossy().into_owned(); + server_config.env.insert( + "VIRTUAL_ENV".to_string(), + abs_venv.to_string_lossy().into_owned(), + ); + let 
system_path = std::env::var("PATH").unwrap_or_default(); + server_config.env.insert( + "PATH".to_string(), + if cfg!(target_os = "windows") { + format!("{};{system_path}", venv_bin.to_string_lossy()) + } else { + format!("{}:{system_path}", venv_bin.to_string_lossy()) + }, + ); + tracing::info!( + venv = %abs_venv.display(), + command = %server_config.command, + "resolved venv for MCP server" + ); + } else { + tracing::warn!( + venv = %abs_venv.display(), + command = %server_config.command, + "venv binary not found, using command as-is" + ); + } + + server_config.venv = Some(abs_venv.to_string_lossy().into_owned()); + } + } + + // Inject LOCALCOWORK_DATA_DIR so MCP servers use platform-standard paths + let app_data = data_dir().to_string_lossy().into_owned(); + for server_config in config.servers.values_mut() { + server_config + .env + .entry("LOCALCOWORK_DATA_DIR".to_string()) + .or_insert_with(|| app_data.clone()); + } + + // Inject vision model endpoint env vars + if let Some((vision_endpoint, vision_model)) = resolve_vision_model(project_root) { + for server_config in config.servers.values_mut() { + server_config + .env + .entry("LOCALCOWORK_VISION_ENDPOINT".to_string()) + .or_insert_with(|| vision_endpoint.clone()); + server_config + .env + .entry("LOCALCOWORK_VISION_MODEL".to_string()) + .or_insert_with(|| vision_model.clone()); + } + tracing::info!( + endpoint = %vision_endpoint, + model = %vision_model, + "injected vision model env vars into MCP servers" + ); + } +} + +/// Find the first vision-capable model from `_models/config.yaml`. +/// +/// Returns `(base_url, model_name)` if a model with the "vision" capability is found. +/// Checks: (1) active model, (2) fallback chain, (3) any model in the config. 
+fn resolve_vision_model(project_root: &std::path::Path) -> Option<(String, String)> { + let config_path = project_root.join("_models/config.yaml"); + let content = std::fs::read_to_string(&config_path).ok()?; + let yaml: serde_json::Value = serde_yaml::from_str(&content).ok()?; + + let models = yaml.get("models")?.as_object()?; + let active = yaml.get("active_model")?.as_str()?; + + // Helper: check if a model has vision capability + let has_vision = |key: &str| -> Option<(String, String)> { + let model = models.get(key)?; + let caps = model.get("capabilities")?.as_array()?; + let is_vision = caps.iter().any(|c| c.as_str() == Some("vision")); + if !is_vision { + return None; + } + let base_url = model.get("base_url")?.as_str()?.to_string(); + let model_name = model + .get("model_name") + .and_then(|v| v.as_str()) + .unwrap_or(key) + .to_string(); + Some((base_url, model_name)) + }; + + // 1. Check active model first + if let Some(result) = has_vision(active) { + return Some(result); + } + + // 2. Check fallback chain + if let Some(chain) = yaml.get("fallback_chain").and_then(|c| c.as_array()) { + for entry in chain { + if let Some(key) = entry.as_str() { + if let Some(result) = has_vision(key) { + return Some(result); + } + } + } + } + + // 3. Scan all models for any with vision capability (e.g., dedicated VL model) + for key in models.keys() { + if let Some(result) = has_vision(key) { + return Some(result); + } + } + + None +} + +/// Run the Tauri application. 
+pub fn run() { + // Initialize tracing FIRST — before any tracing::info!() calls + init_tracing(); + + // Initialize the SQLite-backed ConversationManager + let db_path = resolve_db_path(); + let db = match AgentDatabase::open(&db_path) { + Ok(db) => db, + Err(e) => { + tracing::error!(error = %e, path = %db_path, "failed to open agent database"); + std::process::exit(1); + } + }; + let conversation_manager = ConversationManager::new(db); + + tracing::info!(db_path = %db_path, "agent database initialized"); + + // Register an empty MCP client synchronously so that TokioMutex + // is always available in Tauri state. The async setup task will replace the + // empty client with a fully initialized one once servers are started. + // This prevents panics if start_session is called before MCP init completes. + let empty_mcp_config = mcp_client::types::McpServersConfig { + servers: std::collections::HashMap::new(), + }; + + tauri::Builder::default() + .plugin(tauri_plugin_shell::init()) + .manage(TokioMutex::new(conversation_manager)) + .manage(TokioMutex::new(InFlightRequests::default())) + .manage(TokioMutex::new(McpClient::default())) + .manage(TokioMutex::new(empty_mcp_config)) + .manage(TokioMutex::new(PermissionStore::default())) + .manage(TokioMutex::new(SamplingConfig::load_or_default())) + .manage(TokioMutex::new(None::>) + as PendingConfirmation) + .setup(|app| { + // Initialize MCP client asynchronously during app setup. + // Once servers are started, replace the empty client via lock. + let handle = app.handle().clone(); + tauri::async_runtime::spawn(async move { + // Provision missing Python venvs BEFORE resolving MCP config, + // so that discovery picks up the newly created .venv directories. 
+ let project_root = resolve_project_root(); + commands::python_env_startup::provision_missing_venvs(&project_root).await; + + let config = resolve_mcp_config(); + let mut mcp_client = McpClient::new(config, None); + + let errors = mcp_client.start_all().await; + for (name, err) in &errors { + tracing::warn!( + server = %name, + error = %err, + "MCP server failed to start (non-fatal)" + ); + } + + // Filter tools by enabled_tools allowlist (if configured) + filter_tools_by_allowlist(&mut mcp_client, &project_root); + + + let running = mcp_client.running_server_count(); + let tools = mcp_client.tool_count(); + tracing::info!( + running_servers = running, + total_tools = tools, + "MCP client initialized" + ); + + // Replace the empty placeholder with the fully initialized client + let state: tauri::State<'_, TokioMutex> = handle.state(); + let mut lock = state.lock().await; + *lock = mcp_client; + }); + + Ok(()) + }) + .invoke_handler(tauri::generate_handler![ + commands::greet, + commands::chat::start_session, + commands::chat::send_message, + commands::chat::respond_to_confirmation, + commands::session::list_sessions, + commands::session::load_session, + commands::session::delete_session, + commands::session::get_context_budget, + commands::session::cleanup_empty_sessions, + commands::filesystem::list_directory, + commands::filesystem::get_home_dir, + commands::settings::get_models_config, + commands::settings::get_mcp_servers_status, + commands::settings::list_permission_grants, + commands::settings::revoke_permission, + commands::settings::get_sampling_config, + commands::settings::update_sampling_config, + commands::settings::reset_sampling_config, + commands::settings::get_app_settings, + commands::settings::update_app_settings, + commands::settings::add_allowed_path, + commands::settings::remove_allowed_path, + commands::settings::export_settings, + commands::settings::import_settings, + commands::settings::poll_settings_changed, + 
commands::settings::check_config_reload, + commands::settings::reload_model_config, + commands::hardware::detect_hardware, + commands::model_download::download_model, + commands::model_download::verify_model, + commands::model_download::get_model_dir, + commands::ollama::check_llama_server_status, + commands::ollama::check_ollama_status, + commands::ollama::list_ollama_models, + commands::ollama::pull_ollama_model, + commands::python_env::ensure_python_server_env, + commands::python_env::ensure_all_python_envs, + ]) + .run(tauri::generate_context!()) + .unwrap_or_else(|e| { + tracing::error!(error = %e, "error while running tauri application"); + std::process::exit(1); + }); +} + +#[cfg(test)] +mod tests { + use super::*; + use mcp_client::ServerConfig; + use std::collections::HashMap; + use tempfile::TempDir; + + #[test] + fn test_data_dir_returns_valid_path() { + let dir = data_dir(); + assert!(dir.is_absolute()); + assert!(dir.to_string_lossy().contains("com.localcowork.app")); + } + + #[test] + fn test_cache_dir_is_subdirectory_of_data_dir() { + let data = data_dir(); + let cache = cache_dir(); + assert!(cache.starts_with(&data)); + assert!(cache.to_string_lossy().contains("cache")); + } + + #[test] + fn test_rotate_log_file_creates_rotated_copies() { + let temp_dir = TempDir::new().unwrap(); + let log_path = temp_dir.path().join("test.log"); + + // Create original file + std::fs::write(&log_path, "original content").unwrap(); + + // Rotate + rotate_log_file(&log_path, 3); + + // Original should be moved to .1 + let rotated = log_path.with_extension("log.1"); + assert!(rotated.exists()); + + let content = std::fs::read_to_string(&rotated).unwrap(); + assert_eq!(content, "original content"); + } + + #[test] + fn test_rotate_log_file_handles_missing_file() { + let temp_dir = TempDir::new().unwrap(); + let log_path = temp_dir.path().join("nonexistent.log"); + + // Should not panic + rotate_log_file(&log_path, 3); + } + + #[test] + fn 
test_rotate_log_file_multiple_rotations() { + let temp_dir = TempDir::new().unwrap(); + let log_path = temp_dir.path().join("test.log"); + + // Create and rotate multiple times + std::fs::write(&log_path, "v1").unwrap(); + rotate_log_file(&log_path, 3); + + std::fs::write(&log_path, "v2").unwrap(); + rotate_log_file(&log_path, 3); + + std::fs::write(&log_path, "v3").unwrap(); + rotate_log_file(&log_path, 3); + + // Check all versions exist + assert!(log_path.with_extension("log.1").exists()); + assert!(log_path.with_extension("log.2").exists()); + assert!(log_path.with_extension("log.3").exists()); + + // Oldest should be v1 + let v1 = std::fs::read_to_string(log_path.with_extension("log.3")).unwrap(); + assert_eq!(v1, "v1"); + } + + #[test] + fn test_resolve_db_path_returns_sqlite_path() { + let path = resolve_db_path(); + assert!(path.starts_with('/')); // Should be absolute path + assert!(path.ends_with(".db")); + } + + #[test] + fn test_resolve_project_root_finds_mcp_servers() { + // This test verifies the function returns a valid path + let root = resolve_project_root(); + assert!(root.is_absolute()); + } + + fn test_filter_by_enabled_servers_filters_correctly() { + let temp_dir = TempDir::new().unwrap(); + let project_root = temp_dir.path(); + + // Create config with enabled_servers + let config_content = r#" +enabled_servers: + - filesystem + - task +"#; + std::fs::write(project_root.join("_models/config.yaml"), config_content).unwrap(); + + // Create servers + let mut servers = HashMap::new(); + servers.insert("filesystem".to_string(), ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }); + servers.insert("task".to_string(), ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }); + servers.insert("calendar".to_string(), ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + 
}); // Should be removed + servers.insert("email".to_string(), ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }); // Should be removed + + let before = servers.len(); + filter_by_enabled_servers(&mut servers, project_root); + let after = servers.len(); + + assert_eq!(before, 4); + assert_eq!(after, 2); + assert!(servers.contains_key("filesystem")); + assert!(servers.contains_key("task")); + assert!(!servers.contains_key("calendar")); + assert!(!servers.contains_key("email")); + } + + #[test] + fn test_filter_by_enabled_servers_handles_missing_config() { + let temp_dir = TempDir::new().unwrap(); + let project_root = temp_dir.path(); + // No config file at all + + let mut servers = HashMap::new(); + servers.insert("a".to_string(), ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }); + servers.insert("b".to_string(), ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }); + + let before = servers.len(); + filter_by_enabled_servers(&mut servers, project_root); + + // Should keep all since no config + assert_eq!(servers.len(), before); + } + + #[test] + fn test_filter_by_enabled_servers_no_config_keeps_all() { + let temp_dir = TempDir::new().unwrap(); + let project_root = temp_dir.path(); + // No config file + + let mut servers = HashMap::new(); + servers.insert("a".to_string(), ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }); + servers.insert("b".to_string(), ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }); + + let before = servers.len(); + filter_by_enabled_servers(&mut servers, project_root); + + // Should keep all since no config + assert_eq!(servers.len(), before); + } + + #[test] + fn test_load_override_file_returns_empty_for_missing() { 
+ let temp_dir = TempDir::new().unwrap(); + let project_root = temp_dir.path(); + + let result = load_override_file(project_root); + assert!(result.is_empty()); + } + + #[test] + fn test_load_override_file_parses_valid_config() { + let temp_dir = TempDir::new().unwrap(); + let project_root = temp_dir.path(); + + let config_content = r#"{ + "servers": { + "test-server": { + "command": "node", + "args": ["test.js"] + } + } + }"#; + std::fs::write(project_root.join("mcp-servers.json"), config_content).unwrap(); + + let result = load_override_file(project_root); + assert!(result.contains_key("test-server")); + } + + #[test] + fn test_resolve_vision_model_returns_none_without_config() { + let temp_dir = TempDir::new().unwrap(); + + let result = resolve_vision_model(temp_dir.path()); + assert!(result.is_none()); + } + + #[test] + fn test_filter_tools_by_allowlist_works_without_config() { + // Test that filter_tools_by_allowlist doesn't panic without config + let temp_dir = TempDir::new().unwrap(); + let project_root = temp_dir.path(); + + let mut mcp_client = McpClient::new( + mcp_client::types::McpServersConfig { servers: HashMap::new() }, + None, + ); + + // Should not panic + filter_tools_by_allowlist(&mut mcp_client, project_root); + } +} diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs new file mode 100644 index 0000000..e355c34 --- /dev/null +++ b/src-tauri/src/main.rs @@ -0,0 +1,6 @@ +// Prevents additional console window on Windows in release +#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] + +fn main() { + localcowork::run(); +} diff --git a/src-tauri/src/mcp_client/client.rs b/src-tauri/src/mcp_client/client.rs new file mode 100644 index 0000000..a1009c3 --- /dev/null +++ b/src-tauri/src/mcp_client/client.rs @@ -0,0 +1,330 @@ +//! MCP Client — high-level interface for tool execution. +//! +//! Orchestrates server lifecycle, tool registry, and tool call dispatch. +//! This is the primary API used by the ToolRouter (WS-2D). 
+ +use std::collections::HashMap; +use std::time::Instant; + +use super::errors::McpError; +use super::lifecycle; +use super::registry::ToolRegistry; +use super::types::{McpServersConfig, ServerConfig, ToolCallResult}; + +// ─── Constants ─────────────────────────────────────────────────────────────── + +/// Default timeout for tool call execution (ms). +const DEFAULT_CALL_TIMEOUT_MS: u64 = 30_000; + +// ─── McpClient ─────────────────────────────────────────────────────────────── + +/// High-level MCP client that manages multiple servers and routes tool calls. +pub struct McpClient { + /// Running server processes. + servers: HashMap, + /// Server configurations (for restarts). + configs: HashMap, + /// Aggregated tool definitions from all servers. + pub registry: ToolRegistry, + /// Working directory for server processes. + working_dir: Option, + /// Tool call timeout in milliseconds. + call_timeout_ms: u64, +} + +impl McpClient { + /// Create a new MCP client from a servers configuration file. + pub fn new(config: McpServersConfig, working_dir: Option) -> Self { + Self { + servers: HashMap::new(), + configs: config.servers, + registry: ToolRegistry::new(), + working_dir, + call_timeout_ms: DEFAULT_CALL_TIMEOUT_MS, + } + } + + /// Set the tool call timeout in milliseconds. + pub fn set_call_timeout(&mut self, timeout_ms: u64) { + self.call_timeout_ms = timeout_ms; + } + + // ─── Lifecycle ─────────────────────────────────────────────────────── + + /// Start all configured servers and build the tool registry. + /// + /// Returns a list of servers that failed to start (partial startup is OK). 
+ pub async fn start_all(&mut self) -> Vec<(String, McpError)> { + let (servers, errors) = + lifecycle::spawn_all_servers(&self.configs, self.working_dir.as_deref()).await; + + // Build registry from all successfully started servers + for (name, server) in &servers { + self.registry + .register_server_tools(name, server.tools.clone()); + } + + self.servers = servers; + errors + } + + /// Start a specific server by name. + pub async fn start_server(&mut self, name: &str) -> Result<(), McpError> { + let config = self.configs.get(name).ok_or(McpError::ConfigError { + reason: format!("no configuration for server '{name}'"), + })?; + + let server = + lifecycle::spawn_server(name, config, self.working_dir.as_deref()).await?; + + self.registry + .register_server_tools(name, server.tools.clone()); + self.servers.insert(name.to_string(), server); + + Ok(()) + } + + /// Shut down all servers gracefully. + pub async fn shutdown_all(&mut self) { + lifecycle::shutdown_all_servers(&mut self.servers).await; + self.registry = ToolRegistry::new(); + } + + /// Shut down a specific server. + pub async fn shutdown_server(&mut self, name: &str) { + if let Some(mut server) = self.servers.remove(name) { + let _ = server.shutdown().await; + } + self.registry.unregister_server(name); + } + + // ─── Tool Execution ────────────────────────────────────────────────── + + /// Execute a tool call, routing to the appropriate server. + /// + /// Steps: + /// 1. Validate the tool exists and arguments are structurally valid + /// 2. Find the owning server + /// 3. Send JSON-RPC `tools/call` request + /// 4. Parse and return the result + pub async fn call_tool( + &mut self, + tool_name: &str, + arguments: serde_json::Value, + ) -> Result { + let start = Instant::now(); + + // 1. Validate + self.registry.validate_tool_call(tool_name, &arguments)?; + + // 2. 
Find server + let server_name = self + .registry + .get_server_for_tool(tool_name) + .ok_or(McpError::UnknownTool { + name: tool_name.to_string(), + })? + .to_string(); + + let server = self + .servers + .get(&server_name) + .ok_or(McpError::ServerCrashed { + name: server_name.clone(), + reason: "server not running".into(), + })?; + + // 3. Send request + let params = serde_json::json!({ + "name": tool_name, + "arguments": arguments, + }); + + let response = tokio::time::timeout( + std::time::Duration::from_millis(self.call_timeout_ms), + server.transport.request("tools/call", Some(params)), + ) + .await + .map_err(|_| McpError::Timeout { + tool: tool_name.to_string(), + timeout_ms: self.call_timeout_ms, + })? + .map_err(|e| { + // Check if this is a transport error (server might have crashed) + if matches!(e, McpError::TransportError { .. }) { + McpError::ServerCrashed { + name: server_name.clone(), + reason: e.to_string(), + } + } else { + e + } + })?; + + let elapsed = start.elapsed().as_millis() as u64; + + // 4. Parse response + match super::transport::extract_result(response) { + Ok(result) => Ok(ToolCallResult { + tool_name: tool_name.to_string(), + success: true, + result: Some(result), + error: None, + execution_time_ms: elapsed, + }), + Err(McpError::ServerError { code, message, .. }) => Ok(ToolCallResult { + tool_name: tool_name.to_string(), + success: false, + result: None, + error: Some(format!("[{code}] {message}")), + execution_time_ms: elapsed, + }), + Err(e) => Err(e), + } + } + + /// Restart a crashed server and re-register its tools. 
+ pub async fn restart_server(&mut self, name: &str) -> Result<(), McpError> { + let config = self.configs.get(name).ok_or(McpError::ConfigError { + reason: format!("no configuration for server '{name}'"), + })?; + + let restart_count = self + .servers + .get(name) + .map(|s| s.restart_count()) + .unwrap_or(0); + + // Remove the old server + self.registry.unregister_server(name); + if let Some(mut old) = self.servers.remove(name) { + let _ = old.shutdown().await; + } + + // Restart with backoff + let server = lifecycle::restart_server( + name, + config, + self.working_dir.as_deref(), + restart_count, + ) + .await?; + + self.registry + .register_server_tools(name, server.tools.clone()); + self.servers.insert(name.to_string(), server); + + Ok(()) + } + + // ─── Status ────────────────────────────────────────────────────────── + + /// Get the number of running servers. + pub fn running_server_count(&self) -> usize { + self.servers.len() + } + + /// Get the number of registered tools. + pub fn tool_count(&self) -> usize { + self.registry.len() + } + + /// Check if a specific server is running. + pub fn is_server_running(&self, name: &str) -> bool { + self.servers.contains_key(name) + } + + /// Get a list of running server names. + pub fn running_servers(&self) -> Vec { + self.servers.keys().cloned().collect() + } + + /// Get names of all configured servers (including those that failed to start). 
+ pub fn configured_servers(&self) -> Vec { + let mut names: Vec = self.configs.keys().cloned().collect(); + names.sort(); + names + } +} + +impl Default for McpClient { + fn default() -> Self { + Self { + servers: HashMap::new(), + configs: HashMap::new(), + registry: ToolRegistry::new(), + working_dir: None, + call_timeout_ms: DEFAULT_CALL_TIMEOUT_MS, + } + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn empty_config() -> McpServersConfig { + McpServersConfig { + servers: HashMap::new(), + } + } + + #[test] + fn test_new_client_empty() { + let client = McpClient::new(empty_config(), None); + assert_eq!(client.running_server_count(), 0); + assert_eq!(client.tool_count(), 0); + assert!(client.registry.is_empty()); + } + + #[test] + fn test_set_call_timeout() { + let mut client = McpClient::new(empty_config(), None); + client.set_call_timeout(5000); + assert_eq!(client.call_timeout_ms, 5000); + } + + #[test] + fn test_is_server_running() { + let client = McpClient::new(empty_config(), None); + assert!(!client.is_server_running("filesystem")); + } + + #[test] + fn test_running_servers_empty() { + let client = McpClient::new(empty_config(), None); + assert!(client.running_servers().is_empty()); + } + + #[test] + fn test_configured_servers() { + let mut servers = HashMap::new(); + servers.insert( + "zeta".to_string(), + ServerConfig { + command: "npx".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }, + ); + servers.insert( + "alpha".to_string(), + ServerConfig { + command: "npx".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }, + ); + let config = McpServersConfig { servers }; + let client = McpClient::new(config, None); + + let names = client.configured_servers(); + assert_eq!(names, vec!["alpha", "zeta"]); // sorted + } +} diff --git a/src-tauri/src/mcp_client/discovery.rs 
b/src-tauri/src/mcp_client/discovery.rs new file mode 100644 index 0000000..97a84ef --- /dev/null +++ b/src-tauri/src/mcp_client/discovery.rs @@ -0,0 +1,362 @@ +//! MCP Server Auto-Discovery — scan `mcp-servers/` and build configs from conventions. +//! +//! Eliminates the need to manually maintain `mcp-servers.json` for every server. +//! Servers are detected by the presence of `package.json` (TypeScript) or +//! `pyproject.toml` (Python). The JSON file becomes an optional override. + +use std::collections::HashMap; +use std::path::Path; + +use super::types::ServerConfig; + +// ─── Language Detection ────────────────────────────────────────────────────── + +/// Detected language of an MCP server. +#[derive(Debug, PartialEq, Eq)] +enum ServerLanguage { + TypeScript, + Python, +} + +/// Detect the language of a server directory by checking marker files. +/// +/// - `package.json` → TypeScript +/// - `pyproject.toml` → Python +/// - Neither → `None` (not a server) +fn detect_language(server_dir: &Path) -> Option { + if server_dir.join("package.json").exists() { + Some(ServerLanguage::TypeScript) + } else if server_dir.join("pyproject.toml").exists() { + Some(ServerLanguage::Python) + } else { + None + } +} + +// ─── Platform Helpers ──────────────────────────────────────────────────────── + +/// Platform-correct npx command. +/// +/// Windows requires `npx.cmd` because `npx` is a batch script; +/// `Command::new("npx")` fails without the extension on Windows. +fn default_npx_command() -> &'static str { + if cfg!(target_os = "windows") { + "npx.cmd" + } else { + "npx" + } +} + +/// Platform-correct Python command. +/// +/// macOS 12.3+ removed the `python` symlink; only `python3` exists. +/// Windows installs Python as `python.exe` via the official installer. 
+fn default_python_command() -> &'static str { + if cfg!(target_os = "windows") { + "python" + } else { + "python3" + } +} + +// ─── Config Generation ────────────────────────────────────────────────────── + +/// Generate a `ServerConfig` for a TypeScript MCP server. +fn ts_config(name: &str) -> ServerConfig { + ServerConfig { + command: default_npx_command().to_string(), + args: vec!["tsx".to_string(), "src/index.ts".to_string()], + env: HashMap::new(), + cwd: Some(format!("mcp-servers/{name}")), + venv: None, + } +} + +/// Detect the Python entry point module for a server. +/// +/// Checks for `src/server.py` first (preferred convention), then +/// `src/main.py` (used by knowledge, security, meeting, screenshot-pipeline). +fn detect_py_entry_module(server_dir: &Path) -> String { + if server_dir.join("src").join("server.py").exists() { + "src.server".to_string() + } else if server_dir.join("src").join("main.py").exists() { + "src.main".to_string() + } else { + // Fallback to convention — will fail at runtime with a clear error + "src.server".to_string() + } +} + +/// Generate a `ServerConfig` for a Python MCP server. +/// +/// If a `.venv` directory exists inside the server dir, sets `venv: ".venv"`. +/// Detects the entry point module by checking for `src/server.py` or `src/main.py`. +fn py_config(name: &str, server_dir: &Path) -> ServerConfig { + let venv = if server_dir.join(".venv").is_dir() { + Some(".venv".to_string()) + } else { + None + }; + + let entry_module = detect_py_entry_module(server_dir); + + ServerConfig { + command: default_python_command().to_string(), + args: vec!["-m".to_string(), entry_module], + env: HashMap::new(), + cwd: Some(format!("mcp-servers/{name}")), + venv, + } +} + +// ─── Discovery ────────────────────────────────────────────────────────────── + +/// Scan the `mcp-servers/` directory and generate `ServerConfig` entries. +/// +/// Skips directories starting with `_` or `.`. 
Returns an empty map if the +/// directory doesn't exist (graceful degradation). +pub fn discover_servers(mcp_servers_dir: &Path) -> HashMap { + let mut configs = HashMap::new(); + + let entries = match std::fs::read_dir(mcp_servers_dir) { + Ok(entries) => entries, + Err(e) => { + tracing::warn!( + path = %mcp_servers_dir.display(), + error = %e, + "mcp-servers directory not found, skipping auto-discovery" + ); + return configs; + } + }; + + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_dir() { + continue; + } + + let dir_name = match path.file_name().and_then(|n| n.to_str()) { + Some(name) => name.to_string(), + None => continue, + }; + + // Skip internal/hidden directories + if dir_name.starts_with('_') || dir_name.starts_with('.') { + continue; + } + + if let Some(language) = detect_language(&path) { + let config = match language { + ServerLanguage::TypeScript => ts_config(&dir_name), + ServerLanguage::Python => py_config(&dir_name, &path), + }; + tracing::debug!( + server = %dir_name, + language = ?language, + "auto-discovered MCP server" + ); + configs.insert(dir_name, config); + } + } + + configs +} + +// ─── Merge ────────────────────────────────────────────────────────────────── + +/// Merge auto-discovered configs with manual overrides from `mcp-servers.json`. +/// +/// Override entries **fully replace** discovered entries for the same server name. +/// Override-only servers (not on disk) are added as-is (supports external servers). 
+pub fn merge_configs( + mut discovered: HashMap, + overrides: HashMap, +) -> HashMap { + for (name, override_config) in overrides { + // Override fully replaces the discovered config + discovered.insert(name, override_config); + } + discovered +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn test_discover_ts_server() { + let tmp = TempDir::new().unwrap(); + let server_dir = tmp.path().join("filesystem"); + std::fs::create_dir(&server_dir).unwrap(); + std::fs::write(server_dir.join("package.json"), "{}").unwrap(); + + let configs = discover_servers(tmp.path()); + assert_eq!(configs.len(), 1); + assert!(configs.contains_key("filesystem")); + + let cfg = &configs["filesystem"]; + assert_eq!(cfg.command, default_npx_command()); + assert_eq!(cfg.args, vec!["tsx", "src/index.ts"]); + assert_eq!(cfg.cwd, Some("mcp-servers/filesystem".to_string())); + assert_eq!(cfg.venv, None); + } + + #[test] + fn test_discover_py_server_with_server_py() { + let tmp = TempDir::new().unwrap(); + let server_dir = tmp.path().join("document"); + std::fs::create_dir_all(server_dir.join("src")).unwrap(); + std::fs::write(server_dir.join("pyproject.toml"), "").unwrap(); + std::fs::write(server_dir.join("src").join("server.py"), "").unwrap(); + + let configs = discover_servers(tmp.path()); + assert_eq!(configs.len(), 1); + + let cfg = &configs["document"]; + assert_eq!(cfg.command, default_python_command()); + assert_eq!(cfg.args, vec!["-m", "src.server"]); + assert_eq!(cfg.venv, None); + } + + #[test] + fn test_discover_py_server_with_main_py() { + let tmp = TempDir::new().unwrap(); + let server_dir = tmp.path().join("knowledge"); + std::fs::create_dir_all(server_dir.join("src")).unwrap(); + std::fs::write(server_dir.join("pyproject.toml"), "").unwrap(); + std::fs::write(server_dir.join("src").join("main.py"), "").unwrap(); + + let configs = discover_servers(tmp.path()); 
+ assert_eq!(configs.len(), 1); + + let cfg = &configs["knowledge"]; + assert_eq!(cfg.command, default_python_command()); + assert_eq!(cfg.args, vec!["-m", "src.main"]); + assert_eq!(cfg.venv, None); + } + + #[test] + fn test_discover_py_server_with_venv() { + let tmp = TempDir::new().unwrap(); + let server_dir = tmp.path().join("ocr"); + std::fs::create_dir(&server_dir).unwrap(); + std::fs::write(server_dir.join("pyproject.toml"), "").unwrap(); + std::fs::create_dir(server_dir.join(".venv")).unwrap(); + + let configs = discover_servers(tmp.path()); + let cfg = &configs["ocr"]; + assert_eq!(cfg.command, default_python_command()); + assert_eq!(cfg.venv, Some(".venv".to_string())); + } + + #[test] + fn test_skip_underscore_dirs() { + let tmp = TempDir::new().unwrap(); + let shared = tmp.path().join("_shared"); + std::fs::create_dir(&shared).unwrap(); + std::fs::write(shared.join("package.json"), "{}").unwrap(); + + let configs = discover_servers(tmp.path()); + assert!(configs.is_empty()); + } + + #[test] + fn test_skip_dot_dirs() { + let tmp = TempDir::new().unwrap(); + let hidden = tmp.path().join(".hidden"); + std::fs::create_dir(&hidden).unwrap(); + std::fs::write(hidden.join("package.json"), "{}").unwrap(); + + let configs = discover_servers(tmp.path()); + assert!(configs.is_empty()); + } + + #[test] + fn test_skip_non_directory_files() { + let tmp = TempDir::new().unwrap(); + std::fs::write(tmp.path().join("README.md"), "# readme").unwrap(); + + let configs = discover_servers(tmp.path()); + assert!(configs.is_empty()); + } + + #[test] + fn test_missing_directory() { + let configs = discover_servers(Path::new("/nonexistent/path/mcp-servers")); + assert!(configs.is_empty()); + } + + #[test] + fn test_merge_override_replaces() { + let mut discovered = HashMap::new(); + discovered.insert( + "filesystem".to_string(), + ts_config("filesystem"), + ); + + let mut overrides = HashMap::new(); + overrides.insert( + "filesystem".to_string(), + ServerConfig { + command: 
"node".to_string(), + args: vec!["dist/index.js".to_string()], + env: HashMap::new(), + cwd: Some("/custom/path".to_string()), + venv: None, + }, + ); + + let merged = merge_configs(discovered, overrides); + assert_eq!(merged["filesystem"].command, "node"); + assert_eq!(merged["filesystem"].cwd, Some("/custom/path".to_string())); + } + + #[test] + fn test_merge_adds_override_only_servers() { + let discovered = HashMap::new(); + let mut overrides = HashMap::new(); + overrides.insert( + "external".to_string(), + ServerConfig { + command: "custom-mcp".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }, + ); + + let merged = merge_configs(discovered, overrides); + assert!(merged.contains_key("external")); + assert_eq!(merged["external"].command, "custom-mcp"); + } + + #[test] + fn test_merge_preserves_non_overridden() { + let mut discovered = HashMap::new(); + discovered.insert("fs".to_string(), ts_config("fs")); + discovered.insert("ocr".to_string(), ts_config("ocr")); + + let mut overrides = HashMap::new(); + overrides.insert( + "fs".to_string(), + ServerConfig { + command: "node".to_string(), + args: vec![], + env: HashMap::new(), + cwd: None, + venv: None, + }, + ); + + let merged = merge_configs(discovered, overrides); + // fs was overridden + assert_eq!(merged["fs"].command, "node"); + // ocr was NOT overridden — preserved from discovery + assert_eq!(merged["ocr"].command, "npx"); + } +} diff --git a/src-tauri/src/mcp_client/errors.rs b/src-tauri/src/mcp_client/errors.rs new file mode 100644 index 0000000..2123d83 --- /dev/null +++ b/src-tauri/src/mcp_client/errors.rs @@ -0,0 +1,76 @@ +//! MCP Client error types. + +use thiserror::Error; + +/// Errors that can occur during MCP client operations. +#[derive(Debug, Error)] +pub enum McpError { + /// A server process failed to start. 
+ #[error("failed to spawn server '{name}': {reason}")] + SpawnFailed { + name: String, + reason: String, + }, + + /// The initialization handshake failed. + #[error("server '{name}' initialization failed: {reason}")] + InitFailed { + name: String, + reason: String, + }, + + /// JSON-RPC communication error (malformed message, I/O error). + #[error("transport error for server '{server}': {reason}")] + TransportError { + server: String, + reason: String, + }, + + /// Server returned a JSON-RPC error response. + #[error("server error [{code}]: {message}")] + ServerError { + code: i32, + message: String, + data: Option, + }, + + /// Tool not found in the aggregated registry. + #[error("unknown tool: '{name}'")] + UnknownTool { + name: String, + }, + + /// Tool call arguments failed schema validation. + #[error("invalid arguments for '{tool}': {reason}")] + InvalidArguments { + tool: String, + reason: String, + }, + + /// A tool call timed out. + #[error("tool call '{tool}' timed out after {timeout_ms}ms")] + Timeout { + tool: String, + timeout_ms: u64, + }, + + /// Server process crashed unexpectedly. + #[error("server '{name}' crashed: {reason}")] + ServerCrashed { + name: String, + reason: String, + }, + + /// Configuration error (missing servers, bad config file). + #[error("config error: {reason}")] + ConfigError { + reason: String, + }, + + /// All restart attempts exhausted for a server. + #[error("server '{name}' failed after {attempts} restart attempts")] + RestartExhausted { + name: String, + attempts: u32, + }, +} diff --git a/src-tauri/src/mcp_client/lifecycle.rs b/src-tauri/src/mcp_client/lifecycle.rs new file mode 100644 index 0000000..c670f74 --- /dev/null +++ b/src-tauri/src/mcp_client/lifecycle.rs @@ -0,0 +1,334 @@ +//! Server process lifecycle management. +//! +//! Handles spawning, monitoring, restarting, and shutting down MCP server +//! child processes. Each server runs as a separate OS process communicating +//! via JSON-RPC over stdio. 
+ +use std::collections::HashMap; +use std::time::Duration; + +use tokio::process::{Child, Command}; +use tokio::time::sleep; + +use super::errors::McpError; +use super::transport::StdioTransport; +use super::types::{InitializeResult, McpToolDefinition, ServerConfig}; + +// ─── Constants ─────────────────────────────────────────────────────────────── + +/// Maximum restart attempts before giving up on a server. +const MAX_RESTART_ATTEMPTS: u32 = 3; + +/// Base delay between restart attempts (doubles each time). +const RESTART_BASE_DELAY: Duration = Duration::from_secs(1); + +/// Timeout for the initialize handshake. +/// +/// Set to 30s to accommodate ML-heavy servers (meeting, ocr) that import +/// PyTorch, Whisper, and other large frameworks at startup. +const INIT_TIMEOUT: Duration = Duration::from_secs(30); + +/// Timeout for graceful shutdown before force-killing. +const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + +// ─── ManagedServer ─────────────────────────────────────────────────────────── + +/// A running MCP server process with its transport and tool definitions. +pub struct ManagedServer { + /// Human-readable server name (e.g., "filesystem"). + pub name: String, + /// The child process handle. + process: Child, + /// JSON-RPC transport (stdin/stdout). + pub transport: StdioTransport, + /// Tool definitions received during initialization. + pub tools: Vec, + /// Number of times this server has been restarted. + restart_count: u32, + /// The original config used to spawn this server (retained for restart). + #[allow(dead_code)] + config: ServerConfig, +} + +impl ManagedServer { + /// How many times this server has been restarted. + pub fn restart_count(&self) -> u32 { + self.restart_count + } + + /// Check if the server process is still running. 
+ pub async fn is_alive(&mut self) -> bool { + match self.process.try_wait() { + Ok(None) => true, // Still running + Ok(Some(_)) => false, // Exited + Err(_) => false, // Error checking — assume dead + } + } + + /// Attempt to gracefully shut down the server. + pub async fn shutdown(&mut self) -> Result<(), McpError> { + // Send shutdown notification (best-effort) + let _ = self.transport.notify("shutdown", None).await; + + // Wait for graceful exit + let result = tokio::time::timeout(SHUTDOWN_TIMEOUT, self.process.wait()).await; + + match result { + Ok(Ok(_)) => Ok(()), + _ => { + // Force kill if graceful shutdown failed/timed out + let _ = self.process.kill().await; + Ok(()) + } + } + } +} + +// ─── Spawning ──────────────────────────────────────────────────────────────── + +/// Spawn a single MCP server process and perform the initialization handshake. +/// +/// Returns a `ManagedServer` with its transport and discovered tools. +pub async fn spawn_server( + name: &str, + config: &ServerConfig, + working_dir: Option<&str>, +) -> Result { + let mut cmd = Command::new(&config.command); + cmd.args(&config.args); + + // Set environment variables + for (key, value) in &config.env { + cmd.env(key, value); + } + + // Set working directory: per-server cwd overrides the global working_dir + let effective_dir = config.cwd.as_deref().or(working_dir); + if let Some(dir) = effective_dir { + cmd.current_dir(dir); + } + + // Windows: prevent console window from appearing for child processes + #[cfg(target_os = "windows")] + { + use std::os::windows::process::CommandExt; + const CREATE_NO_WINDOW: u32 = 0x08000000; + cmd.creation_flags(CREATE_NO_WINDOW); + } + + // Wire stdio for JSON-RPC + cmd.stdin(std::process::Stdio::piped()); + cmd.stdout(std::process::Stdio::piped()); + cmd.stderr(std::process::Stdio::piped()); // Capture stderr for logging + + let mut child = cmd.spawn().map_err(|e| McpError::SpawnFailed { + name: name.to_string(), + reason: format!("{e}"), + })?; + + 
let stdin = child.stdin.take().ok_or(McpError::SpawnFailed { + name: name.to_string(), + reason: "failed to capture stdin".into(), + })?; + + let stdout = child.stdout.take().ok_or(McpError::SpawnFailed { + name: name.to_string(), + reason: "failed to capture stdout".into(), + })?; + + // Extract stderr for diagnostic capture on failure + let stderr_handle = child.stderr.take(); + + let transport = StdioTransport::new(name, stdin, stdout); + + // Perform initialization handshake with timeout + let tools = match tokio::time::timeout(INIT_TIMEOUT, initialize(&transport, name)).await { + Ok(Ok(tools)) => tools, + Ok(Err(e)) => { + let stderr_ctx = read_stderr_on_failure(stderr_handle).await; + if !stderr_ctx.is_empty() { + tracing::warn!( + server = name, + stderr = %stderr_ctx, + "server stderr captured on failure" + ); + } + let reason = format!("{e}{}", format_stderr_suffix(&stderr_ctx)); + return Err(McpError::InitFailed { + name: name.to_string(), + reason, + }); + } + Err(_) => { + let stderr_ctx = read_stderr_on_failure(stderr_handle).await; + if !stderr_ctx.is_empty() { + tracing::warn!( + server = name, + stderr = %stderr_ctx, + "server stderr captured on timeout" + ); + } + let _ = child.kill().await; + let reason = format!( + "initialization timed out after {}s{}", + INIT_TIMEOUT.as_secs(), + format_stderr_suffix(&stderr_ctx) + ); + return Err(McpError::InitFailed { + name: name.to_string(), + reason, + }); + } + }; + + Ok(ManagedServer { + name: name.to_string(), + process: child, + transport, + tools, + restart_count: 0, + config: config.clone(), + }) +} + +/// Read any available stderr output from a failed server process. +/// +/// Uses a short timeout to avoid blocking if stderr is empty or the process +/// is still writing. Truncates to 2000 chars to keep log messages readable. 
+async fn read_stderr_on_failure( + stderr_handle: Option, +) -> String { + use tokio::io::AsyncReadExt; + + let Some(mut stderr) = stderr_handle else { + return String::new(); + }; + + let mut buf = String::new(); + match tokio::time::timeout( + Duration::from_millis(500), + stderr.read_to_string(&mut buf), + ) + .await + { + Ok(Ok(_)) => { + if buf.len() > 2000 { + buf.truncate(2000); + buf.push_str("...(truncated)"); + } + buf + } + _ => String::new(), + } +} + +/// Format a stderr suffix for error messages (empty string if no stderr). +fn format_stderr_suffix(stderr: &str) -> String { + if stderr.is_empty() { + String::new() + } else { + format!(" | stderr: {}", stderr.trim()) + } +} + +/// Perform the MCP initialization handshake. +async fn initialize( + transport: &StdioTransport, + server_name: &str, +) -> Result, McpError> { + let response = transport.request("initialize", None).await?; + + let result = super::transport::extract_result(response)?; + + let init_result: InitializeResult = + serde_json::from_value(result).map_err(|e| McpError::InitFailed { + name: server_name.to_string(), + reason: format!("failed to parse initialize response: {e}"), + })?; + + Ok(init_result.tools) +} + +/// Restart a crashed server with exponential backoff. +/// +/// Returns the new `ManagedServer` if successful, or an error if all +/// attempts are exhausted. 
+pub async fn restart_server( + name: &str, + config: &ServerConfig, + working_dir: Option<&str>, + current_restart_count: u32, +) -> Result { + if current_restart_count >= MAX_RESTART_ATTEMPTS { + return Err(McpError::RestartExhausted { + name: name.to_string(), + attempts: MAX_RESTART_ATTEMPTS, + }); + } + + // Exponential backoff: 1s, 2s, 4s + let delay = RESTART_BASE_DELAY * 2u32.pow(current_restart_count); + sleep(delay).await; + + let mut server = spawn_server(name, config, working_dir).await?; + server.restart_count = current_restart_count + 1; + Ok(server) +} + +// ─── Batch Operations ──────────────────────────────────────────────────────── + +/// Spawn all configured servers concurrently. +/// +/// Returns a map of server name → `ManagedServer`. Servers that fail to +/// start are logged but not included (partial startup is acceptable). +pub async fn spawn_all_servers( + configs: &HashMap, + working_dir: Option<&str>, +) -> (HashMap, Vec<(String, McpError)>) { + let mut servers = HashMap::new(); + let mut errors = Vec::new(); + + // Spawn servers concurrently using join_all + let mut handles: Vec<(String, _)> = Vec::new(); + for (name, config) in configs { + let name = name.clone(); + let config = config.clone(); + let wd = working_dir.map(|s| s.to_string()); + handles.push(( + name.clone(), + tokio::spawn(async move { + spawn_server(&name, &config, wd.as_deref()).await + }), + )); + } + + for (name, handle) in handles { + match handle.await { + Ok(Ok(server)) => { + servers.insert(name, server); + } + Ok(Err(e)) => { + errors.push((name, e)); + } + Err(e) => { + errors.push(( + name.clone(), + McpError::SpawnFailed { + name, + reason: format!("join error: {e}"), + }, + )); + } + } + } + + (servers, errors) +} + +/// Shut down all managed servers gracefully. 
+pub async fn shutdown_all_servers(servers: &mut HashMap) { + for (_, server) in servers.iter_mut() { + let _ = server.shutdown().await; + } + servers.clear(); +} diff --git a/src-tauri/src/mcp_client/mod.rs b/src-tauri/src/mcp_client/mod.rs new file mode 100644 index 0000000..59ed5a6 --- /dev/null +++ b/src-tauri/src/mcp_client/mod.rs @@ -0,0 +1,25 @@ +//! MCP Client — JSON-RPC over stdio transport for MCP server management. +//! +//! This module handles: +//! - Spawning and managing MCP server child processes +//! - JSON-RPC 2.0 communication over process stdio +//! - Tool discovery and aggregation across all servers +//! - Tool call routing, validation, and execution +//! - Server lifecycle (start, restart with backoff, graceful shutdown) +//! +//! The MCP Client is used by the ToolRouter (WS-2D) to dispatch tool calls +//! from the LLM to the appropriate MCP server. + +pub mod client; +pub mod discovery; +pub mod errors; +pub mod lifecycle; +pub mod registry; +pub mod transport; +pub mod types; + +// Re-exports for convenience +pub use client::McpClient; +pub use errors::McpError; +pub use registry::{CategoryRegistry, ToolCategory, ToolRegistry, ToolResolution}; +pub use types::{McpServersConfig, McpToolDefinition, ServerConfig, ToolCallResult}; diff --git a/src-tauri/src/mcp_client/registry.rs b/src-tauri/src/mcp_client/registry.rs new file mode 100644 index 0000000..8cf1cde --- /dev/null +++ b/src-tauri/src/mcp_client/registry.rs @@ -0,0 +1,1605 @@ +//! Tool registry — aggregates tool definitions across all MCP servers. +//! +//! Provides: +//! - Tool lookup by fully-qualified name (`server.tool`) +//! - Server-name extraction from tool names +//! - Validation that a tool call matches the registered schema +//! 
- Serialization of tools into the LLM system prompt format + +use std::collections::HashMap; + +use super::errors::McpError; +use super::types::McpToolDefinition; + +// ─── ToolRegistry ──────────────────────────────────────────────────────────── + +/// Aggregated tool registry across all MCP servers. +/// +/// Tool names are stored as `"server_name.tool_name"` (e.g., `"filesystem.list_dir"`). +#[derive(Debug, Clone)] +pub struct ToolRegistry { + /// `tool_name → (server_name, definition)`. + tools: HashMap, +} + +impl ToolRegistry { + /// Create an empty registry. + pub fn new() -> Self { + Self { + tools: HashMap::new(), + } + } + + /// Register tools from a server. + /// + /// Tool names are expected to already be fully qualified (`server.tool`). + /// If not, they are prefixed with the server name. + pub fn register_server_tools(&mut self, server_name: &str, tools: Vec) { + for tool in tools { + let fq_name = if tool.name.contains('.') { + tool.name.clone() + } else { + format!("{server_name}.{}", tool.name) + }; + + self.tools + .insert(fq_name, (server_name.to_string(), tool)); + } + } + + /// Remove all tools belonging to a server. + pub fn unregister_server(&mut self, server_name: &str) { + self.tools.retain(|_, (srv, _)| srv != server_name); + } + + /// Look up a tool by its fully-qualified name. + pub fn get_tool(&self, name: &str) -> Option<&McpToolDefinition> { + self.tools.get(name).map(|(_, def)| def) + } + + /// Get the server name that owns a tool. + pub fn get_server_for_tool(&self, tool_name: &str) -> Option<&str> { + self.tools.get(tool_name).map(|(srv, _)| srv.as_str()) + } + + /// Extract the server name from a fully-qualified tool name. + /// + /// E.g., `"filesystem.list_dir"` → `"filesystem"`. + pub fn server_name_from_tool(tool_name: &str) -> Option<&str> { + tool_name.split('.').next() + } + + /// Check whether a tool requires user confirmation before execution. 
+ pub fn requires_confirmation(&self, tool_name: &str) -> bool { + self.tools + .get(tool_name) + .map(|(_, def)| def.confirmation_required) + .unwrap_or(true) // Default to requiring confirmation for unknown tools + } + + /// Check whether a tool supports undo. + pub fn supports_undo(&self, tool_name: &str) -> bool { + self.tools + .get(tool_name) + .map(|(_, def)| def.undo_supported) + .unwrap_or(false) + } + + /// Return all registered tool definitions. + pub fn all_tools(&self) -> Vec<&McpToolDefinition> { + self.tools.values().map(|(_, def)| def).collect() + } + + /// Return all registered tool names. + pub fn tool_names(&self) -> Vec<&str> { + self.tools.keys().map(|k| k.as_str()).collect() + } + + /// Return `(name, description)` pairs for all tools. + /// + /// Used by the RAG pre-filter to build the embedding index. + pub fn tool_name_description_pairs(&self) -> Vec<(String, String)> { + self.tools + .iter() + .map(|(name, (_, def))| (name.clone(), def.description.clone())) + .collect() + } + + /// Number of registered tools. + pub fn len(&self) -> usize { + self.tools.len() + } + + /// Whether the registry is empty. + pub fn is_empty(&self) -> bool { + self.tools.is_empty() + } + + /// Count tools belonging to a specific server. + pub fn tools_for_server(&self, server_name: &str) -> usize { + self.tools + .values() + .filter(|(srv, _)| srv == server_name) + .count() + } + + /// Return fully-qualified tool names belonging to a specific server. + pub fn tool_names_for_server(&self, server_name: &str) -> Vec { + let mut names: Vec = self + .tools + .iter() + .filter(|(_, (srv, _))| srv == server_name) + .map(|(name, _)| name.clone()) + .collect(); + names.sort(); + names + } + + /// Retain only tools whose fully-qualified names appear in the allowlist. + /// + /// Removes all tools not in the set. Used by `enabled_tools` config to + /// curate a tight, high-accuracy tool surface for specific deployments. 
+ pub fn retain_tools(&mut self, allowed: &std::collections::HashSet) { + let before = self.tools.len(); + self.tools.retain(|name, _| allowed.contains(name)); + let after = self.tools.len(); + tracing::info!( + before, + after, + "filtered tool registry by enabled_tools allowlist" + ); + } + + /// Return all unique server names. + pub fn server_names(&self) -> Vec { + let mut names: Vec = self + .tools + .values() + .map(|(srv, _)| srv.clone()) + .collect::>() + .into_iter() + .collect(); + names.sort(); + names + } + + /// Validate a tool call: tool exists and arguments match schema. + /// + /// This is a basic structural check — required fields present, correct types + /// for top-level fields. Full JSON Schema validation is deferred to the + /// server itself. + pub fn validate_tool_call( + &self, + tool_name: &str, + arguments: &serde_json::Value, + ) -> Result<(), McpError> { + let def = self.get_tool(tool_name).ok_or(McpError::UnknownTool { + name: tool_name.to_string(), + })?; + + // Validate required fields if schema specifies them + if let Some(required) = def.params_schema.get("required") { + if let Some(required_arr) = required.as_array() { + let args_obj = arguments.as_object(); + for field in required_arr { + if let Some(field_name) = field.as_str() { + let has_field = args_obj + .map(|obj| obj.contains_key(field_name)) + .unwrap_or(false); + if !has_field { + return Err(McpError::InvalidArguments { + tool: tool_name.to_string(), + reason: format!("missing required field: '{field_name}'"), + }); + } + } + } + } + } + + Ok(()) + } + + /// Generate a concise capability summary for the system prompt. + /// + /// Lists available servers with tool counts and categorizes them by + /// action type (read vs write) based on `confirmation_required` metadata. + /// Designed to be compact (~170 tokens) so it fits within the system + /// prompt budget alongside behavioral rules and few-shot examples. 
+ pub fn capability_summary(&self) -> String { + if self.is_empty() { + return "No MCP tools currently available. Built-in tools: \ + list_directory, read_file." + .to_string(); + } + + let server_names = self.server_names(); + let total_tools = self.len(); + + // Build per-server summaries: "filesystem (9)" + let server_parts: Vec = server_names + .iter() + .map(|name| { + let count = self.tools_for_server(name); + format!("{name} ({count})") + }) + .collect(); + + // Categorize servers by whether they have read-only or write tools + let mut read_servers: Vec = Vec::new(); + let mut write_servers: Vec = Vec::new(); + + for name in &server_names { + let has_read = self + .tools + .values() + .any(|(srv, def)| srv == name && !def.confirmation_required); + let has_write = self + .tools + .values() + .any(|(srv, def)| srv == name && def.confirmation_required); + + if has_read { + read_servers.push(name.clone()); + } + if has_write { + write_servers.push(name.clone()); + } + } + + let mut summary = format!( + "Available capabilities ({total_tools} tools across {} servers): {}.", + server_names.len(), + server_parts.join(", "), + ); + + if !read_servers.is_empty() { + summary.push_str(&format!( + "\nREAD servers (execute immediately): {}.", + read_servers.join(", ") + )); + } + + if !write_servers.is_empty() { + summary.push_str(&format!( + "\nWRITE servers (confirmation shown automatically): {}.", + write_servers.join(", ") + )); + } + + summary + } + + /// Serialize all tool definitions into OpenAI function-calling format. + /// + /// Used to populate the `tools` field in chat completion requests. + pub fn to_openai_tools(&self) -> Vec { + self.tools + .iter() + .map(|(fq_name, (_, def))| { + serde_json::json!({ + "type": "function", + "function": { + "name": fq_name, + "description": def.description, + "parameters": def.params_schema, + } + }) + }) + .collect() + } + + /// Serialize tool definitions for a specific set of tool names. 
+ /// + /// Used by the two-pass category system to expand selected categories + /// into their real tool definitions. Tools not found in the registry + /// are silently skipped. + pub fn to_openai_tools_filtered(&self, tool_names: &[String]) -> Vec { + tool_names + .iter() + .filter_map(|name| { + self.tools.get(name).map(|(_, def)| { + serde_json::json!({ + "type": "function", + "function": { + "name": name, + "description": def.description, + "parameters": def.params_schema, + } + }) + }) + }) + .collect() + } +} + +// ─── Tool Resolution ──────────────────────────────────────────────────────── + +/// Known semantic equivalences where Levenshtein edit distance fails. +/// +/// Checked AFTER exact match and unprefixed lookup, BEFORE fuzzy matching. +/// Only the tool suffix (the part after the dot) is matched — the server +/// prefix is preserved from the original name. +/// +/// Example: `filesystem.rename_file` → `filesystem.move_file` (not `read_file` +/// which is closer by edit distance but semantically wrong). +const SEMANTIC_ALIASES: &[(&str, &str)] = &[ + ("rename_file", "move_file"), + ("rename", "move_file"), + ("delete_file", "move_to_trash"), + ("remove_file", "move_to_trash"), +]; + +/// Result of resolving a tool name against the registry. +/// +/// The agent loop uses this to understand *how* a name was resolved and to +/// generate helpful error messages when tools are not found. +#[derive(Debug, Clone, PartialEq)] +pub enum ToolResolution { + /// Name found as-is in the registry. + Exact(String), + + /// Name lacked a server prefix but uniquely matched one registered tool. + Unprefixed { + resolved: String, + original: String, + }, + + /// Name was qualified (`server.tool`) but didn't exist. A similar tool + /// from the same server was found via edit distance. + Corrected { + resolved: String, + original: String, + score: f64, + }, + + /// No match found. `suggestions` contains up to 3 similar tool names. 
+ NotFound { + original: String, + suggestions: Vec, + }, +} + +impl ToolResolution { + /// The resolved tool name, if resolution succeeded. + pub fn resolved_name(&self) -> Option<&str> { + match self { + Self::Exact(name) => Some(name), + Self::Unprefixed { resolved, .. } => Some(resolved), + Self::Corrected { resolved, .. } => Some(resolved), + Self::NotFound { .. } => None, + } + } + + /// Whether this resolution found a usable tool name. + pub fn is_resolved(&self) -> bool { + !matches!(self, Self::NotFound { .. }) + } +} + +impl ToolRegistry { + /// Resolve a tool name that may be wrong, unprefixed, or hallucinated. + /// + /// Strategy (first match wins): + /// 1. **Exact:** name exists in the registry as-is. + /// 2. **Unprefixed:** name has no dot — search for `*.{name}`, unique match wins. + /// 3. **Same-server fuzzy:** name has dot but doesn't exist — find the most + /// similar tool from the same server prefix via Levenshtein distance. + /// 4. **NotFound:** no match above `min_similarity` threshold (0.0–1.0). + pub fn resolve(&self, name: &str, min_similarity: f64) -> ToolResolution { + // 1. Exact match + if self.get_tool(name).is_some() { + return ToolResolution::Exact(name.to_string()); + } + + // 2. Unprefixed (no dot) — search for `*.{name}` + if !name.contains('.') { + let suffix = format!(".{name}"); + let candidates: Vec<&str> = self + .tools + .keys() + .filter(|fq| fq.ends_with(&suffix)) + .map(|s| s.as_str()) + .collect(); + + return match candidates.len() { + 1 => ToolResolution::Unprefixed { + resolved: candidates[0].to_string(), + original: name.to_string(), + }, + 0 => ToolResolution::NotFound { + original: name.to_string(), + suggestions: self.find_similar(name, 3), + }, + _ => { + // Ambiguous — return NotFound with the candidates as suggestions + ToolResolution::NotFound { + original: name.to_string(), + suggestions: candidates.into_iter().map(String::from).collect(), + } + } + }; + } + + // 3. 
Semantic alias — known intent mappings where edit distance fails. + // e.g., "rename_file" means "move_file" but is closer to "read_file" by + // Levenshtein distance. This table is checked AFTER exact match, BEFORE + // fuzzy matching, and only fires for qualified names (server.tool). + let server_prefix = name.split('.').next().unwrap_or(""); + let suffix = name.split('.').nth(1).unwrap_or(""); + + for &(alias_from, alias_to) in SEMANTIC_ALIASES { + if suffix == alias_from { + let candidate = format!("{server_prefix}.{alias_to}"); + if self.get_tool(&candidate).is_some() { + return ToolResolution::Corrected { + resolved: candidate, + original: name.to_string(), + score: 1.0, // Semantic match — highest confidence + }; + } + } + } + + // 4. Qualified but wrong — find similar tools from the same server prefix + + let same_server: Vec<(&str, &str)> = self + .tools + .keys() + .filter_map(|fq| { + let tool_suffix = fq.split('.').nth(1)?; + if fq.starts_with(server_prefix) && fq.contains('.') { + Some((fq.as_str(), tool_suffix)) + } else { + None + } + }) + .collect(); + + if !same_server.is_empty() { + // Compare suffixes (the part after the dot) via Levenshtein + let mut best: Option<(&str, f64)> = None; + for (fq_name, tool_suffix) in &same_server { + let score = similarity(suffix, tool_suffix); + if score >= min_similarity + && best.map_or(true, |(_, best_score)| score > best_score) + { + best = Some((fq_name, score)); + } + } + + if let Some((resolved, score)) = best { + return ToolResolution::Corrected { + resolved: resolved.to_string(), + original: name.to_string(), + score, + }; + } + } + + // 4. Nothing matched — provide suggestions from the full registry + ToolResolution::NotFound { + original: name.to_string(), + suggestions: self.find_similar(name, 3), + } + } + + /// Find up to `max_results` tools most similar to `name`, ranked by score. 
+ /// + /// Compares against tool suffixes if `name` is qualified (`server.tool`), + /// or against full names otherwise. Returns `(tool_name, score)` pairs + /// with score in 0.0–1.0 (higher = more similar). + pub fn find_similar(&self, name: &str, max_results: usize) -> Vec { + let query_suffix = name.split('.').next_back().unwrap_or(name); + + let mut scored: Vec<(String, f64)> = self + .tools + .keys() + .map(|fq| { + let tool_suffix = fq.split('.').nth(1).unwrap_or(fq); + let score = similarity(query_suffix, tool_suffix); + (fq.clone(), score) + }) + .filter(|(_, score)| *score > 0.3) // Floor — don't suggest wildly different tools + .collect(); + + scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + scored.truncate(max_results); + scored.into_iter().map(|(name, _)| name).collect() + } +} + +// ─── Edit Distance ────────────────────────────────────────────────────────── + +/// Compute the Levenshtein edit distance between two strings. +fn levenshtein(a: &str, b: &str) -> usize { + let a_bytes = a.as_bytes(); + let b_bytes = b.as_bytes(); + let m = a_bytes.len(); + let n = b_bytes.len(); + + // Use single-row DP for O(min(m,n)) space + let mut prev: Vec = (0..=n).collect(); + let mut curr = vec![0usize; n + 1]; + + for i in 1..=m { + curr[0] = i; + for j in 1..=n { + let cost = if a_bytes[i - 1] == b_bytes[j - 1] { + 0 + } else { + 1 + }; + curr[j] = (prev[j] + 1) // deletion + .min(curr[j - 1] + 1) // insertion + .min(prev[j - 1] + cost); // substitution + } + std::mem::swap(&mut prev, &mut curr); + } + + prev[n] +} + +/// Normalized similarity between two strings (0.0 = completely different, 1.0 = identical). 
+fn similarity(a: &str, b: &str) -> f64 { + let max_len = a.len().max(b.len()); + if max_len == 0 { + return 1.0; + } + let dist = levenshtein(a, b); + 1.0 - (dist as f64 / max_len as f64) +} + +impl Default for ToolRegistry { + fn default() -> Self { + Self::new() + } +} + +// ─── Category-Based Tool Selection (Tier 1.5) ────────────────────────────── + +/// A functional category grouping related tools for two-pass selection. +/// +/// Presented to the model as a synthetic "meta-tool" on the first turn. +/// When selected by the model, the category is expanded to its real tools. +#[derive(Debug, Clone)] +pub struct ToolCategory { + /// Category identifier used as the function name (e.g., `"file_browse"`). + pub name: String, + /// Human-readable description for the model. Must be discriminative + /// enough to distinguish sibling categories. + pub description: String, + /// Fully-qualified tool names belonging to this category. + pub tool_names: Vec, +} + +/// Registry of tool categories for two-pass selection. +/// +/// Built from the live `ToolRegistry` at startup. Categories are hardcoded +/// functional groupings (not auto-generated from servers), because the +/// grouping is semantic — filesystem read vs write, document read vs create. +/// +/// The 16 categories reduce the tool-selection decision space from 83 flat +/// tools (~10,700 tokens) to 16 categories (~1,600 tokens), near the +/// K=15 sweet spot identified in ADR-010 benchmarks. +#[derive(Debug, Clone)] +pub struct CategoryRegistry { + categories: Vec, + tool_to_category: HashMap, +} + +impl CategoryRegistry { + /// Build category definitions, filtering out categories whose tools + /// are not present in the live registry (server not running). + /// + /// A category is included if at least one of its tools is registered. 
+ pub fn build(registry: &ToolRegistry) -> Self { + let defs = default_category_definitions(); + let mut categories = Vec::new(); + let mut tool_to_category = HashMap::new(); + + for (name, description, tool_names) in defs { + // Keep only tools that actually exist in the registry + let live_tools: Vec = tool_names + .into_iter() + .filter(|t| registry.get_tool(t).is_some()) + .collect(); + + if live_tools.is_empty() { + continue; // Skip categories with no live tools + } + + for tool in &live_tools { + tool_to_category.insert(tool.clone(), name.clone()); + } + + categories.push(ToolCategory { + name, + description, + tool_names: live_tools, + }); + } + + Self { + categories, + tool_to_category, + } + } + + /// Serialize categories as OpenAI function-calling format. + /// + /// Each category becomes a synthetic tool with a single `"intent"` + /// parameter. The model calls these to signal which capability areas + /// it needs for the current task. + pub fn to_openai_tools(&self) -> Vec { + self.categories + .iter() + .map(|cat| { + serde_json::json!({ + "type": "function", + "function": { + "name": cat.name, + "description": cat.description, + "parameters": { + "type": "object", + "properties": { + "intent": { + "type": "string", + "description": "Brief description of what you want to do" + } + }, + "required": ["intent"] + } + } + }) + }) + .collect() + } + + /// Expand selected category names into the union of their real tool names. + /// + /// Unknown category names are silently ignored. Duplicate tools across + /// categories are deduplicated. 
+ pub fn expand_categories(&self, selected: &[String]) -> Vec { + let mut tool_set: std::collections::HashSet = + std::collections::HashSet::new(); + + for cat in &self.categories { + if selected.iter().any(|s| s == &cat.name) { + tool_set.extend(cat.tool_names.iter().cloned()); + } + } + + let mut tools: Vec = tool_set.into_iter().collect(); + tools.sort(); // Deterministic ordering + tools + } + + /// Check if a name is a known category (not a real tool). + pub fn is_category(&self, name: &str) -> bool { + self.categories.iter().any(|c| c.name == name) + } + + /// Look up which category a tool belongs to. + pub fn category_for_tool(&self, tool_name: &str) -> Option<&str> { + self.tool_to_category.get(tool_name).map(|s| s.as_str()) + } + + /// Number of active categories. + pub fn len(&self) -> usize { + self.categories.len() + } + + /// Whether any categories are registered. + pub fn is_empty(&self) -> bool { + self.categories.is_empty() + } + + /// All category names, in definition order. + pub fn category_names(&self) -> Vec<&str> { + self.categories.iter().map(|c| c.name.as_str()).collect() + } +} + +/// The 16 hardcoded functional categories covering all 83 tools. +/// +/// Returns `(name, description, tool_names)` triples. Tool names use +/// the fully-qualified `"server.tool"` format. 
+/// +/// Category design rationale: +/// - 16 categories (~K=15 sweet spot from ADR-010 benchmarks) +/// - Filesystem split into read (file_browse) vs write (file_edit) +/// to reduce mutable operation exposure +/// - Document split into read vs create for the same reason +/// - clipboard + system-info merged (both are quick OS queries) +/// - system.{open_application, open_file_with, take_screenshot} are +/// separate from system-info because they're action-oriented +/// - screenshot_pipeline tools grouped with image_ocr (visual extraction) +/// - system-settings gets its own category (OS preference changes) +fn default_category_definitions() -> Vec<(String, String, Vec)> { + vec![ + ( + "file_browse".into(), + "Browse, search, read, and inspect files and folders. List directory \ + contents, read file contents, search by filename or glob pattern, check \ + file metadata (size, dates, permissions), and watch folders for changes." + .into(), + vec![ + "filesystem.list_dir".into(), + "filesystem.read_file".into(), + "filesystem.search_files".into(), + "filesystem.get_metadata".into(), + "filesystem.watch_folder".into(), + ], + ), + ( + "file_edit".into(), + "Create, write, move, copy, rename, or delete files. All file modification \ + operations. Requires user confirmation." + .into(), + vec![ + "filesystem.write_file".into(), + "filesystem.move_file".into(), + "filesystem.copy_file".into(), + "filesystem.delete_file".into(), + ], + ), + ( + "document_read".into(), + "Extract text content from PDF, DOCX, HTML, or spreadsheet files. Compare \ + two documents for differences. Read CSV and Excel data." + .into(), + vec![ + "document.extract_text".into(), + "document.diff_documents".into(), + "document.read_spreadsheet".into(), + ], + ), + ( + "document_create".into(), + "Create PDF or DOCX files from markdown, fill PDF form fields, merge \ + multiple PDFs, or convert between document formats." 
+ .into(), + vec![ + "document.convert_format".into(), + "document.create_pdf".into(), + "document.fill_pdf_form".into(), + "document.merge_pdfs".into(), + "document.create_docx".into(), + ], + ), + ( + "image_ocr".into(), + "Extract text from images and screenshots using OCR. Extract structured \ + data (receipts, invoices) or tabular data from visual content. Capture \ + screenshots with text extraction, identify UI elements, and suggest \ + actions from screen captures." + .into(), + vec![ + "ocr.extract_text_from_image".into(), + "ocr.extract_text_from_pdf".into(), + "ocr.extract_structured_data".into(), + "ocr.extract_table".into(), + "screenshot.capture_and_extract".into(), + "screenshot.extract_ui_elements".into(), + "screenshot.suggest_actions".into(), + ], + ), + ( + "data_analysis".into(), + "Query SQLite databases with SQL, process and export CSV files, find \ + duplicate records, and detect anomalies in datasets." + .into(), + vec![ + "data.query_sqlite".into(), + "data.deduplicate_records".into(), + "data.summarize_anomalies".into(), + "data.write_csv".into(), + "data.write_sqlite".into(), + ], + ), + ( + "knowledge_search".into(), + "Semantic search across indexed documents. Index folders for search, ask \ + questions about file contents, and find related text passages." + .into(), + vec![ + "knowledge.index_folder".into(), + "knowledge.search_documents".into(), + "knowledge.ask_about_files".into(), + "knowledge.update_index".into(), + "knowledge.get_related_chunks".into(), + ], + ), + ( + "security_privacy".into(), + "Scan files for PII (SSN, credit cards, emails) or leaked secrets \ + (API keys, passwords). Find duplicate files, propose cleanup, encrypt \ + or decrypt files." 
+ .into(), + vec![ + "security.scan_for_pii".into(), + "security.scan_for_secrets".into(), + "security.find_duplicates".into(), + "security.propose_cleanup".into(), + "security.encrypt_file".into(), + "security.decrypt_file".into(), + ], + ), + ( + "task_management".into(), + "Create, list, and update tasks with priorities and due dates. Check \ + overdue items. Generate a daily task briefing." + .into(), + vec![ + "task.create_task".into(), + "task.list_tasks".into(), + "task.update_task".into(), + "task.get_overdue".into(), + "task.daily_briefing".into(), + ], + ), + ( + "calendar_scheduling".into(), + "View calendar events in a date range, create new events, find available \ + time slots, and block focus time." + .into(), + vec![ + "calendar.list_events".into(), + "calendar.create_event".into(), + "calendar.find_free_slots".into(), + "calendar.create_time_block".into(), + ], + ), + ( + "email_messaging".into(), + "Draft and send emails, list saved drafts, search mail by keyword, sender, \ + or date, and summarize email conversation threads." + .into(), + vec![ + "email.draft_email".into(), + "email.list_drafts".into(), + "email.search_emails".into(), + "email.summarize_thread".into(), + "email.send_draft".into(), + ], + ), + ( + "meeting_audio".into(), + "Transcribe audio recordings to text, extract action items and commitments \ + from transcripts, and generate formatted meeting minutes." + .into(), + vec![ + "meeting.transcribe_audio".into(), + "meeting.extract_action_items".into(), + "meeting.extract_commitments".into(), + "meeting.generate_minutes".into(), + ], + ), + ( + "clipboard_system".into(), + "Access clipboard contents and history (get, set, history). Get system \ + information (OS, CPU, memory, disk, network), monitor CPU and memory \ + usage, and list running processes." 
+ .into(), + vec![ + "clipboard.get_clipboard".into(), + "clipboard.set_clipboard".into(), + "clipboard.clipboard_history".into(), + "system.get_system_info".into(), + "system.list_processes".into(), + "system.get_cpu_usage".into(), + "system.get_memory_usage".into(), + "system.get_disk_usage".into(), + "system.get_network_info".into(), + ], + ), + ( + "app_launcher".into(), + "Launch applications by name, open files with a specific program, take \ + a screenshot of the current screen, or kill a running process." + .into(), + vec![ + "system.open_application".into(), + "system.open_file_with".into(), + "system.take_screenshot".into(), + "system.kill_process".into(), + ], + ), + ( + "audit_compliance".into(), + "View tool execution logs, get session summaries, generate text audit \ + reports, and export signed audit PDFs." + .into(), + vec![ + "audit.get_tool_log".into(), + "audit.get_session_summary".into(), + "audit.generate_audit_report".into(), + "audit.export_audit_pdf".into(), + ], + ), + ( + "system_settings".into(), + "View and change OS preferences: display settings, sleep timer, audio \ + volume, default applications, default browser, power settings, and \ + Do Not Disturb mode. Requires user confirmation for changes." 
+ .into(), + vec![ + "system-settings.get_display_settings".into(), + "system-settings.set_display_sleep".into(), + "system-settings.get_audio_settings".into(), + "system-settings.set_audio_volume".into(), + "system-settings.get_default_apps".into(), + "system-settings.set_default_browser".into(), + "system-settings.get_power_settings".into(), + "system-settings.toggle_do_not_disturb".into(), + ], + ), + ] +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_tool(name: &str, confirmation: bool, undo: bool) -> McpToolDefinition { + McpToolDefinition { + name: name.to_string(), + description: format!("Test tool: {name}"), + params_schema: serde_json::json!({ + "type": "object", + "properties": { + "path": { "type": "string" } + }, + "required": ["path"] + }), + returns_schema: serde_json::json!({}), + confirmation_required: confirmation, + undo_supported: undo, + } + } + + #[test] + fn test_register_and_lookup() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "filesystem", + vec![sample_tool("list_dir", false, false)], + ); + + assert_eq!(registry.len(), 1); + assert!(registry.get_tool("filesystem.list_dir").is_some()); + assert!(registry.get_tool("nonexistent.tool").is_none()); + } + + #[test] + fn test_fully_qualified_names_preserved() { + let mut registry = ToolRegistry::new(); + let tool = sample_tool("filesystem.read_file", false, false); + registry.register_server_tools("filesystem", vec![tool]); + + // Already has a dot, so no double-prefix + assert!(registry.get_tool("filesystem.read_file").is_some()); + } + + #[test] + fn test_server_name_lookup() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "ocr", + vec![sample_tool("extract_text", false, false)], + ); + + assert_eq!( + registry.get_server_for_tool("ocr.extract_text"), + Some("ocr") + ); + } + + #[test] + fn test_server_name_from_tool() { + 
assert_eq!( + ToolRegistry::server_name_from_tool("filesystem.list_dir"), + Some("filesystem") + ); + } + + #[test] + fn test_confirmation_and_undo() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "fs", + vec![ + sample_tool("read_file", false, false), + sample_tool("write_file", true, true), + ], + ); + + assert!(!registry.requires_confirmation("fs.read_file")); + assert!(registry.requires_confirmation("fs.write_file")); + assert!(!registry.supports_undo("fs.read_file")); + assert!(registry.supports_undo("fs.write_file")); + // Unknown tools default to requiring confirmation + assert!(registry.requires_confirmation("unknown.tool")); + } + + #[test] + fn test_unregister_server() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools("a", vec![sample_tool("tool1", false, false)]); + registry.register_server_tools("b", vec![sample_tool("tool2", false, false)]); + + assert_eq!(registry.len(), 2); + registry.unregister_server("a"); + assert_eq!(registry.len(), 1); + assert!(registry.get_tool("a.tool1").is_none()); + assert!(registry.get_tool("b.tool2").is_some()); + } + + #[test] + fn test_validate_tool_call_valid() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools("fs", vec![sample_tool("list_dir", false, false)]); + + let args = serde_json::json!({"path": "/tmp"}); + assert!(registry.validate_tool_call("fs.list_dir", &args).is_ok()); + } + + #[test] + fn test_validate_tool_call_missing_required() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools("fs", vec![sample_tool("list_dir", false, false)]); + + let args = serde_json::json!({}); + let err = registry.validate_tool_call("fs.list_dir", &args).unwrap_err(); + assert!(matches!(err, McpError::InvalidArguments { .. 
})); + } + + #[test] + fn test_validate_tool_call_unknown_tool() { + let registry = ToolRegistry::new(); + let args = serde_json::json!({}); + let err = registry + .validate_tool_call("nonexistent.tool", &args) + .unwrap_err(); + assert!(matches!(err, McpError::UnknownTool { .. })); + } + + #[test] + fn test_to_openai_tools() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools("fs", vec![sample_tool("list_dir", false, false)]); + + let openai_tools = registry.to_openai_tools(); + assert_eq!(openai_tools.len(), 1); + assert_eq!(openai_tools[0]["type"], "function"); + assert_eq!(openai_tools[0]["function"]["name"], "fs.list_dir"); + } + + #[test] + fn test_server_names() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools("fs", vec![sample_tool("t1", false, false)]); + registry.register_server_tools("ocr", vec![sample_tool("t2", false, false)]); + + let names = registry.server_names(); + assert_eq!(names.len(), 2); + assert!(names.contains(&"fs".to_string())); + assert!(names.contains(&"ocr".to_string())); + } + + #[test] + fn test_tools_for_server() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "fs", + vec![ + sample_tool("list_dir", false, false), + sample_tool("read_file", false, false), + ], + ); + registry.register_server_tools("ocr", vec![sample_tool("extract", false, false)]); + + assert_eq!(registry.tools_for_server("fs"), 2); + assert_eq!(registry.tools_for_server("ocr"), 1); + assert_eq!(registry.tools_for_server("nonexistent"), 0); + } + + #[test] + fn test_capability_summary_empty() { + let registry = ToolRegistry::new(); + let summary = registry.capability_summary(); + assert!(summary.contains("No MCP tools currently available")); + assert!(summary.contains("list_directory")); + assert!(summary.contains("read_file")); + } + + #[test] + fn test_capability_summary_with_servers() { + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "filesystem", + 
vec![ + sample_tool("list_dir", false, false), + sample_tool("read_file", false, false), + sample_tool("write_file", true, true), + ], + ); + registry.register_server_tools( + "ocr", + vec![sample_tool("extract_text", false, false)], + ); + + let summary = registry.capability_summary(); + assert!(summary.contains("4 tools across 2 servers")); + assert!(summary.contains("filesystem (3)")); + assert!(summary.contains("ocr (1)")); + } + + #[test] + fn test_capability_summary_categorizes_by_confirmation() { + let mut registry = ToolRegistry::new(); + // "audit" has only read tools (no confirmation) + registry.register_server_tools( + "audit", + vec![sample_tool("get_log", false, false)], + ); + // "email" has only write tools (confirmation required) + registry.register_server_tools( + "email", + vec![sample_tool("send_draft", true, false)], + ); + // "filesystem" has both + registry.register_server_tools( + "filesystem", + vec![ + sample_tool("read_file", false, false), + sample_tool("delete_file", true, true), + ], + ); + + let summary = registry.capability_summary(); + + // READ servers line should include audit and filesystem but not email + let read_line = summary + .lines() + .find(|l| l.starts_with("READ servers")) + .expect("should have READ line"); + assert!(read_line.contains("audit")); + assert!(read_line.contains("filesystem")); + assert!(!read_line.contains("email")); + + // WRITE servers line should include email and filesystem but not audit + let write_line = summary + .lines() + .find(|l| l.starts_with("WRITE servers")) + .expect("should have WRITE line"); + assert!(write_line.contains("email")); + assert!(write_line.contains("filesystem")); + assert!(!write_line.contains("audit")); + } + + // ─── Levenshtein / Similarity Tests ────────────────────────────── + + #[test] + fn test_levenshtein_identical() { + assert_eq!(levenshtein("move_file", "move_file"), 0); + } + + #[test] + fn test_levenshtein_basic() { + assert_eq!(levenshtein("rename_file", 
"move_file"), 5); + assert_eq!(levenshtein("rename_file", "read_file"), 3); + assert_eq!(levenshtein("kitten", "sitting"), 3); + } + + #[test] + fn test_levenshtein_empty() { + assert_eq!(levenshtein("", "abc"), 3); + assert_eq!(levenshtein("abc", ""), 3); + assert_eq!(levenshtein("", ""), 0); + } + + #[test] + fn test_similarity_range() { + let s = similarity("move_file", "move_file"); + assert!((s - 1.0).abs() < f64::EPSILON); + + let s = similarity("rename_file", "move_file"); + assert!(s > 0.5); // "rename_file" vs "move_file" share "_file" + + let s = similarity("abc", "xyz"); + assert!(s < 0.5); // Completely different + } + + // ─── ToolResolution Tests ──────────────────────────────────────── + + fn build_filesystem_registry() -> ToolRegistry { + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "filesystem", + vec![ + sample_tool("list_dir", false, false), + sample_tool("read_file", false, false), + sample_tool("write_file", true, true), + sample_tool("move_file", true, false), + sample_tool("copy_file", true, false), + sample_tool("search_files", false, false), + ], + ); + registry.register_server_tools( + "ocr", + vec![sample_tool("extract_text_from_image", false, false)], + ); + registry + } + + #[test] + fn test_resolve_exact_match() { + let registry = build_filesystem_registry(); + let result = registry.resolve("filesystem.list_dir", 0.5); + assert_eq!(result, ToolResolution::Exact("filesystem.list_dir".into())); + } + + #[test] + fn test_resolve_unprefixed() { + let registry = build_filesystem_registry(); + let result = registry.resolve("move_file", 0.5); + assert!(matches!( + result, + ToolResolution::Unprefixed { ref resolved, .. 
} if resolved == "filesystem.move_file" + )); + } + + #[test] + fn test_resolve_corrected_rename_to_move_via_alias() { + let registry = build_filesystem_registry(); + let result = registry.resolve("filesystem.rename_file", 0.5); + // "rename_file" is closer to "read_file" by edit distance (3 vs 5), + // but SEMANTIC_ALIASES maps "rename_file" → "move_file" before + // Levenshtein runs. This prevents dispatching to the wrong tool. + assert_eq!( + result, + ToolResolution::Corrected { + resolved: "filesystem.move_file".to_string(), + original: "filesystem.rename_file".to_string(), + score: 1.0, + }, + ); + } + + #[test] + fn test_resolve_corrected_returns_best_match() { + // With a limited registry, verify the closest match is selected + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "filesystem", + vec![ + sample_tool("move_file", true, false), + sample_tool("list_dir", false, false), + ], + ); + let result = registry.resolve("filesystem.move_files", 0.5); + // "move_files" vs "move_file" = distance 1, similarity ~0.9 + assert!( + matches!( + result, + ToolResolution::Corrected { ref resolved, score, .. } + if resolved == "filesystem.move_file" && score > 0.8 + ), + "expected Corrected to filesystem.move_file, got: {result:?}" + ); + } + + #[test] + fn test_resolve_not_found_below_threshold() { + let registry = build_filesystem_registry(); + // "filesystem.zzzzzz" has no similarity to any tool + let result = registry.resolve("filesystem.zzzzzz", 0.5); + assert!(matches!(result, ToolResolution::NotFound { .. })); + } + + #[test] + fn test_resolve_not_found_unknown_server() { + let registry = build_filesystem_registry(); + let result = registry.resolve("nonexistent.some_tool", 0.5); + assert!(matches!(result, ToolResolution::NotFound { .. 
})); + } + + #[test] + fn test_resolve_not_found_has_suggestions() { + let registry = build_filesystem_registry(); + let result = registry.resolve("filesystem.zzzzzz", 0.3); + if let ToolResolution::NotFound { suggestions, .. } = result { + // Should have some suggestions even at low threshold + // (tools with shared "_" characters might score above 0.3) + assert!(suggestions.len() <= 3); // max_results = 3 + } + } + + #[test] + fn test_resolve_unprefixed_ambiguous() { + // Register the same tool name in two different servers + let mut registry = ToolRegistry::new(); + registry.register_server_tools("a", vec![sample_tool("run", false, false)]); + registry.register_server_tools("b", vec![sample_tool("run", false, false)]); + + let result = registry.resolve("run", 0.5); + // Ambiguous — two matches → NotFound with candidates as suggestions + assert!(matches!(result, ToolResolution::NotFound { ref suggestions, .. } if suggestions.len() == 2)); + } + + #[test] + fn test_find_similar() { + let registry = build_filesystem_registry(); + let similar = registry.find_similar("rename_file", 3); + assert!(!similar.is_empty()); + // "move_file" should be among the suggestions (shares "_file" suffix) + assert!(similar.iter().any(|s| s.contains("move_file"))); + } + + #[test] + fn test_resolution_resolved_name() { + let exact = ToolResolution::Exact("filesystem.list_dir".into()); + assert_eq!(exact.resolved_name(), Some("filesystem.list_dir")); + + let not_found = ToolResolution::NotFound { + original: "bad.tool".into(), + suggestions: vec![], + }; + assert_eq!(not_found.resolved_name(), None); + } + + // ─── to_openai_tools_filtered Tests ───────────────────────────── + + #[test] + fn test_to_openai_tools_filtered_returns_matching() { + let registry = build_filesystem_registry(); + let filtered = registry.to_openai_tools_filtered(&[ + "filesystem.list_dir".to_string(), + "filesystem.move_file".to_string(), + ]); + assert_eq!(filtered.len(), 2); + let names: Vec<&str> = filtered 
+ .iter() + .filter_map(|v| v["function"]["name"].as_str()) + .collect(); + assert!(names.contains(&"filesystem.list_dir")); + assert!(names.contains(&"filesystem.move_file")); + } + + #[test] + fn test_to_openai_tools_filtered_skips_missing() { + let registry = build_filesystem_registry(); + let filtered = registry.to_openai_tools_filtered(&[ + "filesystem.list_dir".to_string(), + "nonexistent.tool".to_string(), + ]); + assert_eq!(filtered.len(), 1); + } + + #[test] + fn test_to_openai_tools_filtered_empty_input() { + let registry = build_filesystem_registry(); + let filtered = registry.to_openai_tools_filtered(&[]); + assert!(filtered.is_empty()); + } + + // ─── CategoryRegistry Tests ───────────────────────────────────── + + /// Build a registry with tools from multiple servers for category testing. + fn build_multi_server_registry() -> ToolRegistry { + let mut registry = ToolRegistry::new(); + registry.register_server_tools( + "filesystem", + vec![ + sample_tool("list_dir", false, false), + sample_tool("read_file", false, false), + sample_tool("search_files", false, false), + sample_tool("get_metadata", false, false), + sample_tool("watch_folder", false, false), + sample_tool("write_file", true, false), + sample_tool("move_file", true, true), + sample_tool("copy_file", true, false), + sample_tool("delete_file", true, true), + ], + ); + registry.register_server_tools( + "ocr", + vec![ + sample_tool("extract_text_from_image", false, false), + sample_tool("extract_text_from_pdf", false, false), + sample_tool("extract_structured_data", false, false), + sample_tool("extract_table", false, false), + ], + ); + registry.register_server_tools( + "task", + vec![ + sample_tool("create_task", true, false), + sample_tool("list_tasks", false, false), + sample_tool("update_task", true, false), + sample_tool("get_overdue", false, false), + sample_tool("daily_briefing", false, false), + ], + ); + registry + } + + #[test] + fn 
test_category_registry_build_filters_missing_servers() { + // Registry only has filesystem, ocr, task — not all 13 servers. + // Categories for missing servers should be excluded. + let registry = build_multi_server_registry(); + let cat_reg = CategoryRegistry::build(&registry); + + // Should include: file_browse, file_edit, image_ocr, task_management + assert!(cat_reg.is_category("file_browse")); + assert!(cat_reg.is_category("file_edit")); + assert!(cat_reg.is_category("image_ocr")); + assert!(cat_reg.is_category("task_management")); + + // Should NOT include categories for servers that aren't running + assert!(!cat_reg.is_category("email_messaging")); + assert!(!cat_reg.is_category("calendar_scheduling")); + assert!(!cat_reg.is_category("meeting_audio")); + assert!(!cat_reg.is_category("knowledge_search")); + } + + #[test] + fn test_category_registry_expand_single() { + let registry = build_multi_server_registry(); + let cat_reg = CategoryRegistry::build(&registry); + + let expanded = cat_reg.expand_categories(&["file_browse".to_string()]); + assert_eq!(expanded.len(), 5); + assert!(expanded.contains(&"filesystem.list_dir".to_string())); + assert!(expanded.contains(&"filesystem.read_file".to_string())); + assert!(expanded.contains(&"filesystem.search_files".to_string())); + assert!(expanded.contains(&"filesystem.get_metadata".to_string())); + assert!(expanded.contains(&"filesystem.watch_folder".to_string())); + } + + #[test] + fn test_category_registry_expand_multiple() { + let registry = build_multi_server_registry(); + let cat_reg = CategoryRegistry::build(&registry); + + let expanded = cat_reg.expand_categories(&[ + "file_browse".to_string(), + "image_ocr".to_string(), + ]); + // 5 file_browse + 4 image_ocr = 9 tools + assert_eq!(expanded.len(), 9); + assert!(expanded.contains(&"filesystem.list_dir".to_string())); + assert!(expanded.contains(&"ocr.extract_text_from_image".to_string())); + } + + #[test] + fn test_category_registry_expand_unknown_ignored() { + let 
registry = build_multi_server_registry(); + let cat_reg = CategoryRegistry::build(&registry); + + let expanded = cat_reg.expand_categories(&[ + "file_browse".to_string(), + "nonexistent_category".to_string(), + ]); + // Should only have file_browse tools, nonexistent is ignored + assert_eq!(expanded.len(), 5); + } + + #[test] + fn test_category_registry_to_openai_tools() { + let registry = build_multi_server_registry(); + let cat_reg = CategoryRegistry::build(&registry); + + let tools = cat_reg.to_openai_tools(); + assert!(!tools.is_empty()); + + // Each tool should have the correct structure + for tool in &tools { + assert_eq!(tool["type"], "function"); + assert!(tool["function"]["name"].is_string()); + assert!(tool["function"]["description"].is_string()); + assert!(tool["function"]["parameters"]["properties"]["intent"].is_object()); + } + + // Check a specific category + let file_browse = tools + .iter() + .find(|t| t["function"]["name"] == "file_browse"); + assert!(file_browse.is_some()); + } + + #[test] + fn test_category_registry_is_category() { + let registry = build_multi_server_registry(); + let cat_reg = CategoryRegistry::build(&registry); + + assert!(cat_reg.is_category("file_browse")); + assert!(cat_reg.is_category("image_ocr")); + assert!(!cat_reg.is_category("filesystem.list_dir")); // real tool, not category + assert!(!cat_reg.is_category("nonexistent")); + } + + #[test] + fn test_category_registry_category_for_tool() { + let registry = build_multi_server_registry(); + let cat_reg = CategoryRegistry::build(&registry); + + assert_eq!( + cat_reg.category_for_tool("filesystem.list_dir"), + Some("file_browse") + ); + assert_eq!( + cat_reg.category_for_tool("filesystem.move_file"), + Some("file_edit") + ); + assert_eq!( + cat_reg.category_for_tool("ocr.extract_text_from_image"), + Some("image_ocr") + ); + assert_eq!(cat_reg.category_for_tool("nonexistent.tool"), None); + } + + #[test] + fn test_category_registry_len() { + let registry = build_multi_server_registry(); 
+ let cat_reg = CategoryRegistry::build(&registry); + + // Should have 4 categories: file_browse, file_edit, image_ocr, task_management + assert_eq!(cat_reg.len(), 4); + assert!(!cat_reg.is_empty()); + } + + #[test] + fn test_category_registry_empty() { + let registry = ToolRegistry::new(); + let cat_reg = CategoryRegistry::build(&registry); + + assert_eq!(cat_reg.len(), 0); + assert!(cat_reg.is_empty()); + } + + #[test] + fn test_category_registry_category_names() { + let registry = build_multi_server_registry(); + let cat_reg = CategoryRegistry::build(&registry); + + let names = cat_reg.category_names(); + assert!(names.contains(&"file_browse")); + assert!(names.contains(&"file_edit")); + assert!(names.contains(&"image_ocr")); + assert!(names.contains(&"task_management")); + } + + #[test] + fn test_default_categories_cover_all_83_tools() { + // Verify that the hardcoded categories cover all 83 tools across 15 servers. + // Original 67 PRD tools + 5 extra system tools + 3 screenshot-pipeline + 8 system-settings. + let defs = default_category_definitions(); + let mut all_tools: Vec<String> = Vec::new(); + for (_, _, tools) in &defs { + all_tools.extend(tools.clone()); + } + // Should be exactly 83 tools + assert_eq!(all_tools.len(), 83, "categories must cover all 83 tools"); + + // No duplicates + let unique: std::collections::HashSet<&str> = + all_tools.iter().map(|s| s.as_str()).collect(); + assert_eq!( + unique.len(), + all_tools.len(), + "no tool should appear in multiple categories" + ); + } + + #[test] + fn test_default_categories_count() { + let defs = default_category_definitions(); + assert_eq!(defs.len(), 16, "should have exactly 16 categories"); + } +} diff --git a/src-tauri/src/mcp_client/transport.rs b/src-tauri/src/mcp_client/transport.rs new file mode 100644 index 0000000..f1ff112 --- /dev/null +++ b/src-tauri/src/mcp_client/transport.rs @@ -0,0 +1,244 @@ +//! JSON-RPC over stdio transport. +//! +//! 
Handles low-level communication with MCP server child processes: +//! - Writing JSON-RPC requests to stdin +//! - Reading JSON-RPC responses from stdout +//! - Line-delimited JSON protocol (one JSON object per line) + +use std::sync::atomic::{AtomicU64, Ordering}; + +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{ChildStdin, ChildStdout}; +use tokio::sync::Mutex; + +use super::errors::McpError; +use super::types::{JsonRpcRequest, JsonRpcResponse}; + +// ─── Request ID Generator ──────────────────────────────────────────────────── + +/// Global monotonic request ID counter. +static NEXT_REQUEST_ID: AtomicU64 = AtomicU64::new(1); + +/// Generate a unique request ID. +pub fn next_request_id() -> u64 { + NEXT_REQUEST_ID.fetch_add(1, Ordering::Relaxed) +} + +// ─── Transport ─────────────────────────────────────────────────────────────── + +/// Bi-directional JSON-RPC transport over a child process's stdio. +pub struct StdioTransport { + server_name: String, + writer: Mutex<ChildStdin>, + reader: Mutex<BufReader<ChildStdout>>, +} + +impl StdioTransport { + /// Create a new transport from a child process's stdin/stdout. + pub fn new(server_name: &str, stdin: ChildStdin, stdout: ChildStdout) -> Self { + Self { + server_name: server_name.to_string(), + writer: Mutex::new(stdin), + reader: Mutex::new(BufReader::new(stdout)), + } + } + + /// Send a JSON-RPC request and wait for the matching response. + /// + /// This is a simple request-response pattern: write one line of JSON, + /// read lines until we get a response with a matching `id`. 
+ pub async fn request( + &self, + method: &str, + params: Option<serde_json::Value>, + ) -> Result<JsonRpcResponse, McpError> { + let id = next_request_id(); + let req = JsonRpcRequest::new(id, method, params); + + // Serialize and send + let mut json = serde_json::to_string(&req).map_err(|e| McpError::TransportError { + server: self.server_name.clone(), + reason: format!("failed to serialize request: {e}"), + })?; + json.push('\n'); + + { + let mut writer = self.writer.lock().await; + writer + .write_all(json.as_bytes()) + .await + .map_err(|e| McpError::TransportError { + server: self.server_name.clone(), + reason: format!("failed to write to stdin: {e}"), + })?; + writer + .flush() + .await + .map_err(|e| McpError::TransportError { + server: self.server_name.clone(), + reason: format!("failed to flush stdin: {e}"), + })?; + } + + // Read response lines until we find one with matching id + let mut line_buf = String::new(); + let mut reader = self.reader.lock().await; + + loop { + line_buf.clear(); + let bytes_read = reader + .read_line(&mut line_buf) + .await + .map_err(|e| McpError::TransportError { + server: self.server_name.clone(), + reason: format!("failed to read from stdout: {e}"), + })?; + + if bytes_read == 0 { + return Err(McpError::TransportError { + server: self.server_name.clone(), + reason: "server stdout closed (process may have exited)".into(), + }); + } + + let trimmed = line_buf.trim(); + if trimmed.is_empty() { + continue; + } + + // Try to parse as JSON-RPC response + match serde_json::from_str::<JsonRpcResponse>(trimmed) { + Ok(resp) if resp.id == id => return Ok(resp), + Ok(_) => { + // Response for a different request ID — skip + // This shouldn't happen in our single-threaded protocol, + // but handle gracefully. + continue; + } + Err(_) => { + // Not a JSON-RPC response — could be server log output. + // Skip and keep reading. + continue; + } + } + } + } + + /// Send a JSON-RPC notification (no response expected). 
+ pub async fn notify( + &self, + method: &str, + params: Option<serde_json::Value>, + ) -> Result<(), McpError> { + let notification = serde_json::json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }); + + let mut json = serde_json::to_string(&notification).map_err(|e| { + McpError::TransportError { + server: self.server_name.clone(), + reason: format!("failed to serialize notification: {e}"), + } + })?; + json.push('\n'); + + let mut writer = self.writer.lock().await; + writer + .write_all(json.as_bytes()) + .await + .map_err(|e| McpError::TransportError { + server: self.server_name.clone(), + reason: format!("failed to write notification: {e}"), + })?; + writer + .flush() + .await + .map_err(|e| McpError::TransportError { + server: self.server_name.clone(), + reason: format!("failed to flush notification: {e}"), + })?; + + Ok(()) + } +} + +// ─── Response Helpers ──────────────────────────────────────────────────────── + +/// Extract the result from a JSON-RPC response, converting errors to `McpError`. 
+pub fn extract_result(response: JsonRpcResponse) -> Result<serde_json::Value, McpError> { + if let Some(err) = response.error { + return Err(McpError::ServerError { + code: err.code, + message: err.message, + data: err.data, + }); + } + + response.result.ok_or(McpError::ServerError { + code: -32603, + message: "response missing both result and error".into(), + data: None, + }) +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_next_request_id_is_monotonic() { + let id1 = next_request_id(); + let id2 = next_request_id(); + assert!(id2 > id1); + } + + #[test] + fn test_extract_result_success() { + let resp = JsonRpcResponse { + jsonrpc: "2.0".into(), + id: 1, + result: Some(serde_json::json!({"text": "hello"})), + error: None, + }; + let result = extract_result(resp).unwrap(); + assert_eq!(result["text"], "hello"); + } + + #[test] + fn test_extract_result_error() { + let resp = JsonRpcResponse { + jsonrpc: "2.0".into(), + id: 1, + result: None, + error: Some(super::super::types::JsonRpcError { + code: -32601, + message: "Method not found".into(), + data: None, + }), + }; + let err = extract_result(resp).unwrap_err(); + match err { + McpError::ServerError { code, message, .. } => { + assert_eq!(code, -32601); + assert_eq!(message, "Method not found"); + } + _ => panic!("expected ServerError"), + } + } + + #[test] + fn test_extract_result_missing_both() { + let resp = JsonRpcResponse { + jsonrpc: "2.0".into(), + id: 1, + result: None, + error: None, + }; + let err = extract_result(resp).unwrap_err(); + assert!(matches!(err, McpError::ServerError { .. })); + } +} diff --git a/src-tauri/src/mcp_client/types.rs b/src-tauri/src/mcp_client/types.rs new file mode 100644 index 0000000..2e9b112 --- /dev/null +++ b/src-tauri/src/mcp_client/types.rs @@ -0,0 +1,196 @@ +//! Shared types for the MCP client. +//! +//! JSON-RPC 2.0 message types and MCP protocol structures. 
+ +use serde::{Deserialize, Serialize}; + +// ─── JSON-RPC 2.0 ─────────────────────────────────────────────────────────── + +/// JSON-RPC 2.0 request message. +#[derive(Debug, Clone, Serialize)] +pub struct JsonRpcRequest { + pub jsonrpc: String, + pub id: u64, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option<serde_json::Value>, +} + +impl JsonRpcRequest { + /// Create a new JSON-RPC request. + pub fn new(id: u64, method: &str, params: Option<serde_json::Value>) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id, + method: method.to_string(), + params, + } + } +} + +/// JSON-RPC 2.0 response message (success or error). +#[derive(Debug, Clone, Deserialize)] +pub struct JsonRpcResponse { + #[allow(dead_code)] + pub jsonrpc: String, + pub id: u64, + pub result: Option<serde_json::Value>, + pub error: Option<JsonRpcError>, +} + +/// JSON-RPC 2.0 error object. +#[derive(Debug, Clone, Deserialize)] +pub struct JsonRpcError { + pub code: i32, + pub message: String, + pub data: Option<serde_json::Value>, +} + +// ─── MCP Protocol Types ────────────────────────────────────────────────────── + +/// MCP tool definition as returned by `initialize`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpToolDefinition { + pub name: String, + pub description: String, + #[serde(default, alias = "inputSchema")] + pub params_schema: serde_json::Value, + #[serde(default)] + pub returns_schema: serde_json::Value, + #[serde(default, alias = "confirmationRequired")] + pub confirmation_required: bool, + #[serde(default, alias = "undoSupported")] + pub undo_supported: bool, +} + +/// Server configuration from `mcp_servers.json`. +#[derive(Debug, Clone, Deserialize)] +pub struct ServerConfig { + pub command: String, + #[serde(default)] + pub args: Vec<String>, + #[serde(default)] + pub env: std::collections::HashMap<String, String>, + /// Per-server working directory (overrides the global working_dir). + #[serde(default)] + pub cwd: Option<String>, + /// Optional Python virtual environment path. 
When set, `command` is resolved + /// to `{venv}/bin/{command}` and `VIRTUAL_ENV` + `PATH` are injected. + #[serde(default)] + pub venv: Option<String>, +} + +/// Top-level MCP servers configuration file. +#[derive(Debug, Clone, Deserialize)] +pub struct McpServersConfig { + pub servers: std::collections::HashMap<String, ServerConfig>, +} + +/// Result of a tool call execution. +#[derive(Debug, Clone, Serialize)] +pub struct ToolCallResult { + pub tool_name: String, + pub success: bool, + pub result: Option<serde_json::Value>, + pub error: Option<String>, + pub execution_time_ms: u64, +} + +/// MCP initialize response payload. +#[derive(Debug, Clone, Deserialize)] +pub struct InitializeResult { + #[serde(default)] + pub capabilities: serde_json::Value, + #[serde(default)] + pub tools: Vec<McpToolDefinition>, + #[serde(default, alias = "serverInfo")] + pub server_info: Option<ServerInfo>, +} + +/// Server info returned in the initialize response. +#[derive(Debug, Clone, Deserialize)] +pub struct ServerInfo { + pub name: Option<String>, + pub version: Option<String>, +} + +// ─── Standard MCP Error Codes ──────────────────────────────────────────────── + +/// Well-known JSON-RPC / MCP error codes. +pub mod error_codes { + /// Invalid JSON was received. + pub const PARSE_ERROR: i32 = -32700; + /// The JSON sent is not a valid Request object. + pub const INVALID_REQUEST: i32 = -32600; + /// The method does not exist or is not available. + pub const METHOD_NOT_FOUND: i32 = -32601; + /// Invalid method parameters. + pub const INVALID_PARAMS: i32 = -32602; + /// Internal JSON-RPC error. + pub const INTERNAL_ERROR: i32 = -32603; + /// File not found (MCP extension). + pub const FILE_NOT_FOUND: i32 = -32001; + /// Permission denied (MCP extension). + pub const PERMISSION_DENIED: i32 = -32002; + /// Operation cancelled by user (MCP extension). 
+ pub const CANCELLED: i32 = -32003; +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_json_rpc_request_serialization() { + let req = JsonRpcRequest::new(1, "initialize", None); + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"jsonrpc\":\"2.0\"")); + assert!(json.contains("\"id\":1")); + assert!(json.contains("\"method\":\"initialize\"")); + // params should be omitted when None + assert!(!json.contains("params")); + } + + #[test] + fn test_json_rpc_request_with_params() { + let params = serde_json::json!({"name": "test.tool", "arguments": {"path": "/tmp"}}); + let req = JsonRpcRequest::new(42, "tools/call", Some(params)); + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"id\":42")); + assert!(json.contains("tools/call")); + assert!(json.contains("/tmp")); + } + + #[test] + fn test_json_rpc_response_deserialization() { + let json = r#"{"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}"#; + let resp: JsonRpcResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.id, 1); + assert!(resp.result.is_some()); + assert!(resp.error.is_none()); + } + + #[test] + fn test_json_rpc_error_response() { + let json = r#"{ + "jsonrpc": "2.0", + "id": 2, + "result": null, + "error": {"code": -32601, "message": "Method not found"} + }"#; + let resp: JsonRpcResponse = serde_json::from_str(json).unwrap(); + assert!(resp.error.is_some()); + let err = resp.error.unwrap(); + assert_eq!(err.code, error_codes::METHOD_NOT_FOUND); + } + + #[test] + fn test_tool_definition_defaults() { + let json = r#"{"name": "test.tool", "description": "A test tool"}"#; + let tool: McpToolDefinition = serde_json::from_str(json).unwrap(); + assert!(!tool.confirmation_required); + assert!(!tool.undo_supported); + } +} diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json new file mode 100644 index 0000000..dedd647 --- /dev/null 
+++ b/src-tauri/tauri.conf.json @@ -0,0 +1,54 @@ +{ + "$schema": "https://schema.tauri.app/config/2", + "productName": "LocalCowork", + "version": "0.1.0", + "identifier": "com.localcowork.app", + "build": { + "frontendDist": "../dist", + "devUrl": "http://localhost:5173", + "beforeDevCommand": "npm run dev", + "beforeBuildCommand": "npm run build" + }, + "app": { + "windows": [ + { + "label": "main", + "title": "LocalCowork", + "width": 1200, + "height": 800, + "minWidth": 800, + "minHeight": 600, + "resizable": true, + "fullscreen": false + } + ], + "security": { + "csp": "default-src 'self'; connect-src 'self' http://localhost:* https://localhost:*; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self' data:" + } + }, + "bundle": { + "active": true, + "targets": "all", + "icon": [ + "icons/32x32.png", + "icons/128x128.png", + "icons/128x128@2x.png", + "icons/icon.icns", + "icons/icon.ico" + ], + "category": "public.app-category.productivity", + "copyright": "Copyright 2026 LocalCowork", + "publisher": "LocalCowork", + "longDescription": "On-device AI assistant powered by a locally-hosted LLM. Full privacy, no cloud dependency.", + "macOS": { + "entitlements": "entitlements.plist", + "minimumSystemVersion": "12.0" + }, + "windows": { + "nsis": { + "installMode": "currentUser", + "displayLanguageSelector": false + } + } + } +} diff --git a/src/App.tsx b/src/App.tsx new file mode 100644 index 0000000..b4ece1a --- /dev/null +++ b/src/App.tsx @@ -0,0 +1,92 @@ +import { useEffect } from "react"; + +import { ChatPanel } from "./components/Chat"; +import { FileBrowser } from "./components/FileBrowser"; +import { OnboardingWizard } from "./components/Onboarding"; +import { SettingsPanel } from "./components/Settings"; +import { useOnboardingStore } from "./stores/onboardingStore"; +import { useSettingsStore } from "./stores/settingsStore"; + +/** + * Root application component. 
+ * + * Shows the OnboardingWizard on first run, then the main app layout. + */ +export function App(): React.JSX.Element { + const toggleSettings = useSettingsStore((s) => s.togglePanel); + const isSettingsOpen = useSettingsStore((s) => s.isOpen); + const startConfigWatch = useSettingsStore((s) => s.startConfigWatch); + const stopConfigWatch = useSettingsStore((s) => s.stopConfigWatch); + const configReloadNotification = useSettingsStore( + (s) => s.configReloadNotification, + ); + const clearConfigReloadNotification = useSettingsStore( + (s) => s.clearConfigReloadNotification, + ); + const isOnboardingComplete = useOnboardingStore((s) => s.isComplete); + + // Start/stop config file watching based on settings panel state + useEffect(() => { + if (isSettingsOpen) { + startConfigWatch(); + } else { + stopConfigWatch(); + } + return () => stopConfigWatch(); + }, [isSettingsOpen, startConfigWatch, stopConfigWatch]); + + if (!isOnboardingComplete) { + return <OnboardingWizard />; + } + + return ( +
+ {/* Config reload toast notification */} + {configReloadNotification && ( +
+ 🔄 + {configReloadNotification} + +
+ )} + +
+
+
+

LocalCowork

+ on-device +
+ + powered by LFM2-24B-A2B from Liquid AI + +
+
+ +
+ +
+ + +
+ +
+ v0.1.0 — Agent Core +
+ + +
+ ); +} diff --git a/src/components/Chat/MessageInput.tsx b/src/components/Chat/MessageInput.tsx new file mode 100644 index 0000000..fd401b6 --- /dev/null +++ b/src/components/Chat/MessageInput.tsx @@ -0,0 +1,118 @@ +/** + * MessageInput — text input area for sending messages. + * + * Supports Enter to send (Shift+Enter for newline) and disables + * input while the assistant is generating. Includes an InputToolbar + * below the textarea for folder context (Cowork-style "Work in a folder"). + * Implements debouncing to prevent duplicate sends. + */ + +import { useCallback, useRef, useState } from "react"; + +import { InputToolbar } from "./InputToolbar"; + +interface MessageInputProps { + readonly onSend: (content: string) => void; + readonly disabled: boolean; +} + +/** Minimum time between send requests to prevent duplicates (500ms) */ +const SEND_DEBOUNCE_MS = 500; + +export function MessageInput({ + onSend, + disabled, +}: MessageInputProps): React.JSX.Element { + const [value, setValue] = useState(""); + const textareaRef = useRef<HTMLTextAreaElement>(null); + const lastSendTimeRef = useRef(0); + const [isDebouncing, setIsDebouncing] = useState(false); + + const handleSend = useCallback(() => { + const trimmed = value.trim(); + if (!trimmed || disabled) return; + + // Debounce: ignore clicks within 500ms + const now = Date.now(); + if (now - lastSendTimeRef.current < SEND_DEBOUNCE_MS) { + setIsDebouncing(true); + setTimeout(() => setIsDebouncing(false), SEND_DEBOUNCE_MS); + return; + } + lastSendTimeRef.current = now; + + onSend(trimmed); + setValue(""); + + // Reset textarea height + if (textareaRef.current) { + textareaRef.current.style.height = "auto"; + } + }, [value, disabled, onSend]); + + const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>): void => { + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + handleSend(); + } + }; + + const handleInput = (e: React.ChangeEvent<HTMLTextAreaElement>): void => { + setValue(e.target.value); + + // Auto-resize textarea + const textarea = e.target; + 
textarea.style.height = "auto"; + textarea.style.height = `${Math.min(textarea.scrollHeight, 200)}px`; + }; + + const isLoading = disabled || isDebouncing; + + return ( +
+
+