draft: use rust messages in typescript (#1393)

This commit is contained in:
Bradley Axen
2025-02-27 04:02:43 +01:00
committed by GitHub
parent 552facb7ef
commit 0602b35ddc
19 changed files with 1224 additions and 1537 deletions

View File

@@ -10,9 +10,8 @@ use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::message::{Message, MessageContent};
use mcp_core::{content::Content, role::Role};
use serde::Deserialize;
use serde_json::{json, Value};
use mcp_core::role::Role;
use serde::{Deserialize, Serialize};
use std::{
convert::Infallible,
pin::Pin,
@@ -23,33 +22,13 @@ use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
// Types matching the incoming JSON structure
// Direct message serialization for the chat request
#[derive(Debug, Deserialize)]
struct ChatRequest {
messages: Vec<IncomingMessage>,
messages: Vec<Message>,
}
#[derive(Debug, Deserialize)]
struct IncomingMessage {
role: String,
content: String,
#[serde(default)]
#[serde(rename = "toolInvocations")]
tool_invocations: Vec<ToolInvocation>,
}
#[derive(Debug, Deserialize)]
struct ToolInvocation {
state: String,
#[serde(rename = "toolCallId")]
tool_call_id: String,
#[serde(rename = "toolName")]
tool_name: String,
args: Value,
result: Option<Vec<Content>>,
}
// Custom SSE response type that implements the Vercel AI SDK protocol
// Custom SSE response type for streaming messages
pub struct SseResponse {
rx: ReceiverStream<String>,
}
@@ -79,188 +58,32 @@ impl IntoResponse for SseResponse {
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.header("x-vercel-ai-data-stream", "v1")
.body(body)
.unwrap()
}
}
// Convert incoming messages to our internal Message type
fn convert_messages(incoming: Vec<IncomingMessage>) -> Vec<Message> {
let mut messages = Vec::new();
for msg in incoming {
match msg.role.as_str() {
"user" => {
messages.push(Message::user().with_text(msg.content));
}
"assistant" => {
// First handle any tool invocations - each represents a complete request/response cycle
for tool in msg.tool_invocations {
if tool.state == "result" {
// Add the original tool request from assistant
let tool_call = mcp_core::tool::ToolCall {
name: tool.tool_name,
arguments: tool.args,
};
messages.push(
Message::assistant()
.with_tool_request(tool.tool_call_id.clone(), Ok(tool_call)),
);
// Add the tool response from user
if let Some(result) = &tool.result {
messages.push(
Message::user()
.with_tool_response(tool.tool_call_id, Ok(result.clone())),
);
}
}
}
// Then add the assistant's text response after tool interactions
if !msg.content.is_empty() {
messages.push(Message::assistant().with_text(msg.content));
}
}
_ => {
tracing::warn!("Unknown role: {}", msg.role);
}
}
}
messages
// Message event types for SSE streaming.
//
// `#[serde(tag = "type")]` serializes each variant as a JSON object with a
// `"type"` discriminator field, e.g. `{"type":"Message","message":{…}}`, so
// the TypeScript client can switch on `type`.
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum MessageEvent {
    // A complete agent message forwarded to the client.
    Message { message: Message },
    // An error surfaced to the client (stream setup or per-message failure).
    Error { error: String },
    // Terminal event; `reason` is "stop" on success or "error" on failure
    // (see the handler's finish/error paths).
    Finish { reason: String },
}
// Protocol-specific message formatting
struct ProtocolFormatter;
impl ProtocolFormatter {
fn format_text(text: &str) -> String {
let encoded_text = serde_json::to_string(text).unwrap_or_else(|_| String::new());
format!("0:{}\n", encoded_text)
}
fn format_tool_call(id: &str, name: &str, args: &Value) -> String {
// Tool calls start with "9:"
let tool_call = json!({
"toolCallId": id,
"toolName": name,
"args": args
});
format!("9:{}\n", tool_call)
}
fn format_tool_response(id: &str, result: &Vec<Content>) -> String {
// Tool responses start with "a:"
let response = json!({
"toolCallId": id,
"result": result,
});
format!("a:{}\n", response)
}
fn format_error(error: &str) -> String {
// Error messages start with "3:" in the new protocol.
let encoded_error = serde_json::to_string(error).unwrap_or_else(|_| String::new());
format!("3:{}\n", encoded_error)
}
fn format_finish(reason: &str) -> String {
// Finish messages start with "d:"
let finish = json!({
"finishReason": reason,
"usage": {
"promptTokens": 0,
"completionTokens": 0
}
});
format!("d:{}\n", finish)
}
}
async fn stream_message(
message: Message,
// Stream a message as an SSE event
async fn stream_event(
event: MessageEvent,
tx: &mpsc::Sender<String>,
) -> Result<(), mpsc::error::SendError<String>> {
match message.role {
Role::User => {
// Handle tool responses
for content in message.content {
// I believe with the protocol we aren't intended to pass back user messages, so we only deal with
// the tool responses here
if let MessageContent::ToolResponse(response) = content {
// We should return a result for either an error or a success
match response.tool_result {
Ok(result) => {
tx.send(ProtocolFormatter::format_tool_response(
&response.id,
&result,
))
.await?;
}
Err(err) => {
// Send an error message first
tx.send(ProtocolFormatter::format_error(&err.to_string()))
.await?;
// Then send an empty tool response to maintain the protocol
let result =
vec![Content::text(format!("Error: {}", err)).with_priority(0.0)];
tx.send(ProtocolFormatter::format_tool_response(
&response.id,
&result,
))
.await?;
}
}
}
}
}
Role::Assistant => {
for content in message.content {
match content {
MessageContent::ToolRequest(request) => {
match request.tool_call {
Ok(tool_call) => {
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
&tool_call.name,
&tool_call.arguments,
))
.await?;
}
Err(err) => {
// Send a placeholder tool call to maintain protocol
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
"invalid_tool",
&json!({"error": err.to_string()}),
))
.await?;
}
}
}
MessageContent::Text(text) => {
for line in text.text.lines() {
let modified_line = format!("{}\n", line);
tx.send(ProtocolFormatter::format_text(&modified_line))
.await?;
}
}
MessageContent::ToolConfirmationRequest(_) => {
// skip tool confirmation requests
}
MessageContent::Image(_) => {
// skip images
}
MessageContent::ToolResponse(_) => {
// skip tool responses
}
}
}
}
}
Ok(())
let json = serde_json::to_string(&event).unwrap_or_else(|e| {
format!(
r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#,
e
)
});
tx.send(format!("data: {}\n\n", json)).await
}
async fn handler(
@@ -278,19 +101,12 @@ async fn handler(
return Err(StatusCode::UNAUTHORIZED);
}
// Check protocol header (optional in our case)
if let Some(protocol) = headers.get("x-protocol") {
if protocol.to_str().map(|p| p != "data").unwrap_or(true) {
return Err(StatusCode::BAD_REQUEST);
}
}
// Create channel for streaming
let (tx, rx) = mpsc::channel(100);
let stream = ReceiverStream::new(rx);
// Convert incoming messages
let messages = convert_messages(request.messages);
// Get messages directly from the request
let messages = request.messages;
// Get a lock on the shared agent
let agent = state.agent.clone();
@@ -301,10 +117,20 @@ async fn handler(
let agent = match agent.as_ref() {
Some(agent) => agent,
None => {
let _ = tx
.send(ProtocolFormatter::format_error("No agent configured"))
.await;
let _ = tx.send(ProtocolFormatter::format_finish("error")).await;
let _ = stream_event(
MessageEvent::Error {
error: "No agent configured".to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
)
.await;
return;
}
};
@@ -313,10 +139,20 @@ async fn handler(
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {:?}", e);
let _ = tx
.send(ProtocolFormatter::format_error(&e.to_string()))
.await;
let _ = tx.send(ProtocolFormatter::format_finish("error")).await;
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
)
.await;
let _ = stream_event(
MessageEvent::Finish {
reason: "error".to_string(),
},
&tx,
)
.await;
return;
}
};
@@ -326,25 +162,32 @@ async fn handler(
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(message))) => {
if let Err(e) = stream_message(message, &tx).await {
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
break;
}
Ok(None) => {
break;
}
Err(_) => { // Heartbeat, used to detect disconnected clients and then end running tools.
Err(_) => { // Heartbeat, used to detect disconnected clients
if tx.is_closed() {
// Kill any running processes when the client disconnects
// TODO is this used? I suspect post MCP this is on the server instead
// goose::process_store::kill_processes();
break;
}
continue;
@@ -354,24 +197,30 @@ async fn handler(
}
}
// Send finish message
let _ = tx.send(ProtocolFormatter::format_finish("stop")).await;
// Send finish event
let _ = stream_event(
MessageEvent::Finish {
reason: "stop".to_string(),
},
&tx,
)
.await;
});
Ok(SseResponse::new(stream))
}
#[derive(Debug, Deserialize, serde::Serialize)]
#[derive(Debug, Deserialize, Serialize)]
struct AskRequest {
prompt: String,
}
#[derive(Debug, serde::Serialize)]
#[derive(Debug, Serialize)]
struct AskResponse {
response: String,
}
// simple ask an AI for a response, non streaming
// Simple ask an AI for a response, non streaming
async fn ask_handler(
State(state): State<AppState>,
headers: HeaderMap,
@@ -478,85 +327,6 @@ mod tests {
}
}
#[test]
fn test_convert_messages_user_only() {
let incoming = vec![IncomingMessage {
role: "user".to_string(),
content: "Hello".to_string(),
tool_invocations: vec![],
}];
let messages = convert_messages(incoming);
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, Role::User);
assert!(
matches!(&messages[0].content[0], MessageContent::Text(text) if text.text == "Hello")
);
}
#[test]
fn test_convert_messages_with_tool_invocation() {
let tool_result = vec![Content::text("tool response").with_priority(0.0)];
let incoming = vec![IncomingMessage {
role: "assistant".to_string(),
content: "".to_string(),
tool_invocations: vec![ToolInvocation {
state: "result".to_string(),
tool_call_id: "123".to_string(),
tool_name: "test_tool".to_string(),
args: json!({"key": "value"}),
result: Some(tool_result.clone()),
}],
}];
let messages = convert_messages(incoming);
assert_eq!(messages.len(), 2); // Tool request and response
// Check tool request
assert_eq!(messages[0].role, Role::Assistant);
assert!(
matches!(&messages[0].content[0], MessageContent::ToolRequest(req) if req.id == "123")
);
// Check tool response
assert_eq!(messages[1].role, Role::User);
assert!(
matches!(&messages[1].content[0], MessageContent::ToolResponse(resp) if resp.id == "123")
);
}
#[test]
fn test_protocol_formatter() {
// Test text formatting
let text = "Hello world";
let formatted = ProtocolFormatter::format_text(text);
assert_eq!(formatted, "0:\"Hello world\"\n");
// Test tool call formatting
let formatted =
ProtocolFormatter::format_tool_call("123", "test_tool", &json!({"key": "value"}));
assert!(formatted.starts_with("9:"));
assert!(formatted.contains("\"toolCallId\":\"123\""));
assert!(formatted.contains("\"toolName\":\"test_tool\""));
// Test tool response formatting
let result = vec![Content::text("response").with_priority(0.0)];
let formatted = ProtocolFormatter::format_tool_response("123", &result);
assert!(formatted.starts_with("a:"));
assert!(formatted.contains("\"toolCallId\":\"123\""));
// Test error formatting
let formatted = ProtocolFormatter::format_error("Test error");
println!("Formatted error: {}", formatted);
assert!(formatted.starts_with("3:"));
assert!(formatted.contains("Test error"));
// Test finish formatting
let formatted = ProtocolFormatter::format_finish("stop");
assert!(formatted.starts_with("d:"));
assert!(formatted.contains("\"finishReason\":\"stop\""));
}
mod integration_tests {
use super::*;
use axum::{body::Body, http::Request};
@@ -575,7 +345,7 @@ mod tests {
});
let agent = AgentFactory::create("reference", mock_provider).unwrap();
let state = AppState {
config: Arc::new(Mutex::new(HashMap::new())), // Add this line
config: Arc::new(Mutex::new(HashMap::new())),
agent: Arc::new(Mutex::new(Some(agent))),
secret_key: "test-secret".to_string(),
};

