diff --git a/examples/localcowork/src-tauri/src/agent_core/orchestrator.rs b/examples/localcowork/src-tauri/src/agent_core/orchestrator.rs index 2222db2..b20bcdc 100644 --- a/examples/localcowork/src-tauri/src/agent_core/orchestrator.rs +++ b/examples/localcowork/src-tauri/src/agent_core/orchestrator.rs @@ -188,15 +188,17 @@ pub async fn orchestrate_dual_model( tracing::info!(step_count = plan.steps.len(), "orchestrator: plan created"); - // ── Build tool embedding index ────────────────────────────────────── + // ── Build tool embedding index (with cache) ─────────────────────────── let tool_pairs: Vec<(String, String)> = { let mcp = mcp_state.lock().await; mcp.registry.tool_name_description_pairs() }; - let tool_index = match ToolEmbeddingIndex::build( + let cache_dir = crate::cache_dir(); + let tool_index = match ToolEmbeddingIndex::build_with_cache( router.current_base_url(), &tool_pairs, + Some(&cache_dir), ) .await { @@ -212,7 +214,7 @@ pub async fn orchestrate_dual_model( } }; - tracing::info!(tool_count = tool_index.len(), "orchestrator: tool index built"); + tracing::info!(tool_count = tool_index.len(), "orchestrator: tool index ready"); // ── Plan validation gate (Improvement I4) ───────────────────────── { diff --git a/examples/localcowork/src-tauri/src/agent_core/tool_prefilter.rs b/examples/localcowork/src-tauri/src/agent_core/tool_prefilter.rs index 54afdc5..eb7b0d5 100644 --- a/examples/localcowork/src-tauri/src/agent_core/tool_prefilter.rs +++ b/examples/localcowork/src-tauri/src/agent_core/tool_prefilter.rs @@ -4,11 +4,19 @@ //! endpoint, then for each user query, embeds the query and selects the //! top-K tools by cosine similarity. //! +//! Caching: The embedding index can be cached to disk for fast startup. +//! Cache is invalidated when the model endpoint or tool list changes. +//! //! Ported from TypeScript `tests/model-behavior/benchmark-lfm.ts` (lines 270-414). 
use reqwest::Client as HttpClient; use serde::{Deserialize, Serialize}; -use std::time::Duration; +use std::fs; +use std::path::{Path, PathBuf}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +#[allow(unused_imports)] +use std::io::{self, Write}; // ─── Error Type ───────────────────────────────────────────────────────────── @@ -43,6 +51,27 @@ struct EmbeddingResponse { data: Vec, } +/// Metadata stored alongside the cached index for validation. +#[derive(Debug, Serialize, Deserialize)] +struct CacheMetadata { + /// Endpoint used to build the index (e.g., "http://localhost:1234/v1") + endpoint: String, + /// Hash of tool names + descriptions to detect tool changes. + tool_hash: String, + /// Unix timestamp when the cache was created. + created_at: u64, + /// Embedding dimension. + dimension: usize, +} + +/// Full cache file contents. +#[derive(Debug, Serialize, Deserialize)] +struct CachedIndex { + metadata: CacheMetadata, + tool_names: Vec, + embeddings: Vec>, +} + // ─── Tool Embedding Index ─────────────────────────────────────────────────── /// Pre-computed tool embedding index for RAG pre-filtering. @@ -94,6 +123,143 @@ impl ToolEmbeddingIndex { }) } + /// Build the index, trying to load from cache first. + /// + /// If a valid cached index exists (same endpoint + tools), loads it. + /// Otherwise builds fresh and saves to cache. 
+    ///
+    /// # Errors
+    ///
+    /// Returns the error from [`Self::build`] when no usable cache exists and
+    /// the fresh build fails. Cache read and write failures are never returned:
+    /// they are logged, and the index is rebuilt or simply left unpersisted.
+    pub async fn build_with_cache(
+        endpoint: &str,
+        tools: &[(String, String)],
+        cache_dir: Option<&Path>,
+    ) -> Result<Self, ToolPreFilterError> {
+        // Fast path: reuse a previously persisted index when it still matches.
+        if let Some(dir) = cache_dir {
+            match Self::load_from_cache(endpoint, tools, dir) {
+                Ok(cached) => {
+                    tracing::info!(
+                        tool_count = cached.len(),
+                        "orchestrator: tool index loaded from cache"
+                    );
+                    return Ok(cached);
+                }
+                // A miss is routine (first run, tools changed, ...). Record the
+                // reason instead of discarding it, then fall through to a rebuild.
+                Err(e) => {
+                    tracing::debug!(error = %e, "tool index cache miss; rebuilding");
+                }
+            }
+        }
+
+        // Build fresh
+        let index = Self::build(endpoint, tools).await?;
+
+        // Persisting is best-effort: a failed write must not fail the caller.
+        if let Some(dir) = cache_dir {
+            match index.save_to_cache(endpoint, tools, dir) {
+                Ok(()) => tracing::info!("orchestrator: tool index cached"),
+                Err(e) => tracing::warn!(error = %e, "failed to save tool index cache"),
+            }
+        }
+
+        Ok(index)
+    }
+
+    /// Get the cache file path for a given endpoint.
+    ///
+    /// The endpoint is flattened into a filesystem-safe stem: the scheme
+    /// separator `://` collapses to a single `_`, and every remaining character
+    /// other than an ASCII letter, digit, `-` or `_` is replaced by `_`.
+    /// `http://localhost:1234/v1` therefore maps to
+    /// `tool_index_http_localhost_1234_v1.json`.
+    ///
+    /// Distinct endpoints may flatten to the same file name; the endpoint stored
+    /// inside the cache file is what decides whether an entry is actually valid.
+    fn cache_path(cache_dir: &Path, endpoint: &str) -> PathBuf {
+        let safe_name: String = endpoint
+            .replace("://", "_")
+            .chars()
+            .map(|c| {
+                if c.is_ascii_alphanumeric() || c == '-' || c == '_' {
+                    c
+                } else {
+                    '_'
+                }
+            })
+            .collect();
+        cache_dir.join(format!("tool_index_{}.json", safe_name))
+    }
+
+    /// Load the index from cache if valid.
+    ///
+    /// A cache entry is accepted only when it was written for the same
+    /// `endpoint`, its stored tool hash equals the hash of `tools`, it holds
+    /// exactly one embedding per tool name, and every embedding has the
+    /// dimension recorded in the metadata.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ToolPreFilterError::RequestFailed`] with a human-readable
+    /// reason when the file is missing, unreadable, unparsable, stale, or
+    /// internally inconsistent.
+    fn load_from_cache(
+        endpoint: &str,
+        tools: &[(String, String)],
+        cache_dir: &Path,
+    ) -> Result<Self, ToolPreFilterError> {
+        let reject = |reason: String| ToolPreFilterError::RequestFailed { reason };
+        let cache_path = Self::cache_path(cache_dir, endpoint);
+
+        // Read directly rather than probing with `exists()` first: one fewer
+        // filesystem call, and no gap in which the file can disappear.
+        let content = fs::read_to_string(&cache_path).map_err(|e| {
+            if e.kind() == io::ErrorKind::NotFound {
+                reject("cache file not found".to_string())
+            } else {
+                reject(format!("failed to read cache: {}", e))
+            }
+        })?;
+
+        let cached: CachedIndex = serde_json::from_str(&content)
+            .map_err(|e| reject(format!("failed to parse cache: {}", e)))?;
+
+        // Validate endpoint
+        if cached.metadata.endpoint != endpoint {
+            return Err(reject("endpoint changed".to_string()));
+        }
+
+        // Validate tool hash
+        if cached.metadata.tool_hash != compute_tool_hash(tools) {
+            return Err(reject("tools changed".to_string()));
+        }
+
+        // Every tool name needs exactly one vector. A truncated or hand-edited
+        // file would otherwise produce an index whose rows do not line up.
+        if cached.tool_names.len() != cached.embeddings.len() {
+            return Err(reject(format!(
+                "cached index is inconsistent: {} tool names but {} embeddings",
+                cached.tool_names.len(),
+                cached.embeddings.len()
+            )));
+        }
+
+        // All vectors must have the dimension that was recorded when saving.
+        let dimension = cached.metadata.dimension;
+        if cached.embeddings.iter().any(|v| v.len() != dimension) {
+            return Err(reject(format!(
+                "cached embedding length differs from recorded dimension {}",
+                dimension
+            )));
+        }
+
+        Ok(Self {
+            tool_names: cached.tool_names,
+            embeddings: cached.embeddings,
+        })
+    }
+
+    /// Save the index to cache.
+    ///
+    /// The JSON is first written to a temporary sibling file and then renamed
+    /// onto the final path, so an interrupted write cannot leave a truncated
+    /// file at the cache location.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ToolPreFilterError::RequestFailed`] when the cache directory
+    /// cannot be created, the index cannot be serialized, or the file cannot be
+    /// written or moved into place.
+    fn save_to_cache(
+        &self,
+        endpoint: &str,
+        tools: &[(String, String)],
+        cache_dir: &Path,
+    ) -> Result<(), ToolPreFilterError> {
+        let fail = |reason: String| ToolPreFilterError::RequestFailed { reason };
+
+        // Create cache directory if needed
+        fs::create_dir_all(cache_dir)
+            .map_err(|e| fail(format!("failed to create cache dir: {}", e)))?;
+
+        // An empty index has no first vector to measure; record 0 for it.
+        let dimension = self.embeddings.first().map_or(0, |e| e.len());
+        // A system clock set before the Unix epoch degrades to 0, not an error.
+        let created_at = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .map(|d| d.as_secs())
+            .unwrap_or(0);
+
+        let cached = CachedIndex {
+            metadata: CacheMetadata {
+                endpoint: endpoint.to_string(),
+                tool_hash: compute_tool_hash(tools),
+                created_at,
+                dimension,
+            },
+            tool_names: self.tool_names.clone(),
+            embeddings: self.embeddings.clone(),
+        };
+
+        // Compact JSON: the payload is almost entirely float arrays, which the
+        // pretty printer would emit one number per line.
+        let content = serde_json::to_string(&cached)
+            .map_err(|e| fail(format!("failed to serialize cache: {}", e)))?;
+
+        let cache_path = Self::cache_path(cache_dir, endpoint);
+        let tmp_path = cache_path.with_extension("json.tmp");
+
+        fs::write(&tmp_path, content)
+            .map_err(|e| fail(format!("failed to write cache: {}", e)))?;
+        fs::rename(&tmp_path, &cache_path).map_err(|e| {
+            // Best-effort cleanup; the rename failure is the error worth reporting.
+            let _ = fs::remove_file(&tmp_path);
+            fail(format!("failed to write cache: {}", e))
+        })?;
+
+        Ok(())
+    }
+
     /// Select the top-K tool names by cosine similarity to the query.
     ///
     /// Returns `(selected_names, scored_tools)` where scored_tools is sorted
@@ -281,6 +447,21 @@ fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
     a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
 }
 
