feat: ollama tool shim (#1448)
Co-authored-by: Alice Hau <ahau@squareup.com>
@@ -296,8 +296,16 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
     let spin = spinner();
     spin.start("Checking your configuration...");
 
-    // Use max tokens to speed up the provider test.
-    let model_config = goose::model::ModelConfig::new(model.clone()).with_max_tokens(Some(50));
+    // Create model config with env var settings
+    let model_config = goose::model::ModelConfig::new(model.clone())
+        .with_max_tokens(Some(50))
+        .with_toolshim(
+            std::env::var("GOOSE_TOOLSHIM")
+                .map(|val| val == "1" || val.to_lowercase() == "true")
+                .unwrap_or(false),
+        )
+        .with_toolshim_model(std::env::var("GOOSE_TOOLSHIM_OLLAMA_MODEL").ok());
 
     let provider = create(provider_name, model_config)?;
 
     let messages =
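The GOOSE_TOOLSHIM gate above treats "1" or any casing of "true" as on; anything else, including an unset variable, leaves the shim off. The same truthiness check in isolation (the helper name is illustrative, not part of this commit):

    /// Illustrative helper mirroring the env check used in the diff above.
    fn env_flag(name: &str) -> bool {
        std::env::var(name)
            .map(|val| val == "1" || val.to_lowercase() == "true")
            .unwrap_or(false)
    }

    // With GOOSE_TOOLSHIM=TRUE -> true; =0 -> false; unset -> false.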
@@ -19,6 +19,9 @@ use crate::message::{Message, ToolRequest};
 use crate::providers::base::Provider;
 use crate::providers::base::ProviderUsage;
 use crate::providers::errors::ProviderError;
+use crate::providers::toolshim::{
+    augment_message_with_tool_calls, modify_system_prompt_for_tool_json, OllamaInterpreter,
+};
 use crate::register_agent;
 use crate::session;
 use crate::token_counter::TokenCounter;
@@ -217,7 +220,17 @@ impl Agent for TruncateAgent {
             tools.push(list_resources_tool);
         }
 
-        let system_prompt = capabilities.get_system_prompt().await;
+        let config = capabilities.provider().get_model_config();
+        let mut system_prompt = capabilities.get_system_prompt().await;
+        let mut toolshim_tools = vec![];
+        if config.toolshim {
+            // If tool interpretation is enabled, modify the system prompt to instruct to return JSON tool requests
+            system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
+            // make a copy of tools before empty
+            toolshim_tools = tools.clone();
+            // pass empty tools vector to provider completion since toolshim will handle tool calls instead
+            tools = vec![];
+        }
 
         // Set the user_message field in the span instead of creating a new event
        if let Some(content) = messages
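Condensed, the agent-side handoff works like this: when config.toolshim is set, the provider never sees the tool definitions; they are folded into the system prompt instead, and a copy is kept for post-processing in the next hunk. A hedged sketch of that control flow with stand-in types (the real code uses goose's Tool type and modify_system_prompt_for_tool_json):

    #[derive(Clone)]
    struct Tool { name: String }

    fn prepare_for_provider(
        toolshim: bool,
        system_prompt: &mut String,
        tools: &mut Vec<Tool>,
    ) -> Vec<Tool> {
        let mut toolshim_tools = vec![];
        if toolshim {
            // Stand-in for modify_system_prompt_for_tool_json: describe the
            // tools inside the prompt so the model can emit JSON tool requests.
            for t in tools.iter() {
                system_prompt.push_str(&format!("\nTool Name: {}", t.name));
            }
            // Keep a copy for the interpreter; hand the provider an empty list.
            toolshim_tools = tools.clone();
            tools.clear();
        }
        toolshim_tools
    }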
@@ -236,7 +249,15 @@ impl Agent for TruncateAgent {
                 &messages,
                 &tools,
             ).await {
-                Ok((response, usage)) => {
+                Ok((mut response, usage)) => {
+                    // Post-process / structure the response only if tool interpretation is enabled
+                    if config.toolshim {
+                        let interpreter = OllamaInterpreter::new()
+                            .map_err(|e| anyhow::anyhow!("Failed to create OllamaInterpreter: {}", e))?;
+
+                        response = augment_message_with_tool_calls(&interpreter, response, &toolshim_tools).await?;
+                    }
+
                     capabilities.record_usage(usage.clone()).await;
 
                     // record usage for the session in the session file
@@ -21,6 +21,10 @@ pub struct ModelConfig {
     pub temperature: Option<f32>,
     /// Optional maximum tokens to generate
     pub max_tokens: Option<i32>,
+    /// Whether to interpret tool calls with toolshim
+    pub toolshim: bool,
+    /// Model to use for toolshim (optional as a default exists)
+    pub toolshim_model: Option<String>,
 }
 
 impl ModelConfig {
@@ -34,12 +38,20 @@ impl ModelConfig {
         let context_limit = Self::get_model_specific_limit(&model_name);
         let tokenizer_name = Self::infer_tokenizer_name(&model_name);
 
+        let toolshim = std::env::var("GOOSE_TOOLSHIM")
+            .map(|val| val == "1" || val.to_lowercase() == "true")
+            .unwrap_or(false);
+
+        let toolshim_model = std::env::var("GOOSE_TOOLSHIM_OLLAMA_MODEL").ok();
+
         Self {
             model_name,
             tokenizer_name: tokenizer_name.to_string(),
             context_limit,
             temperature: None,
             max_tokens: None,
+            toolshim,
+            toolshim_model,
         }
     }
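Because ModelConfig::new reads both variables itself, a freshly constructed config already reflects the process environment before any builder call runs. For instance (a hypothetical snippet, not from the commit):

    std::env::set_var("GOOSE_TOOLSHIM", "1");
    std::env::set_var("GOOSE_TOOLSHIM_OLLAMA_MODEL", "mistral-nemo");
    let config = goose::model::ModelConfig::new("llama3.2".to_string());
    assert!(config.toolshim);
    assert_eq!(config.toolshim_model, Some("mistral-nemo".to_string()));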
@@ -96,7 +108,19 @@ impl ModelConfig {
         self
     }
 
-    // Get the tokenizer name
+    /// Set whether to interpret tool calls
+    pub fn with_toolshim(mut self, toolshim: bool) -> Self {
+        self.toolshim = toolshim;
+        self
+    }
+
+    /// Set the tool call interpreter model
+    pub fn with_toolshim_model(mut self, model: Option<String>) -> Self {
+        self.toolshim_model = model;
+        self
+    }
+
+    /// Get the tokenizer name
     pub fn tokenizer_name(&self) -> &str {
         &self.tokenizer_name
     }
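Note the ordering these setters imply: ModelConfig::new seeds toolshim from the environment, and with_toolshim then overwrites that seed, so an explicit builder call wins regardless of GOOSE_TOOLSHIM. A short illustration (hypothetical, not from the commit):

    std::env::set_var("GOOSE_TOOLSHIM", "true");
    let config = goose::model::ModelConfig::new("llama3.2".to_string())
        .with_toolshim(false); // builder call overrides the env-derived default
    assert!(!config.toolshim);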
@@ -142,4 +166,20 @@ mod tests {
         assert_eq!(config.max_tokens, Some(1000));
         assert_eq!(config.context_limit, Some(50_000));
     }
+
+    #[test]
+    fn test_model_config_tool_interpretation() {
+        // Test without env vars - should be false
+        let config = ModelConfig::new("test-model".to_string());
+        assert!(!config.toolshim);
+
+        // Test with tool interpretation setting
+        let config = ModelConfig::new("test-model".to_string()).with_toolshim(true);
+        assert!(config.toolshim);
+
+        // Test tool interpreter model
+        let config = ModelConfig::new("test-model".to_string())
+            .with_toolshim_model(Some("mistral-nemo".to_string()));
+        assert_eq!(config.toolshim_model, Some("mistral-nemo".to_string()));
+    }
 }
@@ -826,6 +826,8 @@ mod tests {
             context_limit: Some(4096),
             temperature: None,
             max_tokens: Some(1024),
+            toolshim: false,
+            toolshim_model: None,
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
@@ -856,6 +858,8 @@ mod tests {
             context_limit: Some(4096),
             temperature: None,
             max_tokens: Some(1024),
+            toolshim: false,
+            toolshim_model: None,
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
@@ -887,6 +891,8 @@ mod tests {
             context_limit: Some(4096),
             temperature: None,
             max_tokens: Some(1024),
+            toolshim: false,
+            toolshim_model: None,
         };
         let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
         let obj = request.as_object().unwrap();
@@ -14,6 +14,7 @@ pub mod oauth;
 pub mod ollama;
 pub mod openai;
 pub mod openrouter;
+pub mod toolshim;
 pub mod utils;
 
 pub use factory::{create, providers};
@@ -6,7 +6,6 @@ use crate::model::ModelConfig;
 use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
 use anyhow::Result;
 use async_trait::async_trait;
-use indoc::formatdoc;
 use mcp_core::tool::Tool;
 use reqwest::Client;
 use serde_json::Value;
@@ -53,8 +52,8 @@ impl OllamaProvider {
         })
     }
 
-    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        // TODO: remove this later when the UI handles provider config refresh
+    /// Get the base URL for Ollama API calls
+    fn get_base_url(&self) -> Result<Url, ProviderError> {
         // OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme
         let base = if self.host.starts_with("http://") || self.host.starts_with("https://") {
             self.host.clone()
@@ -73,6 +72,13 @@ impl OllamaProvider {
             })?;
         }
 
+        Ok(base_url)
+    }
+
+    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
+        // TODO: remove this later when the UI handles provider config refresh
+        let base_url = self.get_base_url()?;
+
         let url = base_url.join("v1/chat/completions").map_err(|e| {
             ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
         })?;
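get_base_url tolerates OLLAMA_HOST values given as a bare host, host:port, or full URL; OllamaInterpreter::get_ollama_base_url in the new toolshim module below applies the same normalization. A sketch of the scheme rule, assuming Ollama's conventional default port 11434:

    fn normalize(host: &str) -> String {
        // Prefix a scheme only when one is missing, as in the diff above.
        if host.starts_with("http://") || host.starts_with("https://") {
            host.to_string()
        } else {
            format!("http://{}", host)
        }
    }

    // normalize("localhost:11434")        -> "http://localhost:11434"
    // normalize("https://ollama.example") -> "https://ollama.example"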
@@ -116,86 +122,17 @@ impl Provider for OllamaProvider {
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<(Message, ProviderUsage), ProviderError> {
-        // Transform the system message to replace developer instructions
-        let modified_system = if let Some(dev_section) = system.split("## developer").nth(1) {
-            if let (Some(start_idx), Some(end_idx)) = (
-                dev_section.find("### Instructions"),
-                dev_section.find("operating system:"),
-            ) {
-                let new_instructions = formatdoc! {r#"
-                    The Developer extension enables you to edit code files, execute shell commands, and capture screen/window content. These tools allow for various development and debugging workflows.
-                    Available Tools:
-                    1. Shell Execution (`shell`)
-                    Executes commands in the shell and returns the combined output and error messages.
-                    Use cases:
-                    - Running scripts: `python script.py`
-                    - Installing dependencies: `pip install -r requirements.txt`
-                    - Checking system information: `uname -a`, `df -h`
-                    - Searching for files or text: **Use `rg` (ripgrep) instead of `find` or `ls -r`**
-                    - Find a file: `rg --files | rg example.py`
-                    - Search within files: `rg 'class Example'`
-                    Best Practices:
-                    - **Avoid commands with large output** (pipe them to a file if necessary).
-                    - **Run background processes** if they take a long time (e.g., `uvicorn main:app &`).
-                    - **git commands can be run on the shell, however if the git extension is installed, you should use the git tool instead.
-                    - **If the shell command is a rm, mv, or cp, you should verify with the user before running the command.
-                    2. Text Editor (`text_editor`)
-                    Performs file-based operations such as viewing, writing, replacing text, and undoing edits.
-                    Commands:
-                    - view: Read the content of a file.
-                    - write: Create or overwrite a file. Caution: Overwrites the entire file!
-                    - str_replace: Replace a specific string in a file.
-                    - undo_edit: Revert the last edit.
-                    Example Usage:
-                    text_editor(command="view", file_path="/absolute/path/to/file.py")
-                    text_editor(command="write", file_path="/absolute/path/to/file.py", file_text="print('hello world')")
-                    text_editor(command="str_replace", file_path="/absolute/path/to/file.py", old_str="hello world", new_str="goodbye world")
-                    text_editor(command="undo_edit", file_path="/absolute/path/to/file.py")
-                    Protocol for Text Editor:
-                    For edit and replace commands, please verify what you are editing with the user before running the command.
-                    - User: "Please edit the file /absolute/path/to/file.py"
-                    - Assistant: "Ok sounds good, I'll be editing the file /absolute/path/to/file.py and creating modifications xyz to the file. Let me know whether you'd like to proceed."
-                    - User: "Yes, please proceed."
-                    - Assistant: "I've created the modifications xyz to the file /absolute/path/to/file.py"
-                    3. List Windows (`list_windows`)
-                    Lists all visible windows with their titles.
-                    Use this to find window titles for screen capture.
-                    4. Screen Capture (`screen_capture`)
-                    Takes a screenshot of a display or specific window.
-                    Options:
-                    - Capture display: `screen_capture(display=0)` # Main display
-                    - Capture window: `screen_capture(window_title="Window Title")`
-                    Info: at the start of the session, the user's directory is:
-                "#};
-
-                let before_dev = system.split("## developer").next().unwrap_or("");
-                let after_marker = &dev_section[end_idx..];
-
-                format!(
-                    "{}## developer{}### Instructions\n{}{}",
-                    before_dev,
-                    &dev_section[..start_idx],
-                    new_instructions,
-                    after_marker
-                )
-            } else {
-                system.to_string()
-            }
-        } else {
-            system.to_string()
-        };
-
         let payload = create_request(
             &self.model,
-            &modified_system,
+            system,
             messages,
             tools,
             &super::utils::ImageFormat::OpenAi,
         )?;
-        let response = self.post(payload.clone()).await?;
 
         // Parse response
+        let response = self.post(payload.clone()).await?;
         let message = response_to_message(response.clone())?;
 
         let usage = match get_usage(&response) {
             Ok(usage) => usage,
             Err(ProviderError::UsageError(e)) => {
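For context, the deleted block spliced a hard-coded instruction body between the "### Instructions" and "operating system:" markers inside the prompt's "## developer" section, falling back to the untouched prompt when either marker was absent. A reduced sketch of that splice (hypothetical input, same string operations as the removed code):

    fn splice_dev_instructions(system: &str, new_instructions: &str) -> String {
        if let Some(dev_section) = system.split("## developer").nth(1) {
            if let (Some(start_idx), Some(end_idx)) = (
                dev_section.find("### Instructions"),
                dev_section.find("operating system:"),
            ) {
                let before_dev = system.split("## developer").next().unwrap_or("");
                return format!(
                    "{}## developer{}### Instructions\n{}{}",
                    before_dev,
                    &dev_section[..start_idx],
                    new_instructions,
                    &dev_section[end_idx..],
                );
            }
        }
        system.to_string()
    }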
crates/goose/src/providers/toolshim.rs (new file, 362 lines)
@@ -0,0 +1,362 @@
//! # ToolShim Module
//!
//! The ToolShim module provides a reusable component for interpreting and augmenting LLM outputs with tool calls,
//! regardless of whether the underlying model natively supports tool/function calling.
//!
//! ## Overview
//!
//! ToolShim addresses the challenge of working with models that don't natively support tools by:
//!
//! 1. Taking the text output from any LLM
//! 2. Sending it to a separate "interpreter" model (which can be the same or a different model)
//! 3. Using that model to extract tool call intentions into the appropriate format
//! 4. Converting the outputs of the interpreter model into proper tool call structs
//! 5. Augmenting the original message with the extracted tool calls
//!
//! ## Key Components
//!
//! ### ToolInterpreter Trait
//!
//! The core of ToolShim is the `ToolInterpreter` trait, which defines the interface for any model that can interpret text and extract tool calls.
//!
//! ### Implementations
//!
//! The module provides an implementation for Ollama:
//!
//! - `OllamaInterpreter`: Uses Ollama's structured output API to interpret tool calls
//!
//! ### Helper Functions
//!
//! - `augment_message_with_tool_calls`: A utility function that takes any message, extracts text content, sends it to an interpreter, and adds any detected tool calls back to the message.

use super::errors::ProviderError;
use super::ollama::OLLAMA_DEFAULT_PORT;
use super::ollama::OLLAMA_HOST;
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::formats::openai::create_request;
use anyhow::Result;
use mcp_core::tool::{Tool, ToolCall};
use reqwest::Client;
use serde_json::{json, Value};
use std::time::Duration;
use uuid::Uuid;

/// Default model to use for tool interpretation
pub const DEFAULT_INTERPRETER_MODEL_OLLAMA: &str = "mistral-nemo";

/// Environment variables that affect behavior:
/// - GOOSE_TOOLSHIM: When set to "true" or "1", enables using the tool shim in the standard OllamaProvider (default: false)
/// - GOOSE_TOOLSHIM_OLLAMA_MODEL: Ollama model to use as the tool interpreter (default: DEFAULT_INTERPRETER_MODEL_OLLAMA)
///
/// A trait for models that can interpret text into structured tool call JSON format
#[async_trait::async_trait]
pub trait ToolInterpreter {
    /// Interpret potential tool calls from text and convert them to proper tool call JSON format
    async fn interpret_to_tool_calls(
        &self,
        content: &str,
        tools: &[Tool],
    ) -> Result<Vec<ToolCall>, ProviderError>;
}

/// Ollama-specific implementation of the ToolInterpreter trait
pub struct OllamaInterpreter {
    client: Client,
    base_url: String,
}

impl OllamaInterpreter {
    pub fn new() -> Result<Self, ProviderError> {
        let client = Client::builder()
            .timeout(Duration::from_secs(600))
            .build()
            .expect("Failed to create HTTP client");

        let base_url = Self::get_ollama_base_url()?;

        Ok(Self { client, base_url })
    }

    /// Get the Ollama base URL from existing config or use default values
    fn get_ollama_base_url() -> Result<String, ProviderError> {
        let config = crate::config::Config::global();
        let host: String = config
            .get_param("OLLAMA_HOST")
            .unwrap_or_else(|_| OLLAMA_HOST.to_string());

        // Format the URL correctly with http:// prefix if needed
        let base = if host.starts_with("http://") || host.starts_with("https://") {
            host.clone()
        } else {
            format!("http://{}", host)
        };

        let mut base_url = url::Url::parse(&base)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;

        // Set the default port if missing
        let explicit_default_port = host.ends_with(":80") || host.ends_with(":443");
        if base_url.port().is_none() && !explicit_default_port {
            base_url.set_port(Some(OLLAMA_DEFAULT_PORT)).map_err(|_| {
                ProviderError::RequestFailed("Failed to set default port".to_string())
            })?;
        }

        Ok(base_url.to_string())
    }

    fn tool_structured_output_format_schema() -> Value {
        json!({
            "type": "object",
            "properties": {
                "tool_calls": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": {
                                "type": "string",
                                "description": "The name of the tool to call"
                            },
                            "arguments": {
                                "type": "object",
                                "description": "The arguments to pass to the tool"
                            }
                        },
                        "required": ["name", "arguments"]
                    }
                }
            },
            "required": ["tool_calls"]
        })
    }

    async fn post_structured(
        &self,
        system_prompt: &str,
        format_instruction: &str,
        format_schema: Value,
        model: &str,
    ) -> Result<Value, ProviderError> {
        let base_url = self.base_url.trim_end_matches('/');
        let url = format!("{}/api/chat", base_url);

        let mut messages = Vec::new();
        let user_message = Message::user().with_text(format_instruction);
        messages.push(user_message);

        let model_config = ModelConfig::new(model.to_string());

        let mut payload = create_request(
            &model_config,
            system_prompt,
            &messages,
            &[], // No tools
            &super::utils::ImageFormat::OpenAi,
        )?;

        payload["stream"] = json!(false); // needed for the /api/chat endpoint to work
        payload["format"] = format_schema;

        // tracing::warn!("payload: {}", serde_json::to_string_pretty(&payload).unwrap_or_default());

        let response = self.client.post(&url).json(&payload).send().await?;

        if !response.status().is_success() {
            let status = response.status();

            let error_text = match response.text().await {
                Ok(text) => text,
                Err(_) => "Could not read error response".to_string(),
            };

            return Err(ProviderError::RequestFailed(format!(
                "Ollama structured API returned error status {}: {}",
                status, error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            ProviderError::RequestFailed(format!(
                "Failed to parse Ollama structured API response: {e}"
            ))
        })?;

        Ok(response_json)
    }
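    // Illustrative shape of the interpreter response handled below (an
    // assumption based on the parsing code, not an excerpt from the commit;
    // Ollama's /api/chat returns the structured content as a JSON string):
    //
    // {
    //   "message": {
    //     "content": "{\"tool_calls\": [{\"name\": \"shell\", \"arguments\": {\"command\": \"ls\"}}]}"
    //   }
    // }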
    fn process_interpreter_response(response: &Value) -> Result<Vec<ToolCall>, ProviderError> {
        let mut tool_calls = Vec::new();

        // Extract tool_calls array from the response
        if response.get("message").is_some() && response["message"].get("content").is_some() {
            let content = response["message"]["content"].as_str().unwrap_or_default();

            // Try to parse the content as JSON
            if let Ok(content_json) = serde_json::from_str::<Value>(content) {
                // Check for the format with tool_calls array inside an object
                if content_json.is_object() && content_json.get("tool_calls").is_some() {
                    // Process each tool call in the array
                    if let Some(tool_calls_array) = content_json["tool_calls"].as_array() {
                        for item in tool_calls_array {
                            if item.is_object()
                                && item.get("name").is_some()
                                && item.get("arguments").is_some()
                            {
                                // Create ToolCall directly from the JSON data
                                let name = item["name"].as_str().unwrap_or_default().to_string();
                                let arguments = item["arguments"].clone();

                                // Add the tool call to our result vector
                                tool_calls.push(ToolCall::new(name, arguments));
                            }
                        }
                    }
                }
            }
        }

        Ok(tool_calls)
    }
}

#[async_trait::async_trait]
impl ToolInterpreter for OllamaInterpreter {
    async fn interpret_to_tool_calls(
        &self,
        last_assistant_msg: &str,
        tools: &[Tool],
    ) -> Result<Vec<ToolCall>, ProviderError> {
        if tools.is_empty() {
            return Ok(vec![]);
        }

        // Create the system prompt
        let system_prompt = "Rewrite JSON-formatted tool requests into valid JSON tool calls in the following format.

Always respond with the following tool_calls array format:
{
  \"tool_calls\": [
    {
      \"name\": \"tool_name\",
      \"arguments\": {
        \"param1\": \"value1\",
        \"param2\": \"value2\"
      }
    }
  ]
}

You should return an empty tool_calls array if no tools are explicitly referenced:
{
  \"tool_calls\": []
}
";

        // Create enhanced content with instruction to output tool calls as JSON
        let format_instruction = format!(
            "{}\n\nWrite valid json if there is detectable json or an attempt at json",
            last_assistant_msg
        );

        // Define the JSON schema for tool call format
        let format_schema = OllamaInterpreter::tool_structured_output_format_schema();

        // Determine which model to use for interpretation (from env var or default)
        let interpreter_model = std::env::var("GOOSE_TOOLSHIM_OLLAMA_MODEL")
            .unwrap_or_else(|_| DEFAULT_INTERPRETER_MODEL_OLLAMA.to_string());

        // Make a call to ollama with structured output
        let interpreter_response = self
            .post_structured(
                system_prompt,
                &format_instruction,
                format_schema,
                &interpreter_model,
            )
            .await?;

        // Process the interpreter response to get tool calls directly
        let tool_calls = OllamaInterpreter::process_interpreter_response(&interpreter_response)?;

        Ok(tool_calls)
    }
}

/// Creates a string containing formatted tool information
pub fn format_tool_info(tools: &[Tool]) -> String {
    let mut tool_info = String::new();
    for tool in tools {
        tool_info.push_str(&format!(
            "Tool Name: {}\nSchema: {}\nDescription: {}\n\n",
            tool.name,
            serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default(),
            tool.description
        ));
    }
    tool_info
}

/// Modifies the system prompt to include tool usage instructions when tool interpretation is enabled
pub fn modify_system_prompt_for_tool_json(system_prompt: &str, tools: &[Tool]) -> String {
    let tool_info = format_tool_info(tools);
    format!(
        "{}\n\n{}\n\nBreak down your task into smaller steps and do one step and tool call at a time. Do not try to use multiple tools at once. If you want to use a tool, tell the user what tool to use by specifying the tool in this JSON format\n{{\n \"name\": \"tool_name\",\n \"arguments\": {{\n \"parameter1\": \"value1\",\n \"parameter2\": \"value2\"\n }}\n}}. After you get the tool result back, consider the result and then proceed to do the next step and tool call if required.",
        system_prompt,
        tool_info
    )
}

/// Helper function to augment a message with tool calls if any are detected
pub async fn augment_message_with_tool_calls<T: ToolInterpreter>(
    interpreter: &T,
    message: Message,
    tools: &[Tool],
) -> Result<Message, ProviderError> {
    // If there are no tools, return the original message
    if tools.is_empty() {
        return Ok(message);
    }

    // Extract text content from the message
    let content_opt = message.content.iter().find_map(|content| {
        if let MessageContent::Text(text) = content {
            Some(text.text.as_str())
        } else {
            None
        }
    });

    // If there's no text content, return the original message
    let content = match content_opt {
        Some(text) => text,
        None => return Ok(message),
    };

    // If there's already a tool request, return the original message
    if message
        .content
        .iter()
        .any(|content| matches!(content, MessageContent::ToolRequest(_)))
    {
        return Ok(message);
    }

    // Use the interpreter to convert the content to tool calls
    let tool_calls = interpreter.interpret_to_tool_calls(content, tools).await?;

    // If no tool calls were detected, return the original message
    if tool_calls.is_empty() {
        return Ok(message);
    }

    // Add each tool call to the message
    let mut final_message = message;
    for tool_call in tool_calls {
        let id = Uuid::new_v4().to_string();
        final_message = final_message.with_tool_request(id, Ok(tool_call));
    }

    Ok(final_message)
}
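Taken together, a caller holding a plain-text assistant reply would drive the shim roughly as follows (a sketch, not from the commit; it assumes a reachable Ollama instance and goose's Message and Tool types in scope, and mirrors the error mapping used in the agent hunk above):

    use goose::message::Message;
    use goose::providers::toolshim::{augment_message_with_tool_calls, OllamaInterpreter};
    use mcp_core::tool::Tool;

    async fn interpret_reply(reply: Message, tools: &[Tool]) -> anyhow::Result<Message> {
        let interpreter = OllamaInterpreter::new()
            .map_err(|e| anyhow::anyhow!("Failed to create OllamaInterpreter: {}", e))?;
        // Sends the reply's text to the interpreter model and appends any
        // detected tool calls to the message as ToolRequest content.
        let augmented = augment_message_with_tool_calls(&interpreter, reply, tools)
            .await
            .map_err(|e| anyhow::anyhow!("Tool interpretation failed: {}", e))?;
        Ok(augmented)
    }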