Skip to content

Commit 9a64257

Browse files
committed
[fix] Fixes issues with Tool Input parsing.
* Ocassionally the model will generate a tool use which parameters are not a valid json. When this happens it corrupts the conversation history. * Here we first avoid storing the tool use and add the propert validation logic to the conversation history.
1 parent bf183f0 commit 9a64257

File tree

3 files changed

+103
-2
lines changed

3 files changed

+103
-2
lines changed

crates/chat-cli/src/cli/chat/consts.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ pub const MAX_USER_MESSAGE_SIZE: usize = 400_000;
1313

1414
pub const DUMMY_TOOL_NAME: &str = "dummy";
1515

16+
/// Marker key used to identify invalid tool arguments (non-JSON objects)
17+
pub const INVALID_TOOL_ARGS_MARKER: &str = "__error__invalid_args_json__";
18+
1619
pub const MAX_NUMBER_OF_IMAGES_PER_REQUEST: usize = 10;
1720

1821
/// In bytes - 10 MB

crates/chat-cli/src/cli/chat/parser.rs

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use tracing::{
2525
warn,
2626
};
2727

28+
use super::consts::INVALID_TOOL_ARGS_MARKER;
2829
use super::message::{
2930
AssistantMessage,
3031
AssistantToolUse,
@@ -472,7 +473,18 @@ impl ResponseParser {
472473
}
473474

474475
let args = match serde_json::from_str(&tool_string) {
475-
Ok(args) => args,
476+
Ok(args) => {
477+
// Ensure we have a valid JSON object
478+
match args {
479+
serde_json::Value::Object(_) => args,
480+
_ => {
481+
error!("Received non-object JSON for tool arguments: {:?}", args);
482+
serde_json::json!({
483+
INVALID_TOOL_ARGS_MARKER: format!("Expected JSON object, got: {:?}", args)
484+
})
485+
},
486+
}
487+
},
476488
Err(err) if !tool_string.is_empty() => {
477489
// If we failed deserializing after waiting for a long time, then this is most
478490
// likely bedrock responding with a stop event for some reason without actually
@@ -753,4 +765,73 @@ mod tests {
753765
"assistant text preceding a code reference should be ignored as this indicates licensed code is being returned"
754766
);
755767
}
768+
769+
#[tokio::test]
770+
async fn test_response_parser_avoid_invalid_json() {
771+
let content_to_ignore = "IGNORE ME PLEASE";
772+
let tool_use_id = "TEST_ID".to_string();
773+
let tool_name = "execute_bash".to_string();
774+
let tool_args = serde_json::json!("invalid json").to_string();
775+
let mut events = vec![
776+
ChatResponseStream::AssistantResponseEvent {
777+
content: "hi".to_string(),
778+
},
779+
ChatResponseStream::AssistantResponseEvent {
780+
content: " there".to_string(),
781+
},
782+
ChatResponseStream::AssistantResponseEvent {
783+
content: content_to_ignore.to_string(),
784+
},
785+
ChatResponseStream::CodeReferenceEvent(()),
786+
ChatResponseStream::ToolUseEvent {
787+
tool_use_id: tool_use_id.clone(),
788+
name: tool_name.clone(),
789+
input: None,
790+
stop: None,
791+
},
792+
ChatResponseStream::ToolUseEvent {
793+
tool_use_id: tool_use_id.clone(),
794+
name: tool_name.clone(),
795+
input: Some(tool_args),
796+
stop: None,
797+
},
798+
];
799+
events.reverse();
800+
let mock = SendMessageOutput::Mock(events);
801+
let mut parser = ResponseParser::new(
802+
mock,
803+
"".to_string(),
804+
None,
805+
1,
806+
vec![],
807+
mpsc::channel(32).0,
808+
Instant::now(),
809+
SystemTime::now(),
810+
CancellationToken::new(),
811+
Arc::new(Mutex::new(None)),
812+
);
813+
814+
let mut output = String::new();
815+
let mut found_invalid_marker = false;
816+
for _ in 0..5 {
817+
let event = parser.recv().await.unwrap();
818+
output.push_str(&format!("{:?}", event));
819+
820+
// Check for invalid args marker in ToolUse events
821+
if let ResponseEvent::ToolUse(tool_use) = event {
822+
if tool_use.args.get(INVALID_TOOL_ARGS_MARKER).is_some() {
823+
found_invalid_marker = true;
824+
}
825+
}
826+
}
827+
828+
assert!(
829+
!output.contains(content_to_ignore),
830+
"assistant text preceding a code reference should be ignored as this indicates licensed code is being returned"
831+
);
832+
assert!(
833+
found_invalid_marker,
834+
"Expected to find invalid args marker for non-object JSON"
835+
);
836+
}
756837
}

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ use crate::cli::agent::{
6464
McpServerConfig,
6565
};
6666
use crate::cli::chat::cli::prompts::GetPromptError;
67-
use crate::cli::chat::consts::DUMMY_TOOL_NAME;
67+
use crate::cli::chat::consts::{
68+
DUMMY_TOOL_NAME,
69+
INVALID_TOOL_ARGS_MARKER,
70+
};
6871
use crate::cli::chat::message::AssistantToolUse;
6972
use crate::cli::chat::server_messenger::{
7073
ServerMessengerBuilder,
@@ -847,6 +850,20 @@ impl ToolManager {
847850
}
848851

849852
pub async fn get_tool_from_tool_use(&mut self, value: AssistantToolUse) -> Result<Tool, ToolResult> {
853+
// Check for invalid args marker
854+
// in case parser, identified some
855+
// fundamental error in the inputs.
856+
if let Some(error_msg) = value.args.get(INVALID_TOOL_ARGS_MARKER).and_then(|v| v.as_str()) {
857+
return Err(ToolResult {
858+
tool_use_id: value.id.clone(),
859+
content: vec![ToolResultContentBlock::Text(format!(
860+
"The tool \"{}\" is supplied with invalid input format. {}",
861+
value.name, error_msg
862+
))],
863+
status: ToolResultStatus::Error,
864+
});
865+
}
866+
850867
let map_err = |parse_error| ToolResult {
851868
tool_use_id: value.id.clone(),
852869
content: vec![ToolResultContentBlock::Text(format!(

0 commit comments

Comments
 (0)