Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/localcowork/src-tauri/src/agent_core/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,17 @@ pub async fn orchestrate_dual_model(

tracing::info!(step_count = plan.steps.len(), "orchestrator: plan created");

// ── Build tool embedding index ──────────────────────────────────────
// ── Build tool embedding index (with cache) ───────────────────────────
let tool_pairs: Vec<(String, String)> = {
let mcp = mcp_state.lock().await;
mcp.registry.tool_name_description_pairs()
};

let tool_index = match ToolEmbeddingIndex::build(
let cache_dir = crate::cache_dir();
let tool_index = match ToolEmbeddingIndex::build_with_cache(
router.current_base_url(),
&tool_pairs,
Some(&cache_dir),
)
.await
{
Expand All @@ -212,7 +214,7 @@ pub async fn orchestrate_dual_model(
}
};

tracing::info!(tool_count = tool_index.len(), "orchestrator: tool index built");
tracing::info!(tool_count = tool_index.len(), "orchestrator: tool index ready");

// ── Plan validation gate (Improvement I4) ─────────────────────────
{
Expand Down
207 changes: 206 additions & 1 deletion examples/localcowork/src-tauri/src/agent_core/tool_prefilter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@
//! endpoint, then for each user query, embeds the query and selects the
//! top-K tools by cosine similarity.
//!
//! Caching: The embedding index can be cached to disk for fast startup.
//! Cache is invalidated when the model endpoint or tool list changes.
//!
//! Ported from TypeScript `tests/model-behavior/benchmark-lfm.ts` (lines 270-414).

use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

#[allow(unused_imports)]
use std::io::{self, Write};

Comment on lines +18 to 20
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::io::{self, Write} is imported with #[allow(unused_imports)] but never used in this module. Please remove the unused import (and the allow) to keep the module clean and avoid masking future unused-import issues.

Suggested change
#[allow(unused_imports)]
use std::io::{self, Write};

Copilot uses AI. Check for mistakes.
// ─── Error Type ─────────────────────────────────────────────────────────────

Expand Down Expand Up @@ -43,6 +51,27 @@ struct EmbeddingResponse {
data: Vec<RawEmbeddingItem>,
}

/// Metadata stored alongside the cached index for validation.
#[derive(Debug, Serialize, Deserialize)]
struct CacheMetadata {
    /// Endpoint used to build the index (e.g., "http://localhost:1234/v1")
    endpoint: String,
    /// Hash of tool names + descriptions to detect tool changes.
    tool_hash: String,
    /// Unix timestamp when the cache was created.
    /// Informational only — the loader does not use it for expiry here;
    /// NOTE(review): confirm nothing else relies on it before removing.
    created_at: u64,
    /// Embedding dimension (length of each vector in the index).
    dimension: usize,
}

/// Full cache file contents.
///
/// `tool_names[i]` corresponds to `embeddings[i]` — the two vectors are
/// intended to be parallel and the same length.
#[derive(Debug, Serialize, Deserialize)]
struct CachedIndex {
    metadata: CacheMetadata,
    /// Tool names, parallel to `embeddings`.
    tool_names: Vec<String>,
    /// One embedding vector per tool.
    embeddings: Vec<Vec<f32>>,
}

// ─── Tool Embedding Index ───────────────────────────────────────────────────

/// Pre-computed tool embedding index for RAG pre-filtering.
Expand Down Expand Up @@ -94,6 +123,143 @@ impl ToolEmbeddingIndex {
})
}

/// Build the index, trying to load from cache first.
///
/// If a valid cached index exists (same endpoint + tools), loads it.
/// Otherwise builds fresh and saves to cache.
pub async fn build_with_cache(
endpoint: &str,
tools: &[(String, String)],
cache_dir: Option<&Path>,
) -> Result<Self, ToolPreFilterError> {
// Try to load from cache first
if let Some(dir) = cache_dir {
if let Ok(cached) = Self::load_from_cache(endpoint, tools, dir) {
tracing::info!(
tool_count = cached.len(),
"orchestrator: tool index loaded from cache"
);
return Ok(cached);
}
}

// Build fresh
let index = Self::build(endpoint, tools).await?;

// Save to cache
if let Some(dir) = cache_dir {
if let Err(e) = index.save_to_cache(endpoint, tools, dir) {
tracing::warn!(error = %e, "failed to save tool index cache");
} else {
tracing::info!("orchestrator: tool index cached");
}
}

Ok(index)
}

/// Get the cache file path for a given endpoint.
fn cache_path(cache_dir: &Path, endpoint: &str) -> PathBuf {
// Create a safe filename from the endpoint
let safe_name = endpoint
.replace("://", "_")
.replace("/", "_")
.replace(":", "_")
.replace(".", "_");
cache_dir.join(format!("tool_index_{}.json", safe_name))
}

/// Load the index from cache if valid.
fn load_from_cache(
endpoint: &str,
tools: &[(String, String)],
cache_dir: &Path,
) -> Result<Self, ToolPreFilterError> {
let cache_path = Self::cache_path(cache_dir, endpoint);

if !cache_path.exists() {
return Err(ToolPreFilterError::RequestFailed {
reason: "cache file not found".to_string(),
});
}

let content = fs::read_to_string(&cache_path).map_err(|e| ToolPreFilterError::RequestFailed {
reason: format!("failed to read cache: {}", e),
})?;

let cached: CachedIndex = serde_json::from_str(&content).map_err(|e| ToolPreFilterError::RequestFailed {
reason: format!("failed to parse cache: {}", e),
})?;

// Validate endpoint
if cached.metadata.endpoint != endpoint {
return Err(ToolPreFilterError::RequestFailed {
reason: "endpoint changed".to_string(),
});
}

// Validate tool hash
let current_hash = compute_tool_hash(tools);
if cached.metadata.tool_hash != current_hash {
return Err(ToolPreFilterError::RequestFailed {
reason: "tools changed".to_string(),
});
}

// Validate embeddings exist
if cached.embeddings.is_empty() && !cached.tool_names.is_empty() {
return Err(ToolPreFilterError::RequestFailed {
reason: "cached embeddings are empty".to_string(),
});
}

Ok(Self {
tool_names: cached.tool_names,
embeddings: cached.embeddings,
})
Comment on lines +194 to +219
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cache metadata records dimension, but the loader never checks it (nor that all cached embeddings have consistent lengths). Because cosine similarity currently zips vectors, a dimension mismatch would silently produce incorrect scores. Consider validating metadata.dimension against the cached vectors (and/or returning ToolPreFilterError::DimensionMismatch) so invalid caches are rejected instead of degrading retrieval quality.

Copilot uses AI. Check for mistakes.
}
Comment on lines +209 to +220
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_from_cache can construct an index where tool_names.len() and embeddings.len() differ, which will later panic in filter() when indexing self.embeddings[i]. Please validate that counts match (and ideally that each embedding has the expected dimension) before returning the cached index, and treat mismatches as cache-invalid.

Copilot uses AI. Check for mistakes.

/// Save the index to cache.
fn save_to_cache(
&self,
endpoint: &str,
tools: &[(String, String)],
cache_dir: &Path,
) -> Result<(), ToolPreFilterError> {
// Create cache directory if needed
fs::create_dir_all(cache_dir).map_err(|e| ToolPreFilterError::RequestFailed {
reason: format!("failed to create cache dir: {}", e),
})?;

let dimension = self.embeddings.first().map(|e| e.len()).unwrap_or(0);
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);

let cached = CachedIndex {
metadata: CacheMetadata {
endpoint: endpoint.to_string(),
tool_hash: compute_tool_hash(tools),
created_at,
dimension,
},
tool_names: self.tool_names.clone(),
embeddings: self.embeddings.clone(),
};

let content = serde_json::to_string_pretty(&cached).map_err(|e| ToolPreFilterError::RequestFailed {
reason: format!("failed to serialize cache: {}", e),
})?;

let cache_path = Self::cache_path(cache_dir, endpoint);
fs::write(&cache_path, content).map_err(|e| ToolPreFilterError::RequestFailed {
reason: format!("failed to write cache: {}", e),
})?;

Ok(())
}

