mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-21 22:44:24 +01:00
draft: use rust messages in typescript (#1393)
This commit is contained in:
@@ -10,9 +10,8 @@ use bytes::Bytes;
|
||||
use futures::{stream::StreamExt, Stream};
|
||||
use goose::message::{Message, MessageContent};
|
||||
|
||||
use mcp_core::{content::Content, role::Role};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use mcp_core::role::Role;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
pin::Pin,
|
||||
@@ -23,33 +22,13 @@ use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
// Types matching the incoming JSON structure
|
||||
// Direct message serialization for the chat request
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatRequest {
|
||||
messages: Vec<IncomingMessage>,
|
||||
messages: Vec<Message>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct IncomingMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
#[serde(default)]
|
||||
#[serde(rename = "toolInvocations")]
|
||||
tool_invocations: Vec<ToolInvocation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ToolInvocation {
|
||||
state: String,
|
||||
#[serde(rename = "toolCallId")]
|
||||
tool_call_id: String,
|
||||
#[serde(rename = "toolName")]
|
||||
tool_name: String,
|
||||
args: Value,
|
||||
result: Option<Vec<Content>>,
|
||||
}
|
||||
|
||||
// Custom SSE response type that implements the Vercel AI SDK protocol
|
||||
// Custom SSE response type for streaming messages
|
||||
pub struct SseResponse {
|
||||
rx: ReceiverStream<String>,
|
||||
}
|
||||
@@ -79,188 +58,32 @@ impl IntoResponse for SseResponse {
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache")
|
||||
.header("Connection", "keep-alive")
|
||||
.header("x-vercel-ai-data-stream", "v1")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
// Convert incoming messages to our internal Message type
|
||||
fn convert_messages(incoming: Vec<IncomingMessage>) -> Vec<Message> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
for msg in incoming {
|
||||
match msg.role.as_str() {
|
||||
"user" => {
|
||||
messages.push(Message::user().with_text(msg.content));
|
||||
}
|
||||
"assistant" => {
|
||||
// First handle any tool invocations - each represents a complete request/response cycle
|
||||
for tool in msg.tool_invocations {
|
||||
if tool.state == "result" {
|
||||
// Add the original tool request from assistant
|
||||
let tool_call = mcp_core::tool::ToolCall {
|
||||
name: tool.tool_name,
|
||||
arguments: tool.args,
|
||||
};
|
||||
messages.push(
|
||||
Message::assistant()
|
||||
.with_tool_request(tool.tool_call_id.clone(), Ok(tool_call)),
|
||||
);
|
||||
|
||||
// Add the tool response from user
|
||||
if let Some(result) = &tool.result {
|
||||
messages.push(
|
||||
Message::user()
|
||||
.with_tool_response(tool.tool_call_id, Ok(result.clone())),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Then add the assistant's text response after tool interactions
|
||||
if !msg.content.is_empty() {
|
||||
messages.push(Message::assistant().with_text(msg.content));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!("Unknown role: {}", msg.role);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages
|
||||
// Message event types for SSE streaming
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum MessageEvent {
|
||||
Message { message: Message },
|
||||
Error { error: String },
|
||||
Finish { reason: String },
|
||||
}
|
||||
|
||||
// Protocol-specific message formatting
|
||||
struct ProtocolFormatter;
|
||||
|
||||
impl ProtocolFormatter {
|
||||
fn format_text(text: &str) -> String {
|
||||
let encoded_text = serde_json::to_string(text).unwrap_or_else(|_| String::new());
|
||||
format!("0:{}\n", encoded_text)
|
||||
}
|
||||
|
||||
fn format_tool_call(id: &str, name: &str, args: &Value) -> String {
|
||||
// Tool calls start with "9:"
|
||||
let tool_call = json!({
|
||||
"toolCallId": id,
|
||||
"toolName": name,
|
||||
"args": args
|
||||
});
|
||||
format!("9:{}\n", tool_call)
|
||||
}
|
||||
|
||||
fn format_tool_response(id: &str, result: &Vec<Content>) -> String {
|
||||
// Tool responses start with "a:"
|
||||
let response = json!({
|
||||
"toolCallId": id,
|
||||
"result": result,
|
||||
});
|
||||
format!("a:{}\n", response)
|
||||
}
|
||||
|
||||
fn format_error(error: &str) -> String {
|
||||
// Error messages start with "3:" in the new protocol.
|
||||
let encoded_error = serde_json::to_string(error).unwrap_or_else(|_| String::new());
|
||||
format!("3:{}\n", encoded_error)
|
||||
}
|
||||
|
||||
fn format_finish(reason: &str) -> String {
|
||||
// Finish messages start with "d:"
|
||||
let finish = json!({
|
||||
"finishReason": reason,
|
||||
"usage": {
|
||||
"promptTokens": 0,
|
||||
"completionTokens": 0
|
||||
}
|
||||
});
|
||||
format!("d:{}\n", finish)
|
||||
}
|
||||
}
|
||||
|
||||
async fn stream_message(
|
||||
message: Message,
|
||||
// Stream a message as an SSE event
|
||||
async fn stream_event(
|
||||
event: MessageEvent,
|
||||
tx: &mpsc::Sender<String>,
|
||||
) -> Result<(), mpsc::error::SendError<String>> {
|
||||
match message.role {
|
||||
Role::User => {
|
||||
// Handle tool responses
|
||||
for content in message.content {
|
||||
// I believe with the protocol we aren't intended to pass back user messages, so we only deal with
|
||||
// the tool responses here
|
||||
if let MessageContent::ToolResponse(response) = content {
|
||||
// We should return a result for either an error or a success
|
||||
match response.tool_result {
|
||||
Ok(result) => {
|
||||
tx.send(ProtocolFormatter::format_tool_response(
|
||||
&response.id,
|
||||
&result,
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
Err(err) => {
|
||||
// Send an error message first
|
||||
tx.send(ProtocolFormatter::format_error(&err.to_string()))
|
||||
.await?;
|
||||
// Then send an empty tool response to maintain the protocol
|
||||
let result =
|
||||
vec![Content::text(format!("Error: {}", err)).with_priority(0.0)];
|
||||
tx.send(ProtocolFormatter::format_tool_response(
|
||||
&response.id,
|
||||
&result,
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Role::Assistant => {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::ToolRequest(request) => {
|
||||
match request.tool_call {
|
||||
Ok(tool_call) => {
|
||||
tx.send(ProtocolFormatter::format_tool_call(
|
||||
&request.id,
|
||||
&tool_call.name,
|
||||
&tool_call.arguments,
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
Err(err) => {
|
||||
// Send a placeholder tool call to maintain protocol
|
||||
tx.send(ProtocolFormatter::format_tool_call(
|
||||
&request.id,
|
||||
"invalid_tool",
|
||||
&json!({"error": err.to_string()}),
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
MessageContent::Text(text) => {
|
||||
for line in text.text.lines() {
|
||||
let modified_line = format!("{}\n", line);
|
||||
tx.send(ProtocolFormatter::format_text(&modified_line))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
MessageContent::ToolConfirmationRequest(_) => {
|
||||
// skip tool confirmation requests
|
||||
}
|
||||
MessageContent::Image(_) => {
|
||||
// skip images
|
||||
}
|
||||
MessageContent::ToolResponse(_) => {
|
||||
// skip tool responses
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
let json = serde_json::to_string(&event).unwrap_or_else(|e| {
|
||||
format!(
|
||||
r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#,
|
||||
e
|
||||
)
|
||||
});
|
||||
tx.send(format!("data: {}\n\n", json)).await
|
||||
}
|
||||
|
||||
async fn handler(
|
||||
@@ -278,19 +101,12 @@ async fn handler(
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
// Check protocol header (optional in our case)
|
||||
if let Some(protocol) = headers.get("x-protocol") {
|
||||
if protocol.to_str().map(|p| p != "data").unwrap_or(true) {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
}
|
||||
|
||||
// Create channel for streaming
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
let stream = ReceiverStream::new(rx);
|
||||
|
||||
// Convert incoming messages
|
||||
let messages = convert_messages(request.messages);
|
||||
// Get messages directly from the request
|
||||
let messages = request.messages;
|
||||
|
||||
// Get a lock on the shared agent
|
||||
let agent = state.agent.clone();
|
||||
@@ -301,10 +117,20 @@ async fn handler(
|
||||
let agent = match agent.as_ref() {
|
||||
Some(agent) => agent,
|
||||
None => {
|
||||
let _ = tx
|
||||
.send(ProtocolFormatter::format_error("No agent configured"))
|
||||
.await;
|
||||
let _ = tx.send(ProtocolFormatter::format_finish("error")).await;
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: "No agent configured".to_string(),
|
||||
},
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
let _ = stream_event(
|
||||
MessageEvent::Finish {
|
||||
reason: "error".to_string(),
|
||||
},
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -313,10 +139,20 @@ async fn handler(
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to start reply stream: {:?}", e);
|
||||
let _ = tx
|
||||
.send(ProtocolFormatter::format_error(&e.to_string()))
|
||||
.await;
|
||||
let _ = tx.send(ProtocolFormatter::format_finish("error")).await;
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
let _ = stream_event(
|
||||
MessageEvent::Finish {
|
||||
reason: "error".to_string(),
|
||||
},
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -326,25 +162,32 @@ async fn handler(
|
||||
response = timeout(Duration::from_millis(500), stream.next()) => {
|
||||
match response {
|
||||
Ok(Some(Ok(message))) => {
|
||||
if let Err(e) = stream_message(message, &tx).await {
|
||||
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
|
||||
tracing::error!("Error sending message through channel: {}", e);
|
||||
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(Some(Err(e))) => {
|
||||
tracing::error!("Error processing message: {}", e);
|
||||
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
break;
|
||||
}
|
||||
Ok(None) => {
|
||||
break;
|
||||
}
|
||||
Err(_) => { // Heartbeat, used to detect disconnected clients and then end running tools.
|
||||
Err(_) => { // Heartbeat, used to detect disconnected clients
|
||||
if tx.is_closed() {
|
||||
// Kill any running processes when the client disconnects
|
||||
// TODO is this used? I suspect post MCP this is on the server instead
|
||||
// goose::process_store::kill_processes();
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
@@ -354,24 +197,30 @@ async fn handler(
|
||||
}
|
||||
}
|
||||
|
||||
// Send finish message
|
||||
let _ = tx.send(ProtocolFormatter::format_finish("stop")).await;
|
||||
// Send finish event
|
||||
let _ = stream_event(
|
||||
MessageEvent::Finish {
|
||||
reason: "stop".to_string(),
|
||||
},
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(SseResponse::new(stream))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, serde::Serialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct AskRequest {
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AskResponse {
|
||||
response: String,
|
||||
}
|
||||
|
||||
// simple ask an AI for a response, non streaming
|
||||
// Simple ask an AI for a response, non streaming
|
||||
async fn ask_handler(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
@@ -478,85 +327,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_messages_user_only() {
|
||||
let incoming = vec![IncomingMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Hello".to_string(),
|
||||
tool_invocations: vec![],
|
||||
}];
|
||||
|
||||
let messages = convert_messages(incoming);
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].role, Role::User);
|
||||
assert!(
|
||||
matches!(&messages[0].content[0], MessageContent::Text(text) if text.text == "Hello")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_messages_with_tool_invocation() {
|
||||
let tool_result = vec![Content::text("tool response").with_priority(0.0)];
|
||||
let incoming = vec![IncomingMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: "".to_string(),
|
||||
tool_invocations: vec![ToolInvocation {
|
||||
state: "result".to_string(),
|
||||
tool_call_id: "123".to_string(),
|
||||
tool_name: "test_tool".to_string(),
|
||||
args: json!({"key": "value"}),
|
||||
result: Some(tool_result.clone()),
|
||||
}],
|
||||
}];
|
||||
|
||||
let messages = convert_messages(incoming);
|
||||
assert_eq!(messages.len(), 2); // Tool request and response
|
||||
|
||||
// Check tool request
|
||||
assert_eq!(messages[0].role, Role::Assistant);
|
||||
assert!(
|
||||
matches!(&messages[0].content[0], MessageContent::ToolRequest(req) if req.id == "123")
|
||||
);
|
||||
|
||||
// Check tool response
|
||||
assert_eq!(messages[1].role, Role::User);
|
||||
assert!(
|
||||
matches!(&messages[1].content[0], MessageContent::ToolResponse(resp) if resp.id == "123")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_protocol_formatter() {
|
||||
// Test text formatting
|
||||
let text = "Hello world";
|
||||
let formatted = ProtocolFormatter::format_text(text);
|
||||
assert_eq!(formatted, "0:\"Hello world\"\n");
|
||||
|
||||
// Test tool call formatting
|
||||
let formatted =
|
||||
ProtocolFormatter::format_tool_call("123", "test_tool", &json!({"key": "value"}));
|
||||
assert!(formatted.starts_with("9:"));
|
||||
assert!(formatted.contains("\"toolCallId\":\"123\""));
|
||||
assert!(formatted.contains("\"toolName\":\"test_tool\""));
|
||||
|
||||
// Test tool response formatting
|
||||
let result = vec![Content::text("response").with_priority(0.0)];
|
||||
let formatted = ProtocolFormatter::format_tool_response("123", &result);
|
||||
assert!(formatted.starts_with("a:"));
|
||||
assert!(formatted.contains("\"toolCallId\":\"123\""));
|
||||
|
||||
// Test error formatting
|
||||
let formatted = ProtocolFormatter::format_error("Test error");
|
||||
println!("Formatted error: {}", formatted);
|
||||
assert!(formatted.starts_with("3:"));
|
||||
assert!(formatted.contains("Test error"));
|
||||
|
||||
// Test finish formatting
|
||||
let formatted = ProtocolFormatter::format_finish("stop");
|
||||
assert!(formatted.starts_with("d:"));
|
||||
assert!(formatted.contains("\"finishReason\":\"stop\""));
|
||||
}
|
||||
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
use axum::{body::Body, http::Request};
|
||||
@@ -575,7 +345,7 @@ mod tests {
|
||||
});
|
||||
let agent = AgentFactory::create("reference", mock_provider).unwrap();
|
||||
let state = AppState {
|
||||
config: Arc::new(Mutex::new(HashMap::new())), // Add this line
|
||||
config: Arc::new(Mutex::new(HashMap::new())),
|
||||
agent: Arc::new(Mutex::new(Some(agent))),
|
||||
secret_key: "test-secret".to_string(),
|
||||
};
|
||||
|
||||
@@ -14,19 +14,26 @@ use mcp_core::role::Role;
|
||||
use mcp_core::tool::ToolCall;
|
||||
use serde_json::Value;
|
||||
|
||||
mod tool_result_serde;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolRequest {
|
||||
pub id: String,
|
||||
#[serde(with = "tool_result_serde")]
|
||||
pub tool_call: ToolResult<ToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolResponse {
|
||||
pub id: String,
|
||||
#[serde(with = "tool_result_serde")]
|
||||
pub tool_result: ToolResult<Vec<Content>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolConfirmationRequest {
|
||||
pub id: String,
|
||||
pub tool_name: String,
|
||||
@@ -36,6 +43,7 @@ pub struct ToolConfirmationRequest {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
/// Content passed inside a message, which can be both simple content and tool content
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum MessageContent {
|
||||
Text(TextContent),
|
||||
Image(ImageContent),
|
||||
@@ -150,6 +158,7 @@ impl From<Content> for MessageContent {
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
/// A message to or from an LLM
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub created: i64,
|
||||
@@ -292,3 +301,123 @@ impl Message {
|
||||
.all(|c| matches!(c, MessageContent::Text(_)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mcp_core::handler::ToolError;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[test]
|
||||
fn test_message_serialization() {
|
||||
let message = Message::assistant()
|
||||
.with_text("Hello, I'll help you with that.")
|
||||
.with_tool_request(
|
||||
"tool123",
|
||||
Ok(ToolCall::new("test_tool", json!({"param": "value"}))),
|
||||
);
|
||||
|
||||
let json_str = serde_json::to_string_pretty(&message).unwrap();
|
||||
println!("Serialized message: {}", json_str);
|
||||
|
||||
// Parse back to Value to check structure
|
||||
let value: Value = serde_json::from_str(&json_str).unwrap();
|
||||
|
||||
// Check top-level fields
|
||||
assert_eq!(value["role"], "assistant");
|
||||
assert!(value["created"].is_i64());
|
||||
assert!(value["content"].is_array());
|
||||
|
||||
// Check content items
|
||||
let content = &value["content"];
|
||||
|
||||
// First item should be text
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
|
||||
|
||||
// Second item should be toolRequest
|
||||
assert_eq!(content[1]["type"], "toolRequest");
|
||||
assert_eq!(content[1]["id"], "tool123");
|
||||
|
||||
// Check tool_call serialization
|
||||
assert_eq!(content[1]["toolCall"]["status"], "success");
|
||||
assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
|
||||
assert_eq!(
|
||||
content[1]["toolCall"]["value"]["arguments"]["param"],
|
||||
"value"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_serialization() {
|
||||
let message = Message::assistant().with_tool_request(
|
||||
"tool123",
|
||||
Err(ToolError::ExecutionError(
|
||||
"Something went wrong".to_string(),
|
||||
)),
|
||||
);
|
||||
|
||||
let json_str = serde_json::to_string_pretty(&message).unwrap();
|
||||
println!("Serialized error: {}", json_str);
|
||||
|
||||
// Parse back to Value to check structure
|
||||
let value: Value = serde_json::from_str(&json_str).unwrap();
|
||||
|
||||
// Check tool_call serialization with error
|
||||
let tool_call = &value["content"][0]["toolCall"];
|
||||
assert_eq!(tool_call["status"], "error");
|
||||
assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialization() {
|
||||
// Create a JSON string with our new format
|
||||
let json_str = r#"{
|
||||
"role": "assistant",
|
||||
"created": 1740171566,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "I'll help you with that."
|
||||
},
|
||||
{
|
||||
"type": "toolRequest",
|
||||
"id": "tool123",
|
||||
"toolCall": {
|
||||
"status": "success",
|
||||
"value": {
|
||||
"name": "test_tool",
|
||||
"arguments": {"param": "value"}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let message: Message = serde_json::from_str(json_str).unwrap();
|
||||
|
||||
assert_eq!(message.role, Role::Assistant);
|
||||
assert_eq!(message.created, 1740171566);
|
||||
assert_eq!(message.content.len(), 2);
|
||||
|
||||
// Check first content item
|
||||
if let MessageContent::Text(text) = &message.content[0] {
|
||||
assert_eq!(text.text, "I'll help you with that.");
|
||||
} else {
|
||||
panic!("Expected Text content");
|
||||
}
|
||||
|
||||
// Check second content item
|
||||
if let MessageContent::ToolRequest(req) = &message.content[1] {
|
||||
assert_eq!(req.id, "tool123");
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
assert_eq!(tool_call.name, "test_tool");
|
||||
assert_eq!(tool_call.arguments, json!({"param": "value"}));
|
||||
} else {
|
||||
panic!("Expected successful tool call");
|
||||
}
|
||||
} else {
|
||||
panic!("Expected ToolRequest content");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
64
crates/goose/src/message/tool_result_serde.rs
Normal file
64
crates/goose/src/message/tool_result_serde.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use mcp_core::handler::{ToolError, ToolResult};
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
pub fn serialize<T, S>(value: &ToolResult<T>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
T: Serialize,
|
||||
S: Serializer,
|
||||
{
|
||||
match value {
|
||||
Ok(val) => {
|
||||
let mut state = serializer.serialize_struct("ToolResult", 2)?;
|
||||
state.serialize_field("status", "success")?;
|
||||
state.serialize_field("value", val)?;
|
||||
state.end()
|
||||
}
|
||||
Err(err) => {
|
||||
let mut state = serializer.serialize_struct("ToolResult", 2)?;
|
||||
state.serialize_field("status", "error")?;
|
||||
state.serialize_field("error", &err.to_string())?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For deserialization, let's use a simpler approach that works with the format we're serializing to
|
||||
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<ToolResult<T>, D::Error>
|
||||
where
|
||||
T: Deserialize<'de>,
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
// Define a helper enum to handle the two possible formats
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ResultFormat<T> {
|
||||
Success { status: String, value: T },
|
||||
Error { status: String, error: String },
|
||||
}
|
||||
|
||||
let format = ResultFormat::deserialize(deserializer)?;
|
||||
|
||||
match format {
|
||||
ResultFormat::Success { status, value } => {
|
||||
if status == "success" {
|
||||
Ok(Ok(value))
|
||||
} else {
|
||||
Err(serde::de::Error::custom(format!(
|
||||
"Expected status 'success', got '{}'",
|
||||
status
|
||||
)))
|
||||
}
|
||||
}
|
||||
ResultFormat::Error { status, error } => {
|
||||
if status == "error" {
|
||||
Ok(Err(ToolError::ExecutionError(error)))
|
||||
} else {
|
||||
Err(serde::de::Error::custom(format!(
|
||||
"Expected status 'error', got '{}'",
|
||||
status
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
# AI
|
||||
|
||||
This is a small fork of some files in the Vercel AI SDK to make a custom version of useChat which doesn't append text content to messages.
|
||||
|
||||
We can work to surface our desired functionality and upstream a change that could support it.
|
||||
@@ -1,92 +0,0 @@
|
||||
import { processCustomChatResponse } from './process-custom-chat-response';
|
||||
import { IdGenerator, JSONValue, Message, UseChatOptions } from '@ai-sdk/ui-utils';
|
||||
|
||||
// use function to allow for mocking in tests:
|
||||
const getOriginalFetch = () => fetch;
|
||||
|
||||
export async function callCustomChatApi({
|
||||
api,
|
||||
body,
|
||||
streamProtocol = 'data',
|
||||
credentials,
|
||||
headers,
|
||||
abortController,
|
||||
restoreMessagesOnFailure,
|
||||
onResponse,
|
||||
onUpdate,
|
||||
onFinish,
|
||||
onToolCall,
|
||||
generateId,
|
||||
fetch = getOriginalFetch(),
|
||||
}: {
|
||||
api: string;
|
||||
body: Record<string, any>;
|
||||
streamProtocol: 'data' | 'text' | undefined;
|
||||
credentials: RequestCredentials | undefined;
|
||||
headers: HeadersInit | undefined;
|
||||
abortController: (() => AbortController | null) | undefined;
|
||||
restoreMessagesOnFailure: () => void;
|
||||
onResponse: ((response: Response) => void | Promise<void>) | undefined;
|
||||
onUpdate: (newMessages: Message[], data: JSONValue[] | undefined) => void;
|
||||
onFinish: UseChatOptions['onFinish'];
|
||||
onToolCall: UseChatOptions['onToolCall'];
|
||||
generateId: IdGenerator;
|
||||
fetch: ReturnType<typeof getOriginalFetch> | undefined;
|
||||
}) {
|
||||
const response = await fetch(api, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(body),
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...headers,
|
||||
},
|
||||
signal: abortController?.()?.signal,
|
||||
credentials,
|
||||
}).catch((err) => {
|
||||
restoreMessagesOnFailure();
|
||||
throw err;
|
||||
});
|
||||
|
||||
if (onResponse) {
|
||||
try {
|
||||
await onResponse(response);
|
||||
} catch (err) {
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
restoreMessagesOnFailure();
|
||||
throw new Error((await response.text()) ?? 'Failed to fetch the chat response.');
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('The response body is empty.');
|
||||
}
|
||||
|
||||
switch (streamProtocol) {
|
||||
case 'text': {
|
||||
throw new Error('Text protocol not supported in custom chat API');
|
||||
}
|
||||
|
||||
case 'data': {
|
||||
await processCustomChatResponse({
|
||||
stream: response.body,
|
||||
update: onUpdate,
|
||||
onToolCall,
|
||||
onFinish({ message, finishReason, usage }) {
|
||||
if (onFinish && message != null) {
|
||||
onFinish(message, { usage, finishReason });
|
||||
}
|
||||
},
|
||||
generateId,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
default: {
|
||||
const exhaustiveCheck: never = streamProtocol;
|
||||
throw new Error(`Unknown stream protocol: ${exhaustiveCheck}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
export interface LanguageModelUsage {
|
||||
completionTokens: number;
|
||||
promptTokens: number;
|
||||
totalTokens: number;
|
||||
}
|
||||
|
||||
export function calculateLanguageModelUsage(usage: LanguageModelUsage): LanguageModelUsage {
|
||||
return {
|
||||
completionTokens: usage.completionTokens,
|
||||
promptTokens: usage.promptTokens,
|
||||
totalTokens: usage.totalTokens,
|
||||
};
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
import { generateId as generateIdFunction } from '@ai-sdk/provider-utils';
|
||||
import type { JSONValue, Message } from '@ai-sdk/ui-utils';
|
||||
import { parsePartialJson, processDataStream } from '@ai-sdk/ui-utils';
|
||||
import { LanguageModelV1FinishReason } from '@ai-sdk/provider';
|
||||
import { LanguageModelUsage } from './core/types/usage';
|
||||
|
||||
// Simple usage calculation since we don't have access to the original
|
||||
function calculateLanguageModelUsage(usage: LanguageModelUsage): LanguageModelUsage {
|
||||
return {
|
||||
completionTokens: usage.completionTokens,
|
||||
promptTokens: usage.promptTokens,
|
||||
totalTokens: usage.totalTokens,
|
||||
};
|
||||
}
|
||||
|
||||
export async function processCustomChatResponse({
|
||||
stream,
|
||||
update,
|
||||
onToolCall,
|
||||
onFinish,
|
||||
generateId = generateIdFunction,
|
||||
getCurrentDate = () => new Date(),
|
||||
}: {
|
||||
stream: ReadableStream<Uint8Array>;
|
||||
update: (newMessages: Message[], data: JSONValue[] | undefined) => void;
|
||||
onToolCall?: (options: { toolCall: any }) => Promise<any>;
|
||||
onFinish?: (options: {
|
||||
message: Message | undefined;
|
||||
finishReason: LanguageModelV1FinishReason;
|
||||
usage: LanguageModelUsage;
|
||||
}) => void;
|
||||
generateId?: () => string;
|
||||
getCurrentDate?: () => Date;
|
||||
}) {
|
||||
const createdAt = getCurrentDate();
|
||||
let currentMessage: Message | undefined = undefined;
|
||||
const previousMessages: Message[] = [];
|
||||
const data: JSONValue[] = [];
|
||||
let lastEventType: 'text' | 'tool' | undefined = undefined;
|
||||
|
||||
// Keep track of partial tool calls
|
||||
const partialToolCalls: Record<string, { text: string; index: number; toolName: string }> = {};
|
||||
|
||||
let usage: LanguageModelUsage = {
|
||||
completionTokens: NaN,
|
||||
promptTokens: NaN,
|
||||
totalTokens: NaN,
|
||||
};
|
||||
let finishReason: LanguageModelV1FinishReason = 'unknown';
|
||||
|
||||
function execUpdate() {
|
||||
const copiedData = [...data];
|
||||
if (currentMessage == null) {
|
||||
update(previousMessages, copiedData);
|
||||
return;
|
||||
}
|
||||
|
||||
const copiedMessage = {
|
||||
...JSON.parse(JSON.stringify(currentMessage)),
|
||||
revisionId: generateId(),
|
||||
} as Message;
|
||||
|
||||
update([...previousMessages, copiedMessage], copiedData);
|
||||
}
|
||||
|
||||
// Create a new message only if needed
|
||||
function createNewMessage(): Message {
|
||||
if (currentMessage == null) {
|
||||
currentMessage = {
|
||||
id: generateId(),
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
createdAt,
|
||||
};
|
||||
}
|
||||
return currentMessage;
|
||||
}
|
||||
|
||||
// Move the current message to previous messages if it exists
|
||||
function archiveCurrentMessage() {
|
||||
if (currentMessage != null) {
|
||||
previousMessages.push(currentMessage);
|
||||
currentMessage = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
await processDataStream({
|
||||
stream,
|
||||
onTextPart(value) {
|
||||
// If the last event wasn't text, or we don't have a current message, create a new one
|
||||
if (lastEventType !== 'text' || currentMessage == null) {
|
||||
// Only archive if there are no tool invocations in 'call' state
|
||||
const hasPendingToolCalls =
|
||||
currentMessage?.toolInvocations?.some((invocation) => invocation.state === 'call') ??
|
||||
false;
|
||||
|
||||
if (!hasPendingToolCalls) {
|
||||
archiveCurrentMessage();
|
||||
}
|
||||
|
||||
if (!currentMessage) {
|
||||
currentMessage = createNewMessage();
|
||||
}
|
||||
currentMessage.content = value;
|
||||
} else {
|
||||
// Concatenate with the existing message
|
||||
currentMessage.content += value;
|
||||
}
|
||||
lastEventType = 'text';
|
||||
execUpdate();
|
||||
},
|
||||
onToolCallStreamingStartPart(value) {
|
||||
// Always create a new message for tool calls
|
||||
archiveCurrentMessage();
|
||||
currentMessage = createNewMessage();
|
||||
lastEventType = 'tool';
|
||||
|
||||
if (currentMessage.toolInvocations == null) {
|
||||
currentMessage.toolInvocations = [];
|
||||
}
|
||||
|
||||
partialToolCalls[value.toolCallId] = {
|
||||
text: '',
|
||||
toolName: value.toolName,
|
||||
index: currentMessage.toolInvocations.length,
|
||||
};
|
||||
|
||||
currentMessage.toolInvocations.push({
|
||||
state: 'partial-call',
|
||||
toolCallId: value.toolCallId,
|
||||
toolName: value.toolName,
|
||||
args: undefined,
|
||||
});
|
||||
|
||||
execUpdate();
|
||||
},
|
||||
onToolCallDeltaPart(value) {
|
||||
if (!currentMessage) {
|
||||
currentMessage = createNewMessage();
|
||||
}
|
||||
lastEventType = 'tool';
|
||||
|
||||
const partialToolCall = partialToolCalls[value.toolCallId];
|
||||
partialToolCall.text += value.argsTextDelta;
|
||||
|
||||
const { value: partialArgs } = parsePartialJson(partialToolCall.text);
|
||||
|
||||
currentMessage.toolInvocations![partialToolCall.index] = {
|
||||
state: 'partial-call',
|
||||
toolCallId: value.toolCallId,
|
||||
toolName: partialToolCall.toolName,
|
||||
args: partialArgs,
|
||||
};
|
||||
|
||||
execUpdate();
|
||||
},
|
||||
async onToolCallPart(value) {
|
||||
if (!currentMessage) {
|
||||
currentMessage = createNewMessage();
|
||||
}
|
||||
lastEventType = 'tool';
|
||||
|
||||
if (partialToolCalls[value.toolCallId] != null) {
|
||||
currentMessage.toolInvocations![partialToolCalls[value.toolCallId].index] = {
|
||||
state: 'call',
|
||||
...value,
|
||||
};
|
||||
} else {
|
||||
if (currentMessage.toolInvocations == null) {
|
||||
currentMessage.toolInvocations = [];
|
||||
}
|
||||
|
||||
currentMessage.toolInvocations.push({
|
||||
state: 'call',
|
||||
...value,
|
||||
});
|
||||
}
|
||||
|
||||
if (onToolCall) {
|
||||
const result = await onToolCall({ toolCall: value });
|
||||
if (result != null) {
|
||||
currentMessage.toolInvocations![currentMessage.toolInvocations!.length - 1] = {
|
||||
state: 'result',
|
||||
...value,
|
||||
result,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
execUpdate();
|
||||
},
|
||||
onToolResultPart(value) {
|
||||
if (!currentMessage) {
|
||||
currentMessage = createNewMessage();
|
||||
}
|
||||
lastEventType = 'tool';
|
||||
|
||||
const toolInvocations = currentMessage.toolInvocations;
|
||||
if (toolInvocations == null) {
|
||||
throw new Error('tool_result must be preceded by a tool_call');
|
||||
}
|
||||
|
||||
const toolInvocationIndex = toolInvocations.findIndex(
|
||||
(invocation) => invocation.toolCallId === value.toolCallId
|
||||
);
|
||||
|
||||
if (toolInvocationIndex === -1) {
|
||||
throw new Error('tool_result must be preceded by a tool_call with the same toolCallId');
|
||||
}
|
||||
|
||||
toolInvocations[toolInvocationIndex] = {
|
||||
...toolInvocations[toolInvocationIndex],
|
||||
state: 'result' as const,
|
||||
...value,
|
||||
};
|
||||
|
||||
execUpdate();
|
||||
},
|
||||
onDataPart(value) {
|
||||
data.push(...value);
|
||||
execUpdate();
|
||||
},
|
||||
onFinishStepPart() {
|
||||
// Archive the current message when a step finishes
|
||||
archiveCurrentMessage();
|
||||
},
|
||||
onFinishMessagePart(value) {
|
||||
finishReason = value.finishReason;
|
||||
if (value.usage != null) {
|
||||
usage = calculateLanguageModelUsage(value.usage);
|
||||
}
|
||||
},
|
||||
onErrorPart(error) {
|
||||
throw new Error(error);
|
||||
},
|
||||
});
|
||||
|
||||
onFinish?.({ message: currentMessage, finishReason, usage });
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
import throttleFunction from 'throttleit';
|
||||
|
||||
export function throttle<T extends (...args: any[]) => any>(fn: T, waitMs: number | undefined): T {
|
||||
return waitMs != null ? throttleFunction(fn, waitMs) : fn;
|
||||
}
|
||||
@@ -1,611 +0,0 @@
|
||||
import { FetchFunction } from '@ai-sdk/provider-utils';
|
||||
import type {
|
||||
Attachment,
|
||||
ChatRequest,
|
||||
ChatRequestOptions,
|
||||
CreateMessage,
|
||||
IdGenerator,
|
||||
JSONValue,
|
||||
Message,
|
||||
UseChatOptions,
|
||||
} from '@ai-sdk/ui-utils';
|
||||
import { generateId as generateIdFunc } from '@ai-sdk/ui-utils';
|
||||
import { callCustomChatApi as callChatApi } from './call-custom-chat-api';
|
||||
import { useCallback, useEffect, useId, useRef, useState } from 'react';
|
||||
import useSWR, { KeyedMutator } from 'swr';
|
||||
import { throttle } from './throttle';
|
||||
import { getSecretKey } from '../config';
|
||||
|
||||
export type { CreateMessage, Message, UseChatOptions };
|
||||
|
||||
export type UseChatHelpers = {
|
||||
/** Current messages in the chat */
|
||||
messages: Message[];
|
||||
/** The error object of the API request */
|
||||
error: undefined | Error;
|
||||
/**
|
||||
* Append a user message to the chat list. This triggers the API call to fetch
|
||||
* the assistant's response.
|
||||
* @param message The message to append
|
||||
* @param options Additional options to pass to the API call
|
||||
*/
|
||||
append: (
|
||||
message: Message | CreateMessage,
|
||||
chatRequestOptions?: ChatRequestOptions
|
||||
) => Promise<string | null | undefined>;
|
||||
/**
|
||||
* Reload the last AI chat response for the given chat history. If the last
|
||||
* message isn't from the assistant, it will request the API to generate a
|
||||
* new response.
|
||||
*/
|
||||
reload: (chatRequestOptions?: ChatRequestOptions) => Promise<string | null | undefined>;
|
||||
/**
|
||||
* Abort the current request immediately, keep the generated tokens if any.
|
||||
*/
|
||||
stop: () => void;
|
||||
/**
|
||||
* Update the `messages` state locally. This is useful when you want to
|
||||
* edit the messages on the client, and then trigger the `reload` method
|
||||
* manually to regenerate the AI response.
|
||||
*/
|
||||
setMessages: (messages: Message[] | ((messages: Message[]) => Message[])) => void;
|
||||
/** The current value of the input */
|
||||
input: string;
|
||||
/** setState-powered method to update the input value */
|
||||
setInput: React.Dispatch<React.SetStateAction<string>>;
|
||||
/** An input/textarea-ready onChange handler to control the value of the input */
|
||||
handleInputChange: (
|
||||
e: React.ChangeEvent<HTMLInputElement> | React.ChangeEvent<HTMLTextAreaElement>
|
||||
) => void;
|
||||
/** Form submission handler to automatically reset input and append a user message */
|
||||
handleSubmit: (
|
||||
event?: { preventDefault?: () => void },
|
||||
chatRequestOptions?: ChatRequestOptions
|
||||
) => void;
|
||||
metadata?: object;
|
||||
/** Whether the API request is in progress */
|
||||
isLoading: boolean;
|
||||
|
||||
/** Additional data added on the server via StreamData. */
|
||||
data?: JSONValue[];
|
||||
/** Set the data of the chat. You can use this to transform or clear the chat data. */
|
||||
setData: (
|
||||
data: JSONValue[] | undefined | ((data: JSONValue[] | undefined) => JSONValue[] | undefined)
|
||||
) => void;
|
||||
};
|
||||
|
||||
const processResponseStream = async (
|
||||
api: string,
|
||||
chatRequest: ChatRequest,
|
||||
mutate: KeyedMutator<Message[]>,
|
||||
mutateStreamData: KeyedMutator<JSONValue[] | undefined>,
|
||||
existingDataRef: React.MutableRefObject<JSONValue[] | undefined>,
|
||||
extraMetadataRef: React.MutableRefObject<any>,
|
||||
messagesRef: React.MutableRefObject<Message[]>,
|
||||
abortControllerRef: React.MutableRefObject<AbortController | null>,
|
||||
generateId: IdGenerator,
|
||||
streamProtocol: UseChatOptions['streamProtocol'],
|
||||
onFinish: UseChatOptions['onFinish'],
|
||||
onResponse: ((response: Response) => void | Promise<void>) | undefined,
|
||||
onToolCall: UseChatOptions['onToolCall'] | undefined,
|
||||
sendExtraMessageFields: boolean | undefined,
|
||||
experimental_prepareRequestBody:
|
||||
| ((options: {
|
||||
messages: Message[];
|
||||
requestData?: JSONValue;
|
||||
requestBody?: object;
|
||||
}) => JSONValue)
|
||||
| undefined,
|
||||
fetch: FetchFunction | undefined,
|
||||
keepLastMessageOnError: boolean
|
||||
) => {
|
||||
// Do an optimistic update to the chat state to show the updated messages immediately:
|
||||
const previousMessages = messagesRef.current;
|
||||
mutate(chatRequest.messages, false);
|
||||
|
||||
const constructedMessagesPayload = sendExtraMessageFields
|
||||
? chatRequest.messages
|
||||
: chatRequest.messages.map(
|
||||
({ role, content, experimental_attachments, data, annotations, toolInvocations }) => ({
|
||||
role,
|
||||
content,
|
||||
...(experimental_attachments !== undefined && {
|
||||
experimental_attachments,
|
||||
}),
|
||||
...(data !== undefined && { data }),
|
||||
...(annotations !== undefined && { annotations }),
|
||||
...(toolInvocations !== undefined && { toolInvocations }),
|
||||
})
|
||||
);
|
||||
|
||||
const existingData = existingDataRef.current;
|
||||
|
||||
return await callChatApi({
|
||||
api,
|
||||
body: experimental_prepareRequestBody?.({
|
||||
messages: chatRequest.messages,
|
||||
requestData: chatRequest.data,
|
||||
requestBody: chatRequest.body,
|
||||
}) ?? {
|
||||
messages: constructedMessagesPayload,
|
||||
data: chatRequest.data,
|
||||
...extraMetadataRef.current.body,
|
||||
...chatRequest.body,
|
||||
},
|
||||
streamProtocol,
|
||||
credentials: extraMetadataRef.current.credentials,
|
||||
headers: {
|
||||
...extraMetadataRef.current.headers,
|
||||
...chatRequest.headers,
|
||||
'X-Secret-Key': getSecretKey(),
|
||||
},
|
||||
abortController: () => abortControllerRef.current,
|
||||
restoreMessagesOnFailure() {
|
||||
if (!keepLastMessageOnError) {
|
||||
mutate(previousMessages, false);
|
||||
}
|
||||
},
|
||||
onResponse,
|
||||
onUpdate(merged, data) {
|
||||
mutate([...chatRequest.messages, ...merged], false);
|
||||
if (data?.length) {
|
||||
mutateStreamData([...(existingData ?? []), ...data], false);
|
||||
}
|
||||
},
|
||||
onToolCall,
|
||||
onFinish,
|
||||
generateId,
|
||||
fetch,
|
||||
});
|
||||
};
|
||||
|
||||
export function useChat({
|
||||
api = '/api/chat',
|
||||
id,
|
||||
initialMessages,
|
||||
initialInput = '',
|
||||
sendExtraMessageFields,
|
||||
onToolCall,
|
||||
experimental_prepareRequestBody,
|
||||
maxSteps = 1,
|
||||
streamProtocol = 'data',
|
||||
onResponse,
|
||||
onFinish,
|
||||
onError,
|
||||
credentials,
|
||||
headers,
|
||||
body,
|
||||
generateId = generateIdFunc,
|
||||
fetch,
|
||||
keepLastMessageOnError = true,
|
||||
experimental_throttle: throttleWaitMs,
|
||||
}: UseChatOptions & {
|
||||
key?: string;
|
||||
|
||||
/**
|
||||
* Experimental (React only). When a function is provided, it will be used
|
||||
* to prepare the request body for the chat API. This can be useful for
|
||||
* customizing the request body based on the messages and data in the chat.
|
||||
*
|
||||
* @param messages The current messages in the chat.
|
||||
* @param requestData The data object passed in the chat request.
|
||||
* @param requestBody The request body object passed in the chat request.
|
||||
*/
|
||||
experimental_prepareRequestBody?: (options: {
|
||||
messages: Message[];
|
||||
requestData?: JSONValue;
|
||||
requestBody?: object;
|
||||
}) => JSONValue;
|
||||
|
||||
/**
|
||||
Custom throttle wait in ms for the chat messages and data updates.
|
||||
Default is undefined, which disables throttling.
|
||||
*/
|
||||
experimental_throttle?: number;
|
||||
|
||||
/**
|
||||
Maximum number of sequential LLM calls (steps), e.g. when you use tool calls. Must be at least 1.
|
||||
|
||||
A maximum number is required to prevent infinite loops in the case of misconfigured tools.
|
||||
|
||||
By default, it's set to 1, which means that only a single LLM call is made.
|
||||
*/
|
||||
maxSteps?: number;
|
||||
} = {}): UseChatHelpers & {
|
||||
addToolResult: ({ toolCallId, result }: { toolCallId: string; result: any }) => void;
|
||||
} {
|
||||
// Generate a unique id for the chat if not provided.
|
||||
const hookId = useId();
|
||||
const idKey = id ?? hookId;
|
||||
const chatKey = typeof api === 'string' ? [api, idKey] : idKey;
|
||||
|
||||
// Store a empty array as the initial messages
|
||||
// (instead of using a default parameter value that gets re-created each time)
|
||||
// to avoid re-renders:
|
||||
const [initialMessagesFallback] = useState([]);
|
||||
|
||||
// Store the chat state in SWR, using the chatId as the key to share states.
|
||||
const { data: messages, mutate } = useSWR<Message[]>([chatKey, 'messages'], null, {
|
||||
fallbackData: initialMessages ?? initialMessagesFallback,
|
||||
});
|
||||
|
||||
// Keep the latest messages in a ref.
|
||||
const messagesRef = useRef<Message[]>(messages || []);
|
||||
useEffect(() => {
|
||||
messagesRef.current = messages || [];
|
||||
}, [messages]);
|
||||
|
||||
// stream data
|
||||
const { data: streamData, mutate: mutateStreamData } = useSWR<JSONValue[] | undefined>(
|
||||
[chatKey, 'streamData'],
|
||||
null
|
||||
);
|
||||
|
||||
// keep the latest stream data in a ref
|
||||
const streamDataRef = useRef<JSONValue[] | undefined>(streamData);
|
||||
useEffect(() => {
|
||||
streamDataRef.current = streamData;
|
||||
}, [streamData]);
|
||||
|
||||
// We store loading state in another hook to sync loading states across hook invocations
|
||||
const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
|
||||
[chatKey, 'loading'],
|
||||
null
|
||||
);
|
||||
|
||||
const { data: error = undefined, mutate: setError } = useSWR<undefined | Error>(
|
||||
[chatKey, 'error'],
|
||||
null
|
||||
);
|
||||
|
||||
// Abort controller to cancel the current API call.
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
const extraMetadataRef = useRef({
|
||||
credentials,
|
||||
headers,
|
||||
body,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
extraMetadataRef.current = {
|
||||
credentials,
|
||||
headers,
|
||||
body,
|
||||
};
|
||||
}, [credentials, headers, body]);
|
||||
|
||||
const triggerRequest = useCallback(
|
||||
async (chatRequest: ChatRequest) => {
|
||||
const messageCount = messagesRef.current.length;
|
||||
|
||||
try {
|
||||
mutateLoading(true);
|
||||
setError(undefined);
|
||||
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
await processResponseStream(
|
||||
api,
|
||||
chatRequest,
|
||||
// throttle streamed ui updates:
|
||||
throttle(mutate, throttleWaitMs),
|
||||
throttle(mutateStreamData, throttleWaitMs),
|
||||
streamDataRef,
|
||||
extraMetadataRef,
|
||||
messagesRef,
|
||||
abortControllerRef,
|
||||
generateId,
|
||||
streamProtocol,
|
||||
onFinish,
|
||||
onResponse,
|
||||
onToolCall,
|
||||
sendExtraMessageFields,
|
||||
experimental_prepareRequestBody,
|
||||
fetch,
|
||||
keepLastMessageOnError
|
||||
);
|
||||
|
||||
abortControllerRef.current = null;
|
||||
} catch (err) {
|
||||
// Ignore abort errors as they are expected.
|
||||
if ((err as any).name === 'AbortError') {
|
||||
abortControllerRef.current = null;
|
||||
return null;
|
||||
}
|
||||
|
||||
if (onError && err instanceof Error) {
|
||||
onError(err);
|
||||
}
|
||||
|
||||
setError(err as Error);
|
||||
} finally {
|
||||
mutateLoading(false);
|
||||
}
|
||||
|
||||
// auto-submit when all tool calls in the last assistant message have results:
|
||||
const messages = messagesRef.current;
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (
|
||||
// ensure we actually have new messages (to prevent infinite loops in case of errors):
|
||||
messages.length > messageCount &&
|
||||
// ensure there is a last message:
|
||||
lastMessage != null &&
|
||||
// check if the feature is enabled:
|
||||
maxSteps > 1 &&
|
||||
// check that next step is possible:
|
||||
isAssistantMessageWithCompletedToolCalls(lastMessage) &&
|
||||
// limit the number of automatic steps:
|
||||
countTrailingAssistantMessages(messages) < maxSteps
|
||||
) {
|
||||
await triggerRequest({ messages });
|
||||
}
|
||||
},
|
||||
[
|
||||
mutate,
|
||||
mutateLoading,
|
||||
api,
|
||||
extraMetadataRef,
|
||||
onResponse,
|
||||
onFinish,
|
||||
onError,
|
||||
setError,
|
||||
mutateStreamData,
|
||||
streamDataRef,
|
||||
streamProtocol,
|
||||
sendExtraMessageFields,
|
||||
experimental_prepareRequestBody,
|
||||
onToolCall,
|
||||
maxSteps,
|
||||
messagesRef,
|
||||
abortControllerRef,
|
||||
generateId,
|
||||
fetch,
|
||||
keepLastMessageOnError,
|
||||
throttleWaitMs,
|
||||
]
|
||||
);
|
||||
|
||||
const append = useCallback(
|
||||
async (
|
||||
message: Message | CreateMessage,
|
||||
{ data, headers, body, experimental_attachments }: ChatRequestOptions = {}
|
||||
) => {
|
||||
if (!message.id) {
|
||||
message.id = generateId();
|
||||
}
|
||||
|
||||
const attachmentsForRequest = await prepareAttachmentsForRequest(experimental_attachments);
|
||||
|
||||
const messages = messagesRef.current.concat({
|
||||
...message,
|
||||
id: message.id ?? generateId(),
|
||||
createdAt: message.createdAt ?? new Date(),
|
||||
experimental_attachments:
|
||||
attachmentsForRequest.length > 0 ? attachmentsForRequest : undefined,
|
||||
});
|
||||
|
||||
return triggerRequest({ messages, headers, body, data });
|
||||
},
|
||||
[triggerRequest, generateId]
|
||||
);
|
||||
|
||||
const reload = useCallback(
|
||||
async ({ data, headers, body }: ChatRequestOptions = {}) => {
|
||||
const messages = messagesRef.current;
|
||||
|
||||
if (messages.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Remove last assistant message and retry last user message.
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
return triggerRequest({
|
||||
messages: lastMessage.role === 'assistant' ? messages.slice(0, -1) : messages,
|
||||
headers,
|
||||
body,
|
||||
data,
|
||||
});
|
||||
},
|
||||
[triggerRequest]
|
||||
);
|
||||
|
||||
const stop = useCallback(() => {
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const setMessages = useCallback(
|
||||
(messages: Message[] | ((messages: Message[]) => Message[])) => {
|
||||
if (typeof messages === 'function') {
|
||||
messages = messages(messagesRef.current);
|
||||
}
|
||||
|
||||
mutate(messages, false);
|
||||
messagesRef.current = messages;
|
||||
},
|
||||
[mutate]
|
||||
);
|
||||
|
||||
const setData = useCallback(
|
||||
(
|
||||
data: JSONValue[] | undefined | ((data: JSONValue[] | undefined) => JSONValue[] | undefined)
|
||||
) => {
|
||||
if (typeof data === 'function') {
|
||||
data = data(streamDataRef.current);
|
||||
}
|
||||
|
||||
mutateStreamData(data, false);
|
||||
streamDataRef.current = data;
|
||||
},
|
||||
[mutateStreamData]
|
||||
);
|
||||
|
||||
// Input state and handlers.
|
||||
const [input, setInput] = useState(initialInput);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (
|
||||
event?: { preventDefault?: () => void },
|
||||
options: ChatRequestOptions = {},
|
||||
metadata?: object
|
||||
) => {
|
||||
event?.preventDefault?.();
|
||||
|
||||
if (!input && !options.allowEmptySubmit) return;
|
||||
|
||||
if (metadata) {
|
||||
extraMetadataRef.current = {
|
||||
...extraMetadataRef.current,
|
||||
...metadata,
|
||||
};
|
||||
}
|
||||
|
||||
const attachmentsForRequest = await prepareAttachmentsForRequest(
|
||||
options.experimental_attachments
|
||||
);
|
||||
|
||||
const messages =
|
||||
!input && !attachmentsForRequest.length && options.allowEmptySubmit
|
||||
? messagesRef.current
|
||||
: messagesRef.current.concat({
|
||||
id: generateId(),
|
||||
createdAt: new Date(),
|
||||
role: 'user',
|
||||
content: input,
|
||||
experimental_attachments:
|
||||
attachmentsForRequest.length > 0 ? attachmentsForRequest : undefined,
|
||||
});
|
||||
|
||||
const chatRequest: ChatRequest = {
|
||||
messages,
|
||||
headers: options.headers,
|
||||
body: options.body,
|
||||
data: options.data,
|
||||
};
|
||||
|
||||
triggerRequest(chatRequest);
|
||||
|
||||
setInput('');
|
||||
},
|
||||
[input, generateId, triggerRequest]
|
||||
);
|
||||
|
||||
const handleInputChange = (e: any) => {
|
||||
setInput(e.target.value);
|
||||
};
|
||||
|
||||
const addToolResult = ({ toolCallId, result }: { toolCallId: string; result: any }) => {
|
||||
const updatedMessages = messagesRef.current.map((message, index, arr) =>
|
||||
// update the tool calls in the last assistant message:
|
||||
index === arr.length - 1 && message.role === 'assistant' && message.toolInvocations
|
||||
? {
|
||||
...message,
|
||||
toolInvocations: message.toolInvocations.map((toolInvocation) =>
|
||||
toolInvocation.toolCallId === toolCallId
|
||||
? {
|
||||
...toolInvocation,
|
||||
result,
|
||||
state: 'result' as const,
|
||||
}
|
||||
: toolInvocation
|
||||
),
|
||||
}
|
||||
: message
|
||||
);
|
||||
|
||||
mutate(updatedMessages, false);
|
||||
|
||||
// auto-submit when all tool calls in the last assistant message have results:
|
||||
const lastMessage = updatedMessages[updatedMessages.length - 1];
|
||||
if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
|
||||
triggerRequest({ messages: updatedMessages });
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
messages: messages || [],
|
||||
setMessages,
|
||||
data: streamData,
|
||||
setData,
|
||||
error,
|
||||
append,
|
||||
reload,
|
||||
stop,
|
||||
input,
|
||||
setInput,
|
||||
handleInputChange,
|
||||
handleSubmit,
|
||||
isLoading,
|
||||
addToolResult,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
Check if the message is an assistant message with completed tool calls.
|
||||
The message must have at least one tool invocation and all tool invocations
|
||||
must have a result.
|
||||
*/
|
||||
function isAssistantMessageWithCompletedToolCalls(message: Message) {
|
||||
return (
|
||||
message.role === 'assistant' &&
|
||||
message.toolInvocations &&
|
||||
message.toolInvocations.length > 0 &&
|
||||
message.toolInvocations.every((toolInvocation) => 'result' in toolInvocation)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
Returns the number of trailing assistant messages in the array.
|
||||
*/
|
||||
function countTrailingAssistantMessages(messages: Message[]) {
|
||||
let count = 0;
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].role === 'assistant') {
|
||||
count++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
async function prepareAttachmentsForRequest(
|
||||
attachmentsFromOptions: FileList | Array<Attachment> | undefined
|
||||
) {
|
||||
if (attachmentsFromOptions == null) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (attachmentsFromOptions instanceof FileList) {
|
||||
return Promise.all(
|
||||
Array.from(attachmentsFromOptions).map(async (attachment) => {
|
||||
const { name, type } = attachment;
|
||||
|
||||
const dataUrl = await new Promise<string>((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (readerEvent) => {
|
||||
resolve(readerEvent.target?.result as string);
|
||||
};
|
||||
reader.onerror = (error) => reject(error);
|
||||
reader.readAsDataURL(attachment);
|
||||
});
|
||||
|
||||
return {
|
||||
name,
|
||||
contentType: type,
|
||||
url: dataUrl,
|
||||
};
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (Array.isArray(attachmentsFromOptions)) {
|
||||
return attachmentsFromOptions;
|
||||
}
|
||||
|
||||
throw new Error('Invalid attachments type');
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
import React, { useEffect, useRef, useState } from 'react';
|
||||
import { Message, useChat } from '../ai-sdk-fork/useChat';
|
||||
import { getApiUrl } from '../config';
|
||||
import BottomMenu from './BottomMenu';
|
||||
import FlappyGoose from './FlappyGoose';
|
||||
@@ -14,15 +13,13 @@ import UserMessage from './UserMessage';
|
||||
import { askAi } from '../utils/askAI';
|
||||
import Splash from './Splash';
|
||||
import 'react-toastify/dist/ReactToastify.css';
|
||||
import { useMessageStream } from '../hooks/useMessageStream';
|
||||
import { Message, createUserMessage, getTextContent } from '../types/message';
|
||||
|
||||
export interface ChatType {
|
||||
id: number;
|
||||
title: string;
|
||||
messages: Array<{
|
||||
id: string;
|
||||
role: 'function' | 'system' | 'user' | 'assistant' | 'data' | 'tool';
|
||||
content: string;
|
||||
}>;
|
||||
messages: Message[];
|
||||
}
|
||||
|
||||
export default function ChatView({ setView }: { setView: (view: View) => void }) {
|
||||
@@ -39,14 +36,27 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
|
||||
const [showGame, setShowGame] = useState(false);
|
||||
const scrollRef = useRef<ScrollAreaHandle>(null);
|
||||
|
||||
const { messages, append, stop, isLoading, error, setMessages } = useChat({
|
||||
const {
|
||||
messages,
|
||||
append,
|
||||
stop,
|
||||
isLoading,
|
||||
error,
|
||||
setMessages,
|
||||
input: _input,
|
||||
setInput: _setInput,
|
||||
handleInputChange: _handleInputChange,
|
||||
handleSubmit: _submitMessage,
|
||||
} = useMessageStream({
|
||||
api: getApiUrl('/reply'),
|
||||
initialMessages: chat?.messages || [],
|
||||
onFinish: async (message, _) => {
|
||||
onFinish: async (message, _reason) => {
|
||||
window.electron.stopPowerSaveBlocker();
|
||||
|
||||
const fetchResponses = await askAi(message.content);
|
||||
setMessageMetadata((prev) => ({ ...prev, [message.id]: fetchResponses }));
|
||||
// Extract text content from the message to pass to askAi
|
||||
const messageText = getTextContent(message);
|
||||
const fetchResponses = await askAi(messageText);
|
||||
setMessageMetadata((prev) => ({ ...prev, [message.id || '']: fetchResponses }));
|
||||
|
||||
const timeSinceLastInteraction = Date.now() - lastInteractionTime;
|
||||
window.electron.logInfo('last interaction:' + lastInteractionTime);
|
||||
@@ -58,11 +68,16 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
|
||||
});
|
||||
}
|
||||
},
|
||||
onToolCall: (toolCall) => {
|
||||
// Handle tool calls if needed
|
||||
console.log('Tool call received:', toolCall);
|
||||
// Implement tool call handling logic here
|
||||
},
|
||||
});
|
||||
|
||||
// Update chat messages when they change
|
||||
useEffect(() => {
|
||||
setChat({ ...chat, messages });
|
||||
setChat((prevChat) => ({ ...prevChat, messages }));
|
||||
}, [messages]);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -78,10 +93,7 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
|
||||
const content = customEvent.detail?.value || '';
|
||||
if (content.trim()) {
|
||||
setLastInteractionTime(Date.now());
|
||||
append({
|
||||
role: 'user',
|
||||
content,
|
||||
});
|
||||
append(createUserMessage(content));
|
||||
if (scrollRef.current?.scrollToBottom) {
|
||||
scrollRef.current.scrollToBottom();
|
||||
}
|
||||
@@ -97,47 +109,38 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
|
||||
setLastInteractionTime(Date.now());
|
||||
window.electron.stopPowerSaveBlocker();
|
||||
|
||||
const lastMessage: Message = messages[messages.length - 1];
|
||||
if (lastMessage.role === 'user' && lastMessage.toolInvocations === undefined) {
|
||||
// Remove the last user message.
|
||||
// Handle stopping the message stream
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (lastMessage && lastMessage.role === 'user') {
|
||||
// Remove the last user message if it's the most recent one
|
||||
if (messages.length > 1) {
|
||||
setMessages(messages.slice(0, -1));
|
||||
} else {
|
||||
setMessages([]);
|
||||
}
|
||||
} else if (lastMessage.role === 'assistant' && lastMessage.toolInvocations !== undefined) {
|
||||
// Add messaging about interrupted ongoing tool invocations
|
||||
const newLastMessage: Message = {
|
||||
...lastMessage,
|
||||
toolInvocations: lastMessage.toolInvocations.map((invocation) => {
|
||||
if (invocation.state !== 'result') {
|
||||
return {
|
||||
...invocation,
|
||||
result: [
|
||||
{
|
||||
audience: ['user'],
|
||||
text: 'Interrupted.\n',
|
||||
type: 'text',
|
||||
},
|
||||
{
|
||||
audience: ['assistant'],
|
||||
text: 'Interrupted by the user to make a correction.\n',
|
||||
type: 'text',
|
||||
},
|
||||
],
|
||||
state: 'result',
|
||||
};
|
||||
} else {
|
||||
return invocation;
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
const updatedMessages = [...messages.slice(0, -1), newLastMessage];
|
||||
setMessages(updatedMessages);
|
||||
}
|
||||
// Note: Tool call interruption handling would need to be implemented
|
||||
// differently with the new message format
|
||||
};
|
||||
|
||||
// Filter out standalone tool response messages for rendering
|
||||
// They will be shown as part of the tool invocation in the assistant message
|
||||
const filteredMessages = messages.filter((message) => {
|
||||
// Keep all assistant messages and user messages that aren't just tool responses
|
||||
if (message.role === 'assistant') return true;
|
||||
|
||||
// For user messages, check if they're only tool responses
|
||||
if (message.role === 'user') {
|
||||
const hasOnlyToolResponses = message.content.every((c) => c.type === 'toolResponse');
|
||||
const hasTextContent = message.content.some((c) => c.type === 'text');
|
||||
|
||||
// Keep the message if it has text content or is not just tool responses
|
||||
return hasTextContent || !hasOnlyToolResponses;
|
||||
}
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-full h-screen items-center justify-center">
|
||||
<div className="relative flex items-center h-[36px] w-full bg-bgSubtle border-b border-borderSubtle">
|
||||
@@ -145,19 +148,19 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
|
||||
</div>
|
||||
<Card className="flex flex-col flex-1 rounded-none h-[calc(100vh-95px)] w-full bg-bgApp mt-0 border-none relative">
|
||||
{messages.length === 0 ? (
|
||||
<Splash append={append} />
|
||||
<Splash append={(text) => append(createUserMessage(text))} />
|
||||
) : (
|
||||
<ScrollArea ref={scrollRef} className="flex-1 px-4" autoScroll>
|
||||
{messages.map((message) => (
|
||||
<div key={message.id} className="mt-[16px]">
|
||||
{filteredMessages.map((message, index) => (
|
||||
<div key={message.id || index} className="mt-[16px]">
|
||||
{message.role === 'user' ? (
|
||||
<UserMessage message={message} />
|
||||
) : (
|
||||
<GooseMessage
|
||||
message={message}
|
||||
messages={messages}
|
||||
metadata={messageMetadata[message.id]}
|
||||
append={append}
|
||||
metadata={messageMetadata[message.id || '']}
|
||||
append={(text) => append(createUserMessage(text))}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@@ -166,20 +169,17 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
|
||||
<div className="flex flex-col items-center justify-center p-4">
|
||||
<div className="text-red-700 dark:text-red-300 bg-red-400/50 p-3 rounded-lg mb-2">
|
||||
{error.message || 'Honk! Goose experienced an error while responding'}
|
||||
{error.status && <span className="ml-2">(Status: {error.status})</span>}
|
||||
</div>
|
||||
<div
|
||||
className="px-3 py-2 mt-2 text-center whitespace-nowrap cursor-pointer text-textStandard border border-borderSubtle hover:bg-bgSubtle rounded-full inline-block transition-all duration-150"
|
||||
onClick={async () => {
|
||||
// Find the last user message
|
||||
const lastUserMessage = messages.reduceRight(
|
||||
(found, m) => found || (m.role === 'user' ? m : null),
|
||||
null
|
||||
null as Message | null
|
||||
);
|
||||
if (lastUserMessage) {
|
||||
append({
|
||||
role: 'user',
|
||||
content: lastUserMessage.content,
|
||||
});
|
||||
append(lastUserMessage);
|
||||
}
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -1,41 +1,77 @@
|
||||
import React from 'react';
|
||||
import ToolInvocations from './ToolInvocations';
|
||||
import React, { useMemo } from 'react';
|
||||
import LinkPreview from './LinkPreview';
|
||||
import GooseResponseForm from './GooseResponseForm';
|
||||
import { extractUrls } from '../utils/urlUtils';
|
||||
import MarkdownContent from './MarkdownContent';
|
||||
import ToolCallWithResponse from './ToolCallWithResponse';
|
||||
import { Message, getTextContent, getToolRequests, getToolResponses } from '../types/message';
|
||||
|
||||
interface GooseMessageProps {
|
||||
message: any;
|
||||
messages: any[];
|
||||
metadata?: any;
|
||||
append: (value: any) => void;
|
||||
message: Message;
|
||||
messages: Message[];
|
||||
metadata?: string[];
|
||||
append: (value: string) => void;
|
||||
}
|
||||
|
||||
export default function GooseMessage({ message, metadata, messages, append }: GooseMessageProps) {
|
||||
// Extract text content from the message
|
||||
const textContent = getTextContent(message);
|
||||
|
||||
// Get tool requests from the message
|
||||
const toolRequests = getToolRequests(message);
|
||||
|
||||
// Extract URLs under a few conditions
|
||||
// 1. The message is purely text
|
||||
// 2. The link wasn't also present in the previous message
|
||||
// 3. The message contains the explicit http:// or https:// protocol at the beginning
|
||||
const messageIndex = messages?.findIndex((msg) => msg.id === message.id);
|
||||
const previousMessage = messageIndex > 0 ? messages[messageIndex - 1] : null;
|
||||
const previousUrls = previousMessage ? extractUrls(previousMessage.content) : [];
|
||||
const urls = !message.toolInvocations ? extractUrls(message.content, previousUrls) : [];
|
||||
const previousUrls = previousMessage ? extractUrls(getTextContent(previousMessage)) : [];
|
||||
const urls = toolRequests.length === 0 ? extractUrls(textContent, previousUrls) : [];
|
||||
|
||||
// Find tool responses that correspond to the tool requests in this message
|
||||
const toolResponsesMap = useMemo(() => {
|
||||
const responseMap = new Map();
|
||||
|
||||
// Look for tool responses in subsequent messages
|
||||
if (messageIndex !== undefined && messageIndex >= 0) {
|
||||
for (let i = messageIndex + 1; i < messages.length; i++) {
|
||||
const responses = getToolResponses(messages[i]);
|
||||
|
||||
for (const response of responses) {
|
||||
// Check if this response matches any of our tool requests
|
||||
const matchingRequest = toolRequests.find((req) => req.id === response.id);
|
||||
if (matchingRequest) {
|
||||
responseMap.set(response.id, response);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return responseMap;
|
||||
}, [messages, messageIndex, toolRequests]);
|
||||
|
||||
return (
|
||||
<div className="goose-message flex w-[90%] justify-start opacity-0 animate-[appear_150ms_ease-in_forwards]">
|
||||
<div className="flex flex-col w-full">
|
||||
{message.content && (
|
||||
{/* Always show the top content area if there are tool calls, even if textContent is empty */}
|
||||
{(textContent || toolRequests.length > 0) && (
|
||||
<div
|
||||
className={`goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 ${message.toolInvocations ? 'rounded-b-none' : ''}`}
|
||||
className={`goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 ${toolRequests.length > 0 ? 'rounded-b-none' : ''}`}
|
||||
>
|
||||
<MarkdownContent content={message.content} />
|
||||
{textContent ? <MarkdownContent content={textContent} /> : null}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{message.toolInvocations && (
|
||||
{toolRequests.length > 0 && (
|
||||
<div className="goose-message-tool bg-bgApp border border-borderSubtle dark:border-gray-700 rounded-b-2xl px-4 pt-4 pb-2 mt-1">
|
||||
<ToolInvocations toolInvocations={message.toolInvocations} />
|
||||
{toolRequests.map((toolRequest) => (
|
||||
<ToolCallWithResponse
|
||||
key={toolRequest.id}
|
||||
toolRequest={toolRequest}
|
||||
toolResponse={toolResponsesMap.get(toolRequest.id)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -53,7 +89,7 @@ export default function GooseMessage({ message, metadata, messages, append }: Go
|
||||
{/* NOTE from alexhancock on 1/14/2025 - disabling again temporarily due to non-determinism in when the forms show up */}
|
||||
{false && metadata && (
|
||||
<div className="flex mt-[16px]">
|
||||
<GooseResponseForm message={message.content} metadata={metadata} append={append} />
|
||||
<GooseResponseForm message={textContent} metadata={metadata} append={append} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -3,6 +3,8 @@ import MarkdownContent from './MarkdownContent';
|
||||
import { Button } from './ui/button';
|
||||
import { cn } from '../utils';
|
||||
import { Send } from './icons';
|
||||
// Prefixing unused imports with underscore
|
||||
import { createUserMessage as _createUserMessage } from '../types/message';
|
||||
|
||||
interface FormField {
|
||||
label: string;
|
||||
@@ -20,8 +22,8 @@ interface DynamicForm {
|
||||
|
||||
interface GooseResponseFormProps {
|
||||
message: string;
|
||||
metadata: any;
|
||||
append: (value: any) => void;
|
||||
metadata: string[] | null;
|
||||
append: (value: string) => void;
|
||||
}
|
||||
|
||||
export default function GooseResponseForm({
|
||||
@@ -103,31 +105,19 @@ export default function GooseResponseForm({
|
||||
};
|
||||
|
||||
const handleAccept = () => {
|
||||
const message = {
|
||||
content: 'Yes - go ahead.',
|
||||
role: 'user',
|
||||
};
|
||||
append(message);
|
||||
append('Yes - go ahead.');
|
||||
};
|
||||
|
||||
const handleSubmit = () => {
|
||||
if (selectedOption !== null && options[selectedOption]) {
|
||||
const message = {
|
||||
content: `Yes - continue with: ${options[selectedOption].optionTitle}`,
|
||||
role: 'user',
|
||||
};
|
||||
append(message);
|
||||
append(`Yes - continue with: ${options[selectedOption].optionTitle}`);
|
||||
}
|
||||
};
|
||||
|
||||
const handleFormSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (dynamicForm) {
|
||||
const message = {
|
||||
content: JSON.stringify(formValues),
|
||||
role: 'user',
|
||||
};
|
||||
append(message);
|
||||
append(JSON.stringify(formValues));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import React from 'react';
|
||||
import SplashPills from './SplashPills';
|
||||
import { Goose, Rain } from './icons/Goose';
|
||||
import GooseLogo from './GooseLogo';
|
||||
|
||||
export default function Splash({ append }) {
|
||||
|
||||
@@ -5,11 +5,8 @@ function SplashPill({ content, append, className = '', longForm = '' }) {
|
||||
<div
|
||||
className={`px-4 py-2 text-sm text-center text-textSubtle dark:text-textStandard cursor-pointer border border-borderSubtle hover:bg-bgSubtle rounded-full transition-all duration-150 ${className}`}
|
||||
onClick={async () => {
|
||||
const message = {
|
||||
content: longForm || content,
|
||||
role: 'user',
|
||||
};
|
||||
await append(message);
|
||||
// Use the longForm text if provided, otherwise use the content
|
||||
await append(longForm || content);
|
||||
}}
|
||||
>
|
||||
<div className="line-clamp-2">{content}</div>
|
||||
|
||||
138
ui/desktop/src/components/ToolCallWithResponse.tsx
Normal file
138
ui/desktop/src/components/ToolCallWithResponse.tsx
Normal file
@@ -0,0 +1,138 @@
|
||||
import React from 'react';
|
||||
import { Card } from './ui/card';
|
||||
import Box from './ui/Box';
|
||||
import { ToolCallArguments } from './ToolCallArguments';
|
||||
import MarkdownContent from './MarkdownContent';
|
||||
import { LoadingPlaceholder } from './LoadingPlaceholder';
|
||||
import { ChevronUp } from 'lucide-react';
|
||||
import { Content, ToolRequestMessageContent, ToolResponseMessageContent } from '../types/message';
|
||||
import { snakeToTitleCase } from '../utils';
|
||||
|
||||
interface ToolCallWithResponseProps {
|
||||
toolRequest: ToolRequestMessageContent;
|
||||
toolResponse?: ToolResponseMessageContent;
|
||||
}
|
||||
|
||||
export default function ToolCallWithResponse({
|
||||
toolRequest,
|
||||
toolResponse,
|
||||
}: ToolCallWithResponseProps) {
|
||||
const toolCall = toolRequest.toolCall.status === 'success' ? toolRequest.toolCall.value : null;
|
||||
|
||||
if (!toolCall) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full">
|
||||
<Card className="">
|
||||
<ToolCallView toolCall={toolCall} />
|
||||
{toolResponse ? (
|
||||
<ToolResultView
|
||||
result={
|
||||
toolResponse.toolResult.status === 'success'
|
||||
? toolResponse.toolResult.value
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<LoadingPlaceholder />
|
||||
)}
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ToolCallViewProps {
|
||||
toolCall: {
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
};
|
||||
}
|
||||
|
||||
function ToolCallView({ toolCall }: ToolCallViewProps) {
|
||||
return (
|
||||
<div>
|
||||
<div className="flex items-center mb-4">
|
||||
<Box size={16} />
|
||||
<span className="ml-[8px] text-textStandard">
|
||||
{snakeToTitleCase(toolCall.name.substring(toolCall.name.lastIndexOf('__') + 2))}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{toolCall.arguments && <ToolCallArguments args={toolCall.arguments} />}
|
||||
|
||||
<div className="self-stretch h-px my-[10px] -mx-4 bg-borderSubtle dark:bg-gray-700" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ToolResultViewProps {
|
||||
result?: Content[];
|
||||
}
|
||||
|
||||
function ToolResultView({ result }: ToolResultViewProps) {
|
||||
// State to track expanded items
|
||||
const [expandedItems, setExpandedItems] = React.useState<number[]>([]);
|
||||
|
||||
// If no result info, don't show anything
|
||||
if (!result) return null;
|
||||
|
||||
// Find results where either audience is not set, or it's set to a list that includes user
|
||||
const filteredResults = result.filter((item) => {
|
||||
// Check audience (which may not be in the type)
|
||||
const audience = item.annotations?.audience;
|
||||
|
||||
return !audience || audience.includes('user');
|
||||
});
|
||||
|
||||
if (filteredResults.length === 0) return null;
|
||||
|
||||
const toggleExpand = (index: number) => {
|
||||
setExpandedItems((prev) =>
|
||||
prev.includes(index) ? prev.filter((i) => i !== index) : [...prev, index]
|
||||
);
|
||||
};
|
||||
|
||||
const shouldShowExpanded = (item: Content, index: number) => {
|
||||
return (
|
||||
(item.annotations.priority !== undefined && item.annotations.priority >= 0.5) ||
|
||||
expandedItems.includes(index)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="">
|
||||
{filteredResults.map((item, index) => {
|
||||
const isExpanded = shouldShowExpanded(item, index);
|
||||
const shouldMinimize =
|
||||
item.annotations.priority === undefined || item.annotations.priority < 0.5;
|
||||
return (
|
||||
<div key={index} className="relative">
|
||||
{shouldMinimize && (
|
||||
<button
|
||||
onClick={() => toggleExpand(index)}
|
||||
className="mb-1 flex items-center text-textStandard"
|
||||
>
|
||||
<span className="mr-2 text-sm">Output</span>
|
||||
<ChevronUp
|
||||
className={`h-5 w-5 transition-all origin-center ${!isExpanded ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
{(isExpanded || !shouldMinimize) && (
|
||||
<>
|
||||
{item.text && (
|
||||
<MarkdownContent
|
||||
content={item.text}
|
||||
className="whitespace-pre-wrap p-2 max-w-full overflow-x-auto"
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,164 +0,0 @@
|
||||
import React from 'react';
|
||||
import { Card } from './ui/card';
|
||||
import Box from './ui/Box';
|
||||
import { ToolCallArguments } from './ToolCallArguments';
|
||||
import MarkdownContent from './MarkdownContent';
|
||||
import { snakeToTitleCase } from '../utils';
|
||||
import { LoadingPlaceholder } from './LoadingPlaceholder';
|
||||
import { ChevronUp } from 'lucide-react';
|
||||
|
||||
export default function ToolInvocations({ toolInvocations }) {
|
||||
return (
|
||||
<>
|
||||
{toolInvocations.map((toolInvocation) => (
|
||||
<ToolInvocation key={toolInvocation.toolCallId} toolInvocation={toolInvocation} />
|
||||
))}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolInvocation({ toolInvocation }) {
|
||||
return (
|
||||
<div className="w-full">
|
||||
<Card className="">
|
||||
<ToolCall call={toolInvocation} />
|
||||
{toolInvocation.state === 'result' ? (
|
||||
<ToolResult result={toolInvocation} />
|
||||
) : (
|
||||
<LoadingPlaceholder />
|
||||
)}
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ToolCallProps {
|
||||
call: {
|
||||
state: 'call' | 'result';
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: Record<string, any>;
|
||||
};
|
||||
}
|
||||
|
||||
function ToolCall({ call }: ToolCallProps) {
|
||||
return (
|
||||
<div>
|
||||
<div className="flex items-center mb-4">
|
||||
<Box size={16} />
|
||||
<span className="ml-[8px] text-textStandard">
|
||||
{snakeToTitleCase(call.toolName.substring(call.toolName.lastIndexOf('__') + 2))}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{call.args && <ToolCallArguments args={call.args} />}
|
||||
|
||||
<div className="self-stretch h-px my-[10px] -mx-4 bg-borderSubtle dark:bg-gray-700" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface Annotations {
|
||||
audience?: string[]; // Array of audience types
|
||||
priority?: number; // Priority value between 0 and 1
|
||||
}
|
||||
|
||||
interface ResultItem {
|
||||
text?: string;
|
||||
type: 'text' | 'image';
|
||||
mimeType?: string;
|
||||
data?: string; // Base64 encoded image data
|
||||
annotations?: Annotations;
|
||||
}
|
||||
|
||||
interface ToolResultProps {
|
||||
result: {
|
||||
message?: string;
|
||||
result?: ResultItem[];
|
||||
state?: string;
|
||||
toolCallId?: string;
|
||||
toolName?: string;
|
||||
args?: any;
|
||||
input_todo?: any;
|
||||
};
|
||||
}
|
||||
|
||||
function ToolResult({ result }: ToolResultProps) {
|
||||
// State to track expanded items
|
||||
const [expandedItems, setExpandedItems] = React.useState<number[]>([]);
|
||||
|
||||
// If no result info, don't show anything
|
||||
if (!result || !result.result) return null;
|
||||
|
||||
// Normalize to an array
|
||||
const results = Array.isArray(result.result) ? result.result : [result.result];
|
||||
|
||||
// Find results where either audience is not set, or it's set to a list that contains user
|
||||
const filteredResults = results.filter(
|
||||
(item: ResultItem) =>
|
||||
!item.annotations?.audience || item.annotations?.audience?.includes('user')
|
||||
);
|
||||
|
||||
if (filteredResults.length === 0) return null;
|
||||
|
||||
const toggleExpand = (index: number) => {
|
||||
setExpandedItems((prev) =>
|
||||
prev.includes(index) ? prev.filter((i) => i !== index) : [...prev, index]
|
||||
);
|
||||
};
|
||||
|
||||
const shouldShowExpanded = (item: ResultItem, index: number) => {
|
||||
// (priority is defined and > 0.5) OR already in the expandedItems
|
||||
return (
|
||||
(item.annotations?.priority !== undefined && item.annotations?.priority >= 0.5) ||
|
||||
expandedItems.includes(index)
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="">
|
||||
{filteredResults.map((item: ResultItem, index: number) => {
|
||||
const isExpanded = shouldShowExpanded(item, index);
|
||||
// minimize if priority is not set or < 0.5
|
||||
const shouldMinimize =
|
||||
item.annotations?.priority === undefined || item.annotations?.priority < 0.5;
|
||||
return (
|
||||
<div key={index} className="relative">
|
||||
{shouldMinimize && (
|
||||
<button
|
||||
onClick={() => toggleExpand(index)}
|
||||
className="mb-1 flex items-center text-textStandard"
|
||||
>
|
||||
<span className="mr-2 text-sm">Output</span>
|
||||
<ChevronUp
|
||||
className={`h-5 w-5 transition-all origin-center ${!isExpanded ? 'rotate-180' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
{(isExpanded || !shouldMinimize) && (
|
||||
<>
|
||||
{item.type === 'text' && item.text && (
|
||||
<MarkdownContent
|
||||
content={item.text}
|
||||
className="whitespace-pre-wrap p-2 max-w-full overflow-x-auto"
|
||||
/>
|
||||
)}
|
||||
{item.type === 'image' && item.data && item.mimeType && (
|
||||
<img
|
||||
src={`data:${item.mimeType};base64,${item.data}`}
|
||||
alt="Tool result"
|
||||
className="max-w-full h-auto rounded-md"
|
||||
onError={(e) => {
|
||||
console.error('Failed to load image: Invalid MIME-type encoded image data');
|
||||
e.currentTarget.style.display = 'none';
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -2,16 +2,24 @@ import React from 'react';
|
||||
import LinkPreview from './LinkPreview';
|
||||
import { extractUrls } from '../utils/urlUtils';
|
||||
import MarkdownContent from './MarkdownContent';
|
||||
import { Message, getTextContent } from '../types/message';
|
||||
|
||||
export default function UserMessage({ message }) {
|
||||
interface UserMessageProps {
|
||||
message: Message;
|
||||
}
|
||||
|
||||
export default function UserMessage({ message }: UserMessageProps) {
|
||||
// Extract text content from the message
|
||||
const textContent = getTextContent(message);
|
||||
|
||||
// Extract URLs which explicitly contain the http:// or https:// protocol
|
||||
const urls = extractUrls(message.content, []);
|
||||
const urls = extractUrls(textContent, []);
|
||||
|
||||
return (
|
||||
<div className="flex justify-end mt-[16px] w-full opacity-0 animate-[appear_150ms_ease-in_forwards]">
|
||||
<div className="flex-col max-w-[85%]">
|
||||
<div className="flex bg-slate text-white rounded-xl rounded-br-none py-2 px-3">
|
||||
<MarkdownContent content={message.content} className="text-white" />
|
||||
<MarkdownContent content={textContent} className="text-white" />
|
||||
</div>
|
||||
|
||||
{/* TODO(alexhancock): Re-enable link previews once styled well again */}
|
||||
@@ -25,4 +33,4 @@ export default function UserMessage({ message }) {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
487
ui/desktop/src/hooks/useMessageStream.ts
Normal file
487
ui/desktop/src/hooks/useMessageStream.ts
Normal file
@@ -0,0 +1,487 @@
|
||||
import { useState, useCallback, useEffect, useRef, useId } from 'react';
|
||||
import useSWR from 'swr';
|
||||
import { getSecretKey } from '../config';
|
||||
import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';
|
||||
|
||||
// Ensure TextDecoder is available in the global scope
|
||||
const TextDecoder = globalThis.TextDecoder;
|
||||
|
||||
// Event types for SSE stream
|
||||
type MessageEvent =
|
||||
| { type: 'Message'; message: Message }
|
||||
| { type: 'Error'; error: string }
|
||||
| { type: 'Finish'; reason: string };
|
||||
|
||||
export interface UseMessageStreamOptions {
|
||||
/**
|
||||
* The API endpoint that accepts a `{ messages: Message[] }` object and returns
|
||||
* a stream of messages. Defaults to `/api/chat/reply`.
|
||||
*/
|
||||
api?: string;
|
||||
|
||||
/**
|
||||
* A unique identifier for the chat. If not provided, a random one will be
|
||||
* generated. When provided, the hook with the same `id` will
|
||||
* have shared states across components.
|
||||
*/
|
||||
id?: string;
|
||||
|
||||
/**
|
||||
* Initial messages of the chat. Useful to load an existing chat history.
|
||||
*/
|
||||
initialMessages?: Message[];
|
||||
|
||||
/**
|
||||
* Initial input of the chat.
|
||||
*/
|
||||
initialInput?: string;
|
||||
|
||||
/**
|
||||
* Callback function to be called when a tool call is received.
|
||||
* You can optionally return a result for the tool call.
|
||||
*/
|
||||
_onToolCall?: (toolCall: Record<string, unknown>) => void | Promise<unknown> | unknown;
|
||||
|
||||
/**
|
||||
* Callback function to be called when the API response is received.
|
||||
*/
|
||||
onResponse?: (response: Response) => void | Promise<void>;
|
||||
|
||||
/**
|
||||
* Callback function to be called when the assistant message is finished streaming.
|
||||
*/
|
||||
onFinish?: (message: Message, reason: string) => void;
|
||||
|
||||
/**
|
||||
* Callback function to be called when an error is encountered.
|
||||
*/
|
||||
onError?: (error: Error) => void;
|
||||
|
||||
/**
|
||||
* HTTP headers to be sent with the API request.
|
||||
*/
|
||||
headers?: Record<string, string> | HeadersInit;
|
||||
|
||||
/**
|
||||
* Extra body object to be sent with the API request.
|
||||
*/
|
||||
body?: object;
|
||||
|
||||
/**
|
||||
* Maximum number of sequential LLM calls (steps), e.g. when you use tool calls.
|
||||
* Default is 1.
|
||||
*/
|
||||
maxSteps?: number;
|
||||
}
|
||||
|
||||
export interface UseMessageStreamHelpers {
|
||||
/** Current messages in the chat */
|
||||
messages: Message[];
|
||||
|
||||
/** The error object of the API request */
|
||||
error: undefined | Error;
|
||||
|
||||
/**
|
||||
* Append a user message to the chat list. This triggers the API call to fetch
|
||||
* the assistant's response.
|
||||
*/
|
||||
append: (message: Message | string) => Promise<void>;
|
||||
|
||||
/**
|
||||
* Reload the last AI chat response for the given chat history.
|
||||
*/
|
||||
reload: () => Promise<void>;
|
||||
|
||||
/**
|
||||
* Abort the current request immediately.
|
||||
*/
|
||||
stop: () => void;
|
||||
|
||||
/**
|
||||
* Update the `messages` state locally.
|
||||
*/
|
||||
setMessages: (messages: Message[] | ((messages: Message[]) => Message[])) => void;
|
||||
|
||||
/** The current value of the input */
|
||||
input: string;
|
||||
|
||||
/** setState-powered method to update the input value */
|
||||
setInput: React.Dispatch<React.SetStateAction<string>>;
|
||||
|
||||
/** An input/textarea-ready onChange handler to control the value of the input */
|
||||
handleInputChange: (
|
||||
e: React.ChangeEvent<HTMLInputElement> | React.ChangeEvent<HTMLTextAreaElement>
|
||||
) => void;
|
||||
|
||||
/** Form submission handler to automatically reset input and append a user message */
|
||||
handleSubmit: (event?: { preventDefault?: () => void }) => void;
|
||||
|
||||
/** Whether the API request is in progress */
|
||||
isLoading: boolean;
|
||||
|
||||
/** Add a tool result to a tool call */
|
||||
addToolResult: ({ toolCallId, result }: { toolCallId: string; result: unknown }) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook for streaming messages directly from the server using the native Goose message format
|
||||
*/
|
||||
export function useMessageStream({
|
||||
api = '/api/chat/reply',
|
||||
id,
|
||||
initialMessages = [],
|
||||
initialInput = '',
|
||||
onResponse,
|
||||
onFinish,
|
||||
onError,
|
||||
headers,
|
||||
body,
|
||||
maxSteps = 1,
|
||||
}: UseMessageStreamOptions = {}): UseMessageStreamHelpers {
|
||||
// Generate a unique id for the chat if not provided
|
||||
const hookId = useId();
|
||||
const idKey = id ?? hookId;
|
||||
const chatKey = typeof api === 'string' ? [api, idKey] : idKey;
|
||||
|
||||
// Store the chat state in SWR, using the chatId as the key to share states
|
||||
const { data: messages, mutate } = useSWR<Message[]>([chatKey, 'messages'], null, {
|
||||
fallbackData: initialMessages,
|
||||
});
|
||||
|
||||
// Keep the latest messages in a ref
|
||||
const messagesRef = useRef<Message[]>(messages || []);
|
||||
useEffect(() => {
|
||||
messagesRef.current = messages || [];
|
||||
}, [messages]);
|
||||
|
||||
// We store loading state in another hook to sync loading states across hook invocations
|
||||
const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
|
||||
[chatKey, 'loading'],
|
||||
null
|
||||
);
|
||||
|
||||
const { data: error = undefined, mutate: setError } = useSWR<undefined | Error>(
|
||||
[chatKey, 'error'],
|
||||
null
|
||||
);
|
||||
|
||||
// Abort controller to cancel the current API call
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
// Extra metadata for requests
|
||||
const extraMetadataRef = useRef({
|
||||
headers,
|
||||
body,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
extraMetadataRef.current = {
|
||||
headers,
|
||||
body,
|
||||
};
|
||||
}, [headers, body]);
|
||||
|
||||
// Process the SSE stream from the server
|
||||
const processMessageStream = useCallback(
|
||||
async (response: Response, currentMessages: Message[]) => {
|
||||
if (!response.body) {
|
||||
throw new Error('Response body is empty');
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
|
||||
try {
|
||||
let running = true;
|
||||
while (running) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
running = false;
|
||||
break;
|
||||
}
|
||||
|
||||
// Decode the chunk and add it to our buffer
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
// Process complete SSE events
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() || ''; // Keep the last incomplete event in the buffer
|
||||
|
||||
for (const event of events) {
|
||||
if (event.startsWith('data: ')) {
|
||||
try {
|
||||
const data = event.slice(6); // Remove 'data: ' prefix
|
||||
const parsedEvent = JSON.parse(data) as MessageEvent;
|
||||
|
||||
switch (parsedEvent.type) {
|
||||
case 'Message':
|
||||
// Update messages with the new message
|
||||
currentMessages = [...currentMessages, parsedEvent.message];
|
||||
mutate(currentMessages, false);
|
||||
break;
|
||||
|
||||
case 'Error':
|
||||
throw new Error(parsedEvent.error);
|
||||
|
||||
case 'Finish':
|
||||
// Call onFinish with the last message if available
|
||||
if (onFinish && currentMessages.length > 0) {
|
||||
const lastMessage = currentMessages[currentMessages.length - 1];
|
||||
onFinish(lastMessage, parsedEvent.reason);
|
||||
}
|
||||
break;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error parsing SSE event:', e);
|
||||
if (onError && e instanceof Error) {
|
||||
onError(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.name !== 'AbortError') {
|
||||
console.error('Error reading SSE stream:', e);
|
||||
if (onError) {
|
||||
onError(e);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
return currentMessages;
|
||||
},
|
||||
[mutate, onFinish, onError]
|
||||
);
|
||||
|
||||
// Send a request to the server
|
||||
const sendRequest = useCallback(
|
||||
async (requestMessages: Message[]) => {
|
||||
try {
|
||||
mutateLoading(true);
|
||||
setError(undefined);
|
||||
|
||||
// Create abort controller
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
// Log the request messages for debugging
|
||||
console.log('Sending messages to server:', JSON.stringify(requestMessages, null, 2));
|
||||
|
||||
// Send request to the server
|
||||
const response = await fetch(api, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Secret-Key': getSecretKey(),
|
||||
...extraMetadataRef.current.headers,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
messages: requestMessages,
|
||||
...extraMetadataRef.current.body,
|
||||
}),
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
if (onResponse) {
|
||||
await onResponse(response);
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(text || `Error ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
// Process the SSE stream
|
||||
const updatedMessages = await processMessageStream(response, requestMessages);
|
||||
|
||||
// Auto-submit when all tool calls in the last assistant message have results
|
||||
if (maxSteps > 1 && updatedMessages.length > requestMessages.length) {
|
||||
const lastMessage = updatedMessages[updatedMessages.length - 1];
|
||||
if (lastMessage.role === 'assistant' && hasCompletedToolCalls(lastMessage)) {
|
||||
// Count trailing assistant messages to prevent infinite loops
|
||||
let assistantCount = 0;
|
||||
for (let i = updatedMessages.length - 1; i >= 0; i--) {
|
||||
if (updatedMessages[i].role === 'assistant') {
|
||||
assistantCount++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (assistantCount < maxSteps) {
|
||||
await sendRequest(updatedMessages);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abortControllerRef.current = null;
|
||||
} catch (err) {
|
||||
// Ignore abort errors as they are expected
|
||||
if (err instanceof Error && err.name === 'AbortError') {
|
||||
abortControllerRef.current = null;
|
||||
return;
|
||||
}
|
||||
|
||||
if (onError && err instanceof Error) {
|
||||
onError(err);
|
||||
}
|
||||
|
||||
setError(err as Error);
|
||||
} finally {
|
||||
mutateLoading(false);
|
||||
}
|
||||
},
|
||||
[api, processMessageStream, mutateLoading, setError, onResponse, onError, maxSteps]
|
||||
);
|
||||
|
||||
// Append a new message and send request
|
||||
const append = useCallback(
|
||||
async (message: Message | string) => {
|
||||
// If a string is passed, convert it to a Message object
|
||||
const messageToAppend = typeof message === 'string' ? createUserMessage(message) : message;
|
||||
|
||||
console.log('Appending message:', JSON.stringify(messageToAppend, null, 2));
|
||||
|
||||
const currentMessages = [...messagesRef.current, messageToAppend];
|
||||
mutate(currentMessages, false);
|
||||
await sendRequest(currentMessages);
|
||||
},
|
||||
[mutate, sendRequest]
|
||||
);
|
||||
|
||||
// Reload the last message
|
||||
const reload = useCallback(async () => {
|
||||
const currentMessages = messagesRef.current;
|
||||
if (currentMessages.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove last assistant message if present
|
||||
const lastMessage = currentMessages[currentMessages.length - 1];
|
||||
const messagesToSend =
|
||||
lastMessage.role === 'assistant' ? currentMessages.slice(0, -1) : currentMessages;
|
||||
|
||||
await sendRequest(messagesToSend);
|
||||
}, [sendRequest]);
|
||||
|
||||
// Stop the current request
|
||||
const stop = useCallback(() => {
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Set messages directly
|
||||
const setMessages = useCallback(
|
||||
(messagesOrFn: Message[] | ((messages: Message[]) => Message[])) => {
|
||||
if (typeof messagesOrFn === 'function') {
|
||||
const newMessages = messagesOrFn(messagesRef.current);
|
||||
mutate(newMessages, false);
|
||||
messagesRef.current = newMessages;
|
||||
} else {
|
||||
mutate(messagesOrFn, false);
|
||||
messagesRef.current = messagesOrFn;
|
||||
}
|
||||
},
|
||||
[mutate]
|
||||
);
|
||||
|
||||
// Input state and handlers
|
||||
const [input, setInput] = useState(initialInput);
|
||||
|
||||
const handleInputChange = useCallback(
|
||||
(e: React.ChangeEvent<HTMLInputElement> | React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInput(e.target.value);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (event?: { preventDefault?: () => void }) => {
|
||||
event?.preventDefault?.();
|
||||
if (!input.trim()) return;
|
||||
|
||||
console.log('handleSubmit called with input:', input);
|
||||
await append(input);
|
||||
setInput('');
|
||||
},
|
||||
[input, append]
|
||||
);
|
||||
|
||||
// Add tool result to a message
|
||||
const addToolResult = useCallback(
|
||||
({ toolCallId, result }: { toolCallId: string; result: unknown }) => {
|
||||
const currentMessages = messagesRef.current;
|
||||
|
||||
// Find the last assistant message with the tool call
|
||||
let lastAssistantIndex = -1;
|
||||
for (let i = currentMessages.length - 1; i >= 0; i--) {
|
||||
if (currentMessages[i].role === 'assistant') {
|
||||
const toolRequests = currentMessages[i].content.filter(
|
||||
(content) => content.type === 'toolRequest' && content.id === toolCallId
|
||||
);
|
||||
if (toolRequests.length > 0) {
|
||||
lastAssistantIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (lastAssistantIndex === -1) return;
|
||||
|
||||
// Create a tool response message
|
||||
const toolResponseMessage: Message = {
|
||||
role: 'user' as const,
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
content: [
|
||||
{
|
||||
type: 'toolResponse' as const,
|
||||
id: toolCallId,
|
||||
toolResult: {
|
||||
status: 'success' as const,
|
||||
value: Array.isArray(result)
|
||||
? result
|
||||
: [{ type: 'text' as const, text: String(result), priority: 0 }],
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// Insert the tool response after the assistant message
|
||||
const updatedMessages = [
|
||||
...currentMessages.slice(0, lastAssistantIndex + 1),
|
||||
toolResponseMessage,
|
||||
...currentMessages.slice(lastAssistantIndex + 1),
|
||||
];
|
||||
|
||||
mutate(updatedMessages, false);
|
||||
messagesRef.current = updatedMessages;
|
||||
|
||||
// Auto-submit if we have tool results
|
||||
if (maxSteps > 1) {
|
||||
sendRequest(updatedMessages);
|
||||
}
|
||||
},
|
||||
[mutate, maxSteps, sendRequest]
|
||||
);
|
||||
|
||||
return {
|
||||
messages: messages || [],
|
||||
error,
|
||||
append,
|
||||
reload,
|
||||
stop,
|
||||
setMessages,
|
||||
input,
|
||||
setInput,
|
||||
handleInputChange,
|
||||
handleSubmit,
|
||||
isLoading: isLoading || false,
|
||||
addToolResult,
|
||||
};
|
||||
}
|
||||
198
ui/desktop/src/types/message.ts
Normal file
198
ui/desktop/src/types/message.ts
Normal file
@@ -0,0 +1,198 @@
|
||||
/**
|
||||
* Message types that match the Rust message structures
|
||||
* for direct serialization between client and server
|
||||
*/
|
||||
|
||||
export type Role = 'user' | 'assistant';
|
||||
|
||||
export interface TextContent {
|
||||
type: 'text';
|
||||
text: string;
|
||||
annotations?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface ImageContent {
|
||||
type: 'image';
|
||||
data: string;
|
||||
mimeType: string;
|
||||
annotations?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export type Content = TextContent | ImageContent;
|
||||
|
||||
export interface ToolCall {
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface ToolCallResult<T> {
|
||||
status: 'success' | 'error';
|
||||
value?: T;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface ToolRequest {
|
||||
id: string;
|
||||
toolCall: ToolCallResult<ToolCall>;
|
||||
}
|
||||
|
||||
export interface ToolResponse {
|
||||
id: string;
|
||||
toolResult: ToolCallResult<Content[]>;
|
||||
}
|
||||
|
||||
export interface ToolConfirmationRequest {
|
||||
id: string;
|
||||
toolName: string;
|
||||
arguments: Record<string, unknown>;
|
||||
prompt?: string;
|
||||
}
|
||||
|
||||
export interface ToolRequestMessageContent {
|
||||
type: 'toolRequest';
|
||||
id: string;
|
||||
toolCall: ToolCallResult<ToolCall>;
|
||||
}
|
||||
|
||||
export interface ToolResponseMessageContent {
|
||||
type: 'toolResponse';
|
||||
id: string;
|
||||
toolResult: ToolCallResult<Content[]>;
|
||||
}
|
||||
|
||||
export interface ToolConfirmationRequestMessageContent {
|
||||
type: 'toolConfirmationRequest';
|
||||
id: string;
|
||||
toolName: string;
|
||||
arguments: Record<string, unknown>;
|
||||
prompt?: string;
|
||||
}
|
||||
|
||||
export type MessageContent =
|
||||
| TextContent
|
||||
| ImageContent
|
||||
| ToolRequestMessageContent
|
||||
| ToolResponseMessageContent
|
||||
| ToolConfirmationRequestMessageContent;
|
||||
|
||||
export interface Message {
|
||||
id?: string;
|
||||
role: Role;
|
||||
created: number;
|
||||
content: MessageContent[];
|
||||
}
|
||||
|
||||
// Helper functions to create messages
|
||||
export function createUserMessage(text: string): Message {
|
||||
return {
|
||||
id: generateId(),
|
||||
role: 'user',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
content: [{ type: 'text', text }],
|
||||
};
|
||||
}
|
||||
|
||||
export function createAssistantMessage(text: string): Message {
|
||||
return {
|
||||
id: generateId(),
|
||||
role: 'assistant',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
content: [{ type: 'text', text }],
|
||||
};
|
||||
}
|
||||
|
||||
export function createToolRequestMessage(
|
||||
id: string,
|
||||
toolName: string,
|
||||
args: Record<string, unknown>
|
||||
): Message {
|
||||
return {
|
||||
id: generateId(),
|
||||
role: 'assistant',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
content: [
|
||||
{
|
||||
type: 'toolRequest',
|
||||
id,
|
||||
toolCall: {
|
||||
status: 'success',
|
||||
value: {
|
||||
name: toolName,
|
||||
arguments: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
export function createToolResponseMessage(id: string, result: Content[]): Message {
|
||||
return {
|
||||
id: generateId(),
|
||||
role: 'user',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
content: [
|
||||
{
|
||||
type: 'toolResponse',
|
||||
id,
|
||||
toolResult: {
|
||||
status: 'success',
|
||||
value: result,
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
export function createToolErrorResponseMessage(id: string, error: string): Message {
|
||||
return {
|
||||
id: generateId(),
|
||||
role: 'user',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
content: [
|
||||
{
|
||||
type: 'toolResponse',
|
||||
id,
|
||||
toolResult: {
|
||||
status: 'error',
|
||||
error,
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
// Generate a unique ID for messages
|
||||
function generateId(): string {
|
||||
return Math.random().toString(36).substring(2, 10);
|
||||
}
|
||||
|
||||
// Helper functions to extract content from messages
|
||||
export function getTextContent(message: Message): string {
|
||||
return message.content
|
||||
.filter((content): content is TextContent => content.type === 'text')
|
||||
.map((content) => content.text)
|
||||
.join('\n');
|
||||
}
|
||||
|
||||
export function getToolRequests(message: Message): ToolRequestMessageContent[] {
|
||||
return message.content.filter(
|
||||
(content): content is ToolRequestMessageContent => content.type === 'toolRequest'
|
||||
);
|
||||
}
|
||||
|
||||
export function getToolResponses(message: Message): ToolResponseMessageContent[] {
|
||||
return message.content.filter(
|
||||
(content): content is ToolResponseMessageContent => content.type === 'toolResponse'
|
||||
);
|
||||
}
|
||||
|
||||
export function hasCompletedToolCalls(message: Message): boolean {
|
||||
const toolRequests = getToolRequests(message);
|
||||
if (toolRequests.length === 0) return false;
|
||||
|
||||
// For now, we'll assume all tool calls are completed when this is checked
|
||||
// In a real implementation, you'd need to check if all tool requests have responses
|
||||
// by looking through subsequent messages
|
||||
return true;
|
||||
}
|
||||
Reference in New Issue
Block a user