View File

@@ -14,19 +14,26 @@ use mcp_core::role::Role;
use mcp_core::tool::ToolCall;
use serde_json::Value;
mod tool_result_serde;
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolRequest {
pub id: String,
#[serde(with = "tool_result_serde")]
pub tool_call: ToolResult<ToolCall>,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
pub id: String,
#[serde(with = "tool_result_serde")]
pub tool_result: ToolResult<Vec<Content>>,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfirmationRequest {
pub id: String,
pub tool_name: String,
@@ -36,6 +43,7 @@ pub struct ToolConfirmationRequest {
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MessageContent {
Text(TextContent),
Image(ImageContent),
@@ -150,6 +158,7 @@ impl From<Content> for MessageContent {
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
pub role: Role,
pub created: i64,
@@ -292,3 +301,123 @@ impl Message {
.all(|c| matches!(c, MessageContent::Text(_)))
}
}
#[cfg(test)]
mod tests {
    //! Round-trip tests for the camelCase, type-tagged JSON representation of
    //! `Message` and the `tool_result_serde` status/value/error encoding.
    use super::*;
    use mcp_core::handler::ToolError;
    use serde_json::{json, Value};

    #[test]
    fn test_message_serialization() {
        // Assistant message carrying both text and a successful tool request.
        let message = Message::assistant()
            .with_text("Hello, I'll help you with that.")
            .with_tool_request(
                "tool123",
                Ok(ToolCall::new("test_tool", json!({"param": "value"}))),
            );
        let json_str = serde_json::to_string_pretty(&message).unwrap();
        println!("Serialized message: {}", json_str);
        // Parse back to Value to check structure
        let value: Value = serde_json::from_str(&json_str).unwrap();
        // Check top-level fields
        assert_eq!(value["role"], "assistant");
        assert!(value["created"].is_i64());
        assert!(value["content"].is_array());
        // Check content items
        let content = &value["content"];
        // First item should be text
        assert_eq!(content[0]["type"], "text");
        assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
        // Second item should be toolRequest (camelCase type tag)
        assert_eq!(content[1]["type"], "toolRequest");
        assert_eq!(content[1]["id"], "tool123");
        // Check tool_call serialization: success wraps the call under "value"
        assert_eq!(content[1]["toolCall"]["status"], "success");
        assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
        assert_eq!(
            content[1]["toolCall"]["value"]["arguments"]["param"],
            "value"
        );
    }

    #[test]
    fn test_error_serialization() {
        // A failed tool request serializes as {"status":"error","error":…},
        // with the error stringified via Display.
        let message = Message::assistant().with_tool_request(
            "tool123",
            Err(ToolError::ExecutionError(
                "Something went wrong".to_string(),
            )),
        );
        let json_str = serde_json::to_string_pretty(&message).unwrap();
        println!("Serialized error: {}", json_str);
        // Parse back to Value to check structure
        let value: Value = serde_json::from_str(&json_str).unwrap();
        // Check tool_call serialization with error
        let tool_call = &value["content"][0]["toolCall"];
        assert_eq!(tool_call["status"], "error");
        assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
    }

    #[test]
    fn test_deserialization() {
        // Create a JSON string with our new format
        let json_str = r#"{
            "role": "assistant",
            "created": 1740171566,
            "content": [
                {
                    "type": "text",
                    "text": "I'll help you with that."
                },
                {
                    "type": "toolRequest",
                    "id": "tool123",
                    "toolCall": {
                        "status": "success",
                        "value": {
                            "name": "test_tool",
                            "arguments": {"param": "value"}
                        }
                    }
                }
            ]
        }"#;
        let message: Message = serde_json::from_str(json_str).unwrap();
        assert_eq!(message.role, Role::Assistant);
        assert_eq!(message.created, 1740171566);
        assert_eq!(message.content.len(), 2);
        // Check first content item
        if let MessageContent::Text(text) = &message.content[0] {
            assert_eq!(text.text, "I'll help you with that.");
        } else {
            panic!("Expected Text content");
        }
        // Check second content item
        if let MessageContent::ToolRequest(req) = &message.content[1] {
            assert_eq!(req.id, "tool123");
            if let Ok(tool_call) = &req.tool_call {
                assert_eq!(tool_call.name, "test_tool");
                assert_eq!(tool_call.arguments, json!({"param": "value"}));
            } else {
                panic!("Expected successful tool call");
            }
        } else {
            panic!("Expected ToolRequest content");
        }
    }
}

View File

@@ -0,0 +1,64 @@
use mcp_core::handler::{ToolError, ToolResult};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Serializes a `ToolResult<T>` as a status-tagged struct:
/// `{"status":"success","value":…}` for `Ok`, or
/// `{"status":"error","error":"…"}` for `Err` (error stringified via Display).
pub fn serialize<T, S>(value: &ToolResult<T>, serializer: S) -> Result<S::Ok, S::Error>
where
    T: Serialize,
    S: Serializer,
{
    // Both arms emit a two-field struct; only the second field differs.
    let mut state = serializer.serialize_struct("ToolResult", 2)?;
    match value {
        Ok(val) => {
            state.serialize_field("status", "success")?;
            state.serialize_field("value", val)?;
        }
        Err(err) => {
            state.serialize_field("status", "error")?;
            state.serialize_field("error", &err.to_string())?;
        }
    }
    state.end()
}
// Deserialization mirrors the format produced by `serialize` above rather
// than attempting to round-trip the full `ToolError` type.
/// Deserializes `{"status":"success","value":…}` into `Ok(value)` and
/// `{"status":"error","error":…}` into `Err(ToolError::ExecutionError(…))`.
///
/// NOTE(review): the original error variant is NOT round-tripped — every
/// deserialized error comes back as `ToolError::ExecutionError` carrying the
/// stringified message.
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<ToolResult<T>, D::Error>
where
    T: Deserialize<'de>,
    D: Deserializer<'de>,
{
    // Helper enum matching the two shapes emitted by `serialize`. With
    // `untagged`, serde tries `Success` first and falls back to `Error`; the
    // `status` string is then re-checked below so a payload whose fields match
    // one variant but whose status names the other is rejected.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum ResultFormat<T> {
        Success { status: String, value: T },
        Error { status: String, error: String },
    }

    let format = ResultFormat::deserialize(deserializer)?;
    match format {
        ResultFormat::Success { status, value } => {
            if status == "success" {
                Ok(Ok(value))
            } else {
                Err(serde::de::Error::custom(format!(
                    "Expected status 'success', got '{}'",
                    status
                )))
            }
        }
        ResultFormat::Error { status, error } => {
            if status == "error" {
                Ok(Err(ToolError::ExecutionError(error)))
            } else {
                Err(serde::de::Error::custom(format!(
                    "Expected status 'error', got '{}'",
                    status
                )))
            }
        }
    }
}

View File

@@ -1,5 +0,0 @@
# AI
This is a small fork of some files in the Vercel AI SDK to make a custom version of useChat which doesn't append text content to messages.
We can work to surface our desired functionality and upstream a change that could support it.

View File

@@ -1,92 +0,0 @@
import { processCustomChatResponse } from './process-custom-chat-response';
import { IdGenerator, JSONValue, Message, UseChatOptions } from '@ai-sdk/ui-utils';
// use function to allow for mocking in tests:
const getOriginalFetch = () => fetch;
/**
 * Sends a chat request to the API and streams the response back through the
 * data-stream protocol, forwarding parsed events to the provided callbacks.
 *
 * Rolls back the optimistic message update (`restoreMessagesOnFailure`) when
 * the fetch itself fails or the server returns a non-OK status.
 *
 * @throws if the fetch fails, the response is not ok, the response body is
 *   empty, or `streamProtocol` is `'text'` (unsupported here).
 */
export async function callCustomChatApi({
  api,
  body,
  streamProtocol = 'data',
  credentials,
  headers,
  abortController,
  restoreMessagesOnFailure,
  onResponse,
  onUpdate,
  onFinish,
  onToolCall,
  generateId,
  fetch = getOriginalFetch(),
}: {
  api: string;
  body: Record<string, any>;
  streamProtocol: 'data' | 'text' | undefined;
  credentials: RequestCredentials | undefined;
  headers: HeadersInit | undefined;
  abortController: (() => AbortController | null) | undefined;
  restoreMessagesOnFailure: () => void;
  onResponse: ((response: Response) => void | Promise<void>) | undefined;
  onUpdate: (newMessages: Message[], data: JSONValue[] | undefined) => void;
  onFinish: UseChatOptions['onFinish'];
  onToolCall: UseChatOptions['onToolCall'];
  generateId: IdGenerator;
  fetch: ReturnType<typeof getOriginalFetch> | undefined;
}) {
  const response = await fetch(api, {
    method: 'POST',
    body: JSON.stringify(body),
    headers: {
      'Content-Type': 'application/json',
      ...headers,
    },
    signal: abortController?.()?.signal,
    credentials,
  }).catch((err) => {
    restoreMessagesOnFailure();
    throw err;
  });

  if (onResponse) {
    // Let the caller inspect the raw response; its errors propagate as-is
    // (the previous try/catch that only re-threw was a no-op and is removed).
    await onResponse(response);
  }

  if (!response.ok) {
    restoreMessagesOnFailure();
    // Bug fix: `response.text()` always resolves to a string (never nullish),
    // so `??` could never supply the fallback; `||` also covers an empty body.
    throw new Error((await response.text()) || 'Failed to fetch the chat response.');
  }

  if (!response.body) {
    throw new Error('The response body is empty.');
  }

  switch (streamProtocol) {
    case 'text': {
      throw new Error('Text protocol not supported in custom chat API');
    }
    case 'data': {
      await processCustomChatResponse({
        stream: response.body,
        update: onUpdate,
        onToolCall,
        onFinish({ message, finishReason, usage }) {
          if (onFinish && message != null) {
            onFinish(message, { usage, finishReason });
          }
        },
        generateId,
      });
      return;
    }
    default: {
      // Exhaustiveness check: compile error if a new protocol value appears.
      const exhaustiveCheck: never = streamProtocol;
      throw new Error(`Unknown stream protocol: ${exhaustiveCheck}`);
    }
  }
}

View File

@@ -1,13 +0,0 @@
/** Token usage reported for a language-model call. */
export interface LanguageModelUsage {
  completionTokens: number;
  promptTokens: number;
  totalTokens: number;
}

/** Returns a fresh usage object holding the same three token counts. */
export function calculateLanguageModelUsage(usage: LanguageModelUsage): LanguageModelUsage {
  const { completionTokens, promptTokens, totalTokens } = usage;
  return { completionTokens, promptTokens, totalTokens };
}

View File

@@ -1,239 +0,0 @@
import { generateId as generateIdFunction } from '@ai-sdk/provider-utils';
import type { JSONValue, Message } from '@ai-sdk/ui-utils';
import { parsePartialJson, processDataStream } from '@ai-sdk/ui-utils';
import { LanguageModelV1FinishReason } from '@ai-sdk/provider';
import { LanguageModelUsage } from './core/types/usage';
// Local stand-in for the SDK's usage helper (the original is not exported):
// copies the three token counters into a new object.
function calculateLanguageModelUsage(usage: LanguageModelUsage): LanguageModelUsage {
  return {
    completionTokens: usage.completionTokens,
    promptTokens: usage.promptTokens,
    totalTokens: usage.totalTokens,
  } satisfies LanguageModelUsage;
}
/**
 * Consumes a data stream from the chat endpoint and incrementally rebuilds
 * assistant `Message` objects from its parts (text deltas, streamed tool
 * calls, tool results, data parts, finish events), invoking `update` after
 * every change so the UI can re-render.
 *
 * Forked behavior vs. the stock SDK processor: text that arrives after a tool
 * event starts a NEW message rather than being appended to the current one
 * (see `onTextPart` / `onToolCallStreamingStartPart`).
 *
 * @param stream raw byte stream of the chat response body
 * @param update called with (messages so far, data parts) on every change
 * @param onToolCall optional client-side tool executor; a non-null return
 *   value is recorded as that tool invocation's result
 * @param onFinish called once the stream ends, with the last in-flight
 *   message (if any), the finish reason, and usage
 * @param generateId id factory for new messages and revision ids
 * @param getCurrentDate injectable clock used for message `createdAt`
 */
export async function processCustomChatResponse({
  stream,
  update,
  onToolCall,
  onFinish,
  generateId = generateIdFunction,
  getCurrentDate = () => new Date(),
}: {
  stream: ReadableStream<Uint8Array>;
  update: (newMessages: Message[], data: JSONValue[] | undefined) => void;
  onToolCall?: (options: { toolCall: any }) => Promise<any>;
  onFinish?: (options: {
    message: Message | undefined;
    finishReason: LanguageModelV1FinishReason;
    usage: LanguageModelUsage;
  }) => void;
  generateId?: () => string;
  getCurrentDate?: () => Date;
}) {
  const createdAt = getCurrentDate();
  // The message currently being assembled; completed ones move to
  // `previousMessages` via archiveCurrentMessage().
  let currentMessage: Message | undefined = undefined;
  const previousMessages: Message[] = [];
  const data: JSONValue[] = [];
  // Whether the last stream part was text or tool-related; drives the
  // "new message on text after tool" behavior in onTextPart.
  let lastEventType: 'text' | 'tool' | undefined = undefined;

  // Keep track of partial tool calls, keyed by toolCallId; `index` is the
  // invocation's position in currentMessage.toolInvocations.
  const partialToolCalls: Record<string, { text: string; index: number; toolName: string }> = {};

  // NaN until a finish-message part supplies real numbers.
  let usage: LanguageModelUsage = {
    completionTokens: NaN,
    promptTokens: NaN,
    totalTokens: NaN,
  };
  let finishReason: LanguageModelV1FinishReason = 'unknown';

  // Pushes the full message list (with a deep copy of the in-flight message
  // plus a fresh revisionId, so React sees a new object) to the caller.
  function execUpdate() {
    const copiedData = [...data];

    if (currentMessage == null) {
      update(previousMessages, copiedData);
      return;
    }

    const copiedMessage = {
      ...JSON.parse(JSON.stringify(currentMessage)),
      revisionId: generateId(),
    } as Message;

    update([...previousMessages, copiedMessage], copiedData);
  }

  // Create a new message only if needed
  function createNewMessage(): Message {
    if (currentMessage == null) {
      currentMessage = {
        id: generateId(),
        role: 'assistant',
        content: '',
        createdAt,
      };
    }
    return currentMessage;
  }

  // Move the current message to previous messages if it exists
  function archiveCurrentMessage() {
    if (currentMessage != null) {
      previousMessages.push(currentMessage);
      currentMessage = undefined;
    }
  }

  await processDataStream({
    stream,
    onTextPart(value) {
      // If the last event wasn't text, or we don't have a current message, create a new one
      if (lastEventType !== 'text' || currentMessage == null) {
        // Only archive if there are no tool invocations in 'call' state
        // (those are still awaiting their result and must stay current).
        const hasPendingToolCalls =
          currentMessage?.toolInvocations?.some((invocation) => invocation.state === 'call') ??
          false;

        if (!hasPendingToolCalls) {
          archiveCurrentMessage();
        }

        if (!currentMessage) {
          currentMessage = createNewMessage();
        }
        // First text chunk of a new message replaces the (empty) content.
        currentMessage.content = value;
      } else {
        // Concatenate with the existing message
        currentMessage.content += value;
      }

      lastEventType = 'text';
      execUpdate();
    },
    onToolCallStreamingStartPart(value) {
      // Always create a new message for tool calls
      archiveCurrentMessage();
      currentMessage = createNewMessage();
      lastEventType = 'tool';

      if (currentMessage.toolInvocations == null) {
        currentMessage.toolInvocations = [];
      }

      partialToolCalls[value.toolCallId] = {
        text: '',
        toolName: value.toolName,
        index: currentMessage.toolInvocations.length,
      };

      currentMessage.toolInvocations.push({
        state: 'partial-call',
        toolCallId: value.toolCallId,
        toolName: value.toolName,
        args: undefined,
      });

      execUpdate();
    },
    onToolCallDeltaPart(value) {
      if (!currentMessage) {
        currentMessage = createNewMessage();
      }
      lastEventType = 'tool';

      // NOTE(review): assumes a streaming-start part for this toolCallId was
      // seen first; otherwise this lookup is undefined — confirm upstream.
      const partialToolCall = partialToolCalls[value.toolCallId];
      partialToolCall.text += value.argsTextDelta;

      // Re-parse the accumulated (possibly incomplete) JSON args each delta.
      const { value: partialArgs } = parsePartialJson(partialToolCall.text);

      currentMessage.toolInvocations![partialToolCall.index] = {
        state: 'partial-call',
        toolCallId: value.toolCallId,
        toolName: partialToolCall.toolName,
        args: partialArgs,
      };

      execUpdate();
    },
    async onToolCallPart(value) {
      if (!currentMessage) {
        currentMessage = createNewMessage();
      }
      lastEventType = 'tool';

      if (partialToolCalls[value.toolCallId] != null) {
        // Promote the streamed partial call in place to the 'call' state.
        currentMessage.toolInvocations![partialToolCalls[value.toolCallId].index] = {
          state: 'call',
          ...value,
        };
      } else {
        // Tool call arrived whole (no streaming-start part preceded it).
        if (currentMessage.toolInvocations == null) {
          currentMessage.toolInvocations = [];
        }

        currentMessage.toolInvocations.push({
          state: 'call',
          ...value,
        });
      }

      if (onToolCall) {
        const result = await onToolCall({ toolCall: value });
        if (result != null) {
          // Record the client-side execution result immediately.
          currentMessage.toolInvocations![currentMessage.toolInvocations!.length - 1] = {
            state: 'result',
            ...value,
            result,
          };
        }
      }

      execUpdate();
    },
    onToolResultPart(value) {
      if (!currentMessage) {
        currentMessage = createNewMessage();
      }
      lastEventType = 'tool';

      const toolInvocations = currentMessage.toolInvocations;

      if (toolInvocations == null) {
        throw new Error('tool_result must be preceded by a tool_call');
      }

      const toolInvocationIndex = toolInvocations.findIndex(
        (invocation) => invocation.toolCallId === value.toolCallId
      );

      if (toolInvocationIndex === -1) {
        throw new Error('tool_result must be preceded by a tool_call with the same toolCallId');
      }

      toolInvocations[toolInvocationIndex] = {
        ...toolInvocations[toolInvocationIndex],
        state: 'result' as const,
        ...value,
      };

      execUpdate();
    },
    onDataPart(value) {
      data.push(...value);
      execUpdate();
    },
    onFinishStepPart() {
      // Archive the current message when a step finishes
      archiveCurrentMessage();
    },
    onFinishMessagePart(value) {
      finishReason = value.finishReason;
      if (value.usage != null) {
        usage = calculateLanguageModelUsage(value.usage);
      }
    },
    onErrorPart(error) {
      throw new Error(error);
    },
  });

  onFinish?.({ message: currentMessage, finishReason, usage });
}

View File

@@ -1,5 +0,0 @@
import throttleFunction from 'throttleit';
/**
 * Wraps `fn` so it fires at most once per `waitMs` milliseconds; when
 * `waitMs` is null/undefined, returns `fn` unchanged (throttling disabled).
 */
export function throttle<T extends (...args: any[]) => any>(fn: T, waitMs: number | undefined): T {
  if (waitMs == null) {
    return fn;
  }
  return throttleFunction(fn, waitMs);
}

View File

@@ -1,611 +0,0 @@
import { FetchFunction } from '@ai-sdk/provider-utils';
import type {
Attachment,
ChatRequest,
ChatRequestOptions,
CreateMessage,
IdGenerator,
JSONValue,
Message,
UseChatOptions,
} from '@ai-sdk/ui-utils';
import { generateId as generateIdFunc } from '@ai-sdk/ui-utils';
import { callCustomChatApi as callChatApi } from './call-custom-chat-api';
import { useCallback, useEffect, useId, useRef, useState } from 'react';
import useSWR, { KeyedMutator } from 'swr';
import { throttle } from './throttle';
import { getSecretKey } from '../config';
export type { CreateMessage, Message, UseChatOptions };
/** State and actions returned by the custom `useChat` hook. */
export type UseChatHelpers = {
  /** Current messages in the chat */
  messages: Message[];
  /** The error object of the API request */
  error: undefined | Error;
  /**
   * Append a user message to the chat list. This triggers the API call to fetch
   * the assistant's response.
   * @param message The message to append
   * @param options Additional options to pass to the API call
   */
  append: (
    message: Message | CreateMessage,
    chatRequestOptions?: ChatRequestOptions
  ) => Promise<string | null | undefined>;
  /**
   * Reload the last AI chat response for the given chat history. If the last
   * message isn't from the assistant, it will request the API to generate a
   * new response.
   */
  reload: (chatRequestOptions?: ChatRequestOptions) => Promise<string | null | undefined>;
  /**
   * Abort the current request immediately, keep the generated tokens if any.
   */
  stop: () => void;
  /**
   * Update the `messages` state locally. This is useful when you want to
   * edit the messages on the client, and then trigger the `reload` method
   * manually to regenerate the AI response.
   */
  setMessages: (messages: Message[] | ((messages: Message[]) => Message[])) => void;
  /** The current value of the input */
  input: string;
  /** setState-powered method to update the input value */
  setInput: React.Dispatch<React.SetStateAction<string>>;
  /** An input/textarea-ready onChange handler to control the value of the input */
  handleInputChange: (
    e: React.ChangeEvent<HTMLInputElement> | React.ChangeEvent<HTMLTextAreaElement>
  ) => void;
  /** Form submission handler to automatically reset input and append a user message */
  handleSubmit: (
    event?: { preventDefault?: () => void },
    chatRequestOptions?: ChatRequestOptions
  ) => void;
  /** Caller-defined metadata attached to the chat — shape is not constrained here. */
  metadata?: object;
  /** Whether the API request is in progress */
  isLoading: boolean;
  /** Additional data added on the server via StreamData. */
  data?: JSONValue[];
  /** Set the data of the chat. You can use this to transform or clear the chat data. */
  setData: (
    data: JSONValue[] | undefined | ((data: JSONValue[] | undefined) => JSONValue[] | undefined)
  ) => void;
};
// Performs one chat round-trip: optimistically shows the outgoing messages,
// builds the request payload, and streams the server response back into SWR
// state via `callChatApi`. Returns that call's promise so errors surface to
// the caller.
const processResponseStream = async (
  api: string,
  chatRequest: ChatRequest,
  mutate: KeyedMutator<Message[]>,
  mutateStreamData: KeyedMutator<JSONValue[] | undefined>,
  existingDataRef: React.MutableRefObject<JSONValue[] | undefined>,
  extraMetadataRef: React.MutableRefObject<any>,
  messagesRef: React.MutableRefObject<Message[]>,
  abortControllerRef: React.MutableRefObject<AbortController | null>,
  generateId: IdGenerator,
  streamProtocol: UseChatOptions['streamProtocol'],
  onFinish: UseChatOptions['onFinish'],
  onResponse: ((response: Response) => void | Promise<void>) | undefined,
  onToolCall: UseChatOptions['onToolCall'] | undefined,
  sendExtraMessageFields: boolean | undefined,
  experimental_prepareRequestBody:
    | ((options: {
        messages: Message[];
        requestData?: JSONValue;
        requestBody?: object;
      }) => JSONValue)
    | undefined,
  fetch: FetchFunction | undefined,
  keepLastMessageOnError: boolean
) => {
  // Do an optimistic update to the chat state to show the updated messages immediately:
  const previousMessages = messagesRef.current;
  mutate(chatRequest.messages, false);

  // Unless the caller opts in to extra fields, strip each message down to a
  // known whitelist, including optional fields only when they are defined.
  const constructedMessagesPayload = sendExtraMessageFields
    ? chatRequest.messages
    : chatRequest.messages.map(
        ({ role, content, experimental_attachments, data, annotations, toolInvocations }) => ({
          role,
          content,
          ...(experimental_attachments !== undefined && {
            experimental_attachments,
          }),
          ...(data !== undefined && { data }),
          ...(annotations !== undefined && { annotations }),
          ...(toolInvocations !== undefined && { toolInvocations }),
        })
      );

  const existingData = existingDataRef.current;

  return await callChatApi({
    api,
    // A caller-supplied body builder takes precedence over the default payload.
    body: experimental_prepareRequestBody?.({
      messages: chatRequest.messages,
      requestData: chatRequest.data,
      requestBody: chatRequest.body,
    }) ?? {
      messages: constructedMessagesPayload,
      data: chatRequest.data,
      ...extraMetadataRef.current.body,
      ...chatRequest.body,
    },
    streamProtocol,
    credentials: extraMetadataRef.current.credentials,
    headers: {
      ...extraMetadataRef.current.headers,
      ...chatRequest.headers,
      'X-Secret-Key': getSecretKey(),
    },
    abortController: () => abortControllerRef.current,
    restoreMessagesOnFailure() {
      // Roll back the optimistic update unless configured to keep it.
      if (!keepLastMessageOnError) {
        mutate(previousMessages, false);
      }
    },
    onResponse,
    onUpdate(merged, data) {
      // Replace state with the request messages plus everything streamed so far.
      mutate([...chatRequest.messages, ...merged], false);
      if (data?.length) {
        mutateStreamData([...(existingData ?? []), ...data], false);
      }
    },
    onToolCall,
    onFinish,
    generateId,
    fetch,
  });
};
/**
 * React hook that manages a streaming chat conversation against a chat API.
 *
 * Message list, stream data, loading and error state live in SWR caches keyed
 * by `[api, id]`, so multiple components using the same `id` share state.
 * Returns helpers to append/reload messages, stop an in-flight request,
 * manage the input field, and record tool-call results.
 */
export function useChat({
  api = '/api/chat',
  id,
  initialMessages,
  initialInput = '',
  sendExtraMessageFields,
  onToolCall,
  experimental_prepareRequestBody,
  maxSteps = 1,
  streamProtocol = 'data',
  onResponse,
  onFinish,
  onError,
  credentials,
  headers,
  body,
  generateId = generateIdFunc,
  fetch,
  keepLastMessageOnError = true,
  experimental_throttle: throttleWaitMs,
}: UseChatOptions & {
  key?: string;
  /**
   * Experimental (React only). When a function is provided, it will be used
   * to prepare the request body for the chat API. This can be useful for
   * customizing the request body based on the messages and data in the chat.
   *
   * @param messages The current messages in the chat.
   * @param requestData The data object passed in the chat request.
   * @param requestBody The request body object passed in the chat request.
   */
  experimental_prepareRequestBody?: (options: {
    messages: Message[];
    requestData?: JSONValue;
    requestBody?: object;
  }) => JSONValue;
  /**
  Custom throttle wait in ms for the chat messages and data updates.
  Default is undefined, which disables throttling.
  */
  experimental_throttle?: number;
  /**
  Maximum number of sequential LLM calls (steps), e.g. when you use tool calls. Must be at least 1.
  A maximum number is required to prevent infinite loops in the case of misconfigured tools.
  By default, it's set to 1, which means that only a single LLM call is made.
  */
  maxSteps?: number;
} = {}): UseChatHelpers & {
  addToolResult: ({ toolCallId, result }: { toolCallId: string; result: any }) => void;
} {
  // Generate a unique id for the chat if not provided.
  const hookId = useId();
  const idKey = id ?? hookId;
  // SWR keys are scoped to this chat; hooks sharing `id` share all state below.
  const chatKey = typeof api === 'string' ? [api, idKey] : idKey;
  // Store an empty array as the initial messages
  // (instead of using a default parameter value that gets re-created each time)
  // to avoid re-renders:
  const [initialMessagesFallback] = useState([]);
  // Store the chat state in SWR, using the chatId as the key to share states.
  const { data: messages, mutate } = useSWR<Message[]>([chatKey, 'messages'], null, {
    fallbackData: initialMessages ?? initialMessagesFallback,
  });
  // Keep the latest messages in a ref so async callbacks see current state.
  const messagesRef = useRef<Message[]>(messages || []);
  useEffect(() => {
    messagesRef.current = messages || [];
  }, [messages]);
  // stream data (arbitrary JSON values emitted alongside messages)
  const { data: streamData, mutate: mutateStreamData } = useSWR<JSONValue[] | undefined>(
    [chatKey, 'streamData'],
    null
  );
  // keep the latest stream data in a ref
  const streamDataRef = useRef<JSONValue[] | undefined>(streamData);
  useEffect(() => {
    streamDataRef.current = streamData;
  }, [streamData]);
  // We store loading state in another hook to sync loading states across hook invocations
  const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
    [chatKey, 'loading'],
    null
  );
  // Error state, also shared via SWR across hook invocations.
  const { data: error = undefined, mutate: setError } = useSWR<undefined | Error>(
    [chatKey, 'error'],
    null
  );
  // Abort controller to cancel the current API call.
  const abortControllerRef = useRef<AbortController | null>(null);
  // Snapshot of per-request metadata; kept in a ref so triggerRequest always
  // reads the latest values without needing them in its dependency list.
  const extraMetadataRef = useRef({
    credentials,
    headers,
    body,
  });
  useEffect(() => {
    extraMetadataRef.current = {
      credentials,
      headers,
      body,
    };
  }, [credentials, headers, body]);
  // Core request driver: streams the API response into the message state and,
  // when maxSteps > 1, auto-continues after completed tool calls.
  const triggerRequest = useCallback(
    async (chatRequest: ChatRequest) => {
      // Remember the count so we only auto-continue when new messages arrived.
      const messageCount = messagesRef.current.length;
      try {
        mutateLoading(true);
        setError(undefined);
        const abortController = new AbortController();
        abortControllerRef.current = abortController;
        await processResponseStream(
          api,
          chatRequest,
          // throttle streamed ui updates:
          throttle(mutate, throttleWaitMs),
          throttle(mutateStreamData, throttleWaitMs),
          streamDataRef,
          extraMetadataRef,
          messagesRef,
          abortControllerRef,
          generateId,
          streamProtocol,
          onFinish,
          onResponse,
          onToolCall,
          sendExtraMessageFields,
          experimental_prepareRequestBody,
          fetch,
          keepLastMessageOnError
        );
        abortControllerRef.current = null;
      } catch (err) {
        // Ignore abort errors as they are expected.
        if ((err as any).name === 'AbortError') {
          abortControllerRef.current = null;
          return null;
        }
        if (onError && err instanceof Error) {
          onError(err);
        }
        setError(err as Error);
      } finally {
        mutateLoading(false);
      }
      // auto-submit when all tool calls in the last assistant message have results:
      const messages = messagesRef.current;
      const lastMessage = messages[messages.length - 1];
      if (
        // ensure we actually have new messages (to prevent infinite loops in case of errors):
        messages.length > messageCount &&
        // ensure there is a last message:
        lastMessage != null &&
        // check if the feature is enabled:
        maxSteps > 1 &&
        // check that next step is possible:
        isAssistantMessageWithCompletedToolCalls(lastMessage) &&
        // limit the number of automatic steps:
        countTrailingAssistantMessages(messages) < maxSteps
      ) {
        await triggerRequest({ messages });
      }
    },
    [
      mutate,
      mutateLoading,
      api,
      extraMetadataRef,
      onResponse,
      onFinish,
      onError,
      setError,
      mutateStreamData,
      streamDataRef,
      streamProtocol,
      sendExtraMessageFields,
      experimental_prepareRequestBody,
      onToolCall,
      maxSteps,
      messagesRef,
      abortControllerRef,
      generateId,
      fetch,
      keepLastMessageOnError,
      throttleWaitMs,
    ]
  );
  // Append a message (assigning id/createdAt if missing) and send the request.
  const append = useCallback(
    async (
      message: Message | CreateMessage,
      { data, headers, body, experimental_attachments }: ChatRequestOptions = {}
    ) => {
      if (!message.id) {
        message.id = generateId();
      }
      const attachmentsForRequest = await prepareAttachmentsForRequest(experimental_attachments);
      const messages = messagesRef.current.concat({
        ...message,
        id: message.id ?? generateId(),
        createdAt: message.createdAt ?? new Date(),
        experimental_attachments:
          attachmentsForRequest.length > 0 ? attachmentsForRequest : undefined,
      });
      return triggerRequest({ messages, headers, body, data });
    },
    [triggerRequest, generateId]
  );
  // Re-send the conversation, dropping a trailing assistant message so the
  // last user message is retried.
  const reload = useCallback(
    async ({ data, headers, body }: ChatRequestOptions = {}) => {
      const messages = messagesRef.current;
      if (messages.length === 0) {
        return null;
      }
      // Remove last assistant message and retry last user message.
      const lastMessage = messages[messages.length - 1];
      return triggerRequest({
        messages: lastMessage.role === 'assistant' ? messages.slice(0, -1) : messages,
        headers,
        body,
        data,
      });
    },
    [triggerRequest]
  );
  // Abort the in-flight request, if any.
  const stop = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      abortControllerRef.current = null;
    }
  }, []);
  // Replace the message list (accepts a value or an updater function).
  const setMessages = useCallback(
    (messages: Message[] | ((messages: Message[]) => Message[])) => {
      if (typeof messages === 'function') {
        messages = messages(messagesRef.current);
      }
      mutate(messages, false);
      messagesRef.current = messages;
    },
    [mutate]
  );
  // Replace the stream data (accepts a value or an updater function).
  const setData = useCallback(
    (
      data: JSONValue[] | undefined | ((data: JSONValue[] | undefined) => JSONValue[] | undefined)
    ) => {
      if (typeof data === 'function') {
        data = data(streamDataRef.current);
      }
      mutateStreamData(data, false);
      streamDataRef.current = data;
    },
    [mutateStreamData]
  );
  // Input state and handlers.
  const [input, setInput] = useState(initialInput);
  // Form-style submit: wraps the current input into a user message and sends.
  const handleSubmit = useCallback(
    async (
      event?: { preventDefault?: () => void },
      options: ChatRequestOptions = {},
      metadata?: object
    ) => {
      event?.preventDefault?.();
      // Skip empty submissions unless explicitly allowed.
      if (!input && !options.allowEmptySubmit) return;
      if (metadata) {
        extraMetadataRef.current = {
          ...extraMetadataRef.current,
          ...metadata,
        };
      }
      const attachmentsForRequest = await prepareAttachmentsForRequest(
        options.experimental_attachments
      );
      const messages =
        !input && !attachmentsForRequest.length && options.allowEmptySubmit
          ? messagesRef.current
          : messagesRef.current.concat({
              id: generateId(),
              createdAt: new Date(),
              role: 'user',
              content: input,
              experimental_attachments:
                attachmentsForRequest.length > 0 ? attachmentsForRequest : undefined,
            });
      const chatRequest: ChatRequest = {
        messages,
        headers: options.headers,
        body: options.body,
        data: options.data,
      };
      triggerRequest(chatRequest);
      setInput('');
    },
    [input, generateId, triggerRequest]
  );
  const handleInputChange = (e: any) => {
    setInput(e.target.value);
  };
  // Record a tool result on the matching invocation in the last assistant
  // message, then auto-submit once every invocation has a result.
  const addToolResult = ({ toolCallId, result }: { toolCallId: string; result: any }) => {
    const updatedMessages = messagesRef.current.map((message, index, arr) =>
      // update the tool calls in the last assistant message:
      index === arr.length - 1 && message.role === 'assistant' && message.toolInvocations
        ? {
            ...message,
            toolInvocations: message.toolInvocations.map((toolInvocation) =>
              toolInvocation.toolCallId === toolCallId
                ? {
                    ...toolInvocation,
                    result,
                    state: 'result' as const,
                  }
                : toolInvocation
            ),
          }
        : message
    );
    mutate(updatedMessages, false);
    // auto-submit when all tool calls in the last assistant message have results:
    const lastMessage = updatedMessages[updatedMessages.length - 1];
    if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
      triggerRequest({ messages: updatedMessages });
    }
  };
  return {
    messages: messages || [],
    setMessages,
    data: streamData,
    setData,
    error,
    append,
    reload,
    stop,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading,
    addToolResult,
  };
}
/**
 * Check if the message is an assistant message with completed tool calls.
 * The message must have at least one tool invocation and all tool invocations
 * must have a result.
 *
 * Fixed to return a strict boolean: the previous `&&` chain evaluated to
 * `undefined` (not `false`) when `toolInvocations` was missing, leaking a
 * non-boolean out of a predicate.
 */