/// Select the top-K tool names by cosine similarity to the query.
///
/// Returns `(selected_names, scored_tools)` where scored_tools is sorted
Expand Down Expand Up @@ -281,6 +447,21 @@ fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Compute a hash of the tool list for cache validation.
///
/// Implemented as 64-bit FNV-1a so the value is stable across program runs
/// and Rust versions (std's `DefaultHasher` makes no such guarantee, which
/// would spuriously invalidate on-disk caches after a toolchain upgrade).
/// A separator byte is folded in after every field so moving characters
/// between a name and its description changes the hash.
fn compute_tool_hash(tools: &[(String, String)]) -> String {
    // Standard FNV-1a 64-bit parameters.
    const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let fold = |mut hash: u64, bytes: &[u8]| -> u64 {
        for &b in bytes {
            hash ^= u64::from(b);
            hash = hash.wrapping_mul(FNV_PRIME);
        }
        // 0xFF never occurs in valid UTF-8, so it is an unambiguous
        // field separator.
        hash ^= 0xFF;
        hash.wrapping_mul(FNV_PRIME)
    };

    let mut hash = FNV_OFFSET;
    for (name, desc) in tools {
        hash = fold(hash, name.as_bytes());
        hash = fold(hash, desc.as_bytes());
    }
    format!("{:x}", hash)
}

// ─── Tests ──────────────────────────────────────────────────────────────────

#[cfg(test)]
Expand Down Expand Up @@ -359,4 +540,28 @@ mod tests {
assert_eq!(index.len(), 2);
assert!(!index.is_empty());
}

#[test]
fn compute_tool_hash_deterministic() {
let tools = vec![
("tool1".to_string(), "description1".to_string()),
("tool2".to_string(), "description2".to_string()),
];
let hash1 = compute_tool_hash(&tools);
let hash2 = compute_tool_hash(&tools);
assert_eq!(hash1, hash2, "same tools should produce same hash");
}