+/// Compute a hash of the tool list for cache validation.
+///
+/// Uses a simple FNV-like hash for speed.
+///
+/// Concretely this is 64-bit FNV-1a over each tool's name and description, in
+/// order. A `0xff` byte, which never occurs in UTF-8 text, is fed after every
+/// field so that `("ab", "c")` and `("a", "bc")` hash differently.
+///
+/// The algorithm is spelled out here rather than delegated to
+/// `std::collections::hash_map::DefaultHasher`, because the value is written to
+/// disk and compared on a later run, and the standard library does not promise
+/// that `DefaultHasher` keeps the same algorithm across Rust releases.
+///
+/// Returns the hash as 16 lowercase hexadecimal digits.
+fn compute_tool_hash(tools: &[(String, String)]) -> String {
+    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
+    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
+
+    let mut hash = FNV_OFFSET_BASIS;
+    for (name, desc) in tools {
+        let fields = name
+            .bytes()
+            .chain(std::iter::once(0xff))
+            .chain(desc.bytes())
+            .chain(std::iter::once(0xff));
+        for byte in fields {
+            // FNV-1a: xor the byte in first, then multiply by the prime.
+            hash ^= u64::from(byte);
+            hash = hash.wrapping_mul(FNV_PRIME);
+        }
+    }
+    format!("{:016x}", hash)
+}
+
 // ─── Tests ──────────────────────────────────────────────────────────────────
 
 #[cfg(test)]
