mirror of
https://github.com/aljazceru/notedeck.git
synced 2025-12-18 17:14:21 +01:00
827 lines
26 KiB
Rust
827 lines
26 KiB
Rust
use async_openai::{
|
|
config::OpenAIConfig,
|
|
types::{
|
|
ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
|
|
ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
|
|
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
|
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
|
ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
|
|
ChatCompletionTool, ChatCompletionToolType, CreateChatCompletionRequest, FunctionCall,
|
|
FunctionObject,
|
|
},
|
|
Client,
|
|
};
|
|
use futures::StreamExt;
|
|
use nostrdb::{Ndb, NoteKey, Transaction};
|
|
use notedeck::AppContext;
|
|
use notedeck_ui::icons::search_icon;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::{json, Value};
|
|
use std::collections::HashMap;
|
|
use std::sync::mpsc::{self, Receiver};
|
|
use std::sync::Arc;
|
|
use time::{format_description::well_known::Rfc3339, OffsetDateTime};
|
|
|
|
pub use avatar::DaveAvatar;
|
|
use egui_wgpu::RenderState;
|
|
|
|
pub use quaternion::Quaternion;
|
|
pub use vec3::Vec3;
|
|
|
|
mod avatar;
|
|
mod quaternion;
|
|
mod vec3;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub enum Message {
|
|
User(String),
|
|
Assistant(String),
|
|
System(String),
|
|
ToolCalls(Vec<ToolCall>),
|
|
ToolResponse(ToolResponse),
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SearchResponse {
|
|
//context: SearchContext,
|
|
notes: Vec<u64>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub enum ToolResponses {
|
|
Search(SearchResponse),
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolResponse {
|
|
id: String,
|
|
typ: ToolResponses,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolCall {
|
|
id: String,
|
|
typ: ToolCalls,
|
|
}
|
|
|
|
/// A tool call that is still being assembled from streamed chunks.
///
/// Every field starts out as `None` and is filled in as the matching
/// piece arrives in the response stream.
#[derive(Default, Debug, Clone)]
pub struct PartialToolCall {
    /// Tool call id, once seen.
    id: Option<String>,
    /// Tool (function) name, once seen.
    name: Option<String>,
    /// Argument text accumulated so far.
    arguments: Option<String>,
}
|
|
|
|
/// A complete tool call whose name and arguments have not been validated
/// against the registered tools yet.
#[derive(Debug, Clone)]
pub struct UnknownToolCall {
    /// Tool call id from the model's response.
    id: String,
    /// Name of the tool the model wants to invoke.
    name: String,
    /// Unparsed argument text for the tool.
    arguments: String,
}
|
|
|
|
impl UnknownToolCall {
|
|
pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> {
|
|
let Some(tool) = tools.get(&self.name) else {
|
|
return Err(ToolCallError::NotFound(self.name.to_owned()));
|
|
};
|
|
|
|
let parsed_args = (tool.parse_call)(&self.arguments)?;
|
|
Ok(ToolCall {
|
|
id: self.id.clone(),
|
|
typ: parsed_args,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl PartialToolCall {
|
|
pub fn complete(&self) -> Option<UnknownToolCall> {
|
|
Some(UnknownToolCall {
|
|
id: self.id.clone()?,
|
|
name: self.name.clone()?,
|
|
arguments: self.arguments.clone()?,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl ToolCall {
|
|
pub fn to_api(&self) -> ChatCompletionMessageToolCall {
|
|
ChatCompletionMessageToolCall {
|
|
id: self.id.clone(),
|
|
r#type: ChatCompletionToolType::Function,
|
|
function: self.typ.to_api(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub enum ToolCalls {
|
|
Search(SearchCall),
|
|
}
|
|
|
|
impl ToolCalls {
|
|
pub fn to_api(&self) -> FunctionCall {
|
|
FunctionCall {
|
|
name: self.name().to_owned(),
|
|
arguments: self.arguments(),
|
|
}
|
|
}
|
|
|
|
fn name(&self) -> &'static str {
|
|
match self {
|
|
Self::Search(_) => "search",
|
|
}
|
|
}
|
|
|
|
fn arguments(&self) -> String {
|
|
match self {
|
|
Self::Search(search) => serde_json::to_string(search).unwrap(),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub enum DaveResponse {
|
|
ToolCalls(Vec<ToolCall>),
|
|
Token(String),
|
|
}
|
|
|
|
impl Message {
|
|
pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage {
|
|
match self {
|
|
Message::User(msg) => {
|
|
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
|
|
name: None,
|
|
content: ChatCompletionRequestUserMessageContent::Text(msg.clone()),
|
|
})
|
|
}
|
|
|
|
Message::Assistant(msg) => {
|
|
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
|
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
msg.clone(),
|
|
)),
|
|
..Default::default()
|
|
})
|
|
}
|
|
|
|
Message::System(msg) => {
|
|
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
|
content: ChatCompletionRequestSystemMessageContent::Text(msg.clone()),
|
|
..Default::default()
|
|
})
|
|
}
|
|
|
|
Message::ToolCalls(calls) => {
|
|
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
|
tool_calls: Some(calls.iter().map(|c| c.to_api()).collect()),
|
|
..Default::default()
|
|
})
|
|
}
|
|
|
|
Message::ToolResponse(resp) => {
|
|
let tool_response = format_tool_response_for_ai(txn, ndb, &resp.typ);
|
|
|
|
ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
|
tool_call_id: resp.id.clone(),
|
|
content: ChatCompletionRequestToolMessageContent::Text(tool_response),
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct SimpleNote {
|
|
pubkey: String,
|
|
name: String,
|
|
content: String,
|
|
created_at: String,
|
|
note_kind: String, // todo: add replying to
|
|
}
|
|
|
|
/// Describe a note kind for the model.
///
/// Kind `1` maps to `"microblog"` and kind `0` to `"profile"`; every other
/// kind is rendered as its decimal number.
fn note_kind_desc(kind: u64) -> String {
    let label = match kind {
        0 => "profile",
        1 => "microblog",
        other => return other.to_string(),
    };

    label.to_string()
}
|
|
|
|
/// Take the result of a tool response and present it to the ai so that
|
|
/// it can interepret it and take further action
|
|
fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
|
|
match resp {
|
|
ToolResponses::Search(search_r) => {
|
|
let simple_notes: Vec<SimpleNote> = search_r
|
|
.notes
|
|
.iter()
|
|
.filter_map(|nkey| {
|
|
let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else {
|
|
return None;
|
|
};
|
|
|
|
let name = ndb
|
|
.get_profile_by_pubkey(txn, note.pubkey())
|
|
.ok()
|
|
.and_then(|p| p.record().profile())
|
|
.and_then(|p| p.name().or_else(|| p.display_name()))
|
|
.unwrap_or("Anonymous")
|
|
.to_string();
|
|
|
|
let content = note.content().to_owned();
|
|
let pubkey = hex::encode(note.pubkey());
|
|
let note_kind = note_kind_desc(note.kind() as u64);
|
|
let created_at = OffsetDateTime::from_unix_timestamp(note.created_at() as i64)
|
|
.unwrap()
|
|
.format(&Rfc3339)
|
|
.unwrap();
|
|
|
|
Some(SimpleNote {
|
|
pubkey,
|
|
name,
|
|
content,
|
|
created_at,
|
|
note_kind,
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
serde_json::to_string(&json!({"search_results": simple_notes})).unwrap()
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum SearchContext {
|
|
Home,
|
|
Profile,
|
|
Any,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
pub struct SearchCall {
|
|
//context: SearchContext,
|
|
limit: i32,
|
|
query: String,
|
|
}
|
|
|
|
impl SearchCall {
|
|
pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> SearchResponse {
|
|
let limit = 10i32;
|
|
let filter = nostrdb::Filter::new()
|
|
.search(&self.query)
|
|
.limit(limit as u64)
|
|
.build();
|
|
let notes = {
|
|
if let Ok(results) = ndb.query(txn, &[filter], limit) {
|
|
results.into_iter().map(|r| r.note_key.as_u64()).collect()
|
|
} else {
|
|
vec![]
|
|
}
|
|
};
|
|
SearchResponse {
|
|
//context: self.context.clone(),
|
|
notes,
|
|
}
|
|
}
|
|
|
|
pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
|
|
match serde_json::from_str::<SearchCall>(args) {
|
|
Ok(call) => Ok(ToolCalls::Search(call)),
|
|
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
|
|
"Failed to parse args: '{}', error: {}",
|
|
args, e
|
|
))),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct Dave {
|
|
chat: Vec<Message>,
|
|
/// A 3d representation of dave.
|
|
avatar: Option<DaveAvatar>,
|
|
input: String,
|
|
pubkey: String,
|
|
tools: Arc<HashMap<String, Tool>>,
|
|
client: async_openai::Client<OpenAIConfig>,
|
|
incoming_tokens: Option<Receiver<DaveResponse>>,
|
|
model_config: ModelConfig,
|
|
}
|
|
|
|
/// Settings describing which model to talk to and how to reach it.
pub struct ModelConfig {
    /// API base URL override; `None` keeps the client's default.
    endpoint: Option<String>,
    /// Name of the model to request.
    model: String,
    /// API key, if one is available.
    api_key: Option<String>,
}

impl Default for ModelConfig {
    /// `gpt-4o` on the default endpoint, with the API key read from the
    /// `OPENAI_API_KEY` environment variable when it is set.
    fn default() -> Self {
        Self {
            api_key: std::env::var("OPENAI_API_KEY").ok(),
            model: String::from("gpt-4o"),
            endpoint: None,
        }
    }
}
|
|
|
|
impl ModelConfig {
|
|
pub fn ollama() -> Self {
|
|
ModelConfig {
|
|
endpoint: std::env::var("OLLAMA_HOST").ok(),
|
|
model: "hhao/qwen2.5-coder-tools:latest".to_string(),
|
|
api_key: None,
|
|
}
|
|
}
|
|
|
|
pub fn to_api(&self) -> OpenAIConfig {
|
|
let mut cfg = OpenAIConfig::new();
|
|
if let Some(endpoint) = &self.endpoint {
|
|
cfg = cfg.with_api_base(endpoint.to_owned());
|
|
}
|
|
|
|
if let Some(api_key) = &self.api_key {
|
|
cfg = cfg.with_api_key(api_key.to_owned());
|
|
}
|
|
|
|
cfg
|
|
}
|
|
}
|
|
|
|
impl Dave {
|
|
/// Mutable access to the avatar, if one was created.
pub fn avatar_mut(&mut self) -> Option<&mut DaveAvatar> {
    match &mut self.avatar {
        Some(avatar) => Some(avatar),
        None => None,
    }
}
|
|
|
|
pub fn new(render_state: Option<&RenderState>) -> Self {
|
|
//let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
|
|
let model_config = ModelConfig::default();
|
|
let client = Client::with_config(model_config.to_api());
|
|
|
|
let input = "".to_string();
|
|
let pubkey = "32e1827635450ebb3c5a7d12c1f8e7b2b514439ac10a67eef3d9fd9c5c68e245".to_string();
|
|
let avatar = render_state.map(DaveAvatar::new);
|
|
let mut tools: HashMap<String, Tool> = HashMap::new();
|
|
for tool in dave_tools() {
|
|
tools.insert(tool.name.to_string(), tool);
|
|
}
|
|
|
|
let system_prompt = Message::System(format!("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find and summarize content for users. The current user's pubkey is {}.", &pubkey).to_string());
|
|
|
|
Dave {
|
|
client,
|
|
pubkey,
|
|
avatar,
|
|
incoming_tokens: None,
|
|
tools: Arc::new(tools),
|
|
input,
|
|
model_config,
|
|
chat: vec![system_prompt],
|
|
}
|
|
}
|
|
|
|
fn render(&mut self, app_ctx: &AppContext, ui: &mut egui::Ui) {
|
|
let mut should_send = false;
|
|
if let Some(recvr) = &self.incoming_tokens {
|
|
while let Ok(res) = recvr.try_recv() {
|
|
if let Some(avatar) = &mut self.avatar {
|
|
avatar.random_nudge();
|
|
}
|
|
match res {
|
|
DaveResponse::Token(token) => match self.chat.last_mut() {
|
|
Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
|
|
Some(_) => self.chat.push(Message::Assistant(token)),
|
|
None => {}
|
|
},
|
|
|
|
DaveResponse::ToolCalls(toolcalls) => {
|
|
tracing::info!("got tool calls: {:?}", toolcalls);
|
|
self.chat.push(Message::ToolCalls(toolcalls.clone()));
|
|
|
|
let txn = Transaction::new(app_ctx.ndb).unwrap();
|
|
for call in &toolcalls {
|
|
// execute toolcall
|
|
match &call.typ {
|
|
ToolCalls::Search(search_call) => {
|
|
let resp = search_call.execute(&txn, app_ctx.ndb);
|
|
self.chat.push(Message::ToolResponse(ToolResponse {
|
|
id: call.id.clone(),
|
|
typ: ToolResponses::Search(resp),
|
|
}))
|
|
}
|
|
}
|
|
}
|
|
|
|
should_send = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Scroll area for chat messages
|
|
egui::Frame::new()
|
|
.inner_margin(egui::Margin {
|
|
left: 50,
|
|
right: 50,
|
|
top: 50,
|
|
bottom: 50,
|
|
})
|
|
.show(ui, |ui| {
|
|
egui::ScrollArea::vertical()
|
|
.stick_to_bottom(true)
|
|
.auto_shrink([false; 2])
|
|
.show(ui, |ui| {
|
|
ui.vertical(|ui| {
|
|
self.render_chat(ui);
|
|
|
|
self.inputbox(app_ctx, ui);
|
|
})
|
|
});
|
|
});
|
|
|
|
/*
|
|
// he lives in the sidebar now
|
|
if let Some(avatar) = &mut self.avatar {
|
|
let avatar_size = Vec2::splat(300.0);
|
|
let pos = Vec2::splat(100.0).to_pos2();
|
|
let pos = Rect::from_min_max(pos, pos + avatar_size);
|
|
avatar.render(pos, ui);
|
|
}
|
|
*/
|
|
|
|
// send again
|
|
if should_send {
|
|
self.send_user_message(app_ctx, ui.ctx());
|
|
}
|
|
}
|
|
|
|
fn render_chat(&self, ui: &mut egui::Ui) {
|
|
for message in &self.chat {
|
|
match message {
|
|
Message::User(msg) => self.user_chat(msg, ui),
|
|
Message::Assistant(msg) => self.assistant_chat(msg, ui),
|
|
Message::ToolResponse(msg) => Self::tool_response_ui(msg, ui),
|
|
Message::System(_msg) => {
|
|
// system prompt is not rendered. Maybe we could
|
|
// have a debug option to show this
|
|
}
|
|
Message::ToolCalls(toolcalls) => {
|
|
Self::tool_call_ui(toolcalls, ui);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Show a tool response as a plain label containing its debug output.
fn tool_response_ui(tool_response: &ToolResponse, ui: &mut egui::Ui) {
    let text = format!("tool_response: {:?}", tool_response);
    ui.label(text);
}
|
|
|
|
/// Draw the search icon followed by a label naming the query.
fn search_call_ui(search_call: &SearchCall, ui: &mut egui::Ui) {
    ui.add(search_icon(16.0, 16.0));
    ui.add_space(8.0);

    //let context = match search_call.context {
    //    SearchContext::Profile => "profile ",
    //    SearchContext::Any => "",
    //    SearchContext::Home => "home ",
    //};

    let text = format!("Searching for '{}'", /*context,*/ search_call.query);
    ui.label(text);
}
|
|
|
|
fn tool_call_ui(toolcalls: &[ToolCall], ui: &mut egui::Ui) {
|
|
ui.vertical(|ui| {
|
|
for call in toolcalls {
|
|
match &call.typ {
|
|
ToolCalls::Search(search_call) => {
|
|
ui.horizontal(|ui| {
|
|
egui::Frame::new()
|
|
.inner_margin(10.0)
|
|
.corner_radius(10.0)
|
|
.fill(ui.visuals().widgets.inactive.weak_bg_fill)
|
|
.show(ui, |ui| {
|
|
Self::search_call_ui(search_call, ui);
|
|
})
|
|
});
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
fn inputbox(&mut self, app_ctx: &AppContext, ui: &mut egui::Ui) {
|
|
ui.horizontal(|ui| {
|
|
ui.add(egui::TextEdit::multiline(&mut self.input));
|
|
ui.add_space(8.0);
|
|
if ui.button("Send").clicked() {
|
|
self.chat.push(Message::User(self.input.clone()));
|
|
self.send_user_message(app_ctx, ui.ctx());
|
|
self.input.clear();
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Draw a user message inside a rounded, filled frame laid out
/// right-to-left.
fn user_chat(&self, msg: &str, ui: &mut egui::Ui) {
    let layout = egui::Layout::right_to_left(egui::Align::TOP);
    ui.with_layout(layout, |ui| {
        let fill = ui.visuals().widgets.inactive.weak_bg_fill;
        egui::Frame::new()
            .inner_margin(10.0)
            .corner_radius(10.0)
            .fill(fill)
            .show(ui, |ui| {
                ui.label(msg);
            })
    });
}
|
|
|
|
/// Draw an assistant message as a wrapping label.
fn assistant_chat(&self, msg: &str, ui: &mut egui::Ui) {
    ui.horizontal_wrapped(|ui| {
        let label = egui::Label::new(msg).wrap_mode(egui::TextWrapMode::Wrap);
        ui.add(label);
    });
}
|
|
|
|
fn send_user_message(&mut self, app_ctx: &AppContext, ctx: &egui::Context) {
|
|
let messages: Vec<ChatCompletionRequestMessage> = {
|
|
let txn = Transaction::new(app_ctx.ndb).expect("txn");
|
|
self.chat
|
|
.iter()
|
|
.map(|c| c.to_api_msg(&txn, app_ctx.ndb))
|
|
.collect()
|
|
};
|
|
tracing::debug!("sending messages, latest: {:?}", messages.last().unwrap());
|
|
let pubkey = self.pubkey.clone();
|
|
let ctx = ctx.clone();
|
|
let client = self.client.clone();
|
|
let tools = self.tools.clone();
|
|
let model_name = self.model_config.model.clone();
|
|
|
|
let (tx, rx) = mpsc::channel();
|
|
self.incoming_tokens = Some(rx);
|
|
|
|
tokio::spawn(async move {
|
|
let mut token_stream = match client
|
|
.chat()
|
|
.create_stream(CreateChatCompletionRequest {
|
|
model: model_name,
|
|
stream: Some(true),
|
|
messages,
|
|
tools: Some(dave_tools().iter().map(|t| t.to_api()).collect()),
|
|
user: Some(pubkey),
|
|
..Default::default()
|
|
})
|
|
.await
|
|
{
|
|
Err(err) => {
|
|
tracing::error!("openai chat error: {err}");
|
|
return;
|
|
}
|
|
|
|
Ok(stream) => stream,
|
|
};
|
|
|
|
let mut all_tool_calls: HashMap<u32, PartialToolCall> = HashMap::new();
|
|
|
|
while let Some(token) = token_stream.next().await {
|
|
let token = match token {
|
|
Ok(token) => token,
|
|
Err(err) => {
|
|
tracing::error!("failed to get token: {err}");
|
|
return;
|
|
}
|
|
};
|
|
|
|
for choice in &token.choices {
|
|
let resp = &choice.delta;
|
|
|
|
// if we have tool call arg chunks, collect them here
|
|
if let Some(tool_calls) = &resp.tool_calls {
|
|
for tool in tool_calls {
|
|
let entry = all_tool_calls.entry(tool.index).or_default();
|
|
|
|
if let Some(id) = &tool.id {
|
|
entry.id.get_or_insert(id.to_string());
|
|
}
|
|
|
|
if let Some(name) = tool.function.as_ref().and_then(|f| f.name.as_ref())
|
|
{
|
|
entry.name.get_or_insert(name.to_string());
|
|
}
|
|
|
|
if let Some(argchunk) =
|
|
tool.function.as_ref().and_then(|f| f.arguments.as_ref())
|
|
{
|
|
entry
|
|
.arguments
|
|
.get_or_insert_with(String::new)
|
|
.push_str(argchunk);
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(content) = &resp.content {
|
|
if let Err(err) = tx.send(DaveResponse::Token(content.to_owned())) {
|
|
tracing::error!("failed to send dave response token to ui: {err}");
|
|
}
|
|
ctx.request_repaint();
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut parsed_tool_calls = vec![];
|
|
for (_index, partial) in all_tool_calls {
|
|
let Some(unknown_tool_call) = partial.complete() else {
|
|
tracing::error!("could not complete partial tool call: {:?}", partial);
|
|
continue;
|
|
};
|
|
|
|
match unknown_tool_call.parse(&tools) {
|
|
Ok(tool_call) => {
|
|
parsed_tool_calls.push(tool_call);
|
|
}
|
|
Err(err) => {
|
|
tracing::error!(
|
|
"failed to parse tool call {:?}: {:?}",
|
|
unknown_tool_call,
|
|
err,
|
|
);
|
|
// TODO: return error to user
|
|
}
|
|
};
|
|
}
|
|
|
|
if !parsed_tool_calls.is_empty() {
|
|
tx.send(DaveResponse::ToolCalls(parsed_tool_calls)).unwrap();
|
|
ctx.request_repaint();
|
|
}
|
|
|
|
tracing::debug!("stream closed");
|
|
});
|
|
}
|
|
}
|
|
|
|
impl notedeck::App for Dave {
|
|
fn update(&mut self, ctx: &mut AppContext<'_>, ui: &mut egui::Ui) {
|
|
/*
|
|
self.app
|
|
.frame_history
|
|
.on_new_frame(ctx.input(|i| i.time), frame.info().cpu_usage);
|
|
*/
|
|
|
|
//update_dave(self, ctx, ui.ctx());
|
|
self.render(ctx, ui);
|
|
}
|
|
}
|
|
|
|
/// The type of a tool argument.
#[derive(Debug, Clone)]
enum ArgType {
    /// Free-form text.
    String,
    /// A numeric value.
    Number,
    /// Text restricted to one of the listed values.
    Enum(Vec<&'static str>),
}

impl ArgType {
    /// The type name used for this argument in the tool's parameter schema.
    ///
    /// Enums are reported as `"string"`; their allowed values are listed
    /// separately.
    pub fn type_string(&self) -> &'static str {
        match self {
            Self::Number => "number",
            Self::String | Self::Enum(_) => "string",
        }
    }
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
struct ToolArg {
|
|
typ: ArgType,
|
|
name: &'static str,
|
|
required: bool,
|
|
description: &'static str,
|
|
default: Option<Value>,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Tool {
|
|
parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>,
|
|
name: &'static str,
|
|
description: &'static str,
|
|
arguments: Vec<ToolArg>,
|
|
}
|
|
|
|
impl Tool {
|
|
pub fn to_function_object(&self) -> FunctionObject {
|
|
let required_args = self
|
|
.arguments
|
|
.iter()
|
|
.filter_map(|arg| {
|
|
if arg.required {
|
|
Some(Value::String(arg.name.to_owned()))
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
|
|
parameters.insert("type".to_string(), Value::String("object".to_string()));
|
|
parameters.insert("required".to_string(), Value::Array(required_args));
|
|
parameters.insert("additionalProperties".to_string(), Value::Bool(false));
|
|
|
|
let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
|
|
|
|
for arg in &self.arguments {
|
|
let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
|
|
props.insert(
|
|
"type".to_string(),
|
|
Value::String(arg.typ.type_string().to_string()),
|
|
);
|
|
if let Some(default) = &arg.default {
|
|
props.insert("default".to_string(), default.clone());
|
|
}
|
|
props.insert(
|
|
"description".to_string(),
|
|
Value::String(arg.description.to_owned()),
|
|
);
|
|
if let ArgType::Enum(enums) = &arg.typ {
|
|
props.insert(
|
|
"enum".to_string(),
|
|
Value::Array(
|
|
enums
|
|
.iter()
|
|
.map(|s| Value::String((*s).to_owned()))
|
|
.collect(),
|
|
),
|
|
);
|
|
}
|
|
|
|
properties.insert(arg.name.to_owned(), Value::Object(props));
|
|
}
|
|
|
|
parameters.insert("properties".to_string(), Value::Object(properties));
|
|
|
|
FunctionObject {
|
|
name: self.name.to_owned(),
|
|
description: Some(self.description.to_owned()),
|
|
strict: Some(true),
|
|
parameters: Some(Value::Object(parameters)),
|
|
}
|
|
}
|
|
|
|
pub fn to_api(&self) -> ChatCompletionTool {
|
|
ChatCompletionTool {
|
|
r#type: ChatCompletionToolType::Function,
|
|
function: self.to_function_object(),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn search_tool() -> Tool {
|
|
Tool {
|
|
name: "search",
|
|
parse_call: SearchCall::parse,
|
|
description: "Full-text search functionality. Used for finding individual notes with specific terms in the contents.",
|
|
arguments: vec![
|
|
ToolArg {
|
|
name: "query",
|
|
typ: ArgType::String,
|
|
required: true,
|
|
default: None,
|
|
description: "The fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include 'and', 'punctuation', etc if you don't need to.",
|
|
},
|
|
|
|
ToolArg {
|
|
name: "limit",
|
|
typ: ArgType::Number,
|
|
required: true,
|
|
default: Some(Value::Number(serde_json::Number::from_i128(10).unwrap())),
|
|
description: "The number of results to return.",
|
|
},
|
|
|
|
/*
|
|
ToolArg {
|
|
name: "kind",
|
|
typ: ArgType::Enum(vec!["microblog", "longform"]),
|
|
required: true,
|
|
description: "The kind of note. microblogs are short snippets of texts (aka tweets, this is what you want to search by default). Longform are blog posts/articles.",
|
|
},
|
|
|
|
ToolArg {
|
|
name: "context",
|
|
typ: ArgType::Enum(vec!["home", "profile", "any"]),
|
|
required: true,
|
|
description: "The context in which the search is occuring. valid options are 'home', 'profile', 'any'",
|
|
}
|
|
*/
|
|
]
|
|
}
|
|
}
|
|
|
|
/// Reasons a tool call from the model could not be turned into a
/// [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallError {
    /// The call carried no tool name.
    EmptyName,
    /// The call carried no arguments.
    EmptyArgs,
    /// No registered tool has the given name.
    NotFound(String),
    /// The arguments could not be parsed; carries a description of the failure.
    ArgParseFailure(String),
}
|
|
|
|
fn dave_tools() -> Vec<Tool> {
|
|
vec![search_tool()]
|
|
}
|