Merge branch 'block:main' into goose-api

This commit is contained in:
2025-05-29 10:18:25 +02:00
committed by GitHub
435 changed files with 64808 additions and 6579 deletions

View File

@@ -21,6 +21,11 @@ on:
required: false
default: true
type: boolean
ref:
description: 'Git ref to checkout (branch, tag, or SHA). Defaults to main branch if not specified.'
required: false
type: string
default: ''
secrets:
CERTIFICATE_OSX_APPLICATION:
description: 'Certificate for macOS application signing'
@@ -45,6 +50,30 @@ jobs:
runs-on: macos-latest
name: Bundle Desktop App on macOS
steps:
# Debug information about the workflow and inputs
- name: Debug workflow info
env:
WORKFLOW_NAME: ${{ github.workflow }}
WORKFLOW_REF: ${{ github.ref }}
EVENT_NAME: ${{ github.event_name }}
REPOSITORY: ${{ github.repository }}
INPUT_REF: ${{ inputs.ref }}
INPUT_VERSION: ${{ inputs.version }}
INPUT_SIGNING: ${{ inputs.signing }}
INPUT_QUICK_TEST: ${{ inputs.quick_test }}
run: |
echo "=== Workflow Information ==="
echo "Workflow: ${WORKFLOW_NAME}"
echo "Ref: ${WORKFLOW_REF}"
echo "Event: ${EVENT_NAME}"
echo "Repo: ${REPOSITORY}"
echo ""
echo "=== Input Parameters ==="
echo "Build ref: ${INPUT_REF:-<default branch>}"
echo "Version: ${INPUT_VERSION:-not set}"
echo "Signing: ${INPUT_SIGNING:-false}"
echo "Quick test: ${INPUT_QUICK_TEST:-true}"
# Check initial disk space
- name: Check initial disk space
run: df -h
@@ -52,43 +81,63 @@ jobs:
# Validate Signing Secrets if signing is enabled
- name: Validate Signing Secrets
if: ${{ inputs.signing }}
env:
HAS_CERT: ${{ secrets.CERTIFICATE_OSX_APPLICATION != '' }}
HAS_CERT_PASS: ${{ secrets.CERTIFICATE_PASSWORD != '' }}
HAS_APPLE_ID: ${{ secrets.APPLE_ID != '' }}
HAS_APPLE_PASS: ${{ secrets.APPLE_ID_PASSWORD != '' }}
HAS_TEAM_ID: ${{ secrets.APPLE_TEAM_ID != '' }}
run: |
if [[ -z "${{ secrets.CERTIFICATE_OSX_APPLICATION }}" ]]; then
echo "Error: CERTIFICATE_OSX_APPLICATION secret is required for signing."
exit 1
fi
if [[ -z "${{ secrets.CERTIFICATE_PASSWORD }}" ]]; then
echo "Error: CERTIFICATE_PASSWORD secret is required for signing."
exit 1
fi
if [[ -z "${{ secrets.APPLE_ID }}" ]]; then
echo "Error: APPLE_ID secret is required for signing."
exit 1
fi
if [[ -z "${{ secrets.APPLE_ID_PASSWORD }}" ]]; then
echo "Error: APPLE_ID_PASSWORD secret is required for signing."
exit 1
fi
if [[ -z "${{ secrets.APPLE_TEAM_ID }}" ]]; then
echo "Error: APPLE_TEAM_ID secret is required for signing."
missing=()
[[ "${HAS_CERT}" != "true" ]] && missing+=("CERTIFICATE_OSX_APPLICATION")
[[ "${HAS_CERT_PASS}" != "true" ]] && missing+=("CERTIFICATE_PASSWORD")
[[ "${HAS_APPLE_ID}" != "true" ]] && missing+=("APPLE_ID")
[[ "${HAS_APPLE_PASS}" != "true" ]] && missing+=("APPLE_ID_PASSWORD")
[[ "${HAS_TEAM_ID}" != "true" ]] && missing+=("APPLE_TEAM_ID")
if (( ${#missing[@]} > 0 )); then
echo "Error: Missing required signing secrets:"
printf '%s\n' "${missing[@]}"
exit 1
fi
echo "All required signing secrets are present."
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
# Only pass ref if it's explicitly set, otherwise let checkout action use its default behavior
ref: ${{ inputs.ref != '' && inputs.ref || '' }}
fetch-depth: 0
- name: Debug git status
run: |
echo "=== Git Status ==="
git status
echo ""
echo "=== Current Commit ==="
git rev-parse HEAD
git rev-parse --abbrev-ref HEAD
echo ""
echo "=== Recent Commits ==="
git log --oneline -n 5
echo ""
echo "=== Remote Branches ==="
git branch -r
# Update versions before build
- name: Update versions
if: ${{ inputs.version != '' }}
env:
VERSION: ${{ inputs.version }}
run: |
# Update version in Cargo.toml
sed -i.bak 's/^version = ".*"/version = "'${{ inputs.version }}'"/' Cargo.toml
sed -i.bak "s/^version = \".*\"/version = \"${VERSION}\"/" Cargo.toml
rm -f Cargo.toml.bak
# Update version in package.json
cd ui/desktop
npm version ${{ inputs.version }} --no-git-tag-version --allow-same-version
npm version "${VERSION}" --no-git-tag-version --allow-same-version
# Pre-build cleanup to ensure enough disk space
- name: Pre-build cleanup
@@ -105,6 +154,11 @@ jobs:
# Check disk space after cleanup
df -h
- name: Install protobuf
run: |
brew install protobuf
echo "PROTOC=$(which protoc)" >> $GITHUB_ENV
- name: Setup Rust
uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
with:
@@ -194,6 +248,10 @@ jobs:
- name: Make Signed App
if: ${{ inputs.signing }}
env:
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_ID_PASSWORD: ${{ secrets.APPLE_ID_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
run: |
attempt=0
max_attempts=2
@@ -208,10 +266,6 @@ jobs:
exit 1
fi
working-directory: ui/desktop
env:
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_ID_PASSWORD: ${{ secrets.APPLE_ID_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
- name: Final cleanup before artifact upload
run: |

View File

@@ -1,12 +1,8 @@
on:
push:
paths-ignore:
- "documentation/**"
branches:
- main
pull_request:
paths-ignore:
- "documentation/**"
branches:
- main
workflow_dispatch:
@@ -61,7 +57,7 @@ jobs:
- name: Install Dependencies
run: |
sudo apt update -y
sudo apt install -y libdbus-1-dev gnome-keyring libxcb1-dev
sudo apt install -y libdbus-1-dev gnome-keyring libxcb1-dev protobuf-compiler
- name: Setup Rust
uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable

View File

@@ -28,9 +28,26 @@ jobs:
runs-on: ubuntu-latest
outputs:
continue: ${{ steps.command.outputs.continue || github.event_name == 'workflow_dispatch' }}
# Cannot use github.event.pull_request.number since the trigger is 'issue_comment'
pr_number: ${{ steps.command.outputs.issue_number || github.event.inputs.pr_number }}
pr_sha: ${{ steps.get_pr_info.outputs.sha }}
steps:
- name: Debug workflow trigger
env:
WORKFLOW_NAME: ${{ github.workflow }}
WORKFLOW_REF: ${{ github.ref }}
EVENT_NAME: ${{ github.event_name }}
EVENT_ACTION: ${{ github.event.action }}
ACTOR: ${{ github.actor }}
REPOSITORY: ${{ github.repository }}
run: |
echo "=== Workflow Trigger Info ==="
echo "Workflow: ${WORKFLOW_NAME}"
echo "Ref: ${WORKFLOW_REF}"
echo "Event: ${EVENT_NAME}"
echo "Action: ${EVENT_ACTION}"
echo "Actor: ${ACTOR}"
echo "Repository: ${REPOSITORY}"
- if: ${{ github.event_name == 'issue_comment' }}
uses: github/command@319d5236cc34ed2cb72a47c058a363db0b628ebe # pin@v1.3.0
id: command
@@ -40,13 +57,56 @@ jobs:
reaction: "eyes"
allowed_contexts: pull_request
# Get the PR's SHA
- name: Get PR info
id: get_pr_info
if: ${{ steps.command.outputs.continue == 'true' || github.event_name == 'workflow_dispatch' }}
uses: actions/github-script@v7
with:
script: |
let prNumber;
if (context.eventName === 'workflow_dispatch') {
prNumber = context.payload.inputs.pr_number;
} else {
prNumber = context.payload.issue.number;
}
if (!prNumber) {
throw new Error('No PR number found');
}
console.log('Using PR number:', prNumber);
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: parseInt(prNumber, 10)
});
console.log('PR Details:', {
number: pr.number,
head: {
ref: pr.head.ref,
sha: pr.head.sha,
label: pr.head.label
},
base: {
ref: pr.base.ref,
sha: pr.base.sha,
label: pr.base.label
}
});
core.setOutput('sha', pr.head.sha);
bundle-desktop:
# Only run this if ".bundle" command is detected.
needs: [trigger-on-command]
if: ${{ needs.trigger-on-command.outputs.continue == 'true' }}
uses: ./.github/workflows/bundle-desktop.yml
with:
signing: true
ref: ${{ needs.trigger-on-command.outputs.pr_sha }}
secrets:
CERTIFICATE_OSX_APPLICATION: ${{ secrets.CERTIFICATE_OSX_APPLICATION }}
CERTIFICATE_PASSWORD: ${{ secrets.CERTIFICATE_PASSWORD }}

3
.gitignore vendored
View File

@@ -1,3 +1,6 @@
__pycache__
*.pyc
*.jar
run_cli.sh
tokenizer_files/
.DS_Store

View File

@@ -4,4 +4,10 @@
if git diff --cached --name-only | grep -q "^ui/desktop/"; then
. "$(dirname -- "$0")/_/husky.sh"
cd ui/desktop && npx lint-staged
fi
fi
# Only auto-format ui-v2 TS code if relevant files are modified
if git diff --cached --name-only | grep -q "^ui-v2/"; then
. "$(dirname -- "$0")/_/husky.sh"
cd ui-v2 && npx lint-staged
fi

2836
Cargo.lock generated

File diff suppressed because it is too large

View File

@@ -4,7 +4,7 @@ resolver = "2"
[workspace.package]
edition = "2021"
version = "1.0.20"
version = "1.0.24"
authors = ["Block <ai-oss-tools@block.xyz>"]
license = "Apache-2.0"
repository = "https://github.com/block/goose"

View File

@@ -79,6 +79,11 @@ run-ui:
@echo "Running UI..."
cd ui/desktop && npm install && npm run start-gui
run-ui-only:
@echo "Running UI..."
cd ui/desktop && npm install && npm run start-gui
# Run UI with alpha changes
run-ui-alpha:
@just release-binary
@@ -319,3 +324,24 @@ win-total-dbg *allparam:
win-total-rls *allparam:
just win-bld-rls{{allparam}}
just win-run-rls
### Build and run the Kotlin example with
### auto-generated bindings for goose-llm
kotlin-example:
# Build Rust dylib and generate Kotlin bindings
cargo build -p goose-llm
cargo run --features=uniffi/cli --bin uniffi-bindgen generate \
--library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
# Compile and run the Kotlin example
cd bindings/kotlin/ && kotlinc \
example/Usage.kt \
uniffi/goose_llm/goose_llm.kt \
-classpath "libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
-include-runtime \
-d example.jar
cd bindings/kotlin/ && java \
-Djna.library.path=$HOME/Development/goose/target/debug \
-classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
UsageKt

View File

@@ -0,0 +1,200 @@
import kotlinx.coroutines.runBlocking
import uniffi.goose_llm.*
fun main() = runBlocking {
val now = System.currentTimeMillis() / 1000
val msgs = listOf(
// 1) User sends a plain-text prompt
Message(
role = Role.USER,
created = now,
content = listOf(
MessageContent.Text(
TextContent("What is 7 x 6?")
)
)
),
// 2) Assistant makes a tool request (ToolReq) to calculate 7×6
Message(
role = Role.ASSISTANT,
created = now + 2,
content = listOf(
MessageContent.ToolReq(
ToolRequest(
id = "calc1",
toolCall = """
{
"status": "success",
"value": {
"name": "calculator_extension__toolname",
"arguments": {
"operation": "multiply",
"numbers": [7, 6]
},
"needsApproval": false
}
}
""".trimIndent()
)
)
)
),
// 3) User (on behalf of the tool) responds with the tool result (ToolResp)
Message(
role = Role.USER,
created = now + 3,
content = listOf(
MessageContent.ToolResp(
ToolResponse(
id = "calc1",
toolResult = """
{
"status": "success",
"value": [
{"type": "text", "text": "42"}
]
}
""".trimIndent()
)
)
)
),
)
printMessages(msgs)
println("---\n")
// Setup provider
val providerName = "databricks"
val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set")
val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set")
val providerConfig = """{"host": "$host", "token": "$token"}"""
println("Provider Name: $providerName")
println("Provider Config: $providerConfig")
val sessionName = generateSessionName(providerName, providerConfig, msgs)
println("\nSession Name: $sessionName")
val tooltip = generateTooltip(providerName, providerConfig, msgs)
println("\nTooltip: $tooltip")
// Completion
val modelName = "goose-gpt-4-1"
val modelConfig = ModelConfig(
modelName,
100000u, // UInt
0.1f, // Float
200 // Int
)
val calculatorTool = createToolConfig(
name = "calculator",
description = "Perform basic arithmetic operations",
inputSchema = """
{
"type": "object",
"required": ["operation", "numbers"],
"properties": {
"operation": {
"type": "string",
"enum": ["add", "subtract", "multiply", "divide"],
"description": "The arithmetic operation to perform"
},
"numbers": {
"type": "array",
"items": { "type": "number" },
"description": "List of numbers to operate on in order"
}
}
}
""".trimIndent(),
approvalMode = ToolApprovalMode.AUTO
)
val calculator_extension = ExtensionConfig(
name = "calculator_extension",
instructions = "This extension provides a calculator tool.",
tools = listOf(calculatorTool)
)
val extensions = listOf(calculator_extension)
val systemPreamble = "You are a helpful assistant."
val req = createCompletionRequest(
providerName,
providerConfig,
modelConfig,
systemPreamble,
msgs,
extensions
)
val response = completion(req)
println("\nCompletion Response:\n${response.message}")
println()
// ---- UI Extraction (custom schema) ----
runUiExtraction(providerName, providerConfig)
}
suspend fun runUiExtraction(providerName: String, providerConfig: String) {
val systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI."
val messages = listOf(
Message(
role = Role.USER,
created = System.currentTimeMillis() / 1000,
content = listOf(
MessageContent.Text(
TextContent("Make a User Profile Form")
)
)
)
)
val schema = """{
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": ["div","button","header","section","field","form"]
},
"label": { "type": "string" },
"children": {
"type": "array",
"items": { "${'$'}ref": "#" }
},
"attributes": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"value": { "type": "string" }
},
"required": ["name","value"],
"additionalProperties": false
}
}
},
"required": ["type","label","children","attributes"],
"additionalProperties": false
}""".trimIndent();
try {
val response = generateStructuredOutputs(
providerName = providerName,
providerConfig = providerConfig,
systemPrompt = systemPrompt,
messages = messages,
schema = schema
)
println("\nUI Extraction Output:\n${response}")
} catch (e: ProviderException) {
println("\nUI Extraction failed:\n${e.message}")
}
}

File diff suppressed because it is too large

3001
bindings/python/goose_llm.py Normal file

File diff suppressed because it is too large

133
bindings/python/usage.py Normal file
View File

@@ -0,0 +1,133 @@
import asyncio
import os
import time
from goose_llm import (
Message, MessageContent, TextContent, ToolRequest, ToolResponse,
Role, ModelConfig, ToolApprovalMode,
create_tool_config, ExtensionConfig,
generate_session_name, generate_tooltip,
create_completion_request, completion
)
async def main():
now = int(time.time())
# 1) User sends a plain-text prompt
messages = [
Message(
role=Role.USER,
created=now,
content=[MessageContent.TEXT(TextContent(text="What is 7 x 6?"))]
),
# 2) Assistant makes a tool request
Message(
role=Role.ASSISTANT,
created=now + 2,
content=[MessageContent.TOOL_REQ(ToolRequest(
id="calc1",
tool_call="""
{
"status": "success",
"value": {
"name": "calculator_extension__toolname",
"arguments": {
"operation": "multiply",
"numbers": [7, 6]
},
"needsApproval": false
}
}
"""
))]
),
# 3) User sends tool result
Message(
role=Role.USER,
created=now + 3,
content=[MessageContent.TOOL_RESP(ToolResponse(
id="calc1",
tool_result="""
{
"status": "success",
"value": [
{"type": "text", "text": "42"}
]
}
"""
))]
)
]
provider_name = "databricks"
provider_config = f'''{{
"host": "{os.environ.get("DATABRICKS_HOST")}",
"token": "{os.environ.get("DATABRICKS_TOKEN")}"
}}'''
print(f"Provider Name: {provider_name}")
print(f"Provider Config: {provider_config}")
session_name = await generate_session_name(provider_name, provider_config, messages)
print(f"\nSession Name: {session_name}")
tooltip = await generate_tooltip(provider_name, provider_config, messages)
print(f"\nTooltip: {tooltip}")
model_config = ModelConfig(
model_name="goose-gpt-4-1",
max_tokens=500,
temperature=0.1,
context_limit=4096,
)
calculator_tool = create_tool_config(
name="calculator",
description="Perform basic arithmetic operations",
input_schema="""
{
"type": "object",
"required": ["operation", "numbers"],
"properties": {
"operation": {
"type": "string",
"enum": ["add", "subtract", "multiply", "divide"],
"description": "The arithmetic operation to perform"
},
"numbers": {
"type": "array",
"items": { "type": "number" },
"description": "List of numbers to operate on in order"
}
}
}
""",
approval_mode=ToolApprovalMode.AUTO
)
calculator_extension = ExtensionConfig(
name="calculator_extension",
instructions="This extension provides a calculator tool.",
tools=[calculator_tool]
)
system_preamble = "You are a helpful assistant."
extensions = [calculator_extension]
req = create_completion_request(
provider_name,
provider_config,
model_config,
system_preamble,
messages,
extensions
)
resp = await completion(req)
print(f"\nCompletion Response:\n{resp.message}")
print(f"Msg content: {resp.message.content[0][0]}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -25,6 +25,7 @@ include_dir = "0.7.4"
once_cell = "1.19"
regex = "1.11.1"
toml = "0.8.20"
dotenvy = "0.15.7"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -0,0 +1,273 @@
# Goose Benchmarking Framework
The `goose-bench` crate provides a framework for benchmarking and evaluating LLM models with Goose. It helps quantify model performance across various tasks and generate structured reports.
## Features
- Run benchmark suites across multiple LLM models
- Execute evaluations in parallel when supported
- Generate structured JSON and CSV reports
- Process evaluation results with custom scripts
- Calculate aggregate metrics across evaluations
- Support for tool-shim evaluation
- Generate leaderboards and comparative metrics
## Prerequisites
- **Python Environment**: The `generate-leaderboard` command executes Python scripts and requires a valid Python environment with necessary dependencies (pandas, etc.)
- **OpenAI API Key**: For evaluations using LLM-as-judge (like `blog_summary` and `restaurant_research`), you must have an `OPENAI_API_KEY` environment variable set, as the judge uses the OpenAI GPT-4o model
## Benchmark Workflow
Running benchmarks is a two-step process:
### Step 1: Run Benchmarks
First, run the benchmark evaluations with your configuration:
```bash
goose bench run --config /path/to/your-config.json
```
This will execute all evaluations for all models specified in your configuration and create a benchmark directory with results.
### Step 2: Generate Leaderboard
After the benchmarks complete, generate the leaderboard and aggregated metrics:
```bash
goose bench generate-leaderboard --benchmark-dir /path/to/benchmark-output-directory
```
The benchmark directory path will be shown in the output of the previous command, typically in the format `benchmark-YYYY-MM-DD-HH:MM:SS`.
**Note**: This command requires a valid Python environment as it executes Python scripts for data aggregation and leaderboard generation.
## Configuration
Benchmark configuration is provided through a JSON file. Here's a sample configuration file (leaderboard-config.json) that you can use as a template:
```json
{
"models": [
{
"provider": "databricks",
"name": "gpt-4-1-mini",
"parallel_safe": true,
"tool_shim": {
"use_tool_shim": false,
"tool_shim_model": null
}
},
{
"provider": "databricks",
"name": "claude-3-5-sonnet",
"parallel_safe": true,
"tool_shim": null
},
{
"provider": "databricks",
"name": "gpt-4o",
"parallel_safe": true,
"tool_shim": null
}
],
"evals": [
{
"selector": "core:developer",
"post_process_cmd": null,
"parallel_safe": true
},
{
"selector": "core:developer_search_replace",
"post_process_cmd": null,
"parallel_safe": true
},
{
"selector": "vibes:blog_summary",
"post_process_cmd": "/Users/ahau/Development/goose-1.0/goose/scripts/bench-postprocess-scripts/llm-judges/run_vibes_judge.sh",
"parallel_safe": true
},
{
"selector": "vibes:restaurant_research",
"post_process_cmd": "/Users/ahau/Development/goose-1.0/goose/scripts/bench-postprocess-scripts/llm-judges/run_vibes_judge.sh",
"parallel_safe": true
}
],
"include_dirs": [],
"repeat": 3,
"run_id": null,
"output_dir": "/path/to/output/directory",
"eval_result_filename": "eval-results.json",
"run_summary_filename": "run-results-summary.json",
"env_file": "/path/to/.goosebench.env"
}
```
## Configuration Options
### Models
- `provider`: The LLM provider (e.g., "databricks", "openai")
- `name`: The model name
- `parallel_safe`: Whether the model can be run in parallel
- `tool_shim`: Configuration for tool-shim support
- `use_tool_shim`: Whether to use tool-shim
- `tool_shim_model`: Optional custom model for tool-shim
### Evaluations
- `selector`: The evaluation selector in format `suite:evaluation`
- `post_process_cmd`: Optional path to a post-processing script
- `parallel_safe`: Whether the evaluation can be run in parallel
### Global Configuration
- `include_dirs`: Additional directories to include in the benchmark environment
- `repeat`: Number of times to repeat evaluations (for statistical significance)
- `run_id`: Optional identifier for the run (defaults to timestamp)
- `output_dir`: Directory to store benchmark results (must be absolute path)
- `eval_result_filename`: Filename for individual evaluation results
- `run_summary_filename`: Filename for run summary
- `env_file`: Optional path to environment variables file
## Environment Variables
You can provide environment variables through the `env_file` configuration option. This is useful for provider API keys and other sensitive information. Example `.goosebench.env` file:
```bash
OPENAI_API_KEY=your_openai_api_key_here
DATABRICKS_TOKEN=your_databricks_token_here
# Add other environment variables as needed
```
**Important**: For evaluations that use LLM-as-judge (like `blog_summary` and `restaurant_research`), you must set `OPENAI_API_KEY` as the judging system uses OpenAI's GPT-4o model.
## Post-Processing
You can specify post-processing commands for evaluations, which will be executed after each evaluation completes. The command receives the path to the evaluation results file as its first argument.
For example, the `run_vibes_judge.sh` script processes outputs from the `blog_summary` and `restaurant_research` evaluations, using LLM-based judging to assign scores.
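As a concrete illustration, a post-process command can be any executable, for example a small Python script like the sketch below. Only the documented contract is relied upon (the results file path arrives as the first argument); the JSON layout it manipulates and the `custom_judge_score` metric are assumptions for illustration, not the actual `eval-results.json` schema.
```python
#!/usr/bin/env python3
# Sketch of a post-process command: goose-bench invokes it with the path to the
# evaluation results file as the first argument. The "metrics" layout and the
# custom_judge_score entry below are hypothetical, shown only to illustrate the flow.
import json
import sys

def main() -> None:
    results_path = sys.argv[1]
    with open(results_path) as f:
        results = json.load(f)
    # Inspect or enrich the results, e.g. attach a custom judged score.
    results.setdefault("metrics", []).append(
        {"name": "custom_judge_score", "value": 1.0}  # hypothetical metric
    )
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)

if __name__ == "__main__":
    main()
```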
## Output Structure
Results are organized in a directory structure that follows this pattern:
```
{benchmark_dir}/
├── config.cfg # Configuration used for the benchmark
├── {provider}-{model}/
│ ├── eval-results/
│ │ └── aggregate_metrics.csv # Aggregated metrics for this model
│ └── run-{run_id}/
│ ├── {suite}/
│ │ └── {evaluation}/
│ │ ├── eval-results.json # Individual evaluation results
│ │ ├── {eval_name}.jsonl # Session logs
│ │ └── work_dir.json # Info about evaluation working dir
│ └── run-results-summary.json # Summary of all evaluations in this run
├── leaderboard.csv # Final leaderboard comparing all models
└── all_metrics.csv # Union of all metrics across all models
```
### Output Files Explained
#### Per-Model Files
- **`eval-results/aggregate_metrics.csv`**: Contains aggregated metrics for each evaluation, averaged across all runs. Includes metrics like `score_mean`, `total_tokens_mean`, `prompt_execution_time_seconds_mean`, etc.
#### Global Output Files
- **`leaderboard.csv`**: Final leaderboard ranking all models by their average performance across evaluations. Contains columns like:
- `provider`, `model_name`: Model identification
- `avg_score_mean`: Average score across all evaluations
- `avg_prompt_execution_time_seconds_mean`: Average execution time
- `avg_total_tool_calls_mean`: Average number of tool calls
- `avg_total_tokens_mean`: Average token usage
- **`all_metrics.csv`**: Comprehensive dataset containing detailed metrics for every model-evaluation combination. This is a union of all individual model metrics, useful for detailed analysis and custom reporting.
Each model gets its own directory, containing run results and aggregated CSV files for analysis. The `generate-leaderboard` command processes all individual evaluation results and creates the comparative metrics files.
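For quick ad-hoc analysis, the generated CSVs can be loaded directly with pandas. The snippet below is a minimal sketch, assuming pandas is installed and that `leaderboard.csv` uses the columns listed above.
```python
# Sketch only: load the generated leaderboard and print the top models by score.
import pandas as pd

leaderboard = pd.read_csv("leaderboard.csv")
top = leaderboard.sort_values("avg_score_mean", ascending=False)
# Show model identification, average score, and average token usage.
print(top[["provider", "model_name", "avg_score_mean", "avg_total_tokens_mean"]].head())
```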
## Error Handling and Troubleshooting
**Important**: The current version of goose-bench does not have robust error handling for common issues that can occur during evaluation runs, such as:
- Rate limiting from inference providers
- Network timeouts or connection errors
- Provider API errors that cause early session termination
- Resource exhaustion or memory issues
### Checking for Failed Evaluations
After running benchmarks, you should inspect the generated metrics files to identify any evaluations that may have failed or terminated early:
1. **Check the `aggregate_metrics.csv` files** in each model's `eval-results/` directory for:
- Missing evaluations (fewer rows than expected)
- Unusually low scores or metrics
- Zero or near-zero execution times
- Missing or NaN values
2. **Look for `server_error_mean` column** in the aggregate metrics - values greater than 0 indicate server errors occurred during evaluation
3. **Review session logs** (`.jsonl` files) in individual evaluation directories for error messages like:
- "Server error"
- "Rate limit exceeded"
- "TEMPORARILY_UNAVAILABLE"
- Unexpected session terminations
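Checks 1 and 2 above can be partially automated. The following is a rough sketch, assuming pandas is installed and the directory layout described in the Output Structure section; it flags server errors, missing values, and suspicious execution times in each model's `aggregate_metrics.csv`.
```python
# Sketch only: scan every aggregate_metrics.csv under a benchmark directory for
# signs of failed or truncated evaluations.
import sys
from pathlib import Path

import pandas as pd

def check_benchmark_dir(benchmark_dir: str) -> None:
    for csv_path in Path(benchmark_dir).glob("*/eval-results/aggregate_metrics.csv"):
        df = pd.read_csv(csv_path)
        problems = []
        if "server_error_mean" in df.columns and (df["server_error_mean"] > 0).any():
            problems.append("server errors reported")
        if df.isna().any().any():
            problems.append("missing/NaN values")
        if "prompt_execution_time_seconds_mean" in df.columns and (
            df["prompt_execution_time_seconds_mean"] <= 0
        ).any():
            problems.append("zero or negative execution times")
        status = "; ".join(problems) if problems else "looks ok"
        print(f"{csv_path}: {status}")

if __name__ == "__main__":
    check_benchmark_dir(sys.argv[1])
```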
### Re-running Failed Evaluations
If you identify failed evaluations, you may need to:
1. **Adjust rate limiting**: Add delays between requests or reduce parallel execution
2. **Update environment variables**: Ensure API keys and tokens are valid
3. **Re-run specific model/evaluation combinations**: Create a new config with only the failed combinations
4. **Check provider status**: Verify the inference provider is operational
Example of creating a config to re-run failed evaluations:
```json
{
"models": [
{
"provider": "databricks",
"name": "claude-3-5-sonnet",
"parallel_safe": false
}
],
"evals": [
{
"selector": "vibes:blog_summary",
"post_process_cmd": "/path/to/scripts/bench-postprocess-scripts/llm-judges/run_vibes_judge.sh",
"parallel_safe": false
}
],
"repeat": 1,
"output_dir": "/path/to/retry-benchmark"
}
```
We recommend monitoring evaluation progress and checking for errors regularly, especially when running large benchmark suites across multiple models.
## Available Commands
### List Evaluations
```bash
goose bench selectors --config /path/to/config.json
```
### Generate Initial Config
```bash
goose bench init-config --name my-benchmark-config.json
```
### Run Benchmarks
```bash
goose bench run --config /path/to/config.json
```
### Generate Leaderboard
```bash
goose bench generate-leaderboard --benchmark-dir /path/to/benchmark-output
```

View File

@@ -30,6 +30,7 @@ pub struct BenchRunConfig {
pub include_dirs: Vec<PathBuf>,
pub repeat: Option<usize>,
pub run_id: Option<String>,
pub output_dir: Option<PathBuf>,
pub eval_result_filename: String,
pub run_summary_filename: String,
pub env_file: Option<PathBuf>,
@@ -63,6 +64,7 @@ impl Default for BenchRunConfig {
include_dirs: vec![],
repeat: Some(2),
run_id: None,
output_dir: None,
eval_result_filename: "eval-results.json".to_string(),
run_summary_filename: "run-results-summary.json".to_string(),
env_file: None,

View File

@@ -1,9 +1,9 @@
use anyhow::Context;
use chrono::Local;
use include_dir::{include_dir, Dir};
use serde::{Deserialize, Serialize};
use std::fs;
use std::io;
use std::io::ErrorKind;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
@@ -53,15 +53,35 @@ impl BenchmarkWorkDir {
}
}
pub fn init_experiment() {
pub fn init_experiment(output_dir: PathBuf) -> anyhow::Result<()> {
if !output_dir.is_absolute() {
anyhow::bail!(
"Internal Error: init_experiment received a non-absolute path: {}",
output_dir.display()
);
}
// create experiment folder
let current_time = Local::now().format("%H:%M:%S").to_string();
let current_date = Local::now().format("%Y-%m-%d").to_string();
let exp_name = format!("{}-{}", &current_date, current_time);
let base_path = PathBuf::from(format!("./benchmark-{}", exp_name));
fs::create_dir_all(&base_path).unwrap();
std::env::set_current_dir(&base_path).unwrap();
let exp_folder_name = format!("benchmark-{}-{}", &current_date, &current_time);
let base_path = output_dir.join(exp_folder_name);
fs::create_dir_all(&base_path).with_context(|| {
format!(
"Failed to create benchmark directory: {}",
base_path.display()
)
})?;
std::env::set_current_dir(&base_path).with_context(|| {
format!(
"Failed to change working directory to: {}",
base_path.display()
)
})?;
Ok(())
}
pub fn canonical_dirs(include_dirs: Vec<PathBuf>) -> Vec<PathBuf> {
include_dirs
.iter()
@@ -186,7 +206,7 @@ impl BenchmarkWorkDir {
Ok(())
} else {
let error_message = String::from_utf8_lossy(&output.stderr).to_string();
Err(io::Error::new(ErrorKind::Other, error_message))
Err(io::Error::other(error_message))
}
}

View File

@@ -3,17 +3,29 @@ use crate::bench_work_dir::BenchmarkWorkDir;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;
pub type Model = (String, String);
pub type Extension = String;
#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub enum EvalMetricValue {
Integer(i64),
Float(f64),
String(String),
Boolean(bool),
}
impl fmt::Display for EvalMetricValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
EvalMetricValue::Integer(i) => write!(f, "{}", i),
EvalMetricValue::Float(fl) => write!(f, "{:.2}", fl),
EvalMetricValue::String(s) => write!(f, "{}", s),
EvalMetricValue::Boolean(b) => write!(f, "{}", b),
}
}
}
#[derive(Debug, Serialize)]
pub struct EvalMetric {
pub name: String,

View File

@@ -98,17 +98,6 @@ impl BenchmarkResults {
}
}
impl fmt::Display for EvalMetricValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
EvalMetricValue::Integer(i) => write!(f, "{}", i),
EvalMetricValue::Float(fl) => write!(f, "{:.2}", fl),
EvalMetricValue::String(s) => write!(f, "{}", s),
EvalMetricValue::Boolean(b) => write!(f, "{}", b),
}
}
}
impl fmt::Display for BenchmarkResults {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Benchmark Results")?;

View File

@@ -3,6 +3,7 @@ use crate::bench_work_dir::BenchmarkWorkDir;
use crate::eval_suites::EvaluationSuite;
use crate::runners::model_runner::ModelRunner;
use crate::utilities::{await_process_exits, parallel_bench_cmd};
use anyhow::Context;
use std::path::PathBuf;
#[derive(Clone)]
@@ -11,9 +12,27 @@ pub struct BenchRunner {
}
impl BenchRunner {
pub fn new(config: PathBuf) -> anyhow::Result<BenchRunner> {
let config = BenchRunConfig::from(config)?;
BenchmarkWorkDir::init_experiment();
pub fn new(config_path: PathBuf) -> anyhow::Result<BenchRunner> {
let config = BenchRunConfig::from(config_path.clone())?;
let resolved_output_dir = match &config.output_dir {
Some(path) => {
if !path.is_absolute() {
anyhow::bail!(
"Config Error in '{}': 'output_dir' must be an absolute path, but found relative path: {}",
config_path.display(),
path.display()
);
}
path.clone()
}
None => std::env::current_dir().context(
"Failed to get current working directory to use as default output directory",
)?,
};
BenchmarkWorkDir::init_experiment(resolved_output_dir)?;
config.save("config.cfg".to_string());
Ok(BenchRunner { config })
}

View File

@@ -4,12 +4,14 @@ use crate::bench_work_dir::BenchmarkWorkDir;
use crate::eval_suites::{EvaluationSuite, ExtensionRequirements};
use crate::reporting::EvaluationResult;
use crate::utilities::await_process_exits;
use anyhow::{bail, Context, Result};
use std::env;
use std::fs;
use std::future::Future;
use std::path::PathBuf;
use std::process::Command;
use std::time::{SystemTime, UNIX_EPOCH};
use tracing;
#[derive(Clone)]
pub struct EvalRunner {
@@ -17,13 +19,17 @@ pub struct EvalRunner {
}
impl EvalRunner {
pub fn from(config: String) -> anyhow::Result<EvalRunner> {
let config = BenchRunConfig::from_string(config)?;
pub fn from(config: String) -> Result<EvalRunner> {
let config = BenchRunConfig::from_string(config)
.context("Failed to parse evaluation configuration")?;
Ok(EvalRunner { config })
}
fn create_work_dir(&self, config: &BenchRunConfig) -> anyhow::Result<BenchmarkWorkDir> {
let goose_model = config.models.first().unwrap();
fn create_work_dir(&self, config: &BenchRunConfig) -> Result<BenchmarkWorkDir> {
let goose_model = config
.models
.first()
.context("No model specified in configuration")?;
let model_name = goose_model.name.clone();
let provider_name = goose_model.provider.clone();
@@ -48,13 +54,21 @@ impl EvalRunner {
let work_dir = BenchmarkWorkDir::new(work_dir_name, include_dir);
Ok(work_dir)
}
pub async fn run<F, Fut>(&mut self, agent_generator: F) -> anyhow::Result<()>
pub async fn run<F, Fut>(&mut self, agent_generator: F) -> Result<()>
where
F: Fn(ExtensionRequirements, String) -> Fut,
Fut: Future<Output = BenchAgent> + Send,
{
let mut work_dir = self.create_work_dir(&self.config)?;
let bench_eval = self.config.evals.first().unwrap();
let mut work_dir = self
.create_work_dir(&self.config)
.context("Failed to create evaluation work directory")?;
let bench_eval = self
.config
.evals
.first()
.context("No evaluations specified in configuration")?;
let run_id = &self
.config
@@ -65,41 +79,89 @@ impl EvalRunner {
// create entire dir subtree for eval and cd into dir for running eval
work_dir.set_eval(&bench_eval.selector, run_id);
tracing::info!("Set evaluation directory for {}", bench_eval.selector);
if let Some(eval) = EvaluationSuite::from(&bench_eval.selector) {
let now_stamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_nanos();
let now_stamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.context("Failed to get current timestamp")?
.as_nanos();
let session_id = format!("{}-{}", bench_eval.selector.clone(), now_stamp);
let mut agent = agent_generator(eval.required_extensions(), session_id).await;
tracing::info!("Agent created for {}", eval.name());
let mut result = EvaluationResult::new(eval.name().to_string());
if let Ok(metrics) = eval.run(&mut agent, &mut work_dir).await {
for (name, metric) in metrics {
result.add_metric(name, metric);
match eval.run(&mut agent, &mut work_dir).await {
Ok(metrics) => {
tracing::info!("Evaluation run successful with {} metrics", metrics.len());
for (name, metric) in metrics {
result.add_metric(name, metric);
}
}
// Add any errors that occurred
for error in agent.get_errors().await {
result.add_error(error);
Err(e) => {
tracing::error!("Evaluation run failed: {}", e);
}
}
let eval_results = serde_json::to_string_pretty(&result)?;
// Add any errors that occurred
let errors = agent.get_errors().await;
tracing::info!("Agent reported {} errors", errors.len());
for error in errors {
result.add_error(error);
}
// Write results to file
let eval_results = serde_json::to_string_pretty(&result)
.context("Failed to serialize evaluation results to JSON")?;
let eval_results_file = env::current_dir()
.context("Failed to get current directory")?
.join(&self.config.eval_result_filename);
fs::write(&eval_results_file, &eval_results).with_context(|| {
format!(
"Failed to write evaluation results to {}",
eval_results_file.display()
)
})?;
tracing::info!(
"Wrote evaluation results to {}",
eval_results_file.display()
);
let eval_results_file = env::current_dir()?.join(&self.config.eval_result_filename);
fs::write(&eval_results_file, &eval_results)?;
self.config.save("config.cfg".to_string());
work_dir.save();
// handle running post-process cmd if configured
if let Some(cmd) = &bench_eval.post_process_cmd {
let handle = Command::new(cmd).arg(&eval_results_file).spawn()?;
tracing::info!("Running post-process command: {:?}", cmd);
let handle = Command::new(cmd)
.arg(&eval_results_file)
.spawn()
.with_context(|| {
format!("Failed to execute post-process command: {:?}", cmd)
})?;
await_process_exits(&mut [handle], Vec::new());
}
// copy session file into eval-dir
let here = env::current_dir()?.canonicalize()?;
BenchmarkWorkDir::deep_copy(agent.session_file().as_path(), here.as_path(), false)?;
let here = env::current_dir()
.context("Failed to get current directory")?
.canonicalize()
.context("Failed to canonicalize current directory path")?;
BenchmarkWorkDir::deep_copy(agent.session_file().as_path(), here.as_path(), false)
.context("Failed to copy session file to evaluation directory")?;
tracing::info!("Evaluation completed successfully");
} else {
tracing::error!("No evaluation found for selector: {}", bench_eval.selector);
bail!("No evaluation found for selector: {}", bench_eval.selector);
}
Ok(())

View File

@@ -0,0 +1,81 @@
use anyhow::{bail, ensure, Context, Result};
use std::path::PathBuf;
use tracing;
pub struct MetricAggregator;
impl MetricAggregator {
/// Generate leaderboard and aggregated metrics CSV files from benchmark directory
pub fn generate_csv_from_benchmark_dir(benchmark_dir: &PathBuf) -> Result<()> {
use std::process::Command;
// Step 1: Run prepare_aggregate_metrics.py to create aggregate_metrics.csv files
let prepare_script_path = std::env::current_dir()
.context("Failed to get current working directory")?
.join("scripts")
.join("bench-postprocess-scripts")
.join("prepare_aggregate_metrics.py");
ensure!(
prepare_script_path.exists(),
"Prepare script not found: {}",
prepare_script_path.display()
);
tracing::info!(
"Preparing aggregate metrics from benchmark directory: {}",
benchmark_dir.display()
);
let output = Command::new(&prepare_script_path)
.arg("--benchmark-dir")
.arg(benchmark_dir)
.output()
.context("Failed to execute prepare_aggregate_metrics.py script")?;
if !output.status.success() {
let error_message = String::from_utf8_lossy(&output.stderr);
bail!("Failed to prepare aggregate metrics: {}", error_message);
}
let success_message = String::from_utf8_lossy(&output.stdout);
tracing::info!("{}", success_message);
// Step 2: Run generate_leaderboard.py to create the final leaderboard
let leaderboard_script_path = std::env::current_dir()
.context("Failed to get current working directory")?
.join("scripts")
.join("bench-postprocess-scripts")
.join("generate_leaderboard.py");
ensure!(
leaderboard_script_path.exists(),
"Leaderboard script not found: {}",
leaderboard_script_path.display()
);
tracing::info!(
"Generating leaderboard from benchmark directory: {}",
benchmark_dir.display()
);
let output = Command::new(&leaderboard_script_path)
.arg("--benchmark-dir")
.arg(benchmark_dir)
.arg("--leaderboard-output")
.arg("leaderboard.csv")
.arg("--union-output")
.arg("all_metrics.csv")
.output()
.context("Failed to execute generate_leaderboard.py script")?;
if !output.status.success() {
let error_message = String::from_utf8_lossy(&output.stderr);
bail!("Failed to generate leaderboard: {}", error_message);
}
let success_message = String::from_utf8_lossy(&output.stdout);
tracing::info!("{}", success_message);
Ok(())
}
}

View File

@@ -1,3 +1,4 @@
pub mod bench_runner;
pub mod eval_runner;
pub mod metric_aggregator;
pub mod model_runner;

View File

@@ -3,12 +3,14 @@ use crate::eval_suites::EvaluationSuite;
use crate::reporting::{BenchmarkResults, SuiteResult};
use crate::runners::eval_runner::EvalRunner;
use crate::utilities::{await_process_exits, parallel_bench_cmd};
use anyhow::{Context, Result};
use dotenvy::from_path_iter;
use std::collections::HashMap;
use std::fs::read_to_string;
use std::io::{self, BufRead};
use std::path::PathBuf;
use std::process::Child;
use std::thread;
use tracing;
#[derive(Clone)]
pub struct ModelRunner {
@@ -16,23 +18,27 @@ pub struct ModelRunner {
}
impl ModelRunner {
pub fn from(config: String) -> anyhow::Result<ModelRunner> {
let config = BenchRunConfig::from_string(config)?;
pub fn from(config: String) -> Result<ModelRunner> {
let config =
BenchRunConfig::from_string(config).context("Failed to parse configuration")?;
Ok(ModelRunner { config })
}
pub fn run(&self) -> anyhow::Result<()> {
let model = self.config.models.first().unwrap();
pub fn run(&self) -> Result<()> {
let model = self
.config
.models
.first()
.context("No model specified in config")?;
let suites = self.collect_evals_for_run();
let mut handles = vec![];
for i in 0..self.config.repeat.unwrap_or(1) {
let mut self_copy = self.clone();
let self_copy = self.clone();
let model_clone = model.clone();
let suites_clone = suites.clone();
// create thread to handle launching parallel processes to run model's evals in parallel
let handle = thread::spawn(move || {
let handle = thread::spawn(move || -> Result<()> {
self_copy.run_benchmark(&model_clone, suites_clone, i.to_string())
});
handles.push(handle);
@@ -41,55 +47,32 @@ impl ModelRunner {
let mut all_runs_results: Vec<BenchmarkResults> = Vec::new();
for i in 0..self.config.repeat.unwrap_or(1) {
let run_results =
self.collect_run_results(model.clone(), suites.clone(), i.to_string())?;
all_runs_results.push(run_results);
match self.collect_run_results(model.clone(), suites.clone(), i.to_string()) {
Ok(run_results) => all_runs_results.push(run_results),
Err(e) => {
tracing::error!("Failed to collect results for run {}: {}", i, e)
}
}
}
// write summary file
Ok(())
}
fn load_env_file(&self, path: &PathBuf) -> anyhow::Result<Vec<(String, String)>> {
let file = std::fs::File::open(path)?;
let reader = io::BufReader::new(file);
let mut env_vars = Vec::new();
for line in reader.lines() {
let line = line?;
// Skip empty lines and comments
if line.trim().is_empty() || line.trim_start().starts_with('#') {
continue;
}
// Split on first '=' only
if let Some((key, value)) = line.split_once('=') {
let key = key.trim().to_string();
// Remove quotes if present
let value = value
.trim()
.trim_matches('"')
.trim_matches('\'')
.to_string();
env_vars.push((key, value));
}
}
Ok(env_vars)
}
fn run_benchmark(
&mut self,
&self,
model: &BenchModel,
suites: HashMap<String, Vec<BenchEval>>,
run_id: String,
) -> anyhow::Result<()> {
) -> Result<()> {
let mut results_handles = HashMap::<String, Vec<Child>>::new();
// Load environment variables from file if specified
let mut envs = self.toolshim_envs();
if let Some(env_file) = &self.config.env_file {
let env_vars = self.load_env_file(env_file)?;
let env_vars = ModelRunner::load_env_file(env_file).context(format!(
"Failed to load environment file: {}",
env_file.display()
))?;
envs.extend(env_vars);
}
envs.push(("GOOSE_MODEL".to_string(), model.clone().name));
@@ -116,9 +99,13 @@ impl ModelRunner {
// Run parallel-safe evaluations in parallel
if !parallel_evals.is_empty() {
for eval_selector in &parallel_evals {
self.config.run_id = Some(run_id.clone());
self.config.evals = vec![(*eval_selector).clone()];
let cfg = self.config.to_string()?;
let mut config_copy = self.config.clone();
config_copy.run_id = Some(run_id.clone());
config_copy.evals = vec![(*eval_selector).clone()];
let cfg = config_copy
.to_string()
.context("Failed to serialize configuration")?;
let handle = parallel_bench_cmd("exec-eval".to_string(), cfg, envs.clone());
results_handles.get_mut(suite).unwrap().push(handle);
}
@@ -126,9 +113,13 @@ impl ModelRunner {
// Run non-parallel-safe evaluations sequentially
for eval_selector in &sequential_evals {
self.config.run_id = Some(run_id.clone());
self.config.evals = vec![(*eval_selector).clone()];
let cfg = self.config.to_string()?;
let mut config_copy = self.config.clone();
config_copy.run_id = Some(run_id.clone());
config_copy.evals = vec![(*eval_selector).clone()];
let cfg = config_copy
.to_string()
.context("Failed to serialize configuration")?;
let handle = parallel_bench_cmd("exec-eval".to_string(), cfg, envs.clone());
// Wait for this process to complete before starting the next one
@@ -150,7 +141,7 @@ impl ModelRunner {
model: BenchModel,
suites: HashMap<String, Vec<BenchEval>>,
run_id: String,
) -> anyhow::Result<BenchmarkResults> {
) -> Result<BenchmarkResults> {
let mut results = BenchmarkResults::new(model.provider.clone());
let mut summary_path: Option<PathBuf> = None;
@@ -161,7 +152,17 @@ impl ModelRunner {
let mut eval_path =
EvalRunner::path_for_eval(&model, eval_selector, run_id.clone());
eval_path.push(self.config.eval_result_filename.clone());
let eval_result = serde_json::from_str(&read_to_string(&eval_path)?)?;
let content = read_to_string(&eval_path).with_context(|| {
format!(
"Failed to read evaluation results from {}",
eval_path.display()
)
})?;
let eval_result = serde_json::from_str(&content)
.context("Failed to parse evaluation results JSON")?;
suite_result.add_evaluation(eval_result);
// use current eval to determine where the summary should be written
@@ -180,12 +181,21 @@ impl ModelRunner {
results.add_suite(suite_result);
}
let mut run_summary = PathBuf::new();
run_summary.push(summary_path.clone().unwrap());
run_summary.push(&self.config.run_summary_filename);
if let Some(path) = summary_path {
let mut run_summary = PathBuf::new();
run_summary.push(path);
run_summary.push(&self.config.run_summary_filename);
let output_str = serde_json::to_string_pretty(&results)?;
std::fs::write(run_summary, &output_str)?;
let output_str = serde_json::to_string_pretty(&results)
.context("Failed to serialize benchmark results to JSON")?;
std::fs::write(&run_summary, &output_str).with_context(|| {
format!(
"Failed to write results summary to {}",
run_summary.display()
)
})?;
}
Ok(results)
}
@@ -210,20 +220,29 @@ impl ModelRunner {
fn toolshim_envs(&self) -> Vec<(String, String)> {
// read tool-shim preference from config, set respective env vars accordingly
let model = self.config.models.first().unwrap();
let mut shim_envs: Vec<(String, String)> = Vec::new();
if let Some(shim_opt) = &model.tool_shim {
if shim_opt.use_tool_shim {
shim_envs.push(("GOOSE_TOOLSHIM".to_string(), "true".to_string()));
if let Some(shim_model) = &shim_opt.tool_shim_model {
shim_envs.push((
"GOOSE_TOOLSHIM_OLLAMA_MODEL".to_string(),
shim_model.clone(),
));
if let Some(model) = self.config.models.first() {
if let Some(shim_opt) = &model.tool_shim {
if shim_opt.use_tool_shim {
shim_envs.push(("GOOSE_TOOLSHIM".to_string(), "true".to_string()));
if let Some(shim_model) = &shim_opt.tool_shim_model {
shim_envs.push((
"GOOSE_TOOLSHIM_OLLAMA_MODEL".to_string(),
shim_model.clone(),
));
}
}
}
}
shim_envs
}
fn load_env_file(path: &PathBuf) -> Result<Vec<(String, String)>> {
let iter =
from_path_iter(path).context("Failed to read environment variables from file")?;
let env_vars = iter
.map(|item| item.context("Failed to parse environment variable"))
.collect::<Result<_, _>>()?;
Ok(env_vars)
}
}

View File

@@ -1,15 +1,14 @@
use anyhow::Result;
use std::env;
use std::process::{Child, Command};
use std::thread::JoinHandle;
use tracing;
pub fn await_process_exits(
child_processes: &mut [Child],
handles: Vec<JoinHandle<anyhow::Result<()>>>,
) {
pub fn await_process_exits(child_processes: &mut [Child], handles: Vec<JoinHandle<Result<()>>>) {
for child in child_processes.iter_mut() {
match child.wait() {
Ok(status) => println!("Child exited with status: {}", status),
Err(e) => println!("Error waiting for child: {}", e),
Ok(status) => tracing::info!("Child exited with status: {}", status),
Err(e) => tracing::error!("Error waiting for child: {}", e),
}
}
@@ -18,7 +17,7 @@ pub fn await_process_exits(
Ok(_res) => (),
Err(e) => {
// Handle thread panic
println!("Thread panicked: {:?}", e);
tracing::error!("Thread panicked: {:?}", e);
}
}
}

View File

@@ -52,6 +52,9 @@ shlex = "1.3.0"
async-trait = "0.1.86"
base64 = "0.22.1"
regex = "1.11.1"
minijinja = "2.8.0"
nix = { version = "0.30.1", features = ["process", "signal"] }
tar = "0.4"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -7,15 +7,22 @@ use crate::commands::bench::agent_generator;
use crate::commands::configure::handle_configure;
use crate::commands::info::handle_info;
use crate::commands::mcp::run_server;
use crate::commands::project::{handle_project_default, handle_projects_interactive};
use crate::commands::recipe::{handle_deeplink, handle_validate};
// Import the new handlers from commands::schedule
use crate::commands::schedule::{
handle_schedule_add, handle_schedule_list, handle_schedule_remove, handle_schedule_run_now,
handle_schedule_sessions,
};
use crate::commands::session::{handle_session_list, handle_session_remove};
use crate::logging::setup_logging;
use crate::recipe::load_recipe;
use crate::recipes::recipe::{explain_recipe_with_parameters, load_recipe_as_template};
use crate::session;
use crate::session::{build_session, SessionBuilderConfig};
use goose_bench::bench_config::BenchRunConfig;
use goose_bench::runners::bench_runner::BenchRunner;
use goose_bench::runners::eval_runner::EvalRunner;
use goose_bench::runners::metric_aggregator::MetricAggregator;
use goose_bench::runners::model_runner::ModelRunner;
use std::io::Read;
use std::path::PathBuf;
@@ -59,6 +66,13 @@ fn extract_identifier(identifier: Identifier) -> session::Identifier {
}
}
fn parse_key_val(s: &str) -> Result<(String, String), String> {
match s.split_once('=') {
Some((key, value)) => Ok((key.to_string(), value.to_string())),
None => Err(format!("invalid KEY=VALUE: {}", s)),
}
}
#[derive(Subcommand)]
enum SessionCommand {
#[command(about = "List all available sessions")]
@@ -81,17 +95,52 @@ enum SessionCommand {
)]
ascending: bool,
},
#[command(about = "Remove sessions")]
#[command(about = "Remove sessions. Runs interactively if no ID or regex is provided.")]
Remove {
#[arg(short, long, help = "session id to be removed", default_value = "")]
#[arg(short, long, help = "Session ID to be removed (optional)")]
id: Option<String>,
#[arg(short, long, help = "Regex for removing matched sessions (optional)")]
regex: Option<String>,
},
}
#[derive(Subcommand, Debug)]
enum SchedulerCommand {
#[command(about = "Add a new scheduled job")]
Add {
#[arg(long, help = "Unique ID for the job")]
id: String,
#[arg(long, help = "Cron string for the schedule (e.g., '0 0 * * * *')")]
cron: String,
#[arg(
short,
long,
help = "regex for removing matched session",
default_value = ""
help = "Recipe source (path to file, or base64 encoded recipe string)"
)]
regex: String,
recipe_source: String,
},
#[command(about = "List all scheduled jobs")]
List {},
#[command(about = "Remove a scheduled job by ID")]
Remove {
#[arg(long, help = "ID of the job to remove")] // Changed from positional to named --id
id: String,
},
/// List sessions created by a specific schedule
#[command(about = "List sessions created by a specific schedule")]
Sessions {
/// ID of the schedule
#[arg(long, help = "ID of the schedule")] // Explicitly make it --id
id: String,
/// Maximum number of sessions to return
#[arg(long, help = "Maximum number of sessions to return")]
limit: Option<u32>,
},
/// Run a scheduled job immediately
#[command(about = "Run a scheduled job immediately")]
RunNow {
/// ID of the schedule to run
#[arg(long, help = "ID of the schedule to run")] // Explicitly make it --id
id: String,
},
}
@@ -134,24 +183,39 @@ pub enum BenchCommand {
#[arg(short, long, help = "A serialized config file for the eval only.")]
config: String,
},
#[command(
name = "generate-leaderboard",
about = "Generate a leaderboard CSV from benchmark results"
)]
GenerateLeaderboard {
#[arg(
short,
long,
help = "Path to the benchmark directory containing model evaluation results"
)]
benchmark_dir: PathBuf,
},
}
#[derive(Subcommand)]
enum RecipeCommand {
/// Validate a recipe file
#[command(about = "Validate a recipe file")]
#[command(about = "Validate a recipe")]
Validate {
/// Path to the recipe file to validate
#[arg(help = "Path to the recipe file to validate")]
file: String,
/// Recipe name to get recipe file to validate
#[arg(help = "recipe name to get recipe file or full path to the recipe file to validate")]
recipe_name: String,
},
/// Generate a deeplink for a recipe file
#[command(about = "Generate a deeplink for a recipe file")]
#[command(about = "Generate a deeplink for a recipe")]
Deeplink {
/// Path to the recipe file
#[arg(help = "Path to the recipe file")]
file: String,
/// Recipe name to get recipe file to generate deeplink
#[arg(
help = "recipe name to get recipe file or full path to the recipe file to generate deeplink"
)]
recipe_name: String,
},
}
@@ -194,6 +258,14 @@ enum Command {
)]
resume: bool,
/// Show message history when resuming
#[arg(
long,
help = "Show previous messages when resuming a session",
requires = "resume"
)]
history: bool,
/// Enable debug output mode
#[arg(
long,
@@ -202,6 +274,15 @@ enum Command {
)]
debug: bool,
/// Maximum number of consecutive identical tool calls allowed
#[arg(
long = "max-tool-repetitions",
value_name = "NUMBER",
help = "Maximum number of consecutive identical tool calls allowed",
long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops."
)]
max_tool_repetitions: Option<u32>,
/// Add stdio extensions with environment variables and commands
#[arg(
long = "with-extension",
@@ -233,6 +314,14 @@ enum Command {
builtins: Vec<String>,
},
/// Open the last project directory
#[command(about = "Open the last project directory", visible_alias = "p")]
Project {},
/// List recent project directories
#[command(about = "List recent project directories", visible_alias = "ps")]
Projects,
/// Execute commands from an instruction file
#[command(about = "Execute commands from an instruction file or stdin")]
Run {
@@ -259,18 +348,28 @@ enum Command {
)]
input_text: Option<String>,
/// Path to recipe.yaml file
/// Recipe name or full path to the recipe file
#[arg(
short = None,
long = "recipe",
value_name = "FILE",
help = "Path to recipe.yaml file",
long_help = "Path to a recipe.yaml file that defines a custom agent configuration",
value_name = "RECIPE_NAME or FULL_PATH_TO_RECIPE_FILE",
help = "Recipe name to get recipe file or the full path of the recipe file (use --explain to see recipe details)",
long_help = "Recipe name to get recipe file or the full path of the recipe file that defines a custom agent configuration. Use --explain to see the recipe's title, description, and parameters.",
conflicts_with = "instructions",
conflicts_with = "input_text"
)]
recipe: Option<String>,
#[arg(
long,
value_name = "KEY=VALUE",
help = "Dynamic parameters (e.g., --params username=alice --params channel_name=goose-channel)",
long_help = "Key-value parameters to pass to the recipe file. Can be specified multiple times.",
action = clap::ArgAction::Append,
value_parser = parse_key_val,
)]
params: Vec<(String, String)>,
/// Continue in interactive mode after processing input
#[arg(
short = 's',
@@ -279,6 +378,31 @@ enum Command {
)]
interactive: bool,
/// Run without storing a session file
#[arg(
long = "no-session",
help = "Run without storing a session file",
long_help = "Execute commands without creating or using a session file. Useful for automated runs.",
conflicts_with_all = ["resume", "name", "path"]
)]
no_session: bool,
/// Show the recipe title, description, and parameters
#[arg(
long = "explain",
help = "Show the recipe title, description, and parameters"
)]
explain: bool,
/// Maximum number of consecutive identical tool calls allowed
#[arg(
long = "max-tool-repetitions",
value_name = "NUMBER",
help = "Maximum number of consecutive identical tool calls allowed",
long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops."
)]
max_tool_repetitions: Option<u32>,
/// Identifier for this run session
#[command(flatten)]
identifier: Option<Identifier>,
@@ -339,6 +463,13 @@ enum Command {
command: RecipeCommand,
},
/// Manage scheduled jobs
#[command(about = "Manage scheduled jobs", visible_alias = "sched")]
Schedule {
#[command(subcommand)]
command: SchedulerCommand,
},
/// Update the Goose CLI version
#[command(about = "Update the goose CLI version")]
Update {
@@ -378,6 +509,11 @@ struct InputConfig {
pub async fn cli() -> Result<()> {
let cli = Cli::parse();
// Track the current directory in projects.json
if let Err(e) = crate::project_tracker::update_project_tracker(None, None) {
eprintln!("Warning: Failed to update project tracker: {}", e);
}
match cli.command {
Some(Command::Configure {}) => {
let _ = handle_configure().await;
@@ -394,7 +530,9 @@ pub async fn cli() -> Result<()> {
command,
identifier,
resume,
history,
debug,
max_tool_repetitions,
extensions,
remote_extensions,
builtins,
@@ -414,26 +552,44 @@ pub async fn cli() -> Result<()> {
}
None => {
// Run session command by default
let mut session = build_session(SessionBuilderConfig {
let mut session: crate::Session = build_session(SessionBuilderConfig {
identifier: identifier.map(extract_identifier),
resume,
no_session: false,
extensions,
remote_extensions,
builtins,
extensions_override: None,
additional_system_prompt: None,
debug,
max_tool_repetitions,
})
.await;
setup_logging(
session.session_file().file_stem().and_then(|s| s.to_str()),
None,
)?;
// Render previous messages if resuming a session and history flag is set
if resume && history {
session.render_message_history();
}
let _ = session.interactive(None).await;
Ok(())
}
};
}
Some(Command::Project {}) => {
// Default behavior: offer to resume the last project
handle_project_default()?;
return Ok(());
}
Some(Command::Projects) => {
// Interactive project selection
handle_projects_interactive()?;
return Ok(());
}
Some(Command::Run {
instructions,
input_text,
@@ -441,13 +597,17 @@ pub async fn cli() -> Result<()> {
interactive,
identifier,
resume,
no_session,
debug,
max_tool_repetitions,
extensions,
remote_extensions,
builtins,
params,
explain,
}) => {
let input_config = match (instructions, input_text, recipe) {
(Some(file), _, _) if file == "-" => {
let input_config = match (instructions, input_text, recipe, explain) {
(Some(file), _, _, _) if file == "-" => {
let mut input = String::new();
std::io::stdin()
.read_to_string(&mut input)
@@ -459,7 +619,7 @@ pub async fn cli() -> Result<()> {
additional_system_prompt: None,
}
}
(Some(file), _, _) => {
(Some(file), _, _, _) => {
let contents = std::fs::read_to_string(&file).unwrap_or_else(|err| {
eprintln!(
"Instruction file not found — did you mean to use goose run --text?\n{}",
@@ -473,23 +633,28 @@ pub async fn cli() -> Result<()> {
additional_system_prompt: None,
}
}
(_, Some(text), _) => InputConfig {
(_, Some(text), _, _) => InputConfig {
contents: Some(text),
extensions_override: None,
additional_system_prompt: None,
},
(_, _, Some(file)) => {
let recipe = load_recipe(&file, true).unwrap_or_else(|err| {
eprintln!("{}: {}", console::style("Error").red().bold(), err);
std::process::exit(1);
});
(_, _, Some(recipe_name), explain) => {
if explain {
explain_recipe_with_parameters(&recipe_name, params)?;
return Ok(());
}
let recipe =
load_recipe_as_template(&recipe_name, params).unwrap_or_else(|err| {
eprintln!("{}: {}", console::style("Error").red().bold(), err);
std::process::exit(1);
});
InputConfig {
contents: recipe.prompt,
extensions_override: recipe.extensions,
additional_system_prompt: Some(recipe.instructions),
additional_system_prompt: recipe.instructions,
}
}
(None, None, None) => {
(None, None, None, _) => {
eprintln!("Error: Must provide either --instructions (-i), --text (-t), or --recipe. Use -i - for stdin.");
std::process::exit(1);
}
@@ -498,12 +663,14 @@ pub async fn cli() -> Result<()> {
let mut session = build_session(SessionBuilderConfig {
identifier: identifier.map(extract_identifier),
resume,
no_session,
extensions,
remote_extensions,
builtins,
extensions_override: input_config.extensions_override,
additional_system_prompt: input_config.additional_system_prompt,
debug,
max_tool_repetitions,
})
.await;
@@ -523,6 +690,32 @@ pub async fn cli() -> Result<()> {
return Ok(());
}
Some(Command::Schedule { command }) => {
match command {
SchedulerCommand::Add {
id,
cron,
recipe_source,
} => {
handle_schedule_add(id, cron, recipe_source).await?;
}
SchedulerCommand::List {} => {
handle_schedule_list().await?;
}
SchedulerCommand::Remove { id } => {
handle_schedule_remove(id).await?;
}
SchedulerCommand::Sessions { id, limit } => {
// List recent sessions created by this schedule
handle_schedule_sessions(id, limit).await?;
}
SchedulerCommand::RunNow { id } => {
// Trigger the scheduled job immediately
handle_schedule_run_now(id).await?;
}
}
return Ok(());
}
Some(Command::Update {
canary,
reconfigure,
@@ -533,22 +726,31 @@ pub async fn cli() -> Result<()> {
Some(Command::Bench { cmd }) => {
match cmd {
BenchCommand::Selectors { config } => BenchRunner::list_selectors(config)?,
BenchCommand::InitConfig { name } => BenchRunConfig::default().save(name),
BenchCommand::InitConfig { name } => {
let mut config = BenchRunConfig::default();
let cwd =
std::env::current_dir().expect("Failed to get current working directory");
config.output_dir = Some(cwd);
config.save(name);
}
BenchCommand::Run { config } => BenchRunner::new(config)?.run()?,
BenchCommand::EvalModel { config } => ModelRunner::from(config)?.run()?,
BenchCommand::ExecEval { config } => {
EvalRunner::from(config)?.run(agent_generator).await?
}
BenchCommand::GenerateLeaderboard { benchmark_dir } => {
MetricAggregator::generate_csv_from_benchmark_dir(&benchmark_dir)?
}
}
return Ok(());
}
Some(Command::Recipe { command }) => {
match command {
RecipeCommand::Validate { file } => {
handle_validate(file)?;
RecipeCommand::Validate { recipe_name } => {
handle_validate(&recipe_name)?;
}
RecipeCommand::Deeplink { file } => {
handle_deeplink(file)?;
RecipeCommand::Deeplink { recipe_name } => {
handle_deeplink(&recipe_name)?;
}
}
return Ok(());
@@ -559,7 +761,19 @@ pub async fn cli() -> Result<()> {
Ok(())
} else {
// Run session command by default
let mut session = build_session(SessionBuilderConfig::default()).await;
let mut session = build_session(SessionBuilderConfig {
identifier: None,
resume: false,
no_session: false,
extensions: Vec::new(),
remote_extensions: Vec::new(),
builtins: Vec::new(),
extensions_override: None,
additional_system_prompt: None,
debug: false,
max_tool_repetitions: None,
})
.await;
setup_logging(
session.session_file().file_stem().and_then(|s| s.to_str()),
None,

View File

@@ -34,12 +34,14 @@ pub async fn agent_generator(
let base_session = build_session(SessionBuilderConfig {
identifier,
resume: false,
no_session: false,
extensions: requirements.external,
remote_extensions: requirements.remote,
builtins: requirements.builtin,
extensions_override: None,
additional_system_prompt: None,
debug: false,
max_tool_repetitions: None,
})
.await;

View File

@@ -21,6 +21,8 @@ use serde_json::{json, Value};
use std::collections::HashMap;
use std::error::Error;
use crate::recipes::github_recipe::GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY;
// useful for light themes where there is no discernible colour contrast between
// cursor-selected and cursor-unselected items.
const MULTISELECT_VISIBILITY_HINT: &str = "<";
@@ -193,7 +195,7 @@ pub async fn handle_configure() -> Result<(), Box<dyn Error>> {
.item(
"settings",
"Goose Settings",
"Set the Goose Mode, Tool Output, Tool Permissions, Experiment and more",
"Set the Goose Mode, Tool Output, Tool Permissions, Experiment, Goose recipe github repo and more",
)
.interact()?;
@@ -325,11 +327,40 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
}
}
// Select model, defaulting to the provider's recommended model UNLESS there is an env override
let default_model = std::env::var("GOOSE_MODEL").unwrap_or(provider_meta.default_model.clone());
let model: String = cliclack::input("Enter a model from that provider:")
.default_input(&default_model)
.interact()?;
// Attempt to fetch supported models for this provider
let spin = spinner();
spin.start("Attempting to fetch supported models...");
let models_res = {
let temp_model_config = goose::model::ModelConfig::new(provider_meta.default_model.clone());
let temp_provider = create(provider_name, temp_model_config)?;
temp_provider.fetch_supported_models_async().await
};
spin.stop(style("Model fetch complete").green());
// Select a model: on fetch error show styled error and abort; if Some(models), show list; if None, free-text input
let model: String = match models_res {
Err(e) => {
// Provider hook error
cliclack::outro(style(e.to_string()).on_red().white())?;
return Ok(false);
}
Ok(Some(models)) => cliclack::select("Select a model:")
.items(
&models
.iter()
.map(|m| (m, m.as_str(), ""))
.collect::<Vec<_>>(),
)
.interact()?
.to_string(),
Ok(None) => {
let default_model =
std::env::var("GOOSE_MODEL").unwrap_or(provider_meta.default_model.clone());
cliclack::input("Enter a model from that provider:")
.default_input(&default_model)
.interact()?
}
};
// Test the configuration
let spin = spinner();
@@ -793,6 +824,11 @@ pub fn remove_extension_dialog() -> Result<(), Box<dyn Error>> {
pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
let setting_type = cliclack::select("What setting would you like to configure?")
.item("goose_mode", "Goose Mode", "Configure Goose mode")
.item(
"goose_router_strategy",
"Router Tool Selection Strategy",
"Configure the strategy for selecting tools to use",
)
.item(
"tool_permission",
"Tool Permission",
@@ -808,12 +844,20 @@ pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
"Toggle Experiment",
"Enable or disable an experiment feature",
)
.item(
"recipe",
"Goose recipe github repo",
"Goose will pull recipes from this repo if not found locally.",
)
.interact()?;
match setting_type {
"goose_mode" => {
configure_goose_mode_dialog()?;
}
"goose_router_strategy" => {
configure_goose_router_strategy_dialog()?;
}
"tool_permission" => {
configure_tool_permissions_dialog().await.and(Ok(()))?;
}
@@ -823,6 +867,9 @@ pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
"experiment" => {
toggle_experiments_dialog()?;
}
"recipe" => {
configure_recipe_dialog()?;
}
_ => unreachable!(),
};
@@ -882,6 +929,49 @@ pub fn configure_goose_mode_dialog() -> Result<(), Box<dyn Error>> {
Ok(())
}
pub fn configure_goose_router_strategy_dialog() -> Result<(), Box<dyn Error>> {
let config = Config::global();
// Check if GOOSE_ROUTER_STRATEGY is set as an environment variable
if std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY").is_ok() {
let _ = cliclack::log::info("Notice: GOOSE_ROUTER_TOOL_SELECTION_STRATEGY environment variable is set. Configuration will override this.");
}
let strategy = cliclack::select("Which router strategy would you like to use?")
.item(
"vector",
"Vector Strategy",
"Use vector-based similarity to select tools",
)
.item(
"default",
"Default Strategy",
"Use the default tool selection strategy",
)
.interact()?;
match strategy {
"vector" => {
config.set_param(
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
Value::String("vector".to_string()),
)?;
cliclack::outro(
"Set to Vector Strategy - using vector-based similarity for tool selection",
)?;
}
"default" => {
config.set_param(
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
Value::String("default".to_string()),
)?;
cliclack::outro("Set to Default Strategy - using default tool selection")?;
}
_ => unreachable!(),
};
Ok(())
}
pub fn configure_tool_output_dialog() -> Result<(), Box<dyn Error>> {
let config = Config::global();
// Check if GOOSE_CLI_MIN_PRIORITY is set as an environment variable
@@ -1104,3 +1194,26 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box<dyn Error>> {
Ok(())
}
fn configure_recipe_dialog() -> Result<(), Box<dyn Error>> {
let key_name = GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY;
let config = Config::global();
let default_recipe_repo = std::env::var(key_name)
.ok()
.or_else(|| config.get_param(key_name).unwrap_or(None));
let mut recipe_repo_input = cliclack::input(
"Enter your Goose Recipe Github repo (owner/repo): eg: my_org/goose-recipes",
)
.required(false);
if let Some(recipe_repo) = default_recipe_repo {
recipe_repo_input = recipe_repo_input.default_input(&recipe_repo);
}
let input_value: String = recipe_repo_input.interact()?;
// If the input is blank, clear the recipe GitHub repo setting in the config file
if input_value.trim().is_empty() {
config.delete(key_name)?;
} else {
config.set_param(key_name, Value::String(input_value))?;
}
Ok(())
}

View File

@@ -7,6 +7,16 @@ use mcp_server::router::RouterService;
use mcp_server::{BoundedService, ByteTransport, Server};
use tokio::io::{stdin, stdout};
use std::sync::Arc;
use tokio::sync::Notify;
#[cfg(unix)]
use nix::sys::signal::{kill, Signal};
#[cfg(unix)]
use nix::unistd::getpgrp;
#[cfg(unix)]
use nix::unistd::Pid;
pub async fn run_server(name: &str) -> Result<()> {
// Initialize logging
crate::logging::setup_logging(Some(&format!("mcp-{name}")), None)?;
@@ -26,10 +36,38 @@ pub async fn run_server(name: &str) -> Result<()> {
_ => None,
};
// Create shutdown notification channel
let shutdown = Arc::new(Notify::new());
let shutdown_clone = shutdown.clone();
// Spawn shutdown signal handler
tokio::spawn(async move {
crate::signal::shutdown_signal().await;
shutdown_clone.notify_one();
});
// Create and run the server
let server = Server::new(router.unwrap_or_else(|| panic!("Unknown server requested {}", name)));
let transport = ByteTransport::new(stdin(), stdout());
tracing::info!("Server initialized and ready to handle requests");
Ok(server.run(transport).await?)
tokio::select! {
result = server.run(transport) => {
Ok(result?)
}
_ = shutdown.notified() => {
// On Unix systems, kill the entire process group
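// so that any child processes spawned by the server are stopped as well.
// kill() with a negative PID sends the signal to every process in that group.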
#[cfg(unix)]
{
fn terminate_process_group() {
let pgid = getpgrp();
kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGTERM)
.expect("Failed to send SIGTERM to process group");
}
terminate_process_group();
}
Ok(())
}
}
}

View File

@@ -2,6 +2,8 @@ pub mod bench;
pub mod configure;
pub mod info;
pub mod mcp;
pub mod project;
pub mod recipe;
pub mod schedule;
pub mod session;
pub mod update;

View File

@@ -0,0 +1,307 @@
use anyhow::Result;
use chrono::DateTime;
use cliclack::{self, intro, outro};
use std::path::Path;
use crate::project_tracker::ProjectTracker;
/// Format a DateTime for display
fn format_date(date: DateTime<chrono::Utc>) -> String {
// Format: "2025-05-08 18:15:30"
date.format("%Y-%m-%d %H:%M:%S").to_string()
}
/// Handle the default project command
///
/// Offers options to resume the most recently accessed project
pub fn handle_project_default() -> Result<()> {
let tracker = ProjectTracker::load()?;
let mut projects = tracker.list_projects();
if projects.is_empty() {
// If no projects exist, just start a new one in the current directory
println!("No previous projects found. Starting a new session in the current directory.");
let mut command = std::process::Command::new("goose");
command.arg("session");
let status = command.status()?;
if !status.success() {
println!("Failed to run Goose. Exit code: {:?}", status.code());
}
return Ok(());
}
// Sort projects by last_accessed (newest first)
projects.sort_by(|a, b| b.last_accessed.cmp(&a.last_accessed));
// Get the most recent project
let project = &projects[0];
let project_dir = &project.path;
// Check if the directory exists
if !Path::new(project_dir).exists() {
println!(
"Most recent project directory '{}' no longer exists.",
project_dir
);
return Ok(());
}
// Format the path for display
let path = Path::new(project_dir);
let components: Vec<_> = path.components().collect();
let len = components.len();
let short_path = if len <= 2 {
project_dir.clone()
} else {
let mut path_str = String::new();
path_str.push_str("...");
for component in components.iter().skip(len - 2) {
path_str.push('/');
path_str.push_str(component.as_os_str().to_string_lossy().as_ref());
}
path_str
};
// Ask the user what they want to do
let _ = intro("Goose Project Manager");
let current_dir = std::env::current_dir()?;
let current_dir_display = current_dir.display();
let choice = cliclack::select("Choose an option:")
.item(
"resume",
format!("Resume project with session: {}", short_path),
"Continue with the previous session",
)
.item(
"fresh",
format!("Resume project with fresh session: {}", short_path),
"Change to the project directory but start a new session",
)
.item(
"new",
format!(
"Start new project in current directory: {}",
current_dir_display
),
"Stay in the current directory and start a new session",
)
.interact()?;
match choice {
"resume" => {
let _ = outro(format!("Changing to directory: {}", project_dir));
// Get the session ID if available
let session_id = project.last_session_id.clone();
// Change to the project directory
std::env::set_current_dir(project_dir)?;
// Build the command to run Goose
let mut command = std::process::Command::new("goose");
command.arg("session");
if let Some(id) = session_id {
command.arg("--name").arg(&id).arg("--resume");
println!("Resuming session: {}", id);
}
// Execute the command
let status = command.status()?;
if !status.success() {
println!("Failed to run Goose. Exit code: {:?}", status.code());
}
}
"fresh" => {
let _ = outro(format!(
"Changing to directory: {} with a fresh session",
project_dir
));
// Change to the project directory
std::env::set_current_dir(project_dir)?;
// Build the command to run Goose with a fresh session
let mut command = std::process::Command::new("goose");
command.arg("session");
// Execute the command
let status = command.status()?;
if !status.success() {
println!("Failed to run Goose. Exit code: {:?}", status.code());
}
}
"new" => {
let _ = outro("Starting a new session in the current directory");
// Build the command to run Goose
let mut command = std::process::Command::new("goose");
command.arg("session");
// Execute the command
let status = command.status()?;
if !status.success() {
println!("Failed to run Goose. Exit code: {:?}", status.code());
}
}
_ => {
let _ = outro("Operation canceled");
}
}
Ok(())
}
/// Handle the interactive projects command
///
/// Shows a list of projects and lets the user select one to resume
pub fn handle_projects_interactive() -> Result<()> {
let tracker = ProjectTracker::load()?;
let mut projects = tracker.list_projects();
if projects.is_empty() {
println!("No projects found.");
return Ok(());
}
// Sort projects by last_accessed (newest first)
projects.sort_by(|a, b| b.last_accessed.cmp(&a.last_accessed));
// Format project paths for display
let project_choices: Vec<(String, String)> = projects
.iter()
.enumerate()
.map(|(i, project)| {
let path = Path::new(&project.path);
let components: Vec<_> = path.components().collect();
let len = components.len();
let short_path = if len <= 2 {
project.path.clone()
} else {
let mut path_str = String::new();
path_str.push_str("...");
for component in components.iter().skip(len - 2) {
path_str.push('/');
path_str.push_str(component.as_os_str().to_string_lossy().as_ref());
}
path_str
};
// Include last instruction if available (truncated)
let instruction_preview =
project
.last_instruction
.as_ref()
.map_or(String::new(), |instr| {
let truncated = if instr.len() > 40 {
format!("{}...", &instr[0..37])
} else {
instr.clone()
};
format!(" [{}]", truncated)
});
let formatted_date = format_date(project.last_accessed);
(
format!("{}", i + 1), // Value to return
format!("{} ({}){}", short_path, formatted_date, instruction_preview), // Display text with instruction
)
})
.collect();
// Let the user select a project
let _ = intro("Goose Project Manager");
let mut select = cliclack::select("Select a project:");
// Add each project as an option
for (value, display) in &project_choices {
select = select.item(value, display, "");
}
// Add a cancel option
let cancel_value = String::from("cancel");
select = select.item(&cancel_value, "Cancel", "Don't resume any project");
let selected = select.interact()?;
if selected == "cancel" {
let _ = outro("Project selection canceled.");
return Ok(());
}
// Parse the selected index
let index = selected.parse::<usize>().unwrap_or(0);
if index == 0 || index > projects.len() {
let _ = outro("Invalid selection.");
return Ok(());
}
// Get the selected project
let project = &projects[index - 1];
let project_dir = &project.path;
// Check if the directory exists
if !Path::new(project_dir).exists() {
let _ = outro(format!(
"Project directory '{}' no longer exists.",
project_dir
));
return Ok(());
}
// Ask if the user wants to resume the session or start a new one
let session_id = project.last_session_id.clone();
let has_previous_session = session_id.is_some();
// Change to the project directory first
std::env::set_current_dir(project_dir)?;
let _ = outro(format!("Changed to directory: {}", project_dir));
// Only ask about resuming if there's a previous session
let resume_session = if has_previous_session {
let session_choice = cliclack::select("What would you like to do?")
.item(
"resume",
"Resume previous session",
"Continue with the previous session",
)
.item(
"new",
"Start new session",
"Start a fresh session in this project directory",
)
.interact()?;
session_choice == "resume"
} else {
false
};
// Build the command to run Goose
let mut command = std::process::Command::new("goose");
command.arg("session");
if resume_session {
if let Some(id) = session_id {
command.arg("--name").arg(&id).arg("--resume");
println!("Resuming session: {}", id);
}
} else {
println!("Starting new session");
}
// Execute the command
let status = command.status()?;
if !status.success() {
println!("Failed to run Goose. Exit code: {:?}", status.code());
}
Ok(())
}

View File

@@ -1,9 +1,8 @@
use anyhow::Result;
use base64::Engine;
use console::style;
use std::path::Path;
use crate::recipe::load_recipe;
use crate::recipes::recipe::load_recipe;
/// Validates a recipe file
///
@@ -14,9 +13,9 @@ use crate::recipe::load_recipe;
/// # Returns
///
/// Result indicating success or failure
pub fn handle_validate<P: AsRef<Path>>(file_path: P) -> Result<()> {
pub fn handle_validate(recipe_name: &str) -> Result<()> {
// Load and validate the recipe file
match load_recipe(&file_path, false) {
match load_recipe(recipe_name) {
Ok(_) => {
println!("{} recipe file is valid", style("").green().bold());
Ok(())
@@ -37,9 +36,9 @@ pub fn handle_validate<P: AsRef<Path>>(file_path: P) -> Result<()> {
/// # Returns
///
/// Result indicating success or failure
pub fn handle_deeplink<P: AsRef<Path>>(file_path: P) -> Result<()> {
pub fn handle_deeplink(recipe_name: &str) -> Result<()> {
// Load the recipe file first to validate it
match load_recipe(&file_path, false) {
match load_recipe(recipe_name) {
Ok(recipe) => {
if let Ok(recipe_json) = serde_json::to_string(&recipe) {
let deeplink = base64::engine::general_purpose::STANDARD.encode(recipe_json);

View File

@@ -0,0 +1,186 @@
use anyhow::{bail, Context, Result};
use base64::engine::{general_purpose::STANDARD as BASE64_STANDARD, Engine};
use goose::scheduler::{
get_default_scheduled_recipes_dir, get_default_scheduler_storage_path, ScheduledJob, Scheduler,
SchedulerError,
};
use std::path::Path;
// Base64 decoding function - might be needed if recipe_source_arg can be base64
// For now, handle_schedule_add will assume it's a path.
async fn _decode_base64_recipe(source: &str) -> Result<String> {
let bytes = BASE64_STANDARD
.decode(source.as_bytes())
.with_context(|| "Recipe source is not a valid path and not valid Base64.")?;
String::from_utf8(bytes).with_context(|| "Decoded Base64 recipe source is not valid UTF-8.")
}
pub async fn handle_schedule_add(
id: String,
cron: String,
recipe_source_arg: String, // This is expected to be a file path by the Scheduler
) -> Result<()> {
println!(
"[CLI Debug] Scheduling job ID: {}, Cron: {}, Recipe Source Path: {}",
id, cron, recipe_source_arg
);
// The Scheduler's add_scheduled_job will handle copying the recipe from recipe_source_arg
// to its internal storage and validating the path.
let job = ScheduledJob {
id: id.clone(),
source: recipe_source_arg.clone(), // Pass the original user-provided path
cron,
last_run: None,
currently_running: false,
paused: false,
};
let scheduler_storage_path =
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
let scheduler = Scheduler::new(scheduler_storage_path)
.await
.context("Failed to initialize scheduler")?;
match scheduler.add_scheduled_job(job).await {
Ok(_) => {
// The scheduler has copied the recipe to its internal directory.
// We can reconstruct the likely path for display if needed, or adjust success message.
let scheduled_recipes_dir = get_default_scheduled_recipes_dir()
.unwrap_or_else(|_| Path::new("./.goose_scheduled_recipes").to_path_buf()); // Fallback for display
let extension = Path::new(&recipe_source_arg)
.extension()
.and_then(|ext| ext.to_str())
.unwrap_or("yaml");
let final_recipe_path = scheduled_recipes_dir.join(format!("{}.{}", id, extension));
println!(
"Scheduled job '{}' added. Recipe expected at {:?}",
id, final_recipe_path
);
Ok(())
}
Err(e) => {
// No local file to clean up by the CLI in this revised flow.
match e {
SchedulerError::JobIdExists(job_id) => {
bail!("Error: Job with ID '{}' already exists.", job_id);
}
SchedulerError::RecipeLoadError(msg) => {
bail!(
"Error with recipe source: {}. Path: {}",
msg,
recipe_source_arg
);
}
_ => Err(anyhow::Error::new(e))
.context(format!("Failed to add job '{}' to scheduler", id)),
}
}
}
}
pub async fn handle_schedule_list() -> Result<()> {
let scheduler_storage_path =
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
let scheduler = Scheduler::new(scheduler_storage_path)
.await
.context("Failed to initialize scheduler")?;
let jobs = scheduler.list_scheduled_jobs().await;
if jobs.is_empty() {
println!("No scheduled jobs found.");
} else {
println!("Scheduled Jobs:");
for job in jobs {
println!(
"- ID: {}\n Cron: {}\n Recipe Source (in store): {}\n Last Run: {}",
job.id,
job.cron,
job.source, // This source is now the path within scheduled_recipes_dir
job.last_run
.map_or_else(|| "Never".to_string(), |dt| dt.to_rfc3339())
);
}
}
Ok(())
}
pub async fn handle_schedule_remove(id: String) -> Result<()> {
let scheduler_storage_path =
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
let scheduler = Scheduler::new(scheduler_storage_path)
.await
.context("Failed to initialize scheduler")?;
match scheduler.remove_scheduled_job(&id).await {
Ok(_) => {
println!("Scheduled job '{}' and its associated recipe removed.", id);
Ok(())
}
Err(e) => match e {
SchedulerError::JobNotFound(job_id) => {
bail!("Error: Job with ID '{}' not found.", job_id);
}
_ => Err(anyhow::Error::new(e))
.context(format!("Failed to remove job '{}' from scheduler", id)),
},
}
}
pub async fn handle_schedule_sessions(id: String, limit: Option<u32>) -> Result<()> {
let scheduler_storage_path =
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
let scheduler = Scheduler::new(scheduler_storage_path)
.await
.context("Failed to initialize scheduler")?;
match scheduler.sessions(&id, limit.unwrap_or(50) as usize).await {
Ok(sessions) => {
if sessions.is_empty() {
println!("No sessions found for schedule ID '{}'.", id);
} else {
println!("Sessions for schedule ID '{}':", id);
// Each entry is a (session_name, SessionMetadata) pair
for (session_name, metadata) in sessions {
println!(
" - Session ID: {}, Working Dir: {}, Description: \"{}\", Messages: {}, Schedule ID: {:?}",
session_name, // Display the session_name as Session ID
metadata.working_dir.display(),
metadata.description,
metadata.message_count,
metadata.schedule_id.as_deref().unwrap_or("N/A")
);
}
}
}
Err(e) => {
bail!("Failed to get sessions for schedule '{}': {:?}", id, e);
}
}
Ok(())
}
pub async fn handle_schedule_run_now(id: String) -> Result<()> {
let scheduler_storage_path =
get_default_scheduler_storage_path().context("Failed to get scheduler storage path")?;
let scheduler = Scheduler::new(scheduler_storage_path)
.await
.context("Failed to initialize scheduler")?;
match scheduler.run_now(&id).await {
Ok(session_id) => {
println!(
"Successfully triggered schedule '{}'. New session ID: {}",
id, session_id
);
}
Err(e) => match e {
SchedulerError::JobNotFound(job_id) => {
bail!("Error: Job with ID '{}' not found.", job_id);
}
_ => bail!("Failed to run schedule '{}' now: {:?}", id, e),
},
}
Ok(())
}

View File

@@ -1,18 +1,20 @@
use anyhow::{Context, Result};
use cliclack::{confirm, multiselect};
use goose::session::info::{get_session_info, SessionInfo, SortOrder};
use regex::Regex;
use std::fs;
const TRUNCATED_DESC_LENGTH: usize = 60;
pub fn remove_sessions(sessions: Vec<SessionInfo>) -> Result<()> {
println!("The following sessions will be removed:");
for session in &sessions {
println!("- {}", session.id);
}
let should_delete =
cliclack::confirm("Are you sure you want to delete all these sessions? (yes/no):")
.initial_value(true)
.interact()?;
let should_delete = confirm("Are you sure you want to delete these sessions?")
.initial_value(false)
.interact()?;
if should_delete {
for session in sessions {
@@ -27,8 +29,50 @@ pub fn remove_sessions(sessions: Vec<SessionInfo>) -> Result<()> {
Ok(())
}
pub fn handle_session_remove(id: String, regex_string: String) -> Result<()> {
let sessions = match get_session_info(SortOrder::Descending) {
fn prompt_interactive_session_selection(sessions: &[SessionInfo]) -> Result<Vec<SessionInfo>> {
if sessions.is_empty() {
println!("No sessions to delete.");
return Ok(vec![]);
}
let mut selector = multiselect(
"Select sessions to delete (use spacebar, Enter to confirm, Ctrl+C to cancel):",
);
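// Map each human-readable menu label back to its SessionInfo so the
// selections returned by interact() can be resolved to sessions.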
let display_map: std::collections::HashMap<String, SessionInfo> = sessions
.iter()
.map(|s| {
let desc = if s.metadata.description.is_empty() {
"(no description)"
} else {
&s.metadata.description
};
let truncated_desc = if desc.len() > TRUNCATED_DESC_LENGTH {
format!("{}...", &desc[..TRUNCATED_DESC_LENGTH - 3])
} else {
desc.to_string()
};
let display_text = format!("{} - {} ({})", s.modified, truncated_desc, s.id);
(display_text, s.clone())
})
.collect();
for display_text in display_map.keys() {
selector = selector.item(display_text.clone(), display_text.clone(), "");
}
let selected_display_texts: Vec<String> = selector.interact()?;
let selected_sessions: Vec<SessionInfo> = selected_display_texts
.into_iter()
.filter_map(|text| display_map.get(&text).cloned())
.collect();
Ok(selected_sessions)
}
pub fn handle_session_remove(id: Option<String>, regex_string: Option<String>) -> Result<()> {
let all_sessions = match get_session_info(SortOrder::Descending) {
Ok(sessions) => sessions,
Err(e) => {
tracing::error!("Failed to retrieve sessions: {:?}", e);
@@ -37,29 +81,35 @@ pub fn handle_session_remove(id: String, regex_string: String) -> Result<()> {
};
let matched_sessions: Vec<SessionInfo>;
if !id.is_empty() {
if let Some(session) = sessions.iter().find(|s| s.id == id) {
if let Some(id_val) = id {
if let Some(session) = all_sessions.iter().find(|s| s.id == id_val) {
matched_sessions = vec![session.clone()];
} else {
return Err(anyhow::anyhow!("Session '{}' not found.", id));
return Err(anyhow::anyhow!("Session '{}' not found.", id_val));
}
} else if !regex_string.is_empty() {
let session_regex = Regex::new(&regex_string)
.with_context(|| format!("Invalid regex pattern '{}'", regex_string))?;
matched_sessions = sessions
} else if let Some(regex_val) = regex_string {
let session_regex = Regex::new(&regex_val)
.with_context(|| format!("Invalid regex pattern '{}'", regex_val))?;
matched_sessions = all_sessions
.into_iter()
.filter(|session| session_regex.is_match(&session.id))
.collect();
if matched_sessions.is_empty() {
println!(
"Regex string '{}' does not match any sessions",
regex_string
);
println!("Regex string '{}' does not match any sessions", regex_val);
return Ok(());
}
} else {
return Err(anyhow::anyhow!("Neither --regex nor --id flags provided."));
if all_sessions.is_empty() {
return Err(anyhow::anyhow!("No sessions found."));
}
matched_sessions = prompt_interactive_session_selection(&all_sessions)?;
}
if matched_sessions.is_empty() {
return Ok(());
}
remove_sessions(matched_sessions)

View File

@@ -3,9 +3,10 @@ use once_cell::sync::Lazy;
pub mod cli;
pub mod commands;
pub mod logging;
pub mod recipe;
pub mod project_tracker;
pub mod recipes;
pub mod session;
pub mod signal;
// Re-export commonly used types
pub use session::Session;

View File

@@ -0,0 +1,146 @@
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use etcetera::{choose_app_strategy, AppStrategy};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
/// Structure to track project information
#[derive(Debug, Serialize, Deserialize)]
pub struct ProjectInfo {
/// The absolute path to the project directory
pub path: String,
/// Last time the project was accessed
pub last_accessed: DateTime<Utc>,
/// Last instruction sent to goose (if available)
pub last_instruction: Option<String>,
/// Last session ID associated with this project
pub last_session_id: Option<String>,
}
/// Structure to hold all tracked projects
#[derive(Debug, Serialize, Deserialize)]
pub struct ProjectTracker {
projects: HashMap<String, ProjectInfo>,
}
/// Project information with path as a separate field for easier access
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectInfoDisplay {
/// The absolute path to the project directory
pub path: String,
/// Last time the project was accessed
pub last_accessed: DateTime<Utc>,
/// Last instruction sent to goose (if available)
pub last_instruction: Option<String>,
/// Last session ID associated with this project
pub last_session_id: Option<String>,
}
impl ProjectTracker {
/// Get the path to the projects.json file
fn get_projects_file() -> Result<PathBuf> {
let projects_file = choose_app_strategy(crate::APP_STRATEGY.clone())
.context("goose requires a home dir")?
.in_data_dir("projects.json");
// Ensure data directory exists
if let Some(parent) = projects_file.parent() {
if !parent.exists() {
fs::create_dir_all(parent)?;
}
}
Ok(projects_file)
}
/// Load the project tracker from the projects.json file
pub fn load() -> Result<Self> {
let projects_file = Self::get_projects_file()?;
if projects_file.exists() {
let file_content = fs::read_to_string(&projects_file)?;
let tracker: ProjectTracker = serde_json::from_str(&file_content)
.context("Failed to parse projects.json file")?;
Ok(tracker)
} else {
// If the file doesn't exist, create a new empty tracker
Ok(ProjectTracker {
projects: HashMap::new(),
})
}
}
/// Save the project tracker to the projects.json file
pub fn save(&self) -> Result<()> {
let projects_file = Self::get_projects_file()?;
let json = serde_json::to_string_pretty(self)?;
fs::write(projects_file, json)?;
Ok(())
}
/// Update project information for the current directory
///
/// # Arguments
/// * `project_dir` - The project directory to update
/// * `instruction` - Optional instruction that was sent to goose
/// * `session_id` - Optional session ID associated with this project
pub fn update_project(
&mut self,
project_dir: &Path,
instruction: Option<&str>,
session_id: Option<&str>,
) -> Result<()> {
let dir_str = project_dir.to_string_lossy().to_string();
// Create or update the project entry
let project_info = self.projects.entry(dir_str.clone()).or_insert(ProjectInfo {
path: dir_str,
last_accessed: Utc::now(),
last_instruction: None,
last_session_id: None,
});
// Update the last accessed time
project_info.last_accessed = Utc::now();
// Update the last instruction if provided
if let Some(instr) = instruction {
project_info.last_instruction = Some(instr.to_string());
}
// Update the session ID if provided
if let Some(id) = session_id {
project_info.last_session_id = Some(id.to_string());
}
self.save()
}
/// List all tracked projects
///
/// Returns a vector of ProjectInfoDisplay objects
pub fn list_projects(&self) -> Vec<ProjectInfoDisplay> {
self.projects
.values()
.map(|info| ProjectInfoDisplay {
path: info.path.clone(),
last_accessed: info.last_accessed,
last_instruction: info.last_instruction.clone(),
last_session_id: info.last_session_id.clone(),
})
.collect()
}
}
/// Update the project tracker with the current directory and optional instruction
///
/// # Arguments
/// * `instruction` - Optional instruction that was sent to goose
/// * `session_id` - Optional session ID associated with this project
pub fn update_project_tracker(instruction: Option<&str>, session_id: Option<&str>) -> Result<()> {
let current_dir = std::env::current_dir()?;
let mut tracker = ProjectTracker::load()?;
tracker.update_project(&current_dir, instruction, session_id)
}

View File

@@ -1,70 +0,0 @@
use anyhow::{Context, Result};
use console::style;
use std::path::Path;
use goose::recipe::Recipe;
/// Loads and validates a recipe from a YAML or JSON file
///
/// # Arguments
///
/// * `path` - Path to the recipe file (YAML or JSON)
/// * `log` - whether to log information about the recipe or not
///
/// # Returns
///
/// The parsed recipe struct if successful
///
/// # Errors
///
/// Returns an error if:
/// - The file doesn't exist
/// - The file can't be read
/// - The YAML/JSON is invalid
/// - The required fields are missing
pub fn load_recipe<P: AsRef<Path>>(path: P, log: bool) -> Result<Recipe> {
let path = path.as_ref();
// Check if file exists
if !path.exists() {
return Err(anyhow::anyhow!("recipe file not found: {}", path.display()));
}
// Read file content
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read recipe file: {}", path.display()))?;
// Determine file format based on extension and parse accordingly
let recipe: Recipe = if let Some(extension) = path.extension() {
match extension.to_str().unwrap_or("").to_lowercase().as_str() {
"json" => serde_json::from_str(&content)
.with_context(|| format!("Failed to parse JSON recipe file: {}", path.display()))?,
"yaml" => serde_yaml::from_str(&content)
.with_context(|| format!("Failed to parse YAML recipe file: {}", path.display()))?,
_ => {
return Err(anyhow::anyhow!(
"Unsupported file format for recipe file: {}. Expected .yaml or .json",
path.display()
))
}
}
} else {
return Err(anyhow::anyhow!(
"File has no extension: {}. Expected .yaml or .json",
path.display()
));
};
if log {
// Display information about the loaded recipe
println!(
"{} {}",
style("Loading recipe:").green().bold(),
style(&recipe.title).green()
);
println!("{} {}", style("Description:").dim(), &recipe.description);
println!(); // Add a blank line for spacing
}
Ok(recipe)
}

View File

@@ -0,0 +1,192 @@
use anyhow::Result;
use console::style;
use std::env;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
use std::process::Stdio;
use tar::Archive;
use crate::recipes::recipe::RECIPE_FILE_EXTENSIONS;
pub const GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY: &str = "GOOSE_RECIPE_GITHUB_REPO";
pub fn retrieve_recipe_from_github(
recipe_name: &str,
recipe_repo_full_name: &str,
) -> Result<(String, PathBuf)> {
println!(
"📦 Looking for recipe \"{}\" in github repo: {}",
recipe_name, recipe_repo_full_name
);
ensure_gh_authenticated()?;
let max_attempts = 2;
let mut last_err = None;
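// If the first attempt fails, wipe the cached clone and retry once with a
// fresh checkout before giving up.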
for attempt in 1..=max_attempts {
match clone_and_download_recipe(recipe_name, recipe_repo_full_name) {
Ok(download_dir) => match read_recipe_file(&download_dir) {
Ok(content) => return Ok((content, download_dir)),
Err(err) => return Err(err),
},
Err(err) => {
last_err = Some(err);
}
}
if attempt < max_attempts {
clean_cloned_dirs(recipe_repo_full_name)?;
}
}
Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Unknown error occurred")))
}
fn clean_cloned_dirs(recipe_repo_full_name: &str) -> anyhow::Result<()> {
let local_repo_path = get_local_repo_path(&env::temp_dir(), recipe_repo_full_name)?;
if local_repo_path.exists() {
fs::remove_dir_all(&local_repo_path)?;
}
Ok(())
}
fn read_recipe_file(download_dir: &Path) -> Result<String> {
for ext in RECIPE_FILE_EXTENSIONS {
let candidate_file_path = download_dir.join(format!("recipe.{}", ext));
if candidate_file_path.exists() {
let content = fs::read_to_string(&candidate_file_path)?;
println!(
"⬇️ Retrieved recipe file: {}",
candidate_file_path
.strip_prefix(download_dir)
.unwrap()
.display()
);
return Ok(content);
}
}
Err(anyhow::anyhow!(
"No recipe file found in {} (looked for extensions: {:?})",
download_dir.display(),
RECIPE_FILE_EXTENSIONS
))
}
fn clone_and_download_recipe(recipe_name: &str, recipe_repo_full_name: &str) -> Result<PathBuf> {
let local_repo_path = ensure_repo_cloned(recipe_repo_full_name)?;
fetch_origin(&local_repo_path)?;
get_folder_from_github(&local_repo_path, recipe_name)
}
fn ensure_gh_authenticated() -> Result<()> {
// Check authentication status
let status = Command::new("gh")
.args(["auth", "status"])
.status()
.map_err(|_| {
anyhow::anyhow!("Failed to run `gh auth status`. Make sure you have `gh` installed.")
})?;
if status.success() {
return Ok(());
}
println!("GitHub CLI is not authenticated. Launching `gh auth login`...");
// Run `gh auth login` interactively
let login_status = Command::new("gh")
.args(["auth", "login", "--git-protocol", "https"])
.status()
.map_err(|_| anyhow::anyhow!("Failed to run `gh auth login`"))?;
if !login_status.success() {
Err(anyhow::anyhow!("Failed to authenticate using GitHub CLI."))
} else {
Ok(())
}
}
fn get_local_repo_path(
local_repo_parent_path: &Path,
recipe_repo_full_name: &str,
) -> Result<PathBuf> {
let (_, repo_name) = recipe_repo_full_name
.split_once('/')
.ok_or_else(|| anyhow::anyhow!("Invalid repository name format"))?;
let local_repo_path = local_repo_parent_path.to_path_buf().join(repo_name);
Ok(local_repo_path)
}
fn ensure_repo_cloned(recipe_repo_full_name: &str) -> Result<PathBuf> {
let local_repo_parent_path = env::temp_dir();
if !local_repo_parent_path.exists() {
std::fs::create_dir_all(local_repo_parent_path.clone())?;
}
let local_repo_path = get_local_repo_path(&local_repo_parent_path, recipe_repo_full_name)?;
if local_repo_path.join(".git").exists() {
Ok(local_repo_path)
} else {
let error_message: String = format!("Failed to clone repo: {}", recipe_repo_full_name);
let status = Command::new("gh")
.args(["repo", "clone", recipe_repo_full_name])
.current_dir(local_repo_parent_path.clone())
.status()
.map_err(|_: std::io::Error| anyhow::anyhow!(error_message.clone()))?;
if status.success() {
Ok(local_repo_path)
} else {
Err(anyhow::anyhow!(error_message))
}
}
}
fn fetch_origin(local_repo_path: &Path) -> Result<()> {
let error_message: String = format!("Failed to fetch at {}", local_repo_path.to_str().unwrap());
let status = Command::new("git")
.args(["fetch", "origin"])
.current_dir(local_repo_path)
.status()
.map_err(|_| anyhow::anyhow!(error_message.clone()))?;
if status.success() {
Ok(())
} else {
Err(anyhow::anyhow!(error_message))
}
}
fn get_folder_from_github(local_repo_path: &Path, recipe_name: &str) -> Result<PathBuf> {
let ref_and_path = format!("origin/main:{}", recipe_name);
let output_dir = env::temp_dir().join(recipe_name);
if output_dir.exists() {
fs::remove_dir_all(&output_dir)?;
}
fs::create_dir_all(&output_dir)?;
let archive_output = Command::new("git")
.args(["archive", &ref_and_path])
.current_dir(local_repo_path)
.stdout(Stdio::piped())
.spawn()?;
let stdout = archive_output
.stdout
.ok_or_else(|| anyhow::anyhow!("Failed to capture stdout from git archive"))?;
let mut archive = Archive::new(stdout);
archive.unpack(&output_dir)?;
list_files(&output_dir)?;
Ok(output_dir)
}
fn list_files(dir: &Path) -> Result<()> {
println!("{}", style("Files downloaded from github:").bold());
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.is_file() {
println!(" - {}", path.display());
}
}
Ok(())
}

View File

@@ -0,0 +1,4 @@
pub mod github_recipe;
pub mod print_recipe;
pub mod recipe;
pub mod search_recipe;

View File

@@ -0,0 +1,83 @@
use std::collections::HashMap;
use console::style;
use goose::recipe::Recipe;
use crate::recipes::recipe::BUILT_IN_RECIPE_DIR_PARAM;
pub fn print_recipe_explanation(recipe: &Recipe) {
println!(
"{} {}",
style("🔍 Loading recipe:").bold().green(),
style(&recipe.title).green()
);
println!("{}", style("📄 Description:").bold());
println!(" {}", recipe.description);
if let Some(params) = &recipe.parameters {
if !params.is_empty() {
println!("{}", style("⚙️ Recipe Parameters:").bold());
for param in params {
let default_display = match &param.default {
Some(val) => format!(" (default: {})", val),
None => String::new(),
};
println!(
" - {} ({}, {}){}: {}",
style(&param.key).cyan(),
param.input_type,
param.requirement,
default_display,
param.description
);
}
}
}
}
pub fn print_parameters_with_values(params: HashMap<String, String>) {
for (key, value) in params {
let label = if key == BUILT_IN_RECIPE_DIR_PARAM {
" (built-in)"
} else {
""
};
println!(" {}{}: {}", key, label, value);
}
}
pub fn print_required_parameters_for_template(
params_for_template: HashMap<String, String>,
missing_params: Vec<String>,
) {
if !params_for_template.is_empty() {
println!(
"{}",
style("📥 Parameters used to load this recipe:").bold()
);
print_parameters_with_values(params_for_template)
}
if !missing_params.is_empty() {
println!(
"{}",
style("🔴 Missing parameters in the command line if you want to run the recipe:")
.bold()
);
for param in missing_params.iter() {
println!(" - {}", param);
}
println!(
"📩 {}:",
style("Please provide the following parameters in the command line if you want to run the recipe:").bold()
);
println!(" {}", missing_parameters_command_line(missing_params));
}
}
pub fn missing_parameters_command_line(missing_params: Vec<String>) -> String {
missing_params
.iter()
.map(|key| format!("--params {}=your_value", key))
.collect::<Vec<_>>()
.join(" ")
}

View File

@@ -0,0 +1,532 @@
use anyhow::Result;
use console::style;
use crate::recipes::print_recipe::{
missing_parameters_command_line, print_parameters_with_values, print_recipe_explanation,
print_required_parameters_for_template,
};
use crate::recipes::search_recipe::retrieve_recipe_file;
use goose::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement};
use minijinja::{Environment, Error, Template, UndefinedBehavior};
use serde_json::Value as JsonValue;
use serde_yaml::Value as YamlValue;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
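// Every recipe template gets a built-in `recipe_dir` variable (the directory the
// recipe was loaded from), and recipe files may use either a .yaml or .json extension.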
pub const BUILT_IN_RECIPE_DIR_PARAM: &str = "recipe_dir";
pub const RECIPE_FILE_EXTENSIONS: &[&str] = &["yaml", "json"];
/// Loads, validates a recipe from a YAML or JSON file, and renders it with the given parameters
///
/// # Arguments
///
/// * `recipe_name` - Recipe name or path to the recipe file (YAML or JSON)
/// * `params` - parameters to render the recipe with
///
/// # Returns
///
/// The rendered recipe if successful
///
/// # Errors
///
/// Returns an error if:
/// - Recipe is not valid
/// - The required fields are missing
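///
/// # Example
///
/// An illustrative sketch; the recipe file name and parameter key are hypothetical:
///
/// ```ignore
/// let recipe = load_recipe_as_template(
///     "my_recipe.yaml",
///     vec![("username".to_string(), "alice".to_string())],
/// )?;
/// println!("{}", recipe.title);
/// ```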
pub fn load_recipe_as_template(recipe_name: &str, params: Vec<(String, String)>) -> Result<Recipe> {
let (recipe_file_content, recipe_parent_dir) = retrieve_recipe_file(recipe_name)?;
let recipe = validate_recipe_file_parameters(&recipe_file_content)?;
let (params_for_template, missing_params) =
apply_values_to_parameters(&params, recipe.parameters, recipe_parent_dir, true)?;
if !missing_params.is_empty() {
return Err(anyhow::anyhow!(
"Please provide the following parameters in the command line: {}",
missing_parameters_command_line(missing_params)
));
}
let rendered_content = render_content_with_params(&recipe_file_content, &params_for_template)?;
let recipe = parse_recipe_content(&rendered_content)?;
// Display information about the loaded recipe
println!(
"{} {}",
style("Loading recipe:").green().bold(),
style(&recipe.title).green()
);
println!("{} {}", style("Description:").bold(), &recipe.description);
if !params_for_template.is_empty() {
println!("{}", style("Parameters used to load this recipe:").bold());
print_parameters_with_values(params_for_template);
}
println!();
Ok(recipe)
}
/// Loads and validates a recipe from a YAML or JSON file
///
/// # Arguments
///
/// * `recipe_name` - Recipe name or path to the recipe file (YAML or JSON)
///
/// # Returns
///
/// The parsed recipe struct if successful
///
/// # Errors
///
/// Returns an error if:
/// - The file doesn't exist
/// - The file can't be read
/// - The YAML/JSON is invalid
/// - The parameter definition does not match the template variables in the recipe file
pub fn load_recipe(recipe_name: &str) -> Result<Recipe> {
let (recipe_file_content, _) = retrieve_recipe_file(recipe_name)?;
validate_recipe_file_parameters(&recipe_file_content)
}
pub fn explain_recipe_with_parameters(
recipe_name: &str,
params: Vec<(String, String)>,
) -> Result<()> {
let (recipe_file_content, recipe_parent_dir) = retrieve_recipe_file(recipe_name)?;
let raw_recipe = validate_recipe_file_parameters(&recipe_file_content)?;
print_recipe_explanation(&raw_recipe);
let recipe_parameters = raw_recipe.parameters;
let (params_for_template, missing_params) =
apply_values_to_parameters(&params, recipe_parameters, recipe_parent_dir, false)?;
print_required_parameters_for_template(params_for_template, missing_params);
Ok(())
}
fn validate_recipe_file_parameters(recipe_file_content: &str) -> Result<Recipe> {
let recipe_from_recipe_file: Recipe = parse_recipe_content(recipe_file_content)?;
validate_optional_parameters(&recipe_from_recipe_file)?;
validate_parameters_in_template(&recipe_from_recipe_file.parameters, recipe_file_content)?;
Ok(recipe_from_recipe_file)
}
fn validate_parameters_in_template(
recipe_parameters: &Option<Vec<RecipeParameter>>,
recipe_file_content: &str,
) -> Result<()> {
let mut template_variables = extract_template_variables(recipe_file_content)?;
template_variables.remove(BUILT_IN_RECIPE_DIR_PARAM);
let param_keys: HashSet<String> = recipe_parameters
.as_ref()
.unwrap_or(&vec![])
.iter()
.map(|p| p.key.clone())
.collect();
let missing_keys = template_variables
.difference(&param_keys)
.collect::<Vec<_>>();
let extra_keys = param_keys
.difference(&template_variables)
.collect::<Vec<_>>();
if missing_keys.is_empty() && extra_keys.is_empty() {
return Ok(());
}
let mut message = String::new();
if !missing_keys.is_empty() {
message.push_str(&format!(
"Missing definitions for parameters in the recipe file: {}.",
missing_keys
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>()
.join(", ")
));
}
if !extra_keys.is_empty() {
message.push_str(&format!(
"\nUnnecessary parameter definitions: {}.",
extra_keys
.iter()
.map(|s| s.to_string())
.collect::<Vec<_>>()
.join(", ")
));
}
Err(anyhow::anyhow!("{}", message.trim_end()))
}
fn validate_optional_parameters(recipe: &Recipe) -> Result<()> {
let optional_params_without_default_values: Vec<String> = recipe
.parameters
.as_ref()
.unwrap_or(&vec![])
.iter()
.filter(|p| {
matches!(p.requirement, RecipeParameterRequirement::Optional) && p.default.is_none()
})
.map(|p| p.key.clone())
.collect();
if optional_params_without_default_values.is_empty() {
Ok(())
} else {
Err(anyhow::anyhow!("Optional parameters missing default values in the recipe: {}. Please provide defaults.", optional_params_without_default_values.join(", ")))
}
}
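// Recipes can be authored in either JSON or YAML: try strict JSON first and fall
// back to YAML if the content is not valid JSON.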
fn parse_recipe_content(content: &str) -> Result<Recipe> {
if serde_json::from_str::<JsonValue>(content).is_ok() {
Ok(serde_json::from_str(content)?)
} else if serde_yaml::from_str::<YamlValue>(content).is_ok() {
Ok(serde_yaml::from_str(content)?)
} else {
Err(anyhow::anyhow!(
"Unsupported file format for recipe file. Expected .yaml or .json"
))
}
}
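// Collect every variable referenced by the recipe template (with strict undefined
// behavior) so the declared parameters can be checked against what the template
// actually uses.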
fn extract_template_variables(template_str: &str) -> Result<HashSet<String>> {
let mut env = Environment::new();
env.set_undefined_behavior(UndefinedBehavior::Strict);
let template = env
.template_from_str(template_str)
.map_err(|e: Error| anyhow::anyhow!("Invalid template syntax: {}", e.to_string()))?;
Ok(template.undeclared_variables(true))
}
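// Resolve values for every recipe parameter: user-supplied --params take precedence,
// the built-in `recipe_dir` is always injected, defaults fill the rest, `user_prompt`
// parameters are asked for interactively when enabled, and anything still unresolved
// is returned as missing.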
fn apply_values_to_parameters(
user_params: &[(String, String)],
recipe_parameters: Option<Vec<RecipeParameter>>,
recipe_parent_dir: PathBuf,
enable_user_prompt: bool,
) -> Result<(HashMap<String, String>, Vec<String>)> {
let mut param_map: HashMap<String, String> = user_params.iter().cloned().collect();
let recipe_parent_dir_str = recipe_parent_dir
.to_str()
.ok_or_else(|| anyhow::anyhow!("Invalid UTF-8 in recipe_dir"))?;
param_map.insert(
BUILT_IN_RECIPE_DIR_PARAM.to_string(),
recipe_parent_dir_str.to_string(),
);
let mut missing_params: Vec<String> = Vec::new();
for param in recipe_parameters.unwrap_or_default() {
if !param_map.contains_key(&param.key) {
match (&param.default, &param.requirement) {
(Some(default), _) => param_map.insert(param.key.clone(), default.clone()),
(None, RecipeParameterRequirement::UserPrompt) if enable_user_prompt => {
let input_value = cliclack::input(format!(
"Please enter {} ({})",
param.key, param.description
))
.interact()?;
param_map.insert(param.key.clone(), input_value)
}
_ => {
missing_params.push(param.key.clone());
None
}
};
}
}
Ok((param_map, missing_params))
}
fn render_content_with_params(content: &str, params: &HashMap<String, String>) -> Result<String> {
// Create a minijinja environment and context
let mut env = minijinja::Environment::new();
env.set_undefined_behavior(UndefinedBehavior::Strict);
let template: Template<'_, '_> = env
.template_from_str(content)
.map_err(|e: Error| anyhow::anyhow!("Invalid template syntax: {}", e.to_string()))?;
// Render the template with the parameters
template.render(params).map_err(|e: Error| {
anyhow::anyhow!(
"Failed to render the recipe {} - please check if all required parameters are provided",
e.to_string()
)
})
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use goose::recipe::{RecipeParameterInputType, RecipeParameterRequirement};
use tempfile::TempDir;
use super::*;
fn setup_recipe_file(instructions_and_parameters: &str) -> (TempDir, PathBuf) {
let recipe_content = format!(
r#"{{
"version": "1.0.0",
"title": "Test Recipe",
"description": "A test recipe",
{}
}}"#,
instructions_and_parameters
);
// Create a temporary file
let temp_dir = tempfile::tempdir().unwrap();
let recipe_path: std::path::PathBuf = temp_dir.path().join("test_recipe.json");
std::fs::write(&recipe_path, recipe_content).unwrap();
(temp_dir, recipe_path)
}
#[test]
fn test_render_content_with_params() {
// Test basic parameter substitution
let content = "Hello {{ name }}!";
let mut params = HashMap::new();
params.insert("name".to_string(), "World".to_string());
let result = render_content_with_params(content, &params).unwrap();
assert_eq!(result, "Hello World!");
// Test empty parameter substitution
let content = "Hello {{ empty }}!";
let mut params = HashMap::new();
params.insert("empty".to_string(), "".to_string());
let result = render_content_with_params(content, &params).unwrap();
assert_eq!(result, "Hello !");
// Test multiple parameters
let content = "{{ greeting }} {{ name }}!";
let mut params = HashMap::new();
params.insert("greeting".to_string(), "Hi".to_string());
params.insert("name".to_string(), "Alice".to_string());
let result = render_content_with_params(content, &params).unwrap();
assert_eq!(result, "Hi Alice!");
// Test missing parameter results in error
let content = "Hello {{ missing }}!";
let params = HashMap::new();
let err = render_content_with_params(content, &params).unwrap_err();
assert!(err
.to_string()
.contains("please check if all required parameters"));
// Test invalid template syntax results in error
let content = "Hello {{ unclosed";
let params = HashMap::new();
let err = render_content_with_params(content, &params).unwrap_err();
assert!(err.to_string().contains("Invalid template syntax"));
}
#[test]
fn test_load_recipe_as_template_success() {
let instructions_and_parameters = r#"
"instructions": "Test instructions with {{ my_name }}",
"parameters": [
{
"key": "my_name",
"input_type": "string",
"requirement": "required",
"description": "A test parameter"
}
]"#;
let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
let params = vec![("my_name".to_string(), "value".to_string())];
let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap();
assert_eq!(recipe.title, "Test Recipe");
assert_eq!(recipe.description, "A test recipe");
assert_eq!(recipe.instructions.unwrap(), "Test instructions with value");
// Verify parameters match recipe definition
assert_eq!(recipe.parameters.as_ref().unwrap().len(), 1);
let param = &recipe.parameters.as_ref().unwrap()[0];
assert_eq!(param.key, "my_name");
assert!(matches!(param.input_type, RecipeParameterInputType::String));
assert!(matches!(
param.requirement,
RecipeParameterRequirement::Required
));
assert_eq!(param.description, "A test parameter");
}
#[test]
fn test_load_recipe_as_template_success_variable_in_prompt() {
let instructions_and_parameters = r#"
"instructions": "Test instructions",
"prompt": "My prompt {{ my_name }}",
"parameters": [
{
"key": "my_name",
"input_type": "string",
"requirement": "required",
"description": "A test parameter"
}
]"#;
let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
let params = vec![("my_name".to_string(), "value".to_string())];
let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap();
assert_eq!(recipe.title, "Test Recipe");
assert_eq!(recipe.description, "A test recipe");
assert_eq!(recipe.instructions.unwrap(), "Test instructions");
assert_eq!(recipe.prompt.unwrap(), "My prompt value");
let param = &recipe.parameters.as_ref().unwrap()[0];
assert_eq!(param.key, "my_name");
assert!(matches!(param.input_type, RecipeParameterInputType::String));
assert!(matches!(
param.requirement,
RecipeParameterRequirement::Required
));
assert_eq!(param.description, "A test parameter");
}
#[test]
fn test_load_recipe_as_template_wrong_parameters_in_recipe_file() {
let instructions_and_parameters = r#"
"instructions": "Test instructions with {{ expected_param1 }} {{ expected_param2 }}",
"parameters": [
{
"key": "wrong_param_key",
"input_type": "string",
"requirement": "required",
"description": "A test parameter"
}
]"#;
let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
let load_recipe_result = load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new());
assert!(load_recipe_result.is_err());
let err = load_recipe_result.unwrap_err();
println!("{}", err.to_string());
assert!(err
.to_string()
.contains("Unnecessary parameter definitions: wrong_param_key."));
assert!(err
.to_string()
.contains("Missing definitions for parameters in the recipe file:"));
assert!(err.to_string().contains("expected_param1"));
assert!(err.to_string().contains("expected_param2"));
}
#[test]
fn test_load_recipe_as_template_with_default_values_in_recipe_file() {
let instructions_and_parameters = r#"
"instructions": "Test instructions with {{ param_with_default }} {{ param_without_default }}",
"parameters": [
{
"key": "param_with_default",
"input_type": "string",
"requirement": "optional",
"default": "my_default_value",
"description": "A test parameter"
},
{
"key": "param_without_default",
"input_type": "string",
"requirement": "required",
"description": "A test parameter"
}
]"#;
let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
let params = vec![("param_without_default".to_string(), "value1".to_string())];
let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), params).unwrap();
assert_eq!(recipe.title, "Test Recipe");
assert_eq!(recipe.description, "A test recipe");
assert_eq!(
recipe.instructions.unwrap(),
"Test instructions with my_default_value value1"
);
}
#[test]
fn test_load_recipe_as_template_optional_parameters_with_empty_default_values_in_recipe_file() {
let instructions_and_parameters = r#"
"instructions": "Test instructions with {{ optional_param }}",
"parameters": [
{
"key": "optional_param",
"input_type": "string",
"requirement": "optional",
"description": "A test parameter",
"default": "",
}
]"#;
let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new()).unwrap();
assert_eq!(recipe.title, "Test Recipe");
assert_eq!(recipe.description, "A test recipe");
assert_eq!(recipe.instructions.unwrap(), "Test instructions with ");
}
#[test]
fn test_load_recipe_as_template_optional_parameters_without_default_values_in_recipe_file() {
let instructions_and_parameters = r#"
"instructions": "Test instructions with {{ optional_param }}",
"parameters": [
{
"key": "optional_param",
"input_type": "string",
"requirement": "optional",
"description": "A test parameter"
}
]"#;
let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
let load_recipe_result = load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new());
assert!(load_recipe_result.is_err());
let err = load_recipe_result.unwrap_err();
println!("{}", err.to_string());
assert!(err.to_string().contains(
"Optional parameters missing default values in the recipe: optional_param. Please provide defaults."
));
}
#[test]
fn test_load_recipe_as_template_wrong_input_type_in_recipe_file() {
let instructions_and_parameters = r#"
"instructions": "Test instructions with {{ param }}",
"parameters": [
{
"key": "param",
"input_type": "some_invalid_type",
"requirement": "required",
"description": "A test parameter"
}
]"#;
let params = vec![("param".to_string(), "value".to_string())];
let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
let load_recipe_result = load_recipe_as_template(recipe_path.to_str().unwrap(), params);
assert!(load_recipe_result.is_err());
let err = load_recipe_result.unwrap_err();
assert!(err
.to_string()
.contains("unknown variant `some_invalid_type`"));
}
#[test]
fn test_load_recipe_as_template_success_without_parameters() {
let instructions_and_parameters = r#"
"instructions": "Test instructions"
"#;
let (_temp_dir, recipe_path) = setup_recipe_file(instructions_and_parameters);
let recipe = load_recipe_as_template(recipe_path.to_str().unwrap(), Vec::new()).unwrap();
assert_eq!(recipe.instructions.unwrap(), "Test instructions");
assert!(recipe.parameters.is_none());
}
}

View File

@@ -0,0 +1,101 @@
use anyhow::{anyhow, Result};
use goose::config::Config;
use std::path::{Path, PathBuf};
use std::{env, fs};
use crate::recipes::recipe::RECIPE_FILE_EXTENSIONS;
use super::github_recipe::{retrieve_recipe_from_github, GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY};
const GOOSE_RECIPE_PATH_ENV_VAR: &str = "GOOSE_RECIPE_PATH";
pub fn retrieve_recipe_file(recipe_name: &str) -> Result<(String, PathBuf)> {
// If recipe_name ends with yaml or json, treat it as a direct file path
if RECIPE_FILE_EXTENSIONS
.iter()
.any(|ext| recipe_name.ends_with(&format!(".{}", ext)))
{
let path = PathBuf::from(recipe_name);
return read_recipe_file(path);
}
retrieve_recipe_from_local_path(recipe_name).or_else(|e| {
if let Some(recipe_repo_full_name) = configured_github_recipe_repo() {
retrieve_recipe_from_github(recipe_name, &recipe_repo_full_name)
} else {
Err(e)
}
})
}
fn read_recipe_in_dir(dir: &Path, recipe_name: &str) -> Result<(String, PathBuf)> {
for ext in RECIPE_FILE_EXTENSIONS {
let recipe_path = dir.join(format!("{}.{}", recipe_name, ext));
if let Ok(result) = read_recipe_file(recipe_path) {
return Ok(result);
}
}
Err(anyhow!(
"No {}.yaml or {}.json recipe file found in directory: {}",
recipe_name,
recipe_name,
dir.display()
))
}
fn retrieve_recipe_from_local_path(recipe_name: &str) -> Result<(String, PathBuf)> {
let mut search_dirs = vec![PathBuf::from(".")];
if let Ok(recipe_path_env) = env::var(GOOSE_RECIPE_PATH_ENV_VAR) {
let path_separator = if cfg!(windows) { ';' } else { ':' };
let recipe_path_env_dirs: Vec<PathBuf> = recipe_path_env
.split(path_separator)
.map(PathBuf::from)
.collect();
search_dirs.extend(recipe_path_env_dirs);
}
for dir in &search_dirs {
if let Ok(result) = read_recipe_in_dir(dir, recipe_name) {
return Ok(result);
}
}
let search_dirs_str = search_dirs
.iter()
.map(|p| p.to_string_lossy())
.collect::<Vec<_>>()
.join(":");
Err(anyhow!(
" Failed to retrieve {}.yaml or {}.json in {}",
recipe_name,
recipe_name,
search_dirs_str
))
}
fn configured_github_recipe_repo() -> Option<String> {
let config = Config::global();
match config.get_param(GOOSE_RECIPE_GITHUB_REPO_CONFIG_KEY) {
Ok(Some(recipe_repo_full_name)) => Some(recipe_repo_full_name),
_ => None,
}
}
fn read_recipe_file<P: AsRef<Path>>(recipe_path: P) -> Result<(String, PathBuf)> {
let path = recipe_path.as_ref();
let content = fs::read_to_string(path)
.map_err(|e| anyhow!("Failed to read recipe file {}: {}", path.display(), e))?;
let canonical = path.canonicalize().map_err(|e| {
anyhow!(
"Failed to resolve absolute path for {}: {}",
path.display(),
e
)
})?;
let parent_dir = canonical
.parent()
.ok_or_else(|| anyhow!("Resolved path has no parent: {}", canonical.display()))?
.to_path_buf();
Ok((content, parent_dir))
}

View File

@@ -21,6 +21,8 @@ pub struct SessionBuilderConfig {
pub identifier: Option<Identifier>,
/// Whether to resume an existing session
pub resume: bool,
/// Whether to run without a session file
pub no_session: bool,
/// List of stdio extension commands to add
pub extensions: Vec<String>,
/// List of remote extension commands to add
@@ -33,6 +35,8 @@ pub struct SessionBuilderConfig {
pub additional_system_prompt: Option<String>,
/// Enable debug printing
pub debug: bool,
/// Maximum number of consecutive identical tool calls allowed
pub max_tool_repetitions: Option<u32>,
}
pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
@@ -51,10 +55,31 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
// Create the agent
let agent: Agent = Agent::new();
let new_provider = create(&provider_name, model_config).unwrap();
let _ = agent.update_provider(new_provider).await;
agent
.update_provider(new_provider)
.await
.unwrap_or_else(|e| {
output::render_error(&format!("Failed to initialize agent: {}", e));
process::exit(1);
});
// Configure tool monitoring if max_tool_repetitions is set
if let Some(max_repetitions) = session_config.max_tool_repetitions {
agent.configure_tool_monitor(Some(max_repetitions)).await;
}
// Handle session file resolution and resuming
let session_file = if session_config.resume {
let session_file = if session_config.no_session {
// Use a temporary path that won't be written to
#[cfg(unix)]
{
std::path::PathBuf::from("/dev/null")
}
#[cfg(windows)]
{
std::path::PathBuf::from("NUL")
}
} else if session_config.resume {
if let Some(identifier) = session_config.identifier {
let session_file = session::get_path(identifier);
if !session_file.exists() {
@@ -87,7 +112,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
session::get_path(id)
};
if session_config.resume {
if session_config.resume && !session_config.no_session {
// Read the session metadata
let metadata = session::read_metadata(&session_file).unwrap_or_else(|e| {
output::render_error(&format!("Failed to read session metadata: {}", e));
@@ -103,7 +128,17 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
.interact().expect("Failed to get user input");
if change_workdir {
std::env::set_current_dir(metadata.working_dir).unwrap();
if !metadata.working_dir.exists() {
output::render_error(&format!(
"Cannot switch to original working directory - {} no longer exists",
style(metadata.working_dir.display()).cyan()
));
} else if let Err(e) = std::env::set_current_dir(&metadata.working_dir) {
output::render_error(&format!(
"Failed to switch to original working directory: {}",
e
));
}
}
}
}

View File

@@ -352,8 +352,13 @@ impl Helper for GooseCompleter {}
impl Hinter for GooseCompleter {
type Hint = String;
fn hint(&self, _line: &str, _pos: usize, _ctx: &rustyline::Context<'_>) -> Option<Self::Hint> {
None
fn hint(&self, line: &str, _pos: usize, _ctx: &rustyline::Context<'_>) -> Option<Self::Hint> {
// Only show hint when line is empty
if line.is_empty() {
Some("Press Enter to send, Ctrl-J for new line".to_string())
} else {
None
}
}
}
@@ -367,7 +372,9 @@ impl Highlighter for GooseCompleter {
}
fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> {
Cow::Borrowed(hint)
// Style the hint text with a dim color
let styled = console::Style::new().dim().apply_to(hint).to_string();
Cow::Owned(styled)
}
fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {

View File

@@ -18,6 +18,7 @@ pub enum InputResult {
Plan(PlanCommandOptions),
EndPlan,
Recipe(Option<String>),
Summarize,
}
#[derive(Debug)]
@@ -91,6 +92,7 @@ fn handle_slash_command(input: &str) -> Option<InputResult> {
const CMD_PLAN: &str = "/plan";
const CMD_ENDPLAN: &str = "/endplan";
const CMD_RECIPE: &str = "/recipe";
const CMD_SUMMARIZE: &str = "/summarize";
match input {
"/exit" | "/quit" => Some(InputResult::Exit),
@@ -133,6 +135,7 @@ fn handle_slash_command(input: &str) -> Option<InputResult> {
s if s.starts_with(CMD_PLAN) => parse_plan_command(s[CMD_PLAN.len()..].trim().to_string()),
s if s == CMD_ENDPLAN => Some(InputResult::EndPlan),
s if s.starts_with(CMD_RECIPE) => parse_recipe_command(s),
s if s == CMD_SUMMARIZE => Some(InputResult::Summarize),
_ => None,
}
}
@@ -241,6 +244,7 @@ fn print_help() {
/endplan - Exit plan mode and return to 'normal' goose mode.
/recipe [filepath] - Generate a recipe from the current conversation and save it to the specified filepath (must end with .yaml).
If no filepath is provided, it will be saved to ./recipe.yaml.
/summarize - Summarize the current conversation to reduce context length while preserving key information.
/? or /help - Display this help message
Navigation:
@@ -474,4 +478,15 @@ mod tests {
let result = handle_slash_command("/recipe /path/to/file.txt");
assert!(matches!(result, Some(InputResult::Retry)));
}
#[test]
fn test_summarize_command() {
// Test the summarize command
let result = handle_slash_command("/summarize");
assert!(matches!(result, Some(InputResult::Summarize)));
// Test with whitespace
let result = handle_slash_command(" /summarize ");
assert!(matches!(result, Some(InputResult::Summarize)));
}
}

View File

@@ -122,6 +122,21 @@ impl Session {
}
}
/// Helper function to summarize context messages
async fn summarize_context_messages(
messages: &mut Vec<Message>,
agent: &Agent,
message_suffix: &str,
) -> Result<()> {
// Summarize messages to fit within context length
let (summarized_messages, _) = agent.summarize_context(messages).await?;
let msg = format!("Context maxed out\n{}\n{}", "-".repeat(50), message_suffix);
output::render_text(&msg, Some(Color::Yellow), true);
*messages = summarized_messages;
Ok(())
}
/// Add a stdio extension to the session
///
/// # Arguments
@@ -290,6 +305,22 @@ impl Session {
// Persist messages with provider for automatic description generation
session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?;
// Track the current directory and last instruction in projects.json
let session_id = self
.session_file
.file_stem()
.and_then(|s| s.to_str())
.map(|s| s.to_string());
if let Err(e) =
crate::project_tracker::update_project_tracker(Some(&message), session_id.as_deref())
{
eprintln!(
"Warning: Failed to update project tracker with instruction: {}",
e
);
}
self.process_agent_response(false).await?;
Ok(())
}
@@ -356,6 +387,20 @@ impl Session {
self.messages.push(Message::user().with_text(&content));
// Track the current directory and last instruction in projects.json
let session_id = self
.session_file
.file_stem()
.and_then(|s| s.to_str())
.map(|s| s.to_string());
if let Err(e) = crate::project_tracker::update_project_tracker(
Some(&content),
session_id.as_deref(),
) {
eprintln!("Warning: Failed to update project tracker with instruction: {}", e);
}
// Get the provider from the agent for description generation
let provider = self.agent.provider().await?;
@@ -503,6 +548,62 @@ impl Session {
}
}
continue;
}
InputResult::Summarize => {
save_history(&mut editor);
let prompt = "Are you sure you want to summarize this conversation? This will condense the message history.";
let should_summarize =
match cliclack::confirm(prompt).initial_value(true).interact() {
Ok(choice) => choice,
Err(e) => {
if e.kind() == std::io::ErrorKind::Interrupted {
false // If interrupted, set should_summarize to false
} else {
return Err(e.into());
}
}
};
if should_summarize {
println!("{}", console::style("Summarizing conversation...").yellow());
output::show_thinking();
// Get the provider for summarization
let provider = self.agent.provider().await?;
// Call the summarize_context method which uses the summarize_messages function
let (summarized_messages, _) =
self.agent.summarize_context(&self.messages).await?;
// Update the session messages with the summarized ones
self.messages = summarized_messages;
// Persist the summarized messages
session::persist_messages(
&self.session_file,
&self.messages,
Some(provider),
)
.await?;
output::hide_thinking();
println!(
"{}",
console::style("Conversation has been summarized.").green()
);
println!(
"{}",
console::style(
"Key information has been preserved while reducing context length."
)
.green()
);
} else {
println!("{}", console::style("Summarization cancelled.").yellow());
}
continue;
}
}
@@ -532,10 +633,21 @@ impl Session {
match planner_response_type {
PlannerResponseType::Plan => {
println!();
let should_act =
cliclack::confirm("Do you want to clear message history & act on this plan?")
.initial_value(true)
.interact()?;
let should_act = match cliclack::confirm(
"Do you want to clear message history & act on this plan?",
)
.initial_value(true)
.interact()
{
Ok(choice) => choice,
Err(e) => {
if e.kind() == std::io::ErrorKind::Interrupted {
false // If interrupted, set should_act to false
} else {
return Err(e.into());
}
}
};
if should_act {
output::render_act_on_plan();
self.run_mode = RunMode::Normal;
@@ -596,6 +708,7 @@ impl Session {
id: session_id.clone(),
working_dir: std::env::current_dir()
.expect("failed to get current session working directory"),
schedule_id: None,
}),
)
.await?;
@@ -614,51 +727,99 @@ impl Session {
let prompt = "Goose would like to call the above tool, do you allow?".to_string();
// Get confirmation from user
let permission = cliclack::select(prompt)
let permission_result = cliclack::select(prompt)
.item(Permission::AllowOnce, "Allow", "Allow the tool call once")
.item(Permission::AlwaysAllow, "Always Allow", "Always allow the tool call")
.item(Permission::DenyOnce, "Deny", "Deny the tool call")
.interact()?;
self.agent.handle_confirmation(confirmation.id.clone(), PermissionConfirmation {
principal_type: PrincipalType::Tool,
permission,
},).await;
.item(Permission::Cancel, "Cancel", "Cancel the AI response and tool call")
.interact();
let permission = match permission_result {
Ok(p) => p, // If Ok, use the selected permission
Err(e) => {
// Check if the error is an interruption (Ctrl+C/Cmd+C, Escape)
if e.kind() == std::io::ErrorKind::Interrupted {
Permission::Cancel // If interrupted, set permission to Cancel
} else {
return Err(e.into()); // Otherwise, convert and propagate the original error
}
}
};
if permission == Permission::Cancel {
output::render_text("Tool call cancelled. Returning to chat...", Some(Color::Yellow), true);
let mut response_message = Message::user();
response_message.content.push(MessageContent::tool_response(
confirmation.id.clone(),
Err(ToolError::ExecutionError("Tool call cancelled by user".to_string()))
));
self.messages.push(response_message);
session::persist_messages(&self.session_file, &self.messages, None).await?;
drop(stream);
break;
} else {
self.agent.handle_confirmation(confirmation.id.clone(), PermissionConfirmation {
principal_type: PrincipalType::Tool,
permission,
},).await;
}
} else if let Some(MessageContent::ContextLengthExceeded(_)) = message.content.first() {
output::hide_thinking();
let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
let selected = cliclack::select(prompt)
.item("clear", "Clear Session", "Removes all messages from Goose's memory")
.item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
.item("summarize", "Summarize Session", "Summarize the session to reduce context length")
.interact()?;
if interactive {
// In interactive mode, ask the user what to do
let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
let selected_result = cliclack::select(prompt)
.item("clear", "Clear Session", "Removes all messages from Goose's memory")
.item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
.item("summarize", "Summarize Session", "Summarize the session to reduce context length")
.item("cancel", "Cancel", "Cancel and return to chat")
.interact();
match selected {
"clear" => {
self.messages.clear();
let msg = format!("Session cleared.\n{}", "-".repeat(50));
output::render_text(&msg, Some(Color::Yellow), true);
break; // exit the loop to hand back control to the user
}
"truncate" => {
// Truncate messages to fit within context length
let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
let msg = format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50));
output::render_text("", Some(Color::Yellow), true);
output::render_text(&msg, Some(Color::Yellow), true);
self.messages = truncated_messages;
}
"summarize" => {
// Summarize messages to fit within context length
let (summarized_messages, _) = self.agent.summarize_context(&self.messages).await?;
let msg = format!("Context maxed out\n{}\nGoose summarized messages for you.", "-".repeat(50));
output::render_text(&msg, Some(Color::Yellow), true);
self.messages = summarized_messages;
}
_ => {
unreachable!()
let selected = match selected_result {
Ok(s) => s,
Err(e) => {
if e.kind() == std::io::ErrorKind::Interrupted {
"cancel" // If interrupted, set selected to cancel
} else {
return Err(e.into());
}
}
};
match selected {
"clear" => {
self.messages.clear();
let msg = format!("Session cleared.\n{}", "-".repeat(50));
output::render_text(&msg, Some(Color::Yellow), true);
break; // exit the loop to hand back control to the user
}
"truncate" => {
// Truncate messages to fit within context length
let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
let msg = format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50));
output::render_text("", Some(Color::Yellow), true);
output::render_text(&msg, Some(Color::Yellow), true);
self.messages = truncated_messages;
}
"summarize" => {
// Use the helper function to summarize context
Self::summarize_context_messages(&mut self.messages, &self.agent, "Goose summarized messages for you.").await?;
}
"cancel" => {
break; // Return to main prompt
}
_ => {
unreachable!()
}
}
} else {
// In headless mode (goose run), automatically use summarize
Self::summarize_context_messages(&mut self.messages, &self.agent, "Goose automatically summarized messages to continue processing.").await?;
}
// Restart the stream after handling ContextLengthExceeded
stream = self
.agent
@@ -668,6 +829,7 @@ impl Session {
id: session_id.clone(),
working_dir: std::env::current_dir()
.expect("failed to get current session working directory"),
schedule_id: None,
}),
)
.await?;
@@ -852,6 +1014,31 @@ impl Session {
self.messages.clone()
}
/// Render all past messages from the session history
pub fn render_message_history(&self) {
if self.messages.is_empty() {
return;
}
// Print session restored message
println!(
"\n{} {} messages loaded into context.",
console::style("Session restored:").green().bold(),
console::style(self.messages.len()).green()
);
// Render each message
for message in &self.messages {
output::render_message(message, self.debug);
}
// Add a visual separator after restored messages
println!(
"\n{}\n",
console::style("──────── New Messages ────────").dim()
);
}
/// Get the session metadata
pub fn get_metadata(&self) -> Result<session::SessionMetadata> {
if !self.session_file.exists() {
@@ -976,29 +1163,30 @@ fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
use goose::model::ModelConfig;
use goose::providers::create;
let (reasoner_provider, reasoner_model) = match (
std::env::var("GOOSE_PLANNER_PROVIDER"),
std::env::var("GOOSE_PLANNER_MODEL"),
) {
(Ok(provider), Ok(model)) => (provider, model),
_ => {
println!(
"WARNING: GOOSE_PLANNER_PROVIDER or GOOSE_PLANNER_MODEL is not set. \
Using default model from config..."
);
let config = Config::global();
let provider = config
.get_param("GOOSE_PROVIDER")
.expect("No provider configured. Run 'goose configure' first");
let model = config
.get_param("GOOSE_MODEL")
.expect("No model configured. Run 'goose configure' first");
(provider, model)
}
let config = Config::global();
// Try planner-specific provider first, fallback to default provider
let provider = if let Ok(provider) = config.get_param::<String>("GOOSE_PLANNER_PROVIDER") {
provider
} else {
println!("WARNING: GOOSE_PLANNER_PROVIDER not found. Using default provider...");
config
.get_param::<String>("GOOSE_PROVIDER")
.expect("No provider configured. Run 'goose configure' first")
};
let model_config = ModelConfig::new(reasoner_model);
let reasoner = create(&reasoner_provider, model_config)?;
// Try planner-specific model first, fallback to default model
let model = if let Ok(model) = config.get_param::<String>("GOOSE_PLANNER_MODEL") {
model
} else {
println!("WARNING: GOOSE_PLANNER_MODEL not found. Using default model...");
config
.get_param::<String>("GOOSE_MODEL")
.expect("No model configured. Run 'goose configure' first")
};
let model_config = ModelConfig::new(model);
let reasoner = create(&provider, model_config)?;
Ok(reasoner)
}

View File

@@ -209,7 +209,7 @@ fn render_tool_response(resp: &ToolResponse, theme: Theme, debug: bool) {
let min_priority = config
.get_param::<f32>("GOOSE_CLI_MIN_PRIORITY")
.ok()
.unwrap_or(0.0);
.unwrap_or(0.5);
if content
.priority()
@@ -405,9 +405,15 @@ fn print_markdown(content: &str, theme: Theme) {
.unwrap();
}
const MAX_STRING_LENGTH: usize = 40;
const INDENT: &str = " ";
fn get_tool_params_max_length() -> usize {
Config::global()
.get_param::<usize>("GOOSE_CLI_TOOL_PARAMS_TRUNCATION_MAX_LENGTH")
.ok()
.unwrap_or(40)
}
fn print_params(value: &Value, depth: usize, debug: bool) {
let indent = INDENT.repeat(depth);
@@ -427,7 +433,7 @@ fn print_params(value: &Value, depth: usize, debug: bool) {
}
}
Value::String(s) => {
if !debug && s.len() > MAX_STRING_LENGTH {
if !debug && s.len() > get_tool_params_max_length() {
println!("{}{}: {}", indent, style(key).dim(), style("...").dim());
} else {
println!("{}{}: {}", indent, style(key).dim(), style(s).green());
@@ -452,7 +458,7 @@ fn print_params(value: &Value, depth: usize, debug: bool) {
}
}
Value::String(s) => {
if !debug && s.len() > MAX_STRING_LENGTH {
if !debug && s.len() > get_tool_params_max_length() {
println!(
"{}{}",
indent,
@@ -527,6 +533,8 @@ fn shorten_path(path: &str, debug: bool) -> String {
pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) {
let start_session_msg = if resume {
"resuming session |"
} else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") {
"running without session |"
} else {
"starting session |"
};
@@ -538,11 +546,15 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
style("model:").dim(),
style(model).cyan().dim(),
);
println!(
" {} {}",
style("logging to").dim(),
style(session_file.display()).dim().cyan(),
);
if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") {
println!(
" {} {}",
style("logging to").dim(),
style(session_file.display()).dim().cyan(),
);
}
println!(
" {} {}",
style("working directory:").dim(),

View File

@@ -0,0 +1,36 @@
use std::future::Future;
use std::pin::Pin;
use tokio::signal;
#[cfg(unix)]
pub fn shutdown_signal() -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(async move {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
})
}
#[cfg(not(unix))]
pub fn shutdown_signal() -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(async move {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
})
}

View File

@@ -16,10 +16,10 @@ To build the FFI library, you'll need Rust and Cargo installed. Then run:
```bash
# Build the library in debug mode
cargo build --package goose_ffi
cargo build --package goose-ffi
# Build the library in release mode (recommended for production)
cargo build --release --package goose_ffi
cargo build --release --package goose-ffi
```
This will generate a dynamic library (.so, .dll, or .dylib depending on your platform) in the `target` directory, and automatically generate the C header file in the `include` directory.
@@ -54,7 +54,7 @@ To run the Python example:
```bash
# First, build the FFI library
cargo build --release --package goose_ffi
cargo build --release --package goose-ffi
# Then set the environment variables & run the example
DATABRICKS_HOST=... DATABRICKS_API_KEY=... python crates/goose-ffi/examples/goose_agent.py

View File

@@ -0,0 +1,61 @@
[package]
name = "goose-llm"
edition.workspace = true
version.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
description.workspace = true
[lib]
crate-type = ["lib", "cdylib"]
name = "goose_llm"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
thiserror = "1.0"
minijinja = "2.8.0"
include_dir = "0.7.4"
once_cell = "1.20.2"
chrono = { version = "0.4.38", features = ["serde"] }
reqwest = { version = "0.12.9", features = [
"rustls-tls-native-roots",
"json",
"cookies",
"gzip",
"brotli",
"deflate",
"zstd",
"charset",
"http2",
"stream"
], default-features = false }
async-trait = "0.1"
url = "2.5"
base64 = "0.21"
regex = "1.11.1"
tracing = "0.1"
smallvec = { version = "1.13", features = ["serde"] }
indoc = "1.0"
# https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/fixtures/futures/Cargo.toml#L22
uniffi = { version = "0.29", features = ["tokio", "cli", "scaffolding-ffi-buffer-fns"] }
tokio = { version = "1.43", features = ["time", "sync"] }
[dev-dependencies]
criterion = "0.5"
tempfile = "3.15.0"
dotenv = "0.15"
lazy_static = "1.5"
ctor = "0.2.7"
tokio = { version = "1.43", features = ["full"] }
[[bin]]
# https://mozilla.github.io/uniffi-rs/latest/tutorial/foreign_language_bindings.html
name = "uniffi-bindgen"
path = "uniffi-bindgen.rs"
[[example]]
name = "simple"
path = "examples/simple.rs"

View File

@@ -0,0 +1,68 @@
## goose-llm
This crate is intended for use over a foreign function interface (FFI). It is stateless
and contains logic related to providers and prompts:
- chat completion with model providers
- detecting read-only tools for smart approval
- methods for summarization / truncation
Run:
```bash
cargo run -p goose-llm --example simple
```
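For orientation, here is a minimal sketch of driving the same API directly from Rust, condensed from `examples/simple.rs` in this crate. The Databricks host/token environment variables and the model name are placeholders for whichever provider you configure:

```rust
use goose_llm::{
    completion,
    types::completion::{CompletionRequest, ExtensionConfig},
    Message, ModelConfig,
};
use serde_json::json;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Placeholder provider configuration -- swap in your own provider details.
    let provider_config = json!({
        "host": std::env::var("DATABRICKS_HOST")?,
        "token": std::env::var("DATABRICKS_TOKEN")?,
    });
    let model_config = ModelConfig::new("claude-3-5-haiku".to_string());

    let messages = vec![Message::user().with_text("Say hello in one sentence.")];
    let extensions: Vec<ExtensionConfig> = vec![]; // ToolConfig-bearing extensions go here

    let response = completion(CompletionRequest::new(
        "databricks".to_string(),
        provider_config,
        model_config,
        "You are a helpful assistant.".to_string(),
        messages,
        extensions,
    ))
    .await?;

    println!("{}", serde_json::to_string_pretty(&response)?);
    Ok(())
}
```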
## Kotlin bindings
Structure:
```
.
├── crates
│   └── goose-llm/...
├── target
│   └── debug/libgoose_llm.dylib
└── bindings
    └── kotlin
        ├── example
        │   └── Usage.kt              ← your demo app
        └── uniffi
            └── goose_llm
                └── goose_llm.kt      ← auto-generated bindings
```
#### Kotlin -> Rust: run example
The following `just` command creates the Kotlin bindings, then compiles and runs an example.
```bash
just kotlin-example
```
You will need to download the required jars into the `bindings/kotlin/libs` directory (only the first time):
```bash
pushd bindings/kotlin/libs/
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlin/kotlin-stdlib/1.9.0/kotlin-stdlib-1.9.0.jar
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlinx/kotlinx-coroutines-core-jvm/1.7.3/kotlinx-coroutines-core-jvm-1.7.3.jar
curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.jar
popd
```
To generate only the Kotlin bindings:
```bash
# run from project root directory
cargo build -p goose-llm
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
```
#### Python -> Rust: generate bindings, run example
```bash
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language python --out-dir bindings/python
DYLD_LIBRARY_PATH=./target/debug python bindings/python/usage.py
```

View File

@@ -0,0 +1,123 @@
use std::vec;
use anyhow::Result;
use goose_llm::{
completion,
extractors::generate_tooltip,
types::completion::{
CompletionRequest, CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig,
},
Message, ModelConfig,
};
use serde_json::json;
#[tokio::main]
async fn main() -> Result<()> {
let provider = "databricks";
let provider_config = json!({
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
// let model_name = "goose-gpt-4-1"; // parallel tool calls
let model_name = "claude-3-5-haiku";
let model_config = ModelConfig::new(model_name.to_string());
let calculator_tool = ToolConfig::new(
"calculator",
"Perform basic arithmetic operations",
json!({
"type": "object",
"required": ["operation", "numbers"],
"properties": {
"operation": {
"type": "string",
"enum": ["add", "subtract", "multiply", "divide"],
"description": "The arithmetic operation to perform",
},
"numbers": {
"type": "array",
"items": {"type": "number"},
"description": "List of numbers to operate on in order",
}
}
}),
ToolApprovalMode::Auto,
);
let bash_tool = ToolConfig::new(
"bash_shell",
"Run a shell command",
json!({
"type": "object",
"required": ["command"],
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute"
}
}
}),
ToolApprovalMode::Manual,
);
let list_dir_tool = ToolConfig::new(
"list_directory",
"List files in a directory",
json!({
"type": "object",
"required": ["path"],
"properties": {
"path": {
"type": "string",
"description": "The directory path to list files from"
}
}
}),
ToolApprovalMode::Auto,
);
let extensions = vec![
ExtensionConfig::new(
"calculator_extension".to_string(),
Some("This extension provides a calculator tool.".to_string()),
vec![calculator_tool],
),
ExtensionConfig::new(
"bash_extension".to_string(),
Some("This extension provides a bash shell tool.".to_string()),
vec![bash_tool, list_dir_tool],
),
];
let system_preamble = "You are a helpful assistant.";
for text in [
"Add 10037 + 23123 using calculator and also run 'date -u' using bash",
"List all files in the current directory",
] {
println!("\n---------------\n");
println!("User Input: {text}");
let messages = vec![
Message::user().with_text("Hi there!"),
Message::assistant().with_text("How can I help?"),
Message::user().with_text(text),
];
let completion_response: CompletionResponse = completion(CompletionRequest::new(
provider.to_string(),
provider_config.clone(),
model_config.clone(),
system_preamble.to_string(),
messages.clone(),
extensions.clone(),
))
.await?;
// Print the response
println!("\nCompletion Response:");
println!("{}", serde_json::to_string_pretty(&completion_response)?);
let tooltip = generate_tooltip(provider, provider_config.clone().into(), &messages).await?;
println!("\nTooltip: {}", tooltip);
}
Ok(())
}

View File

@@ -0,0 +1,144 @@
use std::{collections::HashMap, time::Instant};
use anyhow::Result;
use chrono::Utc;
use serde_json::Value;
use crate::{
message::{Message, MessageContent},
prompt_template,
providers::create,
types::{
completion::{
CompletionError, CompletionRequest, CompletionResponse, ExtensionConfig,
RuntimeMetrics, ToolApprovalMode, ToolConfig,
},
core::ToolCall,
},
};
#[uniffi::export]
pub fn print_messages(messages: Vec<Message>) {
for msg in messages {
println!("[{:?} @ {}] {:?}", msg.role, msg.created, msg.content);
}
}
/// Public API for the Goose LLM completion function
#[uniffi::export(async_runtime = "tokio")]
pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, CompletionError> {
let start_total = Instant::now();
let provider = create(
&req.provider_name,
req.provider_config.clone(),
req.model_config.clone(),
)
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
let system_prompt = construct_system_prompt(&req.system_preamble, &req.extensions)?;
let tools = collect_prefixed_tools(&req.extensions);
// Call the LLM provider
let start_provider = Instant::now();
let mut response = provider
.complete(&system_prompt, &req.messages, &tools)
.await?;
let provider_elapsed_sec = start_provider.elapsed().as_secs_f32();
let usage_tokens = response.usage.total_tokens;
let tool_configs = collect_prefixed_tool_configs(&req.extensions);
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?;
Ok(CompletionResponse::new(
response.message,
response.model,
response.usage,
calculate_runtime_metrics(start_total, provider_elapsed_sec, usage_tokens),
))
}
/// Render the global `system.md` template with the provided context.
fn construct_system_prompt(
system_preamble: &str,
extensions: &[ExtensionConfig],
) -> Result<String, CompletionError> {
let mut context: HashMap<&str, Value> = HashMap::new();
context.insert("system_preamble", Value::String(system_preamble.to_owned()));
context.insert("extensions", serde_json::to_value(extensions)?);
context.insert(
"current_date",
Value::String(Utc::now().format("%Y-%m-%d").to_string()),
);
Ok(prompt_template::render_global_file("system.md", &context)?)
}
/// Determine if a tool call requires manual approval.
fn determine_needs_approval(config: &ToolConfig, _call: &ToolCall) -> bool {
match config.approval_mode {
ToolApprovalMode::Auto => false,
ToolApprovalMode::Manual => true,
ToolApprovalMode::Smart => {
// TODO: Implement smart approval logic later
true
}
}
}
/// Set `needs_approval` on every tool call in the message.
/// Returns a `ToolNotFound` error if the corresponding `ToolConfig` is missing.
pub fn update_needs_approval_for_tool_calls(
message: &mut Message,
tool_configs: &HashMap<String, ToolConfig>,
) -> Result<(), CompletionError> {
for content in &mut message.content.iter_mut() {
if let MessageContent::ToolReq(req) = content {
if let Ok(call) = &mut req.tool_call.0 {
// Provide a clear error message when the tool config is missing
let config = tool_configs.get(&call.name).ok_or_else(|| {
CompletionError::ToolNotFound(format!(
"could not find tool config for '{}'",
call.name
))
})?;
let needs_approval = determine_needs_approval(config, call);
call.set_needs_approval(needs_approval);
}
}
}
Ok(())
}
/// Collect all `Tool` instances from the extensions.
fn collect_prefixed_tools(extensions: &[ExtensionConfig]) -> Vec<crate::types::core::Tool> {
extensions
.iter()
.flat_map(|ext| ext.get_prefixed_tools())
.collect()
}
/// Collect all `ToolConfig` entries from the extensions into a map.
fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap<String, ToolConfig> {
extensions
.iter()
.flat_map(|ext| ext.get_prefixed_tool_configs())
.collect()
}
/// Compute runtime metrics for the request.
fn calculate_runtime_metrics(
total_start: Instant,
provider_elapsed_sec: f32,
token_count: Option<i32>,
) -> RuntimeMetrics {
let total_ms = total_start.elapsed().as_secs_f32();
let tokens_per_sec = token_count.and_then(|toks| {
if provider_elapsed_sec > 0.0 {
Some(toks as f64 / (provider_elapsed_sec as f64))
} else {
None
}
});
RuntimeMetrics::new(total_ms, provider_elapsed_sec, tokens_per_sec)
}

View File

@@ -0,0 +1,5 @@
mod session_name;
mod tooltip;
pub use session_name::generate_session_name;
pub use tooltip::generate_tooltip;

View File

@@ -0,0 +1,111 @@
use crate::generate_structured_outputs;
use crate::providers::errors::ProviderError;
use crate::types::core::Role;
use crate::{message::Message, types::json_value_ffi::JsonValueFfi};
use anyhow::Result;
use indoc::indoc;
use serde_json::{json, Value};
const SESSION_NAME_EXAMPLES: &[&str] = &[
"Research Synthesis",
"Sentiment Analysis",
"Performance Report",
"Feedback Collector",
"Accessibility Check",
"Design Reminder",
"Project Reminder",
"Launch Checklist",
"Metrics Monitor",
"Incident Response",
"Deploy Cabinet App",
"Design Reminder Alert",
"Generate Monthly Expense Report",
"Automate Incident Response Workflow",
"Analyze Brand Sentiment Trends",
"Monitor Device Health Issues",
"Collect UI Feedback Summary",
"Schedule Project Deadline Reminders",
];
fn build_system_prompt() -> String {
let examples = SESSION_NAME_EXAMPLES
.iter()
.map(|e| format!("- {}", e))
.collect::<Vec<_>>()
.join("\n");
indoc! {r#"
You are an assistant that crafts a concise session title.
Given the first couple of user messages in the conversation so far,
reply with only a short name (up to 4 words) that best describes
this session's goal.
Examples:
"#}
.to_string()
+ &examples
}
/// Generates a short (≤4 words) session name
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_session_name(
provider_name: &str,
provider_config: JsonValueFfi,
messages: &[Message],
) -> Result<String, ProviderError> {
// Collect up to the first 3 user messages (truncated to 300 chars each)
let context: Vec<String> = messages
.iter()
.filter(|m| m.role == Role::User)
.take(3)
.map(|m| {
let text = m.content.concat_text_str();
if text.len() > 300 {
text.chars().take(300).collect()
} else {
text
}
})
.collect();
if context.is_empty() {
return Err(ProviderError::ExecutionError(
"No user messages found to generate a session name.".to_string(),
));
}
let system_prompt = build_system_prompt();
let user_msg_text = format!("Here are the user messages:\n{}", context.join("\n"));
// Use `extract` with a simple string schema
let schema = json!({
"type": "object",
"properties": {
"name": { "type": "string" }
},
"required": ["name"],
"additionalProperties": false
});
let resp = generate_structured_outputs(
provider_name,
provider_config,
&system_prompt,
&[Message::user().with_text(&user_msg_text)],
schema,
)
.await?;
let obj = resp
.data
.as_object()
.ok_or_else(|| ProviderError::ResponseParseError("Expected object".into()))?;
let name = obj
.get("name")
.and_then(Value::as_str)
.ok_or_else(|| ProviderError::ResponseParseError("Missing or non-string name".into()))?
.to_string();
Ok(name)
}

View File

@@ -0,0 +1,169 @@
use crate::generate_structured_outputs;
use crate::message::{Message, MessageContent};
use crate::providers::errors::ProviderError;
use crate::types::core::{Content, Role};
use crate::types::json_value_ffi::JsonValueFfi;
use anyhow::Result;
use indoc::indoc;
use serde_json::{json, Value};
const TOOLTIP_EXAMPLES: &[&str] = &[
"analyzing KPIs",
"detecting anomalies",
"building artifacts in Buildkite",
"categorizing issues",
"checking dependencies",
"collecting feedback",
"deploying changes in AWS",
"drafting report in Google Docs",
"extracting action items",
"generating insights",
"logging issues",
"monitoring tickets in Zendesk",
"notifying design team",
"running integration tests",
"scanning threads in Figma",
"sending reminders in Gmail",
"sending surveys",
"sharing with stakeholders",
"summarizing findings",
"transcribing meeting",
"tracking resolution",
"updating status in Linear",
];
fn build_system_prompt() -> String {
let examples = TOOLTIP_EXAMPLES
.iter()
.map(|e| format!("- {}", e))
.collect::<Vec<_>>()
.join("\n");
indoc! {r#"
You are an assistant that summarizes the recent conversation into a tooltip.
Given the last two messages, reply with only a short tooltip (up to 4 words)
describing what is happening now.
Examples:
"#}
.to_string()
+ &examples
}
/// Generates a tooltip summarizing the last two messages in the session,
/// including any tool calls or results.
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_tooltip(
provider_name: &str,
provider_config: JsonValueFfi,
messages: &[Message],
) -> Result<String, ProviderError> {
// Need at least two messages to generate a tooltip
if messages.len() < 2 {
return Err(ProviderError::ExecutionError(
"Need at least two messages to generate a tooltip".to_string(),
));
}
// Helper to render a single message's content
fn render_message(m: &Message) -> String {
let mut parts = Vec::new();
for content in m.content.iter() {
match content {
MessageContent::Text(text_block) => {
let txt = text_block.text.trim();
if !txt.is_empty() {
parts.push(txt.to_string());
}
}
MessageContent::ToolReq(req) => {
if let Ok(tool_call) = &req.tool_call.0 {
parts.push(format!(
"called tool '{}' with args {}",
tool_call.name, tool_call.arguments
));
} else if let Err(e) = &req.tool_call.0 {
parts.push(format!("tool request error: {}", e));
}
}
MessageContent::ToolResp(resp) => match &resp.tool_result.0 {
Ok(contents) => {
let results: Vec<String> = contents
.iter()
.map(|c| match c {
Content::Text(t) => t.text.clone(),
Content::Image(_) => "[image]".to_string(),
})
.collect();
parts.push(format!("tool responded with: {}", results.join(" ")));
}
Err(e) => {
parts.push(format!("tool error: {}", e));
}
},
_ => {} // ignore other variants
}
}
let role = match m.role {
Role::User => "User",
Role::Assistant => "Assistant",
};
format!("{}: {}", role, parts.join("; "))
}
// Take the last two messages (in correct chronological order)
let rendered: Vec<String> = messages
.iter()
.rev()
.take(2)
.map(render_message)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect();
let system_prompt = build_system_prompt();
let user_msg_text = format!(
"Here are the last two messages:\n{}\n\nTooltip:",
rendered.join("\n")
);
// Schema wrapping our tooltip string
let schema = json!({
"type": "object",
"properties": {
"tooltip": { "type": "string" }
},
"required": ["tooltip"],
"additionalProperties": false
});
// Get the structured outputs
let resp = generate_structured_outputs(
provider_name,
provider_config,
&system_prompt,
&[Message::user().with_text(&user_msg_text)],
schema,
)
.await?;
// Pull out the tooltip field
let obj = resp
.data
.as_object()
.ok_or_else(|| ProviderError::ResponseParseError("Expected JSON object".into()))?;
let tooltip = obj
.get("tooltip")
.and_then(Value::as_str)
.ok_or_else(|| {
ProviderError::ResponseParseError("Missing or non-string `tooltip` field".into())
})?
.to_string();
Ok(tooltip)
}

View File

@@ -0,0 +1,15 @@
uniffi::setup_scaffolding!();
mod completion;
pub mod extractors;
pub mod message;
mod model;
mod prompt_template;
pub mod providers;
mod structured_outputs;
pub mod types;
pub use completion::completion;
pub use message::Message;
pub use model::ModelConfig;
pub use structured_outputs::generate_structured_outputs;

View File

@@ -0,0 +1,84 @@
use std::{iter::FromIterator, ops::Deref};
use crate::message::MessageContent;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
/// Holds the heterogeneous fragments that make up one chat message.
///
/// * Up to two items are stored inline on the stack.
/// * Falls back to a heap allocation only when necessary.
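///
/// A minimal, purely illustrative example (not compiled as a doc-test):
///
/// ```ignore
/// let mut contents = Contents::default();
/// contents.push(MessageContent::text("hello"));
/// contents.push(MessageContent::text("world"));
/// assert_eq!(contents.concat_text_str(), "hello\nworld");
/// ```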
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct Contents(SmallVec<[MessageContent; 2]>);
impl Contents {
/*----------------------------------------------------------
* 1-line ergonomic helpers
*---------------------------------------------------------*/
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, MessageContent> {
self.0.iter_mut()
}
pub fn push(&mut self, item: impl Into<MessageContent>) {
self.0.push(item.into());
}
pub fn texts(&self) -> impl Iterator<Item = &str> {
self.0.iter().filter_map(|c| c.as_text())
}
pub fn concat_text_str(&self) -> String {
self.texts().collect::<Vec<_>>().join("\n")
}
/// Returns `true` if *any* item satisfies the predicate.
pub fn any_is<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().any(pred)
}
/// Returns `true` if *every* item satisfies the predicate.
pub fn all_are<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().all(pred)
}
}
impl From<Vec<MessageContent>> for Contents {
fn from(v: Vec<MessageContent>) -> Self {
Contents(SmallVec::from_vec(v))
}
}
impl FromIterator<MessageContent> for Contents {
fn from_iter<I: IntoIterator<Item = MessageContent>>(iter: I) -> Self {
Contents(SmallVec::from_iter(iter))
}
}
/*--------------------------------------------------------------
* Allow &message.content to behave like a slice of fragments.
*-------------------------------------------------------------*/
impl Deref for Contents {
type Target = [MessageContent];
fn deref(&self) -> &Self::Target {
&self.0
}
}
// — Register the contents type with UniFFI, converting to/from Vec<MessageContent> —
// We need to do this because UniFFI's FFI layer cannot pass the SmallVec-backed type directly,
// so we lower to and lift from a plain Vec<MessageContent> at the boundary.
uniffi::custom_type!(Contents, Vec<MessageContent>, {
lower: |contents: &Contents| {
contents.0.to_vec()
},
try_lift: |contents: Vec<MessageContent>| {
Ok(Contents::from(contents))
},
});

View File

@@ -0,0 +1,240 @@
use serde::{Deserialize, Serialize};
use serde_json;
use crate::message::tool_result_serde;
use crate::types::core::{Content, ImageContent, TextContent, ToolCall, ToolResult};
// — Newtype wrappers (local structs) so we satisfy Rust's orphan rules —
// We need these because we can't implement UniFFI's FfiConverter directly on a type alias.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolRequestToolCall(#[serde(with = "tool_result_serde")] pub ToolResult<ToolCall>);
impl ToolRequestToolCall {
pub fn as_result(&self) -> &ToolResult<ToolCall> {
&self.0
}
}
impl std::ops::Deref for ToolRequestToolCall {
type Target = ToolResult<ToolCall>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Result<ToolCall, crate::types::core::ToolError>> for ToolRequestToolCall {
fn from(res: Result<ToolCall, crate::types::core::ToolError>) -> Self {
ToolRequestToolCall(res)
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResponseToolResult(
#[serde(with = "tool_result_serde")] pub ToolResult<Vec<Content>>,
);
impl ToolResponseToolResult {
pub fn as_result(&self) -> &ToolResult<Vec<Content>> {
&self.0
}
}
impl std::ops::Deref for ToolResponseToolResult {
type Target = ToolResult<Vec<Content>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Result<Vec<Content>, crate::types::core::ToolError>> for ToolResponseToolResult {
fn from(res: Result<Vec<Content>, crate::types::core::ToolError>) -> Self {
ToolResponseToolResult(res)
}
}
// — Register the newtypes with UniFFI, converting via JSON strings —
// UniFFI's FFI layer supports only primitive buffers (here String), so we JSON-serialize
// through our `tool_result_serde` to preserve the same success/error schema on both sides.
uniffi::custom_type!(ToolRequestToolCall, String, {
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
uniffi::custom_type!(ToolResponseToolResult, String, {
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ToolRequest {
pub id: String,
pub tool_call: ToolRequestToolCall,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
pub id: String,
pub tool_result: ToolResponseToolResult,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct ThinkingContent {
pub thinking: String,
pub signature: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct RedactedThinkingContent {
pub data: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
/// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MessageContent {
Text(TextContent),
Image(ImageContent),
ToolReq(ToolRequest),
ToolResp(ToolResponse),
Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent),
}
impl MessageContent {
pub fn text<S: Into<String>>(text: S) -> Self {
MessageContent::Text(TextContent { text: text.into() })
}
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
MessageContent::Image(ImageContent {
data: data.into(),
mime_type: mime_type.into(),
})
}
pub fn tool_request<S: Into<String>>(id: S, tool_call: ToolRequestToolCall) -> Self {
MessageContent::ToolReq(ToolRequest {
id: id.into(),
tool_call,
})
}
pub fn tool_response<S: Into<String>>(id: S, tool_result: ToolResponseToolResult) -> Self {
MessageContent::ToolResp(ToolResponse {
id: id.into(),
tool_result,
})
}
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
MessageContent::Thinking(ThinkingContent {
thinking: thinking.into(),
signature: signature.into(),
})
}
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
}
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
if let MessageContent::ToolReq(ref tool_request) = self {
Some(tool_request)
} else {
None
}
}
pub fn as_tool_response(&self) -> Option<&ToolResponse> {
if let MessageContent::ToolResp(ref tool_response) = self {
Some(tool_response)
} else {
None
}
}
pub fn as_tool_response_text(&self) -> Option<String> {
if let Some(tool_response) = self.as_tool_response() {
if let Ok(contents) = &tool_response.tool_result.0 {
let texts: Vec<String> = contents
.iter()
.filter_map(|content| content.as_text().map(String::from))
.collect();
if !texts.is_empty() {
return Some(texts.join("\n"));
}
}
}
None
}
pub fn as_tool_request_id(&self) -> Option<&str> {
if let Self::ToolReq(r) = self {
Some(&r.id)
} else {
None
}
}
pub fn as_tool_response_id(&self) -> Option<&str> {
if let Self::ToolResp(r) = self {
Some(&r.id)
} else {
None
}
}
/// Get the text content if this is a TextContent variant
pub fn as_text(&self) -> Option<&str> {
match self {
MessageContent::Text(text) => Some(&text.text),
_ => None,
}
}
/// Get the thinking content if this is a ThinkingContent variant
pub fn as_thinking(&self) -> Option<&ThinkingContent> {
match self {
MessageContent::Thinking(thinking) => Some(thinking),
_ => None,
}
}
/// Get the redacted thinking content if this is a RedactedThinkingContent variant
pub fn as_redacted_thinking(&self) -> Option<&RedactedThinkingContent> {
match self {
MessageContent::RedactedThinking(redacted) => Some(redacted),
_ => None,
}
}
pub fn is_text(&self) -> bool {
matches!(self, Self::Text(_))
}
pub fn is_image(&self) -> bool {
matches!(self, Self::Image(_))
}
pub fn is_tool_request(&self) -> bool {
matches!(self, Self::ToolReq(_))
}
pub fn is_tool_response(&self) -> bool {
matches!(self, Self::ToolResp(_))
}
}
impl From<Content> for MessageContent {
fn from(content: Content) -> Self {
match content {
Content::Text(text) => MessageContent::Text(text),
Content::Image(image) => MessageContent::Image(image),
}
}
}

View File

@@ -0,0 +1,284 @@
//! Messages which represent the content sent back and forth to the LLM provider
//!
//! We use these messages in the agent code, and interfaces which interact with
//! the agent. That lets us reuse message histories across different interfaces.
//!
//! The content of the messages uses MCP types to avoid additional conversions
//! when interacting with MCP servers.
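//!
//! A purely illustrative sketch of assembling messages with the builders below
//! (module paths are indicative only):
//!
//! ```ignore
//! use goose_llm::Message;
//! use goose_llm::types::core::ToolCall;
//! use serde_json::json;
//!
//! let question = Message::user().with_text("What's in my current directory?");
//! let reply = Message::assistant()
//!     .with_text("I'll check.")
//!     .with_tool_request("tool-1", Ok(ToolCall::new("list_directory", json!({"path": "."}))));
//! assert!(reply.contains_tool_call());
//! ```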
mod contents;
mod message_content;
mod tool_result_serde;
pub use contents::Contents;
pub use message_content::{
MessageContent, RedactedThinkingContent, ThinkingContent, ToolRequest, ToolRequestToolCall,
ToolResponse, ToolResponseToolResult,
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use crate::types::core::Role;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
pub role: Role,
pub created: i64,
pub content: Contents,
}
impl Message {
pub fn new(role: Role) -> Self {
Self {
role,
created: Utc::now().timestamp_millis(),
content: Contents::default(),
}
}
/// Create a new user message with the current timestamp
pub fn user() -> Self {
Self::new(Role::User)
}
/// Create a new assistant message with the current timestamp
pub fn assistant() -> Self {
Self::new(Role::Assistant)
}
/// Add any item that implements Into<MessageContent> to the message
pub fn with_content(mut self, item: impl Into<MessageContent>) -> Self {
self.content.push(item);
self
}
/// Add text content to the message
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
self.with_content(MessageContent::text(text))
}
/// Add image content to the message
pub fn with_image<S: Into<String>, T: Into<String>>(self, data: S, mime_type: T) -> Self {
self.with_content(MessageContent::image(data, mime_type))
}
/// Add a tool request to the message
pub fn with_tool_request<S: Into<String>, T: Into<ToolRequestToolCall>>(
self,
id: S,
tool_call: T,
) -> Self {
self.with_content(MessageContent::tool_request(id, tool_call.into()))
}
/// Add a tool response to the message
pub fn with_tool_response<S: Into<String>>(
self,
id: S,
result: ToolResponseToolResult,
) -> Self {
self.with_content(MessageContent::tool_response(id, result))
}
/// Add thinking content to the message
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
self,
thinking: S1,
signature: S2,
) -> Self {
self.with_content(MessageContent::thinking(thinking, signature))
}
/// Add redacted thinking content to the message
pub fn with_redacted_thinking<S: Into<String>>(self, data: S) -> Self {
self.with_content(MessageContent::redacted_thinking(data))
}
/// Check if the message is a tool call
pub fn contains_tool_call(&self) -> bool {
self.content.any_is(MessageContent::is_tool_request)
}
/// Check if the message is a tool response
pub fn contains_tool_response(&self) -> bool {
self.content.any_is(MessageContent::is_tool_response)
}
/// Check if the message contains only text content
pub fn has_only_text_content(&self) -> bool {
self.content.all_are(MessageContent::is_text)
}
/// Retrieves all tool `id` from ToolRequest messages
pub fn tool_request_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_request_id)
.collect()
}
/// Retrieves all tool `id` from ToolResponse messages
pub fn tool_response_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_response_id)
.collect()
}
/// Retrieves all tool `id` from the message
pub fn tool_ids(&self) -> HashSet<&str> {
self.tool_request_ids()
.into_iter()
.chain(self.tool_response_ids())
.collect()
}
}
#[cfg(test)]
mod tests {
use serde_json::{json, Value};
use super::*;
use crate::types::core::{ToolCall, ToolError};
#[test]
fn test_message_serialization() {
let message = Message::assistant()
.with_text("Hello, I'll help you with that.")
.with_tool_request(
"tool123",
Ok(ToolCall::new("test_tool", json!({"param": "value"})).into()),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized message: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
println!(
"Read back serialized message: {}",
serde_json::to_string_pretty(&value).unwrap()
);
// Check top-level fields
assert_eq!(value["role"], "assistant");
assert!(value["created"].is_i64());
assert!(value["content"].is_array());
// Check content items
let content = &value["content"];
// First item should be text
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
// Second item should be toolRequest
assert_eq!(content[1]["type"], "toolReq");
assert_eq!(content[1]["id"], "tool123");
// Check tool_call serialization
assert_eq!(content[1]["toolCall"]["status"], "success");
assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
assert_eq!(
content[1]["toolCall"]["value"]["arguments"]["param"],
"value"
);
}
#[test]
fn test_error_serialization() {
let message = Message::assistant().with_tool_request(
"tool123",
Err(ToolError::ExecutionError(
"Something went wrong".to_string(),
)),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized error: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
// Check tool_call serialization with error
let tool_call = &value["content"][0]["toolCall"];
assert_eq!(tool_call["status"], "error");
assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
}
#[test]
fn test_deserialization() {
// Create a JSON string with our new format
let json_str = r#"{
"role": "assistant",
"created": 1740171566,
"content": [
{
"type": "text",
"text": "I'll help you with that."
},
{
"type": "toolReq",
"id": "tool123",
"toolCall": {
"status": "success",
"value": {
"name": "test_tool",
"arguments": {"param": "value"},
"needsApproval": false
}
}
}
]
}"#;
let message: Message = serde_json::from_str(json_str).unwrap();
assert_eq!(message.role, Role::Assistant);
assert_eq!(message.created, 1740171566);
assert_eq!(message.content.len(), 2);
// Check first content item
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "I'll help you with that.");
} else {
panic!("Expected Text content");
}
// Check second content item
if let MessageContent::ToolReq(req) = &message.content[1] {
assert_eq!(req.id, "tool123");
if let Ok(tool_call) = req.tool_call.as_result() {
assert_eq!(tool_call.name, "test_tool");
assert_eq!(tool_call.arguments, json!({"param": "value"}));
} else {
panic!("Expected successful tool call");
}
} else {
panic!("Expected ToolRequest content");
}
}
#[test]
fn test_message_with_text() {
let message = Message::user().with_text("Hello");
assert_eq!(message.content.concat_text_str(), "Hello");
}
#[test]
fn test_message_with_tool_request() {
let tool_call = Ok(ToolCall::new("test_tool", json!({})));
let message = Message::assistant().with_tool_request("req1", tool_call);
assert!(message.contains_tool_call());
assert!(!message.contains_tool_response());
let ids = message.tool_ids();
assert_eq!(ids.len(), 1);
assert!(ids.contains("req1"));
}
}

View File

@@ -0,0 +1,64 @@
use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
use crate::types::core::{ToolError, ToolResult};
pub fn serialize<T, S>(value: &ToolResult<T>, serializer: S) -> Result<S::Ok, S::Error>
where
T: Serialize,
S: Serializer,
{
match value {
Ok(val) => {
let mut state = serializer.serialize_struct("ToolResult", 2)?;
state.serialize_field("status", "success")?;
state.serialize_field("value", val)?;
state.end()
}
Err(err) => {
let mut state = serializer.serialize_struct("ToolResult", 2)?;
state.serialize_field("status", "error")?;
state.serialize_field("error", &err.to_string())?;
state.end()
}
}
}
// Deserialization mirrors the "status"/"value"/"error" format produced by `serialize` above
pub fn deserialize<'de, T, D>(deserializer: D) -> Result<ToolResult<T>, D::Error>
where
T: Deserialize<'de>,
D: Deserializer<'de>,
{
// Define a helper enum to handle the two possible formats
#[derive(Deserialize)]
#[serde(untagged)]
enum ResultFormat<T> {
Success { status: String, value: T },
Error { status: String, error: String },
}
let format = ResultFormat::deserialize(deserializer)?;
match format {
ResultFormat::Success { status, value } => {
if status == "success" {
Ok(Ok(value))
} else {
Err(serde::de::Error::custom(format!(
"Expected status 'success', got '{}'",
status
)))
}
}
ResultFormat::Error { status, error } => {
if status == "error" {
Ok(Err(ToolError::ExecutionError(error)))
} else {
Err(serde::de::Error::custom(format!(
"Expected status 'error', got '{}'",
status
)))
}
}
}
}
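// A minimal sketch of exercising this module directly; `serde_json::value::Serializer`
// and a `serde_json::Value` stand in for whatever (de)serializer the caller wires up
// through `#[serde(with = ...)]`.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::types::core::{ToolError, ToolResult};

    #[test]
    fn test_wire_format_round_trip() {
        // Ok(value) serializes to {"status": "success", "value": ...}
        let ok: ToolResult<String> = Ok("hi".to_string());
        let value = super::serialize(&ok, serde_json::value::Serializer).unwrap();
        assert_eq!(value["status"], "success");
        assert_eq!(value["value"], "hi");

        // Err(error) serializes to {"status": "error", "error": ...}
        let err: ToolResult<String> = Err(ToolError::ExecutionError("boom".to_string()));
        let value = super::serialize(&err, serde_json::value::Serializer).unwrap();
        assert_eq!(value["status"], "error");

        // Deserialization accepts the same shape back
        let back: ToolResult<String> =
            super::deserialize(json!({"status": "success", "value": "hi"})).unwrap();
        assert_eq!(back.unwrap(), "hi");
    }
}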

View File

@@ -0,0 +1,119 @@
use serde::{Deserialize, Serialize};
const DEFAULT_CONTEXT_LIMIT: u32 = 128_000;
/// Configuration for model-specific settings and limits
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct ModelConfig {
/// The name of the model to use
pub model_name: String,
/// Optional explicit context limit that overrides any defaults
pub context_limit: Option<u32>,
/// Optional temperature setting (0.0 - 1.0)
pub temperature: Option<f32>,
/// Optional maximum tokens to generate
pub max_tokens: Option<i32>,
}
impl ModelConfig {
/// Create a new ModelConfig with the specified model name
///
/// The context limit is set with the following precedence:
/// 1. Explicit context_limit if provided in config
/// 2. Model-specific default based on model name
/// 3. Global default (128_000), applied in `context_limit()`
pub fn new(model_name: String) -> Self {
let context_limit = Self::get_model_specific_limit(&model_name);
Self {
model_name,
context_limit,
temperature: None,
max_tokens: None,
}
}
/// Get model-specific context limit based on model name
fn get_model_specific_limit(model_name: &str) -> Option<u32> {
// Implement some sensible defaults
match model_name {
// OpenAI models, https://platform.openai.com/docs/models#models-overview
name if name.contains("gpt-4o") => Some(128_000),
name if name.contains("gpt-4-turbo") => Some(128_000),
// Anthropic models, https://docs.anthropic.com/en/docs/about-claude/models
name if name.contains("claude-3") => Some(200_000),
name if name.contains("claude-4") => Some(200_000),
// Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1
name if name.contains("llama3.2") => Some(128_000),
name if name.contains("llama3.3") => Some(128_000),
_ => None,
}
}
/// Set an explicit context limit
pub fn with_context_limit(mut self, limit: Option<u32>) -> Self {
// The default is None (which resolves to DEFAULT_CONTEXT_LIMIT); only overwrite
// when the input is Some, so callers can pass an optional limit straight through
if limit.is_some() {
self.context_limit = limit;
}
self
}
/// Set the temperature
pub fn with_temperature(mut self, temp: Option<f32>) -> Self {
self.temperature = temp;
self
}
/// Set the max tokens
pub fn with_max_tokens(mut self, tokens: Option<i32>) -> Self {
self.max_tokens = tokens;
self
}
/// Get the context_limit for the current model
/// If none is defined, use DEFAULT_CONTEXT_LIMIT
pub fn context_limit(&self) -> u32 {
self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_model_config_context_limits() {
// Test explicit limit
let config =
ModelConfig::new("claude-3-opus".to_string()).with_context_limit(Some(150_000));
assert_eq!(config.context_limit(), 150_000);
// Test model-specific defaults
let config = ModelConfig::new("claude-3-opus".to_string());
assert_eq!(config.context_limit(), 200_000);
let config = ModelConfig::new("gpt-4-turbo".to_string());
assert_eq!(config.context_limit(), 128_000);
// Test fallback to default
let config = ModelConfig::new("unknown-model".to_string());
assert_eq!(config.context_limit(), DEFAULT_CONTEXT_LIMIT);
}
#[test]
fn test_model_config_settings() {
let config = ModelConfig::new("test-model".to_string())
.with_temperature(Some(0.7))
.with_max_tokens(Some(1000))
.with_context_limit(Some(50_000));
assert_eq!(config.temperature, Some(0.7));
assert_eq!(config.max_tokens, Some(1000));
assert_eq!(config.context_limit, Some(50_000));
}
}

View File

@@ -0,0 +1,115 @@
use std::{
path::PathBuf,
sync::{Arc, RwLock},
};
use include_dir::{include_dir, Dir};
use minijinja::{Environment, Error as MiniJinjaError, Value as MJValue};
use once_cell::sync::Lazy;
use serde::Serialize;
/// This directory will be embedded into the final binary.
/// Typically used to store "core" or "system" prompts.
static CORE_PROMPTS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/prompts");
/// A global MiniJinja environment storing the "core" prompts.
///
/// - Loaded at startup from the `CORE_PROMPTS_DIR`.
/// - Ideal for "system" templates that don't change often.
/// - *Not* used for extension prompts (which are ephemeral).
static GLOBAL_ENV: Lazy<Arc<RwLock<Environment<'static>>>> = Lazy::new(|| {
let mut env = Environment::new();
// Pre-load all core templates from the embedded dir.
for file in CORE_PROMPTS_DIR.files() {
let name = file.path().to_string_lossy().to_string();
let source = String::from_utf8_lossy(file.contents()).to_string();
// Since we're using 'static lifetime for the Environment, we need to ensure
// the strings we add as templates live for the entire program duration.
// We can achieve this by leaking the strings (acceptable for initialization).
let static_name: &'static str = Box::leak(name.into_boxed_str());
let static_source: &'static str = Box::leak(source.into_boxed_str());
if let Err(e) = env.add_template(static_name, static_source) {
println!("Failed to add template {}: {}", static_name, e);
}
}
Arc::new(RwLock::new(env))
});
/// Renders a prompt from the global environment by name.
///
/// # Arguments
/// * `template_name` - The name of the template (usually the file path or a custom ID).
/// * `context_data` - Data to be inserted into the template (must be `Serialize`).
pub fn render_global_template<T: Serialize>(
template_name: &str,
context_data: &T,
) -> Result<String, MiniJinjaError> {
let env = GLOBAL_ENV.read().expect("GLOBAL_ENV lock poisoned");
let tmpl = env.get_template(template_name)?;
let ctx = MJValue::from_serialize(context_data);
let rendered = tmpl.render(ctx)?;
Ok(rendered.trim().to_string())
}
/// Renders a file from `CORE_PROMPTS_DIR` within the global environment.
///
/// # Arguments
/// * `template_file` - The file path within the embedded directory (e.g. "system.md").
/// * `context_data` - Data to be inserted into the template (must be `Serialize`).
///
/// This function **assumes** the file is already in `CORE_PROMPTS_DIR`. If it wasn't
/// added to the global environment at startup (due to parse errors, etc.), this will error out.
pub fn render_global_file<T: Serialize>(
template_file: impl Into<PathBuf>,
context_data: &T,
) -> Result<String, MiniJinjaError> {
let file_path = template_file.into();
let template_name = file_path.to_string_lossy().to_string();
render_global_template(&template_name, context_data)
}
#[cfg(test)]
mod tests {
use super::*;
/// For convenience in tests, define a small struct or use a HashMap to provide context.
#[derive(Serialize)]
struct TestContext {
name: String,
age: u32,
}
#[test]
fn test_global_file_render() {
// "mock.md" should exist in the embedded CORE_PROMPTS_DIR
// and have placeholders for `name` and `age`.
let context = TestContext {
name: "Alice".to_string(),
age: 30,
};
let result = render_global_file("mock.md", &context).unwrap();
// Assume mock.md content is something like:
// "This prompt is only used for testing.\n\nHello, {{ name }}! You are {{ age }} years old."
assert_eq!(
result,
"This prompt is only used for testing.\n\nHello, Alice! You are 30 years old."
);
}
#[test]
fn test_global_file_not_found() {
let context = TestContext {
name: "Unused".to_string(),
age: 99,
};
let result = render_global_file("non_existent.md", &context);
assert!(result.is_err(), "Should fail because file is missing");
}
}

View File

@@ -0,0 +1,3 @@
This prompt is only used for testing.

Hello, {{ name }}! You are {{ age }} years old.

View File

@@ -0,0 +1,34 @@
{{system_preamble}}
The current date is {{current_date}}.
Goose uses LLM providers with tool calling capability. You can be used with different language models (gpt-4o, claude-3.5-sonnet, o1, llama-3.2, deepseek-r1, etc).
These models have varying knowledge cut-off dates depending on when they were trained, but typically it's between 5-10 months prior to the current date.
# Extensions
Extensions allow other applications to provide context to Goose. Extensions connect Goose to different data sources and tools.
{% if (extensions is defined) and extensions %}
Because you dynamically load extensions, your conversation history may refer
to interactions with extensions that are not currently active. The currently
active extensions are below. Each of these extensions provides tools that are
in your tool specification.
{% for extension in extensions %}
## {{extension.name}}
{% if extension.instructions %}### Instructions
{{extension.instructions}}{% endif %}
{% endfor %}
{% else %}
No extensions are defined. You should let the user know that they should add extensions.{% endif %}
# Response Guidelines
- Use Markdown formatting for all responses.
- Follow best practices for Markdown, including:
- Using headers for organization.
- Bullet points for lists.
- Links formatted correctly, either as linked text (e.g., [this is linked text](https://example.com)) or automatic links using angle brackets (e.g., <http://example.com/>).
- For code examples, use fenced code blocks by placing triple backticks (` ``` `) before and after the code. Include the language identifier after the opening backticks (e.g., ` ```python `) to enable syntax highlighting.
- Ensure clarity, conciseness, and proper formatting to enhance readability and usability.

View File

@@ -0,0 +1,131 @@
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use super::errors::ProviderError;
use crate::{message::Message, types::core::Tool};
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, uniffi::Record)]
pub struct Usage {
pub input_tokens: Option<i32>,
pub output_tokens: Option<i32>,
pub total_tokens: Option<i32>,
}
impl Usage {
pub fn new(
input_tokens: Option<i32>,
output_tokens: Option<i32>,
total_tokens: Option<i32>,
) -> Self {
Self {
input_tokens,
output_tokens,
total_tokens,
}
}
}
#[derive(Debug, Clone, uniffi::Record)]
pub struct ProviderCompleteResponse {
pub message: Message,
pub model: String,
pub usage: Usage,
}
impl ProviderCompleteResponse {
pub fn new(message: Message, model: String, usage: Usage) -> Self {
Self {
message,
model,
usage,
}
}
}
/// Response from a structured extraction call
#[derive(Debug, Clone, uniffi::Record)]
pub struct ProviderExtractResponse {
/// The extracted JSON object
pub data: serde_json::Value,
/// Which model produced it
pub model: String,
/// Token usage stats
pub usage: Usage,
}
impl ProviderExtractResponse {
pub fn new(data: serde_json::Value, model: String, usage: Usage) -> Self {
Self { data, model, usage }
}
}
/// Base trait for AI providers (OpenAI, Anthropic, etc)
#[async_trait]
pub trait Provider: Send + Sync {
/// Generate the next message using the configured model and other parameters
///
/// # Arguments
/// * `system` - The system prompt that guides the model's behavior
/// * `messages` - The conversation history as a sequence of messages
/// * `tools` - Optional list of tools the model can use
///
/// # Returns
/// A tuple containing the model's response message and provider usage statistics
///
/// # Errors
/// ProviderError
/// - It's important to raise ContextLengthExceeded correctly since agent handles it
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<ProviderCompleteResponse, ProviderError>;
/// Structured extraction: always JSONSchema
///
/// # Arguments
/// * `system` - The system prompt guiding the extraction task
/// * `messages` - The conversation history
/// * `schema` - A JSONSchema for the expected output.
///   Will set strict=true for OpenAI & Databricks.
///
/// # Returns
/// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
///
/// # Errors
/// * `ProviderError::ContextLengthExceeded` if the prompt is too large
/// * other `ProviderError` variants for API/network failures
async fn extract(
&self,
system: &str,
messages: &[Message],
schema: &serde_json::Value,
) -> Result<ProviderExtractResponse, ProviderError>;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_usage_creation() {
let usage = Usage::new(Some(10), Some(20), Some(30));
assert_eq!(usage.input_tokens, Some(10));
assert_eq!(usage.output_tokens, Some(20));
assert_eq!(usage.total_tokens, Some(30));
}
#[test]
fn test_provider_complete_response_creation() {
let message = Message::user().with_text("Hello, world!");
let usage = Usage::new(Some(10), Some(20), Some(30));
let response =
ProviderCompleteResponse::new(message.clone(), "test_model".to_string(), usage.clone());
assert_eq!(response.message, message);
assert_eq!(response.model, "test_model");
assert_eq!(response.usage, usage);
}
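    // A minimal sketch of implementing `Provider` for a stubbed backend.
    // `EchoProvider` is purely illustrative and not part of the crate; the real
    // implementations live in the sibling `openai` and `databricks` modules.
    // Exercising it end-to-end would additionally require an async runtime.
    use async_trait::async_trait;

    use crate::providers::errors::ProviderError;
    use crate::types::core::Tool;

    #[allow(dead_code)]
    struct EchoProvider;

    #[async_trait]
    impl Provider for EchoProvider {
        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<ProviderCompleteResponse, ProviderError> {
            // Always reply with a fixed assistant message and empty usage stats
            Ok(ProviderCompleteResponse::new(
                Message::assistant().with_text("echo"),
                "echo-model".to_string(),
                Usage::default(),
            ))
        }

        async fn extract(
            &self,
            _system: &str,
            _messages: &[Message],
            _schema: &serde_json::Value,
        ) -> Result<ProviderExtractResponse, ProviderError> {
            // Return an empty JSON object regardless of the requested schema
            Ok(ProviderExtractResponse::new(
                serde_json::json!({}),
                "echo-model".to_string(),
                Usage::default(),
            ))
        }
    }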
}

View File

@@ -0,0 +1,298 @@
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use url::Url;
use super::{
errors::ProviderError,
formats::databricks::{create_request, get_usage, response_to_message},
utils::{get_env, get_model, ImageFormat},
};
use crate::{
message::Message,
model::ModelConfig,
providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
types::core::Tool,
};
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
// Databricks can pass through to a wide range of models; we only provide the default
pub const _DATABRICKS_KNOWN_MODELS: &[&str] = &[
"databricks-meta-llama-3-3-70b-instruct",
"databricks-claude-3-7-sonnet",
];
fn default_timeout() -> u64 {
60
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabricksProviderConfig {
pub host: String,
pub token: String,
#[serde(default)]
pub image_format: ImageFormat,
#[serde(default = "default_timeout")]
pub timeout: u64, // timeout in seconds
}
impl DatabricksProviderConfig {
pub fn new(host: String, token: String) -> Self {
Self {
host,
token,
image_format: ImageFormat::OpenAi,
timeout: default_timeout(),
}
}
pub fn from_env() -> Self {
let host = get_env("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST");
let token = get_env("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN");
Self::new(host, token)
}
}
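// Illustrative config payload when this provider is constructed through the
// factory (host and token values below are placeholders):
//
//     {"host": "https://dbc-12345.cloud.databricks.com/", "token": "dapi..."}
//
// `image_format` and `timeout` fall back to their serde defaults (OpenAi, 60s)
// when omitted.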
#[derive(Debug)]
pub struct DatabricksProvider {
config: DatabricksProviderConfig,
model: ModelConfig,
client: Client,
}
impl DatabricksProvider {
pub fn from_env(model: ModelConfig) -> Self {
let config = DatabricksProviderConfig::from_env();
DatabricksProvider::from_config(config, model)
.expect("Failed to initialize DatabricksProvider")
}
}
impl Default for DatabricksProvider {
fn default() -> Self {
let config = DatabricksProviderConfig::from_env();
let model = ModelConfig::new(DATABRICKS_DEFAULT_MODEL.to_string());
DatabricksProvider::from_config(config, model)
.expect("Failed to initialize DatabricksProvider")
}
}
impl DatabricksProvider {
pub fn from_config(config: DatabricksProviderConfig, model: ModelConfig) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(config.timeout))
.build()?;
Ok(Self {
config,
model,
client,
})
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let base_url = Url::parse(&self.config.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let path = format!("serving-endpoints/{}/invocations", self.model.model_name);
let url = base_url.join(&path).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let auth_header = format!("Bearer {}", &self.config.token);
let response = self
.client
.post(url)
.header("Authorization", auth_header)
.json(&payload)
.send()
.await?;
let status = response.status();
let payload: Option<Value> = response.json().await.ok();
match status {
StatusCode::OK => payload.ok_or_else(|| {
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
}),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}",
status, payload
)))
}
StatusCode::BAD_REQUEST => {
// Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
let payload_str = serde_json::to_string(&payload)
.unwrap_or_default()
.to_lowercase();
let check_phrases = [
"too long",
"context length",
"context_length_exceeded",
"reduce the length",
"token count",
"exceeds",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded(payload_str));
}
let mut error_msg = "Unknown error".to_string();
if let Some(payload) = &payload {
// Try `message` first; if that fails, fall back to `external_model_message`
error_msg = payload
.get("message")
.and_then(|m| m.as_str())
.or_else(|| {
payload
.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())
})
.unwrap_or("Unknown error")
.to_string();
}
tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}. Message: {}",
status, error_msg
)))
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ => {
tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}",
status
)))
}
}
}
}
#[async_trait]
impl Provider for DatabricksProvider {
#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<ProviderCompleteResponse, ProviderError> {
let mut payload = create_request(
&self.model,
system,
messages,
tools,
&self.config.image_format,
)?;
// Remove the model key which is part of the url with databricks
payload
.as_object_mut()
.expect("payload should have model key")
.remove("model");
let response = self.post(payload.clone()).await?;
// Parse response
let message = response_to_message(response.clone())?;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage data: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
Ok(ProviderCompleteResponse::new(message, model, usage))
}
async fn extract(
&self,
system: &str,
messages: &[Message],
schema: &Value,
) -> Result<ProviderExtractResponse, ProviderError> {
// 1. Build base payload (no tools)
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
// 2. Inject strict JSONSchema wrapper
payload
.as_object_mut()
.expect("payload must be an object")
.insert(
"response_format".to_string(),
json!({
"type": "json_schema",
"json_schema": {
"name": "extraction",
"schema": schema,
"strict": true
}
}),
);
// 3. Call Databricks
let response = self.post(payload.clone()).await?;
// 4. Extract the assistant's `content` and parse it into JSON
let msg = &response["choices"][0]["message"];
let raw = msg.get("content").cloned().ok_or_else(|| {
ProviderError::ResponseParseError("Missing content in extract response".into())
})?;
let data = match raw {
Value::String(s) => serde_json::from_str(&s)
.map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
Value::Object(_) | Value::Array(_) => raw,
other => {
return Err(ProviderError::ResponseParseError(format!(
"Unexpected content type: {:?}",
other
)))
}
};
// 5. Gather usage & model info
let usage = match get_usage(&response) {
Ok(u) => u,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage in extract: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
Ok(ProviderExtractResponse::new(data, model, usage))
}
}

View File

@@ -0,0 +1,144 @@
use thiserror::Error;
#[derive(Error, Debug, uniffi::Error)]
pub enum ProviderError {
#[error("Authentication error: {0}")]
Authentication(String),
#[error("Context length exceeded: {0}")]
ContextLengthExceeded(String),
#[error("Rate limit exceeded: {0}")]
RateLimitExceeded(String),
#[error("Server error: {0}")]
ServerError(String),
#[error("Request failed: {0}")]
RequestFailed(String),
#[error("Execution error: {0}")]
ExecutionError(String),
#[error("Usage data error: {0}")]
UsageError(String),
#[error("Invalid response: {0}")]
ResponseParseError(String),
}
impl From<anyhow::Error> for ProviderError {
fn from(error: anyhow::Error) -> Self {
ProviderError::ExecutionError(error.to_string())
}
}
impl From<reqwest::Error> for ProviderError {
fn from(error: reqwest::Error) -> Self {
ProviderError::ExecutionError(error.to_string())
}
}
#[derive(serde::Deserialize, Debug)]
pub struct OpenAIError {
#[serde(deserialize_with = "code_as_string")]
pub code: Option<String>,
pub message: Option<String>,
#[serde(rename = "type")]
pub error_type: Option<String>,
}
fn code_as_string<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use std::fmt;
use serde::de::{self, Visitor};
struct CodeVisitor;
impl<'de> Visitor<'de> for CodeVisitor {
type Value = Option<String>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string, a number, null, or none for the code field")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Some(value.to_string()))
}
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Some(value.to_string()))
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(None)
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(None)
}
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(CodeVisitor)
}
}
deserializer.deserialize_option(CodeVisitor)
}
impl OpenAIError {
pub fn is_context_length_exceeded(&self) -> bool {
if let Some(code) = &self.code {
code == "context_length_exceeded" || code == "string_above_max_length"
} else {
false
}
}
}
impl std::fmt::Display for OpenAIError {
/// Format the error for display.
/// E.g. {"message": "Invalid API key", "code": "invalid_api_key", "type": "client_error"}
/// would be formatted as "Invalid API key (code: invalid_api_key, type: client_error)"
/// and {"message": "Foo"} as just "Foo", etc.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(message) = &self.message {
write!(f, "{}", message)?;
}
let mut in_parenthesis = false;
if let Some(code) = &self.code {
write!(f, " (code: {}", code)?;
in_parenthesis = true;
}
if let Some(typ) = &self.error_type {
if in_parenthesis {
write!(f, ", type: {}", typ)?;
} else {
write!(f, " (type: {}", typ)?;
in_parenthesis = true;
}
}
if in_parenthesis {
write!(f, ")")?;
}
Ok(())
}
}

View File

@@ -0,0 +1,29 @@
use std::sync::Arc;
use anyhow::Result;
use super::{
base::Provider,
databricks::{DatabricksProvider, DatabricksProviderConfig},
openai::{OpenAiProvider, OpenAiProviderConfig},
};
use crate::model::ModelConfig;
pub fn create(
name: &str,
provider_config: serde_json::Value,
model: ModelConfig,
) -> Result<Arc<dyn Provider>> {
// We use Arc instead of Box to be able to clone for multiple async tasks
match name {
"openai" => {
let config: OpenAiProviderConfig = serde_json::from_value(provider_config)?;
Ok(Arc::new(OpenAiProvider::from_config(config, model)?))
}
"databricks" => {
let config: DatabricksProviderConfig = serde_json::from_value(provider_config)?;
Ok(Arc::new(DatabricksProvider::from_config(config, model)?))
}
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
}
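// Example usage (a minimal sketch; the API key is a placeholder):
//
//     let provider = create(
//         "openai",
//         serde_json::json!({"api_key": "sk-..."}),
//         ModelConfig::new("gpt-4o".to_string()),
//     )?;
//     let response = provider.complete("You are helpful.", &messages, &[]).await?;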

File diff suppressed because it is too large

View File

@@ -0,0 +1,2 @@
pub mod databricks;
pub mod openai;

View File

@@ -0,0 +1,897 @@
use anyhow::{anyhow, Error};
use serde_json::{json, Value};
use crate::{
message::{Message, MessageContent},
model::ModelConfig,
providers::{
base::Usage,
errors::ProviderError,
utils::{
convert_image, detect_image_path, is_valid_function_name, load_image_file,
sanitize_function_name, ImageFormat,
},
},
types::core::{Content, Role, Tool, ToolCall, ToolError},
};
/// Convert internal Message format to OpenAI's API message specification
/// Some OpenAI-compatible endpoints use the Anthropic image spec at the content level
/// even though the message structure otherwise follows OpenAI; the `ImageFormat` enum switches between the two.
pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> {
let mut messages_spec = Vec::new();
for message in messages {
let mut converted = json!({
"role": message.role
});
let mut output = Vec::new();
for content in message.content.iter() {
match content {
MessageContent::Text(text) => {
if !text.text.is_empty() {
// Check for image paths in the text
if let Some(image_path) = detect_image_path(&text.text) {
// Try to load and convert the image
if let Ok(image) = load_image_file(image_path) {
converted["content"] = json!([
{"type": "text", "text": text.text},
convert_image(&image, image_format)
]);
} else {
// If image loading fails, just use the text
converted["content"] = json!(text.text);
}
} else {
converted["content"] = json!(text.text);
}
}
}
MessageContent::Thinking(_) => {
// Thinking blocks are not directly used in OpenAI format
continue;
}
MessageContent::RedactedThinking(_) => {
// Redacted thinking blocks are not directly used in OpenAI format
continue;
}
MessageContent::ToolReq(request) => match &request.tool_call.as_result() {
Ok(tool_call) => {
let sanitized_name = sanitize_function_name(&tool_call.name);
let tool_calls = converted
.as_object_mut()
.unwrap()
.entry("tool_calls")
.or_insert(json!([]));
tool_calls.as_array_mut().unwrap().push(json!({
"id": request.id,
"type": "function",
"function": {
"name": sanitized_name,
"arguments": tool_call.arguments.to_string(),
}
}));
}
Err(e) => {
output.push(json!({
"role": "tool",
"content": format!("Error: {}", e),
"tool_call_id": request.id
}));
}
},
MessageContent::ToolResp(response) => {
match &response.tool_result.0 {
Ok(contents) => {
// Process all content, replacing images with placeholder text
let mut tool_content = Vec::new();
let mut image_messages = Vec::new();
for content in contents {
match content {
Content::Image(image) => {
// Add placeholder text in the tool response
tool_content.push(Content::text("This tool result included an image that is uploaded in the next message."));
// Create a separate image message
image_messages.push(json!({
"role": "user",
"content": [convert_image(image, image_format)]
}));
}
_ => {
tool_content.push(content.clone());
}
}
}
let tool_response_content: Value = json!(tool_content
.iter()
.map(|content| match content {
Content::Text(text) => text.text.clone(),
_ => String::new(),
})
.collect::<Vec<String>>()
.join(" "));
// First add the tool response with all content
output.push(json!({
"role": "tool",
"content": tool_response_content,
"tool_call_id": response.id
}));
// Then add any image messages that need to follow
output.extend(image_messages);
}
Err(e) => {
// A tool result error is shown as output so the model can interpret the error message
output.push(json!({
"role": "tool",
"content": format!("The tool call returned the following error:\n{}", e),
"tool_call_id": response.id
}));
}
}
}
MessageContent::Image(image) => {
// Handle direct image content
converted["content"] = json!([convert_image(image, image_format)]);
}
}
}
if converted.get("content").is_some() || converted.get("tool_calls").is_some() {
output.insert(0, converted);
}
messages_spec.extend(output);
}
messages_spec
}
/// Convert internal Tool format to OpenAI's API tool specification
pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
let mut tool_names = std::collections::HashSet::new();
let mut result = Vec::new();
for tool in tools {
if !tool_names.insert(&tool.name) {
return Err(anyhow!("Duplicate tool name: {}", tool.name));
}
let mut description = tool.description.clone();
// OpenAI's tool description max length is 1024 characters
description.truncate(1024);
result.push(json!({
"type": "function",
"function": {
"name": tool.name,
"description": description,
"parameters": tool.input_schema,
}
}));
}
Ok(result)
}
/// Convert OpenAI's API response to internal Message format
pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
let original = response["choices"][0]["message"].clone();
let mut content = Vec::new();
if let Some(text) = original.get("content") {
if let Some(text_str) = text.as_str() {
content.push(MessageContent::text(text_str));
}
}
if let Some(tool_calls) = original.get("tool_calls") {
if let Some(tool_calls_array) = tool_calls.as_array() {
for tool_call in tool_calls_array {
let id = tool_call["id"].as_str().unwrap_or_default().to_string();
let function_name = tool_call["function"]["name"]
.as_str()
.unwrap_or_default()
.to_string();
let mut arguments = tool_call["function"]["arguments"]
.as_str()
.unwrap_or_default()
.to_string();
// If arguments is empty, JSON parsing would fail later, so default to an empty object.
if arguments.is_empty() {
arguments = "{}".to_string();
}
if !is_valid_function_name(&function_name) {
let error = ToolError::NotFound(format!(
"The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
function_name
));
content.push(MessageContent::tool_request(id, Err(error).into()));
} else {
match serde_json::from_str::<Value>(&arguments) {
Ok(params) => {
content.push(MessageContent::tool_request(
id,
Ok(ToolCall::new(&function_name, params)).into(),
));
}
Err(e) => {
let error = ToolError::InvalidParameters(format!(
"Could not interpret tool use parameters for id {}: {}",
id, e
));
content.push(MessageContent::tool_request(id, Err(error).into()));
}
}
}
}
}
}
Ok(Message {
role: Role::Assistant,
created: chrono::Utc::now().timestamp_millis(),
content: content.into(),
})
}
pub fn get_usage(data: &Value) -> Result<Usage, ProviderError> {
let usage = data
.get("usage")
.ok_or_else(|| ProviderError::UsageError("No usage data in response".to_string()))?;
let input_tokens = usage
.get("prompt_tokens")
.and_then(|v| v.as_i64())
.map(|v| v as i32);
let output_tokens = usage
.get("completion_tokens")
.and_then(|v| v.as_i64())
.map(|v| v as i32);
let total_tokens = usage
.get("total_tokens")
.and_then(|v| v.as_i64())
.map(|v| v as i32)
.or_else(|| match (input_tokens, output_tokens) {
(Some(input), Some(output)) => Some(input + output),
_ => None,
});
Ok(Usage::new(input_tokens, output_tokens, total_tokens))
}
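// The expected payload shape is the standard OpenAI-style usage block, e.g.
// (illustrative): {"usage": {"prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35}}.
// When "total_tokens" is absent it is derived from the other two counts.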
/// Validates and fixes tool schemas to ensure they have proper parameter structure.
/// If parameters describe an object, ensures the `properties`, `required`, and `type` fields are present.
pub fn validate_tool_schemas(tools: &mut [Value]) {
for tool in tools.iter_mut() {
if let Some(function) = tool.get_mut("function") {
if let Some(parameters) = function.get_mut("parameters") {
if parameters.is_object() {
ensure_valid_json_schema(parameters);
}
}
}
}
}
/// Ensures that the given JSON value follows the expected JSON Schema structure.
fn ensure_valid_json_schema(schema: &mut Value) {
if let Some(params_obj) = schema.as_object_mut() {
// Check if this is meant to be an object type schema
let is_object_type = params_obj
.get("type")
.and_then(|t| t.as_str())
.is_none_or(|t| t == "object"); // Default to true if no type is specified
// Only apply full schema validation to object types
if is_object_type {
// Ensure required fields exist with default values
params_obj.entry("properties").or_insert_with(|| json!({}));
params_obj.entry("required").or_insert_with(|| json!([]));
params_obj.entry("type").or_insert_with(|| json!("object"));
// Recursively validate properties if it exists
if let Some(properties) = params_obj.get_mut("properties") {
if let Some(properties_obj) = properties.as_object_mut() {
for (_key, prop) in properties_obj.iter_mut() {
if prop.is_object()
&& prop.get("type").and_then(|t| t.as_str()) == Some("object")
{
ensure_valid_json_schema(prop);
}
}
}
}
}
}
}
pub fn create_request(
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
image_format: &ImageFormat,
) -> anyhow::Result<Value, Error> {
if model_config.model_name.starts_with("o1-mini") {
return Err(anyhow!(
"o1-mini model is not currently supported since Goose uses tool calling and o1-mini does not support it. Please use o1 or o3 models instead."
));
}
let is_ox_model = model_config.model_name.starts_with("o");
// Only extract reasoning effort for O1/O3 models
let (model_name, reasoning_effort) = if is_ox_model {
let parts: Vec<&str> = model_config.model_name.split('-').collect();
let last_part = parts.last().unwrap();
match *last_part {
"low" | "medium" | "high" => {
let base_name = parts[..parts.len() - 1].join("-");
(base_name, Some(last_part.to_string()))
}
_ => (
model_config.model_name.to_string(),
Some("medium".to_string()),
),
}
} else {
// For non-O family models, use the model name as is and no reasoning effort
(model_config.model_name.to_string(), None)
};
let system_message = json!({
"role": if is_ox_model { "developer" } else { "system" },
"content": system
});
let messages_spec = format_messages(messages, image_format);
let mut tools_spec = if !tools.is_empty() {
format_tools(tools)?
} else {
vec![]
};
// Validate tool schemas
validate_tool_schemas(&mut tools_spec);
let mut messages_array = vec![system_message];
messages_array.extend(messages_spec);
let mut payload = json!({
"model": model_name,
"messages": messages_array
});
if let Some(effort) = reasoning_effort {
payload
.as_object_mut()
.unwrap()
.insert("reasoning_effort".to_string(), json!(effort));
}
if !tools_spec.is_empty() {
payload
.as_object_mut()
.unwrap()
.insert("tools".to_string(), json!(tools_spec));
}
// o1, o3 models currently don't support temperature
if !is_ox_model {
if let Some(temp) = model_config.temperature {
payload
.as_object_mut()
.unwrap()
.insert("temperature".to_string(), json!(temp));
}
}
// o-family models use max_completion_tokens instead of max_tokens
if let Some(tokens) = model_config.max_tokens {
let key = if is_ox_model {
"max_completion_tokens"
} else {
"max_tokens"
};
payload
.as_object_mut()
.unwrap()
.insert(key.to_string(), json!(tokens));
}
Ok(payload)
}
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
use crate::types::core::Content;
#[test]
fn test_validate_tool_schemas() {
// Test case 1: Empty parameters object
// Input JSON with an incomplete parameters object
let mut actual = vec![json!({
"type": "function",
"function": {
"name": "test_func",
"description": "test description",
"parameters": {
"type": "object"
}
}
})];
// Run the function to validate and update schemas
validate_tool_schemas(&mut actual);
// Expected JSON after validation
let expected = vec![json!({
"type": "function",
"function": {
"name": "test_func",
"description": "test description",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
}
})];
// Compare entire JSON structures instead of individual fields
assert_eq!(actual, expected);
// Test case 2: Missing type field
let mut tools = vec![json!({
"type": "function",
"function": {
"name": "test_func",
"description": "test description",
"parameters": {
"properties": {}
}
}
})];
validate_tool_schemas(&mut tools);
let params = tools[0]["function"]["parameters"].as_object().unwrap();
assert_eq!(params["type"], "object");
// Test case 3: Complete valid schema should remain unchanged
let original_schema = json!({
"type": "function",
"function": {
"name": "test_func",
"description": "test description",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City and country"
}
},
"required": ["location"]
}
}
});
let mut tools = vec![original_schema.clone()];
validate_tool_schemas(&mut tools);
assert_eq!(tools[0], original_schema);
}
const OPENAI_TOOL_USE_RESPONSE: &str = r#"{
"choices": [{
"role": "assistant",
"message": {
"tool_calls": [{
"id": "1",
"function": {
"name": "example_fn",
"arguments": "{\"param\": \"value\"}"
}
}]
}
}],
"usage": {
"input_tokens": 10,
"output_tokens": 25,
"total_tokens": 35
}
}"#;
#[test]
fn test_format_messages() -> anyhow::Result<()> {
let message = Message::user().with_text("Hello");
let spec = format_messages(&[message], &ImageFormat::OpenAi);
assert_eq!(spec.len(), 1);
assert_eq!(spec[0]["role"], "user");
assert_eq!(spec[0]["content"], "Hello");
Ok(())
}
#[test]
fn test_format_tools() -> anyhow::Result<()> {
let tool = Tool::new(
"test_tool",
"A test tool",
json!({
"type": "object",
"properties": {
"input": {
"type": "string",
"description": "Test parameter"
}
},
"required": ["input"]
}),
);
let spec = format_tools(&[tool])?;
assert_eq!(spec.len(), 1);
assert_eq!(spec[0]["type"], "function");
assert_eq!(spec[0]["function"]["name"], "test_tool");
Ok(())
}
#[test]
fn test_format_messages_complex() -> anyhow::Result<()> {
let mut messages = vec![
Message::assistant().with_text("Hello!"),
Message::user().with_text("How are you?"),
Message::assistant().with_tool_request(
"tool1",
Ok(ToolCall::new("example", json!({"param1": "value1"}))),
),
];
// Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolReq(request) = &messages[2].content[0] {
request.id.clone()
} else {
panic!("should be tool request");
};
messages.push(
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi);
assert_eq!(spec.len(), 4);
assert_eq!(spec[0]["role"], "assistant");
assert_eq!(spec[0]["content"], "Hello!");
assert_eq!(spec[1]["role"], "user");
assert_eq!(spec[1]["content"], "How are you?");
assert_eq!(spec[2]["role"], "assistant");
assert!(spec[2]["tool_calls"].is_array());
assert_eq!(spec[3]["role"], "tool");
assert_eq!(spec[3]["content"], "Result");
assert_eq!(spec[3]["tool_call_id"], spec[2]["tool_calls"][0]["id"]);
Ok(())
}
#[test]
fn test_format_messages_multiple_content() -> anyhow::Result<()> {
let mut messages = vec![Message::assistant().with_tool_request(
"tool1",
Ok(ToolCall::new("example", json!({"param1": "value1"}))),
)];
// Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolReq(request) = &messages[0].content[0] {
request.id.clone()
} else {
panic!("should be tool request");
};
messages.push(
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi);
assert_eq!(spec.len(), 2);
assert_eq!(spec[0]["role"], "assistant");
assert!(spec[0]["tool_calls"].is_array());
assert_eq!(spec[1]["role"], "tool");
assert_eq!(spec[1]["content"], "Result");
assert_eq!(spec[1]["tool_call_id"], spec[0]["tool_calls"][0]["id"]);
Ok(())
}
#[test]
fn test_format_tools_duplicate() -> anyhow::Result<()> {
let tool1 = Tool::new(
"test_tool",
"Test tool",
json!({
"type": "object",
"properties": {
"input": {
"type": "string",
"description": "Test parameter"
}
},
"required": ["input"]
}),
);
let tool2 = Tool::new(
"test_tool",
"Test tool",
json!({
"type": "object",
"properties": {
"input": {
"type": "string",
"description": "Test parameter"
}
},
"required": ["input"]
}),
);
let result = format_tools(&[tool1, tool2]);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Duplicate tool name"));
Ok(())
}
#[test]
fn test_format_tools_empty() -> anyhow::Result<()> {
let spec = format_tools(&[])?;
assert!(spec.is_empty());
Ok(())
}
#[test]
fn test_format_messages_with_image_path() -> anyhow::Result<()> {
// Create a temporary PNG file with valid PNG magic numbers
let temp_dir = tempfile::tempdir()?;
let png_path = temp_dir.path().join("test.png");
let png_data = [
0x89, 0x50, 0x4E, 0x47, // PNG magic number
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
];
std::fs::write(&png_path, &png_data)?;
let png_path_str = png_path.to_str().unwrap();
// Create message with image path
let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
let spec = format_messages(&[message], &ImageFormat::OpenAi);
assert_eq!(spec.len(), 1);
assert_eq!(spec[0]["role"], "user");
// Content should be an array with text and image
let content = spec[0]["content"].as_array().unwrap();
assert_eq!(content.len(), 2);
assert_eq!(content[0]["type"], "text");
assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
assert_eq!(content[1]["type"], "image_url");
assert!(content[1]["image_url"]["url"]
.as_str()
.unwrap()
.starts_with("data:image/png;base64,"));
Ok(())
}
#[test]
fn test_response_to_message_text() -> anyhow::Result<()> {
let response = json!({
"choices": [{
"role": "assistant",
"message": {
"content": "Hello from John Cena!"
}
}],
"usage": {
"input_tokens": 10,
"output_tokens": 25,
"total_tokens": 35
}
});
let message = response_to_message(response)?;
assert_eq!(message.content.len(), 1);
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "Hello from John Cena!");
} else {
panic!("Expected Text content");
}
assert!(matches!(message.role, Role::Assistant));
Ok(())
}
#[test]
fn test_response_to_message_valid_toolrequest() -> anyhow::Result<()> {
let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
let message = response_to_message(response)?;
assert_eq!(message.content.len(), 1);
if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({"param": "value"}));
} else {
panic!("Expected ToolRequest content");
}
Ok(())
}
#[test]
fn test_response_to_message_invalid_func_name() -> anyhow::Result<()> {
let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] =
json!("invalid fn");
let message = response_to_message(response)?;
if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call.as_result() {
Err(ToolError::NotFound(msg)) => {
assert!(msg.starts_with("The provided function name"));
}
_ => panic!("Expected ToolNotFound error"),
}
} else {
panic!("Expected ToolRequest content");
}
Ok(())
}
#[test]
fn test_response_to_message_json_decode_error() -> anyhow::Result<()> {
let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
json!("invalid json {");
let message = response_to_message(response)?;
if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call.as_result() {
Err(ToolError::InvalidParameters(msg)) => {
assert!(msg.starts_with("Could not interpret tool use parameters"));
}
_ => panic!("Expected InvalidParameters error"),
}
} else {
panic!("Expected ToolRequest content");
}
Ok(())
}
#[test]
fn test_response_to_message_empty_argument() -> anyhow::Result<()> {
let mut response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
serde_json::Value::String("".to_string());
let message = response_to_message(response)?;
if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({}));
} else {
panic!("Expected ToolRequest content");
}
Ok(())
}
#[test]
fn test_create_request_gpt_4o() -> anyhow::Result<()> {
// Test a standard request for gpt-4o (no reasoning effort, "system" role)
let model_config = ModelConfig {
model_name: "gpt-4o".to_string(),
context_limit: Some(4096),
temperature: None,
max_tokens: Some(1024),
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
let expected = json!({
"model": "gpt-4o",
"messages": [
{
"role": "system",
"content": "system"
}
],
"max_tokens": 1024
});
for (key, value) in expected.as_object().unwrap() {
assert_eq!(obj.get(key).unwrap(), value);
}
Ok(())
}
#[test]
fn test_create_request_o1_default() -> anyhow::Result<()> {
// Test default medium reasoning effort for O1 model
let model_config = ModelConfig {
model_name: "o1".to_string(),
context_limit: Some(4096),
temperature: None,
max_tokens: Some(1024),
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
let expected = json!({
"model": "o1",
"messages": [
{
"role": "developer",
"content": "system"
}
],
"reasoning_effort": "medium",
"max_completion_tokens": 1024
});
for (key, value) in expected.as_object().unwrap() {
assert_eq!(obj.get(key).unwrap(), value);
}
Ok(())
}
#[test]
fn test_create_request_o3_custom_reasoning_effort() -> anyhow::Result<()> {
// Test custom reasoning effort for O3 model
let model_config = ModelConfig {
model_name: "o3-mini-high".to_string(),
context_limit: Some(4096),
temperature: None,
max_tokens: Some(1024),
};
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
let obj = request.as_object().unwrap();
let expected = json!({
"model": "o3-mini",
"messages": [
{
"role": "developer",
"content": "system"
}
],
"reasoning_effort": "high",
"max_completion_tokens": 1024
});
for (key, value) in expected.as_object().unwrap() {
assert_eq!(obj.get(key).unwrap(), value);
}
Ok(())
}
}

View File

@@ -0,0 +1,10 @@
pub mod base;
pub mod databricks;
pub mod errors;
mod factory;
pub mod formats;
pub mod openai;
pub mod utils;
pub use base::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage};
pub use factory::create;

View File

@@ -0,0 +1,231 @@
use std::{collections::HashMap, time::Duration};
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use super::{
errors::ProviderError,
formats::openai::{create_request, get_usage, response_to_message},
utils::{emit_debug_trace, get_env, get_model, handle_response_openai_compat, ImageFormat},
};
use crate::{
message::Message,
model::ModelConfig,
providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
types::core::Tool,
};
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4.1", "o1", "o3", "o4-mini"];
fn default_timeout() -> u64 {
60
}
fn default_base_path() -> String {
"v1/chat/completions".to_string()
}
fn default_host() -> String {
"https://api.openai.com".to_string()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiProviderConfig {
pub api_key: String,
#[serde(default = "default_host")]
pub host: String,
#[serde(default)]
pub organization: Option<String>,
#[serde(default = "default_base_path")]
pub base_path: String,
#[serde(default)]
pub project: Option<String>,
#[serde(default)]
pub custom_headers: Option<HashMap<String, String>>,
#[serde(default = "default_timeout")]
pub timeout: u64, // timeout in seconds
}
impl OpenAiProviderConfig {
pub fn new(api_key: String) -> Self {
Self {
api_key,
host: default_host(),
organization: None,
base_path: default_base_path(),
project: None,
custom_headers: None,
timeout: default_timeout(),
}
}
pub fn from_env() -> Self {
let api_key = get_env("OPENAI_API_KEY").expect("Missing OPENAI_API_KEY");
Self::new(api_key)
}
}
#[derive(Debug)]
pub struct OpenAiProvider {
config: OpenAiProviderConfig,
model: ModelConfig,
client: Client,
}
impl OpenAiProvider {
pub fn from_env(model: ModelConfig) -> Self {
let config = OpenAiProviderConfig::from_env();
OpenAiProvider::from_config(config, model).expect("Failed to initialize OpenAiProvider")
}
}
impl Default for OpenAiProvider {
fn default() -> Self {
let config = OpenAiProviderConfig::from_env();
let model = ModelConfig::new(OPEN_AI_DEFAULT_MODEL.to_string());
OpenAiProvider::from_config(config, model).expect("Failed to initialize OpenAiProvider")
}
}
impl OpenAiProvider {
pub fn from_config(config: OpenAiProviderConfig, model: ModelConfig) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(config.timeout))
.build()?;
Ok(Self {
config,
model,
client,
})
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let base_url = url::Url::parse(&self.config.host)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let url = base_url.join(&self.config.base_path).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let mut request = self
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.config.api_key));
// Add organization header if present
if let Some(org) = &self.config.organization {
request = request.header("OpenAI-Organization", org);
}
// Add project header if present
if let Some(project) = &self.config.project {
request = request.header("OpenAI-Project", project);
}
if let Some(custom_headers) = &self.config.custom_headers {
for (key, value) in custom_headers {
request = request.header(key, value);
}
}
let response = request.json(&payload).send().await?;
handle_response_openai_compat(response).await
}
}
#[async_trait]
impl Provider for OpenAiProvider {
#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<ProviderCompleteResponse, ProviderError> {
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
// Make request
let response = self.post(payload.clone()).await?;
// Parse response
let message = response_to_message(response.clone())?;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage data: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
emit_debug_trace(&self.model, &payload, &response, &usage);
Ok(ProviderCompleteResponse::new(message, model, usage))
}
async fn extract(
&self,
system: &str,
messages: &[Message],
schema: &Value,
) -> Result<ProviderExtractResponse, ProviderError> {
// 1. Build base payload (no tools)
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
// 2. Inject strict JSONSchema wrapper
payload
.as_object_mut()
.expect("payload must be an object")
.insert(
"response_format".to_string(),
json!({
"type": "json_schema",
"json_schema": {
"name": "extraction",
"schema": schema,
"strict": true
}
}),
);
// 3. Call OpenAI
let response = self.post(payload.clone()).await?;
// 4. Extract the assistant's `content` and parse it into JSON
let msg = &response["choices"][0]["message"];
let raw = msg.get("content").cloned().ok_or_else(|| {
ProviderError::ResponseParseError("Missing content in extract response".into())
})?;
let data = match raw {
Value::String(s) => serde_json::from_str(&s)
.map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
Value::Object(_) | Value::Array(_) => raw,
other => {
return Err(ProviderError::ResponseParseError(format!(
"Unexpected content type: {:?}",
other
)))
}
};
// 5. Gather usage & model info
let usage = match get_usage(&response) {
Ok(u) => u,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage in extract: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
Ok(ProviderExtractResponse::new(data, model, usage))
}
}

View File

@@ -0,0 +1,359 @@
use std::{env, io::Read, path::Path};
use anyhow::Result;
use base64::Engine;
use regex::Regex;
use reqwest::{Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{from_value, json, Value};
use super::base::Usage;
use crate::{
model::ModelConfig,
providers::errors::{OpenAIError, ProviderError},
types::core::ImageContent,
};
#[derive(serde::Deserialize)]
struct OpenAIErrorResponse {
error: OpenAIError,
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Default)]
pub enum ImageFormat {
#[default]
OpenAi,
Anthropic,
}
/// Timeout in seconds.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Timeout(u32);
impl Default for Timeout {
fn default() -> Self {
Timeout(60)
}
}
/// Convert image content into an image JSON value for the given provider format
pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value {
match image_format {
ImageFormat::OpenAi => json!({
"type": "image_url",
"image_url": {
"url": format!("data:{};base64,{}", image.mime_type, image.data)
}
}),
ImageFormat::Anthropic => json!({
"type": "image",
"source": {
"type": "base64",
"media_type": image.mime_type,
"data": image.data,
}
}),
}
}
/// Handle response from OpenAI compatible endpoints
/// Error codes: https://platform.openai.com/docs/guides/error-codes
/// Context window exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543
pub async fn handle_response_openai_compat(response: Response) -> Result<Value, ProviderError> {
let status = response.status();
// Try to parse the response body as JSON (if applicable)
let payload = match response.json::<Value>().await {
Ok(json) => json,
Err(e) => return Err(ProviderError::RequestFailed(e.to_string())),
};
match status {
StatusCode::OK => Ok(payload),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}",
status, payload
)))
}
StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND => {
tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
if let Ok(err_resp) = from_value::<OpenAIErrorResponse>(payload) {
let err = err_resp.error;
if err.is_context_length_exceeded() {
return Err(ProviderError::ContextLengthExceeded(
err.message.unwrap_or("Unknown error".to_string()),
));
}
return Err(ProviderError::RequestFailed(format!(
"{} (status {})",
err,
status.as_u16()
)));
}
Err(ProviderError::RequestFailed(format!(
"Unknown error (status {})",
status
)))
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ => {
tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}",
status
)))
}
}
}
/// Get a secret from environment variables. The value is parsed as JSON when possible, otherwise treated as a plain string.
pub fn get_env(key: &str) -> Result<String> {
// check environment variables (convert to uppercase)
let env_key = key.to_uppercase();
if let Ok(val) = env::var(&env_key) {
let value: Value = serde_json::from_str(&val).unwrap_or(Value::String(val));
Ok(serde_json::from_value(value)?)
} else {
Err(anyhow::anyhow!(
"Environment variable {} not found",
env_key
))
}
}
pub fn sanitize_function_name(name: &str) -> String {
let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap();
re.replace_all(name, "_").to_string()
}
pub fn is_valid_function_name(name: &str) -> bool {
let re = Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();
re.is_match(name)
}
/// Extract the model name from a JSON object. Most providers expose this as a top-level attribute.
pub fn get_model(data: &Value) -> String {
if let Some(model) = data.get("model") {
if let Some(model_str) = model.as_str() {
model_str.to_string()
} else {
"Unknown".to_string()
}
} else {
"Unknown".to_string()
}
}
/// Check if a file is actually an image by examining its magic bytes
fn is_image_file(path: &Path) -> bool {
if let Ok(mut file) = std::fs::File::open(path) {
let mut buffer = [0u8; 8]; // Large enough for most image magic numbers
if file.read(&mut buffer).is_ok() {
// Check magic numbers for common image formats
return match &buffer[0..4] {
// PNG: 89 50 4E 47
[0x89, 0x50, 0x4E, 0x47] => true,
// JPEG: FF D8 FF
[0xFF, 0xD8, 0xFF, _] => true,
// GIF: 47 49 46 38
[0x47, 0x49, 0x46, 0x38] => true,
_ => false,
};
}
}
false
}
/// Detect if a string contains a path to an image file
pub fn detect_image_path(text: &str) -> Option<&str> {
// Basic image file extension check
let extensions = [".png", ".jpg", ".jpeg"];
// Find any word that ends with an image extension
for word in text.split_whitespace() {
if extensions
.iter()
.any(|ext| word.to_lowercase().ends_with(ext))
{
let path = Path::new(word);
// Check if it's an absolute path and file exists
if path.is_absolute() && path.is_file() {
// Verify it's actually an image file
if is_image_file(path) {
return Some(word);
}
}
}
}
None
}
/// Convert a local image file to base64 encoded ImageContent
pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
let path = Path::new(path);
// Verify it's an image before proceeding
if !is_image_file(path) {
return Err(ProviderError::RequestFailed(
"File is not a valid image".to_string(),
));
}
// Read the file
let bytes = std::fs::read(path)
.map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?;
// Detect mime type from extension
let mime_type = match path.extension().and_then(|e| e.to_str()) {
Some(ext) => match ext.to_lowercase().as_str() {
"png" => "image/png",
"jpg" | "jpeg" => "image/jpeg",
_ => {
return Err(ProviderError::RequestFailed(
"Unsupported image format".to_string(),
));
}
},
None => {
return Err(ProviderError::RequestFailed(
"Unknown image format".to_string(),
));
}
};
// Convert to base64
let data = base64::prelude::BASE64_STANDARD.encode(&bytes);
Ok(ImageContent {
mime_type: mime_type.to_string(),
data,
})
}
pub fn emit_debug_trace(
model_config: &ModelConfig,
payload: &Value,
response: &Value,
usage: &Usage,
) {
tracing::debug!(
model_config = %serde_json::to_string_pretty(model_config).unwrap_or_default(),
input = %serde_json::to_string_pretty(payload).unwrap_or_default(),
output = %serde_json::to_string_pretty(response).unwrap_or_default(),
input_tokens = ?usage.input_tokens.unwrap_or_default(),
output_tokens = ?usage.output_tokens.unwrap_or_default(),
total_tokens = ?usage.total_tokens.unwrap_or_default(),
);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_detect_image_path() {
// Create a temporary PNG file with valid PNG magic numbers
let temp_dir = tempfile::tempdir().unwrap();
let png_path = temp_dir.path().join("test.png");
let png_data = [
0x89, 0x50, 0x4E, 0x47, // PNG magic number
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
];
std::fs::write(&png_path, &png_data).unwrap();
let png_path_str = png_path.to_str().unwrap();
// Create a fake PNG (wrong magic numbers)
let fake_png_path = temp_dir.path().join("fake.png");
std::fs::write(&fake_png_path, b"not a real png").unwrap();
// Test with valid PNG file using absolute path
let text = format!("Here is an image {}", png_path_str);
assert_eq!(detect_image_path(&text), Some(png_path_str));
// Test with non-image file that has .png extension
let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
assert_eq!(detect_image_path(&text), None);
// Test with non-existent file
let text = "Here is a fake.png that doesn't exist";
assert_eq!(detect_image_path(text), None);
// Test with non-image file
let text = "Here is a file.txt";
assert_eq!(detect_image_path(text), None);
// Test with relative path (should not match)
let text = "Here is a relative/path/image.png";
assert_eq!(detect_image_path(text), None);
}
#[test]
fn test_load_image_file() {
// Create a temporary PNG file with valid PNG magic numbers
let temp_dir = tempfile::tempdir().unwrap();
let png_path = temp_dir.path().join("test.png");
let png_data = [
0x89, 0x50, 0x4E, 0x47, // PNG magic number
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
];
std::fs::write(&png_path, &png_data).unwrap();
let png_path_str = png_path.to_str().unwrap();
// Create a fake PNG (wrong magic numbers)
let fake_png_path = temp_dir.path().join("fake.png");
std::fs::write(&fake_png_path, b"not a real png").unwrap();
let fake_png_path_str = fake_png_path.to_str().unwrap();
// Test loading valid PNG file
let result = load_image_file(png_path_str);
assert!(result.is_ok());
let image = result.unwrap();
assert_eq!(image.mime_type, "image/png");
// Test loading fake PNG file
let result = load_image_file(fake_png_path_str);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("not a valid image"));
// Test non-existent file
let result = load_image_file("nonexistent.png");
assert!(result.is_err());
}
#[test]
fn test_sanitize_function_name() {
assert_eq!(sanitize_function_name("hello-world"), "hello-world");
assert_eq!(sanitize_function_name("hello world"), "hello_world");
assert_eq!(sanitize_function_name("hello@world"), "hello_world");
}
#[test]
fn test_is_valid_function_name() {
assert!(is_valid_function_name("hello-world"));
assert!(is_valid_function_name("hello_world"));
assert!(!is_valid_function_name("hello world"));
assert!(!is_valid_function_name("hello@world"));
}
}
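
As a reference for callers, a small sketch in the style of the tests above (not part of the diff) of the two payload shapes `convert_image` produces; the base64 data is a placeholder, not a real image:

```rust
#[test]
fn test_convert_image_shapes() {
    let image = ImageContent {
        mime_type: "image/png".to_string(),
        data: "aGVsbG8=".to_string(), // placeholder base64, not a decodable PNG
    };

    // OpenAI format: a data URL nested under "image_url".
    let openai = convert_image(&image, &ImageFormat::OpenAi);
    assert_eq!(openai["type"], "image_url");
    assert_eq!(openai["image_url"]["url"], "data:image/png;base64,aGVsbG8=");

    // Anthropic format: raw base64 and media type under "source".
    let anthropic = convert_image(&image, &ImageFormat::Anthropic);
    assert_eq!(anthropic["type"], "image");
    assert_eq!(anthropic["source"]["media_type"], "image/png");
    assert_eq!(anthropic["source"]["data"], "aGVsbG8=");
}
```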

View File

@@ -0,0 +1,29 @@
use crate::{
providers::{create, errors::ProviderError, ProviderExtractResponse},
types::json_value_ffi::JsonValueFfi,
Message, ModelConfig,
};
/// Generates a structured output based on the provided schema,
/// system prompt and user messages.
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_structured_outputs(
provider_name: &str,
provider_config: JsonValueFfi,
system_prompt: &str,
messages: &[Message],
schema: JsonValueFfi,
) -> Result<ProviderExtractResponse, ProviderError> {
// Use OpenAI models specifically for this task
let model_name = if provider_name == "databricks" {
"goose-gpt-4-1"
} else {
"gpt-4.1"
};
let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
let provider = create(provider_name, provider_config, model_cfg)?;
let resp = provider.extract(system_prompt, messages, &schema).await?;
Ok(resp)
}
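
A hedged call-site sketch, written as if alongside the function above (not part of the diff); the Databricks config shape mirrors the integration tests in this change:

```rust
// Sketch only: extracts a single field via the strict-schema path.
async fn extract_city_example() -> Result<ProviderExtractResponse, ProviderError> {
    let provider_config = serde_json::json!({
        "host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
        "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
    });
    let schema = serde_json::json!({
        "type": "object",
        "properties": { "city": { "type": "string" } },
        "required": ["city"],
        "additionalProperties": false
    });
    let messages = vec![Message::user().with_text("I flew to Tokyo last week.")];
    generate_structured_outputs(
        "databricks",
        provider_config,
        "Extract the city mentioned by the user.",
        &messages,
        schema,
    )
    .await
}
```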

View File

@@ -0,0 +1,227 @@
// This file defines types for completion interfaces, including the request and response structures.
// Many of these are adapted based on the Goose Service API:
// https://docs.google.com/document/d/1r5vjSK3nBQU1cIRf0WKysDigqMlzzrzl_bxEE4msOiw/edit?tab=t.0
use std::collections::HashMap;
use thiserror::Error;
use serde::{Deserialize, Serialize};
use crate::types::json_value_ffi::JsonValueFfi;
use crate::{message::Message, providers::Usage};
use crate::{model::ModelConfig, providers::errors::ProviderError};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
pub provider_name: String,
pub provider_config: serde_json::Value,
pub model_config: ModelConfig,
pub system_preamble: String,
pub messages: Vec<Message>,
pub extensions: Vec<ExtensionConfig>,
}
impl CompletionRequest {
pub fn new(
provider_name: String,
provider_config: serde_json::Value,
model_config: ModelConfig,
system_preamble: String,
messages: Vec<Message>,
extensions: Vec<ExtensionConfig>,
) -> Self {
Self {
provider_name,
provider_config,
model_config,
system_preamble,
messages,
extensions,
}
}
}
#[uniffi::export]
pub fn create_completion_request(
provider_name: &str,
provider_config: JsonValueFfi,
model_config: ModelConfig,
system_preamble: &str,
messages: Vec<Message>,
extensions: Vec<ExtensionConfig>,
) -> CompletionRequest {
CompletionRequest::new(
provider_name.to_string(),
provider_config,
model_config,
system_preamble.to_string(),
messages,
extensions,
)
}
uniffi::custom_type!(CompletionRequest, String, {
lower: |tc: &CompletionRequest| {
serde_json::to_string(&tc).unwrap()
},
try_lift: |s: String| {
Ok(serde_json::from_str(&s).unwrap())
},
});
// https://mozilla.github.io/uniffi-rs/latest/proc_macro/errors.html
#[derive(Debug, Error, uniffi::Error)]
#[uniffi(flat_error)]
pub enum CompletionError {
#[error("failed to create provider: {0}")]
UnknownProvider(String),
#[error("provider error: {0}")]
Provider(#[from] ProviderError),
#[error("template rendering error: {0}")]
Template(#[from] minijinja::Error),
#[error("json serialization error: {0}")]
Json(#[from] serde_json::Error),
#[error("tool not found error: {0}")]
ToolNotFound(String),
}
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct CompletionResponse {
pub message: Message,
pub model: String,
pub usage: Usage,
pub runtime_metrics: RuntimeMetrics,
}
impl CompletionResponse {
pub fn new(
message: Message,
model: String,
usage: Usage,
runtime_metrics: RuntimeMetrics,
) -> Self {
Self {
message,
model,
usage,
runtime_metrics,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct RuntimeMetrics {
pub total_time_sec: f32,
pub total_time_sec_provider: f32,
pub tokens_per_second: Option<f64>,
}
impl RuntimeMetrics {
pub fn new(
total_time_sec: f32,
total_time_sec_provider: f32,
tokens_per_second: Option<f64>,
) -> Self {
Self {
total_time_sec,
total_time_sec_provider,
tokens_per_second,
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
pub enum ToolApprovalMode {
Auto,
Manual,
Smart,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct ToolConfig {
pub name: String,
pub description: String,
pub input_schema: JsonValueFfi,
pub approval_mode: ToolApprovalMode,
}
impl ToolConfig {
pub fn new(
name: &str,
description: &str,
input_schema: JsonValueFfi,
approval_mode: ToolApprovalMode,
) -> Self {
Self {
name: name.to_string(),
description: description.to_string(),
input_schema,
approval_mode,
}
}
/// Convert the tool config to a core tool
pub fn to_core_tool(&self, name: Option<&str>) -> super::core::Tool {
let tool_name = name.unwrap_or(&self.name);
super::core::Tool::new(
tool_name,
self.description.clone(),
self.input_schema.clone(),
)
}
}
#[uniffi::export]
pub fn create_tool_config(
name: &str,
description: &str,
input_schema: JsonValueFfi,
approval_mode: ToolApprovalMode,
) -> ToolConfig {
ToolConfig::new(name, description, input_schema, approval_mode)
}
// — Register the newtypes with UniFFI, converting via JSON strings —
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct ExtensionConfig {
name: String,
instructions: Option<String>,
tools: Vec<ToolConfig>,
}
impl ExtensionConfig {
pub fn new(name: String, instructions: Option<String>, tools: Vec<ToolConfig>) -> Self {
Self {
name,
instructions,
tools,
}
}
/// Convert the tools to core tools with the extension name as a prefix
pub fn get_prefixed_tools(&self) -> Vec<super::core::Tool> {
self.tools
.iter()
.map(|tool| {
let name = format!("{}__{}", self.name, tool.name);
tool.to_core_tool(Some(&name))
})
.collect()
}
/// Get a map of prefixed tool names to their approval modes
pub fn get_prefixed_tool_configs(&self) -> HashMap<String, ToolConfig> {
self.tools
.iter()
.map(|tool| {
let name = format!("{}__{}", self.name, tool.name);
(name, tool.clone())
})
.collect()
}
}
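
A construction sketch (not part of the diff), written as if inside this module, showing how the FFI builders above fit together; the provider config value is left empty because its shape depends on the provider:

```rust
// Sketch only: builds a CompletionRequest with one extension exposing one tool.
fn build_request_example() -> CompletionRequest {
    let weather_tool = create_tool_config(
        "get_weather",
        "Get the weather for a location",
        serde_json::json!({
            "type": "object",
            "required": ["location"],
            "properties": { "location": { "type": "string" } }
        }),
        ToolApprovalMode::Auto,
    );
    let extension = ExtensionConfig::new(
        "weather".to_string(),
        Some("Use the weather tool for forecast questions.".to_string()),
        vec![weather_tool],
    );
    create_completion_request(
        "openai",
        serde_json::json!({}), // provider-specific config; shape not fixed here
        ModelConfig::new("gpt-4o".to_string()),
        "You are a helpful assistant.",
        vec![Message::user().with_text("What's the weather in Paris?")],
        vec![extension],
    )
}
```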

View File

@@ -0,0 +1,131 @@
// This file defines core types that require serialization to
// construct payloads for LLM model providers and work with MCPs.
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum Content {
Text(TextContent),
Image(ImageContent),
}
impl Content {
pub fn text<S: Into<String>>(text: S) -> Self {
Content::Text(TextContent { text: text.into() })
}
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
Content::Image(ImageContent {
data: data.into(),
mime_type: mime_type.into(),
})
}
/// Get the text content if this is a TextContent variant
pub fn as_text(&self) -> Option<&str> {
match self {
Content::Text(text) => Some(&text.text),
_ => None,
}
}
/// Get the image content if this is an ImageContent variant
pub fn as_image(&self) -> Option<(&str, &str)> {
match self {
Content::Image(image) => Some((&image.data, &image.mime_type)),
_ => None,
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct TextContent {
pub text: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ImageContent {
pub data: String,
pub mime_type: String,
}
/// A tool that can be used by a model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
/// The name of the tool
pub name: String,
/// A description of what the tool does
pub description: String,
/// A JSON Schema object defining the expected parameters for the tool
pub input_schema: serde_json::Value,
}
impl Tool {
/// Create a new tool with the given name and description
pub fn new<N, D>(name: N, description: D, input_schema: serde_json::Value) -> Self
where
N: Into<String>,
D: Into<String>,
{
Tool {
name: name.into(),
description: description.into(),
input_schema,
}
}
}
/// A tool call request that an extension can execute
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
/// The name of the tool to execute
pub name: String,
/// The parameters for the execution
pub arguments: serde_json::Value,
/// Whether the tool call needs approval before execution. Default is false.
pub needs_approval: bool,
}
impl ToolCall {
/// Create a new ToolUse with the given name and parameters
pub fn new<S: Into<String>>(name: S, arguments: serde_json::Value) -> Self {
Self {
name: name.into(),
arguments,
needs_approval: false,
}
}
/// Set needs_approval field
pub fn set_needs_approval(&mut self, flag: bool) {
self.needs_approval = flag;
}
}
#[non_exhaustive]
#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq, uniffi::Error)]
pub enum ToolError {
#[error("Invalid parameters: {0}")]
InvalidParameters(String),
#[error("Execution failed: {0}")]
ExecutionError(String),
#[error("Schema error: {0}")]
SchemaError(String),
#[error("Tool not found: {0}")]
NotFound(String),
}
pub type ToolResult<T> = std::result::Result<T, ToolError>;
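
An illustrative sketch of the core types above (not part of the diff), written as if inside this module:

```rust
// Sketch only: content accessors, a tool definition, and a call against it.
fn core_types_example() {
    let text = Content::text("hello");
    assert_eq!(text.as_text(), Some("hello"));

    let tool = Tool::new(
        "get_weather",
        "Get the weather for a location",
        serde_json::json!({
            "type": "object",
            "required": ["location"],
            "properties": { "location": { "type": "string" } }
        }),
    );

    // Approval defaults to false and can be toggled before execution.
    let mut call = ToolCall::new(
        tool.name.as_str(),
        serde_json::json!({"location": "San Francisco, CA"}),
    );
    call.set_needs_approval(true);
    assert!(call.needs_approval);
}
```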

View File

@@ -0,0 +1,18 @@
use serde_json::Value;
// `serde_json::Value` gets converted to a `String` to pass across the FFI.
// https://github.com/mozilla/uniffi-rs/blob/main/docs/manual/src/types/custom_types.md?plain=1
// https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/examples/custom-types/src/lib.rs#L63-L69
uniffi::custom_type!(Value, String, {
// Remote is required since 'Value' is from a different crate
remote,
lower: |obj| {
serde_json::to_string(&obj).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
pub type JsonValueFfi = Value;
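
A small sanity sketch (not part of the diff) of the lower/lift round trip the macro above performs, i.e. serializing the `Value` to a `String` and parsing it back:

```rust
// Sketch only: mirrors what UniFFI does at the FFI boundary for JsonValueFfi.
fn roundtrip_example() {
    let original: JsonValueFfi = serde_json::json!({"a": 1, "b": ["x", "y"]});
    let lowered: String = serde_json::to_string(&original).unwrap();
    let lifted: JsonValueFfi = serde_json::from_str(&lowered).unwrap();
    assert_eq!(original, lifted);
}
```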

View File

@@ -0,0 +1,3 @@
pub mod completion;
pub mod core;
pub mod json_value_ffi;

View File

@@ -0,0 +1,79 @@
use anyhow::Result;
use dotenv::dotenv;
use goose_llm::extractors::generate_session_name;
use goose_llm::message::Message;
use goose_llm::providers::errors::ProviderError;
fn should_run_test() -> Result<(), String> {
dotenv().ok();
if std::env::var("DATABRICKS_HOST").is_err() {
return Err("Missing DATABRICKS_HOST".to_string());
}
if std::env::var("DATABRICKS_TOKEN").is_err() {
return Err("Missing DATABRICKS_TOKEN".to_string());
}
Ok(())
}
async fn _generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
let provider_name = "databricks";
let provider_config = serde_json::json!({
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
generate_session_name(provider_name, provider_config, messages).await
}
#[tokio::test]
async fn test_generate_session_name_success() {
if should_run_test().is_err() {
println!("Skipping...");
return;
}
// Build a few messages with at least two user messages
let messages = vec![
Message::user().with_text("Hello, how are you?"),
Message::assistant().with_text("Im fine, thanks!"),
Message::user().with_text("Whats the weather in New York tomorrow?"),
];
let name = _generate_session_name(&messages)
.await
.expect("Failed to generate session name");
println!("Generated session name: {:?}", name);
// Should be non-empty and at most 4 words
let name = name.trim();
assert!(!name.is_empty(), "Name must not be empty");
let word_count = name.split_whitespace().count();
assert!(
word_count <= 4,
"Name must be 4 words or less, got {}: {}",
word_count,
name
)
}
#[tokio::test]
async fn test_generate_session_name_no_user() {
if should_run_test().is_err() {
println!("Skipping 'test_generate_session_name_no_user'. Databricks creds not set");
return;
}
// No user messages → expect ExecutionError
let messages = vec![
Message::assistant().with_text("System starting…"),
Message::assistant().with_text("All systems go."),
];
let err = _generate_session_name(&messages).await;
assert!(
matches!(err, Err(ProviderError::ExecutionError(_))),
"Expected ExecutionError when there are no user messages, got: {:?}",
err
);
}

View File

@@ -0,0 +1,88 @@
use anyhow::Result;
use dotenv::dotenv;
use goose_llm::extractors::generate_tooltip;
use goose_llm::message::{Message, MessageContent, ToolRequest};
use goose_llm::providers::errors::ProviderError;
use goose_llm::types::core::{Content, ToolCall};
use serde_json::json;
fn should_run_test() -> Result<(), String> {
dotenv().ok();
if std::env::var("DATABRICKS_HOST").is_err() {
return Err("Missing DATABRICKS_HOST".to_string());
}
if std::env::var("DATABRICKS_TOKEN").is_err() {
return Err("Missing DATABRICKS_TOKEN".to_string());
}
Ok(())
}
async fn _generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
let provider_name = "databricks";
let provider_config = serde_json::json!({
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
generate_tooltip(provider_name, provider_config, messages).await
}
#[tokio::test]
async fn test_generate_tooltip_simple() {
if should_run_test().is_err() {
println!("Skipping...");
return;
}
// Two plain-text messages
let messages = vec![
Message::user().with_text("Hello, how are you?"),
Message::assistant().with_text("I'm fine, thanks! How can I help?"),
];
let tooltip = _generate_tooltip(&messages)
.await
.expect("Failed to generate tooltip");
println!("Generated tooltip: {:?}", tooltip);
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
assert!(
tooltip.len() < 100,
"Tooltip should be reasonably short (<100 chars)"
);
}
#[tokio::test]
async fn test_generate_tooltip_with_tools() {
if should_run_test().is_err() {
println!("Skipping...");
return;
}
// 1) Assistant message with a tool request
let mut tool_req_msg = Message::assistant();
let req = ToolRequest {
id: "1".to_string(),
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))).into(),
};
tool_req_msg.content.push(MessageContent::ToolReq(req));
// 2) User message with the tool response
let tool_resp_msg = Message::user().with_tool_response(
"1",
Ok(vec![Content::text("The current time is 12:00 UTC")]).into(),
);
let messages = vec![tool_req_msg, tool_resp_msg];
let tooltip = _generate_tooltip(&messages)
.await
.expect("Failed to generate tooltip");
println!("Generated tooltip (tools): {:?}", tooltip);
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
assert!(
tooltip.len() < 100,
"Tooltip should be reasonably short (<100 chars)"
);
}

View File

@@ -0,0 +1,380 @@
use anyhow::Result;
use dotenv::dotenv;
use goose_llm::message::{Message, MessageContent};
use goose_llm::providers::base::Provider;
use goose_llm::providers::errors::ProviderError;
use goose_llm::providers::{databricks, openai};
use goose_llm::types::core::{Content, Tool};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
#[derive(Debug, Clone, Copy)]
enum TestStatus {
Passed,
Skipped,
Failed,
}
impl std::fmt::Display for TestStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TestStatus::Passed => write!(f, "✅"),
TestStatus::Skipped => write!(f, "⏭️"),
TestStatus::Failed => write!(f, "❌"),
}
}
}
struct TestReport {
results: Mutex<HashMap<String, TestStatus>>,
}
impl TestReport {
fn new() -> Arc<Self> {
Arc::new(Self {
results: Mutex::new(HashMap::new()),
})
}
fn record_status(&self, provider: &str, status: TestStatus) {
let mut results = self.results.lock().unwrap();
results.insert(provider.to_string(), status);
}
fn record_pass(&self, provider: &str) {
self.record_status(provider, TestStatus::Passed);
}
fn record_skip(&self, provider: &str) {
self.record_status(provider, TestStatus::Skipped);
}
fn record_fail(&self, provider: &str) {
self.record_status(provider, TestStatus::Failed);
}
fn print_summary(&self) {
println!("\n============== Providers ==============");
let results = self.results.lock().unwrap();
let mut providers: Vec<_> = results.iter().collect();
providers.sort_by(|a, b| a.0.cmp(b.0));
for (provider, status) in providers {
println!("{} {}", status, provider);
}
println!("=======================================\n");
}
}
lazy_static::lazy_static! {
static ref TEST_REPORT: Arc<TestReport> = TestReport::new();
static ref ENV_LOCK: Mutex<()> = Mutex::new(());
}
/// Generic test harness for any Provider implementation
struct ProviderTester {
provider: Arc<dyn Provider>,
name: String,
}
impl ProviderTester {
fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
Self {
provider: Arc::new(provider),
name,
}
}
async fn test_basic_response(&self) -> Result<()> {
let message = Message::user().with_text("Just say hello!");
let response = self
.provider
.complete("You are a helpful assistant.", &[message], &[])
.await?;
// For a basic response, we expect a single text response
assert_eq!(
response.message.content.len(),
1,
"Expected single content item in response"
);
// Verify we got a text response
assert!(
matches!(response.message.content[0], MessageContent::Text(_)),
"Expected text response"
);
Ok(())
}
async fn test_tool_usage(&self) -> Result<()> {
let weather_tool = Tool::new(
"get_weather",
"Get the weather for a location",
serde_json::json!({
"type": "object",
"required": ["location"],
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
}
}
}),
);
let message = Message::user().with_text("What's the weather like in San Francisco?");
let response1 = self
.provider
.complete(
"You are a helpful weather assistant.",
&[message.clone()],
&[weather_tool.clone()],
)
.await?;
println!("=== {}::reponse1 ===", self.name);
dbg!(&response1);
println!("===================");
// Verify we got a tool request
assert!(
response1
.message
.content
.iter()
.any(|content| matches!(content, MessageContent::ToolReq(_))),
"Expected tool request in response"
);
let id = &response1
.message
.content
.iter()
.filter_map(|message| message.as_tool_request())
.last()
.expect("got tool request")
.id;
let weather = Message::user().with_tool_response(
id,
Ok(vec![Content::text(
"
50°F°C
Precipitation: 0%
Humidity: 84%
Wind: 2 mph
Weather
Saturday 9:00 PM
Clear",
)])
.into(),
);
// Verify we construct a valid payload including the request/response pair for the next inference
let response2 = self
.provider
.complete(
"You are a helpful weather assistant.",
&[message, response1.message, weather],
&[weather_tool],
)
.await?;
println!("=== {}::reponse2 ===", self.name);
dbg!(&response2);
println!("===================");
assert!(
response2
.message
.content
.iter()
.any(|content| matches!(content, MessageContent::Text(_))),
"Expected text for final response"
);
Ok(())
}
async fn test_context_length_exceeded_error(&self) -> Result<()> {
// Google Gemini has a really long context window
let large_message_content = if self.name.to_lowercase() == "google" {
"hello ".repeat(1_300_000)
} else {
"hello ".repeat(300_000)
};
let messages = vec![
Message::user().with_text("hi there. what is 2 + 2?"),
Message::assistant().with_text("hey! I think it's 4."),
Message::user().with_text(&large_message_content),
Message::assistant().with_text("heyy!!"),
// Messages before this mark should be truncated
Message::user().with_text("what's the meaning of life?"),
Message::assistant().with_text("the meaning of life is 42"),
Message::user().with_text(
"did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'",
),
];
// Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
let result = self
.provider
.complete("You are a helpful assistant.", &messages, &[])
.await;
// Print some debug info
println!("=== {}::context_length_exceeded_error ===", self.name);
dbg!(&result);
println!("===================");
// Ollama truncates by default even when the context window is exceeded
if self.name.to_lowercase() == "ollama" {
assert!(
result.is_ok(),
"Expected to succeed because of default truncation"
);
return Ok(());
}
assert!(
result.is_err(),
"Expected error when context window is exceeded"
);
assert!(
matches!(result.unwrap_err(), ProviderError::ContextLengthExceeded(_)),
"Expected error to be ContextLengthExceeded"
);
Ok(())
}
/// Run all provider tests
async fn run_test_suite(&self) -> Result<()> {
self.test_basic_response().await?;
self.test_tool_usage().await?;
self.test_context_length_exceeded_error().await?;
Ok(())
}
}
fn load_env() {
if let Ok(path) = dotenv() {
println!("Loaded environment from {:?}", path);
}
}
/// Helper function to run a provider test with proper error handling and reporting
async fn test_provider<F, T>(
name: &str,
required_vars: &[&str],
env_modifications: Option<HashMap<&str, Option<String>>>,
provider_fn: F,
) -> Result<()>
where
F: FnOnce() -> T,
T: Provider + Send + Sync + 'static,
{
// We start off as failed, so that if the process panics it is seen as a failure
TEST_REPORT.record_fail(name);
// Take exclusive access to environment modifications
let lock = ENV_LOCK.lock().unwrap();
load_env();
// Save current environment state for required vars and modified vars
let mut original_env = HashMap::new();
for &var in required_vars {
if let Ok(val) = std::env::var(var) {
original_env.insert(var, val);
}
}
if let Some(mods) = &env_modifications {
for &var in mods.keys() {
if let Ok(val) = std::env::var(var) {
original_env.insert(var, val);
}
}
}
// Apply any environment modifications
if let Some(mods) = &env_modifications {
for (&var, value) in mods.iter() {
match value {
Some(val) => std::env::set_var(var, val),
None => std::env::remove_var(var),
}
}
}
// Setup the provider
let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err());
if missing_vars {
println!("Skipping {} tests - credentials not configured", name);
TEST_REPORT.record_skip(name);
return Ok(());
}
let provider = provider_fn();
// Restore original environment
for (&var, value) in original_env.iter() {
std::env::set_var(var, value);
}
if let Some(mods) = env_modifications {
for &var in mods.keys() {
if !original_env.contains_key(var) {
std::env::remove_var(var);
}
}
}
std::mem::drop(lock);
let tester = ProviderTester::new(provider, name.to_string());
match tester.run_test_suite().await {
Ok(_) => {
TEST_REPORT.record_pass(name);
Ok(())
}
Err(e) => {
println!("{} test failed: {}", name, e);
TEST_REPORT.record_fail(name);
Err(e)
}
}
}
#[tokio::test]
async fn openai_complete() -> Result<()> {
test_provider(
"OpenAI",
&["OPENAI_API_KEY"],
None,
openai::OpenAiProvider::default,
)
.await
}
#[tokio::test]
async fn databricks_complete() -> Result<()> {
test_provider(
"Databricks",
&["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
None,
databricks::DatabricksProvider::default,
)
.await
}
// Print the final test report
#[ctor::dtor]
fn print_test_report() {
TEST_REPORT.print_summary();
}

View File

@@ -0,0 +1,195 @@
// tests/providers_extract.rs
use anyhow::Result;
use dotenv::dotenv;
use goose_llm::message::Message;
use goose_llm::providers::base::Provider;
use goose_llm::providers::{databricks::DatabricksProvider, openai::OpenAiProvider};
use goose_llm::ModelConfig;
use serde_json::{json, Value};
use std::sync::Arc;
#[derive(Debug, PartialEq, Copy, Clone)]
enum ProviderType {
OpenAi,
Databricks,
}
impl ProviderType {
fn required_env(&self) -> &'static [&'static str] {
match self {
ProviderType::OpenAi => &["OPENAI_API_KEY"],
ProviderType::Databricks => &["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
}
}
fn create_provider(&self, cfg: ModelConfig) -> Result<Arc<dyn Provider>> {
Ok(match self {
ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(cfg)),
ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(cfg)),
})
}
}
fn check_required_env_vars(required: &[&str]) -> bool {
let missing: Vec<_> = required
.iter()
.filter(|&&v| std::env::var(v).is_err())
.cloned()
.collect();
if !missing.is_empty() {
println!("Skipping test; missing env vars: {:?}", missing);
false
} else {
true
}
}
// --- Shared inputs for "paper" task ---
const PAPER_SYSTEM: &str =
"You are an expert at structured data extraction. Extract the metadata of a research paper into JSON.";
const PAPER_TEXT: &str =
"Application of Quantum Algorithms in Interstellar Navigation: A New Frontier \
by Dr. Stella Voyager, Dr. Nova Star, Dr. Lyra Hunter. Abstract: This paper \
investigates the utilization of quantum algorithms to improve interstellar \
navigation systems. Keywords: Quantum algorithms, interstellar navigation, \
space-time anomalies, quantum superposition, quantum entanglement, space travel.";
fn paper_schema() -> Value {
json!({
"type": "object",
"properties": {
"title": { "type": "string" },
"authors": { "type": "array", "items": { "type": "string" } },
"abstract": { "type": "string" },
"keywords": { "type": "array", "items": { "type": "string" } }
},
"required": ["title","authors","abstract","keywords"],
"additionalProperties": false
})
}
// --- Shared inputs for "UI" task ---
const UI_SYSTEM: &str = "You are a UI generator AI. Convert the user input into a JSON-driven UI.";
const UI_TEXT: &str = "Make a User Profile Form";
fn ui_schema() -> Value {
json!({
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": ["div","button","header","section","field","form"]
},
"label": { "type": "string" },
"children": {
"type": "array",
"items": { "$ref": "#" }
},
"attributes": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"value": { "type": "string" }
},
"required": ["name","value"],
"additionalProperties": false
}
}
},
"required": ["type","label","children","attributes"],
"additionalProperties": false
})
}
/// Generic runner for any extract task
async fn run_extract_test<F>(
provider_type: ProviderType,
model: &str,
system: &'static str,
user_text: &'static str,
schema: Value,
validate: F,
) -> Result<()>
where
F: Fn(&Value) -> bool,
{
dotenv().ok();
if !check_required_env_vars(provider_type.required_env()) {
return Ok(());
}
let cfg = ModelConfig::new(model.to_string()).with_temperature(Some(0.0));
let provider = provider_type.create_provider(cfg)?;
let msg = Message::user().with_text(user_text);
let resp = provider.extract(system, &[msg], &schema).await?;
println!("[{:?}] extract => {}", provider_type, resp.data);
assert!(
validate(&resp.data),
"{:?} failed validation on {}",
provider_type,
resp.data
);
Ok(())
}
/// Helper for the "paper" task
async fn run_extract_paper_test(provider: ProviderType, model: &str) -> Result<()> {
run_extract_test(
provider,
model,
PAPER_SYSTEM,
PAPER_TEXT,
paper_schema(),
|v| {
v.as_object()
.map(|o| {
["title", "authors", "abstract", "keywords"]
.iter()
.all(|k| o.contains_key(*k))
})
.unwrap_or(false)
},
)
.await
}
/// Helper for the "UI" task
async fn run_extract_ui_test(provider: ProviderType, model: &str) -> Result<()> {
run_extract_test(provider, model, UI_SYSTEM, UI_TEXT, ui_schema(), |v| {
v.as_object()
.and_then(|o| o.get("type").and_then(Value::as_str))
== Some("form")
})
.await
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn openai_extract_paper() -> Result<()> {
run_extract_paper_test(ProviderType::OpenAi, "gpt-4o").await
}
#[tokio::test]
async fn openai_extract_ui() -> Result<()> {
run_extract_ui_test(ProviderType::OpenAi, "gpt-4o").await
}
#[tokio::test]
async fn databricks_extract_paper() -> Result<()> {
run_extract_paper_test(ProviderType::Databricks, "goose-gpt-4-1").await
}
#[tokio::test]
async fn databricks_extract_ui() -> Result<()> {
run_extract_ui_test(ProviderType::Databricks, "goose-gpt-4-1").await
}
}

View File

@@ -0,0 +1,3 @@
fn main() {
uniffi::uniffi_bindgen_main()
}

View File

@@ -35,6 +35,7 @@ chrono = { version = "0.4.38", features = ["serde"] }
etcetera = "0.8.0"
tempfile = "3.8"
include_dir = "0.7.4"
google-apis-common = "7.0.0"
google-drive3 = "6.0.0"
google-sheets4 = "6.0.0"
google-docs1 = "6.0.0"
@@ -47,9 +48,21 @@ lopdf = "0.35.0"
docx-rs = "0.4.7"
image = "0.24.9"
umya-spreadsheet = "2.2.3"
keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service", "vendored"] }
keyring = { version = "3.6.1", features = [
"apple-native",
"windows-native",
"sync-secret-service",
"vendored",
] }
oauth2 = { version = "5.0.0", features = ["reqwest"] }
utoipa = { version = "4.1", optional = true }
hyper = "1"
serde_with = "3"
[dev-dependencies]
serial_test = "3.0.0"
sysinfo = "0.32.1"
[features]
utoipa = ["dep:utoipa"]

View File

@@ -8,6 +8,9 @@ use std::{
};
use tokio::process::Command;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
@@ -743,6 +746,23 @@ impl ComputerControllerRouter {
ToolError::ExecutionError(format!("Failed to write script: {}", e))
})?;
// Set execute permissions on Unix systems
#[cfg(unix)]
{
let mut perms = fs::metadata(&script_path)
.map_err(|e| {
ToolError::ExecutionError(format!("Failed to get file metadata: {}", e))
})?
.permissions();
perms.set_mode(0o755); // rwxr-xr-x
fs::set_permissions(&script_path, perms).map_err(|e| {
ToolError::ExecutionError(format!(
"Failed to set execute permissions: {}",
e
))
})?;
}
script_path.display().to_string()
}
"ruby" => {

View File

@@ -65,10 +65,7 @@ impl LinuxAutomation {
DisplayServer::X11 => self.check_x11_dependencies()?,
DisplayServer::Wayland => self.check_wayland_dependencies()?,
DisplayServer::Unknown => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Unable to detect display server",
));
return Err(std::io::Error::other("Unable to detect display server"));
}
}
@@ -106,10 +103,7 @@ impl LinuxAutomation {
match self.display_server {
DisplayServer::X11 => self.execute_x11_command(cmd),
DisplayServer::Wayland => self.execute_wayland_command(cmd),
DisplayServer::Unknown => Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Unknown display server",
)),
DisplayServer::Unknown => Err(std::io::Error::other("Unknown display server")),
}
}
@@ -236,8 +230,7 @@ impl SystemAutomation for LinuxAutomation {
if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).into_owned())
} else {
Err(std::io::Error::new(
std::io::ErrorKind::Other,
Err(std::io::Error::other(
String::from_utf8_lossy(&output.stderr).into_owned(),
))
}

View File

@@ -0,0 +1,478 @@
#![allow(clippy::ptr_arg, dead_code, clippy::enum_variant_names)]
use std::collections::{BTreeSet, HashMap};
use google_apis_common as common;
use tokio::time::sleep;
/// A scope is needed when requesting an
/// [authorization token](https://developers.google.com/workspace/drive/labels/guides/authorize).
#[derive(PartialEq, Eq, Ord, PartialOrd, Hash, Debug, Clone, Copy)]
pub enum Scope {
/// View, use, and manage Drive labels.
DriveLabels,
/// View and use Drive labels.
DriveLabelsReadonly,
/// View, edit, create, and delete all Drive labels in your organization,
/// and view your organization's label-related administration policies.
DriveLabelsAdmin,
/// View all Drive labels and label-related administration policies in your
/// organization.
DriveLabelsAdminReadonly,
}
impl AsRef<str> for Scope {
fn as_ref(&self) -> &str {
match *self {
Scope::DriveLabels => "https://www.googleapis.com/auth/drive.labels",
Scope::DriveLabelsReadonly => "https://www.googleapis.com/auth/drive.labels.readonly",
Scope::DriveLabelsAdmin => "https://www.googleapis.com/auth/drive.admin.labels",
Scope::DriveLabelsAdminReadonly => {
"https://www.googleapis.com/auth/drive.admin.labels.readonly"
}
}
}
}
#[allow(clippy::derivable_impls)]
impl Default for Scope {
fn default() -> Scope {
Scope::DriveLabelsReadonly
}
}
#[derive(Clone)]
pub struct DriveLabelsHub<C> {
pub client: common::Client<C>,
pub auth: Box<dyn common::GetToken>,
_user_agent: String,
_base_url: String,
}
impl<C> common::Hub for DriveLabelsHub<C> {}
impl<'a, C> DriveLabelsHub<C> {
pub fn new<A: 'static + common::GetToken>(
client: common::Client<C>,
auth: A,
) -> DriveLabelsHub<C> {
DriveLabelsHub {
client,
auth: Box::new(auth),
_user_agent: "google-api-rust-client/6.0.0".to_string(),
_base_url: "https://drivelabels.googleapis.com/".to_string(),
}
}
pub fn labels(&'a self) -> LabelMethods<'a, C> {
LabelMethods { hub: self }
}
/// Set the user-agent header field to use in all requests to the server.
/// It defaults to `google-api-rust-client/6.0.0`.
///
/// Returns the previously set user-agent.
pub fn user_agent(&mut self, agent_name: String) -> String {
std::mem::replace(&mut self._user_agent, agent_name)
}
/// Set the base url to use in all requests to the server.
/// It defaults to `https://drivelabels.googleapis.com/`.
///
/// Returns the previously set base url.
pub fn base_url(&mut self, new_base_url: String) -> String {
std::mem::replace(&mut self._base_url, new_base_url)
}
}
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Label {
#[serde(rename = "name")]
pub name: Option<String>,
#[serde(rename = "id")]
pub id: Option<String>,
#[serde(rename = "revisionId")]
pub revision_id: Option<String>,
#[serde(rename = "labelType")]
pub label_type: Option<String>,
#[serde(rename = "creator")]
pub creator: Option<User>,
#[serde(rename = "createTime")]
pub create_time: Option<String>,
#[serde(rename = "revisionCreator")]
pub revision_creator: Option<User>,
#[serde(rename = "revisionCreateTime")]
pub revision_create_time: Option<String>,
#[serde(rename = "publisher")]
pub publisher: Option<User>,
#[serde(rename = "publishTime")]
pub publish_time: Option<String>,
#[serde(rename = "disabler")]
pub disabler: Option<User>,
#[serde(rename = "disableTime")]
pub disable_time: Option<String>,
#[serde(rename = "customer")]
pub customer: Option<String>,
pub properties: Option<LabelProperty>,
pub fields: Option<Vec<Field>>,
// We ignore the remaining fields.
}
impl common::Part for Label {}
impl common::ResponseResult for Label {}
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct LabelProperty {
pub title: Option<String>,
pub description: Option<String>,
}
impl common::Part for LabelProperty {}
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Field {
id: Option<String>,
#[serde(rename = "queryKey")]
query_key: Option<String>,
properties: Option<FieldProperty>,
#[serde(rename = "selectionOptions")]
selection_options: Option<SelectionOption>,
}
impl common::Part for Field {}
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct FieldProperty {
#[serde(rename = "displayName")]
pub display_name: Option<String>,
pub required: Option<bool>,
}
impl common::Part for FieldProperty {}
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct SelectionOption {
#[serde(rename = "listOptions")]
pub list_options: Option<String>,
pub choices: Option<Vec<Choice>>,
}
impl common::Part for SelectionOption {}
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Choice {
id: Option<String>,
properties: Option<ChoiceProperties>,
// We ignore the remaining fields.
}
impl common::Part for Choice {}
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ChoiceProperties {
#[serde(rename = "displayName")]
display_name: Option<String>,
description: Option<String>,
}
impl common::Part for ChoiceProperties {}
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct LabelList {
pub labels: Option<Vec<Label>>,
#[serde(rename = "nextPageToken")]
pub next_page_token: Option<String>,
}
impl common::ResponseResult for LabelList {}
/// Information about a Drive user.
///
/// This type is not used in any activity, and only used as *part* of another schema.
///
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde_with::serde_as]
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct User {
/// Output only. A plain text displayable name for this user.
#[serde(rename = "displayName")]
pub display_name: Option<String>,
/// Output only. The email address of the user. This may not be present in certain contexts if the user has not made their email address visible to the requester.
#[serde(rename = "emailAddress")]
pub email_address: Option<String>,
/// Output only. Identifies what kind of resource this is. Value: the fixed string `"drive#user"`.
pub kind: Option<String>,
/// Output only. Whether this user is the requesting user.
pub me: Option<bool>,
/// Output only. The user's ID as visible in Permission resources.
#[serde(rename = "permissionId")]
pub permission_id: Option<String>,
/// Output only. A link to the user's profile photo, if available.
#[serde(rename = "photoLink")]
pub photo_link: Option<String>,
}
impl common::Part for User {}
pub struct LabelMethods<'a, C>
where
C: 'a,
{
hub: &'a DriveLabelsHub<C>,
}
impl<C> common::MethodsBuilder for LabelMethods<'_, C> {}
impl<'a, C> LabelMethods<'a, C> {
/// Create a builder to help you perform the following tasks:
///
/// List labels
pub fn list(&self) -> LabelListCall<'a, C> {
LabelListCall {
hub: self.hub,
_delegate: Default::default(),
_additional_params: Default::default(),
_scopes: Default::default(),
}
}
}
/// Lists the workspace's labels.
pub struct LabelListCall<'a, C>
where
C: 'a,
{
hub: &'a DriveLabelsHub<C>,
_delegate: Option<&'a mut dyn common::Delegate>,
_additional_params: HashMap<String, String>,
_scopes: BTreeSet<String>,
}
impl<C> common::CallBuilder for LabelListCall<'_, C> {}
impl<'a, C> LabelListCall<'a, C>
where
C: common::Connector,
{
/// Perform the operation you have built so far.
pub async fn doit(mut self) -> common::Result<(common::Response, LabelList)> {
use common::url::Params;
use hyper::header::{AUTHORIZATION, CONTENT_LENGTH, USER_AGENT};
let mut dd = common::DefaultDelegate;
let dlg: &mut dyn common::Delegate = self._delegate.unwrap_or(&mut dd);
dlg.begin(common::MethodInfo {
id: "drivelabels.labels.list",
http_method: hyper::Method::GET,
});
for &field in ["alt"].iter() {
if self._additional_params.contains_key(field) {
dlg.finished(false);
return Err(common::Error::FieldClash(field));
}
}
// TODO: We don't handle any of the query params.
let mut params = Params::with_capacity(2 + self._additional_params.len());
params.extend(self._additional_params.iter());
params.push("alt", "json");
let url = self.hub._base_url.clone() + "v2/labels";
if self._scopes.is_empty() {
self._scopes
.insert(Scope::DriveLabelsReadonly.as_ref().to_string());
}
let url = params.parse_with_url(&url);
loop {
let token = match self
.hub
.auth
.get_token(&self._scopes.iter().map(String::as_str).collect::<Vec<_>>()[..])
.await
{
Ok(token) => token,
Err(e) => match dlg.token(e) {
Ok(token) => token,
Err(e) => {
dlg.finished(false);
return Err(common::Error::MissingToken(e));
}
},
};
let req_result = {
let client = &self.hub.client;
dlg.pre_request();
let mut req_builder = hyper::Request::builder()
.method(hyper::Method::GET)
.uri(url.as_str())
.header(USER_AGENT, self.hub._user_agent.clone());
if let Some(token) = token.as_ref() {
req_builder = req_builder.header(AUTHORIZATION, format!("Bearer {}", token));
}
let request = req_builder
.header(CONTENT_LENGTH, 0_u64)
.body(common::to_body::<String>(None));
client.request(request.unwrap()).await
};
match req_result {
Err(err) => {
if let common::Retry::After(d) = dlg.http_error(&err) {
sleep(d).await;
continue;
}
dlg.finished(false);
return Err(common::Error::HttpError(err));
}
Ok(res) => {
let (parts, body) = res.into_parts();
let body = common::Body::new(body);
if !parts.status.is_success() {
let bytes = common::to_bytes(body).await.unwrap_or_default();
let error = serde_json::from_str(&common::to_string(&bytes));
let response = common::to_response(parts, bytes.into());
if let common::Retry::After(d) =
dlg.http_failure(&response, error.as_ref().ok())
{
sleep(d).await;
continue;
}
dlg.finished(false);
return Err(match error {
Ok(value) => common::Error::BadRequest(value),
_ => common::Error::Failure(response),
});
}
let response = {
let bytes = common::to_bytes(body).await.unwrap_or_default();
let encoded = common::to_string(&bytes);
match serde_json::from_str(&encoded) {
Ok(decoded) => (common::to_response(parts, bytes.into()), decoded),
Err(error) => {
dlg.response_json_decode_error(&encoded, &error);
return Err(common::Error::JsonDecodeError(
encoded.to_string(),
error,
));
}
}
};
dlg.finished(true);
return Ok(response);
}
}
}
}
/// The delegate implementation is consulted whenever there is an intermediate result, or if something goes wrong
/// while executing the actual API request.
///
/// ````text
/// It should be used to handle progress information, and to implement a certain level of resilience.
/// ````
///
/// Sets the *delegate* property to the given value.
pub fn delegate(mut self, new_value: &'a mut dyn common::Delegate) -> LabelListCall<'a, C> {
self._delegate = Some(new_value);
self
}
/// Set any additional parameter of the query string used in the request.
/// It should be used to set parameters which are not yet available through their own
/// setters.
///
/// Please note that this method must not be used to set any of the known parameters
/// which have their own setter method. If done anyway, the request will fail.
///
/// # Additional Parameters
///
/// * *$.xgafv* (query-string) - V1 error format.
/// * *access_token* (query-string) - OAuth access token.
/// * *alt* (query-string) - Data format for response.
/// * *callback* (query-string) - JSONP
/// * *fields* (query-string) - Selector specifying which fields to include in a partial response.
/// * *key* (query-string) - API key. Your API key identifies your project and provides you with API access, quota, and reports. Required unless you provide an OAuth 2.0 token.
/// * *oauth_token* (query-string) - OAuth 2.0 token for the current user.
/// * *prettyPrint* (query-boolean) - Returns response with indentations and line breaks.
/// * *quotaUser* (query-string) - Available to use for quota purposes for server-side applications. Can be any arbitrary string assigned to a user, but should not exceed 40 characters.
/// * *uploadType* (query-string) - Legacy upload protocol for media (e.g. "media", "multipart").
/// * *upload_protocol* (query-string) - Upload protocol for media (e.g. "raw", "multipart").
pub fn param<T>(mut self, name: T, value: T) -> LabelListCall<'a, C>
where
T: AsRef<str>,
{
self._additional_params
.insert(name.as_ref().to_string(), value.as_ref().to_string());
self
}
/// Identifies the authorization scope for the method you are building.
///
/// Use this method to actively specify which scope should be used, instead of the default [`Scope`] variant
/// [`Scope::DriveLabelsReadonly`].
///
/// The `scope` will be added to a set of scopes. This is important as one can maintain access
/// tokens for more than one scope.
///
/// Usually there is more than one suitable scope to authorize an operation, some of which may
/// encompass more rights than others. For example, for listing resources, a *read-only* scope will be
/// sufficient, a read-write scope will do as well.
pub fn add_scope<St>(mut self, scope: St) -> LabelListCall<'a, C>
where
St: AsRef<str>,
{
self._scopes.insert(String::from(scope.as_ref()));
self
}
/// Identifies the authorization scope(s) for the method you are building.
///
/// See [`Self::add_scope()`] for details.
pub fn add_scopes<I, St>(mut self, scopes: I) -> LabelListCall<'a, C>
where
I: IntoIterator<Item = St>,
St: AsRef<str>,
{
self._scopes
.extend(scopes.into_iter().map(|s| String::from(s.as_ref())));
self
}
/// Removes all scopes, and no default scope will be used either.
/// In this case, you have to specify your API-key using the `key` parameter (see [`Self::param()`]
/// for details).
pub fn clear_scopes(mut self) -> LabelListCall<'a, C> {
self._scopes.clear();
self
}
}
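
A hedged usage sketch (not part of the diff): how a caller might list labels once a hyper-based client and a `GetToken` implementation (for example the PKCE OAuth2 client touched later in this change) are in hand; constructing those is assumed and not shown:

```rust
// Sketch only: fetches the first page of labels with the read-only scope.
async fn list_labels_example<C, A>(client: common::Client<C>, auth: A) -> common::Result<()>
where
    C: common::Connector,
    A: common::GetToken + 'static,
{
    let hub = DriveLabelsHub::new(client, auth);
    let (_response, label_list) = hub
        .labels()
        .list()
        .add_scope(Scope::DriveLabelsReadonly)
        .doit()
        .await?;
    for label in label_list.labels.unwrap_or_default() {
        println!("{:?} -> {:?}", label.id, label.properties.and_then(|p| p.title));
    }
    Ok(())
}
```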

File diff suppressed because it is too large

View File

@@ -46,6 +46,7 @@ struct TokenData {
#[serde(skip_serializing_if = "Option::is_none")]
expires_at: Option<u64>,
project_id: String,
scopes: Vec<String>,
}
/// PkceOAuth2Client implements the GetToken trait required by DriveHub
@@ -56,6 +57,7 @@ pub struct PkceOAuth2Client {
credentials_manager: Arc<CredentialsManager>,
http_client: reqwest::Client,
project_id: String,
scopes: Vec<String>,
}
impl PkceOAuth2Client {
@@ -69,6 +71,7 @@ impl PkceOAuth2Client {
// Extract the project_id from the config
let project_id = config.installed.project_id.clone();
let scopes = vec![];
// Create OAuth URLs
let auth_url =
@@ -97,6 +100,7 @@ impl PkceOAuth2Client {
credentials_manager,
http_client,
project_id,
scopes,
})
}
@@ -185,6 +189,7 @@ impl PkceOAuth2Client {
refresh_token: refresh_token_str.clone(),
expires_at,
project_id: self.project_id.clone(),
scopes: scopes.iter().map(|s| s.to_string()).collect(),
};
// Store updated token data
@@ -239,6 +244,7 @@ impl PkceOAuth2Client {
refresh_token: new_refresh_token.clone(),
expires_at,
project_id: self.project_id.clone(),
scopes: self.scopes.clone(),
};
// Store updated token data
@@ -315,37 +321,66 @@ impl GetToken for PkceOAuth2Client {
if let Ok(token_data) = self.credentials_manager.read_credentials::<TokenData>() {
// Verify the project_id matches
if token_data.project_id == self.project_id {
// Check if the token is expired or expiring within a 5-min buffer
if !self.is_token_expired(token_data.expires_at, 300) {
return Ok(Some(token_data.access_token));
}
// Convert stored scopes to &str slices for comparison
let stored_scope_refs: Vec<&str> =
token_data.scopes.iter().map(|s| s.as_str()).collect();
// Token is expired or will expire soon, try to refresh it
debug!("Token is expired or will expire soon, refreshing...");
// Check if we need additional scopes
let needs_additional_scopes = scopes.iter().any(|&scope| {
!stored_scope_refs
.iter()
.any(|&stored| stored.contains(scope))
});
// Try to refresh the token
if let Ok(access_token) = self.refresh_token(&token_data.refresh_token).await {
debug!("Successfully refreshed access token");
return Ok(Some(access_token));
if !needs_additional_scopes {
// Check if the token is expired or expiring within a 5-min buffer
if !self.is_token_expired(token_data.expires_at, 300) {
return Ok(Some(token_data.access_token));
}
// Token is expired or will expire soon, try to refresh it
debug!("Token is expired or will expire soon, refreshing...");
// Try to refresh the token
if let Ok(access_token) =
self.refresh_token(&token_data.refresh_token).await
{
debug!("Successfully refreshed access token");
return Ok(Some(access_token));
}
} else {
// Only allocate new strings when we need to combine scopes
let mut combined_scopes: Vec<&str> =
Vec::with_capacity(scopes.len() + stored_scope_refs.len());
combined_scopes.extend(scopes);
combined_scopes.extend(stored_scope_refs.iter().filter(|&&stored| {
!scopes.iter().any(|&scope| stored.contains(scope))
}));
return self
.perform_oauth_flow(&combined_scopes)
.await
.map(Some)
.map_err(|e| {
error!("OAuth flow failed: {}", e);
e
});
}
}
}
// If we get here, either:
// 1. The project ID didn't match
// 2. Token refresh failed
// 3. There are no valid tokens yet
// 4. We didn't have to change the scopes of an existing token
// Fallback: perform interactive OAuth flow
match self.perform_oauth_flow(scopes).await {
    Ok(token) => {
        debug!("Successfully obtained new access token through OAuth flow");
        Ok(Some(token))
    }
    Err(e) => {
        error!("OAuth flow failed: {}", e);
        Err(e)
    }
}
self.perform_oauth_flow(scopes)
    .await
    .map(Some)
    .map_err(|e| {
        error!("OAuth flow failed: {}", e);
        e
    })
}
}
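For readability, here is a minimal standalone sketch of the scope logic introduced above: check whether the stored grant already covers every requested scope and, if not, build the combined list to request. The helper name and the example scope URLs are illustrative only, not part of the goose codebase.

```rust
/// Returns `None` when the stored grant already covers every requested scope,
/// otherwise the requested scopes plus any stored scopes not already covered.
fn scopes_to_request<'a>(requested: &[&'a str], stored: &'a [String]) -> Option<Vec<&'a str>> {
    let stored_refs: Vec<&str> = stored.iter().map(|s| s.as_str()).collect();

    // A requested scope counts as covered if any stored scope contains it,
    // mirroring the `contains` check in the diff above.
    let needs_more = requested
        .iter()
        .any(|&scope| !stored_refs.iter().any(|&s| s.contains(scope)));
    if !needs_more {
        return None;
    }

    let mut combined: Vec<&str> = Vec::with_capacity(requested.len() + stored_refs.len());
    combined.extend_from_slice(requested);
    combined.extend(
        stored_refs
            .iter()
            .copied()
            .filter(|stored| !requested.iter().any(|&r| stored.contains(r))),
    );
    Some(combined)
}

fn main() {
    // Hypothetical scope strings, purely for illustration.
    let stored = vec!["https://www.googleapis.com/auth/drive.readonly".to_string()];
    let requested = [
        "https://www.googleapis.com/auth/drive.readonly",
        "https://www.googleapis.com/auth/gmail.readonly",
    ];
    // drive.readonly is covered, gmail.readonly is not, so a combined list is returned.
    println!("{:?}", scopes_to_request(&requested, &stored));
}
```

Note that, as in the diff, coverage is a substring check rather than exact equality, so a stored scope that merely contains a requested one is treated as sufficient.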

View File

@@ -1,3 +1,5 @@
IMPORTANT: GOOSE_ALLOWLIST is currently used in main.ts in ui/desktop, not in goose-server. The following is kept for reference in case it is later used on the server side for launch-time enforcement.
# Goose Extension Allowlist
The allowlist feature provides a security mechanism for controlling which MCP commands can be used by goose.
@@ -24,9 +26,11 @@ If this environment variable is not set, no allowlist restrictions will be appli
In certain development or testing scenarios, you may need to bypass the allowlist restrictions. You can do this by setting the `GOOSE_ALLOWLIST_BYPASS` environment variable to `true`:
```bash
export GOOSE_ALLOWLIST_BYPASS=true
# For the GUI, you can have it show a warning instead of blocking (a warning is always shown either way):
export GOOSE_ALLOWLIST_WARNING=true
```
When this environment variable is set to `true` (case insensitive), the allowlist check will be bypassed and all commands will be allowed, even if the `GOOSE_ALLOWLIST` environment variable is set.
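If this check were ever enforced in goose-server at launch time, as the note at the top of this file anticipates, the bypass test could look roughly like the sketch below. This is illustrative only; the function is hypothetical and not part of goose-server today.

```rust
use std::env;

/// Hypothetical launch-time check: GOOSE_ALLOWLIST_BYPASS=true
/// (case insensitive) disables allowlist enforcement.
fn allowlist_bypassed() -> bool {
    env::var("GOOSE_ALLOWLIST_BYPASS")
        .map(|v| v.eq_ignore_ascii_case("true"))
        .unwrap_or(false)
}

fn main() {
    if allowlist_bypassed() {
        println!("allowlist bypassed: all commands permitted");
    } else {
        println!("allowlist (if configured) will be enforced");
    }
}
```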
## Allowlist File Format

View File

@@ -12,9 +12,10 @@ goose = { path = "../goose" }
mcp-core = { path = "../mcp-core" }
goose-mcp = { path = "../goose-mcp" }
mcp-server = { path = "../mcp-server" }
axum = { version = "0.7.2", features = ["ws", "macros"] }
axum = { version = "0.8.1", features = ["ws", "macros"] }
tokio = { version = "1.43", features = ["full"] }
chrono = "0.4"
tokio-cron-scheduler = "0.14.0"
tower-http = { version = "0.5", features = ["cors"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
@@ -26,6 +27,7 @@ tokio-stream = "0.1"
anyhow = "1.0"
bytes = "1.5"
http = "1.0"
base64 = "0.21"
config = { version = "0.14.1", features = ["toml"] }
thiserror = "1.0"
clap = { version = "4.4", features = ["derive"] }
@@ -33,7 +35,7 @@ once_cell = "1.20.2"
etcetera = "0.8.0"
serde_yaml = "0.9.34"
axum-extra = "0.10.0"
utoipa = { version = "4.1", features = ["axum_extras"] }
utoipa = { version = "4.1", features = ["axum_extras", "chrono"] }
dirs = "6.0.0"
reqwest = { version = "0.12.9", features = ["json", "rustls-tls", "blocking"], default-features = false }
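The axum bump from 0.7.2 to 0.8.1 is what drives the route-syntax change later in this diff: axum 0.8 replaced the `/:param` capture style with `/{param}`. A minimal sketch of the new style, with an illustrative handler:

```rust
use axum::{extract::Path, routing::delete, Router};

async fn remove_extension(Path(name): Path<String>) -> String {
    format!("removed extension {name}")
}

fn router() -> Router {
    // axum 0.8: path parameters use curly braces instead of a leading colon.
    Router::new().route("/config/extensions/{name}", delete(remove_extension))
}
```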

View File

@@ -3,7 +3,10 @@ use std::sync::Arc;
use crate::configuration;
use crate::state;
use anyhow::Result;
use etcetera::{choose_app_strategy, AppStrategy};
use goose::agents::Agent;
use goose::config::APP_STRATEGY;
use goose::scheduler::Scheduler as GooseScheduler;
use tower_http::cors::{Any, CorsLayer};
use tracing::info;
@@ -11,27 +14,30 @@ pub async fn run() -> Result<()> {
// Initialize logging
crate::logging::setup_logging(Some("goosed"))?;
// Load configuration
let settings = configuration::Settings::new()?;
// load secret key from GOOSE_SERVER__SECRET_KEY environment variable
let secret_key =
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
let new_agent = Agent::new();
let agent_ref = Arc::new(new_agent);
// Create app state with agent
let state = state::AppState::new(Arc::new(new_agent), secret_key.clone()).await;
let app_state = state::AppState::new(agent_ref.clone(), secret_key.clone()).await;
let schedule_file_path = choose_app_strategy(APP_STRATEGY.clone())?
.data_dir()
.join("schedules.json");
let scheduler_instance = GooseScheduler::new(schedule_file_path).await?;
app_state.set_scheduler(scheduler_instance).await;
// Create router with CORS support
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
let app = crate::routes::configure(state).layer(cors);
let app = crate::routes::configure(app_state).layer(cors);
// Run server
let listener = tokio::net::TcpListener::bind(settings.socket_addr()).await?;
info!("listening on {}", listener.local_addr()?);
axum::serve(listener, app).await?;

View File

@@ -8,6 +8,7 @@ use tracing_subscriber::{
Registry,
};
use goose::config::APP_STRATEGY;
use goose::tracing::langfuse_layer;
/// Returns the directory where log files should be stored.
@@ -17,8 +18,8 @@ fn get_log_directory() -> Result<PathBuf> {
// - macOS/Linux: ~/.local/state/goose/logs/server
// - Windows: ~\AppData\Roaming\Block\goose\data\logs\server
// - Windows has no convention for state_dir, use data_dir instead
let home_dir = choose_app_strategy(crate::APP_STRATEGY.clone())
.context("HOME environment variable not set")?;
let home_dir =
choose_app_strategy(APP_STRATEGY.clone()).context("HOME environment variable not set")?;
let base_log_dir = home_dir
.in_state_dir("logs/server")

View File

@@ -1,12 +1,3 @@
use etcetera::AppStrategyArgs;
use once_cell::sync::Lazy;
pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
top_level_domain: "Block".to_string(),
author: "Block".to_string(),
app_name: "goose".to_string(),
});
mod commands;
mod configuration;
mod error;

View File

@@ -3,8 +3,18 @@ use goose::agents::extension::ToolInfo;
use goose::agents::ExtensionConfig;
use goose::config::permission::PermissionLevel;
use goose::config::ExtensionEntry;
use goose::message::{
ContextLengthExceeded, FrontendToolRequest, Message, MessageContent, RedactedThinkingContent,
SummarizationRequested, ThinkingContent, ToolConfirmationRequest, ToolRequest, ToolResponse,
};
use goose::permission::permission_confirmation::PrincipalType;
use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata};
use goose::session::info::SessionInfo;
use goose::session::SessionMetadata;
use mcp_core::content::{Annotations, Content, EmbeddedResource, ImageContent, TextContent};
use mcp_core::handler::ToolResultSchema;
use mcp_core::resource::ResourceContents;
use mcp_core::role::Role;
use mcp_core::tool::{Tool, ToolAnnotations};
use utoipa::OpenApi;
@@ -25,6 +35,17 @@ use utoipa::OpenApi;
super::routes::config_management::upsert_permissions,
super::routes::agent::get_tools,
super::routes::reply::confirm_permission,
super::routes::context::manage_context,
super::routes::session::list_sessions,
super::routes::session::get_session_history,
super::routes::schedule::create_schedule,
super::routes::schedule::list_schedules,
super::routes::schedule::delete_schedule,
super::routes::schedule::update_schedule,
super::routes::schedule::run_now_handler,
super::routes::schedule::pause_schedule,
super::routes::schedule::unpause_schedule,
super::routes::schedule::sessions_handler
),
components(schemas(
super::routes::config_management::UpsertConfigQuery,
@@ -37,6 +58,28 @@ use utoipa::OpenApi;
super::routes::config_management::ToolPermission,
super::routes::config_management::UpsertPermissionsQuery,
super::routes::reply::PermissionConfirmationRequest,
super::routes::context::ContextManageRequest,
super::routes::context::ContextManageResponse,
super::routes::session::SessionListResponse,
super::routes::session::SessionHistoryResponse,
Message,
MessageContent,
Content,
EmbeddedResource,
ImageContent,
Annotations,
TextContent,
ToolResponse,
ToolRequest,
ToolResultSchema,
ToolConfirmationRequest,
ThinkingContent,
RedactedThinkingContent,
FrontendToolRequest,
ResourceContents,
ContextLengthExceeded,
SummarizationRequested,
Role,
ProviderMetadata,
ExtensionEntry,
ExtensionConfig,
@@ -48,6 +91,15 @@ use utoipa::OpenApi;
PermissionLevel,
PrincipalType,
ModelInfo,
SessionInfo,
SessionMetadata,
super::routes::schedule::CreateScheduleRequest,
super::routes::schedule::UpdateScheduleRequest,
goose::scheduler::ScheduledJob,
super::routes::schedule::RunNowResponse,
super::routes::schedule::ListSchedulesResponse,
super::routes::schedule::SessionsQuery,
super::routes::schedule::SessionDisplayInfo,
))
)]
pub struct ApiDoc;
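For context, `ApiDoc` is only a generator: the `paths(...)` and `components(schemas(...))` lists above determine what ends up in the OpenAPI document. A minimal sketch of dumping the generated spec, assuming utoipa 4.x where `to_pretty_json` is available on the generated document:

```rust
// `OpenApi` is already imported in this module; `openapi()` comes from the
// `OpenApi` derive on `ApiDoc` above.
fn dump_openapi_spec() -> Result<String, serde_json::Error> {
    ApiDoc::openapi().to_pretty_json()
}
```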

View File

@@ -6,15 +6,16 @@ use axum::{
routing::{delete, get, post},
Json, Router,
};
use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs};
use etcetera::{choose_app_strategy, AppStrategy};
use goose::config::Config;
use goose::config::APP_STRATEGY;
use goose::config::{extensions::name_to_key, PermissionManager};
use goose::config::{ExtensionConfigManager, ExtensionEntry};
use goose::model::ModelConfig;
use goose::providers::base::ProviderMetadata;
use goose::providers::providers as get_providers;
use goose::{agents::ExtensionConfig, config::permission::PermissionLevel};
use http::{HeaderMap, StatusCode};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_yaml;
@@ -51,14 +52,12 @@ pub struct ConfigResponse {
pub config: HashMap<String, Value>,
}
// Define a new structure to encapsulate the provider details along with configuration status
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ProviderDetails {
/// Unique identifier and name of the provider
pub name: String,
/// Metadata about the provider
pub metadata: ProviderMetadata,
/// Indicates whether the provider is fully configured
pub is_configured: bool,
}
@@ -69,7 +68,6 @@ pub struct ProvidersResponse {
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ToolPermission {
/// Unique identifier and name of the tool, format <extension_name>__<tool_name>
pub tool_name: String,
pub permission: PermissionLevel,
}
@@ -93,7 +91,6 @@ pub async fn upsert_config(
headers: HeaderMap,
Json(query): Json<UpsertConfigQuery>,
) -> Result<Json<Value>, StatusCode> {
// Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let config = Config::global();
@@ -120,12 +117,10 @@ pub async fn remove_config(
headers: HeaderMap,
Json(query): Json<ConfigKeyQuery>,
) -> Result<Json<String>, StatusCode> {
// Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let config = Config::global();
// Check if the secret flag is true and call the appropriate method
let result = if query.is_secret {
config.delete_secret(&query.key)
} else {
@@ -141,7 +136,7 @@ pub async fn remove_config(
#[utoipa::path(
post,
path = "/config/read",
request_body = ConfigKeyQuery, // Switch back to request_body
request_body = ConfigKeyQuery,
responses(
(status = 200, description = "Configuration value retrieved successfully", body = Value),
(status = 404, description = "Configuration key not found")
@@ -154,16 +149,20 @@ pub async fn read_config(
) -> Result<Json<Value>, StatusCode> {
verify_secret_key(&headers, &state)?;
if query.key == "model-limits" {
let limits = ModelConfig::get_all_model_limits();
return Ok(Json(
serde_json::to_value(limits).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
));
}
let config = Config::global();
match config.get(&query.key, query.is_secret) {
// Always get the actual value
Ok(value) => {
if query.is_secret {
// If it's marked as secret, return a boolean indicating presence
Ok(Json(Value::Bool(true)))
} else {
// Return the actual value if not secret
Ok(Json(value))
}
}
@@ -188,7 +187,6 @@ pub async fn get_extensions(
match ExtensionConfigManager::get_all() {
Ok(extensions) => Ok(Json(ExtensionResponse { extensions })),
Err(err) => {
// Return UNPROCESSABLE_ENTITY only for DeserializeError, INTERNAL_SERVER_ERROR for everything else
if err
.downcast_ref::<goose::config::base::ConfigError>()
.is_some_and(|e| matches!(e, goose::config::base::ConfigError::DeserializeError(_)))
@@ -219,7 +217,6 @@ pub async fn add_extension(
) -> Result<Json<String>, StatusCode> {
verify_secret_key(&headers, &state)?;
// Get existing extensions to check if this is an update
let extensions =
ExtensionConfigManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let key = name_to_key(&extension_query.name);
@@ -275,12 +272,10 @@ pub async fn read_all_config(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<ConfigResponse>, StatusCode> {
// Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let config = Config::global();
// Load values from config file
let values = config
.load_values()
.map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
@@ -288,7 +283,6 @@ pub async fn read_all_config(
Ok(Json(ConfigResponse { config: values }))
}
// Modified providers function using the new response type
#[utoipa::path(
get,
path = "/config/providers",
@@ -302,14 +296,11 @@ pub async fn providers(
) -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
verify_secret_key(&headers, &state)?;
// Fetch the list of providers, which are likely stored in the AppState or can be retrieved via a function call
let providers_metadata = get_providers();
// Construct the response by checking configuration status for each provider
let providers_response: Vec<ProviderDetails> = providers_metadata
.into_iter()
.map(|metadata| {
// Check if the provider is configured (this will depend on how you track configuration status)
let is_configured = check_provider_configured(&metadata);
ProviderDetails {
@@ -339,21 +330,16 @@ pub async fn init_config(
let config = Config::global();
// 200 if config already exists
if config.exists() {
return Ok(Json("Config already exists".to_string()));
}
// Find the workspace root (where the top-level Cargo.toml with [workspace] is)
let workspace_root = match std::env::current_exe() {
Ok(mut exe_path) => {
// Start from the executable's directory and traverse up
while let Some(parent) = exe_path.parent() {
let cargo_toml = parent.join("Cargo.toml");
if cargo_toml.exists() {
// Read the Cargo.toml file
if let Ok(content) = std::fs::read_to_string(&cargo_toml) {
// Check if it contains [workspace]
if content.contains("[workspace]") {
exe_path = parent.to_path_buf();
break;
@@ -367,7 +353,6 @@ pub async fn init_config(
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
// Check if init-config.yaml exists at workspace root
let init_config_path = workspace_root.join("init-config.yaml");
if !init_config_path.exists() {
return Ok(Json(
@@ -375,7 +360,6 @@ pub async fn init_config(
));
}
// Read init-config.yaml and validate
let init_content = match std::fs::read_to_string(&init_config_path) {
Ok(content) => content,
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
@@ -385,7 +369,6 @@ pub async fn init_config(
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
// Save init-config.yaml to ~/.config/goose/config.yaml
match config.save_values(init_values) {
Ok(_) => Ok(Json("Config initialized successfully".to_string())),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
@@ -409,7 +392,7 @@ pub async fn upsert_permissions(
verify_secret_key(&headers, &state)?;
let mut permission_manager = PermissionManager::default();
// Iterate over each tool permission and update permissions
for tool_permission in &query.tool_permissions {
permission_manager.update_user_permission(
&tool_permission.tool_name,
@@ -420,12 +403,6 @@ pub async fn upsert_permissions(
Ok(Json("Permissions updated successfully".to_string()))
}
pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
top_level_domain: "Block".to_string(),
author: "Block".to_string(),
app_name: "goose".to_string(),
});
#[utoipa::path(
post,
path = "/config/backup",
@@ -451,11 +428,9 @@ pub async fn backup_config(
.file_name()
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
// Append ".bak" to the file name
let mut backup_name = file_name.to_os_string();
backup_name.push(".bak");
// Construct the new path with the same parent directory
let backup = config_path.with_file_name(backup_name);
match std::fs::rename(&config_path, &backup) {
Ok(_) => Ok(Json(format!("Moved {:?} to {:?}", config_path, backup))),
@@ -474,10 +449,55 @@ pub fn routes(state: Arc<AppState>) -> Router {
.route("/config/read", post(read_config))
.route("/config/extensions", get(get_extensions))
.route("/config/extensions", post(add_extension))
.route("/config/extensions/:name", delete(remove_extension))
.route("/config/extensions/{name}", delete(remove_extension))
.route("/config/providers", get(providers))
.route("/config/init", post(init_config))
.route("/config/backup", post(backup_config))
.route("/config/permissions", post(upsert_permissions))
.with_state(state)
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_read_model_limits() {
let test_state = AppState::new(
Arc::new(goose::agents::Agent::default()),
"test".to_string(),
)
.await;
let sched_storage_path = choose_app_strategy(APP_STRATEGY.clone())
.unwrap()
.data_dir()
.join("schedules.json");
let sched = goose::scheduler::Scheduler::new(sched_storage_path)
.await
.unwrap();
test_state.set_scheduler(sched).await;
let mut headers = HeaderMap::new();
headers.insert("X-Secret-Key", "test".parse().unwrap());
let result = read_config(
State(test_state),
headers,
Json(ConfigKeyQuery {
key: "model-limits".to_string(),
is_secret: false,
}),
)
.await;
assert!(result.is_ok());
let response = result.unwrap();
let limits: Vec<goose::model::ModelLimitConfig> =
serde_json::from_value(response.0).unwrap();
assert!(!limits.is_empty());
let gpt4_limit = limits.iter().find(|l| l.pattern == "gpt-4o");
assert!(gpt4_limit.is_some());
assert_eq!(gpt4_limit.unwrap().context_limit, 128_000);
}
}

View File

@@ -1,229 +0,0 @@
use super::utils::verify_secret_key;
use crate::state::AppState;
use axum::{
extract::{Query, State},
routing::{delete, get, post},
Json, Router,
};
use goose::config::Config;
use http::{HeaderMap, StatusCode};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, sync::Arc};
#[derive(Serialize)]
struct ConfigResponse {
error: bool,
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConfigRequest {
key: String,
value: String,
is_secret: bool,
}
async fn store_config(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<ConfigRequest>,
) -> Result<Json<ConfigResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let config = Config::global();
let result = if request.is_secret {
config.set_secret(&request.key, Value::String(request.value))
} else {
config.set_param(&request.key, Value::String(request.value))
};
match result {
Ok(_) => Ok(Json(ConfigResponse { error: false })),
Err(_) => Ok(Json(ConfigResponse { error: true })),
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderConfigRequest {
pub providers: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ConfigStatus {
pub is_set: bool,
pub location: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderResponse {
pub supported: bool,
pub name: Option<String>,
pub description: Option<String>,
pub models: Option<Vec<String>>,
pub config_status: HashMap<String, ConfigStatus>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ProviderConfig {
name: String,
description: String,
models: Vec<String>,
required_keys: Vec<String>,
}
static PROVIDER_ENV_REQUIREMENTS: Lazy<HashMap<String, ProviderConfig>> = Lazy::new(|| {
let contents = include_str!("providers_and_keys.json");
serde_json::from_str(contents).expect("Failed to parse providers_and_keys.json")
});
fn check_key_status(config: &Config, key: &str) -> (bool, Option<String>) {
if let Ok(_value) = std::env::var(key) {
(true, Some("env".to_string()))
} else if config.get_param::<String>(key).is_ok() {
(true, Some("yaml".to_string()))
} else if config.get_secret::<String>(key).is_ok() {
(true, Some("keyring".to_string()))
} else {
(false, None)
}
}
async fn check_provider_configs(
Json(request): Json<ProviderConfigRequest>,
) -> Result<Json<HashMap<String, ProviderResponse>>, StatusCode> {
let mut response = HashMap::new();
let config = Config::global();
for provider_name in request.providers {
if let Some(provider_config) = PROVIDER_ENV_REQUIREMENTS.get(&provider_name) {
let mut config_status = HashMap::new();
for key in &provider_config.required_keys {
let (key_set, key_location) = check_key_status(config, key);
config_status.insert(
key.to_string(),
ConfigStatus {
is_set: key_set,
location: key_location,
},
);
}
response.insert(
provider_name,
ProviderResponse {
supported: true,
name: Some(provider_config.name.clone()),
description: Some(provider_config.description.clone()),
models: Some(provider_config.models.clone()),
config_status,
},
);
} else {
response.insert(
provider_name,
ProviderResponse {
supported: false,
name: None,
description: None,
models: None,
config_status: HashMap::new(),
},
);
}
}
Ok(Json(response))
}
#[derive(Deserialize)]
pub struct GetConfigQuery {
key: String,
}
#[derive(Serialize)]
pub struct GetConfigResponse {
value: Option<String>,
}
pub async fn get_config(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Query(query): Query<GetConfigQuery>,
) -> Result<Json<GetConfigResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
// Fetch the configuration value. Right now we don't allow get a secret.
let config = Config::global();
let value = if let Ok(config_value) = config.get_param::<String>(&query.key) {
Some(config_value)
} else {
std::env::var(&query.key).ok()
};
// Return the value
Ok(Json(GetConfigResponse { value }))
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct DeleteConfigRequest {
key: String,
is_secret: bool,
}
async fn delete_config(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<DeleteConfigRequest>,
) -> Result<StatusCode, StatusCode> {
verify_secret_key(&headers, &state)?;
// Attempt to delete the key
let config = Config::global();
let result = if request.is_secret {
config.delete_secret(&request.key)
} else {
config.delete(&request.key)
};
match result {
Ok(_) => Ok(StatusCode::NO_CONTENT),
Err(_) => Err(StatusCode::NOT_FOUND),
}
}
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/configs/providers", post(check_provider_configs))
.route("/configs/get", get(get_config))
.route("/configs/store", post(store_config))
.route("/configs/delete", delete(delete_config))
.with_state(state)
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_unsupported_provider() {
// Setup
let request = ProviderConfigRequest {
providers: vec!["unsupported_provider".to_string()],
};
// Execute
let result = check_provider_configs(Json(request)).await;
// Assert
assert!(result.is_ok());
let Json(response) = result.unwrap();
let provider_status = response
.get("unsupported_provider")
.expect("Provider should exist");
assert!(!provider_status.supported);
assert!(provider_status.config_status.is_empty());
}
}

View File

@@ -9,23 +9,43 @@ use axum::{
use goose::message::Message;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use utoipa::ToSchema;
// Direct message serialization for context mgmt request
#[derive(Debug, Deserialize)]
/// Request payload for context management operations
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ContextManageRequest {
messages: Vec<Message>,
manage_action: String,
/// Collection of messages to be managed
pub messages: Vec<Message>,
/// Operation to perform: "truncation" or "summarize"
pub manage_action: String,
}
// Direct message serialization for context mgmt request
#[derive(Debug, Serialize)]
/// Response from context management operations
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ContextManageResponse {
messages: Vec<Message>,
token_counts: Vec<usize>,
/// Processed messages after the operation
pub messages: Vec<Message>,
/// Token counts for each processed message
pub token_counts: Vec<usize>,
}
#[utoipa::path(
post,
path = "/context/manage",
request_body = ContextManageRequest,
responses(
(status = 200, description = "Context managed successfully", body = ContextManageResponse),
(status = 401, description = "Unauthorized - Invalid or missing API key"),
(status = 412, description = "Precondition failed - Agent not available"),
(status = 500, description = "Internal server error")
),
security(
("api_key" = [])
),
tag = "Context Management"
)]
async fn manage_context(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
@@ -40,7 +60,8 @@ async fn manage_context(
let mut processed_messages: Vec<Message> = vec![];
let mut token_counts: Vec<usize> = vec![];
if request.manage_action == "trunction" {
if request.manage_action == "truncation" {
(processed_messages, token_counts) = agent
.truncate_context(&request.messages)
.await
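Based only on the request and response shapes above, a client call to this endpoint could look like the hedged sketch below; the address, port, and secret key value are placeholders, and the `X-Secret-Key` header mirrors the one used by the tests earlier in this diff.

```rust
use serde_json::json;

async fn manage_context_request(
    messages: serde_json::Value,
) -> Result<serde_json::Value, reqwest::Error> {
    // camelCase field names come from #[serde(rename_all = "camelCase")] above.
    let body = json!({
        "messages": messages,
        "manageAction": "summarize", // or "truncation"
    });
    reqwest::Client::new()
        .post("http://127.0.0.1:3000/context/manage") // placeholder address
        .header("X-Secret-Key", "test")               // placeholder secret
        .json(&body)
        .send()
        .await?
        .json()
        .await
}
```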

Some files were not shown because too many files have changed in this diff