@@ -359,4 +540,28 @@ mod tests {
         assert_eq!(index.len(), 2);
         assert!(!index.is_empty());
     }
+
+    #[test]
+    fn compute_tool_hash_deterministic() {
+        let tools = vec![
+            ("tool1".to_string(), "description1".to_string()),
+            ("tool2".to_string(), "description2".to_string()),
+        ];
+        let hash1 = compute_tool_hash(&tools);
+        let hash2 = compute_tool_hash(&tools);
+        assert_eq!(hash1, hash2, "same tools should produce same hash");
+    }
+
+    #[test]
+    fn compute_tool_hash_different_for_different_tools() {
+        let tools1 = vec![
+            ("tool1".to_string(), "description1".to_string()),
+        ];
+        let tools2 = vec![
+            ("tool2".to_string(), "description2".to_string()),
+        ];
+        let hash1 = compute_tool_hash(&tools1);
+        let hash2 = compute_tool_hash(&tools2);
+        assert_ne!(hash1, hash2, "different tools should produce different hashes");
+    }
+
+    #[test]
+    fn compute_tool_hash_respects_field_boundaries() {
+        let tools1 = vec![("ab".to_string(), "c".to_string())];
+        let tools2 = vec![("a".to_string(), "bc".to_string())];
+        assert_ne!(
+            compute_tool_hash(&tools1),
+            compute_tool_hash(&tools2),
+            "moving a byte between name and description should change the hash"
+        );
+    }
+
+    #[test]
+    fn compute_tool_hash_of_empty_list_is_offset_basis() {
+        assert_eq!(compute_tool_hash(&[]), "cbf29ce484222325");
+    }
 }
diff --git a/examples/localcowork/src-tauri/src/lib.rs b/examples/localcowork/src-tauri/src/lib.rs
index b8c1de0..7e4cdf7 100644
--- a/examples/localcowork/src-tauri/src/lib.rs
+++ b/examples/localcowork/src-tauri/src/lib.rs
@@ -34,6 +34,11 @@ pub(crate) fn data_dir() -> std::path::PathBuf {
         .join(".localcowork")
 }
 
+/// Returns the cache directory for the app (embedding indexes, etc.).
+pub(crate) fn cache_dir() -> std::path::PathBuf {
+    data_dir().join("cache")
+}
+
 /// Initialize the tracing subscriber — writes structured logs to the app data directory.
 ///
 /// On each app startup: