Skip to content

Commit 19873a4

Browse files
committed
adjusts tool settings eval order
1 parent c99ed57 commit 19873a4

File tree

4 files changed

+79
-65
lines changed

4 files changed

+79
-65
lines changed

crates/chat-cli/src/cli/chat/tools/execute/mod.rs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ impl ExecuteCommand {
202202
let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" };
203203
let is_in_allowlist = agent.allowed_tools.contains(tool_name);
204204
match agent.tools_settings.get(tool_name) {
205-
Some(settings) if is_in_allowlist => {
205+
Some(settings) => {
206206
let Settings {
207207
allowed_commands,
208208
denied_commands,
@@ -226,7 +226,7 @@ impl ExecuteCommand {
226226
return PermissionEvalResult::Deny(denied_match_set);
227227
}
228228

229-
if self.requires_acceptance(Some(&allowed_commands), allow_read_only) {
229+
if !is_in_allowlist || self.requires_acceptance(Some(&allowed_commands), allow_read_only) {
230230
PermissionEvalResult::Ask
231231
} else {
232232
PermissionEvalResult::Allow
@@ -263,10 +263,7 @@ pub fn format_output(output: &str, max_size: usize) -> String {
263263

264264
#[cfg(test)]
265265
mod tests {
266-
use std::collections::{
267-
HashMap,
268-
HashSet,
269-
};
266+
use std::collections::HashMap;
270267

271268
use super::*;
272269
use crate::cli::agent::ToolSettingTarget;
@@ -408,13 +405,8 @@ mod tests {
408405
#[test]
409406
fn test_eval_perm() {
410407
let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" };
411-
let agent = Agent {
408+
let mut agent = Agent {
412409
name: "test_agent".to_string(),
413-
allowed_tools: {
414-
let mut allowed_tools = HashSet::<String>::new();
415-
allowed_tools.insert(tool_name.to_string());
416-
allowed_tools
417-
},
418410
tools_settings: {
419411
let mut map = HashMap::<ToolSettingTarget, serde_json::Value>::new();
420412
map.insert(
@@ -428,20 +420,32 @@ mod tests {
428420
..Default::default()
429421
};
430422

431-
let tool = serde_json::from_value::<ExecuteCommand>(serde_json::json!({
423+
let tool_one = serde_json::from_value::<ExecuteCommand>(serde_json::json!({
432424
"command": "git status",
433425
}))
434426
.unwrap();
435427

436-
let res = tool.eval_perm(&agent);
428+
let res = tool_one.eval_perm(&agent);
437429
assert!(matches!(res, PermissionEvalResult::Deny(ref rules) if rules.contains(&"\\Agit .*\\z".to_string())));
438430

439-
let tool = serde_json::from_value::<ExecuteCommand>(serde_json::json!({
431+
let tool_two = serde_json::from_value::<ExecuteCommand>(serde_json::json!({
440432
"command": "echo hello",
441433
}))
442434
.unwrap();
443435

444-
let res = tool.eval_perm(&agent);
436+
let res = tool_two.eval_perm(&agent);
437+
assert!(matches!(res, PermissionEvalResult::Ask));
438+
439+
agent.allowed_tools.insert(tool_name.to_string());
440+
441+
let res = tool_two.eval_perm(&agent);
442+
assert!(matches!(res, PermissionEvalResult::Allow));
443+
444+
// Denied list should remain denied
445+
let res = tool_one.eval_perm(&agent);
446+
assert!(matches!(res, PermissionEvalResult::Deny(ref rules) if rules.contains(&"\\Agit .*\\z".to_string())));
447+
448+
let res = tool_two.eval_perm(&agent);
445449
assert!(matches!(res, PermissionEvalResult::Allow));
446450
}
447451

crates/chat-cli/src/cli/chat/tools/fs_read.rs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ impl FsRead {
126126

127127
let is_in_allowlist = agent.allowed_tools.contains("fs_read");
128128
match agent.tools_settings.get("fs_read") {
129-
Some(settings) if is_in_allowlist => {
129+
Some(settings) => {
130130
let Settings {
131131
allowed_paths,
132132
denied_paths,
@@ -188,7 +188,7 @@ impl FsRead {
188188

189189
// We only want to ask if we are not allowing read only
190190
// operation
191-
if !allow_read_only && !allow_set.is_match(path) {
191+
if !is_in_allowlist && !allow_read_only && !allow_set.is_match(path) {
192192
ask = true;
193193
}
194194
},
@@ -209,7 +209,10 @@ impl FsRead {
209209

210210
// We only want to ask if we are not allowing read only
211211
// operation
212-
if !allow_read_only && !paths.iter().any(|path| allow_set.is_match(path)) {
212+
if !is_in_allowlist
213+
&& !allow_read_only
214+
&& !paths.iter().any(|path| allow_set.is_match(path))
215+
{
213216
ask = true;
214217
}
215218
},
@@ -844,10 +847,7 @@ fn format_mode(mode: u32) -> [char; 9] {
844847

845848
#[cfg(test)]
846849
mod tests {
847-
use std::collections::{
848-
HashMap,
849-
HashSet,
850-
};
850+
use std::collections::HashMap;
851851

852852
use super::*;
853853
use crate::cli::agent::ToolSettingTarget;
@@ -1387,13 +1387,8 @@ mod tests {
13871387
const DENIED_PATH_ONE: &str = "/some/denied/path";
13881388
const DENIED_PATH_GLOB: &str = "/denied/glob/**/path";
13891389

1390-
let agent = Agent {
1390+
let mut agent = Agent {
13911391
name: "test_agent".to_string(),
1392-
allowed_tools: {
1393-
let mut allowed_tools = HashSet::<String>::new();
1394-
allowed_tools.insert("fs_read".to_string());
1395-
allowed_tools
1396-
},
13971392
tools_settings: {
13981393
let mut map = HashMap::<ToolSettingTarget, serde_json::Value>::new();
13991394
map.insert(
@@ -1407,7 +1402,7 @@ mod tests {
14071402
..Default::default()
14081403
};
14091404

1410-
let tool = serde_json::from_value::<FsRead>(serde_json::json!({
1405+
let tool_one = serde_json::from_value::<FsRead>(serde_json::json!({
14111406
"operations": [
14121407
{ "path": DENIED_PATH_ONE, "mode": "Line", "start_line": 1, "end_line": 2 },
14131408
{ "path": "/denied/glob", "mode": "Directory" },
@@ -1418,7 +1413,17 @@ mod tests {
14181413
}))
14191414
.unwrap();
14201415

1421-
let res = tool.eval_perm(&agent);
1416+
let res = tool_one.eval_perm(&agent);
1417+
assert!(matches!(
1418+
res,
1419+
PermissionEvalResult::Deny(ref deny_list)
1420+
if deny_list.iter().filter(|p| *p == DENIED_PATH_GLOB).collect::<Vec<_>>().len() == 2
1421+
&& deny_list.iter().filter(|p| *p == DENIED_PATH_ONE).collect::<Vec<_>>().len() == 1
1422+
));
1423+
1424+
agent.allowed_tools.insert("fs_read".to_string());
1425+
1426+
let res = tool_one.eval_perm(&agent);
14221427
assert!(matches!(
14231428
res,
14241429
PermissionEvalResult::Deny(ref deny_list)

crates/chat-cli/src/cli/chat/tools/fs_write.rs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ impl FsWrite {
433433

434434
let is_in_allowlist = agent.allowed_tools.contains("fs_write");
435435
match agent.tools_settings.get("fs_write") {
436-
Some(settings) if is_in_allowlist => {
436+
Some(settings) => {
437437
let Settings {
438438
allowed_paths,
439439
denied_paths,
@@ -486,7 +486,7 @@ impl FsWrite {
486486
.collect::<Vec<_>>()
487487
});
488488
}
489-
if allow_set.is_match(path) {
489+
if is_in_allowlist || allow_set.is_match(path) {
490490
return PermissionEvalResult::Allow;
491491
}
492492
},
@@ -808,10 +808,7 @@ fn syntect_to_crossterm_color(syntect: syntect::highlighting::Color) -> style::C
808808

809809
#[cfg(test)]
810810
mod tests {
811-
use std::collections::{
812-
HashMap,
813-
HashSet,
814-
};
811+
use std::collections::HashMap;
815812

816813
use super::*;
817814
use crate::cli::agent::ToolSettingTarget;
@@ -1270,13 +1267,8 @@ mod tests {
12701267
const DENIED_PATH_ONE: &str = "/some/denied/path/**";
12711268
const DENIED_PATH_GLOB: &str = "/denied/glob/**/path/**";
12721269

1273-
let agent = Agent {
1270+
let mut agent = Agent {
12741271
name: "test_agent".to_string(),
1275-
allowed_tools: {
1276-
let mut allowed_tools = HashSet::<String>::new();
1277-
allowed_tools.insert("fs_write".to_string());
1278-
allowed_tools
1279-
},
12801272
tools_settings: {
12811273
let mut map = HashMap::<ToolSettingTarget, serde_json::Value>::new();
12821274
map.insert(
@@ -1290,51 +1282,62 @@ mod tests {
12901282
..Default::default()
12911283
};
12921284

1293-
let tool = serde_json::from_value::<FsWrite>(serde_json::json!({
1285+
let tool_one = serde_json::from_value::<FsWrite>(serde_json::json!({
12941286
"path": "/not/a/denied/path/file.txt",
12951287
"command": "create",
12961288
"file_text": "content in nested path"
12971289
}))
12981290
.unwrap();
12991291

1300-
let res = tool.eval_perm(&agent);
1292+
let res = tool_one.eval_perm(&agent);
13011293
assert!(matches!(res, PermissionEvalResult::Ask));
13021294

1303-
let tool = serde_json::from_value::<FsWrite>(serde_json::json!({
1295+
let tool_two = serde_json::from_value::<FsWrite>(serde_json::json!({
13041296
"path": format!("{DENIED_PATH_ONE}/file.txt"),
13051297
"command": "create",
13061298
"file_text": "content in nested path"
13071299
}))
13081300
.unwrap();
13091301

1310-
let res = tool.eval_perm(&agent);
1302+
let res = tool_two.eval_perm(&agent);
13111303
assert!(
13121304
matches!(res, PermissionEvalResult::Deny(ref deny_list) if deny_list.contains(&DENIED_PATH_ONE.to_string()))
13131305
);
13141306

1315-
let tool = serde_json::from_value::<FsWrite>(serde_json::json!({
1307+
let tool_three = serde_json::from_value::<FsWrite>(serde_json::json!({
13161308
"path": format!("/denied/glob/child_one/path/file.txt"),
13171309
"command": "create",
13181310
"file_text": "content in nested path"
13191311
}))
13201312
.unwrap();
13211313

1322-
let res = tool.eval_perm(&agent);
1314+
let res = tool_three.eval_perm(&agent);
13231315
assert!(
13241316
matches!(res, PermissionEvalResult::Deny(ref deny_list) if deny_list.contains(&DENIED_PATH_GLOB.to_string()))
13251317
);
13261318

1327-
let tool = serde_json::from_value::<FsWrite>(serde_json::json!({
1319+
let tool_four = serde_json::from_value::<FsWrite>(serde_json::json!({
13281320
"path": format!("/denied/glob/child_one/grand_child_one/path/file.txt"),
13291321
"command": "create",
13301322
"file_text": "content in nested path"
13311323
}))
13321324
.unwrap();
13331325

1334-
let res = tool.eval_perm(&agent);
1326+
let res = tool_four.eval_perm(&agent);
1327+
assert!(
1328+
matches!(res, PermissionEvalResult::Deny(ref deny_list) if deny_list.contains(&DENIED_PATH_GLOB.to_string()))
1329+
);
1330+
1331+
agent.allowed_tools.insert("fs_write".to_string());
1332+
1333+
// Denied list should remained denied
1334+
let res = tool_four.eval_perm(&agent);
13351335
assert!(
13361336
matches!(res, PermissionEvalResult::Deny(ref deny_list) if deny_list.contains(&DENIED_PATH_GLOB.to_string()))
13371337
);
1338+
1339+
let res = tool_one.eval_perm(&agent);
1340+
assert!(matches!(res, PermissionEvalResult::Allow));
13381341
}
13391342

13401343
#[tokio::test]

crates/chat-cli/src/cli/chat/tools/use_aws.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ impl UseAws {
194194
let Self { service_name, .. } = self;
195195
let is_in_allowlist = agent.allowed_tools.contains("use_aws");
196196
match agent.tools_settings.get("use_aws") {
197-
Some(settings) if is_in_allowlist => {
197+
Some(settings) => {
198198
let settings = match serde_json::from_value::<Settings>(settings.clone()) {
199199
Ok(settings) => settings,
200200
Err(e) => {
@@ -205,7 +205,7 @@ impl UseAws {
205205
if settings.denied_services.contains(service_name) {
206206
return PermissionEvalResult::Deny(vec![service_name.clone()]);
207207
}
208-
if settings.allowed_services.contains(service_name) {
208+
if is_in_allowlist || settings.allowed_services.contains(service_name) {
209209
return PermissionEvalResult::Allow;
210210
}
211211
PermissionEvalResult::Ask
@@ -224,8 +224,6 @@ impl UseAws {
224224

225225
#[cfg(test)]
226226
mod tests {
227-
use std::collections::HashSet;
228-
229227
use super::*;
230228
use crate::cli::agent::ToolSettingTarget;
231229

@@ -351,21 +349,16 @@ mod tests {
351349

352350
#[test]
353351
fn test_eval_perm() {
354-
let cmd = use_aws! {{
352+
let cmd_one = use_aws! {{
355353
"service_name": "s3",
356354
"operation_name": "put-object",
357355
"region": "us-west-2",
358356
"profile_name": "default",
359357
"label": ""
360358
}};
361359

362-
let agent = Agent {
360+
let mut agent = Agent {
363361
name: "test_agent".to_string(),
364-
allowed_tools: {
365-
let mut allowed_tools = HashSet::<String>::new();
366-
allowed_tools.insert("use_aws".to_string());
367-
allowed_tools
368-
},
369362
tools_settings: {
370363
let mut map = HashMap::<ToolSettingTarget, serde_json::Value>::new();
371364
map.insert(
@@ -379,18 +372,27 @@ mod tests {
379372
..Default::default()
380373
};
381374

382-
let res = cmd.eval_perm(&agent);
375+
let res = cmd_one.eval_perm(&agent);
383376
assert!(matches!(res, PermissionEvalResult::Deny(ref services) if services.contains(&"s3".to_string())));
384377

385-
let cmd = use_aws! {{
378+
let cmd_two = use_aws! {{
386379
"service_name": "api_gateway",
387380
"operation_name": "request",
388381
"region": "us-west-2",
389382
"profile_name": "default",
390383
"label": ""
391384
}};
392385

393-
let res = cmd.eval_perm(&agent);
386+
let res = cmd_two.eval_perm(&agent);
394387
assert!(matches!(res, PermissionEvalResult::Ask));
388+
389+
agent.allowed_tools.insert("use_aws".to_string());
390+
391+
let res = cmd_two.eval_perm(&agent);
392+
assert!(matches!(res, PermissionEvalResult::Allow));
393+
394+
// Denied services should still be denied after trusting tool
395+
let res = cmd_one.eval_perm(&agent);
396+
assert!(matches!(res, PermissionEvalResult::Deny(ref services) if services.contains(&"s3".to_string())));
395397
}
396398
}

0 commit comments

Comments
 (0)