mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
Merge branch 'block:main' into goose-api
This commit is contained in:
@@ -25,6 +25,7 @@ include_dir = "0.7.4"
|
||||
once_cell = "1.19"
|
||||
regex = "1.11.1"
|
||||
toml = "0.8.20"
|
||||
dotenvy = "0.15.7"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
winapi = { version = "0.3", features = ["wincred"] }
|
||||
|
||||
273
crates/goose-bench/README.md
Normal file
273
crates/goose-bench/README.md
Normal file
@@ -0,0 +1,273 @@
|
||||
# Goose Benchmarking Framework
|
||||
|
||||
The `goose-bench` crate provides a framework for benchmarking and evaluating LLM models with the Goose framework. This tool helps quantify model performance across various tasks and generate structured reports.
|
||||
|
||||
## Features
|
||||
|
||||
- Run benchmark suites across multiple LLM models
|
||||
- Execute evaluations in parallel when supported
|
||||
- Generate structured JSON and CSV reports
|
||||
- Process evaluation results with custom scripts
|
||||
- Calculate aggregate metrics across evaluations
|
||||
- Support for tool-shim evaluation
|
||||
- Generate leaderboards and comparative metrics
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Python Environment**: The `generate-leaderboard` command executes Python scripts and requires a valid Python environment with necessary dependencies (pandas, etc.)
|
||||
- **OpenAI API Key**: For evaluations using LLM-as-judge (like `blog_summary` and `restaurant_research`), you must have an `OPENAI_API_KEY` environment variable set, as the judge uses the OpenAI GPT-4o model
|
||||
|
||||
## Benchmark Workflow
|
||||
|
||||
Running benchmarks is a two-step process:
|
||||
|
||||
### Step 1: Run Benchmarks
|
||||
|
||||
First, run the benchmark evaluations with your configuration:
|
||||
|
||||
```bash
|
||||
goose bench run --config /path/to/your-config.json
|
||||
```
|
||||
|
||||
This will execute all evaluations for all models specified in your configuration and create a benchmark directory with results.
|
||||
|
||||
### Step 2: Generate Leaderboard
|
||||
|
||||
After the benchmarks complete, generate the leaderboard and aggregated metrics:
|
||||
|
||||
```bash
|
||||
goose bench generate-leaderboard --benchmark-dir /path/to/benchmark-output-directory
|
||||
```
|
||||
|
||||
The benchmark directory path will be shown in the output of the previous command, typically in the format `benchmark-YYYY-MM-DD-HH:MM:SS`.
|
||||
|
||||
**Note**: This command requires a valid Python environment as it executes Python scripts for data aggregation and leaderboard generation.
|
||||
|
||||
## Configuration
|
||||
|
||||
Benchmark configuration is provided through a JSON file. Here's a sample configuration file (leaderboard-config.json) that you can use as a template:
|
||||
|
||||
```json
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"provider": "databricks",
|
||||
"name": "gpt-4-1-mini",
|
||||
"parallel_safe": true,
|
||||
"tool_shim": {
|
||||
"use_tool_shim": false,
|
||||
"tool_shim_model": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"provider": "databricks",
|
||||
"name": "claude-3-5-sonnet",
|
||||
"parallel_safe": true,
|
||||
"tool_shim": null
|
||||
},
|
||||
{
|
||||
"provider": "databricks",
|
||||
"name": "gpt-4o",
|
||||
"parallel_safe": true,
|
||||
"tool_shim": null
|
||||
}
|
||||
],
|
||||
"evals": [
|
||||
{
|
||||
"selector": "core:developer",
|
||||
"post_process_cmd": null,
|
||||
"parallel_safe": true
|
||||
},
|
||||
{
|
||||
"selector": "core:developer_search_replace",
|
||||
"post_process_cmd": null,
|
||||
"parallel_safe": true
|
||||
},
|
||||
{
|
||||
"selector": "vibes:blog_summary",
|
||||
"post_process_cmd": "/Users/ahau/Development/goose-1.0/goose/scripts/bench-postprocess-scripts/llm-judges/run_vibes_judge.sh",
|
||||
"parallel_safe": true
|
||||
},
|
||||
{
|
||||
"selector": "vibes:restaurant_research",
|
||||
"post_process_cmd": "/Users/ahau/Development/goose-1.0/goose/scripts/bench-postprocess-scripts/llm-judges/run_vibes_judge.sh",
|
||||
"parallel_safe": true
|
||||
}
|
||||
],
|
||||
"include_dirs": [],
|
||||
"repeat": 3,
|
||||
"run_id": null,
|
||||
"output_dir": "/path/to/output/directory",
|
||||
"eval_result_filename": "eval-results.json",
|
||||
"run_summary_filename": "run-results-summary.json",
|
||||
"env_file": "/path/to/.goosebench.env"
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Models
|
||||
|
||||
- `provider`: The LLM provider (e.g., "databricks", "openai")
|
||||
- `name`: The model name
|
||||
- `parallel_safe`: Whether the model can be run in parallel
|
||||
- `tool_shim`: Configuration for tool-shim support
|
||||
- `use_tool_shim`: Whether to use tool-shim
|
||||
- `tool_shim_model`: Optional custom model for tool-shim
|
||||
|
||||
### Evaluations
|
||||
|
||||
- `selector`: The evaluation selector in format `suite:evaluation`
|
||||
- `post_process_cmd`: Optional path to a post-processing script
|
||||
- `parallel_safe`: Whether the evaluation can be run in parallel
|
||||
|
||||
### Global Configuration
|
||||
|
||||
- `include_dirs`: Additional directories to include in the benchmark environment
|
||||
- `repeat`: Number of times to repeat evaluations (for statistical significance)
|
||||
- `run_id`: Optional identifier for the run (defaults to timestamp)
|
||||
- `output_dir`: Directory to store benchmark results (must be absolute path)
|
||||
- `eval_result_filename`: Filename for individual evaluation results
|
||||
- `run_summary_filename`: Filename for run summary
|
||||
- `env_file`: Optional path to environment variables file
|
||||
|
||||
## Environment Variables
|
||||
|
||||
You can provide environment variables through the `env_file` configuration option. This is useful for provider API keys and other sensitive information. Example `.goosebench.env` file:
|
||||
|
||||
```bash
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
DATABRICKS_TOKEN=your_databricks_token_here
|
||||
# Add other environment variables as needed
|
||||
```
|
||||
|
||||
**Important**: For evaluations that use LLM-as-judge (like `blog_summary` and `restaurant_research`), you must set `OPENAI_API_KEY` as the judging system uses OpenAI's GPT-4o model.
|
||||
|
||||
## Post-Processing
|
||||
|
||||
You can specify post-processing commands for evaluations, which will be executed after each evaluation completes. The command receives the path to the evaluation results file as its first argument.
|
||||
|
||||
For example, the `run_vibes_judge.sh` script processes outputs from the `blog_summary` and `restaurant_research` evaluations, using LLM-based judging to assign scores.
|
||||
|
||||
## Output Structure
|
||||
|
||||
Results are organized in a directory structure that follows this pattern:
|
||||
|
||||
```
|
||||
{benchmark_dir}/
|
||||
├── config.cfg # Configuration used for the benchmark
|
||||
├── {provider}-{model}/
|
||||
│ ├── eval-results/
|
||||
│ │ └── aggregate_metrics.csv # Aggregated metrics for this model
|
||||
│ └── run-{run_id}/
|
||||
│ ├── {suite}/
|
||||
│ │ └── {evaluation}/
|
||||
│ │ ├── eval-results.json # Individual evaluation results
|
||||
│ │ ├── {eval_name}.jsonl # Session logs
|
||||
│ │ └── work_dir.json # Info about evaluation working dir
|
||||
│ └── run-results-summary.json # Summary of all evaluations in this run
|
||||
├── leaderboard.csv # Final leaderboard comparing all models
|
||||
└── all_metrics.csv # Union of all metrics across all models
|
||||
```
|
||||
|
||||
### Output Files Explained
|
||||
|
||||
#### Per-Model Files
|
||||
|
||||
- **`eval-results/aggregate_metrics.csv`**: Contains aggregated metrics for each evaluation, averaged across all runs. Includes metrics like `score_mean`, `total_tokens_mean`, `prompt_execution_time_seconds_mean`, etc.
|
||||
|
||||
#### Global Output Files
|
||||
|
||||
- **`leaderboard.csv`**: Final leaderboard ranking all models by their average performance across evaluations. Contains columns like:
|
||||
- `provider`, `model_name`: Model identification
|
||||
- `avg_score_mean`: Average score across all evaluations
|
||||
- `avg_prompt_execution_time_seconds_mean`: Average execution time
|
||||
- `avg_total_tool_calls_mean`: Average number of tool calls
|
||||
- `avg_total_tokens_mean`: Average token usage
|
||||
|
||||
- **`all_metrics.csv`**: Comprehensive dataset containing detailed metrics for every model-evaluation combination. This is a union of all individual model metrics, useful for detailed analysis and custom reporting.
|
||||
|
||||
Each model gets its own directory, containing run results and aggregated CSV files for analysis. The `generate-leaderboard` command processes all individual evaluation results and creates the comparative metrics files.
|
||||
|
||||
## Error Handling and Troubleshooting
|
||||
|
||||
**Important**: The current version of goose-bench does not have robust error handling for common issues that can occur during evaluation runs, such as:
|
||||
|
||||
- Rate limiting from inference providers
|
||||
- Network timeouts or connection errors
|
||||
- Provider API errors that cause early session termination
|
||||
- Resource exhaustion or memory issues
|
||||
|
||||
### Checking for Failed Evaluations
|
||||
|
||||
After running benchmarks, you should inspect the generated metrics files to identify any evaluations that may have failed or terminated early:
|
||||
|
||||
1. **Check the `aggregate_metrics.csv` files** in each model's `eval-results/` directory for:
|
||||
- Missing evaluations (fewer rows than expected)
|
||||
- Unusually low scores or metrics
|
||||
- Zero or near-zero execution times
|
||||
- Missing or NaN values
|
||||
|
||||
2. **Look for `server_error_mean` column** in the aggregate metrics - values greater than 0 indicate server errors occurred during evaluation
|
||||
|
||||
3. **Review session logs** (`.jsonl` files) in individual evaluation directories for error messages like:
|
||||
- "Server error"
|
||||
- "Rate limit exceeded"
|
||||
- "TEMPORARILY_UNAVAILABLE"
|
||||
- Unexpected session terminations
|
||||
|
||||
### Re-running Failed Evaluations
|
||||
|
||||
If you identify failed evaluations, you may need to:
|
||||
|
||||
1. **Adjust rate limiting**: Add delays between requests or reduce parallel execution
|
||||
2. **Update environment variables**: Ensure API keys and tokens are valid
|
||||
3. **Re-run specific model/evaluation combinations**: Create a new config with only the failed combinations
|
||||
4. **Check provider status**: Verify the inference provider is operational
|
||||
|
||||
Example of creating a config to re-run failed evaluations:
|
||||
|
||||
```json
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"provider": "databricks",
|
||||
"name": "claude-3-5-sonnet",
|
||||
"parallel_safe": false
|
||||
}
|
||||
],
|
||||
"evals": [
|
||||
{
|
||||
"selector": "vibes:blog_summary",
|
||||
"post_process_cmd": "/path/to/scripts/bench-postprocess-scripts/llm-judges/run_vibes_judge.sh",
|
||||
"parallel_safe": false
|
||||
}
|
||||
],
|
||||
"repeat": 1,
|
||||
"output_dir": "/path/to/retry-benchmark"
|
||||
}
|
||||
```
|
||||
|
||||
We recommend monitoring evaluation progress and checking for errors regularly, especially when running large benchmark suites across multiple models.
|
||||
|
||||
## Available Commands
|
||||
|
||||
### List Evaluations
|
||||
```bash
|
||||
goose bench selectors --config /path/to/config.json
|
||||
```
|
||||
|
||||
### Generate Initial Config
|
||||
```bash
|
||||
goose bench init-config --name my-benchmark-config.json
|
||||
```
|
||||
|
||||
### Run Benchmarks
|
||||
```bash
|
||||
goose bench run --config /path/to/config.json
|
||||
```
|
||||
|
||||
### Generate Leaderboard
|
||||
```bash
|
||||
goose bench generate-leaderboard --benchmark-dir /path/to/benchmark-output
|
||||
```
|
||||
@@ -30,6 +30,7 @@ pub struct BenchRunConfig {
|
||||
pub include_dirs: Vec<PathBuf>,
|
||||
pub repeat: Option<usize>,
|
||||
pub run_id: Option<String>,
|
||||
pub output_dir: Option<PathBuf>,
|
||||
pub eval_result_filename: String,
|
||||
pub run_summary_filename: String,
|
||||
pub env_file: Option<PathBuf>,
|
||||
@@ -63,6 +64,7 @@ impl Default for BenchRunConfig {
|
||||
include_dirs: vec![],
|
||||
repeat: Some(2),
|
||||
run_id: None,
|
||||
output_dir: None,
|
||||
eval_result_filename: "eval-results.json".to_string(),
|
||||
run_summary_filename: "run-results-summary.json".to_string(),
|
||||
env_file: None,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use anyhow::Context;
|
||||
use chrono::Local;
|
||||
use include_dir::{include_dir, Dir};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
@@ -53,15 +53,35 @@ impl BenchmarkWorkDir {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_experiment() {
|
||||
pub fn init_experiment(output_dir: PathBuf) -> anyhow::Result<()> {
|
||||
if !output_dir.is_absolute() {
|
||||
anyhow::bail!(
|
||||
"Internal Error: init_experiment received a non-absolute path: {}",
|
||||
output_dir.display()
|
||||
);
|
||||
}
|
||||
|
||||
// create experiment folder
|
||||
let current_time = Local::now().format("%H:%M:%S").to_string();
|
||||
let current_date = Local::now().format("%Y-%m-%d").to_string();
|
||||
let exp_name = format!("{}-{}", ¤t_date, current_time);
|
||||
let base_path = PathBuf::from(format!("./benchmark-{}", exp_name));
|
||||
fs::create_dir_all(&base_path).unwrap();
|
||||
std::env::set_current_dir(&base_path).unwrap();
|
||||
let exp_folder_name = format!("benchmark-{}-{}", ¤t_date, ¤t_time);
|
||||
let base_path = output_dir.join(exp_folder_name);
|
||||
|
||||
fs::create_dir_all(&base_path).with_context(|| {
|
||||
format!(
|
||||
"Failed to create benchmark directory: {}",
|
||||
base_path.display()
|
||||
)
|
||||
})?;
|
||||
std::env::set_current_dir(&base_path).with_context(|| {
|
||||
format!(
|
||||
"Failed to change working directory to: {}",
|
||||
base_path.display()
|
||||
)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn canonical_dirs(include_dirs: Vec<PathBuf>) -> Vec<PathBuf> {
|
||||
include_dirs
|
||||
.iter()
|
||||
@@ -186,7 +206,7 @@ impl BenchmarkWorkDir {
|
||||
Ok(())
|
||||
} else {
|
||||
let error_message = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
Err(io::Error::new(ErrorKind::Other, error_message))
|
||||
Err(io::Error::other(error_message))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,17 +3,29 @@ use crate::bench_work_dir::BenchmarkWorkDir;
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
pub type Model = (String, String);
|
||||
pub type Extension = String;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub enum EvalMetricValue {
|
||||
Integer(i64),
|
||||
Float(f64),
|
||||
String(String),
|
||||
Boolean(bool),
|
||||
}
|
||||
|
||||
impl fmt::Display for EvalMetricValue {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
EvalMetricValue::Integer(i) => write!(f, "{}", i),
|
||||
EvalMetricValue::Float(fl) => write!(f, "{:.2}", fl),
|
||||
EvalMetricValue::String(s) => write!(f, "{}", s),
|
||||
EvalMetricValue::Boolean(b) => write!(f, "{}", b),
|
||||
}
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EvalMetric {
|
||||
pub name: String,
|
||||
|
||||
@@ -98,17 +98,6 @@ impl BenchmarkResults {
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for EvalMetricValue {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
EvalMetricValue::Integer(i) => write!(f, "{}", i),
|
||||
EvalMetricValue::Float(fl) => write!(f, "{:.2}", fl),
|
||||
EvalMetricValue::String(s) => write!(f, "{}", s),
|
||||
EvalMetricValue::Boolean(b) => write!(f, "{}", b),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for BenchmarkResults {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(f, "Benchmark Results")?;
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::bench_work_dir::BenchmarkWorkDir;
|
||||
use crate::eval_suites::EvaluationSuite;
|
||||
use crate::runners::model_runner::ModelRunner;
|
||||
use crate::utilities::{await_process_exits, parallel_bench_cmd};
|
||||
use anyhow::Context;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -11,9 +12,27 @@ pub struct BenchRunner {
|
||||
}
|
||||
|
||||
impl BenchRunner {
|
||||
pub fn new(config: PathBuf) -> anyhow::Result<BenchRunner> {
|
||||
let config = BenchRunConfig::from(config)?;
|
||||
BenchmarkWorkDir::init_experiment();
|
||||
pub fn new(config_path: PathBuf) -> anyhow::Result<BenchRunner> {
|
||||
let config = BenchRunConfig::from(config_path.clone())?;
|
||||
|
||||
let resolved_output_dir = match &config.output_dir {
|
||||
Some(path) => {
|
||||
if !path.is_absolute() {
|
||||
anyhow::bail!(
|
||||
"Config Error in '{}': 'output_dir' must be an absolute path, but found relative path: {}",
|
||||
config_path.display(),
|
||||
path.display()
|
||||
);
|
||||
}
|
||||
path.clone()
|
||||
}
|
||||
None => std::env::current_dir().context(
|
||||
"Failed to get current working directory to use as default output directory",
|
||||
)?,
|
||||
};
|
||||
|
||||
BenchmarkWorkDir::init_experiment(resolved_output_dir)?;
|
||||
|
||||
config.save("config.cfg".to_string());
|
||||
Ok(BenchRunner { config })
|
||||
}
|
||||
|
||||
@@ -4,12 +4,14 @@ use crate::bench_work_dir::BenchmarkWorkDir;
|
||||
use crate::eval_suites::{EvaluationSuite, ExtensionRequirements};
|
||||
use crate::reporting::EvaluationResult;
|
||||
use crate::utilities::await_process_exits;
|
||||
use anyhow::{bail, Context, Result};
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use tracing;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EvalRunner {
|
||||
@@ -17,13 +19,17 @@ pub struct EvalRunner {
|
||||
}
|
||||
|
||||
impl EvalRunner {
|
||||
pub fn from(config: String) -> anyhow::Result<EvalRunner> {
|
||||
let config = BenchRunConfig::from_string(config)?;
|
||||
pub fn from(config: String) -> Result<EvalRunner> {
|
||||
let config = BenchRunConfig::from_string(config)
|
||||
.context("Failed to parse evaluation configuration")?;
|
||||
Ok(EvalRunner { config })
|
||||
}
|
||||
|
||||
fn create_work_dir(&self, config: &BenchRunConfig) -> anyhow::Result<BenchmarkWorkDir> {
|
||||
let goose_model = config.models.first().unwrap();
|
||||
fn create_work_dir(&self, config: &BenchRunConfig) -> Result<BenchmarkWorkDir> {
|
||||
let goose_model = config
|
||||
.models
|
||||
.first()
|
||||
.context("No model specified in configuration")?;
|
||||
let model_name = goose_model.name.clone();
|
||||
let provider_name = goose_model.provider.clone();
|
||||
|
||||
@@ -48,13 +54,21 @@ impl EvalRunner {
|
||||
let work_dir = BenchmarkWorkDir::new(work_dir_name, include_dir);
|
||||
Ok(work_dir)
|
||||
}
|
||||
pub async fn run<F, Fut>(&mut self, agent_generator: F) -> anyhow::Result<()>
|
||||
|
||||
pub async fn run<F, Fut>(&mut self, agent_generator: F) -> Result<()>
|
||||
where
|
||||
F: Fn(ExtensionRequirements, String) -> Fut,
|
||||
Fut: Future<Output = BenchAgent> + Send,
|
||||
{
|
||||
let mut work_dir = self.create_work_dir(&self.config)?;
|
||||
let bench_eval = self.config.evals.first().unwrap();
|
||||
let mut work_dir = self
|
||||
.create_work_dir(&self.config)
|
||||
.context("Failed to create evaluation work directory")?;
|
||||
|
||||
let bench_eval = self
|
||||
.config
|
||||
.evals
|
||||
.first()
|
||||
.context("No evaluations specified in configuration")?;
|
||||
|
||||
let run_id = &self
|
||||
.config
|
||||
@@ -65,41 +79,89 @@ impl EvalRunner {
|
||||
|
||||
// create entire dir subtree for eval and cd into dir for running eval
|
||||
work_dir.set_eval(&bench_eval.selector, run_id);
|
||||
tracing::info!("Set evaluation directory for {}", bench_eval.selector);
|
||||
|
||||
if let Some(eval) = EvaluationSuite::from(&bench_eval.selector) {
|
||||
let now_stamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_nanos();
|
||||
let now_stamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.context("Failed to get current timestamp")?
|
||||
.as_nanos();
|
||||
|
||||
let session_id = format!("{}-{}", bench_eval.selector.clone(), now_stamp);
|
||||
let mut agent = agent_generator(eval.required_extensions(), session_id).await;
|
||||
tracing::info!("Agent created for {}", eval.name());
|
||||
|
||||
let mut result = EvaluationResult::new(eval.name().to_string());
|
||||
|
||||
if let Ok(metrics) = eval.run(&mut agent, &mut work_dir).await {
|
||||
for (name, metric) in metrics {
|
||||
result.add_metric(name, metric);
|
||||
match eval.run(&mut agent, &mut work_dir).await {
|
||||
Ok(metrics) => {
|
||||
tracing::info!("Evaluation run successful with {} metrics", metrics.len());
|
||||
for (name, metric) in metrics {
|
||||
result.add_metric(name, metric);
|
||||
}
|
||||
}
|
||||
|
||||
// Add any errors that occurred
|
||||
for error in agent.get_errors().await {
|
||||
result.add_error(error);
|
||||
Err(e) => {
|
||||
tracing::error!("Evaluation run failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
let eval_results = serde_json::to_string_pretty(&result)?;
|
||||
// Add any errors that occurred
|
||||
let errors = agent.get_errors().await;
|
||||
tracing::info!("Agent reported {} errors", errors.len());
|
||||
for error in errors {
|
||||
result.add_error(error);
|
||||
}
|
||||
|
||||
// Write results to file
|
||||
let eval_results = serde_json::to_string_pretty(&result)
|
||||
.context("Failed to serialize evaluation results to JSON")?;
|
||||
|
||||
let eval_results_file = env::current_dir()
|
||||
.context("Failed to get current directory")?
|
||||
.join(&self.config.eval_result_filename);
|
||||
|
||||
fs::write(&eval_results_file, &eval_results).with_context(|| {
|
||||
format!(
|
||||
"Failed to write evaluation results to {}",
|
||||
eval_results_file.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(
|
||||
"Wrote evaluation results to {}",
|
||||
eval_results_file.display()
|
||||
);
|
||||
|
||||
let eval_results_file = env::current_dir()?.join(&self.config.eval_result_filename);
|
||||
fs::write(&eval_results_file, &eval_results)?;
|
||||
self.config.save("config.cfg".to_string());
|
||||
work_dir.save();
|
||||
|
||||
// handle running post-process cmd if configured
|
||||
if let Some(cmd) = &bench_eval.post_process_cmd {
|
||||
let handle = Command::new(cmd).arg(&eval_results_file).spawn()?;
|
||||
tracing::info!("Running post-process command: {:?}", cmd);
|
||||
|
||||
let handle = Command::new(cmd)
|
||||
.arg(&eval_results_file)
|
||||
.spawn()
|
||||
.with_context(|| {
|
||||
format!("Failed to execute post-process command: {:?}", cmd)
|
||||
})?;
|
||||
|
||||
await_process_exits(&mut [handle], Vec::new());
|
||||
}
|
||||
|
||||
// copy session file into eval-dir
|
||||
let here = env::current_dir()?.canonicalize()?;
|
||||
BenchmarkWorkDir::deep_copy(agent.session_file().as_path(), here.as_path(), false)?;
|
||||
let here = env::current_dir()
|
||||
.context("Failed to get current directory")?
|
||||
.canonicalize()
|
||||
.context("Failed to canonicalize current directory path")?;
|
||||
|
||||
BenchmarkWorkDir::deep_copy(agent.session_file().as_path(), here.as_path(), false)
|
||||
.context("Failed to copy session file to evaluation directory")?;
|
||||
|
||||
tracing::info!("Evaluation completed successfully");
|
||||
} else {
|
||||
tracing::error!("No evaluation found for selector: {}", bench_eval.selector);
|
||||
bail!("No evaluation found for selector: {}", bench_eval.selector);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
81
crates/goose-bench/src/runners/metric_aggregator.rs
Normal file
81
crates/goose-bench/src/runners/metric_aggregator.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use std::path::PathBuf;
|
||||
use tracing;
|
||||
|
||||
pub struct MetricAggregator;
|
||||
|
||||
impl MetricAggregator {
|
||||
/// Generate leaderboard and aggregated metrics CSV files from benchmark directory
|
||||
pub fn generate_csv_from_benchmark_dir(benchmark_dir: &PathBuf) -> Result<()> {
|
||||
use std::process::Command;
|
||||
|
||||
// Step 1: Run prepare_aggregate_metrics.py to create aggregate_metrics.csv files
|
||||
let prepare_script_path = std::env::current_dir()
|
||||
.context("Failed to get current working directory")?
|
||||
.join("scripts")
|
||||
.join("bench-postprocess-scripts")
|
||||
.join("prepare_aggregate_metrics.py");
|
||||
|
||||
ensure!(
|
||||
prepare_script_path.exists(),
|
||||
"Prepare script not found: {}",
|
||||
prepare_script_path.display()
|
||||
);
|
||||
|
||||
tracing::info!(
|
||||
"Preparing aggregate metrics from benchmark directory: {}",
|
||||
benchmark_dir.display()
|
||||
);
|
||||
|
||||
let output = Command::new(&prepare_script_path)
|
||||
.arg("--benchmark-dir")
|
||||
.arg(benchmark_dir)
|
||||
.output()
|
||||
.context("Failed to execute prepare_aggregate_metrics.py script")?;
|
||||
|
||||
if !output.status.success() {
|
||||
let error_message = String::from_utf8_lossy(&output.stderr);
|
||||
bail!("Failed to prepare aggregate metrics: {}", error_message);
|
||||
}
|
||||
|
||||
let success_message = String::from_utf8_lossy(&output.stdout);
|
||||
tracing::info!("{}", success_message);
|
||||
|
||||
// Step 2: Run generate_leaderboard.py to create the final leaderboard
|
||||
let leaderboard_script_path = std::env::current_dir()
|
||||
.context("Failed to get current working directory")?
|
||||
.join("scripts")
|
||||
.join("bench-postprocess-scripts")
|
||||
.join("generate_leaderboard.py");
|
||||
|
||||
ensure!(
|
||||
leaderboard_script_path.exists(),
|
||||
"Leaderboard script not found: {}",
|
||||
leaderboard_script_path.display()
|
||||
);
|
||||
|
||||
tracing::info!(
|
||||
"Generating leaderboard from benchmark directory: {}",
|
||||
benchmark_dir.display()
|
||||
);
|
||||
|
||||
let output = Command::new(&leaderboard_script_path)
|
||||
.arg("--benchmark-dir")
|
||||
.arg(benchmark_dir)
|
||||
.arg("--leaderboard-output")
|
||||
.arg("leaderboard.csv")
|
||||
.arg("--union-output")
|
||||
.arg("all_metrics.csv")
|
||||
.output()
|
||||
.context("Failed to execute generate_leaderboard.py script")?;
|
||||
|
||||
if !output.status.success() {
|
||||
let error_message = String::from_utf8_lossy(&output.stderr);
|
||||
bail!("Failed to generate leaderboard: {}", error_message);
|
||||
}
|
||||
|
||||
let success_message = String::from_utf8_lossy(&output.stdout);
|
||||
tracing::info!("{}", success_message);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod bench_runner;
|
||||
pub mod eval_runner;
|
||||
pub mod metric_aggregator;
|
||||
pub mod model_runner;
|
||||
|
||||
@@ -3,12 +3,14 @@ use crate::eval_suites::EvaluationSuite;
|
||||
use crate::reporting::{BenchmarkResults, SuiteResult};
|
||||
use crate::runners::eval_runner::EvalRunner;
|
||||
use crate::utilities::{await_process_exits, parallel_bench_cmd};
|
||||
use anyhow::{Context, Result};
|
||||
use dotenvy::from_path_iter;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::read_to_string;
|
||||
use std::io::{self, BufRead};
|
||||
use std::path::PathBuf;
|
||||
use std::process::Child;
|
||||
use std::thread;
|
||||
use tracing;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ModelRunner {
|
||||
@@ -16,23 +18,27 @@ pub struct ModelRunner {
|
||||
}
|
||||
|
||||
impl ModelRunner {
|
||||
pub fn from(config: String) -> anyhow::Result<ModelRunner> {
|
||||
let config = BenchRunConfig::from_string(config)?;
|
||||
pub fn from(config: String) -> Result<ModelRunner> {
|
||||
let config =
|
||||
BenchRunConfig::from_string(config).context("Failed to parse configuration")?;
|
||||
Ok(ModelRunner { config })
|
||||
}
|
||||
|
||||
pub fn run(&self) -> anyhow::Result<()> {
|
||||
let model = self.config.models.first().unwrap();
|
||||
pub fn run(&self) -> Result<()> {
|
||||
let model = self
|
||||
.config
|
||||
.models
|
||||
.first()
|
||||
.context("No model specified in config")?;
|
||||
let suites = self.collect_evals_for_run();
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..self.config.repeat.unwrap_or(1) {
|
||||
let mut self_copy = self.clone();
|
||||
let self_copy = self.clone();
|
||||
let model_clone = model.clone();
|
||||
let suites_clone = suites.clone();
|
||||
// create thread to handle launching parallel processes to run model's evals in parallel
|
||||
let handle = thread::spawn(move || {
|
||||
let handle = thread::spawn(move || -> Result<()> {
|
||||
self_copy.run_benchmark(&model_clone, suites_clone, i.to_string())
|
||||
});
|
||||
handles.push(handle);
|
||||
@@ -41,55 +47,32 @@ impl ModelRunner {
|
||||
|
||||
let mut all_runs_results: Vec<BenchmarkResults> = Vec::new();
|
||||
for i in 0..self.config.repeat.unwrap_or(1) {
|
||||
let run_results =
|
||||
self.collect_run_results(model.clone(), suites.clone(), i.to_string())?;
|
||||
all_runs_results.push(run_results);
|
||||
match self.collect_run_results(model.clone(), suites.clone(), i.to_string()) {
|
||||
Ok(run_results) => all_runs_results.push(run_results),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to collect results for run {}: {}", i, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
// write summary file
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_env_file(&self, path: &PathBuf) -> anyhow::Result<Vec<(String, String)>> {
|
||||
let file = std::fs::File::open(path)?;
|
||||
let reader = io::BufReader::new(file);
|
||||
let mut env_vars = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
// Skip empty lines and comments
|
||||
if line.trim().is_empty() || line.trim_start().starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Split on first '=' only
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
let key = key.trim().to_string();
|
||||
// Remove quotes if present
|
||||
let value = value
|
||||
.trim()
|
||||
.trim_matches('"')
|
||||
.trim_matches('\'')
|
||||
.to_string();
|
||||
env_vars.push((key, value));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(env_vars)
|
||||
}
|
||||
|
||||
fn run_benchmark(
|
||||
&mut self,
|
||||
&self,
|
||||
model: &BenchModel,
|
||||
suites: HashMap<String, Vec<BenchEval>>,
|
||||
run_id: String,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> Result<()> {
|
||||
let mut results_handles = HashMap::<String, Vec<Child>>::new();
|
||||
|
||||
// Load environment variables from file if specified
|
||||
let mut envs = self.toolshim_envs();
|
||||
if let Some(env_file) = &self.config.env_file {
|
||||
let env_vars = self.load_env_file(env_file)?;
|
||||
let env_vars = ModelRunner::load_env_file(env_file).context(format!(
|
||||
"Failed to load environment file: {}",
|
||||
env_file.display()
|
||||
))?;
|
||||
envs.extend(env_vars);
|
||||
}
|
||||
envs.push(("GOOSE_MODEL".to_string(), model.clone().name));
|
||||
@@ -116,9 +99,13 @@ impl ModelRunner {
|
||||
// Run parallel-safe evaluations in parallel
|
||||
if !parallel_evals.is_empty() {
|
||||
for eval_selector in ¶llel_evals {
|
||||
self.config.run_id = Some(run_id.clone());
|
||||
self.config.evals = vec![(*eval_selector).clone()];
|
||||
let cfg = self.config.to_string()?;
|
||||
let mut config_copy = self.config.clone();
|
||||
config_copy.run_id = Some(run_id.clone());
|
||||
config_copy.evals = vec![(*eval_selector).clone()];
|
||||
let cfg = config_copy
|
||||
.to_string()
|
||||
.context("Failed to serialize configuration")?;
|
||||
|
||||
let handle = parallel_bench_cmd("exec-eval".to_string(), cfg, envs.clone());
|
||||
results_handles.get_mut(suite).unwrap().push(handle);
|
||||
}
|
||||
@@ -126,9 +113,13 @@ impl ModelRunner {
|
||||
|
||||
// Run non-parallel-safe evaluations sequentially
|
||||
for eval_selector in &sequential_evals {
|
||||
self.config.run_id = Some(run_id.clone());
|
||||
self.config.evals = vec![(*eval_selector).clone()];
|
||||
let cfg = self.config.to_string()?;
|
||||
let mut config_copy = self.config.clone();
|
||||
config_copy.run_id = Some(run_id.clone());
|
||||
config_copy.evals = vec![(*eval_selector).clone()];
|
||||
let cfg = config_copy
|
||||
.to_string()
|
||||
.context("Failed to serialize configuration")?;
|
||||
|
||||
let handle = parallel_bench_cmd("exec-eval".to_string(), cfg, envs.clone());
|
||||
|
||||
// Wait for this process to complete before starting the next one
|
||||
@@ -150,7 +141,7 @@ impl ModelRunner {
|
||||
model: BenchModel,
|
||||
suites: HashMap<String, Vec<BenchEval>>,
|
||||
run_id: String,
|
||||
) -> anyhow::Result<BenchmarkResults> {
|
||||
) -> Result<BenchmarkResults> {
|
||||
let mut results = BenchmarkResults::new(model.provider.clone());
|
||||
|
||||
let mut summary_path: Option<PathBuf> = None;
|
||||
@@ -161,7 +152,17 @@ impl ModelRunner {
|
||||
let mut eval_path =
|
||||
EvalRunner::path_for_eval(&model, eval_selector, run_id.clone());
|
||||
eval_path.push(self.config.eval_result_filename.clone());
|
||||
let eval_result = serde_json::from_str(&read_to_string(&eval_path)?)?;
|
||||
|
||||
let content = read_to_string(&eval_path).with_context(|| {
|
||||
format!(
|
||||
"Failed to read evaluation results from {}",
|
||||
eval_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let eval_result = serde_json::from_str(&content)
|
||||
.context("Failed to parse evaluation results JSON")?;
|
||||
|
||||
suite_result.add_evaluation(eval_result);
|
||||
|
||||
// use current eval to determine where the summary should be written
|
||||
@@ -180,12 +181,21 @@ impl ModelRunner {
|
||||
results.add_suite(suite_result);
|
||||
}
|
||||
|
||||
let mut run_summary = PathBuf::new();
|
||||
run_summary.push(summary_path.clone().unwrap());
|
||||
run_summary.push(&self.config.run_summary_filename);
|
||||
if let Some(path) = summary_path {
|
||||
let mut run_summary = PathBuf::new();
|
||||
run_summary.push(path);
|
||||
run_summary.push(&self.config.run_summary_filename);
|
||||
|
||||
let output_str = serde_json::to_string_pretty(&results)?;
|
||||
std::fs::write(run_summary, &output_str)?;
|
||||
let output_str = serde_json::to_string_pretty(&results)
|
||||
.context("Failed to serialize benchmark results to JSON")?;
|
||||
|
||||
std::fs::write(&run_summary, &output_str).with_context(|| {
|
||||
format!(
|
||||
"Failed to write results summary to {}",
|
||||
run_summary.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
@@ -210,20 +220,29 @@ impl ModelRunner {
|
||||
|
||||
fn toolshim_envs(&self) -> Vec<(String, String)> {
|
||||
// read tool-shim preference from config, set respective env vars accordingly
|
||||
let model = self.config.models.first().unwrap();
|
||||
|
||||
let mut shim_envs: Vec<(String, String)> = Vec::new();
|
||||
if let Some(shim_opt) = &model.tool_shim {
|
||||
if shim_opt.use_tool_shim {
|
||||
shim_envs.push(("GOOSE_TOOLSHIM".to_string(), "true".to_string()));
|
||||
if let Some(shim_model) = &shim_opt.tool_shim_model {
|
||||
shim_envs.push((
|
||||
"GOOSE_TOOLSHIM_OLLAMA_MODEL".to_string(),
|
||||
shim_model.clone(),
|
||||
));
|
||||
if let Some(model) = self.config.models.first() {
|
||||
if let Some(shim_opt) = &model.tool_shim {
|
||||
if shim_opt.use_tool_shim {
|
||||
shim_envs.push(("GOOSE_TOOLSHIM".to_string(), "true".to_string()));
|
||||
if let Some(shim_model) = &shim_opt.tool_shim_model {
|
||||
shim_envs.push((
|
||||
"GOOSE_TOOLSHIM_OLLAMA_MODEL".to_string(),
|
||||
shim_model.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
shim_envs
|
||||
}
|
||||
|
||||
fn load_env_file(path: &PathBuf) -> Result<Vec<(String, String)>> {
|
||||
let iter =
|
||||
from_path_iter(path).context("Failed to read environment variables from file")?;
|
||||
let env_vars = iter
|
||||
.map(|item| item.context("Failed to parse environment variable"))
|
||||
.collect::<Result<_, _>>()?;
|
||||
Ok(env_vars)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
use anyhow::Result;
|
||||
use std::env;
|
||||
use std::process::{Child, Command};
|
||||
use std::thread::JoinHandle;
|
||||
use tracing;
|
||||
|
||||
pub fn await_process_exits(
|
||||
child_processes: &mut [Child],
|
||||
handles: Vec<JoinHandle<anyhow::Result<()>>>,
|
||||
) {
|
||||
pub fn await_process_exits(child_processes: &mut [Child], handles: Vec<JoinHandle<Result<()>>>) {
|
||||
for child in child_processes.iter_mut() {
|
||||
match child.wait() {
|
||||
Ok(status) => println!("Child exited with status: {}", status),
|
||||
Err(e) => println!("Error waiting for child: {}", e),
|
||||
Ok(status) => tracing::info!("Child exited with status: {}", status),
|
||||
Err(e) => tracing::error!("Error waiting for child: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +17,7 @@ pub fn await_process_exits(
|
||||
Ok(_res) => (),
|
||||
Err(e) => {
|
||||
// Handle thread panic
|
||||
println!("Thread panicked: {:?}", e);
|
||||
tracing::error!("Thread panicked: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,6 +52,9 @@ shlex = "1.3.0"
|
||||
async-trait = "0.1.86"
|
||||
base64 = "0.22.1"
|
||||
regex = "1.11.1"
|
||||
minijinja = "2.8.0"
|
||||
nix = { version = "0.30.1", features = ["process", "signal"] }
|
||||
tar = "0.4"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
winapi = { version = "0.3", features = ["wincred"] }
|
||||
|
||||
@@ -7,15 +7,22 @@ use crate::commands::bench::agent_generator;
|
||||
use crate::commands::configure::handle_configure;
|
||||
use crate::commands::info::handle_info;
|
||||
use crate::commands::mcp::run_server;
|
||||
use crate::commands::project::{handle_project_default, handle_projects_interactive};
|
||||
use crate::commands::recipe::{handle_deeplink, handle_validate};
|
||||
// Import the new handlers from commands::schedule
|
||||
use crate::commands::schedule::{
|
||||
handle_schedule_add, handle_schedule_list, handle_schedule_remove, handle_schedule_run_now,
|
||||
handle_schedule_sessions,
|
||||
};
|
||||
use crate::commands::session::{handle_session_list, handle_session_remove};
|
||||
use crate::logging::setup_logging;
|
||||
use crate::recipe::load_recipe;
|
||||
use crate::recipes::recipe::{explain_recipe_with_parameters, load_recipe_as_template};
|
||||
use crate::session;
|
||||
use crate::session::{build_session, SessionBuilderConfig};
|
||||
use goose_bench::bench_config::BenchRunConfig;
|
||||
use goose_bench::runners::bench_runner::BenchRunner;
|
||||
use goose_bench::runners::eval_runner::EvalRunner;
|
||||
use goose_bench::runners::metric_aggregator::MetricAggregator;
|
||||
use goose_bench::runners::model_runner::ModelRunner;
|
||||
use std::io::Read;
|
||||
use std::path::PathBuf;
|
||||
@@ -59,6 +66,13 @@ fn extract_identifier(identifier: Identifier) -> session::Identifier {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_key_val(s: &str) -> Result<(String, String), String> {
|
||||
match s.split_once('=') {
|
||||
Some((key, value)) => Ok((key.to_string(), value.to_string())),
|
||||
None => Err(format!("invalid KEY=VALUE: {}", s)),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum SessionCommand {
|
||||
#[command(about = "List all available sessions")]
|
||||
@@ -81,17 +95,52 @@ enum SessionCommand {
|
||||
)]
|
||||
ascending: bool,
|
||||
},
|
||||
#[command(about = "Remove sessions")]
|
||||
#[command(about = "Remove sessions. Runs interactively if no ID or regex is provided.")]
|
||||
Remove {
|
||||
#[arg(short, long, help = "session id to be removed", default_value = "")]
|
||||
#[arg(short, long, help = "Session ID to be removed (optional)")]
|
||||
id: Option<String>,
|
||||
#[arg(short, long, help = "Regex for removing matched sessions (optional)")]
|
||||
regex: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum SchedulerCommand {
|
||||
#[command(about = "Add a new scheduled job")]
|
||||
Add {
|
||||
#[arg(long, help = "Unique ID for the job")]
|
||||
id: String,
|
||||
#[arg(long, help = "Cron string for the schedule (e.g., '0 0 * * * *')")]
|
||||
cron: String,
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
help = "regex for removing matched session",
|
||||
default_value = ""
|
||||
help = "Recipe source (path to file, or base64 encoded recipe string)"
|
||||
)]
|
||||
regex: String,
|
||||
recipe_source: String,
|
||||
},
|
||||
#[command(about = "List all scheduled jobs")]
|
||||
List {},
|
||||
#[command(about = "Remove a scheduled job by ID")]
|
||||
Remove {
|
||||
#[arg(long, help = "ID of the job to remove")] // Changed from positional to named --id
|
||||
id: String,
|
||||
},
|
||||
/// List sessions created by a specific schedule
|
||||
#[command(about = "List sessions created by a specific schedule")]
|
||||
Sessions {
|
||||
/// ID of the schedule
|
||||
#[arg(long, help = "ID of the schedule")] // Explicitly make it --id
|
||||
id: String,
|
||||
/// Maximum number of sessions to return
|
||||
#[arg(long, help = "Maximum number of sessions to return")]
|
||||
limit: Option<u32>,
|
||||
},
|
||||
/// Run a scheduled job immediately
|
||||
#[command(about = "Run a scheduled job immediately")]
|
||||
RunNow {
|
||||
/// ID of the schedule to run
|
||||
#[arg(long, help = "ID of the schedule to run")] // Explicitly make it --id
|
||||
id: String,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -134,24 +183,39 @@ pub enum BenchCommand {
|
||||
#[arg(short, long, help = "A serialized config file for the eval only.")]
|
||||
config: String,
|
||||
},
|
||||
|
||||
#[command(
|
||||
name = "generate-leaderboard",
|
||||
about = "Generate a leaderboard CSV from benchmark results"
|
||||
)]
|
||||
GenerateLeaderboard {
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
help = "Path to the benchmark directory containing model evaluation results"
|
||||
)]
|
||||
benchmark_dir: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum RecipeCommand {
|
||||
/// Validate a recipe file
|
||||
#[command(about = "Validate a recipe file")]
|
||||
#[command(about = "Validate a recipe")]
|
||||
Validate {
|
||||
/// Path to the recipe file to validate
|
||||
#[arg(help = "Path to the recipe file to validate")]
|
||||
file: String,
|
||||
/// Recipe name to get recipe file to validate
|
||||
#[arg(help = "recipe name to get recipe file or full path to the recipe file to validate")]
|
||||
recipe_name: String,
|
||||
},
|
||||
|
||||
/// Generate a deeplink for a recipe file
|
||||
#[command(about = "Generate a deeplink for a recipe file")]
|
||||
#[command(about = "Generate a deeplink for a recipe")]
|
||||
Deeplink {
|
||||
/// Path to the recipe file
|
||||
#[arg(help = "Path to the recipe file")]
|
||||
file: String,
|
||||
/// Recipe name to get recipe file to generate deeplink
|
||||
#[arg(
|
||||
help = "recipe name to get recipe file or full path to the recipe file to generate deeplink"
|
||||
)]
|
||||
recipe_name: String,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -194,6 +258,14 @@ enum Command {
|
||||
)]
|
||||
resume: bool,
|
||||
|
||||
/// Show message history when resuming
|
||||
#[arg(
|
||||
long,
|
||||
help = "Show previous messages when resuming a session",
|
||||
requires = "resume"
|
||||
)]
|
||||
history: bool,
|
||||
|
||||
/// Enable debug output mode
|
||||
#[arg(
|
||||
long,
|
||||
@@ -202,6 +274,15 @@ enum Command {
|
||||
)]
|
||||
debug: bool,
|
||||
|
||||
/// Maximum number of consecutive identical tool calls allowed
|
||||
#[arg(
|
||||
long = "max-tool-repetitions",
|
||||
value_name = "NUMBER",
|
||||
help = "Maximum number of consecutive identical tool calls allowed",
|
||||
long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops."
|
||||
)]
|
||||
max_tool_repetitions: Option<u32>,
|
||||
|
||||
/// Add stdio extensions with environment variables and commands
|
||||
#[arg(
|
||||
long = "with-extension",
|
||||
@@ -233,6 +314,14 @@ enum Command {
|
||||
builtins: Vec<String>,
|
||||
},
|
||||
|
||||
/// Open the last project directory
|
||||
#[command(about = "Open the last project directory", visible_alias = "p")]
|
||||
Project {},
|
||||
|
||||
/// List recent project directories
|
||||
#[command(about = "List recent project directories", visible_alias = "ps")]
|
||||
Projects,
|
||||
|
||||
/// Execute commands from an instruction file
|
||||
#[command(about = "Execute commands from an instruction file or stdin")]
|
||||
Run {
|
||||
@@ -259,18 +348,28 @@ enum Command {
|
||||
)]
|
||||
input_text: Option<String>,
|
||||
|
||||
/// Path to recipe.yaml file
|
||||
/// Recipe name or full path to the recipe file
|
||||
#[arg(
|
||||
short = None,
|
||||
long = "recipe",
|
||||
value_name = "FILE",
|
||||
help = "Path to recipe.yaml file",
|
||||
long_help = "Path to a recipe.yaml file that defines a custom agent configuration",
|
||||
value_name = "RECIPE_NAME or FULL_PATH_TO_RECIPE_FILE",
|
||||
help = "Recipe name to get recipe file or the full path of the recipe file (use --explain to see recipe details)",
|
||||
long_help = "Recipe name to get recipe file or the full path of the recipe file that defines a custom agent configuration. Use --explain to see the recipe's title, description, and parameters.",
|
||||
conflicts_with = "instructions",
|
||||
conflicts_with = "input_text"
|
||||
)]
|
||||
recipe: Option<String>,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
value_name = "KEY=VALUE",
|
||||
help = "Dynamic parameters (e.g., --params username=alice --params channel_name=goose-channel)",
|
||||
long_help = "Key-value parameters to pass to the recipe file. Can be specified multiple times.",
|
||||
action = clap::ArgAction::Append,
|
||||
value_parser = parse_key_val,
|
||||
)]
|
||||
params: Vec<(String, String)>,
|
||||
|
||||
/// Continue in interactive mode after processing input
|
||||
#[arg(
|
||||
short = 's',
|
||||
@@ -279,6 +378,31 @@ enum Command {
|
||||
)]
|
||||
interactive: bool,
|
||||
|
||||
/// Run without storing a session file
|
||||
#[arg(
|
||||
long = "no-session",
|
||||
help = "Run without storing a session file",
|
||||
long_help = "Execute commands without creating or using a session file. Useful for automated runs.",
|
||||
conflicts_with_all = ["resume", "name", "path"]
|
||||
)]
|
||||
no_session: bool,
|
||||
|
||||
/// Show the recipe title, description, and parameters
|
||||
#[arg(
|
||||
long = "explain",
|
||||
help = "Show the recipe title, description, and parameters"
|
||||
)]
|
||||
explain: bool,
|
||||
|
||||
/// Maximum number of consecutive identical tool calls allowed
|
||||
#[arg(
|
||||
long = "max-tool-repetitions",
|
||||
value_name = "NUMBER",
|
||||
help = "Maximum number of consecutive identical tool calls allowed",
|
||||
long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops."
|
||||
)]
|
||||
max_tool_repetitions: Option<u32>,
|
||||
|
||||
/// Identifier for this run session
|
||||
#[command(flatten)]
|
||||
identifier: Option<Identifier>,
|
||||
@@ -339,6 +463,13 @@ enum Command {
|
||||
command: RecipeCommand,
|
||||
},
|
||||
|
||||
/// Manage scheduled jobs
|
||||
#[command(about = "Manage scheduled jobs", visible_alias = "sched")]
|
||||
Schedule {
|
||||
#[command(subcommand)]
|
||||
command: SchedulerCommand,
|
||||
},
|
||||
|
||||
/// Update the Goose CLI version
|
||||
#[command(about = "Update the goose CLI version")]
|
||||
Update {
|
||||
@@ -378,6 +509,11 @@ struct InputConfig {
|
||||
pub async fn cli() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Track the current directory in projects.json
|
||||
if let Err(e) = crate::project_tracker::update_project_tracker(None, None) {
|
||||
eprintln!("Warning: Failed to update project tracker: {}", e);
|
||||
}
|
||||
|
||||
match cli.command {
|
||||
Some(Command::Configure {}) => {
|
||||
let _ = handle_configure().await;
|
||||
@@ -394,7 +530,9 @@ pub async fn cli() -> Result<()> {
|
||||
command,
|
||||
identifier,
|
||||
resume,
|
||||
history,
|
||||
debug,
|
||||
max_tool_repetitions,
|
||||
extensions,
|
||||
remote_extensions,
|
||||
builtins,
|
||||
@@ -414,26 +552,44 @@ pub async fn cli() -> Result<()> {
|
||||
}
|
||||
None => {
|
||||
// Run session command by default
|
||||
let mut session = build_session(SessionBuilderConfig {
|
||||
let mut session: crate::Session = build_session(SessionBuilderConfig {
|
||||
identifier: identifier.map(extract_identifier),
|
||||
resume,
|
||||
no_session: false,
|
||||
extensions,
|
||||
remote_extensions,
|
||||
builtins,
|
||||
extensions_override: None,
|
||||
additional_system_prompt: None,
|
||||
debug,
|
||||
max_tool_repetitions,
|
||||
})
|
||||
.await;
|
||||
setup_logging(
|
||||
session.session_file().file_stem().and_then(|s| s.to_str()),
|
||||
None,
|
||||
)?;
|
||||
|
||||
// Render previous messages if resuming a session and history flag is set
|
||||
if resume && history {
|
||||
session.render_message_history();
|
||||
}
|
||||
|
||||
let _ = session.interactive(None).await;
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
}
|
||||
Some(Command::Project {}) => {
|
||||
// Default behavior: offer to resume the last project
|
||||
handle_project_default()?;
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Projects) => {
|
||||
// Interactive project selection
|
||||
handle_projects_interactive()?;
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Run {
|
||||
instructions,
|
||||
input_text,
|
||||
@@ -441,13 +597,17 @@ pub async fn cli() -> Result<()> {
|
||||
interactive,
|
||||
identifier,
|
||||
resume,
|
||||
no_session,
|
||||
debug,
|
||||
max_tool_repetitions,
|
||||
extensions,
|
||||
remote_extensions,
|
||||
builtins,
|
||||
params,
|
||||
explain,
|
||||
}) => {
|
||||
let input_config = match (instructions, input_text, recipe) {
|
||||
(Some(file), _, _) if file == "-" => {
|
||||
let input_config = match (instructions, input_text, recipe, explain) {
|
||||
(Some(file), _, _, _) if file == "-" => {
|
||||
let mut input = String::new();
|
||||
std::io::stdin()
|
||||
.read_to_string(&mut input)
|
||||
@@ -459,7 +619,7 @@ pub async fn cli() -> Result<()> {
|
||||
additional_system_prompt: None,
|
||||
}
|
||||
}
|
||||
(Some(file), _, _) => {
|
||||
(Some(file), _, _, _) => {
|
||||
let contents = std::fs::read_to_string(&file).unwrap_or_else(|err| {
|
||||
eprintln!(
|
||||
"Instruction file not found — did you mean to use goose run --text?\n{}",
|
||||
@@ -473,23 +633,28 @@ pub async fn cli() -> Result<()> {
|
||||
additional_system_prompt: None,
|
||||
}
|
||||
}
|
||||
(_, Some(text), _) => InputConfig {
|
||||
(_, Some(text), _, _) => InputConfig {
|
||||
contents: Some(text),
|
||||
extensions_override: None,
|
||||
additional_system_prompt: None,
|
||||
},
|
||||
(_, _, Some(file)) => {
|
||||
let recipe = load_recipe(&file, true).unwrap_or_else(|err| {
|
||||
eprintln!("{}: {}", console::style("Error").red().bold(), err);
|
||||
std::process::exit(1);
|
||||
});
|
||||
(_, _, Some(recipe_name), explain) => {
|
||||
if explain {
|
||||
explain_recipe_with_parameters(&recipe_name, params)?;
|
||||
return Ok(());
|
||||
}
|
||||
let recipe =
|
||||
load_recipe_as_template(&recipe_name, params).unwrap_or_else(|err| {
|
||||
eprintln!("{}: {}", console::style("Error").red().bold(), err);
|
||||
std::process::exit(1);
|
||||
});
|
||||
InputConfig {
|
||||
contents: recipe.prompt,
|
||||
extensions_override: recipe.extensions,
|
||||
additional_system_prompt: Some(recipe.instructions),
|
||||
additional_system_prompt: recipe.instructions,
|
||||
}
|
||||
}
|
||||
(None, None, None) => {
|
||||
(None, None, None, _) => {
|
||||
eprintln!("Error: Must provide either --instructions (-i), --text (-t), or --recipe. Use -i - for stdin.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
@@ -498,12 +663,14 @@ pub async fn cli() -> Result<()> {
|
||||
let mut session = build_session(SessionBuilderConfig {
|
||||
identifier: identifier.map(extract_identifier),
|
||||
resume,
|
||||
no_session,
|
||||
extensions,
|
||||
remote_extensions,
|
||||
builtins,
|
||||
extensions_override: input_config.extensions_override,
|
||||
additional_system_prompt: input_config.additional_system_prompt,
|
||||
debug,
|
||||
max_tool_repetitions,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -523,6 +690,32 @@ pub async fn cli() -> Result<()> {
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Schedule { command }) => {
|
||||
match command {
|
||||
SchedulerCommand::Add {
|
||||
id,
|
||||
cron,
|
||||
recipe_source,
|
||||
} => {
|
||||
handle_schedule_add(id, cron, recipe_source).await?;
|
||||
}
|
||||
SchedulerCommand::List {} => {
|
||||
handle_schedule_list().await?;
|
||||
}
|
||||
SchedulerCommand::Remove { id } => {
|
||||
handle_schedule_remove(id).await?;
|
||||
}
|
||||
SchedulerCommand::Sessions { id, limit } => {
|
||||
// New arm
|
||||
handle_schedule_sessions(id, limit).await?;
|
||||
}
|
||||
SchedulerCommand::RunNow { id } => {
|
||||
// New arm
|
||||
handle_schedule_run_now(id).await?;
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Update {
|
||||
canary,
|
||||
reconfigure,
|
||||
@@ -533,22 +726,31 @@ pub async fn cli() -> Result<()> {
|
||||
Some(Command::Bench { cmd }) => {
|
||||
match cmd {
|
||||
BenchCommand::Selectors { config } => BenchRunner::list_selectors(config)?,
|
||||
BenchCommand::InitConfig { name } => BenchRunConfig::default().save(name),
|
||||
BenchCommand::InitConfig { name } => {
|
||||
let mut config = BenchRunConfig::default();
|
||||
let cwd =
|
||||
std::env::current_dir().expect("Failed to get current working directory");
|
||||
config.output_dir = Some(cwd);
|
||||
config.save(name);
|
||||
}
|
||||
BenchCommand::Run { config } => BenchRunner::new(config)?.run()?,
|
||||
BenchCommand::EvalModel { config } => ModelRunner::from(config)?.run()?,
|
||||
BenchCommand::ExecEval { config } => {
|
||||
EvalRunner::from(config)?.run(agent_generator).await?
|
||||
}
|
||||
BenchCommand::GenerateLeaderboard { benchmark_dir } => {
|
||||
MetricAggregator::generate_csv_from_benchmark_dir(&benchmark_dir)?
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Recipe { command }) => {
|
||||
match command {
|
||||
RecipeCommand::Validate { file } => {
|
||||
handle_validate(file)?;
|
||||
RecipeCommand::Validate { recipe_name } => {
|
||||
handle_validate(&recipe_name)?;
|
||||
}
|
||||
RecipeCommand::Deeplink { file } => {
|
||||
handle_deeplink(file)?;
|
||||
RecipeCommand::Deeplink { recipe_name } => {
|
||||
handle_deeplink(&recipe_name)?;
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
@@ -559,7 +761,19 @@ pub async fn cli() -> Result<()> {
|
||||
Ok(())
|
||||
} else {
|
||||
// Run session command by default
|
||||
let mut session = build_session(SessionBuilderConfig::default()).await;
|
||||
let mut session = build_session(SessionBuilderConfig {
|
||||
identifier: None,
|
||||
resume: false,
|
||||
no_session: false,
|
||||
extensions: Vec::new(),
|
||||
remote_extensions: Vec::new(),
|
||||
builtins: Vec::new(),
|
||||
extensions_override: None,
|
||||
additional_system_prompt: None,
|
||||
debug: false,
|
||||
max_tool_repetitions: None,
|
||||
})
|
||||
.await;
|
||||
setup_logging(
|
||||
session.session_file().file_stem().and_then(|s| s.to_str()),
|
||||
None,
|
||||
|
||||
@@ -34,12 +34,14 @@ pub async fn agent_generator(
|
||||
let base_session = build_session(SessionBuilderConfig {
|
||||
identifier,
|
||||
resume: false,
|
||||
no_session: false,
|
||||
extensions: requirements.external,
|
||||
remote_extensions: requirements.remote,
|
||||
builtins: requirements.builtin,
|
||||
extensions_override: None,
|
||||
additional_system_prompt: None,
|
||||
debug: false,
|
||||
max_tool_repetitions: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
|
||||
use crate::recipes::github_recipe::GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY;
|
||||
|
||||
// useful for light themes where there is no dicernible colour contrast between
|
||||
// cursor-selected and cursor-unselected items.
|
||||
const MULTISELECT_VISIBILITY_HINT: &str = "<";
|
||||
@@ -193,7 +195,7 @@ pub async fn handle_configure() -> Result<(), Box<dyn Error>> {
|
||||
.item(
|
||||
"settings",
|
||||
"Goose Settings",
|
||||
"Set the Goose Mode, Tool Output, Tool Permissions, Experiment and more",
|
||||
"Set the Goose Mode, Tool Output, Tool Permissions, Experiment, Goose recipe github repo and more",
|
||||
)
|
||||
.interact()?;
|
||||
|
||||
@@ -325,11 +327,40 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
|
||||
}
|
||||
}
|
||||
|
||||
// Select model, defaulting to the provider's recommended model UNLESS there is an env override
|
||||
let default_model = std::env::var("GOOSE_MODEL").unwrap_or(provider_meta.default_model.clone());
|
||||
let model: String = cliclack::input("Enter a model from that provider:")
|
||||
.default_input(&default_model)
|
||||
.interact()?;
|
||||
// Attempt to fetch supported models for this provider
|
||||
let spin = spinner();
|
||||
spin.start("Attempting to fetch supported models...");
|
||||
let models_res = {
|
||||
let temp_model_config = goose::model::ModelConfig::new(provider_meta.default_model.clone());
|
||||
let temp_provider = create(provider_name, temp_model_config)?;
|
||||
temp_provider.fetch_supported_models_async().await
|
||||
};
|
||||
spin.stop(style("Model fetch complete").green());
|
||||
|
||||
// Select a model: on fetch error show styled error and abort; if Some(models), show list; if None, free-text input
|
||||
let model: String = match models_res {
|
||||
Err(e) => {
|
||||
// Provider hook error
|
||||
cliclack::outro(style(e.to_string()).on_red().white())?;
|
||||
return Ok(false);
|
||||
}
|
||||
Ok(Some(models)) => cliclack::select("Select a model:")
|
||||
.items(
|
||||
&models
|
||||
.iter()
|
||||
.map(|m| (m, m.as_str(), ""))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.interact()?
|
||||
.to_string(),
|
||||
Ok(None) => {
|
||||
let default_model =
|
||||
std::env::var("GOOSE_MODEL").unwrap_or(provider_meta.default_model.clone());
|
||||
cliclack::input("Enter a model from that provider:")
|
||||
.default_input(&default_model)
|
||||
.interact()?
|
||||
}
|
||||
};
|
||||
|
||||
// Test the configuration
|
||||
let spin = spinner();
|
||||
@@ -793,6 +824,11 @@ pub fn remove_extension_dialog() -> Result<(), Box<dyn Error>> {
|
||||
pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
|
||||
let setting_type = cliclack::select("What setting would you like to configure?")
|
||||
.item("goose_mode", "Goose Mode", "Configure Goose mode")
|
||||
.item(
|
||||
"goose_router_strategy",
|
||||
"Router Tool Selection Strategy",
|
||||
"Configure the strategy for selecting tools to use",
|
||||
)
|
||||
.item(
|
||||
"tool_permission",
|
||||
"Tool Permission",
|
||||
@@ -808,12 +844,20 @@ pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
|
||||
"Toggle Experiment",
|
||||
"Enable or disable an experiment feature",
|
||||
)
|
||||
.item(
|
||||
"recipe",
|
||||
"Goose recipe github repo",
|
||||
"Goose will pull recipes from this repo if not found locally.",
|
||||
)
|
||||
.interact()?;
|
||||
|
||||
match setting_type {
|
||||
"goose_mode" => {
|
||||
configure_goose_mode_dialog()?;
|
||||
}
|
||||
"goose_router_strategy" => {
|
||||
configure_goose_router_strategy_dialog()?;
|
||||
}
|
||||
"tool_permission" => {
|
||||
configure_tool_permissions_dialog().await.and(Ok(()))?;
|
||||
}
|
||||
@@ -823,6 +867,9 @@ pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
|
||||
"experiment" => {
|
||||
toggle_experiments_dialog()?;
|
||||
}
|
||||
"recipe" => {
|
||||
configure_recipe_dialog()?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
@@ -882,6 +929,49 @@ pub fn configure_goose_mode_dialog() -> Result<(), Box<dyn Error>> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn configure_goose_router_strategy_dialog() -> Result<(), Box<dyn Error>> {
|
||||
let config = Config::global();
|
||||
|
||||
// Check if GOOSE_ROUTER_STRATEGY is set as an environment variable
|
||||
if std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY").is_ok() {
|
||||
let _ = cliclack::log::info("Notice: GOOSE_ROUTER_TOOL_SELECTION_STRATEGY environment variable is set. Configuration will override this.");
|
||||
}
|
||||
|
||||
let strategy = cliclack::select("Which router strategy would you like to use?")
|
||||
.item(
|
||||
"vector",
|
||||
"Vector Strategy",
|
||||
"Use vector-based similarity to select tools",
|
||||
)
|
||||
.item(
|
||||
"default",
|
||||
"Default Strategy",
|
||||
"Use the default tool selection strategy",
|
||||
)
|
||||
.interact()?;
|
||||
|
||||
match strategy {
|
||||
"vector" => {
|
||||
config.set_param(
|
||||
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
|
||||
Value::String("vector".to_string()),
|
||||
)?;
|
||||
cliclack::outro(
|
||||
"Set to Vector Strategy - using vector-based similarity for tool selection",
|
||||
)?;
|
||||
}
|
||||
"default" => {
|
||||
config.set_param(
|
||||
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
|
||||
Value::String("default".to_string()),
|
||||
)?;
|
||||
cliclack::outro("Set to Default Strategy - using default tool selection")?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn configure_tool_output_dialog() -> Result<(), Box<dyn Error>> {
|
||||
let config = Config::global();
|
||||
// Check if GOOSE_CLI_MIN_PRIORITY is set as an environment variable
|
||||
@@ -1104,3 +1194,26 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn configure_recipe_dialog() -> Result<(), Box<dyn Error>> {
|
||||
let key_name = GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY;
|
||||
let config = Config::global();
|
||||
let default_recipe_repo = std::env::var(key_name)
|
||||
.ok()
|
||||
.or_else(|| config.get_param(key_name).unwrap_or(None));
|
||||
let mut recipe_repo_input = cliclack::input(
|
||||
"Enter your Goose Recipe Github repo (owner/repo): eg: my_org/goose-recipes",
|
||||
)
|
||||
.required(false);
|
||||
if let Some(recipe_repo) = default_recipe_repo {
|
||||
recipe_repo_input = recipe_repo_input.default_input(&recipe_repo);
|
||||
}
|
||||
let input_value: String = recipe_repo_input.interact()?;
|
||||
// if input is blank, it clears the recipe github repo settings in the config file
|
||||
if input_value.clone().trim().is_empty() {
|
||||
config.delete(key_name)?;
|
||||
} else {
|
||||
config.set_param(key_name, Value::String(input_value))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -7,6 +7,16 @@ use mcp_server::router::RouterService;
|
||||
use mcp_server::{BoundedService, ByteTransport, Server};
|
||||
use tokio::io::{stdin, stdout};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
#[cfg(unix)]
|
||||
use nix::sys::signal::{kill, Signal};
|
||||
#[cfg(unix)]
|
||||
use nix::unistd::getpgrp;
|
||||
#[cfg(unix)]
|
||||
use nix::unistd::Pid;
|
||||
|
||||
pub async fn run_server(name: &str) -> Result<()> {
|
||||
// Initialize logging
|
||||
crate::logging::setup_logging(Some(&format!("mcp-{name}")), None)?;
|
||||
@@ -26,10 +36,38 @@ pub async fn run_server(name: &str) -> Result<()> {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Create shutdown notification channel
|
||||
let shutdown = Arc::new(Notify::new());
|
||||
let shutdown_clone = shutdown.clone();
|
||||
|
||||
// Spawn shutdown signal handler
|
||||
tokio::spawn(async move {
|
||||
crate::signal::shutdown_signal().await;
|
||||
shutdown_clone.notify_one();
|
||||
});
|
||||
|
||||
// Create and run the server
|
||||
let server = Server::new(router.unwrap_or_else(|| panic!("Unknown server requested {}", name)));
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
|
||||
tracing::info!("Server initialized and ready to handle requests");
|
||||
Ok(server.run(transport).await?)
|
||||
|
||||
tokio::select! {
|
||||
result = server.run(transport) => {
|
||||
Ok(result?)
|
||||
}
|
||||
_ = shutdown.notified() => {
|
||||
// On Unix systems, kill the entire process group
|
||||
#[cfg(unix)]
|
||||
{
|
||||
fn terminate_process_group() {
|
||||
let pgid = getpgrp();
|
||||
kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGTERM)
|
||||
.expect("Failed to send SIGTERM to process group");
|
||||
}
|
||||
terminate_process_group();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ pub mod bench;
|
||||
pub mod configure;
|
||||
pub mod info;
|
||||
pub mod mcp;
|
||||
pub mod project;
|
||||
pub mod recipe;
|
||||
pub mod schedule;
|
||||
pub mod session;
|
||||
pub mod update;
|
||||
|
||||
307
crates/goose-cli/src/commands/project.rs
Normal file
307
crates/goose-cli/src/commands/project.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use anyhow::Result;
|
||||
use chrono::DateTime;
|
||||
use cliclack::{self, intro, outro};
|
||||
use std::path::Path;
|
||||
|
||||
use crate::project_tracker::ProjectTracker;
|
||||
|
||||
/// Format a DateTime for display
|
||||
fn format_date(date: DateTime<chrono::Utc>) -> String {
|
||||
// Format: "2025-05-08 18:15:30"
|
||||
date.format("%Y-%m-%d %H:%M:%S").to_string()
|
||||
}
|
||||
|
||||
/// Handle the default project command
|
||||
///
|
||||
/// Offers options to resume the most recently accessed project
|
||||
pub fn handle_project_default() -> Result<()> {
|
||||
let tracker = ProjectTracker::load()?;
|
||||
let mut projects = tracker.list_projects();
|
||||
|
||||
if projects.is_empty() {
|
||||
// If no projects exist, just start a new one in the current directory
|
||||
println!("No previous projects found. Starting a new session in the current directory.");
|
||||
let mut command = std::process::Command::new("goose");
|
||||
command.arg("session");
|
||||
let status = command.status()?;
|
||||
|
||||
if !status.success() {
|
||||
println!("Failed to run Goose. Exit code: {:?}", status.code());
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Sort projects by last_accessed (newest first)
|
||||
projects.sort_by(|a, b| b.last_accessed.cmp(&a.last_accessed));
|
||||
|
||||
// Get the most recent project
|
||||
let project = &projects[0];
|
||||
let project_dir = &project.path;
|
||||
|
||||
// Check if the directory exists
|
||||
if !Path::new(project_dir).exists() {
|
||||
println!(
|
||||
"Most recent project directory '{}' no longer exists.",
|
||||
project_dir
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Format the path for display
|
||||
let path = Path::new(project_dir);
|
||||
let components: Vec<_> = path.components().collect();
|
||||
let len = components.len();
|
||||
let short_path = if len <= 2 {
|
||||
project_dir.clone()
|
||||
} else {
|
||||
let mut path_str = String::new();
|
||||
path_str.push_str("...");
|
||||
for component in components.iter().skip(len - 2) {
|
||||
path_str.push('/');
|
||||
path_str.push_str(component.as_os_str().to_string_lossy().as_ref());
|
||||
}
|
||||
path_str
|
||||
};
|
||||
|
||||
// Ask the user what they want to do
|
||||
let _ = intro("Goose Project Manager");
|
||||
|
||||
let current_dir = std::env::current_dir()?;
|
||||
let current_dir_display = current_dir.display();
|
||||
|
||||
let choice = cliclack::select("Choose an option:")
|
||||
.item(
|
||||
"resume",
|
||||
format!("Resume project with session: {}", short_path),
|
||||
"Continue with the previous session",
|
||||
)
|
||||
.item(
|
||||
"fresh",
|
||||
format!("Resume project with fresh session: {}", short_path),
|
||||
"Change to the project directory but start a new session",
|
||||
)
|
||||
.item(
|
||||
"new",
|
||||
format!(
|
||||
"Start new project in current directory: {}",
|
||||
current_dir_display
|
||||
),
|
||||
"Stay in the current directory and start a new session",
|
||||
)
|
||||
.interact()?;
|
||||
|
||||
match choice {
|
||||
"resume" => {
|
||||
let _ = outro(format!("Changing to directory: {}", project_dir));
|
||||
|
||||
// Get the session ID if available
|
||||
let session_id = project.last_session_id.clone();
|
||||
|
||||
// Change to the project directory
|
||||
std::env::set_current_dir(project_dir)?;
|
||||
|
||||
// Build the command to run Goose
|
||||
let mut command = std::process::Command::new("goose");
|
||||
command.arg("session");
|
||||
|
||||
if let Some(id) = session_id {
|
||||
command.arg("--name").arg(&id).arg("--resume");
|
||||
println!("Resuming session: {}", id);
|
||||
}
|
||||
|
||||
// Execute the command
|
||||
let status = command.status()?;
|
||||
|
||||
if !status.success() {
|
||||
println!("Failed to run Goose. Exit code: {:?}", status.code());
|
||||
}
|
||||
}
|
||||
"fresh" => {
|
||||
let _ = outro(format!(
|
||||
"Changing to directory: {} with a fresh session",
|
||||
project_dir
|
||||
));
|
||||
|
||||
// Change to the project directory
|
||||
std::env::set_current_dir(project_dir)?;
|
||||
|
||||
// Build the command to run Goose with a fresh session
|
||||
let mut command = std::process::Command::new("goose");
|
||||
command.arg("session");
|
||||
|
||||
// Execute the command
|
||||
let status = command.status()?;
|
||||
|
||||
if !status.success() {
|
||||
println!("Failed to run Goose. Exit code: {:?}", status.code());
|
||||
}
|
||||
}
|
||||
"new" => {
|
||||
let _ = outro("Starting a new session in the current directory");
|
||||
|
||||
// Build the command to run Goose
|
||||
let mut command = std::process::Command::new("goose");
|
||||
command.arg("session");
|
||||
|
||||
// Execute the command
|
||||
let status = command.status()?;
|
||||
|
||||
if !status.success() {
|
||||
println!("Failed to run Goose. Exit code: {:?}", status.code());
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let _ = outro("Operation canceled");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle the interactive projects command
|
||||
///
|
||||
/// Shows a list of projects and lets the user select one to resume
|
||||
pub fn handle_projects_interactive() -> Result<()> {
|
||||
let tracker = ProjectTracker::load()?;
|
||||
let mut projects = tracker.list_projects();
|
||||
|
||||
if projects.is_empty() {
|
||||
println!("No projects found.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Sort projects by last_accessed (newest first)
|
||||
projects.sort_by(|a, b| b.last_accessed.cmp(&a.last_accessed));
|
||||
|
||||
// Format project paths for display
|
||||
let project_choices: Vec<(String, String)> = projects
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, project)| {
|
||||
let path = Path::new(&project.path);
|
||||
let components: Vec<_> = path.components().collect();
|
||||
let len = components.len();
|
||||
let short_path = if len <= 2 {
|
||||
project.path.clone()
|
||||
} else {
|
||||
let mut path_str = String::new();
|
||||
path_str.push_str("...");
|
||||
for component in components.iter().skip(len - 2) {
|
||||
path_str.push('/');
|
||||
path_str.push_str(component.as_os_str().to_string_lossy().as_ref());
|
||||
}
|
||||
path_str
|
||||
};
|
||||
|
||||
// Include last instruction if available (truncated)
|
||||
let instruction_preview =
|
||||
project
|
||||
.last_instruction
|
||||
.as_ref()
|
||||
.map_or(String::new(), |instr| {
|
||||
let truncated = if instr.len() > 40 {
|
||||
format!("{}...", &instr[0..37])
|
||||
} else {
|
||||
instr.clone()
|
||||
};
|
||||
format!(" [{}]", truncated)
|
||||
});
|
||||
|
||||
let formatted_date = format_date(project.last_accessed);
|
||||
(
|
||||
format!("{}", i + 1), // Value to return
|
||||
format!("{} ({}){}", short_path, formatted_date, instruction_preview), // Display text with instruction
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Let the user select a project
|
||||
let _ = intro("Goose Project Manager");
|
||||
let mut select = cliclack::select("Select a project:");
|
||||
|
||||
// Add each project as an option
|
||||
for (value, display) in &project_choices {
|
||||
select = select.item(value, display, "");
|
||||
}
|
||||
|
||||
// Add a cancel option
|
||||
let cancel_value = String::from("cancel");
|
||||
select = select.item(&cancel_value, "Cancel", "Don't resume any project");
|
||||
|
||||
let selected = select.interact()?;
|
||||
|
||||
if selected == "cancel" {
|
||||
let _ = outro("Project selection canceled.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Parse the selected index
|
||||
let index = selected.parse::<usize>().unwrap_or(0);
|
||||
if index == 0 || index > projects.len() {
|
||||
let _ = outro("Invalid selection.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Get the selected project
|
||||
let project = &projects[index - 1];
|
||||
let project_dir = &project.path;
|
||||
|
||||
// Check if the directory exists
|
||||
if !Path::new(project_dir).exists() {
|
||||
let _ = outro(format!(
|
||||
"Project directory '{}' no longer exists.",
|
||||
project_dir
|
||||
));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Ask if the user wants to resume the session or start a new one
|
||||
let session_id = project.last_session_id.clone();
|
||||
let has_previous_session = session_id.is_some();
|
||||
|
||||
// Change to the project directory first
|
||||
std::env::set_current_dir(project_dir)?;
|
||||
let _ = outro(format!("Changed to directory: {}", project_dir));
|
||||
|
||||
// Only ask about resuming if there's a previous session
|
||||
let resume_session = if has_previous_session {
|
||||
let session_choice = cliclack::select("What would you like to do?")
|
||||
.item(
|
||||
"resume",
|
||||
"Resume previous session",
|
||||
"Continue with the previous session",
|
||||
)
|
||||
.item(
|
||||
"new",
|
||||
"Start new session",
|
||||
"Start a fresh session in this project directory",
|
||||
)
|
||||
.interact()?;
|
||||
|
||||
session_choice == "resume"
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// Build the command to run Goose
|
||||
let mut command = std::process::Command::new("goose");
|
||||
command.arg("session");
|
||||
|
||||
if resume_session {
|
||||
if let Some(id) = session_id {
|
||||
command.arg("--name").arg(&id).arg("--resume");
|
||||
println!("Resuming session: {}", id);
|
||||
}
|
||||
} else {
|
||||
println!("Starting new session");
|
||||
}
|
||||
|
||||
// Execute the command
|
||||
let status = command.status()?;
|
||||
|
||||
if !status.success() {
|
||||
println!("Failed to run Goose. Exit code: {:?}", status.code());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use console::style;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::recipe::load_recipe;
|
||||
use crate::recipes::recipe::load_recipe;
|
||||
|
||||
/// Validates a recipe file
|
||||
///
|
||||
@@ -14,9 +13,9 @@ use crate::recipe::load_recipe;
|
||||
/// # Returns
|
||||
///
|
||||
/// Result indicating success or failure
|
||||
pub fn handle_validate<P: AsRef<Path>>(file_path: P) -> Result<()> {
|
||||
pub fn handle_validate(recipe_name: &str) -> Result<()> {
|
||||
// Load and validate the recipe file
|
||||
match load_recipe(&file_path, false) {
|
||||
match load_recipe(recipe_name) {
|
||||
Ok(_) => {
|
||||
println!("{} recipe file is valid", style("✓").green().bold());
|
||||
Ok(())
|
||||
@@ -37,9 +36,9 @@ pub fn handle_validate<P: AsRef<Path>>(file_path: P) -> Result<()> {
|
||||
/// # Returns
|
||||
///
|
||||
/// Result indicating success or failure
|
||||
pub fn handle_deeplink<P: AsRef<Path>>(file_path: P) -> Result<()> {
|
||||
pub fn handle_deeplink(recipe_name: &str) -> Result<()> {
|
||||
// Load the recipe file first to validate it
|
||||
match load_recipe(&file_path, false) {
|
||||
match load_recipe(recipe_name) {
|
||||
Ok(recipe) => {
|
||||
if let Ok(recipe_json) = serde_json::to_string(&recipe) {
|
||||
let deeplink = base64::engine::general_purpose::STANDARD.encode(recipe_json);
|
||||
|
||||
186
crates/goose-cli/src/commands/schedule.rs
Normal file
186
crates/goose-cli/src/commands/schedule.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
use anyhow::{bail, Context, Result};
|
||||
use base64::engine::{general_purpose::STANDARD as BASE64_STANDARD, Engine};
|
||||
use goose::scheduler::{
|
||||
get_default_scheduled_recipes_dir, get_default_scheduler_storage_path, ScheduledJob, Scheduler,
|
||||
SchedulerError,
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
// Base64 decoding function - might be needed if recipe_source_arg can be base64
|
||||
// For now, handle_schedule_add will assume it's a path.
|
||||
async fn _decode_base64_recipe(source: &str) -> Result<String> {
|
||||
let bytes = BASE64_STANDARD
|
||||
.decode(source.as_bytes())
|
||||
.with_context(|| "Recipe source is not a valid path and not valid Base64.")?;
|
||||
String::from_utf8(bytes).with_context(|| "Decoded Base64 recipe source is not valid UTF-8.")
|
||||
}
|
||||
|
||||
pub async fn handle_schedule_add(
|
||||
id: String,
|
||||
cron: String,
|
||||
recipe_source_arg: String, // This is expected to be a file path by the Scheduler
|
||||
) -> Result<()> {
|
||||
println!(
|
||||
"[CLI Debug] Scheduling job ID: {}, Cron: {}, Recipe Source Path: {}",
|
||||
id, cron, recipe_source_arg
|
||||
);
|
||||
|
||||
// The Scheduler's add_scheduled_job will handle copying the recipe from recipe_source_arg
|
||||
// to its internal storage and validating the path.
|
||||
let job = ScheduledJob {
|
||||
id: id.clone(),
|
||||
source: recipe_source_arg.clone(), // Pass the original user-provided path
|
||||
cron,
|
||||
last_run: None,
|
||||
currently_running: false,
|
||||
paused: false,
|
||||
};
|
||||
|
||||
let scheduler_storage_path =
|
||||
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
|
||||
let scheduler = Scheduler::new(scheduler_storage_path)
|
||||
.await
|
||||
.context("Failed to initialize scheduler")?;
|
||||
|
||||
match scheduler.add_scheduled_job(job).await {
|
||||
Ok(_) => {
|
||||
// The scheduler has copied the recipe to its internal directory.
|
||||
// We can reconstruct the likely path for display if needed, or adjust success message.
|
||||
let scheduled_recipes_dir = get_default_scheduled_recipes_dir()
|
||||
.unwrap_or_else(|_| Path::new("./.goose_scheduled_recipes").to_path_buf()); // Fallback for display
|
||||
let extension = Path::new(&recipe_source_arg)
|
||||
.extension()
|
||||
.and_then(|ext| ext.to_str())
|
||||
.unwrap_or("yaml");
|
||||
let final_recipe_path = scheduled_recipes_dir.join(format!("{}.{}", id, extension));
|
||||
|
||||
println!(
|
||||
"Scheduled job '{}' added. Recipe expected at {:?}",
|
||||
id, final_recipe_path
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
// No local file to clean up by the CLI in this revised flow.
|
||||
match e {
|
||||
SchedulerError::JobIdExists(job_id) => {
|
||||
bail!("Error: Job with ID '{}' already exists.", job_id);
|
||||
}
|
||||
SchedulerError::RecipeLoadError(msg) => {
|
||||
bail!(
|
||||
"Error with recipe source: {}. Path: {}",
|
||||
msg,
|
||||
recipe_source_arg
|
||||
);
|
||||
}
|
||||
_ => Err(anyhow::Error::new(e))
|
||||
.context(format!("Failed to add job '{}' to scheduler", id)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_schedule_list() -> Result<()> {
|
||||
let scheduler_storage_path =
|
||||
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
|
||||
let scheduler = Scheduler::new(scheduler_storage_path)
|
||||
.await
|
||||
.context("Failed to initialize scheduler")?;
|
||||
|
||||
let jobs = scheduler.list_scheduled_jobs().await;
|
||||
if jobs.is_empty() {
|
||||
println!("No scheduled jobs found.");
|
||||
} else {
|
||||
println!("Scheduled Jobs:");
|
||||
for job in jobs {
|
||||
println!(
|
||||
"- ID: {}\n Cron: {}\n Recipe Source (in store): {}\n Last Run: {}",
|
||||
job.id,
|
||||
job.cron,
|
||||
job.source, // This source is now the path within scheduled_recipes_dir
|
||||
job.last_run
|
||||
.map_or_else(|| "Never".to_string(), |dt| dt.to_rfc3339())
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_schedule_remove(id: String) -> Result<()> {
|
||||
let scheduler_storage_path =
|
||||
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
|
||||
let scheduler = Scheduler::new(scheduler_storage_path)
|
||||
.await
|
||||
.context("Failed to initialize scheduler")?;
|
||||
|
||||
match scheduler.remove_scheduled_job(&id).await {
|
||||
Ok(_) => {
|
||||
println!("Scheduled job '{}' and its associated recipe removed.", id);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => match e {
|
||||
SchedulerError::JobNotFound(job_id) => {
|
||||
bail!("Error: Job with ID '{}' not found.", job_id);
|
||||
}
|
||||
_ => Err(anyhow::Error::new(e))
|
||||
.context(format!("Failed to remove job '{}' from scheduler", id)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_schedule_sessions(id: String, limit: Option<u32>) -> Result<()> {
|
||||
let scheduler_storage_path =
|
||||
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
|
||||
let scheduler = Scheduler::new(scheduler_storage_path)
|
||||
.await
|
||||
.context("Failed to initialize scheduler")?;
|
||||
|
||||
match scheduler.sessions(&id, limit.unwrap_or(50) as usize).await {
|
||||
Ok(sessions) => {
|
||||
if sessions.is_empty() {
|
||||
println!("No sessions found for schedule ID '{}'.", id);
|
||||
} else {
|
||||
println!("Sessions for schedule ID '{}':", id);
|
||||
// sessions is now Vec<(String, SessionMetadata)>
|
||||
for (session_name, metadata) in sessions {
|
||||
println!(
|
||||
" - Session ID: {}, Working Dir: {}, Description: \"{}\", Messages: {}, Schedule ID: {:?}",
|
||||
session_name, // Display the session_name as Session ID
|
||||
metadata.working_dir.display(),
|
||||
metadata.description,
|
||||
metadata.message_count,
|
||||
metadata.schedule_id.as_deref().unwrap_or("N/A")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
bail!("Failed to get sessions for schedule '{}': {:?}", id, e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_schedule_run_now(id: String) -> Result<()> {
|
||||
let scheduler_storage_path =
|
||||
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
|
||||
let scheduler = Scheduler::new(scheduler_storage_path)
|
||||
.await
|
||||
.context("Failed to initialize scheduler")?;
|
||||
|
||||
match scheduler.run_now(&id).await {
|
||||
Ok(session_id) => {
|
||||
println!(
|
||||
"Successfully triggered schedule '{}'. New session ID: {}",
|
||||
id, session_id
|
||||
);
|
||||
}
|
||||
Err(e) => match e {
|
||||
SchedulerError::JobNotFound(job_id) => {
|
||||
bail!("Error: Job with ID '{}' not found.", job_id);
|
||||
}
|
||||
_ => bail!("Failed to run schedule '{}' now: {:?}", id, e),
|
||||
},
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,18 +1,20 @@
|
||||
use anyhow::{Context, Result};
|
||||
use cliclack::{confirm, multiselect};
|
||||
use goose::session::info::{get_session_info, SessionInfo, SortOrder};
|
||||
use regex::Regex;
|
||||
use std::fs;
|
||||
|
||||
const TRUNCATED_DESC_LENGTH: usize = 60;
|
||||
|
||||
pub fn remove_sessions(sessions: Vec<SessionInfo>) -> Result<()> {
|
||||
println!("The following sessions will be removed:");
|
||||
for session in &sessions {
|
||||
println!("- {}", session.id);
|
||||
}
|
||||
|
||||
let should_delete =
|
||||
cliclack::confirm("Are you sure you want to delete all these sessions? (yes/no):")
|
||||
.initial_value(true)
|
||||
.interact()?;
|
||||
let should_delete = confirm("Are you sure you want to delete these sessions?")
|
||||
.initial_value(false)
|
||||
.interact()?;
|
||||
|
||||
if should_delete {
|
||||
for session in sessions {
|
||||
@@ -27,8 +29,50 @@ pub fn remove_sessions(sessions: Vec<SessionInfo>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn handle_session_remove(id: String, regex_string: String) -> Result<()> {
|
||||
let sessions = match get_session_info(SortOrder::Descending) {
|
||||
fn prompt_interactive_session_selection(sessions: &[SessionInfo]) -> Result<Vec<SessionInfo>> {
|
||||
if sessions.is_empty() {
|
||||
println!("No sessions to delete.");
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let mut selector = multiselect(
|
||||
"Select sessions to delete (use spacebar, Enter to confirm, Ctrl+C to cancel):",
|
||||
);
|
||||
|
||||
let display_map: std::collections::HashMap<String, SessionInfo> = sessions
|
||||
.iter()
|
||||
.map(|s| {
|
||||
let desc = if s.metadata.description.is_empty() {
|
||||
"(no description)"
|
||||
} else {
|
||||
&s.metadata.description
|
||||
};
|
||||
let truncated_desc = if desc.len() > TRUNCATED_DESC_LENGTH {
|
||||
format!("{}...", &desc[..TRUNCATED_DESC_LENGTH - 3])
|
||||
} else {
|
||||
desc.to_string()
|
||||
};
|
||||
let display_text = format!("{} - {} ({})", s.modified, truncated_desc, s.id);
|
||||
(display_text, s.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
for display_text in display_map.keys() {
|
||||
selector = selector.item(display_text.clone(), display_text.clone(), "");
|
||||
}
|
||||
|
||||
let selected_display_texts: Vec<String> = selector.interact()?;
|
||||
|
||||
let selected_sessions: Vec<SessionInfo> = selected_display_texts
|
||||
.into_iter()
|
||||
.filter_map(|text| display_map.get(&text).cloned())
|
||||
.collect();
|
||||
|
||||
Ok(selected_sessions)
|
||||
}
|
||||
|
||||
pub fn handle_session_remove(id: Option<String>, regex_string: Option<String>) -> Result<()> {
|
||||
let all_sessions = match get_session_info(SortOrder::Descending) {
|
||||
Ok(sessions) => sessions,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to retrieve sessions: {:?}", e);
|
||||
@@ -37,29 +81,35 @@ pub fn handle_session_remove(id: String, regex_string: String) -> Result<()> {
|
||||
};
|
||||
|
||||
let matched_sessions: Vec<SessionInfo>;
|
||||
if !id.is_empty() {
|
||||
if let Some(session) = sessions.iter().find(|s| s.id == id) {
|
||||
|
||||
if let Some(id_val) = id {
|
||||
if let Some(session) = all_sessions.iter().find(|s| s.id == id_val) {
|
||||
matched_sessions = vec![session.clone()];
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Session '{}' not found.", id));
|
||||
return Err(anyhow::anyhow!("Session '{}' not found.", id_val));
|
||||
}
|
||||
} else if !regex_string.is_empty() {
|
||||
let session_regex = Regex::new(®ex_string)
|
||||
.with_context(|| format!("Invalid regex pattern '{}'", regex_string))?;
|
||||
matched_sessions = sessions
|
||||
} else if let Some(regex_val) = regex_string {
|
||||
let session_regex = Regex::new(®ex_val)
|
||||
.with_context(|| format!("Invalid regex pattern '{}'", regex_val))?;
|
||||
|
||||
matched_sessions = all_sessions
|
||||
.into_iter()
|
||||
.filter(|session| session_regex.is_match(&session.id))
|
||||
.collect();
|
||||
|
||||
if matched_sessions.is_empty() {
|
||||
println!(
|
||||
"Regex string '{}' does not match any sessions",
|
||||
regex_string
|
||||
);
|
||||
println!("Regex string '{}' does not match any sessions", regex_val);
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Neither --regex nor --id flags provided."));
|
||||
if all_sessions.is_empty() {
|
||||
return Err(anyhow::anyhow!("No sessions found."));
|
||||
}
|
||||
matched_sessions = prompt_interactive_session_selection(&all_sessions)?;
|
||||
}
|
||||
|
||||
if matched_sessions.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
remove_sessions(matched_sessions)
|
||||
|
||||
@@ -3,9 +3,10 @@ use once_cell::sync::Lazy;
|
||||
pub mod cli;
|
||||
pub mod commands;
|
||||
pub mod logging;
|
||||
pub mod recipe;
|
||||
pub mod project_tracker;
|
||||
pub mod recipes;
|
||||
pub mod session;
|
||||
|
||||
pub mod signal;
|
||||
// Re-export commonly used types
|
||||
pub use session::Session;
|
||||
|
||||
|
||||
146
crates/goose-cli/src/project_tracker.rs
Normal file
146
crates/goose-cli/src/project_tracker.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use etcetera::{choose_app_strategy, AppStrategy};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Structure to track project information
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ProjectInfo {
|
||||
/// The absolute path to the project directory
|
||||
pub path: String,
|
||||
/// Last time the project was accessed
|
||||
pub last_accessed: DateTime<Utc>,
|
||||
/// Last instruction sent to goose (if available)
|
||||
pub last_instruction: Option<String>,
|
||||
/// Last session ID associated with this project
|
||||
pub last_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Structure to hold all tracked projects
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ProjectTracker {
|
||||
projects: HashMap<String, ProjectInfo>,
|
||||
}
|
||||
|
||||
/// Project information with path as a separate field for easier access
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProjectInfoDisplay {
|
||||
/// The absolute path to the project directory
|
||||
pub path: String,
|
||||
/// Last time the project was accessed
|
||||
pub last_accessed: DateTime<Utc>,
|
||||
/// Last instruction sent to goose (if available)
|
||||
pub last_instruction: Option<String>,
|
||||
/// Last session ID associated with this project
|
||||
pub last_session_id: Option<String>,
|
||||
}
|
||||
|
||||
impl ProjectTracker {
|
||||
/// Get the path to the projects.json file
|
||||
fn get_projects_file() -> Result<PathBuf> {
|
||||
let projects_file = choose_app_strategy(crate::APP_STRATEGY.clone())
|
||||
.context("goose requires a home dir")?
|
||||
.in_data_dir("projects.json");
|
||||
|
||||
// Ensure data directory exists
|
||||
if let Some(parent) = projects_file.parent() {
|
||||
if !parent.exists() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(projects_file)
|
||||
}
|
||||
|
||||
/// Load the project tracker from the projects.json file
|
||||
pub fn load() -> Result<Self> {
|
||||
let projects_file = Self::get_projects_file()?;
|
||||
|
||||
if projects_file.exists() {
|
||||
let file_content = fs::read_to_string(&projects_file)?;
|
||||
let tracker: ProjectTracker = serde_json::from_str(&file_content)
|
||||
.context("Failed to parse projects.json file")?;
|
||||
Ok(tracker)
|
||||
} else {
|
||||
// If the file doesn't exist, create a new empty tracker
|
||||
Ok(ProjectTracker {
|
||||
projects: HashMap::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Save the project tracker to the projects.json file
|
||||
pub fn save(&self) -> Result<()> {
|
||||
let projects_file = Self::get_projects_file()?;
|
||||
let json = serde_json::to_string_pretty(self)?;
|
||||
fs::write(projects_file, json)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update project information for the current directory
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `project_dir` - The project directory to update
|
||||
/// * `instruction` - Optional instruction that was sent to goose
|
||||
/// * `session_id` - Optional session ID associated with this project
|
||||
pub fn update_project(
|
||||
&mut self,
|
||||
project_dir: &Path,
|
||||
instruction: Option<&str>,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let dir_str = project_dir.to_string_lossy().to_string();
|
||||
|
||||
// Create or update the project entry
|
||||
let project_info = self.projects.entry(dir_str.clone()).or_insert(ProjectInfo {
|
||||
path: dir_str,
|
||||
last_accessed: Utc::now(),
|
||||
last_instruction: None,
|
||||
last_session_id: None,
|
||||
});
|
||||
|
||||
// Update the last accessed time
|
||||
project_info.last_accessed = Utc::now();
|
||||
|
||||
// Update the last instruction if provided
|
||||
if let Some(instr) = instruction {
|
||||
project_info.last_instruction = Some(instr.to_string());
|
||||
}
|
||||
|
||||
// Update the session ID if provided
|
||||
if let Some(id) = session_id {
|
||||
project_info.last_session_id = Some(id.to_string());
|
||||
}
|
||||
|
||||
self.save()
|
||||
}
|
||||
|
||||
/// List all tracked projects
|
||||
///
|
||||
/// Returns a vector of ProjectInfoDisplay objects
|
||||
pub fn list_projects(&self) -> Vec<ProjectInfoDisplay> {
|
||||
self.projects
|
||||
.values()
|
||||
.map(|info| ProjectInfoDisplay {
|
||||
path: info.path.clone(),
|
||||
last_accessed: info.last_accessed,
|
||||
last_instruction: info.last_instruction.clone(),
|
||||
last_session_id: info.last_session_id.clone(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the project tracker with the current directory and optional instruction
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `instruction` - Optional instruction that was sent to goose
|
||||
/// * `session_id` - Optional session ID associated with this project
|
||||
pub fn update_project_tracker(instruction: Option<&str>, session_id: Option<&str>) -> Result<()> {
|
||||
let current_dir = std::env::current_dir()?;
|
||||
let mut tracker = ProjectTracker::load()?;
|
||||
tracker.update_project(¤t_dir, instruction, session_id)
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
use anyhow::{Context, Result};
|
||||
use console::style;
|
||||
use std::path::Path;
|
||||
|
||||
use goose::recipe::Recipe;
|
||||
|
||||
/// Loads and validates a recipe from a YAML or JSON file
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the recipe file (YAML or JSON)
|
||||
/// * `log` - whether to log information about the recipe or not
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The parsed recipe struct if successful
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - The file doesn't exist
|
||||
/// - The file can't be read
|
||||
/// - The YAML/JSON is invalid
|
||||
/// - The required fields are missing
|
||||
pub fn load_recipe<P: AsRef<Path>>(path: P, log: bool) -> Result<Recipe> {
|
||||
let path = path.as_ref();
|
||||
|
||||
// Check if file exists
|
||||
if !path.exists() {
|
||||
return Err(anyhow::anyhow!("recipe file not found: {}", path.display()));
|
||||
}
|
||||
// Read file content
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read recipe file: {}", path.display()))?;
|
||||
|
||||
// Determine file format based on extension and parse accordingly
|
||||
let recipe: Recipe = if let Some(extension) = path.extension() {
|
||||
match extension.to_str().unwrap_or("").to_lowercase().as_str() {
|
||||
"json" => serde_json::from_str(&content)
|
||||
.with_context(|| format!("Failed to parse JSON recipe file: {}", path.display()))?,
|
||||
"yaml" => serde_yaml::from_str(&content)
|
||||
.with_context(|| format!("Failed to parse YAML recipe file: {}", path.display()))?,
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unsupported file format for recipe file: {}. Expected .yaml or .json",
|
||||
path.display()
|
||||
))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"File has no extension: {}. Expected .yaml or .json",
|
||||
path.display()
|
||||
));
|
||||
};
|
||||
|
||||
if log {
|
||||
// Display information about the loaded recipe
|
||||
println!(
|
||||
"{} {}",
|
||||
style("Loading recipe:").green().bold(),
|
||||
style(&recipe.title).green()
|
||||
);
|
||||
println!("{} {}", style("Description:").dim(), &recipe.description);
|
||||
|
||||
println!(); // Add a blank line for spacing
|
||||
}
|
||||
|
||||
Ok(recipe)
|
||||
}
|
||||
192
crates/goose-cli/src/recipes/github_recipe.rs
Normal file
192
crates/goose-cli/src/recipes/github_recipe.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
use anyhow::Result;
|
||||
use console::style;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
use std::process::Stdio;
|
||||
use tar::Archive;
|
||||
|
||||
use crate::recipes::recipe::RECIPE_FILE_EXTENSIONS;
|
||||
|
||||
pub const GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY: &str = "GOOSE_RECIPE_GITHUB_REPO";
|
||||
/// Fetches a recipe folder named `recipe_name` from the GitHub repository
/// `recipe_repo_full_name` ("owner/repo"), returning the recipe file's
/// content and the local directory it was downloaded into.
///
/// Retries the clone/download step once after wiping the local clone, on the
/// assumption that a stale or corrupted clone caused the first failure. A
/// failure while *reading* the downloaded recipe file is not retried.
pub fn retrieve_recipe_from_github(
    recipe_name: &str,
    recipe_repo_full_name: &str,
) -> Result<(String, PathBuf)> {
    println!(
        "📦 Looking for recipe \"{}\" in github repo: {}",
        recipe_name, recipe_repo_full_name
    );
    ensure_gh_authenticated()?;
    let max_attempts = 2;
    let mut last_err = None;

    for attempt in 1..=max_attempts {
        match clone_and_download_recipe(recipe_name, recipe_repo_full_name) {
            Ok(download_dir) => match read_recipe_file(&download_dir) {
                // Success: hand back the content plus the download dir so the
                // caller can resolve relative paths inside the recipe.
                Ok(content) => return Ok((content, download_dir)),
                Err(err) => return Err(err),
            },
            Err(err) => {
                last_err = Some(err);
            }
        }
        // Before the next attempt, remove the local clone so the retry
        // starts from scratch.
        if attempt < max_attempts {
            clean_cloned_dirs(recipe_repo_full_name)?;
        }
    }
    Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Unknown error occurred")))
}
|
||||
|
||||
fn clean_cloned_dirs(recipe_repo_full_name: &str) -> anyhow::Result<()> {
|
||||
let local_repo_path = get_local_repo_path(&env::temp_dir(), recipe_repo_full_name)?;
|
||||
if local_repo_path.exists() {
|
||||
fs::remove_dir_all(&local_repo_path)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn read_recipe_file(download_dir: &Path) -> Result<String> {
|
||||
for ext in RECIPE_FILE_EXTENSIONS {
|
||||
let candidate_file_path = download_dir.join(format!("recipe.{}", ext));
|
||||
if candidate_file_path.exists() {
|
||||
let content = fs::read_to_string(&candidate_file_path)?;
|
||||
println!(
|
||||
"⬇️ Retrieved recipe file: {}",
|
||||
candidate_file_path
|
||||
.strip_prefix(download_dir)
|
||||
.unwrap()
|
||||
.display()
|
||||
);
|
||||
return Ok(content);
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!(
|
||||
"No recipe file found in {} (looked for extensions: {:?})",
|
||||
download_dir.display(),
|
||||
RECIPE_FILE_EXTENSIONS
|
||||
))
|
||||
}
|
||||
|
||||
/// Ensures a local clone of the repo exists, refreshes it from `origin`,
/// then extracts the `recipe_name` folder from `origin/main` into a temp
/// directory and returns that directory.
fn clone_and_download_recipe(recipe_name: &str, recipe_repo_full_name: &str) -> Result<PathBuf> {
    let local_repo_path = ensure_repo_cloned(recipe_repo_full_name)?;
    fetch_origin(&local_repo_path)?;
    get_folder_from_github(&local_repo_path, recipe_name)
}
|
||||
|
||||
fn ensure_gh_authenticated() -> Result<()> {
|
||||
// Check authentication status
|
||||
let status = Command::new("gh")
|
||||
.args(["auth", "status"])
|
||||
.status()
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!("Failed to run `gh auth status`. Make sure you have `gh` installed.")
|
||||
})?;
|
||||
|
||||
if status.success() {
|
||||
return Ok(());
|
||||
}
|
||||
println!("GitHub CLI is not authenticated. Launching `gh auth login`...");
|
||||
// Run `gh auth login` interactively
|
||||
let login_status = Command::new("gh")
|
||||
.args(["auth", "login", "--git-protocol", "https"])
|
||||
.status()
|
||||
.map_err(|_| anyhow::anyhow!("Failed to run `gh auth login`"))?;
|
||||
|
||||
if !login_status.success() {
|
||||
Err(anyhow::anyhow!("Failed to authenticate using GitHub CLI."))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_local_repo_path(
|
||||
local_repo_parent_path: &Path,
|
||||
recipe_repo_full_name: &str,
|
||||
) -> Result<PathBuf> {
|
||||
let (_, repo_name) = recipe_repo_full_name
|
||||
.split_once('/')
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid repository name format"))?;
|
||||
let local_repo_path = local_repo_parent_path.to_path_buf().join(repo_name);
|
||||
Ok(local_repo_path)
|
||||
}
|
||||
|
||||
/// Returns the path to a local clone of `recipe_repo_full_name` under the
/// system temp dir, cloning it with `gh repo clone` if it isn't there yet.
fn ensure_repo_cloned(recipe_repo_full_name: &str) -> Result<PathBuf> {
    let local_repo_parent_path = env::temp_dir();
    if !local_repo_parent_path.exists() {
        std::fs::create_dir_all(local_repo_parent_path.clone())?;
    }
    let local_repo_path = get_local_repo_path(&local_repo_parent_path, recipe_repo_full_name)?;

    // A `.git` directory is taken as evidence of a complete prior clone.
    if local_repo_path.join(".git").exists() {
        Ok(local_repo_path)
    } else {
        let error_message: String = format!("Failed to clone repo: {}", recipe_repo_full_name);
        // Clone into the parent dir; `gh` names the checkout after the repo,
        // matching what get_local_repo_path computed above.
        let status = Command::new("gh")
            .args(["repo", "clone", recipe_repo_full_name])
            .current_dir(local_repo_parent_path.clone())
            .status()
            .map_err(|_: std::io::Error| anyhow::anyhow!(error_message.clone()))?;

        if status.success() {
            Ok(local_repo_path)
        } else {
            Err(anyhow::anyhow!(error_message))
        }
    }
}
|
||||
|
||||
fn fetch_origin(local_repo_path: &Path) -> Result<()> {
|
||||
let error_message: String = format!("Failed to fetch at {}", local_repo_path.to_str().unwrap());
|
||||
let status = Command::new("git")
|
||||
.args(["fetch", "origin"])
|
||||
.current_dir(local_repo_path)
|
||||
.status()
|
||||
.map_err(|_| anyhow::anyhow!(error_message.clone()))?;
|
||||
|
||||
if status.success() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(error_message))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_folder_from_github(local_repo_path: &Path, recipe_name: &str) -> Result<PathBuf> {
|
||||
let ref_and_path = format!("origin/main:{}", recipe_name);
|
||||
let output_dir = env::temp_dir().join(recipe_name);
|
||||
|
||||
if output_dir.exists() {
|
||||
fs::remove_dir_all(&output_dir)?;
|
||||
}
|
||||
fs::create_dir_all(&output_dir)?;
|
||||
|
||||
let archive_output = Command::new("git")
|
||||
.args(["archive", &ref_and_path])
|
||||
.current_dir(local_repo_path)
|
||||
.stdout(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let stdout = archive_output
|
||||
.stdout
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to capture stdout from git archive"))?;
|
||||
|
||||
let mut archive = Archive::new(stdout);
|
||||
archive.unpack(&output_dir)?;
|
||||
list_files(&output_dir)?;
|
||||
|
||||
Ok(output_dir)
|
||||
}
|
||||
|
||||
fn list_files(dir: &Path) -> Result<()> {
|
||||
println!("{}", style("Files downloaded from github:").bold());
|
||||
for entry in fs::read_dir(dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_file() {
|
||||
println!(" - {}", path.display());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
4
crates/goose-cli/src/recipes/mod.rs
Normal file
4
crates/goose-cli/src/recipes/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
/// Retrieval of recipes from a configured GitHub repository via the `gh` CLI.
pub mod github_recipe;
/// Console output helpers for explaining recipes and their parameters.
pub mod print_recipe;
/// Loading, validation, and template rendering of recipe files.
pub mod recipe;
/// Lookup of recipe files across local paths and the configured GitHub repo.
pub mod search_recipe;
|
||||
83
crates/goose-cli/src/recipes/print_recipe.rs
Normal file
83
crates/goose-cli/src/recipes/print_recipe.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use console::style;
|
||||
use goose::recipe::Recipe;
|
||||
|
||||
use crate::recipes::recipe::BUILT_IN_RECIPE_DIR_PARAM;
|
||||
|
||||
/// Prints a recipe's title, description, and declared parameters (with type,
/// requirement, and default where present) to the console.
pub fn print_recipe_explanation(recipe: &Recipe) {
    println!(
        "{} {}",
        style("🔍 Loading recipe:").bold().green(),
        style(&recipe.title).green()
    );
    println!("{}", style("📄 Description:").bold());
    println!(" {}", recipe.description);
    if let Some(params) = &recipe.parameters {
        if !params.is_empty() {
            println!("{}", style("⚙️ Recipe Parameters:").bold());
            for param in params {
                // Only show a "(default: …)" suffix when one is declared.
                let default_display = match &param.default {
                    Some(val) => format!(" (default: {})", val),
                    None => String::new(),
                };

                println!(
                    " - {} ({}, {}){}: {}",
                    style(&param.key).cyan(),
                    param.input_type,
                    param.requirement,
                    default_display,
                    param.description
                );
            }
        }
    }
}
|
||||
|
||||
pub fn print_parameters_with_values(params: HashMap<String, String>) {
|
||||
for (key, value) in params {
|
||||
let label = if key == BUILT_IN_RECIPE_DIR_PARAM {
|
||||
" (built-in)"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
println!(" {}{}: {}", key, label, value);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_required_parameters_for_template(
|
||||
params_for_template: HashMap<String, String>,
|
||||
missing_params: Vec<String>,
|
||||
) {
|
||||
if !params_for_template.is_empty() {
|
||||
println!(
|
||||
"{}",
|
||||
style("📥 Parameters used to load this recipe:").bold()
|
||||
);
|
||||
print_parameters_with_values(params_for_template)
|
||||
}
|
||||
if !missing_params.is_empty() {
|
||||
println!(
|
||||
"{}",
|
||||
style("🔴 Missing parameters in the command line if you want to run the recipe:")
|
||||
.bold()
|
||||
);
|
||||
for param in missing_params.iter() {
|
||||
println!(" - {}", param);
|
||||
}
|
||||
println!(
|
||||
"📩 {}:",
|
||||
style("Please provide the following parameters in the command line if you want to run the recipe:").bold()
|
||||
);
|
||||
println!(" {}", missing_parameters_command_line(missing_params));
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds the `--params key=your_value` fragments a user must append to the
/// command line, one per missing parameter, separated by single spaces.
pub fn missing_parameters_command_line(missing_params: Vec<String>) -> String {
    let mut parts: Vec<String> = Vec::with_capacity(missing_params.len());
    for key in &missing_params {
        parts.push(format!("--params {}=your_value", key));
    }
    parts.join(" ")
}
|
||||
532
crates/goose-cli/src/recipes/recipe.rs
Normal file
532
crates/goose-cli/src/recipes/recipe.rs
Normal file
@@ -0,0 +1,532 @@
|
||||
use anyhow::Result;
|
||||
use console::style;
|
||||
|
||||
use crate::recipes::print_recipe::{
|
||||
missing_parameters_command_line, print_parameters_with_values, print_recipe_explanation,
|
||||
print_required_parameters_for_template,
|
||||
};
|
||||
use crate::recipes::search_recipe::retrieve_recipe_file;
|
||||
use goose::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement};
|
||||
use minijinja::{Environment, Error, Template, UndefinedBehavior};
|
||||
use serde_json::Value as JsonValue;
|
||||
use serde_yaml::Value as YamlValue;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Name of the built-in template parameter holding the recipe file's parent
/// directory; always injected, never declared by the recipe author.
pub const BUILT_IN_RECIPE_DIR_PARAM: &str = "recipe_dir";
/// File extensions recognized as recipe files, in search order.
pub const RECIPE_FILE_EXTENSIONS: &[&str] = &["yaml", "json"];
|
||||
/// Loads, validates a recipe from a YAML or JSON file, and renders it with the given parameters
///
/// # Arguments
///
/// * `path` - Path to the recipe file (YAML or JSON)
/// * `params` - parameters to render the recipe with
///
/// # Returns
///
/// The rendered recipe if successful
///
/// # Errors
///
/// Returns an error if:
/// - Recipe is not valid
/// - The required fields are missing
pub fn load_recipe_as_template(recipe_name: &str, params: Vec<(String, String)>) -> Result<Recipe> {
    let (recipe_file_content, recipe_parent_dir) = retrieve_recipe_file(recipe_name)?;

    // Validate the raw (unrendered) recipe first so parameter-declaration
    // mismatches are reported before any templating happens.
    let recipe = validate_recipe_file_parameters(&recipe_file_content)?;

    // Resolve user-supplied values, defaults, and interactive prompts
    // (prompting enabled) into one parameter map.
    let (params_for_template, missing_params) =
        apply_values_to_parameters(&params, recipe.parameters, recipe_parent_dir, true)?;
    if !missing_params.is_empty() {
        return Err(anyhow::anyhow!(
            "Please provide the following parameters in the command line: {}",
            missing_parameters_command_line(missing_params)
        ));
    }

    let rendered_content = render_content_with_params(&recipe_file_content, &params_for_template)?;

    // Re-parse the rendered text; this is the recipe that will actually run.
    let recipe = parse_recipe_content(&rendered_content)?;

    // Display information about the loaded recipe
    println!(
        "{} {}",
        style("Loading recipe:").green().bold(),
        style(&recipe.title).green()
    );
    println!("{} {}", style("Description:").bold(), &recipe.description);

    if !params_for_template.is_empty() {
        println!("{}", style("Parameters used to load this recipe:").bold());
        print_parameters_with_values(params_for_template);
    }
    println!();
    Ok(recipe)
}
|
||||
|
||||
/// Loads and validates a recipe from a YAML or JSON file
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the recipe file (YAML or JSON)
|
||||
/// * `params` - optional parameters to render the recipe with
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The parsed recipe struct if successful
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - The file doesn't exist
|
||||
/// - The file can't be read
|
||||
/// - The YAML/JSON is invalid
|
||||
/// - The parameter definition does not match the template variables in the recipe file
|
||||
pub fn load_recipe(recipe_name: &str) -> Result<Recipe> {
|
||||
let (recipe_file_content, _) = retrieve_recipe_file(recipe_name)?;
|
||||
|
||||
validate_recipe_file_parameters(&recipe_file_content)
|
||||
}
|
||||
|
||||
pub fn explain_recipe_with_parameters(
|
||||
recipe_name: &str,
|
||||
params: Vec<(String, String)>,
|
||||
) -> Result<()> {
|
||||
let (recipe_file_content, recipe_parent_dir) = retrieve_recipe_file(recipe_name)?;
|
||||
|
||||
let raw_recipe = validate_recipe_file_parameters(&recipe_file_content)?;
|
||||
print_recipe_explanation(&raw_recipe);
|
||||
let recipe_parameters = raw_recipe.parameters;
|
||||
let (params_for_template, missing_params) =
|
||||
apply_values_to_parameters(¶ms, recipe_parameters, recipe_parent_dir, false)?;
|
||||
print_required_parameters_for_template(params_for_template, missing_params);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parses recipe file content and validates its parameter declarations:
/// optional parameters must carry defaults, and declared parameters must
/// match the template variables referenced in the file.
fn validate_recipe_file_parameters(recipe_file_content: &str) -> Result<Recipe> {
    let recipe_from_recipe_file: Recipe = parse_recipe_content(recipe_file_content)?;
    validate_optional_parameters(&recipe_from_recipe_file)?;
    validate_parameters_in_template(&recipe_from_recipe_file.parameters, recipe_file_content)?;
    Ok(recipe_from_recipe_file)
}
|
||||
|
||||
fn validate_parameters_in_template(
|
||||
recipe_parameters: &Option<Vec<RecipeParameter>>,
|
||||
recipe_file_content: &str,
|
||||
) -> Result<()> {
|
||||
let mut template_variables = extract_template_variables(recipe_file_content)?;
|
||||
template_variables.remove(BUILT_IN_RECIPE_DIR_PARAM);
|
||||
|
||||
let param_keys: HashSet<String> = recipe_parameters
|
||||
.as_ref()
|
||||
.unwrap_or(&vec![])
|
||||
.iter()
|
||||
.map(|p| p.key.clone())
|
||||
.collect();
|
||||
|
||||
let missing_keys = template_variables
|
||||
.difference(¶m_keys)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let extra_keys = param_keys
|
||||
.difference(&template_variables)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if missing_keys.is_empty() && extra_keys.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut message = String::new();
|
||||
|
||||
if !missing_keys.is_empty() {
|
||||
message.push_str(&format!(
|
||||
"Missing definitions for parameters in the recipe file: {}.",
|
||||
missing_keys
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
if !extra_keys.is_empty() {
|
||||
message.push_str(&format!(
|
||||
"\nUnnecessary parameter definitions: {}.",
|
||||
extra_keys
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
));
|
||||
}
|
||||
Err(anyhow::anyhow!("{}", message.trim_end()))
|
||||
}
|
||||
|
||||
/// Ensures every parameter marked `optional` declares a default value, so
/// rendering never encounters an unresolvable optional parameter.
fn validate_optional_parameters(recipe: &Recipe) -> Result<()> {
    let optional_params_without_default_values: Vec<String> = recipe
        .parameters
        .as_ref()
        .unwrap_or(&vec![])
        .iter()
        .filter(|p| {
            matches!(p.requirement, RecipeParameterRequirement::Optional) && p.default.is_none()
        })
        .map(|p| p.key.clone())
        .collect();

    if optional_params_without_default_values.is_empty() {
        Ok(())
    } else {
        Err(anyhow::anyhow!("Optional parameters missing default values in the recipe: {}. Please provide defaults.", optional_params_without_default_values.join(", ")))
    }
}
|
||||
|
||||
fn parse_recipe_content(content: &str) -> Result<Recipe> {
|
||||
if serde_json::from_str::<JsonValue>(content).is_ok() {
|
||||
Ok(serde_json::from_str(content)?)
|
||||
} else if serde_yaml::from_str::<YamlValue>(content).is_ok() {
|
||||
Ok(serde_yaml::from_str(content)?)
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"Unsupported file format for recipe file. Expected .yaml or .json"
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_template_variables(template_str: &str) -> Result<HashSet<String>> {
|
||||
let mut env = Environment::new();
|
||||
env.set_undefined_behavior(UndefinedBehavior::Strict);
|
||||
|
||||
let template = env
|
||||
.template_from_str(template_str)
|
||||
.map_err(|e: Error| anyhow::anyhow!("Invalid template syntax: {}", e.to_string()))?;
|
||||
|
||||
Ok(template.undeclared_variables(true))
|
||||
}
|
||||
|
||||
/// Builds the final parameter map for template rendering.
///
/// Precedence per declared parameter: user-supplied value, then declared
/// default, then (for `UserPrompt` parameters, when `enable_user_prompt` is
/// true) an interactive prompt. Anything still unresolved is returned in the
/// missing-parameters list. The built-in `recipe_dir` parameter is always
/// injected with the recipe file's parent directory.
fn apply_values_to_parameters(
    user_params: &[(String, String)],
    recipe_parameters: Option<Vec<RecipeParameter>>,
    recipe_parent_dir: PathBuf,
    enable_user_prompt: bool,
) -> Result<(HashMap<String, String>, Vec<String>)> {
    let mut param_map: HashMap<String, String> = user_params.iter().cloned().collect();
    let recipe_parent_dir_str = recipe_parent_dir
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 in recipe_dir"))?;
    param_map.insert(
        BUILT_IN_RECIPE_DIR_PARAM.to_string(),
        recipe_parent_dir_str.to_string(),
    );
    let mut missing_params: Vec<String> = Vec::new();
    for param in recipe_parameters.unwrap_or_default() {
        if !param_map.contains_key(&param.key) {
            // Every arm evaluates to Option<String> (the previous value
            // returned by HashMap::insert); the result itself is discarded —
            // the arms are run for their side effects only.
            match (&param.default, &param.requirement) {
                (Some(default), _) => param_map.insert(param.key.clone(), default.clone()),
                (None, RecipeParameterRequirement::UserPrompt) if enable_user_prompt => {
                    let input_value = cliclack::input(format!(
                        "Please enter {} ({})",
                        param.key, param.description
                    ))
                    .interact()?;
                    param_map.insert(param.key.clone(), input_value)
                }
                _ => {
                    missing_params.push(param.key.clone());
                    None
                }
            };
        }
    }
    Ok((param_map, missing_params))
}
|
||||
|
||||
fn render_content_with_params(content: &str, params: &HashMap<String, String>) -> Result<String> {
|
||||
// Create a minijinja environment and context
|
||||
let mut env = minijinja::Environment::new();
|
||||
env.set_undefined_behavior(UndefinedBehavior::Strict);
|
||||
let template: Template<'_, '_> = env
|
||||
.template_from_str(content)
|
||||
.map_err(|e: Error| anyhow::anyhow!("Invalid template syntax: {}", e.to_string()))?;
|
||||
|
||||
// Render the template with the parameters
|
||||
template.render(params).map_err(|e: Error| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to render the recipe {} - please check if all required parameters are provided",
|
||||
e.to_string()
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for template rendering and recipe loading/validation.

    use std::path::PathBuf;

    use goose::recipe::{RecipeParameterInputType, RecipeParameterRequirement};
    use tempfile::TempDir;

    use super::*;

    // Writes a minimal JSON recipe (fixed title/description) containing the
    // given instructions/parameters fragment to a temp file. The TempDir is
    // returned so callers keep it alive for the duration of the test.
    fn setup_recipe_file(instructions_and_parameters: &str) -> (TempDir, PathBuf) {
        let recipe_content = format!(
            r#"{{
            "version": "1.0.0",
            "title": "Test Recipe",
            "description": "A test recipe",
            {}
        }}"#,
            instructions_and_parameters
        );
        // Create a temporary file
        let temp_dir = tempfile::tempdir().unwrap();
        let recipe_path: std::path::PathBuf = temp_dir.path().join("test_recipe.json");
        std::fs::write(&recipe_path, recipe_content).unwrap();
        (temp_dir, recipe_path)
    }

    // Covers substitution, empty values, multiple parameters, a missing
    // parameter (strict mode error), and invalid template syntax.
    #[test]
    fn test_render_content_with_params() {
        // Test basic parameter substitution
        let content = "Hello {{ name }}!";
        let mut params = HashMap::new();
        params.insert("name".to_string(), "World".to_string());
        let result = render_content_with_params(content, &params).unwrap();
        assert_eq!(result, "Hello World!");

        // Test empty parameter substitution
        let content = "Hello {{ empty }}!";
        let mut params = HashMap::new();
        params.insert("empty".to_string(), "".to_string());
        let result = render_content_with_params(content, &params).unwrap();
        assert_eq!(result, "Hello !");

        // Test multiple parameters
        let content = "{{ greeting }} {{ name }}!";
        let mut params = HashMap::new();
        params.insert("greeting".to_string(), "Hi".to_string());
        params.insert("name".to_string(), "Alice".to_string());
        let result = render_content_with_params(content, &params).unwrap();
        assert_eq!(result, "Hi Alice!");

        // Test missing parameter results in error
        let content = "Hello {{ missing }}!";
        let params = HashMap::new();
        let err = render_content_with_params(content, &params).unwrap_err();
        assert!(err
            .to_string()
            .contains("please check if all required parameters"));

        // Test invalid template syntax results in error
        let content = "Hello {{ unclosed";
        let params = HashMap::new();
        let err = render_content_with_params(content, &params).unwrap_err();
        assert!(err.to_string().contains("Invalid template syntax"));
    }

    // Happy path: a declared required parameter is supplied and rendered
    // into the instructions; the parameter metadata survives the round trip.
    #[test]
    fn test_load_recipe_as_template_success() {
        let instructions_and_parameters = r#"
            "instructions": "Test instructions with {{ my_name }}",
            "parameters": [
                {
                    "key": "my_name",
                    "input_type": "string",
                    "requirement": "required",
                    "description": "A test parameter"
                }
            ]"#;

        let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);

        let params = vec![("my_name".to_string(), "value".to_string())];
        let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap();

        assert_eq!(recipe.title, "Test Recipe");
        assert_eq!(recipe.description, "A test recipe");
        assert_eq!(recipe.instructions.unwrap(), "Test instructions with value");
        // Verify parameters match recipe definition
        assert_eq!(recipe.parameters.as_ref().unwrap().len(), 1);
        let param = &recipe.parameters.as_ref().unwrap()[0];
        assert_eq!(param.key, "my_name");
        assert!(matches!(param.input_type, RecipeParameterInputType::String));
        assert!(matches!(
            param.requirement,
            RecipeParameterRequirement::Required
        ));
        assert_eq!(param.description, "A test parameter");
    }

    // Template variables in the `prompt` field are rendered too, not just
    // those in `instructions`.
    #[test]
    fn test_load_recipe_as_template_success_variable_in_prompt() {
        let instructions_and_parameters = r#"
            "instructions": "Test instructions",
            "prompt": "My prompt {{ my_name }}",
            "parameters": [
                {
                    "key": "my_name",
                    "input_type": "string",
                    "requirement": "required",
                    "description": "A test parameter"
                }
            ]"#;

        let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);

        let params = vec![("my_name".to_string(), "value".to_string())];
        let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap();

        assert_eq!(recipe.title, "Test Recipe");
        assert_eq!(recipe.description, "A test recipe");
        assert_eq!(recipe.instructions.unwrap(), "Test instructions");
        assert_eq!(recipe.prompt.unwrap(), "My prompt value");
        let param = &recipe.parameters.as_ref().unwrap()[0];
        assert_eq!(param.key, "my_name");
        assert!(matches!(param.input_type, RecipeParameterInputType::String));
        assert!(matches!(
            param.requirement,
            RecipeParameterRequirement::Required
        ));
        assert_eq!(param.description, "A test parameter");
    }

    // Declared parameters that don't appear in the template, and template
    // variables without declarations, are both reported in one error.
    #[test]
    fn test_load_recipe_as_template_wrong_parameters_in_recipe_file() {
        let instructions_and_parameters = r#"
            "instructions": "Test instructions with {{ expected_param1 }} {{ expected_param2 }}",
            "parameters": [
                {
                    "key": "wrong_param_key",
                    "input_type": "string",
                    "requirement": "required",
                    "description": "A test parameter"
                }
            ]"#;
        let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);

        let load_recipe_result = load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new());
        assert!(load_recipe_result.is_err());
        let err = load_recipe_result.unwrap_err();
        println!("{}", err.to_string());
        assert!(err
            .to_string()
            .contains("Unnecessary parameter definitions: wrong_param_key."));
        assert!(err
            .to_string()
            .contains("Missing definitions for parameters in the recipe file:"));
        assert!(err.to_string().contains("expected_param1"));
        assert!(err.to_string().contains("expected_param2"));
    }

    // An optional parameter's declared default fills in when the user
    // supplies no value for it.
    #[test]
    fn test_load_recipe_as_template_with_default_values_in_recipe_file() {
        let instructions_and_parameters = r#"
            "instructions": "Test instructions with {{ param_with_default }} {{ param_without_default }}",
            "parameters": [
                {
                    "key": "param_with_default",
                    "input_type": "string",
                    "requirement": "optional",
                    "default": "my_default_value",
                    "description": "A test parameter"
                },
                {
                    "key": "param_without_default",
                    "input_type": "string",
                    "requirement": "required",
                    "description": "A test parameter"
                }
            ]"#;
        let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
        let params = vec![("param_without_default".to_string(), "value1".to_string())];

        let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap();

        assert_eq!(recipe.title, "Test Recipe");
        assert_eq!(recipe.description, "A test recipe");
        assert_eq!(
            recipe.instructions.unwrap(),
            "Test instructions with my_default_value value1"
        );
    }

    // An empty-string default is still a default and satisfies validation.
    // NOTE(review): the trailing comma after "default" is not strict JSON;
    // this presumably exercises the YAML fallback in parse_recipe_content —
    // confirm that is intentional.
    #[test]
    fn test_load_recipe_as_template_optional_parameters_with_empty_default_values_in_recipe_file() {
        let instructions_and_parameters = r#"
            "instructions": "Test instructions with {{ optional_param }}",
            "parameters": [
                {
                    "key": "optional_param",
                    "input_type": "string",
                    "requirement": "optional",
                    "description": "A test parameter",
                    "default": "",
                }
            ]"#;
        let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);

        let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new()).unwrap();
        assert_eq!(recipe.title, "Test Recipe");
        assert_eq!(recipe.description, "A test recipe");
        assert_eq!(recipe.instructions.unwrap(), "Test instructions with ");
    }

    // An optional parameter with no default at all is a validation error.
    #[test]
    fn test_load_recipe_as_template_optional_parameters_without_default_values_in_recipe_file() {
        let instructions_and_parameters = r#"
            "instructions": "Test instructions with {{ optional_param }}",
            "parameters": [
                {
                    "key": "optional_param",
                    "input_type": "string",
                    "requirement": "optional",
                    "description": "A test parameter"
                }
            ]"#;
        let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);

        let load_recipe_result = load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new());
        assert!(load_recipe_result.is_err());
        let err = load_recipe_result.unwrap_err();
        println!("{}", err.to_string());
        assert!(err.to_string().contains(
            "Optional parameters missing default values in the recipe: optional_param. Please provide defaults."
        ));
    }

    // An unknown `input_type` enum variant fails deserialization with a
    // serde "unknown variant" error.
    #[test]
    fn test_load_recipe_as_template_wrong_input_type_in_recipe_file() {
        let instructions_and_parameters = r#"
            "instructions": "Test instructions with {{ param }}",
            "parameters": [
                {
                    "key": "param",
                    "input_type": "some_invalid_type",
                    "requirement": "required",
                    "description": "A test parameter"
                }
            ]"#;
        let params = vec![("param".to_string(), "value".to_string())];
        let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);

        let load_recipe_result = load_recipe_as_template(recipe_path.to_str().unwrap(), params);
        assert!(load_recipe_result.is_err());
        let err = load_recipe_result.unwrap_err();
        assert!(err
            .to_string()
            .contains("unknown variant `some_invalid_type`"));
    }

    // A recipe with no `parameters` section loads fine with no params given.
    #[test]
    fn test_load_recipe_as_template_success_without_parameters() {
        let instructions_and_parameters = r#"
            "instructions": "Test instructions"
        "#;
        let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);

        let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new()).unwrap();
        assert_eq!(recipe.instructions.unwrap(), "Test instructions");
        assert!(recipe.parameters.is_none());
    }
}
|
||||
101
crates/goose-cli/src/recipes/search_recipe.rs
Normal file
101
crates/goose-cli/src/recipes/search_recipe.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use goose::config::Config;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::{env, fs};
|
||||
|
||||
use crate::recipes::recipe::RECIPE_FILE_EXTENSIONS;
|
||||
|
||||
use super::github_recipe::{retrieve_recipe_from_github, GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY};
|
||||
|
||||
const GOOSE_RECIPE_PATH_ENV_VAR: &str = "GOOSE_RECIPE_PATH";
|
||||
|
||||
/// Resolves a recipe reference to file content plus its parent directory.
///
/// Resolution order: a literal `.yaml`/`.json` path, then a recipe name
/// looked up locally (cwd and GOOSE_RECIPE_PATH), then — if a GitHub recipe
/// repo is configured — a lookup in that repository.
pub fn retrieve_recipe_file(recipe_name: &str) -> Result<(String, PathBuf)> {
    // If recipe_name ends with yaml or json, treat it as a direct file path
    if RECIPE_FILE_EXTENSIONS
        .iter()
        .any(|ext| recipe_name.ends_with(&format!(".{}", ext)))
    {
        let path = PathBuf::from(recipe_name);
        return read_recipe_file(path);
    }
    retrieve_recipe_from_local_path(recipe_name).or_else(|e| {
        if let Some(recipe_repo_full_name) = configured_github_recipe_repo() {
            retrieve_recipe_from_github(recipe_name, &recipe_repo_full_name)
        } else {
            // No GitHub repo configured: surface the local-search error.
            Err(e)
        }
    })
}
|
||||
|
||||
fn read_recipe_in_dir(dir: &Path, recipe_name: &str) -> Result<(String, PathBuf)> {
|
||||
for ext in RECIPE_FILE_EXTENSIONS {
|
||||
let recipe_path = dir.join(format!("{}.{}", recipe_name, ext));
|
||||
if let Ok(result) = read_recipe_file(recipe_path) {
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
Err(anyhow!(format!(
|
||||
"No {}.yaml or {}.json recipe file found in directory: {}",
|
||||
recipe_name,
|
||||
recipe_name,
|
||||
dir.display()
|
||||
)))
|
||||
}
|
||||
|
||||
fn retrieve_recipe_from_local_path(recipe_name: &str) -> Result<(String, PathBuf)> {
|
||||
let mut search_dirs = vec![PathBuf::from(".")];
|
||||
if let Ok(recipe_path_env) = env::var(GOOSE_RECIPE_PATH_ENV_VAR) {
|
||||
let path_separator = if cfg!(windows) { ';' } else { ':' };
|
||||
let recipe_path_env_dirs: Vec<PathBuf> = recipe_path_env
|
||||
.split(path_separator)
|
||||
.map(PathBuf::from)
|
||||
.collect();
|
||||
search_dirs.extend(recipe_path_env_dirs);
|
||||
}
|
||||
for dir in &search_dirs {
|
||||
if let Ok(result) = read_recipe_in_dir(dir, recipe_name) {
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
let search_dirs_str = search_dirs
|
||||
.iter()
|
||||
.map(|p| p.to_string_lossy())
|
||||
.collect::<Vec<_>>()
|
||||
.join(":");
|
||||
Err(anyhow!(
|
||||
"ℹ️ Failed to retrieve {}.yaml or {}.json in {}",
|
||||
recipe_name,
|
||||
recipe_name,
|
||||
search_dirs_str
|
||||
))
|
||||
}
|
||||
|
||||
fn configured_github_recipe_repo() -> Option<String> {
|
||||
let config = Config::global();
|
||||
match config.get_param(GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY) {
|
||||
Ok(Some(recipe_repo_full_name)) => Some(recipe_repo_full_name),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn read_recipe_file<P: AsRef<Path>>(recipe_path: P) -> Result<(String, PathBuf)> {
|
||||
let path = recipe_path.as_ref();
|
||||
|
||||
let content = fs::read_to_string(path)
|
||||
.map_err(|e| anyhow!("Failed to read recipe file {}: {}", path.display(), e))?;
|
||||
|
||||
let canonical = path.canonicalize().map_err(|e| {
|
||||
anyhow!(
|
||||
"Failed to resolve absolute path for {}: {}",
|
||||
path.display(),
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
let parent_dir = canonical
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow!("Resolved path has no parent: {}", canonical.display()))?
|
||||
.to_path_buf();
|
||||
|
||||
Ok((content, parent_dir))
|
||||
}
|
||||
@@ -21,6 +21,8 @@ pub struct SessionBuilderConfig {
|
||||
pub identifier: Option<Identifier>,
|
||||
/// Whether to resume an existing session
|
||||
pub resume: bool,
|
||||
/// Whether to run without a session file
|
||||
pub no_session: bool,
|
||||
/// List of stdio extension commands to add
|
||||
pub extensions: Vec<String>,
|
||||
/// List of remote extension commands to add
|
||||
@@ -33,6 +35,8 @@ pub struct SessionBuilderConfig {
|
||||
pub additional_system_prompt: Option<String>,
|
||||
/// Enable debug printing
|
||||
pub debug: bool,
|
||||
/// Maximum number of consecutive identical tool calls allowed
|
||||
pub max_tool_repetitions: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
@@ -51,10 +55,31 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
// Create the agent
|
||||
let agent: Agent = Agent::new();
|
||||
let new_provider = create(&provider_name, model_config).unwrap();
|
||||
let _ = agent.update_provider(new_provider).await;
|
||||
agent
|
||||
.update_provider(new_provider)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
output::render_error(&format!("Failed to initialize agent: {}", e));
|
||||
process::exit(1);
|
||||
});
|
||||
|
||||
// Configure tool monitoring if max_tool_repetitions is set
|
||||
if let Some(max_repetitions) = session_config.max_tool_repetitions {
|
||||
agent.configure_tool_monitor(Some(max_repetitions)).await;
|
||||
}
|
||||
|
||||
// Handle session file resolution and resuming
|
||||
let session_file = if session_config.resume {
|
||||
let session_file = if session_config.no_session {
|
||||
// Use a temporary path that won't be written to
|
||||
#[cfg(unix)]
|
||||
{
|
||||
std::path::PathBuf::from("/dev/null")
|
||||
}
|
||||
#[cfg(windows)]
|
||||
{
|
||||
std::path::PathBuf::from("NUL")
|
||||
}
|
||||
} else if session_config.resume {
|
||||
if let Some(identifier) = session_config.identifier {
|
||||
let session_file = session::get_path(identifier);
|
||||
if !session_file.exists() {
|
||||
@@ -87,7 +112,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
session::get_path(id)
|
||||
};
|
||||
|
||||
if session_config.resume {
|
||||
if session_config.resume && !session_config.no_session {
|
||||
// Read the session metadata
|
||||
let metadata = session::read_metadata(&session_file).unwrap_or_else(|e| {
|
||||
output::render_error(&format!("Failed to read session metadata: {}", e));
|
||||
@@ -103,7 +128,17 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
.interact().expect("Failed to get user input");
|
||||
|
||||
if change_workdir {
|
||||
std::env::set_current_dir(metadata.working_dir).unwrap();
|
||||
if !metadata.working_dir.exists() {
|
||||
output::render_error(&format!(
|
||||
"Cannot switch to original working directory - {} no longer exists",
|
||||
style(metadata.working_dir.display()).cyan()
|
||||
));
|
||||
} else if let Err(e) = std::env::set_current_dir(&metadata.working_dir) {
|
||||
output::render_error(&format!(
|
||||
"Failed to switch to original working directory: {}",
|
||||
e
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,8 +352,13 @@ impl Helper for GooseCompleter {}
|
||||
impl Hinter for GooseCompleter {
|
||||
type Hint = String;
|
||||
|
||||
fn hint(&self, _line: &str, _pos: usize, _ctx: &rustyline::Context<'_>) -> Option<Self::Hint> {
|
||||
None
|
||||
fn hint(&self, line: &str, _pos: usize, _ctx: &rustyline::Context<'_>) -> Option<Self::Hint> {
|
||||
// Only show hint when line is empty
|
||||
if line.is_empty() {
|
||||
Some("Press Enter to send, Ctrl-J for new line".to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,7 +372,9 @@ impl Highlighter for GooseCompleter {
|
||||
}
|
||||
|
||||
fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> {
|
||||
Cow::Borrowed(hint)
|
||||
// Style the hint text with a dim color
|
||||
let styled = console::Style::new().dim().apply_to(hint).to_string();
|
||||
Cow::Owned(styled)
|
||||
}
|
||||
|
||||
fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
|
||||
|
||||
@@ -18,6 +18,7 @@ pub enum InputResult {
|
||||
Plan(PlanCommandOptions),
|
||||
EndPlan,
|
||||
Recipe(Option<String>),
|
||||
Summarize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -91,6 +92,7 @@ fn handle_slash_command(input: &str) -> Option<InputResult> {
|
||||
const CMD_PLAN: &str = "/plan";
|
||||
const CMD_ENDPLAN: &str = "/endplan";
|
||||
const CMD_RECIPE: &str = "/recipe";
|
||||
const CMD_SUMMARIZE: &str = "/summarize";
|
||||
|
||||
match input {
|
||||
"/exit" | "/quit" => Some(InputResult::Exit),
|
||||
@@ -133,6 +135,7 @@ fn handle_slash_command(input: &str) -> Option<InputResult> {
|
||||
s if s.starts_with(CMD_PLAN) => parse_plan_command(s[CMD_PLAN.len()..].trim().to_string()),
|
||||
s if s == CMD_ENDPLAN => Some(InputResult::EndPlan),
|
||||
s if s.starts_with(CMD_RECIPE) => parse_recipe_command(s),
|
||||
s if s == CMD_SUMMARIZE => Some(InputResult::Summarize),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -241,6 +244,7 @@ fn print_help() {
|
||||
/endplan - Exit plan mode and return to 'normal' goose mode.
|
||||
/recipe [filepath] - Generate a recipe from the current conversation and save it to the specified filepath (must end with .yaml).
|
||||
If no filepath is provided, it will be saved to ./recipe.yaml.
|
||||
/summarize - Summarize the current conversation to reduce context length while preserving key information.
|
||||
/? or /help - Display this help message
|
||||
|
||||
Navigation:
|
||||
@@ -474,4 +478,15 @@ mod tests {
|
||||
let result = handle_slash_command("/recipe /path/to/file.txt");
|
||||
assert!(matches!(result, Some(InputResult::Retry)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_summarize_command() {
|
||||
// Test the summarize command
|
||||
let result = handle_slash_command("/summarize");
|
||||
assert!(matches!(result, Some(InputResult::Summarize)));
|
||||
|
||||
// Test with whitespace
|
||||
let result = handle_slash_command(" /summarize ");
|
||||
assert!(matches!(result, Some(InputResult::Summarize)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,6 +122,21 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to summarize context messages
|
||||
async fn summarize_context_messages(
|
||||
messages: &mut Vec<Message>,
|
||||
agent: &Agent,
|
||||
message_suffix: &str,
|
||||
) -> Result<()> {
|
||||
// Summarize messages to fit within context length
|
||||
let (summarized_messages, _) = agent.summarize_context(messages).await?;
|
||||
let msg = format!("Context maxed out\n{}\n{}", "-".repeat(50), message_suffix);
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
*messages = summarized_messages;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a stdio extension to the session
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -290,6 +305,22 @@ impl Session {
|
||||
// Persist messages with provider for automatic description generation
|
||||
session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?;
|
||||
|
||||
// Track the current directory and last instruction in projects.json
|
||||
let session_id = self
|
||||
.session_file
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Err(e) =
|
||||
crate::project_tracker::update_project_tracker(Some(&message), session_id.as_deref())
|
||||
{
|
||||
eprintln!(
|
||||
"Warning: Failed to update project tracker with instruction: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
|
||||
self.process_agent_response(false).await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -356,6 +387,20 @@ impl Session {
|
||||
|
||||
self.messages.push(Message::user().with_text(&content));
|
||||
|
||||
// Track the current directory and last instruction in projects.json
|
||||
let session_id = self
|
||||
.session_file
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Err(e) = crate::project_tracker::update_project_tracker(
|
||||
Some(&content),
|
||||
session_id.as_deref(),
|
||||
) {
|
||||
eprintln!("Warning: Failed to update project tracker with instruction: {}", e);
|
||||
}
|
||||
|
||||
// Get the provider from the agent for description generation
|
||||
let provider = self.agent.provider().await?;
|
||||
|
||||
@@ -503,6 +548,62 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
InputResult::Summarize => {
|
||||
save_history(&mut editor);
|
||||
|
||||
let prompt = "Are you sure you want to summarize this conversation? This will condense the message history.";
|
||||
let should_summarize =
|
||||
match cliclack::confirm(prompt).initial_value(true).interact() {
|
||||
Ok(choice) => choice,
|
||||
Err(e) => {
|
||||
if e.kind() == std::io::ErrorKind::Interrupted {
|
||||
false // If interrupted, set should_summarize to false
|
||||
} else {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if should_summarize {
|
||||
println!("{}", console::style("Summarizing conversation...").yellow());
|
||||
output::show_thinking();
|
||||
|
||||
// Get the provider for summarization
|
||||
let provider = self.agent.provider().await?;
|
||||
|
||||
// Call the summarize_context method which uses the summarize_messages function
|
||||
let (summarized_messages, _) =
|
||||
self.agent.summarize_context(&self.messages).await?;
|
||||
|
||||
// Update the session messages with the summarized ones
|
||||
self.messages = summarized_messages;
|
||||
|
||||
// Persist the summarized messages
|
||||
session::persist_messages(
|
||||
&self.session_file,
|
||||
&self.messages,
|
||||
Some(provider),
|
||||
)
|
||||
.await?;
|
||||
|
||||
output::hide_thinking();
|
||||
println!(
|
||||
"{}",
|
||||
console::style("Conversation has been summarized.").green()
|
||||
);
|
||||
println!(
|
||||
"{}",
|
||||
console::style(
|
||||
"Key information has been preserved while reducing context length."
|
||||
)
|
||||
.green()
|
||||
);
|
||||
} else {
|
||||
println!("{}", console::style("Summarization cancelled.").yellow());
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -532,10 +633,21 @@ impl Session {
|
||||
match planner_response_type {
|
||||
PlannerResponseType::Plan => {
|
||||
println!();
|
||||
let should_act =
|
||||
cliclack::confirm("Do you want to clear message history & act on this plan?")
|
||||
.initial_value(true)
|
||||
.interact()?;
|
||||
let should_act = match cliclack::confirm(
|
||||
"Do you want to clear message history & act on this plan?",
|
||||
)
|
||||
.initial_value(true)
|
||||
.interact()
|
||||
{
|
||||
Ok(choice) => choice,
|
||||
Err(e) => {
|
||||
if e.kind() == std::io::ErrorKind::Interrupted {
|
||||
false // If interrupted, set should_act to false
|
||||
} else {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
if should_act {
|
||||
output::render_act_on_plan();
|
||||
self.run_mode = RunMode::Normal;
|
||||
@@ -596,6 +708,7 @@ impl Session {
|
||||
id: session_id.clone(),
|
||||
working_dir: std::env::current_dir()
|
||||
.expect("failed to get current session working directory"),
|
||||
schedule_id: None,
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
@@ -614,51 +727,99 @@ impl Session {
|
||||
let prompt = "Goose would like to call the above tool, do you allow?".to_string();
|
||||
|
||||
// Get confirmation from user
|
||||
let permission = cliclack::select(prompt)
|
||||
let permission_result = cliclack::select(prompt)
|
||||
.item(Permission::AllowOnce, "Allow", "Allow the tool call once")
|
||||
.item(Permission::AlwaysAllow, "Always Allow", "Always allow the tool call")
|
||||
.item(Permission::DenyOnce, "Deny", "Deny the tool call")
|
||||
.interact()?;
|
||||
self.agent.handle_confirmation(confirmation.id.clone(), PermissionConfirmation {
|
||||
principal_type: PrincipalType::Tool,
|
||||
permission,
|
||||
},).await;
|
||||
.item(Permission::Cancel, "Cancel", "Cancel the AI response and tool call")
|
||||
.interact();
|
||||
|
||||
let permission = match permission_result {
|
||||
Ok(p) => p, // If Ok, use the selected permission
|
||||
Err(e) => {
|
||||
// Check if the error is an interruption (Ctrl+C/Cmd+C, Escape)
|
||||
if e.kind() == std::io::ErrorKind::Interrupted {
|
||||
Permission::Cancel // If interrupted, set permission to Cancel
|
||||
} else {
|
||||
return Err(e.into()); // Otherwise, convert and propagate the original error
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if permission == Permission::Cancel {
|
||||
output::render_text("Tool call cancelled. Returning to chat...", Some(Color::Yellow), true);
|
||||
|
||||
let mut response_message = Message::user();
|
||||
response_message.content.push(MessageContent::tool_response(
|
||||
confirmation.id.clone(),
|
||||
Err(ToolError::ExecutionError("Tool call cancelled by user".to_string()))
|
||||
));
|
||||
self.messages.push(response_message);
|
||||
session::persist_messages(&self.session_file, &self.messages, None).await?;
|
||||
|
||||
drop(stream);
|
||||
break;
|
||||
} else {
|
||||
self.agent.handle_confirmation(confirmation.id.clone(), PermissionConfirmation {
|
||||
principal_type: PrincipalType::Tool,
|
||||
permission,
|
||||
},).await;
|
||||
}
|
||||
} else if let Some(MessageContent::ContextLengthExceeded(_)) = message.content.first() {
|
||||
output::hide_thinking();
|
||||
|
||||
let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
|
||||
let selected = cliclack::select(prompt)
|
||||
.item("clear", "Clear Session", "Removes all messages from Goose's memory")
|
||||
.item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
|
||||
.item("summarize", "Summarize Session", "Summarize the session to reduce context length")
|
||||
.interact()?;
|
||||
if interactive {
|
||||
// In interactive mode, ask the user what to do
|
||||
let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
|
||||
let selected_result = cliclack::select(prompt)
|
||||
.item("clear", "Clear Session", "Removes all messages from Goose's memory")
|
||||
.item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
|
||||
.item("summarize", "Summarize Session", "Summarize the session to reduce context length")
|
||||
.item("cancel", "Cancel", "Cancel and return to chat")
|
||||
.interact();
|
||||
|
||||
match selected {
|
||||
"clear" => {
|
||||
self.messages.clear();
|
||||
let msg = format!("Session cleared.\n{}", "-".repeat(50));
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
break; // exit the loop to hand back control to the user
|
||||
}
|
||||
"truncate" => {
|
||||
// Truncate messages to fit within context length
|
||||
let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
|
||||
let msg = format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50));
|
||||
output::render_text("", Some(Color::Yellow), true);
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
self.messages = truncated_messages;
|
||||
}
|
||||
"summarize" => {
|
||||
// Summarize messages to fit within context length
|
||||
let (summarized_messages, _) = self.agent.summarize_context(&self.messages).await?;
|
||||
let msg = format!("Context maxed out\n{}\nGoose summarized messages for you.", "-".repeat(50));
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
self.messages = summarized_messages;
|
||||
}
|
||||
_ => {
|
||||
unreachable!()
|
||||
let selected = match selected_result {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
if e.kind() == std::io::ErrorKind::Interrupted {
|
||||
"cancel" // If interrupted, set selected to cancel
|
||||
} else {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match selected {
|
||||
"clear" => {
|
||||
self.messages.clear();
|
||||
let msg = format!("Session cleared.\n{}", "-".repeat(50));
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
break; // exit the loop to hand back control to the user
|
||||
}
|
||||
"truncate" => {
|
||||
// Truncate messages to fit within context length
|
||||
let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
|
||||
let msg = format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50));
|
||||
output::render_text("", Some(Color::Yellow), true);
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
self.messages = truncated_messages;
|
||||
}
|
||||
"summarize" => {
|
||||
// Use the helper function to summarize context
|
||||
Self::summarize_context_messages(&mut self.messages, &self.agent, "Goose summarized messages for you.").await?;
|
||||
}
|
||||
"cancel" => {
|
||||
break; // Return to main prompt
|
||||
}
|
||||
_ => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// In headless mode (goose run), automatically use summarize
|
||||
Self::summarize_context_messages(&mut self.messages, &self.agent, "Goose automatically summarized messages to continue processing.").await?;
|
||||
}
|
||||
|
||||
// Restart the stream after handling ContextLengthExceeded
|
||||
stream = self
|
||||
.agent
|
||||
@@ -668,6 +829,7 @@ impl Session {
|
||||
id: session_id.clone(),
|
||||
working_dir: std::env::current_dir()
|
||||
.expect("failed to get current session working directory"),
|
||||
schedule_id: None,
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
@@ -852,6 +1014,31 @@ impl Session {
|
||||
self.messages.clone()
|
||||
}
|
||||
|
||||
/// Render all past messages from the session history
|
||||
pub fn render_message_history(&self) {
|
||||
if self.messages.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Print session restored message
|
||||
println!(
|
||||
"\n{} {} messages loaded into context.",
|
||||
console::style("Session restored:").green().bold(),
|
||||
console::style(self.messages.len()).green()
|
||||
);
|
||||
|
||||
// Render each message
|
||||
for message in &self.messages {
|
||||
output::render_message(message, self.debug);
|
||||
}
|
||||
|
||||
// Add a visual separator after restored messages
|
||||
println!(
|
||||
"\n{}\n",
|
||||
console::style("──────── New Messages ────────").dim()
|
||||
);
|
||||
}
|
||||
|
||||
/// Get the session metadata
|
||||
pub fn get_metadata(&self) -> Result<session::SessionMetadata> {
|
||||
if !self.session_file.exists() {
|
||||
@@ -976,29 +1163,30 @@ fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::create;
|
||||
|
||||
let (reasoner_provider, reasoner_model) = match (
|
||||
std::env::var("GOOSE_PLANNER_PROVIDER"),
|
||||
std::env::var("GOOSE_PLANNER_MODEL"),
|
||||
) {
|
||||
(Ok(provider), Ok(model)) => (provider, model),
|
||||
_ => {
|
||||
println!(
|
||||
"WARNING: GOOSE_PLANNER_PROVIDER or GOOSE_PLANNER_MODEL is not set. \
|
||||
Using default model from config..."
|
||||
);
|
||||
let config = Config::global();
|
||||
let provider = config
|
||||
.get_param("GOOSE_PROVIDER")
|
||||
.expect("No provider configured. Run 'goose configure' first");
|
||||
let model = config
|
||||
.get_param("GOOSE_MODEL")
|
||||
.expect("No model configured. Run 'goose configure' first");
|
||||
(provider, model)
|
||||
}
|
||||
let config = Config::global();
|
||||
|
||||
// Try planner-specific provider first, fallback to default provider
|
||||
let provider = if let Ok(provider) = config.get_param::<String>("GOOSE_PLANNER_PROVIDER") {
|
||||
provider
|
||||
} else {
|
||||
println!("WARNING: GOOSE_PLANNER_PROVIDER not found. Using default provider...");
|
||||
config
|
||||
.get_param::<String>("GOOSE_PROVIDER")
|
||||
.expect("No provider configured. Run 'goose configure' first")
|
||||
};
|
||||
|
||||
let model_config = ModelConfig::new(reasoner_model);
|
||||
let reasoner = create(&reasoner_provider, model_config)?;
|
||||
// Try planner-specific model first, fallback to default model
|
||||
let model = if let Ok(model) = config.get_param::<String>("GOOSE_PLANNER_MODEL") {
|
||||
model
|
||||
} else {
|
||||
println!("WARNING: GOOSE_PLANNER_MODEL not found. Using default model...");
|
||||
config
|
||||
.get_param::<String>("GOOSE_MODEL")
|
||||
.expect("No model configured. Run 'goose configure' first")
|
||||
};
|
||||
|
||||
let model_config = ModelConfig::new(model);
|
||||
let reasoner = create(&provider, model_config)?;
|
||||
|
||||
Ok(reasoner)
|
||||
}
|
||||
|
||||
@@ -209,7 +209,7 @@ fn render_tool_response(resp: &ToolResponse, theme: Theme, debug: bool) {
|
||||
let min_priority = config
|
||||
.get_param::<f32>("GOOSE_CLI_MIN_PRIORITY")
|
||||
.ok()
|
||||
.unwrap_or(0.0);
|
||||
.unwrap_or(0.5);
|
||||
|
||||
if content
|
||||
.priority()
|
||||
@@ -405,9 +405,15 @@ fn print_markdown(content: &str, theme: Theme) {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
const MAX_STRING_LENGTH: usize = 40;
|
||||
const INDENT: &str = " ";
|
||||
|
||||
fn get_tool_params_max_length() -> usize {
|
||||
Config::global()
|
||||
.get_param::<usize>("GOOSE_CLI_TOOL_PARAMS_TRUNCATION_MAX_LENGTH")
|
||||
.ok()
|
||||
.unwrap_or(40)
|
||||
}
|
||||
|
||||
fn print_params(value: &Value, depth: usize, debug: bool) {
|
||||
let indent = INDENT.repeat(depth);
|
||||
|
||||
@@ -427,7 +433,7 @@ fn print_params(value: &Value, depth: usize, debug: bool) {
|
||||
}
|
||||
}
|
||||
Value::String(s) => {
|
||||
if !debug && s.len() > MAX_STRING_LENGTH {
|
||||
if !debug && s.len() > get_tool_params_max_length() {
|
||||
println!("{}{}: {}", indent, style(key).dim(), style("...").dim());
|
||||
} else {
|
||||
println!("{}{}: {}", indent, style(key).dim(), style(s).green());
|
||||
@@ -452,7 +458,7 @@ fn print_params(value: &Value, depth: usize, debug: bool) {
|
||||
}
|
||||
}
|
||||
Value::String(s) => {
|
||||
if !debug && s.len() > MAX_STRING_LENGTH {
|
||||
if !debug && s.len() > get_tool_params_max_length() {
|
||||
println!(
|
||||
"{}{}",
|
||||
indent,
|
||||
@@ -527,6 +533,8 @@ fn shorten_path(path: &str, debug: bool) -> String {
|
||||
pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) {
|
||||
let start_session_msg = if resume {
|
||||
"resuming session |"
|
||||
} else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") {
|
||||
"running without session |"
|
||||
} else {
|
||||
"starting session |"
|
||||
};
|
||||
@@ -538,11 +546,15 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
|
||||
style("model:").dim(),
|
||||
style(model).cyan().dim(),
|
||||
);
|
||||
println!(
|
||||
" {} {}",
|
||||
style("logging to").dim(),
|
||||
style(session_file.display()).dim().cyan(),
|
||||
);
|
||||
|
||||
if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") {
|
||||
println!(
|
||||
" {} {}",
|
||||
style("logging to").dim(),
|
||||
style(session_file.display()).dim().cyan(),
|
||||
);
|
||||
}
|
||||
|
||||
println!(
|
||||
" {} {}",
|
||||
style("working directory:").dim(),
|
||||
|
||||
36
crates/goose-cli/src/signal.rs
Normal file
36
crates/goose-cli/src/signal.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use tokio::signal;
|
||||
|
||||
#[cfg(unix)]
|
||||
pub fn shutdown_signal() -> Pin<Box<dyn Future<Output = ()> + Send>> {
|
||||
Box::pin(async move {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
pub fn shutdown_signal() -> Pin<Box<dyn Future<Output = ()> + Send>> {
|
||||
Box::pin(async move {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
})
|
||||
}
|
||||
@@ -16,10 +16,10 @@ To build the FFI library, you'll need Rust and Cargo installed. Then run:
|
||||
|
||||
```bash
|
||||
# Build the library in debug mode
|
||||
cargo build --package goose_ffi
|
||||
cargo build --package goose-ffi
|
||||
|
||||
# Build the library in release mode (recommended for production)
|
||||
cargo build --release --package goose_ffi
|
||||
cargo build --release --package goose-ffi
|
||||
```
|
||||
|
||||
This will generate a dynamic library (.so, .dll, or .dylib depending on your platform) in the `target` directory, and automatically generate the C header file in the `include` directory.
|
||||
@@ -54,7 +54,7 @@ To run the Python example:
|
||||
|
||||
```bash
|
||||
# First, build the FFI library
|
||||
cargo build --release --package goose_ffi
|
||||
cargo build --release --package goose-ffi
|
||||
|
||||
# Then set the environment variables & run the example
|
||||
DATABRICKS_HOST=... DATABRICKS_API_KEY=... python crates/goose-ffi/examples/goose_agent.py
|
||||
|
||||
61
crates/goose-llm/Cargo.toml
Normal file
61
crates/goose-llm/Cargo.toml
Normal file
@@ -0,0 +1,61 @@
|
||||
[package]
|
||||
name = "goose-llm"
|
||||
edition.workspace = true
|
||||
version.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
description.workspace = true
|
||||
|
||||
[lib]
|
||||
crate-type = ["lib", "cdylib"]
|
||||
name = "goose_llm"
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
minijinja = "2.8.0"
|
||||
include_dir = "0.7.4"
|
||||
once_cell = "1.20.2"
|
||||
chrono = { version = "0.4.38", features = ["serde"] }
|
||||
reqwest = { version = "0.12.9", features = [
|
||||
"rustls-tls-native-roots",
|
||||
"json",
|
||||
"cookies",
|
||||
"gzip",
|
||||
"brotli",
|
||||
"deflate",
|
||||
"zstd",
|
||||
"charset",
|
||||
"http2",
|
||||
"stream"
|
||||
], default-features = false }
|
||||
async-trait = "0.1"
|
||||
url = "2.5"
|
||||
base64 = "0.21"
|
||||
regex = "1.11.1"
|
||||
tracing = "0.1"
|
||||
smallvec = { version = "1.13", features = ["serde"] }
|
||||
indoc = "1.0"
|
||||
# https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/fixtures/futures/Cargo.toml#L22
|
||||
uniffi = { version = "0.29", features = ["tokio", "cli", "scaffolding-ffi-buffer-fns"] }
|
||||
tokio = { version = "1.43", features = ["time", "sync"] }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5"
|
||||
tempfile = "3.15.0"
|
||||
dotenv = "0.15"
|
||||
lazy_static = "1.5"
|
||||
ctor = "0.2.7"
|
||||
tokio = { version = "1.43", features = ["full"] }
|
||||
|
||||
[[bin]]
|
||||
# https://mozilla.github.io/uniffi-rs/latest/tutorial/foreign_language_bindings.html
|
||||
name = "uniffi-bindgen"
|
||||
path = "uniffi-bindgen.rs"
|
||||
|
||||
[[example]]
|
||||
name = "simple"
|
||||
path = "examples/simple.rs"
|
||||
68
crates/goose-llm/README.md
Normal file
68
crates/goose-llm/README.md
Normal file
@@ -0,0 +1,68 @@
|
||||
## goose-llm
|
||||
|
||||
This crate is meant to be used for foreign function interface (FFI). It's meant to be
|
||||
stateless and contain logic related to providers and prompts:
|
||||
- chat completion with model providers
|
||||
- detecting read-only tools for smart approval
|
||||
- methods for summarization / truncation
|
||||
|
||||
|
||||
Run:
|
||||
```
|
||||
cargo run -p goose-llm --example simple
|
||||
```
|
||||
|
||||
|
||||
## Kotlin bindings
|
||||
|
||||
Structure:
|
||||
```
|
||||
.
|
||||
└── crates
|
||||
└── goose-llm/...
|
||||
└── target
|
||||
└── debug/libgoose_llm.dylib
|
||||
├── bindings
|
||||
│ └── kotlin
|
||||
│ ├── example
|
||||
│ │ └── Usage.kt ← your demo app
|
||||
│ └── uniffi
|
||||
│ └── goose_llm
|
||||
│ └── goose_llm.kt ← auto-generated bindings
|
||||
```
|
||||
|
||||
|
||||
#### Kotlin -> Rust: run example
|
||||
|
||||
The following `just` command creates kotlin bindings, then compiles and runs an example.
|
||||
|
||||
```bash
|
||||
just kotlin-example
|
||||
```
|
||||
|
||||
You will have to download jars in `bindings/kotlin/libs` directory (only the first time):
|
||||
```bash
|
||||
pushd bindings/kotlin/libs/
|
||||
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlin/kotlin-stdlib/1.9.0/kotlin-stdlib-1.9.0.jar
|
||||
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlinx/kotlinx-coroutines-core-jvm/1.7.3/kotlinx-coroutines-core-jvm-1.7.3.jar
|
||||
curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.jar
|
||||
popd
|
||||
```
|
||||
|
||||
To just create the Kotlin bindings:
|
||||
|
||||
```bash
|
||||
# run from project root directory
|
||||
cargo build -p goose-llm
|
||||
|
||||
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
|
||||
```
|
||||
|
||||
|
||||
#### Python -> Rust: generate bindings, run example
|
||||
|
||||
```bash
|
||||
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language python --out-dir bindings/python
|
||||
|
||||
DYLD_LIBRARY_PATH=./target/debug python bindings/python/usage.py
|
||||
```
|
||||
123
crates/goose-llm/examples/simple.rs
Normal file
123
crates/goose-llm/examples/simple.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use std::vec;
|
||||
|
||||
use anyhow::Result;
|
||||
use goose_llm::{
|
||||
completion,
|
||||
extractors::generate_tooltip,
|
||||
types::completion::{
|
||||
CompletionRequest, CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig,
|
||||
},
|
||||
Message, ModelConfig,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let provider = "databricks";
|
||||
let provider_config = json!({
|
||||
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
|
||||
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
|
||||
});
|
||||
// let model_name = "goose-gpt-4-1"; // parallel tool calls
|
||||
let model_name = "claude-3-5-haiku";
|
||||
let model_config = ModelConfig::new(model_name.to_string());
|
||||
|
||||
let calculator_tool = ToolConfig::new(
|
||||
"calculator",
|
||||
"Perform basic arithmetic operations",
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["operation", "numbers"],
|
||||
"properties": {
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"enum": ["add", "subtract", "multiply", "divide"],
|
||||
"description": "The arithmetic operation to perform",
|
||||
},
|
||||
"numbers": {
|
||||
"type": "array",
|
||||
"items": {"type": "number"},
|
||||
"description": "List of numbers to operate on in order",
|
||||
}
|
||||
}
|
||||
}),
|
||||
ToolApprovalMode::Auto,
|
||||
);
|
||||
|
||||
let bash_tool = ToolConfig::new(
|
||||
"bash_shell",
|
||||
"Run a shell command",
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["command"],
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
}
|
||||
}
|
||||
}),
|
||||
ToolApprovalMode::Manual,
|
||||
);
|
||||
|
||||
let list_dir_tool = ToolConfig::new(
|
||||
"list_directory",
|
||||
"List files in a directory",
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["path"],
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The directory path to list files from"
|
||||
}
|
||||
}
|
||||
}),
|
||||
ToolApprovalMode::Auto,
|
||||
);
|
||||
|
||||
let extensions = vec![
|
||||
ExtensionConfig::new(
|
||||
"calculator_extension".to_string(),
|
||||
Some("This extension provides a calculator tool.".to_string()),
|
||||
vec![calculator_tool],
|
||||
),
|
||||
ExtensionConfig::new(
|
||||
"bash_extension".to_string(),
|
||||
Some("This extension provides a bash shell tool.".to_string()),
|
||||
vec![bash_tool, list_dir_tool],
|
||||
),
|
||||
];
|
||||
|
||||
let system_preamble = "You are a helpful assistant.";
|
||||
|
||||
for text in [
|
||||
"Add 10037 + 23123 using calculator and also run 'date -u' using bash",
|
||||
"List all files in the current directory",
|
||||
] {
|
||||
println!("\n---------------\n");
|
||||
println!("User Input: {text}");
|
||||
let messages = vec![
|
||||
Message::user().with_text("Hi there!"),
|
||||
Message::assistant().with_text("How can I help?"),
|
||||
Message::user().with_text(text),
|
||||
];
|
||||
let completion_response: CompletionResponse = completion(CompletionRequest::new(
|
||||
provider.to_string(),
|
||||
provider_config.clone(),
|
||||
model_config.clone(),
|
||||
system_preamble.to_string(),
|
||||
messages.clone(),
|
||||
extensions.clone(),
|
||||
))
|
||||
.await?;
|
||||
// Print the response
|
||||
println!("\nCompletion Response:");
|
||||
println!("{}", serde_json::to_string_pretty(&completion_response)?);
|
||||
|
||||
let tooltip = generate_tooltip(provider, provider_config.clone().into(), &messages).await?;
|
||||
println!("\nTooltip: {}", tooltip);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
144
crates/goose-llm/src/completion.rs
Normal file
144
crates/goose-llm/src/completion.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
use std::{collections::HashMap, time::Instant};
|
||||
|
||||
use anyhow::Result;
|
||||
use chrono::Utc;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
message::{Message, MessageContent},
|
||||
prompt_template,
|
||||
providers::create,
|
||||
types::{
|
||||
completion::{
|
||||
CompletionError, CompletionRequest, CompletionResponse, ExtensionConfig,
|
||||
RuntimeMetrics, ToolApprovalMode, ToolConfig,
|
||||
},
|
||||
core::ToolCall,
|
||||
},
|
||||
};
|
||||
|
||||
#[uniffi::export]
|
||||
pub fn print_messages(messages: Vec<Message>) {
|
||||
for msg in messages {
|
||||
println!("[{:?} @ {}] {:?}", msg.role, msg.created, msg.content);
|
||||
}
|
||||
}
|
||||
|
||||
/// Public API for the Goose LLM completion function
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, CompletionError> {
|
||||
let start_total = Instant::now();
|
||||
|
||||
let provider = create(
|
||||
&req.provider_name,
|
||||
req.provider_config.clone(),
|
||||
req.model_config.clone(),
|
||||
)
|
||||
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
|
||||
|
||||
let system_prompt = construct_system_prompt(&req.system_preamble, &req.extensions)?;
|
||||
let tools = collect_prefixed_tools(&req.extensions);
|
||||
|
||||
// Call the LLM provider
|
||||
let start_provider = Instant::now();
|
||||
let mut response = provider
|
||||
.complete(&system_prompt, &req.messages, &tools)
|
||||
.await?;
|
||||
let provider_elapsed_sec = start_provider.elapsed().as_secs_f32();
|
||||
let usage_tokens = response.usage.total_tokens;
|
||||
|
||||
let tool_configs = collect_prefixed_tool_configs(&req.extensions);
|
||||
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?;
|
||||
|
||||
Ok(CompletionResponse::new(
|
||||
response.message,
|
||||
response.model,
|
||||
response.usage,
|
||||
calculate_runtime_metrics(start_total, provider_elapsed_sec, usage_tokens),
|
||||
))
|
||||
}
|
||||
|
||||
/// Render the global `system.md` template with the provided context.
|
||||
fn construct_system_prompt(
|
||||
system_preamble: &str,
|
||||
extensions: &[ExtensionConfig],
|
||||
) -> Result<String, CompletionError> {
|
||||
let mut context: HashMap<&str, Value> = HashMap::new();
|
||||
context.insert("system_preamble", Value::String(system_preamble.to_owned()));
|
||||
context.insert("extensions", serde_json::to_value(extensions)?);
|
||||
context.insert(
|
||||
"current_date",
|
||||
Value::String(Utc::now().format("%Y-%m-%d").to_string()),
|
||||
);
|
||||
|
||||
Ok(prompt_template::render_global_file("system.md", &context)?)
|
||||
}
|
||||
|
||||
/// Determine if a tool call requires manual approval.
|
||||
fn determine_needs_approval(config: &ToolConfig, _call: &ToolCall) -> bool {
|
||||
match config.approval_mode {
|
||||
ToolApprovalMode::Auto => false,
|
||||
ToolApprovalMode::Manual => true,
|
||||
ToolApprovalMode::Smart => {
|
||||
// TODO: Implement smart approval logic later
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set `needs_approval` on every tool call in the message.
|
||||
/// Returns a `ToolNotFound` error if the corresponding `ToolConfig` is missing.
|
||||
pub fn update_needs_approval_for_tool_calls(
|
||||
message: &mut Message,
|
||||
tool_configs: &HashMap<String, ToolConfig>,
|
||||
) -> Result<(), CompletionError> {
|
||||
for content in &mut message.content.iter_mut() {
|
||||
if let MessageContent::ToolReq(req) = content {
|
||||
if let Ok(call) = &mut req.tool_call.0 {
|
||||
// Provide a clear error message when the tool config is missing
|
||||
let config = tool_configs.get(&call.name).ok_or_else(|| {
|
||||
CompletionError::ToolNotFound(format!(
|
||||
"could not find tool config for '{}'",
|
||||
call.name
|
||||
))
|
||||
})?;
|
||||
let needs_approval = determine_needs_approval(config, call);
|
||||
call.set_needs_approval(needs_approval);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Collect all `Tool` instances from the extensions.
|
||||
fn collect_prefixed_tools(extensions: &[ExtensionConfig]) -> Vec<crate::types::core::Tool> {
|
||||
extensions
|
||||
.iter()
|
||||
.flat_map(|ext| ext.get_prefixed_tools())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Collect all `ToolConfig` entries from the extensions into a map.
|
||||
fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap<String, ToolConfig> {
|
||||
extensions
|
||||
.iter()
|
||||
.flat_map(|ext| ext.get_prefixed_tool_configs())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute runtime metrics for the request.
|
||||
fn calculate_runtime_metrics(
|
||||
total_start: Instant,
|
||||
provider_elapsed_sec: f32,
|
||||
token_count: Option<i32>,
|
||||
) -> RuntimeMetrics {
|
||||
let total_ms = total_start.elapsed().as_secs_f32();
|
||||
let tokens_per_sec = token_count.and_then(|toks| {
|
||||
if provider_elapsed_sec > 0.0 {
|
||||
Some(toks as f64 / (provider_elapsed_sec as f64))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
RuntimeMetrics::new(total_ms, provider_elapsed_sec, tokens_per_sec)
|
||||
}
|
||||
5
crates/goose-llm/src/extractors/mod.rs
Normal file
5
crates/goose-llm/src/extractors/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod session_name;
|
||||
mod tooltip;
|
||||
|
||||
pub use session_name::generate_session_name;
|
||||
pub use tooltip::generate_tooltip;
|
||||
111
crates/goose-llm/src/extractors/session_name.rs
Normal file
111
crates/goose-llm/src/extractors/session_name.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use crate::generate_structured_outputs;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::types::core::Role;
|
||||
use crate::{message::Message, types::json_value_ffi::JsonValueFfi};
|
||||
use anyhow::Result;
|
||||
use indoc::indoc;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Example session names shown to the model as few-shot guidance.
const SESSION_NAME_EXAMPLES: &[&str] = &[
    "Research Synthesis",
    "Sentiment Analysis",
    "Performance Report",
    "Feedback Collector",
    "Accessibility Check",
    "Design Reminder",
    "Project Reminder",
    "Launch Checklist",
    "Metrics Monitor",
    "Incident Response",
    "Deploy Cabinet App",
    "Design Reminder Alert",
    "Generate Monthly Expense Report",
    "Automate Incident Response Workflow",
    "Analyze Brand Sentiment Trends",
    "Monitor Device Health Issues",
    "Collect UI Feedback Summary",
    "Schedule Project Deadline Reminders",
];

/// Build the session-naming system prompt: fixed instructions followed by a
/// bulleted list of the example names above (no trailing newline).
fn build_system_prompt() -> String {
    let mut prompt = String::from(
        "You are an assistant that crafts a concise session title.\n\
         Given the first couple user messages in the conversation so far,\n\
         reply with only a short name (up to 4 words) that best describes\n\
         this session’s goal.\n\
         \n\
         Examples:\n",
    );
    prompt.push_str("- ");
    prompt.push_str(&SESSION_NAME_EXAMPLES.join("\n- "));
    prompt
}
|
||||
|
||||
/// Generates a short (≤4 words) session name
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn generate_session_name(
|
||||
provider_name: &str,
|
||||
provider_config: JsonValueFfi,
|
||||
messages: &[Message],
|
||||
) -> Result<String, ProviderError> {
|
||||
// Collect up to the first 3 user messages (truncated to 300 chars each)
|
||||
let context: Vec<String> = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == Role::User)
|
||||
.take(3)
|
||||
.map(|m| {
|
||||
let text = m.content.concat_text_str();
|
||||
if text.len() > 300 {
|
||||
text.chars().take(300).collect()
|
||||
} else {
|
||||
text
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if context.is_empty() {
|
||||
return Err(ProviderError::ExecutionError(
|
||||
"No user messages found to generate a session name.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let system_prompt = build_system_prompt();
|
||||
let user_msg_text = format!("Here are the user messages:\n{}", context.join("\n"));
|
||||
|
||||
// Use `extract` with a simple string schema
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
},
|
||||
"required": ["name"],
|
||||
"additionalProperties": false
|
||||
});
|
||||
|
||||
let resp = generate_structured_outputs(
|
||||
provider_name,
|
||||
provider_config,
|
||||
&system_prompt,
|
||||
&[Message::user().with_text(&user_msg_text)],
|
||||
schema,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let obj = resp
|
||||
.data
|
||||
.as_object()
|
||||
.ok_or_else(|| ProviderError::ResponseParseError("Expected object".into()))?;
|
||||
|
||||
let name = obj
|
||||
.get("name")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| ProviderError::ResponseParseError("Missing or non-string name".into()))?
|
||||
.to_string();
|
||||
|
||||
Ok(name)
|
||||
}
|
||||
169
crates/goose-llm/src/extractors/tooltip.rs
Normal file
169
crates/goose-llm/src/extractors/tooltip.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use crate::generate_structured_outputs;
|
||||
use crate::message::{Message, MessageContent};
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::types::core::{Content, Role};
|
||||
use crate::types::json_value_ffi::JsonValueFfi;
|
||||
use anyhow::Result;
|
||||
use indoc::indoc;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Example tooltips shown to the model as few-shot guidance.
const TOOLTIP_EXAMPLES: &[&str] = &[
    "analyzing KPIs",
    "detecting anomalies",
    "building artifacts in Buildkite",
    "categorizing issues",
    "checking dependencies",
    "collecting feedback",
    "deploying changes in AWS",
    "drafting report in Google Docs",
    "extracting action items",
    "generating insights",
    "logging issues",
    "monitoring tickets in Zendesk",
    "notifying design team",
    "running integration tests",
    "scanning threads in Figma",
    "sending reminders in Gmail",
    "sending surveys",
    "sharing with stakeholders",
    "summarizing findings",
    "transcribing meeting",
    "tracking resolution",
    "updating status in Linear",
];

/// Build the tooltip system prompt: fixed instructions followed by a
/// bulleted list of the example tooltips above (no trailing newline).
fn build_system_prompt() -> String {
    let mut prompt = String::from(
        "You are an assistant that summarizes the recent conversation into a tooltip.\n\
         Given the last two messages, reply with only a short tooltip (up to 4 words)\n\
         describing what is happening now.\n\
         \n\
         Examples:\n",
    );
    prompt.push_str("- ");
    prompt.push_str(&TOOLTIP_EXAMPLES.join("\n- "));
    prompt
}
|
||||
|
||||
/// Generates a tooltip summarizing the last two messages in the session,
|
||||
/// including any tool calls or results.
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn generate_tooltip(
|
||||
provider_name: &str,
|
||||
provider_config: JsonValueFfi,
|
||||
messages: &[Message],
|
||||
) -> Result<String, ProviderError> {
|
||||
// Need at least two messages to generate a tooltip
|
||||
if messages.len() < 2 {
|
||||
return Err(ProviderError::ExecutionError(
|
||||
"Need at least two messages to generate a tooltip".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Helper to render a single message's content
|
||||
fn render_message(m: &Message) -> String {
|
||||
let mut parts = Vec::new();
|
||||
for content in m.content.iter() {
|
||||
match content {
|
||||
MessageContent::Text(text_block) => {
|
||||
let txt = text_block.text.trim();
|
||||
if !txt.is_empty() {
|
||||
parts.push(txt.to_string());
|
||||
}
|
||||
}
|
||||
MessageContent::ToolReq(req) => {
|
||||
if let Ok(tool_call) = &req.tool_call.0 {
|
||||
parts.push(format!(
|
||||
"called tool '{}' with args {}",
|
||||
tool_call.name, tool_call.arguments
|
||||
));
|
||||
} else if let Err(e) = &req.tool_call.0 {
|
||||
parts.push(format!("tool request error: {}", e));
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResp(resp) => match &resp.tool_result.0 {
|
||||
Ok(contents) => {
|
||||
let results: Vec<String> = contents
|
||||
.iter()
|
||||
.map(|c| match c {
|
||||
Content::Text(t) => t.text.clone(),
|
||||
Content::Image(_) => "[image]".to_string(),
|
||||
})
|
||||
.collect();
|
||||
parts.push(format!("tool responded with: {}", results.join(" ")));
|
||||
}
|
||||
Err(e) => {
|
||||
parts.push(format!("tool error: {}", e));
|
||||
}
|
||||
},
|
||||
_ => {} // ignore other variants
|
||||
}
|
||||
}
|
||||
|
||||
let role = match m.role {
|
||||
Role::User => "User",
|
||||
Role::Assistant => "Assistant",
|
||||
};
|
||||
|
||||
format!("{}: {}", role, parts.join("; "))
|
||||
}
|
||||
|
||||
// Take the last two messages (in correct chronological order)
|
||||
let rendered: Vec<String> = messages
|
||||
.iter()
|
||||
.rev()
|
||||
.take(2)
|
||||
.map(render_message)
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.rev()
|
||||
.collect();
|
||||
|
||||
let system_prompt = build_system_prompt();
|
||||
|
||||
let user_msg_text = format!(
|
||||
"Here are the last two messages:\n{}\n\nTooltip:",
|
||||
rendered.join("\n")
|
||||
);
|
||||
|
||||
// Schema wrapping our tooltip string
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tooltip": { "type": "string" }
|
||||
},
|
||||
"required": ["tooltip"],
|
||||
"additionalProperties": false
|
||||
});
|
||||
|
||||
// Get the structured outputs
|
||||
let resp = generate_structured_outputs(
|
||||
provider_name,
|
||||
provider_config,
|
||||
&system_prompt,
|
||||
&[Message::user().with_text(&user_msg_text)],
|
||||
schema,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Pull out the tooltip field
|
||||
let obj = resp
|
||||
.data
|
||||
.as_object()
|
||||
.ok_or_else(|| ProviderError::ResponseParseError("Expected JSON object".into()))?;
|
||||
|
||||
let tooltip = obj
|
||||
.get("tooltip")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
ProviderError::ResponseParseError("Missing or non-string `tooltip` field".into())
|
||||
})?
|
||||
.to_string();
|
||||
|
||||
Ok(tooltip)
|
||||
}
|
||||
15
crates/goose-llm/src/lib.rs
Normal file
15
crates/goose-llm/src/lib.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
uniffi::setup_scaffolding!();
|
||||
|
||||
mod completion;
|
||||
pub mod extractors;
|
||||
pub mod message;
|
||||
mod model;
|
||||
mod prompt_template;
|
||||
pub mod providers;
|
||||
mod structured_outputs;
|
||||
pub mod types;
|
||||
|
||||
pub use completion::completion;
|
||||
pub use message::Message;
|
||||
pub use model::ModelConfig;
|
||||
pub use structured_outputs::generate_structured_outputs;
|
||||
84
crates/goose-llm/src/message/contents.rs
Normal file
84
crates/goose-llm/src/message/contents.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use std::{iter::FromIterator, ops::Deref};
|
||||
|
||||
use crate::message::MessageContent;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smallvec::SmallVec;
|
||||
|
||||
/// Holds the heterogeneous fragments that make up one chat message.
|
||||
///
|
||||
/// * Up to two items are stored inline on the stack.
|
||||
/// * Falls back to a heap allocation only when necessary.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(transparent)]
|
||||
pub struct Contents(SmallVec<[MessageContent; 2]>);
|
||||
|
||||
impl Contents {
|
||||
/*----------------------------------------------------------
|
||||
* 1-line ergonomic helpers
|
||||
*---------------------------------------------------------*/
|
||||
|
||||
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, MessageContent> {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
|
||||
pub fn push(&mut self, item: impl Into<MessageContent>) {
|
||||
self.0.push(item.into());
|
||||
}
|
||||
|
||||
pub fn texts(&self) -> impl Iterator<Item = &str> {
|
||||
self.0.iter().filter_map(|c| c.as_text())
|
||||
}
|
||||
|
||||
pub fn concat_text_str(&self) -> String {
|
||||
self.texts().collect::<Vec<_>>().join("\n")
|
||||
}
|
||||
|
||||
/// Returns `true` if *any* item satisfies the predicate.
|
||||
pub fn any_is<P>(&self, pred: P) -> bool
|
||||
where
|
||||
P: FnMut(&MessageContent) -> bool,
|
||||
{
|
||||
self.iter().any(pred)
|
||||
}
|
||||
|
||||
/// Returns `true` if *every* item satisfies the predicate.
|
||||
pub fn all_are<P>(&self, pred: P) -> bool
|
||||
where
|
||||
P: FnMut(&MessageContent) -> bool,
|
||||
{
|
||||
self.iter().all(pred)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<MessageContent>> for Contents {
|
||||
fn from(v: Vec<MessageContent>) -> Self {
|
||||
Contents(SmallVec::from_vec(v))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<MessageContent> for Contents {
|
||||
fn from_iter<I: IntoIterator<Item = MessageContent>>(iter: I) -> Self {
|
||||
Contents(SmallVec::from_iter(iter))
|
||||
}
|
||||
}
|
||||
|
||||
/*--------------------------------------------------------------
|
||||
* Allow &message.content to behave like a slice of fragments.
|
||||
*-------------------------------------------------------------*/
|
||||
impl Deref for Contents {
|
||||
type Target = [MessageContent];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
// — Register the contents type with UniFFI, converting to/from Vec<MessageContent> —
|
||||
// We need to do this because UniFFI’s FFI layer supports only primitive buffers (here Vec<u8>),
|
||||
uniffi::custom_type!(Contents, Vec<MessageContent>, {
|
||||
lower: |contents: &Contents| {
|
||||
contents.0.to_vec()
|
||||
},
|
||||
try_lift: |contents: Vec<MessageContent>| {
|
||||
Ok(Contents::from(contents))
|
||||
},
|
||||
});
|
||||
240
crates/goose-llm/src/message/message_content.rs
Normal file
240
crates/goose-llm/src/message/message_content.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
|
||||
use crate::message::tool_result_serde;
|
||||
use crate::types::core::{Content, ImageContent, TextContent, ToolCall, ToolResult};
|
||||
|
||||
// — Newtype wrappers (local structs) so we satisfy Rust’s orphan rules —
|
||||
// We need these because we can’t implement UniFFI’s FfiConverter directly on a type alias.
|
||||
|
||||
/// Newtype over `ToolResult<ToolCall>` so serde/UniFFI traits can be
/// implemented locally — a bare alias would fall foul of the orphan rules.
/// Serialization goes through `tool_result_serde` to keep one schema for
/// both the success and error arms.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolRequestToolCall(#[serde(with = "tool_result_serde")] pub ToolResult<ToolCall>);

impl ToolRequestToolCall {
    /// Borrow the wrapped success/error result.
    pub fn as_result(&self) -> &ToolResult<ToolCall> {
        &self.0
    }
}
// Transparent access to the inner result's methods.
impl std::ops::Deref for ToolRequestToolCall {
    type Target = ToolResult<ToolCall>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
// Wrap a raw tool-call result without further conversion.
impl From<Result<ToolCall, crate::types::core::ToolError>> for ToolRequestToolCall {
    fn from(res: Result<ToolCall, crate::types::core::ToolError>) -> Self {
        ToolRequestToolCall(res)
    }
}
|
||||
|
||||
/// Newtype over `ToolResult<Vec<Content>>` (a tool's response payload) for
/// the same orphan-rule reasons as `ToolRequestToolCall`; serialized via
/// `tool_result_serde`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResponseToolResult(
    #[serde(with = "tool_result_serde")] pub ToolResult<Vec<Content>>,
);

impl ToolResponseToolResult {
    /// Borrow the wrapped success/error result.
    pub fn as_result(&self) -> &ToolResult<Vec<Content>> {
        &self.0
    }
}
// Transparent access to the inner result's methods.
impl std::ops::Deref for ToolResponseToolResult {
    type Target = ToolResult<Vec<Content>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
// Wrap a raw tool-response result without further conversion.
impl From<Result<Vec<Content>, crate::types::core::ToolError>> for ToolResponseToolResult {
    fn from(res: Result<Vec<Content>, crate::types::core::ToolError>) -> Self {
        ToolResponseToolResult(res)
    }
}
|
||||
|
||||
// — Register the newtypes with UniFFI, converting via JSON strings —
|
||||
// UniFFI’s FFI layer supports only primitive buffers (here String), so we JSON-serialize
|
||||
// through our `tool_result_serde` to preserve the same success/error schema on both sides.
|
||||
|
||||
// UniFFI's FFI layer supports only primitive buffers (here String), so the
// newtype travels as JSON produced by `tool_result_serde`.
uniffi::custom_type!(ToolRequestToolCall, String, {
    lower: |obj| {
        // Serializing an in-memory value with this schema cannot fail.
        serde_json::to_string(&obj.0).expect("ToolRequestToolCall serializes to JSON")
    },
    try_lift: |val| {
        // Propagate malformed JSON from the foreign side as a lift error
        // instead of panicking (the original `unwrap()` aborted the process).
        Ok(serde_json::from_str(&val)?)
    },
});
|
||||
|
||||
// Same JSON-string bridging as `ToolRequestToolCall`, for tool responses.
uniffi::custom_type!(ToolResponseToolResult, String, {
    lower: |obj| {
        // Serializing an in-memory value with this schema cannot fail.
        serde_json::to_string(&obj.0).expect("ToolResponseToolResult serializes to JSON")
    },
    try_lift: |val| {
        // Propagate malformed JSON from the foreign side as a lift error
        // instead of panicking (the original `unwrap()` aborted the process).
        Ok(serde_json::from_str(&val)?)
    },
});
|
||||
|
||||
/// A tool invocation requested by the assistant, matched to its response by `id`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ToolRequest {
    // Correlation id; the matching ToolResponse carries the same id.
    pub id: String,
    // Parsed call (or parse error), see ToolRequestToolCall.
    pub tool_call: ToolRequestToolCall,
}

/// The result of executing a tool, matched to its request by `id`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
    // Correlation id of the originating ToolRequest.
    pub id: String,
    // Tool output (or execution error), see ToolResponseToolResult.
    pub tool_result: ToolResponseToolResult,
}

/// Model reasoning text plus its provider-issued signature.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct ThinkingContent {
    pub thinking: String,
    pub signature: String,
}

/// Opaque provider blob standing in for reasoning that was redacted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct RedactedThinkingContent {
    pub data: String,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
/// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MessageContent {
    Text(TextContent),
    Image(ImageContent),
    ToolReq(ToolRequest),
    ToolResp(ToolResponse),
    Thinking(ThinkingContent),
    RedactedThinking(RedactedThinkingContent),
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn text<S: Into<String>>(text: S) -> Self {
|
||||
MessageContent::Text(TextContent { text: text.into() })
|
||||
}
|
||||
|
||||
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
|
||||
MessageContent::Image(ImageContent {
|
||||
data: data.into(),
|
||||
mime_type: mime_type.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tool_request<S: Into<String>>(id: S, tool_call: ToolRequestToolCall) -> Self {
|
||||
MessageContent::ToolReq(ToolRequest {
|
||||
id: id.into(),
|
||||
tool_call,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tool_response<S: Into<String>>(id: S, tool_result: ToolResponseToolResult) -> Self {
|
||||
MessageContent::ToolResp(ToolResponse {
|
||||
id: id.into(),
|
||||
tool_result,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
|
||||
MessageContent::Thinking(ThinkingContent {
|
||||
thinking: thinking.into(),
|
||||
signature: signature.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
|
||||
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
|
||||
}
|
||||
|
||||
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
|
||||
if let MessageContent::ToolReq(ref tool_request) = self {
|
||||
Some(tool_request)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response(&self) -> Option<&ToolResponse> {
|
||||
if let MessageContent::ToolResp(ref tool_response) = self {
|
||||
Some(tool_response)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response_text(&self) -> Option<String> {
|
||||
if let Some(tool_response) = self.as_tool_response() {
|
||||
if let Ok(contents) = &tool_response.tool_result.0 {
|
||||
let texts: Vec<String> = contents
|
||||
.iter()
|
||||
.filter_map(|content| content.as_text().map(String::from))
|
||||
.collect();
|
||||
if !texts.is_empty() {
|
||||
return Some(texts.join("\n"));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn as_tool_request_id(&self) -> Option<&str> {
|
||||
if let Self::ToolReq(r) = self {
|
||||
Some(&r.id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response_id(&self) -> Option<&str> {
|
||||
if let Self::ToolResp(r) = self {
|
||||
Some(&r.id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the text content if this is a TextContent variant
|
||||
pub fn as_text(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessageContent::Text(text) => Some(&text.text),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the thinking content if this is a ThinkingContent variant
|
||||
pub fn as_thinking(&self) -> Option<&ThinkingContent> {
|
||||
match self {
|
||||
MessageContent::Thinking(thinking) => Some(thinking),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the redacted thinking content if this is a RedactedThinkingContent variant
|
||||
pub fn as_redacted_thinking(&self) -> Option<&RedactedThinkingContent> {
|
||||
match self {
|
||||
MessageContent::RedactedThinking(redacted) => Some(redacted),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_text(&self) -> bool {
|
||||
matches!(self, Self::Text(_))
|
||||
}
|
||||
pub fn is_image(&self) -> bool {
|
||||
matches!(self, Self::Image(_))
|
||||
}
|
||||
pub fn is_tool_request(&self) -> bool {
|
||||
matches!(self, Self::ToolReq(_))
|
||||
}
|
||||
pub fn is_tool_response(&self) -> bool {
|
||||
matches!(self, Self::ToolResp(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Content> for MessageContent {
|
||||
fn from(content: Content) -> Self {
|
||||
match content {
|
||||
Content::Text(text) => MessageContent::Text(text),
|
||||
Content::Image(image) => MessageContent::Image(image),
|
||||
}
|
||||
}
|
||||
}
|
||||
284
crates/goose-llm/src/message/mod.rs
Normal file
284
crates/goose-llm/src/message/mod.rs
Normal file
@@ -0,0 +1,284 @@
|
||||
//! Messages which represent the content sent back and forth to LLM provider
|
||||
//!
|
||||
//! We use these messages in the agent code, and interfaces which interact with
|
||||
//! the agent. That let's us reuse message histories across different interfaces.
|
||||
//!
|
||||
//! The content of the messages uses MCP types to avoid additional conversions
|
||||
//! when interacting with MCP servers.
|
||||
|
||||
mod contents;
|
||||
mod message_content;
|
||||
mod tool_result_serde;
|
||||
|
||||
pub use contents::Contents;
|
||||
pub use message_content::{
|
||||
MessageContent, RedactedThinkingContent, ThinkingContent, ToolRequest, ToolRequestToolCall,
|
||||
ToolResponse, ToolResponseToolResult,
|
||||
};
|
||||
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::types::core::Role;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
/// A message to or from an LLM
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub created: i64,
|
||||
pub content: Contents,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: Role) -> Self {
|
||||
Self {
|
||||
role,
|
||||
created: Utc::now().timestamp_millis(),
|
||||
content: Contents::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new user message with the current timestamp
|
||||
pub fn user() -> Self {
|
||||
Self::new(Role::User)
|
||||
}
|
||||
|
||||
/// Create a new assistant message with the current timestamp
|
||||
pub fn assistant() -> Self {
|
||||
Self::new(Role::Assistant)
|
||||
}
|
||||
|
||||
/// Add any item that implements Into<MessageContent> to the message
|
||||
pub fn with_content(mut self, item: impl Into<MessageContent>) -> Self {
|
||||
self.content.push(item);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add text content to the message
|
||||
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
|
||||
self.with_content(MessageContent::text(text))
|
||||
}
|
||||
|
||||
/// Add image content to the message
|
||||
pub fn with_image<S: Into<String>, T: Into<String>>(self, data: S, mime_type: T) -> Self {
|
||||
self.with_content(MessageContent::image(data, mime_type))
|
||||
}
|
||||
|
||||
/// Add a tool request to the message
|
||||
pub fn with_tool_request<S: Into<String>, T: Into<ToolRequestToolCall>>(
|
||||
self,
|
||||
id: S,
|
||||
tool_call: T,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::tool_request(id, tool_call.into()))
|
||||
}
|
||||
|
||||
/// Add a tool response to the message
|
||||
pub fn with_tool_response<S: Into<String>>(
|
||||
self,
|
||||
id: S,
|
||||
result: ToolResponseToolResult,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::tool_response(id, result))
|
||||
}
|
||||
|
||||
/// Add thinking content to the message
|
||||
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
|
||||
self,
|
||||
thinking: S1,
|
||||
signature: S2,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::thinking(thinking, signature))
|
||||
}
|
||||
|
||||
/// Add redacted thinking content to the message
|
||||
pub fn with_redacted_thinking<S: Into<String>>(self, data: S) -> Self {
|
||||
self.with_content(MessageContent::redacted_thinking(data))
|
||||
}
|
||||
|
||||
/// Check if the message is a tool call
|
||||
pub fn contains_tool_call(&self) -> bool {
|
||||
self.content.any_is(MessageContent::is_tool_request)
|
||||
}
|
||||
|
||||
/// Check if the message is a tool response
|
||||
pub fn contains_tool_response(&self) -> bool {
|
||||
self.content.any_is(MessageContent::is_tool_response)
|
||||
}
|
||||
|
||||
/// Check if the message contains only text content
|
||||
pub fn has_only_text_content(&self) -> bool {
|
||||
self.content.all_are(MessageContent::is_text)
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from ToolRequest messages
|
||||
pub fn tool_request_ids(&self) -> HashSet<&str> {
|
||||
self.content
|
||||
.iter()
|
||||
.filter_map(MessageContent::as_tool_request_id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from ToolResponse messages
|
||||
pub fn tool_response_ids(&self) -> HashSet<&str> {
|
||||
self.content
|
||||
.iter()
|
||||
.filter_map(MessageContent::as_tool_response_id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from the message
|
||||
pub fn tool_ids(&self) -> HashSet<&str> {
|
||||
self.tool_request_ids()
|
||||
.into_iter()
|
||||
.chain(self.tool_response_ids())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use serde_json::{json, Value};

    use super::*;
    use crate::types::core::{ToolCall, ToolError};

    // Round-trips a message with text + successful tool request through
    // serde_json and checks the camelCase wire format.
    #[test]
    fn test_message_serialization() {
        let message = Message::assistant()
            .with_text("Hello, I'll help you with that.")
            .with_tool_request(
                "tool123",
                Ok(ToolCall::new("test_tool", json!({"param": "value"})).into()),
            );

        let json_str = serde_json::to_string_pretty(&message).unwrap();
        println!("Serialized message: {}", json_str);

        // Parse back to Value to check structure
        let value: Value = serde_json::from_str(&json_str).unwrap();
        println!(
            "Read back serialized message: {}",
            serde_json::to_string_pretty(&value).unwrap()
        );

        // Check top-level fields
        assert_eq!(value["role"], "assistant");
        assert!(value["created"].is_i64());
        assert!(value["content"].is_array());

        // Check content items
        let content = &value["content"];

        // First item should be text
        assert_eq!(content[0]["type"], "text");
        assert_eq!(content[0]["text"], "Hello, I'll help you with that.");

        // Second item should be toolRequest
        assert_eq!(content[1]["type"], "toolReq");
        assert_eq!(content[1]["id"], "tool123");

        // Check tool_call serialization
        assert_eq!(content[1]["toolCall"]["status"], "success");
        assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
        assert_eq!(
            content[1]["toolCall"]["value"]["arguments"]["param"],
            "value"
        );
    }

    // Errors serialize as { "status": "error", "error": "<display string>" }.
    #[test]
    fn test_error_serialization() {
        let message = Message::assistant().with_tool_request(
            "tool123",
            Err(ToolError::ExecutionError(
                "Something went wrong".to_string(),
            )),
        );

        let json_str = serde_json::to_string_pretty(&message).unwrap();
        println!("Serialized error: {}", json_str);

        // Parse back to Value to check structure
        let value: Value = serde_json::from_str(&json_str).unwrap();

        // Check tool_call serialization with error
        let tool_call = &value["content"][0]["toolCall"];
        assert_eq!(tool_call["status"], "error");
        assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
    }

    // Deserializes a hand-written JSON document in the wire format back into
    // a typed Message (text + successful tool request).
    #[test]
    fn test_deserialization() {
        // Create a JSON string with our new format
        let json_str = r#"{
            "role": "assistant",
            "created": 1740171566,
            "content": [
                {
                    "type": "text",
                    "text": "I'll help you with that."
                },
                {
                    "type": "toolReq",
                    "id": "tool123",
                    "toolCall": {
                        "status": "success",
                        "value": {
                            "name": "test_tool",
                            "arguments": {"param": "value"},
                            "needsApproval": false
                        }
                    }
                }
            ]
        }"#;

        let message: Message = serde_json::from_str(json_str).unwrap();

        assert_eq!(message.role, Role::Assistant);
        assert_eq!(message.created, 1740171566);
        assert_eq!(message.content.len(), 2);

        // Check first content item
        if let MessageContent::Text(text) = &message.content[0] {
            assert_eq!(text.text, "I'll help you with that.");
        } else {
            panic!("Expected Text content");
        }

        // Check second content item
        if let MessageContent::ToolReq(req) = &message.content[1] {
            assert_eq!(req.id, "tool123");
            if let Ok(tool_call) = req.tool_call.as_result() {
                assert_eq!(tool_call.name, "test_tool");
                assert_eq!(tool_call.arguments, json!({"param": "value"}));
            } else {
                panic!("Expected successful tool call");
            }
        } else {
            panic!("Expected ToolRequest content");
        }
    }

    #[test]
    fn test_message_with_text() {
        let message = Message::user().with_text("Hello");
        assert_eq!(message.content.concat_text_str(), "Hello");
    }

    // tool_ids() should surface the id of a request-only message.
    #[test]
    fn test_message_with_tool_request() {
        let tool_call = Ok(ToolCall::new("test_tool", json!({})));

        let message = Message::assistant().with_tool_request("req1", tool_call);
        assert!(message.contains_tool_call());
        assert!(!message.contains_tool_response());

        let ids = message.tool_ids();
        assert_eq!(ids.len(), 1);
        assert!(ids.contains("req1"));
    }
}
|
||||
64
crates/goose-llm/src/message/tool_result_serde.rs
Normal file
64
crates/goose-llm/src/message/tool_result_serde.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
use crate::types::core::{ToolError, ToolResult};
|
||||
|
||||
pub fn serialize<T, S>(value: &ToolResult<T>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
T: Serialize,
|
||||
S: Serializer,
|
||||
{
|
||||
match value {
|
||||
Ok(val) => {
|
||||
let mut state = serializer.serialize_struct("ToolResult", 2)?;
|
||||
state.serialize_field("status", "success")?;
|
||||
state.serialize_field("value", val)?;
|
||||
state.end()
|
||||
}
|
||||
Err(err) => {
|
||||
let mut state = serializer.serialize_struct("ToolResult", 2)?;
|
||||
state.serialize_field("status", "error")?;
|
||||
state.serialize_field("error", &err.to_string())?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For deserialization, let's use a simpler approach that works with the format we're serializing to
/// Deserialize a `ToolResult<T>` from the status-tagged format produced by
/// `serialize` above.
///
/// Note: every error deserializes to `ToolError::ExecutionError`, so the
/// round-trip is lossy with respect to the original error variant.
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<ToolResult<T>, D::Error>
where
    T: Deserialize<'de>,
    D: Deserializer<'de>,
{
    // Define a helper enum to handle the two possible formats.
    // `untagged` makes serde try the variants in declaration order: a payload
    // with a `value` field matches Success, one with an `error` field matches
    // Error. The `status` string is then cross-checked explicitly below.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum ResultFormat<T> {
        Success { status: String, value: T },
        Error { status: String, error: String },
    }

    let format = ResultFormat::deserialize(deserializer)?;

    match format {
        ResultFormat::Success { status, value } => {
            if status == "success" {
                Ok(Ok(value))
            } else {
                // Shape said Success but the tag disagrees — reject.
                Err(serde::de::Error::custom(format!(
                    "Expected status 'success', got '{}'",
                    status
                )))
            }
        }
        ResultFormat::Error { status, error } => {
            if status == "error" {
                Ok(Err(ToolError::ExecutionError(error)))
            } else {
                // Shape said Error but the tag disagrees — reject.
                Err(serde::de::Error::custom(format!(
                    "Expected status 'error', got '{}'",
                    status
                )))
            }
        }
    }
}
|
||||
119
crates/goose-llm/src/model.rs
Normal file
119
crates/goose-llm/src/model.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Fallback context window size (tokens) used by `ModelConfig::context_limit`
// when neither the caller nor the model-name lookup provides a value.
const DEFAULT_CONTEXT_LIMIT: u32 = 128_000;

/// Configuration for model-specific settings and limits
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct ModelConfig {
    /// The name of the model to use
    pub model_name: String,
    /// Optional explicit context limit that overrides any defaults
    pub context_limit: Option<u32>,
    /// Optional temperature setting (0.0 - 1.0)
    pub temperature: Option<f32>,
    /// Optional maximum tokens to generate
    pub max_tokens: Option<i32>,
}
|
||||
|
||||
impl ModelConfig {
|
||||
/// Create a new ModelConfig with the specified model name
|
||||
///
|
||||
/// The context limit is set with the following precedence:
|
||||
/// 1. Explicit context_limit if provided in config
|
||||
/// 2. Model-specific default based on model name
|
||||
/// 3. Global default (128_000) (in get_context_limit)
|
||||
pub fn new(model_name: String) -> Self {
|
||||
let context_limit = Self::get_model_specific_limit(&model_name);
|
||||
|
||||
Self {
|
||||
model_name,
|
||||
context_limit,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get model-specific context limit based on model name
|
||||
fn get_model_specific_limit(model_name: &str) -> Option<u32> {
|
||||
// Implement some sensible defaults
|
||||
match model_name {
|
||||
// OpenAI models, https://platform.openai.com/docs/models#models-overview
|
||||
name if name.contains("gpt-4o") => Some(128_000),
|
||||
name if name.contains("gpt-4-turbo") => Some(128_000),
|
||||
|
||||
// Anthropic models, https://docs.anthropic.com/en/docs/about-claude/models
|
||||
name if name.contains("claude-3") => Some(200_000),
|
||||
name if name.contains("claude-4") => Some(200_000),
|
||||
|
||||
// Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1
|
||||
name if name.contains("llama3.2") => Some(128_000),
|
||||
name if name.contains("llama3.3") => Some(128_000),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set an explicit context limit
|
||||
pub fn with_context_limit(mut self, limit: Option<u32>) -> Self {
|
||||
// Default is None and therefore DEFAULT_CONTEXT_LIMIT, only set
|
||||
// if input is Some to allow passing through with_context_limit in
|
||||
// configuration cases
|
||||
if limit.is_some() {
|
||||
self.context_limit = limit;
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the temperature
|
||||
pub fn with_temperature(mut self, temp: Option<f32>) -> Self {
|
||||
self.temperature = temp;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the max tokens
|
||||
pub fn with_max_tokens(mut self, tokens: Option<i32>) -> Self {
|
||||
self.max_tokens = tokens;
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the context_limit for the current model
|
||||
/// If none are defined, use the DEFAULT_CONTEXT_LIMIT
|
||||
pub fn context_limit(&self) -> u32 {
|
||||
self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Covers the three-level precedence: explicit limit > name-based default
    // > global DEFAULT_CONTEXT_LIMIT.
    #[test]
    fn test_model_config_context_limits() {
        // Test explicit limit
        let config =
            ModelConfig::new("claude-3-opus".to_string()).with_context_limit(Some(150_000));
        assert_eq!(config.context_limit(), 150_000);

        // Test model-specific defaults
        let config = ModelConfig::new("claude-3-opus".to_string());
        assert_eq!(config.context_limit(), 200_000);

        let config = ModelConfig::new("gpt-4-turbo".to_string());
        assert_eq!(config.context_limit(), 128_000);

        // Test fallback to default
        let config = ModelConfig::new("unknown-model".to_string());
        assert_eq!(config.context_limit(), DEFAULT_CONTEXT_LIMIT);
    }

    // Builder setters should store exactly what was passed.
    #[test]
    fn test_model_config_settings() {
        let config = ModelConfig::new("test-model".to_string())
            .with_temperature(Some(0.7))
            .with_max_tokens(Some(1000))
            .with_context_limit(Some(50_000));

        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_tokens, Some(1000));
        assert_eq!(config.context_limit, Some(50_000));
    }
}
|
||||
115
crates/goose-llm/src/prompt_template.rs
Normal file
115
crates/goose-llm/src/prompt_template.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
use std::{
|
||||
path::PathBuf,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use include_dir::{include_dir, Dir};
|
||||
use minijinja::{Environment, Error as MiniJinjaError, Value as MJValue};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::Serialize;
|
||||
|
||||
/// This directory will be embedded into the final binary.
/// Typically used to store "core" or "system" prompts.
/// Every file in it is registered as a template in `GLOBAL_ENV` at startup.
static CORE_PROMPTS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/prompts");
|
||||
|
||||
/// A global MiniJinja environment storing the "core" prompts.
|
||||
///
|
||||
/// - Loaded at startup from the `CORE_PROMPTS_DIR`.
|
||||
/// - Ideal for "system" templates that don't change often.
|
||||
/// - *Not* used for extension prompts (which are ephemeral).
|
||||
static GLOBAL_ENV: Lazy<Arc<RwLock<Environment<'static>>>> = Lazy::new(|| {
|
||||
let mut env = Environment::new();
|
||||
|
||||
// Pre-load all core templates from the embedded dir.
|
||||
for file in CORE_PROMPTS_DIR.files() {
|
||||
let name = file.path().to_string_lossy().to_string();
|
||||
let source = String::from_utf8_lossy(file.contents()).to_string();
|
||||
|
||||
// Since we're using 'static lifetime for the Environment, we need to ensure
|
||||
// the strings we add as templates live for the entire program duration.
|
||||
// We can achieve this by leaking the strings (acceptable for initialization).
|
||||
let static_name: &'static str = Box::leak(name.into_boxed_str());
|
||||
let static_source: &'static str = Box::leak(source.into_boxed_str());
|
||||
|
||||
if let Err(e) = env.add_template(static_name, static_source) {
|
||||
println!("Failed to add template {}: {}", static_name, e);
|
||||
}
|
||||
}
|
||||
|
||||
Arc::new(RwLock::new(env))
|
||||
});
|
||||
|
||||
/// Renders a prompt from the global environment by name.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `template_name` - The name of the template (usually the file path or a custom ID).
|
||||
/// * `context_data` - Data to be inserted into the template (must be `Serialize`).
|
||||
pub fn render_global_template<T: Serialize>(
|
||||
template_name: &str,
|
||||
context_data: &T,
|
||||
) -> Result<String, MiniJinjaError> {
|
||||
let env = GLOBAL_ENV.read().expect("GLOBAL_ENV lock poisoned");
|
||||
let tmpl = env.get_template(template_name)?;
|
||||
let ctx = MJValue::from_serialize(context_data);
|
||||
let rendered = tmpl.render(ctx)?;
|
||||
Ok(rendered.trim().to_string())
|
||||
}
|
||||
|
||||
/// Renders a file from `CORE_PROMPTS_DIR` within the global environment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `template_file` - The file path within the embedded directory (e.g. "system.md").
|
||||
/// * `context_data` - Data to be inserted into the template (must be `Serialize`).
|
||||
///
|
||||
/// This function **assumes** the file is already in `CORE_PROMPTS_DIR`. If it wasn't
|
||||
/// added to the global environment at startup (due to parse errors, etc.), this will error out.
|
||||
pub fn render_global_file<T: Serialize>(
|
||||
template_file: impl Into<PathBuf>,
|
||||
context_data: &T,
|
||||
) -> Result<String, MiniJinjaError> {
|
||||
let file_path = template_file.into();
|
||||
let template_name = file_path.to_string_lossy().to_string();
|
||||
|
||||
render_global_template(&template_name, context_data)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// For convenience in tests, define a small struct or use a HashMap to provide context.
    #[derive(Serialize)]
    struct TestContext {
        name: String,
        age: u32,
    }

    // Happy path: an embedded template renders with the given context.
    #[test]
    fn test_global_file_render() {
        // "mock.md" should exist in the embedded CORE_PROMPTS_DIR
        // and have placeholders for `name` and `age`.
        let context = TestContext {
            name: "Alice".to_string(),
            age: 30,
        };

        let result = render_global_file("mock.md", &context).unwrap();
        // Assume mock.md content is something like:
        // "This prompt is only used for testing.\n\nHello, {{ name }}! You are {{ age }} years old."
        assert_eq!(
            result,
            "This prompt is only used for testing.\n\nHello, Alice! You are 30 years old."
        );
    }

    // Requesting a template that was never embedded must surface an error.
    #[test]
    fn test_global_file_not_found() {
        let context = TestContext {
            name: "Unused".to_string(),
            age: 99,
        };

        let result = render_global_file("non_existent.md", &context);
        assert!(result.is_err(), "Should fail because file is missing");
    }
}
|
||||
3
crates/goose-llm/src/prompts/mock.md
Normal file
3
crates/goose-llm/src/prompts/mock.md
Normal file
@@ -0,0 +1,3 @@
|
||||
This prompt is only used for testing.
|
||||
|
||||
Hello, {{ name }}! You are {{ age }} years old.
|
||||
34
crates/goose-llm/src/prompts/system.md
Normal file
34
crates/goose-llm/src/prompts/system.md
Normal file
@@ -0,0 +1,34 @@
|
||||
{{system_preamble}}
|
||||
|
||||
The current date is {{current_date}}.
|
||||
|
||||
Goose uses LLM providers with tool calling capability. You can be used with different language models (gpt-4o, claude-3.5-sonnet, o1, llama-3.2, deepseek-r1, etc).
|
||||
These models have varying knowledge cut-off dates depending on when they were trained, but typically it's between 5-10 months prior to the current date.
|
||||
|
||||
# Extensions
|
||||
|
||||
Extensions allow other applications to provide context to Goose. Extensions connect Goose to different data sources and tools.
|
||||
|
||||
{% if (extensions is defined) and extensions %}
|
||||
Because you dynamically load extensions, your conversation history may refer
|
||||
to interactions with extensions that are not currently active. The currently
|
||||
active extensions are below. Each of these extensions provides tools that are
|
||||
in your tool specification.
|
||||
|
||||
{% for extension in extensions %}
|
||||
## {{extension.name}}
|
||||
{% if extension.instructions %}### Instructions
|
||||
{{extension.instructions}}{% endif %}
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
No extensions are defined. You should let the user know that they should add extensions.{% endif %}
|
||||
|
||||
# Response Guidelines
|
||||
|
||||
- Use Markdown formatting for all responses.
|
||||
- Follow best practices for Markdown, including:
|
||||
- Using headers for organization.
|
||||
- Bullet points for lists.
|
||||
- Links formatted correctly, either as linked text (e.g., [this is linked text](https://example.com)) or automatic links using angle brackets (e.g., <http://example.com/>).
|
||||
- For code examples, use fenced code blocks by placing triple backticks (` ``` `) before and after the code. Include the language identifier after the opening backticks (e.g., ` ```python `) to enable syntax highlighting.
|
||||
- Ensure clarity, conciseness, and proper formatting to enhance readability and usability.
|
||||
131
crates/goose-llm/src/providers/base.rs
Normal file
131
crates/goose-llm/src/providers/base.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::errors::ProviderError;
|
||||
use crate::{message::Message, types::core::Tool};
|
||||
|
||||
/// Token accounting reported by a provider for a single API call.
/// All fields are optional because not every provider/response includes them.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, uniffi::Record)]
pub struct Usage {
    /// Tokens consumed by the prompt/input, if reported.
    pub input_tokens: Option<i32>,
    /// Tokens generated in the response, if reported.
    pub output_tokens: Option<i32>,
    /// Total tokens for the call, if reported.
    pub total_tokens: Option<i32>,
}
|
||||
|
||||
impl Usage {
|
||||
pub fn new(
|
||||
input_tokens: Option<i32>,
|
||||
output_tokens: Option<i32>,
|
||||
total_tokens: Option<i32>,
|
||||
) -> Self {
|
||||
Self {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
total_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a chat-completion call: the generated message plus the model
/// name and token usage reported by the provider.
#[derive(Debug, Clone, uniffi::Record)]
pub struct ProviderCompleteResponse {
    /// The assistant message produced by the model.
    pub message: Message,
    /// Name of the model that produced the response.
    pub model: String,
    /// Token usage statistics for this call.
    pub usage: Usage,
}
|
||||
|
||||
impl ProviderCompleteResponse {
|
||||
pub fn new(message: Message, model: String, usage: Usage) -> Self {
|
||||
Self {
|
||||
message,
|
||||
model,
|
||||
usage,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response from a structured-extraction call
#[derive(Debug, Clone, uniffi::Record)]
pub struct ProviderExtractResponse {
    /// The extracted JSON object
    pub data: serde_json::Value,
    /// Which model produced it
    pub model: String,
    /// Token usage stats
    pub usage: Usage,
}
|
||||
|
||||
impl ProviderExtractResponse {
|
||||
pub fn new(data: serde_json::Value, model: String, usage: Usage) -> Self {
|
||||
Self { data, model, usage }
|
||||
}
|
||||
}
|
||||
|
||||
/// Base trait for AI providers (OpenAI, Anthropic, etc)
///
/// Implementations must be `Send + Sync` so a provider can be shared across
/// async tasks.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Generate the next message using the configured model and other parameters
    ///
    /// # Arguments
    /// * `system` - The system prompt that guides the model's behavior
    /// * `messages` - The conversation history as a sequence of messages
    /// * `tools` - Optional list of tools the model can use
    ///
    /// # Returns
    /// A tuple containing the model's response message and provider usage statistics
    ///
    /// # Errors
    /// ProviderError
    /// - It's important to raise ContextLengthExceeded correctly since agent handles it
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<ProviderCompleteResponse, ProviderError>;

    /// Structured extraction: always JSON-Schema
    ///
    /// # Arguments
    /// * `system` – system prompt guiding the extraction task
    /// * `messages` – conversation history
    /// * `schema` – a JSON-Schema for the expected output.
    ///   Will set strict=true for OpenAI & Databricks.
    ///
    /// # Returns
    /// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
    ///
    /// # Errors
    /// * `ProviderError::ContextLengthExceeded` if the prompt is too large
    /// * other `ProviderError` variants for API/network failures
    async fn extract(
        &self,
        system: &str,
        messages: &[Message],
        schema: &serde_json::Value,
    ) -> Result<ProviderExtractResponse, ProviderError>;
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Usage::new should store each count unchanged.
    #[test]
    fn test_usage_creation() {
        let usage = Usage::new(Some(10), Some(20), Some(30));
        assert_eq!(usage.input_tokens, Some(10));
        assert_eq!(usage.output_tokens, Some(20));
        assert_eq!(usage.total_tokens, Some(30));
    }

    // ProviderCompleteResponse::new should carry its parts through untouched.
    #[test]
    fn test_provider_complete_response_creation() {
        let message = Message::user().with_text("Hello, world!")
        let usage = Usage::new(Some(10), Some(20), Some(30));
        let response =
            ProviderCompleteResponse::new(message.clone(), "test_model".to_string(), usage.clone());

        assert_eq!(response.message, message);
        assert_eq!(response.model, "test_model");
        assert_eq!(response.usage, usage);
    }
}
|
||||
298
crates/goose-llm/src/providers/databricks.rs
Normal file
298
crates/goose-llm/src/providers/databricks.rs
Normal file
@@ -0,0 +1,298 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use url::Url;
|
||||
|
||||
use super::{
|
||||
errors::ProviderError,
|
||||
formats::databricks::{create_request, get_usage, response_to_message},
|
||||
utils::{get_env, get_model, ImageFormat},
|
||||
};
|
||||
use crate::{
|
||||
message::Message,
|
||||
model::ModelConfig,
|
||||
providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
|
||||
types::core::Tool,
|
||||
};
|
||||
|
||||
/// Model used when the caller does not specify one (see `Default for DatabricksProvider`).
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
// Databricks can passthrough to a wide range of models, we only provide the default
pub const _DATABRICKS_KNOWN_MODELS: &[&str] = &[
    "databricks-meta-llama-3-3-70b-instruct",
    "databricks-claude-3-7-sonnet",
];
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
60
|
||||
}
|
||||
|
||||
/// Connection settings for a Databricks model-serving workspace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabricksProviderConfig {
    /// Workspace base URL (parsed with `Url::parse` when building requests).
    pub host: String,
    /// Access token sent as a `Bearer` credential.
    pub token: String,
    /// How images are encoded in requests; defaults to the OpenAI format.
    #[serde(default)]
    pub image_format: ImageFormat,
    #[serde(default = "default_timeout")]
    pub timeout: u64, // timeout in seconds
}
|
||||
|
||||
impl DatabricksProviderConfig {
    /// Build a config for `host`/`token` with OpenAI-style image encoding and
    /// the default request timeout.
    pub fn new(host: String, token: String) -> Self {
        Self {
            host,
            token,
            image_format: ImageFormat::OpenAi,
            timeout: default_timeout(),
        }
    }

    /// Read `DATABRICKS_HOST` and `DATABRICKS_TOKEN` from the environment.
    ///
    /// # Panics
    /// Panics if either variable is missing.
    pub fn from_env() -> Self {
        let host = get_env("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST");
        let token = get_env("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN");
        Self::new(host, token)
    }
}
|
||||
|
||||
/// `Provider` implementation backed by Databricks model-serving endpoints.
#[derive(Debug)]
pub struct DatabricksProvider {
    /// Workspace connection settings (host, token, timeout, image format).
    config: DatabricksProviderConfig,
    /// Model settings; the model name selects the serving endpoint.
    model: ModelConfig,
    /// Reusable HTTP client configured with the request timeout.
    client: Client,
}
|
||||
|
||||
impl DatabricksProvider {
    /// Build a provider for `model` from `DATABRICKS_HOST`/`DATABRICKS_TOKEN`.
    ///
    /// # Panics
    /// Panics if the environment variables are missing or the HTTP client
    /// cannot be constructed.
    pub fn from_env(model: ModelConfig) -> Self {
        let config = DatabricksProviderConfig::from_env();
        DatabricksProvider::from_config(config, model)
            .expect("Failed to initialize DatabricksProvider")
    }
}
|
||||
|
||||
impl Default for DatabricksProvider {
    /// Environment-configured provider using `DATABRICKS_DEFAULT_MODEL`.
    ///
    /// # Panics
    /// Panics like `from_env`: missing env vars or client construction failure.
    fn default() -> Self {
        let config = DatabricksProviderConfig::from_env();
        let model = ModelConfig::new(DATABRICKS_DEFAULT_MODEL.to_string())
        DatabricksProvider::from_config(config, model)
            .expect("Failed to initialize DatabricksProvider")
    }
}
|
||||
|
||||
impl DatabricksProvider {
    /// Build a provider from an explicit config, constructing an HTTP client
    /// with the configured timeout.
    pub fn from_config(config: DatabricksProviderConfig, model: ModelConfig) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(config.timeout))
            .build()?;

        Ok(Self {
            config,
            model,
            client,
        })
    }

    /// POST `payload` to the serving endpoint for the configured model and
    /// map HTTP status codes onto `ProviderError` variants.
    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        let base_url = Url::parse(&self.config.host)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        // Databricks routes by model name in the URL path.
        let path = format!("serving-endpoints/{}/invocations", self.model.model_name);
        let url = base_url.join(&path).map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        let auth_header = format!("Bearer {}", &self.config.token);
        let response = self
            .client
            .post(url)
            .header("Authorization", auth_header)
            .json(&payload)
            .send()
            .await?;

        // Body parsing is best-effort: a non-JSON body becomes None and is
        // handled per status below.
        let status = response.status();
        let payload: Option<Value> = response.json().await.ok();

        match status {
            StatusCode::OK => payload.ok_or_else(|| {
                ProviderError::RequestFailed("Response body is not valid JSON".to_string())
            }),
            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
                Err(ProviderError::Authentication(format!(
                    "Authentication failed. Please ensure your API keys are valid and have the required permissions. \
                    Status: {}. Response: {:?}",
                    status, payload
                )))
            }
            StatusCode::BAD_REQUEST => {
                // Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
                // We try to extract the error message from the payload and check for phrases that indicate context length exceeded
                let payload_str = serde_json::to_string(&payload)
                    .unwrap_or_default()
                    .to_lowercase();
                // Heuristic markers emitted by the various upstream models when
                // the prompt exceeds the context window.
                let check_phrases = [
                    "too long",
                    "context length",
                    "context_length_exceeded",
                    "reduce the length",
                    "token count",
                    "exceeds",
                ];
                if check_phrases.iter().any(|c| payload_str.contains(c)) {
                    return Err(ProviderError::ContextLengthExceeded(payload_str));
                }

                let mut error_msg = "Unknown error".to_string();
                if let Some(payload) = &payload {
                    // try to convert message to string, if that fails use external_model_message
                    error_msg = payload
                        .get("message")
                        .and_then(|m| m.as_str())
                        .or_else(|| {
                            payload
                                .get("external_model_message")
                                .and_then(|ext| ext.get("message"))
                                .and_then(|m| m.as_str())
                        })
                        .unwrap_or("Unknown error")
                        .to_string();
                }

                tracing::debug!(
                    "{}",
                    format!(
                        "Provider request failed with status: {}. Payload: {:?}",
                        status, payload
                    )
                );
                Err(ProviderError::RequestFailed(format!(
                    "Request failed with status: {}. Message: {}",
                    status, error_msg
                )))
            }
            StatusCode::TOO_MANY_REQUESTS => {
                Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
            }
            StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
                Err(ProviderError::ServerError(format!("{:?}", payload)))
            }
            _ => {
                // Any other status is reported generically.
                tracing::debug!(
                    "{}",
                    format!(
                        "Provider request failed with status: {}. Payload: {:?}",
                        status, payload
                    )
                );
                Err(ProviderError::RequestFailed(format!(
                    "Request failed with status: {}",
                    status
                )))
            }
        }
    }
}
|
||||
|
||||
#[async_trait]
impl Provider for DatabricksProvider {
    /// Run a chat completion against the Databricks serving endpoint.
    ///
    /// Builds an OpenAI-style request from `system`, `messages`, and `tools`,
    /// strips the `model` key (Databricks encodes the model in the URL), posts
    /// it, and parses the reply into a message plus usage/model metadata.
    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<ProviderCompleteResponse, ProviderError> {
        let mut payload = create_request(
            &self.model,
            system,
            messages,
            tools,
            &self.config.image_format,
        )?;
        // Remove the model key which is part of the url with databricks
        payload
            .as_object_mut()
            .expect("payload should have model key")
            .remove("model");

        let response = self.post(payload.clone()).await?;

        // Parse response
        let message = response_to_message(response.clone())?;
        // Missing usage data is tolerated (logged + defaulted); any other
        // usage error is propagated to the caller.
        let usage = match get_usage(&response) {
            Ok(usage) => usage,
            Err(ProviderError::UsageError(e)) => {
                tracing::debug!("Failed to get usage data: {}", e);
                Usage::default()
            }
            Err(e) => return Err(e),
        };
        let model = get_model(&response);
        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);

        Ok(ProviderCompleteResponse::new(message, model, usage))
    }

    /// Run a structured-extraction completion: the model is forced (via
    /// `response_format` with a strict JSON Schema) to return JSON matching
    /// `schema`, which is parsed and returned as a `serde_json::Value`.
    async fn extract(
        &self,
        system: &str,
        messages: &[Message],
        schema: &Value,
    ) -> Result<ProviderExtractResponse, ProviderError> {
        // 1. Build base payload (no tools)
        let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
        // NOTE(review): unlike complete(), the "model" key is NOT removed from
        // the payload here — confirm the Databricks endpoint tolerates it.

        // 2. Inject strict JSON-Schema wrapper
        payload
            .as_object_mut()
            .expect("payload must be an object")
            .insert(
                "response_format".to_string(),
                json!({
                    "type": "json_schema",
                    "json_schema": {
                        "name": "extraction",
                        "schema": schema,
                        "strict": true
                    }
                }),
            );

        // 3. Call OpenAI
        let response = self.post(payload.clone()).await?;

        // 4. Extract the assistant's `content` and parse it into JSON.
        //    Content may come back as a JSON string (parse it) or as an
        //    already-structured object/array (use it directly).
        let msg = &response["choices"][0]["message"];
        let raw = msg.get("content").cloned().ok_or_else(|| {
            ProviderError::ResponseParseError("Missing content in extract response".into())
        })?;
        let data = match raw {
            Value::String(s) => serde_json::from_str(&s)
                .map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
            Value::Object(_) | Value::Array(_) => raw,
            other => {
                return Err(ProviderError::ResponseParseError(format!(
                    "Unexpected content type: {:?}",
                    other
                )))
            }
        };

        // 5. Gather usage & model info (missing usage tolerated, as above)
        let usage = match get_usage(&response) {
            Ok(u) => u,
            Err(ProviderError::UsageError(e)) => {
                tracing::debug!("Failed to get usage in extract: {}", e);
                Usage::default()
            }
            Err(e) => return Err(e),
        };
        let model = get_model(&response);

        Ok(ProviderExtractResponse::new(data, model, usage))
    }
}
|
||||
144
crates/goose-llm/src/providers/errors.rs
Normal file
144
crates/goose-llm/src/providers/errors.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors surfaced by LLM provider implementations.
#[derive(Error, Debug, uniffi::Error)]
pub enum ProviderError {
    /// The provider rejected the request's credentials.
    #[error("Authentication error: {0}")]
    Authentication(String),

    /// The prompt exceeded the model's context window.
    #[error("Context length exceeded: {0}")]
    ContextLengthExceeded(String),

    /// The provider throttled the request (mapped from HTTP 429).
    #[error("Rate limit exceeded: {0}")]
    RateLimitExceeded(String),

    /// The provider failed server-side (mapped from HTTP 500/503).
    #[error("Server error: {0}")]
    ServerError(String),

    /// The request failed for any other reason (non-success status).
    #[error("Request failed: {0}")]
    RequestFailed(String),

    /// A local or transport-level failure while executing the request
    /// (also the catch-all for `anyhow`/`reqwest` errors via `From`).
    #[error("Execution error: {0}")]
    ExecutionError(String),

    /// The response lacked, or had malformed, token-usage data.
    #[error("Usage data error: {0}")]
    UsageError(String),

    /// The response body could not be parsed into the expected shape.
    #[error("Invalid response: {0}")]
    ResponseParseError(String),
}
|
||||
|
||||
impl From<anyhow::Error> for ProviderError {
|
||||
fn from(error: anyhow::Error) -> Self {
|
||||
ProviderError::ExecutionError(error.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for ProviderError {
|
||||
fn from(error: reqwest::Error) -> Self {
|
||||
ProviderError::ExecutionError(error.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// The `error` object returned by OpenAI-compatible APIs.
#[derive(serde::Deserialize, Debug)]
pub struct OpenAIError {
    /// Machine-readable error code; the API may send it as a string, a
    /// number, or null, so a custom deserializer normalizes it to a String.
    #[serde(deserialize_with = "code_as_string")]
    pub code: Option<String>,
    /// Human-readable description of the failure.
    pub message: Option<String>,
    /// Error category; serialized as `type` (renamed because `type` is a
    /// Rust keyword).
    #[serde(rename = "type")]
    pub error_type: Option<String>,
}
|
||||
|
||||
/// Deserialize the OpenAI error `code` field — which may arrive as a string,
/// a number, or null — into a normalized `Option<String>`.
fn code_as_string<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use std::fmt;

    use serde::de::{self, Visitor};

    // Visitor that accepts string / u64 / null-ish input and converts each
    // to an Option<String>.
    struct CodeVisitor;

    impl<'de> Visitor<'de> for CodeVisitor {
        type Value = Option<String>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a string, a number, null, or none for the code field")
        }

        // String codes pass through unchanged.
        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(Some(value.to_string()))
        }

        // Numeric codes are stringified.
        fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(Some(value.to_string()))
        }

        // Absent field -> None.
        fn visit_none<E>(self) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(None)
        }

        // Explicit JSON null -> None.
        fn visit_unit<E>(self) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(None)
        }

        // `Some(...)` wrapper: recurse with deserialize_any so the inner
        // value can still be a string or number.
        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_any(CodeVisitor)
        }
    }

    // deserialize_option routes present values to visit_some and
    // missing/null values to visit_none/visit_unit.
    deserializer.deserialize_option(CodeVisitor)
}
|
||||
|
||||
impl OpenAIError {
|
||||
pub fn is_context_length_exceeded(&self) -> bool {
|
||||
if let Some(code) = &self.code {
|
||||
code == "context_length_exceeded" || code == "string_above_max_length"
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for OpenAIError {
|
||||
/// Format the error for display.
|
||||
/// E.g. {"message": "Invalid API key", "code": "invalid_api_key", "type": "client_error"}
|
||||
/// would be formatted as "Invalid API key (code: invalid_api_key, type: client_error)"
|
||||
/// and {"message": "Foo"} as just "Foo", etc.
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if let Some(message) = &self.message {
|
||||
write!(f, "{}", message)?;
|
||||
}
|
||||
let mut in_parenthesis = false;
|
||||
if let Some(code) = &self.code {
|
||||
write!(f, " (code: {}", code)?;
|
||||
in_parenthesis = true;
|
||||
}
|
||||
if let Some(typ) = &self.error_type {
|
||||
if in_parenthesis {
|
||||
write!(f, ", type: {}", typ)?;
|
||||
} else {
|
||||
write!(f, " (type: {}", typ)?;
|
||||
in_parenthesis = true;
|
||||
}
|
||||
}
|
||||
if in_parenthesis {
|
||||
write!(f, ")")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
29
crates/goose-llm/src/providers/factory.rs
Normal file
29
crates/goose-llm/src/providers/factory.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
use super::{
|
||||
base::Provider,
|
||||
databricks::{DatabricksProvider, DatabricksProviderConfig},
|
||||
openai::{OpenAiProvider, OpenAiProviderConfig},
|
||||
};
|
||||
use crate::model::ModelConfig;
|
||||
|
||||
/// Instantiate a provider implementation by name.
///
/// `provider_config` is deserialized into the provider-specific config type
/// before construction. Returns an error for unknown provider names or for a
/// config that does not match the provider's expected shape.
pub fn create(
    name: &str,
    provider_config: serde_json::Value,
    model: ModelConfig,
) -> Result<Arc<dyn Provider>> {
    // We use Arc instead of Box to be able to clone for multiple async tasks
    match name {
        "openai" => {
            let config: OpenAiProviderConfig = serde_json::from_value(provider_config)?;
            Ok(Arc::new(OpenAiProvider::from_config(config, model)?))
        }
        "databricks" => {
            let config: DatabricksProviderConfig = serde_json::from_value(provider_config)?;
            Ok(Arc::new(DatabricksProvider::from_config(config, model)?))
        }
        _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
    }
}
|
||||
1118
crates/goose-llm/src/providers/formats/databricks.rs
Normal file
1118
crates/goose-llm/src/providers/formats/databricks.rs
Normal file
File diff suppressed because it is too large
Load Diff
2
crates/goose-llm/src/providers/formats/mod.rs
Normal file
2
crates/goose-llm/src/providers/formats/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod databricks;
|
||||
pub mod openai;
|
||||
897
crates/goose-llm/src/providers/formats/openai.rs
Normal file
897
crates/goose-llm/src/providers/formats/openai.rs
Normal file
@@ -0,0 +1,897 @@
|
||||
use anyhow::{anyhow, Error};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::{
|
||||
message::{Message, MessageContent},
|
||||
model::ModelConfig,
|
||||
providers::{
|
||||
base::Usage,
|
||||
errors::ProviderError,
|
||||
utils::{
|
||||
convert_image, detect_image_path, is_valid_function_name, load_image_file,
|
||||
sanitize_function_name, ImageFormat,
|
||||
},
|
||||
},
|
||||
types::core::{Content, Role, Tool, ToolCall, ToolError},
|
||||
};
|
||||
|
||||
/// Convert internal Message format to OpenAI's API message specification
/// some openai compatible endpoints use the anthropic image spec at the content level
/// even though the message structure is otherwise following openai, the enum switches this
///
/// One internal message can expand into several API messages: tool responses
/// become separate `role: "tool"` messages, and images found in tool results
/// become follow-up `role: "user"` messages.
pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> {
    let mut messages_spec = Vec::new();
    for message in messages {
        // Base message for this role; content/tool_calls are filled in below.
        let mut converted = json!({
            "role": message.role
        });

        // Extra messages (tool responses, follow-up image uploads) that must
        // appear after `converted` in the output stream.
        let mut output = Vec::new();

        for content in message.content.iter() {
            match content {
                MessageContent::Text(text) => {
                    if !text.text.is_empty() {
                        // Check for image paths in the text
                        if let Some(image_path) = detect_image_path(&text.text) {
                            // Try to load and convert the image
                            if let Ok(image) = load_image_file(image_path) {
                                // Text plus inline image as a content array.
                                converted["content"] = json!([
                                    {"type": "text", "text": text.text},
                                    convert_image(&image, image_format)
                                ]);
                            } else {
                                // If image loading fails, just use the text
                                converted["content"] = json!(text.text);
                            }
                        } else {
                            converted["content"] = json!(text.text);
                        }
                    }
                }
                MessageContent::Thinking(_) => {
                    // Thinking blocks are not directly used in OpenAI format
                    continue;
                }
                MessageContent::RedactedThinking(_) => {
                    // Redacted thinking blocks are not directly used in OpenAI format
                    continue;
                }
                MessageContent::ToolReq(request) => match &request.tool_call.as_result() {
                    Ok(tool_call) => {
                        let sanitized_name = sanitize_function_name(&tool_call.name);
                        // Append to the assistant message's `tool_calls`
                        // array, creating it on first use.
                        let tool_calls = converted
                            .as_object_mut()
                            .unwrap()
                            .entry("tool_calls")
                            .or_insert(json!([]));

                        tool_calls.as_array_mut().unwrap().push(json!({
                            "id": request.id,
                            "type": "function",
                            "function": {
                                "name": sanitized_name,
                                "arguments": tool_call.arguments.to_string(),
                            }
                        }));
                    }
                    Err(e) => {
                        // An invalid tool call is surfaced as a tool-role
                        // error message so the model can see what went wrong.
                        output.push(json!({
                            "role": "tool",
                            "content": format!("Error: {}", e),
                            "tool_call_id": request.id
                        }));
                    }
                },
                MessageContent::ToolResp(response) => {
                    match &response.tool_result.0 {
                        Ok(contents) => {
                            // Process all content, replacing images with placeholder text
                            let mut tool_content = Vec::new();
                            let mut image_messages = Vec::new();

                            for content in contents {
                                match content {
                                    Content::Image(image) => {
                                        // Add placeholder text in the tool response
                                        tool_content.push(Content::text("This tool result included an image that is uploaded in the next message."));

                                        // Create a separate image message
                                        image_messages.push(json!({
                                            "role": "user",
                                            "content": [convert_image(image, image_format)]
                                        }));
                                    }
                                    _ => {
                                        tool_content.push(content.clone());
                                    }
                                }
                            }
                            // Flatten the text parts into one space-joined
                            // string (non-text parts contribute empty strings).
                            let tool_response_content: Value = json!(tool_content
                                .iter()
                                .map(|content| match content {
                                    Content::Text(text) => text.text.clone(),
                                    _ => String::new(),
                                })
                                .collect::<Vec<String>>()
                                .join(" "));

                            // First add the tool response with all content
                            output.push(json!({
                                "role": "tool",
                                "content": tool_response_content,
                                "tool_call_id": response.id
                            }));
                            // Then add any image messages that need to follow
                            output.extend(image_messages);
                        }
                        Err(e) => {
                            // A tool result error is shown as output so the model can interpret the error message
                            output.push(json!({
                                "role": "tool",
                                "content": format!("The tool call returned the following error:\n{}", e),
                                "tool_call_id": response.id
                            }));
                        }
                    }
                }
                MessageContent::Image(image) => {
                    // Handle direct image content
                    converted["content"] = json!([convert_image(image, image_format)]);
                }
            }
        }

        // Only emit the base message when it actually gained content or tool
        // calls; it must precede the tool/image messages gathered above.
        if converted.get("content").is_some() || converted.get("tool_calls").is_some() {
            output.insert(0, converted);
        }
        messages_spec.extend(output);
    }

    messages_spec
}
|
||||
|
||||
/// Convert internal Tool format to OpenAI's API tool specification
|
||||
pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
|
||||
let mut tool_names = std::collections::HashSet::new();
|
||||
let mut result = Vec::new();
|
||||
|
||||
for tool in tools {
|
||||
if !tool_names.insert(&tool.name) {
|
||||
return Err(anyhow!("Duplicate tool name: {}", tool.name));
|
||||
}
|
||||
|
||||
let mut description = tool.description.clone();
|
||||
description.truncate(1024);
|
||||
|
||||
// OpenAI's tool description max str len is 1024
|
||||
result.push(json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": description,
|
||||
"parameters": tool.input_schema,
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Convert OpenAI's API response to internal Message format
///
/// Reads the first choice's message: plain text content becomes a text part;
/// each entry in `tool_calls` becomes a tool-request part (or a tool-request
/// carrying an error when the function name or arguments are invalid).
pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
    let original = response["choices"][0]["message"].clone();
    let mut content = Vec::new();

    // Plain-text assistant content, if present and actually a string.
    if let Some(text) = original.get("content") {
        if let Some(text_str) = text.as_str() {
            content.push(MessageContent::text(text_str));
        }
    }

    if let Some(tool_calls) = original.get("tool_calls") {
        if let Some(tool_calls_array) = tool_calls.as_array() {
            for tool_call in tool_calls_array {
                let id = tool_call["id"].as_str().unwrap_or_default().to_string();
                let function_name = tool_call["function"]["name"]
                    .as_str()
                    .unwrap_or_default()
                    .to_string();
                let mut arguments = tool_call["function"]["arguments"]
                    .as_str()
                    .unwrap_or_default()
                    .to_string();
                // If arguments is empty, we will have invalid json parsing error later.
                if arguments.is_empty() {
                    arguments = "{}".to_string();
                }

                if !is_valid_function_name(&function_name) {
                    // Invalid names are reported as a NotFound tool error
                    // rather than failing the whole response.
                    let error = ToolError::NotFound(format!(
                        "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
                        function_name
                    ));
                    content.push(MessageContent::tool_request(id, Err(error).into()));
                } else {
                    // Arguments arrive as a JSON string; parse before use.
                    match serde_json::from_str::<Value>(&arguments) {
                        Ok(params) => {
                            content.push(MessageContent::tool_request(
                                id,
                                Ok(ToolCall::new(&function_name, params)).into(),
                            ));
                        }
                        Err(e) => {
                            // Unparseable arguments become an InvalidParameters
                            // tool error attached to this request id.
                            let error = ToolError::InvalidParameters(format!(
                                "Could not interpret tool use parameters for id {}: {}",
                                id, e
                            ));
                            content.push(MessageContent::tool_request(id, Err(error).into()));
                        }
                    }
                }
            }
        }
    }

    Ok(Message {
        role: Role::Assistant,
        created: chrono::Utc::now().timestamp_millis(),
        content: content.into(),
    })
}
|
||||
|
||||
pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
|
||||
let usage = data
|
||||
.get("usage")
|
||||
.ok_or_else(|| ProviderError::UsageError("No usage data in response".to_string()))?;
|
||||
|
||||
let input_tokens = usage
|
||||
.get("prompt_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as i32);
|
||||
|
||||
let output_tokens = usage
|
||||
.get("completion_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as i32);
|
||||
|
||||
let total_tokens = usage
|
||||
.get("total_tokens")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as i32)
|
||||
.or_else(|| match (input_tokens, output_tokens) {
|
||||
(Some(input), Some(output)) => Some(input + output),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
Ok(Usage::new(input_tokens, output_tokens, total_tokens))
|
||||
}
|
||||
|
||||
/// Validates and fixes tool schemas to ensure they have proper parameter structure.
|
||||
/// If parameters exist, ensures they have properties and required fields, or removes parameters entirely.
|
||||
pub fn validate_tool_schemas(tools: &mut [Value]) {
|
||||
for tool in tools.iter_mut() {
|
||||
if let Some(function) = tool.get_mut("function") {
|
||||
if let Some(parameters) = function.get_mut("parameters") {
|
||||
if parameters.is_object() {
|
||||
ensure_valid_json_schema(parameters);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensures that the given JSON value follows the expected JSON Schema structure.
///
/// For object-typed schemas (or schemas with no declared type), fills in the
/// `properties`, `required`, and `type` keys with sensible defaults, then
/// recurses into object-typed properties.
fn ensure_valid_json_schema(schema: &mut Value) {
    if let Some(params_obj) = schema.as_object_mut() {
        // Check if this is meant to be an object type schema
        let is_object_type = params_obj
            .get("type")
            .and_then(|t| t.as_str())
            .is_none_or(|t| t == "object"); // Default to true if no type is specified

        // Only apply full schema validation to object types
        if is_object_type {
            // Ensure required fields exist with default values
            params_obj.entry("properties").or_insert_with(|| json!({}));
            params_obj.entry("required").or_insert_with(|| json!([]));
            params_obj.entry("type").or_insert_with(|| json!("object"));

            // Recursively validate properties if it exists
            if let Some(properties) = params_obj.get_mut("properties") {
                if let Some(properties_obj) = properties.as_object_mut() {
                    for (_key, prop) in properties_obj.iter_mut() {
                        // Recurse only into properties explicitly typed as
                        // objects; scalar/array properties are left alone.
                        if prop.is_object()
                            && prop.get("type").and_then(|t| t.as_str()) == Some("object")
                        {
                            ensure_valid_json_schema(prop);
                        }
                    }
                }
            }
        }
    }
}
|
||||
|
||||
/// Build the OpenAI-style chat-completion request payload.
///
/// Handles the O-family model quirks: a `developer` (instead of `system`)
/// role, a `reasoning_effort` field parsed from a `-low`/`-medium`/`-high`
/// model-name suffix (defaulting to "medium"), no `temperature`, and
/// `max_completion_tokens` in place of `max_tokens`. Rejects `o1-mini`,
/// which does not support tool calling.
pub fn create_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
    image_format: &ImageFormat,
) -> anyhow::Result<Value, Error> {
    if model_config.model_name.starts_with("o1-mini") {
        return Err(anyhow!(
            "o1-mini model is not currently supported since Goose uses tool calling and o1-mini does not support it. Please use o1 or o3 models instead."
        ));
    }

    // NOTE(review): any model name beginning with "o" is treated as
    // O-family here — confirm no non-O models with an "o…" prefix can reach
    // this code path.
    let is_ox_model = model_config.model_name.starts_with("o");

    // Only extract reasoning effort for O1/O3 models
    let (model_name, reasoning_effort) = if is_ox_model {
        let parts: Vec<&str> = model_config.model_name.split('-').collect();
        let last_part = parts.last().unwrap();

        match *last_part {
            // "o3-mini-high" -> model "o3-mini" with effort "high".
            "low" | "medium" | "high" => {
                let base_name = parts[..parts.len() - 1].join("-");
                (base_name, Some(last_part.to_string()))
            }
            // No suffix: keep the name, default effort to "medium".
            _ => (
                model_config.model_name.to_string(),
                Some("medium".to_string()),
            ),
        }
    } else {
        // For non-O family models, use the model name as is and no reasoning effort
        (model_config.model_name.to_string(), None)
    };

    // O-family models take the system prompt under the "developer" role.
    let system_message = json!({
        "role": if is_ox_model { "developer" } else { "system" },
        "content": system
    });

    let messages_spec = format_messages(messages, image_format);
    let mut tools_spec = if !tools.is_empty() {
        format_tools(tools)?
    } else {
        vec![]
    };

    // Validate tool schemas
    validate_tool_schemas(&mut tools_spec);

    // System message always leads the conversation.
    let mut messages_array = vec![system_message];
    messages_array.extend(messages_spec);

    let mut payload = json!({
        "model": model_name,
        "messages": messages_array
    });

    if let Some(effort) = reasoning_effort {
        payload
            .as_object_mut()
            .unwrap()
            .insert("reasoning_effort".to_string(), json!(effort));
    }

    if !tools_spec.is_empty() {
        payload
            .as_object_mut()
            .unwrap()
            .insert("tools".to_string(), json!(tools_spec));
    }
    // o1, o3 models currently don't support temperature
    if !is_ox_model {
        if let Some(temp) = model_config.temperature {
            payload
                .as_object_mut()
                .unwrap()
                .insert("temperature".to_string(), json!(temp));
        }
    }

    // o1 models use max_completion_tokens instead of max_tokens
    if let Some(tokens) = model_config.max_tokens {
        let key = if is_ox_model {
            "max_completion_tokens"
        } else {
            "max_tokens"
        };
        payload
            .as_object_mut()
            .unwrap()
            .insert(key.to_string(), json!(tokens));
    }
    Ok(payload)
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
use crate::types::core::Content;
|
||||
|
||||
#[test]
|
||||
fn test_validate_tool_schemas() {
|
||||
// Test case 1: Empty parameters object
|
||||
// Input JSON with an incomplete parameters object
|
||||
let mut actual = vec![json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_func",
|
||||
"description": "test description",
|
||||
"parameters": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
})];
|
||||
|
||||
// Run the function to validate and update schemas
|
||||
validate_tool_schemas(&mut actual);
|
||||
|
||||
// Expected JSON after validation
|
||||
let expected = vec![json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_func",
|
||||
"description": "test description",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
})];
|
||||
|
||||
// Compare entire JSON structures instead of individual fields
|
||||
assert_eq!(actual, expected);
|
||||
|
||||
// Test case 2: Missing type field
|
||||
let mut tools = vec![json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_func",
|
||||
"description": "test description",
|
||||
"parameters": {
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
})];
|
||||
|
||||
validate_tool_schemas(&mut tools);
|
||||
|
||||
let params = tools[0]["function"]["parameters"].as_object().unwrap();
|
||||
assert_eq!(params["type"], "object");
|
||||
|
||||
// Test case 3: Complete valid schema should remain unchanged
|
||||
let original_schema = json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_func",
|
||||
"description": "test description",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "City and country"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut tools = vec![original_schema.clone()];
|
||||
validate_tool_schemas(&mut tools);
|
||||
assert_eq!(tools[0], original_schema);
|
||||
}
|
||||
|
||||
const OPENAI_TOOL_USE_RESPONSE: &str = r#"{
|
||||
"choices": [{
|
||||
"role": "assistant",
|
||||
"message": {
|
||||
"tool_calls": [{
|
||||
"id": "1",
|
||||
"function": {
|
||||
"name": "example_fn",
|
||||
"arguments": "{\"param\": \"value\"}"
|
||||
}
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 25,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}"#;
|
||||
|
||||
#[test]
|
||||
fn test_format_messages() -> anyhow::Result<()> {
|
||||
let message = Message::user().with_text("Hello");
|
||||
let spec = format_messages(&[message], &ImageFormat::OpenAi);
|
||||
|
||||
assert_eq!(spec.len(), 1);
|
||||
assert_eq!(spec[0]["role"], "user");
|
||||
assert_eq!(spec[0]["content"], "Hello");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tools() -> anyhow::Result<()> {
|
||||
let tool = Tool::new(
|
||||
"test_tool",
|
||||
"A test tool",
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"description": "Test parameter"
|
||||
}
|
||||
},
|
||||
"required": ["input"]
|
||||
}),
|
||||
);
|
||||
|
||||
let spec = format_tools(&[tool])?;
|
||||
|
||||
assert_eq!(spec.len(), 1);
|
||||
assert_eq!(spec[0]["type"], "function");
|
||||
assert_eq!(spec[0]["function"]["name"], "test_tool");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_messages_complex() -> anyhow::Result<()> {
|
||||
let mut messages = vec![
|
||||
Message::assistant().with_text("Hello!"),
|
||||
Message::user().with_text("How are you?"),
|
||||
Message::assistant().with_tool_request(
|
||||
"tool1",
|
||||
Ok(ToolCall::new("example", json!({"param1": "value1"}))),
|
||||
),
|
||||
];
|
||||
|
||||
// Get the ID from the tool request to use in the response
|
||||
let tool_id = if let MessageContent::ToolReq(request) = &messages[2].content[0] {
|
||||
request.id.clone()
|
||||
} else {
|
||||
panic!("should be tool request");
|
||||
};
|
||||
|
||||
messages.push(
|
||||
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
|
||||
);
|
||||
|
||||
let spec = format_messages(&messages, &ImageFormat::OpenAi);
|
||||
|
||||
assert_eq!(spec.len(), 4);
|
||||
assert_eq!(spec[0]["role"], "assistant");
|
||||
assert_eq!(spec[0]["content"], "Hello!");
|
||||
assert_eq!(spec[1]["role"], "user");
|
||||
assert_eq!(spec[1]["content"], "How are you?");
|
||||
assert_eq!(spec[2]["role"], "assistant");
|
||||
assert!(spec[2]["tool_calls"].is_array());
|
||||
assert_eq!(spec[3]["role"], "tool");
|
||||
assert_eq!(spec[3]["content"], "Result");
|
||||
assert_eq!(spec[3]["tool_call_id"], spec[2]["tool_calls"][0]["id"]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_messages_multiple_content() -> anyhow::Result<()> {
|
||||
let mut messages = vec![Message::assistant().with_tool_request(
|
||||
"tool1",
|
||||
Ok(ToolCall::new("example", json!({"param1": "value1"}))),
|
||||
)];
|
||||
|
||||
// Get the ID from the tool request to use in the response
|
||||
let tool_id = if let MessageContent::ToolReq(request) = &messages[0].content[0] {
|
||||
request.id.clone()
|
||||
} else {
|
||||
panic!("should be tool request");
|
||||
};
|
||||
|
||||
messages.push(
|
||||
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
|
||||
);
|
||||
|
||||
let spec = format_messages(&messages, &ImageFormat::OpenAi);
|
||||
|
||||
assert_eq!(spec.len(), 2);
|
||||
assert_eq!(spec[0]["role"], "assistant");
|
||||
assert!(spec[0]["tool_calls"].is_array());
|
||||
assert_eq!(spec[1]["role"], "tool");
|
||||
assert_eq!(spec[1]["content"], "Result");
|
||||
assert_eq!(spec[1]["tool_call_id"], spec[0]["tool_calls"][0]["id"]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tools_duplicate() -> anyhow::Result<()> {
|
||||
let tool1 = Tool::new(
|
||||
"test_tool",
|
||||
"Test tool",
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"description": "Test parameter"
|
||||
}
|
||||
},
|
||||
"required": ["input"]
|
||||
}),
|
||||
);
|
||||
|
||||
let tool2 = Tool::new(
|
||||
"test_tool",
|
||||
"Test tool",
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"description": "Test parameter"
|
||||
}
|
||||
},
|
||||
"required": ["input"]
|
||||
}),
|
||||
);
|
||||
|
||||
let result = format_tools(&[tool1, tool2]);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Duplicate tool name"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tools_empty() -> anyhow::Result<()> {
|
||||
let spec = format_tools(&[])?;
|
||||
assert!(spec.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_messages_with_image_path() -> anyhow::Result<()> {
|
||||
// Create a temporary PNG file with valid PNG magic numbers
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let png_path = temp_dir.path().join("test.png");
|
||||
let png_data = [
|
||||
0x89, 0x50, 0x4E, 0x47, // PNG magic number
|
||||
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
|
||||
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
|
||||
];
|
||||
std::fs::write(&png_path, &png_data)?;
|
||||
let png_path_str = png_path.to_str().unwrap();
|
||||
|
||||
// Create message with image path
|
||||
let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
|
||||
let spec = format_messages(&[message], &ImageFormat::OpenAi);
|
||||
|
||||
assert_eq!(spec.len(), 1);
|
||||
assert_eq!(spec[0]["role"], "user");
|
||||
|
||||
// Content should be an array with text and image
|
||||
let content = spec[0]["content"].as_array().unwrap();
|
||||
assert_eq!(content.len(), 2);
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
|
||||
assert_eq!(content[1]["type"], "image_url");
|
||||
assert!(content[1]["image_url"]["url"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.starts_with("data:image/png;base64,"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_to_message_text() -> anyhow::Result<()> {
|
||||
let response = json!({
|
||||
"choices": [{
|
||||
"role": "assistant",
|
||||
"message": {
|
||||
"content": "Hello from John Cena!"
|
||||
}
|
||||
}],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 25,
|
||||
"total_tokens": 35
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
assert_eq!(message.content.len(), 1);
|
||||
if let MessageContent::Text(text) = &message.content[0] {
|
||||
assert_eq!(text.text, "Hello from John Cena!");
|
||||
} else {
|
||||
panic!("Expected Text content");
|
||||
}
|
||||
assert!(matches!(message.role, Role::Assistant));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_to_message_valid_toolrequest() -> anyhow::Result<()> {
|
||||
let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
assert_eq!(message.content.len(), 1);
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
let tool_call = request.tool_call.as_ref().unwrap();
|
||||
assert_eq!(tool_call.name, "example_fn");
|
||||
assert_eq!(tool_call.arguments, json!({"param": "value"}));
|
||||
} else {
|
||||
panic!("Expected ToolRequest content");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_to_message_invalid_func_name() -> anyhow::Result<()> {
|
||||
let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
|
||||
response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] =
|
||||
json!("invalid fn");
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
match &request.tool_call.as_result() {
|
||||
Err(ToolError::NotFound(msg)) => {
|
||||
assert!(msg.starts_with("The provided function name"));
|
||||
}
|
||||
_ => panic!("Expected ToolNotFound error"),
|
||||
}
|
||||
} else {
|
||||
panic!("Expected ToolRequest content");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
fn test_response_to_message_json_decode_error() -> anyhow::Result<()> {
    // Malformed JSON in the arguments must surface as InvalidParameters
    // on the tool call, not as a hard error from response_to_message.
    let mut payload: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
    payload["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
        json!("invalid json {");

    let message = response_to_message(payload)?;
    let MessageContent::ToolReq(request) = &message.content[0] else {
        panic!("Expected ToolRequest content");
    };
    match &request.tool_call.as_result() {
        Err(ToolError::InvalidParameters(msg)) => {
            assert!(msg.starts_with("Could not interpret tool use parameters"))
        }
        _ => panic!("Expected InvalidParameters error"),
    }

    Ok(())
}
|
||||
|
||||
#[test]
fn test_response_to_message_empty_argument() -> anyhow::Result<()> {
    // An empty arguments string is treated as "no parameters" ({}),
    // not as a JSON decode error.
    let mut payload: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
    payload["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
        serde_json::Value::String(String::new());

    let message = response_to_message(payload)?;
    let MessageContent::ToolReq(request) = &message.content[0] else {
        panic!("Expected ToolRequest content");
    };
    let call = request.tool_call.as_ref().unwrap();
    assert_eq!(call.name, "example_fn");
    assert_eq!(call.arguments, json!({}));

    Ok(())
}
|
||||
|
||||
#[test]
fn test_create_request_gpt_4o() -> anyhow::Result<()> {
    // Non-reasoning gpt-4o model: the system prompt keeps the "system" role
    // and the token budget is sent as plain `max_tokens` (no reasoning_effort).
    // (The original comment incorrectly referred to the O3 model.)
    let model_config = ModelConfig {
        model_name: "gpt-4o".to_string(),
        context_limit: Some(4096),
        temperature: None,
        max_tokens: Some(1024),
    };
    let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
    let obj = request.as_object().unwrap();
    let expected = json!({
        "model": "gpt-4o",
        "messages": [
            {
                "role": "system",
                "content": "system"
            }
        ],
        "max_tokens": 1024
    });

    // Subset comparison: the request may carry extra keys beyond `expected`.
    for (key, value) in expected.as_object().unwrap() {
        assert_eq!(obj.get(key).unwrap(), value);
    }

    Ok(())
}
|
||||
|
||||
#[test]
fn test_create_request_o1_default() -> anyhow::Result<()> {
    // O1-family reasoning model: system prompt becomes the "developer" role,
    // tokens go in `max_completion_tokens`, and reasoning effort defaults
    // to "medium".
    let config = ModelConfig {
        model_name: "o1".to_string(),
        context_limit: Some(4096),
        temperature: None,
        max_tokens: Some(1024),
    };
    let request = create_request(&config, "system", &[], &[], &ImageFormat::OpenAi)?;
    let actual = request.as_object().unwrap();
    let expected = json!({
        "model": "o1",
        "messages": [
            {
                "role": "developer",
                "content": "system"
            }
        ],
        "reasoning_effort": "medium",
        "max_completion_tokens": 1024
    });

    // Subset comparison against the expected keys.
    expected
        .as_object()
        .unwrap()
        .iter()
        .for_each(|(key, value)| assert_eq!(actual.get(key).unwrap(), value));

    Ok(())
}
|
||||
|
||||
#[test]
fn test_create_request_o3_custom_reasoning_effort() -> anyhow::Result<()> {
    // An "-high" suffix on the model name is stripped into the model id
    // ("o3-mini") and mapped to `reasoning_effort: "high"`.
    let config = ModelConfig {
        model_name: "o3-mini-high".to_string(),
        context_limit: Some(4096),
        temperature: None,
        max_tokens: Some(1024),
    };
    let request = create_request(&config, "system", &[], &[], &ImageFormat::OpenAi)?;
    let actual = request.as_object().unwrap();
    let expected = json!({
        "model": "o3-mini",
        "messages": [
            {
                "role": "developer",
                "content": "system"
            }
        ],
        "reasoning_effort": "high",
        "max_completion_tokens": 1024
    });

    expected
        .as_object()
        .unwrap()
        .iter()
        .for_each(|(key, value)| assert_eq!(actual.get(key).unwrap(), value));

    Ok(())
}
|
||||
}
|
||||
10
crates/goose-llm/src/providers/mod.rs
Normal file
10
crates/goose-llm/src/providers/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
//! Provider implementations and shared provider types.

/// Base `Provider` trait and the response/usage types it returns.
pub mod base;
pub mod databricks;
pub mod errors;
// Internal: provider construction is exposed only through `create` below.
mod factory;
pub mod formats;
pub mod openai;
pub mod utils;

pub use base::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage};
pub use factory::create;
|
||||
231
crates/goose-llm/src/providers/openai.rs
Normal file
231
crates/goose-llm/src/providers/openai.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::{
|
||||
errors::ProviderError,
|
||||
formats::openai::{create_request, get_usage, response_to_message},
|
||||
utils::{emit_debug_trace, get_env, get_model, handle_response_openai_compat, ImageFormat},
|
||||
};
|
||||
use crate::{
|
||||
message::Message,
|
||||
model::ModelConfig,
|
||||
providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
|
||||
types::core::Tool,
|
||||
};
|
||||
|
||||
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
|
||||
pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4.1", "o1", "o3", "o4-mini"];
|
||||
|
||||
// Serde default helpers for `OpenAiProviderConfig`
// (referenced via `#[serde(default = "...")]`).

/// Default request timeout, in seconds.
fn default_timeout() -> u64 {
    60
}

/// Default chat-completions endpoint path, joined onto `host`.
fn default_base_path() -> String {
    "v1/chat/completions".to_string()
}

/// Default OpenAI API host.
fn default_host() -> String {
    "https://api.openai.com".to_string()
}
|
||||
|
||||
/// Configuration for [`OpenAiProvider`].
///
/// Deserializable from JSON; every field except `api_key` has a serde
/// default so minimal configs only need the key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiProviderConfig {
    /// Bearer token sent in the `Authorization` header.
    pub api_key: String,
    /// API host, e.g. `https://api.openai.com`.
    #[serde(default = "default_host")]
    pub host: String,
    /// Optional `OpenAI-Organization` header value.
    #[serde(default)]
    pub organization: Option<String>,
    /// Endpoint path joined onto `host` (default: `v1/chat/completions`).
    #[serde(default = "default_base_path")]
    pub base_path: String,
    /// Optional `OpenAI-Project` header value.
    #[serde(default)]
    pub project: Option<String>,
    /// Extra headers applied to every request (may override defaults).
    #[serde(default)]
    pub custom_headers: Option<HashMap<String, String>>,
    #[serde(default = "default_timeout")]
    pub timeout: u64, // timeout in seconds
}
|
||||
|
||||
impl OpenAiProviderConfig {
    /// Create a config with the given API key and default host/base path.
    ///
    /// NOTE(review): this hard-codes `timeout: 600` seconds, while configs
    /// deserialized via serde default to `default_timeout()` (60s) —
    /// confirm the asymmetry is intentional.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            host: default_host(),
            organization: None,
            base_path: default_base_path(),
            project: None,
            custom_headers: None,
            timeout: 600,
        }
    }

    /// Build a config from the `OPENAI_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if `OPENAI_API_KEY` is not set.
    pub fn from_env() -> Self {
        let api_key = get_env("OPENAI_API_KEY").expect("Missing OPENAI_API_KEY");
        Self::new(api_key)
    }
}
|
||||
|
||||
/// Provider backed by the OpenAI chat-completions API
/// (or any OpenAI-compatible endpoint selected via `config.host`).
#[derive(Debug)]
pub struct OpenAiProvider {
    config: OpenAiProviderConfig,
    model: ModelConfig,
    // Shared HTTP client, built once with the configured timeout.
    client: Client,
}
|
||||
|
||||
impl OpenAiProvider {
    /// Build a provider for `model` using env-sourced credentials.
    ///
    /// # Panics
    /// Panics if `OPENAI_API_KEY` is missing or the HTTP client cannot
    /// be constructed.
    pub fn from_env(model: ModelConfig) -> Self {
        let config = OpenAiProviderConfig::from_env();
        OpenAiProvider::from_config(config, model).expect("Failed to initialize OpenAiProvider")
    }
}
|
||||
|
||||
impl Default for OpenAiProvider {
    /// Env-configured provider using [`OPEN_AI_DEFAULT_MODEL`].
    ///
    /// # Panics
    /// Panics if `OPENAI_API_KEY` is missing or client construction fails.
    fn default() -> Self {
        let config = OpenAiProviderConfig::from_env();
        let model = ModelConfig::new(OPEN_AI_DEFAULT_MODEL.to_string());
        OpenAiProvider::from_config(config, model).expect("Failed to initialize OpenAiProvider")
    }
}
|
||||
|
||||
impl OpenAiProvider {
    /// Build a provider from an explicit config, constructing an HTTP
    /// client with the configured request timeout.
    ///
    /// # Errors
    /// Returns an error if the reqwest client cannot be built.
    pub fn from_config(config: OpenAiProviderConfig, model: ModelConfig) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(config.timeout))
            .build()?;

        Ok(Self {
            config,
            model,
            client,
        })
    }

    /// POST `payload` to `{host}/{base_path}` with bearer auth and optional
    /// organization/project/custom headers; returns the parsed JSON body.
    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        let base_url = url::Url::parse(&self.config.host)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        let url = base_url.join(&self.config.base_path).map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        let mut request = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.config.api_key));

        // Add organization header if present
        if let Some(org) = &self.config.organization {
            request = request.header("OpenAI-Organization", org);
        }

        // Add project header if present
        if let Some(project) = &self.config.project {
            request = request.header("OpenAI-Project", project);
        }

        // Custom headers are applied last, so they may override the defaults above.
        if let Some(custom_headers) = &self.config.custom_headers {
            for (key, value) in custom_headers {
                request = request.header(key, value);
            }
        }

        let response = request.json(&payload).send().await?;

        // Shared handler maps OpenAI-compatible status codes onto ProviderError.
        handle_response_openai_compat(response).await
    }
}
|
||||
|
||||
#[async_trait]
impl Provider for OpenAiProvider {
    /// Run a chat completion with optional tools.
    ///
    /// Usage-accounting failures are downgraded to a default `Usage`
    /// (logged at debug level) so a missing `usage` field never fails
    /// an otherwise successful completion.
    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<ProviderCompleteResponse, ProviderError> {
        let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;

        // Make request
        let response = self.post(payload.clone()).await?;

        // Parse response
        let message = response_to_message(response.clone())?;
        let usage = match get_usage(&response) {
            Ok(usage) => usage,
            Err(ProviderError::UsageError(e)) => {
                tracing::debug!("Failed to get usage data: {}", e);
                Usage::default()
            }
            Err(e) => return Err(e),
        };
        let model = get_model(&response);
        emit_debug_trace(&self.model, &payload, &response, &usage);
        Ok(ProviderCompleteResponse::new(message, model, usage))
    }

    /// Extract structured JSON matching `schema` using OpenAI's strict
    /// `json_schema` response format (no tools are passed).
    async fn extract(
        &self,
        system: &str,
        messages: &[Message],
        schema: &Value,
    ) -> Result<ProviderExtractResponse, ProviderError> {
        // 1. Build base payload (no tools)
        let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;

        // 2. Inject strict JSON‐Schema wrapper
        payload
            .as_object_mut()
            .expect("payload must be an object")
            .insert(
                "response_format".to_string(),
                json!({
                    "type": "json_schema",
                    "json_schema": {
                        "name": "extraction",
                        "schema": schema,
                        "strict": true
                    }
                }),
            );

        // 3. Call OpenAI
        let response = self.post(payload.clone()).await?;

        // 4. Extract the assistant’s `content` and parse it into JSON.
        //    The API may return the content either as a JSON string or as
        //    an already-structured object/array.
        let msg = &response["choices"][0]["message"];
        let raw = msg.get("content").cloned().ok_or_else(|| {
            ProviderError::ResponseParseError("Missing content in extract response".into())
        })?;
        let data = match raw {
            Value::String(s) => serde_json::from_str(&s)
                .map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
            Value::Object(_) | Value::Array(_) => raw,
            other => {
                return Err(ProviderError::ResponseParseError(format!(
                    "Unexpected content type: {:?}",
                    other
                )))
            }
        };

        // 5. Gather usage & model info (usage errors are non-fatal, as above)
        let usage = match get_usage(&response) {
            Ok(u) => u,
            Err(ProviderError::UsageError(e)) => {
                tracing::debug!("Failed to get usage in extract: {}", e);
                Usage::default()
            }
            Err(e) => return Err(e),
        };
        let model = get_model(&response);

        Ok(ProviderExtractResponse::new(data, model, usage))
    }
}
|
||||
359
crates/goose-llm/src/providers/utils.rs
Normal file
359
crates/goose-llm/src/providers/utils.rs
Normal file
@@ -0,0 +1,359 @@
|
||||
use std::{env, io::Read, path::Path};
|
||||
|
||||
use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use regex::Regex;
|
||||
use reqwest::{Response, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{from_value, json, Value};
|
||||
|
||||
use super::base::Usage;
|
||||
use crate::{
|
||||
model::ModelConfig,
|
||||
providers::errors::{OpenAIError, ProviderError},
|
||||
types::core::ImageContent,
|
||||
};
|
||||
|
||||
/// Wrapper matching the `{"error": {...}}` envelope of OpenAI error bodies.
#[derive(serde::Deserialize)]
struct OpenAIErrorResponse {
    error: OpenAIError,
}

/// Wire format used when embedding images in a chat payload.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Default)]
pub enum ImageFormat {
    #[default]
    OpenAi,
    Anthropic,
}

/// Timeout in seconds.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Timeout(u32);
impl Default for Timeout {
    // 60s matches the provider-level default_timeout().
    fn default() -> Self {
        Timeout(60)
    }
}
|
||||
|
||||
/// Convert an image content into an image json based on format
|
||||
pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value {
|
||||
match image_format {
|
||||
ImageFormat::OpenAi => json!({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": format!("data:{};base64,{}", image.mime_type, image.data)
|
||||
}
|
||||
}),
|
||||
ImageFormat::Anthropic => json!({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": image.mime_type,
|
||||
"data": image.data,
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle response from OpenAI compatible endpoints
/// Error codes: https://platform.openai.com/docs/guides/error-codes
/// Context window exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543
///
/// Maps HTTP status codes onto `ProviderError` variants; 400/404 bodies
/// are additionally inspected for a context-length-exceeded error.
pub async fn handle_response_openai_compat(response: Response) -> Result<Value, ProviderError> {
    let status = response.status();
    // Try to parse the response body as JSON (if applicable)
    let payload = match response.json::<Value>().await {
        Ok(json) => json,
        Err(e) => return Err(ProviderError::RequestFailed(e.to_string())),
    };

    match status {
        StatusCode::OK => Ok(payload),
        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
            Err(ProviderError::Authentication(format!(
                "Authentication failed. Please ensure your API keys are valid and have the required permissions. \
                Status: {}. Response: {:?}",
                status, payload
            )))
        }
        StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND => {
            tracing::debug!(
                "{}",
                format!(
                    "Provider request failed with status: {}. Payload: {:?}",
                    status, payload
                )
            );
            // Try the structured OpenAI error envelope first so that
            // context-length failures can be surfaced as their own variant.
            if let Ok(err_resp) = from_value::<OpenAIErrorResponse>(payload) {
                let err = err_resp.error;
                if err.is_context_length_exceeded() {
                    return Err(ProviderError::ContextLengthExceeded(
                        err.message.unwrap_or("Unknown error".to_string()),
                    ));
                }
                return Err(ProviderError::RequestFailed(format!(
                    "{} (status {})",
                    err,
                    status.as_u16()
                )));
            }
            Err(ProviderError::RequestFailed(format!(
                "Unknown error (status {})",
                status
            )))
        }
        StatusCode::TOO_MANY_REQUESTS => {
            Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
        }
        StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
            Err(ProviderError::ServerError(format!("{:?}", payload)))
        }
        _ => {
            tracing::debug!(
                "{}",
                format!(
                    "Provider request failed with status: {}. Payload: {:?}",
                    status, payload
                )
            );
            Err(ProviderError::RequestFailed(format!(
                "Request failed with status: {}",
                status
            )))
        }
    }
}
|
||||
|
||||
/// Get a secret from environment variables. The secret is expected to be in JSON format.
|
||||
pub fn get_env(key: &str) -> Result<String> {
|
||||
// check environment variables (convert to uppercase)
|
||||
let env_key = key.to_uppercase();
|
||||
if let Ok(val) = env::var(&env_key) {
|
||||
let value: Value = serde_json::from_str(&val).unwrap_or(Value::String(val));
|
||||
Ok(serde_json::from_value(value)?)
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"Environment variable {} not found",
|
||||
env_key
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace every character that is not ASCII alphanumeric, `_`, or `-`
/// with `_`, matching the character set accepted for function names.
///
/// Implemented as a plain character scan instead of compiling the regex
/// `[^a-zA-Z0-9_-]` on every call — equivalent output, no per-call
/// compilation cost, no `regex` dependency for this helper.
pub fn sanitize_function_name(name: &str) -> String {
    name.chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' || c == '-' {
                c
            } else {
                '_'
            }
        })
        .collect()
}
|
||||
|
||||
/// True if `name` is non-empty and consists only of ASCII alphanumerics,
/// `_`, and `-` — the same predicate as the regex `^[a-zA-Z0-9_-]+$`,
/// but without recompiling a `Regex` on every call.
pub fn is_valid_function_name(name: &str) -> bool {
    !name.is_empty()
        && name
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
}
|
||||
|
||||
/// Extract the model name from a JSON object. Common with most providers to have this top level attribute.
|
||||
pub fn get_model(data: &Value) -> String {
|
||||
if let Some(model) = data.get("model") {
|
||||
if let Some(model_str) = model.as_str() {
|
||||
model_str.to_string()
|
||||
} else {
|
||||
"Unknown".to_string()
|
||||
}
|
||||
} else {
|
||||
"Unknown".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check whether a file is actually an image by examining its magic bytes
/// (PNG, JPEG, or GIF). Unreadable or unrecognized files return false.
fn is_image_file(path: &Path) -> bool {
    let Ok(mut file) = std::fs::File::open(path) else {
        return false;
    };
    // Large enough for most image magic numbers; a short read leaves
    // zero padding, which matches no signature.
    let mut header = [0u8; 8];
    if file.read(&mut header).is_err() {
        return false;
    }
    matches!(
        header[0..4],
        // PNG: 89 50 4E 47
        [0x89, 0x50, 0x4E, 0x47]
            // JPEG: FF D8 FF
            | [0xFF, 0xD8, 0xFF, _]
            // GIF: 47 49 46 38
            | [0x47, 0x49, 0x46, 0x38]
    )
}
|
||||
|
||||
/// Detect if a string contains a path to an image file
|
||||
pub fn detect_image_path(text: &str) -> Option<&str> {
|
||||
// Basic image file extension check
|
||||
let extensions = [".png", ".jpg", ".jpeg"];
|
||||
|
||||
// Find any word that ends with an image extension
|
||||
for word in text.split_whitespace() {
|
||||
if extensions
|
||||
.iter()
|
||||
.any(|ext| word.to_lowercase().ends_with(ext))
|
||||
{
|
||||
let path = Path::new(word);
|
||||
// Check if it's an absolute path and file exists
|
||||
if path.is_absolute() && path.is_file() {
|
||||
// Verify it's actually an image file
|
||||
if is_image_file(path) {
|
||||
return Some(word);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Convert a local image file to base64 encoded ImageContent
|
||||
pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
|
||||
let path = Path::new(path);
|
||||
|
||||
// Verify it's an image before proceeding
|
||||
if !is_image_file(path) {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"File is not a valid image".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Read the file
|
||||
let bytes = std::fs::read(path)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?;
|
||||
|
||||
// Detect mime type from extension
|
||||
let mime_type = match path.extension().and_then(|e| e.to_str()) {
|
||||
Some(ext) => match ext.to_lowercase().as_str() {
|
||||
"png" => "image/png",
|
||||
"jpg" | "jpeg" => "image/jpeg",
|
||||
_ => {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"Unsupported image format".to_string(),
|
||||
));
|
||||
}
|
||||
},
|
||||
None => {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"Unknown image format".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Convert to base64
|
||||
let data = base64::prelude::BASE64_STANDARD.encode(&bytes);
|
||||
|
||||
Ok(ImageContent {
|
||||
mime_type: mime_type.to_string(),
|
||||
data,
|
||||
})
|
||||
}
|
||||
|
||||
/// Record the full request/response pair and token counts on the current
/// tracing span at debug level. Serialization failures degrade to empty
/// strings rather than panicking.
pub fn emit_debug_trace(
    model_config: &ModelConfig,
    payload: &Value,
    response: &Value,
    usage: &Usage,
) {
    tracing::debug!(
        model_config = %serde_json::to_string_pretty(model_config).unwrap_or_default(),
        input = %serde_json::to_string_pretty(payload).unwrap_or_default(),
        output = %serde_json::to_string_pretty(response).unwrap_or_default(),
        input_tokens = ?usage.input_tokens.unwrap_or_default(),
        output_tokens = ?usage.output_tokens.unwrap_or_default(),
        total_tokens = ?usage.total_tokens.unwrap_or_default(),
    );
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_image_path() {
        // Create a temporary PNG file with valid PNG magic numbers
        let temp_dir = tempfile::tempdir().unwrap();
        let png_path = temp_dir.path().join("test.png");
        let png_data = [
            0x89, 0x50, 0x4E, 0x47, // PNG magic number
            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
        ];
        std::fs::write(&png_path, &png_data).unwrap();
        let png_path_str = png_path.to_str().unwrap();

        // Create a fake PNG (wrong magic numbers)
        let fake_png_path = temp_dir.path().join("fake.png");
        std::fs::write(&fake_png_path, b"not a real png").unwrap();

        // Test with valid PNG file using absolute path
        let text = format!("Here is an image {}", png_path_str);
        assert_eq!(detect_image_path(&text), Some(png_path_str));

        // Test with non-image file that has .png extension
        let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
        assert_eq!(detect_image_path(&text), None);

        // Test with non-existent file
        let text = "Here is a fake.png that doesn't exist";
        assert_eq!(detect_image_path(text), None);

        // Test with non-image file
        let text = "Here is a file.txt";
        assert_eq!(detect_image_path(text), None);

        // Test with relative path (should not match)
        let text = "Here is a relative/path/image.png";
        assert_eq!(detect_image_path(text), None);
    }

    #[test]
    fn test_load_image_file() {
        // Create a temporary PNG file with valid PNG magic numbers
        let temp_dir = tempfile::tempdir().unwrap();
        let png_path = temp_dir.path().join("test.png");
        let png_data = [
            0x89, 0x50, 0x4E, 0x47, // PNG magic number
            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
        ];
        std::fs::write(&png_path, &png_data).unwrap();
        let png_path_str = png_path.to_str().unwrap();

        // Create a fake PNG (wrong magic numbers)
        let fake_png_path = temp_dir.path().join("fake.png");
        std::fs::write(&fake_png_path, b"not a real png").unwrap();
        let fake_png_path_str = fake_png_path.to_str().unwrap();

        // Test loading valid PNG file
        let result = load_image_file(png_path_str);
        assert!(result.is_ok());
        let image = result.unwrap();
        assert_eq!(image.mime_type, "image/png");

        // Test loading fake PNG file (magic-byte check rejects it)
        let result = load_image_file(fake_png_path_str);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("not a valid image"));

        // Test non-existent file
        let result = load_image_file("nonexistent.png");
        assert!(result.is_err());
    }

    #[test]
    fn test_sanitize_function_name() {
        assert_eq!(sanitize_function_name("hello-world"), "hello-world");
        assert_eq!(sanitize_function_name("hello world"), "hello_world");
        assert_eq!(sanitize_function_name("hello@world"), "hello_world");
    }

    #[test]
    fn test_is_valid_function_name() {
        assert!(is_valid_function_name("hello-world"));
        assert!(is_valid_function_name("hello_world"));
        assert!(!is_valid_function_name("hello world"));
        assert!(!is_valid_function_name("hello@world"));
    }
}
|
||||
29
crates/goose-llm/src/structured_outputs.rs
Normal file
29
crates/goose-llm/src/structured_outputs.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use crate::{
|
||||
providers::{create, errors::ProviderError, ProviderExtractResponse},
|
||||
types::json_value_ffi::JsonValueFfi,
|
||||
Message, ModelConfig,
|
||||
};
|
||||
|
||||
/// Generates a structured output based on the provided schema,
/// system prompt and user messages.
///
/// Note: regardless of the caller's model, this pins a GPT-4.1-class
/// model (the Databricks-hosted alias when `provider_name` is
/// "databricks") at temperature 0 for deterministic extraction.
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_structured_outputs(
    provider_name: &str,
    provider_config: JsonValueFfi,
    system_prompt: &str,
    messages: &[Message],
    schema: JsonValueFfi,
) -> Result<ProviderExtractResponse, ProviderError> {
    // Use OpenAI models specifically for this task
    let model_name = if provider_name == "databricks" {
        "goose-gpt-4-1"
    } else {
        "gpt-4.1"
    };
    let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
    let provider = create(provider_name, provider_config, model_cfg)?;

    let resp = provider.extract(system_prompt, messages, &schema).await?;

    Ok(resp)
}
|
||||
227
crates/goose-llm/src/types/completion.rs
Normal file
227
crates/goose-llm/src/types/completion.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
// This file defines types for completion interfaces, including the request and response structures.
|
||||
// Many of these are adapted based on the Goose Service API:
|
||||
// https://docs.google.com/document/d/1r5vjSK3nBQU1cIRf0WKysDigqMlzzrzl_bxEE4msOiw/edit?tab=t.0
|
||||
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::types::json_value_ffi::JsonValueFfi;
|
||||
use crate::{message::Message, providers::Usage};
|
||||
use crate::{model::ModelConfig, providers::errors::ProviderError};
|
||||
|
||||
/// A full completion request: which provider to call, how the model is
/// configured, and the conversation plus available extensions (tools).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Provider key understood by `providers::create` (e.g. "openai").
    pub provider_name: String,
    /// Provider-specific config blob, deserialized by the provider itself.
    pub provider_config: serde_json::Value,
    pub model_config: ModelConfig,
    /// System prompt preamble rendered ahead of the conversation.
    pub system_preamble: String,
    pub messages: Vec<Message>,
    /// Extensions whose tools are offered to the model (prefixed by name).
    pub extensions: Vec<ExtensionConfig>,
}
|
||||
|
||||
impl CompletionRequest {
    /// Assemble a request from its parts; no validation is performed here.
    pub fn new(
        provider_name: String,
        provider_config: serde_json::Value,
        model_config: ModelConfig,
        system_preamble: String,
        messages: Vec<Message>,
        extensions: Vec<ExtensionConfig>,
    ) -> Self {
        Self {
            provider_name,
            provider_config,
            model_config,
            system_preamble,
            messages,
            extensions,
        }
    }
}
|
||||
|
||||
/// FFI constructor for [`CompletionRequest`] — thin wrapper over
/// `CompletionRequest::new` taking borrowed strings from the binding layer.
#[uniffi::export]
pub fn create_completion_request(
    provider_name: &str,
    provider_config: JsonValueFfi,
    model_config: ModelConfig,
    system_preamble: &str,
    messages: Vec<Message>,
    extensions: Vec<ExtensionConfig>,
) -> CompletionRequest {
    CompletionRequest::new(
        provider_name.to_string(),
        provider_config,
        model_config,
        system_preamble.to_string(),
        messages,
        extensions,
    )
}
|
||||
|
||||
// Cross the FFI boundary as a JSON string. `lower` can only fail on a
// non-serializable value (a bug), so `unwrap` is acceptable there; but
// `try_lift` receives foreign input, so malformed JSON is propagated as
// an error instead of panicking inside the FFI glue.
uniffi::custom_type!(CompletionRequest, String, {
    lower: |tc: &CompletionRequest| {
        serde_json::to_string(&tc).unwrap()
    },
    try_lift: |s: String| {
        Ok(serde_json::from_str(&s)?)
    },
});
|
||||
|
||||
// https://mozilla.github.io/uniffi-rs/latest/proc_macro/errors.html
/// Errors surfaced by the completion pipeline; flattened for FFI so
/// foreign callers see a single error string per variant.
#[derive(Debug, Error, uniffi::Error)]
#[uniffi(flat_error)]
pub enum CompletionError {
    #[error("failed to create provider: {0}")]
    UnknownProvider(String),

    #[error("provider error: {0}")]
    Provider(#[from] ProviderError),

    #[error("template rendering error: {0}")]
    Template(#[from] minijinja::Error),

    #[error("json serialization error: {0}")]
    Json(#[from] serde_json::Error),

    #[error("tool not found error: {0}")]
    ToolNotFound(String),
}
|
||||
|
||||
/// Result of a completion call: the assistant message, the concrete model
/// that served it, token usage, and timing metrics.
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct CompletionResponse {
    pub message: Message,
    pub model: String,
    pub usage: Usage,
    pub runtime_metrics: RuntimeMetrics,
}

impl CompletionResponse {
    /// Bundle the parts of a completed call.
    pub fn new(
        message: Message,
        model: String,
        usage: Usage,
        runtime_metrics: RuntimeMetrics,
    ) -> Self {
        Self {
            message,
            model,
            usage,
            runtime_metrics,
        }
    }
}
|
||||
|
||||
/// Wall-clock timing for one completion call.
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct RuntimeMetrics {
    /// Total end-to-end time, in seconds.
    pub total_time_sec: f32,
    /// Portion spent inside the provider request, in seconds.
    pub total_time_sec_provider: f32,
    /// Output-token throughput; None when token counts are unavailable.
    pub tokens_per_second: Option<f64>,
}

impl RuntimeMetrics {
    pub fn new(
        total_time_sec: f32,
        total_time_sec_provider: f32,
        tokens_per_second: Option<f64>,
    ) -> Self {
        Self {
            total_time_sec,
            total_time_sec_provider,
            tokens_per_second,
        }
    }
}
|
||||
|
||||
/// How tool invocations are approved before execution.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
pub enum ToolApprovalMode {
    Auto,
    Manual,
    Smart,
}

/// Declaration of a single tool an extension exposes to the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct ToolConfig {
    pub name: String,
    pub description: String,
    /// JSON Schema describing the tool's parameters.
    pub input_schema: JsonValueFfi,
    pub approval_mode: ToolApprovalMode,
}
|
||||
|
||||
impl ToolConfig {
|
||||
pub fn new(
|
||||
name: &str,
|
||||
description: &str,
|
||||
input_schema: JsonValueFfi,
|
||||
approval_mode: ToolApprovalMode,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
description: description.to_string(),
|
||||
input_schema,
|
||||
approval_mode,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the tool config to a core tool
|
||||
pub fn to_core_tool(&self, name: Option<&str>) -> super::core::Tool {
|
||||
let tool_name = name.unwrap_or(&self.name);
|
||||
super::core::Tool::new(
|
||||
tool_name,
|
||||
self.description.clone(),
|
||||
self.input_schema.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// FFI constructor for [`ToolConfig`] — thin wrapper over `ToolConfig::new`.
#[uniffi::export]
pub fn create_tool_config(
    name: &str,
    description: &str,
    input_schema: JsonValueFfi,
    approval_mode: ToolApprovalMode,
) -> ToolConfig {
    ToolConfig::new(name, description, input_schema, approval_mode)
}
|
||||
|
||||
// — Register the newtypes with UniFFI, converting via JSON strings —
|
||||
|
||||
/// An extension: a named group of tools with optional usage instructions.
/// Tool names are exposed to the model with a `<name>__` prefix.
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct ExtensionConfig {
    name: String,
    instructions: Option<String>,
    tools: Vec<ToolConfig>,
}
|
||||
|
||||
impl ExtensionConfig {
|
||||
pub fn new(name: String, instructions: Option<String>, tools: Vec<ToolConfig>) -> Self {
|
||||
Self {
|
||||
name,
|
||||
instructions,
|
||||
tools,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the tools to core tools with the extension name as a prefix
|
||||
pub fn get_prefixed_tools(&self) -> Vec<super::core::Tool> {
|
||||
self.tools
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let name = format!("{}__{}", self.name, tool.name);
|
||||
tool.to_core_tool(Some(&name))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get a map of prefixed tool names to their approval modes
|
||||
pub fn get_prefixed_tool_configs(&self) -> HashMap<String, ToolConfig> {
|
||||
self.tools
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let name = format!("{}__{}", self.name, tool.name);
|
||||
(name, tool.clone())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
131
crates/goose-llm/src/types/core.rs
Normal file
131
crates/goose-llm/src/types/core.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
// This file defines core types that require serialization to
|
||||
// construct payloads for LLM model providers and work with MCPs.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
/// The author of a chat message: the human user or the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
|
||||
|
||||
/// A single content item within a message: text or an image.
///
/// Serialized with an internal `"type"` tag in camelCase, matching the
/// wire format used when building provider payloads.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum Content {
    Text(TextContent),
    Image(ImageContent),
}
|
||||
|
||||
impl Content {
|
||||
pub fn text<S: Into<String>>(text: S) -> Self {
|
||||
Content::Text(TextContent { text: text.into() })
|
||||
}
|
||||
|
||||
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
|
||||
Content::Image(ImageContent {
|
||||
data: data.into(),
|
||||
mime_type: mime_type.into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the text content if this is a TextContent variant
|
||||
pub fn as_text(&self) -> Option<&str> {
|
||||
match self {
|
||||
Content::Text(text) => Some(&text.text),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the image content if this is an ImageContent variant
|
||||
pub fn as_image(&self) -> Option<(&str, &str)> {
|
||||
match self {
|
||||
Content::Image(image) => Some((&image.data, &image.mime_type)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Plain-text message content.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct TextContent {
    pub text: String,
}

/// Image message content.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ImageContent {
    // Image payload; presumably base64-encoded — TODO confirm with callers.
    pub data: String,
    // MIME type of the payload, e.g. "image/png".
    pub mime_type: String,
}
|
||||
|
||||
/// A tool that can be used by a model.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Tool {
|
||||
/// The name of the tool
|
||||
pub name: String,
|
||||
/// A description of what the tool does
|
||||
pub description: String,
|
||||
/// A JSON Schema object defining the expected parameters for the tool
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
impl Tool {
|
||||
/// Create a new tool with the given name and description
|
||||
pub fn new<N, D>(name: N, description: D, input_schema: serde_json::Value) -> Self
|
||||
where
|
||||
N: Into<String>,
|
||||
D: Into<String>,
|
||||
{
|
||||
Tool {
|
||||
name: name.into(),
|
||||
description: description.into(),
|
||||
input_schema,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool call request that an extension can execute
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolCall {
|
||||
/// The name of the tool to execute
|
||||
pub name: String,
|
||||
/// The parameters for the execution
|
||||
pub arguments: serde_json::Value,
|
||||
/// Whether the tool call needs approval before execution. Default is false.
|
||||
pub needs_approval: bool,
|
||||
}
|
||||
|
||||
impl ToolCall {
|
||||
/// Create a new ToolUse with the given name and parameters
|
||||
pub fn new<S: Into<String>>(name: S, arguments: serde_json::Value) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
arguments,
|
||||
needs_approval: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set needs_approval field
|
||||
pub fn set_needs_approval(&mut self, flag: bool) {
|
||||
self.needs_approval = flag;
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors that can occur while validating or executing a tool call.
#[non_exhaustive]
#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq, uniffi::Error)]
pub enum ToolError {
    #[error("Invalid parameters: {0}")]
    InvalidParameters(String),
    #[error("Execution failed: {0}")]
    ExecutionError(String),
    #[error("Schema error: {0}")]
    SchemaError(String),
    #[error("Tool not found: {0}")]
    NotFound(String),
}

/// Convenience alias for results of tool operations.
pub type ToolResult<T> = std::result::Result<T, ToolError>;
|
||||
18
crates/goose-llm/src/types/json_value_ffi.rs
Normal file
18
crates/goose-llm/src/types/json_value_ffi.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use serde_json::Value;
|
||||
|
||||
// `serde_json::Value` gets converted to a `String` to pass across the FFI.
|
||||
// https://github.com/mozilla/uniffi-rs/blob/main/docs/manual/src/types/custom_types.md?plain=1
|
||||
// https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/examples/custom-types/src/lib.rs#L63-L69
|
||||
|
||||
// `serde_json::Value` crosses the FFI as a JSON string.
uniffi::custom_type!(Value, String, {
    // Remote is required since 'Value' is from a different crate
    remote,
    // Lowering cannot fail: any `serde_json::Value` serializes to a string.
    lower: |obj| {
        serde_json::to_string(&obj).unwrap()
    },
    // Propagate parse failures across the FFI instead of panicking:
    // `try_lift` returns a Result precisely so malformed JSON coming
    // from the foreign side surfaces as an error, not a process abort.
    try_lift: |val| {
        Ok(serde_json::from_str(&val)?)
    },
});
|
||||
|
||||
pub type JsonValueFfi = Value;
|
||||
3
crates/goose-llm/src/types/mod.rs
Normal file
3
crates/goose-llm/src/types/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod completion;
|
||||
pub mod core;
|
||||
pub mod json_value_ffi;
|
||||
79
crates/goose-llm/tests/extract_session_name.rs
Normal file
79
crates/goose-llm/tests/extract_session_name.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose_llm::extractors::generate_session_name;
|
||||
use goose_llm::message::Message;
|
||||
use goose_llm::providers::errors::ProviderError;
|
||||
|
||||
fn should_run_test() -> Result<(), String> {
|
||||
dotenv().ok();
|
||||
if std::env::var("DATABRICKS_HOST").is_err() {
|
||||
return Err("Missing DATABRICKS_HOST".to_string());
|
||||
}
|
||||
if std::env::var("DATABRICKS_TOKEN").is_err() {
|
||||
return Err("Missing DATABRICKS_TOKEN".to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn _generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
|
||||
let provider_name = "databricks";
|
||||
let provider_config = serde_json::json!({
|
||||
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
|
||||
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
|
||||
});
|
||||
|
||||
generate_session_name(provider_name, provider_config, messages).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_session_name_success() {
|
||||
if should_run_test().is_err() {
|
||||
println!("Skipping...");
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a few messages with at least two user messages
|
||||
let messages = vec![
|
||||
Message::user().with_text("Hello, how are you?"),
|
||||
Message::assistant().with_text("I’m fine, thanks!"),
|
||||
Message::user().with_text("What’s the weather in New York tomorrow?"),
|
||||
];
|
||||
|
||||
let name = _generate_session_name(&messages)
|
||||
.await
|
||||
.expect("Failed to generate session name");
|
||||
|
||||
println!("Generated session name: {:?}", name);
|
||||
|
||||
// Should be non-empty and at most 4 words
|
||||
let name = name.trim();
|
||||
assert!(!name.is_empty(), "Name must not be empty");
|
||||
let word_count = name.split_whitespace().count();
|
||||
assert!(
|
||||
word_count <= 4,
|
||||
"Name must be 4 words or less, got {}: {}",
|
||||
word_count,
|
||||
name
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_session_name_no_user() {
|
||||
if should_run_test().is_err() {
|
||||
println!("Skipping 'test_generate_session_name_no_user'. Databricks creds not set");
|
||||
return;
|
||||
}
|
||||
|
||||
// No user messages → expect ExecutionError
|
||||
let messages = vec![
|
||||
Message::assistant().with_text("System starting…"),
|
||||
Message::assistant().with_text("All systems go."),
|
||||
];
|
||||
|
||||
let err = _generate_session_name(&messages).await;
|
||||
assert!(
|
||||
matches!(err, Err(ProviderError::ExecutionError(_))),
|
||||
"Expected ExecutionError when there are no user messages, got: {:?}",
|
||||
err
|
||||
);
|
||||
}
|
||||
88
crates/goose-llm/tests/extract_tooltip.rs
Normal file
88
crates/goose-llm/tests/extract_tooltip.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose_llm::extractors::generate_tooltip;
|
||||
use goose_llm::message::{Message, MessageContent, ToolRequest};
|
||||
use goose_llm::providers::errors::ProviderError;
|
||||
use goose_llm::types::core::{Content, ToolCall};
|
||||
use serde_json::json;
|
||||
|
||||
fn should_run_test() -> Result<(), String> {
|
||||
dotenv().ok();
|
||||
if std::env::var("DATABRICKS_HOST").is_err() {
|
||||
return Err("Missing DATABRICKS_HOST".to_string());
|
||||
}
|
||||
if std::env::var("DATABRICKS_TOKEN").is_err() {
|
||||
return Err("Missing DATABRICKS_TOKEN".to_string());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn _generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
|
||||
let provider_name = "databricks";
|
||||
let provider_config = serde_json::json!({
|
||||
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
|
||||
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
|
||||
});
|
||||
|
||||
generate_tooltip(provider_name, provider_config, messages).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_tooltip_simple() {
|
||||
if should_run_test().is_err() {
|
||||
println!("Skipping...");
|
||||
return;
|
||||
}
|
||||
|
||||
// Two plain-text messages
|
||||
let messages = vec![
|
||||
Message::user().with_text("Hello, how are you?"),
|
||||
Message::assistant().with_text("I'm fine, thanks! How can I help?"),
|
||||
];
|
||||
|
||||
let tooltip = _generate_tooltip(&messages)
|
||||
.await
|
||||
.expect("Failed to generate tooltip");
|
||||
println!("Generated tooltip: {:?}", tooltip);
|
||||
|
||||
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
|
||||
assert!(
|
||||
tooltip.len() < 100,
|
||||
"Tooltip should be reasonably short (<100 chars)"
|
||||
);
|
||||
}
|
||||
|
||||
/// Tooltip generation should also work when the history contains a
/// tool request/response pair instead of plain text.
#[tokio::test]
async fn test_generate_tooltip_with_tools() {
    if should_run_test().is_err() {
        println!("Skipping...");
        return;
    }

    // 1) Assistant message with a tool request
    let mut tool_req_msg = Message::assistant();
    let req = ToolRequest {
        id: "1".to_string(),
        tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))).into(),
    };
    tool_req_msg.content.push(MessageContent::ToolReq(req));

    // 2) User message with the tool response (ids must match the request)
    let tool_resp_msg = Message::user().with_tool_response(
        "1",
        Ok(vec![Content::text("The current time is 12:00 UTC")]).into(),
    );

    let messages = vec![tool_req_msg, tool_resp_msg];

    let tooltip = _generate_tooltip(&messages)
        .await
        .expect("Failed to generate tooltip");
    println!("Generated tooltip (tools): {:?}", tooltip);

    assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
    assert!(
        tooltip.len() < 100,
        "Tooltip should be reasonably short (<100 chars)"
    );
}
|
||||
380
crates/goose-llm/tests/providers_complete.rs
Normal file
380
crates/goose-llm/tests/providers_complete.rs
Normal file
@@ -0,0 +1,380 @@
|
||||
use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose_llm::message::{Message, MessageContent};
|
||||
use goose_llm::providers::base::Provider;
|
||||
use goose_llm::providers::errors::ProviderError;
|
||||
use goose_llm::providers::{databricks, openai};
|
||||
use goose_llm::types::core::{Content, Tool};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Outcome of one provider's test suite, rendered as an emoji in the summary.
#[derive(Debug, Clone, Copy)]
enum TestStatus {
    Passed,
    Skipped,
    Failed,
}
|
||||
|
||||
impl std::fmt::Display for TestStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TestStatus::Passed => write!(f, "✅"),
|
||||
TestStatus::Skipped => write!(f, "⏭️"),
|
||||
TestStatus::Failed => write!(f, "❌"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe accumulator of per-provider test outcomes, printed at exit.
struct TestReport {
    // provider name -> last recorded status
    results: Mutex<HashMap<String, TestStatus>>,
}
|
||||
|
||||
impl TestReport {
|
||||
fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
results: Mutex::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
fn record_status(&self, provider: &str, status: TestStatus) {
|
||||
let mut results = self.results.lock().unwrap();
|
||||
results.insert(provider.to_string(), status);
|
||||
}
|
||||
|
||||
fn record_pass(&self, provider: &str) {
|
||||
self.record_status(provider, TestStatus::Passed);
|
||||
}
|
||||
|
||||
fn record_skip(&self, provider: &str) {
|
||||
self.record_status(provider, TestStatus::Skipped);
|
||||
}
|
||||
|
||||
fn record_fail(&self, provider: &str) {
|
||||
self.record_status(provider, TestStatus::Failed);
|
||||
}
|
||||
|
||||
fn print_summary(&self) {
|
||||
println!("\n============== Providers ==============");
|
||||
let results = self.results.lock().unwrap();
|
||||
let mut providers: Vec<_> = results.iter().collect();
|
||||
providers.sort_by(|a, b| a.0.cmp(b.0));
|
||||
|
||||
for (provider, status) in providers {
|
||||
println!("{} {}", status, provider);
|
||||
}
|
||||
println!("=======================================\n");
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
    // Shared report collecting pass/skip/fail per provider across all tests.
    static ref TEST_REPORT: Arc<TestReport> = TestReport::new();
    // Serializes environment-variable mutation across provider tests.
    static ref ENV_LOCK: Mutex<()> = Mutex::new(());
}
|
||||
|
||||
/// Generic test harness for any Provider implementation
struct ProviderTester {
    // Provider under test, behind a trait object so one harness fits all.
    provider: Arc<dyn Provider>,
    // Display name, used in log output and provider-specific branches.
    name: String,
}
|
||||
|
||||
impl ProviderTester {
|
||||
fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
|
||||
Self {
|
||||
provider: Arc::new(provider),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
async fn test_basic_response(&self) -> Result<()> {
|
||||
let message = Message::user().with_text("Just say hello!");
|
||||
|
||||
let response = self
|
||||
.provider
|
||||
.complete("You are a helpful assistant.", &[message], &[])
|
||||
.await?;
|
||||
|
||||
// For a basic response, we expect a single text response
|
||||
assert_eq!(
|
||||
response.message.content.len(),
|
||||
1,
|
||||
"Expected single content item in response"
|
||||
);
|
||||
|
||||
// Verify we got a text response
|
||||
assert!(
|
||||
matches!(response.message.content[0], MessageContent::Text(_)),
|
||||
"Expected text response"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_tool_usage(&self) -> Result<()> {
|
||||
let weather_tool = Tool::new(
|
||||
"get_weather",
|
||||
"Get the weather for a location",
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let message = Message::user().with_text("What's the weather like in San Francisco?");
|
||||
|
||||
let response1 = self
|
||||
.provider
|
||||
.complete(
|
||||
"You are a helpful weather assistant.",
|
||||
&[message.clone()],
|
||||
&[weather_tool.clone()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
println!("=== {}::reponse1 ===", self.name);
|
||||
dbg!(&response1);
|
||||
println!("===================");
|
||||
|
||||
// Verify we got a tool request
|
||||
assert!(
|
||||
response1
|
||||
.message
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| matches!(content, MessageContent::ToolReq(_))),
|
||||
"Expected tool request in response"
|
||||
);
|
||||
|
||||
let id = &response1
|
||||
.message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|message| message.as_tool_request())
|
||||
.last()
|
||||
.expect("got tool request")
|
||||
.id;
|
||||
|
||||
let weather = Message::user().with_tool_response(
|
||||
id,
|
||||
Ok(vec![Content::text(
|
||||
"
|
||||
50°F°C
|
||||
Precipitation: 0%
|
||||
Humidity: 84%
|
||||
Wind: 2 mph
|
||||
Weather
|
||||
Saturday 9:00 PM
|
||||
Clear",
|
||||
)])
|
||||
.into(),
|
||||
);
|
||||
|
||||
// Verify we construct a valid payload including the request/response pair for the next inference
|
||||
let response2 = self
|
||||
.provider
|
||||
.complete(
|
||||
"You are a helpful weather assistant.",
|
||||
&[message, response1.message, weather],
|
||||
&[weather_tool],
|
||||
)
|
||||
.await?;
|
||||
|
||||
println!("=== {}::reponse2 ===", self.name);
|
||||
dbg!(&response2);
|
||||
println!("===================");
|
||||
|
||||
assert!(
|
||||
response2
|
||||
.message
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| matches!(content, MessageContent::Text(_))),
|
||||
"Expected text for final response"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Overflow the provider's context window and check that the typed
/// `ContextLengthExceeded` error is surfaced (Ollama excepted).
async fn test_context_length_exceeded_error(&self) -> Result<()> {
    // Google Gemini has a really long context window
    let large_message_content = if self.name.to_lowercase() == "google" {
        "hello ".repeat(1_300_000)
    } else {
        "hello ".repeat(300_000)
    };

    let messages = vec![
        Message::user().with_text("hi there. what is 2 + 2?"),
        Message::assistant().with_text("hey! I think it's 4."),
        Message::user().with_text(&large_message_content),
        Message::assistant().with_text("heyy!!"),
        // Messages before this mark should be truncated
        Message::user().with_text("what's the meaning of life?"),
        Message::assistant().with_text("the meaning of life is 42"),
        Message::user().with_text(
            "did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'",
        ),
    ];

    // Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
    let result = self
        .provider
        .complete("You are a helpful assistant.", &messages, &[])
        .await;

    // Print some debug info
    println!("=== {}::context_length_exceeded_error ===", self.name);
    dbg!(&result);
    println!("===================");

    // Ollama truncates by default even when the context window is exceeded
    if self.name.to_lowercase() == "ollama" {
        assert!(
            result.is_ok(),
            "Expected to succeed because of default truncation"
        );
        return Ok(());
    }

    assert!(
        result.is_err(),
        "Expected error when context window is exceeded"
    );
    assert!(
        matches!(result.unwrap_err(), ProviderError::ContextLengthExceeded(_)),
        "Expected error to be ContextLengthExceeded"
    );

    Ok(())
}
|
||||
|
||||
/// Run all provider tests
///
/// Runs sequentially; the first failing check aborts the suite for this
/// provider via `?`.
async fn run_test_suite(&self) -> Result<()> {
    self.test_basic_response().await?;
    self.test_tool_usage().await?;
    self.test_context_length_exceeded_error().await?;
    Ok(())
}
|
||||
}
|
||||
|
||||
fn load_env() {
|
||||
if let Ok(path) = dotenv() {
|
||||
println!("Loaded environment from {:?}", path);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to run a provider test with proper error handling and reporting
///
/// `env_modifications` maps variable names to `Some(value)` (set) or `None`
/// (unset); the modifications are applied only around provider construction,
/// under `ENV_LOCK` so concurrent tests don't race on the process environment.
async fn test_provider<F, T>(
    name: &str,
    required_vars: &[&str],
    env_modifications: Option<HashMap<&str, Option<String>>>,
    provider_fn: F,
) -> Result<()>
where
    F: FnOnce() -> T,
    T: Provider + Send + Sync + 'static,
{
    // We start off as failed, so that if the process panics it is seen as a failure
    TEST_REPORT.record_fail(name);

    // Take exclusive access to environment modifications
    let lock = ENV_LOCK.lock().unwrap();

    load_env();

    // Save current environment state for required vars and modified vars
    let mut original_env = HashMap::new();
    for &var in required_vars {
        if let Ok(val) = std::env::var(var) {
            original_env.insert(var, val);
        }
    }
    if let Some(mods) = &env_modifications {
        for &var in mods.keys() {
            if let Ok(val) = std::env::var(var) {
                original_env.insert(var, val);
            }
        }
    }

    // Apply any environment modifications
    if let Some(mods) = &env_modifications {
        for (&var, value) in mods.iter() {
            match value {
                Some(val) => std::env::set_var(var, val),
                None => std::env::remove_var(var),
            }
        }
    }

    // Setup the provider
    // NOTE(review): this early skip return leaves any env modifications
    // applied (restore only happens below) — confirm this is intended.
    let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err());
    if missing_vars {
        println!("Skipping {} tests - credentials not configured", name);
        TEST_REPORT.record_skip(name);
        return Ok(());
    }

    let provider = provider_fn();

    // Restore original environment
    for (&var, value) in original_env.iter() {
        std::env::set_var(var, value);
    }
    if let Some(mods) = env_modifications {
        for &var in mods.keys() {
            if !original_env.contains_key(var) {
                std::env::remove_var(var);
            }
        }
    }

    // Release the env lock before the (potentially slow) network tests run.
    std::mem::drop(lock);

    let tester = ProviderTester::new(provider, name.to_string());
    match tester.run_test_suite().await {
        Ok(_) => {
            TEST_REPORT.record_pass(name);
            Ok(())
        }
        Err(e) => {
            println!("{} test failed: {}", name, e);
            TEST_REPORT.record_fail(name);
            Err(e)
        }
    }
}
|
||||
|
||||
/// Full suite against OpenAI; skipped at runtime unless OPENAI_API_KEY is set.
#[tokio::test]
async fn openai_complete() -> Result<()> {
    test_provider(
        "OpenAI",
        &["OPENAI_API_KEY"],
        None,
        openai::OpenAiProvider::default,
    )
    .await
}
|
||||
|
||||
/// Full suite against Databricks; skipped unless host and token are set.
#[tokio::test]
async fn databricks_complete() -> Result<()> {
    test_provider(
        "Databricks",
        &["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
        None,
        databricks::DatabricksProvider::default,
    )
    .await
}
|
||||
|
||||
// Print the final test report
// Runs at process exit (ctor's dtor hook) so the summary appears after
// every provider test has recorded its status.
#[ctor::dtor]
fn print_test_report() {
    TEST_REPORT.print_summary();
}
|
||||
195
crates/goose-llm/tests/providers_extract.rs
Normal file
195
crates/goose-llm/tests/providers_extract.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
// tests/providers_extract.rs
|
||||
|
||||
use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose_llm::message::Message;
|
||||
use goose_llm::providers::base::Provider;
|
||||
use goose_llm::providers::{databricks::DatabricksProvider, openai::OpenAiProvider};
|
||||
use goose_llm::ModelConfig;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Which backing provider an extract test should run against.
#[derive(Debug, PartialEq, Copy, Clone)]
enum ProviderType {
    OpenAi,
    Databricks,
}
|
||||
|
||||
impl ProviderType {
|
||||
fn required_env(&self) -> &'static [&'static str] {
|
||||
match self {
|
||||
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
||||
ProviderType::Databricks => &["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider(&self, cfg: ModelConfig) -> Result<Arc<dyn Provider>> {
|
||||
Ok(match self {
|
||||
ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(cfg)),
|
||||
ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(cfg)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when every variable in `required` is set in the
/// environment; otherwise prints the missing names and returns `false`.
fn check_required_env_vars(required: &[&str]) -> bool {
    let missing: Vec<_> = required
        .iter()
        .copied()
        .filter(|v| std::env::var(v).is_err())
        .collect();
    if missing.is_empty() {
        return true;
    }
    println!("Skipping test; missing env vars: {:?}", missing);
    false
}
|
||||
|
||||
// --- Shared inputs for "paper" task ---
// System prompt and fixture document for the structured-extraction test.
const PAPER_SYSTEM: &str =
    "You are an expert at structured data extraction. Extract the metadata of a research paper into JSON.";
const PAPER_TEXT: &str =
    "Application of Quantum Algorithms in Interstellar Navigation: A New Frontier \
    by Dr. Stella Voyager, Dr. Nova Star, Dr. Lyra Hunter. Abstract: This paper \
    investigates the utilization of quantum algorithms to improve interstellar \
    navigation systems. Keywords: Quantum algorithms, interstellar navigation, \
    space-time anomalies, quantum superposition, quantum entanglement, space travel.";
|
||||
|
||||
/// JSON Schema for the "paper" task: all four metadata fields are
/// required and no extra properties are allowed.
fn paper_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "title": { "type": "string" },
            "authors": { "type": "array", "items": { "type": "string" } },
            "abstract": { "type": "string" },
            "keywords": { "type": "array", "items": { "type": "string" } }
        },
        "required": ["title","authors","abstract","keywords"],
        "additionalProperties": false
    })
}
|
||||
|
||||
// --- Shared inputs for "UI" task ---
// System prompt and user request for the JSON-driven UI generation test.
const UI_SYSTEM: &str = "You are a UI generator AI. Convert the user input into a JSON-driven UI.";
const UI_TEXT: &str = "Make a User Profile Form";
|
||||
|
||||
/// Recursive JSON Schema for the "UI" task: a tree of typed elements,
/// each with a label, child elements (via `"$ref": "#"`), and
/// name/value attributes.
fn ui_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "type": {
                "type": "string",
                "enum": ["div","button","header","section","field","form"]
            },
            "label": { "type": "string" },
            "children": {
                // Self-reference makes the schema recursive.
                "type": "array",
                "items": { "$ref": "#" }
            },
            "attributes": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "value": { "type": "string" }
                    },
                    "required": ["name","value"],
                    "additionalProperties": false
                }
            }
        },
        "required": ["type","label","children","attributes"],
        "additionalProperties": false
    })
}
|
||||
|
||||
/// Generic runner for any extract task
///
/// Skips (returns Ok) when the provider's env credentials are missing;
/// otherwise calls `extract` with the given schema and asserts that
/// `validate` accepts the structured result.
async fn run_extract_test<F>(
    provider_type: ProviderType,
    model: &str,
    system: &'static str,
    user_text: &'static str,
    schema: Value,
    validate: F,
) -> Result<()>
where
    F: Fn(&Value) -> bool,
{
    dotenv().ok();
    if !check_required_env_vars(provider_type.required_env()) {
        return Ok(());
    }

    // Temperature 0 for as-deterministic-as-possible extraction output.
    let cfg = ModelConfig::new(model.to_string()).with_temperature(Some(0.0));
    let provider = provider_type.create_provider(cfg)?;

    let msg = Message::user().with_text(user_text);
    let resp = provider.extract(system, &[msg], &schema).await?;

    println!("[{:?}] extract => {}", provider_type, resp.data);

    assert!(
        validate(&resp.data),
        "{:?} failed validation on {}",
        provider_type,
        resp.data
    );
    Ok(())
}
|
||||
|
||||
/// Helper for the "paper" task
|
||||
async fn run_extract_paper_test(provider: ProviderType, model: &str) -> Result<()> {
|
||||
run_extract_test(
|
||||
provider,
|
||||
model,
|
||||
PAPER_SYSTEM,
|
||||
PAPER_TEXT,
|
||||
paper_schema(),
|
||||
|v| {
|
||||
v.as_object()
|
||||
.map(|o| {
|
||||
["title", "authors", "abstract", "keywords"]
|
||||
.iter()
|
||||
.all(|k| o.contains_key(*k))
|
||||
})
|
||||
.unwrap_or(false)
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Helper for the "UI" task
|
||||
async fn run_extract_ui_test(provider: ProviderType, model: &str) -> Result<()> {
|
||||
run_extract_test(provider, model, UI_SYSTEM, UI_TEXT, ui_schema(), |v| {
|
||||
v.as_object()
|
||||
.and_then(|o| o.get("type").and_then(Value::as_str))
|
||||
== Some("form")
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Each test skips itself at runtime when the provider's credentials
    // are absent (see check_required_env_vars).

    #[tokio::test]
    async fn openai_extract_paper() -> Result<()> {
        run_extract_paper_test(ProviderType::OpenAi, "gpt-4o").await
    }

    #[tokio::test]
    async fn openai_extract_ui() -> Result<()> {
        run_extract_ui_test(ProviderType::OpenAi, "gpt-4o").await
    }

    #[tokio::test]
    async fn databricks_extract_paper() -> Result<()> {
        run_extract_paper_test(ProviderType::Databricks, "goose-gpt-4-1").await
    }

    #[tokio::test]
    async fn databricks_extract_ui() -> Result<()> {
        run_extract_ui_test(ProviderType::Databricks, "goose-gpt-4-1").await
    }
}
|
||||
3
crates/goose-llm/uniffi-bindgen.rs
Normal file
3
crates/goose-llm/uniffi-bindgen.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
// Entry point for the `uniffi-bindgen` helper binary: generates
// foreign-language bindings from this crate's UniFFI definitions.
fn main() {
    uniffi::uniffi_bindgen_main()
}
|
||||
@@ -35,6 +35,7 @@ chrono = { version = "0.4.38", features = ["serde"] }
|
||||
etcetera = "0.8.0"
|
||||
tempfile = "3.8"
|
||||
include_dir = "0.7.4"
|
||||
google-apis-common = "7.0.0"
|
||||
google-drive3 = "6.0.0"
|
||||
google-sheets4 = "6.0.0"
|
||||
google-docs1 = "6.0.0"
|
||||
@@ -47,9 +48,21 @@ lopdf = "0.35.0"
|
||||
docx-rs = "0.4.7"
|
||||
image = "0.24.9"
|
||||
umya-spreadsheet = "2.2.3"
|
||||
keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service", "vendored"] }
|
||||
keyring = { version = "3.6.1", features = [
|
||||
"apple-native",
|
||||
"windows-native",
|
||||
"sync-secret-service",
|
||||
"vendored",
|
||||
] }
|
||||
oauth2 = { version = "5.0.0", features = ["reqwest"] }
|
||||
utoipa = { version = "4.1", optional = true }
|
||||
hyper = "1"
|
||||
serde_with = "3"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
serial_test = "3.0.0"
|
||||
sysinfo = "0.32.1"
|
||||
|
||||
[features]
|
||||
utoipa = ["dep:utoipa"]
|
||||
|
||||
@@ -8,6 +8,9 @@ use std::{
|
||||
};
|
||||
use tokio::process::Command;
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
@@ -743,6 +746,23 @@ impl ComputerControllerRouter {
|
||||
ToolError::ExecutionError(format!("Failed to write script: {}", e))
|
||||
})?;
|
||||
|
||||
// Set execute permissions on Unix systems
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let mut perms = fs::metadata(&script_path)
|
||||
.map_err(|e| {
|
||||
ToolError::ExecutionError(format!("Failed to get file metadata: {}", e))
|
||||
})?
|
||||
.permissions();
|
||||
perms.set_mode(0o755); // rwxr-xr-x
|
||||
fs::set_permissions(&script_path, perms).map_err(|e| {
|
||||
ToolError::ExecutionError(format!(
|
||||
"Failed to set execute permissions: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
script_path.display().to_string()
|
||||
}
|
||||
"ruby" => {
|
||||
|
||||
@@ -65,10 +65,7 @@ impl LinuxAutomation {
|
||||
DisplayServer::X11 => self.check_x11_dependencies()?,
|
||||
DisplayServer::Wayland => self.check_wayland_dependencies()?,
|
||||
DisplayServer::Unknown => {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"Unable to detect display server",
|
||||
));
|
||||
return Err(std::io::Error::other("Unable to detect display server"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,10 +103,7 @@ impl LinuxAutomation {
|
||||
match self.display_server {
|
||||
DisplayServer::X11 => self.execute_x11_command(cmd),
|
||||
DisplayServer::Wayland => self.execute_wayland_command(cmd),
|
||||
DisplayServer::Unknown => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"Unknown display server",
|
||||
)),
|
||||
DisplayServer::Unknown => Err(std::io::Error::other("Unknown display server")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,8 +230,7 @@ impl SystemAutomation for LinuxAutomation {
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).into_owned())
|
||||
} else {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
Err(std::io::Error::other(
|
||||
String::from_utf8_lossy(&output.stderr).into_owned(),
|
||||
))
|
||||
}
|
||||
|
||||
478
crates/goose-mcp/src/google_drive/google_labels.rs
Normal file
478
crates/goose-mcp/src/google_drive/google_labels.rs
Normal file
@@ -0,0 +1,478 @@
|
||||
#![allow(clippy::ptr_arg, dead_code, clippy::enum_variant_names)]
|
||||
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
|
||||
use google_apis_common as common;
|
||||
use tokio::time::sleep;
|
||||
|
||||
/// A scope is needed when requesting an
|
||||
/// [authorization token](https://developers.google.com/workspace/drive/labels/guides/authorize).
|
||||
#[derive(PartialEq, Eq, Ord, PartialOrd, Hash, Debug, Clone, Copy)]
pub enum Scope {
    /// Full access: view, use, and manage Drive labels.
    DriveLabels,

    /// Read-only access: view and use Drive labels.
    DriveLabelsReadonly,

    /// Administrative access: view, edit, create, and delete all Drive labels
    /// in the organization, and view its label-related administration policies.
    DriveLabelsAdmin,

    /// Read-only administrative access: view all Drive labels and
    /// label-related administration policies in the organization.
    DriveLabelsAdminReadonly,
}

impl AsRef<str> for Scope {
    /// Return the OAuth scope URL associated with this variant.
    fn as_ref(&self) -> &str {
        match self {
            Self::DriveLabels => "https://www.googleapis.com/auth/drive.labels",
            Self::DriveLabelsReadonly => "https://www.googleapis.com/auth/drive.labels.readonly",
            Self::DriveLabelsAdmin => "https://www.googleapis.com/auth/drive.admin.labels",
            Self::DriveLabelsAdminReadonly => {
                "https://www.googleapis.com/auth/drive.admin.labels.readonly"
            }
        }
    }
}

#[allow(clippy::derivable_impls)]
impl Default for Scope {
    /// The least-privileged (read-only) scope is the default.
    fn default() -> Scope {
        Self::DriveLabelsReadonly
    }
}
|
||||
|
||||
/// Central hub for talking to the Drive Labels API: holds the HTTP client,
/// the token source, and per-hub settings (user agent, base URL).
#[derive(Clone)]
pub struct DriveLabelsHub<C> {
    pub client: common::Client<C>,
    pub auth: Box<dyn common::GetToken>,
    // Sent as the User-Agent header on every request.
    _user_agent: String,
    // Prefix for every request URL; see `base_url()` to override.
    _base_url: String,
}

impl<C> common::Hub for DriveLabelsHub<C> {}

impl<'a, C> DriveLabelsHub<C> {
    /// Build a hub from an HTTP client and any token source implementing
    /// [`common::GetToken`], using the default user agent and base URL.
    pub fn new<A: 'static + common::GetToken>(
        client: common::Client<C>,
        auth: A,
    ) -> DriveLabelsHub<C> {
        DriveLabelsHub {
            client,
            auth: Box::new(auth),
            _user_agent: "google-api-rust-client/6.0.0".to_string(),
            _base_url: "https://drivelabels.googleapis.com/".to_string(),
        }
    }

    /// Entry point for the label-related method builders.
    pub fn labels(&'a self) -> LabelMethods<'a, C> {
        LabelMethods { hub: self }
    }

    /// Set the user-agent header field to use in all requests to the server.
    /// It defaults to `google-api-rust-client/6.0.0`.
    ///
    /// Returns the previously set user-agent.
    pub fn user_agent(&mut self, agent_name: String) -> String {
        std::mem::replace(&mut self._user_agent, agent_name)
    }

    /// Set the base url to use in all requests to the server.
    /// It defaults to `https://drivelabels.googleapis.com/`.
    ///
    /// Returns the previously set base url.
    pub fn base_url(&mut self, new_base_url: String) -> String {
        std::mem::replace(&mut self._base_url, new_base_url)
    }
}
|
||||
|
||||
/// A Drive label as returned by the Drive Labels API (`labels.list`).
///
/// All fields are optional because the server may omit any of them in a
/// partial response.
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Label {
    /// Resource name of the label (e.g. `labels/{id}`).
    #[serde(rename = "name")]
    pub name: Option<String>,
    /// Globally unique identifier of the label.
    #[serde(rename = "id")]
    pub id: Option<String>,
    #[serde(rename = "revisionId")]
    pub revision_id: Option<String>,
    #[serde(rename = "labelType")]
    pub label_type: Option<String>,
    #[serde(rename = "creator")]
    pub creator: Option<User>,
    #[serde(rename = "createTime")]
    pub create_time: Option<String>,
    #[serde(rename = "revisionCreator")]
    pub revision_creator: Option<User>,
    #[serde(rename = "revisionCreateTime")]
    pub revision_create_time: Option<String>,
    #[serde(rename = "publisher")]
    pub publisher: Option<User>,
    #[serde(rename = "publishTime")]
    pub publish_time: Option<String>,
    #[serde(rename = "disabler")]
    pub disabler: Option<User>,
    #[serde(rename = "disableTime")]
    pub disable_time: Option<String>,
    #[serde(rename = "customer")]
    pub customer: Option<String>,
    /// Basic properties (title/description) of the label.
    pub properties: Option<LabelProperty>,
    /// Fields defined on this label.
    pub fields: Option<Vec<Field>>,
    // We ignore the remaining fields.
}

impl common::Part for Label {}

impl common::ResponseResult for Label {}
|
||||
|
||||
/// Basic display properties of a [`Label`]: title and description.
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct LabelProperty {
    pub title: Option<String>,
    pub description: Option<String>,
}

impl common::Part for LabelProperty {}

/// A single field defined on a label (e.g. a selection or text field).
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Field {
    id: Option<String>,
    // Key used when querying for files by this field.
    #[serde(rename = "queryKey")]
    query_key: Option<String>,
    properties: Option<FieldProperty>,
    // Present only for selection-type fields.
    #[serde(rename = "selectionOptions")]
    selection_options: Option<SelectionOption>,
}

impl common::Part for Field {}

/// Display properties of a [`Field`].
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct FieldProperty {
    #[serde(rename = "displayName")]
    pub display_name: Option<String>,
    pub required: Option<bool>,
}

impl common::Part for FieldProperty {}

/// Options for a selection-type field: list behavior plus the set of choices.
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct SelectionOption {
    #[serde(rename = "listOptions")]
    pub list_options: Option<String>,
    pub choices: Option<Vec<Choice>>,
}

impl common::Part for SelectionOption {}

/// One selectable choice within a selection field.
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Choice {
    id: Option<String>,
    properties: Option<ChoiceProperties>,
    // We ignore the remaining fields.
}

impl common::Part for Choice {}

/// Display properties of a [`Choice`].
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ChoiceProperties {
    #[serde(rename = "displayName")]
    display_name: Option<String>,
    description: Option<String>,
}

impl common::Part for ChoiceProperties {}
|
||||
|
||||
/// Response body of `labels.list`: one page of labels plus a token for the
/// next page, if any.
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct LabelList {
    pub labels: Option<Vec<Label>>,
    // Absent when this is the last page.
    #[serde(rename = "nextPageToken")]
    pub next_page_token: Option<String>,
}

impl common::ResponseResult for LabelList {}

/// Information about a Drive user.
///
/// This type is not used in any activity, and only used as *part* of another schema.
///
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct User {
    /// Output only. A plain text displayable name for this user.
    #[serde(rename = "displayName")]
    pub display_name: Option<String>,
    /// Output only. The email address of the user. This may not be present in certain contexts if the user has not made their email address visible to the requester.
    #[serde(rename = "emailAddress")]
    pub email_address: Option<String>,
    /// Output only. Identifies what kind of resource this is. Value: the fixed string `"drive#user"`.
    pub kind: Option<String>,
    /// Output only. Whether this user is the requesting user.
    pub me: Option<bool>,
    /// Output only. The user's ID as visible in Permission resources.
    #[serde(rename = "permissionId")]
    pub permission_id: Option<String>,
    /// Output only. A link to the user's profile photo, if available.
    #[serde(rename = "photoLink")]
    pub photo_link: Option<String>,
}

impl common::Part for User {}
|
||||
|
||||
/// Builder factory for all label-related API calls; obtained via
/// [`DriveLabelsHub::labels`].
pub struct LabelMethods<'a, C>
where
    C: 'a,
{
    hub: &'a DriveLabelsHub<C>,
}

impl<C> common::MethodsBuilder for LabelMethods<'_, C> {}

impl<'a, C> LabelMethods<'a, C> {
    /// Create a builder to help you perform the following tasks:
    ///
    /// List labels
    pub fn list(&self) -> LabelListCall<'a, C> {
        LabelListCall {
            hub: self.hub,
            _delegate: Default::default(),
            _additional_params: Default::default(),
            _scopes: Default::default(),
        }
    }
}
|
||||
|
||||
/// Lists the workspace's labels.
|
||||
pub struct LabelListCall<'a, C>
|
||||
where
|
||||
C: 'a,
|
||||
{
|
||||
hub: &'a DriveLabelsHub<C>,
|
||||
_delegate: Option<&'a mut dyn common::Delegate>,
|
||||
_additional_params: HashMap<String, String>,
|
||||
_scopes: BTreeSet<String>,
|
||||
}
|
||||
|
||||
impl<C> common::CallBuilder for LabelListCall<'_, C> {}
|
||||
|
||||
impl<'a, C> LabelListCall<'a, C>
|
||||
where
|
||||
C: common::Connector,
|
||||
{
|
||||
    /// Perform the operation you have built so far.
    ///
    /// Issues a GET to `{base_url}v2/labels`, authenticating with a token for
    /// the configured scopes, and decodes the JSON response into a
    /// [`LabelList`]. The delegate may request retries on HTTP errors or
    /// non-success statuses, which is why the request runs in a loop.
    pub async fn doit(mut self) -> common::Result<(common::Response, LabelList)> {
        use common::url::Params;
        use hyper::header::{AUTHORIZATION, CONTENT_LENGTH, USER_AGENT};

        // Use the caller-supplied delegate if any, otherwise a default one.
        let mut dd = common::DefaultDelegate;
        let dlg: &mut dyn common::Delegate = self._delegate.unwrap_or(&mut dd);
        dlg.begin(common::MethodInfo {
            id: "drivelabels.labels.list",
            http_method: hyper::Method::GET,
        });

        // Reject additional params that would clash with ones we set ourselves.
        for &field in ["alt"].iter() {
            if self._additional_params.contains_key(field) {
                dlg.finished(false);
                return Err(common::Error::FieldClash(field));
            }
        }

        // TODO: We don't handle any of the query params.
        let mut params = Params::with_capacity(2 + self._additional_params.len());

        params.extend(self._additional_params.iter());

        params.push("alt", "json");
        let url = self.hub._base_url.clone() + "v2/labels";

        // Fall back to the least-privileged scope when none was specified.
        if self._scopes.is_empty() {
            self._scopes
                .insert(Scope::DriveLabelsReadonly.as_ref().to_string());
        }

        let url = params.parse_with_url(&url);

        // Retry loop: the delegate decides (via Retry::After) whether a failed
        // attempt should be repeated after a delay.
        loop {
            // Obtain an access token; the delegate may recover from a failure.
            let token = match self
                .hub
                .auth
                .get_token(&self._scopes.iter().map(String::as_str).collect::<Vec<_>>()[..])
                .await
            {
                Ok(token) => token,
                Err(e) => match dlg.token(e) {
                    Ok(token) => token,
                    Err(e) => {
                        dlg.finished(false);
                        return Err(common::Error::MissingToken(e));
                    }
                },
            };
            let req_result = {
                let client = &self.hub.client;
                dlg.pre_request();
                let mut req_builder = hyper::Request::builder()
                    .method(hyper::Method::GET)
                    .uri(url.as_str())
                    .header(USER_AGENT, self.hub._user_agent.clone());

                // Token may be absent (e.g. API-key-only usage).
                if let Some(token) = token.as_ref() {
                    req_builder = req_builder.header(AUTHORIZATION, format!("Bearer {}", token));
                }

                // GET with an empty body.
                let request = req_builder
                    .header(CONTENT_LENGTH, 0_u64)
                    .body(common::to_body::<String>(None));
                client.request(request.unwrap()).await
            };

            match req_result {
                // Transport-level failure: ask the delegate whether to retry.
                Err(err) => {
                    if let common::Retry::After(d) = dlg.http_error(&err) {
                        sleep(d).await;
                        continue;
                    }
                    dlg.finished(false);
                    return Err(common::Error::HttpError(err));
                }
                Ok(res) => {
                    let (parts, body) = res.into_parts();
                    let body = common::Body::new(body);
                    // Non-2xx: try to decode a structured API error, let the
                    // delegate decide about a retry, otherwise fail.
                    if !parts.status.is_success() {
                        let bytes = common::to_bytes(body).await.unwrap_or_default();
                        let error = serde_json::from_str(&common::to_string(&bytes));
                        let response = common::to_response(parts, bytes.into());

                        if let common::Retry::After(d) =
                            dlg.http_failure(&response, error.as_ref().ok())
                        {
                            sleep(d).await;
                            continue;
                        }

                        dlg.finished(false);

                        return Err(match error {
                            Ok(value) => common::Error::BadRequest(value),
                            _ => common::Error::Failure(response),
                        });
                    }
                    // Success: decode the JSON body into the typed result.
                    let response = {
                        let bytes = common::to_bytes(body).await.unwrap_or_default();
                        let encoded = common::to_string(&bytes);
                        match serde_json::from_str(&encoded) {
                            Ok(decoded) => (common::to_response(parts, bytes.into()), decoded),
                            Err(error) => {
                                dlg.response_json_decode_error(&encoded, &error);
                                return Err(common::Error::JsonDecodeError(
                                    encoded.to_string(),
                                    error,
                                ));
                            }
                        }
                    };

                    dlg.finished(true);
                    return Ok(response);
                }
            }
        }
    }
|
||||
|
||||
/// The delegate implementation is consulted whenever there is an intermediate result, or if something goes wrong
|
||||
/// while executing the actual API request.
|
||||
///
|
||||
/// ````text
|
||||
/// It should be used to handle progress information, and to implement a certain level of resilience.
|
||||
/// ````
|
||||
///
|
||||
/// Sets the *delegate* property to the given value.
|
||||
pub fn delegate(mut self, new_value: &'a mut dyn common::Delegate) -> LabelListCall<'a, C> {
|
||||
self._delegate = Some(new_value);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set any additional parameter of the query string used in the request.
|
||||
/// It should be used to set parameters which are not yet available through their own
|
||||
/// setters.
|
||||
///
|
||||
/// Please note that this method must not be used to set any of the known parameters
|
||||
/// which have their own setter method. If done anyway, the request will fail.
|
||||
///
|
||||
/// # Additional Parameters
|
||||
///
|
||||
/// * *$.xgafv* (query-string) - V1 error format.
|
||||
/// * *access_token* (query-string) - OAuth access token.
|
||||
/// * *alt* (query-string) - Data format for response.
|
||||
/// * *callback* (query-string) - JSONP
|
||||
/// * *fields* (query-string) - Selector specifying which fields to include in a partial response.
|
||||
/// * *key* (query-string) - API key. Your API key identifies your project and provides you with API access, quota, and reports. Required unless you provide an OAuth 2.0 token.
|
||||
/// * *oauth_token* (query-string) - OAuth 2.0 token for the current user.
|
||||
/// * *prettyPrint* (query-boolean) - Returns response with indentations and line breaks.
|
||||
/// * *quotaUser* (query-string) - Available to use for quota purposes for server-side applications. Can be any arbitrary string assigned to a user, but should not exceed 40 characters.
|
||||
/// * *uploadType* (query-string) - Legacy upload protocol for media (e.g. "media", "multipart").
|
||||
/// * *upload_protocol* (query-string) - Upload protocol for media (e.g. "raw", "multipart").
|
||||
pub fn param<T>(mut self, name: T, value: T) -> LabelListCall<'a, C>
|
||||
where
|
||||
T: AsRef<str>,
|
||||
{
|
||||
self._additional_params
|
||||
.insert(name.as_ref().to_string(), value.as_ref().to_string());
|
||||
self
|
||||
}
|
||||
|
||||
/// Identifies the authorization scope for the method you are building.
|
||||
///
|
||||
/// Use this method to actively specify which scope should be used, instead of the default [`Scope`] variant
|
||||
/// [`Scope::DriveLabelsReadonly`].
|
||||
///
|
||||
/// The `scope` will be added to a set of scopes. This is important as one can maintain access
|
||||
/// tokens for more than one scope.
|
||||
///
|
||||
/// Usually there is more than one suitable scope to authorize an operation, some of which may
|
||||
/// encompass more rights than others. For example, for listing resources, a *read-only* scope will be
|
||||
/// sufficient, a read-write scope will do as well.
|
||||
pub fn add_scope<St>(mut self, scope: St) -> LabelListCall<'a, C>
|
||||
where
|
||||
St: AsRef<str>,
|
||||
{
|
||||
self._scopes.insert(String::from(scope.as_ref()));
|
||||
self
|
||||
}
|
||||
/// Identifies the authorization scope(s) for the method you are building.
|
||||
///
|
||||
/// See [`Self::add_scope()`] for details.
|
||||
pub fn add_scopes<I, St>(mut self, scopes: I) -> LabelListCall<'a, C>
|
||||
where
|
||||
I: IntoIterator<Item = St>,
|
||||
St: AsRef<str>,
|
||||
{
|
||||
self._scopes
|
||||
.extend(scopes.into_iter().map(|s| String::from(s.as_ref())));
|
||||
self
|
||||
}
|
||||
|
||||
/// Removes all scopes, and no default scope will be used either.
|
||||
/// In this case, you have to specify your API-key using the `key` parameter (see [`Self::param()`]
|
||||
/// for details).
|
||||
pub fn clear_scopes(mut self) -> LabelListCall<'a, C> {
|
||||
self._scopes.clear();
|
||||
self
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -46,6 +46,7 @@ struct TokenData {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
expires_at: Option<u64>,
|
||||
project_id: String,
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
/// PkceOAuth2Client implements the GetToken trait required by DriveHub
|
||||
@@ -56,6 +57,7 @@ pub struct PkceOAuth2Client {
|
||||
credentials_manager: Arc<CredentialsManager>,
|
||||
http_client: reqwest::Client,
|
||||
project_id: String,
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
impl PkceOAuth2Client {
|
||||
@@ -69,6 +71,7 @@ impl PkceOAuth2Client {
|
||||
|
||||
// Extract the project_id from the config
|
||||
let project_id = config.installed.project_id.clone();
|
||||
let scopes = vec![];
|
||||
|
||||
// Create OAuth URLs
|
||||
let auth_url =
|
||||
@@ -97,6 +100,7 @@ impl PkceOAuth2Client {
|
||||
credentials_manager,
|
||||
http_client,
|
||||
project_id,
|
||||
scopes,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -185,6 +189,7 @@ impl PkceOAuth2Client {
|
||||
refresh_token: refresh_token_str.clone(),
|
||||
expires_at,
|
||||
project_id: self.project_id.clone(),
|
||||
scopes: scopes.iter().map(|s| s.to_string()).collect(),
|
||||
};
|
||||
|
||||
// Store updated token data
|
||||
@@ -239,6 +244,7 @@ impl PkceOAuth2Client {
|
||||
refresh_token: new_refresh_token.clone(),
|
||||
expires_at,
|
||||
project_id: self.project_id.clone(),
|
||||
scopes: self.scopes.clone(),
|
||||
};
|
||||
|
||||
// Store updated token data
|
||||
@@ -315,37 +321,66 @@ impl GetToken for PkceOAuth2Client {
|
||||
if let Ok(token_data) = self.credentials_manager.read_credentials::<TokenData>() {
|
||||
// Verify the project_id matches
|
||||
if token_data.project_id == self.project_id {
|
||||
// Check if the token is expired or expiring within a 5-min buffer
|
||||
if !self.is_token_expired(token_data.expires_at, 300) {
|
||||
return Ok(Some(token_data.access_token));
|
||||
}
|
||||
// Convert stored scopes to &str slices for comparison
|
||||
let stored_scope_refs: Vec<&str> =
|
||||
token_data.scopes.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
// Token is expired or will expire soon, try to refresh it
|
||||
debug!("Token is expired or will expire soon, refreshing...");
|
||||
// Check if we need additional scopes
|
||||
let needs_additional_scopes = scopes.iter().any(|&scope| {
|
||||
!stored_scope_refs
|
||||
.iter()
|
||||
.any(|&stored| stored.contains(scope))
|
||||
});
|
||||
|
||||
// Try to refresh the token
|
||||
if let Ok(access_token) = self.refresh_token(&token_data.refresh_token).await {
|
||||
debug!("Successfully refreshed access token");
|
||||
return Ok(Some(access_token));
|
||||
if !needs_additional_scopes {
|
||||
// Check if the token is expired or expiring within a 5-min buffer
|
||||
if !self.is_token_expired(token_data.expires_at, 300) {
|
||||
return Ok(Some(token_data.access_token));
|
||||
}
|
||||
|
||||
// Token is expired or will expire soon, try to refresh it
|
||||
debug!("Token is expired or will expire soon, refreshing...");
|
||||
|
||||
// Try to refresh the token
|
||||
if let Ok(access_token) =
|
||||
self.refresh_token(&token_data.refresh_token).await
|
||||
{
|
||||
debug!("Successfully refreshed access token");
|
||||
return Ok(Some(access_token));
|
||||
}
|
||||
} else {
|
||||
// Only allocate new strings when we need to combine scopes
|
||||
let mut combined_scopes: Vec<&str> =
|
||||
Vec::with_capacity(scopes.len() + stored_scope_refs.len());
|
||||
combined_scopes.extend(scopes);
|
||||
combined_scopes.extend(stored_scope_refs.iter().filter(|&&stored| {
|
||||
!scopes.iter().any(|&scope| stored.contains(scope))
|
||||
}));
|
||||
|
||||
return self
|
||||
.perform_oauth_flow(&combined_scopes)
|
||||
.await
|
||||
.map(Some)
|
||||
.map_err(|e| {
|
||||
error!("OAuth flow failed: {}", e);
|
||||
e
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here, either:
|
||||
// 1. The project ID didn't match
|
||||
// 2. Token refresh failed
|
||||
// 3. There are no valid tokens yet
|
||||
// 4. We didn't have to change the scopes of an existing token
|
||||
// Fallback: perform interactive OAuth flow
|
||||
match self.perform_oauth_flow(scopes).await {
|
||||
Ok(token) => {
|
||||
debug!("Successfully obtained new access token through OAuth flow");
|
||||
Ok(Some(token))
|
||||
}
|
||||
Err(e) => {
|
||||
self.perform_oauth_flow(scopes)
|
||||
.await
|
||||
.map(Some)
|
||||
.map_err(|e| {
|
||||
error!("OAuth flow failed: {}", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
e
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
IMPORTANT: currently GOOSE_ALLOWLIST is used in main.ts in ui/desktop, and not in goose-server. The following is for reference in case it is used on the server side for launch time enforcement.
|
||||
|
||||
# Goose Extension Allowlist
|
||||
|
||||
The allowlist feature provides a security mechanism for controlling which MCP commands can be used by goose.
|
||||
@@ -24,9 +26,11 @@ If this environment variable is not set, no allowlist restrictions will be appli
|
||||
In certain development or testing scenarios, you may need to bypass the allowlist restrictions. You can do this by setting the `GOOSE_ALLOWLIST_BYPASS` environment variable to `true`:
|
||||
|
||||
```bash
|
||||
export GOOSE_ALLOWLIST_BYPASS=true
|
||||
# For the GUI, this makes the allowlist non-blocking: a warning is shown instead of the command being blocked (the warning is always displayed):
|
||||
export GOOSE_ALLOWLIST_WARNING=true
|
||||
```
|
||||
|
||||
|
||||
When this environment variable is set to `true` (case insensitive), the allowlist check will be bypassed and all commands will be allowed, even if the `GOOSE_ALLOWLIST` environment variable is set.
|
||||
|
||||
## Allowlist File Format
|
||||
|
||||
@@ -12,9 +12,10 @@ goose = { path = "../goose" }
|
||||
mcp-core = { path = "../mcp-core" }
|
||||
goose-mcp = { path = "../goose-mcp" }
|
||||
mcp-server = { path = "../mcp-server" }
|
||||
axum = { version = "0.7.2", features = ["ws", "macros"] }
|
||||
axum = { version = "0.8.1", features = ["ws", "macros"] }
|
||||
tokio = { version = "1.43", features = ["full"] }
|
||||
chrono = "0.4"
|
||||
tokio-cron-scheduler = "0.14.0"
|
||||
tower-http = { version = "0.5", features = ["cors"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
@@ -26,6 +27,7 @@ tokio-stream = "0.1"
|
||||
anyhow = "1.0"
|
||||
bytes = "1.5"
|
||||
http = "1.0"
|
||||
base64 = "0.21"
|
||||
config = { version = "0.14.1", features = ["toml"] }
|
||||
thiserror = "1.0"
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
@@ -33,7 +35,7 @@ once_cell = "1.20.2"
|
||||
etcetera = "0.8.0"
|
||||
serde_yaml = "0.9.34"
|
||||
axum-extra = "0.10.0"
|
||||
utoipa = { version = "4.1", features = ["axum_extras"] }
|
||||
utoipa = { version = "4.1", features = ["axum_extras", "chrono"] }
|
||||
dirs = "6.0.0"
|
||||
reqwest = { version = "0.12.9", features = ["json", "rustls-tls", "blocking"], default-features = false }
|
||||
|
||||
|
||||
@@ -3,7 +3,10 @@ use std::sync::Arc;
|
||||
use crate::configuration;
|
||||
use crate::state;
|
||||
use anyhow::Result;
|
||||
use etcetera::{choose_app_strategy, AppStrategy};
|
||||
use goose::agents::Agent;
|
||||
use goose::config::APP_STRATEGY;
|
||||
use goose::scheduler::Scheduler as GooseScheduler;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing::info;
|
||||
|
||||
@@ -11,27 +14,30 @@ pub async fn run() -> Result<()> {
|
||||
// Initialize logging
|
||||
crate::logging::setup_logging(Some("goosed"))?;
|
||||
|
||||
// Load configuration
|
||||
let settings = configuration::Settings::new()?;
|
||||
|
||||
// load secret key from GOOSE_SERVER__SECRET_KEY environment variable
|
||||
let secret_key =
|
||||
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
|
||||
|
||||
let new_agent = Agent::new();
|
||||
let agent_ref = Arc::new(new_agent);
|
||||
|
||||
// Create app state with agent
|
||||
let state = state::AppState::new(Arc::new(new_agent), secret_key.clone()).await;
|
||||
let app_state = state::AppState::new(agent_ref.clone(), secret_key.clone()).await;
|
||||
|
||||
let schedule_file_path = choose_app_strategy(APP_STRATEGY.clone())?
|
||||
.data_dir()
|
||||
.join("schedules.json");
|
||||
|
||||
let scheduler_instance = GooseScheduler::new(schedule_file_path).await?;
|
||||
app_state.set_scheduler(scheduler_instance).await;
|
||||
|
||||
// Create router with CORS support
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
let app = crate::routes::configure(state).layer(cors);
|
||||
let app = crate::routes::configure(app_state).layer(cors);
|
||||
|
||||
// Run server
|
||||
let listener = tokio::net::TcpListener::bind(settings.socket_addr()).await?;
|
||||
info!("listening on {}", listener.local_addr()?);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
@@ -8,6 +8,7 @@ use tracing_subscriber::{
|
||||
Registry,
|
||||
};
|
||||
|
||||
use goose::config::APP_STRATEGY;
|
||||
use goose::tracing::langfuse_layer;
|
||||
|
||||
/// Returns the directory where log files should be stored.
|
||||
@@ -17,8 +18,8 @@ fn get_log_directory() -> Result<PathBuf> {
|
||||
// - macOS/Linux: ~/.local/state/goose/logs/server
|
||||
// - Windows: ~\AppData\Roaming\Block\goose\data\logs\server
|
||||
// - Windows has no convention for state_dir, use data_dir instead
|
||||
let home_dir = choose_app_strategy(crate::APP_STRATEGY.clone())
|
||||
.context("HOME environment variable not set")?;
|
||||
let home_dir =
|
||||
choose_app_strategy(APP_STRATEGY.clone()).context("HOME environment variable not set")?;
|
||||
|
||||
let base_log_dir = home_dir
|
||||
.in_state_dir("logs/server")
|
||||
|
||||
@@ -1,12 +1,3 @@
|
||||
use etcetera::AppStrategyArgs;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
|
||||
top_level_domain: "Block".to_string(),
|
||||
author: "Block".to_string(),
|
||||
app_name: "goose".to_string(),
|
||||
});
|
||||
|
||||
mod commands;
|
||||
mod configuration;
|
||||
mod error;
|
||||
|
||||
@@ -3,8 +3,18 @@ use goose::agents::extension::ToolInfo;
|
||||
use goose::agents::ExtensionConfig;
|
||||
use goose::config::permission::PermissionLevel;
|
||||
use goose::config::ExtensionEntry;
|
||||
use goose::message::{
|
||||
ContextLengthExceeded, FrontendToolRequest, Message, MessageContent, RedactedThinkingContent,
|
||||
SummarizationRequested, ThinkingContent, ToolConfirmationRequest, ToolRequest, ToolResponse,
|
||||
};
|
||||
use goose::permission::permission_confirmation::PrincipalType;
|
||||
use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata};
|
||||
use goose::session::info::SessionInfo;
|
||||
use goose::session::SessionMetadata;
|
||||
use mcp_core::content::{Annotations, Content, EmbeddedResource, ImageContent, TextContent};
|
||||
use mcp_core::handler::ToolResultSchema;
|
||||
use mcp_core::resource::ResourceContents;
|
||||
use mcp_core::role::Role;
|
||||
use mcp_core::tool::{Tool, ToolAnnotations};
|
||||
use utoipa::OpenApi;
|
||||
|
||||
@@ -25,6 +35,17 @@ use utoipa::OpenApi;
|
||||
super::routes::config_management::upsert_permissions,
|
||||
super::routes::agent::get_tools,
|
||||
super::routes::reply::confirm_permission,
|
||||
super::routes::context::manage_context,
|
||||
super::routes::session::list_sessions,
|
||||
super::routes::session::get_session_history,
|
||||
super::routes::schedule::create_schedule,
|
||||
super::routes::schedule::list_schedules,
|
||||
super::routes::schedule::delete_schedule,
|
||||
super::routes::schedule::update_schedule,
|
||||
super::routes::schedule::run_now_handler,
|
||||
super::routes::schedule::pause_schedule,
|
||||
super::routes::schedule::unpause_schedule,
|
||||
super::routes::schedule::sessions_handler
|
||||
),
|
||||
components(schemas(
|
||||
super::routes::config_management::UpsertConfigQuery,
|
||||
@@ -37,6 +58,28 @@ use utoipa::OpenApi;
|
||||
super::routes::config_management::ToolPermission,
|
||||
super::routes::config_management::UpsertPermissionsQuery,
|
||||
super::routes::reply::PermissionConfirmationRequest,
|
||||
super::routes::context::ContextManageRequest,
|
||||
super::routes::context::ContextManageResponse,
|
||||
super::routes::session::SessionListResponse,
|
||||
super::routes::session::SessionHistoryResponse,
|
||||
Message,
|
||||
MessageContent,
|
||||
Content,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
Annotations,
|
||||
TextContent,
|
||||
ToolResponse,
|
||||
ToolRequest,
|
||||
ToolResultSchema,
|
||||
ToolConfirmationRequest,
|
||||
ThinkingContent,
|
||||
RedactedThinkingContent,
|
||||
FrontendToolRequest,
|
||||
ResourceContents,
|
||||
ContextLengthExceeded,
|
||||
SummarizationRequested,
|
||||
Role,
|
||||
ProviderMetadata,
|
||||
ExtensionEntry,
|
||||
ExtensionConfig,
|
||||
@@ -48,6 +91,15 @@ use utoipa::OpenApi;
|
||||
PermissionLevel,
|
||||
PrincipalType,
|
||||
ModelInfo,
|
||||
SessionInfo,
|
||||
SessionMetadata,
|
||||
super::routes::schedule::CreateScheduleRequest,
|
||||
super::routes::schedule::UpdateScheduleRequest,
|
||||
goose::scheduler::ScheduledJob,
|
||||
super::routes::schedule::RunNowResponse,
|
||||
super::routes::schedule::ListSchedulesResponse,
|
||||
super::routes::schedule::SessionsQuery,
|
||||
super::routes::schedule::SessionDisplayInfo,
|
||||
))
|
||||
)]
|
||||
pub struct ApiDoc;
|
||||
|
||||
@@ -6,15 +6,16 @@ use axum::{
|
||||
routing::{delete, get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs};
|
||||
use etcetera::{choose_app_strategy, AppStrategy};
|
||||
use goose::config::Config;
|
||||
use goose::config::APP_STRATEGY;
|
||||
use goose::config::{extensions::name_to_key, PermissionManager};
|
||||
use goose::config::{ExtensionConfigManager, ExtensionEntry};
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::ProviderMetadata;
|
||||
use goose::providers::providers as get_providers;
|
||||
use goose::{agents::ExtensionConfig, config::permission::PermissionLevel};
|
||||
use http::{HeaderMap, StatusCode};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_yaml;
|
||||
@@ -51,14 +52,12 @@ pub struct ConfigResponse {
|
||||
pub config: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
// Define a new structure to encapsulate the provider details along with configuration status
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ProviderDetails {
|
||||
/// Unique identifier and name of the provider
|
||||
pub name: String,
|
||||
/// Metadata about the provider
|
||||
|
||||
pub metadata: ProviderMetadata,
|
||||
/// Indicates whether the provider is fully configured
|
||||
|
||||
pub is_configured: bool,
|
||||
}
|
||||
|
||||
@@ -69,7 +68,6 @@ pub struct ProvidersResponse {
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ToolPermission {
|
||||
/// Unique identifier and name of the tool, format <extension_name>__<tool_name>
|
||||
pub tool_name: String,
|
||||
pub permission: PermissionLevel,
|
||||
}
|
||||
@@ -93,7 +91,6 @@ pub async fn upsert_config(
|
||||
headers: HeaderMap,
|
||||
Json(query): Json<UpsertConfigQuery>,
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
// Use the helper function to verify the secret key
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let config = Config::global();
|
||||
@@ -120,12 +117,10 @@ pub async fn remove_config(
|
||||
headers: HeaderMap,
|
||||
Json(query): Json<ConfigKeyQuery>,
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
// Use the helper function to verify the secret key
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let config = Config::global();
|
||||
|
||||
// Check if the secret flag is true and call the appropriate method
|
||||
let result = if query.is_secret {
|
||||
config.delete_secret(&query.key)
|
||||
} else {
|
||||
@@ -141,7 +136,7 @@ pub async fn remove_config(
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/config/read",
|
||||
request_body = ConfigKeyQuery, // Switch back to request_body
|
||||
request_body = ConfigKeyQuery,
|
||||
responses(
|
||||
(status = 200, description = "Configuration value retrieved successfully", body = Value),
|
||||
(status = 404, description = "Configuration key not found")
|
||||
@@ -154,16 +149,20 @@ pub async fn read_config(
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
if query.key == "model-limits" {
|
||||
let limits = ModelConfig::get_all_model_limits();
|
||||
return Ok(Json(
|
||||
serde_json::to_value(limits).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
|
||||
));
|
||||
}
|
||||
|
||||
let config = Config::global();
|
||||
|
||||
match config.get(&query.key, query.is_secret) {
|
||||
// Always get the actual value
|
||||
Ok(value) => {
|
||||
if query.is_secret {
|
||||
// If it's marked as secret, return a boolean indicating presence
|
||||
Ok(Json(Value::Bool(true)))
|
||||
} else {
|
||||
// Return the actual value if not secret
|
||||
Ok(Json(value))
|
||||
}
|
||||
}
|
||||
@@ -188,7 +187,6 @@ pub async fn get_extensions(
|
||||
match ExtensionConfigManager::get_all() {
|
||||
Ok(extensions) => Ok(Json(ExtensionResponse { extensions })),
|
||||
Err(err) => {
|
||||
// Return UNPROCESSABLE_ENTITY only for DeserializeError, INTERNAL_SERVER_ERROR for everything else
|
||||
if err
|
||||
.downcast_ref::<goose::config::base::ConfigError>()
|
||||
.is_some_and(|e| matches!(e, goose::config::base::ConfigError::DeserializeError(_)))
|
||||
@@ -219,7 +217,6 @@ pub async fn add_extension(
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Get existing extensions to check if this is an update
|
||||
let extensions =
|
||||
ExtensionConfigManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let key = name_to_key(&extension_query.name);
|
||||
@@ -275,12 +272,10 @@ pub async fn read_all_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<ConfigResponse>, StatusCode> {
|
||||
// Use the helper function to verify the secret key
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let config = Config::global();
|
||||
|
||||
// Load values from config file
|
||||
let values = config
|
||||
.load_values()
|
||||
.map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
|
||||
@@ -288,7 +283,6 @@ pub async fn read_all_config(
|
||||
Ok(Json(ConfigResponse { config: values }))
|
||||
}
|
||||
|
||||
// Modified providers function using the new response type
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/config/providers",
|
||||
@@ -302,14 +296,11 @@ pub async fn providers(
|
||||
) -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Fetch the list of providers, which are likely stored in the AppState or can be retrieved via a function call
|
||||
let providers_metadata = get_providers();
|
||||
|
||||
// Construct the response by checking configuration status for each provider
|
||||
let providers_response: Vec<ProviderDetails> = providers_metadata
|
||||
.into_iter()
|
||||
.map(|metadata| {
|
||||
// Check if the provider is configured (this will depend on how you track configuration status)
|
||||
let is_configured = check_provider_configured(&metadata);
|
||||
|
||||
ProviderDetails {
|
||||
@@ -339,21 +330,16 @@ pub async fn init_config(
|
||||
|
||||
let config = Config::global();
|
||||
|
||||
// 200 if config already exists
|
||||
if config.exists() {
|
||||
return Ok(Json("Config already exists".to_string()));
|
||||
}
|
||||
|
||||
// Find the workspace root (where the top-level Cargo.toml with [workspace] is)
|
||||
let workspace_root = match std::env::current_exe() {
|
||||
Ok(mut exe_path) => {
|
||||
// Start from the executable's directory and traverse up
|
||||
while let Some(parent) = exe_path.parent() {
|
||||
let cargo_toml = parent.join("Cargo.toml");
|
||||
if cargo_toml.exists() {
|
||||
// Read the Cargo.toml file
|
||||
if let Ok(content) = std::fs::read_to_string(&cargo_toml) {
|
||||
// Check if it contains [workspace]
|
||||
if content.contains("[workspace]") {
|
||||
exe_path = parent.to_path_buf();
|
||||
break;
|
||||
@@ -367,7 +353,6 @@ pub async fn init_config(
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
};
|
||||
|
||||
// Check if init-config.yaml exists at workspace root
|
||||
let init_config_path = workspace_root.join("init-config.yaml");
|
||||
if !init_config_path.exists() {
|
||||
return Ok(Json(
|
||||
@@ -375,7 +360,6 @@ pub async fn init_config(
|
||||
));
|
||||
}
|
||||
|
||||
// Read init-config.yaml and validate
|
||||
let init_content = match std::fs::read_to_string(&init_config_path) {
|
||||
Ok(content) => content,
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
@@ -385,7 +369,6 @@ pub async fn init_config(
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
};
|
||||
|
||||
// Save init-config.yaml to ~/.config/goose/config.yaml
|
||||
match config.save_values(init_values) {
|
||||
Ok(_) => Ok(Json("Config initialized successfully".to_string())),
|
||||
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
@@ -409,7 +392,7 @@ pub async fn upsert_permissions(
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let mut permission_manager = PermissionManager::default();
|
||||
// Iterate over each tool permission and update permissions
|
||||
|
||||
for tool_permission in &query.tool_permissions {
|
||||
permission_manager.update_user_permission(
|
||||
&tool_permission.tool_name,
|
||||
@@ -420,12 +403,6 @@ pub async fn upsert_permissions(
|
||||
Ok(Json("Permissions updated successfully".to_string()))
|
||||
}
|
||||
|
||||
pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
|
||||
top_level_domain: "Block".to_string(),
|
||||
author: "Block".to_string(),
|
||||
app_name: "goose".to_string(),
|
||||
});
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/config/backup",
|
||||
@@ -451,11 +428,9 @@ pub async fn backup_config(
|
||||
.file_name()
|
||||
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
// Append ".bak" to the file name
|
||||
let mut backup_name = file_name.to_os_string();
|
||||
backup_name.push(".bak");
|
||||
|
||||
// Construct the new path with the same parent directory
|
||||
let backup = config_path.with_file_name(backup_name);
|
||||
match std::fs::rename(&config_path, &backup) {
|
||||
Ok(_) => Ok(Json(format!("Moved {:?} to {:?}", config_path, backup))),
|
||||
@@ -474,10 +449,55 @@ pub fn routes(state: Arc<AppState>) -> Router {
|
||||
.route("/config/read", post(read_config))
|
||||
.route("/config/extensions", get(get_extensions))
|
||||
.route("/config/extensions", post(add_extension))
|
||||
.route("/config/extensions/:name", delete(remove_extension))
|
||||
.route("/config/extensions/{name}", delete(remove_extension))
|
||||
.route("/config/providers", get(providers))
|
||||
.route("/config/init", post(init_config))
|
||||
.route("/config/backup", post(backup_config))
|
||||
.route("/config/permissions", post(upsert_permissions))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_model_limits() {
|
||||
let test_state = AppState::new(
|
||||
Arc::new(goose::agents::Agent::default()),
|
||||
"test".to_string(),
|
||||
)
|
||||
.await;
|
||||
let sched_storage_path = choose_app_strategy(APP_STRATEGY.clone())
|
||||
.unwrap()
|
||||
.data_dir()
|
||||
.join("schedules.json");
|
||||
let sched = goose::scheduler::Scheduler::new(sched_storage_path)
|
||||
.await
|
||||
.unwrap();
|
||||
test_state.set_scheduler(sched).await;
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Secret-Key", "test".parse().unwrap());
|
||||
|
||||
let result = read_config(
|
||||
State(test_state),
|
||||
headers,
|
||||
Json(ConfigKeyQuery {
|
||||
key: "model-limits".to_string(),
|
||||
is_secret: false,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let response = result.unwrap();
|
||||
|
||||
let limits: Vec<goose::model::ModelLimitConfig> =
|
||||
serde_json::from_value(response.0).unwrap();
|
||||
assert!(!limits.is_empty());
|
||||
|
||||
let gpt4_limit = limits.iter().find(|l| l.pattern == "gpt-4o");
|
||||
assert!(gpt4_limit.is_some());
|
||||
assert_eq!(gpt4_limit.unwrap().context_limit, 128_000);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,229 +0,0 @@
|
||||
use super::utils::verify_secret_key;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
routing::{delete, get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use goose::config::Config;
|
||||
use http::{HeaderMap, StatusCode};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ConfigResponse {
|
||||
error: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct ConfigRequest {
|
||||
key: String,
|
||||
value: String,
|
||||
is_secret: bool,
|
||||
}
|
||||
|
||||
async fn store_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<ConfigRequest>,
|
||||
) -> Result<Json<ConfigResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let config = Config::global();
|
||||
let result = if request.is_secret {
|
||||
config.set_secret(&request.key, Value::String(request.value))
|
||||
} else {
|
||||
config.set_param(&request.key, Value::String(request.value))
|
||||
};
|
||||
match result {
|
||||
Ok(_) => Ok(Json(ConfigResponse { error: false })),
|
||||
Err(_) => Ok(Json(ConfigResponse { error: true })),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ProviderConfigRequest {
|
||||
pub providers: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ConfigStatus {
|
||||
pub is_set: bool,
|
||||
pub location: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ProviderResponse {
|
||||
pub supported: bool,
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub models: Option<Vec<String>>,
|
||||
pub config_status: HashMap<String, ConfigStatus>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ProviderConfig {
|
||||
name: String,
|
||||
description: String,
|
||||
models: Vec<String>,
|
||||
required_keys: Vec<String>,
|
||||
}
|
||||
|
||||
static PROVIDER_ENV_REQUIREMENTS: Lazy<HashMap<String, ProviderConfig>> = Lazy::new(|| {
|
||||
let contents = include_str!("providers_and_keys.json");
|
||||
serde_json::from_str(contents).expect("Failed to parse providers_and_keys.json")
|
||||
});
|
||||
|
||||
fn check_key_status(config: &Config, key: &str) -> (bool, Option<String>) {
|
||||
if let Ok(_value) = std::env::var(key) {
|
||||
(true, Some("env".to_string()))
|
||||
} else if config.get_param::<String>(key).is_ok() {
|
||||
(true, Some("yaml".to_string()))
|
||||
} else if config.get_secret::<String>(key).is_ok() {
|
||||
(true, Some("keyring".to_string()))
|
||||
} else {
|
||||
(false, None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_provider_configs(
|
||||
Json(request): Json<ProviderConfigRequest>,
|
||||
) -> Result<Json<HashMap<String, ProviderResponse>>, StatusCode> {
|
||||
let mut response = HashMap::new();
|
||||
let config = Config::global();
|
||||
|
||||
for provider_name in request.providers {
|
||||
if let Some(provider_config) = PROVIDER_ENV_REQUIREMENTS.get(&provider_name) {
|
||||
let mut config_status = HashMap::new();
|
||||
|
||||
for key in &provider_config.required_keys {
|
||||
let (key_set, key_location) = check_key_status(config, key);
|
||||
config_status.insert(
|
||||
key.to_string(),
|
||||
ConfigStatus {
|
||||
is_set: key_set,
|
||||
location: key_location,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
response.insert(
|
||||
provider_name,
|
||||
ProviderResponse {
|
||||
supported: true,
|
||||
name: Some(provider_config.name.clone()),
|
||||
description: Some(provider_config.description.clone()),
|
||||
models: Some(provider_config.models.clone()),
|
||||
config_status,
|
||||
},
|
||||
);
|
||||
} else {
|
||||
response.insert(
|
||||
provider_name,
|
||||
ProviderResponse {
|
||||
supported: false,
|
||||
name: None,
|
||||
description: None,
|
||||
models: None,
|
||||
config_status: HashMap::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct GetConfigQuery {
|
||||
key: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GetConfigResponse {
|
||||
value: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn get_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Query(query): Query<GetConfigQuery>,
|
||||
) -> Result<Json<GetConfigResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Fetch the configuration value. Right now we don't allow get a secret.
|
||||
let config = Config::global();
|
||||
let value = if let Ok(config_value) = config.get_param::<String>(&query.key) {
|
||||
Some(config_value)
|
||||
} else {
|
||||
std::env::var(&query.key).ok()
|
||||
};
|
||||
|
||||
// Return the value
|
||||
Ok(Json(GetConfigResponse { value }))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct DeleteConfigRequest {
|
||||
key: String,
|
||||
is_secret: bool,
|
||||
}
|
||||
|
||||
async fn delete_config(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<DeleteConfigRequest>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Attempt to delete the key
|
||||
let config = Config::global();
|
||||
let result = if request.is_secret {
|
||||
config.delete_secret(&request.key)
|
||||
} else {
|
||||
config.delete(&request.key)
|
||||
};
|
||||
match result {
|
||||
Ok(_) => Ok(StatusCode::NO_CONTENT),
|
||||
Err(_) => Err(StatusCode::NOT_FOUND),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/configs/providers", post(check_provider_configs))
|
||||
.route("/configs/get", get(get_config))
|
||||
.route("/configs/store", post(store_config))
|
||||
.route("/configs/delete", delete(delete_config))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unsupported_provider() {
|
||||
// Setup
|
||||
let request = ProviderConfigRequest {
|
||||
providers: vec!["unsupported_provider".to_string()],
|
||||
};
|
||||
|
||||
// Execute
|
||||
let result = check_provider_configs(Json(request)).await;
|
||||
|
||||
// Assert
|
||||
assert!(result.is_ok());
|
||||
let Json(response) = result.unwrap();
|
||||
|
||||
let provider_status = response
|
||||
.get("unsupported_provider")
|
||||
.expect("Provider should exist");
|
||||
assert!(!provider_status.supported);
|
||||
assert!(provider_status.config_status.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -9,23 +9,43 @@ use axum::{
|
||||
use goose::message::Message;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
// Direct message serialization for context mgmt request
|
||||
#[derive(Debug, Deserialize)]
|
||||
/// Request payload for context management operations
|
||||
#[derive(Debug, Deserialize, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ContextManageRequest {
|
||||
messages: Vec<Message>,
|
||||
manage_action: String,
|
||||
/// Collection of messages to be managed
|
||||
pub messages: Vec<Message>,
|
||||
/// Operation to perform: "truncation" or "summarize"
|
||||
pub manage_action: String,
|
||||
}
|
||||
|
||||
// Direct message serialization for context mgmt request
|
||||
#[derive(Debug, Serialize)]
|
||||
/// Response from context management operations
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ContextManageResponse {
|
||||
messages: Vec<Message>,
|
||||
token_counts: Vec<usize>,
|
||||
/// Processed messages after the operation
|
||||
pub messages: Vec<Message>,
|
||||
/// Token counts for each processed message
|
||||
pub token_counts: Vec<usize>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/context/manage",
|
||||
request_body = ContextManageRequest,
|
||||
responses(
|
||||
(status = 200, description = "Context managed successfully", body = ContextManageResponse),
|
||||
(status = 401, description = "Unauthorized - Invalid or missing API key"),
|
||||
(status = 412, description = "Precondition failed - Agent not available"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(
|
||||
("api_key" = [])
|
||||
),
|
||||
tag = "Context Management"
|
||||
)]
|
||||
async fn manage_context(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
@@ -40,7 +60,8 @@ async fn manage_context(
|
||||
|
||||
let mut processed_messages: Vec<Message> = vec![];
|
||||
let mut token_counts: Vec<usize> = vec![];
|
||||
if request.manage_action == "trunction" {
|
||||
|
||||
if request.manage_action == "truncation" {
|
||||
(processed_messages, token_counts) = agent
|
||||
.truncate_context(&request.messages)
|
||||
.await
|
||||
|
||||
@@ -268,12 +268,16 @@ async fn remove_extension(
|
||||
.get_agent()
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
agent.remove_extension(&name).await;
|
||||
|
||||
Ok(Json(ExtensionResponse {
|
||||
error: false,
|
||||
message: None,
|
||||
}))
|
||||
match agent.remove_extension(&name).await {
|
||||
Ok(_) => Ok(Json(ExtensionResponse {
|
||||
error: false,
|
||||
message: None,
|
||||
})),
|
||||
Err(e) => Ok(Json(ExtensionResponse {
|
||||
error: true,
|
||||
message: Some(format!("Failed to remove extension: {:?}", e)),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers the extension management routes with the Axum router.
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
// Export route modules
|
||||
pub mod agent;
|
||||
pub mod config_management;
|
||||
pub mod configs;
|
||||
pub mod context;
|
||||
pub mod extension;
|
||||
pub mod health;
|
||||
pub mod recipe;
|
||||
pub mod reply;
|
||||
pub mod schedule;
|
||||
pub mod session;
|
||||
pub mod utils;
|
||||
use std::sync::Arc;
|
||||
@@ -21,8 +21,8 @@ pub fn configure(state: Arc<crate::state::AppState>) -> Router {
|
||||
.merge(agent::routes(state.clone()))
|
||||
.merge(context::routes(state.clone()))
|
||||
.merge(extension::routes(state.clone()))
|
||||
.merge(configs::routes(state.clone()))
|
||||
.merge(config_management::routes(state.clone()))
|
||||
.merge(recipe::routes(state.clone()))
|
||||
.merge(session::routes(state.clone()))
|
||||
.merge(schedule::routes(state.clone()))
|
||||
}
|
||||
|
||||
@@ -49,9 +49,9 @@
|
||||
},
|
||||
"azure_openai": {
|
||||
"name": "Azure OpenAI",
|
||||
"description": "Connect to Azure OpenAI Service",
|
||||
"description": "Connect to Azure OpenAI Service. If no API key is provided, Azure credential chain will be used.",
|
||||
"models": ["gpt-4o", "gpt-4o-mini"],
|
||||
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
|
||||
"required_keys": ["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
|
||||
},
|
||||
"aws_bedrock": {
|
||||
"name": "AWS Bedrock",
|
||||
|
||||
@@ -35,7 +35,6 @@ use tokio::time::timeout;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
// Direct message serialization for the chat request
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatRequest {
|
||||
messages: Vec<Message>,
|
||||
@@ -43,7 +42,6 @@ struct ChatRequest {
|
||||
session_working_dir: String,
|
||||
}
|
||||
|
||||
// Custom SSE response type for streaming messages
|
||||
pub struct SseResponse {
|
||||
rx: ReceiverStream<String>,
|
||||
}
|
||||
@@ -78,7 +76,6 @@ impl IntoResponse for SseResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// Message event types for SSE streaming
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum MessageEvent {
|
||||
@@ -87,7 +84,6 @@ enum MessageEvent {
|
||||
Finish { reason: String },
|
||||
}
|
||||
|
||||
// Stream a message as an SSE event
|
||||
async fn stream_event(
|
||||
event: MessageEvent,
|
||||
tx: &mpsc::Sender<String>,
|
||||
@@ -108,19 +104,16 @@ async fn handler(
|
||||
) -> Result<SseResponse, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Create channel for streaming
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
let stream = ReceiverStream::new(rx);
|
||||
|
||||
let messages = request.messages;
|
||||
let session_working_dir = request.session_working_dir;
|
||||
|
||||
// Generate a new session ID if not provided in the request
|
||||
let session_id = request
|
||||
.session_id
|
||||
.unwrap_or_else(session::generate_session_id);
|
||||
|
||||
// Spawn task to handle streaming
|
||||
tokio::spawn(async move {
|
||||
let agent = state.get_agent().await;
|
||||
let agent = match agent {
|
||||
@@ -166,7 +159,6 @@ async fn handler(
|
||||
}
|
||||
};
|
||||
|
||||
// Get the provider first, before starting the reply stream
|
||||
let provider = agent.provider().await;
|
||||
|
||||
let mut stream = match agent
|
||||
@@ -175,6 +167,7 @@ async fn handler(
|
||||
Some(SessionConfig {
|
||||
id: session::Identifier::Name(session_id.clone()),
|
||||
working_dir: PathBuf::from(session_working_dir),
|
||||
schedule_id: None,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
@@ -200,7 +193,6 @@ async fn handler(
|
||||
}
|
||||
};
|
||||
|
||||
// Collect all messages for storage
|
||||
let mut all_messages = messages.clone();
|
||||
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
|
||||
|
||||
@@ -221,7 +213,7 @@ async fn handler(
|
||||
break;
|
||||
}
|
||||
|
||||
// Store messages and generate description in background
|
||||
|
||||
let session_path = session_path.clone();
|
||||
let messages = all_messages.clone();
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
@@ -255,7 +247,6 @@ async fn handler(
|
||||
}
|
||||
}
|
||||
|
||||
// Send finish event
|
||||
let _ = stream_event(
|
||||
MessageEvent::Finish {
|
||||
reason: "stop".to_string(),
|
||||
@@ -280,7 +271,6 @@ struct AskResponse {
|
||||
response: String,
|
||||
}
|
||||
|
||||
// Simple ask an AI for a response, non streaming
|
||||
async fn ask_handler(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
@@ -290,7 +280,6 @@ async fn ask_handler(
|
||||
|
||||
let session_working_dir = request.session_working_dir;
|
||||
|
||||
// Generate a new session ID if not provided in the request
|
||||
let session_id = request
|
||||
.session_id
|
||||
.unwrap_or_else(session::generate_session_id);
|
||||
@@ -300,13 +289,10 @@ async fn ask_handler(
|
||||
.await
|
||||
.map_err(|_| StatusCode::PRECONDITION_FAILED)?;
|
||||
|
||||
// Get the provider first, before starting the reply stream
|
||||
let provider = agent.provider().await;
|
||||
|
||||
// Create a single message for the prompt
|
||||
let messages = vec![Message::user().with_text(request.prompt)];
|
||||
|
||||
// Get response from agent
|
||||
let mut response_text = String::new();
|
||||
let mut stream = match agent
|
||||
.reply(
|
||||
@@ -314,6 +300,7 @@ async fn ask_handler(
|
||||
Some(SessionConfig {
|
||||
id: session::Identifier::Name(session_id.clone()),
|
||||
working_dir: PathBuf::from(session_working_dir),
|
||||
schedule_id: None,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
@@ -325,7 +312,6 @@ async fn ask_handler(
|
||||
}
|
||||
};
|
||||
|
||||
// Collect all messages for storage
|
||||
let mut all_messages = messages.clone();
|
||||
let mut response_message = Message::assistant();
|
||||
|
||||
@@ -349,15 +335,12 @@ async fn ask_handler(
|
||||
}
|
||||
}
|
||||
|
||||
// Add the complete response message to the conversation history
|
||||
if !response_message.content.is_empty() {
|
||||
all_messages.push(response_message);
|
||||
}
|
||||
|
||||
// Get the session path - file will be created when needed
|
||||
let session_path = session::get_path(session::Identifier::Name(session_id.clone()));
|
||||
|
||||
// Store messages and generate description in background
|
||||
let session_path = session_path.clone();
|
||||
let messages = all_messages.clone();
|
||||
let provider = Arc::clone(provider.as_ref().unwrap());
|
||||
@@ -438,13 +421,11 @@ async fn submit_tool_result(
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Log the raw request for debugging
|
||||
tracing::info!(
|
||||
"Received tool result request: {}",
|
||||
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||
);
|
||||
|
||||
// Try to parse into our struct
|
||||
let payload: ToolResultRequest = match serde_json::from_value(raw.0.clone()) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
@@ -465,7 +446,6 @@ async fn submit_tool_result(
|
||||
Ok(Json(json!({"status": "ok"})))
|
||||
}
|
||||
|
||||
// Configure routes for this module
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/reply", post(handler))
|
||||
@@ -488,7 +468,6 @@ mod tests {
|
||||
};
|
||||
use mcp_core::tool::Tool;
|
||||
|
||||
// Mock Provider implementation for testing
|
||||
#[derive(Clone)]
|
||||
struct MockProvider {
|
||||
model_config: ModelConfig,
|
||||
@@ -523,10 +502,8 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
use tower::ServiceExt;
|
||||
|
||||
// This test requires tokio runtime
|
||||
#[tokio::test]
|
||||
async fn test_ask_endpoint() {
|
||||
// Create a mock app state with mock provider
|
||||
let mock_model_config = ModelConfig::new("test-model".to_string());
|
||||
let mock_provider = Arc::new(MockProvider {
|
||||
model_config: mock_model_config,
|
||||
@@ -534,11 +511,15 @@ mod tests {
|
||||
let agent = Agent::new();
|
||||
let _ = agent.update_provider(mock_provider).await;
|
||||
let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await;
|
||||
let scheduler_path = goose::scheduler::get_default_scheduler_storage_path()
|
||||
.expect("Failed to get default scheduler storage path");
|
||||
let scheduler = goose::scheduler::Scheduler::new(scheduler_path)
|
||||
.await
|
||||
.unwrap();
|
||||
state.set_scheduler(scheduler).await;
|
||||
|
||||
// Build router
|
||||
let app = routes(state);
|
||||
|
||||
// Create request
|
||||
let request = Request::builder()
|
||||
.uri("/ask")
|
||||
.method("POST")
|
||||
@@ -554,10 +535,8 @@ mod tests {
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// Send request
|
||||
let response = app.oneshot(request).await.unwrap();
|
||||
|
||||
// Assert response status
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
}
|
||||
}
|
||||
|
||||
403
crates/goose-server/src/routes/schedule.rs
Normal file
403
crates/goose-server/src/routes/schedule.rs
Normal file
@@ -0,0 +1,403 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
routing::{delete, get, post, put},
|
||||
Json, Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use chrono::NaiveDateTime;
|
||||
|
||||
use crate::routes::utils::verify_secret_key;
|
||||
use crate::state::AppState;
|
||||
use goose::scheduler::ScheduledJob;
|
||||
|
||||
#[derive(Deserialize, Serialize, utoipa::ToSchema)]
|
||||
pub struct CreateScheduleRequest {
|
||||
id: String,
|
||||
recipe_source: String,
|
||||
cron: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, utoipa::ToSchema)]
|
||||
pub struct UpdateScheduleRequest {
|
||||
cron: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, utoipa::ToSchema)]
|
||||
pub struct ListSchedulesResponse {
|
||||
jobs: Vec<ScheduledJob>,
|
||||
}
|
||||
|
||||
// Response for the run_now endpoint
|
||||
#[derive(Serialize, utoipa::ToSchema)]
|
||||
pub struct RunNowResponse {
|
||||
session_id: String,
|
||||
}
|
||||
|
||||
// Query parameters for the sessions endpoint
|
||||
#[derive(Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
|
||||
pub struct SessionsQuery {
|
||||
#[serde(default = "default_limit")]
|
||||
limit: u32,
|
||||
}
|
||||
|
||||
fn default_limit() -> u32 {
|
||||
50 // Default limit for sessions listed
|
||||
}
|
||||
|
||||
// Struct for the frontend session list
|
||||
#[derive(Serialize, utoipa::ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionDisplayInfo {
|
||||
id: String, // Derived from session_name (filename)
|
||||
name: String, // From metadata.description
|
||||
created_at: String, // Derived from session_name, in ISO 8601 format
|
||||
working_dir: String, // from metadata.working_dir (as String)
|
||||
schedule_id: Option<String>,
|
||||
message_count: usize,
|
||||
total_tokens: Option<i32>,
|
||||
input_tokens: Option<i32>,
|
||||
output_tokens: Option<i32>,
|
||||
accumulated_total_tokens: Option<i32>,
|
||||
accumulated_input_tokens: Option<i32>,
|
||||
accumulated_output_tokens: Option<i32>,
|
||||
}
|
||||
|
||||
fn parse_session_name_to_iso(session_name: &str) -> String {
|
||||
NaiveDateTime::parse_from_str(session_name, "%Y%m%d_%H%M%S")
|
||||
.map(|dt| dt.and_utc().to_rfc3339())
|
||||
.unwrap_or_else(|_| String::new()) // Fallback to empty string if parsing fails
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/schedule/create",
|
||||
request_body = CreateScheduleRequest,
|
||||
responses(
|
||||
(status = 200, description = "Scheduled job created successfully", body = ScheduledJob),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
tag = "schedule"
|
||||
)]
|
||||
#[axum::debug_handler]
|
||||
async fn create_schedule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<CreateScheduleRequest>,
|
||||
) -> Result<Json<ScheduledJob>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
let scheduler = state
|
||||
.scheduler()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let job = ScheduledJob {
|
||||
id: req.id,
|
||||
source: req.recipe_source,
|
||||
cron: req.cron,
|
||||
last_run: None,
|
||||
currently_running: false,
|
||||
paused: false,
|
||||
};
|
||||
scheduler
|
||||
.add_scheduled_job(job.clone())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
eprintln!("Error creating schedule: {:?}", e); // Log error
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
Ok(Json(job))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/schedule/list",
|
||||
responses(
|
||||
(status = 200, description = "A list of scheduled jobs", body = ListSchedulesResponse),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
tag = "schedule"
|
||||
)]
|
||||
#[axum::debug_handler]
|
||||
async fn list_schedules(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<ListSchedulesResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
let scheduler = state
|
||||
.scheduler()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let jobs = scheduler.list_scheduled_jobs().await;
|
||||
Ok(Json(ListSchedulesResponse { jobs }))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/schedule/delete/{id}",
|
||||
params(
|
||||
("id" = String, Path, description = "ID of the schedule to delete")
|
||||
),
|
||||
responses(
|
||||
(status = 204, description = "Scheduled job deleted successfully"),
|
||||
(status = 404, description = "Scheduled job not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
tag = "schedule"
|
||||
)]
|
||||
#[axum::debug_handler]
|
||||
async fn delete_schedule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
let scheduler = state
|
||||
.scheduler()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
scheduler.remove_scheduled_job(&id).await.map_err(|e| {
|
||||
eprintln!("Error deleting schedule '{}': {:?}", id, e);
|
||||
match e {
|
||||
goose::scheduler::SchedulerError::JobNotFound(_) => StatusCode::NOT_FOUND,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
})?;
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/schedule/{id}/run_now",
|
||||
params(
|
||||
("id" = String, Path, description = "ID of the schedule to run")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Scheduled job triggered successfully, returns new session ID", body = RunNowResponse),
|
||||
(status = 404, description = "Scheduled job not found"),
|
||||
(status = 500, description = "Internal server error when trying to run the job")
|
||||
),
|
||||
tag = "schedule"
|
||||
)]
|
||||
#[axum::debug_handler]
|
||||
async fn run_now_handler(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<RunNowResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
let scheduler = state
|
||||
.scheduler()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
match scheduler.run_now(&id).await {
|
||||
Ok(session_id) => Ok(Json(RunNowResponse { session_id })),
|
||||
Err(e) => {
|
||||
eprintln!("Error running schedule '{}' now: {:?}", id, e);
|
||||
match e {
|
||||
goose::scheduler::SchedulerError::JobNotFound(_) => Err(StatusCode::NOT_FOUND),
|
||||
_ => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/schedule/{id}/sessions",
|
||||
params(
|
||||
("id" = String, Path, description = "ID of the schedule"),
|
||||
SessionsQuery // This will automatically pick up 'limit' as a query parameter
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "A list of session display info", body = Vec<SessionDisplayInfo>),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
tag = "schedule"
|
||||
)]
|
||||
#[axum::debug_handler]
|
||||
async fn sessions_handler(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap, // Added this line
|
||||
Path(schedule_id_param): Path<String>, // Renamed to avoid confusion with session_id
|
||||
Query(query_params): Query<SessionsQuery>,
|
||||
) -> Result<Json<Vec<SessionDisplayInfo>>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?; // Added this line
|
||||
let scheduler = state
|
||||
.scheduler()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
match scheduler
|
||||
.sessions(&schedule_id_param, query_params.limit as usize)
|
||||
.await
|
||||
{
|
||||
Ok(session_tuples) => {
|
||||
// Expecting Vec<(String, goose::session::storage::SessionMetadata)>
|
||||
let display_infos: Vec<SessionDisplayInfo> = session_tuples
|
||||
.into_iter()
|
||||
.map(|(session_name, metadata)| SessionDisplayInfo {
|
||||
id: session_name.clone(),
|
||||
name: metadata.description, // Use description as name
|
||||
created_at: parse_session_name_to_iso(&session_name),
|
||||
working_dir: metadata.working_dir.to_string_lossy().into_owned(),
|
||||
schedule_id: metadata.schedule_id, // This is the ID of the schedule itself
|
||||
message_count: metadata.message_count,
|
||||
total_tokens: metadata.total_tokens,
|
||||
input_tokens: metadata.input_tokens,
|
||||
output_tokens: metadata.output_tokens,
|
||||
accumulated_total_tokens: metadata.accumulated_total_tokens,
|
||||
accumulated_input_tokens: metadata.accumulated_input_tokens,
|
||||
accumulated_output_tokens: metadata.accumulated_output_tokens,
|
||||
})
|
||||
.collect();
|
||||
Ok(Json(display_infos))
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"Error fetching sessions for schedule '{}': {:?}",
|
||||
schedule_id_param, e
|
||||
);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/schedule/{id}/pause",
|
||||
params(
|
||||
("id" = String, Path, description = "ID of the schedule to pause")
|
||||
),
|
||||
responses(
|
||||
(status = 204, description = "Scheduled job paused successfully"),
|
||||
(status = 404, description = "Scheduled job not found"),
|
||||
(status = 400, description = "Cannot pause a currently running job"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
tag = "schedule"
|
||||
)]
|
||||
#[axum::debug_handler]
|
||||
async fn pause_schedule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
let scheduler = state
|
||||
.scheduler()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
scheduler.pause_schedule(&id).await.map_err(|e| {
|
||||
eprintln!("Error pausing schedule '{}': {:?}", id, e);
|
||||
match e {
|
||||
goose::scheduler::SchedulerError::JobNotFound(_) => StatusCode::NOT_FOUND,
|
||||
goose::scheduler::SchedulerError::AnyhowError(_) => StatusCode::BAD_REQUEST,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
})?;
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/schedule/{id}/unpause",
|
||||
params(
|
||||
("id" = String, Path, description = "ID of the schedule to unpause")
|
||||
),
|
||||
responses(
|
||||
(status = 204, description = "Scheduled job unpaused successfully"),
|
||||
(status = 404, description = "Scheduled job not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
tag = "schedule"
|
||||
)]
|
||||
#[axum::debug_handler]
|
||||
async fn unpause_schedule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
let scheduler = state
|
||||
.scheduler()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
scheduler.unpause_schedule(&id).await.map_err(|e| {
|
||||
eprintln!("Error unpausing schedule '{}': {:?}", id, e);
|
||||
match e {
|
||||
goose::scheduler::SchedulerError::JobNotFound(_) => StatusCode::NOT_FOUND,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
})?;
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/schedule/{id}",
|
||||
params(
|
||||
("id" = String, Path, description = "ID of the schedule to update")
|
||||
),
|
||||
request_body = UpdateScheduleRequest,
|
||||
responses(
|
||||
(status = 200, description = "Scheduled job updated successfully", body = ScheduledJob),
|
||||
(status = 404, description = "Scheduled job not found"),
|
||||
(status = 400, description = "Cannot update a currently running job or invalid request"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
tag = "schedule"
|
||||
)]
|
||||
#[axum::debug_handler]
|
||||
async fn update_schedule(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateScheduleRequest>,
|
||||
) -> Result<Json<ScheduledJob>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
let scheduler = state
|
||||
.scheduler()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
scheduler
|
||||
.update_schedule(&id, req.cron)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
eprintln!("Error updating schedule '{}': {:?}", id, e);
|
||||
match e {
|
||||
goose::scheduler::SchedulerError::JobNotFound(_) => StatusCode::NOT_FOUND,
|
||||
goose::scheduler::SchedulerError::AnyhowError(_) => StatusCode::BAD_REQUEST,
|
||||
goose::scheduler::SchedulerError::CronParseError(_) => StatusCode::BAD_REQUEST,
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
})?;
|
||||
|
||||
// Return the updated schedule
|
||||
let jobs = scheduler.list_scheduled_jobs().await;
|
||||
let updated_job = jobs
|
||||
.into_iter()
|
||||
.find(|job| job.id == id)
|
||||
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(updated_job))
|
||||
}
|
||||
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/schedule/create", post(create_schedule))
|
||||
.route("/schedule/list", get(list_schedules))
|
||||
.route("/schedule/delete/{id}", delete(delete_schedule)) // Corrected
|
||||
.route("/schedule/{id}", put(update_schedule))
|
||||
.route("/schedule/{id}/run_now", post(run_now_handler)) // Corrected
|
||||
.route("/schedule/{id}/pause", post(pause_schedule))
|
||||
.route("/schedule/{id}/unpause", post(unpause_schedule))
|
||||
.route("/schedule/{id}/sessions", get(sessions_handler)) // Corrected
|
||||
.with_state(state)
|
||||
}
|
||||
@@ -11,20 +11,41 @@ use axum::{
|
||||
use goose::message::Message;
|
||||
use goose::session;
|
||||
use goose::session::info::{get_session_info, SessionInfo, SortOrder};
|
||||
use goose::session::SessionMetadata;
|
||||
use serde::Serialize;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SessionListResponse {
|
||||
#[derive(Serialize, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionListResponse {
|
||||
/// List of available session information objects
|
||||
sessions: Vec<SessionInfo>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SessionHistoryResponse {
|
||||
#[derive(Serialize, ToSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionHistoryResponse {
|
||||
/// Unique identifier for the session
|
||||
session_id: String,
|
||||
metadata: session::SessionMetadata,
|
||||
/// Session metadata containing creation time and other details
|
||||
metadata: SessionMetadata,
|
||||
/// List of messages in the session conversation
|
||||
messages: Vec<Message>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/sessions",
|
||||
responses(
|
||||
(status = 200, description = "List of available sessions retrieved successfully", body = SessionListResponse),
|
||||
(status = 401, description = "Unauthorized - Invalid or missing API key"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(
|
||||
("api_key" = [])
|
||||
),
|
||||
tag = "Session Management"
|
||||
)]
|
||||
// List all available sessions
|
||||
async fn list_sessions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
@@ -38,6 +59,23 @@ async fn list_sessions(
|
||||
Ok(Json(SessionListResponse { sessions }))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/sessions/{session_id}",
|
||||
params(
|
||||
("session_id" = String, Path, description = "Unique identifier for the session")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Session history retrieved successfully", body = SessionHistoryResponse),
|
||||
(status = 401, description = "Unauthorized - Invalid or missing API key"),
|
||||
(status = 404, description = "Session not found"),
|
||||
(status = 500, description = "Internal server error")
|
||||
),
|
||||
security(
|
||||
("api_key" = [])
|
||||
),
|
||||
tag = "Session Management"
|
||||
)]
|
||||
// Get a specific session's history
|
||||
async fn get_session_history(
|
||||
State(state): State<Arc<AppState>>,
|
||||
@@ -70,6 +108,6 @@ async fn get_session_history(
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/sessions", get(list_sessions))
|
||||
.route("/sessions/:session_id", get(get_session_history))
|
||||
.route("/sessions/{session_id}", get(get_session_history))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
@@ -1,21 +1,15 @@
|
||||
use goose::agents::Agent;
|
||||
use goose::scheduler::Scheduler;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Shared reference to an Agent that can be cloned cheaply
|
||||
/// without cloning the underlying Agent object
|
||||
pub type AgentRef = Arc<Agent>;
|
||||
|
||||
/// Thread-safe container for an optional Agent reference
|
||||
/// Outer Arc: Allows multiple route handlers to access the same Mutex
|
||||
/// - Mutex provides exclusive access for updates
|
||||
/// - Option allows for the case where no agent exists yet
|
||||
///
|
||||
/// Shared application state
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
// agent: SharedAgentStore,
|
||||
agent: Option<AgentRef>,
|
||||
pub secret_key: String,
|
||||
pub scheduler: Arc<Mutex<Option<Arc<Scheduler>>>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -23,6 +17,7 @@ impl AppState {
|
||||
Arc::new(Self {
|
||||
agent: Some(agent.clone()),
|
||||
secret_key,
|
||||
scheduler: Arc::new(Mutex::new(None)),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -31,4 +26,17 @@ impl AppState {
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow::anyhow!("Agent needs to be created first."))
|
||||
}
|
||||
|
||||
pub async fn set_scheduler(&self, sched: Arc<Scheduler>) {
|
||||
let mut guard = self.scheduler.lock().await;
|
||||
*guard = Some(sched);
|
||||
}
|
||||
|
||||
pub async fn scheduler(&self) -> Result<Arc<Scheduler>, anyhow::Error> {
|
||||
self.scheduler
|
||||
.lock()
|
||||
.await
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow::anyhow!("Scheduler not initialized"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,10 +37,10 @@
|
||||
"/config/extension": {
|
||||
"post": {
|
||||
"tags": [
|
||||
"super::routes::config_management"
|
||||
"config"
|
||||
],
|
||||
"summary": "Add an extension configuration",
|
||||
"operationId": "add_extension",
|
||||
"operationId": "add_extension_config",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
@@ -208,6 +208,180 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/schedule/create": {
|
||||
"post": {
|
||||
"tags": ["schedule"],
|
||||
"summary": "Create a new scheduled job",
|
||||
"operationId": "create_schedule",
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CreateScheduleRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Scheduled job created successfully",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ScheduledJob"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal server error"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/schedule/list": {
|
||||
"get": {
|
||||
"tags": ["schedule"],
|
||||
"summary": "List all scheduled jobs",
|
||||
"operationId": "list_schedules",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A list of scheduled jobs",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"jobs": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ScheduledJob"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal server error"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/schedule/delete/{id}": {
|
||||
"delete": {
|
||||
"tags": ["schedule"],
|
||||
"summary": "Delete a scheduled job by ID",
|
||||
"operationId": "delete_schedule",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"description": "ID of the schedule to delete",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "Scheduled job deleted successfully"
|
||||
},
|
||||
"404": {
|
||||
"description": "Scheduled job not found"
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal server error"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/schedule/{id}/run_now": {
|
||||
"post": {
|
||||
"tags": ["schedule"],
|
||||
"summary": "Run a scheduled job immediately",
|
||||
"operationId": "run_schedule_now",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"description": "ID of the schedule to run",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Scheduled job triggered successfully, returns new session ID",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/RunNowResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Scheduled job not found"
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal server error when trying to run the job"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/schedule/{id}/sessions": {
|
||||
"get": {
|
||||
"tags": ["schedule"],
|
||||
"summary": "List sessions created by a specific schedule",
|
||||
"operationId": "list_schedule_sessions",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"description": "ID of the schedule",
|
||||
"schema": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"description": "Maximum number of sessions to return",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"default": 50
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "A list of session metadata",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/SessionMetadata"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "Internal server error"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
@@ -273,7 +447,127 @@
|
||||
"description": "The value to set for the configuration"
|
||||
}
|
||||
}
|
||||
},
|
||||
"CreateScheduleRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"recipe_source",
|
||||
"cron"
|
||||
],
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Unique ID for the new schedule."
|
||||
},
|
||||
"recipe_source": {
|
||||
"type": "string",
|
||||
"description": "Path to the recipe file to be executed by this schedule."
|
||||
},
|
||||
"cron": {
|
||||
"type": "string",
|
||||
"description": "Cron string defining when the job should run."
|
||||
}
|
||||
}
|
||||
},
|
||||
"ScheduledJob": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"source",
|
||||
"cron"
|
||||
],
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the scheduled job."
|
||||
},
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "Path to the recipe file for this job."
|
||||
},
|
||||
"cron": {
|
||||
"type": "string",
|
||||
"description": "Cron string defining the schedule."
|
||||
},
|
||||
"last_run": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "Timestamp of the last time the job was run.",
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"SessionMetadata": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"working_dir",
|
||||
"description",
|
||||
"message_count"
|
||||
],
|
||||
"properties": {
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "Working directory for the session."
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "A short description of the session."
|
||||
},
|
||||
"schedule_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the schedule that triggered this session, if any.",
|
||||
"nullable": true
|
||||
},
|
||||
"message_count": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
"description": "Number of messages in the session."
|
||||
},
|
||||
"total_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"nullable": true
|
||||
},
|
||||
"input_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"nullable": true
|
||||
},
|
||||
"output_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"nullable": true
|
||||
},
|
||||
"accumulated_total_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"nullable": true
|
||||
},
|
||||
"accumulated_input_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"nullable": true
|
||||
},
|
||||
"accumulated_output_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"RunNowResponse": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"session_id"
|
||||
],
|
||||
"properties": {
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the newly created session."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ nanoid = "0.4"
|
||||
sha2 = "0.10"
|
||||
base64 = "0.21"
|
||||
url = "2.5"
|
||||
axum = "0.7"
|
||||
axum = "0.8.1"
|
||||
webbrowser = "0.8"
|
||||
dotenv = "0.15"
|
||||
lazy_static = "1.5"
|
||||
@@ -60,7 +60,8 @@ serde_yaml = "0.9.34"
|
||||
once_cell = "1.20.2"
|
||||
etcetera = "0.8.0"
|
||||
rand = "0.8.5"
|
||||
utoipa = "4.1"
|
||||
utoipa = { version = "4.1", features = ["chrono"] }
|
||||
tokio-cron-scheduler = "0.14.0"
|
||||
|
||||
# For Bedrock provider
|
||||
aws-config = { version = "1.5.16", features = ["behavior-version-latest"] }
|
||||
@@ -73,6 +74,11 @@ jsonwebtoken = "9.3.1"
|
||||
# Added blake3 hashing library as a dependency
|
||||
blake3 = "1.5"
|
||||
fs2 = "0.4.3"
|
||||
futures-util = "0.3.31"
|
||||
|
||||
# Vector database for tool selection
|
||||
lancedb = "0.13"
|
||||
arrow = "52.2"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
winapi = { version = "0.3", features = ["wincred"] }
|
||||
|
||||
@@ -12,18 +12,25 @@ use crate::permission::PermissionConfirmation;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::recipe::{Author, Recipe};
|
||||
use crate::tool_monitor::{ToolCall, ToolMonitor};
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tracing::{debug, error, instrument};
|
||||
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult, ToolInfo};
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionError, ExtensionResult, ToolInfo};
|
||||
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
|
||||
use crate::agents::platform_tools::{
|
||||
PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME,
|
||||
PLATFORM_READ_RESOURCE_TOOL_NAME, PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
|
||||
};
|
||||
use crate::agents::prompt_manager::PromptManager;
|
||||
use crate::agents::router_tool_selector::{
|
||||
create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector,
|
||||
};
|
||||
use crate::agents::router_tools::ROUTER_VECTOR_SEARCH_TOOL_NAME;
|
||||
use crate::agents::tool_router_index_manager::ToolRouterIndexManager;
|
||||
use crate::agents::tool_vectordb::generate_table_id;
|
||||
use crate::agents::types::SessionConfig;
|
||||
use crate::agents::types::{FrontendTool, ToolResultReceiver};
|
||||
use mcp_core::{
|
||||
@@ -31,6 +38,7 @@ use mcp_core::{
|
||||
};
|
||||
|
||||
use super::platform_tools;
|
||||
use super::router_tools;
|
||||
use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
|
||||
|
||||
/// The main goose Agent
|
||||
@@ -44,6 +52,8 @@ pub struct Agent {
|
||||
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
||||
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
pub(super) tool_result_rx: ToolResultReceiver,
|
||||
pub(super) tool_monitor: Mutex<Option<ToolMonitor>>,
|
||||
pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
@@ -62,6 +72,24 @@ impl Agent {
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
||||
tool_monitor: Mutex::new(None),
|
||||
router_tool_selector: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn configure_tool_monitor(&self, max_repetitions: Option<u32>) {
|
||||
let mut tool_monitor = self.tool_monitor.lock().await;
|
||||
*tool_monitor = Some(ToolMonitor::new(max_repetitions));
|
||||
}
|
||||
|
||||
pub async fn get_tool_stats(&self) -> Option<HashMap<String, u32>> {
|
||||
let tool_monitor = self.tool_monitor.lock().await;
|
||||
tool_monitor.as_ref().map(|monitor| monitor.get_stats())
|
||||
}
|
||||
|
||||
pub async fn reset_tool_monitor(&self) {
|
||||
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
|
||||
monitor.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,6 +144,20 @@ impl Agent {
|
||||
tool_call: mcp_core::tool::ToolCall,
|
||||
request_id: String,
|
||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
||||
// Check if this tool call should be allowed based on repetition monitoring
|
||||
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
|
||||
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
|
||||
|
||||
if !monitor.check_tool_call(tool_call_info) {
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(
|
||||
"Tool call rejected: exceeded maximum allowed repetitions".to_string(),
|
||||
)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if tool_call.name == PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME {
|
||||
let extension_name = tool_call
|
||||
.arguments
|
||||
@@ -151,6 +193,15 @@ impl Agent {
|
||||
Err(ToolError::ExecutionError(
|
||||
"Frontend tool execution required".to_string(),
|
||||
))
|
||||
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if let Some(selector) = selector {
|
||||
selector.select_tools(tool_call.arguments.clone()).await
|
||||
} else {
|
||||
Err(ToolError::ExecutionError(
|
||||
"Encountered vector search error.".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
extension_manager
|
||||
.dispatch_tool_call(tool_call.clone())
|
||||
@@ -162,7 +213,10 @@ impl Agent {
|
||||
"output" = serde_json::to_string(&result).unwrap(),
|
||||
);
|
||||
|
||||
(request_id, result)
|
||||
// Process the response to handle large text content
|
||||
let processed_result = super::large_response_handler::process_tool_response(result);
|
||||
|
||||
(request_id, processed_result)
|
||||
}
|
||||
|
||||
pub(super) async fn manage_extensions(
|
||||
@@ -220,6 +274,33 @@ impl Agent {
|
||||
})
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()));
|
||||
|
||||
// Update vector index if operation was successful and vector routing is enabled
|
||||
if result.is_ok() {
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if ToolRouterIndexManager::vector_tool_router_enabled(&selector) {
|
||||
if let Some(selector) = selector {
|
||||
let vector_action = if action == "disable" { "remove" } else { "add" };
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
|
||||
&selector,
|
||||
&extension_manager,
|
||||
&extension_name,
|
||||
vector_action,
|
||||
)
|
||||
.await
|
||||
{
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(format!(
|
||||
"Failed to update vector index: {}",
|
||||
e
|
||||
))),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(request_id, result)
|
||||
}
|
||||
|
||||
@@ -253,10 +334,32 @@ impl Agent {
|
||||
}
|
||||
_ => {
|
||||
let mut extension_manager = self.extension_manager.lock().await;
|
||||
extension_manager.add_extension(extension).await?;
|
||||
extension_manager.add_extension(extension.clone()).await?;
|
||||
}
|
||||
};
|
||||
|
||||
// If vector tool selection is enabled, index the tools
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if ToolRouterIndexManager::vector_tool_router_enabled(&selector) {
|
||||
if let Some(selector) = selector {
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
|
||||
&selector,
|
||||
&extension_manager,
|
||||
&extension.name(),
|
||||
"add",
|
||||
)
|
||||
.await
|
||||
{
|
||||
return Err(ExtensionError::SetupError(format!(
|
||||
"Failed to index tools for extension {}: {}",
|
||||
extension.name(),
|
||||
e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -282,12 +385,61 @@ impl Agent {
|
||||
prefixed_tools
|
||||
}
|
||||
|
||||
pub async fn remove_extension(&self, name: &str) {
|
||||
pub async fn list_tools_for_router(
|
||||
&self,
|
||||
strategy: Option<RouterToolSelectionStrategy>,
|
||||
) -> Vec<Tool> {
|
||||
let mut prefixed_tools = vec![];
|
||||
match strategy {
|
||||
Some(RouterToolSelectionStrategy::Vector) => {
|
||||
prefixed_tools.push(router_tools::vector_search_tool());
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
// Get recent tool calls from router tool selector if available
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if let Some(selector) = selector {
|
||||
if let Ok(recent_calls) = selector.get_recent_tool_calls(20).await {
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
// Add recent tool calls to the list, avoiding duplicates
|
||||
for tool_name in recent_calls {
|
||||
// Find the tool in the extension manager's tools
|
||||
if let Ok(extension_tools) = extension_manager.get_prefixed_tools(None).await {
|
||||
if let Some(tool) = extension_tools.iter().find(|t| t.name == tool_name) {
|
||||
// Only add if not already in prefixed_tools
|
||||
if !prefixed_tools.iter().any(|t| t.name == tool.name) {
|
||||
prefixed_tools.push(tool.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prefixed_tools
|
||||
}
|
||||
|
||||
pub async fn remove_extension(&self, name: &str) -> Result<()> {
|
||||
let mut extension_manager = self.extension_manager.lock().await;
|
||||
extension_manager
|
||||
.remove_extension(name)
|
||||
.await
|
||||
.expect("Failed to remove extension");
|
||||
extension_manager.remove_extension(name).await?;
|
||||
|
||||
// If vector tool selection is enabled, remove tools from the index
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if ToolRouterIndexManager::vector_tool_router_enabled(&selector) {
|
||||
if let Some(selector) = selector {
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
ToolRouterIndexManager::update_extension_tools(
|
||||
&selector,
|
||||
&extension_manager,
|
||||
name,
|
||||
"remove",
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_extensions(&self) -> Vec<String> {
|
||||
@@ -360,6 +512,26 @@ impl Agent {
|
||||
filtered_response) =
|
||||
self.categorize_tool_requests(&response).await;
|
||||
|
||||
// Record tool calls in the router selector
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if let Some(selector) = selector {
|
||||
// Record frontend tool calls
|
||||
for request in &frontend_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
|
||||
tracing::error!("Failed to record frontend tool call: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Record remaining tool calls
|
||||
for request in &remaining_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
if let Err(e) = selector.record_tool_call(&tool_call.name).await {
|
||||
tracing::error!("Failed to record tool call: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Yield the assistant's response with frontend tool requests filtered out
|
||||
yield filtered_response.clone();
|
||||
@@ -438,7 +610,7 @@ impl Agent {
|
||||
&permission_check_result.needs_approval,
|
||||
tool_futures_arc.clone(),
|
||||
&mut permission_manager,
|
||||
message_tool_response.clone(),
|
||||
message_tool_response.clone()
|
||||
);
|
||||
|
||||
// We have a stream of tool_approval_requests to handle
|
||||
@@ -511,7 +683,35 @@ impl Agent {
|
||||
|
||||
/// Update the provider used by this agent
|
||||
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
|
||||
*self.provider.lock().await = Some(provider);
|
||||
*self.provider.lock().await = Some(provider.clone());
|
||||
self.update_router_tool_selector(provider).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_router_tool_selector(&self, provider: Arc<dyn Provider>) -> Result<()> {
|
||||
let config = Config::global();
|
||||
let router_tool_selection_strategy = config
|
||||
.get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
|
||||
.unwrap_or_else(|_| "default".to_string());
|
||||
|
||||
let strategy = match router_tool_selection_strategy.to_lowercase().as_str() {
|
||||
"vector" => Some(RouterToolSelectionStrategy::Vector),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(strategy) = strategy {
|
||||
let table_name = generate_table_id();
|
||||
let selector = create_tool_selector(Some(strategy), provider, table_name)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
|
||||
|
||||
let selector = Arc::new(selector);
|
||||
*self.router_tool_selector.lock().await = Some(selector.clone());
|
||||
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -593,6 +793,7 @@ impl Agent {
|
||||
self.frontend_instructions.lock().await.clone(),
|
||||
extension_manager.suggest_disable_extensions_prompt().await,
|
||||
Some(model_name),
|
||||
None,
|
||||
);
|
||||
|
||||
let recipe_prompt = prompt_manager.get_recipe_prompt().await;
|
||||
|
||||
247
crates/goose/src/agents/large_response_handler.rs
Normal file
247
crates/goose/src/agents/large_response_handler.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use chrono::Utc;
|
||||
use mcp_core::{Content, ToolError};
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
|
||||
// Constant for the size threshold (20K characters)
|
||||
const LARGE_TEXT_THRESHOLD: usize = 20_000;
|
||||
|
||||
/// Process tool response and handle large text content
|
||||
pub fn process_tool_response(
|
||||
response: Result<Vec<Content>, ToolError>,
|
||||
) -> Result<Vec<Content>, ToolError> {
|
||||
match response {
|
||||
Ok(contents) => {
|
||||
let mut processed_contents = Vec::new();
|
||||
|
||||
for content in contents {
|
||||
match content {
|
||||
Content::Text(text_content) => {
|
||||
// Check if text exceeds threshold
|
||||
if text_content.text.len() > LARGE_TEXT_THRESHOLD {
|
||||
// Write to temp file
|
||||
match write_large_text_to_file(&text_content.text) {
|
||||
Ok(file_path) => {
|
||||
// Create a new text content with reference to the file
|
||||
let message = format!(
|
||||
"The response returned from the tool call was larger ({} characters) and is stored in the file which you can use other tools to examine or search in: {}",
|
||||
text_content.text.len(),
|
||||
file_path
|
||||
);
|
||||
processed_contents.push(Content::text(message));
|
||||
}
|
||||
Err(e) => {
|
||||
// If file writing fails, include original content with warning
|
||||
let warning = format!(
|
||||
"Warning: Failed to write large response to file: {}. Showing full content instead.\n\n{}",
|
||||
e,
|
||||
text_content.text
|
||||
);
|
||||
processed_contents.push(Content::text(warning));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Keep original content for smaller texts
|
||||
processed_contents.push(Content::Text(text_content));
|
||||
}
|
||||
}
|
||||
// Pass through other content types unchanged
|
||||
_ => processed_contents.push(content),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(processed_contents)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Write large text content to a temporary file
|
||||
fn write_large_text_to_file(content: &str) -> Result<String, std::io::Error> {
|
||||
// Create temp directory if it doesn't exist
|
||||
let temp_dir = std::env::temp_dir().join("goose_mcp_responses");
|
||||
std::fs::create_dir_all(&temp_dir)?;
|
||||
|
||||
// Generate a unique filename with timestamp
|
||||
let timestamp = Utc::now().format("%Y%m%d_%H%M%S");
|
||||
let filename = format!("mcp_response_{}.txt", timestamp);
|
||||
let file_path = temp_dir.join(&filename);
|
||||
|
||||
// Write content to file
|
||||
let mut file = File::create(&file_path)?;
|
||||
file.write_all(content.as_bytes())?;
|
||||
|
||||
Ok(file_path.to_string_lossy().to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mcp_core::{Content, ImageContent, TextContent, ToolError};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
#[test]
|
||||
fn test_small_text_response_passes_through() {
|
||||
// Create a small text response
|
||||
let small_text = "This is a small text response";
|
||||
let content = Content::Text(TextContent {
|
||||
text: small_text.to_string(),
|
||||
annotations: None,
|
||||
});
|
||||
|
||||
let response = Ok(vec![content]);
|
||||
|
||||
// Process the response
|
||||
let processed = process_tool_response(response).unwrap();
|
||||
|
||||
// Verify the response is unchanged
|
||||
assert_eq!(processed.len(), 1);
|
||||
if let Content::Text(text_content) = &processed[0] {
|
||||
assert_eq!(text_content.text, small_text);
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_large_text_response_redirected_to_file() {
|
||||
// Create a text larger than the threshold
|
||||
let large_text = "a".repeat(LARGE_TEXT_THRESHOLD + 1000);
|
||||
let content = Content::Text(TextContent {
|
||||
text: large_text.clone(),
|
||||
annotations: None,
|
||||
});
|
||||
|
||||
let response = Ok(vec![content]);
|
||||
|
||||
// Process the response
|
||||
let processed = process_tool_response(response).unwrap();
|
||||
|
||||
// Verify the response contains a message about the file
|
||||
assert_eq!(processed.len(), 1);
|
||||
if let Content::Text(text_content) = &processed[0] {
|
||||
assert!(text_content
|
||||
.text
|
||||
.contains("The response returned from the tool call was larger"));
|
||||
assert!(text_content.text.contains("characters"));
|
||||
|
||||
// Extract the file path from the message
|
||||
if let Some(file_path) = text_content.text.split("stored in the file: ").nth(1) {
|
||||
// Verify the file exists and contains the original text
|
||||
let path = Path::new(file_path.trim());
|
||||
if path.exists() {
|
||||
// Only check content if file exists (may not exist in CI environments)
|
||||
if let Ok(file_content) = fs::read_to_string(path) {
|
||||
assert_eq!(file_content, large_text);
|
||||
}
|
||||
|
||||
// Clean up the file
|
||||
let _ = fs::remove_file(path); // Ignore errors on cleanup
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_image_content_passes_through() {
|
||||
// Create an image content
|
||||
let image_content = Content::Image(ImageContent {
|
||||
data: "base64data".to_string(),
|
||||
mime_type: "image/png".to_string(),
|
||||
annotations: None,
|
||||
});
|
||||
|
||||
let response = Ok(vec![image_content]);
|
||||
|
||||
// Process the response
|
||||
let processed = process_tool_response(response).unwrap();
|
||||
|
||||
// Verify the response is unchanged
|
||||
assert_eq!(processed.len(), 1);
|
||||
match &processed[0] {
|
||||
Content::Image(img) => {
|
||||
assert_eq!(img.data, "base64data");
|
||||
assert_eq!(img.mime_type, "image/png");
|
||||
}
|
||||
_ => panic!("Expected image content"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_content_handled_correctly() {
|
||||
// Create a response with mixed content types
|
||||
let small_text = Content::text("Small text");
|
||||
let large_text = Content::Text(TextContent {
|
||||
text: "a".repeat(LARGE_TEXT_THRESHOLD + 1000),
|
||||
annotations: None,
|
||||
});
|
||||
let image = Content::Image(ImageContent {
|
||||
data: "image_data".to_string(),
|
||||
mime_type: "image/jpeg".to_string(),
|
||||
annotations: None,
|
||||
});
|
||||
|
||||
let response = Ok(vec![small_text, large_text, image]);
|
||||
|
||||
// Process the response
|
||||
let processed = process_tool_response(response).unwrap();
|
||||
|
||||
// Verify each item is handled correctly
|
||||
assert_eq!(processed.len(), 3);
|
||||
|
||||
// First item should be unchanged small text
|
||||
if let Content::Text(text_content) = &processed[0] {
|
||||
assert_eq!(text_content.text, "Small text");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
}
|
||||
|
||||
// Second item should be a message about the file
|
||||
if let Content::Text(text_content) = &processed[1] {
|
||||
assert!(text_content
|
||||
.text
|
||||
.contains("The response returned from the tool call was larger"));
|
||||
|
||||
// Extract the file path and clean up
|
||||
if let Some(file_path) = text_content.text.split("stored in the file: ").nth(1) {
|
||||
let path = Path::new(file_path.trim());
|
||||
if path.exists() {
|
||||
let _ = fs::remove_file(path); // Ignore errors on cleanup
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
}
|
||||
|
||||
// Third item should be unchanged image
|
||||
match &processed[2] {
|
||||
Content::Image(img) => {
|
||||
assert_eq!(img.data, "image_data");
|
||||
assert_eq!(img.mime_type, "image/jpeg");
|
||||
}
|
||||
_ => panic!("Expected image content"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_response_passes_through() {
|
||||
// Create an error response
|
||||
let error = ToolError::ExecutionError("Test error".to_string());
|
||||
let response: Result<Vec<Content>, ToolError> = Err(error);
|
||||
|
||||
// Process the response
|
||||
let processed = process_tool_response(response);
|
||||
|
||||
// Verify the error is passed through unchanged
|
||||
assert!(processed.is_err());
|
||||
match processed {
|
||||
Err(ToolError::ExecutionError(msg)) => {
|
||||
assert_eq!(msg, "Test error");
|
||||
}
|
||||
_ => panic!("Expected execution error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,15 @@ mod agent;
|
||||
mod context;
|
||||
pub mod extension;
|
||||
pub mod extension_manager;
|
||||
mod large_response_handler;
|
||||
pub mod platform_tools;
|
||||
pub mod prompt_manager;
|
||||
mod reply_parts;
|
||||
mod router_tool_selector;
|
||||
mod router_tools;
|
||||
mod tool_execution;
|
||||
mod tool_router_index_manager;
|
||||
pub(crate) mod tool_vectordb;
|
||||
mod types;
|
||||
|
||||
pub use agent::Agent;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user