This commit is contained in:
2025-06-05 10:29:02 +02:00
364 changed files with 14643 additions and 6840 deletions

View File

@@ -20,6 +20,10 @@ on:
type: string
required: false
default: '["x86_64","aarch64"]'
ref:
type: string
required: false
default: 'refs/heads/main'
name: "Reusable workflow to build CLI"
@@ -41,6 +45,9 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4
with:
ref: ${{ inputs.ref }}
fetch-depth: 0
- name: Update version in Cargo.toml
if: ${{ inputs.version != '' }}
@@ -48,14 +55,8 @@ jobs:
sed -i.bak 's/^version = ".*"/version = "'${{ inputs.version }}'"/' Cargo.toml
rm -f Cargo.toml.bak
- name: Setup Rust
uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
with:
toolchain: stable
target: ${{ matrix.architecture }}-${{ matrix.target-suffix }}
- name: Install cross
run: cargo install cross --git https://github.com/cross-rs/cross
run: source ./bin/activate-hermit && cargo install cross --git https://github.com/cross-rs/cross
- name: Build CLI
env:
@@ -64,6 +65,7 @@ jobs:
RUST_BACKTRACE: 1
CROSS_VERBOSE: 1
run: |
source ./bin/activate-hermit
export TARGET="${{ matrix.architecture }}-${{ matrix.target-suffix }}"
rustup target add "${TARGET}"
echo "Building for target: ${TARGET}"
@@ -72,7 +74,7 @@ jobs:
echo "Cross version:"
cross --version
# 'cross' is used to cross-compile for different architectures (see Cross.toml)
echo "Building with explicit PROTOC path..."
cross build --release --target ${TARGET} -p goose-cli -vv
# tar the goose binary as goose-<TARGET>.tar.bz2

View File

@@ -21,6 +21,10 @@ on:
required: false
default: true
type: boolean
ref:
type: string
required: false
default: 'refs/heads/main'
secrets:
CERTIFICATE_OSX_APPLICATION:
description: 'Certificate for macOS application signing'
@@ -77,6 +81,9 @@ jobs:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
ref: ${{ inputs.ref }}
fetch-depth: 0
# Update versions before build
- name: Update versions
@@ -87,18 +94,14 @@ jobs:
rm -f Cargo.toml.bak
# Update version in package.json
source ./bin/activate-hermit
cd ui/desktop
npm version ${{ inputs.version }} --no-git-tag-version --allow-same-version
- name: Setup Rust
uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
with:
toolchain: stable
targets: x86_64-apple-darwin
# Pre-build cleanup to ensure enough disk space
- name: Pre-build cleanup
run: |
source ./bin/activate-hermit
echo "Performing pre-build cleanup..."
# Clean npm cache
npm cache clean --force || true
@@ -137,7 +140,7 @@ jobs:
# Build specifically for Intel architecture
- name: Build goosed for Intel
run: cargo build --release -p goose-server --target x86_64-apple-darwin
run: source ./bin/activate-hermit && cargo build --release -p goose-server --target x86_64-apple-darwin
# Post-build cleanup to free space
- name: Post-build cleanup
@@ -164,13 +167,8 @@ jobs:
CERTIFICATE_OSX_APPLICATION: ${{ secrets.CERTIFICATE_OSX_APPLICATION }}
CERTIFICATE_PASSWORD: ${{ secrets.CERTIFICATE_PASSWORD }}
- name: Set up Node.js
uses: actions/setup-node@7c12f8017d5436eb855f1ed4399f037a36fbd9e8 # pin@v2
with:
node-version: 'lts/*'
- name: Install dependencies
run: npm ci
run: source ../../bin/activate-hermit && npm ci
working-directory: ui/desktop
# Configure Electron builder for Intel architecture
@@ -187,6 +185,7 @@ jobs:
- name: Make Unsigned App
if: ${{ !inputs.signing }}
run: |
source ../../bin/activate-hermit
attempt=0
max_attempts=2
until [ $attempt -ge $max_attempts ]; do
@@ -204,6 +203,7 @@ jobs:
- name: Make Signed App
if: ${{ inputs.signing }}
run: |
source ../../bin/activate-hermit
attempt=0
max_attempts=2
until [ $attempt -ge $max_attempts ]; do

View File

@@ -17,32 +17,31 @@ on:
required: false
WINDOWS_CERTIFICATE_PASSWORD:
required: false
ref:
type: string
required: false
default: 'refs/heads/main'
jobs:
build-desktop-windows:
name: Build Desktop (Windows)
runs-on: windows-latest
runs-on: ubuntu-latest # Use Ubuntu for cross-compilation
steps:
# 1) Check out source
- name: Checkout repository
uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744
with:
ref: ${{ inputs.ref }}
fetch-depth: 0
# 2) Set up Rust
- name: Set up Rust
uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b
# If you need a specific version, you could do:
# or uses: actions/setup-rust@v1
# with:
# rust-version: 1.73.0
# 3) Set up Node.js
# 2) Set up Node.js
- name: Set up Node.js
uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 # pin@v3
with:
node-version: 16
node-version: 18
# 4) Cache dependencies (optional, can add more paths if needed)
# 3) Cache dependencies
- name: Cache node_modules
uses: actions/cache@2f8e54208210a422b2efd51efaa6bd6d7ca8920f # pin@v3
with:
@@ -53,103 +52,92 @@ jobs:
restore-keys: |
${{ runner.os }}-build-desktop-windows-
# 5) Install top-level dependencies if a package.json is in root
- name: Install top-level deps
# 4) Build Rust for Windows using Docker (cross-compilation)
- name: Build Windows executable using Docker
run: |
if (Test-Path package.json) {
npm install
}
echo "Building Windows executable using Docker cross-compilation..."
docker volume create goose-windows-cache || true
docker run --rm \
-v "$(pwd)":/usr/src/myapp \
-v goose-windows-cache:/usr/local/cargo/registry \
-w /usr/src/myapp \
rust:latest \
sh -c "rustup target add x86_64-pc-windows-gnu && \
apt-get update && \
apt-get install -y mingw-w64 protobuf-compiler cmake && \
export CC_x86_64_pc_windows_gnu=x86_64-w64-mingw32-gcc && \
export CXX_x86_64_pc_windows_gnu=x86_64-w64-mingw32-g++ && \
export AR_x86_64_pc_windows_gnu=x86_64-w64-mingw32-ar && \
export CARGO_TARGET_X86_64_PC_WINDOWS_GNU_LINKER=x86_64-w64-mingw32-gcc && \
export PKG_CONFIG_ALLOW_CROSS=1 && \
export PROTOC=/usr/bin/protoc && \
export PATH=/usr/bin:\$PATH && \
protoc --version && \
cargo build --release --target x86_64-pc-windows-gnu && \
GCC_DIR=\$(ls -d /usr/lib/gcc/x86_64-w64-mingw32/*/ | head -n 1) && \
cp \$GCC_DIR/libstdc++-6.dll /usr/src/myapp/target/x86_64-pc-windows-gnu/release/ && \
cp \$GCC_DIR/libgcc_s_seh-1.dll /usr/src/myapp/target/x86_64-pc-windows-gnu/release/ && \
cp /usr/x86_64-w64-mingw32/lib/libwinpthread-1.dll /usr/src/myapp/target/x86_64-pc-windows-gnu/release/"
# 6) Build rust for x86_64-pc-windows-gnu
- name: Install MinGW dependencies
run: |
choco install mingw --version=8.1.0
# Debug - check installation paths
Write-Host "Checking MinGW installation..."
Get-ChildItem -Path "C:\ProgramData\chocolatey\lib\mingw" -Recurse -Filter "*.dll" | ForEach-Object {
Write-Host $_.FullName
}
Get-ChildItem -Path "C:\tools" -Recurse -Filter "*.dll" | ForEach-Object {
Write-Host $_.FullName
}
- name: Cargo build for Windows
run: |
cargo build --release --target x86_64-pc-windows-gnu
# 7) Check that the compiled goosed.exe exists and copy exe/dll to ui/desktop/src/bin
# 5) Prepare Windows binary and DLLs
- name: Prepare Windows binary and DLLs
run: |
if (!(Test-Path .\target\x86_64-pc-windows-gnu\release\goosed.exe)) {
Write-Error "Windows binary not found."; exit 1;
}
Write-Host "Copying Windows binary and DLLs to ui/desktop/src/bin..."
if (!(Test-Path ui\desktop\src\bin)) {
New-Item -ItemType Directory -Path ui\desktop\src\bin | Out-Null
}
Copy-Item .\target\x86_64-pc-windows-gnu\release\goosed.exe ui\desktop\src\bin\
if [ ! -f "./target/x86_64-pc-windows-gnu/release/goosed.exe" ]; then
echo "Windows binary not found."
exit 1
fi
# Copy MinGW DLLs - try both possible locations
$mingwPaths = @(
"C:\ProgramData\chocolatey\lib\mingw\tools\install\mingw64\bin",
"C:\tools\mingw64\bin"
)
echo "Cleaning destination directory..."
rm -rf ./ui/desktop/src/bin
mkdir -p ./ui/desktop/src/bin
foreach ($path in $mingwPaths) {
if (Test-Path "$path\libstdc++-6.dll") {
Write-Host "Found MinGW DLLs in $path"
Copy-Item "$path\libstdc++-6.dll" ui\desktop\src\bin\
Copy-Item "$path\libgcc_s_seh-1.dll" ui\desktop\src\bin\
Copy-Item "$path\libwinpthread-1.dll" ui\desktop\src\bin\
break
}
}
echo "Copying Windows binary and DLLs..."
cp -f ./target/x86_64-pc-windows-gnu/release/goosed.exe ./ui/desktop/src/bin/
cp -f ./target/x86_64-pc-windows-gnu/release/*.dll ./ui/desktop/src/bin/
# Copy any other DLLs from the release directory
ls .\target\x86_64-pc-windows-gnu\release\*.dll | ForEach-Object {
Copy-Item $_ ui\desktop\src\bin\
}
# Copy Windows platform files (tools, scripts, etc.)
if [ -d "./ui/desktop/src/platform/windows/bin" ]; then
echo "Copying Windows platform files..."
for file in ./ui/desktop/src/platform/windows/bin/*.{exe,dll,cmd}; do
if [ -f "$file" ] && [ "$(basename "$file")" != "goosed.exe" ]; then
cp -f "$file" ./ui/desktop/src/bin/
fi
done
# 8) Install & build UI desktop
if [ -d "./ui/desktop/src/platform/windows/bin/goose-npm" ]; then
echo "Setting up npm environment..."
rsync -a --delete ./ui/desktop/src/platform/windows/bin/goose-npm/ ./ui/desktop/src/bin/goose-npm/
fi
echo "Windows-specific files copied successfully"
fi
# 6) Install & build UI desktop
- name: Build desktop UI with npm
run: |
cd ui\desktop
cd ui/desktop
npm install
npm run bundle:windows
# 9) Copy exe/dll to final out/Goose-win32-x64/resources/bin
# 7) Copy exe/dll to final out/Goose-win32-x64/resources/bin
- name: Copy exe/dll to out folder
run: |
cd ui\desktop
if (!(Test-Path .\out\Goose-win32-x64\resources\bin)) {
New-Item -ItemType Directory -Path .\out\Goose-win32-x64\resources\bin | Out-Null
}
Copy-Item .\src\bin\goosed.exe .\out\Goose-win32-x64\resources\bin\
ls .\src\bin\*.dll | ForEach-Object {
Copy-Item $_ .\out\Goose-win32-x64\resources\bin\
}
cd ui/desktop
mkdir -p ./out/Goose-win32-x64/resources/bin
rsync -av src/bin/ out/Goose-win32-x64/resources/bin/
# 10) Code signing (if enabled)
# 8) Code signing (if enabled)
- name: Sign Windows executable
# Skip this step by default - enable when we have a certificate
if: inputs.signing && inputs.signing == true
env:
WINDOWS_CERTIFICATE: ${{ secrets.WINDOWS_CERTIFICATE }}
WINDOWS_CERTIFICATE_PASSWORD: ${{ secrets.WINDOWS_CERTIFICATE_PASSWORD }}
run: |
# Create a temporary certificate file
$certBytes = [Convert]::FromBase64String($env:WINDOWS_CERTIFICATE)
$certPath = Join-Path -Path $env:RUNNER_TEMP -ChildPath "certificate.pfx"
[IO.File]::WriteAllBytes($certPath, $certBytes)
# Note: This would need to be adapted for Linux-based signing
# or moved to a Windows runner for the signing step only
echo "Code signing would be implemented here"
echo "Currently skipped as we're running on Ubuntu"
# Sign the main executable
$signtool = "C:\Program Files (x86)\Windows Kits\10\bin\10.0.17763.0\x64\signtool.exe"
& $signtool sign /f $certPath /p $env:WINDOWS_CERTIFICATE_PASSWORD /tr http://timestamp.digicert.com /td sha256 /fd sha256 "ui\desktop\out\Goose-win32-x64\Goose.exe"
# Clean up the certificate
Remove-Item -Path $certPath
# 11) Upload the final Windows build
# 9) Upload the final Windows build
- name: Upload Windows build artifacts
uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # pin@v4
with:

View File

@@ -135,6 +135,7 @@ jobs:
sed -i.bak "s/^version = \".*\"/version = \"${VERSION}\"/" Cargo.toml
rm -f Cargo.toml.bak
source ./bin/activate-hermit
# Update version in package.json
cd ui/desktop
npm version "${VERSION}" --no-git-tag-version --allow-same-version
@@ -142,6 +143,7 @@ jobs:
# Pre-build cleanup to ensure enough disk space
- name: Pre-build cleanup
run: |
source ./bin/activate-hermit
echo "Performing pre-build cleanup..."
# Clean npm cache
npm cache clean --force || true
@@ -154,16 +156,6 @@ jobs:
# Check disk space after cleanup
df -h
- name: Install protobuf
run: |
brew install protobuf
echo "PROTOC=$(which protoc)" >> $GITHUB_ENV
- name: Setup Rust
uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
with:
toolchain: stable
- name: Cache Cargo registry
uses: actions/cache@2f8e54208210a422b2efd51efaa6bd6d7ca8920f # pin@v3
with:
@@ -190,7 +182,7 @@ jobs:
# Build the project
- name: Build goosed
run: cargo build --release -p goose-server
run: source ./bin/activate-hermit && cargo build --release -p goose-server
# Post-build cleanup to free space
- name: Post-build cleanup
@@ -216,13 +208,8 @@ jobs:
CERTIFICATE_OSX_APPLICATION: ${{ secrets.CERTIFICATE_OSX_APPLICATION }}
CERTIFICATE_PASSWORD: ${{ secrets.CERTIFICATE_PASSWORD }}
- name: Set up Node.js
uses: actions/setup-node@7c12f8017d5436eb855f1ed4399f037a36fbd9e8 # pin@v2
with:
node-version: 'lts/*'
- name: Install dependencies
run: npm ci
run: source ../../bin/activate-hermit && npm ci
working-directory: ui/desktop
# Check disk space before bundling
@@ -232,6 +219,7 @@ jobs:
- name: Make Unsigned App
if: ${{ !inputs.signing }}
run: |
source ../../bin/activate-hermit
attempt=0
max_attempts=2
until [ $attempt -ge $max_attempts ]; do
@@ -253,6 +241,7 @@ jobs:
APPLE_ID_PASSWORD: ${{ secrets.APPLE_ID_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
run: |
attempt=0
max_attempts=2
until [ $attempt -ge $max_attempts ]; do

View File

@@ -5,6 +5,9 @@ on:
pull_request:
branches:
- main
merge_group:
branches:
- main
workflow_dispatch:
name: CI
@@ -17,13 +20,8 @@ jobs:
- name: Checkout Code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4
- name: Setup Rust
uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
with:
toolchain: stable
- name: Run cargo fmt
run: cargo fmt --check
run: source ./bin/activate-hermit && cargo fmt --check
rust-build-and-test:
name: Build and Test Rust Project
@@ -57,12 +55,7 @@ jobs:
- name: Install Dependencies
run: |
sudo apt update -y
sudo apt install -y libdbus-1-dev gnome-keyring libxcb1-dev protobuf-compiler
- name: Setup Rust
uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
with:
toolchain: stable
sudo apt install -y libdbus-1-dev gnome-keyring libxcb1-dev
- name: Cache Cargo Registry
uses: actions/cache@2f8e54208210a422b2efd51efaa6bd6d7ca8920f # pin@v3
@@ -91,7 +84,7 @@ jobs:
- name: Build and Test
run: |
gnome-keyring-daemon --components=secrets --daemonize --unlock <<< 'foobar'
cargo test
source ../bin/activate-hermit && cargo test
working-directory: crates
# Add disk space cleanup before linting
@@ -120,7 +113,7 @@ jobs:
run: df -h
- name: Lint
run: cargo clippy -- -D warnings
run: source ./bin/activate-hermit && cargo clippy -- -D warnings
desktop-lint:
name: Lint Electron Desktop App
@@ -129,22 +122,17 @@ jobs:
- name: Checkout Code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4
- name: Set up Node.js
uses: actions/setup-node@7c12f8017d5436eb855f1ed4399f037a36fbd9e8 # pin@v2
with:
node-version: "lts/*"
- name: Install Dependencies
run: npm ci
run: source ../../bin/activate-hermit && npm ci
working-directory: ui/desktop
- name: Run Lint
run: npm run lint:check
run: source ../../bin/activate-hermit && npm run lint:check
working-directory: ui/desktop
# Faster Desktop App build for PRs only
bundle-desktop-unsigned:
uses: ./.github/workflows/bundle-desktop.yml
if: github.event_name == 'pull_request'
if: github.event_name == 'pull_request' || github.event_name == 'merge_group'
with:
signing: false

View File

@@ -27,6 +27,7 @@ jobs:
outputs:
continue: ${{ steps.command.outputs.continue || github.event_name == 'workflow_dispatch' }}
pr_number: ${{ steps.command.outputs.issue_number || github.event.inputs.pr_number }}
head_sha: ${{ steps.set_head_sha.outputs.head_sha || github.sha }}
steps:
- if: ${{ github.event_name == 'issue_comment' }}
uses: github/command@v1.3.0
@@ -37,10 +38,26 @@ jobs:
reaction: "eyes"
allowed_contexts: pull_request
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4
- name: Get PR head SHA with gh
id: set_head_sha
run: |
echo "Get PR head SHA with gh"
HEAD_SHA=$(gh pr view "$ISSUE_NUMBER" --json headRefOid -q .headRefOid)
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
echo "head_sha=$HEAD_SHA"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
ISSUE_NUMBER: ${{ steps.command.outputs.issue_number }}
build-cli:
needs: [trigger-on-command]
if: ${{ needs.trigger-on-command.outputs.continue == 'true' }}
uses: ./.github/workflows/build-cli.yml
with:
ref: ${{ needs.trigger-on-command.outputs.head_sha }}
pr-comment-cli:
name: PR Comment with CLI builds

View File

@@ -30,6 +30,7 @@ jobs:
continue: ${{ steps.command.outputs.continue || github.event_name == 'workflow_dispatch' }}
# Cannot use github.event.pull_request.number since the trigger is 'issue_comment'
pr_number: ${{ steps.command.outputs.issue_number || github.event.inputs.pr_number }}
head_sha: ${{ steps.set_head_sha.outputs.head_sha || github.sha }}
steps:
- if: ${{ github.event_name == 'issue_comment' }}
uses: github/command@319d5236cc34ed2cb72a47c058a363db0b628ebe # pin@v1.3.0
@@ -40,6 +41,20 @@ jobs:
reaction: "eyes"
allowed_contexts: pull_request
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4
- name: Get PR head SHA with gh
id: set_head_sha
run: |
echo "Get PR head SHA with gh"
HEAD_SHA=$(gh pr view "$ISSUE_NUMBER" --json headRefOid -q .headRefOid)
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
echo "head_sha=$HEAD_SHA"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
ISSUE_NUMBER: ${{ steps.command.outputs.issue_number }}
bundle-desktop-intel:
# Only run this if ".bundle-intel" command is detected.
needs: [trigger-on-command]
@@ -47,6 +62,7 @@ jobs:
uses: ./.github/workflows/bundle-desktop-intel.yml
with:
signing: true
ref: ${{ needs.trigger-on-command.outputs.head_sha }}
secrets:
CERTIFICATE_OSX_APPLICATION: ${{ secrets.CERTIFICATE_OSX_APPLICATION }}
CERTIFICATE_PASSWORD: ${{ secrets.CERTIFICATE_PASSWORD }}

View File

@@ -30,6 +30,7 @@ jobs:
continue: ${{ steps.command.outputs.continue || github.event_name == 'workflow_dispatch' }}
# Cannot use github.event.pull_request.number since the trigger is 'issue_comment'
pr_number: ${{ steps.command.outputs.issue_number || github.event.inputs.pr_number }}
head_sha: ${{ steps.set_head_sha.outputs.head_sha || github.sha }}
steps:
- if: ${{ github.event_name == 'issue_comment' }}
uses: github/command@319d5236cc34ed2cb72a47c058a363db0b628ebe # pin@v1.3.0
@@ -40,6 +41,20 @@ jobs:
reaction: "eyes"
allowed_contexts: pull_request
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4
- name: Get PR head SHA with gh
id: set_head_sha
run: |
echo "Get PR head SHA with gh"
HEAD_SHA=$(gh pr view "$ISSUE_NUMBER" --json headRefOid -q .headRefOid)
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
echo "head_sha=$HEAD_SHA"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
ISSUE_NUMBER: ${{ steps.command.outputs.issue_number }}
bundle-desktop-windows:
# Only run this if ".bundle-windows" command is detected.
needs: [trigger-on-command]
@@ -47,6 +62,7 @@ jobs:
uses: ./.github/workflows/bundle-desktop-windows.yml
with:
signing: false # false for now as we don't have a cert yet
ref: ${{ needs.trigger-on-command.outputs.head_sha }}
secrets:
WINDOWS_CERTIFICATE: ${{ secrets.WINDOWS_CERTIFICATE }}
WINDOWS_CERTIFICATE_PASSWORD: ${{ secrets.WINDOWS_CERTIFICATE_PASSWORD }}

7
.gitignore vendored
View File

@@ -23,9 +23,12 @@ target/
./ui/desktop/node_modules
./ui/desktop/out
# Generated Goose DLLs (built at build time, not checked in)
ui/desktop/src/bin/goose_ffi.dll
ui/desktop/src/bin/goose_llm.dll
# Hermit
/.hermit/
/bin/
.hermit/
debug_*.txt

View File

@@ -2,12 +2,20 @@
# Only auto-format desktop TS code if relevant files are modified
if git diff --cached --name-only | grep -q "^ui/desktop/"; then
if [ -d "ui/desktop" ]; then
. "$(dirname -- "$0")/_/husky.sh"
cd ui/desktop && npx lint-staged
else
echo "Warning: ui/desktop directory does not exist, skipping lint-staged"
fi
fi
# Only auto-format ui-v2 TS code if relevant files are modified
if git diff --cached --name-only | grep -q "^ui-v2/"; then
if [ -d "ui-v2" ]; then
. "$(dirname -- "$0")/_/husky.sh"
cd ui-v2 && npx lint-staged
else
echo "Warning: ui-v2 directory does not exist, skipping lint-staged"
fi
fi

72
Cargo.lock generated
View File

@@ -2083,8 +2083,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crunchy"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
source = "git+https://github.com/nmathewson/crunchy?branch=cross-compilation-fix#260ec5f08969480c342bb3fe47f88870ed5c6cce"
[[package]]
name = "crypto-common"
@@ -3452,12 +3451,13 @@ dependencies = [
"tokenizers",
"tokio",
"tokio-cron-scheduler",
"tokio-stream",
"tracing",
"tracing-subscriber",
"url",
"utoipa",
"uuid",
"webbrowser",
"webbrowser 0.8.15",
"winapi",
"wiremock",
]
@@ -3514,8 +3514,10 @@ version = "1.0.24"
dependencies = [
"anyhow",
"async-trait",
"axum",
"base64 0.22.1",
"bat",
"bytes",
"chrono",
"clap 4.5.31",
"cliclack",
@@ -3525,6 +3527,8 @@ dependencies = [
"goose",
"goose-bench",
"goose-mcp",
"http 1.2.0",
"indicatif",
"mcp-client",
"mcp-core",
"mcp-server",
@@ -3544,9 +3548,12 @@ dependencies = [
"tempfile",
"test-case",
"tokio",
"tokio-stream",
"tower-http",
"tracing",
"tracing-appender",
"tracing-subscriber",
"webbrowser 1.0.4",
"winapi",
]
@@ -3639,7 +3646,7 @@ dependencies = [
"url",
"urlencoding",
"utoipa",
"webbrowser",
"webbrowser 0.8.15",
"xcap",
]
@@ -3941,6 +3948,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c"
[[package]]
name = "httparse"
version = "1.10.1"
@@ -5902,6 +5915,31 @@ dependencies = [
"malloc_buf",
]
[[package]]
name = "objc2"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88c6597e14493ab2e44ce58f2fdecf095a51f12ca57bec060a11c57332520551"
dependencies = [
"objc2-encode",
]
[[package]]
name = "objc2-encode"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
[[package]]
name = "objc2-foundation"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c"
dependencies = [
"bitflags 2.9.0",
"objc2",
]
[[package]]
name = "object"
version = "0.36.7"
@@ -8727,12 +8765,21 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
"bitflags 2.9.0",
"bytes",
"futures-util",
"http 1.2.0",
"http-body 1.0.1",
"http-body-util",
"http-range-header",
"httpdate",
"mime",
"mime_guess",
"percent-encoding",
"pin-project-lite",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
@@ -9420,6 +9467,23 @@ dependencies = [
"web-sys",
]
[[package]]
name = "webbrowser"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5df295f8451142f1856b1bd86a606dfe9587d439bc036e319c827700dbd555e"
dependencies = [
"core-foundation 0.10.0",
"home",
"jni",
"log",
"ndk-context",
"objc2",
"objc2-foundation",
"url",
"web-sys",
]
[[package]]
name = "webpki-roots"
version = "0.26.8"

View File

@@ -9,3 +9,7 @@ authors = ["Block <ai-oss-tools@block.xyz>"]
license = "Apache-2.0"
repository = "https://github.com/block/goose"
description = "An AI agent"
# Patch for Windows cross-compilation issue with crunchy
[patch.crates-io]
crunchy = { git = "https://github.com/nmathewson/crunchy", branch = "cross-compilation-fix" }

View File

@@ -6,29 +6,59 @@ pre-build = [
"dpkg --add-architecture arm64",
"""\
apt-get update --fix-missing && apt-get install -y \
curl \
unzip \
pkg-config \
libssl-dev:arm64 \
libdbus-1-dev:arm64 \
libxcb1-dev:arm64
""",
"""\
curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v31.1/protoc-31.1-linux-x86_64.zip && \
unzip -o protoc-31.1-linux-x86_64.zip -d /usr/local && \
chmod +x /usr/local/bin/protoc && \
ln -sf /usr/local/bin/protoc /usr/bin/protoc && \
which protoc && \
protoc --version
"""
]
[target.x86_64-unknown-linux-gnu]
xargo = false
pre-build = [
# Install necessary dependencies for x86_64
# We don't need architecture-specific flags because x86_64 dependencies are installable on Ubuntu system
"""\
apt-get update && apt-get install -y \
curl \
unzip \
pkg-config \
libssl-dev \
libdbus-1-dev \
libxcb1-dev \
libxcb1-dev
""",
"""\
curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v31.1/protoc-31.1-linux-x86_64.zip && \
unzip -o protoc-31.1-linux-x86_64.zip -d /usr/local && \
chmod +x /usr/local/bin/protoc && \
ln -sf /usr/local/bin/protoc /usr/bin/protoc && \
which protoc && \
protoc --version
"""
]
[target.x86_64-pc-windows-gnu]
image = "dockcross/windows-static-x64:latest"
# Enable verbose output for Windows builds
build-std = true
env = { "RUST_LOG" = "debug", "RUST_BACKTRACE" = "1", "CROSS_VERBOSE" = "1" }
image = "ghcr.io/cross-rs/x86_64-pc-windows-gnu:latest"
env = { "RUST_LOG" = "debug", "RUST_BACKTRACE" = "1", "CROSS_VERBOSE" = "1", "PKG_CONFIG_ALLOW_CROSS" = "1" }
pre-build = [
"""\
apt-get update && apt-get install -y \
curl \
unzip
""",
"""\
curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v31.1/protoc-31.1-linux-x86_64.zip && \
unzip protoc-31.1-linux-x86_64.zip -d /usr/local && \
chmod +x /usr/local/bin/protoc && \
export PROTOC=/usr/local/bin/protoc && \
protoc --version
"""
]

View File

@@ -25,7 +25,15 @@ release-windows:
rust:latest \
sh -c "rustup target add x86_64-pc-windows-gnu && \
apt-get update && \
apt-get install -y mingw-w64 && \
apt-get install -y mingw-w64 protobuf-compiler cmake && \
export CC_x86_64_pc_windows_gnu=x86_64-w64-mingw32-gcc && \
export CXX_x86_64_pc_windows_gnu=x86_64-w64-mingw32-g++ && \
export AR_x86_64_pc_windows_gnu=x86_64-w64-mingw32-ar && \
export CARGO_TARGET_X86_64_PC_WINDOWS_GNU_LINKER=x86_64-w64-mingw32-gcc && \
export PKG_CONFIG_ALLOW_CROSS=1 && \
export PROTOC=/usr/bin/protoc && \
export PATH=/usr/bin:\$PATH && \
protoc --version && \
cargo build --release --target x86_64-pc-windows-gnu && \
GCC_DIR=\$(ls -d /usr/lib/gcc/x86_64-w64-mingw32/*/ | head -n 1) && \
cp \$GCC_DIR/libstdc++-6.dll /usr/src/myapp/target/x86_64-pc-windows-gnu/release/ && \

View File

@@ -23,6 +23,31 @@ Whether you're prototyping an idea, refining existing code, or managing intricat
Designed for maximum flexibility, goose works with any LLM, seamlessly integrates with MCP servers, and is available as both a desktop app as well as CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation.
## Multiple Model Configuration
goose supports using different models for different purposes to optimize performance and cost, which can work across model providers as well as models.
### Lead/Worker Model Pattern
Use a powerful model for initial planning and complex reasoning, then switch to a faster/cheaper model for execution, this happens automatically by goose:
```bash
# Required: Enable lead model mode
export GOOSE_LEAD_MODEL=modelY
# Optional: configure a provider for the lead model if not the default provider
export GOOSE_LEAD_PROVIDER=providerX # Defaults to main provider
```
### Planning Model Configuration
Use a specialized model for the `/plan` command in CLI mode, this is explicitly invoked when you want to plan (vs execute)
```bash
# Optional: Use different model for planning
export GOOSE_PLANNER_PROVIDER=openai
export GOOSE_PLANNER_MODEL=gpt-4
```
Both patterns help you balance model capabilities with cost and speed for optimal results, and switch between models and vendors as required.
# Quick Links
- [Quickstart](https://block.github.io/goose/docs/quickstart)

1
bin/.node-22.9.0.pkg Symbolic link
View File

@@ -0,0 +1 @@
hermit

1
bin/.protoc-31.1.pkg Symbolic link
View File

@@ -0,0 +1 @@
hermit

1
bin/.rustup-1.25.2.pkg Symbolic link
View File

@@ -0,0 +1 @@
hermit

7
bin/README.hermit.md Normal file
View File

@@ -0,0 +1,7 @@
# Hermit environment
This is a [Hermit](https://github.com/cashapp/hermit) bin directory.
The symlinks in this directory are managed by Hermit and will automatically
download and install Hermit itself as well as packages. These packages are
local to this environment.

21
bin/activate-hermit Executable file
View File

@@ -0,0 +1,21 @@
#!/bin/bash
# This file must be used with "source bin/activate-hermit" from bash or zsh.
# You cannot run it directly
#
# THIS FILE IS GENERATED; DO NOT MODIFY
if [ "${BASH_SOURCE-}" = "$0" ]; then
echo "You must source this script: \$ source $0" >&2
exit 33
fi
BIN_DIR="$(dirname "${BASH_SOURCE[0]:-${(%):-%x}}")"
if "${BIN_DIR}/hermit" noop > /dev/null; then
eval "$("${BIN_DIR}/hermit" activate "${BIN_DIR}/..")"
if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ]; then
hash -r 2>/dev/null
fi
echo "Hermit environment $("${HERMIT_ENV}"/bin/hermit env HERMIT_ENV) activated"
fi

24
bin/activate-hermit.fish Executable file
View File

@@ -0,0 +1,24 @@
#!/usr/bin/env fish
# This file must be sourced with "source bin/activate-hermit.fish" from Fish shell.
# You cannot run it directly.
#
# THIS FILE IS GENERATED; DO NOT MODIFY
if status is-interactive
set BIN_DIR (dirname (status --current-filename))
if "$BIN_DIR/hermit" noop > /dev/null
# Source the activation script generated by Hermit
"$BIN_DIR/hermit" activate "$BIN_DIR/.." | source
# Clear the command cache if applicable
functions -c > /dev/null 2>&1
# Display activation message
echo "Hermit environment $($HERMIT_ENV/bin/hermit env HERMIT_ENV) activated"
end
else
echo "You must source this script: source $argv[0]" >&2
exit 33
end

1
bin/cargo Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/cargo-clippy Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/cargo-fmt Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/cargo-miri Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/clippy-driver Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/corepack Symbolic link
View File

@@ -0,0 +1 @@
.node-22.9.0.pkg

43
bin/hermit Executable file
View File

@@ -0,0 +1,43 @@
#!/bin/bash
#
# THIS FILE IS GENERATED; DO NOT MODIFY
set -eo pipefail
export HERMIT_USER_HOME=~
if [ -z "${HERMIT_STATE_DIR}" ]; then
case "$(uname -s)" in
Darwin)
export HERMIT_STATE_DIR="${HERMIT_USER_HOME}/Library/Caches/hermit"
;;
Linux)
export HERMIT_STATE_DIR="${XDG_CACHE_HOME:-${HERMIT_USER_HOME}/.cache}/hermit"
;;
esac
fi
export HERMIT_DIST_URL="${HERMIT_DIST_URL:-https://github.com/cashapp/hermit/releases/download/stable}"
HERMIT_CHANNEL="$(basename "${HERMIT_DIST_URL}")"
export HERMIT_CHANNEL
export HERMIT_EXE=${HERMIT_EXE:-${HERMIT_STATE_DIR}/pkg/hermit@${HERMIT_CHANNEL}/hermit}
if [ ! -x "${HERMIT_EXE}" ]; then
echo "Bootstrapping ${HERMIT_EXE} from ${HERMIT_DIST_URL}" 1>&2
INSTALL_SCRIPT="$(mktemp)"
# This value must match that of the install script
INSTALL_SCRIPT_SHA256="09ed936378857886fd4a7a4878c0f0c7e3d839883f39ca8b4f2f242e3126e1c6"
if [ "${INSTALL_SCRIPT_SHA256}" = "BYPASS" ]; then
curl -fsSL "${HERMIT_DIST_URL}/install.sh" -o "${INSTALL_SCRIPT}"
else
# Install script is versioned by its sha256sum value
curl -fsSL "${HERMIT_DIST_URL}/install-${INSTALL_SCRIPT_SHA256}.sh" -o "${INSTALL_SCRIPT}"
# Verify install script's sha256sum
openssl dgst -sha256 "${INSTALL_SCRIPT}" | \
awk -v EXPECTED="$INSTALL_SCRIPT_SHA256" \
'$2!=EXPECTED {print "Install script sha256 " $2 " does not match " EXPECTED; exit 1}'
fi
/bin/bash "${INSTALL_SCRIPT}" 1>&2
fi
exec "${HERMIT_EXE}" --level=fatal exec "$0" -- "$@"

4
bin/hermit.hcl Normal file
View File

@@ -0,0 +1,4 @@
manage-git = false
github-token-auth {
}

1
bin/node Symbolic link
View File

@@ -0,0 +1 @@
.node-22.9.0.pkg

1
bin/npm Symbolic link
View File

@@ -0,0 +1 @@
.node-22.9.0.pkg

1
bin/npx Symbolic link
View File

@@ -0,0 +1 @@
.node-22.9.0.pkg

1
bin/protoc Symbolic link
View File

@@ -0,0 +1 @@
.protoc-31.1.pkg

1
bin/rls Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/rust-analyzer Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/rust-gdb Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/rust-gdbgui Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/rust-lldb Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/rustc Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/rustdoc Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/rustfmt Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

1
bin/rustup Symbolic link
View File

@@ -0,0 +1 @@
.rustup-1.25.2.pkg

View File

@@ -55,6 +55,15 @@ regex = "1.11.1"
minijinja = "2.8.0"
nix = { version = "0.30.1", features = ["process", "signal"] }
tar = "0.4"
# Web server dependencies
axum = { version = "0.8.1", features = ["ws", "macros"] }
tower-http = { version = "0.5", features = ["cors", "fs"] }
tokio-stream = "0.1"
bytes = "1.5"
http = "1.0"
webbrowser = "1.0"
indicatif = "0.17.11"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -0,0 +1,78 @@
# Goose Web Interface
The `goose web` command provides a (preview) web-based chat interface for interacting with Goose.
Do not expose this publicly - this is in a preview state as an option.
## Usage
```bash
# Start the web server on default port (3000)
goose web
# Start on a specific port
goose web --port 8080
# Start and automatically open in browser
goose web --open
# Bind to a specific host
goose web --host 0.0.0.0 --port 8080
```
## Features
- **Real-time chat interface**: Communicate with Goose through a clean web UI
- **WebSocket support**: Real-time message streaming
- **Session management**: Each browser tab maintains its own session
- **Responsive design**: Works on desktop and mobile devices
## Architecture
The web interface is built with:
- **Backend**: Rust with Axum web framework
- **Frontend**: Vanilla JavaScript with WebSocket communication
- **Styling**: CSS with dark/light mode support
## Development Notes
### Current Implementation
The web interface provides:
1. A simple chat UI similar to the desktop Electron app
2. WebSocket-based real-time communication
3. Basic session management (messages are stored in memory)
### Future Enhancements
- [ ] Persistent session storage
- [ ] Tool call visualization
- [ ] File upload support
- [ ] Multiple session tabs
- [ ] Authentication/authorization
- [ ] Streaming responses with proper formatting
- [ ] Code syntax highlighting
- [ ] Export chat history
### Integration with Goose Agent
The web server creates an instance of the Goose Agent and processes messages through the same pipeline as the CLI. However, some features like:
- Extension management
- Tool confirmations
- File system interactions
...may require additional UI components to be fully functional.
## Security Considerations
Currently, the web interface:
- Binds to localhost by default for security
- Does not include authentication (planned for future)
- Should not be exposed to the internet without proper security measures
## Troubleshooting
If you encounter issues:
1. **Port already in use**: Try a different port with `--port`
2. **Cannot connect**: Ensure no firewall is blocking the port
3. **Agent not configured**: Run `goose configure` first to set up a provider

View File

@@ -34,7 +34,7 @@ struct Cli {
command: Option<Command>,
}
#[derive(Args)]
#[derive(Args, Debug)]
#[group(required = false, multiple = false)]
struct Identifier {
#[arg(
@@ -102,6 +102,19 @@ enum SessionCommand {
#[arg(short, long, help = "Regex for removing matched sessions (optional)")]
regex: Option<String>,
},
#[command(about = "Export a session to Markdown format")]
Export {
#[command(flatten)]
identifier: Option<Identifier>,
#[arg(
short,
long,
help = "Output file path (default: stdout)",
long_help = "Path to save the exported Markdown. If not provided, output will be sent to stdout"
)]
output: Option<PathBuf>,
},
}
#[derive(Subcommand, Debug)]
@@ -491,6 +504,31 @@ enum Command {
#[command(subcommand)]
cmd: BenchCommand,
},
/// Start a web server with a chat interface
#[command(about = "Start a web server with a chat interface", hide = true)]
Web {
/// Port to run the web server on
#[arg(
short,
long,
default_value = "3000",
help = "Port to run the web server on"
)]
port: u16,
/// Host to bind the web server to
#[arg(
long,
default_value = "127.0.0.1",
help = "Host to bind the web server to"
)]
host: String,
/// Open browser automatically
#[arg(long, help = "Open browser automatically when server starts")]
open: bool,
},
}
#[derive(clap::ValueEnum, Clone, Debug)]
@@ -550,6 +588,23 @@ pub async fn cli() -> Result<()> {
handle_session_remove(id, regex)?;
return Ok(());
}
Some(SessionCommand::Export { identifier, output }) => {
let session_identifier = if let Some(id) = identifier {
extract_identifier(id)
} else {
// If no identifier is provided, prompt for interactive selection
match crate::commands::session::prompt_interactive_session_selection() {
Ok(id) => id,
Err(e) => {
eprintln!("Error: {}", e);
return Ok(());
}
}
};
crate::commands::session::handle_session_export(session_identifier, output)?;
Ok(())
}
None => {
// Run session command by default
let mut session: crate::Session = build_session(SessionBuilderConfig {
@@ -755,6 +810,10 @@ pub async fn cli() -> Result<()> {
}
return Ok(());
}
Some(Command::Web { port, host, open }) => {
crate::commands::web::handle_web(port, host, open).await?;
return Ok(());
}
None => {
return if !Config::global().exists() {
let _ = handle_configure().await;

View File

@@ -7,3 +7,4 @@ pub mod recipe;
pub mod schedule;
pub mod session;
pub mod update;
pub mod web;

View File

@@ -34,6 +34,8 @@ pub async fn handle_schedule_add(
last_run: None,
currently_running: false,
paused: false,
current_session_id: None,
process_start_time: None,
};
let scheduler_storage_path =

View File

@@ -1,8 +1,11 @@
use crate::session::message_to_markdown;
use anyhow::{Context, Result};
use cliclack::{confirm, multiselect};
use cliclack::{confirm, multiselect, select};
use goose::session::info::{get_session_info, SessionInfo, SortOrder};
use goose::session::{self, Identifier};
use regex::Regex;
use std::fs;
use std::path::{Path, PathBuf};
const TRUNCATED_DESC_LENGTH: usize = 60;
@@ -29,7 +32,7 @@ pub fn remove_sessions(sessions: Vec<SessionInfo>) -> Result<()> {
Ok(())
}
fn prompt_interactive_session_selection(sessions: &[SessionInfo]) -> Result<Vec<SessionInfo>> {
fn prompt_interactive_session_removal(sessions: &[SessionInfo]) -> Result<Vec<SessionInfo>> {
if sessions.is_empty() {
println!("No sessions to delete.");
return Ok(vec![]);
@@ -105,7 +108,7 @@ pub fn handle_session_remove(id: Option<String>, regex_string: Option<String>) -
if all_sessions.is_empty() {
return Err(anyhow::anyhow!("No sessions found."));
}
matched_sessions = prompt_interactive_session_selection(&all_sessions)?;
matched_sessions = prompt_interactive_session_removal(&all_sessions)?;
}
if matched_sessions.is_empty() {
@@ -165,3 +168,184 @@ pub fn handle_session_list(verbose: bool, format: String, ascending: bool) -> Re
}
Ok(())
}
/// Export a session to Markdown without creating a full Session object
///
/// This function directly reads messages from the session file and converts them to Markdown
/// without creating an Agent or prompting about working directories.
pub fn handle_session_export(identifier: Identifier, output_path: Option<PathBuf>) -> Result<()> {
// Get the session file path
let session_file_path = goose::session::get_path(identifier.clone());
if !session_file_path.exists() {
return Err(anyhow::anyhow!(
"Session file not found (expected path: {})",
session_file_path.display()
));
}
// Read messages directly without using Session
let messages = match goose::session::read_messages(&session_file_path) {
Ok(msgs) => msgs,
Err(e) => {
return Err(anyhow::anyhow!("Failed to read session messages: {}", e));
}
};
// Generate the markdown content using the export functionality
let markdown = export_session_to_markdown(messages, &session_file_path, None);
// Output the markdown
if let Some(output) = output_path {
fs::write(&output, markdown)
.with_context(|| format!("Failed to write to output file: {}", output.display()))?;
println!("Session exported to {}", output.display());
} else {
println!("{}", markdown);
}
Ok(())
}
/// Convert a list of messages to markdown format for session export
///
/// This function handles the formatting of a complete session including headers,
/// message organization, and proper tool request/response pairing.
fn export_session_to_markdown(
messages: Vec<goose::message::Message>,
session_file: &Path,
session_name_override: Option<&str>,
) -> String {
let mut markdown_output = String::new();
let session_name = session_name_override.unwrap_or_else(|| {
session_file
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("Unnamed Session")
});
markdown_output.push_str(&format!("# Session Export: {}\n\n", session_name));
if messages.is_empty() {
markdown_output.push_str("*(This session has no messages)*\n");
return markdown_output;
}
markdown_output.push_str(&format!("*Total messages: {}*\n\n---\n\n", messages.len()));
// Track if the last message had tool requests to properly handle tool responses
let mut skip_next_if_tool_response = false;
for message in &messages {
// Check if this is a User message containing only ToolResponses
let is_only_tool_response = message.role == mcp_core::role::Role::User
&& message
.content
.iter()
.all(|content| matches!(content, goose::message::MessageContent::ToolResponse(_)));
// If the previous message had tool requests and this one is just tool responses,
// don't create a new User section - we'll attach the responses to the tool calls
if skip_next_if_tool_response && is_only_tool_response {
// Export the tool responses without a User heading
markdown_output.push_str(&message_to_markdown(message, false));
markdown_output.push_str("\n\n---\n\n");
skip_next_if_tool_response = false;
continue;
}
// Reset the skip flag - we'll update it below if needed
skip_next_if_tool_response = false;
// Output the role prefix except for tool response-only messages
if !is_only_tool_response {
let role_prefix = match message.role {
mcp_core::role::Role::User => "### User:\n",
mcp_core::role::Role::Assistant => "### Assistant:\n",
};
markdown_output.push_str(role_prefix);
}
// Add the message content
markdown_output.push_str(&message_to_markdown(message, false));
markdown_output.push_str("\n\n---\n\n");
// Check if this message has any tool requests, to handle the next message differently
if message
.content
.iter()
.any(|content| matches!(content, goose::message::MessageContent::ToolRequest(_)))
{
skip_next_if_tool_response = true;
}
}
markdown_output
}
/// Prompt the user to interactively select a session
///
/// Shows a list of available sessions and lets the user select one
pub fn prompt_interactive_session_selection() -> Result<session::Identifier> {
// Get sessions sorted by modification date (newest first)
let sessions = match get_session_info(SortOrder::Descending) {
Ok(sessions) => sessions,
Err(e) => {
tracing::error!("Failed to list sessions: {:?}", e);
return Err(anyhow::anyhow!("Failed to list sessions"));
}
};
if sessions.is_empty() {
return Err(anyhow::anyhow!("No sessions found"));
}
// Build the selection prompt
let mut selector = select("Select a session to export:");
// Map to display text
let display_map: std::collections::HashMap<String, SessionInfo> = sessions
.iter()
.map(|s| {
let desc = if s.metadata.description.is_empty() {
"(no description)"
} else {
&s.metadata.description
};
// Truncate description if too long
let truncated_desc = if desc.len() > 40 {
format!("{}...", &desc[..37])
} else {
desc.to_string()
};
let display_text = format!("{} - {} ({})", s.modified, truncated_desc, s.id);
(display_text, s.clone())
})
.collect();
// Add each session as an option
for display_text in display_map.keys() {
selector = selector.item(display_text.clone(), display_text.clone(), "");
}
// Add a cancel option
let cancel_value = String::from("cancel");
selector = selector.item(cancel_value, "Cancel", "Cancel export");
// Get user selection
let selected_display_text: String = selector.interact()?;
if selected_display_text == "cancel" {
return Err(anyhow::anyhow!("Export canceled"));
}
// Retrieve the selected session
if let Some(session) = display_map.get(&selected_display_text) {
Ok(goose::session::Identifier::Name(session.id.clone()))
} else {
Err(anyhow::anyhow!("Invalid selection"))
}
}

View File

@@ -0,0 +1,640 @@
use anyhow::Result;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::{Html, IntoResponse, Response},
routing::get,
Json, Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use goose::agents::{Agent, AgentEvent};
use goose::message::Message as GooseMessage;
use goose::session;
use serde::{Deserialize, Serialize};
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::{Mutex, RwLock};
use tower_http::cors::{Any, CorsLayer};
use tracing::error;
type SessionStore = Arc<RwLock<std::collections::HashMap<String, Arc<Mutex<Vec<GooseMessage>>>>>>;
type CancellationStore = Arc<RwLock<std::collections::HashMap<String, tokio::task::AbortHandle>>>;
#[derive(Clone)]
struct AppState {
agent: Arc<Agent>,
sessions: SessionStore,
cancellations: CancellationStore,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum WebSocketMessage {
#[serde(rename = "message")]
Message {
content: String,
session_id: String,
timestamp: i64,
},
#[serde(rename = "cancel")]
Cancel { session_id: String },
#[serde(rename = "response")]
Response {
content: String,
role: String,
timestamp: i64,
},
#[serde(rename = "tool_request")]
ToolRequest {
id: String,
tool_name: String,
arguments: serde_json::Value,
},
#[serde(rename = "tool_response")]
ToolResponse {
id: String,
result: serde_json::Value,
is_error: bool,
},
#[serde(rename = "tool_confirmation")]
ToolConfirmation {
id: String,
tool_name: String,
arguments: serde_json::Value,
needs_confirmation: bool,
},
#[serde(rename = "error")]
Error { message: String },
#[serde(rename = "thinking")]
Thinking { message: String },
#[serde(rename = "context_exceeded")]
ContextExceeded { message: String },
#[serde(rename = "cancelled")]
Cancelled { message: String },
#[serde(rename = "complete")]
Complete { message: String },
}
pub async fn handle_web(port: u16, host: String, open: bool) -> Result<()> {
// Setup logging
crate::logging::setup_logging(Some("goose-web"), None)?;
// Load config and create agent just like the CLI does
let config = goose::config::Config::global();
let provider_name: String = match config.get_param("GOOSE_PROVIDER") {
Ok(p) => p,
Err(_) => {
eprintln!("No provider configured. Run 'goose configure' first");
std::process::exit(1);
}
};
let model: String = match config.get_param("GOOSE_MODEL") {
Ok(m) => m,
Err(_) => {
eprintln!("No model configured. Run 'goose configure' first");
std::process::exit(1);
}
};
let model_config = goose::model::ModelConfig::new(model.clone());
// Create the agent
let agent = Agent::new();
let provider = goose::providers::create(&provider_name, model_config)?;
agent.update_provider(provider).await?;
// Load and enable extensions from config
let extensions = goose::config::ExtensionConfigManager::get_all()?;
for ext_config in extensions {
if ext_config.enabled {
if let Err(e) = agent.add_extension(ext_config.config.clone()).await {
eprintln!(
"Warning: Failed to load extension {}: {}",
ext_config.config.name(),
e
);
}
}
}
let state = AppState {
agent: Arc::new(agent),
sessions: Arc::new(RwLock::new(std::collections::HashMap::new())),
cancellations: Arc::new(RwLock::new(std::collections::HashMap::new())),
};
// Build router
let app = Router::new()
.route("/", get(serve_index))
.route("/session/{session_name}", get(serve_session))
.route("/ws", get(websocket_handler))
.route("/api/health", get(health_check))
.route("/api/sessions", get(list_sessions))
.route("/api/sessions/{session_id}", get(get_session))
.route("/static/{*path}", get(serve_static))
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
)
.with_state(state);
let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
println!("\n🪿 Starting Goose web server");
println!(" Provider: {} | Model: {}", provider_name, model);
println!(
" Working directory: {}",
std::env::current_dir()?.display()
);
println!(" Server: http://{}", addr);
println!(" Press Ctrl+C to stop\n");
if open {
// Open browser
let url = format!("http://{}", addr);
if let Err(e) = webbrowser::open(&url) {
eprintln!("Failed to open browser: {}", e);
}
}
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn serve_index() -> Html<&'static str> {
Html(include_str!("../../static/index.html"))
}
async fn serve_session(
axum::extract::Path(session_name): axum::extract::Path<String>,
) -> Html<String> {
let html = include_str!("../../static/index.html");
// Inject the session name into the HTML so JavaScript can use it
let html_with_session = html.replace(
"<script src=\"/static/script.js\"></script>",
&format!(
"<script>window.GOOSE_SESSION_NAME = '{}';</script>\n <script src=\"/static/script.js\"></script>",
session_name
)
);
Html(html_with_session)
}
async fn serve_static(axum::extract::Path(path): axum::extract::Path<String>) -> Response {
match path.as_str() {
"style.css" => (
[("content-type", "text/css")],
include_str!("../../static/style.css"),
)
.into_response(),
"script.js" => (
[("content-type", "application/javascript")],
include_str!("../../static/script.js"),
)
.into_response(),
_ => (axum::http::StatusCode::NOT_FOUND, "Not found").into_response(),
}
}
async fn health_check() -> Json<serde_json::Value> {
Json(serde_json::json!({
"status": "ok",
"service": "goose-web"
}))
}
async fn list_sessions() -> Json<serde_json::Value> {
match session::list_sessions() {
Ok(sessions) => {
let session_info: Vec<serde_json::Value> = sessions
.into_iter()
.map(|(name, path)| {
let metadata = session::read_metadata(&path).unwrap_or_default();
serde_json::json!({
"name": name,
"path": path,
"description": metadata.description,
"message_count": metadata.message_count,
"working_dir": metadata.working_dir
})
})
.collect();
Json(serde_json::json!({
"sessions": session_info
}))
}
Err(e) => Json(serde_json::json!({
"error": e.to_string()
})),
}
}
async fn get_session(
axum::extract::Path(session_id): axum::extract::Path<String>,
) -> Json<serde_json::Value> {
let session_file = session::get_path(session::Identifier::Name(session_id));
match session::read_messages(&session_file) {
Ok(messages) => {
let metadata = session::read_metadata(&session_file).unwrap_or_default();
Json(serde_json::json!({
"metadata": metadata,
"messages": messages
}))
}
Err(e) => Json(serde_json::json!({
"error": e.to_string()
})),
}
}
async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, state))
}
async fn handle_socket(socket: WebSocket, state: AppState) {
let (sender, mut receiver) = socket.split();
let sender = Arc::new(Mutex::new(sender));
while let Some(msg) = receiver.next().await {
if let Ok(msg) = msg {
match msg {
Message::Text(text) => {
match serde_json::from_str::<WebSocketMessage>(&text.to_string()) {
Ok(WebSocketMessage::Message {
content,
session_id,
..
}) => {
// Get session file path from session_id
let session_file =
session::get_path(session::Identifier::Name(session_id.clone()));
// Get or create session in memory (for fast access during processing)
let session_messages = {
let sessions = state.sessions.read().await;
if let Some(session) = sessions.get(&session_id) {
session.clone()
} else {
drop(sessions);
let mut sessions = state.sessions.write().await;
// Load existing messages from JSONL file if it exists
let existing_messages = session::read_messages(&session_file)
.unwrap_or_else(|_| Vec::new());
let new_session = Arc::new(Mutex::new(existing_messages));
sessions.insert(session_id.clone(), new_session.clone());
new_session
}
};
// Clone sender for async processing
let sender_clone = sender.clone();
let agent = state.agent.clone();
// Process message in a separate task to allow streaming
let task_handle = tokio::spawn(async move {
let result = process_message_streaming(
&agent,
session_messages,
session_file,
content,
sender_clone,
)
.await;
if let Err(e) = result {
error!("Error processing message: {}", e);
}
});
// Store the abort handle
{
let mut cancellations = state.cancellations.write().await;
cancellations
.insert(session_id.clone(), task_handle.abort_handle());
}
// Wait for task completion and handle abort
let sender_for_abort = sender.clone();
let session_id_for_cleanup = session_id.clone();
let cancellations_for_cleanup = state.cancellations.clone();
tokio::spawn(async move {
match task_handle.await {
Ok(_) => {
// Task completed normally
}
Err(e) if e.is_cancelled() => {
// Task was aborted
let mut sender = sender_for_abort.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(
&WebSocketMessage::Cancelled {
message: "Operation cancelled by user"
.to_string(),
},
)
.unwrap()
.into(),
))
.await;
}
Err(e) => {
error!("Task error: {}", e);
}
}
// Clean up cancellation token
{
let mut cancellations = cancellations_for_cleanup.write().await;
cancellations.remove(&session_id_for_cleanup);
}
});
}
Ok(WebSocketMessage::Cancel { session_id }) => {
// Cancel the active operation for this session
let abort_handle = {
let mut cancellations = state.cancellations.write().await;
cancellations.remove(&session_id)
};
if let Some(handle) = abort_handle {
handle.abort();
// Send cancellation confirmation
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Cancelled {
message: "Operation cancelled".to_string(),
})
.unwrap()
.into(),
))
.await;
}
}
Ok(_) => {
// Ignore other message types
}
Err(e) => {
error!("Failed to parse WebSocket message: {}", e);
}
}
}
Message::Close(_) => break,
_ => {}
}
} else {
break;
}
}
}
async fn process_message_streaming(
agent: &Agent,
session_messages: Arc<Mutex<Vec<GooseMessage>>>,
session_file: std::path::PathBuf,
content: String,
sender: Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
) -> Result<()> {
use futures::StreamExt;
use goose::agents::SessionConfig;
use goose::message::MessageContent;
use goose::session;
// Create a user message
let user_message = GooseMessage::user().with_text(content.clone());
// Get existing messages from session and add the new user message
let mut messages = {
let mut session_msgs = session_messages.lock().await;
session_msgs.push(user_message.clone());
session_msgs.clone()
};
// Persist messages to JSONL file with provider for automatic description generation
let provider = agent.provider().await;
if provider.is_err() {
let error_msg = "I'm not properly configured yet. Please configure a provider through the CLI first using `goose configure`.".to_string();
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Response {
content: error_msg,
role: "assistant".to_string(),
timestamp: chrono::Utc::now().timestamp_millis(),
})
.unwrap()
.into(),
))
.await;
return Ok(());
}
let provider = provider.unwrap();
session::persist_messages(&session_file, &messages, Some(provider.clone())).await?;
// Create a session config
let session_config = SessionConfig {
id: session::Identifier::Path(session_file.clone()),
working_dir: std::env::current_dir()?,
schedule_id: None,
};
// Get response from agent
match agent.reply(&messages, Some(session_config)).await {
Ok(mut stream) => {
while let Some(result) = stream.next().await {
match result {
Ok(AgentEvent::Message(message)) => {
// Add message to our session
{
let mut session_msgs = session_messages.lock().await;
session_msgs.push(message.clone());
}
// Persist messages to JSONL file (no provider needed for assistant messages)
let current_messages = {
let session_msgs = session_messages.lock().await;
session_msgs.clone()
};
session::persist_messages(&session_file, &current_messages, None).await?;
// Handle different message content types
for content in &message.content {
match content {
MessageContent::Text(text) => {
// Send the text response
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Response {
content: text.text.clone(),
role: "assistant".to_string(),
timestamp: chrono::Utc::now().timestamp_millis(),
})
.unwrap()
.into(),
))
.await;
}
MessageContent::ToolRequest(req) => {
// Send tool request notification
let mut sender = sender.lock().await;
if let Ok(tool_call) = &req.tool_call {
let _ = sender
.send(Message::Text(
serde_json::to_string(
&WebSocketMessage::ToolRequest {
id: req.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
},
)
.unwrap()
.into(),
))
.await;
}
}
MessageContent::ToolResponse(_resp) => {
// Tool responses are already included in the complete message stream
// and will be persisted to session history. No need to send separate
// WebSocket messages as this would cause duplicates.
}
MessageContent::ToolConfirmationRequest(confirmation) => {
// Send tool confirmation request
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(
&WebSocketMessage::ToolConfirmation {
id: confirmation.id.clone(),
tool_name: confirmation.tool_name.clone(),
arguments: confirmation.arguments.clone(),
needs_confirmation: true,
},
)
.unwrap()
.into(),
))
.await;
// For now, auto-approve in web mode
// TODO: Implement proper confirmation UI
agent.handle_confirmation(
confirmation.id.clone(),
goose::permission::PermissionConfirmation {
principal_type: goose::permission::permission_confirmation::PrincipalType::Tool,
permission: goose::permission::Permission::AllowOnce,
}
).await;
}
MessageContent::Thinking(thinking) => {
// Send thinking indicator
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Thinking {
message: thinking.thinking.clone(),
})
.unwrap()
.into(),
))
.await;
}
MessageContent::ContextLengthExceeded(msg) => {
// Send context exceeded notification
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(
&WebSocketMessage::ContextExceeded {
message: msg.msg.clone(),
},
)
.unwrap()
.into(),
))
.await;
// For now, auto-summarize in web mode
// TODO: Implement proper UI for context handling
let (summarized_messages, _) =
agent.summarize_context(&messages).await?;
messages = summarized_messages;
}
_ => {
// Handle other message types as needed
}
}
}
}
Ok(AgentEvent::McpNotification(_notification)) => {
// Handle MCP notifications if needed
// For now, we'll just log them
tracing::info!("Received MCP notification in web interface");
}
Err(e) => {
error!("Error in message stream: {}", e);
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Error {
message: format!("Error: {}", e),
})
.unwrap()
.into(),
))
.await;
break;
}
}
}
}
Err(e) => {
error!("Error calling agent: {}", e);
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Error {
message: format!("Error: {}", e),
})
.unwrap()
.into(),
))
.await;
}
}
// Send completion message
let mut sender = sender.lock().await;
let _ = sender
.send(Message::Text(
serde_json::to_string(&WebSocketMessage::Complete {
message: "Response complete".to_string(),
})
.unwrap()
.into(),
))
.await;
Ok(())
}
// Add webbrowser dependency for opening browser
use webbrowser;

View File

@@ -7,6 +7,7 @@ use goose::session;
use goose::session::Identifier;
use mcp_client::transport::Error as McpClientError;
use std::process;
use std::sync::Arc;
use super::output;
use super::Session;
@@ -55,6 +56,22 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
// Create the agent
let agent: Agent = Agent::new();
let new_provider = create(&provider_name, model_config).unwrap();
// Keep a reference to the provider for display_session_info
let provider_for_display = Arc::clone(&new_provider);
// Log model information at startup
if let Some(lead_worker) = new_provider.as_lead_worker() {
let (lead_model, worker_model) = lead_worker.get_model_info();
tracing::info!(
"🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled",
lead_model,
worker_model
);
} else {
tracing::info!("🤖 Using model: {}", model);
}
agent
.update_provider(new_provider)
.await
@@ -217,6 +234,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
session.agent.override_system_prompt(override_prompt).await;
}
output::display_session_info(session_config.resume, &provider_name, &model, &session_file);
output::display_session_info(
session_config.resume,
&provider_name,
&model,
&session_file,
Some(&provider_for_display),
);
session
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,15 @@
mod builder;
mod completion;
mod export;
mod input;
mod output;
mod prompt;
mod thinking;
pub use self::export::message_to_markdown;
pub use builder::{build_session, SessionBuilderConfig};
use console::Color;
use goose::agents::AgentEvent;
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::Permission;
use goose::permission::PermissionConfirmation;
@@ -15,8 +18,7 @@ pub use goose::session::Identifier;
use anyhow::{Context, Result};
use completion::GooseCompleter;
use etcetera::choose_app_strategy;
use etcetera::AppStrategy;
use etcetera::{choose_app_strategy, AppStrategy};
use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::{Agent, SessionConfig};
use goose::config::Config;
@@ -25,6 +27,8 @@ use goose::session;
use input::InputResult;
use mcp_core::handler::ToolError;
use mcp_core::prompt::PromptMessage;
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::protocol::JsonRpcNotification;
use rand::{distributions::Alphanumeric, Rng};
use serde_json::Value;
@@ -351,9 +355,10 @@ impl Session {
// Create and use a global history file in ~/.config/goose directory
// This allows command history to persist across different chat sessions
// instead of being tied to each individual session's messages
let history_file = choose_app_strategy(crate::APP_STRATEGY.clone())
.expect("goose requires a home dir")
.in_config_dir("history.txt");
let strategy =
choose_app_strategy(crate::APP_STRATEGY.clone()).expect("goose requires a home dir");
let config_dir = strategy.config_dir();
let history_file = config_dir.join("history.txt");
// Ensure config directory exists
if let Some(parent) = history_file.parent() {
@@ -379,6 +384,9 @@ impl Session {
output::display_greeting();
loop {
// Display context usage before each prompt
self.display_context_usage().await?;
match input::get_input(&mut editor)? {
input::InputResult::Message(content) => {
match self.run_mode {
@@ -713,12 +721,14 @@ impl Session {
)
.await?;
let mut progress_bars = output::McpSpinners::new();
use futures::StreamExt;
loop {
tokio::select! {
result = stream.next() => {
match result {
Some(Ok(message)) => {
Some(Ok(AgentEvent::Message(message))) => {
// If it's a confirmation request, get approval but otherwise do not render/persist
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
output::hide_thinking();
@@ -768,23 +778,27 @@ impl Session {
} else if let Some(MessageContent::ContextLengthExceeded(_)) = message.content.first() {
output::hide_thinking();
// Check for user-configured default context strategy
let config = Config::global();
let context_strategy = config.get_param::<String>("GOOSE_CONTEXT_STRATEGY")
.unwrap_or_else(|_| if interactive { "prompt".to_string() } else { "summarize".to_string() });
let selected = match context_strategy.as_str() {
"clear" => "clear",
"truncate" => "truncate",
"summarize" => "summarize",
_ => {
if interactive {
// In interactive mode, ask the user what to do
// In interactive mode with no default, ask the user what to do
let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
let selected_result = cliclack::select(prompt)
cliclack::select(prompt)
.item("clear", "Clear Session", "Removes all messages from Goose's memory")
.item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
.item("summarize", "Summarize Session", "Summarize the session to reduce context length")
.item("cancel", "Cancel", "Cancel and return to chat")
.interact();
let selected = match selected_result {
Ok(s) => s,
Err(e) => {
if e.kind() == std::io::ErrorKind::Interrupted {
"cancel" // If interrupted, set selected to cancel
.interact()?
} else {
return Err(e.into());
// In headless mode, default to summarize
"summarize"
}
}
};
@@ -792,33 +806,41 @@ impl Session {
match selected {
"clear" => {
self.messages.clear();
let msg = format!("Session cleared.\n{}", "-".repeat(50));
let msg = if context_strategy == "clear" {
format!("Context maxed out - automatically cleared session.\n{}", "-".repeat(50))
} else {
format!("Session cleared.\n{}", "-".repeat(50))
};
output::render_text(&msg, Some(Color::Yellow), true);
break; // exit the loop to hand back control to the user
}
"truncate" => {
// Truncate messages to fit within context length
let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
let msg = format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50));
let msg = if context_strategy == "truncate" {
format!("Context maxed out - automatically truncated messages.\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50))
} else {
format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50))
};
output::render_text("", Some(Color::Yellow), true);
output::render_text(&msg, Some(Color::Yellow), true);
self.messages = truncated_messages;
}
"summarize" => {
// Use the helper function to summarize context
Self::summarize_context_messages(&mut self.messages, &self.agent, "Goose summarized messages for you.").await?;
}
"cancel" => {
break; // Return to main prompt
let message_suffix = if context_strategy == "summarize" {
"Goose automatically summarized messages for you."
} else if interactive {
"Goose summarized messages for you."
} else {
"Goose automatically summarized messages to continue processing."
};
Self::summarize_context_messages(&mut self.messages, &self.agent, message_suffix).await?;
}
_ => {
unreachable!()
}
}
} else {
// In headless mode (goose run), automatically use summarize
Self::summarize_context_messages(&mut self.messages, &self.agent, "Goose automatically summarized messages to continue processing.").await?;
}
// Restart the stream after handling ContextLengthExceeded
stream = self
@@ -842,10 +864,55 @@ impl Session {
session::persist_messages(&self.session_file, &self.messages, None).await?;
if interactive {output::hide_thinking()};
let _ = progress_bars.hide();
output::render_message(&message, self.debug);
if interactive {output::show_thinking()};
}
}
Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
if let JsonRpcMessage::Notification(JsonRpcNotification{
method,
params: Some(Value::Object(o)),
..
}) = message {
match method.as_str() {
"notifications/message" => {
let data = o.get("data").unwrap_or(&Value::Null);
let message = match data {
Value::String(s) => s.clone(),
Value::Object(o) => {
if let Some(Value::String(output)) = o.get("output") {
output.to_owned()
} else {
data.to_string()
}
},
v => {
v.to_string()
},
};
progress_bars.log(&message);
},
"notifications/progress" => {
let progress = o.get("progress").and_then(|v| v.as_f64());
let token = o.get("progressToken").map(|v| v.to_string());
let message = o.get("message").and_then(|v| v.as_str());
let total = o
.get("total")
.and_then(|v| v.as_f64());
if let (Some(progress), Some(token)) = (progress, token) {
progress_bars.update(
token.as_str(),
progress,
total,
message,
);
}
},
_ => (),
}
}
}
Some(Err(e)) => {
eprintln!("Error: {}", e);
drop(stream);
@@ -872,6 +939,7 @@ impl Session {
}
}
}
Ok(())
}
@@ -1054,6 +1122,26 @@ impl Session {
Ok(metadata.total_tokens)
}
/// Display enhanced context usage with session totals
pub async fn display_context_usage(&self) -> Result<()> {
let provider = self.agent.provider().await?;
let model_config = provider.get_model_config();
let context_limit = model_config.context_limit.unwrap_or(32000);
match self.get_metadata() {
Ok(metadata) => {
let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;
output::display_context_usage(total_tokens, context_limit);
}
Err(_) => {
output::display_context_usage(0, context_limit);
}
}
Ok(())
}
/// Handle prompt command execution
async fn handle_prompt_command(&mut self, opts: input::PromptCommandOptions) -> Result<()> {
// name is required

View File

@@ -2,12 +2,16 @@ use bat::WrappingMode;
use console::{style, Color};
use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use mcp_core::prompt::PromptArgument;
use mcp_core::tool::ToolCall;
use serde_json::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::io::Error;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
// Re-export theme for use in main
#[derive(Clone, Copy)]
@@ -144,6 +148,10 @@ pub fn render_message(message: &Message, debug: bool) {
}
pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
render_text_no_newlines(format!("\n{}\n\n", text).as_str(), color, dim);
}
pub fn render_text_no_newlines(text: &str, color: Option<Color>, dim: bool) {
let mut styled_text = style(text);
if dim {
styled_text = styled_text.dim();
@@ -153,7 +161,7 @@ pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
} else {
styled_text = styled_text.green();
}
println!("\n{}\n", styled_text);
print!("{}", styled_text);
}
pub fn render_enter_plan_mode() {
@@ -359,7 +367,6 @@ fn render_shell_request(call: &ToolCall, debug: bool) {
}
_ => print_params(&call.arguments, 0, debug),
}
println!();
}
fn render_default_request(call: &ToolCall, debug: bool) {
@@ -530,7 +537,13 @@ fn shorten_path(path: &str, debug: bool) -> String {
}
// Session display functions
pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) {
pub fn display_session_info(
resume: bool,
provider: &str,
model: &str,
session_file: &Path,
provider_instance: Option<&Arc<dyn goose::providers::base::Provider>>,
) {
let start_session_msg = if resume {
"resuming session |"
} else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") {
@@ -538,6 +551,22 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
} else {
"starting session |"
};
// Check if we have lead/worker mode
if let Some(provider_inst) = provider_instance {
if let Some(lead_worker) = provider_inst.as_lead_worker() {
let (lead_model, worker_model) = lead_worker.get_model_info();
println!(
"{} {} {} {} {} {} {}",
style(start_session_msg).dim(),
style("provider:").dim(),
style(provider).cyan().dim(),
style("lead model:").dim(),
style(&lead_model).cyan().dim(),
style("worker model:").dim(),
style(&worker_model).cyan().dim(),
);
} else {
println!(
"{} {} {} {} {}",
style(start_session_msg).dim(),
@@ -546,6 +575,18 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
style("model:").dim(),
style(model).cyan().dim(),
);
}
} else {
// Fallback to original behavior if no provider instance
println!(
"{} {} {} {} {}",
style(start_session_msg).dim(),
style("provider:").dim(),
style(provider).cyan().dim(),
style("model:").dim(),
style(model).cyan().dim(),
);
}
if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") {
println!(
@@ -568,6 +609,102 @@ pub fn display_greeting() {
println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n");
}
/// Display context window usage with both current and session totals
pub fn display_context_usage(total_tokens: usize, context_limit: usize) {
use console::style;
// Calculate percentage used
let percentage = (total_tokens as f64 / context_limit as f64 * 100.0).round() as usize;
// Create dot visualization
let dot_count = 10;
let filled_dots = ((percentage as f64 / 100.0) * dot_count as f64).round() as usize;
let empty_dots = dot_count - filled_dots;
let filled = "".repeat(filled_dots);
let empty = "".repeat(empty_dots);
// Combine dots and apply color
let dots = format!("{}{}", filled, empty);
let colored_dots = if percentage < 50 {
style(dots).green()
} else if percentage < 85 {
style(dots).yellow()
} else {
style(dots).red()
};
// Print the status line
println!(
"Context: {} {}% ({}/{} tokens)",
colored_dots, percentage, total_tokens, context_limit
);
}
pub struct McpSpinners {
bars: HashMap<String, ProgressBar>,
log_spinner: Option<ProgressBar>,
multi_bar: MultiProgress,
}
impl McpSpinners {
pub fn new() -> Self {
McpSpinners {
bars: HashMap::new(),
log_spinner: None,
multi_bar: MultiProgress::new(),
}
}
pub fn log(&mut self, message: &str) {
let spinner = self.log_spinner.get_or_insert_with(|| {
let bar = self.multi_bar.add(
ProgressBar::new_spinner()
.with_style(
ProgressStyle::with_template("{spinner:.green} {msg}")
.unwrap()
.tick_chars("⠋⠙⠚⠛⠓⠒⠊⠉"),
)
.with_message(message.to_string()),
);
bar.enable_steady_tick(Duration::from_millis(100));
bar
});
spinner.set_message(message.to_string());
}
pub fn update(&mut self, token: &str, value: f64, total: Option<f64>, message: Option<&str>) {
let bar = self.bars.entry(token.to_string()).or_insert_with(|| {
if let Some(total) = total {
self.multi_bar.add(
ProgressBar::new((total * 100.0) as u64).with_style(
ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}")
.unwrap(),
),
)
} else {
self.multi_bar.add(ProgressBar::new_spinner())
}
});
bar.set_position((value * 100.0) as u64);
if let Some(msg) = message {
bar.set_message(msg.to_string());
}
}
pub fn hide(&mut self) -> Result<(), Error> {
self.bars.iter_mut().for_each(|(_, bar)| {
bar.disable_steady_tick();
});
if let Some(spinner) = self.log_spinner.as_mut() {
spinner.disable_steady_tick();
}
self.multi_bar.clear()
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -0,0 +1,46 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Goose Chat</title>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<div class="container">
<header>
<h1 id="session-title">Goose Chat</h1>
<div class="status" id="connection-status">Connecting...</div>
</header>
<div class="chat-container">
<div class="messages" id="messages">
<div class="welcome-message">
<h2>Welcome to Goose!</h2>
<p>I'm your AI assistant. How can I help you today?</p>
<div class="suggestion-pills">
<div class="suggestion-pill" onclick="sendSuggestion('What can you do?')">What can you do?</div>
<div class="suggestion-pill" onclick="sendSuggestion('Demo writing and reading files')">Demo writing and reading files</div>
<div class="suggestion-pill" onclick="sendSuggestion('Make a snake game in a new folder')">Make a snake game in a new folder</div>
<div class="suggestion-pill" onclick="sendSuggestion('List files in my current directory')">List files in my current directory</div>
<div class="suggestion-pill" onclick="sendSuggestion('Take a screenshot and summarize')">Take a screenshot and summarize</div>
</div>
</div>
</div>
<div class="input-container">
<textarea
id="message-input"
placeholder="Type your message here..."
rows="3"
autofocus
></textarea>
<button id="send-button" type="button">Send</button>
</div>
</div>
</div>
<script src="/static/script.js"></script>
</body>
</html>

View File

@@ -0,0 +1,523 @@
// WebSocket connection and chat functionality
let socket = null;
let sessionId = getSessionId();
let isConnected = false;
// DOM elements
const messagesContainer = document.getElementById('messages');
const messageInput = document.getElementById('message-input');
const sendButton = document.getElementById('send-button');
const connectionStatus = document.getElementById('connection-status');
// Track if we're currently processing
let isProcessing = false;
// Get session ID - either from URL parameter, injected session name, or generate new one
function getSessionId() {
// Check if session name was injected by server (for /session/:name routes)
if (window.GOOSE_SESSION_NAME) {
return window.GOOSE_SESSION_NAME;
}
// Check URL parameters
const urlParams = new URLSearchParams(window.location.search);
const sessionParam = urlParams.get('session') || urlParams.get('name');
if (sessionParam) {
return sessionParam;
}
// Generate new session ID using CLI format
return generateSessionId();
}
// Generate a session ID using timestamp format (yyyymmdd_hhmmss) like CLI
function generateSessionId() {
const now = new Date();
const year = now.getFullYear();
const month = String(now.getMonth() + 1).padStart(2, '0');
const day = String(now.getDate()).padStart(2, '0');
const hour = String(now.getHours()).padStart(2, '0');
const minute = String(now.getMinutes()).padStart(2, '0');
const second = String(now.getSeconds()).padStart(2, '0');
return `${year}${month}${day}_${hour}${minute}${second}`;
}
// Format timestamp
function formatTimestamp(date) {
return date.toLocaleTimeString('en-US', {
hour: '2-digit',
minute: '2-digit'
});
}
// Create message element
function createMessageElement(content, role, timestamp) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${role}`;
// Create content div
const contentDiv = document.createElement('div');
contentDiv.className = 'message-content';
contentDiv.innerHTML = formatMessageContent(content);
messageDiv.appendChild(contentDiv);
// Add timestamp
const timestampDiv = document.createElement('div');
timestampDiv.className = 'timestamp';
timestampDiv.textContent = formatTimestamp(new Date(timestamp || Date.now()));
messageDiv.appendChild(timestampDiv);
return messageDiv;
}
// Format message content (handle markdown-like formatting)
function formatMessageContent(content) {
// Escape HTML
let formatted = content
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;');
// Handle code blocks
formatted = formatted.replace(/```(\w+)?\n([\s\S]*?)```/g, (match, lang, code) => {
return `<pre><code class="language-${lang || 'plaintext'}">${code.trim()}</code></pre>`;
});
// Handle inline code
formatted = formatted.replace(/`([^`]+)`/g, '<code>$1</code>');
// Handle line breaks
formatted = formatted.replace(/\n/g, '<br>');
return formatted;
}
// Add message to chat
function addMessage(content, role, timestamp) {
// Remove welcome message if it exists
const welcomeMessage = messagesContainer.querySelector('.welcome-message');
if (welcomeMessage) {
welcomeMessage.remove();
}
const messageElement = createMessageElement(content, role, timestamp);
messagesContainer.appendChild(messageElement);
// Scroll to bottom
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
// Add thinking indicator
function addThinkingIndicator() {
removeThinkingIndicator(); // Remove any existing one first
const thinkingDiv = document.createElement('div');
thinkingDiv.id = 'thinking-indicator';
thinkingDiv.className = 'message thinking-message';
thinkingDiv.innerHTML = `
<div class="thinking-dots">
<span></span>
<span></span>
<span></span>
</div>
<span class="thinking-text">Goose is thinking...</span>
`;
messagesContainer.appendChild(thinkingDiv);
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
// Remove thinking indicator
function removeThinkingIndicator() {
const thinking = document.getElementById('thinking-indicator');
if (thinking) {
thinking.remove();
}
}
// Connect to WebSocket
function connectWebSocket() {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//${window.location.host}/ws`;
socket = new WebSocket(wsUrl);
socket.onopen = () => {
console.log('WebSocket connected');
isConnected = true;
connectionStatus.textContent = 'Connected';
connectionStatus.className = 'status connected';
sendButton.disabled = false;
// Check if this session exists and load history if it does
loadSessionIfExists();
};
socket.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
handleServerMessage(data);
} catch (e) {
console.error('Failed to parse message:', e);
}
};
socket.onclose = () => {
console.log('WebSocket disconnected');
isConnected = false;
connectionStatus.textContent = 'Disconnected';
connectionStatus.className = 'status disconnected';
sendButton.disabled = true;
// Attempt to reconnect after 3 seconds
setTimeout(connectWebSocket, 3000);
};
socket.onerror = (error) => {
console.error('WebSocket error:', error);
};
}
// Handle messages from server
function handleServerMessage(data) {
switch (data.type) {
case 'response':
// For streaming responses, we need to handle partial messages
handleStreamingResponse(data);
break;
case 'tool_request':
handleToolRequest(data);
break;
case 'tool_response':
handleToolResponse(data);
break;
case 'tool_confirmation':
handleToolConfirmation(data);
break;
case 'thinking':
handleThinking(data);
break;
case 'context_exceeded':
handleContextExceeded(data);
break;
case 'cancelled':
handleCancelled(data);
break;
case 'complete':
handleComplete(data);
break;
case 'error':
removeThinkingIndicator();
resetSendButton();
addMessage(`Error: ${data.message}`, 'assistant', Date.now());
break;
default:
console.log('Unknown message type:', data.type);
}
}
// Track current streaming message
let currentStreamingMessage = null;
// Handle streaming responses
function handleStreamingResponse(data) {
removeThinkingIndicator();
// If this is the first chunk of a new message, or we don't have a current streaming message
if (!currentStreamingMessage) {
// Create a new message element
const messageElement = createMessageElement(data.content, data.role || 'assistant', data.timestamp);
messageElement.setAttribute('data-streaming', 'true');
messagesContainer.appendChild(messageElement);
currentStreamingMessage = {
element: messageElement,
content: data.content,
role: data.role || 'assistant',
timestamp: data.timestamp
};
} else {
// Append to existing streaming message
currentStreamingMessage.content += data.content;
// Update the message content using the proper content div
const contentDiv = currentStreamingMessage.element.querySelector('.message-content');
if (contentDiv) {
contentDiv.innerHTML = formatMessageContent(currentStreamingMessage.content);
}
}
// Scroll to bottom
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
// Handle tool requests
function handleToolRequest(data) {
removeThinkingIndicator(); // Remove thinking when tool starts
// Reset streaming message so tool doesn't interfere with message flow
currentStreamingMessage = null;
const toolDiv = document.createElement('div');
toolDiv.className = 'message assistant tool-message';
const headerDiv = document.createElement('div');
headerDiv.className = 'tool-header';
headerDiv.innerHTML = `🔧 <strong>${data.tool_name}</strong>`;
const contentDiv = document.createElement('div');
contentDiv.className = 'tool-content';
// Format the arguments
if (data.tool_name === 'developer__shell' && data.arguments.command) {
contentDiv.innerHTML = `<pre><code>${escapeHtml(data.arguments.command)}</code></pre>`;
} else if (data.tool_name === 'developer__text_editor') {
const action = data.arguments.command || 'unknown';
const path = data.arguments.path || 'unknown';
contentDiv.innerHTML = `<div class="tool-param"><strong>action:</strong> ${action}</div>`;
contentDiv.innerHTML += `<div class="tool-param"><strong>path:</strong> ${escapeHtml(path)}</div>`;
if (data.arguments.file_text) {
contentDiv.innerHTML += `<div class="tool-param"><strong>content:</strong> <pre><code>${escapeHtml(data.arguments.file_text.substring(0, 200))}${data.arguments.file_text.length > 200 ? '...' : ''}</code></pre></div>`;
}
} else {
contentDiv.innerHTML = `<pre><code>${JSON.stringify(data.arguments, null, 2)}</code></pre>`;
}
toolDiv.appendChild(headerDiv);
toolDiv.appendChild(contentDiv);
// Add a "running" indicator
const runningDiv = document.createElement('div');
runningDiv.className = 'tool-running';
runningDiv.innerHTML = '⏳ Running...';
toolDiv.appendChild(runningDiv);
messagesContainer.appendChild(toolDiv);
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
// Handle tool responses
function handleToolResponse(data) {
// Remove the "running" indicator from the last tool message
const toolMessages = messagesContainer.querySelectorAll('.tool-message');
if (toolMessages.length > 0) {
const lastToolMessage = toolMessages[toolMessages.length - 1];
const runningIndicator = lastToolMessage.querySelector('.tool-running');
if (runningIndicator) {
runningIndicator.remove();
}
}
if (data.is_error) {
const errorDiv = document.createElement('div');
errorDiv.className = 'message tool-error';
errorDiv.innerHTML = `<strong>Tool Error:</strong> ${escapeHtml(data.result.error || 'Unknown error')}`;
messagesContainer.appendChild(errorDiv);
} else {
// Handle successful tool response
if (Array.isArray(data.result)) {
data.result.forEach(content => {
if (content.type === 'text' && content.text) {
const responseDiv = document.createElement('div');
responseDiv.className = 'message tool-result';
responseDiv.innerHTML = `<pre>${escapeHtml(content.text)}</pre>`;
messagesContainer.appendChild(responseDiv);
}
});
}
}
messagesContainer.scrollTop = messagesContainer.scrollHeight;
// Reset streaming message so next assistant response creates a new message
currentStreamingMessage = null;
// Show thinking indicator because assistant will likely follow up with explanation
// Only show if we're still processing (cancel button is active)
if (isProcessing) {
addThinkingIndicator();
}
}
// Handle tool confirmations
function handleToolConfirmation(data) {
const confirmDiv = document.createElement('div');
confirmDiv.className = 'message tool-confirmation';
confirmDiv.innerHTML = `
<div class="tool-confirm-header">⚠️ Tool Confirmation Required</div>
<div class="tool-confirm-content">
<strong>${data.tool_name}</strong> wants to execute with:
<pre><code>${JSON.stringify(data.arguments, null, 2)}</code></pre>
</div>
<div class="tool-confirm-note">Auto-approved in web mode (UI coming soon)</div>
`;
messagesContainer.appendChild(confirmDiv);
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
// Handle thinking messages
function handleThinking(data) {
// For now, just log thinking messages
console.log('Thinking:', data.message);
}
// Handle context exceeded
function handleContextExceeded(data) {
const contextDiv = document.createElement('div');
contextDiv.className = 'message context-warning';
contextDiv.innerHTML = `
<div class="context-header">⚠️ Context Length Exceeded</div>
<div class="context-content">${escapeHtml(data.message)}</div>
<div class="context-note">Auto-summarizing conversation...</div>
`;
messagesContainer.appendChild(contextDiv);
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
// Handle cancelled operation
function handleCancelled(data) {
removeThinkingIndicator();
resetSendButton();
const cancelDiv = document.createElement('div');
cancelDiv.className = 'message system-message cancelled';
cancelDiv.innerHTML = `<em>${escapeHtml(data.message)}</em>`;
messagesContainer.appendChild(cancelDiv);
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
// Handle completion of response
function handleComplete(data) {
removeThinkingIndicator();
resetSendButton();
// Finalize any streaming message
if (currentStreamingMessage) {
currentStreamingMessage = null;
}
}
// Reset send button to normal state
function resetSendButton() {
isProcessing = false;
sendButton.textContent = 'Send';
sendButton.classList.remove('cancel-mode');
}
// Escape HTML to prevent XSS
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
// Send message or cancel
function sendMessage() {
if (isProcessing) {
// Cancel the current operation
socket.send(JSON.stringify({
type: 'cancel',
session_id: sessionId
}));
return;
}
const message = messageInput.value.trim();
if (!message || !isConnected) return;
// Add user message to chat
addMessage(message, 'user', Date.now());
// Clear input
messageInput.value = '';
messageInput.style.height = 'auto';
// Add thinking indicator
addThinkingIndicator();
// Update button to show cancel
isProcessing = true;
sendButton.textContent = 'Cancel';
sendButton.classList.add('cancel-mode');
// Send message through WebSocket
socket.send(JSON.stringify({
type: 'message',
content: message,
session_id: sessionId,
timestamp: Date.now()
}));
}
// Handle suggestion pill clicks
function sendSuggestion(text) {
if (!isConnected || isProcessing) return;
messageInput.value = text;
sendMessage();
}
// Load session history if the session exists (like --resume in CLI)
async function loadSessionIfExists() {
try {
const response = await fetch(`/api/sessions/${sessionId}`);
if (response.ok) {
const sessionData = await response.json();
if (sessionData.messages && sessionData.messages.length > 0) {
// Remove welcome message since we're resuming
const welcomeMessage = messagesContainer.querySelector('.welcome-message');
if (welcomeMessage) {
welcomeMessage.remove();
}
// Display session resumed message
const resumeDiv = document.createElement('div');
resumeDiv.className = 'message system-message';
resumeDiv.innerHTML = `<em>Session resumed: ${sessionData.messages.length} messages loaded</em>`;
messagesContainer.appendChild(resumeDiv);
// Update page title with session description if available
if (sessionData.metadata && sessionData.metadata.description) {
document.title = `Goose Chat - ${sessionData.metadata.description}`;
}
messagesContainer.scrollTop = messagesContainer.scrollHeight;
}
}
} catch (error) {
console.log('No existing session found or error loading:', error);
// This is fine - just means it's a new session
}
}
// Event listeners
sendButton.addEventListener('click', sendMessage);
messageInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendMessage();
}
});
// Auto-resize textarea
messageInput.addEventListener('input', () => {
messageInput.style.height = 'auto';
messageInput.style.height = messageInput.scrollHeight + 'px';
});
// Initialize WebSocket connection
connectWebSocket();
// Focus on input
messageInput.focus();
// Update session title
function updateSessionTitle() {
const titleElement = document.getElementById('session-title');
// Just show "Goose Chat" - no need to show session ID
titleElement.textContent = 'Goose Chat';
}
// Update title on load
updateSessionTitle();

View File

@@ -0,0 +1,480 @@
:root {
/* Dark theme colors (matching the dark.png) */
--bg-primary: #000000;
--bg-secondary: #0a0a0a;
--bg-tertiary: #1a1a1a;
--text-primary: #ffffff;
--text-secondary: #a0a0a0;
--text-muted: #666666;
--border-color: #333333;
--border-subtle: #1a1a1a;
--accent-color: #ffffff;
--accent-hover: #f0f0f0;
--user-bg: #1a1a1a;
--assistant-bg: #0a0a0a;
--input-bg: #0a0a0a;
--input-border: #333333;
--button-bg: #ffffff;
--button-text: #000000;
--button-hover: #e0e0e0;
--pill-bg: transparent;
--pill-border: #333333;
--pill-hover: #1a1a1a;
--tool-bg: #0f0f0f;
--code-bg: #0f0f0f;
}
/* Light theme */
@media (prefers-color-scheme: light) {
:root {
--bg-primary: #ffffff;
--bg-secondary: #fafafa;
--bg-tertiary: #f5f5f5;
--text-primary: #000000;
--text-secondary: #666666;
--text-muted: #999999;
--border-color: #e1e5e9;
--border-subtle: #f0f0f0;
--accent-color: #000000;
--accent-hover: #333333;
--user-bg: #f0f0f0;
--assistant-bg: #fafafa;
--input-bg: #ffffff;
--input-border: #e1e5e9;
--button-bg: #000000;
--button-text: #ffffff;
--button-hover: #333333;
--pill-bg: #f5f5f5;
--pill-border: #e1e5e9;
--pill-hover: #e8eaed;
--tool-bg: #f8f9fa;
--code-bg: #f5f5f5;
}
}
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
background-color: var(--bg-primary);
color: var(--text-primary);
line-height: 1.5;
height: 100vh;
overflow: hidden;
font-size: 14px;
}
.container {
display: flex;
flex-direction: column;
height: 100vh;
max-width: 100%;
margin: 0 auto;
}
header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 1rem 1.5rem;
background-color: var(--bg-primary);
border-bottom: 1px solid var(--border-subtle);
}
header h1 {
font-size: 1.25rem;
font-weight: 600;
display: flex;
align-items: center;
gap: 0.75rem;
}
header h1::before {
content: "🪿";
font-size: 1.5rem;
}
.status {
font-size: 0.75rem;
color: var(--text-secondary);
padding: 0.25rem 0.75rem;
border-radius: 1rem;
background-color: var(--bg-secondary);
border: 1px solid var(--border-color);
}
.status.connected {
color: #10b981;
border-color: #10b981;
background-color: rgba(16, 185, 129, 0.1);
}
.status.disconnected {
color: #ef4444;
border-color: #ef4444;
background-color: rgba(239, 68, 68, 0.1);
}
.chat-container {
flex: 1;
display: flex;
flex-direction: column;
overflow: hidden;
}
.messages {
flex: 1;
overflow-y: auto;
padding: 2rem;
display: flex;
flex-direction: column;
gap: 1.5rem;
}
.welcome-message {
text-align: center;
padding: 4rem 2rem;
color: var(--text-secondary);
}
.welcome-message h2 {
font-size: 1.5rem;
margin-bottom: 1rem;
color: var(--text-primary);
font-weight: 600;
}
.welcome-message p {
font-size: 1rem;
margin-bottom: 2rem;
}
/* Suggestion pills like in the design */
.suggestion-pills {
display: flex;
flex-wrap: wrap;
gap: 0.75rem;
justify-content: center;
margin-top: 2rem;
}
.suggestion-pill {
padding: 0.75rem 1.25rem;
background-color: var(--pill-bg);
border: 1px solid var(--pill-border);
border-radius: 2rem;
color: var(--text-primary);
font-size: 0.875rem;
cursor: pointer;
transition: all 0.2s ease;
text-decoration: none;
display: inline-block;
}
.suggestion-pill:hover {
background-color: var(--pill-hover);
border-color: var(--border-color);
}
.message {
max-width: 80%;
padding: 1rem 1.25rem;
border-radius: 1rem;
word-wrap: break-word;
position: relative;
}
.message.user {
align-self: flex-end;
background-color: var(--user-bg);
margin-left: auto;
border: 1px solid var(--border-subtle);
}
.message.assistant {
align-self: flex-start;
background-color: var(--assistant-bg);
border: 1px solid var(--border-subtle);
}
.message-content {
flex: 1;
margin-bottom: 0.5rem;
}
.message .timestamp {
font-size: 0.6875rem;
color: var(--text-muted);
margin-top: 0.5rem;
opacity: 0.7;
}
.message pre {
background-color: var(--code-bg);
padding: 0.75rem;
border-radius: 0.5rem;
overflow-x: auto;
margin: 0.75rem 0;
border: 1px solid var(--border-color);
font-size: 0.8125rem;
}
.message code {
background-color: var(--code-bg);
padding: 0.125rem 0.375rem;
border-radius: 0.25rem;
font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
font-size: 0.8125rem;
border: 1px solid var(--border-color);
}
.input-container {
display: flex;
gap: 0.75rem;
padding: 1.5rem;
background-color: var(--bg-primary);
border-top: 1px solid var(--border-subtle);
}
#message-input {
flex: 1;
padding: 0.875rem 1rem;
border: 1px solid var(--input-border);
border-radius: 0.75rem;
background-color: var(--input-bg);
color: var(--text-primary);
font-family: inherit;
font-size: 0.875rem;
resize: none;
min-height: 2.75rem;
max-height: 8rem;
outline: none;
transition: border-color 0.2s ease;
}
#message-input:focus {
border-color: var(--accent-color);
}
#message-input::placeholder {
color: var(--text-muted);
}
#send-button {
padding: 0.875rem 1.5rem;
background-color: var(--button-bg);
color: var(--button-text);
border: none;
border-radius: 0.75rem;
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: all 0.2s ease;
min-width: 4rem;
}
#send-button:hover {
background-color: var(--button-hover);
transform: translateY(-1px);
}
#send-button:disabled {
opacity: 0.5;
cursor: not-allowed;
transform: none;
}
#send-button.cancel-mode {
background-color: #ef4444;
color: #ffffff;
}
#send-button.cancel-mode:hover {
background-color: #dc2626;
}
/* Scrollbar styling */
.messages::-webkit-scrollbar {
width: 6px;
}
.messages::-webkit-scrollbar-track {
background: transparent;
}
.messages::-webkit-scrollbar-thumb {
background: var(--border-color);
border-radius: 3px;
}
.messages::-webkit-scrollbar-thumb:hover {
background: var(--text-secondary);
}
/* Tool call styling */
.tool-message, .tool-result, .tool-error, .tool-confirmation, .context-warning {
background-color: var(--tool-bg);
border: 1px solid var(--border-color);
border-radius: 0.75rem;
padding: 1rem;
margin: 0.75rem 0;
max-width: 90%;
}
.tool-header, .tool-confirm-header, .context-header {
font-weight: 600;
color: var(--accent-color);
margin-bottom: 0.75rem;
font-size: 0.875rem;
}
.tool-content {
font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
font-size: 0.8125rem;
color: var(--text-secondary);
}
.tool-param {
margin: 0.5rem 0;
}
.tool-param strong {
color: var(--text-primary);
}
.tool-running {
font-size: 0.8125rem;
color: var(--accent-color);
margin-top: 0.75rem;
font-style: italic;
}
.tool-error {
border-color: #ef4444;
background-color: rgba(239, 68, 68, 0.05);
}
.tool-error strong {
color: #ef4444;
}
.tool-result {
background-color: var(--tool-bg);
border-left: 3px solid var(--accent-color);
margin-left: 1.5rem;
border-radius: 0.5rem;
}
.tool-confirmation {
border-color: #f59e0b;
background-color: rgba(245, 158, 11, 0.05);
}
.tool-confirm-note, .context-note {
font-size: 0.75rem;
color: var(--text-muted);
margin-top: 0.75rem;
font-style: italic;
}
.context-warning {
border-color: #f59e0b;
background-color: rgba(245, 158, 11, 0.05);
}
.context-header {
color: #f59e0b;
}
.system-message {
text-align: center;
color: var(--text-secondary);
font-style: italic;
margin: 1rem 0;
font-size: 0.875rem;
}
.cancelled {
color: #ef4444;
}
/* Thinking indicator */
.thinking-message {
display: flex;
align-items: center;
gap: 0.75rem;
color: var(--text-secondary);
font-style: italic;
padding: 1rem 1.25rem;
background-color: var(--bg-secondary);
border-radius: 1rem;
border: 1px solid var(--border-subtle);
max-width: 80%;
font-size: 0.875rem;
}
.thinking-dots {
display: flex;
gap: 0.25rem;
}
.thinking-dots span {
width: 6px;
height: 6px;
border-radius: 50%;
background-color: var(--text-secondary);
animation: thinking-bounce 1.4s infinite ease-in-out both;
}
.thinking-dots span:nth-child(1) {
animation-delay: -0.32s;
}
.thinking-dots span:nth-child(2) {
animation-delay: -0.16s;
}
@keyframes thinking-bounce {
0%, 80%, 100% {
transform: scale(0.6);
opacity: 0.5;
}
40% {
transform: scale(1);
opacity: 1;
}
}
/* Keep the old loading indicator for backwards compatibility */
.loading-message {
display: none;
}
/* Responsive design */
@media (max-width: 768px) {
.messages {
padding: 1rem;
gap: 1rem;
}
.message {
max-width: 90%;
padding: 0.875rem 1rem;
}
.input-container {
padding: 1rem;
}
header {
padding: 0.75rem 1rem;
}
.welcome-message {
padding: 2rem 1rem;
}
}

View File

@@ -3,7 +3,7 @@ use std::ptr;
use std::sync::Arc;
use futures::StreamExt;
use goose::agents::Agent;
use goose::agents::{Agent, AgentEvent};
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::databricks::DatabricksProvider;
@@ -256,13 +256,16 @@ pub unsafe extern "C" fn goose_agent_send_message(
while let Some(message_result) = stream.next().await {
match message_result {
Ok(message) => {
Ok(AgentEvent::Message(message)) => {
// Get text or serialize to JSON
// Note: Message doesn't have as_text method, we'll serialize to JSON
if let Ok(json) = serde_json::to_string(&message) {
full_response.push_str(&json);
}
}
Ok(AgentEvent::McpNotification(_)) => {
// TODO: Handle MCP notifications.
}
Err(e) => {
full_response.push_str(&format!("\nError in message stream: {}", e));
}

View File

@@ -138,6 +138,11 @@ impl DatabricksProvider {
"reduce the length",
"token count",
"exceeds",
"exceed context limit",
"input length",
"max_tokens",
"decrease input length",
"context limit",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded(payload_str));

View File

@@ -6,7 +6,7 @@ use serde_json::{json, Value};
use std::{
collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex,
};
use tokio::process::Command;
use tokio::{process::Command, sync::mpsc};
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
@@ -14,7 +14,7 @@ use std::os::unix::fs::PermissionsExt;
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource,
tool::{Tool, ToolAnnotations},
Content,
@@ -1155,6 +1155,7 @@ impl Router for ComputerControllerRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -13,13 +13,17 @@ use std::{
path::{Path, PathBuf},
pin::Pin,
};
use tokio::process::Command;
use tokio::{
io::{AsyncBufReadExt, BufReader},
process::Command,
sync::mpsc,
};
use url::Url;
use include_dir::{include_dir, Dir};
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities},
resource::Resource,
tool::Tool,
Content,
@@ -404,10 +408,20 @@ impl DeveloperRouter {
if local_ignore_path.is_file() {
let _ = builder.add(local_ignore_path);
has_ignore_file = true;
} else {
// If no .gooseignore exists, check for .gitignore as fallback
let gitignore_path = cwd.join(".gitignore");
if gitignore_path.is_file() {
tracing::debug!(
"No .gooseignore found, using .gitignore as fallback for ignore patterns"
);
let _ = builder.add(gitignore_path);
has_ignore_file = true;
}
}
// Only use default patterns if no .gooseignore files were found
// If the file is empty, we will not ignore any file
// AND no .gitignore was used as fallback
if !has_ignore_file {
// Add some sensible defaults
let _ = builder.add_line(None, "**/.env");
@@ -456,7 +470,11 @@ impl DeveloperRouter {
}
// Shell command execution with platform-specific handling
async fn bash(&self, params: Value) -> Result<Vec<Content>, ToolError> {
async fn bash(
&self,
params: Value,
notifier: mpsc::Sender<JsonRpcMessage>,
) -> Result<Vec<Content>, ToolError> {
let command =
params
.get("command")
@@ -488,27 +506,102 @@ impl DeveloperRouter {
// Get platform-specific shell configuration
let shell_config = get_shell_config();
let cmd_with_redirect = format_command_for_platform(command);
let cmd_str = format_command_for_platform(command);
// Execute the command using platform-specific shell
let child = Command::new(&shell_config.executable)
let mut child = Command::new(&shell_config.executable)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.stdin(Stdio::null())
.kill_on_drop(true)
.arg(&shell_config.arg)
.arg(cmd_with_redirect)
.arg(cmd_str)
.spawn()
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let mut stdout_reader = BufReader::new(stdout);
let mut stderr_reader = BufReader::new(stderr);
let output_task = tokio::spawn(async move {
let mut combined_output = String::new();
let mut stdout_buf = Vec::new();
let mut stderr_buf = Vec::new();
let mut stdout_done = false;
let mut stderr_done = false;
loop {
tokio::select! {
n = stdout_reader.read_until(b'\n', &mut stdout_buf), if !stdout_done => {
if n? == 0 {
stdout_done = true;
} else {
let line = String::from_utf8_lossy(&stdout_buf);
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "notifications/message".to_string(),
params: Some(json!({
"data": {
"type": "shell",
"stream": "stdout",
"output": line.to_string(),
}
})),
})).ok();
combined_output.push_str(&line);
stdout_buf.clear();
}
}
n = stderr_reader.read_until(b'\n', &mut stderr_buf), if !stderr_done => {
if n? == 0 {
stderr_done = true;
} else {
let line = String::from_utf8_lossy(&stderr_buf);
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "notifications/message".to_string(),
params: Some(json!({
"data": {
"type": "shell",
"stream": "stderr",
"output": line.to_string(),
}
})),
})).ok();
combined_output.push_str(&line);
stderr_buf.clear();
}
}
else => break,
}
if stdout_done && stderr_done {
break;
}
}
Ok::<_, std::io::Error>(combined_output)
});
// Wait for the command to complete and get output
let output = child
.wait_with_output()
child
.wait()
.await
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
let stdout_str = String::from_utf8_lossy(&output.stdout);
let output_str = stdout_str;
let output_str = match output_task.await {
Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
Err(e) => return Err(ToolError::ExecutionError(e.to_string())),
};
// Check the character count of the output
const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB
@@ -1048,12 +1141,13 @@ impl Router for DeveloperRouter {
&self,
tool_name: &str,
arguments: Value,
notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();
Box::pin(async move {
match tool_name.as_str() {
"shell" => this.bash(arguments).await,
"shell" => this.bash(arguments, notifier).await,
"text_editor" => this.text_editor(arguments).await,
"list_windows" => this.list_windows(arguments).await,
"screen_capture" => this.screen_capture(arguments).await,
@@ -1195,6 +1289,10 @@ mod tests {
.await
}
fn dummy_sender() -> mpsc::Sender<JsonRpcMessage> {
mpsc::channel(1).0
}
#[tokio::test]
#[serial]
async fn test_shell_missing_parameters() {
@@ -1202,7 +1300,7 @@ mod tests {
std::env::set_current_dir(&temp_dir).unwrap();
let router = get_router().await;
let result = router.call_tool("shell", json!({})).await;
let result = router.call_tool("shell", json!({}), dummy_sender()).await;
assert!(result.is_err());
let err = result.err().unwrap();
@@ -1263,6 +1361,7 @@ mod tests {
"command": "view",
"path": large_file_str
}),
dummy_sender(),
)
.await;
@@ -1288,6 +1387,7 @@ mod tests {
"command": "view",
"path": many_chars_str
}),
dummy_sender(),
)
.await;
@@ -1319,6 +1419,7 @@ mod tests {
"path": file_path_str,
"file_text": "Hello, world!"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1331,6 +1432,7 @@ mod tests {
"command": "view",
"path": file_path_str
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1369,6 +1471,7 @@ mod tests {
"path": file_path_str,
"file_text": "Hello, world!"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1383,6 +1486,7 @@ mod tests {
"old_str": "world",
"new_str": "Rust"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1407,6 +1511,7 @@ mod tests {
"command": "view",
"path": file_path_str
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1444,6 +1549,7 @@ mod tests {
"path": file_path_str,
"file_text": "First line"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1458,6 +1564,7 @@ mod tests {
"old_str": "First line",
"new_str": "Second line"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1470,6 +1577,7 @@ mod tests {
"command": "undo_edit",
"path": file_path_str
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1485,6 +1593,7 @@ mod tests {
"command": "view",
"path": file_path_str
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1583,6 +1692,7 @@ mod tests {
"path": temp_dir.path().join("secret.txt").to_str().unwrap(),
"file_text": "test content"
}),
dummy_sender(),
)
.await;
@@ -1601,6 +1711,7 @@ mod tests {
"path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
"file_text": "test content"
}),
dummy_sender(),
)
.await;
@@ -1642,6 +1753,7 @@ mod tests {
json!({
"command": format!("cat {}", secret_file_path.to_str().unwrap())
}),
dummy_sender(),
)
.await;
@@ -1658,6 +1770,202 @@ mod tests {
json!({
"command": format!("cat {}", allowed_file_path.to_str().unwrap())
}),
dummy_sender(),
)
.await;
assert!(result.is_ok(), "Should be able to cat non-ignored file");
temp_dir.close().unwrap();
}
#[tokio::test]
#[serial]
async fn test_gitignore_fallback_when_no_gooseignore() {
let temp_dir = tempfile::tempdir().unwrap();
std::env::set_current_dir(&temp_dir).unwrap();
// Create a .gitignore file but no .gooseignore
std::fs::write(temp_dir.path().join(".gitignore"), "*.log\n*.tmp\n.env").unwrap();
let router = DeveloperRouter::new();
// Test that gitignore patterns are respected
assert!(
router.is_ignored(Path::new("test.log")),
"*.log pattern from .gitignore should be ignored"
);
assert!(
router.is_ignored(Path::new("build.tmp")),
"*.tmp pattern from .gitignore should be ignored"
);
assert!(
router.is_ignored(Path::new(".env")),
".env pattern from .gitignore should be ignored"
);
assert!(
!router.is_ignored(Path::new("test.txt")),
"test.txt should not be ignored"
);
temp_dir.close().unwrap();
}
#[tokio::test]
#[serial]
async fn test_gooseignore_takes_precedence_over_gitignore() {
let temp_dir = tempfile::tempdir().unwrap();
std::env::set_current_dir(&temp_dir).unwrap();
// Create both .gooseignore and .gitignore files with different patterns
std::fs::write(temp_dir.path().join(".gooseignore"), "*.secret").unwrap();
std::fs::write(temp_dir.path().join(".gitignore"), "*.log\ntarget/").unwrap();
let router = DeveloperRouter::new();
// .gooseignore patterns should be used
assert!(
router.is_ignored(Path::new("test.secret")),
"*.secret pattern from .gooseignore should be ignored"
);
// .gitignore patterns should NOT be used when .gooseignore exists
assert!(
!router.is_ignored(Path::new("test.log")),
"*.log pattern from .gitignore should NOT be ignored when .gooseignore exists"
);
assert!(
!router.is_ignored(Path::new("build.tmp")),
"*.tmp pattern from .gitignore should NOT be ignored when .gooseignore exists"
);
temp_dir.close().unwrap();
}
#[tokio::test]
#[serial]
async fn test_default_patterns_when_no_ignore_files() {
let temp_dir = tempfile::tempdir().unwrap();
std::env::set_current_dir(&temp_dir).unwrap();
// Don't create any ignore files
let router = DeveloperRouter::new();
// Default patterns should be used
assert!(
router.is_ignored(Path::new(".env")),
".env should be ignored by default patterns"
);
assert!(
router.is_ignored(Path::new(".env.local")),
".env.local should be ignored by default patterns"
);
assert!(
router.is_ignored(Path::new("secrets.txt")),
"secrets.txt should be ignored by default patterns"
);
assert!(
!router.is_ignored(Path::new("normal.txt")),
"normal.txt should not be ignored"
);
temp_dir.close().unwrap();
}
#[tokio::test]
#[serial]
async fn test_text_editor_respects_gitignore_fallback() {
let temp_dir = tempfile::tempdir().unwrap();
std::env::set_current_dir(&temp_dir).unwrap();
// Create a .gitignore file but no .gooseignore
std::fs::write(temp_dir.path().join(".gitignore"), "*.log").unwrap();
let router = DeveloperRouter::new();
// Try to write to a file ignored by .gitignore
let result = router
.call_tool(
"text_editor",
json!({
"command": "write",
"path": temp_dir.path().join("test.log").to_str().unwrap(),
"file_text": "test content"
}),
dummy_sender(),
)
.await;
assert!(
result.is_err(),
"Should not be able to write to file ignored by .gitignore fallback"
);
assert!(matches!(result.unwrap_err(), ToolError::ExecutionError(_)));
// Try to write to a non-ignored file
let result = router
.call_tool(
"text_editor",
json!({
"command": "write",
"path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
"file_text": "test content"
}),
dummy_sender(),
)
.await;
assert!(
result.is_ok(),
"Should be able to write to non-ignored file"
);
temp_dir.close().unwrap();
}
#[tokio::test]
#[serial]
async fn test_bash_respects_gitignore_fallback() {
let temp_dir = tempfile::tempdir().unwrap();
std::env::set_current_dir(&temp_dir).unwrap();
// Create a .gitignore file but no .gooseignore
std::fs::write(temp_dir.path().join(".gitignore"), "*.log").unwrap();
let router = DeveloperRouter::new();
// Create a file that would be ignored by .gitignore
let log_file_path = temp_dir.path().join("test.log");
std::fs::write(&log_file_path, "log content").unwrap();
// Try to cat the ignored file
let result = router
.call_tool(
"shell",
json!({
"command": format!("cat {}", log_file_path.to_str().unwrap())
}),
dummy_sender(),
)
.await;
assert!(
result.is_err(),
"Should not be able to cat file ignored by .gitignore fallback"
);
assert!(matches!(result.unwrap_err(), ToolError::ExecutionError(_)));
// Try to cat a non-ignored file
let allowed_file_path = temp_dir.path().join("allowed.txt");
std::fs::write(&allowed_file_path, "allowed content").unwrap();
let result = router
.call_tool(
"shell",
json!({
"command": format!("cat {}", allowed_file_path.to_str().unwrap())
}),
dummy_sender(),
)
.await;

View File

@@ -4,7 +4,6 @@ use std::env;
pub struct ShellConfig {
pub executable: String,
pub arg: String,
pub redirect_syntax: String,
}
impl Default for ShellConfig {
@@ -14,13 +13,11 @@ impl Default for ShellConfig {
Self {
executable: "powershell.exe".to_string(),
arg: "-NoProfile -NonInteractive -Command".to_string(),
redirect_syntax: "2>&1".to_string(),
}
} else {
Self {
executable: "bash".to_string(),
arg: "-c".to_string(),
redirect_syntax: "2>&1".to_string(),
}
}
}
@@ -31,13 +28,12 @@ pub fn get_shell_config() -> ShellConfig {
}
pub fn format_command_for_platform(command: &str) -> String {
let config = get_shell_config();
if cfg!(windows) {
// For PowerShell, wrap the command in braces to handle special characters
format!("{{ {} }} {}", command, config.redirect_syntax)
format!("{{ {} }}", command)
} else {
// For other shells, no braces needed
format!("{} {}", command, config.redirect_syntax)
command.to_string()
}
}

View File

@@ -7,6 +7,7 @@ use base64::Engine;
use chrono::NaiveDate;
use indoc::indoc;
use lazy_static::lazy_static;
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::tool::ToolAnnotations;
use oauth_pkce::PkceOAuth2Client;
use regex::Regex;
@@ -14,6 +15,7 @@ use serde_json::{json, Value};
use std::io::Cursor;
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
use storage::CredentialsManager;
use tokio::sync::mpsc;
use mcp_core::content::Content;
use mcp_core::{
@@ -3281,6 +3283,7 @@ impl Router for GoogleDriveRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -5,7 +5,7 @@ use mcp_core::{
content::Content,
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource,
role::Role,
tool::Tool,
@@ -16,7 +16,7 @@ use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{sleep, Duration};
use tracing::error;
@@ -158,6 +158,7 @@ impl Router for JetBrainsRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -10,11 +10,12 @@ use std::{
path::PathBuf,
pin::Pin,
};
use tokio::sync::mpsc;
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource,
tool::{Tool, ToolAnnotations, ToolCall},
Content,
@@ -520,6 +521,7 @@ impl Router for MemoryRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -3,11 +3,12 @@ use include_dir::{include_dir, Dir};
use indoc::formatdoc;
use serde_json::{json, Value};
use std::{future::Future, pin::Pin};
use tokio::sync::mpsc;
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource,
role::Role,
tool::{Tool, ToolAnnotations},
@@ -130,6 +131,7 @@ impl Router for TutorialRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -45,6 +45,8 @@ use utoipa::OpenApi;
super::routes::schedule::run_now_handler,
super::routes::schedule::pause_schedule,
super::routes::schedule::unpause_schedule,
super::routes::schedule::kill_running_job,
super::routes::schedule::inspect_running_job,
super::routes::schedule::sessions_handler
),
components(schemas(
@@ -95,6 +97,8 @@ use utoipa::OpenApi;
SessionMetadata,
super::routes::schedule::CreateScheduleRequest,
super::routes::schedule::UpdateScheduleRequest,
super::routes::schedule::KillJobResponse,
super::routes::schedule::InspectJobResponse,
goose::scheduler::ScheduledJob,
super::routes::schedule::RunNowResponse,
super::routes::schedule::ListSchedulesResponse,

View File

@@ -20,7 +20,7 @@
"gcp_vertex_ai": {
"name": "GCP Vertex AI",
"description": "Use Vertex AI platform models",
"models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05", "gemini-2.5-pro-exp-03-25"],
"models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05", "gemini-2.5-pro-exp-03-25", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-05-06"],
"required_keys": ["GCP_PROJECT_ID", "GCP_LOCATION"]
},
"google": {

View File

@@ -10,7 +10,7 @@ use axum::{
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::{
agents::SessionConfig,
agents::{AgentEvent, SessionConfig},
message::{Message, MessageContent},
permission::permission_confirmation::PrincipalType,
};
@@ -18,7 +18,7 @@ use goose::{
permission::{Permission, PermissionConfirmation},
session,
};
use mcp_core::{role::Role, Content, ToolResult};
use mcp_core::{protocol::JsonRpcMessage, role::Role, Content, ToolResult};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
@@ -79,9 +79,19 @@ impl IntoResponse for SseResponse {
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum MessageEvent {
Message { message: Message },
Error { error: String },
Finish { reason: String },
Message {
message: Message,
},
Error {
error: String,
},
Finish {
reason: String,
},
Notification {
request_id: String,
message: JsonRpcMessage,
},
}
async fn stream_event(
@@ -200,7 +210,7 @@ async fn handler(
tokio::select! {
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(message))) => {
Ok(Some(Ok(AgentEvent::Message(message)))) => {
all_messages.push(message.clone());
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
@@ -223,6 +233,20 @@ async fn handler(
}
});
}
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
if let Err(e) = stream_event(MessageEvent::Notification{
request_id: request_id.clone(),
message: n,
}, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
let _ = stream_event(
@@ -317,7 +341,7 @@ async fn ask_handler(
while let Some(response) = stream.next().await {
match response {
Ok(message) => {
Ok(AgentEvent::Message(message)) => {
if message.role == Role::Assistant {
for content in &message.content {
if let MessageContent::Text(text) = content {
@@ -328,6 +352,10 @@ async fn ask_handler(
}
}
}
Ok(AgentEvent::McpNotification(n)) => {
// Handle notifications if needed
tracing::info!("Received notification: {:?}", n);
}
Err(e) => {
tracing::error!("Error processing as_ai message: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);

View File

@@ -31,6 +31,21 @@ pub struct ListSchedulesResponse {
jobs: Vec<ScheduledJob>,
}
// Response for the kill endpoint
#[derive(Serialize, utoipa::ToSchema)]
pub struct KillJobResponse {
message: String,
}
// Response for the inspect endpoint
#[derive(Serialize, utoipa::ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct InspectJobResponse {
session_id: Option<String>,
process_start_time: Option<String>,
running_duration_seconds: Option<i64>,
}
// Response for the run_now endpoint
#[derive(Serialize, utoipa::ToSchema)]
pub struct RunNowResponse {
@@ -100,6 +115,8 @@ async fn create_schedule(
last_run: None,
currently_running: false,
paused: false,
current_session_id: None,
process_start_time: None,
};
scheduler
.add_scheduled_job(job.clone())
@@ -199,6 +216,17 @@ async fn run_now_handler(
eprintln!("Error running schedule '{}' now: {:?}", id, e);
match e {
goose::scheduler::SchedulerError::JobNotFound(_) => Err(StatusCode::NOT_FOUND),
goose::scheduler::SchedulerError::AnyhowError(ref err) => {
// Check if this is a cancellation error
if err.to_string().contains("was successfully cancelled") {
// Return a special session_id to indicate cancellation
Ok(Json(RunNowResponse {
session_id: "CANCELLED".to_string(),
}))
} else {
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
_ => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
@@ -389,6 +417,92 @@ async fn update_schedule(
Ok(Json(updated_job))
}
#[utoipa::path(
post,
path = "/schedule/{id}/kill",
responses(
(status = 200, description = "Running job killed successfully"),
),
tag = "schedule"
)]
#[axum::debug_handler]
pub async fn kill_running_job(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Path(id): Path<String>,
) -> Result<Json<KillJobResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let scheduler = state
.scheduler()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
scheduler.kill_running_job(&id).await.map_err(|e| {
eprintln!("Error killing running job '{}': {:?}", id, e);
match e {
goose::scheduler::SchedulerError::JobNotFound(_) => StatusCode::NOT_FOUND,
goose::scheduler::SchedulerError::AnyhowError(_) => StatusCode::BAD_REQUEST,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
})?;
Ok(Json(KillJobResponse {
message: format!("Successfully killed running job '{}'", id),
}))
}
#[utoipa::path(
get,
path = "/schedule/{id}/inspect",
params(
("id" = String, Path, description = "ID of the schedule to inspect")
),
responses(
(status = 200, description = "Running job information", body = InspectJobResponse),
(status = 404, description = "Scheduled job not found"),
(status = 500, description = "Internal server error")
),
tag = "schedule"
)]
#[axum::debug_handler]
pub async fn inspect_running_job(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Path(id): Path<String>,
) -> Result<Json<InspectJobResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let scheduler = state
.scheduler()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
match scheduler.get_running_job_info(&id).await {
Ok(info) => {
if let Some((session_id, start_time)) = info {
let duration = chrono::Utc::now().signed_duration_since(start_time);
Ok(Json(InspectJobResponse {
session_id: Some(session_id),
process_start_time: Some(start_time.to_rfc3339()),
running_duration_seconds: Some(duration.num_seconds()),
}))
} else {
Ok(Json(InspectJobResponse {
session_id: None,
process_start_time: None,
running_duration_seconds: None,
}))
}
}
Err(e) => {
eprintln!("Error inspecting running job '{}': {:?}", id, e);
match e {
goose::scheduler::SchedulerError::JobNotFound(_) => Err(StatusCode::NOT_FOUND),
_ => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
}
}
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/schedule/create", post(create_schedule))
@@ -398,6 +512,8 @@ pub fn routes(state: Arc<AppState>) -> Router {
.route("/schedule/{id}/run_now", post(run_now_handler)) // Corrected
.route("/schedule/{id}/pause", post(pause_schedule))
.route("/schedule/{id}/unpause", post(unpause_schedule))
.route("/schedule/{id}/kill", post(kill_running_job))
.route("/schedule/{id}/inspect", get(inspect_running_job))
.route("/schedule/{id}/sessions", get(sessions_handler)) // Corrected
.with_state(state)
}

File diff suppressed because it is too large Load Diff

View File

@@ -71,10 +71,10 @@ aws-sdk-bedrockruntime = "1.74.0"
# For GCP Vertex AI provider auth
jsonwebtoken = "9.3.1"
# Added blake3 hashing library as a dependency
blake3 = "1.5"
fs2 = "0.4.3"
futures-util = "0.3.31"
tokio-stream = "0.1.17"
# Vector database for tool selection
lancedb = "0.13"

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use dotenv::dotenv;
use futures::StreamExt;
use goose::agents::{Agent, ExtensionConfig};
use goose::agents::{Agent, AgentEvent, ExtensionConfig};
use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT};
use goose::message::Message;
use goose::providers::databricks::DatabricksProvider;
@@ -20,10 +20,11 @@ async fn main() {
let config = ExtensionConfig::stdio(
"developer",
"./target/debug/developer",
"./target/debug/goose",
DEFAULT_EXTENSION_DESCRIPTION,
DEFAULT_EXTENSION_TIMEOUT,
);
)
.with_args(vec!["mcp", "developer"]);
agent.add_extension(config).await.unwrap();
println!("Extensions:");
@@ -35,11 +36,8 @@ async fn main() {
.with_text("can you summarize the readme.md in this dir using just a haiku?")];
let mut stream = agent.reply(&messages, None).await.unwrap();
while let Some(message) = stream.next().await {
println!(
"{}",
serde_json::to_string_pretty(&message.unwrap()).unwrap()
);
while let Some(Ok(AgentEvent::Message(message))) = stream.next().await {
println!("{}", serde_json::to_string_pretty(&message).unwrap());
println!("\n");
}
}

View File

@@ -1,9 +1,14 @@
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use futures::stream::BoxStream;
use futures::TryStreamExt;
use futures::{FutureExt, Stream, TryStreamExt};
use futures_util::stream;
use futures_util::stream::StreamExt;
use mcp_core::protocol::JsonRpcMessage;
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
use crate::message::Message;
@@ -39,7 +44,7 @@ use mcp_core::{
use super::platform_tools;
use super::router_tools;
use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
/// The main goose Agent
pub struct Agent {
@@ -56,6 +61,12 @@ pub struct Agent {
pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>,
}
#[derive(Clone, Debug)]
pub enum AgentEvent {
Message(Message),
McpNotification((String, JsonRpcMessage)),
}
impl Agent {
pub fn new() -> Self {
// Create channels with buffer size 32 (adjust if needed)
@@ -100,6 +111,40 @@ impl Default for Agent {
}
}
pub enum ToolStreamItem<T> {
Message(JsonRpcMessage),
Result(T),
}
pub type ToolStream = Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<Vec<Content>>>> + Send>>;
// tool_stream combines a stream of JsonRpcMessages with a future representing the
// final result of the tool call. MCP notifications are not request-scoped, but
// this lets us capture all notifications emitted during the tool call for
// simpler consumption
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
where
S: Stream<Item = JsonRpcMessage> + Send + Unpin + 'static,
F: Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
{
Box::pin(async_stream::stream! {
tokio::pin!(done);
let mut rx = rx;
loop {
tokio::select! {
Some(msg) = rx.next() => {
yield ToolStreamItem::Message(msg);
}
r = &mut done => {
yield ToolStreamItem::Result(r);
break;
}
}
}
})
}
impl Agent {
/// Get a reference count clone to the provider
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
@@ -143,7 +188,7 @@ impl Agent {
&self,
tool_call: mcp_core::tool::ToolCall,
request_id: String,
) -> (String, Result<Vec<Content>, ToolError>) {
) -> (String, Result<ToolCallResult, ToolError>) {
// Check if this tool call should be allowed based on repetition monitoring
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
@@ -171,52 +216,65 @@ impl Agent {
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
return self
let (request_id, result) = self
.manage_extensions(action, extension_name, request_id)
.await;
return (request_id, Ok(ToolCallResult::from(result)));
}
let extension_manager = self.extension_manager.lock().await;
let result = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
let result: ToolCallResult = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
// Check if the tool is read_resource and handle it separately
ToolCallResult::from(
extension_manager
.read_resource(tool_call.arguments.clone())
.await
.await,
)
} else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME {
ToolCallResult::from(
extension_manager
.list_resources(tool_call.arguments.clone())
.await
.await,
)
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
extension_manager.search_available_extensions().await
ToolCallResult::from(extension_manager.search_available_extensions().await)
} else if self.is_frontend_tool(&tool_call.name).await {
// For frontend tools, return an error indicating we need frontend execution
Err(ToolError::ExecutionError(
ToolCallResult::from(Err(ToolError::ExecutionError(
"Frontend tool execution required".to_string(),
))
)))
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
let selector = self.router_tool_selector.lock().await.clone();
if let Some(selector) = selector {
ToolCallResult::from(if let Some(selector) = selector {
selector.select_tools(tool_call.arguments.clone()).await
} else {
Err(ToolError::ExecutionError(
"Encountered vector search error.".to_string(),
))
}
})
} else {
extension_manager
// Clone the result to ensure no references to extension_manager are returned
let result = extension_manager
.dispatch_tool_call(tool_call.clone())
.await
.await;
match result {
Ok(call_result) => call_result,
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
}
};
debug!(
"input" = serde_json::to_string(&tool_call).unwrap(),
"output" = serde_json::to_string(&result).unwrap(),
);
// Process the response to handle large text content
let processed_result = super::large_response_handler::process_tool_response(result);
(request_id, processed_result)
(
request_id,
Ok(ToolCallResult {
notification_stream: result.notification_stream,
result: Box::new(
result
.result
.map(super::large_response_handler::process_tool_response),
),
}),
)
}
pub(super) async fn manage_extensions(
@@ -466,7 +524,7 @@ impl Agent {
&self,
messages: &[Message],
session: Option<SessionConfig>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
) -> anyhow::Result<BoxStream<'_, anyhow::Result<AgentEvent>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
@@ -532,9 +590,8 @@ impl Agent {
}
}
}
// Yield the assistant's response with frontend tool requests filtered out
yield filtered_response.clone();
yield AgentEvent::Message(filtered_response.clone());
tokio::task::yield_now().await;
@@ -556,7 +613,7 @@ impl Agent {
// execution is yeield back to this reply loop, and is of the same Message
// type, so we can yield that back up to be handled
while let Some(msg) = frontend_tool_stream.try_next().await? {
yield msg;
yield AgentEvent::Message(msg);
}
// Clone goose_mode once before the match to avoid move issues
@@ -584,13 +641,23 @@ impl Agent {
self.provider().await?).await;
// Handle pre-approved and read-only tools in parallel
let mut tool_futures: Vec<ToolFuture> = Vec::new();
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
// Skip the confirmation for approved tools
for request in &permission_check_result.approved {
if let Ok(tool_call) = request.tool_call.clone() {
let tool_future = self.dispatch_tool_call(tool_call, request.id.clone());
tool_futures.push(Box::pin(tool_future));
let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
tool_futures.push((req_id, match tool_result {
Ok(result) => tool_stream(
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
result.result,
),
Err(e) => tool_stream(
Box::new(stream::empty()),
futures::future::ready(Err(e)),
),
}));
}
}
@@ -618,7 +685,7 @@ impl Agent {
// type, so we can yield the Message back up to be handled and grab any
// confirmations or denials
while let Some(msg) = tool_approval_stream.try_next().await? {
yield msg;
yield AgentEvent::Message(msg);
}
tool_futures = {
@@ -628,16 +695,30 @@ impl Agent {
futures_lock.drain(..).collect::<Vec<_>>()
};
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
let with_id = tool_futures
.into_iter()
.map(|(request_id, stream)| {
stream.map(move |item| (request_id.clone(), item))
})
.collect::<Vec<_>>();
let mut combined = stream::select_all(with_id);
let mut all_install_successful = true;
for (request_id, output) in results.into_iter() {
while let Some((request_id, item)) = combined.next().await {
match item {
ToolStreamItem::Result(output) => {
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
all_install_successful = false;
}
let mut response = message_tool_response.lock().await;
*response = response.clone().with_tool_response(request_id, output);
},
ToolStreamItem::Message(msg) => {
yield AgentEvent::McpNotification((request_id, msg))
}
}
}
// Update system prompt and tools if installations were successful
@@ -647,7 +728,7 @@ impl Agent {
}
let final_message_tool_resp = message_tool_response.lock().await.clone();
yield final_message_tool_resp.clone();
yield AgentEvent::Message(final_message_tool_resp.clone());
messages.push(response);
messages.push(final_message_tool_resp);
@@ -656,15 +737,15 @@ impl Agent {
// At this point, the last message should be a user message
// because call to provider led to context length exceeded error
// Immediately yield a special message and break
yield Message::assistant().with_context_length_exceeded(
yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
"The context length of the model has been exceeded. Please start a new session and try again.",
);
));
break;
},
Err(e) => {
// Create an error message & terminate the stream
error!("Error: {}", e);
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
break;
}
}

View File

@@ -1,8 +1,7 @@
use anyhow::Result;
use chrono::{DateTime, TimeZone, Utc};
use futures::future;
use futures::stream::{FuturesUnordered, StreamExt};
use mcp_client::McpService;
use futures::{future, FutureExt};
use mcp_core::protocol::GetPromptResult;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
@@ -10,15 +9,22 @@ use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task;
use tracing::{debug, error, warn};
use tokio_stream::wrappers::ReceiverStream;
use tracing::{error, warn};
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
use super::tool_execution::ToolCallResult;
use crate::agents::extension::Envs;
use crate::config::{Config, ExtensionConfigManager};
use crate::prompt_template;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
<<<<<<< HEAD
use mcp_client::transport::{PendingRequests, SseTransport, StdioTransport, Transport};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
=======
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError};
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
use serde_json::Value;
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
@@ -115,7 +121,8 @@ impl ExtensionManager {
/// Add a new MCP extension based on the provided client type
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
let sanitized_name = normalize(config.key().to_string());
let config_name = config.key().to_string();
let sanitized_name = normalize(config_name.clone());
/// Helper function to merge environment variables from direct envs and keychain-stored env_keys
async fn merge_environments(
@@ -185,6 +192,7 @@ impl ExtensionManager {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = SseTransport::new(uri, all_envs);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
@@ -194,6 +202,17 @@ impl ExtensionManager {
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
ExtensionConfig::Stdio {
cmd,
@@ -206,6 +225,7 @@ impl ExtensionManager {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
@@ -215,6 +235,17 @@ impl ExtensionManager {
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
ExtensionConfig::Builtin {
name,
@@ -233,6 +264,7 @@ impl ExtensionManager {
HashMap::new(),
);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
@@ -242,6 +274,17 @@ impl ExtensionManager {
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
_ => unreachable!(),
};
@@ -627,7 +670,7 @@ impl ExtensionManager {
}
}
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> ToolResult<Vec<Content>> {
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> Result<ToolCallResult> {
// Dispatch tool call based on the prefix naming convention
let (client_name, client) = self
.get_client_for_tool(&tool_call.name)
@@ -638,22 +681,26 @@ impl ExtensionManager {
.name
.strip_prefix(client_name)
.and_then(|s| s.strip_prefix("__"))
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?
.to_string();
let arguments = tool_call.arguments.clone();
let client = client.clone();
let notifications_receiver = client.lock().await.subscribe().await;
let fut = async move {
let client_guard = client.lock().await;
let result = client_guard
.call_tool(tool_name, tool_call.clone().arguments)
client_guard
.call_tool(&tool_name, arguments)
.await
.map(|result| result.content)
.map_err(|e| ToolError::ExecutionError(e.to_string()));
.map(|call| call.content)
.map_err(|e| ToolError::ExecutionError(e.to_string()))
};
debug!(
"input" = serde_json::to_string(&tool_call).unwrap(),
"output" = serde_json::to_string(&result).unwrap(),
);
result
Ok(ToolCallResult {
result: Box::new(fut.boxed()),
notification_stream: Some(Box::new(ReceiverStream::new(notifications_receiver))),
})
}
pub async fn list_prompts_from_extension(
@@ -811,10 +858,11 @@ mod tests {
use mcp_client::client::Error;
use mcp_client::client::McpClientTrait;
use mcp_core::protocol::{
CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult,
ListToolsResult, ReadResourceResult,
CallToolResult, GetPromptResult, InitializeResult, JsonRpcMessage, ListPromptsResult,
ListResourcesResult, ListToolsResult, ReadResourceResult,
};
use serde_json::json;
use tokio::sync::mpsc;
struct MockClient {}
@@ -867,6 +915,10 @@ mod tests {
) -> Result<GetPromptResult, Error> {
Err(Error::NotInitialized)
}
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
mpsc::channel(1).1
}
}
#[test]
@@ -988,6 +1040,9 @@ mod tests {
let result = extension_manager
.dispatch_tool_call(invalid_tool_call)
.await
.unwrap()
.result
.await;
assert!(matches!(
result.err().unwrap(),
@@ -1004,6 +1059,11 @@ mod tests {
let result = extension_manager
.dispatch_tool_call(invalid_tool_call)
.await;
assert!(matches!(result.err().unwrap(), ToolError::NotFound(_)));
if let Err(err) = result {
let tool_err = err.downcast_ref::<ToolError>().expect("Expected ToolError");
assert!(matches!(tool_err, ToolError::NotFound(_)));
} else {
panic!("Expected ToolError::NotFound");
}
}
}

View File

@@ -3,8 +3,7 @@ use mcp_core::{Content, ToolError};
use std::fs::File;
use std::io::Write;
// Constant for the size threshold (20K characters)
const LARGE_TEXT_THRESHOLD: usize = 20_000;
const LARGE_TEXT_THRESHOLD: usize = 200_000;
/// Process tool response and handle large text content
pub fn process_tool_response(

View File

@@ -13,7 +13,7 @@ mod tool_router_index_manager;
pub(crate) mod tool_vectordb;
mod types;
pub use agent::Agent;
pub use agent::{Agent, AgentEvent};
pub use extension::ExtensionConfig;
pub use extension_manager::ExtensionManager;
pub use prompt_manager::PromptManager;

View File

@@ -8,7 +8,8 @@ use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::{Provider, ProviderUsage};
use crate::providers::errors::ProviderError;
use crate::providers::toolshim::{
augment_message_with_tool_calls, modify_system_prompt_for_tool_json, OllamaInterpreter,
augment_message_with_tool_calls, convert_tool_messages_to_text,
modify_system_prompt_for_tool_json, OllamaInterpreter,
};
use crate::session;
use mcp_core::tool::Tool;
@@ -110,8 +111,17 @@ impl Agent {
) -> Result<(Message, ProviderUsage), ProviderError> {
let config = provider.get_model_config();
// Convert tool messages to text if toolshim is enabled
let messages_for_provider = if config.toolshim {
convert_tool_messages_to_text(messages)
} else {
messages.to_vec()
};
// Call the provider to get a response
let (mut response, usage) = provider.complete(system_prompt, messages, tools).await?;
let (mut response, usage) = provider
.complete(system_prompt, &messages_for_provider, tools)
.await?;
// Store the model information in the global store
crate::providers::base::set_current_model(&usage.model);

View File

@@ -39,13 +39,13 @@ impl VectorToolSelector {
pub async fn new(provider: Arc<dyn Provider>, table_name: String) -> Result<Self> {
let vector_db = ToolVectorDB::new(Some(table_name)).await?;
let embedding_provider = if env::var("EMBEDDING_MODEL_PROVIDER").is_ok() {
let embedding_provider = if env::var("GOOSE_EMBEDDING_MODEL_PROVIDER").is_ok() {
// If env var is set, create a new provider for embeddings
// Get embedding model and provider from environment variables
let embedding_model = env::var("EMBEDDING_MODEL")
let embedding_model = env::var("GOOSE_EMBEDDING_MODEL")
.unwrap_or_else(|_| "text-embedding-3-small".to_string());
let embedding_provider_name =
env::var("EMBEDDING_MODEL_PROVIDER").unwrap_or_else(|_| "openai".to_string());
env::var("GOOSE_EMBEDDING_MODEL_PROVIDER").unwrap_or_else(|_| "openai".to_string());
// Create the provider using the factory
let model_config = ModelConfig::new(embedding_model);

View File

@@ -1,23 +1,35 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use async_stream::try_stream;
use futures::stream::BoxStream;
use futures::StreamExt;
use futures::stream::{self, BoxStream};
use futures::{Stream, StreamExt};
use mcp_core::protocol::JsonRpcMessage;
use tokio::sync::Mutex;
use crate::config::permission::PermissionLevel;
use crate::config::PermissionManager;
use crate::message::{Message, ToolRequest};
use crate::permission::Permission;
use mcp_core::{Content, ToolError};
use mcp_core::{Content, ToolResult};
// Type alias for ToolFutures - used in the agent loop to join all futures together
pub(crate) type ToolFuture<'a> =
Pin<Box<dyn Future<Output = (String, Result<Vec<Content>, ToolError>)> + Send + 'a>>;
pub(crate) type ToolFuturesVec<'a> = Arc<Mutex<Vec<ToolFuture<'a>>>>;
// ToolCallResult combines the result of a tool call with an optional notification stream that
// can be used to receive notifications from the tool.
pub struct ToolCallResult {
pub result: Box<dyn Future<Output = ToolResult<Vec<Content>>> + Send + Unpin>,
pub notification_stream: Option<Box<dyn Stream<Item = JsonRpcMessage> + Send + Unpin>>,
}
impl From<ToolResult<Vec<Content>>> for ToolCallResult {
fn from(result: ToolResult<Vec<Content>>) -> Self {
Self {
result: Box::new(futures::future::ready(result)),
notification_stream: None,
}
}
}
use super::agent::{tool_stream, ToolStream};
use crate::agents::Agent;
pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \
@@ -37,7 +49,7 @@ impl Agent {
pub(crate) fn handle_approval_tool_requests<'a>(
&'a self,
tool_requests: &'a [ToolRequest],
tool_futures: ToolFuturesVec<'a>,
tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
permission_manager: &'a mut PermissionManager,
message_tool_response: Arc<Mutex<Message>>,
) -> BoxStream<'a, anyhow::Result<Message>> {
@@ -56,9 +68,19 @@ impl Agent {
while let Some((req_id, confirmation)) = rx.recv().await {
if req_id == request.id {
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone()).await;
let mut futures = tool_futures.lock().await;
futures.push(Box::pin(tool_future));
futures.push((req_id, match tool_result {
Ok(result) => tool_stream(
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
result.result,
),
Err(e) => tool_stream(
Box::new(stream::empty()),
futures::future::ready(Err(e)),
),
}));
if confirmation.permission == Permission::AlwaysAllow {
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);

View File

@@ -148,6 +148,12 @@ impl Usage {
use async_trait::async_trait;
/// Trait for LeadWorkerProvider-specific functionality
pub trait LeadWorkerProviderTrait {
/// Get information about the lead and worker models for logging
fn get_model_info(&self) -> (String, String);
}
/// Base trait for AI providers (OpenAI, Anthropic, etc)
#[async_trait]
pub trait Provider: Send + Sync {
@@ -195,6 +201,12 @@ pub trait Provider: Send + Sync {
"This provider does not support embeddings".to_string(),
))
}
/// Check if this provider is a LeadWorkerProvider
/// This is used for logging model information at startup
fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
None
}
}
#[cfg(test)]

View File

@@ -17,6 +17,7 @@ use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
use tokio::time::sleep;
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
@@ -24,6 +25,17 @@ const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
// https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
/// Default timeout for API requests in seconds
const DEFAULT_TIMEOUT_SECS: u64 = 600;
/// Default initial interval for retry (in milliseconds)
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000;
/// Default maximum number of retries
const DEFAULT_MAX_RETRIES: usize = 6;
/// Default retry backoff multiplier
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Default maximum interval for retry (in milliseconds)
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
// Databricks can passthrough to a wide range of models, we only provide the default
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
@@ -36,6 +48,53 @@ pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
pub const DATABRICKS_DOC_URL: &str =
"https://docs.databricks.com/en/generative-ai/external-models/index.html";
/// Retry configuration for handling rate limit errors
#[derive(Debug, Clone)]
struct RetryConfig {
/// Maximum number of retry attempts
max_retries: usize,
/// Initial interval between retries in milliseconds
initial_interval_ms: u64,
/// Multiplier for backoff (exponential)
backoff_multiplier: f64,
/// Maximum interval between retries in milliseconds
max_interval_ms: u64,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: DEFAULT_MAX_RETRIES,
initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
}
}
}
impl RetryConfig {
/// Calculate the delay for a specific retry attempt (with jitter)
fn delay_for_attempt(&self, attempt: usize) -> Duration {
if attempt == 0 {
return Duration::from_millis(0);
}
// Calculate exponential backoff
let exponent = (attempt - 1) as u32;
let base_delay_ms = (self.initial_interval_ms as f64
* self.backoff_multiplier.powi(exponent as i32)) as u64;
// Apply max limit
let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms);
// Add jitter (+/-20% randomness) to avoid thundering herd problem
let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // Between 0.8 and 1.2
let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64;
Duration::from_millis(jittered_delay_ms)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DatabricksAuth {
Token(String),
@@ -70,6 +129,8 @@ pub struct DatabricksProvider {
auth: DatabricksAuth,
model: ModelConfig,
image_format: ImageFormat,
#[serde(skip)]
retry_config: RetryConfig,
}
impl Default for DatabricksProvider {
@@ -100,9 +161,12 @@ impl DatabricksProvider {
let host = host?;
let client = Client::builder()
.timeout(Duration::from_secs(600))
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()?;
// Load optional retry configuration from environment
let retry_config = Self::load_retry_config(config);
// If we find a databricks token we prefer that
if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") {
return Ok(Self {
@@ -111,6 +175,7 @@ impl DatabricksProvider {
auth: DatabricksAuth::token(api_key),
model,
image_format: ImageFormat::OpenAi,
retry_config,
});
}
@@ -121,9 +186,44 @@ impl DatabricksProvider {
host,
model,
image_format: ImageFormat::OpenAi,
retry_config,
})
}
/// Loads retry configuration from environment variables or uses defaults.
fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
let max_retries = config
.get_param("DATABRICKS_MAX_RETRIES")
.ok()
.and_then(|v: String| v.parse::<usize>().ok())
.unwrap_or(DEFAULT_MAX_RETRIES);
let initial_interval_ms = config
.get_param("DATABRICKS_INITIAL_RETRY_INTERVAL_MS")
.ok()
.and_then(|v: String| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);
let backoff_multiplier = config
.get_param("DATABRICKS_BACKOFF_MULTIPLIER")
.ok()
.and_then(|v: String| v.parse::<f64>().ok())
.unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);
let max_interval_ms = config
.get_param("DATABRICKS_MAX_RETRY_INTERVAL_MS")
.ok()
.and_then(|v: String| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);
RetryConfig {
max_retries,
initial_interval_ms,
backoff_multiplier,
max_interval_ms,
}
}
/// Create a new DatabricksProvider with the specified host and token
///
/// # Arguments
@@ -145,6 +245,7 @@ impl DatabricksProvider {
auth: DatabricksAuth::token(api_key),
model,
image_format: ImageFormat::OpenAi,
retry_config: RetryConfig::default(),
})
}
@@ -182,10 +283,25 @@ impl DatabricksProvider {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
// Initialize retry counter
let mut attempts = 0;
let mut last_error = None;
loop {
// Check if we've exceeded max retries
if attempts > 0 && attempts > self.retry_config.max_retries {
let error_msg = format!(
"Exceeded maximum retry attempts ({}) for rate limiting (429)",
self.retry_config.max_retries
);
tracing::error!("{}", error_msg);
return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
}
let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(url)
.post(url.clone())
.header("Authorization", auth_header)
.json(&payload)
.send()
@@ -195,15 +311,24 @@ impl DatabricksProvider {
let payload: Option<Value> = response.json().await.ok();
match status {
StatusCode::OK => payload.ok_or_else(|| ProviderError::RequestFailed("Response body is not valid JSON".to_string())),
StatusCode::OK => {
return payload.ok_or_else(|| {
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
});
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", status, payload)))
return Err(ProviderError::Authentication(format!(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}",
status, payload
)));
}
StatusCode::BAD_REQUEST => {
// Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
let payload_str = serde_json::to_string(&payload).unwrap_or_default().to_lowercase();
let payload_str = serde_json::to_string(&payload)
.unwrap_or_default()
.to_lowercase();
let check_phrases = [
"too long",
"context length",
@@ -211,6 +336,11 @@ impl DatabricksProvider {
"reduce the length",
"token count",
"exceeds",
"exceed context limit",
"input length",
"max_tokens",
"decrease input length",
"context limit",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded(payload_str));
@@ -223,29 +353,78 @@ impl DatabricksProvider {
.get("message")
.and_then(|m| m.as_str())
.or_else(|| {
payload.get("external_model_message")
payload
.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())
})
.unwrap_or("Unknown error").to_string();
.unwrap_or("Unknown error")
.to_string();
}
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)))
return Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}. Message: {}",
status, error_msg
)));
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
attempts += 1;
let error_msg = format!(
"Rate limit exceeded (attempt {}/{}): {:?}",
attempts, self.retry_config.max_retries, payload
);
tracing::warn!("{}. Retrying after backoff...", error_msg);
// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::RateLimitExceeded(error_msg));
// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;
// Continue to the next retry attempt
continue;
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
attempts += 1;
let error_msg = format!(
"Server error (attempt {}/{}): {:?}",
attempts, self.retry_config.max_retries, payload
);
tracing::warn!("{}. Retrying after backoff...", error_msg);
// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::ServerError(error_msg));
// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;
// Continue to the next retry attempt
continue;
}
_ => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
return Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}",
status
)));
}
}
}
}

View File

@@ -10,14 +10,31 @@ use super::{
githubcopilot::GithubCopilotProvider,
google::GoogleProvider,
groq::GroqProvider,
lead_worker::LeadWorkerProvider,
ollama::OllamaProvider,
openai::OpenAiProvider,
openrouter::OpenRouterProvider,
snowflake::SnowflakeProvider,
venice::VeniceProvider,
};
use crate::model::ModelConfig;
use anyhow::Result;
#[cfg(test)]
use super::errors::ProviderError;
#[cfg(test)]
use mcp_core::tool::Tool;
fn default_lead_turns() -> usize {
3
}
fn default_failure_threshold() -> usize {
2
}
fn default_fallback_turns() -> usize {
2
}
pub fn providers() -> Vec<ProviderMetadata> {
vec![
AnthropicProvider::metadata(),
@@ -32,10 +49,67 @@ pub fn providers() -> Vec<ProviderMetadata> {
OpenAiProvider::metadata(),
OpenRouterProvider::metadata(),
VeniceProvider::metadata(),
SnowflakeProvider::metadata(),
]
}
pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
let config = crate::config::Config::global();
// Check for lead model environment variables
if let Ok(lead_model_name) = config.get_param::<String>("GOOSE_LEAD_MODEL") {
tracing::info!("Creating lead/worker provider from environment variables");
return create_lead_worker_from_env(name, &model, &lead_model_name);
}
// Default: create regular provider
create_provider(name, model)
}
/// Create a lead/worker provider from environment variables
fn create_lead_worker_from_env(
default_provider_name: &str,
default_model: &ModelConfig,
lead_model_name: &str,
) -> Result<Arc<dyn Provider>> {
let config = crate::config::Config::global();
// Get lead provider (optional, defaults to main provider)
let lead_provider_name = config
.get_param::<String>("GOOSE_LEAD_PROVIDER")
.unwrap_or_else(|_| default_provider_name.to_string());
// Get configuration parameters with defaults
let lead_turns = config
.get_param::<usize>("GOOSE_LEAD_TURNS")
.unwrap_or(default_lead_turns());
let failure_threshold = config
.get_param::<usize>("GOOSE_LEAD_FAILURE_THRESHOLD")
.unwrap_or(default_failure_threshold());
let fallback_turns = config
.get_param::<usize>("GOOSE_LEAD_FALLBACK_TURNS")
.unwrap_or(default_fallback_turns());
// Create model configs
let lead_model_config = ModelConfig::new(lead_model_name.to_string());
let worker_model_config = default_model.clone();
// Create the providers
let lead_provider = create_provider(&lead_provider_name, lead_model_config)?;
let worker_provider = create_provider(default_provider_name, worker_model_config)?;
// Create the lead/worker provider with configured settings
Ok(Arc::new(LeadWorkerProvider::new_with_settings(
lead_provider,
worker_provider,
lead_turns,
failure_threshold,
fallback_turns,
)))
}
fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
// We use Arc instead of Box to be able to clone for multiple async tasks
match name {
"openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
@@ -49,7 +123,220 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
"gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
"google" => Ok(Arc::new(GoogleProvider::from_env(model)?)),
"venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
"snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
"github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::{Message, MessageContent};
use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
use chrono::Utc;
use mcp_core::{content::TextContent, Role};
use std::env;
#[derive(Clone)]
struct MockTestProvider {
name: String,
model_config: ModelConfig,
}
#[async_trait::async_trait]
impl Provider for MockTestProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"mock_test",
"Mock Test Provider",
"A mock provider for testing",
"mock-model",
vec!["mock-model"],
"",
vec![],
)
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
text: format!(
"Response from {} with model {}",
self.name, self.model_config.model_name
),
annotations: None,
})],
},
ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()),
))
}
}
#[test]
fn test_create_lead_worker_provider() {
// Save current env vars
let saved_lead = env::var("GOOSE_LEAD_MODEL").ok();
let saved_provider = env::var("GOOSE_LEAD_PROVIDER").ok();
let saved_turns = env::var("GOOSE_LEAD_TURNS").ok();
// Test with basic lead model configuration
env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
// This will try to create a lead/worker provider
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
// The creation might succeed or fail depending on API keys, but we can verify the logic path
match result {
Ok(_) => {
// If it succeeds, it means we created a lead/worker provider successfully
// This would happen if API keys are available in the test environment
}
Err(error) => {
// If it fails, it should be due to missing API keys, confirming we tried to create providers
let error_msg = error.to_string();
assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
}
}
// Test with different lead provider
env::set_var("GOOSE_LEAD_PROVIDER", "anthropic");
env::set_var("GOOSE_LEAD_TURNS", "5");
let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
// Similar validation as above - will fail due to missing API keys but confirms the logic
// Restore env vars
match saved_lead {
Some(val) => env::set_var("GOOSE_LEAD_MODEL", val),
None => env::remove_var("GOOSE_LEAD_MODEL"),
}
match saved_provider {
Some(val) => env::set_var("GOOSE_LEAD_PROVIDER", val),
None => env::remove_var("GOOSE_LEAD_PROVIDER"),
}
match saved_turns {
Some(val) => env::set_var("GOOSE_LEAD_TURNS", val),
None => env::remove_var("GOOSE_LEAD_TURNS"),
}
}
#[test]
fn test_lead_model_env_vars_with_defaults() {
// Save current env vars
let saved_vars = [
("GOOSE_LEAD_MODEL", env::var("GOOSE_LEAD_MODEL").ok()),
("GOOSE_LEAD_PROVIDER", env::var("GOOSE_LEAD_PROVIDER").ok()),
("GOOSE_LEAD_TURNS", env::var("GOOSE_LEAD_TURNS").ok()),
(
"GOOSE_LEAD_FAILURE_THRESHOLD",
env::var("GOOSE_LEAD_FAILURE_THRESHOLD").ok(),
),
(
"GOOSE_LEAD_FALLBACK_TURNS",
env::var("GOOSE_LEAD_FALLBACK_TURNS").ok(),
),
];
// Clear all lead env vars
for (key, _) in &saved_vars {
env::remove_var(key);
}
// Set only the required lead model
env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
// This should use defaults for all other values
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
// Should attempt to create lead/worker provider (will fail due to missing API keys but confirms logic)
match result {
Ok(_) => {
// Success means we have API keys and created the provider
}
Err(error) => {
// Should fail due to missing API keys, confirming we tried to create providers
let error_msg = error.to_string();
assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
}
}
// Test with custom values
env::set_var("GOOSE_LEAD_TURNS", "7");
env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", "4");
env::set_var("GOOSE_LEAD_FALLBACK_TURNS", "3");
let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
// Should still attempt to create lead/worker provider with custom settings
// Restore all env vars
for (key, value) in saved_vars {
match value {
Some(val) => env::set_var(key, val),
None => env::remove_var(key),
}
}
}
#[test]
fn test_create_regular_provider_without_lead_config() {
// Save current env vars
let saved_lead = env::var("GOOSE_LEAD_MODEL").ok();
let saved_provider = env::var("GOOSE_LEAD_PROVIDER").ok();
let saved_turns = env::var("GOOSE_LEAD_TURNS").ok();
let saved_threshold = env::var("GOOSE_LEAD_FAILURE_THRESHOLD").ok();
let saved_fallback = env::var("GOOSE_LEAD_FALLBACK_TURNS").ok();
// Ensure all GOOSE_LEAD_* variables are not set
env::remove_var("GOOSE_LEAD_MODEL");
env::remove_var("GOOSE_LEAD_PROVIDER");
env::remove_var("GOOSE_LEAD_TURNS");
env::remove_var("GOOSE_LEAD_FAILURE_THRESHOLD");
env::remove_var("GOOSE_LEAD_FALLBACK_TURNS");
// This should try to create a regular provider
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
// The creation might succeed or fail depending on API keys
match result {
Ok(_) => {
// If it succeeds, it means we created a regular provider successfully
// This would happen if API keys are available in the test environment
}
Err(error) => {
// If it fails, it should be due to missing API keys
let error_msg = error.to_string();
assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
}
}
// Restore env vars
if let Some(val) = saved_lead {
env::set_var("GOOSE_LEAD_MODEL", val);
}
if let Some(val) = saved_provider {
env::set_var("GOOSE_LEAD_PROVIDER", val);
}
if let Some(val) = saved_turns {
env::set_var("GOOSE_LEAD_TURNS", val);
}
if let Some(val) = saved_threshold {
env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", val);
}
if let Some(val) = saved_fallback {
env::set_var("GOOSE_LEAD_FALLBACK_TURNS", val);
}
}
}

View File

@@ -98,6 +98,10 @@ pub enum GeminiVersion {
Pro20Exp,
/// Gemini 2.5 Pro Experimental version
Pro25Exp,
/// Gemini 2.5 Flash Preview version
Flash25Preview,
/// Gemini 2.5 Pro Preview version
Pro25Preview,
/// Generic Gemini model for custom or new versions
Generic(String),
}
@@ -118,6 +122,8 @@ impl fmt::Display for GcpVertexAIModel {
GeminiVersion::Flash20 => "gemini-2.0-flash-001",
GeminiVersion::Pro20Exp => "gemini-2.0-pro-exp-02-05",
GeminiVersion::Pro25Exp => "gemini-2.5-pro-exp-03-25",
GeminiVersion::Flash25Preview => "gemini-2.5-flash-preview-05-20",
GeminiVersion::Pro25Preview => "gemini-2.5-pro-preview-05-06",
GeminiVersion::Generic(name) => name,
},
};
@@ -154,6 +160,8 @@ impl TryFrom<&str> for GcpVertexAIModel {
"gemini-2.0-flash-001" => Ok(Self::Gemini(GeminiVersion::Flash20)),
"gemini-2.0-pro-exp-02-05" => Ok(Self::Gemini(GeminiVersion::Pro20Exp)),
"gemini-2.5-pro-exp-03-25" => Ok(Self::Gemini(GeminiVersion::Pro25Exp)),
"gemini-2.5-flash-preview-05-20" => Ok(Self::Gemini(GeminiVersion::Flash25Preview)),
"gemini-2.5-pro-preview-05-06" => Ok(Self::Gemini(GeminiVersion::Pro25Preview)),
// Generic models based on prefix matching
_ if s.starts_with("claude-") => {
Ok(Self::Claude(ClaudeVersion::Generic(s.to_string())))
@@ -349,6 +357,8 @@ mod tests {
"gemini-2.0-flash-001",
"gemini-2.0-pro-exp-02-05",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-flash-preview-05-20",
"gemini-2.5-pro-preview-05-06",
];
for model_id in valid_models {
@@ -372,6 +382,8 @@ mod tests {
("gemini-2.0-flash-001", GcpLocation::Iowa),
("gemini-2.0-pro-exp-02-05", GcpLocation::Iowa),
("gemini-2.5-pro-exp-03-25", GcpLocation::Iowa),
("gemini-2.5-flash-preview-05-20", GcpLocation::Iowa),
("gemini-2.5-pro-preview-05-06", GcpLocation::Iowa),
];
for (model_id, expected_location) in test_cases {

View File

@@ -4,3 +4,4 @@ pub mod databricks;
pub mod gcpvertexai;
pub mod google;
pub mod openai;
pub mod snowflake;

View File

@@ -0,0 +1,716 @@
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::base::Usage;
use crate::providers::errors::ProviderError;
use anyhow::{anyhow, Result};
use mcp_core::content::Content;
use mcp_core::role::Role;
use mcp_core::tool::{Tool, ToolCall};
use serde_json::{json, Value};
use std::collections::HashSet;
/// Convert internal Message format to Snowflake's API message specification
pub fn format_messages(messages: &[Message]) -> Vec<Value> {
let mut snowflake_messages = Vec::new();
// Convert messages to Snowflake format
for message in messages {
let role = match message.role {
Role::User => "user",
Role::Assistant => "assistant",
};
let mut text_content = String::new();
for msg_content in &message.content {
match msg_content {
MessageContent::Text(text) => {
if !text_content.is_empty() {
text_content.push('\n');
}
text_content.push_str(&text.text);
}
MessageContent::ToolRequest(_tool_request) => {
// Skip tool requests in message formatting - tools are handled separately
// through the tools parameter in the API request
continue;
}
MessageContent::ToolResponse(tool_response) => {
if let Ok(result) = &tool_response.tool_result {
let text = result
.iter()
.filter_map(|c| match c {
Content::Text(t) => Some(t.text.clone()),
_ => None,
})
.collect::<Vec<_>>()
.join("\n");
if !text_content.is_empty() {
text_content.push('\n');
}
if !text.is_empty() {
text_content.push_str(&format!("Tool result: {}", text));
}
}
}
MessageContent::ToolConfirmationRequest(_) => {
// Skip tool confirmation requests
}
MessageContent::ContextLengthExceeded(_) => {
// Skip
}
MessageContent::SummarizationRequested(_) => {
// Skip
}
MessageContent::Thinking(_thinking) => {
// Skip thinking for now
}
MessageContent::RedactedThinking(_redacted) => {
// Skip redacted thinking for now
}
MessageContent::Image(_) => continue, // Snowflake doesn't support image content yet
MessageContent::FrontendToolRequest(_tool_request) => {
// Skip frontend tool requests
}
}
}
// Add message if it has text content
if !text_content.is_empty() {
snowflake_messages.push(json!({
"role": role,
"content": text_content
}));
}
}
// Only add default message if we truly have no messages at all
// This should be rare and only for edge cases
if snowflake_messages.is_empty() {
snowflake_messages.push(json!({
"role": "user",
"content": "Continue the conversation"
}));
}
snowflake_messages
}
/// Convert internal Tool format to Snowflake's API tool specification
pub fn format_tools(tools: &[Tool]) -> Vec<Value> {
let mut unique_tools = HashSet::new();
let mut tool_specs = Vec::new();
for tool in tools.iter() {
if unique_tools.insert(tool.name.clone()) {
let tool_spec = json!({
"type": "generic",
"name": tool.name,
"description": tool.description,
"input_schema": tool.input_schema
});
tool_specs.push(json!({"tool_spec": tool_spec}));
}
}
tool_specs
}
/// Convert system message to Snowflake's API system specification
pub fn format_system(system: &str) -> Value {
json!({
"role": "system",
"content": system,
})
}
/// Convert Snowflake's streaming API response to internal Message format
pub fn parse_streaming_response(sse_data: &str) -> Result<Message> {
let mut message = Message::assistant();
let mut accumulated_text = String::new();
let mut tool_use_id: Option<String> = None;
let mut tool_name: Option<String> = None;
let mut tool_input = String::new();
// Parse each SSE event
for line in sse_data.lines() {
if !line.starts_with("data: ") {
continue;
}
let json_str = &line[6..]; // Remove "data: " prefix
if json_str.trim().is_empty() || json_str.trim() == "[DONE]" {
continue;
}
let event: Value = match serde_json::from_str(json_str) {
Ok(v) => v,
Err(_) => {
continue;
}
};
if let Some(choices) = event.get("choices").and_then(|c| c.as_array()) {
if let Some(choice) = choices.first() {
if let Some(delta) = choice.get("delta") {
match delta.get("type").and_then(|t| t.as_str()) {
Some("text") => {
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
accumulated_text.push_str(content);
}
}
Some("tool_use") => {
if let Some(id) = delta.get("tool_use_id").and_then(|i| i.as_str()) {
tool_use_id = Some(id.to_string());
}
if let Some(name) = delta.get("name").and_then(|n| n.as_str()) {
tool_name = Some(name.to_string());
}
if let Some(input) = delta.get("input").and_then(|i| i.as_str()) {
tool_input.push_str(input);
}
}
_ => {}
}
}
}
}
}
// Add accumulated text if any
if !accumulated_text.is_empty() {
message = message.with_text(accumulated_text);
}
// Add tool use if complete
if let (Some(id), Some(name)) = (&tool_use_id, &tool_name) {
if !tool_input.is_empty() {
let input_value = serde_json::from_str::<Value>(&tool_input)
.unwrap_or_else(|_| Value::String(tool_input.clone()));
let tool_call = ToolCall::new(name, input_value);
message = message.with_tool_request(id, Ok(tool_call));
} else if tool_name.is_some() {
// Tool with no input - use empty object
let tool_call = ToolCall::new(name, Value::Object(serde_json::Map::new()));
message = message.with_tool_request(id, Ok(tool_call));
}
}
Ok(message)
}
/// Convert Snowflake's API response to internal Message format
pub fn response_to_message(response: Value) -> Result<Message> {
let mut message = Message::assistant();
let content_list = response.get("content_list").and_then(|cl| cl.as_array());
// Handle case where content_list is missing or empty
let content_list = match content_list {
Some(list) if !list.is_empty() => list,
_ => {
// If no content_list or empty, check if there's a direct content field
if let Some(direct_content) = response.get("content").and_then(|c| c.as_str()) {
if !direct_content.is_empty() {
message = message.with_text(direct_content.to_string());
}
return Ok(message);
} else {
// Return empty assistant message for empty responses
return Ok(message);
}
}
};
// Process all content items in the list
for content in content_list {
match content.get("type").and_then(|t| t.as_str()) {
Some("text") => {
if let Some(text) = content.get("text").and_then(|t| t.as_str()) {
if !text.is_empty() {
message = message.with_text(text.to_string());
}
}
}
Some("tool_use") => {
let id = content
.get("tool_use_id")
.and_then(|i| i.as_str())
.ok_or_else(|| anyhow!("Missing tool_use id"))?;
let name = content
.get("name")
.and_then(|n| n.as_str())
.ok_or_else(|| anyhow!("Missing tool_use name"))?;
let input = content
.get("input")
.ok_or_else(|| anyhow!("Missing tool input"))?
.clone();
let tool_call = ToolCall::new(name, input);
message = message.with_tool_request(id, Ok(tool_call));
}
Some("thinking") => {
let thinking = content
.get("thinking")
.and_then(|t| t.as_str())
.ok_or_else(|| anyhow!("Missing thinking content"))?;
let signature = content
.get("signature")
.and_then(|s| s.as_str())
.ok_or_else(|| anyhow!("Missing thinking signature"))?;
message = message.with_thinking(thinking, signature);
}
Some("redacted_thinking") => {
let data = content
.get("data")
.and_then(|d| d.as_str())
.ok_or_else(|| anyhow!("Missing redacted_thinking data"))?;
message = message.with_redacted_thinking(data);
}
_ => {
// Ignore unrecognized content types
}
}
}
Ok(message)
}
/// Extract usage information from Snowflake's API response
pub fn get_usage(data: &Value) -> Result<Usage> {
// Extract usage data if available
if let Some(usage) = data.get("usage") {
let input_tokens = usage
.get("input_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let output_tokens = usage
.get("output_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as i32);
let total_tokens = match (input_tokens, output_tokens) {
(Some(input), Some(output)) => Some(input + output),
_ => None,
};
Ok(Usage::new(input_tokens, output_tokens, total_tokens))
} else {
tracing::debug!(
"Failed to get usage data: {}",
ProviderError::UsageError("No usage data found in response".to_string())
);
// If no usage data, return None for all values
Ok(Usage::new(None, None, None))
}
}
/// Create a complete request payload for Snowflake's API
pub fn create_request(
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<Value> {
let mut snowflake_messages = format_messages(messages);
let system_spec = format_system(system);
// Add system message to the beginning of the messages
snowflake_messages.insert(0, system_spec);
// Check if we have any messages to send
if snowflake_messages.is_empty() {
return Err(anyhow!("No valid messages to send to Snowflake API"));
}
// Detect description generation requests and exclude tools to prevent interference
// with normal tool execution flow
let is_description_request =
system.contains("Reply with only a description in four words or less");
let tool_specs = if is_description_request {
// For description generation, don't include any tools to avoid confusion
format_tools(&[])
} else {
format_tools(tools)
};
let max_tokens = model_config.max_tokens.unwrap_or(4096);
let mut payload = json!({
"model": model_config.model_name,
"messages": snowflake_messages,
"max_tokens": max_tokens,
});
// Add tools if present and not a description request
if !tool_specs.is_empty() {
if let Some(obj) = payload.as_object_mut() {
obj.insert("tools".to_string(), json!(tool_specs));
} else {
return Err(anyhow!(
"Failed to create request payload: payload is not a JSON object"
));
}
}
Ok(payload)
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_parse_text_response() -> Result<()> {
let response = json!({
"id": "msg_123",
"type": "message",
"role": "assistant",
"content_list": [{
"type": "text",
"text": "Hello! How can I assist you today?"
}],
"model": "claude-3-5-sonnet",
"stop_reason": "end_turn",
"stop_sequence": null,
"usage": {
"input_tokens": 12,
"output_tokens": 15
}
});
let message = response_to_message(response.clone())?;
let usage = get_usage(&response)?;
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "Hello! How can I assist you today?");
} else {
panic!("Expected Text content");
}
assert_eq!(usage.input_tokens, Some(12));
assert_eq!(usage.output_tokens, Some(15));
assert_eq!(usage.total_tokens, Some(27)); // 12 + 15
Ok(())
}
#[test]
fn test_parse_tool_response() -> Result<()> {
let response = json!({
"id": "msg_123",
"type": "message",
"role": "assistant",
"content_list": [{
"type": "tool_use",
"tool_use_id": "tool_1",
"name": "calculator",
"input": {"expression": "2 + 2"}
}],
"model": "claude-3-5-sonnet",
"stop_reason": "end_turn",
"stop_sequence": null,
"usage": {
"input_tokens": 15,
"output_tokens": 20
}
});
let message = response_to_message(response.clone())?;
let usage = get_usage(&response)?;
if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
let tool_call = tool_request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "calculator");
assert_eq!(tool_call.arguments, json!({"expression": "2 + 2"}));
} else {
panic!("Expected ToolRequest content");
}
assert_eq!(usage.input_tokens, Some(15));
assert_eq!(usage.output_tokens, Some(20));
assert_eq!(usage.total_tokens, Some(35)); // 15 + 20
Ok(())
}
#[test]
fn test_message_to_snowflake_spec() {
let messages = vec![
Message::user().with_text("Hello"),
Message::assistant().with_text("Hi there"),
Message::user().with_text("How are you?"),
];
let spec = format_messages(&messages);
assert_eq!(spec.len(), 3);
assert_eq!(spec[0]["role"], "user");
assert_eq!(spec[0]["content"], "Hello");
assert_eq!(spec[1]["role"], "assistant");
assert_eq!(spec[1]["content"], "Hi there");
assert_eq!(spec[2]["role"], "user");
assert_eq!(spec[2]["content"], "How are you?");
}
#[test]
fn test_tools_to_snowflake_spec() {
let tools = vec![
Tool::new(
"calculator",
"Calculate mathematical expressions",
json!({
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": "The mathematical expression to evaluate"
}
}
}),
None,
),
Tool::new(
"weather",
"Get weather information",
json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get weather for"
}
}
}),
None,
),
];
let spec = format_tools(&tools);
assert_eq!(spec.len(), 2);
assert_eq!(spec[0]["tool_spec"]["name"], "calculator");
assert_eq!(
spec[0]["tool_spec"]["description"],
"Calculate mathematical expressions"
);
assert_eq!(spec[1]["tool_spec"]["name"], "weather");
assert_eq!(
spec[1]["tool_spec"]["description"],
"Get weather information"
);
}
#[test]
fn test_system_to_snowflake_spec() {
let system = "You are a helpful assistant.";
let spec = format_system(system);
assert_eq!(spec["role"], "system");
assert_eq!(spec["content"], system);
}
#[test]
fn test_parse_streaming_response() -> Result<()> {
let sse_data = r#"data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","choices":[{"delta":{"type":"text","content":"I","content_list":[{"type":"text","text":"I"}],"text":"I"}}],"usage":{}}
data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","choices":[{"delta":{"type":"text","content":"'ll help you check Nvidia's current","content_list":[{"type":"text","text":"'ll help you check Nvidia's current"}],"text":"'ll help you check Nvidia's current"}}],"usage":{}}
data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","choices":[{"delta":{"type":"tool_use","tool_use_id":"tooluse_FB_nOElDTAOKa-YnVWI5Uw","name":"get_stock_price","content_list":[{"tool_use_id":"tooluse_FB_nOElDTAOKa-YnVWI5Uw","name":"get_stock_price"}],"text":""}}],"usage":{}}
data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","choices":[{"delta":{"type":"tool_use","input":"{\"symbol\":\"NVDA\"}","content_list":[{"input":"{\"symbol\":\"NVDA\"}"}],"text":""}}],"usage":{"prompt_tokens":397,"completion_tokens":65,"total_tokens":462}}
"#;
let message = parse_streaming_response(sse_data)?;
// Should have both text and tool request
assert_eq!(message.content.len(), 2);
if let MessageContent::Text(text) = &message.content[0] {
assert!(text.text.contains("I'll help you check Nvidia's current"));
} else {
panic!("Expected Text content first");
}
if let MessageContent::ToolRequest(tool_request) = &message.content[1] {
let tool_call = tool_request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "get_stock_price");
assert_eq!(tool_call.arguments, json!({"symbol": "NVDA"}));
assert_eq!(tool_request.id, "tooluse_FB_nOElDTAOKa-YnVWI5Uw");
} else {
panic!("Expected ToolRequest content second");
}
Ok(())
}
#[test]
fn test_create_request_format() -> Result<()> {
use crate::model::ModelConfig;
let model_config = ModelConfig::new("claude-3-5-sonnet".to_string());
let system = "You are a helpful assistant that can use tools to get information.";
let messages = vec![Message::user().with_text("What is the stock price of Nvidia?")];
let tools = vec![Tool::new(
"get_stock_price",
"Get stock price information",
json!({
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The symbol for the stock ticker, e.g. Snowflake = SNOW"
}
},
"required": ["symbol"]
}),
None,
)];
let request = create_request(&model_config, system, &messages, &tools)?;
// Check basic structure
assert_eq!(request["model"], "claude-3-5-sonnet");
let messages_array = request["messages"].as_array().unwrap();
assert_eq!(messages_array.len(), 2); // system + user message
// First message should be system with simple content
assert_eq!(messages_array[0]["role"], "system");
assert_eq!(
messages_array[0]["content"],
"You are a helpful assistant that can use tools to get information."
);
// Second message should be user with simple content
assert_eq!(messages_array[1]["role"], "user");
assert_eq!(
messages_array[1]["content"],
"What is the stock price of Nvidia?"
);
// Tools should have tool_spec wrapper
let tools_array = request["tools"].as_array().unwrap();
assert_eq!(tools_array[0]["tool_spec"]["name"], "get_stock_price");
Ok(())
}
#[test]
fn test_parse_mixed_text_and_tool_response() -> Result<()> {
let response = json!({
"id": "msg_123",
"type": "message",
"role": "assistant",
"content": "I'll help you with that calculation.",
"content_list": [
{
"type": "text",
"text": "I'll help you with that calculation."
},
{
"type": "tool_use",
"tool_use_id": "tool_1",
"name": "calculator",
"input": {"expression": "2 + 2"}
}
],
"model": "claude-3-5-sonnet",
"usage": {
"input_tokens": 10,
"output_tokens": 15
}
});
let message = response_to_message(response.clone())?;
// Should have both text and tool request content
assert_eq!(message.content.len(), 2);
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "I'll help you with that calculation.");
} else {
panic!("Expected Text content first");
}
if let MessageContent::ToolRequest(tool_request) = &message.content[1] {
let tool_call = tool_request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "calculator");
assert_eq!(tool_request.id, "tool_1");
} else {
panic!("Expected ToolRequest content second");
}
Ok(())
}
#[test]
fn test_empty_tools_array() {
let tools: Vec<Tool> = vec![];
let spec = format_tools(&tools);
assert_eq!(spec.len(), 0);
}
#[test]
fn test_create_request_excludes_tools_for_description() -> Result<()> {
use crate::model::ModelConfig;
let model_config = ModelConfig::new("claude-3-5-sonnet".to_string());
let system = "Reply with only a description in four words or less";
let messages = vec![Message::user().with_text("Test message")];
let tools = vec![Tool::new(
"test_tool",
"Test tool",
json!({"type": "object", "properties": {}}),
None,
)];
let request = create_request(&model_config, system, &messages, &tools)?;
// Should not include tools for description requests
assert!(request.get("tools").is_none());
Ok(())
}
#[test]
fn test_message_formatting_skips_tool_requests() {
use mcp_core::tool::ToolCall;
// Create a conversation with text, tool requests, and tool responses
let tool_call = ToolCall::new("calculator", json!({"expression": "2 + 2"}));
let messages = vec![
Message::user().with_text("Calculate 2 + 2"),
Message::assistant()
.with_text("I'll help you calculate that.")
.with_tool_request("tool_1", Ok(tool_call)),
Message::user().with_text("Thanks!"),
];
let spec = format_messages(&messages);
// Should only have 3 messages - the tool request should be skipped
assert_eq!(spec.len(), 3);
assert_eq!(spec[0]["role"], "user");
assert_eq!(spec[0]["content"], "Calculate 2 + 2");
assert_eq!(spec[1]["role"], "assistant");
assert_eq!(spec[1]["content"], "I'll help you calculate that.");
assert_eq!(spec[2]["role"], "user");
assert_eq!(spec[2]["content"], "Thanks!");
// Verify no tool request content is in the message history
for message in &spec {
let content = message["content"].as_str().unwrap();
assert!(!content.contains("Using tool:"));
assert!(!content.contains("calculator"));
}
}
}

View File

@@ -434,6 +434,9 @@ impl Provider for GcpVertexAIProvider {
GcpVertexAIModel::Gemini(GeminiVersion::Pro15),
GcpVertexAIModel::Gemini(GeminiVersion::Flash20),
GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp),
GcpVertexAIModel::Gemini(GeminiVersion::Pro25Exp),
GcpVertexAIModel::Gemini(GeminiVersion::Flash25Preview),
GcpVertexAIModel::Gemini(GeminiVersion::Pro25Preview),
]
.iter()
.map(|model| model.to_string())

View File

@@ -230,7 +230,7 @@ impl GithubCopilotProvider {
async fn refresh_api_info(&self) -> Result<CopilotTokenInfo> {
let config = Config::global();
let token = match config.get_secret::<String>("GITHUB_TOKEN") {
let token = match config.get_secret::<String>("GITHUB_COPILOT_TOKEN") {
Ok(token) => token,
Err(err) => match err {
ConfigError::NotFound(_) => {
@@ -238,7 +238,7 @@ impl GithubCopilotProvider {
.get_access_token()
.await
.context("unable to login into github")?;
config.set_secret("GITHUB_TOKEN", Value::String(token.clone()))?;
config.set_secret("GITHUB_COPILOT_TOKEN", Value::String(token.clone()))?;
token
}
_ => return Err(err.into()),

View File

@@ -0,0 +1,637 @@
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use mcp_core::{tool::Tool, Content};
/// A provider that switches between a lead model and a worker model based on turn count
/// and can fallback to lead model on consecutive failures
pub struct LeadWorkerProvider {
lead_provider: Arc<dyn Provider>,
worker_provider: Arc<dyn Provider>,
lead_turns: usize,
turn_count: Arc<Mutex<usize>>,
failure_count: Arc<Mutex<usize>>,
max_failures_before_fallback: usize,
fallback_turns: usize,
in_fallback_mode: Arc<Mutex<bool>>,
fallback_remaining: Arc<Mutex<usize>>,
}
impl LeadWorkerProvider {
/// Create a new LeadWorkerProvider
///
/// # Arguments
/// * `lead_provider` - The provider to use for the initial turns
/// * `worker_provider` - The provider to use after lead_turns
/// * `lead_turns` - Number of turns to use the lead provider (default: 3)
pub fn new(
lead_provider: Arc<dyn Provider>,
worker_provider: Arc<dyn Provider>,
lead_turns: Option<usize>,
) -> Self {
Self {
lead_provider,
worker_provider,
lead_turns: lead_turns.unwrap_or(3),
turn_count: Arc::new(Mutex::new(0)),
failure_count: Arc::new(Mutex::new(0)),
max_failures_before_fallback: 2, // Fallback after 2 consecutive failures
fallback_turns: 2, // Use lead model for 2 turns when in fallback mode
in_fallback_mode: Arc::new(Mutex::new(false)),
fallback_remaining: Arc::new(Mutex::new(0)),
}
}
/// Create a new LeadWorkerProvider with custom settings
///
/// # Arguments
/// * `lead_provider` - The provider to use for the initial turns
/// * `worker_provider` - The provider to use after lead_turns
/// * `lead_turns` - Number of turns to use the lead provider
/// * `failure_threshold` - Number of consecutive failures before fallback
/// * `fallback_turns` - Number of turns to use lead model in fallback mode
pub fn new_with_settings(
lead_provider: Arc<dyn Provider>,
worker_provider: Arc<dyn Provider>,
lead_turns: usize,
failure_threshold: usize,
fallback_turns: usize,
) -> Self {
Self {
lead_provider,
worker_provider,
lead_turns,
turn_count: Arc::new(Mutex::new(0)),
failure_count: Arc::new(Mutex::new(0)),
max_failures_before_fallback: failure_threshold,
fallback_turns,
in_fallback_mode: Arc::new(Mutex::new(false)),
fallback_remaining: Arc::new(Mutex::new(0)),
}
}
/// Reset the turn counter and failure tracking (useful for new conversations)
pub async fn reset_turn_count(&self) {
let mut count = self.turn_count.lock().await;
*count = 0;
let mut failures = self.failure_count.lock().await;
*failures = 0;
let mut fallback = self.in_fallback_mode.lock().await;
*fallback = false;
let mut remaining = self.fallback_remaining.lock().await;
*remaining = 0;
}
/// Get the current turn count
pub async fn get_turn_count(&self) -> usize {
*self.turn_count.lock().await
}
/// Get the current failure count
pub async fn get_failure_count(&self) -> usize {
*self.failure_count.lock().await
}
/// Check if currently in fallback mode
pub async fn is_in_fallback_mode(&self) -> bool {
*self.in_fallback_mode.lock().await
}
/// Get the currently active provider based on turn count and fallback state
async fn get_active_provider(&self) -> Arc<dyn Provider> {
let count = *self.turn_count.lock().await;
let in_fallback = *self.in_fallback_mode.lock().await;
// Use lead provider if we're in initial turns OR in fallback mode
if count < self.lead_turns || in_fallback {
Arc::clone(&self.lead_provider)
} else {
Arc::clone(&self.worker_provider)
}
}
/// Handle the result of a completion attempt and update failure tracking
async fn handle_completion_result(
&self,
result: &Result<(Message, ProviderUsage), ProviderError>,
) {
match result {
Ok((message, _usage)) => {
// Check for task-level failures in the response
let has_task_failure = self.detect_task_failures(message).await;
if has_task_failure {
// Task failure detected - increment failure count
let mut failures = self.failure_count.lock().await;
*failures += 1;
let failure_count = *failures;
let turn_count = *self.turn_count.lock().await;
tracing::warn!(
"Task failure detected in response (failure count: {})",
failure_count
);
// Check if we should trigger fallback
if turn_count >= self.lead_turns
&& !*self.in_fallback_mode.lock().await
&& failure_count >= self.max_failures_before_fallback
{
let mut in_fallback = self.in_fallback_mode.lock().await;
let mut fallback_remaining = self.fallback_remaining.lock().await;
*in_fallback = true;
*fallback_remaining = self.fallback_turns;
*failures = 0; // Reset failure count when entering fallback
tracing::warn!(
"🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns",
self.max_failures_before_fallback,
self.fallback_turns
);
}
} else {
// Success - reset failure count and handle fallback mode
let mut failures = self.failure_count.lock().await;
*failures = 0;
let mut in_fallback = self.in_fallback_mode.lock().await;
let mut fallback_remaining = self.fallback_remaining.lock().await;
if *in_fallback {
*fallback_remaining -= 1;
if *fallback_remaining == 0 {
*in_fallback = false;
tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed");
}
}
}
// Increment turn count on any completion (success or task failure)
let mut count = self.turn_count.lock().await;
*count += 1;
}
Err(_) => {
// Technical failure - just log and let it bubble up
// For technical failures (API/LLM issues), we don't want to second-guess
// the model choice - just let the default model handle it
tracing::warn!(
"Technical failure detected - API/LLM issue, will use default model"
);
// Don't increment turn count or failure tracking for technical failures
// as these are temporary infrastructure issues, not model capability issues
}
}
}
/// Detect task-level failures in the model's response
async fn detect_task_failures(&self, message: &Message) -> bool {
let mut failure_indicators = 0;
for content in &message.content {
match content {
MessageContent::ToolRequest(tool_request) => {
// Check if tool request itself failed (malformed, etc.)
if tool_request.tool_call.is_err() {
failure_indicators += 1;
tracing::debug!(
"Failed tool request detected: {:?}",
tool_request.tool_call
);
}
}
MessageContent::ToolResponse(tool_response) => {
// Check if tool execution failed
if let Err(tool_error) = &tool_response.tool_result {
failure_indicators += 1;
tracing::debug!("Tool execution failure detected: {:?}", tool_error);
} else if let Ok(contents) = &tool_response.tool_result {
// Check tool output for error indicators
if self.contains_error_indicators(contents) {
failure_indicators += 1;
tracing::debug!("Tool output contains error indicators");
}
}
}
MessageContent::Text(text_content) => {
// Check for user correction patterns or error acknowledgments
if self.contains_user_correction_patterns(&text_content.text) {
failure_indicators += 1;
tracing::debug!("User correction pattern detected in text");
}
}
_ => {}
}
}
// Consider it a failure if we have multiple failure indicators
failure_indicators >= 1
}
/// Check if tool output contains error indicators
fn contains_error_indicators(&self, contents: &[Content]) -> bool {
for content in contents {
if let Content::Text(text_content) = content {
let text_lower = text_content.text.to_lowercase();
// Common error patterns in tool outputs
if text_lower.contains("error:")
|| text_lower.contains("failed:")
|| text_lower.contains("exception:")
|| text_lower.contains("traceback")
|| text_lower.contains("syntax error")
|| text_lower.contains("permission denied")
|| text_lower.contains("file not found")
|| text_lower.contains("command not found")
|| text_lower.contains("compilation failed")
|| text_lower.contains("test failed")
|| text_lower.contains("assertion failed")
{
return true;
}
}
}
false
}
/// Check for user correction patterns in text
fn contains_user_correction_patterns(&self, text: &str) -> bool {
let text_lower = text.to_lowercase();
// Patterns indicating user is correcting or expressing dissatisfaction
text_lower.contains("that's wrong")
|| text_lower.contains("that's not right")
|| text_lower.contains("that doesn't work")
|| text_lower.contains("try again")
|| text_lower.contains("let me correct")
|| text_lower.contains("actually, ")
|| text_lower.contains("no, that's")
|| text_lower.contains("that's incorrect")
|| text_lower.contains("fix this")
|| text_lower.contains("this is broken")
|| text_lower.contains("this doesn't")
|| text_lower.starts_with("no,")
|| text_lower.starts_with("wrong")
|| text_lower.starts_with("incorrect")
}
}
impl LeadWorkerProviderTrait for LeadWorkerProvider {
/// Get information about the lead and worker models for logging
fn get_model_info(&self) -> (String, String) {
let lead_model = self.lead_provider.get_model_config().model_name;
let worker_model = self.worker_provider.get_model_config().model_name;
(lead_model, worker_model)
}
}
#[async_trait]
impl Provider for LeadWorkerProvider {
fn metadata() -> ProviderMetadata {
// This is a wrapper provider, so we return minimal metadata
ProviderMetadata::new(
"lead_worker",
"Lead/Worker Provider",
"A provider that switches between lead and worker models based on turn count",
"", // No default model as this is determined by the wrapped providers
vec![], // No known models as this depends on wrapped providers
"", // No doc link
vec![], // No config keys as configuration is done through wrapped providers
)
}
fn get_model_config(&self) -> ModelConfig {
// Return the lead provider's model config as the default
// In practice, this might need to be more sophisticated
self.lead_provider.get_model_config()
}
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
// Get the active provider
let provider = self.get_active_provider().await;
// Log which provider is being used
let turn_count = *self.turn_count.lock().await;
let in_fallback = *self.in_fallback_mode.lock().await;
let fallback_remaining = *self.fallback_remaining.lock().await;
let provider_type = if turn_count < self.lead_turns {
"lead (initial)"
} else if in_fallback {
"lead (fallback)"
} else {
"worker"
};
if in_fallback {
tracing::info!(
"🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining)",
provider_type,
turn_count + 1,
fallback_remaining
);
} else {
tracing::info!(
"Using {} provider for turn {} (lead_turns: {})",
provider_type,
turn_count + 1,
self.lead_turns
);
}
// Make the completion request
let result = provider.complete(system, messages, tools).await;
// For technical failures, try with default model (lead provider) instead
let final_result = match &result {
Err(_) => {
tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type);
// Try with lead provider as the default/fallback for technical failures
let default_result = self.lead_provider.complete(system, messages, tools).await;
match &default_result {
Ok(_) => {
tracing::info!(
"✅ Default model (lead provider) succeeded after technical failure"
);
default_result
}
Err(_) => {
tracing::error!("❌ Default model (lead provider) also failed - returning original error");
result // Return the original error
}
}
}
Ok(_) => result, // Success with original provider
};
// Handle the result and update tracking (only for successful completions)
self.handle_completion_result(&final_result).await;
final_result
}
async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
// Combine models from both providers
let lead_models = self.lead_provider.fetch_supported_models_async().await?;
let worker_models = self.worker_provider.fetch_supported_models_async().await?;
match (lead_models, worker_models) {
(Some(lead), Some(worker)) => {
let mut all_models = lead;
all_models.extend(worker);
all_models.sort();
all_models.dedup();
Ok(Some(all_models))
}
(Some(models), None) | (None, Some(models)) => Ok(Some(models)),
(None, None) => Ok(None),
}
}
fn supports_embeddings(&self) -> bool {
// Support embeddings if either provider supports them
self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings()
}
async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
// Use the lead provider for embeddings if it supports them, otherwise use worker
if self.lead_provider.supports_embeddings() {
self.lead_provider.create_embeddings(texts).await
} else if self.worker_provider.supports_embeddings() {
self.worker_provider.create_embeddings(texts).await
} else {
Err(ProviderError::ExecutionError(
"Neither lead nor worker provider supports embeddings".to_string(),
))
}
}
/// Check if this provider is a LeadWorkerProvider
fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
Some(self)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
use chrono::Utc;
use mcp_core::{content::TextContent, Role};
#[derive(Clone)]
struct MockProvider {
name: String,
model_config: ModelConfig,
}
#[async_trait]
impl Provider for MockProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::empty()
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
text: format!("Response from {}", self.name),
annotations: None,
})],
},
ProviderUsage::new(self.name.clone(), Usage::default()),
))
}
}
#[tokio::test]
async fn test_lead_worker_switching() {
let lead_provider = Arc::new(MockProvider {
name: "lead".to_string(),
model_config: ModelConfig::new("lead-model".to_string()),
});
let worker_provider = Arc::new(MockProvider {
name: "worker".to_string(),
model_config: ModelConfig::new("worker-model".to_string()),
});
let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(3));
// First three turns should use lead provider
for i in 0..3 {
let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
assert_eq!(usage.model, "lead");
assert_eq!(provider.get_turn_count().await, i + 1);
assert!(!provider.is_in_fallback_mode().await);
}
// Subsequent turns should use worker provider
for i in 3..6 {
let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
assert_eq!(usage.model, "worker");
assert_eq!(provider.get_turn_count().await, i + 1);
assert!(!provider.is_in_fallback_mode().await);
}
// Reset and verify it goes back to lead
provider.reset_turn_count().await;
assert_eq!(provider.get_turn_count().await, 0);
assert_eq!(provider.get_failure_count().await, 0);
assert!(!provider.is_in_fallback_mode().await);
let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
assert_eq!(usage.model, "lead");
}
#[tokio::test]
async fn test_technical_failure_retry() {
let lead_provider = Arc::new(MockFailureProvider {
name: "lead".to_string(),
model_config: ModelConfig::new("lead-model".to_string()),
should_fail: false, // Lead provider works
});
let worker_provider = Arc::new(MockFailureProvider {
name: "worker".to_string(),
model_config: ModelConfig::new("worker-model".to_string()),
should_fail: true, // Worker will fail
});
let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2));
// First two turns use lead (should succeed)
for _i in 0..2 {
let result = provider.complete("system", &[], &[]).await;
assert!(result.is_ok());
assert_eq!(result.unwrap().1.model, "lead");
assert!(!provider.is_in_fallback_mode().await);
}
// Next turn uses worker (will fail, but should retry with lead and succeed)
let result = provider.complete("system", &[], &[]).await;
assert!(result.is_ok()); // Should succeed because lead provider is used as fallback
assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider
assert_eq!(provider.get_failure_count().await, 0); // No failure tracking for technical failures
assert!(!provider.is_in_fallback_mode().await); // Not in fallback mode
// Another turn - should still try worker first, then retry with lead
let result = provider.complete("system", &[], &[]).await;
assert!(result.is_ok()); // Should succeed because lead provider is used as fallback
assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider
assert_eq!(provider.get_failure_count().await, 0); // Still no failure tracking
assert!(!provider.is_in_fallback_mode().await); // Still not in fallback mode
}
#[tokio::test]
async fn test_fallback_on_task_failures() {
// Test that task failures (not technical failures) still trigger fallback mode
// This would need a different mock that simulates task failures in successful responses
// For now, we'll test the fallback mode functionality directly
let lead_provider = Arc::new(MockFailureProvider {
name: "lead".to_string(),
model_config: ModelConfig::new("lead-model".to_string()),
should_fail: false,
});
let worker_provider = Arc::new(MockFailureProvider {
name: "worker".to_string(),
model_config: ModelConfig::new("worker-model".to_string()),
should_fail: false,
});
let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2));
// Simulate being in fallback mode
{
let mut in_fallback = provider.in_fallback_mode.lock().await;
*in_fallback = true;
let mut fallback_remaining = provider.fallback_remaining.lock().await;
*fallback_remaining = 2;
let mut turn_count = provider.turn_count.lock().await;
*turn_count = 4; // Past initial lead turns
}
// Should use lead provider in fallback mode
let result = provider.complete("system", &[], &[]).await;
assert!(result.is_ok());
assert_eq!(result.unwrap().1.model, "lead");
assert!(provider.is_in_fallback_mode().await);
// One more fallback turn
let result = provider.complete("system", &[], &[]).await;
assert!(result.is_ok());
assert_eq!(result.unwrap().1.model, "lead");
assert!(!provider.is_in_fallback_mode().await); // Should exit fallback mode
}
#[derive(Clone)]
struct MockFailureProvider {
name: String,
model_config: ModelConfig,
should_fail: bool,
}
#[async_trait]
impl Provider for MockFailureProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::empty()
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
if self.should_fail {
Err(ProviderError::ExecutionError(
"Simulated failure".to_string(),
))
} else {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
text: format!("Response from {}", self.name),
annotations: None,
})],
},
ProviderUsage::new(self.name.clone(), Usage::default()),
))
}
}
}
}

