From 2bdfe5bac6e8859152072dc59b4004cb09ed4166 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Wed, 11 Jun 2025 19:24:06 -0400 Subject: [PATCH] [goose-llm] update example to add toolResult status error (#2854) --- bindings/kotlin/example/Usage.kt | 76 +++++- crates/goose-llm/src/message/contents.rs | 102 ++++++++ .../goose-llm/src/message/message_content.rs | 243 +++++++++++++++++- 3 files changed, 407 insertions(+), 14 deletions(-) diff --git a/bindings/kotlin/example/Usage.kt b/bindings/kotlin/example/Usage.kt index 99089515..cdb06c82 100644 --- a/bindings/kotlin/example/Usage.kt +++ b/bindings/kotlin/example/Usage.kt @@ -29,7 +29,7 @@ fun main() = runBlocking { "value": { "name": "calculator_extension__toolname", "arguments": { - "operation": "multiply", + "operation": "doesnotexist", "numbers": [7, 6] }, "needsApproval": false @@ -45,6 +45,51 @@ fun main() = runBlocking { Message( role = Role.USER, created = now + 3, + content = listOf( + MessageContent.ToolResp( + ToolResponse( + id = "calc1", + toolResult = """ + { + "status": "error", + "error": "Invalid value for operation: 'doesnotexist'. Valid values are: ['add', 'subtract', 'multiply', 'divide']" + } + """.trimIndent() + ) + ) + ) + ), + + // 4) Assistant makes a tool request (ToolReq) to calculate 7×6 + Message( + role = Role.ASSISTANT, + created = now + 4, + content = listOf( + MessageContent.ToolReq( + ToolRequest( + id = "calc1", + toolCall = """ + { + "status": "success", + "value": { + "name": "calculator_extension__toolname", + "arguments": { + "operation": "multiply", + "numbers": [7, 6] + }, + "needsApproval": false + } + } + """.trimIndent() + ) + ) + ) + ), + + // 5) User (on behalf of the tool) responds with the tool result (ToolResp) + Message( + role = Role.USER, + created = now + 5, content = listOf( MessageContent.ToolResp( ToolResponse( @@ -124,8 +169,30 @@ fun main() = runBlocking { val extensions = listOf(calculator_extension) val systemPreamble = "You are a helpful assistant." + // Testing with tool calls with an error in tool name + val reqToolErr = createCompletionRequest( + providerName, + providerConfig, + modelConfig, + systemPreamble, + messages = listOf( + Message( + role = Role.USER, + created = now, + content = listOf( + MessageContent.Text( + TextContent("What is 7 x 6?") + ) + ) + )), + extensions = extensions + ) - val req = createCompletionRequest( + val respToolErr = completion(reqToolErr) + println("\nCompletion Response (one msg):\n${respToolErr.message}") + println() + + val reqAll = createCompletionRequest( providerName, providerConfig, modelConfig, @@ -134,14 +201,13 @@ fun main() = runBlocking { extensions = extensions ) - val response = completion(req) - println("\nCompletion Response:\n${response.message}") + val respAll = completion(reqAll) + println("\nCompletion Response (all msgs):\n${respAll.message}") println() // ---- UI Extraction (custom schema) ---- runUiExtraction(providerName, providerConfig) - // --- Prompt Override --- val prompt_req = createCompletionRequest( providerName, diff --git a/crates/goose-llm/src/message/contents.rs b/crates/goose-llm/src/message/contents.rs index 89e80f79..9c8f459f 100644 --- a/crates/goose-llm/src/message/contents.rs +++ b/crates/goose-llm/src/message/contents.rs @@ -82,3 +82,105 @@ uniffi::custom_type!(Contents, Vec, { Ok(Contents::from(contents)) }, }); + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::core::{Content, TextContent, ToolCall, ToolError}; + use serde_json::json; + + // ------------------------------------------------------------ + // Helpers + // ------------------------------------------------------------ + fn make_tool_req_ok(id: &str) -> MessageContent { + let call = ToolCall::new("echo", json!({"text": "hi"})); + MessageContent::tool_request(id, Ok(call).into()) + } + + fn make_tool_resp_ok(id: &str) -> MessageContent { + let body = vec![Content::Text(TextContent { + text: "done".into(), + })]; + MessageContent::tool_response(id, Ok(body).into()) + } + + fn make_tool_req_err(id: &str) -> MessageContent { + let err = ToolError::NotFound(format!( + "The provided function name '{}' had invalid characters", + "bad$name" + )); + MessageContent::tool_request(id, Err(err).into()) + } + + fn make_tool_resp_err(id: &str) -> MessageContent { + let err = ToolError::InvalidParameters("Could not interpret tool use parameters".into()); + MessageContent::tool_response(id, Err(err).into()) + } + + // ------------------------------------------------------------ + // Round-trip: success + // ------------------------------------------------------------ + #[test] + fn contents_roundtrip_ok() { + let items: Contents = vec![make_tool_req_ok("req-1"), make_tool_resp_ok("resp-1")].into(); + + // ---- serialise + let json_str = serde_json::to_string(&items).expect("serialise OK"); + println!("JSON: {:?}", json_str); + + assert!( + json_str.contains(r#""type":"toolReq""#) + && json_str.contains(r#""type":"toolResp""#) + && json_str.contains(r#""status":"success""#), + "JSON should contain both variants and success-status" + ); + + // ---- deserialise + let parsed: Contents = serde_json::from_str(&json_str).expect("deserialise OK"); + + assert_eq!(parsed, items, "full round-trip equality"); + } + + // ------------------------------------------------------------ + // Round-trip: error (all variants collapse to ExecutionError) + // ------------------------------------------------------------ + #[test] + fn contents_roundtrip_err() { + let original_items: Contents = + vec![make_tool_req_err("req-e"), make_tool_resp_err("resp-e")].into(); + + // ---- serialise + let json_str = serde_json::to_string(&original_items).expect("serialise OK"); + println!("JSON: {:?}", json_str); + + assert!(json_str.contains(r#""status":"error""#)); + + // ---- deserialise + let parsed: Contents = serde_json::from_str(&json_str).expect("deserialise OK"); + + // ─── validate structure ─────────────────────────────────── + assert_eq!(parsed.len(), 2); + + // ToolReq error + match &parsed[0] { + MessageContent::ToolReq(req) => match &*req.tool_call { + Err(ToolError::ExecutionError(msg)) => { + assert!(msg.contains("invalid characters")) + } + other => panic!("expected ExecutionError, got {:?}", other), + }, + other => panic!("expected ToolReq, got {:?}", other), + } + + // ToolResp error + match &parsed[1] { + MessageContent::ToolResp(resp) => match &*resp.tool_result { + Err(ToolError::ExecutionError(msg)) => { + assert!(msg.contains("interpret tool use parameters")) + } + other => panic!("expected ExecutionError, got {:?}", other), + }, + other => panic!("expected ToolResp, got {:?}", other), + } + } +} diff --git a/crates/goose-llm/src/message/message_content.rs b/crates/goose-llm/src/message/message_content.rs index b657fbfb..75daa3a8 100644 --- a/crates/goose-llm/src/message/message_content.rs +++ b/crates/goose-llm/src/message/message_content.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use serde_json; +use serde_json::{self, Deserializer, Serializer}; use crate::message::tool_result_serde; use crate::types::core::{Content, ImageContent, TextContent, ToolCall, ToolResult}; @@ -52,22 +52,43 @@ impl From, crate::types::core::ToolError>> for ToolResponseT // — Register the newtypes with UniFFI, converting via JSON strings — // UniFFI’s FFI layer supports only primitive buffers (here String), so we JSON-serialize // through our `tool_result_serde` to preserve the same success/error schema on both sides. +// see https://github.com/mozilla/uniffi-rs/issues/2533 uniffi::custom_type!(ToolRequestToolCall, String, { - lower: |obj| { - serde_json::to_string(&obj.0).unwrap() + lower: |wrapper: &ToolRequestToolCall| { + let mut buf = Vec::new(); + { + let mut ser = Serializer::new(&mut buf); + // note the borrow on wrapper.0 + tool_result_serde::serialize(&wrapper.0, &mut ser) + .expect("ToolRequestToolCall serialization failed"); + } + String::from_utf8(buf).expect("ToolRequestToolCall produced invalid UTF-8") }, - try_lift: |val| { - Ok(serde_json::from_str(&val).unwrap() ) + try_lift: |s: String| { + let mut de = Deserializer::from_str(&s); + let result = tool_result_serde::deserialize(&mut de) + .map_err(anyhow::Error::new)?; + Ok(ToolRequestToolCall(result)) }, }); uniffi::custom_type!(ToolResponseToolResult, String, { - lower: |obj| { - serde_json::to_string(&obj.0).unwrap() + lower: |wrapper: &ToolResponseToolResult| { + let mut buf = Vec::new(); + { + let mut ser = Serializer::new(&mut buf); + // note the borrow on wrapper.0 + tool_result_serde::serialize(&wrapper.0, &mut ser) + .expect("ToolResponseToolResult serialization failed"); + } + String::from_utf8(buf).expect("ToolResponseToolResult produced invalid UTF-8") }, - try_lift: |val| { - Ok(serde_json::from_str(&val).unwrap() ) + try_lift: |s: String| { + let mut de = Deserializer::from_str(&s); + let result = tool_result_serde::deserialize(&mut de) + .map_err(anyhow::Error::new)?; + Ok(ToolResponseToolResult(result)) }, }); @@ -238,3 +259,207 @@ impl From for MessageContent { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::core::{ToolCall, ToolError}; + use crate::UniFfiTag; + use serde_json::json; + use uniffi::{FfiConverter, RustBuffer}; + + // ---------- ToolRequestToolCall ---------------------------------------------------------- + + #[test] + fn tool_request_tool_call_roundtrip_ok() { + // Build a valid ToolCall + let call = ToolCall::new("my_function", json!({"a": 1, "b": 2})); + + // Wrap it in the new-type + let wrapper = ToolRequestToolCall::from(Ok(call.clone())); + + // Serialize → JSON + let json_str = serde_json::to_string(&wrapper).expect("serialize OK"); + assert!( + json_str.contains(r#""status":"success""#), + "must mark success" + ); + + // Deserialize ← JSON + let parsed: ToolRequestToolCall = serde_json::from_str(&json_str).expect("deserialize OK"); + + // Round-trip equality + assert_eq!(*parsed, Ok(call)); + } + + #[test] + fn tool_request_tool_call_roundtrip_err() { + // Typical failure variant that could come from `is_valid_function_name` + let err = ToolError::NotFound( + "The provided function name 'bad$name' had invalid characters".into(), + ); + + let wrapper = ToolRequestToolCall::from(Err(err.clone())); + + let json_str = serde_json::to_string(&wrapper).expect("serialize OK"); + assert!( + json_str.contains(r#""status":"error""#) && json_str.contains("invalid characters"), + "must mark error and carry message" + ); + + let parsed: ToolRequestToolCall = serde_json::from_str(&json_str).expect("deserialize OK"); + + match &*parsed { + Err(ToolError::ExecutionError(msg)) => { + assert!(msg.contains("invalid characters")) + } + other => panic!("expected ExecutionError, got {:?}", other), + } + } + + // ---------- ToolResponseToolResult ------------------------------------------------------- + + #[test] + fn tool_response_tool_result_roundtrip_ok() { + // Minimal content vector (one text item) + let content_vec = vec![Content::Text(TextContent { + text: "hello".into(), + })]; + + let wrapper = ToolResponseToolResult::from(Ok(content_vec.clone())); + + let json_str = serde_json::to_string(&wrapper).expect("serialize OK"); + assert!(json_str.contains(r#""status":"success""#)); + + let parsed: ToolResponseToolResult = + serde_json::from_str(&json_str).expect("deserialize OK"); + + assert_eq!(*parsed, Ok(content_vec)); + } + + #[test] + fn tool_response_tool_result_roundtrip_err() { + let err = ToolError::InvalidParameters("Could not interpret tool use parameters".into()); + + let wrapper = ToolResponseToolResult::from(Err(err.clone())); + + let json_str = serde_json::to_string(&wrapper).expect("serialize OK"); + assert!(json_str.contains(r#""status":"error""#)); + + let parsed: ToolResponseToolResult = + serde_json::from_str(&json_str).expect("deserialize OK"); + + match &*parsed { + Err(ToolError::ExecutionError(msg)) => { + assert!(msg.contains("interpret tool use")) + } + other => panic!("expected ExecutionError, got {:?}", other), + } + } + + // ---------- FFI (lower / lift) round-trips ---------------------------------------------- + // https://mozilla.github.io/uniffi-rs/latest/internals/lifting_and_lowering.html + + #[test] + fn ffi_roundtrip_tool_request_ok_and_err() { + // ---------- status: success ---------- + let ok_call = ToolCall::new("echo", json!({"text": "hi"})); + let ok_wrapper = ToolRequestToolCall::from(Ok(ok_call.clone())); + + // First lower → inspect JSON + let buf1: RustBuffer = + >::lower(ok_wrapper.clone()); + + let json_ok: String = + >::try_lift(buf1).expect("lift String OK"); + println!("ToolReq - Lowered JSON (status: success): {:?}", json_ok); + assert!(json_ok.contains(r#""status":"success""#)); + + // Second lower → round-trip wrapper + let buf2: RustBuffer = + >::lower(ok_wrapper.clone()); + + let lifted_ok = >::try_lift(buf2) + .expect("lift wrapper OK"); + println!( + "ToolReq - Lifted wrapper (status: success): {:?}", + lifted_ok + ); + assert_eq!(lifted_ok, ok_wrapper); + + // ---------- status: error ---------- + let err_call = ToolError::NotFound("no such function".into()); + let err_wrapper = ToolRequestToolCall::from(Err(err_call.clone())); + + let buf1: RustBuffer = + >::lower(err_wrapper.clone()); + let json_err: String = + >::try_lift(buf1).expect("lift String ERR"); + println!("ToolReq - Lowered JSON (status: error): {:?}", json_err); + assert!(json_err.contains(r#""status":"error""#)); + + let buf2: RustBuffer = + >::lower(err_wrapper.clone()); + let lifted_err = >::try_lift(buf2) + .expect("lift wrapper ERR"); + println!("ToolReq - Lifted wrapper (status: error): {:?}", lifted_err); + + match &*lifted_err { + Err(ToolError::ExecutionError(msg)) => { + assert!(msg.contains("no such function")) + } + other => panic!("expected ExecutionError, got {:?}", other), + } + } + + #[test] + fn ffi_roundtrip_tool_response_ok_and_err() { + // ---------- status: success ---------- + let body = vec![Content::Text(TextContent { + text: "done".into(), + })]; + let ok_wrapper = ToolResponseToolResult::from(Ok(body.clone())); + + let buf1: RustBuffer = + >::lower(ok_wrapper.clone()); + let json_ok: String = >::try_lift(buf1).unwrap(); + println!("ToolResp - Lowered JSON (status: success): {:?}", json_ok); + assert!(json_ok.contains(r#""status":"success""#)); + + let buf2: RustBuffer = + >::lower(ok_wrapper.clone()); + let lifted_ok = + >::try_lift(buf2).unwrap(); + println!( + "ToolResp - Lifted wrapper (status: success): {:?}", + lifted_ok + ); + assert_eq!(lifted_ok, ok_wrapper); + + // ---------- status: error ---------- + let err_call = ToolError::InvalidParameters("bad params".into()); + let err_wrapper = ToolResponseToolResult::from(Err(err_call.clone())); + + let buf1: RustBuffer = + >::lower(err_wrapper.clone()); + let json_err: String = >::try_lift(buf1).unwrap(); + println!("ToolResp - Lowered JSON (status: error): {:?}", json_err); + assert!(json_err.contains(r#""status":"error""#)); + + let buf2: RustBuffer = + >::lower(err_wrapper.clone()); + let lifted_err = + >::try_lift(buf2).unwrap(); + println!( + "ToolResp - Lifted wrapper (status: error): {:?}", + lifted_err + ); + + match &*lifted_err { + Err(ToolError::ExecutionError(msg)) => { + assert!(msg.contains("bad params")) + } + other => panic!("expected ExecutionError, got {:?}", other), + } + } +}