function isAssistantMessageWithCompletedToolCalls(message: Message): boolean {
  if (message.role !== 'assistant') {
    return false;
  }
  const invocations = message.toolInvocations;
  if (!invocations || invocations.length === 0) {
    return false;
  }
  // An invocation is complete once a `result` property is present on it.
  return invocations.every((toolInvocation) => 'result' in toolInvocation);
}
/**
 * Count how many consecutive messages at the end of the list were produced
 * by the assistant. Used to cap the number of automatic follow-up steps.
 */
function countTrailingAssistantMessages(messages: Message[]) {
  let trailing = 0;
  let index = messages.length - 1;
  // Walk backwards from the tail until a non-assistant message is found.
  while (index >= 0 && messages[index].role === 'assistant') {
    trailing += 1;
    index -= 1;
  }
  return trailing;
}
/**
 * Normalize the attachments passed via request options into the
 * `{ name, contentType, url }` shape the chat API expects.
 *
 * A browser FileList is read into base64 data URLs; an already-shaped
 * attachment array is forwarded unchanged; anything else is rejected.
 */
async function prepareAttachmentsForRequest(
  attachmentsFromOptions: FileList | Array<Attachment> | undefined
) {
  // Nothing supplied: normalize to an empty list.
  if (attachmentsFromOptions == null) {
    return [];
  }
  if (attachmentsFromOptions instanceof FileList) {
    // Read every file into a data URL via FileReader.
    return Promise.all(
      Array.from(attachmentsFromOptions).map(async (file) => {
        const dataUrl = await new Promise<string>((resolve, reject) => {
          const reader = new FileReader();
          reader.onload = (readerEvent) => resolve(readerEvent.target?.result as string);
          reader.onerror = (error) => reject(error);
          reader.readAsDataURL(file);
        });
        return {
          name: file.name,
          contentType: file.type,
          url: dataUrl,
        };
      })
    );
  }
  if (Array.isArray(attachmentsFromOptions)) {
    return attachmentsFromOptions;
  }
  throw new Error('Invalid attachments type');
}