View File

@@ -13,10 +13,12 @@ pub mod gcpvertexai;
pub mod githubcopilot;
pub mod google;
pub mod groq;
pub mod lead_worker;
pub mod oauth;
pub mod ollama;
pub mod openai;
pub mod openrouter;
pub mod snowflake;
pub mod toolshim;
pub mod utils;
pub mod utils_universal_openai_stream;

View File

@@ -249,7 +249,7 @@ impl EmbeddingCapable for OpenAiProvider {
}
// Get embedding model from env var or use default
let embedding_model = std::env::var("EMBEDDING_MODEL")
let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL")
.unwrap_or_else(|_| "text-embedding-3-small".to_string());
let request = EmbeddingRequest {

View File

@@ -0,0 +1,439 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::time::Duration;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use super::formats::snowflake::{create_request, get_usage, response_to_message};
use super::utils::{get_model, ImageFormat};
use crate::config::ConfigError;
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;
use url::Url;
pub const SNOWFLAKE_DEFAULT_MODEL: &str = "claude-3-7-sonnet";
pub const SNOWFLAKE_KNOWN_MODELS: &[&str] = &["claude-3-7-sonnet", "claude-3-5-sonnet"];
pub const SNOWFLAKE_DOC_URL: &str =
"https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#choosing-a-model";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SnowflakeAuth {
Token(String),
}
impl SnowflakeAuth {
pub fn token(token: String) -> Self {
Self::Token(token)
}
}
#[derive(Debug, serde::Serialize)]
pub struct SnowflakeProvider {
#[serde(skip)]
client: Client,
host: String,
auth: SnowflakeAuth,
model: ModelConfig,
image_format: ImageFormat,
}
impl Default for SnowflakeProvider {
fn default() -> Self {
let model = ModelConfig::new(SnowflakeProvider::metadata().default_model);
SnowflakeProvider::from_env(model).expect("Failed to initialize Snowflake provider")
}
}
impl SnowflakeProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let mut host: Result<String, ConfigError> = config.get_param("SNOWFLAKE_HOST");
if host.is_err() {
host = config.get_secret("SNOWFLAKE_HOST")
}
if host.is_err() {
return Err(ConfigError::NotFound(
"Did not find SNOWFLAKE_HOST in either config file or keyring".to_string(),
)
.into());
}
let mut host = host?;
// Convert host to lowercase
host = host.to_lowercase();
// Ensure host ends with snowflakecomputing.com
if !host.ends_with("snowflakecomputing.com") {
host = format!("{}.snowflakecomputing.com", host);
}
let mut token: Result<String, ConfigError> = config.get_param("SNOWFLAKE_TOKEN");
if token.is_err() {
token = config.get_secret("SNOWFLAKE_TOKEN")
}
if token.is_err() {
return Err(ConfigError::NotFound(
"Did not find SNOWFLAKE_TOKEN in either config file or keyring".to_string(),
)
.into());
}
let client = Client::builder()
.timeout(Duration::from_secs(600))
.build()?;
// Use token-based authentication
let api_key = token?;
Ok(Self {
client,
host,
auth: SnowflakeAuth::token(api_key),
model,
image_format: ImageFormat::OpenAi,
})
}
async fn ensure_auth_header(&self) -> Result<String> {
match &self.auth {
// https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#using-a-programmatic-access-token-pat
SnowflakeAuth::Token(token) => Ok(format!("Bearer {}", token)),
}
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let base_url_str =
if !self.host.starts_with("https://") && !self.host.starts_with("http://") {
format!("https://{}", self.host)
} else {
self.host.clone()
};
let base_url = Url::parse(&base_url_str)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
let path = "api/v2/cortex/inference:complete";
let url = base_url.join(path).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;
let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(url)
.header("Authorization", auth_header)
.header("User-Agent", "Goose")
.json(&payload)
.send()
.await?;
let status = response.status();
let payload_text: String = response.text().await.ok().unwrap_or_default();
if status == StatusCode::OK {
if let Ok(payload) = serde_json::from_str::<Value>(&payload_text) {
if payload.get("code").is_some() {
let code = payload
.get("code")
.and_then(|c| c.as_str())
.unwrap_or("Unknown code");
let message = payload
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown message");
return Err(ProviderError::RequestFailed(format!(
"{} - {}",
code, message
)));
}
}
}
let lines = payload_text.lines().collect::<Vec<_>>();
let mut text = String::new();
let mut tool_name = String::new();
let mut tool_input = String::new();
let mut tool_use_id = String::new();
for line in lines.iter() {
if line.is_empty() {
continue;
}
let json_str = match line.strip_prefix("data: ") {
Some(s) => s,
None => continue,
};
if let Ok(json_line) = serde_json::from_str::<Value>(json_str) {
let choices = match json_line.get("choices").and_then(|c| c.as_array()) {
Some(choices) => choices,
None => {
continue;
}
};
let choice = match choices.first() {
Some(choice) => choice,
None => {
continue;
}
};
let delta = match choice.get("delta") {
Some(delta) => delta,
None => {
continue;
}
};
// Track if we found text in content_list to avoid duplication
let mut found_text_in_content_list = false;
// Handle content_list array first
if let Some(content_list) = delta.get("content_list").and_then(|cl| cl.as_array()) {
for content_item in content_list {
match content_item.get("type").and_then(|t| t.as_str()) {
Some("text") => {
if let Some(text_content) =
content_item.get("text").and_then(|t| t.as_str())
{
text.push_str(text_content);
found_text_in_content_list = true;
}
}
Some("tool_use") => {
if let Some(tool_id) =
content_item.get("tool_use_id").and_then(|id| id.as_str())
{
tool_use_id.push_str(tool_id);
}
if let Some(name) =
content_item.get("name").and_then(|n| n.as_str())
{
tool_name.push_str(name);
}
if let Some(input) =
content_item.get("input").and_then(|i| i.as_str())
{
tool_input.push_str(input);
}
}
_ => {
// Handle content items without explicit type but with tool information
if let Some(name) =
content_item.get("name").and_then(|n| n.as_str())
{
tool_name.push_str(name);
}
if let Some(tool_id) =
content_item.get("tool_use_id").and_then(|id| id.as_str())
{
tool_use_id.push_str(tool_id);
}
if let Some(input) =
content_item.get("input").and_then(|i| i.as_str())
{
tool_input.push_str(input);
}
}
}
}
}
// Handle direct content field (for text) only if we didn't find text in content_list
if !found_text_in_content_list {
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
text.push_str(content);
}
}
}
}
// Build the appropriate response structure
let mut content_list = Vec::new();
// Add text content if available
if !text.is_empty() {
content_list.push(json!({
"type": "text",
"text": text
}));
}
// Add tool use content only if we have complete tool information
if !tool_use_id.is_empty() && !tool_name.is_empty() {
// Parse tool input as JSON if it's not empty
let parsed_input = if tool_input.is_empty() {
json!({})
} else {
serde_json::from_str::<Value>(&tool_input)
.unwrap_or_else(|_| json!({"raw_input": tool_input}))
};
content_list.push(json!({
"type": "tool_use",
"tool_use_id": tool_use_id,
"name": tool_name,
"input": parsed_input
}));
}
// Ensure we always have at least some content
if content_list.is_empty() {
content_list.push(json!({
"type": "text",
"text": ""
}));
}
let answer_payload = json!({
"role": "assistant",
"content": text,
"content_list": content_list
});
match status {
StatusCode::OK => Ok(answer_payload),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
// Extract a clean error message from the response if available
let error_msg = payload_text
.lines()
.find(|line| line.contains("\"message\""))
.and_then(|line| {
let json_str = line.strip_prefix("data: ").unwrap_or(line);
serde_json::from_str::<Value>(json_str).ok()
})
.and_then(|json| {
json.get("message")
.and_then(|m| m.as_str())
.map(|s| s.to_string())
})
.unwrap_or_else(|| "Invalid credentials".to_string());
Err(ProviderError::Authentication(format!(
"Authentication failed. Please check your SNOWFLAKE_TOKEN and SNOWFLAKE_HOST configuration. Error: {}",
error_msg
)))
}
StatusCode::BAD_REQUEST => {
// Snowflake provides a generic 'error' but also includes 'external_model_message' which is provider specific
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
let payload_str = payload_text.to_lowercase();
let check_phrases = [
"too long",
"context length",
"context_length_exceeded",
"reduce the length",
"token count",
"exceeds",
"exceed context limit",
"input length",
"max_tokens",
"decrease input length",
"context limit",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded("Request exceeds maximum context length. Please reduce the number of messages or content size.".to_string()));
}
// Try to parse a clean error message from the response
let error_msg = if let Ok(json) = serde_json::from_str::<Value>(&payload_text) {
json.get("message")
.and_then(|m| m.as_str())
.map(|s| s.to_string())
.or_else(|| {
json.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())
.map(|s| s.to_string())
})
.unwrap_or_else(|| "Bad request".to_string())
} else {
"Bad request".to_string()
};
tracing::debug!(
"Provider request failed with status: {}. Response: {}",
status,
payload_text
);
Err(ProviderError::RequestFailed(format!(
"Request failed: {}",
error_msg
)))
}
StatusCode::TOO_MANY_REQUESTS => Err(ProviderError::RateLimitExceeded(
"Rate limit exceeded. Please try again later.".to_string(),
)),
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(
"Snowflake service is temporarily unavailable. Please try again later."
.to_string(),
))
}
_ => {
tracing::debug!(
"Provider request failed with status: {}. Response: {}",
status,
payload_text
);
Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}",
status
)))
}
}
}
}
#[async_trait]
impl Provider for SnowflakeProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"snowflake",
"Snowflake",
"Access several models using Snowflake Cortex services.",
SNOWFLAKE_DEFAULT_MODEL,
SNOWFLAKE_KNOWN_MODELS.to_vec(),
SNOWFLAKE_DOC_URL,
vec![
ConfigKey::new("SNOWFLAKE_HOST", true, false, None),
ConfigKey::new("SNOWFLAKE_TOKEN", true, true, None),
],
)
}
fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}
#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(&self.model, system, messages, tools)?;
let response = self.post(payload.clone()).await?;
// Parse response
let message = response_to_message(response.clone())?;
let usage = get_usage(&response)?;
let model = get_model(&response);
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
Ok((message, ProviderUsage::new(model, usage)))
}
}

