mirror of
https://github.com/aljazceru/notedeck.git
synced 2025-12-21 18:24:21 +01:00
dave: tools working even better
Signed-off-by: William Casarin <jb55@jb55.com>
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -3302,10 +3302,13 @@ dependencies = [
|
|||||||
"egui",
|
"egui",
|
||||||
"egui-wgpu",
|
"egui-wgpu",
|
||||||
"futures",
|
"futures",
|
||||||
|
"hex",
|
||||||
|
"nostrdb",
|
||||||
"notedeck",
|
"notedeck",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
@@ -5261,9 +5264,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "time"
|
name = "time"
|
||||||
version = "0.3.40"
|
version = "0.3.41"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9d9c75b47bdff86fa3334a3db91356b8d7d86a9b839dab7d0bdc5c3d3a077618"
|
checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"deranged",
|
"deranged",
|
||||||
"itoa",
|
"itoa",
|
||||||
@@ -5282,9 +5285,9 @@ checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "time-macros"
|
name = "time-macros"
|
||||||
version = "0.2.21"
|
version = "0.2.22"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "29aa485584182073ed57fd5004aa09c371f021325014694e432313345865fd04"
|
checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"num-conv",
|
"num-conv",
|
||||||
"time-core",
|
"time-core",
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ tracing = { workspace = true }
|
|||||||
egui-wgpu = { workspace = true }
|
egui-wgpu = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
|
nostrdb = { workspace = true }
|
||||||
|
hex = { workspace = true }
|
||||||
|
time = "0.3.41"
|
||||||
bytemuck = "1.22.0"
|
bytemuck = "1.22.0"
|
||||||
futures = "0.3.31"
|
futures = "0.3.31"
|
||||||
reqwest = "0.12.15"
|
reqwest = "0.12.15"
|
||||||
|
|||||||
@@ -1,21 +1,25 @@
|
|||||||
use async_openai::{
|
use async_openai::{
|
||||||
config::OpenAIConfig,
|
config::OpenAIConfig,
|
||||||
types::{
|
types::{
|
||||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
|
ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
|
||||||
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
|
||||||
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
|
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
||||||
ChatCompletionRequestUserMessageContent, ChatCompletionTool, ChatCompletionToolType,
|
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
||||||
CreateChatCompletionRequest, FunctionObject,
|
ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
|
||||||
|
ChatCompletionTool, ChatCompletionToolType, CreateChatCompletionRequest, FunctionCall,
|
||||||
|
FunctionObject,
|
||||||
},
|
},
|
||||||
Client,
|
Client,
|
||||||
};
|
};
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
use nostrdb::{Ndb, NoteKey, Transaction};
|
||||||
use notedeck::AppContext;
|
use notedeck::AppContext;
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::{json, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::mpsc::{self, Receiver};
|
use std::sync::mpsc::{self, Receiver};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use time::{format_description::well_known::Rfc3339, OffsetDateTime};
|
||||||
|
|
||||||
use avatar::DaveAvatar;
|
use avatar::DaveAvatar;
|
||||||
use egui::{Rect, Vec2};
|
use egui::{Rect, Vec2};
|
||||||
@@ -28,10 +32,114 @@ pub enum Message {
|
|||||||
User(String),
|
User(String),
|
||||||
Assistant(String),
|
Assistant(String),
|
||||||
System(String),
|
System(String),
|
||||||
|
ToolCalls(Vec<ToolCall>),
|
||||||
|
ToolResponse(ToolResponse),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct SearchResponse {
|
||||||
|
context: SearchContext,
|
||||||
|
notes: Vec<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub enum ToolResponses {
|
||||||
|
Search(SearchResponse),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolResponse {
|
||||||
|
id: String,
|
||||||
|
typ: ToolResponses,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
id: String,
|
||||||
|
typ: ToolCalls,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Debug, Clone)]
|
||||||
|
pub struct PartialToolCall {
|
||||||
|
id: Option<String>,
|
||||||
|
name: Option<String>,
|
||||||
|
arguments: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct UnknownToolCall {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UnknownToolCall {
|
||||||
|
pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> {
|
||||||
|
let Some(tool) = tools.get(&self.name) else {
|
||||||
|
return Err(ToolCallError::NotFound(self.name.to_owned()));
|
||||||
|
};
|
||||||
|
|
||||||
|
let parsed_args = (tool.parse_call)(&self.arguments)?;
|
||||||
|
Ok(ToolCall {
|
||||||
|
id: self.id.clone(),
|
||||||
|
typ: parsed_args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialToolCall {
|
||||||
|
pub fn complete(&self) -> Option<UnknownToolCall> {
|
||||||
|
Some(UnknownToolCall {
|
||||||
|
id: self.id.clone()?,
|
||||||
|
name: self.name.clone()?,
|
||||||
|
arguments: self.arguments.clone()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolCall {
|
||||||
|
pub fn to_api(&self) -> ChatCompletionMessageToolCall {
|
||||||
|
ChatCompletionMessageToolCall {
|
||||||
|
id: self.id.clone(),
|
||||||
|
r#type: ChatCompletionToolType::Function,
|
||||||
|
function: self.typ.to_api(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub enum ToolCalls {
|
||||||
|
Search(SearchCall),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolCalls {
|
||||||
|
pub fn to_api(&self) -> FunctionCall {
|
||||||
|
FunctionCall {
|
||||||
|
name: self.name().to_owned(),
|
||||||
|
arguments: self.arguments(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Search(_) => "search",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn arguments(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::Search(search) => serde_json::to_string(search).unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum DaveResponse {
|
||||||
|
ToolCalls(Vec<ToolCall>),
|
||||||
|
Token(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
pub fn to_api_msg(&self) -> ChatCompletionRequestMessage {
|
pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage {
|
||||||
match self {
|
match self {
|
||||||
Message::User(msg) => {
|
Message::User(msg) => {
|
||||||
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
|
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
|
||||||
@@ -55,11 +163,88 @@ impl Message {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Message::ToolCalls(calls) => {
|
||||||
|
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
||||||
|
tool_calls: Some(calls.iter().map(|c| c.to_api()).collect()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
Message::ToolResponse(resp) => {
|
||||||
|
let tool_response = format_tool_response_for_ai(txn, ndb, &resp.typ);
|
||||||
|
|
||||||
|
ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
||||||
|
tool_call_id: resp.id.clone(),
|
||||||
|
content: ChatCompletionRequestToolMessageContent::Text(tool_response),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
struct SimpleNote {
|
||||||
|
pubkey: String,
|
||||||
|
name: String,
|
||||||
|
content: String,
|
||||||
|
created_at: String,
|
||||||
|
note_kind: String, // todo: add replying to
|
||||||
|
}
|
||||||
|
|
||||||
|
fn note_kind_desc(kind: u64) -> String {
|
||||||
|
match kind {
|
||||||
|
1 => "microblog".to_string(),
|
||||||
|
0 => "profile".to_string(),
|
||||||
|
_ => kind.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take the result of a tool response and present it to the ai so that
|
||||||
|
/// it can interepret it and take further action
|
||||||
|
fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
|
||||||
|
match resp {
|
||||||
|
ToolResponses::Search(search_r) => {
|
||||||
|
let simple_notes: Vec<SimpleNote> = search_r
|
||||||
|
.notes
|
||||||
|
.iter()
|
||||||
|
.filter_map(|nkey| {
|
||||||
|
let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let name = ndb
|
||||||
|
.get_profile_by_pubkey(txn, note.pubkey())
|
||||||
|
.ok()
|
||||||
|
.and_then(|p| p.record().profile())
|
||||||
|
.and_then(|p| p.name().or_else(|| p.display_name()))
|
||||||
|
.unwrap_or("Anonymous")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let content = note.content().to_owned();
|
||||||
|
let pubkey = hex::encode(note.pubkey());
|
||||||
|
let note_kind = note_kind_desc(note.kind() as u64);
|
||||||
|
let created_at = OffsetDateTime::from_unix_timestamp(note.created_at() as i64)
|
||||||
|
.unwrap()
|
||||||
|
.format(&Rfc3339)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Some(SimpleNote {
|
||||||
|
pubkey,
|
||||||
|
name,
|
||||||
|
content,
|
||||||
|
created_at,
|
||||||
|
note_kind,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
serde_json::to_string(&json!({"search_results": simple_notes})).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum SearchContext {
|
pub enum SearchContext {
|
||||||
Home,
|
Home,
|
||||||
@@ -67,16 +252,35 @@ pub enum SearchContext {
|
|||||||
Any,
|
Any,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
pub struct SearchCall {
|
pub struct SearchCall {
|
||||||
context: SearchContext,
|
context: SearchContext,
|
||||||
query: String,
|
query: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SearchCall {
|
impl SearchCall {
|
||||||
pub fn parse(args: &str) -> Result<ToolCall, ToolCallError> {
|
pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> SearchResponse {
|
||||||
|
let limit = 10i32;
|
||||||
|
let filter = nostrdb::Filter::new()
|
||||||
|
.search(&self.query)
|
||||||
|
.limit(limit as u64)
|
||||||
|
.build();
|
||||||
|
let notes = {
|
||||||
|
if let Ok(results) = ndb.query(&txn, &[filter], limit) {
|
||||||
|
results.into_iter().map(|r| r.note_key.as_u64()).collect()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
SearchResponse {
|
||||||
|
context: self.context.clone(),
|
||||||
|
notes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
|
||||||
match serde_json::from_str::<SearchCall>(args) {
|
match serde_json::from_str::<SearchCall>(args) {
|
||||||
Ok(call) => Ok(ToolCall::Search(call)),
|
Ok(call) => Ok(ToolCalls::Search(call)),
|
||||||
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
|
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
|
||||||
"Failed to parse args: '{}', error: {}",
|
"Failed to parse args: '{}', error: {}",
|
||||||
args, e
|
args, e
|
||||||
@@ -85,16 +289,6 @@ impl SearchCall {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum ToolCall {
|
|
||||||
Search(SearchCall),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum DaveResponse {
|
|
||||||
ToolCall(ToolCall),
|
|
||||||
Token(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Dave {
|
pub struct Dave {
|
||||||
chat: Vec<Message>,
|
chat: Vec<Message>,
|
||||||
/// A 3d representation of dave.
|
/// A 3d representation of dave.
|
||||||
@@ -116,13 +310,15 @@ impl Dave {
|
|||||||
let client = Client::with_config(config);
|
let client = Client::with_config(config);
|
||||||
|
|
||||||
let input = "".to_string();
|
let input = "".to_string();
|
||||||
let pubkey = "test_pubkey".to_string();
|
let pubkey = "32e1827635450ebb3c5a7d12c1f8e7b2b514439ac10a67eef3d9fd9c5c68e245".to_string();
|
||||||
let avatar = render_state.map(DaveAvatar::new);
|
let avatar = render_state.map(DaveAvatar::new);
|
||||||
let mut tools: HashMap<String, Tool> = HashMap::new();
|
let mut tools: HashMap<String, Tool> = HashMap::new();
|
||||||
for tool in dave_tools() {
|
for tool in dave_tools() {
|
||||||
tools.insert(tool.name.to_string(), tool);
|
tools.insert(tool.name.to_string(), tool);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let system_prompt = Message::System(format!("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find and summarize content for users. The current user's pubkey is {}.", &pubkey).to_string());
|
||||||
|
|
||||||
Dave {
|
Dave {
|
||||||
client,
|
client,
|
||||||
pubkey,
|
pubkey,
|
||||||
@@ -130,13 +326,11 @@ impl Dave {
|
|||||||
incoming_tokens: None,
|
incoming_tokens: None,
|
||||||
tools: Arc::new(tools),
|
tools: Arc::new(tools),
|
||||||
input,
|
input,
|
||||||
chat: vec![
|
chat: vec![system_prompt],
|
||||||
Message::System("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find content for users (TODO: actually implement this)".to_string()),
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render(&mut self, ui: &mut egui::Ui) {
|
fn render(&mut self, app_ctx: &AppContext, ui: &mut egui::Ui) {
|
||||||
if let Some(recvr) = &self.incoming_tokens {
|
if let Some(recvr) = &self.incoming_tokens {
|
||||||
while let Ok(res) = recvr.try_recv() {
|
while let Ok(res) = recvr.try_recv() {
|
||||||
match res {
|
match res {
|
||||||
@@ -146,8 +340,23 @@ impl Dave {
|
|||||||
None => {}
|
None => {}
|
||||||
},
|
},
|
||||||
|
|
||||||
DaveResponse::ToolCall(tool) => {
|
DaveResponse::ToolCalls(toolcalls) => {
|
||||||
tracing::info!("got tool call: {:?}", tool);
|
tracing::info!("got tool calls: {:?}", toolcalls);
|
||||||
|
self.chat.push(Message::ToolCalls(toolcalls.clone()));
|
||||||
|
|
||||||
|
let txn = Transaction::new(app_ctx.ndb).unwrap();
|
||||||
|
for call in &toolcalls {
|
||||||
|
// execute toolcall
|
||||||
|
match &call.typ {
|
||||||
|
ToolCalls::Search(search_call) => {
|
||||||
|
let resp = search_call.execute(&txn, app_ctx.ndb);
|
||||||
|
self.chat.push(Message::ToolResponse(ToolResponse {
|
||||||
|
id: call.id.clone(),
|
||||||
|
typ: ToolResponses::Search(resp),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -162,7 +371,7 @@ impl Dave {
|
|||||||
ui.vertical(|ui| {
|
ui.vertical(|ui| {
|
||||||
self.render_chat(ui);
|
self.render_chat(ui);
|
||||||
|
|
||||||
self.inputbox(ui);
|
self.inputbox(app_ctx, ui);
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -180,20 +389,48 @@ impl Dave {
|
|||||||
match message {
|
match message {
|
||||||
Message::User(msg) => self.user_chat(msg, ui),
|
Message::User(msg) => self.user_chat(msg, ui),
|
||||||
Message::Assistant(msg) => self.assistant_chat(msg, ui),
|
Message::Assistant(msg) => self.assistant_chat(msg, ui),
|
||||||
|
Message::ToolResponse(msg) => Self::tool_response_ui(msg, ui),
|
||||||
Message::System(_msg) => {
|
Message::System(_msg) => {
|
||||||
// system prompt is not rendered. Maybe we could
|
// system prompt is not rendered. Maybe we could
|
||||||
// have a debug option to show this
|
// have a debug option to show this
|
||||||
}
|
}
|
||||||
|
Message::ToolCalls(toolcalls) => {
|
||||||
|
Self::tool_call_ui(&toolcalls, ui);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inputbox(&mut self, ui: &mut egui::Ui) {
|
fn tool_response_ui(tool_response: &ToolResponse, ui: &mut egui::Ui) {
|
||||||
|
ui.label(format!("tool_response: {:?}", tool_response));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tool_call_ui(toolcalls: &[ToolCall], ui: &mut egui::Ui) {
|
||||||
|
ui.vertical(|ui| {
|
||||||
|
for call in toolcalls {
|
||||||
|
match &call.typ {
|
||||||
|
ToolCalls::Search(search_call) => {
|
||||||
|
ui.horizontal(|ui| {
|
||||||
|
let context = match search_call.context {
|
||||||
|
SearchContext::Profile => "profile ",
|
||||||
|
SearchContext::Any => " ",
|
||||||
|
SearchContext::Home => "home ",
|
||||||
|
};
|
||||||
|
|
||||||
|
ui.label(format!("Searching {}for '{}'", context, search_call.query));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inputbox(&mut self, app_ctx: &AppContext, ui: &mut egui::Ui) {
|
||||||
ui.horizontal(|ui| {
|
ui.horizontal(|ui| {
|
||||||
ui.add(egui::TextEdit::multiline(&mut self.input));
|
ui.add(egui::TextEdit::multiline(&mut self.input));
|
||||||
if ui.button("Sned").clicked() {
|
if ui.button("Sned").clicked() {
|
||||||
self.chat.push(Message::User(self.input.clone()));
|
self.chat.push(Message::User(self.input.clone()));
|
||||||
self.send_user_message(ui.ctx());
|
self.send_user_message(app_ctx, ui.ctx());
|
||||||
self.input.clear();
|
self.input.clear();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -217,14 +454,22 @@ impl Dave {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_user_message(&mut self, ctx: &egui::Context) {
|
fn send_user_message(&mut self, app_ctx: &AppContext, ctx: &egui::Context) {
|
||||||
let messages = self.chat.iter().map(|c| c.to_api_msg()).collect();
|
let messages = {
|
||||||
|
let txn = Transaction::new(app_ctx.ndb).expect("txn");
|
||||||
|
self.chat
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.to_api_msg(&txn, app_ctx.ndb))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
let pubkey = self.pubkey.clone();
|
let pubkey = self.pubkey.clone();
|
||||||
let (tx, rx) = mpsc::channel();
|
|
||||||
self.incoming_tokens = Some(rx);
|
|
||||||
let ctx = ctx.clone();
|
let ctx = ctx.clone();
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let tools = self.tools.clone();
|
let tools = self.tools.clone();
|
||||||
|
|
||||||
|
let (tx, rx) = mpsc::channel();
|
||||||
|
self.incoming_tokens = Some(rx);
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut token_stream = match client
|
let mut token_stream = match client
|
||||||
.chat()
|
.chat()
|
||||||
@@ -247,8 +492,7 @@ impl Dave {
|
|||||||
Ok(stream) => stream,
|
Ok(stream) => stream,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut tool_call_name: Option<String> = None;
|
let mut all_tool_calls: HashMap<u32, PartialToolCall> = HashMap::new();
|
||||||
let mut tool_call_chunks: Vec<String> = vec![];
|
|
||||||
|
|
||||||
while let Some(token) = token_stream.next().await {
|
while let Some(token) = token_stream.next().await {
|
||||||
let token = match token {
|
let token = match token {
|
||||||
@@ -265,19 +509,25 @@ impl Dave {
|
|||||||
// if we have tool call arg chunks, collect them here
|
// if we have tool call arg chunks, collect them here
|
||||||
if let Some(tool_calls) = &resp.tool_calls {
|
if let Some(tool_calls) = &resp.tool_calls {
|
||||||
for tool in tool_calls {
|
for tool in tool_calls {
|
||||||
let Some(fcall) = &tool.function else {
|
let entry = all_tool_calls.entry(tool.index).or_default();
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(name) = &fcall.name {
|
if let Some(id) = &tool.id {
|
||||||
tool_call_name = Some(name.clone());
|
entry.id.get_or_insert(id.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(argchunk) = &fcall.arguments else {
|
if let Some(name) = tool.function.as_ref().and_then(|f| f.name.as_ref())
|
||||||
continue;
|
{
|
||||||
};
|
entry.name.get_or_insert(name.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
tool_call_chunks.push(argchunk.clone());
|
if let Some(argchunk) =
|
||||||
|
tool.function.as_ref().and_then(|f| f.arguments.as_ref())
|
||||||
|
{
|
||||||
|
entry
|
||||||
|
.arguments
|
||||||
|
.get_or_insert_with(String::new)
|
||||||
|
.push_str(&argchunk);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -288,28 +538,31 @@ impl Dave {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(tool_name) = tool_call_name {
|
let mut parsed_tool_calls = vec![];
|
||||||
if !tool_call_chunks.is_empty() {
|
for (_index, partial) in &all_tool_calls {
|
||||||
let args = tool_call_chunks.join("");
|
let Some(unknown_tool_call) = partial.complete() else {
|
||||||
match parse_tool_call(&tools, &tool_name, &args) {
|
tracing::error!("could not complete partial tool call: {:?}", partial);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
match unknown_tool_call.parse(&tools) {
|
||||||
Ok(tool_call) => {
|
Ok(tool_call) => {
|
||||||
tx.send(DaveResponse::ToolCall(tool_call)).unwrap();
|
parsed_tool_calls.push(tool_call);
|
||||||
ctx.request_repaint();
|
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
"failed to parse tool call err({:?}): name({:?}) args({:?})",
|
"failed to parse tool call {:?}: {:?}",
|
||||||
|
unknown_tool_call,
|
||||||
err,
|
err,
|
||||||
tool_name,
|
|
||||||
args,
|
|
||||||
);
|
);
|
||||||
// TODO: return error to user
|
// TODO: return error to user
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} else {
|
|
||||||
// TODO: return error to user
|
|
||||||
tracing::error!("got tool call '{}' with no arguments?", tool_name);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !parsed_tool_calls.is_empty() {
|
||||||
|
tx.send(DaveResponse::ToolCalls(parsed_tool_calls)).unwrap();
|
||||||
|
ctx.request_repaint();
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::debug!("stream closed");
|
tracing::debug!("stream closed");
|
||||||
@@ -318,7 +571,7 @@ impl Dave {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl notedeck::App for Dave {
|
impl notedeck::App for Dave {
|
||||||
fn update(&mut self, _ctx: &mut AppContext<'_>, ui: &mut egui::Ui) {
|
fn update(&mut self, ctx: &mut AppContext<'_>, ui: &mut egui::Ui) {
|
||||||
/*
|
/*
|
||||||
self.app
|
self.app
|
||||||
.frame_history
|
.frame_history
|
||||||
@@ -326,7 +579,7 @@ impl notedeck::App for Dave {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
//update_dave(self, ctx, ui.ctx());
|
//update_dave(self, ctx, ui.ctx());
|
||||||
self.render(ui);
|
self.render(ctx, ui);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -357,7 +610,7 @@ struct ToolArg {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Tool {
|
pub struct Tool {
|
||||||
parse_call: fn(&str) -> Result<ToolCall, ToolCallError>,
|
parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
description: &'static str,
|
description: &'static str,
|
||||||
arguments: Vec<ToolArg>,
|
arguments: Vec<ToolArg>,
|
||||||
@@ -458,18 +711,6 @@ pub enum ToolCallError {
|
|||||||
ArgParseFailure(String),
|
ArgParseFailure(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_tool_call(
|
|
||||||
tools: &HashMap<String, Tool>,
|
|
||||||
name: &str,
|
|
||||||
args: &str,
|
|
||||||
) -> Result<ToolCall, ToolCallError> {
|
|
||||||
let Some(tool) = tools.get(name) else {
|
|
||||||
return Err(ToolCallError::NotFound(name.to_owned()));
|
|
||||||
};
|
|
||||||
|
|
||||||
(tool.parse_call)(&args)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn dave_tools() -> Vec<Tool> {
|
fn dave_tools() -> Vec<Tool> {
|
||||||
vec![search_tool()]
|
vec![search_tool()]
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user