#[test]
fn compute_tool_hash_different_for_different_tools() {
let tools1 = vec![
("tool1".to_string(), "description1".to_string()),
];
let tools2 = vec![
("tool2".to_string(), "description2".to_string()),
];
let hash1 = compute_tool_hash(&tools1);
let hash2 = compute_tool_hash(&tools2);
assert_ne!(hash1, hash2, "different tools should produce different hashes");
}
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR adds disk caching logic (build_with_cache / load_from_cache / save_to_cache), but the tests added only cover compute_tool_hash. Since this module already has unit tests, please add coverage for the cache round-trip and invalidation cases (endpoint change, tool_hash change, corrupted file) using a temp directory.

Suggested change
}
}
// ─── Cache Tests ─────────────────────────────────────────────────────────
use std::env;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
/// Create a unique temporary directory for cache tests.
fn temp_cache_dir() -> PathBuf {
let mut dir = env::temp_dir();
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time went backwards")
.as_nanos();
dir.push(format!("tool_prefilter_cache_test_{}", timestamp));
fs::create_dir_all(&dir).expect("failed to create temp cache dir");
dir
}
#[test]
fn cache_round_trip_works() {
let cache_dir = temp_cache_dir();
let index = ToolEmbeddingIndex {
tool_names: vec!["tool1".to_string()],
embeddings: vec![vec![1.0, 2.0, 3.0]],
};
let endpoint = "http://example.com/embeddings";
let tools = vec![("tool1".to_string(), "description1".to_string())];
let tool_hash = compute_tool_hash(&tools);
// Save to cache, then load it back and verify contents.
save_to_cache(&cache_dir, endpoint, &tool_hash, &index)
.expect("saving to cache should succeed");
let loaded = load_from_cache(&cache_dir, endpoint, &tool_hash)
.expect("expected Some(index) from cache");
assert_eq!(loaded.tool_names, index.tool_names);
assert_eq!(loaded.embeddings, index.embeddings);
}
#[test]
fn cache_invalidation_on_endpoint_change() {
let cache_dir = temp_cache_dir();
let index = ToolEmbeddingIndex {
tool_names: vec!["tool1".to_string()],
embeddings: vec![vec![1.0]],
};
let endpoint = "http://example.com/embeddings";
let tools = vec![("tool1".to_string(), "description1".to_string())];
let tool_hash = compute_tool_hash(&tools);
save_to_cache(&cache_dir, endpoint, &tool_hash, &index)
.expect("saving to cache should succeed");
let different_endpoint = "http://other.example.com/embeddings";
let loaded = load_from_cache(&cache_dir, different_endpoint, &tool_hash);
assert!(
loaded.is_none(),
"cache should be invalidated when endpoint changes"
);
}
#[test]
fn cache_invalidation_on_tool_hash_change() {
let cache_dir = temp_cache_dir();
let index = ToolEmbeddingIndex {
tool_names: vec!["tool1".to_string()],
embeddings: vec![vec![1.0]],
};
let endpoint = "http://example.com/embeddings";
let tools = vec![("tool1".to_string(), "description1".to_string())];
let tool_hash = compute_tool_hash(&tools);
save_to_cache(&cache_dir, endpoint, &tool_hash, &index)
.expect("saving to cache should succeed");
let different_tools =
vec![("tool2".to_string(), "a different description".to_string())];
let different_tool_hash = compute_tool_hash(&different_tools);
let loaded = load_from_cache(&cache_dir, endpoint, &different_tool_hash);
assert!(
loaded.is_none(),
"cache should be invalidated when tool hash changes"
);
}
#[test]
fn cache_corrupted_file_returns_none() {
let cache_dir = temp_cache_dir();
let index = ToolEmbeddingIndex {
tool_names: vec!["tool1".to_string()],
embeddings: vec![vec![1.0]],
};
let endpoint = "http://example.com/embeddings";
let tools = vec![("tool1".to_string(), "description1".to_string())];
let tool_hash = compute_tool_hash(&tools);
// First create a valid cache file.
save_to_cache(&cache_dir, endpoint, &tool_hash, &index)
.expect("saving to cache should succeed");
// Corrupt whatever file was created in the cache directory.
let cache_file_path = fs::read_dir(&cache_dir)
.expect("cache dir should be readable")
.filter_map(Result::ok)
.map(|entry| entry.path())
.next()
.expect("cache dir should contain at least one file");
fs::write(&cache_file_path, b"not valid json")
.expect("failed to corrupt cache file");
let loaded = load_from_cache(&cache_dir, endpoint, &tool_hash);
assert!(
loaded.is_none(),
"corrupted cache file should be treated as a cache miss"
);
}

Copilot uses AI. Check for mistakes.
}
5 changes: 5 additions & 0 deletions examples/localcowork/src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ pub(crate) fn data_dir() -> std::path::PathBuf {
.join(".localcowork")
}

/// Returns the cache directory for the app (embedding indexes, etc.).
///
/// The directory lives under the app data directory; it is not created
/// here — writers create it on demand.
pub(crate) fn cache_dir() -> std::path::PathBuf {
    let mut dir = data_dir();
    dir.push("cache");
    dir
}

/// Initialize the tracing subscriber — writes structured logs to the app data directory.
///
/// On each app startup:
Expand Down
Loading