View File

@@ -38,6 +38,7 @@ use crate::model::ModelConfig;
use crate::providers::formats::openai::create_request;
use anyhow::Result;
use mcp_core::tool::{Tool, ToolCall};
use mcp_core::Content;
use reqwest::Client;
use serde_json::{json, Value};
use std::time::Duration;
@@ -164,7 +165,10 @@ impl OllamaInterpreter {
payload["stream"] = json!(false); // needed for the /api/chat endpoint to work
payload["format"] = format_schema;
// tracing::warn!("payload: {}", serde_json::to_string_pretty(&payload).unwrap_or_default());
tracing::info!(
"Tool interpreter payload: {}",
serde_json::to_string_pretty(&payload).unwrap_or_default()
);
let response = self.client.post(&url).json(&payload).send().await?;
@@ -193,7 +197,10 @@ impl OllamaInterpreter {
fn process_interpreter_response(response: &Value) -> Result<Vec<ToolCall>, ProviderError> {
let mut tool_calls = Vec::new();
tracing::info!(
"Tool interpreter response is {}",
serde_json::to_string_pretty(&response).unwrap_or_default()
);
// Extract tool_calls array from the response
if response.get("message").is_some() && response["message"].get("content").is_some() {
let content = response["message"]["content"].as_str().unwrap_or_default();
@@ -298,6 +305,72 @@ pub fn format_tool_info(tools: &[Tool]) -> String {
tool_info
}
/// Convert messages containing ToolRequest/ToolResponse to text messages for toolshim mode
/// This is necessary because some providers (like Bedrock) validate that tool_use/tool_result
/// blocks can only exist when tools are defined, but in toolshim mode we pass empty tools
pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec<Message> {
messages
.iter()
.map(|message| {
let mut new_content = Vec::new();
let mut has_tool_content = false;
for content in &message.content {
match content {
MessageContent::ToolRequest(req) => {
has_tool_content = true;
// Convert tool request to text format
let text = if let Ok(tool_call) = &req.tool_call {
format!(
"Using tool: {}\n{{\n \"name\": \"{}\",\n \"arguments\": {}\n}}",
tool_call.name,
tool_call.name,
serde_json::to_string_pretty(&tool_call.arguments)
.unwrap_or_default()
)
} else {
"Tool request failed".to_string()
};
new_content.push(MessageContent::text(text));
}
MessageContent::ToolResponse(res) => {
has_tool_content = true;
// Convert tool response to text format
let text = match &res.tool_result {
Ok(contents) => {
let text_contents: Vec<String> = contents
.iter()
.filter_map(|c| match c {
Content::Text(t) => Some(t.text.clone()),
_ => None,
})
.collect();
format!("Tool result:\n{}", text_contents.join("\n"))
}
Err(e) => format!("Tool error: {}", e),
};
new_content.push(MessageContent::text(text));
}
_ => {
// Keep other content types as-is
new_content.push(content.clone());
}
}
}
if has_tool_content {
Message {
role: message.role.clone(),
content: new_content,
created: message.created,
}
} else {
message.clone()
}
})
.collect()
}
/// Modifies the system prompt to include tool usage instructions when tool interpretation is enabled
pub fn modify_system_prompt_for_tool_json(system_prompt: &str, tools: &[Tool]) -> String {
let tool_info = format_tool_info(tools);

View File

@@ -28,7 +28,7 @@ fn default_version() -> String {
///
/// # Example
///
/// ```
///
/// use goose::recipe::Recipe;
///
/// // Using the builder pattern
@@ -52,7 +52,7 @@ fn default_version() -> String {
/// author: None,
/// parameters: None,
/// };
/// ```
///
#[derive(Serialize, Deserialize, Debug)]
pub struct Recipe {
// Required fields
@@ -166,7 +166,7 @@ impl Recipe {
///
/// # Example
///
/// ```
///
/// use goose::recipe::Recipe;
///
/// let recipe = Recipe::builder()
@@ -175,7 +175,7 @@ impl Recipe {
/// .instructions("Act as a helpful assistant")
/// .build()
/// .expect("Failed to build Recipe: missing required fields");
/// ```
///
pub fn builder() -> RecipeBuilder {
RecipeBuilder {
version: default_version(),

View File

@@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler};
use crate::agents::AgentEvent;
use crate::agents::{Agent, SessionConfig};
use crate::config::{self, Config};
use crate::message::Message;
@@ -20,6 +21,10 @@ use crate::recipe::Recipe;
use crate::session;
use crate::session::storage::SessionMetadata;
// Track running tasks with their abort handles
type RunningTasksMap = HashMap<String, tokio::task::AbortHandle>;
type JobsMap = HashMap<String, (JobId, ScheduledJob)>;
pub fn get_default_scheduler_storage_path() -> Result<PathBuf, io::Error> {
let strategy = choose_app_strategy(config::APP_STRATEGY.clone())
.map_err(|e| io::Error::new(io::ErrorKind::NotFound, e.to_string()))?;
@@ -111,11 +116,15 @@ pub struct ScheduledJob {
pub currently_running: bool,
#[serde(default)]
pub paused: bool,
#[serde(default)]
pub current_session_id: Option<String>,
#[serde(default)]
pub process_start_time: Option<DateTime<Utc>>,
}
async fn persist_jobs_from_arc(
storage_path: &Path,
jobs_arc: &Arc<Mutex<HashMap<String, (JobId, ScheduledJob)>>>,
jobs_arc: &Arc<Mutex<JobsMap>>,
) -> Result<(), SchedulerError> {
let jobs_guard = jobs_arc.lock().await;
let list: Vec<ScheduledJob> = jobs_guard.values().map(|(_, j)| j.clone()).collect();
@@ -129,8 +138,9 @@ async fn persist_jobs_from_arc(
pub struct Scheduler {
internal_scheduler: TokioJobScheduler,
jobs: Arc<Mutex<HashMap<String, (JobId, ScheduledJob)>>>,
jobs: Arc<Mutex<JobsMap>>,
storage_path: PathBuf,
running_tasks: Arc<Mutex<RunningTasksMap>>,
}
impl Scheduler {
@@ -140,11 +150,13 @@ impl Scheduler {
.map_err(|e| SchedulerError::SchedulerInternalError(e.to_string()))?;
let jobs = Arc::new(Mutex::new(HashMap::new()));
let running_tasks = Arc::new(Mutex::new(HashMap::new()));
let arc_self = Arc::new(Self {
internal_scheduler,
jobs,
storage_path,
running_tasks,
});
arc_self.load_jobs_from_storage().await?;
@@ -208,17 +220,21 @@ impl Scheduler {
let mut stored_job = original_job_spec.clone();
stored_job.source = destination_recipe_path.to_string_lossy().into_owned();
stored_job.current_session_id = None;
stored_job.process_start_time = None;
tracing::info!("Updated job source path to: {}", stored_job.source);
let job_for_task = stored_job.clone();
let jobs_arc_for_task = self.jobs.clone();
let storage_path_for_task = self.storage_path.clone();
let running_tasks_for_task = self.running_tasks.clone();
let cron_task = Job::new_async(&stored_job.cron, move |_uuid, _l| {
let task_job_id = job_for_task.id.clone();
let current_jobs_arc = jobs_arc_for_task.clone();
let local_storage_path = storage_path_for_task.clone();
let job_to_execute = job_for_task.clone(); // Clone for run_scheduled_job_internal
let running_tasks_arc = running_tasks_for_task.clone();
Box::pin(async move {
// Check if the job is paused before executing
@@ -243,6 +259,7 @@ impl Scheduler {
if let Some((_, current_job_in_map)) = jobs_map_guard.get_mut(&task_job_id) {
current_job_in_map.last_run = Some(current_time);
current_job_in_map.currently_running = true;
current_job_in_map.process_start_time = Some(current_time);
needs_persist = true;
}
}
@@ -258,14 +275,37 @@ impl Scheduler {
);
}
}
// Pass None for provider_override in normal execution
let result = run_scheduled_job_internal(job_to_execute, None).await;
// Spawn the job execution as an abortable task
let job_task = tokio::spawn(run_scheduled_job_internal(
job_to_execute.clone(),
None,
Some(current_jobs_arc.clone()),
Some(task_job_id.clone()),
));
// Store the abort handle at the scheduler level
{
let mut running_tasks_guard = running_tasks_arc.lock().await;
running_tasks_guard.insert(task_job_id.clone(), job_task.abort_handle());
}
// Wait for the job to complete or be aborted
let result = job_task.await;
// Remove the abort handle
{
let mut running_tasks_guard = running_tasks_arc.lock().await;
running_tasks_guard.remove(&task_job_id);
}
// Update the job status after execution
{
let mut jobs_map_guard = current_jobs_arc.lock().await;
if let Some((_, current_job_in_map)) = jobs_map_guard.get_mut(&task_job_id) {
current_job_in_map.currently_running = false;
current_job_in_map.current_session_id = None;
current_job_in_map.process_start_time = None;
needs_persist = true;
}
}
@@ -282,13 +322,28 @@ impl Scheduler {
}
}
if let Err(e) = result {
match result {
Ok(Ok(_session_id)) => {
tracing::info!("Scheduled job '{}' completed successfully", &task_job_id);
}
Ok(Err(e)) => {
tracing::error!(
"Scheduled job '{}' execution failed: {}",
&e.job_id,
e.error
);
}
Err(join_error) if join_error.is_cancelled() => {
tracing::info!("Scheduled job '{}' was cancelled/killed", &task_job_id);
}
Err(join_error) => {
tracing::error!(
"Scheduled job '{}' task failed: {}",
&task_job_id,
join_error
);
}
}
})
})
.map_err(|e| SchedulerError::CronParseError(e.to_string()))?;
@@ -328,12 +383,14 @@ impl Scheduler {
let job_for_task = job_to_load.clone();
let jobs_arc_for_task = self.jobs.clone();
let storage_path_for_task = self.storage_path.clone();
let running_tasks_for_task = self.running_tasks.clone();
let cron_task = Job::new_async(&job_to_load.cron, move |_uuid, _l| {
let task_job_id = job_for_task.id.clone();
let current_jobs_arc = jobs_arc_for_task.clone();
let local_storage_path = storage_path_for_task.clone();
let job_to_execute = job_for_task.clone(); // Clone for run_scheduled_job_internal
let running_tasks_arc = running_tasks_for_task.clone();
Box::pin(async move {
// Check if the job is paused before executing
@@ -358,6 +415,7 @@ impl Scheduler {
if let Some((_, stored_job)) = jobs_map_guard.get_mut(&task_job_id) {
stored_job.last_run = Some(current_time);
stored_job.currently_running = true;
stored_job.process_start_time = Some(current_time);
needs_persist = true;
}
}
@@ -373,14 +431,37 @@ impl Scheduler {
);
}
}
// Pass None for provider_override in normal execution
let result = run_scheduled_job_internal(job_to_execute, None).await;
// Spawn the job execution as an abortable task
let job_task = tokio::spawn(run_scheduled_job_internal(
job_to_execute,
None,
Some(current_jobs_arc.clone()),
Some(task_job_id.clone()),
));
// Store the abort handle at the scheduler level
{
let mut running_tasks_guard = running_tasks_arc.lock().await;
running_tasks_guard.insert(task_job_id.clone(), job_task.abort_handle());
}
// Wait for the job to complete or be aborted
let result = job_task.await;
// Remove the abort handle
{
let mut running_tasks_guard = running_tasks_arc.lock().await;
running_tasks_guard.remove(&task_job_id);
}
// Update the job status after execution
{
let mut jobs_map_guard = current_jobs_arc.lock().await;
if let Some((_, stored_job)) = jobs_map_guard.get_mut(&task_job_id) {
stored_job.currently_running = false;
stored_job.current_session_id = None;
stored_job.process_start_time = None;
needs_persist = true;
}
}
@@ -397,13 +478,31 @@ impl Scheduler {
}
}
if let Err(e) = result {
match result {
Ok(Ok(_session_id)) => {
tracing::info!(
"Scheduled job '{}' completed successfully",
&task_job_id
);
}
Ok(Err(e)) => {
tracing::error!(
"Scheduled job '{}' execution failed: {}",
&e.job_id,
e.error
);
}
Err(join_error) if join_error.is_cancelled() => {
tracing::info!("Scheduled job '{}' was cancelled/killed", &task_job_id);
}
Err(join_error) => {
tracing::error!(
"Scheduled job '{}' task failed: {}",
&task_job_id,
join_error
);
}
}
})
})
.map_err(|e| SchedulerError::CronParseError(e.to_string()))?;
@@ -421,7 +520,7 @@ impl Scheduler {
// Renamed and kept for direct use when a guard is already held (e.g. add/remove)
async fn persist_jobs_to_storage_with_guard(
&self,
jobs_guard: &tokio::sync::MutexGuard<'_, HashMap<String, (JobId, ScheduledJob)>>,
jobs_guard: &tokio::sync::MutexGuard<'_, JobsMap>,
) -> Result<(), SchedulerError> {
let list: Vec<ScheduledJob> = jobs_guard.values().map(|(_, j)| j.clone()).collect();
if let Some(parent) = self.storage_path.parent() {
@@ -523,14 +622,36 @@ impl Scheduler {
}
};
// Pass None for provider_override in normal execution
let run_result = run_scheduled_job_internal(job_to_run.clone(), None).await;
// Spawn the job execution as an abortable task for run_now
let job_task = tokio::spawn(run_scheduled_job_internal(
job_to_run.clone(),
None,
Some(self.jobs.clone()),
Some(sched_id.to_string()),
));
// Store the abort handle for run_now jobs
{
let mut running_tasks_guard = self.running_tasks.lock().await;
running_tasks_guard.insert(sched_id.to_string(), job_task.abort_handle());
}
// Wait for the job to complete or be aborted
let run_result = job_task.await;
// Remove the abort handle
{
let mut running_tasks_guard = self.running_tasks.lock().await;
running_tasks_guard.remove(sched_id);
}
// Clear the currently_running flag after execution
{
let mut jobs_guard = self.jobs.lock().await;
if let Some((_tokio_job_id, job_in_map)) = jobs_guard.get_mut(sched_id) {
job_in_map.currently_running = false;
job_in_map.current_session_id = None;
job_in_map.process_start_time = None;
job_in_map.last_run = Some(Utc::now());
} // MutexGuard is dropped here
}
@@ -539,12 +660,24 @@ impl Scheduler {
self.persist_jobs().await?;
match run_result {
Ok(session_id) => Ok(session_id),
Err(e) => Err(SchedulerError::AnyhowError(anyhow!(
Ok(Ok(session_id)) => Ok(session_id),
Ok(Err(e)) => Err(SchedulerError::AnyhowError(anyhow!(
"Failed to execute job '{}' immediately: {}",
sched_id,
e.error
))),
Err(join_error) if join_error.is_cancelled() => {
tracing::info!("Run now job '{}' was cancelled/killed", sched_id);
Err(SchedulerError::AnyhowError(anyhow!(
"Job '{}' was successfully cancelled",
sched_id
)))
}
Err(join_error) => Err(SchedulerError::AnyhowError(anyhow!(
"Failed to execute job '{}' immediately: {}",
sched_id,
join_error
))),
}
}
@@ -608,12 +741,14 @@ impl Scheduler {
let job_for_task = job_def.clone();
let jobs_arc_for_task = self.jobs.clone();
let storage_path_for_task = self.storage_path.clone();
let running_tasks_for_task = self.running_tasks.clone();
let cron_task = Job::new_async(&new_cron, move |_uuid, _l| {
let task_job_id = job_for_task.id.clone();
let current_jobs_arc = jobs_arc_for_task.clone();
let local_storage_path = storage_path_for_task.clone();
let job_to_execute = job_for_task.clone();
let running_tasks_arc = running_tasks_for_task.clone();
Box::pin(async move {
// Check if the job is paused before executing
@@ -641,6 +776,7 @@ impl Scheduler {
{
current_job_in_map.last_run = Some(current_time);
current_job_in_map.currently_running = true;
current_job_in_map.process_start_time = Some(current_time);
needs_persist = true;
}
}
@@ -657,7 +793,29 @@ impl Scheduler {
}
}
let result = run_scheduled_job_internal(job_to_execute, None).await;
// Spawn the job execution as an abortable task
let job_task = tokio::spawn(run_scheduled_job_internal(
job_to_execute,
None,
Some(current_jobs_arc.clone()),
Some(task_job_id.clone()),
));
// Store the abort handle at the scheduler level
{
let mut running_tasks_guard = running_tasks_arc.lock().await;
running_tasks_guard
.insert(task_job_id.clone(), job_task.abort_handle());
}
// Wait for the job to complete or be aborted
let result = job_task.await;
// Remove the abort handle
{
let mut running_tasks_guard = running_tasks_arc.lock().await;
running_tasks_guard.remove(&task_job_id);
}
// Update the job status after execution
{
@@ -666,6 +824,8 @@ impl Scheduler {
jobs_map_guard.get_mut(&task_job_id)
{
current_job_in_map.currently_running = false;
current_job_in_map.current_session_id = None;
current_job_in_map.process_start_time = None;
needs_persist = true;
}
}
@@ -682,13 +842,34 @@ impl Scheduler {
}
}
if let Err(e) = result {
match result {
Ok(Ok(_session_id)) => {
tracing::info!(
"Scheduled job '{}' completed successfully",
&task_job_id
);
}
Ok(Err(e)) => {
tracing::error!(
"Scheduled job '{}' execution failed: {}",
&e.job_id,
e.error
);
}
Err(join_error) if join_error.is_cancelled() => {
tracing::info!(
"Scheduled job '{}' was cancelled/killed",
&task_job_id
);
}
Err(join_error) => {
tracing::error!(
"Scheduled job '{}' task failed: {}",
&task_job_id,
join_error
);
}
}
})
})
.map_err(|e| SchedulerError::CronParseError(e.to_string()))?;
@@ -709,6 +890,70 @@ impl Scheduler {
None => Err(SchedulerError::JobNotFound(sched_id.to_string())),
}
}
pub async fn kill_running_job(&self, sched_id: &str) -> Result<(), SchedulerError> {
let mut jobs_guard = self.jobs.lock().await;
match jobs_guard.get_mut(sched_id) {
Some((_, job_def)) => {
if !job_def.currently_running {
return Err(SchedulerError::AnyhowError(anyhow!(
"Schedule '{}' is not currently running",
sched_id
)));
}
tracing::info!("Killing running job '{}'", sched_id);
// Abort the running task if it exists
{
let mut running_tasks_guard = self.running_tasks.lock().await;
if let Some(abort_handle) = running_tasks_guard.remove(sched_id) {
abort_handle.abort();
tracing::info!("Aborted running task for job '{}'", sched_id);
} else {
tracing::warn!(
"No abort handle found for job '{}' in running tasks map",
sched_id
);
}
}
// Mark the job as no longer running
job_def.currently_running = false;
job_def.current_session_id = None;
job_def.process_start_time = None;
self.persist_jobs_to_storage_with_guard(&jobs_guard).await?;
tracing::info!("Successfully killed job '{}'", sched_id);
Ok(())
}
None => Err(SchedulerError::JobNotFound(sched_id.to_string())),
}
}
pub async fn get_running_job_info(
&self,
sched_id: &str,
) -> Result<Option<(String, DateTime<Utc>)>, SchedulerError> {
let jobs_guard = self.jobs.lock().await;
match jobs_guard.get(sched_id) {
Some((_, job_def)) => {
if job_def.currently_running {
if let (Some(session_id), Some(start_time)) =
(&job_def.current_session_id, &job_def.process_start_time)
{
Ok(Some((session_id.clone(), *start_time)))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
None => Err(SchedulerError::JobNotFound(sched_id.to_string())),
}
}
}
#[derive(Debug)]
@@ -720,6 +965,8 @@ struct JobExecutionError {
async fn run_scheduled_job_internal(
job: ScheduledJob,
provider_override: Option<Arc<dyn GooseProvider>>, // New optional parameter
jobs_arc: Option<Arc<Mutex<JobsMap>>>,
job_id: Option<String>,
) -> std::result::Result<String, JobExecutionError> {
tracing::info!("Executing job: {} (Source: {})", job.id, job.source);
@@ -811,6 +1058,15 @@ async fn run_scheduled_job_internal(
tracing::info!("Agent configured with provider for job '{}'", job.id);
let session_id_for_return = session::generate_session_id();
// Update the job with the session ID if we have access to the jobs arc
if let (Some(jobs_arc), Some(job_id_str)) = (jobs_arc.as_ref(), job_id.as_ref()) {
let mut jobs_guard = jobs_arc.lock().await;
if let Some((_, job_def)) = jobs_guard.get_mut(job_id_str) {
job_def.current_session_id = Some(session_id_for_return.clone());
}
}
let session_file_path = crate::session::storage::get_path(
crate::session::storage::Identifier::Name(session_id_for_return.clone()),
);
@@ -843,13 +1099,19 @@ async fn run_scheduled_job_internal(
use futures::StreamExt;
while let Some(message_result) = stream.next().await {
// Check if the task has been cancelled
tokio::task::yield_now().await;
match message_result {
Ok(msg) => {
Ok(AgentEvent::Message(msg)) => {
if msg.role == mcp_core::role::Role::Assistant {
tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content);
}
all_session_messages.push(msg);
}
Ok(AgentEvent::McpNotification(_)) => {
// Handle notifications if needed
}
Err(e) => {
tracing::error!(
"[Job {}] Error receiving message from agent: {}",
@@ -1053,6 +1315,8 @@ mod tests {
last_run: None,
currently_running: false,
paused: false,
current_session_id: None,
process_start_time: None,
};
// Create the mock provider instance for the test
@@ -1061,7 +1325,7 @@ mod tests {
// Call run_scheduled_job_internal, passing the mock provider
let created_session_id =
run_scheduled_job_internal(dummy_job.clone(), Some(mock_provider_instance))
run_scheduled_job_internal(dummy_job.clone(), Some(mock_provider_instance), None, None)
.await
.expect("run_scheduled_job_internal failed");

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use anyhow::Result;
use futures::StreamExt;
use goose::agents::Agent;
use goose::agents::{Agent, AgentEvent};
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::Provider;
@@ -132,7 +132,10 @@ async fn run_truncate_test(
let mut responses = Vec::new();
while let Some(response_result) = reply_stream.next().await {
match response_result {
Ok(response) => responses.push(response),
Ok(AgentEvent::Message(response)) => responses.push(response),
Ok(AgentEvent::McpNotification(n)) => {
println!("MCP Notification: {n:?}");
}
Err(e) => {
println!("Error: {:?}", e);
return Err(e);

View File

@@ -4,7 +4,7 @@ use goose::message::{Message, MessageContent};
use goose::providers::base::Provider;
use goose::providers::errors::ProviderError;
use goose::providers::{
anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter,
anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter, snowflake,
};
use mcp_core::content::Content;
use mcp_core::tool::Tool;
@@ -491,6 +491,17 @@ async fn test_google_provider() -> Result<()> {
.await
}
#[tokio::test]
async fn test_snowflake_provider() -> Result<()> {
test_provider(
"Snowflake",
&["SNOWFLAKE_HOST", "SNOWFLAKE_TOKEN"],
None,
snowflake::SnowflakeProvider::default,
)
.await
}
// Print the final test report
#[ctor::dtor]
fn print_test_report() {

View File

@@ -1,7 +1,6 @@
use mcp_client::{
client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
transport::{SseTransport, StdioTransport, Transport},
McpService,
};
use rand::Rng;
use rand::SeedableRng;
@@ -20,18 +19,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle1 = transport1.start().await?;
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
let client1 = McpClient::new(service1);
let client1 = McpClient::connect(handle1, Duration::from_secs(30)).await?;
let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle2 = transport2.start().await?;
let service2 = McpService::with_timeout(handle2, Duration::from_secs(30));
let client2 = McpClient::new(service2);
let client2 = McpClient::connect(handle2, Duration::from_secs(30)).await?;
let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
let handle3 = transport3.start().await?;
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
let client3 = McpClient::new(service3);
let client3 = McpClient::connect(handle3, Duration::from_secs(10)).await?;
// Initialize both clients
let mut clients: Vec<Box<dyn McpClientTrait>> =

View File

@@ -0,0 +1,122 @@
use anyhow::Result;
use futures::lock::Mutex;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::StdioTransport;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
#[tokio::main]
async fn main() -> Result<()> {
// Initialize logging
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::from_default_env()
.add_directive("mcp_client=debug".parse().unwrap())
.add_directive("eventsource_client=info".parse().unwrap()),
)
.init();
test_transport(sse_transport().await?).await?;
test_transport(stdio_transport().await?).await?;
Ok(())
}
async fn sse_transport() -> Result<SseTransport> {
let port = "60053";
tokio::process::Command::new("npx")
.env("PORT", port)
.arg("@modelcontextprotocol/server-everything")
.arg("sse")
.spawn()?;
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(SseTransport::new(
format!("http://localhost:{}/sse", port),
HashMap::new(),
))
}
async fn stdio_transport() -> Result<StdioTransport> {
Ok(StdioTransport::new(
"npx",
vec!["@modelcontextprotocol/server-everything"]
.into_iter()
.map(|s| s.to_string())
.collect(),
HashMap::new(),
))
}
async fn test_transport<T>(transport: T) -> Result<()>
where
T: Transport + Send + 'static,
{
// Start transport
let handle = transport.start().await?;
// Create client
let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?;
println!("Client created\n");
let mut receiver = client.subscribe().await;
let events = Arc::new(Mutex::new(Vec::new()));
let events_clone = events.clone();
tokio::spawn(async move {
while let Some(event) = receiver.recv().await {
println!("Received event: {event:?}");
events_clone.lock().await.push(event);
}
});
// Initialize
let server_info = client
.initialize(
ClientInfo {
name: "test-client".into(),
version: "1.0.0".into(),
},
ClientCapabilities::default(),
)
.await?;
println!("Connected to server: {server_info:?}\n");
// Sleep for 100ms to allow the server to start - surprisingly this is required!
tokio::time::sleep(Duration::from_millis(500)).await;
// List tools
let tools = client.list_tools(None).await?;
println!("Available tools: {tools:#?}\n");
// Call tool
let tool_result = client
.call_tool("echo", serde_json::json!({ "message": "honk" }))
.await?;
println!("Tool result: {tool_result:#?}\n");
let collected_eventes_before = events.lock().await.len();
let n_steps = 5;
let long_op = client
.call_tool(
"longRunningOperation",
serde_json::json!({ "duration": 3, "steps": n_steps }),
)
.await?;
println!("Long op result: {long_op:#?}\n");
let collected_events_after = events.lock().await.len();
assert_eq!(collected_events_after - collected_eventes_before, n_steps);
// List resources
let resources = client.list_resources(None).await?;
println!("Resources: {resources:#?}\n");
// Read resource
let resource = client.read_resource("test://static/resource/1").await?;
println!("Resource: {resource:#?}\n");
Ok(())
}

View File

@@ -1,7 +1,6 @@
use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
@@ -23,11 +22,8 @@ async fn main() -> Result<()> {
// Start transport
let handle = transport.start().await?;
// Create the service with timeout middleware
let service = McpService::with_timeout(handle, Duration::from_secs(3));
// Create client
let mut client = McpClient::new(service);
let mut client = McpClient::connect(handle, Duration::from_secs(3)).await?;
println!("Client created\n");
// Initialize

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use anyhow::Result;
use mcp_client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService,
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
StdioTransport, Transport,
};
use std::time::Duration;
@@ -25,11 +25,8 @@ async fn main() -> Result<(), ClientError> {
// 2) Start the transport to get a handle
let transport_handle = transport.start().await?;
// 3) Create the service with timeout middleware
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));
// 4) Create the client with the middleware-wrapped service
let mut client = McpClient::new(service);
// 3) Create the client with the middleware-wrapped service
let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
// Initialize
let server_info = client

Some files were not shown because too many files have changed in this diff Show More