View File

@@ -1,5 +1,4 @@
import React, { useEffect, useRef, useState } from 'react';
import { Message, useChat } from '../ai-sdk-fork/useChat';
import { getApiUrl } from '../config';
import BottomMenu from './BottomMenu';
import FlappyGoose from './FlappyGoose';
@@ -14,15 +13,13 @@ import UserMessage from './UserMessage';
import { askAi } from '../utils/askAI';
import Splash from './Splash';
import 'react-toastify/dist/ReactToastify.css';
import { useMessageStream } from '../hooks/useMessageStream';
import { Message, createUserMessage, getTextContent } from '../types/message';
export interface ChatType {
id: number;
title: string;
messages: Array<{
id: string;
role: 'function' | 'system' | 'user' | 'assistant' | 'data' | 'tool';
content: string;
}>;
messages: Message[];
}
export default function ChatView({ setView }: { setView: (view: View) => void }) {
@@ -39,14 +36,27 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
const [showGame, setShowGame] = useState(false);
const scrollRef = useRef<ScrollAreaHandle>(null);
const { messages, append, stop, isLoading, error, setMessages } = useChat({
const {
messages,
append,
stop,
isLoading,
error,
setMessages,
input: _input,
setInput: _setInput,
handleInputChange: _handleInputChange,
handleSubmit: _submitMessage,
} = useMessageStream({
api: getApiUrl('/reply'),
initialMessages: chat?.messages || [],
onFinish: async (message, _) => {
onFinish: async (message, _reason) => {
window.electron.stopPowerSaveBlocker();
const fetchResponses = await askAi(message.content);
setMessageMetadata((prev) => ({ ...prev, [message.id]: fetchResponses }));
// Extract text content from the message to pass to askAi
const messageText = getTextContent(message);
const fetchResponses = await askAi(messageText);
setMessageMetadata((prev) => ({ ...prev, [message.id || '']: fetchResponses }));
const timeSinceLastInteraction = Date.now() - lastInteractionTime;
window.electron.logInfo('last interaction:' + lastInteractionTime);
@@ -58,11 +68,16 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
});
}
},
onToolCall: (toolCall) => {
// Handle tool calls if needed
console.log('Tool call received:', toolCall);
// Implement tool call handling logic here
},
});
// Update chat messages when they change
useEffect(() => {
setChat({ ...chat, messages });
setChat((prevChat) => ({ ...prevChat, messages }));
}, [messages]);
useEffect(() => {
@@ -78,10 +93,7 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
const content = customEvent.detail?.value || '';
if (content.trim()) {
setLastInteractionTime(Date.now());
append({
role: 'user',
content,
});
append(createUserMessage(content));
if (scrollRef.current?.scrollToBottom) {
scrollRef.current.scrollToBottom();
}
@@ -97,47 +109,38 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
setLastInteractionTime(Date.now());
window.electron.stopPowerSaveBlocker();
const lastMessage: Message = messages[messages.length - 1];
if (lastMessage.role === 'user' && lastMessage.toolInvocations === undefined) {
// Remove the last user message.
// Handle stopping the message stream
const lastMessage = messages[messages.length - 1];
if (lastMessage && lastMessage.role === 'user') {
// Remove the last user message if it's the most recent one
if (messages.length > 1) {
setMessages(messages.slice(0, -1));
} else {
setMessages([]);
}
} else if (lastMessage.role === 'assistant' && lastMessage.toolInvocations !== undefined) {
// Add messaging about interrupted ongoing tool invocations
const newLastMessage: Message = {
...lastMessage,
toolInvocations: lastMessage.toolInvocations.map((invocation) => {
if (invocation.state !== 'result') {
return {
...invocation,
result: [
{
audience: ['user'],
text: 'Interrupted.\n',
type: 'text',
},
{
audience: ['assistant'],
text: 'Interrupted by the user to make a correction.\n',
type: 'text',
},
],
state: 'result',
};
} else {
return invocation;
}
}),
};
const updatedMessages = [...messages.slice(0, -1), newLastMessage];
setMessages(updatedMessages);
}
// Note: Tool call interruption handling would need to be implemented
// differently with the new message format
};
// Filter out standalone tool response messages for rendering
// They will be shown as part of the tool invocation in the assistant message
const filteredMessages = messages.filter((message) => {
// Keep all assistant messages and user messages that aren't just tool responses
if (message.role === 'assistant') return true;
// For user messages, check if they're only tool responses
if (message.role === 'user') {
const hasOnlyToolResponses = message.content.every((c) => c.type === 'toolResponse');
const hasTextContent = message.content.some((c) => c.type === 'text');
// Keep the message if it has text content or is not just tool responses
return hasTextContent || !hasOnlyToolResponses;
}
return true;
});
return (
<div className="flex flex-col w-full h-screen items-center justify-center">
<div className="relative flex items-center h-[36px] w-full bg-bgSubtle border-b border-borderSubtle">
@@ -145,19 +148,19 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
</div>
<Card className="flex flex-col flex-1 rounded-none h-[calc(100vh-95px)] w-full bg-bgApp mt-0 border-none relative">
{messages.length === 0 ? (
<Splash append={append} />
<Splash append={(text) => append(createUserMessage(text))} />
) : (
<ScrollArea ref={scrollRef} className="flex-1 px-4" autoScroll>
{messages.map((message) => (
<div key={message.id} className="mt-[16px]">
{filteredMessages.map((message, index) => (
<div key={message.id || index} className="mt-[16px]">
{message.role === 'user' ? (
<UserMessage message={message} />
) : (
<GooseMessage
message={message}
messages={messages}
metadata={messageMetadata[message.id]}
append={append}
metadata={messageMetadata[message.id || '']}
append={(text) => append(createUserMessage(text))}
/>
)}
</div>
@@ -166,20 +169,17 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
<div className="flex flex-col items-center justify-center p-4">
<div className="text-red-700 dark:text-red-300 bg-red-400/50 p-3 rounded-lg mb-2">
{error.message || 'Honk! Goose experienced an error while responding'}
{error.status && <span className="ml-2">(Status: {error.status})</span>}
</div>
<div
className="px-3 py-2 mt-2 text-center whitespace-nowrap cursor-pointer text-textStandard border border-borderSubtle hover:bg-bgSubtle rounded-full inline-block transition-all duration-150"
onClick={async () => {
// Find the last user message
const lastUserMessage = messages.reduceRight(
(found, m) => found || (m.role === 'user' ? m : null),
null
null as Message | null
);
if (lastUserMessage) {
append({
role: 'user',
content: lastUserMessage.content,
});
append(lastUserMessage);
}
}}
>

View File

@@ -1,41 +1,77 @@
import React from 'react';
import ToolInvocations from './ToolInvocations';
import React, { useMemo } from 'react';
import LinkPreview from './LinkPreview';
import GooseResponseForm from './GooseResponseForm';
import { extractUrls } from '../utils/urlUtils';
import MarkdownContent from './MarkdownContent';
import ToolCallWithResponse from './ToolCallWithResponse';
import { Message, getTextContent, getToolRequests, getToolResponses } from '../types/message';
interface GooseMessageProps {
message: any;
messages: any[];
metadata?: any;
append: (value: any) => void;
message: Message;
messages: Message[];
metadata?: string[];
append: (value: string) => void;
}
export default function GooseMessage({ message, metadata, messages, append }: GooseMessageProps) {
// Extract text content from the message
const textContent = getTextContent(message);
// Get tool requests from the message
const toolRequests = getToolRequests(message);
// Extract URLs under a few conditions
// 1. The message is purely text
// 2. The link wasn't also present in the previous message
// 3. The message contains the explicit http:// or https:// protocol at the beginning
const messageIndex = messages?.findIndex((msg) => msg.id === message.id);
const previousMessage = messageIndex > 0 ? messages[messageIndex - 1] : null;
const previousUrls = previousMessage ? extractUrls(previousMessage.content) : [];
const urls = !message.toolInvocations ? extractUrls(message.content, previousUrls) : [];
const previousUrls = previousMessage ? extractUrls(getTextContent(previousMessage)) : [];
const urls = toolRequests.length === 0 ? extractUrls(textContent, previousUrls) : [];
// Find tool responses that correspond to the tool requests in this message
const toolResponsesMap = useMemo(() => {
const responseMap = new Map();
// Look for tool responses in subsequent messages
if (messageIndex !== undefined && messageIndex >= 0) {
for (let i = messageIndex + 1; i < messages.length; i++) {
const responses = getToolResponses(messages[i]);
for (const response of responses) {
// Check if this response matches any of our tool requests
const matchingRequest = toolRequests.find((req) => req.id === response.id);
if (matchingRequest) {
responseMap.set(response.id, response);
}
}
}
}
return responseMap;
}, [messages, messageIndex, toolRequests]);
return (
<div className="goose-message flex w-[90%] justify-start opacity-0 animate-[appear_150ms_ease-in_forwards]">
<div className="flex flex-col w-full">
{message.content && (
{/* Always show the top content area if there are tool calls, even if textContent is empty */}
{(textContent || toolRequests.length > 0) && (
<div
className={`goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 ${message.toolInvocations ? 'rounded-b-none' : ''}`}
className={`goose-message-content bg-bgSubtle rounded-2xl px-4 py-2 ${toolRequests.length > 0 ? 'rounded-b-none' : ''}`}
>
<MarkdownContent content={message.content} />
{textContent ? <MarkdownContent content={textContent} /> : null}
</div>
)}
{message.toolInvocations && (
{toolRequests.length > 0 && (
<div className="goose-message-tool bg-bgApp border border-borderSubtle dark:border-gray-700 rounded-b-2xl px-4 pt-4 pb-2 mt-1">
<ToolInvocations toolInvocations={message.toolInvocations} />
{toolRequests.map((toolRequest) => (
<ToolCallWithResponse
key={toolRequest.id}
toolRequest={toolRequest}
toolResponse={toolResponsesMap.get(toolRequest.id)}
/>
))}
</div>
)}
</div>
@@ -53,7 +89,7 @@ export default function GooseMessage({ message, metadata, messages, append }: Go
{/* NOTE from alexhancock on 1/14/2025 - disabling again temporarily due to non-determinism in when the forms show up */}
{false && metadata && (
<div className="flex mt-[16px]">
<GooseResponseForm message={message.content} metadata={metadata} append={append} />
<GooseResponseForm message={textContent} metadata={metadata} append={append} />
</div>
)}
</div>

View File

@@ -3,6 +3,8 @@ import MarkdownContent from './MarkdownContent';
import { Button } from './ui/button';
import { cn } from '../utils';
import { Send } from './icons';
// Prefixing unused imports with underscore
import { createUserMessage as _createUserMessage } from '../types/message';
interface FormField {
label: string;
@@ -20,8 +22,8 @@ interface DynamicForm {
interface GooseResponseFormProps {
message: string;
metadata: any;
append: (value: any) => void;
metadata: string[] | null;
append: (value: string) => void;
}
export default function GooseResponseForm({
@@ -103,31 +105,19 @@ export default function GooseResponseForm({
};
const handleAccept = () => {
const message = {
content: 'Yes - go ahead.',
role: 'user',
};
append(message);
append('Yes - go ahead.');
};
const handleSubmit = () => {
if (selectedOption !== null && options[selectedOption]) {
const message = {
content: `Yes - continue with: ${options[selectedOption].optionTitle}`,
role: 'user',
};
append(message);
append(`Yes - continue with: ${options[selectedOption].optionTitle}`);
}
};
const handleFormSubmit = (e: React.FormEvent) => {
e.preventDefault();
if (dynamicForm) {
const message = {
content: JSON.stringify(formValues),
role: 'user',
};
append(message);
append(JSON.stringify(formValues));
}
};

View File

@@ -1,6 +1,5 @@
import React from 'react';
import SplashPills from './SplashPills';
import { Goose, Rain } from './icons/Goose';
import GooseLogo from './GooseLogo';
export default function Splash({ append }) {

View File

@@ -5,11 +5,8 @@ function SplashPill({ content, append, className = '', longForm = '' }) {
<div
className={`px-4 py-2 text-sm text-center text-textSubtle dark:text-textStandard cursor-pointer border border-borderSubtle hover:bg-bgSubtle rounded-full transition-all duration-150 ${className}`}
onClick={async () => {
const message = {
content: longForm || content,
role: 'user',
};
await append(message);
// Use the longForm text if provided, otherwise use the content
await append(longForm || content);
}}
>
<div className="line-clamp-2">{content}</div>

View File

@@ -0,0 +1,138 @@
import React from 'react';
import { Card } from './ui/card';
import Box from './ui/Box';
import { ToolCallArguments } from './ToolCallArguments';
import MarkdownContent from './MarkdownContent';
import { LoadingPlaceholder } from './LoadingPlaceholder';
import { ChevronUp } from 'lucide-react';
import { Content, ToolRequestMessageContent, ToolResponseMessageContent } from '../types/message';
import { snakeToTitleCase } from '../utils';
interface ToolCallWithResponseProps {
toolRequest: ToolRequestMessageContent;
toolResponse?: ToolResponseMessageContent;
}
export default function ToolCallWithResponse({
toolRequest,
toolResponse,
}: ToolCallWithResponseProps) {
const toolCall = toolRequest.toolCall.status === 'success' ? toolRequest.toolCall.value : null;
if (!toolCall) {
return null;
}
return (
<div className="w-full">
<Card className="">
<ToolCallView toolCall={toolCall} />
{toolResponse ? (
<ToolResultView
result={
toolResponse.toolResult.status === 'success'
? toolResponse.toolResult.value
: undefined
}
/>
) : (
<LoadingPlaceholder />
)}
</Card>
</div>
);
}
interface ToolCallViewProps {
toolCall: {
name: string;
arguments: Record<string, unknown>;
};
}
/**
 * Header row for a tool call: a box icon, the humanized tool name, the
 * call's arguments (when present), and a divider line.
 */
function ToolCallView({ toolCall }: ToolCallViewProps) {
  // Drop the `namespace__` prefix and title-case the remaining snake_case name.
  const displayName = snakeToTitleCase(
    toolCall.name.substring(toolCall.name.lastIndexOf('__') + 2)
  );
  return (
    <div>
      <div className="flex items-center mb-4">
        <Box size={16} />
        <span className="ml-[8px] text-textStandard">{displayName}</span>
      </div>
      {toolCall.arguments && <ToolCallArguments args={toolCall.arguments} />}
      <div className="self-stretch h-px my-[10px] -mx-4 bg-borderSubtle dark:bg-gray-700" />
    </div>
  );
}
interface ToolResultViewProps {
result?: Content[];
}
/**
 * Renders the user-visible portion of a tool's result content.
 *
 * Items whose `annotations.audience` excludes 'user' are hidden. Items with
 * priority >= 0.5 render expanded; lower/unset priority items start minimized
 * behind an "Output" toggle button.
 */
function ToolResultView({ result }: ToolResultViewProps) {
  // Indices of minimized items the user has manually expanded.
  const [expandedItems, setExpandedItems] = React.useState<number[]>([]);
  // If no result info, don't show anything
  if (!result) return null;
  // Find results where either audience is not set, or it's set to a list that includes user
  const filteredResults = result.filter((item) => {
    // Check audience (which may not be in the type)
    const audience = item.annotations?.audience;
    return !audience || audience.includes('user');
  });
  if (filteredResults.length === 0) return null;
  const toggleExpand = (index: number) => {
    setExpandedItems((prev) =>
      prev.includes(index) ? prev.filter((i) => i !== index) : [...prev, index]
    );
  };
  // Expanded when high priority (>= 0.5) or manually toggled open.
  // BUGFIX: `annotations` itself is optional (the audience filter above
  // already treats it that way), so read priority via optional chaining —
  // the previous `item.annotations.priority` access crashed when a result
  // item carried no annotations at all.
  const shouldShowExpanded = (item: Content, index: number) => {
    const priority = item.annotations?.priority;
    return (priority !== undefined && priority >= 0.5) || expandedItems.includes(index);
  };
  return (
    <div className="">
      {filteredResults.map((item, index) => {
        const isExpanded = shouldShowExpanded(item, index);
        // Minimize when priority is missing or below the 0.5 threshold.
        const priority = item.annotations?.priority;
        const shouldMinimize = priority === undefined || priority < 0.5;
        return (
          <div key={index} className="relative">
            {shouldMinimize && (
              <button
                onClick={() => toggleExpand(index)}
                className="mb-1 flex items-center text-textStandard"
              >
                <span className="mr-2 text-sm">Output</span>
                <ChevronUp
                  className={`h-5 w-5 transition-all origin-center ${!isExpanded ? 'rotate-180' : ''}`}
                />
              </button>
            )}
            {(isExpanded || !shouldMinimize) && (
              <>
                {item.text && (
                  <MarkdownContent
                    content={item.text}
                    className="whitespace-pre-wrap p-2 max-w-full overflow-x-auto"
                  />
                )}
              </>
            )}
          </div>
        );
      })}
    </div>
  );
}

View File

@@ -1,164 +0,0 @@
import React from 'react';
import { Card } from './ui/card';
import Box from './ui/Box';
import { ToolCallArguments } from './ToolCallArguments';
import MarkdownContent from './MarkdownContent';
import { snakeToTitleCase } from '../utils';
import { LoadingPlaceholder } from './LoadingPlaceholder';
import { ChevronUp } from 'lucide-react';
export default function ToolInvocations({ toolInvocations }) {
return (
<>
{toolInvocations.map((toolInvocation) => (
<ToolInvocation key={toolInvocation.toolCallId} toolInvocation={toolInvocation} />
))}
</>
);
}
function ToolInvocation({ toolInvocation }) {
return (
<div className="w-full">
<Card className="">
<ToolCall call={toolInvocation} />
{toolInvocation.state === 'result' ? (
<ToolResult result={toolInvocation} />
) : (
<LoadingPlaceholder />
)}
</Card>
</div>
);
}
interface ToolCallProps {
call: {
state: 'call' | 'result';
toolCallId: string;
toolName: string;
args: Record<string, any>;
};
}
function ToolCall({ call }: ToolCallProps) {
return (
<div>
<div className="flex items-center mb-4">
<Box size={16} />
<span className="ml-[8px] text-textStandard">
{snakeToTitleCase(call.toolName.substring(call.toolName.lastIndexOf('__') + 2))}
</span>
</div>
{call.args && <ToolCallArguments args={call.args} />}
<div className="self-stretch h-px my-[10px] -mx-4 bg-borderSubtle dark:bg-gray-700" />
</div>
);
}
interface Annotations {
audience?: string[]; // Array of audience types
priority?: number; // Priority value between 0 and 1
}
interface ResultItem {
text?: string;
type: 'text' | 'image';
mimeType?: string;
data?: string; // Base64 encoded image data
annotations?: Annotations;
}
interface ToolResultProps {
result: {
message?: string;
result?: ResultItem[];
state?: string;
toolCallId?: string;
toolName?: string;
args?: any;
input_todo?: any;
};
}
function ToolResult({ result }: ToolResultProps) {
// State to track expanded items
const [expandedItems, setExpandedItems] = React.useState<number[]>([]);
// If no result info, don't show anything
if (!result || !result.result) return null;
// Normalize to an array
const results = Array.isArray(result.result) ? result.result : [result.result];
// Find results where either audience is not set, or it's set to a list that contains user
const filteredResults = results.filter(
(item: ResultItem) =>
!item.annotations?.audience || item.annotations?.audience?.includes('user')
);
if (filteredResults.length === 0) return null;
const toggleExpand = (index: number) => {
setExpandedItems((prev) =>
prev.includes(index) ? prev.filter((i) => i !== index) : [...prev, index]
);
};
const shouldShowExpanded = (item: ResultItem, index: number) => {
// (priority is defined and > 0.5) OR already in the expandedItems
return (
(item.annotations?.priority !== undefined && item.annotations?.priority >= 0.5) ||
expandedItems.includes(index)
);
};
return (
<div className="">
{filteredResults.map((item: ResultItem, index: number) => {
const isExpanded = shouldShowExpanded(item, index);
// minimize if priority is not set or < 0.5
const shouldMinimize =
item.annotations?.priority === undefined || item.annotations?.priority < 0.5;
return (
<div key={index} className="relative">
{shouldMinimize && (
<button
onClick={() => toggleExpand(index)}
className="mb-1 flex items-center text-textStandard"
>
<span className="mr-2 text-sm">Output</span>
<ChevronUp
className={`h-5 w-5 transition-all origin-center ${!isExpanded ? 'rotate-180' : ''}`}
/>
</button>
)}
{(isExpanded || !shouldMinimize) && (
<>
{item.type === 'text' && item.text && (
<MarkdownContent
content={item.text}
className="whitespace-pre-wrap p-2 max-w-full overflow-x-auto"
/>
)}
{item.type === 'image' && item.data && item.mimeType && (
<img
src={`data:${item.mimeType};base64,${item.data}`}
alt="Tool result"
className="max-w-full h-auto rounded-md"
onError={(e) => {
console.error('Failed to load image: Invalid MIME-type encoded image data');
e.currentTarget.style.display = 'none';
}}
/>
)}
</>
)}
</div>
);
})}
</div>
);
}

View File

@@ -2,16 +2,24 @@ import React from 'react';
import LinkPreview from './LinkPreview';
import { extractUrls } from '../utils/urlUtils';
import MarkdownContent from './MarkdownContent';
import { Message, getTextContent } from '../types/message';
export default function UserMessage({ message }) {
interface UserMessageProps {
message: Message;
}
export default function UserMessage({ message }: UserMessageProps) {
// Extract text content from the message
const textContent = getTextContent(message);
// Extract URLs which explicitly contain the http:// or https:// protocol
const urls = extractUrls(message.content, []);
const urls = extractUrls(textContent, []);
return (
<div className="flex justify-end mt-[16px] w-full opacity-0 animate-[appear_150ms_ease-in_forwards]">
<div className="flex-col max-w-[85%]">
<div className="flex bg-slate text-white rounded-xl rounded-br-none py-2 px-3">
<MarkdownContent content={message.content} className="text-white" />
<MarkdownContent content={textContent} className="text-white" />
</div>
{/* TODO(alexhancock): Re-enable link previews once styled well again */}
@@ -25,4 +33,4 @@ export default function UserMessage({ message }) {
</div>
</div>
);
}
}

View File

@@ -0,0 +1,487 @@
import { useState, useCallback, useEffect, useRef, useId } from 'react';
import useSWR from 'swr';
import { getSecretKey } from '../config';
import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';
// Ensure TextDecoder is available in the global scope.
// Resolved from `globalThis` so the hook works in both browser and
// Electron-renderer style environments.
const TextDecoder = globalThis.TextDecoder;
// Event types for the SSE stream. Each `data:` payload parses to exactly one
// of these variants:
// - 'Message': a complete Goose message to append to the transcript
// - 'Error':   a server-side failure description
// - 'Finish':  the stream is done, with a reason string
// NOTE(review): this local name shadows the DOM `MessageEvent` type within
// this module.
type MessageEvent =
  | { type: 'Message'; message: Message }
  | { type: 'Error'; error: string }
  | { type: 'Finish'; reason: string };
export interface UseMessageStreamOptions {
  /**
   * The API endpoint that accepts a `{ messages: Message[] }` object and returns
   * a stream of messages. Defaults to `/api/chat/reply`.
   */
  api?: string;

  /**
   * A unique identifier for the chat. If not provided, a random one will be
   * generated. When provided, the hook with the same `id` will
   * have shared states across components.
   */
  id?: string;

  /**
   * Initial messages of the chat. Useful to load an existing chat history.
   */
  initialMessages?: Message[];

  /**
   * Initial input of the chat.
   */
  initialInput?: string;

  /**
   * Callback function to be called when a tool call is received.
   * You can optionally return a result for the tool call.
   * NOTE(review): currently unused — `useMessageStream` never destructures or
   * invokes this option (hence the underscore prefix). Confirm whether it can
   * be removed or is reserved for upcoming work.
   */
  _onToolCall?: (toolCall: Record<string, unknown>) => void | Promise<unknown> | unknown;

  /**
   * Callback function to be called when the API response is received.
   */
  onResponse?: (response: Response) => void | Promise<void>;

  /**
   * Callback function to be called when the assistant message is finished streaming.
   * Receives the last streamed message and the server-provided finish reason.
   */
  onFinish?: (message: Message, reason: string) => void;

  /**
   * Callback function to be called when an error is encountered.
   */
  onError?: (error: Error) => void;

  /**
   * HTTP headers to be sent with the API request.
   */
  headers?: Record<string, string> | HeadersInit;

  /**
   * Extra body object to be sent with the API request.
   * Spread into the JSON payload alongside `messages`.
   */
  body?: object;

  /**
   * Maximum number of sequential LLM calls (steps), e.g. when you use tool calls.
   * Default is 1.
   */
  maxSteps?: number;
}
/**
 * The state and actions returned by `useMessageStream`.
 */
export interface UseMessageStreamHelpers {
  /** Current messages in the chat */
  messages: Message[];

  /** The error object of the API request */
  error: undefined | Error;

  /**
   * Append a user message to the chat list. This triggers the API call to fetch
   * the assistant's response. A plain string is converted to a user Message.
   */
  append: (message: Message | string) => Promise<void>;

  /**
   * Reload the last AI chat response for the given chat history.
   */
  reload: () => Promise<void>;

  /**
   * Abort the current request immediately.
   */
  stop: () => void;

  /**
   * Update the `messages` state locally.
   */
  setMessages: (messages: Message[] | ((messages: Message[]) => Message[])) => void;

  /** The current value of the input */
  input: string;

  /** setState-powered method to update the input value */
  setInput: React.Dispatch<React.SetStateAction<string>>;

  /** An input/textarea-ready onChange handler to control the value of the input */
  handleInputChange: (
    e: React.ChangeEvent<HTMLInputElement> | React.ChangeEvent<HTMLTextAreaElement>
  ) => void;

  /** Form submission handler to automatically reset input and append a user message */
  handleSubmit: (event?: { preventDefault?: () => void }) => void;

  /** Whether the API request is in progress */
  isLoading: boolean;

  /** Add a tool result to a tool call */
  addToolResult: ({ toolCallId, result }: { toolCallId: string; result: unknown }) => void;
}
/**
 * Hook for streaming messages directly from the server using the native Goose
 * message format — the wire shape matches the Rust `Message` struct, so no
 * translation layer is needed on either side.
 *
 * Messages / loading / error state live in SWR keyed by `[api, id]`, so
 * multiple components using the same `id` share one chat.
 */
export function useMessageStream({
  api = '/api/chat/reply',
  id,
  initialMessages = [],
  initialInput = '',
  onResponse,
  onFinish,
  onError,
  headers,
  body,
  maxSteps = 1,
}: UseMessageStreamOptions = {}): UseMessageStreamHelpers {
  // Generate a unique id for the chat if not provided
  const hookId = useId();
  const idKey = id ?? hookId;
  const chatKey = typeof api === 'string' ? [api, idKey] : idKey;

  // Store the chat state in SWR, using the chatId as the key to share states
  const { data: messages, mutate } = useSWR<Message[]>([chatKey, 'messages'], null, {
    fallbackData: initialMessages,
  });

  // Keep the latest messages in a ref so async callbacks read fresh state
  // without being recreated on every render.
  const messagesRef = useRef<Message[]>(messages || []);
  useEffect(() => {
    messagesRef.current = messages || [];
  }, [messages]);

  // We store loading state in another hook to sync loading states across hook invocations
  const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
    [chatKey, 'loading'],
    null
  );

  const { data: error = undefined, mutate: setError } = useSWR<undefined | Error>(
    [chatKey, 'error'],
    null
  );

  // Abort controller to cancel the current API call
  const abortControllerRef = useRef<AbortController | null>(null);

  // Extra metadata for requests; kept in a ref so changing headers/body does
  // not invalidate the request callbacks below.
  const extraMetadataRef = useRef({
    headers,
    body,
  });
  useEffect(() => {
    extraMetadataRef.current = {
      headers,
      body,
    };
  }, [headers, body]);

  // Process the SSE stream from the server. Returns the (possibly extended)
  // message list once the stream ends.
  const processMessageStream = useCallback(
    async (response: Response, currentMessages: Message[]) => {
      if (!response.body) {
        throw new Error('Response body is empty');
      }

      const reader = response.body.getReader();
      const decoder = new TextDecoder();
      let buffer = '';

      try {
        let running = true;
        while (running) {
          const { done, value } = await reader.read();
          if (done) {
            running = false;
            break;
          }

          // Decode the chunk and add it to our buffer
          buffer += decoder.decode(value, { stream: true });

          // Process complete SSE events (events are separated by a blank line)
          const events = buffer.split('\n\n');
          buffer = events.pop() || ''; // Keep the last incomplete event in the buffer

          for (const event of events) {
            if (event.startsWith('data: ')) {
              try {
                const data = event.slice(6); // Remove 'data: ' prefix
                const parsedEvent = JSON.parse(data) as MessageEvent;

                switch (parsedEvent.type) {
                  case 'Message':
                    // Update messages with the new message
                    currentMessages = [...currentMessages, parsedEvent.message];
                    mutate(currentMessages, false);
                    break;
                  case 'Error':
                    // NOTE(review): this throw is caught by the catch just
                    // below, so a server 'Error' event is surfaced via onError
                    // but logged as a *parse* error and the stream keeps
                    // draining — confirm this is intended.
                    throw new Error(parsedEvent.error);
                  case 'Finish':
                    // Call onFinish with the last message if available
                    if (onFinish && currentMessages.length > 0) {
                      const lastMessage = currentMessages[currentMessages.length - 1];
                      onFinish(lastMessage, parsedEvent.reason);
                    }
                    break;
                }
              } catch (e) {
                console.error('Error parsing SSE event:', e);
                if (onError && e instanceof Error) {
                  onError(e);
                }
              }
            }
          }
        }
      } catch (e) {
        // AbortError is expected on stop(); anything else is reported.
        if (e instanceof Error && e.name !== 'AbortError') {
          console.error('Error reading SSE stream:', e);
          if (onError) {
            onError(e);
          }
        }
      } finally {
        reader.releaseLock();
      }

      return currentMessages;
    },
    [mutate, onFinish, onError]
  );

  // Send a request to the server and stream the reply; recurses (bounded by
  // maxSteps via the trailing-assistant count) when the last assistant message
  // ends with completed tool calls.
  const sendRequest = useCallback(
    async (requestMessages: Message[]) => {
      try {
        mutateLoading(true);
        setError(undefined);

        // Create abort controller
        const abortController = new AbortController();
        abortControllerRef.current = abortController;

        // Log the request messages for debugging
        console.log('Sending messages to server:', JSON.stringify(requestMessages, null, 2));

        // Send request to the server
        const response = await fetch(api, {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            'X-Secret-Key': getSecretKey(),
            ...extraMetadataRef.current.headers,
          },
          body: JSON.stringify({
            messages: requestMessages,
            ...extraMetadataRef.current.body,
          }),
          signal: abortController.signal,
        });

        if (onResponse) {
          await onResponse(response);
        }

        if (!response.ok) {
          const text = await response.text();
          throw new Error(text || `Error ${response.status}: ${response.statusText}`);
        }

        // Process the SSE stream
        const updatedMessages = await processMessageStream(response, requestMessages);

        // Auto-submit when all tool calls in the last assistant message have results
        if (maxSteps > 1 && updatedMessages.length > requestMessages.length) {
          const lastMessage = updatedMessages[updatedMessages.length - 1];
          if (lastMessage.role === 'assistant' && hasCompletedToolCalls(lastMessage)) {
            // Count trailing assistant messages to prevent infinite loops
            let assistantCount = 0;
            for (let i = updatedMessages.length - 1; i >= 0; i--) {
              if (updatedMessages[i].role === 'assistant') {
                assistantCount++;
              } else {
                break;
              }
            }
            if (assistantCount < maxSteps) {
              await sendRequest(updatedMessages);
            }
          }
        }

        abortControllerRef.current = null;
      } catch (err) {
        // Ignore abort errors as they are expected
        if (err instanceof Error && err.name === 'AbortError') {
          abortControllerRef.current = null;
          return;
        }

        if (onError && err instanceof Error) {
          onError(err);
        }

        setError(err as Error);
      } finally {
        mutateLoading(false);
      }
    },
    [api, processMessageStream, mutateLoading, setError, onResponse, onError, maxSteps]
  );

  // Append a new message and send request
  const append = useCallback(
    async (message: Message | string) => {
      // If a string is passed, convert it to a Message object
      const messageToAppend = typeof message === 'string' ? createUserMessage(message) : message;

      console.log('Appending message:', JSON.stringify(messageToAppend, null, 2));
      const currentMessages = [...messagesRef.current, messageToAppend];
      mutate(currentMessages, false);
      await sendRequest(currentMessages);
    },
    [mutate, sendRequest]
  );

  // Reload the last message: resend history, dropping a trailing assistant
  // reply so the server regenerates it.
  const reload = useCallback(async () => {
    const currentMessages = messagesRef.current;
    if (currentMessages.length === 0) {
      return;
    }

    // Remove last assistant message if present
    const lastMessage = currentMessages[currentMessages.length - 1];
    const messagesToSend =
      lastMessage.role === 'assistant' ? currentMessages.slice(0, -1) : currentMessages;

    await sendRequest(messagesToSend);
  }, [sendRequest]);

  // Stop the current request
  const stop = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      abortControllerRef.current = null;
    }
  }, []);

  // Set messages directly (accepts a new array or an updater function)
  const setMessages = useCallback(
    (messagesOrFn: Message[] | ((messages: Message[]) => Message[])) => {
      if (typeof messagesOrFn === 'function') {
        const newMessages = messagesOrFn(messagesRef.current);
        mutate(newMessages, false);
        messagesRef.current = newMessages;
      } else {
        mutate(messagesOrFn, false);
        messagesRef.current = messagesOrFn;
      }
    },
    [mutate]
  );

  // Input state and handlers
  const [input, setInput] = useState(initialInput);

  const handleInputChange = useCallback(
    (e: React.ChangeEvent<HTMLInputElement> | React.ChangeEvent<HTMLTextAreaElement>) => {
      setInput(e.target.value);
    },
    []
  );

  const handleSubmit = useCallback(
    async (event?: { preventDefault?: () => void }) => {
      event?.preventDefault?.();
      if (!input.trim()) return;

      console.log('handleSubmit called with input:', input);
      await append(input);
      setInput('');
    },
    [input, append]
  );

  // Add a tool result for a tool call issued by the most recent assistant
  // message containing that call id.
  const addToolResult = useCallback(
    ({ toolCallId, result }: { toolCallId: string; result: unknown }) => {
      const currentMessages = messagesRef.current;

      // Find the last assistant message with the tool call
      let lastAssistantIndex = -1;
      for (let i = currentMessages.length - 1; i >= 0; i--) {
        if (currentMessages[i].role === 'assistant') {
          const toolRequests = currentMessages[i].content.filter(
            (content) => content.type === 'toolRequest' && content.id === toolCallId
          );
          if (toolRequests.length > 0) {
            lastAssistantIndex = i;
            break;
          }
        }
      }

      if (lastAssistantIndex === -1) return;

      // Create a tool response message; non-array results are wrapped in a
      // single text content item.
      const toolResponseMessage: Message = {
        role: 'user' as const,
        created: Math.floor(Date.now() / 1000),
        content: [
          {
            type: 'toolResponse' as const,
            id: toolCallId,
            toolResult: {
              status: 'success' as const,
              value: Array.isArray(result)
                ? result
                : [{ type: 'text' as const, text: String(result), priority: 0 }],
            },
          },
        ],
      };

      // Insert the tool response after the assistant message
      const updatedMessages = [
        ...currentMessages.slice(0, lastAssistantIndex + 1),
        toolResponseMessage,
        ...currentMessages.slice(lastAssistantIndex + 1),
      ];

      mutate(updatedMessages, false);
      messagesRef.current = updatedMessages;

      // Auto-submit if we have tool results
      if (maxSteps > 1) {
        sendRequest(updatedMessages);
      }
    },
    [mutate, maxSteps, sendRequest]
  );

  return {
    messages: messages || [],
    error,
    append,
    reload,
    stop,
    setMessages,
    input,
    setInput,
    handleInputChange,
    handleSubmit,
    isLoading: isLoading || false,
    addToolResult,
  };
}

View File

@@ -0,0 +1,198 @@
/**
 * Message types that match the Rust message structures
 * for direct serialization between client and server
 */

/** Author of a message; mirrors the Rust `Role` enum. */
export type Role = 'user' | 'assistant';

/** Plain-text content item. */
export interface TextContent {
  type: 'text';
  text: string;
  // Free-form metadata bag; its shape is not constrained on this side.
  annotations?: Record<string, unknown>;
}

/** Image content item. */
export interface ImageContent {
  type: 'image';
  // Raw base64 payload; consumers build the `data:<mime>;base64,` URI.
  data: string;
  mimeType: string;
  annotations?: Record<string, unknown>;
}

/** Content items that can appear inside a tool result. */
export type Content = TextContent | ImageContent;

/** A tool invocation: tool name plus its JSON arguments. */
export interface ToolCall {
  name: string;
  arguments: Record<string, unknown>;
}

/**
 * Result wrapper mirroring a Rust `Result`-style serialization:
 * `value` is set when status is 'success', `error` when it is 'error'.
 */
export interface ToolCallResult<T> {
  status: 'success' | 'error';
  value?: T;
  error?: string;
}

// NOTE(review): the three types below duplicate the *MessageContent shapes
// further down minus the `type` discriminant — confirm both sets are needed.

/** A request to run a tool, keyed by call id. */
export interface ToolRequest {
  id: string;
  toolCall: ToolCallResult<ToolCall>;
}

/** The result of running a tool, keyed by the originating call id. */
export interface ToolResponse {
  id: string;
  toolResult: ToolCallResult<Content[]>;
}

/** A request asking the user to confirm a tool invocation. */
export interface ToolConfirmationRequest {
  id: string;
  toolName: string;
  arguments: Record<string, unknown>;
  prompt?: string;
}

/** Tool request as it appears inside `Message.content`. */
export interface ToolRequestMessageContent {
  type: 'toolRequest';
  id: string;
  toolCall: ToolCallResult<ToolCall>;
}

/** Tool response as it appears inside `Message.content`. */
export interface ToolResponseMessageContent {
  type: 'toolResponse';
  id: string;
  toolResult: ToolCallResult<Content[]>;
}

/** Tool confirmation request as it appears inside `Message.content`. */
export interface ToolConfirmationRequestMessageContent {
  type: 'toolConfirmationRequest';
  id: string;
  toolName: string;
  arguments: Record<string, unknown>;
  prompt?: string;
}

/** Discriminated union of everything that can appear in `Message.content`. */
export type MessageContent =
  | TextContent
  | ImageContent
  | ToolRequestMessageContent
  | ToolResponseMessageContent
  | ToolConfirmationRequestMessageContent;

/**
 * A chat message in the native Goose format, serialized 1:1 with the Rust
 * `Message` struct.
 */
export interface Message {
  // Client-side identifier; optional, so presumably not required by the
  // server — TODO confirm against the Rust deserializer.
  id?: string;
  role: Role;
  // Unix timestamp in seconds (factory helpers use Math.floor(Date.now() / 1000)).
  created: number;
  content: MessageContent[];
}
// Helper functions to create messages
export function createUserMessage(text: string): Message {
return {
id: generateId(),
role: 'user',
created: Math.floor(Date.now() / 1000),
content: [{ type: 'text', text }],
};
}
export function createAssistantMessage(text: string): Message {
return {
id: generateId(),
role: 'assistant',
created: Math.floor(Date.now() / 1000),
content: [{ type: 'text', text }],
};
}
export function createToolRequestMessage(
id: string,
toolName: string,
args: Record<string, unknown>
): Message {
return {
id: generateId(),
role: 'assistant',
created: Math.floor(Date.now() / 1000),
content: [
{
type: 'toolRequest',
id,
toolCall: {
status: 'success',
value: {
name: toolName,
arguments: args,
},
},
},
],
};
}
export function createToolResponseMessage(id: string, result: Content[]): Message {
return {
id: generateId(),
role: 'user',
created: Math.floor(Date.now() / 1000),
content: [
{
type: 'toolResponse',
id,
toolResult: {
status: 'success',
value: result,
},
},
],
};
}
export function createToolErrorResponseMessage(id: string, error: string): Message {
return {
id: generateId(),
role: 'user',
created: Math.floor(Date.now() / 1000),
content: [
{
type: 'toolResponse',
id,
toolResult: {
status: 'error',
error,
},
},
],
};
}
// Generate a short random id (up to 8 base-36 chars) for client-side messages.
// Not cryptographically unique — suitable for UI keying only.
function generateId(): string {
  return Math.random().toString(36).slice(2, 10);
}
// Helper functions to extract content from messages

/**
 * Concatenate every text content item of a message, newline-separated.
 * Non-text items (images, tool requests/responses) are skipped.
 */
export function getTextContent(message: Message): string {
  const parts: string[] = [];
  for (const item of message.content) {
    if (item.type === 'text') {
      parts.push(item.text);
    }
  }
  return parts.join('\n');
}
/**
 * Collect every tool-request content item carried by a message.
 */
export function getToolRequests(message: Message): ToolRequestMessageContent[] {
  const requests: ToolRequestMessageContent[] = [];
  for (const item of message.content) {
    if (item.type === 'toolRequest') {
      requests.push(item);
    }
  }
  return requests;
}
/**
 * Collect every tool-response content item carried by a message.
 */
export function getToolResponses(message: Message): ToolResponseMessageContent[] {
  const responses: ToolResponseMessageContent[] = [];
  for (const item of message.content) {
    if (item.type === 'toolResponse') {
      responses.push(item);
    }
  }
  return responses;
}
/**
 * Whether a message carries at least one tool request.
 *
 * NOTE(review): despite the name, this does NOT verify that responses exist —
 * it returns true for any message containing a tool request (as the original
 * did). Checking actual completion would require scanning subsequent messages.
 */
export function hasCompletedToolCalls(message: Message): boolean {
  return message.content.some((item) => item.type === 'toolRequest');
}