Mirror of https://github.com/aljazceru/goose.git, synced 2025-12-18 14:44:21 +01:00

Commit: merge
.github/workflows/build-cli.yml (vendored, 18 changes)

@@ -20,6 +20,10 @@ on:
        type: string
        required: false
        default: '["x86_64","aarch64"]'
      ref:
        type: string
        required: false
        default: 'refs/heads/main'

name: "Reusable workflow to build CLI"

@@ -41,6 +45,9 @@ jobs:
    steps:
      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4
        with:
          ref: ${{ inputs.ref }}
          fetch-depth: 0

      - name: Update version in Cargo.toml
        if: ${{ inputs.version != '' }}
@@ -48,14 +55,8 @@ jobs:
          sed -i.bak 's/^version = ".*"/version = "'${{ inputs.version }}'"/' Cargo.toml
          rm -f Cargo.toml.bak

      - name: Setup Rust
        uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
        with:
          toolchain: stable
          target: ${{ matrix.architecture }}-${{ matrix.target-suffix }}

      - name: Install cross
        run: cargo install cross --git https://github.com/cross-rs/cross
        run: source ./bin/activate-hermit && cargo install cross --git https://github.com/cross-rs/cross

      - name: Build CLI
        env:
@@ -64,6 +65,7 @@ jobs:
          RUST_BACKTRACE: 1
          CROSS_VERBOSE: 1
        run: |
          source ./bin/activate-hermit
          export TARGET="${{ matrix.architecture }}-${{ matrix.target-suffix }}"
          rustup target add "${TARGET}"
          echo "Building for target: ${TARGET}"
@@ -72,7 +74,7 @@ jobs:
          echo "Cross version:"
          cross --version

          # 'cross' is used to cross-compile for different architectures (see Cross.toml)
          echo "Building with explicit PROTOC path..."
          cross build --release --target ${TARGET} -p goose-cli -vv

          # tar the goose binary as goose-<TARGET>.tar.bz2
.github/workflows/bundle-desktop-intel.yml (vendored, 26 changes)

@@ -21,6 +21,10 @@ on:
        required: false
        default: true
        type: boolean
      ref:
        type: string
        required: false
        default: 'refs/heads/main'
    secrets:
      CERTIFICATE_OSX_APPLICATION:
        description: 'Certificate for macOS application signing'
@@ -77,6 +81,9 @@ jobs:

      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
        with:
          ref: ${{ inputs.ref }}
          fetch-depth: 0

      # Update versions before build
      - name: Update versions
@@ -87,18 +94,14 @@ jobs:
          rm -f Cargo.toml.bak

          # Update version in package.json
          source ./bin/activate-hermit
          cd ui/desktop
          npm version ${{ inputs.version }} --no-git-tag-version --allow-same-version

      - name: Setup Rust
        uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
        with:
          toolchain: stable
          targets: x86_64-apple-darwin

      # Pre-build cleanup to ensure enough disk space
      - name: Pre-build cleanup
        run: |
          source ./bin/activate-hermit
          echo "Performing pre-build cleanup..."
          # Clean npm cache
          npm cache clean --force || true
@@ -137,7 +140,7 @@ jobs:

      # Build specifically for Intel architecture
      - name: Build goosed for Intel
        run: cargo build --release -p goose-server --target x86_64-apple-darwin
        run: source ./bin/activate-hermit && cargo build --release -p goose-server --target x86_64-apple-darwin

      # Post-build cleanup to free space
      - name: Post-build cleanup
@@ -164,13 +167,8 @@ jobs:
          CERTIFICATE_OSX_APPLICATION: ${{ secrets.CERTIFICATE_OSX_APPLICATION }}
          CERTIFICATE_PASSWORD: ${{ secrets.CERTIFICATE_PASSWORD }}

      - name: Set up Node.js
        uses: actions/setup-node@7c12f8017d5436eb855f1ed4399f037a36fbd9e8 # pin@v2
        with:
          node-version: 'lts/*'

      - name: Install dependencies
        run: npm ci
        run: source ../../bin/activate-hermit && npm ci
        working-directory: ui/desktop

      # Configure Electron builder for Intel architecture
@@ -187,6 +185,7 @@ jobs:
      - name: Make Unsigned App
        if: ${{ !inputs.signing }}
        run: |
          source ../../bin/activate-hermit
          attempt=0
          max_attempts=2
          until [ $attempt -ge $max_attempts ]; do
@@ -204,6 +203,7 @@ jobs:
      - name: Make Signed App
        if: ${{ inputs.signing }}
        run: |
          source ../../bin/activate-hermit
          attempt=0
          max_attempts=2
          until [ $attempt -ge $max_attempts ]; do
.github/workflows/bundle-desktop-windows.yml (vendored, 160 changes)

@@ -17,32 +17,31 @@ on:
        required: false
      WINDOWS_CERTIFICATE_PASSWORD:
        required: false
      ref:
        type: string
        required: false
        default: 'refs/heads/main'

jobs:
  build-desktop-windows:
    name: Build Desktop (Windows)
    runs-on: windows-latest
    runs-on: ubuntu-latest # Use Ubuntu for cross-compilation

    steps:
      # 1) Check out source
      - name: Checkout repository
        uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744
        with:
          ref: ${{ inputs.ref }}
          fetch-depth: 0

      # 2) Set up Rust
      - name: Set up Rust
        uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b
        # If you need a specific version, you could do:
        # or uses: actions/setup-rust@v1
        # with:
        #   rust-version: 1.73.0

      # 3) Set up Node.js
      # 2) Set up Node.js
      - name: Set up Node.js
        uses: actions/setup-node@1a4442cacd436585916779262731d5b162bc6ec7 # pin@v3
        with:
          node-version: 16
          node-version: 18

      # 4) Cache dependencies (optional, can add more paths if needed)
      # 3) Cache dependencies
      - name: Cache node_modules
        uses: actions/cache@2f8e54208210a422b2efd51efaa6bd6d7ca8920f # pin@v3
        with:
@@ -53,103 +52,92 @@ jobs:
          restore-keys: |
            ${{ runner.os }}-build-desktop-windows-

      # 5) Install top-level dependencies if a package.json is in root
      - name: Install top-level deps
      # 4) Build Rust for Windows using Docker (cross-compilation)
      - name: Build Windows executable using Docker
        run: |
          if (Test-Path package.json) {
            npm install
          }
          echo "Building Windows executable using Docker cross-compilation..."
          docker volume create goose-windows-cache || true
          docker run --rm \
            -v "$(pwd)":/usr/src/myapp \
            -v goose-windows-cache:/usr/local/cargo/registry \
            -w /usr/src/myapp \
            rust:latest \
            sh -c "rustup target add x86_64-pc-windows-gnu && \
                   apt-get update && \
                   apt-get install -y mingw-w64 protobuf-compiler cmake && \
                   export CC_x86_64_pc_windows_gnu=x86_64-w64-mingw32-gcc && \
                   export CXX_x86_64_pc_windows_gnu=x86_64-w64-mingw32-g++ && \
                   export AR_x86_64_pc_windows_gnu=x86_64-w64-mingw32-ar && \
                   export CARGO_TARGET_X86_64_PC_WINDOWS_GNU_LINKER=x86_64-w64-mingw32-gcc && \
                   export PKG_CONFIG_ALLOW_CROSS=1 && \
                   export PROTOC=/usr/bin/protoc && \
                   export PATH=/usr/bin:\$PATH && \
                   protoc --version && \
                   cargo build --release --target x86_64-pc-windows-gnu && \
                   GCC_DIR=\$(ls -d /usr/lib/gcc/x86_64-w64-mingw32/*/ | head -n 1) && \
                   cp \$GCC_DIR/libstdc++-6.dll /usr/src/myapp/target/x86_64-pc-windows-gnu/release/ && \
                   cp \$GCC_DIR/libgcc_s_seh-1.dll /usr/src/myapp/target/x86_64-pc-windows-gnu/release/ && \
                   cp /usr/x86_64-w64-mingw32/lib/libwinpthread-1.dll /usr/src/myapp/target/x86_64-pc-windows-gnu/release/"

      # 6) Build rust for x86_64-pc-windows-gnu
      - name: Install MinGW dependencies
        run: |
          choco install mingw --version=8.1.0
          # Debug - check installation paths
          Write-Host "Checking MinGW installation..."
          Get-ChildItem -Path "C:\ProgramData\chocolatey\lib\mingw" -Recurse -Filter "*.dll" | ForEach-Object {
            Write-Host $_.FullName
          }
          Get-ChildItem -Path "C:\tools" -Recurse -Filter "*.dll" | ForEach-Object {
            Write-Host $_.FullName
          }

      - name: Cargo build for Windows
        run: |
          cargo build --release --target x86_64-pc-windows-gnu

      # 7) Check that the compiled goosed.exe exists and copy exe/dll to ui/desktop/src/bin
      # 5) Prepare Windows binary and DLLs
      - name: Prepare Windows binary and DLLs
        run: |
          if (!(Test-Path .\target\x86_64-pc-windows-gnu\release\goosed.exe)) {
            Write-Error "Windows binary not found."; exit 1;
          }
          Write-Host "Copying Windows binary and DLLs to ui/desktop/src/bin..."
          if (!(Test-Path ui\desktop\src\bin)) {
            New-Item -ItemType Directory -Path ui\desktop\src\bin | Out-Null
          }
          Copy-Item .\target\x86_64-pc-windows-gnu\release\goosed.exe ui\desktop\src\bin\
          if [ ! -f "./target/x86_64-pc-windows-gnu/release/goosed.exe" ]; then
            echo "Windows binary not found."
            exit 1
          fi

          # Copy MinGW DLLs - try both possible locations
          $mingwPaths = @(
            "C:\ProgramData\chocolatey\lib\mingw\tools\install\mingw64\bin",
            "C:\tools\mingw64\bin"
          )
          echo "Cleaning destination directory..."
          rm -rf ./ui/desktop/src/bin
          mkdir -p ./ui/desktop/src/bin

          foreach ($path in $mingwPaths) {
            if (Test-Path "$path\libstdc++-6.dll") {
              Write-Host "Found MinGW DLLs in $path"
              Copy-Item "$path\libstdc++-6.dll" ui\desktop\src\bin\
              Copy-Item "$path\libgcc_s_seh-1.dll" ui\desktop\src\bin\
              Copy-Item "$path\libwinpthread-1.dll" ui\desktop\src\bin\
              break
            }
          }
          echo "Copying Windows binary and DLLs..."
          cp -f ./target/x86_64-pc-windows-gnu/release/goosed.exe ./ui/desktop/src/bin/
          cp -f ./target/x86_64-pc-windows-gnu/release/*.dll ./ui/desktop/src/bin/

          # Copy any other DLLs from the release directory
          ls .\target\x86_64-pc-windows-gnu\release\*.dll | ForEach-Object {
            Copy-Item $_ ui\desktop\src\bin\
          }
          # Copy Windows platform files (tools, scripts, etc.)
          if [ -d "./ui/desktop/src/platform/windows/bin" ]; then
            echo "Copying Windows platform files..."
            for file in ./ui/desktop/src/platform/windows/bin/*.{exe,dll,cmd}; do
              if [ -f "$file" ] && [ "$(basename "$file")" != "goosed.exe" ]; then
                cp -f "$file" ./ui/desktop/src/bin/
              fi
            done

      # 8) Install & build UI desktop
            if [ -d "./ui/desktop/src/platform/windows/bin/goose-npm" ]; then
              echo "Setting up npm environment..."
              rsync -a --delete ./ui/desktop/src/platform/windows/bin/goose-npm/ ./ui/desktop/src/bin/goose-npm/
            fi
            echo "Windows-specific files copied successfully"
          fi

      # 6) Install & build UI desktop
      - name: Build desktop UI with npm
        run: |
          cd ui\desktop
          cd ui/desktop
          npm install
          npm run bundle:windows

      # 9) Copy exe/dll to final out/Goose-win32-x64/resources/bin
      # 7) Copy exe/dll to final out/Goose-win32-x64/resources/bin
      - name: Copy exe/dll to out folder
        run: |
          cd ui\desktop
          if (!(Test-Path .\out\Goose-win32-x64\resources\bin)) {
            New-Item -ItemType Directory -Path .\out\Goose-win32-x64\resources\bin | Out-Null
          }
          Copy-Item .\src\bin\goosed.exe .\out\Goose-win32-x64\resources\bin\
          ls .\src\bin\*.dll | ForEach-Object {
            Copy-Item $_ .\out\Goose-win32-x64\resources\bin\
          }
          cd ui/desktop
          mkdir -p ./out/Goose-win32-x64/resources/bin
          rsync -av src/bin/ out/Goose-win32-x64/resources/bin/

      # 10) Code signing (if enabled)
      # 8) Code signing (if enabled)
      - name: Sign Windows executable
        # Skip this step by default - enable when we have a certificate
        if: inputs.signing && inputs.signing == true
        env:
          WINDOWS_CERTIFICATE: ${{ secrets.WINDOWS_CERTIFICATE }}
          WINDOWS_CERTIFICATE_PASSWORD: ${{ secrets.WINDOWS_CERTIFICATE_PASSWORD }}
        run: |
          # Create a temporary certificate file
          $certBytes = [Convert]::FromBase64String($env:WINDOWS_CERTIFICATE)
          $certPath = Join-Path -Path $env:RUNNER_TEMP -ChildPath "certificate.pfx"
          [IO.File]::WriteAllBytes($certPath, $certBytes)
          # Note: This would need to be adapted for Linux-based signing
          # or moved to a Windows runner for the signing step only
          echo "Code signing would be implemented here"
          echo "Currently skipped as we're running on Ubuntu"

          # Sign the main executable
          $signtool = "C:\Program Files (x86)\Windows Kits\10\bin\10.0.17763.0\x64\signtool.exe"
          & $signtool sign /f $certPath /p $env:WINDOWS_CERTIFICATE_PASSWORD /tr http://timestamp.digicert.com /td sha256 /fd sha256 "ui\desktop\out\Goose-win32-x64\Goose.exe"

          # Clean up the certificate
          Remove-Item -Path $certPath

      # 11) Upload the final Windows build
      # 9) Upload the final Windows build
      - name: Upload Windows build artifacts
        uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # pin@v4
        with:
.github/workflows/bundle-desktop.yml (vendored, 23 changes)

@@ -135,6 +135,7 @@ jobs:
          sed -i.bak "s/^version = \".*\"/version = \"${VERSION}\"/" Cargo.toml
          rm -f Cargo.toml.bak

          source ./bin/activate-hermit
          # Update version in package.json
          cd ui/desktop
          npm version "${VERSION}" --no-git-tag-version --allow-same-version
@@ -142,6 +143,7 @@ jobs:
      # Pre-build cleanup to ensure enough disk space
      - name: Pre-build cleanup
        run: |
          source ./bin/activate-hermit
          echo "Performing pre-build cleanup..."
          # Clean npm cache
          npm cache clean --force || true
@@ -154,16 +156,6 @@ jobs:
          # Check disk space after cleanup
          df -h

      - name: Install protobuf
        run: |
          brew install protobuf
          echo "PROTOC=$(which protoc)" >> $GITHUB_ENV

      - name: Setup Rust
        uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
        with:
          toolchain: stable

      - name: Cache Cargo registry
        uses: actions/cache@2f8e54208210a422b2efd51efaa6bd6d7ca8920f # pin@v3
        with:
@@ -190,7 +182,7 @@ jobs:

      # Build the project
      - name: Build goosed
        run: cargo build --release -p goose-server
        run: source ./bin/activate-hermit && cargo build --release -p goose-server

      # Post-build cleanup to free space
      - name: Post-build cleanup
@@ -216,13 +208,8 @@ jobs:
          CERTIFICATE_OSX_APPLICATION: ${{ secrets.CERTIFICATE_OSX_APPLICATION }}
          CERTIFICATE_PASSWORD: ${{ secrets.CERTIFICATE_PASSWORD }}

      - name: Set up Node.js
        uses: actions/setup-node@7c12f8017d5436eb855f1ed4399f037a36fbd9e8 # pin@v2
        with:
          node-version: 'lts/*'

      - name: Install dependencies
        run: npm ci
        run: source ../../bin/activate-hermit && npm ci
        working-directory: ui/desktop

      # Check disk space before bundling
@@ -232,6 +219,7 @@ jobs:
      - name: Make Unsigned App
        if: ${{ !inputs.signing }}
        run: |
          source ../../bin/activate-hermit
          attempt=0
          max_attempts=2
          until [ $attempt -ge $max_attempts ]; do
@@ -253,6 +241,7 @@ jobs:
          APPLE_ID_PASSWORD: ${{ secrets.APPLE_ID_PASSWORD }}
          APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
        run: |

          attempt=0
          max_attempts=2
          until [ $attempt -ge $max_attempts ]; do
.github/workflows/ci.yml (vendored, 34 changes)

@@ -5,6 +5,9 @@ on:
  pull_request:
    branches:
      - main
  merge_group:
    branches:
      - main
  workflow_dispatch:

name: CI

@@ -17,19 +20,14 @@ jobs:
      - name: Checkout Code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4

      - name: Setup Rust
        uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
        with:
          toolchain: stable

      - name: Run cargo fmt
        run: cargo fmt --check
        run: source ./bin/activate-hermit && cargo fmt --check

  rust-build-and-test:
    name: Build and Test Rust Project
    runs-on: ubuntu-latest
    steps:
      # Add disk space cleanup before linting
      - name: Check disk space before build
        run: df -h

@@ -57,12 +55,7 @@ jobs:
      - name: Install Dependencies
        run: |
          sudo apt update -y
          sudo apt install -y libdbus-1-dev gnome-keyring libxcb1-dev protobuf-compiler

      - name: Setup Rust
        uses: dtolnay/rust-toolchain@38b70195107dddab2c7bbd522bcf763bac00963b # pin@stable
        with:
          toolchain: stable
          sudo apt install -y libdbus-1-dev gnome-keyring libxcb1-dev

      - name: Cache Cargo Registry
        uses: actions/cache@2f8e54208210a422b2efd51efaa6bd6d7ca8920f # pin@v3
@@ -91,7 +84,7 @@ jobs:
      - name: Build and Test
        run: |
          gnome-keyring-daemon --components=secrets --daemonize --unlock <<< 'foobar'
          cargo test
          source ../bin/activate-hermit && cargo test
        working-directory: crates

      # Add disk space cleanup before linting
@@ -120,7 +113,7 @@ jobs:
        run: df -h

      - name: Lint
        run: cargo clippy -- -D warnings
        run: source ./bin/activate-hermit && cargo clippy -- -D warnings

  desktop-lint:
    name: Lint Electron Desktop App
@@ -129,22 +122,17 @@ jobs:
      - name: Checkout Code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4

      - name: Set up Node.js
        uses: actions/setup-node@7c12f8017d5436eb855f1ed4399f037a36fbd9e8 # pin@v2
        with:
          node-version: "lts/*"

      - name: Install Dependencies
        run: npm ci
        run: source ../../bin/activate-hermit && npm ci
        working-directory: ui/desktop

      - name: Run Lint
        run: npm run lint:check
        run: source ../../bin/activate-hermit && npm run lint:check
        working-directory: ui/desktop

  # Faster Desktop App build for PRs only
  bundle-desktop-unsigned:
    uses: ./.github/workflows/bundle-desktop.yml
    if: github.event_name == 'pull_request'
    if: github.event_name == 'pull_request' || github.event_name == 'merge_group'
    with:
      signing: false
.github/workflows/pr-comment-build-cli.yml (vendored, 17 changes)

@@ -27,6 +27,7 @@ jobs:
    outputs:
      continue: ${{ steps.command.outputs.continue || github.event_name == 'workflow_dispatch' }}
      pr_number: ${{ steps.command.outputs.issue_number || github.event.inputs.pr_number }}
      head_sha: ${{ steps.set_head_sha.outputs.head_sha || github.sha }}
    steps:
      - if: ${{ github.event_name == 'issue_comment' }}
        uses: github/command@v1.3.0
@@ -37,10 +38,26 @@ jobs:
          reaction: "eyes"
          allowed_contexts: pull_request

      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4

      - name: Get PR head SHA with gh
        id: set_head_sha
        run: |
          echo "Get PR head SHA with gh"
          HEAD_SHA=$(gh pr view "$ISSUE_NUMBER" --json headRefOid -q .headRefOid)
          echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
          echo "head_sha=$HEAD_SHA"
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          ISSUE_NUMBER: ${{ steps.command.outputs.issue_number }}

  build-cli:
    needs: [trigger-on-command]
    if: ${{ needs.trigger-on-command.outputs.continue == 'true' }}
    uses: ./.github/workflows/build-cli.yml
    with:
      ref: ${{ needs.trigger-on-command.outputs.head_sha }}

  pr-comment-cli:
    name: PR Comment with CLI builds
.github/workflows/pr-comment-bundle-intel.yml (vendored, 16 changes)

@@ -30,6 +30,7 @@ jobs:
      continue: ${{ steps.command.outputs.continue || github.event_name == 'workflow_dispatch' }}
      # Cannot use github.event.pull_request.number since the trigger is 'issue_comment'
      pr_number: ${{ steps.command.outputs.issue_number || github.event.inputs.pr_number }}
      head_sha: ${{ steps.set_head_sha.outputs.head_sha || github.sha }}
    steps:
      - if: ${{ github.event_name == 'issue_comment' }}
        uses: github/command@319d5236cc34ed2cb72a47c058a363db0b628ebe # pin@v1.3.0
@@ -40,6 +41,20 @@ jobs:
          reaction: "eyes"
          allowed_contexts: pull_request

      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4

      - name: Get PR head SHA with gh
        id: set_head_sha
        run: |
          echo "Get PR head SHA with gh"
          HEAD_SHA=$(gh pr view "$ISSUE_NUMBER" --json headRefOid -q .headRefOid)
          echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
          echo "head_sha=$HEAD_SHA"
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          ISSUE_NUMBER: ${{ steps.command.outputs.issue_number }}

  bundle-desktop-intel:
    # Only run this if ".bundle-intel" command is detected.
    needs: [trigger-on-command]
@@ -47,6 +62,7 @@ jobs:
    uses: ./.github/workflows/bundle-desktop-intel.yml
    with:
      signing: true
      ref: ${{ needs.trigger-on-command.outputs.head_sha }}
    secrets:
      CERTIFICATE_OSX_APPLICATION: ${{ secrets.CERTIFICATE_OSX_APPLICATION }}
      CERTIFICATE_PASSWORD: ${{ secrets.CERTIFICATE_PASSWORD }}
.github/workflows/pr-comment-bundle-windows.yml (vendored, 16 changes)

@@ -30,6 +30,7 @@ jobs:
      continue: ${{ steps.command.outputs.continue || github.event_name == 'workflow_dispatch' }}
      # Cannot use github.event.pull_request.number since the trigger is 'issue_comment'
      pr_number: ${{ steps.command.outputs.issue_number || github.event.inputs.pr_number }}
      head_sha: ${{ steps.set_head_sha.outputs.head_sha || github.sha }}
    steps:
      - if: ${{ github.event_name == 'issue_comment' }}
        uses: github/command@319d5236cc34ed2cb72a47c058a363db0b628ebe # pin@v1.3.0
@@ -40,6 +41,20 @@ jobs:
          reaction: "eyes"
          allowed_contexts: pull_request

      - name: Checkout code
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4

      - name: Get PR head SHA with gh
        id: set_head_sha
        run: |
          echo "Get PR head SHA with gh"
          HEAD_SHA=$(gh pr view "$ISSUE_NUMBER" --json headRefOid -q .headRefOid)
          echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
          echo "head_sha=$HEAD_SHA"
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          ISSUE_NUMBER: ${{ steps.command.outputs.issue_number }}

  bundle-desktop-windows:
    # Only run this if ".bundle-windows" command is detected.
    needs: [trigger-on-command]
@@ -47,6 +62,7 @@ jobs:
    uses: ./.github/workflows/bundle-desktop-windows.yml
    with:
      signing: false # false for now as we don't have a cert yet
      ref: ${{ needs.trigger-on-command.outputs.head_sha }}
    secrets:
      WINDOWS_CERTIFICATE: ${{ secrets.WINDOWS_CERTIFICATE }}
      WINDOWS_CERTIFICATE_PASSWORD: ${{ secrets.WINDOWS_CERTIFICATE_PASSWORD }}
.gitignore (vendored, 7 changes)

@@ -23,9 +23,12 @@ target/
./ui/desktop/node_modules
./ui/desktop/out

# Generated Goose DLLs (built at build time, not checked in)
ui/desktop/src/bin/goose_ffi.dll
ui/desktop/src/bin/goose_llm.dll

# Hermit
/.hermit/
/bin/
.hermit/

debug_*.txt
@@ -2,12 +2,20 @@

# Only auto-format desktop TS code if relevant files are modified
if git diff --cached --name-only | grep -q "^ui/desktop/"; then
  . "$(dirname -- "$0")/_/husky.sh"
  cd ui/desktop && npx lint-staged
  if [ -d "ui/desktop" ]; then
    . "$(dirname -- "$0")/_/husky.sh"
    cd ui/desktop && npx lint-staged
  else
    echo "Warning: ui/desktop directory does not exist, skipping lint-staged"
  fi
fi

# Only auto-format ui-v2 TS code if relevant files are modified
if git diff --cached --name-only | grep -q "^ui-v2/"; then
  . "$(dirname -- "$0")/_/husky.sh"
  cd ui-v2 && npx lint-staged
  if [ -d "ui-v2" ]; then
    . "$(dirname -- "$0")/_/husky.sh"
    cd ui-v2 && npx lint-staged
  else
    echo "Warning: ui-v2 directory does not exist, skipping lint-staged"
  fi
fi
Cargo.lock (generated, 72 changes)

@@ -2083,8 +2083,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crunchy"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
source = "git+https://github.com/nmathewson/crunchy?branch=cross-compilation-fix#260ec5f08969480c342bb3fe47f88870ed5c6cce"

[[package]]
name = "crypto-common"
@@ -3452,12 +3451,13 @@ dependencies = [
 "tokenizers",
 "tokio",
 "tokio-cron-scheduler",
 "tokio-stream",
 "tracing",
 "tracing-subscriber",
 "url",
 "utoipa",
 "uuid",
 "webbrowser",
 "webbrowser 0.8.15",
 "winapi",
 "wiremock",
]
@@ -3514,8 +3514,10 @@ version = "1.0.24"
dependencies = [
 "anyhow",
 "async-trait",
 "axum",
 "base64 0.22.1",
 "bat",
 "bytes",
 "chrono",
 "clap 4.5.31",
 "cliclack",
@@ -3525,6 +3527,8 @@ dependencies = [
 "goose",
 "goose-bench",
 "goose-mcp",
 "http 1.2.0",
 "indicatif",
 "mcp-client",
 "mcp-core",
 "mcp-server",
@@ -3544,9 +3548,12 @@ dependencies = [
 "tempfile",
 "test-case",
 "tokio",
 "tokio-stream",
 "tower-http",
 "tracing",
 "tracing-appender",
 "tracing-subscriber",
 "webbrowser 1.0.4",
 "winapi",
]

@@ -3639,7 +3646,7 @@ dependencies = [
 "url",
 "urlencoding",
 "utoipa",
 "webbrowser",
 "webbrowser 0.8.15",
 "xcap",
]

@@ -3941,6 +3948,12 @@ dependencies = [
 "pin-project-lite",
]

[[package]]
name = "http-range-header"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c"

[[package]]
name = "httparse"
version = "1.10.1"
@@ -5902,6 +5915,31 @@ dependencies = [
 "malloc_buf",
]

[[package]]
name = "objc2"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88c6597e14493ab2e44ce58f2fdecf095a51f12ca57bec060a11c57332520551"
dependencies = [
 "objc2-encode",
]

[[package]]
name = "objc2-encode"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"

[[package]]
name = "objc2-foundation"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c"
dependencies = [
 "bitflags 2.9.0",
 "objc2",
]

[[package]]
name = "object"
version = "0.36.7"
@@ -8727,12 +8765,21 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
 "bitflags 2.9.0",
 "bytes",
 "futures-util",
 "http 1.2.0",
 "http-body 1.0.1",
 "http-body-util",
 "http-range-header",
 "httpdate",
 "mime",
 "mime_guess",
 "percent-encoding",
 "pin-project-lite",
 "tokio",
 "tokio-util",
 "tower-layer",
 "tower-service",
 "tracing",
]

[[package]]
@@ -9420,6 +9467,23 @@ dependencies = [
 "web-sys",
]

[[package]]
name = "webbrowser"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5df295f8451142f1856b1bd86a606dfe9587d439bc036e319c827700dbd555e"
dependencies = [
 "core-foundation 0.10.0",
 "home",
 "jni",
 "log",
 "ndk-context",
 "objc2",
 "objc2-foundation",
 "url",
 "web-sys",
]

[[package]]
name = "webpki-roots"
version = "0.26.8"
@@ -9,3 +9,7 @@ authors = ["Block <ai-oss-tools@block.xyz>"]
license = "Apache-2.0"
repository = "https://github.com/block/goose"
description = "An AI agent"

# Patch for Windows cross-compilation issue with crunchy
[patch.crates-io]
crunchy = { git = "https://github.com/nmathewson/crunchy", branch = "cross-compilation-fix" }
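Whether the patch actually took effect can be checked locally; the sketch below relies only on standard Cargo behavior, nothing goose-specific:

```bash
# Re-resolve crunchy against the [patch.crates-io] entry
cargo update -p crunchy
# The lock file should now point at the git fork rather than crates.io
grep -A 3 '^name = "crunchy"' Cargo.lock
```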
Cross.toml (44 changes)

@@ -6,29 +6,59 @@ pre-build = [
    "dpkg --add-architecture arm64",
    """\
    apt-get update --fix-missing && apt-get install -y \
        curl \
        unzip \
        pkg-config \
        libssl-dev:arm64 \
        libdbus-1-dev:arm64 \
        libxcb1-dev:arm64
    """,
    """\
    curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v31.1/protoc-31.1-linux-x86_64.zip && \
    unzip -o protoc-31.1-linux-x86_64.zip -d /usr/local && \
    chmod +x /usr/local/bin/protoc && \
    ln -sf /usr/local/bin/protoc /usr/bin/protoc && \
    which protoc && \
    protoc --version
    """
]

[target.x86_64-unknown-linux-gnu]
xargo = false
pre-build = [
    # Install necessary dependencies for x86_64
    # We don't need architecture-specific flags because x86_64 dependencies are installable on Ubuntu system
    """\
    apt-get update && apt-get install -y \
        curl \
        unzip \
        pkg-config \
        libssl-dev \
        libdbus-1-dev \
        libxcb1-dev \
        libxcb1-dev
    """,
    """\
    curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v31.1/protoc-31.1-linux-x86_64.zip && \
    unzip -o protoc-31.1-linux-x86_64.zip -d /usr/local && \
    chmod +x /usr/local/bin/protoc && \
    ln -sf /usr/local/bin/protoc /usr/bin/protoc && \
    which protoc && \
    protoc --version
    """
]

[target.x86_64-pc-windows-gnu]
image = "dockcross/windows-static-x64:latest"
# Enable verbose output for Windows builds
build-std = true
env = { "RUST_LOG" = "debug", "RUST_BACKTRACE" = "1", "CROSS_VERBOSE" = "1" }
image = "ghcr.io/cross-rs/x86_64-pc-windows-gnu:latest"
env = { "RUST_LOG" = "debug", "RUST_BACKTRACE" = "1", "CROSS_VERBOSE" = "1", "PKG_CONFIG_ALLOW_CROSS" = "1" }
pre-build = [
    """\
    apt-get update && apt-get install -y \
        curl \
        unzip
    """,
    """\
    curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v31.1/protoc-31.1-linux-x86_64.zip && \
    unzip protoc-31.1-linux-x86_64.zip -d /usr/local && \
    chmod +x /usr/local/bin/protoc && \
    export PROTOC=/usr/local/bin/protoc && \
    protoc --version && \
    """
]
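The same cross-compilation path can be exercised locally with the commands the build-cli workflow above already uses; this is a sketch assuming Docker is available, since `cross` runs these pre-build hooks inside the target's container:

```bash
# Install cross from git, as the workflow does
cargo install cross --git https://github.com/cross-rs/cross
# Build the CLI for one of the configured targets; the Cross.toml pre-build
# hooks install protoc and the other dependencies inside the container
cross build --release --target x86_64-unknown-linux-gnu -p goose-cli -vv
```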
Justfile (10 changes)

@@ -25,7 +25,15 @@ release-windows:
        rust:latest \
        sh -c "rustup target add x86_64-pc-windows-gnu && \
               apt-get update && \
               apt-get install -y mingw-w64 && \
               apt-get install -y mingw-w64 protobuf-compiler cmake && \
               export CC_x86_64_pc_windows_gnu=x86_64-w64-mingw32-gcc && \
               export CXX_x86_64_pc_windows_gnu=x86_64-w64-mingw32-g++ && \
               export AR_x86_64_pc_windows_gnu=x86_64-w64-mingw32-ar && \
               export CARGO_TARGET_X86_64_PC_WINDOWS_GNU_LINKER=x86_64-w64-mingw32-gcc && \
               export PKG_CONFIG_ALLOW_CROSS=1 && \
               export PROTOC=/usr/bin/protoc && \
               export PATH=/usr/bin:\$PATH && \
               protoc --version && \
               cargo build --release --target x86_64-pc-windows-gnu && \
               GCC_DIR=\$(ls -d /usr/lib/gcc/x86_64-w64-mingw32/*/ | head -n 1) && \
               cp \$GCC_DIR/libstdc++-6.dll /usr/src/myapp/target/x86_64-pc-windows-gnu/release/ && \
README.md (25 changes)

@@ -23,6 +23,31 @@ Whether you're prototyping an idea, refining existing code, or managing intricat

Designed for maximum flexibility, goose works with any LLM, seamlessly integrates with MCP servers, and is available as both a desktop app and a CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation.

## Multiple Model Configuration

goose supports using different models for different purposes to optimize performance and cost; the models can come from the same provider or from different providers.

### Lead/Worker Model Pattern
Use a powerful model for initial planning and complex reasoning, then switch to a faster/cheaper model for execution; goose performs this switch automatically:

```bash
# Required: Enable lead model mode
export GOOSE_LEAD_MODEL=modelY
# Optional: configure a provider for the lead model if not the default provider
export GOOSE_LEAD_PROVIDER=providerX # Defaults to main provider
```

### Planning Model Configuration
Use a specialized model for the `/plan` command in CLI mode; this is invoked explicitly when you want to plan (vs. execute):

```bash
# Optional: Use different model for planning
export GOOSE_PLANNER_PROVIDER=openai
export GOOSE_PLANNER_MODEL=gpt-4
```

Both patterns help you balance model capability against cost and speed, and let you switch between models and vendors as required.
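Putting the two together, a full configuration might look like the sketch below. The `GOOSE_*` variable names come from this README and from the CLI's config lookup; the specific provider and model names are placeholders, not recommendations:

```bash
# Hypothetical combined setup: plan with a strong model, execute with a cheaper one
export GOOSE_PROVIDER=openai          # main (worker) provider
export GOOSE_MODEL=gpt-4o-mini        # worker model, used for execution
export GOOSE_LEAD_MODEL=gpt-4o        # lead model, used for planning and reasoning
export GOOSE_PLANNER_PROVIDER=openai  # used only by the CLI /plan command
export GOOSE_PLANNER_MODEL=gpt-4
goose session                         # lead/worker switching happens automatically
```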

# Quick Links
- [Quickstart](https://block.github.io/goose/docs/quickstart)
bin/.node-22.9.0.pkg (new symbolic link)
@@ -0,0 +1 @@
hermit

bin/.protoc-31.1.pkg (new symbolic link)
@@ -0,0 +1 @@
hermit

bin/.rustup-1.25.2.pkg (new symbolic link)
@@ -0,0 +1 @@
hermit

bin/README.hermit.md (new file, 7 lines)
@@ -0,0 +1,7 @@
# Hermit environment

This is a [Hermit](https://github.com/cashapp/hermit) bin directory.

The symlinks in this directory are managed by Hermit and will automatically
download and install Hermit itself as well as packages. These packages are
local to this environment.
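In practice the environment is entered by sourcing the activation script added below; after that, tools such as `protoc` and `rustup` resolve to the Hermit-managed symlinks in `bin/`. A minimal sketch of the flow the CI workflows in this commit rely on:

```bash
# From the repository root: activate the Hermit environment in this shell
source ./bin/activate-hermit
# Tools now resolve to the pinned, Hermit-managed versions
which protoc    # -> the bin/protoc symlink
protoc --version
```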
bin/activate-hermit (new executable file, 21 lines)
@@ -0,0 +1,21 @@
#!/bin/bash
# This file must be used with "source bin/activate-hermit" from bash or zsh.
# You cannot run it directly
#
# THIS FILE IS GENERATED; DO NOT MODIFY

if [ "${BASH_SOURCE-}" = "$0" ]; then
    echo "You must source this script: \$ source $0" >&2
    exit 33
fi

BIN_DIR="$(dirname "${BASH_SOURCE[0]:-${(%):-%x}}")"
if "${BIN_DIR}/hermit" noop > /dev/null; then
    eval "$("${BIN_DIR}/hermit" activate "${BIN_DIR}/..")"

    if [ -n "${BASH-}" ] || [ -n "${ZSH_VERSION-}" ]; then
        hash -r 2>/dev/null
    fi

    echo "Hermit environment $("${HERMIT_ENV}"/bin/hermit env HERMIT_ENV) activated"
fi

bin/activate-hermit.fish (new executable file, 24 lines)
@@ -0,0 +1,24 @@
#!/usr/bin/env fish

# This file must be sourced with "source bin/activate-hermit.fish" from Fish shell.
# You cannot run it directly.
#
# THIS FILE IS GENERATED; DO NOT MODIFY

if status is-interactive
    set BIN_DIR (dirname (status --current-filename))

    if "$BIN_DIR/hermit" noop > /dev/null
        # Source the activation script generated by Hermit
        "$BIN_DIR/hermit" activate "$BIN_DIR/.." | source

        # Clear the command cache if applicable
        functions -c > /dev/null 2>&1

        # Display activation message
        echo "Hermit environment $($HERMIT_ENV/bin/hermit env HERMIT_ENV) activated"
    end
else
    echo "You must source this script: source $argv[0]" >&2
    exit 33
end

bin/cargo-clippy (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/cargo-fmt (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/cargo-miri (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/clippy-driver (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/corepack (new symbolic link)
@@ -0,0 +1 @@
.node-22.9.0.pkg

bin/hermit (new executable file, 43 lines)
@@ -0,0 +1,43 @@
#!/bin/bash
#
# THIS FILE IS GENERATED; DO NOT MODIFY

set -eo pipefail

export HERMIT_USER_HOME=~

if [ -z "${HERMIT_STATE_DIR}" ]; then
    case "$(uname -s)" in
    Darwin)
        export HERMIT_STATE_DIR="${HERMIT_USER_HOME}/Library/Caches/hermit"
        ;;
    Linux)
        export HERMIT_STATE_DIR="${XDG_CACHE_HOME:-${HERMIT_USER_HOME}/.cache}/hermit"
        ;;
    esac
fi

export HERMIT_DIST_URL="${HERMIT_DIST_URL:-https://github.com/cashapp/hermit/releases/download/stable}"
HERMIT_CHANNEL="$(basename "${HERMIT_DIST_URL}")"
export HERMIT_CHANNEL
export HERMIT_EXE=${HERMIT_EXE:-${HERMIT_STATE_DIR}/pkg/hermit@${HERMIT_CHANNEL}/hermit}

if [ ! -x "${HERMIT_EXE}" ]; then
    echo "Bootstrapping ${HERMIT_EXE} from ${HERMIT_DIST_URL}" 1>&2
    INSTALL_SCRIPT="$(mktemp)"
    # This value must match that of the install script
    INSTALL_SCRIPT_SHA256="09ed936378857886fd4a7a4878c0f0c7e3d839883f39ca8b4f2f242e3126e1c6"
    if [ "${INSTALL_SCRIPT_SHA256}" = "BYPASS" ]; then
        curl -fsSL "${HERMIT_DIST_URL}/install.sh" -o "${INSTALL_SCRIPT}"
    else
        # Install script is versioned by its sha256sum value
        curl -fsSL "${HERMIT_DIST_URL}/install-${INSTALL_SCRIPT_SHA256}.sh" -o "${INSTALL_SCRIPT}"
        # Verify install script's sha256sum
        openssl dgst -sha256 "${INSTALL_SCRIPT}" | \
            awk -v EXPECTED="$INSTALL_SCRIPT_SHA256" \
            '$2!=EXPECTED {print "Install script sha256 " $2 " does not match " EXPECTED; exit 1}'
    fi
    /bin/bash "${INSTALL_SCRIPT}" 1>&2
fi

exec "${HERMIT_EXE}" --level=fatal exec "$0" -- "$@"

bin/hermit.hcl (new file, 4 lines)
@@ -0,0 +1,4 @@
manage-git = false

github-token-auth {
}

bin/protoc (new symbolic link)
@@ -0,0 +1 @@
.protoc-31.1.pkg

bin/rust-analyzer (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/rust-gdb (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/rust-gdbgui (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/rust-lldb (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/rustdoc (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/rustfmt (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg

bin/rustup (new symbolic link)
@@ -0,0 +1 @@
.rustup-1.25.2.pkg
@@ -55,6 +55,15 @@ regex = "1.11.1"
minijinja = "2.8.0"
nix = { version = "0.30.1", features = ["process", "signal"] }
tar = "0.4"
# Web server dependencies
axum = { version = "0.8.1", features = ["ws", "macros"] }
tower-http = { version = "0.5", features = ["cors", "fs"] }
tokio-stream = "0.1"
bytes = "1.5"
http = "1.0"
webbrowser = "1.0"

indicatif = "0.17.11"

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }
crates/goose-cli/WEB_INTERFACE.md (new file, 78 lines)
@@ -0,0 +1,78 @@
# Goose Web Interface

The `goose web` command provides a preview of a web-based chat interface for interacting with Goose.
Do not expose it publicly; it is a preview feature.

## Usage

```bash
# Start the web server on default port (3000)
goose web

# Start on a specific port
goose web --port 8080

# Start and automatically open in browser
goose web --open

# Bind to a specific host
goose web --host 0.0.0.0 --port 8080
```

## Features

- **Real-time chat interface**: Communicate with Goose through a clean web UI
- **WebSocket support**: Real-time message streaming
- **Session management**: Each browser tab maintains its own session
- **Responsive design**: Works on desktop and mobile devices

## Architecture

The web interface is built with:
- **Backend**: Rust with Axum web framework
- **Frontend**: Vanilla JavaScript with WebSocket communication
- **Styling**: CSS with dark/light mode support
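Browser and server exchange JSON over the WebSocket; the envelope below is inferred from the `WebSocketMessage` enum in `web.rs` later in this commit, which serializes with a `"type"` tag. The `/ws` path and the `websocat` client are illustrative assumptions only:

```bash
# Hypothetical manual exchange with a locally running `goose web`
# (the /ws path is an assumption; check the router in web.rs)
echo '{"type":"message","content":"hello","session_id":"tab-1","timestamp":0}' \
  | websocat ws://127.0.0.1:3000/ws
```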
## Development Notes

### Current Implementation

The web interface provides:
1. A simple chat UI similar to the desktop Electron app
2. WebSocket-based real-time communication
3. Basic session management (messages are stored in memory)

### Future Enhancements

- [ ] Persistent session storage
- [ ] Tool call visualization
- [ ] File upload support
- [ ] Multiple session tabs
- [ ] Authentication/authorization
- [ ] Streaming responses with proper formatting
- [ ] Code syntax highlighting
- [ ] Export chat history

### Integration with Goose Agent

The web server creates an instance of the Goose Agent and processes messages through the same pipeline as the CLI. However, some features like:
- Extension management
- Tool confirmations
- File system interactions

...may require additional UI components to be fully functional.

## Security Considerations

Currently, the web interface:
- Binds to localhost by default for security
- Does not include authentication (planned for future)
- Should not be exposed to the internet without proper security measures

## Troubleshooting

If you encounter issues:

1. **Port already in use**: Try a different port with `--port`
2. **Cannot connect**: Ensure no firewall is blocking the port
3. **Agent not configured**: Run `goose configure` first to set up a provider
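For the port conflict case, a quick way to see what is already bound before retrying on another port (standard tooling, nothing goose-specific):

```bash
lsof -i :3000          # identify the process holding the default port
goose web --port 3001  # retry on a free port
```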
@@ -34,7 +34,7 @@ struct Cli {
    command: Option<Command>,
}

#[derive(Args)]
#[derive(Args, Debug)]
#[group(required = false, multiple = false)]
struct Identifier {
    #[arg(
@@ -102,6 +102,19 @@ enum SessionCommand {
        #[arg(short, long, help = "Regex for removing matched sessions (optional)")]
        regex: Option<String>,
    },
    #[command(about = "Export a session to Markdown format")]
    Export {
        #[command(flatten)]
        identifier: Option<Identifier>,

        #[arg(
            short,
            long,
            help = "Output file path (default: stdout)",
            long_help = "Path to save the exported Markdown. If not provided, output will be sent to stdout"
        )]
        output: Option<PathBuf>,
    },
}

#[derive(Subcommand, Debug)]
@@ -491,6 +504,31 @@ enum Command {
        #[command(subcommand)]
        cmd: BenchCommand,
    },

    /// Start a web server with a chat interface
    #[command(about = "Start a web server with a chat interface", hide = true)]
    Web {
        /// Port to run the web server on
        #[arg(
            short,
            long,
            default_value = "3000",
            help = "Port to run the web server on"
        )]
        port: u16,

        /// Host to bind the web server to
        #[arg(
            long,
            default_value = "127.0.0.1",
            help = "Host to bind the web server to"
        )]
        host: String,

        /// Open browser automatically
        #[arg(long, help = "Open browser automatically when server starts")]
        open: bool,
    },
}

#[derive(clap::ValueEnum, Clone, Debug)]
@@ -550,6 +588,23 @@ pub async fn cli() -> Result<()> {
            handle_session_remove(id, regex)?;
            return Ok(());
        }
        Some(SessionCommand::Export { identifier, output }) => {
            let session_identifier = if let Some(id) = identifier {
                extract_identifier(id)
            } else {
                // If no identifier is provided, prompt for interactive selection
                match crate::commands::session::prompt_interactive_session_selection() {
                    Ok(id) => id,
                    Err(e) => {
                        eprintln!("Error: {}", e);
                        return Ok(());
                    }
                }
            };

            crate::commands::session::handle_session_export(session_identifier, output)?;
            Ok(())
        }
        None => {
            // Run session command by default
            let mut session: crate::Session = build_session(SessionBuilderConfig {
@@ -755,6 +810,10 @@ pub async fn cli() -> Result<()> {
            }
            return Ok(());
        }
        Some(Command::Web { port, host, open }) => {
            crate::commands::web::handle_web(port, host, open).await?;
            return Ok(());
        }
        None => {
            return if !Config::global().exists() {
                let _ = handle_configure().await;
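Wired together, the new subcommand supports invocations like the sketch below; `-o/--output` follows from the clap `short, long` attributes on `output`, and omitting the identifier triggers the interactive session picker added in `session.rs`:

```bash
# Export a session to stdout, choosing it interactively
goose session export

# Export a session straight to a Markdown file
goose session export --output session.md
```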
@@ -7,3 +7,4 @@ pub mod recipe;
pub mod schedule;
pub mod session;
pub mod update;
pub mod web;
@@ -34,6 +34,8 @@ pub async fn handle_schedule_add(
        last_run: None,
        currently_running: false,
        paused: false,
        current_session_id: None,
        process_start_time: None,
    };

    let scheduler_storage_path =
@@ -1,8 +1,11 @@
use crate::session::message_to_markdown;
use anyhow::{Context, Result};
use cliclack::{confirm, multiselect};
use cliclack::{confirm, multiselect, select};
use goose::session::info::{get_session_info, SessionInfo, SortOrder};
use goose::session::{self, Identifier};
use regex::Regex;
use std::fs;
use std::path::{Path, PathBuf};

const TRUNCATED_DESC_LENGTH: usize = 60;

@@ -29,7 +32,7 @@ pub fn remove_sessions(sessions: Vec<SessionInfo>) -> Result<()> {
    Ok(())
}

fn prompt_interactive_session_selection(sessions: &[SessionInfo]) -> Result<Vec<SessionInfo>> {
fn prompt_interactive_session_removal(sessions: &[SessionInfo]) -> Result<Vec<SessionInfo>> {
    if sessions.is_empty() {
        println!("No sessions to delete.");
        return Ok(vec![]);
@@ -105,7 +108,7 @@ pub fn handle_session_remove(id: Option<String>, regex_string: Option<String>) -
        if all_sessions.is_empty() {
            return Err(anyhow::anyhow!("No sessions found."));
        }
        matched_sessions = prompt_interactive_session_selection(&all_sessions)?;
        matched_sessions = prompt_interactive_session_removal(&all_sessions)?;
    }

    if matched_sessions.is_empty() {
@@ -165,3 +168,184 @@ pub fn handle_session_list(verbose: bool, format: String, ascending: bool) -> Re
    }
    Ok(())
}

/// Export a session to Markdown without creating a full Session object
///
/// This function directly reads messages from the session file and converts them to Markdown
/// without creating an Agent or prompting about working directories.
pub fn handle_session_export(identifier: Identifier, output_path: Option<PathBuf>) -> Result<()> {
    // Get the session file path
    let session_file_path = goose::session::get_path(identifier.clone());

    if !session_file_path.exists() {
        return Err(anyhow::anyhow!(
            "Session file not found (expected path: {})",
            session_file_path.display()
        ));
    }

    // Read messages directly without using Session
    let messages = match goose::session::read_messages(&session_file_path) {
        Ok(msgs) => msgs,
        Err(e) => {
            return Err(anyhow::anyhow!("Failed to read session messages: {}", e));
        }
    };

    // Generate the markdown content using the export functionality
    let markdown = export_session_to_markdown(messages, &session_file_path, None);

    // Output the markdown
    if let Some(output) = output_path {
        fs::write(&output, markdown)
            .with_context(|| format!("Failed to write to output file: {}", output.display()))?;
        println!("Session exported to {}", output.display());
    } else {
        println!("{}", markdown);
    }

    Ok(())
}

/// Convert a list of messages to markdown format for session export
///
/// This function handles the formatting of a complete session including headers,
/// message organization, and proper tool request/response pairing.
fn export_session_to_markdown(
    messages: Vec<goose::message::Message>,
    session_file: &Path,
    session_name_override: Option<&str>,
) -> String {
    let mut markdown_output = String::new();

    let session_name = session_name_override.unwrap_or_else(|| {
        session_file
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("Unnamed Session")
    });

    markdown_output.push_str(&format!("# Session Export: {}\n\n", session_name));

    if messages.is_empty() {
        markdown_output.push_str("*(This session has no messages)*\n");
        return markdown_output;
    }

    markdown_output.push_str(&format!("*Total messages: {}*\n\n---\n\n", messages.len()));

    // Track if the last message had tool requests to properly handle tool responses
    let mut skip_next_if_tool_response = false;

    for message in &messages {
        // Check if this is a User message containing only ToolResponses
        let is_only_tool_response = message.role == mcp_core::role::Role::User
            && message
                .content
                .iter()
                .all(|content| matches!(content, goose::message::MessageContent::ToolResponse(_)));

        // If the previous message had tool requests and this one is just tool responses,
        // don't create a new User section - we'll attach the responses to the tool calls
        if skip_next_if_tool_response && is_only_tool_response {
            // Export the tool responses without a User heading
            markdown_output.push_str(&message_to_markdown(message, false));
            markdown_output.push_str("\n\n---\n\n");
            skip_next_if_tool_response = false;
            continue;
        }

        // Reset the skip flag - we'll update it below if needed
        skip_next_if_tool_response = false;

        // Output the role prefix except for tool response-only messages
        if !is_only_tool_response {
            let role_prefix = match message.role {
                mcp_core::role::Role::User => "### User:\n",
                mcp_core::role::Role::Assistant => "### Assistant:\n",
            };
            markdown_output.push_str(role_prefix);
        }

        // Add the message content
        markdown_output.push_str(&message_to_markdown(message, false));
        markdown_output.push_str("\n\n---\n\n");

        // Check if this message has any tool requests, to handle the next message differently
        if message
            .content
            .iter()
            .any(|content| matches!(content, goose::message::MessageContent::ToolRequest(_)))
        {
            skip_next_if_tool_response = true;
        }
    }

    markdown_output
}

/// Prompt the user to interactively select a session
///
/// Shows a list of available sessions and lets the user select one
pub fn prompt_interactive_session_selection() -> Result<session::Identifier> {
    // Get sessions sorted by modification date (newest first)
    let sessions = match get_session_info(SortOrder::Descending) {
        Ok(sessions) => sessions,
        Err(e) => {
            tracing::error!("Failed to list sessions: {:?}", e);
            return Err(anyhow::anyhow!("Failed to list sessions"));
        }
    };

    if sessions.is_empty() {
        return Err(anyhow::anyhow!("No sessions found"));
    }

    // Build the selection prompt
    let mut selector = select("Select a session to export:");

    // Map to display text
    let display_map: std::collections::HashMap<String, SessionInfo> = sessions
        .iter()
        .map(|s| {
            let desc = if s.metadata.description.is_empty() {
                "(no description)"
            } else {
                &s.metadata.description
            };

            // Truncate description if too long
            let truncated_desc = if desc.len() > 40 {
                format!("{}...", &desc[..37])
            } else {
                desc.to_string()
            };

            let display_text = format!("{} - {} ({})", s.modified, truncated_desc, s.id);
            (display_text, s.clone())
        })
        .collect();

    // Add each session as an option
    for display_text in display_map.keys() {
        selector = selector.item(display_text.clone(), display_text.clone(), "");
    }

    // Add a cancel option
    let cancel_value = String::from("cancel");
    selector = selector.item(cancel_value, "Cancel", "Cancel export");

    // Get user selection
    let selected_display_text: String = selector.interact()?;

    if selected_display_text == "cancel" {
        return Err(anyhow::anyhow!("Export canceled"));
    }

    // Retrieve the selected session
    if let Some(session) = display_map.get(&selected_display_text) {
        Ok(goose::session::Identifier::Name(session.id.clone()))
    } else {
        Err(anyhow::anyhow!("Invalid selection"))
    }
}
640
crates/goose-cli/src/commands/web.rs
Normal file
640
crates/goose-cli/src/commands/web.rs
Normal file
@@ -0,0 +1,640 @@
|
||||
use anyhow::Result;
use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        State,
    },
    response::{Html, IntoResponse, Response},
    routing::get,
    Json, Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use goose::agents::{Agent, AgentEvent};
use goose::message::Message as GooseMessage;
use goose::session;
use serde::{Deserialize, Serialize};
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::{Mutex, RwLock};
use tower_http::cors::{Any, CorsLayer};
use tracing::error;

type SessionStore = Arc<RwLock<std::collections::HashMap<String, Arc<Mutex<Vec<GooseMessage>>>>>>;
type CancellationStore = Arc<RwLock<std::collections::HashMap<String, tokio::task::AbortHandle>>>;

#[derive(Clone)]
struct AppState {
    agent: Arc<Agent>,
    sessions: SessionStore,
    cancellations: CancellationStore,
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum WebSocketMessage {
    #[serde(rename = "message")]
    Message {
        content: String,
        session_id: String,
        timestamp: i64,
    },
    #[serde(rename = "cancel")]
    Cancel { session_id: String },
    #[serde(rename = "response")]
    Response {
        content: String,
        role: String,
        timestamp: i64,
    },
    #[serde(rename = "tool_request")]
    ToolRequest {
        id: String,
        tool_name: String,
        arguments: serde_json::Value,
    },
    #[serde(rename = "tool_response")]
    ToolResponse {
        id: String,
        result: serde_json::Value,
        is_error: bool,
    },
    #[serde(rename = "tool_confirmation")]
    ToolConfirmation {
        id: String,
        tool_name: String,
        arguments: serde_json::Value,
        needs_confirmation: bool,
    },
    #[serde(rename = "error")]
    Error { message: String },
    #[serde(rename = "thinking")]
    Thinking { message: String },
    #[serde(rename = "context_exceeded")]
    ContextExceeded { message: String },
    #[serde(rename = "cancelled")]
    Cancelled { message: String },
    #[serde(rename = "complete")]
    Complete { message: String },
}

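Because of the `#[serde(tag = "type")]` attribute, each variant serializes as a flat JSON object whose `type` field carries the variant's rename. A minimal standalone sketch of the resulting wire format, with the enum trimmed to two variants for brevity:

```rust
use serde::{Deserialize, Serialize};

// Trimmed-down copy of the enum above, just to show the wire format.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(tag = "type")]
enum WebSocketMessage {
    #[serde(rename = "cancel")]
    Cancel { session_id: String },
    #[serde(rename = "complete")]
    Complete { message: String },
}

fn main() {
    let msg = WebSocketMessage::Complete {
        message: "Response complete".to_string(),
    };
    // Internally tagged: the tag sits alongside the variant's fields.
    assert_eq!(
        serde_json::to_string(&msg).unwrap(),
        r#"{"type":"complete","message":"Response complete"}"#
    );
}
```
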
pub async fn handle_web(port: u16, host: String, open: bool) -> Result<()> {
    // Setup logging
    crate::logging::setup_logging(Some("goose-web"), None)?;

    // Load config and create agent just like the CLI does
    let config = goose::config::Config::global();

    let provider_name: String = match config.get_param("GOOSE_PROVIDER") {
        Ok(p) => p,
        Err(_) => {
            eprintln!("No provider configured. Run 'goose configure' first");
            std::process::exit(1);
        }
    };

    let model: String = match config.get_param("GOOSE_MODEL") {
        Ok(m) => m,
        Err(_) => {
            eprintln!("No model configured. Run 'goose configure' first");
            std::process::exit(1);
        }
    };

    let model_config = goose::model::ModelConfig::new(model.clone());

    // Create the agent
    let agent = Agent::new();
    let provider = goose::providers::create(&provider_name, model_config)?;
    agent.update_provider(provider).await?;

    // Load and enable extensions from config
    let extensions = goose::config::ExtensionConfigManager::get_all()?;
    for ext_config in extensions {
        if ext_config.enabled {
            if let Err(e) = agent.add_extension(ext_config.config.clone()).await {
                eprintln!(
                    "Warning: Failed to load extension {}: {}",
                    ext_config.config.name(),
                    e
                );
            }
        }
    }

    let state = AppState {
        agent: Arc::new(agent),
        sessions: Arc::new(RwLock::new(std::collections::HashMap::new())),
        cancellations: Arc::new(RwLock::new(std::collections::HashMap::new())),
    };

    // Build router
    let app = Router::new()
        .route("/", get(serve_index))
        .route("/session/{session_name}", get(serve_session))
        .route("/ws", get(websocket_handler))
        .route("/api/health", get(health_check))
        .route("/api/sessions", get(list_sessions))
        .route("/api/sessions/{session_id}", get(get_session))
        .route("/static/{*path}", get(serve_static))
        .layer(
            CorsLayer::new()
                .allow_origin(Any)
                .allow_methods(Any)
                .allow_headers(Any),
        )
        .with_state(state);

    let addr: SocketAddr = format!("{}:{}", host, port).parse()?;

    println!("\n🪿 Starting Goose web server");
    println!("   Provider: {} | Model: {}", provider_name, model);
    println!(
        "   Working directory: {}",
        std::env::current_dir()?.display()
    );
    println!("   Server: http://{}", addr);
    println!("   Press Ctrl+C to stop\n");

    if open {
        // Open browser
        let url = format!("http://{}", addr);
        if let Err(e) = webbrowser::open(&url) {
            eprintln!("Failed to open browser: {}", e);
        }
    }

    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;

    Ok(())
}

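For a quick smoke test of the JSON endpoints registered above, something like the following works. This is a sketch: it assumes a server already listening locally on port 3000, and that `reqwest` (with its `json` feature) and `tokio` are available as dev-dependencies.

```rust
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Assumes `goose web` is already listening on 127.0.0.1:3000.
    let body: serde_json::Value = reqwest::get("http://127.0.0.1:3000/api/health")
        .await?
        .json()
        .await?;
    // Matches the payload built by the `health_check` handler defined below.
    assert_eq!(body["status"], "ok");
    assert_eq!(body["service"], "goose-web");
    Ok(())
}
```
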
async fn serve_index() -> Html<&'static str> {
    Html(include_str!("../../static/index.html"))
}

async fn serve_session(
    axum::extract::Path(session_name): axum::extract::Path<String>,
) -> Html<String> {
    let html = include_str!("../../static/index.html");
    // Inject the session name into the HTML so JavaScript can use it
    let html_with_session = html.replace(
        "<script src=\"/static/script.js\"></script>",
        &format!(
            "<script>window.GOOSE_SESSION_NAME = '{}';</script>\n    <script src=\"/static/script.js\"></script>",
            session_name
        ),
    );
    Html(html_with_session)
}

async fn serve_static(axum::extract::Path(path): axum::extract::Path<String>) -> Response {
    match path.as_str() {
        "style.css" => (
            [("content-type", "text/css")],
            include_str!("../../static/style.css"),
        )
            .into_response(),
        "script.js" => (
            [("content-type", "application/javascript")],
            include_str!("../../static/script.js"),
        )
            .into_response(),
        _ => (axum::http::StatusCode::NOT_FOUND, "Not found").into_response(),
    }
}

async fn health_check() -> Json<serde_json::Value> {
    Json(serde_json::json!({
        "status": "ok",
        "service": "goose-web"
    }))
}

async fn list_sessions() -> Json<serde_json::Value> {
    match session::list_sessions() {
        Ok(sessions) => {
            let session_info: Vec<serde_json::Value> = sessions
                .into_iter()
                .map(|(name, path)| {
                    let metadata = session::read_metadata(&path).unwrap_or_default();
                    serde_json::json!({
                        "name": name,
                        "path": path,
                        "description": metadata.description,
                        "message_count": metadata.message_count,
                        "working_dir": metadata.working_dir
                    })
                })
                .collect();

            Json(serde_json::json!({
                "sessions": session_info
            }))
        }
        Err(e) => Json(serde_json::json!({
            "error": e.to_string()
        })),
    }
}

async fn get_session(
    axum::extract::Path(session_id): axum::extract::Path<String>,
) -> Json<serde_json::Value> {
    let session_file = session::get_path(session::Identifier::Name(session_id));

    match session::read_messages(&session_file) {
        Ok(messages) => {
            let metadata = session::read_metadata(&session_file).unwrap_or_default();
            Json(serde_json::json!({
                "metadata": metadata,
                "messages": messages
            }))
        }
        Err(e) => Json(serde_json::json!({
            "error": e.to_string()
        })),
    }
}

async fn websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
) -> impl IntoResponse {
    ws.on_upgrade(|socket| handle_socket(socket, state))
}

async fn handle_socket(socket: WebSocket, state: AppState) {
    let (sender, mut receiver) = socket.split();
    let sender = Arc::new(Mutex::new(sender));

    while let Some(msg) = receiver.next().await {
        if let Ok(msg) = msg {
            match msg {
                Message::Text(text) => {
                    match serde_json::from_str::<WebSocketMessage>(&text.to_string()) {
                        Ok(WebSocketMessage::Message {
                            content,
                            session_id,
                            ..
                        }) => {
                            // Get session file path from session_id
                            let session_file =
                                session::get_path(session::Identifier::Name(session_id.clone()));

                            // Get or create session in memory (for fast access during processing)
                            let session_messages = {
                                let sessions = state.sessions.read().await;
                                if let Some(session) = sessions.get(&session_id) {
                                    session.clone()
                                } else {
                                    drop(sessions);
                                    let mut sessions = state.sessions.write().await;

                                    // Load existing messages from JSONL file if it exists
                                    let existing_messages = session::read_messages(&session_file)
                                        .unwrap_or_else(|_| Vec::new());

                                    let new_session = Arc::new(Mutex::new(existing_messages));
                                    sessions.insert(session_id.clone(), new_session.clone());
                                    new_session
                                }
                            };

                            // Clone sender for async processing
                            let sender_clone = sender.clone();
                            let agent = state.agent.clone();

                            // Process message in a separate task to allow streaming
                            let task_handle = tokio::spawn(async move {
                                let result = process_message_streaming(
                                    &agent,
                                    session_messages,
                                    session_file,
                                    content,
                                    sender_clone,
                                )
                                .await;

                                if let Err(e) = result {
                                    error!("Error processing message: {}", e);
                                }
                            });

                            // Store the abort handle
                            {
                                let mut cancellations = state.cancellations.write().await;
                                cancellations
                                    .insert(session_id.clone(), task_handle.abort_handle());
                            }

                            // Wait for task completion and handle abort
                            let sender_for_abort = sender.clone();
                            let session_id_for_cleanup = session_id.clone();
                            let cancellations_for_cleanup = state.cancellations.clone();

                            tokio::spawn(async move {
                                match task_handle.await {
                                    Ok(_) => {
                                        // Task completed normally
                                    }
                                    Err(e) if e.is_cancelled() => {
                                        // Task was aborted
                                        let mut sender = sender_for_abort.lock().await;
                                        let _ = sender
                                            .send(Message::Text(
                                                serde_json::to_string(
                                                    &WebSocketMessage::Cancelled {
                                                        message: "Operation cancelled by user"
                                                            .to_string(),
                                                    },
                                                )
                                                .unwrap()
                                                .into(),
                                            ))
                                            .await;
                                    }
                                    Err(e) => {
                                        error!("Task error: {}", e);
                                    }
                                }

                                // Clean up cancellation token
                                {
                                    let mut cancellations =
                                        cancellations_for_cleanup.write().await;
                                    cancellations.remove(&session_id_for_cleanup);
                                }
                            });
                        }
                        Ok(WebSocketMessage::Cancel { session_id }) => {
                            // Cancel the active operation for this session
                            let abort_handle = {
                                let mut cancellations = state.cancellations.write().await;
                                cancellations.remove(&session_id)
                            };

                            if let Some(handle) = abort_handle {
                                handle.abort();

                                // Send cancellation confirmation
                                let mut sender = sender.lock().await;
                                let _ = sender
                                    .send(Message::Text(
                                        serde_json::to_string(&WebSocketMessage::Cancelled {
                                            message: "Operation cancelled".to_string(),
                                        })
                                        .unwrap()
                                        .into(),
                                    ))
                                    .await;
                            }
                        }
                        Ok(_) => {
                            // Ignore other message types
                        }
                        Err(e) => {
                            error!("Failed to parse WebSocket message: {}", e);
                        }
                    }
                }
                Message::Close(_) => break,
                _ => {}
            }
        } else {
            break;
        }
    }
}

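The cancellation path above relies on tokio's abort-handle semantics: the spawned task's `AbortHandle` is stashed per session, `abort()` is called on a Cancel message, and the awaiting task observes the abort via `JoinError::is_cancelled()`. A minimal standalone sketch of that pattern:

```rust
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Mirrors the pattern above: keep the AbortHandle, abort later,
    // and distinguish cancellation from other task failures.
    let task = tokio::spawn(async {
        tokio::time::sleep(Duration::from_secs(60)).await;
    });
    let handle = task.abort_handle();
    handle.abort();

    let err = task.await.unwrap_err();
    assert!(err.is_cancelled());
}
```
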
async fn process_message_streaming(
    agent: &Agent,
    session_messages: Arc<Mutex<Vec<GooseMessage>>>,
    session_file: std::path::PathBuf,
    content: String,
    sender: Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
) -> Result<()> {
    use futures::StreamExt;
    use goose::agents::SessionConfig;
    use goose::message::MessageContent;
    use goose::session;

    // Create a user message
    let user_message = GooseMessage::user().with_text(content.clone());

    // Get existing messages from session and add the new user message
    let mut messages = {
        let mut session_msgs = session_messages.lock().await;
        session_msgs.push(user_message.clone());
        session_msgs.clone()
    };

    // Persist messages to JSONL file with provider for automatic description generation
    let provider = agent.provider().await;
    if provider.is_err() {
        let error_msg = "I'm not properly configured yet. Please configure a provider through the CLI first using `goose configure`.".to_string();
        let mut sender = sender.lock().await;
        let _ = sender
            .send(Message::Text(
                serde_json::to_string(&WebSocketMessage::Response {
                    content: error_msg,
                    role: "assistant".to_string(),
                    timestamp: chrono::Utc::now().timestamp_millis(),
                })
                .unwrap()
                .into(),
            ))
            .await;
        return Ok(());
    }

    let provider = provider.unwrap();
    session::persist_messages(&session_file, &messages, Some(provider.clone())).await?;

    // Create a session config
    let session_config = SessionConfig {
        id: session::Identifier::Path(session_file.clone()),
        working_dir: std::env::current_dir()?,
        schedule_id: None,
    };

    // Get response from agent
    match agent.reply(&messages, Some(session_config)).await {
        Ok(mut stream) => {
            while let Some(result) = stream.next().await {
                match result {
                    Ok(AgentEvent::Message(message)) => {
                        // Add message to our session
                        {
                            let mut session_msgs = session_messages.lock().await;
                            session_msgs.push(message.clone());
                        }

                        // Persist messages to JSONL file (no provider needed for assistant messages)
                        let current_messages = {
                            let session_msgs = session_messages.lock().await;
                            session_msgs.clone()
                        };
                        session::persist_messages(&session_file, &current_messages, None).await?;

                        // Handle different message content types
                        for content in &message.content {
                            match content {
                                MessageContent::Text(text) => {
                                    // Send the text response
                                    let mut sender = sender.lock().await;
                                    let _ = sender
                                        .send(Message::Text(
                                            serde_json::to_string(&WebSocketMessage::Response {
                                                content: text.text.clone(),
                                                role: "assistant".to_string(),
                                                timestamp: chrono::Utc::now().timestamp_millis(),
                                            })
                                            .unwrap()
                                            .into(),
                                        ))
                                        .await;
                                }
                                MessageContent::ToolRequest(req) => {
                                    // Send tool request notification
                                    let mut sender = sender.lock().await;
                                    if let Ok(tool_call) = &req.tool_call {
                                        let _ = sender
                                            .send(Message::Text(
                                                serde_json::to_string(
                                                    &WebSocketMessage::ToolRequest {
                                                        id: req.id.clone(),
                                                        tool_name: tool_call.name.clone(),
                                                        arguments: tool_call.arguments.clone(),
                                                    },
                                                )
                                                .unwrap()
                                                .into(),
                                            ))
                                            .await;
                                    }
                                }
                                MessageContent::ToolResponse(_resp) => {
                                    // Tool responses are already included in the complete message stream
                                    // and will be persisted to session history. No need to send separate
                                    // WebSocket messages as this would cause duplicates.
                                }
                                MessageContent::ToolConfirmationRequest(confirmation) => {
                                    // Send tool confirmation request
                                    let mut sender = sender.lock().await;
                                    let _ = sender
                                        .send(Message::Text(
                                            serde_json::to_string(
                                                &WebSocketMessage::ToolConfirmation {
                                                    id: confirmation.id.clone(),
                                                    tool_name: confirmation.tool_name.clone(),
                                                    arguments: confirmation.arguments.clone(),
                                                    needs_confirmation: true,
                                                },
                                            )
                                            .unwrap()
                                            .into(),
                                        ))
                                        .await;

                                    // For now, auto-approve in web mode
                                    // TODO: Implement proper confirmation UI
                                    agent.handle_confirmation(
                                        confirmation.id.clone(),
                                        goose::permission::PermissionConfirmation {
                                            principal_type: goose::permission::permission_confirmation::PrincipalType::Tool,
                                            permission: goose::permission::Permission::AllowOnce,
                                        }
                                    ).await;
                                }
                                MessageContent::Thinking(thinking) => {
                                    // Send thinking indicator
                                    let mut sender = sender.lock().await;
                                    let _ = sender
                                        .send(Message::Text(
                                            serde_json::to_string(&WebSocketMessage::Thinking {
                                                message: thinking.thinking.clone(),
                                            })
                                            .unwrap()
                                            .into(),
                                        ))
                                        .await;
                                }
                                MessageContent::ContextLengthExceeded(msg) => {
                                    // Send context exceeded notification
                                    let mut sender = sender.lock().await;
                                    let _ = sender
                                        .send(Message::Text(
                                            serde_json::to_string(
                                                &WebSocketMessage::ContextExceeded {
                                                    message: msg.msg.clone(),
                                                },
                                            )
                                            .unwrap()
                                            .into(),
                                        ))
                                        .await;

                                    // For now, auto-summarize in web mode
                                    // TODO: Implement proper UI for context handling
                                    let (summarized_messages, _) =
                                        agent.summarize_context(&messages).await?;
                                    messages = summarized_messages;
                                }
                                _ => {
                                    // Handle other message types as needed
                                }
                            }
                        }
                    }
                    Ok(AgentEvent::McpNotification(_notification)) => {
                        // Handle MCP notifications if needed
                        // For now, we'll just log them
                        tracing::info!("Received MCP notification in web interface");
                    }
                    Err(e) => {
                        error!("Error in message stream: {}", e);
                        let mut sender = sender.lock().await;
                        let _ = sender
                            .send(Message::Text(
                                serde_json::to_string(&WebSocketMessage::Error {
                                    message: format!("Error: {}", e),
                                })
                                .unwrap()
                                .into(),
                            ))
                            .await;
                        break;
                    }
                }
            }
        }
        Err(e) => {
            error!("Error calling agent: {}", e);
            let mut sender = sender.lock().await;
            let _ = sender
                .send(Message::Text(
                    serde_json::to_string(&WebSocketMessage::Error {
                        message: format!("Error: {}", e),
                    })
                    .unwrap()
                    .into(),
                ))
                .await;
        }
    }

    // Send completion message
    let mut sender = sender.lock().await;
    let _ = sender
        .send(Message::Text(
            serde_json::to_string(&WebSocketMessage::Complete {
                message: "Response complete".to_string(),
            })
            .unwrap()
            .into(),
        ))
        .await;

    Ok(())
}

// Add webbrowser dependency for opening browser
use webbrowser;

@@ -7,6 +7,7 @@ use goose::session;
use goose::session::Identifier;
use mcp_client::transport::Error as McpClientError;
use std::process;
use std::sync::Arc;

use super::output;
use super::Session;
@@ -55,6 +56,22 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
    // Create the agent
    let agent: Agent = Agent::new();
    let new_provider = create(&provider_name, model_config).unwrap();

    // Keep a reference to the provider for display_session_info
    let provider_for_display = Arc::clone(&new_provider);

    // Log model information at startup
    if let Some(lead_worker) = new_provider.as_lead_worker() {
        let (lead_model, worker_model) = lead_worker.get_model_info();
        tracing::info!(
            "🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled",
            lead_model,
            worker_model
        );
    } else {
        tracing::info!("🤖 Using model: {}", model);
    }

    agent
        .update_provider(new_provider)
        .await
@@ -217,6 +234,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
        session.agent.override_system_prompt(override_prompt).await;
    }

    output::display_session_info(session_config.resume, &provider_name, &model, &session_file);
    output::display_session_info(
        session_config.resume,
        &provider_name,
        &model,
        &session_file,
        Some(&provider_for_display),
    );
    session
}

1095
crates/goose-cli/src/session/export.rs
Normal file
File diff suppressed because it is too large
@@ -1,12 +1,15 @@
mod builder;
mod completion;
mod export;
mod input;
mod output;
mod prompt;
mod thinking;

pub use self::export::message_to_markdown;
pub use builder::{build_session, SessionBuilderConfig};
use console::Color;
use goose::agents::AgentEvent;
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::Permission;
use goose::permission::PermissionConfirmation;
@@ -15,8 +18,7 @@ pub use goose::session::Identifier;

use anyhow::{Context, Result};
use completion::GooseCompleter;
use etcetera::choose_app_strategy;
use etcetera::AppStrategy;
use etcetera::{choose_app_strategy, AppStrategy};
use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::{Agent, SessionConfig};
use goose::config::Config;
@@ -25,6 +27,8 @@ use goose::session;
use input::InputResult;
use mcp_core::handler::ToolError;
use mcp_core::prompt::PromptMessage;
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::protocol::JsonRpcNotification;

use rand::{distributions::Alphanumeric, Rng};
use serde_json::Value;
@@ -351,9 +355,10 @@ impl Session {
        // Create and use a global history file in ~/.config/goose directory
        // This allows command history to persist across different chat sessions
        // instead of being tied to each individual session's messages
        let history_file = choose_app_strategy(crate::APP_STRATEGY.clone())
            .expect("goose requires a home dir")
            .in_config_dir("history.txt");
        let strategy =
            choose_app_strategy(crate::APP_STRATEGY.clone()).expect("goose requires a home dir");
        let config_dir = strategy.config_dir();
        let history_file = config_dir.join("history.txt");

        // Ensure config directory exists
        if let Some(parent) = history_file.parent() {
@@ -379,6 +384,9 @@ impl Session {

        output::display_greeting();
        loop {
            // Display context usage before each prompt
            self.display_context_usage().await?;

            match input::get_input(&mut editor)? {
                input::InputResult::Message(content) => {
                    match self.run_mode {
@@ -713,12 +721,14 @@ impl Session {
            )
            .await?;

            let mut progress_bars = output::McpSpinners::new();

            use futures::StreamExt;
            loop {
                tokio::select! {
                    result = stream.next() => {
                        match result {
                            Some(Ok(message)) => {
                            Some(Ok(AgentEvent::Message(message))) => {
                                // If it's a confirmation request, get approval but otherwise do not render/persist
                                if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
                                    output::hide_thinking();
@@ -768,56 +778,68 @@ impl Session {
                                } else if let Some(MessageContent::ContextLengthExceeded(_)) = message.content.first() {
                                    output::hide_thinking();

                                    if interactive {
                                        // In interactive mode, ask the user what to do
                                        let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
                                        let selected_result = cliclack::select(prompt)
                                            .item("clear", "Clear Session", "Removes all messages from Goose's memory")
                                            .item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
                                            .item("summarize", "Summarize Session", "Summarize the session to reduce context length")
                                            .item("cancel", "Cancel", "Cancel and return to chat")
                                            .interact();
                                    // Check for user-configured default context strategy
                                    let config = Config::global();
                                    let context_strategy = config.get_param::<String>("GOOSE_CONTEXT_STRATEGY")
                                        .unwrap_or_else(|_| if interactive { "prompt".to_string() } else { "summarize".to_string() });

                                        let selected = match selected_result {
                                            Ok(s) => s,
                                            Err(e) => {
                                                if e.kind() == std::io::ErrorKind::Interrupted {
                                                    "cancel" // If interrupted, set selected to cancel
                                                } else {
                                                    return Err(e.into());
                                                }
                                            }
                                        };

                                        match selected {
                                            "clear" => {
                                                self.messages.clear();
                                                let msg = format!("Session cleared.\n{}", "-".repeat(50));
                                                output::render_text(&msg, Some(Color::Yellow), true);
                                                break; // exit the loop to hand back control to the user
                                            }
                                            "truncate" => {
                                                // Truncate messages to fit within context length
                                                let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
                                                let msg = format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50));
                                                output::render_text("", Some(Color::Yellow), true);
                                                output::render_text(&msg, Some(Color::Yellow), true);
                                                self.messages = truncated_messages;
                                            }
                                            "summarize" => {
                                                // Use the helper function to summarize context
                                                Self::summarize_context_messages(&mut self.messages, &self.agent, "Goose summarized messages for you.").await?;
                                            }
                                            "cancel" => {
                                                break; // Return to main prompt
                                            }
                                            _ => {
                                                unreachable!()
                                    let selected = match context_strategy.as_str() {
                                        "clear" => "clear",
                                        "truncate" => "truncate",
                                        "summarize" => "summarize",
                                        _ => {
                                            if interactive {
                                                // In interactive mode with no default, ask the user what to do
                                                let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
                                                cliclack::select(prompt)
                                                    .item("clear", "Clear Session", "Removes all messages from Goose's memory")
                                                    .item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
                                                    .item("summarize", "Summarize Session", "Summarize the session to reduce context length")
                                                    .interact()?
                                            } else {
                                                // In headless mode, default to summarize
                                                "summarize"
                                            }
                                        }
                                    } else {
                                        // In headless mode (goose run), automatically use summarize
                                        Self::summarize_context_messages(&mut self.messages, &self.agent, "Goose automatically summarized messages to continue processing.").await?;
                                    };

                                    match selected {
                                        "clear" => {
                                            self.messages.clear();
                                            let msg = if context_strategy == "clear" {
                                                format!("Context maxed out - automatically cleared session.\n{}", "-".repeat(50))
                                            } else {
                                                format!("Session cleared.\n{}", "-".repeat(50))
                                            };
                                            output::render_text(&msg, Some(Color::Yellow), true);
                                            break; // exit the loop to hand back control to the user
                                        }
                                        "truncate" => {
                                            // Truncate messages to fit within context length
                                            let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
                                            let msg = if context_strategy == "truncate" {
                                                format!("Context maxed out - automatically truncated messages.\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50))
                                            } else {
                                                format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50))
                                            };
                                            output::render_text("", Some(Color::Yellow), true);
                                            output::render_text(&msg, Some(Color::Yellow), true);
                                            self.messages = truncated_messages;
                                        }
                                        "summarize" => {
                                            // Use the helper function to summarize context
                                            let message_suffix = if context_strategy == "summarize" {
                                                "Goose automatically summarized messages for you."
                                            } else if interactive {
                                                "Goose summarized messages for you."
                                            } else {
                                                "Goose automatically summarized messages to continue processing."
                                            };
                                            Self::summarize_context_messages(&mut self.messages, &self.agent, message_suffix).await?;
                                        }
                                        _ => {
                                            unreachable!()
                                        }
                                    }

                                    // Restart the stream after handling ContextLengthExceeded
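The hunk above keys the new behavior off a `GOOSE_CONTEXT_STRATEGY` config value, falling back to prompting in interactive mode and summarizing in headless mode. A standalone sketch of that fallback chain, with the config lookup stubbed out (the real code reads Goose's config store via `get_param`):

```rust
// Strategy selection distilled from the diff; `configured` stands in
// for the GOOSE_CONTEXT_STRATEGY lookup.
fn context_strategy(configured: Option<&str>, interactive: bool) -> &'static str {
    match configured {
        Some("clear") => "clear",
        Some("truncate") => "truncate",
        Some("summarize") => "summarize",
        // Unset or unrecognized: prompt the user interactively,
        // otherwise summarize so headless runs can continue.
        _ => {
            if interactive {
                "prompt"
            } else {
                "summarize"
            }
        }
    }
}

fn main() {
    assert_eq!(context_strategy(Some("truncate"), true), "truncate");
    assert_eq!(context_strategy(None, true), "prompt");
    assert_eq!(context_strategy(None, false), "summarize");
}
```
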
@@ -842,10 +864,55 @@ impl Session {
                                session::persist_messages(&self.session_file, &self.messages, None).await?;

                                if interactive {output::hide_thinking()};
                                let _ = progress_bars.hide();
                                output::render_message(&message, self.debug);
                                if interactive {output::show_thinking()};
                            }
                        }
                        Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
                            if let JsonRpcMessage::Notification(JsonRpcNotification{
                                method,
                                params: Some(Value::Object(o)),
                                ..
                            }) = message {
                                match method.as_str() {
                                    "notifications/message" => {
                                        let data = o.get("data").unwrap_or(&Value::Null);
                                        let message = match data {
                                            Value::String(s) => s.clone(),
                                            Value::Object(o) => {
                                                if let Some(Value::String(output)) = o.get("output") {
                                                    output.to_owned()
                                                } else {
                                                    data.to_string()
                                                }
                                            },
                                            v => {
                                                v.to_string()
                                            },
                                        };
                                        progress_bars.log(&message);
                                    },
                                    "notifications/progress" => {
                                        let progress = o.get("progress").and_then(|v| v.as_f64());
                                        let token = o.get("progressToken").map(|v| v.to_string());
                                        let message = o.get("message").and_then(|v| v.as_str());
                                        let total = o
                                            .get("total")
                                            .and_then(|v| v.as_f64());
                                        if let (Some(progress), Some(token)) = (progress, token) {
                                            progress_bars.update(
                                                token.as_str(),
                                                progress,
                                                total,
                                                message,
                                            );
                                        }
                                    },
                                    _ => (),
                                }
                            }
                        }
                        Some(Err(e)) => {
                            eprintln!("Error: {}", e);
                            drop(stream);
@@ -872,6 +939,7 @@ impl Session {
                }
            }
        }

        Ok(())
    }

@@ -1054,6 +1122,26 @@ impl Session {
        Ok(metadata.total_tokens)
    }

    /// Display enhanced context usage with session totals
    pub async fn display_context_usage(&self) -> Result<()> {
        let provider = self.agent.provider().await?;
        let model_config = provider.get_model_config();
        let context_limit = model_config.context_limit.unwrap_or(32000);

        match self.get_metadata() {
            Ok(metadata) => {
                let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;

                output::display_context_usage(total_tokens, context_limit);
            }
            Err(_) => {
                output::display_context_usage(0, context_limit);
            }
        }

        Ok(())
    }

    /// Handle prompt command execution
    async fn handle_prompt_command(&mut self, opts: input::PromptCommandOptions) -> Result<()> {
        // name is required

@@ -2,12 +2,16 @@ use bat::WrappingMode;
use console::{style, Color};
use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use mcp_core::prompt::PromptArgument;
use mcp_core::tool::ToolCall;
use serde_json::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::io::Error;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;

// Re-export theme for use in main
#[derive(Clone, Copy)]
@@ -144,6 +148,10 @@ pub fn render_message(message: &Message, debug: bool) {
}

pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
    render_text_no_newlines(format!("\n{}\n\n", text).as_str(), color, dim);
}

pub fn render_text_no_newlines(text: &str, color: Option<Color>, dim: bool) {
    let mut styled_text = style(text);
    if dim {
        styled_text = styled_text.dim();
@@ -153,7 +161,7 @@ pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
    } else {
        styled_text = styled_text.green();
    }
    println!("\n{}\n", styled_text);
    print!("{}", styled_text);
}

pub fn render_enter_plan_mode() {
@@ -359,7 +367,6 @@ fn render_shell_request(call: &ToolCall, debug: bool) {
        }
        _ => print_params(&call.arguments, 0, debug),
    }
    println!();
}

fn render_default_request(call: &ToolCall, debug: bool) {
@@ -530,7 +537,13 @@ fn shorten_path(path: &str, debug: bool) -> String {
}

// Session display functions
pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) {
pub fn display_session_info(
    resume: bool,
    provider: &str,
    model: &str,
    session_file: &Path,
    provider_instance: Option<&Arc<dyn goose::providers::base::Provider>>,
) {
    let start_session_msg = if resume {
        "resuming session |"
    } else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") {
@@ -538,14 +551,42 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
    } else {
        "starting session |"
    };
    println!(
        "{} {} {} {} {}",
        style(start_session_msg).dim(),
        style("provider:").dim(),
        style(provider).cyan().dim(),
        style("model:").dim(),
        style(model).cyan().dim(),
    );

    // Check if we have lead/worker mode
    if let Some(provider_inst) = provider_instance {
        if let Some(lead_worker) = provider_inst.as_lead_worker() {
            let (lead_model, worker_model) = lead_worker.get_model_info();
            println!(
                "{} {} {} {} {} {} {}",
                style(start_session_msg).dim(),
                style("provider:").dim(),
                style(provider).cyan().dim(),
                style("lead model:").dim(),
                style(&lead_model).cyan().dim(),
                style("worker model:").dim(),
                style(&worker_model).cyan().dim(),
            );
        } else {
            println!(
                "{} {} {} {} {}",
                style(start_session_msg).dim(),
                style("provider:").dim(),
                style(provider).cyan().dim(),
                style("model:").dim(),
                style(model).cyan().dim(),
            );
        }
    } else {
        // Fallback to original behavior if no provider instance
        println!(
            "{} {} {} {} {}",
            style(start_session_msg).dim(),
            style("provider:").dim(),
            style(provider).cyan().dim(),
            style("model:").dim(),
            style(model).cyan().dim(),
        );
    }

    if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") {
        println!(
@@ -568,6 +609,102 @@ pub fn display_greeting() {
    println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n");
}

/// Display context window usage with both current and session totals
pub fn display_context_usage(total_tokens: usize, context_limit: usize) {
    use console::style;

    // Calculate percentage used
    let percentage = (total_tokens as f64 / context_limit as f64 * 100.0).round() as usize;

    // Create dot visualization
    let dot_count = 10;
    let filled_dots = ((percentage as f64 / 100.0) * dot_count as f64).round() as usize;
    let empty_dots = dot_count - filled_dots;

    let filled = "●".repeat(filled_dots);
    let empty = "○".repeat(empty_dots);

    // Combine dots and apply color
    let dots = format!("{}{}", filled, empty);
    let colored_dots = if percentage < 50 {
        style(dots).green()
    } else if percentage < 85 {
        style(dots).yellow()
    } else {
        style(dots).red()
    };

    // Print the status line
    println!(
        "Context: {} {}% ({}/{} tokens)",
        colored_dots, percentage, total_tokens, context_limit
    );
}

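To make the dot arithmetic concrete: 8,000 tokens against a 32,000-token limit is 25%, which rounds to 3 of 10 filled dots, since `f64::round` rounds 2.5 away from zero. Note also that `dot_count - filled_dots` would underflow if usage ever exceeded 100%, so a clamp might be worth adding. A quick check of the math:

```rust
fn main() {
    let (total_tokens, context_limit) = (8_000usize, 32_000usize);
    let percentage = (total_tokens as f64 / context_limit as f64 * 100.0).round() as usize;
    assert_eq!(percentage, 25);

    let dot_count = 10;
    let filled_dots = ((percentage as f64 / 100.0) * dot_count as f64).round() as usize;
    // f64::round rounds 2.5 away from zero, so 25% fills 3 dots.
    assert_eq!(filled_dots, 3);
    assert_eq!(dot_count - filled_dots, 7); // empty dots
}
```
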
pub struct McpSpinners {
    bars: HashMap<String, ProgressBar>,
    log_spinner: Option<ProgressBar>,

    multi_bar: MultiProgress,
}

impl McpSpinners {
    pub fn new() -> Self {
        McpSpinners {
            bars: HashMap::new(),
            log_spinner: None,
            multi_bar: MultiProgress::new(),
        }
    }

    pub fn log(&mut self, message: &str) {
        let spinner = self.log_spinner.get_or_insert_with(|| {
            let bar = self.multi_bar.add(
                ProgressBar::new_spinner()
                    .with_style(
                        ProgressStyle::with_template("{spinner:.green} {msg}")
                            .unwrap()
                            .tick_chars("⠋⠙⠚⠛⠓⠒⠊⠉"),
                    )
                    .with_message(message.to_string()),
            );
            bar.enable_steady_tick(Duration::from_millis(100));
            bar
        });

        spinner.set_message(message.to_string());
    }

    pub fn update(&mut self, token: &str, value: f64, total: Option<f64>, message: Option<&str>) {
        let bar = self.bars.entry(token.to_string()).or_insert_with(|| {
            if let Some(total) = total {
                self.multi_bar.add(
                    ProgressBar::new((total * 100.0) as u64).with_style(
                        ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}")
                            .unwrap(),
                    ),
                )
            } else {
                self.multi_bar.add(ProgressBar::new_spinner())
            }
        });
        bar.set_position((value * 100.0) as u64);
        if let Some(msg) = message {
            bar.set_message(msg.to_string());
        }
    }

    pub fn hide(&mut self) -> Result<(), Error> {
        self.bars.iter_mut().for_each(|(_, bar)| {
            bar.disable_steady_tick();
        });
        if let Some(spinner) = self.log_spinner.as_mut() {
            spinner.disable_steady_tick();
        }
        self.multi_bar.clear()
    }
}

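The type above is a thin wrapper over indicatif's `MultiProgress`: one steady-tick spinner for log lines plus a bar per progress token, both scaled by 100 so fractional MCP progress maps onto integer positions. A self-contained sketch of the same pattern using indicatif directly (assumes indicatif 0.17; timings are illustrative):

```rust
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::{thread, time::Duration};

fn main() {
    // One MultiProgress owns both a log spinner and a per-token bar,
    // mirroring what McpSpinners does internally.
    let multi = MultiProgress::new();
    let spinner = multi.add(
        ProgressBar::new_spinner()
            .with_style(ProgressStyle::with_template("{spinner:.green} {msg}").unwrap()),
    );
    spinner.enable_steady_tick(Duration::from_millis(100));
    spinner.set_message("starting long-running tool...");

    // Length is total * 100, position is fraction * 100, as in `update`.
    let bar = multi.add(ProgressBar::new(100));
    for step in 0..=10u64 {
        bar.set_position(step * 10);
        thread::sleep(Duration::from_millis(50));
    }
    let _ = multi.clear();
}
```
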
#[cfg(test)]
mod tests {
    use super::*;

46
crates/goose-cli/static/index.html
Normal file
@@ -0,0 +1,46 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Goose Chat</title>
    <link rel="stylesheet" href="/static/style.css">
</head>
<body>
    <div class="container">
        <header>
            <h1 id="session-title">Goose Chat</h1>
            <div class="status" id="connection-status">Connecting...</div>
        </header>

        <div class="chat-container">
            <div class="messages" id="messages">
                <div class="welcome-message">
                    <h2>Welcome to Goose!</h2>
                    <p>I'm your AI assistant. How can I help you today?</p>

                    <div class="suggestion-pills">
                        <div class="suggestion-pill" onclick="sendSuggestion('What can you do?')">What can you do?</div>
                        <div class="suggestion-pill" onclick="sendSuggestion('Demo writing and reading files')">Demo writing and reading files</div>
                        <div class="suggestion-pill" onclick="sendSuggestion('Make a snake game in a new folder')">Make a snake game in a new folder</div>
                        <div class="suggestion-pill" onclick="sendSuggestion('List files in my current directory')">List files in my current directory</div>
                        <div class="suggestion-pill" onclick="sendSuggestion('Take a screenshot and summarize')">Take a screenshot and summarize</div>
                    </div>
                </div>
            </div>

            <div class="input-container">
                <textarea
                    id="message-input"
                    placeholder="Type your message here..."
                    rows="3"
                    autofocus
                ></textarea>
                <button id="send-button" type="button">Send</button>
            </div>
        </div>
    </div>

    <script src="/static/script.js"></script>
</body>
</html>

523
crates/goose-cli/static/script.js
Normal file
@@ -0,0 +1,523 @@
// WebSocket connection and chat functionality
let socket = null;
let sessionId = getSessionId();
let isConnected = false;

// DOM elements
const messagesContainer = document.getElementById('messages');
const messageInput = document.getElementById('message-input');
const sendButton = document.getElementById('send-button');
const connectionStatus = document.getElementById('connection-status');

// Track if we're currently processing
let isProcessing = false;

// Get session ID - either from URL parameter, injected session name, or generate new one
function getSessionId() {
    // Check if session name was injected by server (for /session/:name routes)
    if (window.GOOSE_SESSION_NAME) {
        return window.GOOSE_SESSION_NAME;
    }

    // Check URL parameters
    const urlParams = new URLSearchParams(window.location.search);
    const sessionParam = urlParams.get('session') || urlParams.get('name');
    if (sessionParam) {
        return sessionParam;
    }

    // Generate new session ID using CLI format
    return generateSessionId();
}

// Generate a session ID using timestamp format (yyyymmdd_hhmmss) like CLI
function generateSessionId() {
    const now = new Date();
    const year = now.getFullYear();
    const month = String(now.getMonth() + 1).padStart(2, '0');
    const day = String(now.getDate()).padStart(2, '0');
    const hour = String(now.getHours()).padStart(2, '0');
    const minute = String(now.getMinutes()).padStart(2, '0');
    const second = String(now.getSeconds()).padStart(2, '0');

    return `${year}${month}${day}_${hour}${minute}${second}`;
}

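The same `yyyymmdd_hhmmss` convention reduces to a single format call on the Rust side; a sketch, assuming the `chrono` crate (which goose already depends on):

```rust
use chrono::Local;

fn main() {
    // Same shape the JS above produces, e.g. "20250115_093042".
    let session_id = Local::now().format("%Y%m%d_%H%M%S").to_string();
    println!("{}", session_id);
}
```
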
// Format timestamp
|
||||
function formatTimestamp(date) {
|
||||
return date.toLocaleTimeString('en-US', {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
});
|
||||
}
|
||||
|
||||
// Create message element
|
||||
function createMessageElement(content, role, timestamp) {
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = `message ${role}`;
|
||||
|
||||
// Create content div
|
||||
const contentDiv = document.createElement('div');
|
||||
contentDiv.className = 'message-content';
|
||||
contentDiv.innerHTML = formatMessageContent(content);
|
||||
messageDiv.appendChild(contentDiv);
|
||||
|
||||
// Add timestamp
|
||||
const timestampDiv = document.createElement('div');
|
||||
timestampDiv.className = 'timestamp';
|
||||
timestampDiv.textContent = formatTimestamp(new Date(timestamp || Date.now()));
|
||||
messageDiv.appendChild(timestampDiv);
|
||||
|
||||
return messageDiv;
|
||||
}
|
||||
|
||||
// Format message content (handle markdown-like formatting)
|
||||
function formatMessageContent(content) {
|
||||
// Escape HTML
|
||||
let formatted = content
|
||||
.replace(/&/g, '&')
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>');
|
||||
|
||||
// Handle code blocks
|
||||
formatted = formatted.replace(/```(\w+)?\n([\s\S]*?)```/g, (match, lang, code) => {
|
||||
return `<pre><code class="language-${lang || 'plaintext'}">${code.trim()}</code></pre>`;
|
||||
});
|
||||
|
||||
// Handle inline code
|
||||
formatted = formatted.replace(/`([^`]+)`/g, '<code>$1</code>');
|
||||
|
||||
// Handle line breaks
|
||||
formatted = formatted.replace(/\n/g, '<br>');
|
||||
|
||||
return formatted;
|
||||
}
|
||||
|
||||
// Add message to chat
|
||||
function addMessage(content, role, timestamp) {
|
||||
// Remove welcome message if it exists
|
||||
const welcomeMessage = messagesContainer.querySelector('.welcome-message');
|
||||
if (welcomeMessage) {
|
||||
welcomeMessage.remove();
|
||||
}
|
||||
|
||||
const messageElement = createMessageElement(content, role, timestamp);
|
||||
messagesContainer.appendChild(messageElement);
|
||||
|
||||
// Scroll to bottom
|
||||
messagesContainer.scrollTop = messagesContainer.scrollHeight;
|
||||
}
|
||||
|
||||
// Add thinking indicator
|
||||
function addThinkingIndicator() {
|
||||
removeThinkingIndicator(); // Remove any existing one first
|
||||
|
||||
const thinkingDiv = document.createElement('div');
|
||||
thinkingDiv.id = 'thinking-indicator';
|
||||
thinkingDiv.className = 'message thinking-message';
|
||||
thinkingDiv.innerHTML = `
|
||||
<div class="thinking-dots">
|
||||
<span></span>
|
||||
<span></span>
|
||||
<span></span>
|
||||
</div>
|
||||
<span class="thinking-text">Goose is thinking...</span>
|
||||
`;
|
||||
messagesContainer.appendChild(thinkingDiv);
|
||||
messagesContainer.scrollTop = messagesContainer.scrollHeight;
|
||||
}
|
||||
|
||||
// Remove thinking indicator
|
||||
function removeThinkingIndicator() {
|
||||
const thinking = document.getElementById('thinking-indicator');
|
||||
if (thinking) {
|
||||
thinking.remove();
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to WebSocket
|
||||
function connectWebSocket() {
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const wsUrl = `${protocol}//${window.location.host}/ws`;
|
||||
|
||||
socket = new WebSocket(wsUrl);
|
||||
|
||||
socket.onopen = () => {
|
||||
console.log('WebSocket connected');
|
||||
isConnected = true;
|
||||
connectionStatus.textContent = 'Connected';
|
||||
connectionStatus.className = 'status connected';
|
||||
sendButton.disabled = false;
|
||||
|
||||
// Check if this session exists and load history if it does
|
||||
loadSessionIfExists();
|
||||
};
|
||||
|
||||
socket.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
handleServerMessage(data);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse message:', e);
|
||||
}
|
||||
};
|
||||
|
||||
socket.onclose = () => {
|
||||
console.log('WebSocket disconnected');
|
||||
isConnected = false;
|
||||
connectionStatus.textContent = 'Disconnected';
|
||||
connectionStatus.className = 'status disconnected';
|
||||
sendButton.disabled = true;
|
||||
|
||||
// Attempt to reconnect after 3 seconds
|
||||
setTimeout(connectWebSocket, 3000);
|
||||
};
|
||||
|
||||
socket.onerror = (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
};
|
||||
}
|
||||
|
||||
// Handle messages from server
|
||||
function handleServerMessage(data) {
|
||||
switch (data.type) {
|
||||
case 'response':
|
||||
// For streaming responses, we need to handle partial messages
|
||||
handleStreamingResponse(data);
|
||||
break;
|
||||
case 'tool_request':
|
||||
handleToolRequest(data);
|
||||
break;
|
||||
case 'tool_response':
|
||||
handleToolResponse(data);
|
||||
break;
|
||||
case 'tool_confirmation':
|
||||
handleToolConfirmation(data);
|
||||
break;
|
||||
case 'thinking':
|
||||
handleThinking(data);
|
||||
break;
|
||||
case 'context_exceeded':
|
||||
handleContextExceeded(data);
|
||||
break;
|
||||
case 'cancelled':
|
||||
handleCancelled(data);
|
||||
break;
|
||||
case 'complete':
|
||||
handleComplete(data);
|
||||
break;
|
||||
case 'error':
|
||||
removeThinkingIndicator();
|
||||
resetSendButton();
|
||||
addMessage(`Error: ${data.message}`, 'assistant', Date.now());
|
||||
break;
|
||||
default:
|
||||
console.log('Unknown message type:', data.type);
|
||||
}
|
||||
}
|
||||
|
||||
// Track current streaming message
|
||||
let currentStreamingMessage = null;
|
||||
|
||||
// Handle streaming responses
|
||||
function handleStreamingResponse(data) {
|
||||
removeThinkingIndicator();
|
||||
|
||||
// If this is the first chunk of a new message, or we don't have a current streaming message
|
||||
if (!currentStreamingMessage) {
|
||||
// Create a new message element
|
||||
const messageElement = createMessageElement(data.content, data.role || 'assistant', data.timestamp);
|
||||
messageElement.setAttribute('data-streaming', 'true');
|
||||
messagesContainer.appendChild(messageElement);
|
||||
|
||||
currentStreamingMessage = {
|
||||
element: messageElement,
|
||||
content: data.content,
|
||||
role: data.role || 'assistant',
|
||||
timestamp: data.timestamp
|
||||
};
|
||||
} else {
|
||||
// Append to existing streaming message
|
||||
currentStreamingMessage.content += data.content;
|
||||
|
||||
// Update the message content using the proper content div
|
||||
const contentDiv = currentStreamingMessage.element.querySelector('.message-content');
|
||||
if (contentDiv) {
|
||||
contentDiv.innerHTML = formatMessageContent(currentStreamingMessage.content);
|
||||
}
|
||||
}
|
||||
|
||||
// Scroll to bottom
|
||||
messagesContainer.scrollTop = messagesContainer.scrollHeight;
|
||||
}
|
||||
|
||||
// Handle tool requests
|
||||
function handleToolRequest(data) {
|
||||
removeThinkingIndicator(); // Remove thinking when tool starts
|
||||
|
||||
// Reset streaming message so tool doesn't interfere with message flow
|
||||
currentStreamingMessage = null;
|
||||
|
||||
const toolDiv = document.createElement('div');
|
||||
toolDiv.className = 'message assistant tool-message';
|
||||
|
||||
const headerDiv = document.createElement('div');
|
||||
headerDiv.className = 'tool-header';
|
||||
headerDiv.innerHTML = `🔧 <strong>${data.tool_name}</strong>`;
|
||||
|
||||
const contentDiv = document.createElement('div');
|
||||
contentDiv.className = 'tool-content';
|
||||
|
||||
// Format the arguments
|
||||
if (data.tool_name === 'developer__shell' && data.arguments.command) {
|
||||
contentDiv.innerHTML = `<pre><code>${escapeHtml(data.arguments.command)}</code></pre>`;
|
||||
} else if (data.tool_name === 'developer__text_editor') {
|
||||
const action = data.arguments.command || 'unknown';
|
||||
const path = data.arguments.path || 'unknown';
|
||||
contentDiv.innerHTML = `<div class="tool-param"><strong>action:</strong> ${action}</div>`;
|
||||
contentDiv.innerHTML += `<div class="tool-param"><strong>path:</strong> ${escapeHtml(path)}</div>`;
|
||||
if (data.arguments.file_text) {
|
||||
contentDiv.innerHTML += `<div class="tool-param"><strong>content:</strong> <pre><code>${escapeHtml(data.arguments.file_text.substring(0, 200))}${data.arguments.file_text.length > 200 ? '...' : ''}</code></pre></div>`;
|
||||
}
|
||||
} else {
|
||||
contentDiv.innerHTML = `<pre><code>${JSON.stringify(data.arguments, null, 2)}</code></pre>`;
|
||||
}
|
||||
|
||||
toolDiv.appendChild(headerDiv);
|
||||
toolDiv.appendChild(contentDiv);
|
||||
|
||||
// Add a "running" indicator
|
||||
const runningDiv = document.createElement('div');
|
||||
runningDiv.className = 'tool-running';
|
||||
runningDiv.innerHTML = '⏳ Running...';
|
||||
toolDiv.appendChild(runningDiv);
|
||||
|
||||
messagesContainer.appendChild(toolDiv);
|
||||
messagesContainer.scrollTop = messagesContainer.scrollHeight;
|
||||
}
|
||||
|
||||
// Handle tool responses
|
||||
function handleToolResponse(data) {
|
||||
// Remove the "running" indicator from the last tool message
|
||||
const toolMessages = messagesContainer.querySelectorAll('.tool-message');
|
||||
if (toolMessages.length > 0) {
|
||||
const lastToolMessage = toolMessages[toolMessages.length - 1];
|
||||
const runningIndicator = lastToolMessage.querySelector('.tool-running');
|
||||
if (runningIndicator) {
|
||||
runningIndicator.remove();
|
||||
}
|
||||
}
|
||||
|
||||
if (data.is_error) {
|
||||
const errorDiv = document.createElement('div');
|
||||
errorDiv.className = 'message tool-error';
|
||||
errorDiv.innerHTML = `<strong>Tool Error:</strong> ${escapeHtml(data.result.error || 'Unknown error')}`;
|
||||
messagesContainer.appendChild(errorDiv);
|
||||
} else {
|
||||
// Handle successful tool response
|
||||
if (Array.isArray(data.result)) {
|
||||
data.result.forEach(content => {
|
||||
if (content.type === 'text' && content.text) {
|
||||
const responseDiv = document.createElement('div');
|
||||
responseDiv.className = 'message tool-result';
|
||||
responseDiv.innerHTML = `<pre>${escapeHtml(content.text)}</pre>`;
|
||||
messagesContainer.appendChild(responseDiv);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
messagesContainer.scrollTop = messagesContainer.scrollHeight;
|
||||
|
||||
// Reset streaming message so next assistant response creates a new message
|
||||
currentStreamingMessage = null;
|
||||
|
||||
    // Show thinking indicator because assistant will likely follow up with explanation
    // Only show if we're still processing (cancel button is active)
    if (isProcessing) {
        addThinkingIndicator();
    }
}

// Handle tool confirmations
function handleToolConfirmation(data) {
    const confirmDiv = document.createElement('div');
    confirmDiv.className = 'message tool-confirmation';
    confirmDiv.innerHTML = `
        <div class="tool-confirm-header">⚠️ Tool Confirmation Required</div>
        <div class="tool-confirm-content">
            <strong>${data.tool_name}</strong> wants to execute with:
            <pre><code>${JSON.stringify(data.arguments, null, 2)}</code></pre>
        </div>
        <div class="tool-confirm-note">Auto-approved in web mode (UI coming soon)</div>
    `;
    messagesContainer.appendChild(confirmDiv);
    messagesContainer.scrollTop = messagesContainer.scrollHeight;
}

// Handle thinking messages
function handleThinking(data) {
    // For now, just log thinking messages
    console.log('Thinking:', data.message);
}

// Handle context exceeded
function handleContextExceeded(data) {
    const contextDiv = document.createElement('div');
    contextDiv.className = 'message context-warning';
    contextDiv.innerHTML = `
        <div class="context-header">⚠️ Context Length Exceeded</div>
        <div class="context-content">${escapeHtml(data.message)}</div>
        <div class="context-note">Auto-summarizing conversation...</div>
    `;
    messagesContainer.appendChild(contextDiv);
    messagesContainer.scrollTop = messagesContainer.scrollHeight;
}

// Handle cancelled operation
function handleCancelled(data) {
    removeThinkingIndicator();
    resetSendButton();

    const cancelDiv = document.createElement('div');
    cancelDiv.className = 'message system-message cancelled';
    cancelDiv.innerHTML = `<em>${escapeHtml(data.message)}</em>`;
    messagesContainer.appendChild(cancelDiv);
    messagesContainer.scrollTop = messagesContainer.scrollHeight;
}

// Handle completion of response
function handleComplete(data) {
    removeThinkingIndicator();
    resetSendButton();
    // Finalize any streaming message
    if (currentStreamingMessage) {
        currentStreamingMessage = null;
    }
}

// Reset send button to normal state
function resetSendButton() {
    isProcessing = false;
    sendButton.textContent = 'Send';
    sendButton.classList.remove('cancel-mode');
}

// Escape HTML to prevent XSS
function escapeHtml(text) {
    const div = document.createElement('div');
    div.textContent = text;
    return div.innerHTML;
}

// Send message or cancel
function sendMessage() {
    if (isProcessing) {
        // Cancel the current operation
        socket.send(JSON.stringify({
            type: 'cancel',
            session_id: sessionId
        }));
        return;
    }

    const message = messageInput.value.trim();
    if (!message || !isConnected) return;

    // Add user message to chat
    addMessage(message, 'user', Date.now());

    // Clear input
    messageInput.value = '';
    messageInput.style.height = 'auto';

    // Add thinking indicator
    addThinkingIndicator();

    // Update button to show cancel
    isProcessing = true;
    sendButton.textContent = 'Cancel';
    sendButton.classList.add('cancel-mode');

    // Send message through WebSocket
    socket.send(JSON.stringify({
        type: 'message',
        content: message,
        session_id: sessionId,
        timestamp: Date.now()
    }));
}

// Handle suggestion pill clicks
function sendSuggestion(text) {
    if (!isConnected || isProcessing) return;

    messageInput.value = text;
    sendMessage();
}

// Load session history if the session exists (like --resume in CLI)
async function loadSessionIfExists() {
    try {
        const response = await fetch(`/api/sessions/${sessionId}`);
        if (response.ok) {
            const sessionData = await response.json();
            if (sessionData.messages && sessionData.messages.length > 0) {
                // Remove welcome message since we're resuming
                const welcomeMessage = messagesContainer.querySelector('.welcome-message');
                if (welcomeMessage) {
                    welcomeMessage.remove();
                }

                // Display session resumed message
                const resumeDiv = document.createElement('div');
                resumeDiv.className = 'message system-message';
                resumeDiv.innerHTML = `<em>Session resumed: ${sessionData.messages.length} messages loaded</em>`;
                messagesContainer.appendChild(resumeDiv);

                // Update page title with session description if available
                if (sessionData.metadata && sessionData.metadata.description) {
                    document.title = `Goose Chat - ${sessionData.metadata.description}`;
                }

                messagesContainer.scrollTop = messagesContainer.scrollHeight;
            }
        }
    } catch (error) {
        console.log('No existing session found or error loading:', error);
        // This is fine - just means it's a new session
    }
}

// Event listeners
sendButton.addEventListener('click', sendMessage);

messageInput.addEventListener('keydown', (e) => {
    if (e.key === 'Enter' && !e.shiftKey) {
        e.preventDefault();
        sendMessage();
    }
});

// Auto-resize textarea
messageInput.addEventListener('input', () => {
    messageInput.style.height = 'auto';
    messageInput.style.height = messageInput.scrollHeight + 'px';
});

// Initialize WebSocket connection
connectWebSocket();

// Focus on input
messageInput.focus();

// Update session title
function updateSessionTitle() {
    const titleElement = document.getElementById('session-title');
    // Just show "Goose Chat" - no need to show session ID
    titleElement.textContent = 'Goose Chat';
}

// Update title on load
updateSessionTitle();
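A note on the wire format before the next file: the script above sends exactly two frame shapes over the WebSocket ('message' and 'cancel'). A minimal Rust sketch of how a server might model them, assuming serde and serde_json; `ClientFrame` and `parse_frame` are illustrative names, not part of this change:

// Illustrative only -- not part of this diff. Field names mirror the JSON
// keys sendMessage() actually emits; the types themselves are hypothetical.
use serde::Deserialize;

#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientFrame {
    // {"type":"message","content":...,"session_id":...,"timestamp":...}
    Message {
        content: String,
        session_id: String,
        timestamp: u64,
    },
    // {"type":"cancel","session_id":...}
    Cancel {
        session_id: String,
    },
}

fn parse_frame(raw: &str) -> serde_json::Result<ClientFrame> {
    serde_json::from_str(raw)
}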
480
crates/goose-cli/static/style.css
Normal file
@@ -0,0 +1,480 @@
:root {
    /* Dark theme colors (matching the dark.png) */
    --bg-primary: #000000;
    --bg-secondary: #0a0a0a;
    --bg-tertiary: #1a1a1a;
    --text-primary: #ffffff;
    --text-secondary: #a0a0a0;
    --text-muted: #666666;
    --border-color: #333333;
    --border-subtle: #1a1a1a;
    --accent-color: #ffffff;
    --accent-hover: #f0f0f0;
    --user-bg: #1a1a1a;
    --assistant-bg: #0a0a0a;
    --input-bg: #0a0a0a;
    --input-border: #333333;
    --button-bg: #ffffff;
    --button-text: #000000;
    --button-hover: #e0e0e0;
    --pill-bg: transparent;
    --pill-border: #333333;
    --pill-hover: #1a1a1a;
    --tool-bg: #0f0f0f;
    --code-bg: #0f0f0f;
}

/* Light theme */
@media (prefers-color-scheme: light) {
    :root {
        --bg-primary: #ffffff;
        --bg-secondary: #fafafa;
        --bg-tertiary: #f5f5f5;
        --text-primary: #000000;
        --text-secondary: #666666;
        --text-muted: #999999;
        --border-color: #e1e5e9;
        --border-subtle: #f0f0f0;
        --accent-color: #000000;
        --accent-hover: #333333;
        --user-bg: #f0f0f0;
        --assistant-bg: #fafafa;
        --input-bg: #ffffff;
        --input-border: #e1e5e9;
        --button-bg: #000000;
        --button-text: #ffffff;
        --button-hover: #333333;
        --pill-bg: #f5f5f5;
        --pill-border: #e1e5e9;
        --pill-hover: #e8eaed;
        --tool-bg: #f8f9fa;
        --code-bg: #f5f5f5;
    }
}

* {
    margin: 0;
    padding: 0;
    box-sizing: border-box;
}

body {
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
    background-color: var(--bg-primary);
    color: var(--text-primary);
    line-height: 1.5;
    height: 100vh;
    overflow: hidden;
    font-size: 14px;
}

.container {
    display: flex;
    flex-direction: column;
    height: 100vh;
    max-width: 100%;
    margin: 0 auto;
}

header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    padding: 1rem 1.5rem;
    background-color: var(--bg-primary);
    border-bottom: 1px solid var(--border-subtle);
}

header h1 {
    font-size: 1.25rem;
    font-weight: 600;
    display: flex;
    align-items: center;
    gap: 0.75rem;
}

header h1::before {
    content: "🪿";
    font-size: 1.5rem;
}

.status {
    font-size: 0.75rem;
    color: var(--text-secondary);
    padding: 0.25rem 0.75rem;
    border-radius: 1rem;
    background-color: var(--bg-secondary);
    border: 1px solid var(--border-color);
}

.status.connected {
    color: #10b981;
    border-color: #10b981;
    background-color: rgba(16, 185, 129, 0.1);
}

.status.disconnected {
    color: #ef4444;
    border-color: #ef4444;
    background-color: rgba(239, 68, 68, 0.1);
}

.chat-container {
    flex: 1;
    display: flex;
    flex-direction: column;
    overflow: hidden;
}

.messages {
    flex: 1;
    overflow-y: auto;
    padding: 2rem;
    display: flex;
    flex-direction: column;
    gap: 1.5rem;
}

.welcome-message {
    text-align: center;
    padding: 4rem 2rem;
    color: var(--text-secondary);
}

.welcome-message h2 {
    font-size: 1.5rem;
    margin-bottom: 1rem;
    color: var(--text-primary);
    font-weight: 600;
}

.welcome-message p {
    font-size: 1rem;
    margin-bottom: 2rem;
}

/* Suggestion pills like in the design */
.suggestion-pills {
    display: flex;
    flex-wrap: wrap;
    gap: 0.75rem;
    justify-content: center;
    margin-top: 2rem;
}

.suggestion-pill {
    padding: 0.75rem 1.25rem;
    background-color: var(--pill-bg);
    border: 1px solid var(--pill-border);
    border-radius: 2rem;
    color: var(--text-primary);
    font-size: 0.875rem;
    cursor: pointer;
    transition: all 0.2s ease;
    text-decoration: none;
    display: inline-block;
}

.suggestion-pill:hover {
    background-color: var(--pill-hover);
    border-color: var(--border-color);
}

.message {
    max-width: 80%;
    padding: 1rem 1.25rem;
    border-radius: 1rem;
    word-wrap: break-word;
    position: relative;
}

.message.user {
    align-self: flex-end;
    background-color: var(--user-bg);
    margin-left: auto;
    border: 1px solid var(--border-subtle);
}

.message.assistant {
    align-self: flex-start;
    background-color: var(--assistant-bg);
    border: 1px solid var(--border-subtle);
}

.message-content {
    flex: 1;
    margin-bottom: 0.5rem;
}

.message .timestamp {
    font-size: 0.6875rem;
    color: var(--text-muted);
    margin-top: 0.5rem;
    opacity: 0.7;
}

.message pre {
    background-color: var(--code-bg);
    padding: 0.75rem;
    border-radius: 0.5rem;
    overflow-x: auto;
    margin: 0.75rem 0;
    border: 1px solid var(--border-color);
    font-size: 0.8125rem;
}

.message code {
    background-color: var(--code-bg);
    padding: 0.125rem 0.375rem;
    border-radius: 0.25rem;
    font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
    font-size: 0.8125rem;
    border: 1px solid var(--border-color);
}

.input-container {
    display: flex;
    gap: 0.75rem;
    padding: 1.5rem;
    background-color: var(--bg-primary);
    border-top: 1px solid var(--border-subtle);
}

#message-input {
    flex: 1;
    padding: 0.875rem 1rem;
    border: 1px solid var(--input-border);
    border-radius: 0.75rem;
    background-color: var(--input-bg);
    color: var(--text-primary);
    font-family: inherit;
    font-size: 0.875rem;
    resize: none;
    min-height: 2.75rem;
    max-height: 8rem;
    outline: none;
    transition: border-color 0.2s ease;
}

#message-input:focus {
    border-color: var(--accent-color);
}

#message-input::placeholder {
    color: var(--text-muted);
}

#send-button {
    padding: 0.875rem 1.5rem;
    background-color: var(--button-bg);
    color: var(--button-text);
    border: none;
    border-radius: 0.75rem;
    font-size: 0.875rem;
    font-weight: 500;
    cursor: pointer;
    transition: all 0.2s ease;
    min-width: 4rem;
}

#send-button:hover {
    background-color: var(--button-hover);
    transform: translateY(-1px);
}

#send-button:disabled {
    opacity: 0.5;
    cursor: not-allowed;
    transform: none;
}

#send-button.cancel-mode {
    background-color: #ef4444;
    color: #ffffff;
}

#send-button.cancel-mode:hover {
    background-color: #dc2626;
}

/* Scrollbar styling */
.messages::-webkit-scrollbar {
    width: 6px;
}

.messages::-webkit-scrollbar-track {
    background: transparent;
}

.messages::-webkit-scrollbar-thumb {
    background: var(--border-color);
    border-radius: 3px;
}

.messages::-webkit-scrollbar-thumb:hover {
    background: var(--text-secondary);
}

/* Tool call styling */
.tool-message, .tool-result, .tool-error, .tool-confirmation, .context-warning {
    background-color: var(--tool-bg);
    border: 1px solid var(--border-color);
    border-radius: 0.75rem;
    padding: 1rem;
    margin: 0.75rem 0;
    max-width: 90%;
}

.tool-header, .tool-confirm-header, .context-header {
    font-weight: 600;
    color: var(--accent-color);
    margin-bottom: 0.75rem;
    font-size: 0.875rem;
}

.tool-content {
    font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
    font-size: 0.8125rem;
    color: var(--text-secondary);
}

.tool-param {
    margin: 0.5rem 0;
}

.tool-param strong {
    color: var(--text-primary);
}

.tool-running {
    font-size: 0.8125rem;
    color: var(--accent-color);
    margin-top: 0.75rem;
    font-style: italic;
}

.tool-error {
    border-color: #ef4444;
    background-color: rgba(239, 68, 68, 0.05);
}

.tool-error strong {
    color: #ef4444;
}

.tool-result {
    background-color: var(--tool-bg);
    border-left: 3px solid var(--accent-color);
    margin-left: 1.5rem;
    border-radius: 0.5rem;
}

.tool-confirmation {
    border-color: #f59e0b;
    background-color: rgba(245, 158, 11, 0.05);
}

.tool-confirm-note, .context-note {
    font-size: 0.75rem;
    color: var(--text-muted);
    margin-top: 0.75rem;
    font-style: italic;
}

.context-warning {
    border-color: #f59e0b;
    background-color: rgba(245, 158, 11, 0.05);
}

.context-header {
    color: #f59e0b;
}

.system-message {
    text-align: center;
    color: var(--text-secondary);
    font-style: italic;
    margin: 1rem 0;
    font-size: 0.875rem;
}

.cancelled {
    color: #ef4444;
}

/* Thinking indicator */
.thinking-message {
    display: flex;
    align-items: center;
    gap: 0.75rem;
    color: var(--text-secondary);
    font-style: italic;
    padding: 1rem 1.25rem;
    background-color: var(--bg-secondary);
    border-radius: 1rem;
    border: 1px solid var(--border-subtle);
    max-width: 80%;
    font-size: 0.875rem;
}

.thinking-dots {
    display: flex;
    gap: 0.25rem;
}

.thinking-dots span {
    width: 6px;
    height: 6px;
    border-radius: 50%;
    background-color: var(--text-secondary);
    animation: thinking-bounce 1.4s infinite ease-in-out both;
}

.thinking-dots span:nth-child(1) {
    animation-delay: -0.32s;
}

.thinking-dots span:nth-child(2) {
    animation-delay: -0.16s;
}

@keyframes thinking-bounce {
    0%, 80%, 100% {
        transform: scale(0.6);
        opacity: 0.5;
    }
    40% {
        transform: scale(1);
        opacity: 1;
    }
}

/* Keep the old loading indicator for backwards compatibility */
.loading-message {
    display: none;
}

/* Responsive design */
@media (max-width: 768px) {
    .messages {
        padding: 1rem;
        gap: 1rem;
    }

    .message {
        max-width: 90%;
        padding: 0.875rem 1rem;
    }

    .input-container {
        padding: 1rem;
    }

    header {
        padding: 0.75rem 1rem;
    }

    .welcome-message {
        padding: 2rem 1rem;
    }
}
@@ -3,7 +3,7 @@ use std::ptr;
use std::sync::Arc;

use futures::StreamExt;
use goose::agents::Agent;
use goose::agents::{Agent, AgentEvent};
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::databricks::DatabricksProvider;
@@ -256,13 +256,16 @@ pub unsafe extern "C" fn goose_agent_send_message(

        while let Some(message_result) = stream.next().await {
            match message_result {
                Ok(message) => {
                Ok(AgentEvent::Message(message)) => {
                    // Get text or serialize to JSON
                    // Note: Message doesn't have as_text method, we'll serialize to JSON
                    if let Ok(json) = serde_json::to_string(&message) {
                        full_response.push_str(&json);
                    }
                }
                Ok(AgentEvent::McpNotification(_)) => {
                    // TODO: Handle MCP notifications.
                }
                Err(e) => {
                    full_response.push_str(&format!("\nError in message stream: {}", e));
                }

@@ -138,6 +138,11 @@ impl DatabricksProvider {
            "reduce the length",
            "token count",
            "exceeds",
            "exceed context limit",
            "input length",
            "max_tokens",
            "decrease input length",
            "context limit",
        ];
        if check_phrases.iter().any(|c| payload_str.contains(c)) {
            return Err(ProviderError::ContextLengthExceeded(payload_str));

@@ -6,7 +6,7 @@ use serde_json::{json, Value};
use std::{
    collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex,
};
use tokio::process::Command;
use tokio::{process::Command, sync::mpsc};

#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
@@ -14,7 +14,7 @@ use std::os::unix::fs::PermissionsExt;
use mcp_core::{
    handler::{PromptError, ResourceError, ToolError},
    prompt::Prompt,
    protocol::ServerCapabilities,
    protocol::{JsonRpcMessage, ServerCapabilities},
    resource::Resource,
    tool::{Tool, ToolAnnotations},
    Content,
@@ -1155,6 +1155,7 @@ impl Router for ComputerControllerRouter {
        &self,
        tool_name: &str,
        arguments: Value,
        _notifier: mpsc::Sender<JsonRpcMessage>,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
        let this = self.clone();
        let tool_name = tool_name.to_string();

@@ -13,13 +13,17 @@ use std::{
    path::{Path, PathBuf},
    pin::Pin,
};
use tokio::process::Command;
use tokio::{
    io::{AsyncBufReadExt, BufReader},
    process::Command,
    sync::mpsc,
};
use url::Url;

use include_dir::{include_dir, Dir};
use mcp_core::{
    handler::{PromptError, ResourceError, ToolError},
    protocol::ServerCapabilities,
    protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities},
    resource::Resource,
    tool::Tool,
    Content,
@@ -404,10 +408,20 @@ impl DeveloperRouter {
        if local_ignore_path.is_file() {
            let _ = builder.add(local_ignore_path);
            has_ignore_file = true;
        } else {
            // If no .gooseignore exists, check for .gitignore as fallback
            let gitignore_path = cwd.join(".gitignore");
            if gitignore_path.is_file() {
                tracing::debug!(
                    "No .gooseignore found, using .gitignore as fallback for ignore patterns"
                );
                let _ = builder.add(gitignore_path);
                has_ignore_file = true;
            }
        }

        // Only use default patterns if no .gooseignore files were found
        // If the file is empty, we will not ignore any file
        // AND no .gitignore was used as fallback
        if !has_ignore_file {
            // Add some sensible defaults
            let _ = builder.add_line(None, "**/.env");
@@ -456,7 +470,11 @@ impl DeveloperRouter {
    }

    // Shell command execution with platform-specific handling
    async fn bash(&self, params: Value) -> Result<Vec<Content>, ToolError> {
    async fn bash(
        &self,
        params: Value,
        notifier: mpsc::Sender<JsonRpcMessage>,
    ) -> Result<Vec<Content>, ToolError> {
        let command =
            params
                .get("command")
@@ -488,27 +506,102 @@ impl DeveloperRouter {

        // Get platform-specific shell configuration
        let shell_config = get_shell_config();
        let cmd_with_redirect = format_command_for_platform(command);
        let cmd_str = format_command_for_platform(command);

        // Execute the command using platform-specific shell
        let child = Command::new(&shell_config.executable)
        let mut child = Command::new(&shell_config.executable)
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .stdin(Stdio::null())
            .kill_on_drop(true)
            .arg(&shell_config.arg)
            .arg(cmd_with_redirect)
            .arg(cmd_str)
            .spawn()
            .map_err(|e| ToolError::ExecutionError(e.to_string()))?;

        let stdout = child.stdout.take().unwrap();
        let stderr = child.stderr.take().unwrap();

        let mut stdout_reader = BufReader::new(stdout);
        let mut stderr_reader = BufReader::new(stderr);

        let output_task = tokio::spawn(async move {
            let mut combined_output = String::new();

            let mut stdout_buf = Vec::new();
            let mut stderr_buf = Vec::new();

            let mut stdout_done = false;
            let mut stderr_done = false;

            loop {
                tokio::select! {
                    n = stdout_reader.read_until(b'\n', &mut stdout_buf), if !stdout_done => {
                        if n? == 0 {
                            stdout_done = true;
                        } else {
                            let line = String::from_utf8_lossy(&stdout_buf);

                            notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
                                jsonrpc: "2.0".to_string(),
                                method: "notifications/message".to_string(),
                                params: Some(json!({
                                    "data": {
                                        "type": "shell",
                                        "stream": "stdout",
                                        "output": line.to_string(),
                                    }
                                })),
                            })).ok();

                            combined_output.push_str(&line);
                            stdout_buf.clear();
                        }
                    }

                    n = stderr_reader.read_until(b'\n', &mut stderr_buf), if !stderr_done => {
                        if n? == 0 {
                            stderr_done = true;
                        } else {
                            let line = String::from_utf8_lossy(&stderr_buf);

                            notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
                                jsonrpc: "2.0".to_string(),
                                method: "notifications/message".to_string(),
                                params: Some(json!({
                                    "data": {
                                        "type": "shell",
                                        "stream": "stderr",
                                        "output": line.to_string(),
                                    }
                                })),
                            })).ok();

                            combined_output.push_str(&line);
                            stderr_buf.clear();
                        }
                    }

                    else => break,
                }

                if stdout_done && stderr_done {
                    break;
                }
            }
            Ok::<_, std::io::Error>(combined_output)
        });

        // Wait for the command to complete and get output
        let output = child
            .wait_with_output()
        child
            .wait()
            .await
            .map_err(|e| ToolError::ExecutionError(e.to_string()))?;

        let stdout_str = String::from_utf8_lossy(&output.stdout);
        let output_str = stdout_str;
        let output_str = match output_task.await {
            Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
            Err(e) => return Err(ToolError::ExecutionError(e.to_string())),
        };

        // Check the character count of the output
        const MAX_CHAR_COUNT: usize = 400_000; // 400,000 chars ≈ 400KB

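One detail of the streaming loop above worth calling out: `try_send` is tokio's non-blocking mpsc send, so when the notification channel is full or closed the line is simply dropped from the notification stream (hence the `.ok()`), while still being appended to `combined_output`; a slow or absent notification consumer can therefore never stall the shell command. A minimal sketch of that failure mode, assuming only tokio:

// Illustrative only -- demonstrates try_send semantics, not this codebase's channel.
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Capacity-1 channel with nothing draining it: the second try_send
    // fails fast instead of blocking, mirroring how shell notifications
    // are dropped rather than stalling the command.
    let (tx, _rx) = mpsc::channel::<String>(1);
    assert!(tx.try_send("first line".into()).is_ok());
    assert!(tx.try_send("second line".into()).is_err()); // channel full
}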
@@ -1048,12 +1141,13 @@ impl Router for DeveloperRouter {
        &self,
        tool_name: &str,
        arguments: Value,
        notifier: mpsc::Sender<JsonRpcMessage>,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
        let this = self.clone();
        let tool_name = tool_name.to_string();
        Box::pin(async move {
            match tool_name.as_str() {
                "shell" => this.bash(arguments).await,
                "shell" => this.bash(arguments, notifier).await,
                "text_editor" => this.text_editor(arguments).await,
                "list_windows" => this.list_windows(arguments).await,
                "screen_capture" => this.screen_capture(arguments).await,
@@ -1195,6 +1289,10 @@ mod tests {
        .await
    }

    fn dummy_sender() -> mpsc::Sender<JsonRpcMessage> {
        mpsc::channel(1).0
    }

    #[tokio::test]
    #[serial]
    async fn test_shell_missing_parameters() {
@@ -1202,7 +1300,7 @@ mod tests {
        std::env::set_current_dir(&temp_dir).unwrap();

        let router = get_router().await;
        let result = router.call_tool("shell", json!({})).await;
        let result = router.call_tool("shell", json!({}), dummy_sender()).await;

        assert!(result.is_err());
        let err = result.err().unwrap();
@@ -1263,6 +1361,7 @@ mod tests {
                    "command": "view",
                    "path": large_file_str
                }),
                dummy_sender(),
            )
            .await;

@@ -1288,6 +1387,7 @@ mod tests {
                    "command": "view",
                    "path": many_chars_str
                }),
                dummy_sender(),
            )
            .await;

@@ -1319,6 +1419,7 @@ mod tests {
                    "path": file_path_str,
                    "file_text": "Hello, world!"
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1331,6 +1432,7 @@ mod tests {
                    "command": "view",
                    "path": file_path_str
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1369,6 +1471,7 @@ mod tests {
                    "path": file_path_str,
                    "file_text": "Hello, world!"
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1383,6 +1486,7 @@ mod tests {
                    "old_str": "world",
                    "new_str": "Rust"
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1407,6 +1511,7 @@ mod tests {
                    "command": "view",
                    "path": file_path_str
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1444,6 +1549,7 @@ mod tests {
                    "path": file_path_str,
                    "file_text": "First line"
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1458,6 +1564,7 @@ mod tests {
                    "old_str": "First line",
                    "new_str": "Second line"
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1470,6 +1577,7 @@ mod tests {
                    "command": "undo_edit",
                    "path": file_path_str
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1485,6 +1593,7 @@ mod tests {
                    "command": "view",
                    "path": file_path_str
                }),
                dummy_sender(),
            )
            .await
            .unwrap();
@@ -1583,6 +1692,7 @@ mod tests {
                    "path": temp_dir.path().join("secret.txt").to_str().unwrap(),
                    "file_text": "test content"
                }),
                dummy_sender(),
            )
            .await;

@@ -1601,6 +1711,7 @@ mod tests {
                    "path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
                    "file_text": "test content"
                }),
                dummy_sender(),
            )
            .await;

@@ -1642,6 +1753,7 @@ mod tests {
                json!({
                    "command": format!("cat {}", secret_file_path.to_str().unwrap())
                }),
                dummy_sender(),
            )
            .await;

@@ -1658,6 +1770,202 @@ mod tests {
                json!({
                    "command": format!("cat {}", allowed_file_path.to_str().unwrap())
                }),
                dummy_sender(),
            )
            .await;

        assert!(result.is_ok(), "Should be able to cat non-ignored file");

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn test_gitignore_fallback_when_no_gooseignore() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::env::set_current_dir(&temp_dir).unwrap();

        // Create a .gitignore file but no .gooseignore
        std::fs::write(temp_dir.path().join(".gitignore"), "*.log\n*.tmp\n.env").unwrap();

        let router = DeveloperRouter::new();

        // Test that gitignore patterns are respected
        assert!(
            router.is_ignored(Path::new("test.log")),
            "*.log pattern from .gitignore should be ignored"
        );
        assert!(
            router.is_ignored(Path::new("build.tmp")),
            "*.tmp pattern from .gitignore should be ignored"
        );
        assert!(
            router.is_ignored(Path::new(".env")),
            ".env pattern from .gitignore should be ignored"
        );
        assert!(
            !router.is_ignored(Path::new("test.txt")),
            "test.txt should not be ignored"
        );

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn test_gooseignore_takes_precedence_over_gitignore() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::env::set_current_dir(&temp_dir).unwrap();

        // Create both .gooseignore and .gitignore files with different patterns
        std::fs::write(temp_dir.path().join(".gooseignore"), "*.secret").unwrap();
        std::fs::write(temp_dir.path().join(".gitignore"), "*.log\ntarget/").unwrap();

        let router = DeveloperRouter::new();

        // .gooseignore patterns should be used
        assert!(
            router.is_ignored(Path::new("test.secret")),
            "*.secret pattern from .gooseignore should be ignored"
        );

        // .gitignore patterns should NOT be used when .gooseignore exists
        assert!(
            !router.is_ignored(Path::new("test.log")),
            "*.log pattern from .gitignore should NOT be ignored when .gooseignore exists"
        );
        assert!(
            !router.is_ignored(Path::new("build.tmp")),
            "*.tmp pattern from .gitignore should NOT be ignored when .gooseignore exists"
        );

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn test_default_patterns_when_no_ignore_files() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::env::set_current_dir(&temp_dir).unwrap();

        // Don't create any ignore files
        let router = DeveloperRouter::new();

        // Default patterns should be used
        assert!(
            router.is_ignored(Path::new(".env")),
            ".env should be ignored by default patterns"
        );
        assert!(
            router.is_ignored(Path::new(".env.local")),
            ".env.local should be ignored by default patterns"
        );
        assert!(
            router.is_ignored(Path::new("secrets.txt")),
            "secrets.txt should be ignored by default patterns"
        );
        assert!(
            !router.is_ignored(Path::new("normal.txt")),
            "normal.txt should not be ignored"
        );

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn test_text_editor_respects_gitignore_fallback() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::env::set_current_dir(&temp_dir).unwrap();

        // Create a .gitignore file but no .gooseignore
        std::fs::write(temp_dir.path().join(".gitignore"), "*.log").unwrap();

        let router = DeveloperRouter::new();

        // Try to write to a file ignored by .gitignore
        let result = router
            .call_tool(
                "text_editor",
                json!({
                    "command": "write",
                    "path": temp_dir.path().join("test.log").to_str().unwrap(),
                    "file_text": "test content"
                }),
                dummy_sender(),
            )
            .await;

        assert!(
            result.is_err(),
            "Should not be able to write to file ignored by .gitignore fallback"
        );
        assert!(matches!(result.unwrap_err(), ToolError::ExecutionError(_)));

        // Try to write to a non-ignored file
        let result = router
            .call_tool(
                "text_editor",
                json!({
                    "command": "write",
                    "path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
                    "file_text": "test content"
                }),
                dummy_sender(),
            )
            .await;

        assert!(
            result.is_ok(),
            "Should be able to write to non-ignored file"
        );

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn test_bash_respects_gitignore_fallback() {
        let temp_dir = tempfile::tempdir().unwrap();
        std::env::set_current_dir(&temp_dir).unwrap();

        // Create a .gitignore file but no .gooseignore
        std::fs::write(temp_dir.path().join(".gitignore"), "*.log").unwrap();

        let router = DeveloperRouter::new();

        // Create a file that would be ignored by .gitignore
        let log_file_path = temp_dir.path().join("test.log");
        std::fs::write(&log_file_path, "log content").unwrap();

        // Try to cat the ignored file
        let result = router
            .call_tool(
                "shell",
                json!({
                    "command": format!("cat {}", log_file_path.to_str().unwrap())
                }),
                dummy_sender(),
            )
            .await;

        assert!(
            result.is_err(),
            "Should not be able to cat file ignored by .gitignore fallback"
        );
        assert!(matches!(result.unwrap_err(), ToolError::ExecutionError(_)));

        // Try to cat a non-ignored file
        let allowed_file_path = temp_dir.path().join("allowed.txt");
        std::fs::write(&allowed_file_path, "allowed content").unwrap();

        let result = router
            .call_tool(
                "shell",
                json!({
                    "command": format!("cat {}", allowed_file_path.to_str().unwrap())
                }),
                dummy_sender(),
            )
            .await;

@@ -4,7 +4,6 @@ use std::env;
pub struct ShellConfig {
    pub executable: String,
    pub arg: String,
    pub redirect_syntax: String,
}

impl Default for ShellConfig {
@@ -14,13 +13,11 @@ impl Default for ShellConfig {
            Self {
                executable: "powershell.exe".to_string(),
                arg: "-NoProfile -NonInteractive -Command".to_string(),
                redirect_syntax: "2>&1".to_string(),
            }
        } else {
            Self {
                executable: "bash".to_string(),
                arg: "-c".to_string(),
                redirect_syntax: "2>&1".to_string(),
            }
        }
    }
@@ -31,13 +28,12 @@ pub fn get_shell_config() -> ShellConfig {
}

pub fn format_command_for_platform(command: &str) -> String {
    let config = get_shell_config();
    if cfg!(windows) {
        // For PowerShell, wrap the command in braces to handle special characters
        format!("{{ {} }} {}", command, config.redirect_syntax)
        format!("{{ {} }}", command)
    } else {
        // For other shells, no braces needed
        format!("{} {}", command, config.redirect_syntax)
        command.to_string()
    }
}

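Since redirection moved out of the command string (stderr is now captured through its own pipe in the bash rework above), `format_command_for_platform` reduces to a brace-wrap on Windows and a pass-through elsewhere. An illustrative check, assuming the function above is in scope:

fn main() {
    // Pass-through on non-Windows platforms; brace-wrapped for PowerShell.
    let formatted = format_command_for_platform("echo hello");
    if cfg!(windows) {
        assert_eq!(formatted, "{ echo hello }");
    } else {
        assert_eq!(formatted, "echo hello");
    }
}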
@@ -7,6 +7,7 @@ use base64::Engine;
use chrono::NaiveDate;
use indoc::indoc;
use lazy_static::lazy_static;
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::tool::ToolAnnotations;
use oauth_pkce::PkceOAuth2Client;
use regex::Regex;
@@ -14,6 +15,7 @@ use serde_json::{json, Value};
use std::io::Cursor;
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
use storage::CredentialsManager;
use tokio::sync::mpsc;

use mcp_core::content::Content;
use mcp_core::{
@@ -3281,6 +3283,7 @@ impl Router for GoogleDriveRouter {
        &self,
        tool_name: &str,
        arguments: Value,
        _notifier: mpsc::Sender<JsonRpcMessage>,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
        let this = self.clone();
        let tool_name = tool_name.to_string();

@@ -5,7 +5,7 @@ use mcp_core::{
    content::Content,
    handler::{PromptError, ResourceError, ToolError},
    prompt::Prompt,
    protocol::ServerCapabilities,
    protocol::{JsonRpcMessage, ServerCapabilities},
    resource::Resource,
    role::Role,
    tool::Tool,
@@ -16,7 +16,7 @@ use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{sleep, Duration};
use tracing::error;

@@ -158,6 +158,7 @@ impl Router for JetBrainsRouter {
        &self,
        tool_name: &str,
        arguments: Value,
        _notifier: mpsc::Sender<JsonRpcMessage>,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
        let this = self.clone();
        let tool_name = tool_name.to_string();

@@ -10,11 +10,12 @@ use std::{
    path::PathBuf,
    pin::Pin,
};
use tokio::sync::mpsc;

use mcp_core::{
    handler::{PromptError, ResourceError, ToolError},
    prompt::Prompt,
    protocol::ServerCapabilities,
    protocol::{JsonRpcMessage, ServerCapabilities},
    resource::Resource,
    tool::{Tool, ToolAnnotations, ToolCall},
    Content,
@@ -520,6 +521,7 @@ impl Router for MemoryRouter {
        &self,
        tool_name: &str,
        arguments: Value,
        _notifier: mpsc::Sender<JsonRpcMessage>,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
        let this = self.clone();
        let tool_name = tool_name.to_string();

@@ -3,11 +3,12 @@ use include_dir::{include_dir, Dir};
use indoc::formatdoc;
use serde_json::{json, Value};
use std::{future::Future, pin::Pin};
use tokio::sync::mpsc;

use mcp_core::{
    handler::{PromptError, ResourceError, ToolError},
    prompt::Prompt,
    protocol::ServerCapabilities,
    protocol::{JsonRpcMessage, ServerCapabilities},
    resource::Resource,
    role::Role,
    tool::{Tool, ToolAnnotations},
@@ -130,6 +131,7 @@ impl Router for TutorialRouter {
        &self,
        tool_name: &str,
        arguments: Value,
        _notifier: mpsc::Sender<JsonRpcMessage>,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
        let this = self.clone();
        let tool_name = tool_name.to_string();

@@ -45,6 +45,8 @@ use utoipa::OpenApi;
        super::routes::schedule::run_now_handler,
        super::routes::schedule::pause_schedule,
        super::routes::schedule::unpause_schedule,
        super::routes::schedule::kill_running_job,
        super::routes::schedule::inspect_running_job,
        super::routes::schedule::sessions_handler
    ),
    components(schemas(
@@ -95,6 +97,8 @@ use utoipa::OpenApi;
        SessionMetadata,
        super::routes::schedule::CreateScheduleRequest,
        super::routes::schedule::UpdateScheduleRequest,
        super::routes::schedule::KillJobResponse,
        super::routes::schedule::InspectJobResponse,
        goose::scheduler::ScheduledJob,
        super::routes::schedule::RunNowResponse,
        super::routes::schedule::ListSchedulesResponse,

@@ -20,7 +20,7 @@
    "gcp_vertex_ai": {
        "name": "GCP Vertex AI",
        "description": "Use Vertex AI platform models",
        "models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05", "gemini-2.5-pro-exp-03-25"],
        "models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05", "gemini-2.5-pro-exp-03-25", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-05-06"],
        "required_keys": ["GCP_PROJECT_ID", "GCP_LOCATION"]
    },
    "google": {

@@ -10,7 +10,7 @@ use axum::{
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::{
    agents::SessionConfig,
    agents::{AgentEvent, SessionConfig},
    message::{Message, MessageContent},
    permission::permission_confirmation::PrincipalType,
};
@@ -18,7 +18,7 @@ use goose::{
    permission::{Permission, PermissionConfirmation},
    session,
};
use mcp_core::{role::Role, Content, ToolResult};
use mcp_core::{protocol::JsonRpcMessage, role::Role, Content, ToolResult};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
@@ -79,9 +79,19 @@ impl IntoResponse for SseResponse {
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum MessageEvent {
    Message { message: Message },
    Error { error: String },
    Finish { reason: String },
    Message {
        message: Message,
    },
    Error {
        error: String,
    },
    Finish {
        reason: String,
    },
    Notification {
        request_id: String,
        message: JsonRpcMessage,
    },
}

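Because `MessageEvent` is internally tagged with `#[serde(tag = "type")]`, each SSE payload names its variant inline. An illustrative check, assuming the enum above and serde_json are in scope:

// Illustrative: the tag field is merged into the variant's own object.
let ev = MessageEvent::Finish { reason: "stop".to_string() };
assert_eq!(
    serde_json::to_string(&ev).unwrap(),
    r#"{"type":"Finish","reason":"stop"}"#
);
// The new variant serializes along the same lines:
// {"type":"Notification","request_id":"...","message":{...}}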
async fn stream_event(
@@ -200,7 +210,7 @@ async fn handler(
        tokio::select! {
            response = timeout(Duration::from_millis(500), stream.next()) => {
                match response {
                    Ok(Some(Ok(message))) => {
                    Ok(Some(Ok(AgentEvent::Message(message)))) => {
                        all_messages.push(message.clone());
                        if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
                            tracing::error!("Error sending message through channel: {}", e);
@@ -223,6 +233,20 @@ async fn handler(
                            }
                        });
                    }
                    Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
                        if let Err(e) = stream_event(MessageEvent::Notification{
                            request_id: request_id.clone(),
                            message: n,
                        }, &tx).await {
                            tracing::error!("Error sending message through channel: {}", e);
                            let _ = stream_event(
                                MessageEvent::Error {
                                    error: e.to_string(),
                                },
                                &tx,
                            ).await;
                        }
                    }
                    Ok(Some(Err(e))) => {
                        tracing::error!("Error processing message: {}", e);
                        let _ = stream_event(
@@ -317,7 +341,7 @@ async fn ask_handler(

    while let Some(response) = stream.next().await {
        match response {
            Ok(message) => {
            Ok(AgentEvent::Message(message)) => {
                if message.role == Role::Assistant {
                    for content in &message.content {
                        if let MessageContent::Text(text) = content {
@@ -328,6 +352,10 @@ async fn ask_handler(
                    }
                }
            }
            Ok(AgentEvent::McpNotification(n)) => {
                // Handle notifications if needed
                tracing::info!("Received notification: {:?}", n);
            }
            Err(e) => {
                tracing::error!("Error processing as_ai message: {}", e);
                return Err(StatusCode::INTERNAL_SERVER_ERROR);

@@ -31,6 +31,21 @@ pub struct ListSchedulesResponse {
    jobs: Vec<ScheduledJob>,
}

// Response for the kill endpoint
#[derive(Serialize, utoipa::ToSchema)]
pub struct KillJobResponse {
    message: String,
}

// Response for the inspect endpoint
#[derive(Serialize, utoipa::ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct InspectJobResponse {
    session_id: Option<String>,
    process_start_time: Option<String>,
    running_duration_seconds: Option<i64>,
}

// Response for the run_now endpoint
#[derive(Serialize, utoipa::ToSchema)]
pub struct RunNowResponse {
@@ -100,6 +115,8 @@ async fn create_schedule(
        last_run: None,
        currently_running: false,
        paused: false,
        current_session_id: None,
        process_start_time: None,
    };
    scheduler
        .add_scheduled_job(job.clone())
@@ -199,6 +216,17 @@ async fn run_now_handler(
        eprintln!("Error running schedule '{}' now: {:?}", id, e);
        match e {
            goose::scheduler::SchedulerError::JobNotFound(_) => Err(StatusCode::NOT_FOUND),
            goose::scheduler::SchedulerError::AnyhowError(ref err) => {
                // Check if this is a cancellation error
                if err.to_string().contains("was successfully cancelled") {
                    // Return a special session_id to indicate cancellation
                    Ok(Json(RunNowResponse {
                        session_id: "CANCELLED".to_string(),
                    }))
                } else {
                    Err(StatusCode::INTERNAL_SERVER_ERROR)
                }
            }
            _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
@@ -389,6 +417,92 @@ async fn update_schedule(
    Ok(Json(updated_job))
}

#[utoipa::path(
    post,
    path = "/schedule/{id}/kill",
    responses(
        (status = 200, description = "Running job killed successfully"),
    ),
    tag = "schedule"
)]
#[axum::debug_handler]
pub async fn kill_running_job(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Path(id): Path<String>,
) -> Result<Json<KillJobResponse>, StatusCode> {
    verify_secret_key(&headers, &state)?;
    let scheduler = state
        .scheduler()
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    scheduler.kill_running_job(&id).await.map_err(|e| {
        eprintln!("Error killing running job '{}': {:?}", id, e);
        match e {
            goose::scheduler::SchedulerError::JobNotFound(_) => StatusCode::NOT_FOUND,
            goose::scheduler::SchedulerError::AnyhowError(_) => StatusCode::BAD_REQUEST,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        }
    })?;

    Ok(Json(KillJobResponse {
        message: format!("Successfully killed running job '{}'", id),
    }))
}

#[utoipa::path(
    get,
    path = "/schedule/{id}/inspect",
    params(
        ("id" = String, Path, description = "ID of the schedule to inspect")
    ),
    responses(
        (status = 200, description = "Running job information", body = InspectJobResponse),
        (status = 404, description = "Scheduled job not found"),
        (status = 500, description = "Internal server error")
    ),
    tag = "schedule"
)]
#[axum::debug_handler]
pub async fn inspect_running_job(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Path(id): Path<String>,
) -> Result<Json<InspectJobResponse>, StatusCode> {
    verify_secret_key(&headers, &state)?;
    let scheduler = state
        .scheduler()
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    match scheduler.get_running_job_info(&id).await {
        Ok(info) => {
            if let Some((session_id, start_time)) = info {
                let duration = chrono::Utc::now().signed_duration_since(start_time);
                Ok(Json(InspectJobResponse {
                    session_id: Some(session_id),
                    process_start_time: Some(start_time.to_rfc3339()),
                    running_duration_seconds: Some(duration.num_seconds()),
                }))
            } else {
                Ok(Json(InspectJobResponse {
                    session_id: None,
                    process_start_time: None,
                    running_duration_seconds: None,
                }))
            }
        }
        Err(e) => {
            eprintln!("Error inspecting running job '{}': {:?}", id, e);
            match e {
                goose::scheduler::SchedulerError::JobNotFound(_) => Err(StatusCode::NOT_FOUND),
                _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
            }
        }
    }
}

pub fn routes(state: Arc<AppState>) -> Router {
    Router::new()
        .route("/schedule/create", post(create_schedule))
@@ -398,6 +512,8 @@ pub fn routes(state: Arc<AppState>) -> Router {
        .route("/schedule/{id}/run_now", post(run_now_handler)) // Corrected
        .route("/schedule/{id}/pause", post(pause_schedule))
        .route("/schedule/{id}/unpause", post(unpause_schedule))
        .route("/schedule/{id}/kill", post(kill_running_job))
        .route("/schedule/{id}/inspect", get(inspect_running_job))
        .route("/schedule/{id}/sessions", get(sessions_handler)) // Corrected
        .with_state(state)
}

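The two new endpoints can be exercised like any other authenticated route on this server. A hypothetical client sketch, assuming reqwest; the base URL and the `X-Secret-Key` header name are illustrative stand-ins for wherever the server runs and whatever header `verify_secret_key` actually checks:

// Hypothetical sketch -- base URL and header name are assumptions, not this API's contract.
async fn kill_running_job_example(id: &str, secret: &str) -> Result<String, reqwest::Error> {
    reqwest::Client::new()
        .post(format!("http://localhost:3000/schedule/{id}/kill"))
        .header("X-Secret-Key", secret)
        .send()
        .await?
        .text()
        .await
}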
File diff suppressed because it is too large
@@ -71,10 +71,10 @@ aws-sdk-bedrockruntime = "1.74.0"
# For GCP Vertex AI provider auth
jsonwebtoken = "9.3.1"

# Added blake3 hashing library as a dependency
blake3 = "1.5"
fs2 = "0.4.3"
futures-util = "0.3.31"
tokio-stream = "0.1.17"

# Vector database for tool selection
lancedb = "0.13"

@@ -2,7 +2,7 @@ use std::sync::Arc;

use dotenv::dotenv;
use futures::StreamExt;
use goose::agents::{Agent, ExtensionConfig};
use goose::agents::{Agent, AgentEvent, ExtensionConfig};
use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT};
use goose::message::Message;
use goose::providers::databricks::DatabricksProvider;
@@ -20,10 +20,11 @@ async fn main() {

    let config = ExtensionConfig::stdio(
        "developer",
        "./target/debug/developer",
        "./target/debug/goose",
        DEFAULT_EXTENSION_DESCRIPTION,
        DEFAULT_EXTENSION_TIMEOUT,
    );
    )
    .with_args(vec!["mcp", "developer"]);
    agent.add_extension(config).await.unwrap();

    println!("Extensions:");
@@ -35,11 +36,8 @@ async fn main() {
        .with_text("can you summarize the readme.md in this dir using just a haiku?")];

    let mut stream = agent.reply(&messages, None).await.unwrap();
    while let Some(message) = stream.next().await {
        println!(
            "{}",
            serde_json::to_string_pretty(&message.unwrap()).unwrap()
        );
    while let Some(Ok(AgentEvent::Message(message))) = stream.next().await {
        println!("{}", serde_json::to_string_pretty(&message).unwrap());
        println!("\n");
    }
}

@@ -1,9 +1,14 @@
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use anyhow::{anyhow, Result};
use futures::stream::BoxStream;
use futures::TryStreamExt;
use futures::{FutureExt, Stream, TryStreamExt};
use futures_util::stream;
use futures_util::stream::StreamExt;
use mcp_core::protocol::JsonRpcMessage;

use crate::config::{Config, ExtensionConfigManager, PermissionManager};
use crate::message::Message;
@@ -39,7 +44,7 @@ use mcp_core::{

use super::platform_tools;
use super::router_tools;
use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};

/// The main goose Agent
pub struct Agent {
@@ -56,6 +61,12 @@ pub struct Agent {
    pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>,
}

#[derive(Clone, Debug)]
pub enum AgentEvent {
    Message(Message),
    McpNotification((String, JsonRpcMessage)),
}

impl Agent {
    pub fn new() -> Self {
        // Create channels with buffer size 32 (adjust if needed)
@@ -100,6 +111,40 @@ impl Default for Agent {
    }
}

pub enum ToolStreamItem<T> {
    Message(JsonRpcMessage),
    Result(T),
}

pub type ToolStream = Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<Vec<Content>>>> + Send>>;

// tool_stream combines a stream of JsonRpcMessages with a future representing the
// final result of the tool call. MCP notifications are not request-scoped, but
// this lets us capture all notifications emitted during the tool call for
// simpler consumption
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
where
    S: Stream<Item = JsonRpcMessage> + Send + Unpin + 'static,
    F: Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
{
    Box::pin(async_stream::stream! {
        tokio::pin!(done);
        let mut rx = rx;

        loop {
            tokio::select! {
                Some(msg) = rx.next() => {
                    yield ToolStreamItem::Message(msg);
                }
                r = &mut done => {
                    yield ToolStreamItem::Result(r);
                    break;
                }
            }
        }
    })
}

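On the consumer side, the shape is: feed tool_stream any notification stream plus the tool's completion future, then drain one merged stream. A minimal sketch, assuming the items above are in scope and tokio-stream's ReceiverStream (added to the dependencies in this same change) wraps the notifier channel:

// Illustrative consumption sketch, not part of this diff.
use futures_util::stream::StreamExt;

async fn drain_tool_stream(
    rx: tokio::sync::mpsc::Receiver<JsonRpcMessage>,
    done: impl std::future::Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
) {
    // Wrap the receiver so it implements Stream, then merge it with the
    // completion future into one ordered stream of ToolStreamItem values.
    let mut merged = tool_stream(tokio_stream::wrappers::ReceiverStream::new(rx), done);
    while let Some(item) = merged.next().await {
        match item {
            // Surface each notification as it arrives (e.g. as an
            // AgentEvent::McpNotification further up the stack).
            ToolStreamItem::Message(notification) => drop(notification),
            // The final result always comes last; the stream ends after it.
            ToolStreamItem::Result(result) => drop(result),
        }
    }
}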
impl Agent {
|
||||
/// Get a reference count clone to the provider
|
||||
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
|
||||
@@ -143,7 +188,7 @@ impl Agent {
|
||||
&self,
|
||||
tool_call: mcp_core::tool::ToolCall,
|
||||
request_id: String,
|
||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
||||
) -> (String, Result<ToolCallResult, ToolError>) {
|
||||
// Check if this tool call should be allowed based on repetition monitoring
|
||||
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
|
||||
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
|
||||
@@ -171,52 +216,65 @@ impl Agent {
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
return self
|
||||
let (request_id, result) = self
|
||||
.manage_extensions(action, extension_name, request_id)
|
||||
.await;
|
||||
|
||||
return (request_id, Ok(ToolCallResult::from(result)));
|
||||
}
|
||||
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
let result = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
||||
let result: ToolCallResult = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
||||
// Check if the tool is read_resource and handle it separately
|
||||
extension_manager
|
||||
.read_resource(tool_call.arguments.clone())
|
||||
.await
|
||||
ToolCallResult::from(
|
||||
extension_manager
|
||||
.read_resource(tool_call.arguments.clone())
|
||||
.await,
|
||||
)
|
||||
} else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME {
|
||||
extension_manager
|
||||
.list_resources(tool_call.arguments.clone())
|
||||
.await
|
||||
ToolCallResult::from(
|
||||
extension_manager
|
||||
.list_resources(tool_call.arguments.clone())
|
||||
.await,
|
||||
)
|
||||
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
|
||||
extension_manager.search_available_extensions().await
|
||||
ToolCallResult::from(extension_manager.search_available_extensions().await)
|
||||
} else if self.is_frontend_tool(&tool_call.name).await {
|
||||
// For frontend tools, return an error indicating we need frontend execution
|
||||
Err(ToolError::ExecutionError(
|
||||
ToolCallResult::from(Err(ToolError::ExecutionError(
|
||||
"Frontend tool execution required".to_string(),
|
||||
))
|
||||
)))
|
||||
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if let Some(selector) = selector {
|
||||
ToolCallResult::from(if let Some(selector) = selector {
|
||||
selector.select_tools(tool_call.arguments.clone()).await
|
||||
} else {
|
||||
Err(ToolError::ExecutionError(
|
||||
"Encountered vector search error.".to_string(),
|
||||
))
|
||||
}
|
||||
})
|
||||
} else {
|
||||
extension_manager
|
||||
// Clone the result to ensure no references to extension_manager are returned
|
||||
let result = extension_manager
|
||||
.dispatch_tool_call(tool_call.clone())
|
||||
.await
|
||||
.await;
|
||||
match result {
|
||||
Ok(call_result) => call_result,
|
||||
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"input" = serde_json::to_string(&tool_call).unwrap(),
|
||||
"output" = serde_json::to_string(&result).unwrap(),
|
||||
);
|
||||
|
||||
// Process the response to handle large text content
|
||||
let processed_result = super::large_response_handler::process_tool_response(result);
|
||||
|
||||
(request_id, processed_result)
|
||||
(
|
||||
request_id,
|
||||
Ok(ToolCallResult {
|
||||
notification_stream: result.notification_stream,
|
||||
result: Box::new(
|
||||
result
|
||||
.result
|
||||
.map(super::large_response_handler::process_tool_response),
|
||||
),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) async fn manage_extensions(
|
||||
@@ -466,7 +524,7 @@ impl Agent {
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session: Option<SessionConfig>,
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<AgentEvent>>> {
|
||||
let mut messages = messages.to_vec();
|
||||
let reply_span = tracing::Span::current();
|
||||
|
||||
@@ -532,9 +590,8 @@ impl Agent {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Yield the assistant's response with frontend tool requests filtered out
|
||||
yield filtered_response.clone();
|
||||
yield AgentEvent::Message(filtered_response.clone());
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
@@ -556,7 +613,7 @@ impl Agent {
|
||||
// execution is yeield back to this reply loop, and is of the same Message
|
||||
// type, so we can yield that back up to be handled
|
||||
while let Some(msg) = frontend_tool_stream.try_next().await? {
|
||||
yield msg;
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
@@ -584,13 +641,23 @@ impl Agent {
self.provider().await?).await;

// Handle pre-approved and read-only tools in parallel
let mut tool_futures: Vec<ToolFuture> = Vec::new();
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();

// Skip the confirmation for approved tools
for request in &permission_check_result.approved {
if let Ok(tool_call) = request.tool_call.clone() {
let tool_future = self.dispatch_tool_call(tool_call, request.id.clone());
tool_futures.push(Box::pin(tool_future));
let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;

tool_futures.push((req_id, match tool_result {
Ok(result) => tool_stream(
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
result.result,
),
Err(e) => tool_stream(
Box::new(stream::empty()),
futures::future::ready(Err(e)),
),
}));
}
}

@@ -618,7 +685,7 @@ impl Agent {
// type, so we can yield the Message back up to be handled and grab any
// confirmations or denials
while let Some(msg) = tool_approval_stream.try_next().await? {
yield msg;
yield AgentEvent::Message(msg);
}

tool_futures = {
@@ -628,16 +695,30 @@ impl Agent {
futures_lock.drain(..).collect::<Vec<_>>()
};

// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
let with_id = tool_futures
.into_iter()
.map(|(request_id, stream)| {
stream.map(move |item| (request_id.clone(), item))
})
.collect::<Vec<_>>();

let mut combined = stream::select_all(with_id);

let mut all_install_successful = true;

for (request_id, output) in results.into_iter() {
if enable_extension_request_ids.contains(&request_id) && output.is_err() {
all_install_successful = false;
while let Some((request_id, item)) = combined.next().await {
match item {
ToolStreamItem::Result(output) => {
if enable_extension_request_ids.contains(&request_id) && output.is_err() {
all_install_successful = false;
}
let mut response = message_tool_response.lock().await;
*response = response.clone().with_tool_response(request_id, output);
},
ToolStreamItem::Message(msg) => {
yield AgentEvent::McpNotification((request_id, msg))
}
}
let mut response = message_tool_response.lock().await;
*response = response.clone().with_tool_response(request_id, output);
}

// Update system prompt and tools if installations were successful
@@ -647,7 +728,7 @@ impl Agent {
}

let final_message_tool_resp = message_tool_response.lock().await.clone();
yield final_message_tool_resp.clone();
yield AgentEvent::Message(final_message_tool_resp.clone());

messages.push(response);
messages.push(final_message_tool_resp);
@@ -656,15 +737,15 @@ impl Agent {
// At this point, the last message should be a user message
// because the call to the provider led to a context length exceeded error
// Immediately yield a special message and break
yield Message::assistant().with_context_length_exceeded(
yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
"The context length of the model has been exceeded. Please start a new session and try again.",
);
));
break;
},
Err(e) => {
// Create an error message & terminate the stream
error!("Error: {}", e);
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
break;
}
}

@@ -1,8 +1,7 @@
use anyhow::Result;
use chrono::{DateTime, TimeZone, Utc};
use futures::future;
use futures::stream::{FuturesUnordered, StreamExt};
use mcp_client::McpService;
use futures::{future, FutureExt};
use mcp_core::protocol::GetPromptResult;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
@@ -10,15 +9,22 @@ use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task;
use tracing::{debug, error, warn};
use tokio_stream::wrappers::ReceiverStream;
use tracing::{error, warn};

use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
use super::tool_execution::ToolCallResult;
use crate::agents::extension::Envs;
use crate::config::{Config, ExtensionConfigManager};
use crate::prompt_template;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
<<<<<<< HEAD
use mcp_client::transport::{PendingRequests, SseTransport, StdioTransport, Transport};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
=======
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError};
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
use serde_json::Value;

// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
@@ -115,7 +121,8 @@ impl ExtensionManager {
/// Add a new MCP extension based on the provided client type
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
let sanitized_name = normalize(config.key().to_string());
let config_name = config.key().to_string();
let sanitized_name = normalize(config_name.clone());

/// Helper function to merge environment variables from direct envs and keychain-stored env_keys
async fn merge_environments(
@@ -185,6 +192,7 @@ impl ExtensionManager {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = SseTransport::new(uri, all_envs);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
@@ -194,6 +202,17 @@ impl ExtensionManager {
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
ExtensionConfig::Stdio {
cmd,
@@ -206,6 +225,7 @@ impl ExtensionManager {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
@@ -215,6 +235,17 @@ impl ExtensionManager {
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
ExtensionConfig::Builtin {
name,
@@ -233,6 +264,7 @@ impl ExtensionManager {
HashMap::new(),
);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
@@ -242,6 +274,17 @@ impl ExtensionManager {
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
_ => unreachable!(),
};
@@ -627,7 +670,7 @@ impl ExtensionManager {
}
}

pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> ToolResult<Vec<Content>> {
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> Result<ToolCallResult> {
// Dispatch tool call based on the prefix naming convention
let (client_name, client) = self
.get_client_for_tool(&tool_call.name)
@@ -638,22 +681,26 @@ impl ExtensionManager {
.name
.strip_prefix(client_name)
.and_then(|s| s.strip_prefix("__"))
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?
.to_string();

let client_guard = client.lock().await;
let arguments = tool_call.arguments.clone();
let client = client.clone();
let notifications_receiver = client.lock().await.subscribe().await;

let result = client_guard
.call_tool(tool_name, tool_call.clone().arguments)
.await
.map(|result| result.content)
.map_err(|e| ToolError::ExecutionError(e.to_string()));
let fut = async move {
let client_guard = client.lock().await;
client_guard
.call_tool(&tool_name, arguments)
.await
.map(|call| call.content)
.map_err(|e| ToolError::ExecutionError(e.to_string()))
};

debug!(
"input" = serde_json::to_string(&tool_call).unwrap(),
"output" = serde_json::to_string(&result).unwrap(),
);

result
Ok(ToolCallResult {
result: Box::new(fut.boxed()),
notification_stream: Some(Box::new(ReceiverStream::new(notifications_receiver))),
})
}

pub async fn list_prompts_from_extension(
@@ -811,10 +858,11 @@ mod tests {
use mcp_client::client::Error;
use mcp_client::client::McpClientTrait;
use mcp_core::protocol::{
CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult,
ListToolsResult, ReadResourceResult,
CallToolResult, GetPromptResult, InitializeResult, JsonRpcMessage, ListPromptsResult,
ListResourcesResult, ListToolsResult, ReadResourceResult,
};
use serde_json::json;
use tokio::sync::mpsc;

struct MockClient {}

@@ -867,6 +915,10 @@ mod tests {
) -> Result<GetPromptResult, Error> {
Err(Error::NotInitialized)
}

async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
mpsc::channel(1).1
}
}

#[test]
@@ -988,6 +1040,9 @@ mod tests {

let result = extension_manager
.dispatch_tool_call(invalid_tool_call)
.await
.unwrap()
.result
.await;
assert!(matches!(
result.err().unwrap(),
@@ -1004,6 +1059,11 @@ mod tests {
let result = extension_manager
.dispatch_tool_call(invalid_tool_call)
.await;
assert!(matches!(result.err().unwrap(), ToolError::NotFound(_)));
if let Err(err) = result {
let tool_err = err.downcast_ref::<ToolError>().expect("Expected ToolError");
assert!(matches!(tool_err, ToolError::NotFound(_)));
} else {
panic!("Expected ToolError::NotFound");
}
}
}

@@ -3,8 +3,7 @@ use mcp_core::{Content, ToolError};
use std::fs::File;
use std::io::Write;

// Constant for the size threshold (20K characters)
const LARGE_TEXT_THRESHOLD: usize = 20_000;
const LARGE_TEXT_THRESHOLD: usize = 200_000;

/// Process tool response and handle large text content
pub fn process_tool_response(

@@ -13,7 +13,7 @@ mod tool_router_index_manager;
pub(crate) mod tool_vectordb;
mod types;

pub use agent::Agent;
pub use agent::{Agent, AgentEvent};
pub use extension::ExtensionConfig;
pub use extension_manager::ExtensionManager;
pub use prompt_manager::PromptManager;

@@ -8,7 +8,8 @@ use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::{Provider, ProviderUsage};
use crate::providers::errors::ProviderError;
use crate::providers::toolshim::{
augment_message_with_tool_calls, modify_system_prompt_for_tool_json, OllamaInterpreter,
augment_message_with_tool_calls, convert_tool_messages_to_text,
modify_system_prompt_for_tool_json, OllamaInterpreter,
};
use crate::session;
use mcp_core::tool::Tool;
@@ -110,8 +111,17 @@ impl Agent {
) -> Result<(Message, ProviderUsage), ProviderError> {
let config = provider.get_model_config();

// Convert tool messages to text if toolshim is enabled
let messages_for_provider = if config.toolshim {
convert_tool_messages_to_text(messages)
} else {
messages.to_vec()
};

// Call the provider to get a response
let (mut response, usage) = provider.complete(system_prompt, messages, tools).await?;
let (mut response, usage) = provider
.complete(system_prompt, &messages_for_provider, tools)
.await?;

// Store the model information in the global store
crate::providers::base::set_current_model(&usage.model);

@@ -39,13 +39,13 @@ impl VectorToolSelector {
pub async fn new(provider: Arc<dyn Provider>, table_name: String) -> Result<Self> {
let vector_db = ToolVectorDB::new(Some(table_name)).await?;

let embedding_provider = if env::var("EMBEDDING_MODEL_PROVIDER").is_ok() {
let embedding_provider = if env::var("GOOSE_EMBEDDING_MODEL_PROVIDER").is_ok() {
// If env var is set, create a new provider for embeddings
// Get embedding model and provider from environment variables
let embedding_model = env::var("EMBEDDING_MODEL")
let embedding_model = env::var("GOOSE_EMBEDDING_MODEL")
.unwrap_or_else(|_| "text-embedding-3-small".to_string());
let embedding_provider_name =
env::var("EMBEDDING_MODEL_PROVIDER").unwrap_or_else(|_| "openai".to_string());
env::var("GOOSE_EMBEDDING_MODEL_PROVIDER").unwrap_or_else(|_| "openai".to_string());

// Create the provider using the factory
let model_config = ModelConfig::new(embedding_model);

@@ -1,23 +1,35 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use async_stream::try_stream;
use futures::stream::BoxStream;
use futures::StreamExt;
use futures::stream::{self, BoxStream};
use futures::{Stream, StreamExt};
use mcp_core::protocol::JsonRpcMessage;
use tokio::sync::Mutex;

use crate::config::permission::PermissionLevel;
use crate::config::PermissionManager;
use crate::message::{Message, ToolRequest};
use crate::permission::Permission;
use mcp_core::{Content, ToolError};
use mcp_core::{Content, ToolResult};

// Type alias for ToolFutures - used in the agent loop to join all futures together
pub(crate) type ToolFuture<'a> =
Pin<Box<dyn Future<Output = (String, Result<Vec<Content>, ToolError>)> + Send + 'a>>;
pub(crate) type ToolFuturesVec<'a> = Arc<Mutex<Vec<ToolFuture<'a>>>>;
// ToolCallResult combines the result of a tool call with an optional notification stream that
// can be used to receive notifications from the tool.
pub struct ToolCallResult {
pub result: Box<dyn Future<Output = ToolResult<Vec<Content>>> + Send + Unpin>,
pub notification_stream: Option<Box<dyn Stream<Item = JsonRpcMessage> + Send + Unpin>>,
}

impl From<ToolResult<Vec<Content>>> for ToolCallResult {
fn from(result: ToolResult<Vec<Content>>) -> Self {
Self {
result: Box::new(futures::future::ready(result)),
notification_stream: None,
}
}
}
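
// Illustrative sketch (not part of this commit): wrapping an immediate result
// in ToolCallResult and awaiting the boxed future. Assumes the crate-internal
// types above and mcp_core's Content::text constructor.
async fn demo_tool_call_result() {
    let immediate = ToolCallResult::from(Ok(vec![Content::text("done")]));
    // A pre-computed result carries no notification stream.
    assert!(immediate.notification_stream.is_none());
    // `result` is a boxed, Unpin future; awaiting it yields the ToolResult.
    let contents = immediate.result.await.expect("ok result");
    assert_eq!(contents.len(), 1);
}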

use super::agent::{tool_stream, ToolStream};
use crate::agents::Agent;

pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \
@@ -37,7 +49,7 @@ impl Agent {
pub(crate) fn handle_approval_tool_requests<'a>(
&'a self,
tool_requests: &'a [ToolRequest],
tool_futures: ToolFuturesVec<'a>,
tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
permission_manager: &'a mut PermissionManager,
message_tool_response: Arc<Mutex<Message>>,
) -> BoxStream<'a, anyhow::Result<Message>> {
@@ -56,9 +68,19 @@ impl Agent {
while let Some((req_id, confirmation)) = rx.recv().await {
if req_id == request.id {
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone()).await;
let mut futures = tool_futures.lock().await;
futures.push(Box::pin(tool_future));

futures.push((req_id, match tool_result {
Ok(result) => tool_stream(
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
result.result,
),
Err(e) => tool_stream(
Box::new(stream::empty()),
futures::future::ready(Err(e)),
),
}));

if confirmation.permission == Permission::AlwaysAllow {
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);

@@ -148,6 +148,12 @@ impl Usage {

use async_trait::async_trait;

/// Trait for LeadWorkerProvider-specific functionality
pub trait LeadWorkerProviderTrait {
/// Get information about the lead and worker models for logging
fn get_model_info(&self) -> (String, String);
}

/// Base trait for AI providers (OpenAI, Anthropic, etc)
#[async_trait]
pub trait Provider: Send + Sync {
@@ -195,6 +201,12 @@ pub trait Provider: Send + Sync {
"This provider does not support embeddings".to_string(),
))
}

/// Check if this provider is a LeadWorkerProvider
/// This is used for logging model information at startup
fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
None
}
}

#[cfg(test)]

@@ -17,6 +17,7 @@ use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
use tokio::time::sleep;

const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
@@ -24,6 +25,17 @@ const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020";
// https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];

/// Default timeout for API requests in seconds
const DEFAULT_TIMEOUT_SECS: u64 = 600;
/// Default initial interval for retry (in milliseconds)
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000;
/// Default maximum number of retries
const DEFAULT_MAX_RETRIES: usize = 6;
/// Default retry backoff multiplier
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Default maximum interval for retry (in milliseconds)
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;

pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
// Databricks can passthrough to a wide range of models, we only provide the default
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
@@ -36,6 +48,53 @@ pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
pub const DATABRICKS_DOC_URL: &str =
"https://docs.databricks.com/en/generative-ai/external-models/index.html";

/// Retry configuration for handling rate limit errors
#[derive(Debug, Clone)]
struct RetryConfig {
/// Maximum number of retry attempts
max_retries: usize,
/// Initial interval between retries in milliseconds
initial_interval_ms: u64,
/// Multiplier for backoff (exponential)
backoff_multiplier: f64,
/// Maximum interval between retries in milliseconds
max_interval_ms: u64,
}

impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: DEFAULT_MAX_RETRIES,
initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
}
}
}

impl RetryConfig {
/// Calculate the delay for a specific retry attempt (with jitter)
fn delay_for_attempt(&self, attempt: usize) -> Duration {
if attempt == 0 {
return Duration::from_millis(0);
}

// Calculate exponential backoff
let exponent = (attempt - 1) as u32;
let base_delay_ms = (self.initial_interval_ms as f64
* self.backoff_multiplier.powi(exponent as i32)) as u64;

// Apply max limit
let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms);

// Add jitter (+/-20% randomness) to avoid thundering herd problem
let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // Between 0.8 and 1.2
let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64;

Duration::from_millis(jittered_delay_ms)
}
}
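
// Illustrative sketch (not part of this commit) of the schedule produced by
// the defaults above: 5s initial, 2.0x multiplier, 320s cap, +/-20% jitter.
// Same-module context assumed, since RetryConfig is private.
fn demo_backoff_schedule() {
    let cfg = RetryConfig::default();
    // attempt 1 -> ~5s, 2 -> ~10s, 3 -> ~20s, 4 -> ~40s, ... capped at 320s,
    // each scaled by a random factor in [0.8, 1.2].
    for attempt in 1..=6 {
        println!("attempt {} -> {:?}", attempt, cfg.delay_for_attempt(attempt));
    }
}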

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DatabricksAuth {
Token(String),
@@ -70,6 +129,8 @@ pub struct DatabricksProvider {
auth: DatabricksAuth,
model: ModelConfig,
image_format: ImageFormat,
#[serde(skip)]
retry_config: RetryConfig,
}

impl Default for DatabricksProvider {
@@ -100,9 +161,12 @@ impl DatabricksProvider {
let host = host?;

let client = Client::builder()
.timeout(Duration::from_secs(600))
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()?;

// Load optional retry configuration from environment
let retry_config = Self::load_retry_config(config);

// If we find a databricks token we prefer that
if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") {
return Ok(Self {
@@ -111,6 +175,7 @@ impl DatabricksProvider {
auth: DatabricksAuth::token(api_key),
model,
image_format: ImageFormat::OpenAi,
retry_config,
});
}

@@ -121,9 +186,44 @@ impl DatabricksProvider {
host,
model,
image_format: ImageFormat::OpenAi,
retry_config,
})
}

/// Loads retry configuration from environment variables or uses defaults.
fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
let max_retries = config
.get_param("DATABRICKS_MAX_RETRIES")
.ok()
.and_then(|v: String| v.parse::<usize>().ok())
.unwrap_or(DEFAULT_MAX_RETRIES);

let initial_interval_ms = config
.get_param("DATABRICKS_INITIAL_RETRY_INTERVAL_MS")
.ok()
.and_then(|v: String| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);

let backoff_multiplier = config
.get_param("DATABRICKS_BACKOFF_MULTIPLIER")
.ok()
.and_then(|v: String| v.parse::<f64>().ok())
.unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);

let max_interval_ms = config
.get_param("DATABRICKS_MAX_RETRY_INTERVAL_MS")
.ok()
.and_then(|v: String| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);

RetryConfig {
max_retries,
initial_interval_ms,
backoff_multiplier,
max_interval_ms,
}
}
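
// Illustrative sketch (not part of this commit): the keys read by
// load_retry_config above, with the defaults they fall back to when unset or
// unparsable. Whether these resolve from the process environment depends on
// how crate::config::Config is backed; treat the env-var route as an assumption.
fn demo_retry_overrides() {
    std::env::set_var("DATABRICKS_MAX_RETRIES", "3"); // default 6
    std::env::set_var("DATABRICKS_INITIAL_RETRY_INTERVAL_MS", "1000"); // default 5000
    std::env::set_var("DATABRICKS_BACKOFF_MULTIPLIER", "1.5"); // default 2.0
    std::env::set_var("DATABRICKS_MAX_RETRY_INTERVAL_MS", "30000"); // default 320_000
}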

/// Create a new DatabricksProvider with the specified host and token
///
/// # Arguments
@@ -145,6 +245,7 @@ impl DatabricksProvider {
auth: DatabricksAuth::token(api_key),
model,
image_format: ImageFormat::OpenAi,
retry_config: RetryConfig::default(),
})
}

@@ -182,70 +283,148 @@ impl DatabricksProvider {
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
})?;

let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(url)
.header("Authorization", auth_header)
.json(&payload)
.send()
.await?;
// Initialize retry counter
let mut attempts = 0;
let mut last_error = None;

let status = response.status();
let payload: Option<Value> = response.json().await.ok();

match status {
StatusCode::OK => payload.ok_or_else(|| ProviderError::RequestFailed("Response body is not valid JSON".to_string())),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", status, payload)))
}
StatusCode::BAD_REQUEST => {
// Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
let payload_str = serde_json::to_string(&payload).unwrap_or_default().to_lowercase();
let check_phrases = [
"too long",
"context length",
"context_length_exceeded",
"reduce the length",
"token count",
"exceeds",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded(payload_str));
}

let mut error_msg = "Unknown error".to_string();
if let Some(payload) = &payload {
// try to convert message to string, if that fails use external_model_message
error_msg = payload
.get("message")
.and_then(|m| m.as_str())
.or_else(|| {
payload.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())
})
.unwrap_or("Unknown error").to_string();
}

tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
loop {
// Check if we've exceeded max retries
if attempts > 0 && attempts > self.retry_config.max_retries {
let error_msg = format!(
"Exceeded maximum retry attempts ({}) for rate limiting (429)",
self.retry_config.max_retries
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)))
tracing::error!("{}", error_msg);
return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))

let auth_header = self.ensure_auth_header().await?;
let response = self
.client
.post(url.clone())
.header("Authorization", auth_header)
.json(&payload)
.send()
.await?;

let status = response.status();
let payload: Option<Value> = response.json().await.ok();

match status {
StatusCode::OK => {
return payload.ok_or_else(|| {
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
});
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
return Err(ProviderError::Authentication(format!(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}",
status, payload
)));
}
StatusCode::BAD_REQUEST => {
// Databricks provides a generic 'error' but also includes 'external_model_message' which is provider specific
// We try to extract the error message from the payload and check for phrases that indicate context length exceeded
let payload_str = serde_json::to_string(&payload)
.unwrap_or_default()
.to_lowercase();
let check_phrases = [
"too long",
"context length",
"context_length_exceeded",
"reduce the length",
"token count",
"exceeds",
"exceed context limit",
"input length",
"max_tokens",
"decrease input length",
"context limit",
];
if check_phrases.iter().any(|c| payload_str.contains(c)) {
return Err(ProviderError::ContextLengthExceeded(payload_str));
}

let mut error_msg = "Unknown error".to_string();
if let Some(payload) = &payload {
// try to convert message to string, if that fails use external_model_message
error_msg = payload
.get("message")
.and_then(|m| m.as_str())
.or_else(|| {
payload
.get("external_model_message")
.and_then(|ext| ext.get("message"))
.and_then(|m| m.as_str())
})
.unwrap_or("Unknown error")
.to_string();
}

tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
return Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}. Message: {}",
status, error_msg
)));
}
StatusCode::TOO_MANY_REQUESTS => {
attempts += 1;
let error_msg = format!(
"Rate limit exceeded (attempt {}/{}): {:?}",
attempts, self.retry_config.max_retries, payload
);
tracing::warn!("{}. Retrying after backoff...", error_msg);

// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::RateLimitExceeded(error_msg));

// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;

// Continue to the next retry attempt
continue;
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
attempts += 1;
let error_msg = format!(
"Server error (attempt {}/{}): {:?}",
attempts, self.retry_config.max_retries, payload
);
tracing::warn!("{}. Retrying after backoff...", error_msg);

// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::ServerError(error_msg));

// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;

// Continue to the next retry attempt
continue;
}
_ => {
tracing::debug!(
"{}",
format!(
"Provider request failed with status: {}. Payload: {:?}",
status, payload
)
);
return Err(ProviderError::RequestFailed(format!(
"Request failed with status: {}",
status
)));
}
}
}
}

@@ -10,14 +10,31 @@ use super::{
githubcopilot::GithubCopilotProvider,
google::GoogleProvider,
groq::GroqProvider,
lead_worker::LeadWorkerProvider,
ollama::OllamaProvider,
openai::OpenAiProvider,
openrouter::OpenRouterProvider,
snowflake::SnowflakeProvider,
venice::VeniceProvider,
};
use crate::model::ModelConfig;
use anyhow::Result;

#[cfg(test)]
use super::errors::ProviderError;
#[cfg(test)]
use mcp_core::tool::Tool;

fn default_lead_turns() -> usize {
3
}
fn default_failure_threshold() -> usize {
2
}
fn default_fallback_turns() -> usize {
2
}

pub fn providers() -> Vec<ProviderMetadata> {
vec![
AnthropicProvider::metadata(),
@@ -32,10 +49,67 @@ pub fn providers() -> Vec<ProviderMetadata> {
OpenAiProvider::metadata(),
OpenRouterProvider::metadata(),
VeniceProvider::metadata(),
SnowflakeProvider::metadata(),
]
}

pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
let config = crate::config::Config::global();

// Check for lead model environment variables
if let Ok(lead_model_name) = config.get_param::<String>("GOOSE_LEAD_MODEL") {
tracing::info!("Creating lead/worker provider from environment variables");

return create_lead_worker_from_env(name, &model, &lead_model_name);
}

// Default: create regular provider
create_provider(name, model)
}

/// Create a lead/worker provider from environment variables
fn create_lead_worker_from_env(
default_provider_name: &str,
default_model: &ModelConfig,
lead_model_name: &str,
) -> Result<Arc<dyn Provider>> {
let config = crate::config::Config::global();

// Get lead provider (optional, defaults to main provider)
let lead_provider_name = config
.get_param::<String>("GOOSE_LEAD_PROVIDER")
.unwrap_or_else(|_| default_provider_name.to_string());

// Get configuration parameters with defaults
let lead_turns = config
.get_param::<usize>("GOOSE_LEAD_TURNS")
.unwrap_or(default_lead_turns());
let failure_threshold = config
.get_param::<usize>("GOOSE_LEAD_FAILURE_THRESHOLD")
.unwrap_or(default_failure_threshold());
let fallback_turns = config
.get_param::<usize>("GOOSE_LEAD_FALLBACK_TURNS")
.unwrap_or(default_fallback_turns());

// Create model configs
let lead_model_config = ModelConfig::new(lead_model_name.to_string());
let worker_model_config = default_model.clone();

// Create the providers
let lead_provider = create_provider(&lead_provider_name, lead_model_config)?;
let worker_provider = create_provider(default_provider_name, worker_model_config)?;

// Create the lead/worker provider with configured settings
Ok(Arc::new(LeadWorkerProvider::new_with_settings(
lead_provider,
worker_provider,
lead_turns,
failure_threshold,
fallback_turns,
)))
}
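
// Illustrative sketch (not part of this commit) of driving `create` with the
// GOOSE_LEAD_* variables above; provider construction still needs valid
// credentials, so this only shows the configuration path.
fn demo_lead_worker_config() -> Result<()> {
    std::env::set_var("GOOSE_LEAD_MODEL", "gpt-4o"); // required to enable lead/worker
    std::env::set_var("GOOSE_LEAD_TURNS", "5"); // optional, default 3
    // The worker uses the provider/model passed in; the lead model comes from the env.
    let _provider = create("openai", ModelConfig::new("gpt-4o-mini".to_string()))?;
    Ok(())
}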

fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
// We use Arc instead of Box to be able to clone for multiple async tasks
match name {
"openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
@@ -49,7 +123,220 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
"gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
"google" => Ok(Arc::new(GoogleProvider::from_env(model)?)),
"venice" => Ok(Arc::new(VeniceProvider::from_env(model)?)),
"snowflake" => Ok(Arc::new(SnowflakeProvider::from_env(model)?)),
"github_copilot" => Ok(Arc::new(GithubCopilotProvider::from_env(model)?)),
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::message::{Message, MessageContent};
use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
use chrono::Utc;
use mcp_core::{content::TextContent, Role};
use std::env;

#[derive(Clone)]
struct MockTestProvider {
name: String,
model_config: ModelConfig,
}

#[async_trait::async_trait]
impl Provider for MockTestProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"mock_test",
"Mock Test Provider",
"A mock provider for testing",
"mock-model",
vec!["mock-model"],
"",
vec![],
)
}

fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}

async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
text: format!(
"Response from {} with model {}",
self.name, self.model_config.model_name
),
annotations: None,
})],
},
ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()),
))
}
}

#[test]
fn test_create_lead_worker_provider() {
// Save current env vars
let saved_lead = env::var("GOOSE_LEAD_MODEL").ok();
let saved_provider = env::var("GOOSE_LEAD_PROVIDER").ok();
let saved_turns = env::var("GOOSE_LEAD_TURNS").ok();

// Test with basic lead model configuration
env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");

// This will try to create a lead/worker provider
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

// The creation might succeed or fail depending on API keys, but we can verify the logic path
match result {
Ok(_) => {
// If it succeeds, it means we created a lead/worker provider successfully
// This would happen if API keys are available in the test environment
}
Err(error) => {
// If it fails, it should be due to missing API keys, confirming we tried to create providers
let error_msg = error.to_string();
assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
}
}

// Test with different lead provider
env::set_var("GOOSE_LEAD_PROVIDER", "anthropic");
env::set_var("GOOSE_LEAD_TURNS", "5");

let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
// Similar validation as above - will fail due to missing API keys but confirms the logic

// Restore env vars
match saved_lead {
Some(val) => env::set_var("GOOSE_LEAD_MODEL", val),
None => env::remove_var("GOOSE_LEAD_MODEL"),
}
match saved_provider {
Some(val) => env::set_var("GOOSE_LEAD_PROVIDER", val),
None => env::remove_var("GOOSE_LEAD_PROVIDER"),
}
match saved_turns {
Some(val) => env::set_var("GOOSE_LEAD_TURNS", val),
None => env::remove_var("GOOSE_LEAD_TURNS"),
}
}

#[test]
fn test_lead_model_env_vars_with_defaults() {
// Save current env vars
let saved_vars = [
("GOOSE_LEAD_MODEL", env::var("GOOSE_LEAD_MODEL").ok()),
("GOOSE_LEAD_PROVIDER", env::var("GOOSE_LEAD_PROVIDER").ok()),
("GOOSE_LEAD_TURNS", env::var("GOOSE_LEAD_TURNS").ok()),
(
"GOOSE_LEAD_FAILURE_THRESHOLD",
env::var("GOOSE_LEAD_FAILURE_THRESHOLD").ok(),
),
(
"GOOSE_LEAD_FALLBACK_TURNS",
env::var("GOOSE_LEAD_FALLBACK_TURNS").ok(),
),
];

// Clear all lead env vars
for (key, _) in &saved_vars {
env::remove_var(key);
}

// Set only the required lead model
env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");

// This should use defaults for all other values
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

// Should attempt to create lead/worker provider (will fail due to missing API keys but confirms logic)
match result {
Ok(_) => {
// Success means we have API keys and created the provider
}
Err(error) => {
// Should fail due to missing API keys, confirming we tried to create providers
let error_msg = error.to_string();
assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
}
}

// Test with custom values
env::set_var("GOOSE_LEAD_TURNS", "7");
env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", "4");
env::set_var("GOOSE_LEAD_FALLBACK_TURNS", "3");

let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
// Should still attempt to create lead/worker provider with custom settings

// Restore all env vars
for (key, value) in saved_vars {
match value {
Some(val) => env::set_var(key, val),
None => env::remove_var(key),
}
}
}

#[test]
fn test_create_regular_provider_without_lead_config() {
// Save current env vars
let saved_lead = env::var("GOOSE_LEAD_MODEL").ok();
let saved_provider = env::var("GOOSE_LEAD_PROVIDER").ok();
let saved_turns = env::var("GOOSE_LEAD_TURNS").ok();
let saved_threshold = env::var("GOOSE_LEAD_FAILURE_THRESHOLD").ok();
let saved_fallback = env::var("GOOSE_LEAD_FALLBACK_TURNS").ok();

// Ensure all GOOSE_LEAD_* variables are not set
env::remove_var("GOOSE_LEAD_MODEL");
env::remove_var("GOOSE_LEAD_PROVIDER");
env::remove_var("GOOSE_LEAD_TURNS");
env::remove_var("GOOSE_LEAD_FAILURE_THRESHOLD");
env::remove_var("GOOSE_LEAD_FALLBACK_TURNS");

// This should try to create a regular provider
let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

// The creation might succeed or fail depending on API keys
match result {
Ok(_) => {
// If it succeeds, it means we created a regular provider successfully
// This would happen if API keys are available in the test environment
}
Err(error) => {
// If it fails, it should be due to missing API keys
let error_msg = error.to_string();
assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
}
}

// Restore env vars
if let Some(val) = saved_lead {
env::set_var("GOOSE_LEAD_MODEL", val);
}
if let Some(val) = saved_provider {
env::set_var("GOOSE_LEAD_PROVIDER", val);
}
if let Some(val) = saved_turns {
env::set_var("GOOSE_LEAD_TURNS", val);
}
if let Some(val) = saved_threshold {
env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", val);
}
if let Some(val) = saved_fallback {
env::set_var("GOOSE_LEAD_FALLBACK_TURNS", val);
}
}
}

@@ -98,6 +98,10 @@ pub enum GeminiVersion {
Pro20Exp,
/// Gemini 2.5 Pro Experimental version
Pro25Exp,
/// Gemini 2.5 Flash Preview version
Flash25Preview,
/// Gemini 2.5 Pro Preview version
Pro25Preview,
/// Generic Gemini model for custom or new versions
Generic(String),
}
@@ -118,6 +122,8 @@ impl fmt::Display for GcpVertexAIModel {
GeminiVersion::Flash20 => "gemini-2.0-flash-001",
GeminiVersion::Pro20Exp => "gemini-2.0-pro-exp-02-05",
GeminiVersion::Pro25Exp => "gemini-2.5-pro-exp-03-25",
GeminiVersion::Flash25Preview => "gemini-2.5-flash-preview-05-20",
GeminiVersion::Pro25Preview => "gemini-2.5-pro-preview-05-06",
GeminiVersion::Generic(name) => name,
},
};
@@ -154,6 +160,8 @@ impl TryFrom<&str> for GcpVertexAIModel {
"gemini-2.0-flash-001" => Ok(Self::Gemini(GeminiVersion::Flash20)),
"gemini-2.0-pro-exp-02-05" => Ok(Self::Gemini(GeminiVersion::Pro20Exp)),
"gemini-2.5-pro-exp-03-25" => Ok(Self::Gemini(GeminiVersion::Pro25Exp)),
"gemini-2.5-flash-preview-05-20" => Ok(Self::Gemini(GeminiVersion::Flash25Preview)),
"gemini-2.5-pro-preview-05-06" => Ok(Self::Gemini(GeminiVersion::Pro25Preview)),
// Generic models based on prefix matching
_ if s.starts_with("claude-") => {
Ok(Self::Claude(ClaudeVersion::Generic(s.to_string())))
@@ -349,6 +357,8 @@ mod tests {
"gemini-2.0-flash-001",
"gemini-2.0-pro-exp-02-05",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-flash-preview-05-20",
"gemini-2.5-pro-preview-05-06",
];

for model_id in valid_models {
@@ -372,6 +382,8 @@ mod tests {
("gemini-2.0-flash-001", GcpLocation::Iowa),
("gemini-2.0-pro-exp-02-05", GcpLocation::Iowa),
("gemini-2.5-pro-exp-03-25", GcpLocation::Iowa),
("gemini-2.5-flash-preview-05-20", GcpLocation::Iowa),
("gemini-2.5-pro-preview-05-06", GcpLocation::Iowa),
];

for (model_id, expected_location) in test_cases {

@@ -4,3 +4,4 @@ pub mod databricks;
pub mod gcpvertexai;
pub mod google;
pub mod openai;
pub mod snowflake;

crates/goose/src/providers/formats/snowflake.rs (new file, 716 lines)
@@ -0,0 +1,716 @@
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::base::Usage;
use crate::providers::errors::ProviderError;
use anyhow::{anyhow, Result};
use mcp_core::content::Content;
use mcp_core::role::Role;
use mcp_core::tool::{Tool, ToolCall};
use serde_json::{json, Value};
use std::collections::HashSet;

/// Convert internal Message format to Snowflake's API message specification
pub fn format_messages(messages: &[Message]) -> Vec<Value> {
let mut snowflake_messages = Vec::new();

// Convert messages to Snowflake format
for message in messages {
let role = match message.role {
Role::User => "user",
Role::Assistant => "assistant",
};

let mut text_content = String::new();

for msg_content in &message.content {
match msg_content {
MessageContent::Text(text) => {
if !text_content.is_empty() {
text_content.push('\n');
}
text_content.push_str(&text.text);
}
MessageContent::ToolRequest(_tool_request) => {
// Skip tool requests in message formatting - tools are handled separately
// through the tools parameter in the API request
continue;
}
MessageContent::ToolResponse(tool_response) => {
if let Ok(result) = &tool_response.tool_result {
let text = result
.iter()
.filter_map(|c| match c {
Content::Text(t) => Some(t.text.clone()),
_ => None,
})
.collect::<Vec<_>>()
.join("\n");

if !text_content.is_empty() {
text_content.push('\n');
}
if !text.is_empty() {
text_content.push_str(&format!("Tool result: {}", text));
}
}
}
MessageContent::ToolConfirmationRequest(_) => {
// Skip tool confirmation requests
}
MessageContent::ContextLengthExceeded(_) => {
// Skip
}
MessageContent::SummarizationRequested(_) => {
// Skip
}
MessageContent::Thinking(_thinking) => {
// Skip thinking for now
}
MessageContent::RedactedThinking(_redacted) => {
// Skip redacted thinking for now
}
MessageContent::Image(_) => continue, // Snowflake doesn't support image content yet
MessageContent::FrontendToolRequest(_tool_request) => {
// Skip frontend tool requests
}
}
}

// Add message if it has text content
if !text_content.is_empty() {
snowflake_messages.push(json!({
"role": role,
"content": text_content
}));
}
}

// Only add default message if we truly have no messages at all
// This should be rare and only for edge cases
if snowflake_messages.is_empty() {
snowflake_messages.push(json!({
"role": "user",
"content": "Continue the conversation"
}));
}

snowflake_messages
}
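
// Illustrative sketch (not part of this commit): two simple messages through
// format_messages. Assumes the crate's Message builder API — Message::assistant()
// with with_text() appears elsewhere in this diff; Message::user() is assumed
// to exist analogously.
fn demo_format_messages() {
    let messages = vec![
        Message::user().with_text("What is 2 + 2?"),
        Message::assistant().with_text("4"),
    ];
    let formatted = format_messages(&messages);
    // Each message flattens to {"role": ..., "content": ...}.
    assert_eq!(formatted[0]["role"], "user");
    assert_eq!(formatted[1]["content"], "4");
}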

/// Convert internal Tool format to Snowflake's API tool specification
pub fn format_tools(tools: &[Tool]) -> Vec<Value> {
let mut unique_tools = HashSet::new();
let mut tool_specs = Vec::new();

for tool in tools.iter() {
if unique_tools.insert(tool.name.clone()) {
let tool_spec = json!({
"type": "generic",
"name": tool.name,
"description": tool.description,
"input_schema": tool.input_schema
});

tool_specs.push(json!({"tool_spec": tool_spec}));
}
}

tool_specs
}
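
// Illustrative sketch (not part of this commit): duplicate tool names are
// dropped and each survivor is wrapped in a {"tool_spec": ...} envelope.
// Assumes a three-argument Tool::new (name, description, input_schema);
// adjust if the crate's signature differs.
fn demo_format_tools() {
    let calc = Tool::new(
        "calculator".to_string(),
        "Evaluate arithmetic expressions".to_string(),
        json!({"type": "object", "properties": {"expression": {"type": "string"}}}),
    );
    let specs = format_tools(&[calc.clone(), calc]);
    assert_eq!(specs.len(), 1); // second copy deduplicated by name
    assert_eq!(specs[0]["tool_spec"]["type"], "generic");
}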

/// Convert system message to Snowflake's API system specification
pub fn format_system(system: &str) -> Value {
json!({
"role": "system",
"content": system,
})
}

/// Convert Snowflake's streaming API response to internal Message format
pub fn parse_streaming_response(sse_data: &str) -> Result<Message> {
let mut message = Message::assistant();
let mut accumulated_text = String::new();
let mut tool_use_id: Option<String> = None;
let mut tool_name: Option<String> = None;
let mut tool_input = String::new();

// Parse each SSE event
for line in sse_data.lines() {
if !line.starts_with("data: ") {
continue;
}

let json_str = &line[6..]; // Remove "data: " prefix
if json_str.trim().is_empty() || json_str.trim() == "[DONE]" {
continue;
}

let event: Value = match serde_json::from_str(json_str) {
Ok(v) => v,
Err(_) => {
continue;
}
};

if let Some(choices) = event.get("choices").and_then(|c| c.as_array()) {
if let Some(choice) = choices.first() {
if let Some(delta) = choice.get("delta") {
match delta.get("type").and_then(|t| t.as_str()) {
Some("text") => {
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
accumulated_text.push_str(content);
}
}
Some("tool_use") => {
if let Some(id) = delta.get("tool_use_id").and_then(|i| i.as_str()) {
tool_use_id = Some(id.to_string());
}
if let Some(name) = delta.get("name").and_then(|n| n.as_str()) {
tool_name = Some(name.to_string());
}
if let Some(input) = delta.get("input").and_then(|i| i.as_str()) {
tool_input.push_str(input);
}
}
_ => {}
}
}
}
}
}

// Add accumulated text if any
if !accumulated_text.is_empty() {
message = message.with_text(accumulated_text);
}

// Add tool use if complete
if let (Some(id), Some(name)) = (&tool_use_id, &tool_name) {
if !tool_input.is_empty() {
let input_value = serde_json::from_str::<Value>(&tool_input)
.unwrap_or_else(|_| Value::String(tool_input.clone()));
let tool_call = ToolCall::new(name, input_value);
message = message.with_tool_request(id, Ok(tool_call));
} else if tool_name.is_some() {
// Tool with no input - use empty object
let tool_call = ToolCall::new(name, Value::Object(serde_json::Map::new()));
message = message.with_tool_request(id, Ok(tool_call));
}
}

Ok(message)
}
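
// Illustrative sketch (not part of this commit): two text deltas and a [DONE]
// sentinel through the parser; the deltas accumulate into a single text
// content item, in the delta format handled above.
fn demo_parse_streaming() -> Result<()> {
    let sse = "data: {\"choices\":[{\"delta\":{\"type\":\"text\",\"content\":\"Hel\"}}]}\n\
               data: {\"choices\":[{\"delta\":{\"type\":\"text\",\"content\":\"lo\"}}]}\n\
               data: [DONE]\n";
    let message = parse_streaming_response(sse)?;
    if let MessageContent::Text(text) = &message.content[0] {
        assert_eq!(text.text, "Hello");
    }
    Ok(())
}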

/// Convert Snowflake's API response to internal Message format
pub fn response_to_message(response: Value) -> Result<Message> {
let mut message = Message::assistant();

let content_list = response.get("content_list").and_then(|cl| cl.as_array());

// Handle case where content_list is missing or empty
let content_list = match content_list {
Some(list) if !list.is_empty() => list,
_ => {
// If no content_list or empty, check if there's a direct content field
if let Some(direct_content) = response.get("content").and_then(|c| c.as_str()) {
if !direct_content.is_empty() {
message = message.with_text(direct_content.to_string());
}
return Ok(message);
} else {
// Return empty assistant message for empty responses
return Ok(message);
}
}
};

// Process all content items in the list
for content in content_list {
match content.get("type").and_then(|t| t.as_str()) {
Some("text") => {
if let Some(text) = content.get("text").and_then(|t| t.as_str()) {
if !text.is_empty() {
message = message.with_text(text.to_string());
}
}
}
Some("tool_use") => {
let id = content
.get("tool_use_id")
.and_then(|i| i.as_str())
.ok_or_else(|| anyhow!("Missing tool_use id"))?;
let name = content
.get("name")
.and_then(|n| n.as_str())
.ok_or_else(|| anyhow!("Missing tool_use name"))?;

let input = content
.get("input")
.ok_or_else(|| anyhow!("Missing tool input"))?
.clone();

let tool_call = ToolCall::new(name, input);
message = message.with_tool_request(id, Ok(tool_call));
}
Some("thinking") => {
let thinking = content
.get("thinking")
.and_then(|t| t.as_str())
.ok_or_else(|| anyhow!("Missing thinking content"))?;
let signature = content
.get("signature")
.and_then(|s| s.as_str())
.ok_or_else(|| anyhow!("Missing thinking signature"))?;
message = message.with_thinking(thinking, signature);
}
Some("redacted_thinking") => {
let data = content
.get("data")
.and_then(|d| d.as_str())
.ok_or_else(|| anyhow!("Missing redacted_thinking data"))?;
message = message.with_redacted_thinking(data);
}
_ => {
// Ignore unrecognized content types
}
}
}

Ok(message)
}

/// Extract usage information from Snowflake's API response
pub fn get_usage(data: &Value) -> Result<Usage> {
// Extract usage data if available
if let Some(usage) = data.get("usage") {
let input_tokens = usage
.get("input_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as i32);

let output_tokens = usage
.get("output_tokens")
.and_then(|v| v.as_u64())
.map(|v| v as i32);

let total_tokens = match (input_tokens, output_tokens) {
(Some(input), Some(output)) => Some(input + output),
_ => None,
};

Ok(Usage::new(input_tokens, output_tokens, total_tokens))
} else {
tracing::debug!(
"Failed to get usage data: {}",
ProviderError::UsageError("No usage data found in response".to_string())
);
// If no usage data, return None for all values
Ok(Usage::new(None, None, None))
}
}
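
// Illustrative sketch (not part of this commit): usage extraction is tolerant.
// Totals are computed only when both token counts are present, and a missing
// usage object is not an error.
fn demo_get_usage() -> Result<()> {
    let with_usage = json!({"usage": {"input_tokens": 10, "output_tokens": 5}});
    assert_eq!(get_usage(&with_usage)?.total_tokens, Some(15));

    let without_usage = json!({});
    assert_eq!(get_usage(&without_usage)?.input_tokens, None);
    Ok(())
}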

/// Create a complete request payload for Snowflake's API
pub fn create_request(
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<Value> {
let mut snowflake_messages = format_messages(messages);
let system_spec = format_system(system);

// Add system message to the beginning of the messages
snowflake_messages.insert(0, system_spec);

// Check if we have any messages to send
if snowflake_messages.is_empty() {
return Err(anyhow!("No valid messages to send to Snowflake API"));
}

// Detect description generation requests and exclude tools to prevent interference
// with normal tool execution flow
let is_description_request =
system.contains("Reply with only a description in four words or less");

let tool_specs = if is_description_request {
// For description generation, don't include any tools to avoid confusion
format_tools(&[])
} else {
format_tools(tools)
};

let max_tokens = model_config.max_tokens.unwrap_or(4096);
let mut payload = json!({
"model": model_config.model_name,
"messages": snowflake_messages,
"max_tokens": max_tokens,
});

// Add tools if present and not a description request
if !tool_specs.is_empty() {
if let Some(obj) = payload.as_object_mut() {
obj.insert("tools".to_string(), json!(tool_specs));
} else {
return Err(anyhow!(
"Failed to create request payload: payload is not a JSON object"
));
}
}

Ok(payload)
}
|
||||
|
||||
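
// The resulting payload has this shape (illustrative sketch, matching the tests below;
// field values are placeholders):
// {
//   "model": "claude-3-5-sonnet",
//   "messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}],
//   "max_tokens": 4096,
//   "tools": [{"tool_spec": {"name": "...", "description": "...", ...}}]
// }
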
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_parse_text_response() -> Result<()> {
        let response = json!({
            "id": "msg_123",
            "type": "message",
            "role": "assistant",
            "content_list": [{
                "type": "text",
                "text": "Hello! How can I assist you today?"
            }],
            "model": "claude-3-5-sonnet",
            "stop_reason": "end_turn",
            "stop_sequence": null,
            "usage": {
                "input_tokens": 12,
                "output_tokens": 15
            }
        });

        let message = response_to_message(response.clone())?;
        let usage = get_usage(&response)?;

        if let MessageContent::Text(text) = &message.content[0] {
            assert_eq!(text.text, "Hello! How can I assist you today?");
        } else {
            panic!("Expected Text content");
        }

        assert_eq!(usage.input_tokens, Some(12));
        assert_eq!(usage.output_tokens, Some(15));
        assert_eq!(usage.total_tokens, Some(27)); // 12 + 15

        Ok(())
    }

    #[test]
    fn test_parse_tool_response() -> Result<()> {
        let response = json!({
            "id": "msg_123",
            "type": "message",
            "role": "assistant",
            "content_list": [{
                "type": "tool_use",
                "tool_use_id": "tool_1",
                "name": "calculator",
                "input": {"expression": "2 + 2"}
            }],
            "model": "claude-3-5-sonnet",
            "stop_reason": "end_turn",
            "stop_sequence": null,
            "usage": {
                "input_tokens": 15,
                "output_tokens": 20
            }
        });

        let message = response_to_message(response.clone())?;
        let usage = get_usage(&response)?;

        if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
            let tool_call = tool_request.tool_call.as_ref().unwrap();
            assert_eq!(tool_call.name, "calculator");
            assert_eq!(tool_call.arguments, json!({"expression": "2 + 2"}));
        } else {
            panic!("Expected ToolRequest content");
        }

        assert_eq!(usage.input_tokens, Some(15));
        assert_eq!(usage.output_tokens, Some(20));
        assert_eq!(usage.total_tokens, Some(35)); // 15 + 20

        Ok(())
    }

    #[test]
    fn test_message_to_snowflake_spec() {
        let messages = vec![
            Message::user().with_text("Hello"),
            Message::assistant().with_text("Hi there"),
            Message::user().with_text("How are you?"),
        ];

        let spec = format_messages(&messages);

        assert_eq!(spec.len(), 3);
        assert_eq!(spec[0]["role"], "user");
        assert_eq!(spec[0]["content"], "Hello");
        assert_eq!(spec[1]["role"], "assistant");
        assert_eq!(spec[1]["content"], "Hi there");
        assert_eq!(spec[2]["role"], "user");
        assert_eq!(spec[2]["content"], "How are you?");
    }

    #[test]
    fn test_tools_to_snowflake_spec() {
        let tools = vec![
            Tool::new(
                "calculator",
                "Calculate mathematical expressions",
                json!({
                    "type": "object",
                    "properties": {
                        "expression": {
                            "type": "string",
                            "description": "The mathematical expression to evaluate"
                        }
                    }
                }),
                None,
            ),
            Tool::new(
                "weather",
                "Get weather information",
                json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The location to get weather for"
                        }
                    }
                }),
                None,
            ),
        ];

        let spec = format_tools(&tools);

        assert_eq!(spec.len(), 2);
        assert_eq!(spec[0]["tool_spec"]["name"], "calculator");
        assert_eq!(
            spec[0]["tool_spec"]["description"],
            "Calculate mathematical expressions"
        );
        assert_eq!(spec[1]["tool_spec"]["name"], "weather");
        assert_eq!(
            spec[1]["tool_spec"]["description"],
            "Get weather information"
        );
    }

    #[test]
    fn test_system_to_snowflake_spec() {
        let system = "You are a helpful assistant.";
        let spec = format_system(system);

        assert_eq!(spec["role"], "system");
        assert_eq!(spec["content"], system);
    }

    #[test]
    fn test_parse_streaming_response() -> Result<()> {
        let sse_data = r#"data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","choices":[{"delta":{"type":"text","content":"I","content_list":[{"type":"text","text":"I"}],"text":"I"}}],"usage":{}}

data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","choices":[{"delta":{"type":"text","content":"'ll help you check Nvidia's current","content_list":[{"type":"text","text":"'ll help you check Nvidia's current"}],"text":"'ll help you check Nvidia's current"}}],"usage":{}}

data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","choices":[{"delta":{"type":"tool_use","tool_use_id":"tooluse_FB_nOElDTAOKa-YnVWI5Uw","name":"get_stock_price","content_list":[{"tool_use_id":"tooluse_FB_nOElDTAOKa-YnVWI5Uw","name":"get_stock_price"}],"text":""}}],"usage":{}}

data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","choices":[{"delta":{"type":"tool_use","input":"{\"symbol\":\"NVDA\"}","content_list":[{"input":"{\"symbol\":\"NVDA\"}"}],"text":""}}],"usage":{"prompt_tokens":397,"completion_tokens":65,"total_tokens":462}}
"#;

        let message = parse_streaming_response(sse_data)?;

        // Should have both text and a tool request
        assert_eq!(message.content.len(), 2);

        if let MessageContent::Text(text) = &message.content[0] {
            assert!(text.text.contains("I'll help you check Nvidia's current"));
        } else {
            panic!("Expected Text content first");
        }

        if let MessageContent::ToolRequest(tool_request) = &message.content[1] {
            let tool_call = tool_request.tool_call.as_ref().unwrap();
            assert_eq!(tool_call.name, "get_stock_price");
            assert_eq!(tool_call.arguments, json!({"symbol": "NVDA"}));
            assert_eq!(tool_request.id, "tooluse_FB_nOElDTAOKa-YnVWI5Uw");
        } else {
            panic!("Expected ToolRequest content second");
        }

        Ok(())
    }

    #[test]
    fn test_create_request_format() -> Result<()> {
        use crate::model::ModelConfig;

        let model_config = ModelConfig::new("claude-3-5-sonnet".to_string());

        let system = "You are a helpful assistant that can use tools to get information.";
        let messages = vec![Message::user().with_text("What is the stock price of Nvidia?")];

        let tools = vec![Tool::new(
            "get_stock_price",
            "Get stock price information",
            json!({
                "type": "object",
                "properties": {
                    "symbol": {
                        "type": "string",
                        "description": "The symbol for the stock ticker, e.g. Snowflake = SNOW"
                    }
                },
                "required": ["symbol"]
            }),
            None,
        )];

        let request = create_request(&model_config, system, &messages, &tools)?;

        // Check basic structure
        assert_eq!(request["model"], "claude-3-5-sonnet");

        let messages_array = request["messages"].as_array().unwrap();
        assert_eq!(messages_array.len(), 2); // system + user message

        // First message should be system with simple content
        assert_eq!(messages_array[0]["role"], "system");
        assert_eq!(
            messages_array[0]["content"],
            "You are a helpful assistant that can use tools to get information."
        );

        // Second message should be user with simple content
        assert_eq!(messages_array[1]["role"], "user");
        assert_eq!(
            messages_array[1]["content"],
            "What is the stock price of Nvidia?"
        );

        // Tools should have the tool_spec wrapper
        let tools_array = request["tools"].as_array().unwrap();
        assert_eq!(tools_array[0]["tool_spec"]["name"], "get_stock_price");

        Ok(())
    }

    #[test]
    fn test_parse_mixed_text_and_tool_response() -> Result<()> {
        let response = json!({
            "id": "msg_123",
            "type": "message",
            "role": "assistant",
            "content": "I'll help you with that calculation.",
            "content_list": [
                {
                    "type": "text",
                    "text": "I'll help you with that calculation."
                },
                {
                    "type": "tool_use",
                    "tool_use_id": "tool_1",
                    "name": "calculator",
                    "input": {"expression": "2 + 2"}
                }
            ],
            "model": "claude-3-5-sonnet",
            "usage": {
                "input_tokens": 10,
                "output_tokens": 15
            }
        });

        let message = response_to_message(response.clone())?;

        // Should have both text and tool request content
        assert_eq!(message.content.len(), 2);

        if let MessageContent::Text(text) = &message.content[0] {
            assert_eq!(text.text, "I'll help you with that calculation.");
        } else {
            panic!("Expected Text content first");
        }

        if let MessageContent::ToolRequest(tool_request) = &message.content[1] {
            let tool_call = tool_request.tool_call.as_ref().unwrap();
            assert_eq!(tool_call.name, "calculator");
            assert_eq!(tool_request.id, "tool_1");
        } else {
            panic!("Expected ToolRequest content second");
        }

        Ok(())
    }

    #[test]
    fn test_empty_tools_array() {
        let tools: Vec<Tool> = vec![];
        let spec = format_tools(&tools);
        assert_eq!(spec.len(), 0);
    }

    #[test]
    fn test_create_request_excludes_tools_for_description() -> Result<()> {
        use crate::model::ModelConfig;

        let model_config = ModelConfig::new("claude-3-5-sonnet".to_string());
        let system = "Reply with only a description in four words or less";
        let messages = vec![Message::user().with_text("Test message")];
        let tools = vec![Tool::new(
            "test_tool",
            "Test tool",
            json!({"type": "object", "properties": {}}),
            None,
        )];

        let request = create_request(&model_config, system, &messages, &tools)?;

        // Should not include tools for description requests
        assert!(request.get("tools").is_none());

        Ok(())
    }

    #[test]
    fn test_message_formatting_skips_tool_requests() {
        use mcp_core::tool::ToolCall;

        // Create a conversation with text, tool requests, and tool responses
        let tool_call = ToolCall::new("calculator", json!({"expression": "2 + 2"}));

        let messages = vec![
            Message::user().with_text("Calculate 2 + 2"),
            Message::assistant()
                .with_text("I'll help you calculate that.")
                .with_tool_request("tool_1", Ok(tool_call)),
            Message::user().with_text("Thanks!"),
        ];

        let spec = format_messages(&messages);

        // Should only have 3 messages - the tool request should be skipped
        assert_eq!(spec.len(), 3);
        assert_eq!(spec[0]["role"], "user");
        assert_eq!(spec[0]["content"], "Calculate 2 + 2");
        assert_eq!(spec[1]["role"], "assistant");
        assert_eq!(spec[1]["content"], "I'll help you calculate that.");
        assert_eq!(spec[2]["role"], "user");
        assert_eq!(spec[2]["content"], "Thanks!");

        // Verify no tool request content is in the message history
        for message in &spec {
            let content = message["content"].as_str().unwrap();
            assert!(!content.contains("Using tool:"));
            assert!(!content.contains("calculator"));
        }
    }
}

@@ -434,6 +434,9 @@ impl Provider for GcpVertexAIProvider {
            GcpVertexAIModel::Gemini(GeminiVersion::Pro15),
            GcpVertexAIModel::Gemini(GeminiVersion::Flash20),
            GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp),
+           GcpVertexAIModel::Gemini(GeminiVersion::Pro25Exp),
+           GcpVertexAIModel::Gemini(GeminiVersion::Flash25Preview),
+           GcpVertexAIModel::Gemini(GeminiVersion::Pro25Preview),
        ]
        .iter()
        .map(|model| model.to_string())

@@ -230,7 +230,7 @@ impl GithubCopilotProvider {

    async fn refresh_api_info(&self) -> Result<CopilotTokenInfo> {
        let config = Config::global();
-       let token = match config.get_secret::<String>("GITHUB_TOKEN") {
+       let token = match config.get_secret::<String>("GITHUB_COPILOT_TOKEN") {
            Ok(token) => token,
            Err(err) => match err {
                ConfigError::NotFound(_) => {
@@ -238,7 +238,7 @@ impl GithubCopilotProvider {
                        .get_access_token()
                        .await
                        .context("unable to login into github")?;
-                   config.set_secret("GITHUB_TOKEN", Value::String(token.clone()))?;
+                   config.set_secret("GITHUB_COPILOT_TOKEN", Value::String(token.clone()))?;
                    token
                }
                _ => return Err(err.into()),

637
crates/goose/src/providers/lead_worker.rs
Normal file
@@ -0,0 +1,637 @@
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;

use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use mcp_core::{tool::Tool, Content};

/// A provider that switches between a lead model and a worker model based on turn count,
/// and can fall back to the lead model on consecutive failures
pub struct LeadWorkerProvider {
    lead_provider: Arc<dyn Provider>,
    worker_provider: Arc<dyn Provider>,
    lead_turns: usize,
    turn_count: Arc<Mutex<usize>>,
    failure_count: Arc<Mutex<usize>>,
    max_failures_before_fallback: usize,
    fallback_turns: usize,
    in_fallback_mode: Arc<Mutex<bool>>,
    fallback_remaining: Arc<Mutex<usize>>,
}

impl LeadWorkerProvider {
    /// Create a new LeadWorkerProvider
    ///
    /// # Arguments
    /// * `lead_provider` - The provider to use for the initial turns
    /// * `worker_provider` - The provider to use after `lead_turns`
    /// * `lead_turns` - Number of turns to use the lead provider (default: 3)
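    ///
    /// # Example
    ///
    /// A minimal sketch; `lead` and `worker` stand for any two already-configured
    /// values implementing `Provider` (hypothetical names, not defined here):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// // The lead model handles the first three turns, then the worker takes over.
    /// let provider = LeadWorkerProvider::new(Arc::new(lead), Arc::new(worker), Some(3));
    /// ```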
    pub fn new(
        lead_provider: Arc<dyn Provider>,
        worker_provider: Arc<dyn Provider>,
        lead_turns: Option<usize>,
    ) -> Self {
        Self {
            lead_provider,
            worker_provider,
            lead_turns: lead_turns.unwrap_or(3),
            turn_count: Arc::new(Mutex::new(0)),
            failure_count: Arc::new(Mutex::new(0)),
            max_failures_before_fallback: 2, // Fall back after 2 consecutive failures
            fallback_turns: 2, // Use the lead model for 2 turns when in fallback mode
            in_fallback_mode: Arc::new(Mutex::new(false)),
            fallback_remaining: Arc::new(Mutex::new(0)),
        }
    }

    /// Create a new LeadWorkerProvider with custom settings
    ///
    /// # Arguments
    /// * `lead_provider` - The provider to use for the initial turns
    /// * `worker_provider` - The provider to use after `lead_turns`
    /// * `lead_turns` - Number of turns to use the lead provider
    /// * `failure_threshold` - Number of consecutive failures before fallback
    /// * `fallback_turns` - Number of turns to use the lead model in fallback mode
    pub fn new_with_settings(
        lead_provider: Arc<dyn Provider>,
        worker_provider: Arc<dyn Provider>,
        lead_turns: usize,
        failure_threshold: usize,
        fallback_turns: usize,
    ) -> Self {
        Self {
            lead_provider,
            worker_provider,
            lead_turns,
            turn_count: Arc::new(Mutex::new(0)),
            failure_count: Arc::new(Mutex::new(0)),
            max_failures_before_fallback: failure_threshold,
            fallback_turns,
            in_fallback_mode: Arc::new(Mutex::new(false)),
            fallback_remaining: Arc::new(Mutex::new(0)),
        }
    }

    /// Reset the turn counter and failure tracking (useful for new conversations)
    pub async fn reset_turn_count(&self) {
        let mut count = self.turn_count.lock().await;
        *count = 0;
        let mut failures = self.failure_count.lock().await;
        *failures = 0;
        let mut fallback = self.in_fallback_mode.lock().await;
        *fallback = false;
        let mut remaining = self.fallback_remaining.lock().await;
        *remaining = 0;
    }

    /// Get the current turn count
    pub async fn get_turn_count(&self) -> usize {
        *self.turn_count.lock().await
    }

    /// Get the current failure count
    pub async fn get_failure_count(&self) -> usize {
        *self.failure_count.lock().await
    }

    /// Check if currently in fallback mode
    pub async fn is_in_fallback_mode(&self) -> bool {
        *self.in_fallback_mode.lock().await
    }

    /// Get the currently active provider based on turn count and fallback state
    async fn get_active_provider(&self) -> Arc<dyn Provider> {
        let count = *self.turn_count.lock().await;
        let in_fallback = *self.in_fallback_mode.lock().await;

        // Use the lead provider if we're in the initial turns OR in fallback mode
        if count < self.lead_turns || in_fallback {
            Arc::clone(&self.lead_provider)
        } else {
            Arc::clone(&self.worker_provider)
        }
    }

    /// Handle the result of a completion attempt and update failure tracking
    async fn handle_completion_result(
        &self,
        result: &Result<(Message, ProviderUsage), ProviderError>,
    ) {
        match result {
            Ok((message, _usage)) => {
                // Check for task-level failures in the response
                let has_task_failure = self.detect_task_failures(message).await;

                if has_task_failure {
                    // Task failure detected - increment the failure count
                    let mut failures = self.failure_count.lock().await;
                    *failures += 1;

                    let failure_count = *failures;
                    let turn_count = *self.turn_count.lock().await;

                    tracing::warn!(
                        "Task failure detected in response (failure count: {})",
                        failure_count
                    );

                    // Check if we should trigger fallback
                    if turn_count >= self.lead_turns
                        && !*self.in_fallback_mode.lock().await
                        && failure_count >= self.max_failures_before_fallback
                    {
                        let mut in_fallback = self.in_fallback_mode.lock().await;
                        let mut fallback_remaining = self.fallback_remaining.lock().await;

                        *in_fallback = true;
                        *fallback_remaining = self.fallback_turns;
                        *failures = 0; // Reset the failure count when entering fallback

                        tracing::warn!(
                            "🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns",
                            self.max_failures_before_fallback,
                            self.fallback_turns
                        );
                    }
                } else {
                    // Success - reset the failure count and handle fallback mode
                    let mut failures = self.failure_count.lock().await;
                    *failures = 0;

                    let mut in_fallback = self.in_fallback_mode.lock().await;
                    let mut fallback_remaining = self.fallback_remaining.lock().await;

                    if *in_fallback {
                        *fallback_remaining -= 1;
                        if *fallback_remaining == 0 {
                            *in_fallback = false;
                            tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed");
                        }
                    }
                }

                // Increment the turn count on any completion (success or task failure)
                let mut count = self.turn_count.lock().await;
                *count += 1;
            }
            Err(_) => {
                // Technical failure - just log and let it bubble up.
                // For technical failures (API/LLM issues), we don't want to second-guess
                // the model choice - just let the default model handle it.
                tracing::warn!(
                    "Technical failure detected - API/LLM issue, will use default model"
                );

                // Don't increment turn count or failure tracking for technical failures,
                // as these are temporary infrastructure issues, not model capability issues.
            }
        }
    }

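    // Fallback lifecycle, as implemented above: task failures past the lead phase
    // flip `in_fallback_mode` on for `fallback_turns` turns; each successful
    // fallback turn decrements `fallback_remaining`, and the worker model resumes
    // once it reaches zero.
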
    /// Detect task-level failures in the model's response
    async fn detect_task_failures(&self, message: &Message) -> bool {
        let mut failure_indicators = 0;

        for content in &message.content {
            match content {
                MessageContent::ToolRequest(tool_request) => {
                    // Check if the tool request itself failed (malformed, etc.)
                    if tool_request.tool_call.is_err() {
                        failure_indicators += 1;
                        tracing::debug!(
                            "Failed tool request detected: {:?}",
                            tool_request.tool_call
                        );
                    }
                }
                MessageContent::ToolResponse(tool_response) => {
                    // Check if tool execution failed
                    if let Err(tool_error) = &tool_response.tool_result {
                        failure_indicators += 1;
                        tracing::debug!("Tool execution failure detected: {:?}", tool_error);
                    } else if let Ok(contents) = &tool_response.tool_result {
                        // Check the tool output for error indicators
                        if self.contains_error_indicators(contents) {
                            failure_indicators += 1;
                            tracing::debug!("Tool output contains error indicators");
                        }
                    }
                }
                MessageContent::Text(text_content) => {
                    // Check for user correction patterns or error acknowledgments
                    if self.contains_user_correction_patterns(&text_content.text) {
                        failure_indicators += 1;
                        tracing::debug!("User correction pattern detected in text");
                    }
                }
                _ => {}
            }
        }

        // Consider it a failure if we have at least one failure indicator
        failure_indicators >= 1
    }

    /// Check if tool output contains error indicators
    fn contains_error_indicators(&self, contents: &[Content]) -> bool {
        for content in contents {
            if let Content::Text(text_content) = content {
                let text_lower = text_content.text.to_lowercase();

                // Common error patterns in tool outputs
                if text_lower.contains("error:")
                    || text_lower.contains("failed:")
                    || text_lower.contains("exception:")
                    || text_lower.contains("traceback")
                    || text_lower.contains("syntax error")
                    || text_lower.contains("permission denied")
                    || text_lower.contains("file not found")
                    || text_lower.contains("command not found")
                    || text_lower.contains("compilation failed")
                    || text_lower.contains("test failed")
                    || text_lower.contains("assertion failed")
                {
                    return true;
                }
            }
        }
        false
    }

    /// Check for user correction patterns in text
    fn contains_user_correction_patterns(&self, text: &str) -> bool {
        let text_lower = text.to_lowercase();

        // Patterns indicating the user is correcting or expressing dissatisfaction
        text_lower.contains("that's wrong")
            || text_lower.contains("that's not right")
            || text_lower.contains("that doesn't work")
            || text_lower.contains("try again")
            || text_lower.contains("let me correct")
            || text_lower.contains("actually, ")
            || text_lower.contains("no, that's")
            || text_lower.contains("that's incorrect")
            || text_lower.contains("fix this")
            || text_lower.contains("this is broken")
            || text_lower.contains("this doesn't")
            || text_lower.starts_with("no,")
            || text_lower.starts_with("wrong")
            || text_lower.starts_with("incorrect")
    }
}

impl LeadWorkerProviderTrait for LeadWorkerProvider {
    /// Get information about the lead and worker models for logging
    fn get_model_info(&self) -> (String, String) {
        let lead_model = self.lead_provider.get_model_config().model_name;
        let worker_model = self.worker_provider.get_model_config().model_name;
        (lead_model, worker_model)
    }
}

#[async_trait]
impl Provider for LeadWorkerProvider {
    fn metadata() -> ProviderMetadata {
        // This is a wrapper provider, so we return minimal metadata
        ProviderMetadata::new(
            "lead_worker",
            "Lead/Worker Provider",
            "A provider that switches between lead and worker models based on turn count",
            "", // No default model, as this is determined by the wrapped providers
            vec![], // No known models, as this depends on the wrapped providers
            "", // No doc link
            vec![], // No config keys, as configuration is done through the wrapped providers
        )
    }

    fn get_model_config(&self) -> ModelConfig {
        // Return the lead provider's model config as the default.
        // In practice, this might need to be more sophisticated.
        self.lead_provider.get_model_config()
    }

    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        // Get the active provider
        let provider = self.get_active_provider().await;

        // Log which provider is being used
        let turn_count = *self.turn_count.lock().await;
        let in_fallback = *self.in_fallback_mode.lock().await;
        let fallback_remaining = *self.fallback_remaining.lock().await;

        let provider_type = if turn_count < self.lead_turns {
            "lead (initial)"
        } else if in_fallback {
            "lead (fallback)"
        } else {
            "worker"
        };

        if in_fallback {
            tracing::info!(
                "🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining)",
                provider_type,
                turn_count + 1,
                fallback_remaining
            );
        } else {
            tracing::info!(
                "Using {} provider for turn {} (lead_turns: {})",
                provider_type,
                turn_count + 1,
                self.lead_turns
            );
        }

        // Make the completion request
        let result = provider.complete(system, messages, tools).await;

        // For technical failures, try the default model (the lead provider) instead
        let final_result = match &result {
            Err(_) => {
                tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type);

                // Try the lead provider as the default/fallback for technical failures
                let default_result = self.lead_provider.complete(system, messages, tools).await;

                match &default_result {
                    Ok(_) => {
                        tracing::info!(
                            "✅ Default model (lead provider) succeeded after technical failure"
                        );
                        default_result
                    }
                    Err(_) => {
                        tracing::error!("❌ Default model (lead provider) also failed - returning original error");
                        result // Return the original error
                    }
                }
            }
            Ok(_) => result, // Success with the original provider
        };

        // Handle the result and update tracking (the turn and failure counters
        // only change on successful completions)
        self.handle_completion_result(&final_result).await;

        final_result
    }

    async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
        // Combine models from both providers
        let lead_models = self.lead_provider.fetch_supported_models_async().await?;
        let worker_models = self.worker_provider.fetch_supported_models_async().await?;

        match (lead_models, worker_models) {
            (Some(lead), Some(worker)) => {
                let mut all_models = lead;
                all_models.extend(worker);
                all_models.sort();
                all_models.dedup();
                Ok(Some(all_models))
            }
            (Some(models), None) | (None, Some(models)) => Ok(Some(models)),
            (None, None) => Ok(None),
        }
    }

    fn supports_embeddings(&self) -> bool {
        // Support embeddings if either provider supports them
        self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings()
    }

    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        // Use the lead provider for embeddings if it supports them, otherwise use the worker
        if self.lead_provider.supports_embeddings() {
            self.lead_provider.create_embeddings(texts).await
        } else if self.worker_provider.supports_embeddings() {
            self.worker_provider.create_embeddings(texts).await
        } else {
            Err(ProviderError::ExecutionError(
                "Neither lead nor worker provider supports embeddings".to_string(),
            ))
        }
    }

    /// Check if this provider is a LeadWorkerProvider
    fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
        Some(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::MessageContent;
    use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
    use chrono::Utc;
    use mcp_core::{content::TextContent, Role};

    #[derive(Clone)]
    struct MockProvider {
        name: String,
        model_config: ModelConfig,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message {
                    role: Role::Assistant,
                    created: Utc::now().timestamp(),
                    content: vec![MessageContent::Text(TextContent {
                        text: format!("Response from {}", self.name),
                        annotations: None,
                    })],
                },
                ProviderUsage::new(self.name.clone(), Usage::default()),
            ))
        }
    }

    #[tokio::test]
    async fn test_lead_worker_switching() {
        let lead_provider = Arc::new(MockProvider {
            name: "lead".to_string(),
            model_config: ModelConfig::new("lead-model".to_string()),
        });

        let worker_provider = Arc::new(MockProvider {
            name: "worker".to_string(),
            model_config: ModelConfig::new("worker-model".to_string()),
        });

        let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(3));

        // The first three turns should use the lead provider
        for i in 0..3 {
            let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
            assert_eq!(usage.model, "lead");
            assert_eq!(provider.get_turn_count().await, i + 1);
            assert!(!provider.is_in_fallback_mode().await);
        }

        // Subsequent turns should use the worker provider
        for i in 3..6 {
            let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
            assert_eq!(usage.model, "worker");
            assert_eq!(provider.get_turn_count().await, i + 1);
            assert!(!provider.is_in_fallback_mode().await);
        }

        // Reset and verify it goes back to the lead provider
        provider.reset_turn_count().await;
        assert_eq!(provider.get_turn_count().await, 0);
        assert_eq!(provider.get_failure_count().await, 0);
        assert!(!provider.is_in_fallback_mode().await);

        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "lead");
    }

    #[tokio::test]
    async fn test_technical_failure_retry() {
        let lead_provider = Arc::new(MockFailureProvider {
            name: "lead".to_string(),
            model_config: ModelConfig::new("lead-model".to_string()),
            should_fail: false, // The lead provider works
        });

        let worker_provider = Arc::new(MockFailureProvider {
            name: "worker".to_string(),
            model_config: ModelConfig::new("worker-model".to_string()),
            should_fail: true, // The worker will fail
        });

        let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2));

        // The first two turns use the lead provider (and should succeed)
        for _i in 0..2 {
            let result = provider.complete("system", &[], &[]).await;
            assert!(result.is_ok());
            assert_eq!(result.unwrap().1.model, "lead");
            assert!(!provider.is_in_fallback_mode().await);
        }

        // The next turn uses the worker (it will fail, but should retry with the lead and succeed)
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok()); // Should succeed because the lead provider is used as fallback
        assert_eq!(result.unwrap().1.model, "lead"); // Should be the lead provider
        assert_eq!(provider.get_failure_count().await, 0); // No failure tracking for technical failures
        assert!(!provider.is_in_fallback_mode().await); // Not in fallback mode

        // Another turn - should still try the worker first, then retry with the lead
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok()); // Should succeed because the lead provider is used as fallback
        assert_eq!(result.unwrap().1.model, "lead"); // Should be the lead provider
        assert_eq!(provider.get_failure_count().await, 0); // Still no failure tracking
        assert!(!provider.is_in_fallback_mode().await); // Still not in fallback mode
    }

    #[tokio::test]
    async fn test_fallback_on_task_failures() {
        // Test that task failures (not technical failures) still trigger fallback mode.
        // This would need a different mock that simulates task failures in successful
        // responses; for now, we test the fallback mode functionality directly.
        let lead_provider = Arc::new(MockFailureProvider {
            name: "lead".to_string(),
            model_config: ModelConfig::new("lead-model".to_string()),
            should_fail: false,
        });

        let worker_provider = Arc::new(MockFailureProvider {
            name: "worker".to_string(),
            model_config: ModelConfig::new("worker-model".to_string()),
            should_fail: false,
        });

        let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2));

        // Simulate being in fallback mode
        {
            let mut in_fallback = provider.in_fallback_mode.lock().await;
            *in_fallback = true;
            let mut fallback_remaining = provider.fallback_remaining.lock().await;
            *fallback_remaining = 2;
            let mut turn_count = provider.turn_count.lock().await;
            *turn_count = 4; // Past the initial lead turns
        }

        // Should use the lead provider in fallback mode
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        assert!(provider.is_in_fallback_mode().await);

        // One more fallback turn
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        assert!(!provider.is_in_fallback_mode().await); // Should exit fallback mode
    }

    #[derive(Clone)]
    struct MockFailureProvider {
        name: String,
        model_config: ModelConfig,
        should_fail: bool,
    }

    #[async_trait]
    impl Provider for MockFailureProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            if self.should_fail {
                Err(ProviderError::ExecutionError(
                    "Simulated failure".to_string(),
                ))
            } else {
                Ok((
                    Message {
                        role: Role::Assistant,
                        created: Utc::now().timestamp(),
                        content: vec![MessageContent::Text(TextContent {
                            text: format!("Response from {}", self.name),
                            annotations: None,
                        })],
                    },
                    ProviderUsage::new(self.name.clone(), Usage::default()),
                ))
            }
        }
    }
}

@@ -13,10 +13,12 @@ pub mod gcpvertexai;
pub mod githubcopilot;
pub mod google;
pub mod groq;
+pub mod lead_worker;
pub mod oauth;
pub mod ollama;
pub mod openai;
pub mod openrouter;
+pub mod snowflake;
pub mod toolshim;
pub mod utils;
pub mod utils_universal_openai_stream;

@@ -249,7 +249,7 @@ impl EmbeddingCapable for OpenAiProvider {
        }

        // Get embedding model from env var or use default
-       let embedding_model = std::env::var("EMBEDDING_MODEL")
+       let embedding_model = std::env::var("GOOSE_EMBEDDING_MODEL")
            .unwrap_or_else(|_| "text-embedding-3-small".to_string());

        let request = EmbeddingRequest {

439
crates/goose/src/providers/snowflake.rs
Normal file
@@ -0,0 +1,439 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::time::Duration;

use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use super::formats::snowflake::{create_request, get_usage, response_to_message};
use super::utils::{get_model, ImageFormat};
use crate::config::ConfigError;
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;
use url::Url;

pub const SNOWFLAKE_DEFAULT_MODEL: &str = "claude-3-7-sonnet";
pub const SNOWFLAKE_KNOWN_MODELS: &[&str] = &["claude-3-7-sonnet", "claude-3-5-sonnet"];

pub const SNOWFLAKE_DOC_URL: &str =
    "https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#choosing-a-model";

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SnowflakeAuth {
    Token(String),
}

impl SnowflakeAuth {
    pub fn token(token: String) -> Self {
        Self::Token(token)
    }
}

#[derive(Debug, serde::Serialize)]
pub struct SnowflakeProvider {
    #[serde(skip)]
    client: Client,
    host: String,
    auth: SnowflakeAuth,
    model: ModelConfig,
    image_format: ImageFormat,
}

impl Default for SnowflakeProvider {
    fn default() -> Self {
        let model = ModelConfig::new(SnowflakeProvider::metadata().default_model);
        SnowflakeProvider::from_env(model).expect("Failed to initialize Snowflake provider")
    }
}

impl SnowflakeProvider {
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let mut host: Result<String, ConfigError> = config.get_param("SNOWFLAKE_HOST");
        if host.is_err() {
            host = config.get_secret("SNOWFLAKE_HOST")
        }
        if host.is_err() {
            return Err(ConfigError::NotFound(
                "Did not find SNOWFLAKE_HOST in either config file or keyring".to_string(),
            )
            .into());
        }

        let mut host = host?;

        // Convert the host to lowercase
        host = host.to_lowercase();

        // Ensure the host ends with snowflakecomputing.com
        if !host.ends_with("snowflakecomputing.com") {
            host = format!("{}.snowflakecomputing.com", host);
        }
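        // e.g. a bare account identifier such as "MyOrg-MyAccount" (hypothetical)
        // becomes "myorg-myaccount.snowflakecomputing.com", while a full
        // "*.snowflakecomputing.com" host is only lowercased.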

        let mut token: Result<String, ConfigError> = config.get_param("SNOWFLAKE_TOKEN");

        if token.is_err() {
            token = config.get_secret("SNOWFLAKE_TOKEN")
        }

        if token.is_err() {
            return Err(ConfigError::NotFound(
                "Did not find SNOWFLAKE_TOKEN in either config file or keyring".to_string(),
            )
            .into());
        }

        let client = Client::builder()
            .timeout(Duration::from_secs(600))
            .build()?;

        // Use token-based authentication
        let api_key = token?;
        Ok(Self {
            client,
            host,
            auth: SnowflakeAuth::token(api_key),
            model,
            image_format: ImageFormat::OpenAi,
        })
    }

    async fn ensure_auth_header(&self) -> Result<String> {
        match &self.auth {
            // https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#using-a-programmatic-access-token-pat
            SnowflakeAuth::Token(token) => Ok(format!("Bearer {}", token)),
        }
    }

    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        let base_url_str =
            if !self.host.starts_with("https://") && !self.host.starts_with("http://") {
                format!("https://{}", self.host)
            } else {
                self.host.clone()
            };
        let base_url = Url::parse(&base_url_str)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        let path = "api/v2/cortex/inference:complete";
        let url = base_url.join(path).map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        let auth_header = self.ensure_auth_header().await?;
        let response = self
            .client
            .post(url)
            .header("Authorization", auth_header)
            .header("User-Agent", "Goose")
            .json(&payload)
            .send()
            .await?;

        let status = response.status();

        let payload_text: String = response.text().await.ok().unwrap_or_default();

        if status == StatusCode::OK {
            if let Ok(payload) = serde_json::from_str::<Value>(&payload_text) {
                if payload.get("code").is_some() {
                    let code = payload
                        .get("code")
                        .and_then(|c| c.as_str())
                        .unwrap_or("Unknown code");
                    let message = payload
                        .get("message")
                        .and_then(|m| m.as_str())
                        .unwrap_or("Unknown message");
                    return Err(ProviderError::RequestFailed(format!(
                        "{} - {}",
                        code, message
                    )));
                }
            }
        }

        let lines = payload_text.lines().collect::<Vec<_>>();

        let mut text = String::new();
        let mut tool_name = String::new();
        let mut tool_input = String::new();
        let mut tool_use_id = String::new();
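
        // Each "data:" line carries one streaming delta, e.g. (illustrative,
        // matching the streaming test in the formats module):
        //   data: {"choices":[{"delta":{"type":"text","content":"Hi","content_list":[{"type":"text","text":"Hi"}]}}],"usage":{}}
        // Text deltas are concatenated, while tool_use deltas arrive split across
        // events (id and name first, then input fragments), so each piece is accumulated.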
        for line in lines.iter() {
            if line.is_empty() {
                continue;
            }

            let json_str = match line.strip_prefix("data: ") {
                Some(s) => s,
                None => continue,
            };

            if let Ok(json_line) = serde_json::from_str::<Value>(json_str) {
                let choices = match json_line.get("choices").and_then(|c| c.as_array()) {
                    Some(choices) => choices,
                    None => {
                        continue;
                    }
                };

                let choice = match choices.first() {
                    Some(choice) => choice,
                    None => {
                        continue;
                    }
                };

                let delta = match choice.get("delta") {
                    Some(delta) => delta,
                    None => {
                        continue;
                    }
                };

                // Track whether we found text in content_list to avoid duplication
                let mut found_text_in_content_list = false;

                // Handle the content_list array first
                if let Some(content_list) = delta.get("content_list").and_then(|cl| cl.as_array()) {
                    for content_item in content_list {
                        match content_item.get("type").and_then(|t| t.as_str()) {
                            Some("text") => {
                                if let Some(text_content) =
                                    content_item.get("text").and_then(|t| t.as_str())
                                {
                                    text.push_str(text_content);
                                    found_text_in_content_list = true;
                                }
                            }
                            Some("tool_use") => {
                                if let Some(tool_id) =
                                    content_item.get("tool_use_id").and_then(|id| id.as_str())
                                {
                                    tool_use_id.push_str(tool_id);
                                }
                                if let Some(name) =
                                    content_item.get("name").and_then(|n| n.as_str())
                                {
                                    tool_name.push_str(name);
                                }
                                if let Some(input) =
                                    content_item.get("input").and_then(|i| i.as_str())
                                {
                                    tool_input.push_str(input);
                                }
                            }
                            _ => {
                                // Handle content items without an explicit type but with tool information
                                if let Some(name) =
                                    content_item.get("name").and_then(|n| n.as_str())
                                {
                                    tool_name.push_str(name);
                                }
                                if let Some(tool_id) =
                                    content_item.get("tool_use_id").and_then(|id| id.as_str())
                                {
                                    tool_use_id.push_str(tool_id);
                                }
                                if let Some(input) =
                                    content_item.get("input").and_then(|i| i.as_str())
                                {
                                    tool_input.push_str(input);
                                }
                            }
                        }
                    }
                }

                // Handle the direct content field (for text) only if we didn't find text in content_list
                if !found_text_in_content_list {
                    if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
                        text.push_str(content);
                    }
                }
            }
        }

        // Build the appropriate response structure
        let mut content_list = Vec::new();

        // Add text content if available
        if !text.is_empty() {
            content_list.push(json!({
                "type": "text",
                "text": text
            }));
        }

        // Add tool_use content only if we have complete tool information
        if !tool_use_id.is_empty() && !tool_name.is_empty() {
            // Parse the tool input as JSON if it's not empty
            let parsed_input = if tool_input.is_empty() {
                json!({})
            } else {
                serde_json::from_str::<Value>(&tool_input)
                    .unwrap_or_else(|_| json!({"raw_input": tool_input}))
            };

            content_list.push(json!({
                "type": "tool_use",
                "tool_use_id": tool_use_id,
                "name": tool_name,
                "input": parsed_input
            }));
        }

        // Ensure we always have at least some content
        if content_list.is_empty() {
            content_list.push(json!({
                "type": "text",
                "text": ""
            }));
        }

        let answer_payload = json!({
            "role": "assistant",
            "content": text,
            "content_list": content_list
        });

        match status {
            StatusCode::OK => Ok(answer_payload),
            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
                // Extract a clean error message from the response if available
                let error_msg = payload_text
                    .lines()
                    .find(|line| line.contains("\"message\""))
                    .and_then(|line| {
                        let json_str = line.strip_prefix("data: ").unwrap_or(line);
                        serde_json::from_str::<Value>(json_str).ok()
                    })
                    .and_then(|json| {
                        json.get("message")
                            .and_then(|m| m.as_str())
                            .map(|s| s.to_string())
                    })
                    .unwrap_or_else(|| "Invalid credentials".to_string());

                Err(ProviderError::Authentication(format!(
                    "Authentication failed. Please check your SNOWFLAKE_TOKEN and SNOWFLAKE_HOST configuration. Error: {}",
                    error_msg
                )))
            }
            StatusCode::BAD_REQUEST => {
                // Snowflake provides a generic 'error' but also includes 'external_model_message', which is provider specific.
                // We try to extract the error message from the payload and check for phrases that indicate the context length was exceeded.
                let payload_str = payload_text.to_lowercase();
                let check_phrases = [
                    "too long",
                    "context length",
                    "context_length_exceeded",
                    "reduce the length",
                    "token count",
                    "exceeds",
                    "exceed context limit",
                    "input length",
                    "max_tokens",
                    "decrease input length",
                    "context limit",
                ];
                if check_phrases.iter().any(|c| payload_str.contains(c)) {
                    return Err(ProviderError::ContextLengthExceeded("Request exceeds maximum context length. Please reduce the number of messages or content size.".to_string()));
                }

                // Try to parse a clean error message from the response
                let error_msg = if let Ok(json) = serde_json::from_str::<Value>(&payload_text) {
                    json.get("message")
                        .and_then(|m| m.as_str())
                        .map(|s| s.to_string())
                        .or_else(|| {
                            json.get("external_model_message")
                                .and_then(|ext| ext.get("message"))
                                .and_then(|m| m.as_str())
                                .map(|s| s.to_string())
                        })
                        .unwrap_or_else(|| "Bad request".to_string())
                } else {
                    "Bad request".to_string()
                };

                tracing::debug!(
                    "Provider request failed with status: {}. Response: {}",
                    status,
                    payload_text
                );
                Err(ProviderError::RequestFailed(format!(
                    "Request failed: {}",
                    error_msg
                )))
            }
            StatusCode::TOO_MANY_REQUESTS => Err(ProviderError::RateLimitExceeded(
                "Rate limit exceeded. Please try again later.".to_string(),
            )),
            StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
                Err(ProviderError::ServerError(
                    "Snowflake service is temporarily unavailable. Please try again later."
                        .to_string(),
                ))
            }
            _ => {
                tracing::debug!(
                    "Provider request failed with status: {}. Response: {}",
                    status,
                    payload_text
                );
                Err(ProviderError::RequestFailed(format!(
                    "Request failed with status: {}",
                    status
                )))
            }
        }
    }
}

#[async_trait]
impl Provider for SnowflakeProvider {
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "snowflake",
            "Snowflake",
            "Access several models using Snowflake Cortex services.",
            SNOWFLAKE_DEFAULT_MODEL,
            SNOWFLAKE_KNOWN_MODELS.to_vec(),
            SNOWFLAKE_DOC_URL,
            vec![
                ConfigKey::new("SNOWFLAKE_HOST", true, false, None),
                ConfigKey::new("SNOWFLAKE_TOKEN", true, true, None),
            ],
        )
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let payload = create_request(&self.model, system, messages, tools)?;

        let response = self.post(payload.clone()).await?;

        // Parse the response
        let message = response_to_message(response.clone())?;
        let usage = get_usage(&response)?;
        let model = get_model(&response);
        super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);

        Ok((message, ProviderUsage::new(model, usage)))
    }
}
@@ -38,6 +38,7 @@ use crate::model::ModelConfig;
use crate::providers::formats::openai::create_request;
use anyhow::Result;
use mcp_core::tool::{Tool, ToolCall};
+use mcp_core::Content;
use reqwest::Client;
use serde_json::{json, Value};
use std::time::Duration;

@@ -164,7 +165,10 @@ impl OllamaInterpreter {
        payload["stream"] = json!(false); // needed for the /api/chat endpoint to work
        payload["format"] = format_schema;

-       // tracing::warn!("payload: {}", serde_json::to_string_pretty(&payload).unwrap_or_default());
+       tracing::info!(
+           "Tool interpreter payload: {}",
+           serde_json::to_string_pretty(&payload).unwrap_or_default()
+       );

        let response = self.client.post(&url).json(&payload).send().await?;

@@ -193,7 +197,10 @@ impl OllamaInterpreter {

    fn process_interpreter_response(response: &Value) -> Result<Vec<ToolCall>, ProviderError> {
        let mut tool_calls = Vec::new();

        tracing::info!(
            "Tool interpreter response is {}",
            serde_json::to_string_pretty(&response).unwrap_or_default()
        );
        // Extract the tool_calls array from the response
        if response.get("message").is_some() && response["message"].get("content").is_some() {
            let content = response["message"]["content"].as_str().unwrap_or_default();

@@ -298,6 +305,72 @@ pub fn format_tool_info(tools: &[Tool]) -> String {
    tool_info
}

/// Convert messages containing ToolRequest/ToolResponse to text messages for toolshim mode.
/// This is necessary because some providers (like Bedrock) validate that tool_use/tool_result
/// blocks can only exist when tools are defined, but in toolshim mode we pass empty tools.
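///
/// Illustrative sketch of the transformation (hypothetical values): a ToolRequest
/// for `calculator` with `{"expression": "2 + 2"}` becomes a plain text message
/// starting with `Using tool: calculator`, and a successful ToolResponse becomes
/// text starting with `Tool result:`.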
pub fn convert_tool_messages_to_text(messages: &[Message]) -> Vec<Message> {
    messages
        .iter()
        .map(|message| {
            let mut new_content = Vec::new();
            let mut has_tool_content = false;

            for content in &message.content {
                match content {
                    MessageContent::ToolRequest(req) => {
                        has_tool_content = true;
                        // Convert the tool request to a text format
                        let text = if let Ok(tool_call) = &req.tool_call {
                            format!(
                                "Using tool: {}\n{{\n \"name\": \"{}\",\n \"arguments\": {}\n}}",
                                tool_call.name,
                                tool_call.name,
                                serde_json::to_string_pretty(&tool_call.arguments)
                                    .unwrap_or_default()
                            )
                        } else {
                            "Tool request failed".to_string()
                        };
                        new_content.push(MessageContent::text(text));
                    }
                    MessageContent::ToolResponse(res) => {
                        has_tool_content = true;
                        // Convert the tool response to a text format
                        let text = match &res.tool_result {
                            Ok(contents) => {
                                let text_contents: Vec<String> = contents
                                    .iter()
                                    .filter_map(|c| match c {
                                        Content::Text(t) => Some(t.text.clone()),
                                        _ => None,
                                    })
                                    .collect();
                                format!("Tool result:\n{}", text_contents.join("\n"))
                            }
                            Err(e) => format!("Tool error: {}", e),
                        };
                        new_content.push(MessageContent::text(text));
                    }
                    _ => {
                        // Keep other content types as-is
                        new_content.push(content.clone());
                    }
                }
            }

            if has_tool_content {
                Message {
                    role: message.role.clone(),
                    content: new_content,
                    created: message.created,
                }
            } else {
                message.clone()
            }
        })
        .collect()
}

/// Modifies the system prompt to include tool usage instructions when tool interpretation is enabled
pub fn modify_system_prompt_for_tool_json(system_prompt: &str, tools: &[Tool]) -> String {
    let tool_info = format_tool_info(tools);

@@ -28,7 +28,7 @@ fn default_version() -> String {
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
/// use goose::recipe::Recipe;
|
||||
///
|
||||
/// // Using the builder pattern
|
||||
@@ -52,7 +52,7 @@ fn default_version() -> String {
|
||||
/// author: None,
|
||||
/// parameters: None,
|
||||
/// };
|
||||
/// ```
|
||||
///
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Recipe {
|
||||
// Required fields
|
||||
@@ -166,7 +166,7 @@ impl Recipe {
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
/// use goose::recipe::Recipe;
|
||||
///
|
||||
/// let recipe = Recipe::builder()
|
||||
@@ -175,7 +175,7 @@ impl Recipe {
|
||||
/// .instructions("Act as a helpful assistant")
|
||||
/// .build()
|
||||
/// .expect("Failed to build Recipe: missing required fields");
|
||||
/// ```
|
||||
///
|
||||
pub fn builder() -> RecipeBuilder {
|
||||
RecipeBuilder {
|
||||
version: default_version(),
|
||||
|
||||
@@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler};

use crate::agents::AgentEvent;
use crate::agents::{Agent, SessionConfig};
use crate::config::{self, Config};
use crate::message::Message;
@@ -20,6 +21,10 @@ use crate::recipe::Recipe;
use crate::session;
use crate::session::storage::SessionMetadata;

// Track running tasks with their abort handles
type RunningTasksMap = HashMap<String, tokio::task::AbortHandle>;
type JobsMap = HashMap<String, (JobId, ScheduledJob)>;

pub fn get_default_scheduler_storage_path() -> Result<PathBuf, io::Error> {
    let strategy = choose_app_strategy(config::APP_STRATEGY.clone())
        .map_err(|e| io::Error::new(io::ErrorKind::NotFound, e.to_string()))?;
@@ -111,11 +116,15 @@ pub struct ScheduledJob {
    pub currently_running: bool,
    #[serde(default)]
    pub paused: bool,
    #[serde(default)]
    pub current_session_id: Option<String>,
    #[serde(default)]
    pub process_start_time: Option<DateTime<Utc>>,
}

async fn persist_jobs_from_arc(
    storage_path: &Path,
    jobs_arc: &Arc<Mutex<HashMap<String, (JobId, ScheduledJob)>>>,
    jobs_arc: &Arc<Mutex<JobsMap>>,
) -> Result<(), SchedulerError> {
    let jobs_guard = jobs_arc.lock().await;
    let list: Vec<ScheduledJob> = jobs_guard.values().map(|(_, j)| j.clone()).collect();
@@ -129,8 +138,9 @@ async fn persist_jobs_from_arc(

pub struct Scheduler {
    internal_scheduler: TokioJobScheduler,
    jobs: Arc<Mutex<HashMap<String, (JobId, ScheduledJob)>>>,
    jobs: Arc<Mutex<JobsMap>>,
    storage_path: PathBuf,
    running_tasks: Arc<Mutex<RunningTasksMap>>,
}

impl Scheduler {
@@ -140,11 +150,13 @@ impl Scheduler {
            .map_err(|e| SchedulerError::SchedulerInternalError(e.to_string()))?;

        let jobs = Arc::new(Mutex::new(HashMap::new()));
        let running_tasks = Arc::new(Mutex::new(HashMap::new()));

        let arc_self = Arc::new(Self {
            internal_scheduler,
            jobs,
            storage_path,
            running_tasks,
        });

        arc_self.load_jobs_from_storage().await?;
@@ -208,17 +220,21 @@ impl Scheduler {

        let mut stored_job = original_job_spec.clone();
        stored_job.source = destination_recipe_path.to_string_lossy().into_owned();
        stored_job.current_session_id = None;
        stored_job.process_start_time = None;
        tracing::info!("Updated job source path to: {}", stored_job.source);

        let job_for_task = stored_job.clone();
        let jobs_arc_for_task = self.jobs.clone();
        let storage_path_for_task = self.storage_path.clone();
        let running_tasks_for_task = self.running_tasks.clone();

        let cron_task = Job::new_async(&stored_job.cron, move |_uuid, _l| {
            let task_job_id = job_for_task.id.clone();
            let current_jobs_arc = jobs_arc_for_task.clone();
            let local_storage_path = storage_path_for_task.clone();
            let job_to_execute = job_for_task.clone(); // Clone for run_scheduled_job_internal
            let running_tasks_arc = running_tasks_for_task.clone();

            Box::pin(async move {
                // Check if the job is paused before executing
@@ -243,6 +259,7 @@ impl Scheduler {
                    if let Some((_, current_job_in_map)) = jobs_map_guard.get_mut(&task_job_id) {
                        current_job_in_map.last_run = Some(current_time);
                        current_job_in_map.currently_running = true;
                        current_job_in_map.process_start_time = Some(current_time);
                        needs_persist = true;
                    }
                }
@@ -258,14 +275,37 @@ impl Scheduler {
                        );
                    }
                }
                // Pass None for provider_override in normal execution
                let result = run_scheduled_job_internal(job_to_execute, None).await;

                // Spawn the job execution as an abortable task
                let job_task = tokio::spawn(run_scheduled_job_internal(
                    job_to_execute.clone(),
                    None,
                    Some(current_jobs_arc.clone()),
                    Some(task_job_id.clone()),
                ));

                // Store the abort handle at the scheduler level
                {
                    let mut running_tasks_guard = running_tasks_arc.lock().await;
                    running_tasks_guard.insert(task_job_id.clone(), job_task.abort_handle());
                }

                // Wait for the job to complete or be aborted
                let result = job_task.await;

                // Remove the abort handle
                {
                    let mut running_tasks_guard = running_tasks_arc.lock().await;
                    running_tasks_guard.remove(&task_job_id);
                }

                // Update the job status after execution
                {
                    let mut jobs_map_guard = current_jobs_arc.lock().await;
                    if let Some((_, current_job_in_map)) = jobs_map_guard.get_mut(&task_job_id) {
                        current_job_in_map.currently_running = false;
                        current_job_in_map.current_session_id = None;
                        current_job_in_map.process_start_time = None;
                        needs_persist = true;
                    }
                }
@@ -282,12 +322,27 @@ impl Scheduler {
                    }
                }

                if let Err(e) = result {
                    tracing::error!(
                        "Scheduled job '{}' execution failed: {}",
                        &e.job_id,
                        e.error
                    );
                match result {
                    Ok(Ok(_session_id)) => {
                        tracing::info!("Scheduled job '{}' completed successfully", &task_job_id);
                    }
                    Ok(Err(e)) => {
                        tracing::error!(
                            "Scheduled job '{}' execution failed: {}",
                            &e.job_id,
                            e.error
                        );
                    }
                    Err(join_error) if join_error.is_cancelled() => {
                        tracing::info!("Scheduled job '{}' was cancelled/killed", &task_job_id);
                    }
                    Err(join_error) => {
                        tracing::error!(
                            "Scheduled job '{}' task failed: {}",
                            &task_job_id,
                            join_error
                        );
                    }
                }
            })
        })
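The spawn/abort-handle/await sequence above is the standard tokio pattern for making a task externally killable; a self-contained sketch of just that mechanic (names and the 60-second workload are illustrative, not from this commit):

use std::time::Duration;

#[tokio::main]
async fn main() {
    // Spawn the work as its own task so it can be aborted from outside.
    let task = tokio::spawn(async {
        tokio::time::sleep(Duration::from_secs(60)).await;
        "finished"
    });

    // Keep the handle somewhere (the scheduler stores it in running_tasks).
    let abort_handle = task.abort_handle();

    // Later, a kill request aborts the task...
    abort_handle.abort();

    // ...and awaiting the JoinHandle reports the cancellation.
    match task.await {
        Ok(value) => println!("completed: {value}"),
        Err(join_error) if join_error.is_cancelled() => println!("cancelled"),
        Err(join_error) => println!("panicked: {join_error}"),
    }
}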
@@ -328,12 +383,14 @@ impl Scheduler {
        let job_for_task = job_to_load.clone();
        let jobs_arc_for_task = self.jobs.clone();
        let storage_path_for_task = self.storage_path.clone();
        let running_tasks_for_task = self.running_tasks.clone();

        let cron_task = Job::new_async(&job_to_load.cron, move |_uuid, _l| {
            let task_job_id = job_for_task.id.clone();
            let current_jobs_arc = jobs_arc_for_task.clone();
            let local_storage_path = storage_path_for_task.clone();
            let job_to_execute = job_for_task.clone(); // Clone for run_scheduled_job_internal
            let running_tasks_arc = running_tasks_for_task.clone();

            Box::pin(async move {
                // Check if the job is paused before executing
@@ -358,6 +415,7 @@ impl Scheduler {
                    if let Some((_, stored_job)) = jobs_map_guard.get_mut(&task_job_id) {
                        stored_job.last_run = Some(current_time);
                        stored_job.currently_running = true;
                        stored_job.process_start_time = Some(current_time);
                        needs_persist = true;
                    }
                }
@@ -373,14 +431,37 @@ impl Scheduler {
                        );
                    }
                }
                // Pass None for provider_override in normal execution
                let result = run_scheduled_job_internal(job_to_execute, None).await;

                // Spawn the job execution as an abortable task
                let job_task = tokio::spawn(run_scheduled_job_internal(
                    job_to_execute,
                    None,
                    Some(current_jobs_arc.clone()),
                    Some(task_job_id.clone()),
                ));

                // Store the abort handle at the scheduler level
                {
                    let mut running_tasks_guard = running_tasks_arc.lock().await;
                    running_tasks_guard.insert(task_job_id.clone(), job_task.abort_handle());
                }

                // Wait for the job to complete or be aborted
                let result = job_task.await;

                // Remove the abort handle
                {
                    let mut running_tasks_guard = running_tasks_arc.lock().await;
                    running_tasks_guard.remove(&task_job_id);
                }

                // Update the job status after execution
                {
                    let mut jobs_map_guard = current_jobs_arc.lock().await;
                    if let Some((_, stored_job)) = jobs_map_guard.get_mut(&task_job_id) {
                        stored_job.currently_running = false;
                        stored_job.current_session_id = None;
                        stored_job.process_start_time = None;
                        needs_persist = true;
                    }
                }
@@ -397,12 +478,30 @@ impl Scheduler {
                    }
                }

                if let Err(e) = result {
                    tracing::error!(
                        "Scheduled job '{}' execution failed: {}",
                        &e.job_id,
                        e.error
                    );
                match result {
                    Ok(Ok(_session_id)) => {
                        tracing::info!(
                            "Scheduled job '{}' completed successfully",
                            &task_job_id
                        );
                    }
                    Ok(Err(e)) => {
                        tracing::error!(
                            "Scheduled job '{}' execution failed: {}",
                            &e.job_id,
                            e.error
                        );
                    }
                    Err(join_error) if join_error.is_cancelled() => {
                        tracing::info!("Scheduled job '{}' was cancelled/killed", &task_job_id);
                    }
                    Err(join_error) => {
                        tracing::error!(
                            "Scheduled job '{}' task failed: {}",
                            &task_job_id,
                            join_error
                        );
                    }
                }
            })
        })
@@ -421,7 +520,7 @@ impl Scheduler {
    // Renamed and kept for direct use when a guard is already held (e.g. add/remove)
    async fn persist_jobs_to_storage_with_guard(
        &self,
        jobs_guard: &tokio::sync::MutexGuard<'_, HashMap<String, (JobId, ScheduledJob)>>,
        jobs_guard: &tokio::sync::MutexGuard<'_, JobsMap>,
    ) -> Result<(), SchedulerError> {
        let list: Vec<ScheduledJob> = jobs_guard.values().map(|(_, j)| j.clone()).collect();
        if let Some(parent) = self.storage_path.parent() {
@@ -523,14 +622,36 @@ impl Scheduler {
            }
        };

        // Pass None for provider_override in normal execution
        let run_result = run_scheduled_job_internal(job_to_run.clone(), None).await;
        // Spawn the job execution as an abortable task for run_now
        let job_task = tokio::spawn(run_scheduled_job_internal(
            job_to_run.clone(),
            None,
            Some(self.jobs.clone()),
            Some(sched_id.to_string()),
        ));

        // Store the abort handle for run_now jobs
        {
            let mut running_tasks_guard = self.running_tasks.lock().await;
            running_tasks_guard.insert(sched_id.to_string(), job_task.abort_handle());
        }

        // Wait for the job to complete or be aborted
        let run_result = job_task.await;

        // Remove the abort handle
        {
            let mut running_tasks_guard = self.running_tasks.lock().await;
            running_tasks_guard.remove(sched_id);
        }

        // Clear the currently_running flag after execution
        {
            let mut jobs_guard = self.jobs.lock().await;
            if let Some((_tokio_job_id, job_in_map)) = jobs_guard.get_mut(sched_id) {
                job_in_map.currently_running = false;
                job_in_map.current_session_id = None;
                job_in_map.process_start_time = None;
                job_in_map.last_run = Some(Utc::now());
            } // MutexGuard is dropped here
        }
@@ -539,12 +660,24 @@ impl Scheduler {
        self.persist_jobs().await?;

        match run_result {
            Ok(session_id) => Ok(session_id),
            Err(e) => Err(SchedulerError::AnyhowError(anyhow!(
            Ok(Ok(session_id)) => Ok(session_id),
            Ok(Err(e)) => Err(SchedulerError::AnyhowError(anyhow!(
                "Failed to execute job '{}' immediately: {}",
                sched_id,
                e.error
            ))),
            Err(join_error) if join_error.is_cancelled() => {
                tracing::info!("Run now job '{}' was cancelled/killed", sched_id);
                Err(SchedulerError::AnyhowError(anyhow!(
                    "Job '{}' was successfully cancelled",
                    sched_id
                )))
            }
            Err(join_error) => Err(SchedulerError::AnyhowError(anyhow!(
                "Failed to execute job '{}' immediately: {}",
                sched_id,
                join_error
            ))),
        }
    }

@@ -608,12 +741,14 @@ impl Scheduler {
        let job_for_task = job_def.clone();
        let jobs_arc_for_task = self.jobs.clone();
        let storage_path_for_task = self.storage_path.clone();
        let running_tasks_for_task = self.running_tasks.clone();

        let cron_task = Job::new_async(&new_cron, move |_uuid, _l| {
            let task_job_id = job_for_task.id.clone();
            let current_jobs_arc = jobs_arc_for_task.clone();
            let local_storage_path = storage_path_for_task.clone();
            let job_to_execute = job_for_task.clone();
            let running_tasks_arc = running_tasks_for_task.clone();

            Box::pin(async move {
                // Check if the job is paused before executing
@@ -641,6 +776,7 @@ impl Scheduler {
                    {
                        current_job_in_map.last_run = Some(current_time);
                        current_job_in_map.currently_running = true;
                        current_job_in_map.process_start_time = Some(current_time);
                        needs_persist = true;
                    }
                }
@@ -657,7 +793,29 @@ impl Scheduler {
                    }
                }

                let result = run_scheduled_job_internal(job_to_execute, None).await;
                // Spawn the job execution as an abortable task
                let job_task = tokio::spawn(run_scheduled_job_internal(
                    job_to_execute,
                    None,
                    Some(current_jobs_arc.clone()),
                    Some(task_job_id.clone()),
                ));

                // Store the abort handle at the scheduler level
                {
                    let mut running_tasks_guard = running_tasks_arc.lock().await;
                    running_tasks_guard
                        .insert(task_job_id.clone(), job_task.abort_handle());
                }

                // Wait for the job to complete or be aborted
                let result = job_task.await;

                // Remove the abort handle
                {
                    let mut running_tasks_guard = running_tasks_arc.lock().await;
                    running_tasks_guard.remove(&task_job_id);
                }

                // Update the job status after execution
                {
@@ -666,6 +824,8 @@ impl Scheduler {
                        jobs_map_guard.get_mut(&task_job_id)
                    {
                        current_job_in_map.currently_running = false;
                        current_job_in_map.current_session_id = None;
                        current_job_in_map.process_start_time = None;
                        needs_persist = true;
                    }
                }
@@ -682,12 +842,33 @@ impl Scheduler {
                    }
                }

                if let Err(e) = result {
                    tracing::error!(
                        "Scheduled job '{}' execution failed: {}",
                        &e.job_id,
                        e.error
                    );
                match result {
                    Ok(Ok(_session_id)) => {
                        tracing::info!(
                            "Scheduled job '{}' completed successfully",
                            &task_job_id
                        );
                    }
                    Ok(Err(e)) => {
                        tracing::error!(
                            "Scheduled job '{}' execution failed: {}",
                            &e.job_id,
                            e.error
                        );
                    }
                    Err(join_error) if join_error.is_cancelled() => {
                        tracing::info!(
                            "Scheduled job '{}' was cancelled/killed",
                            &task_job_id
                        );
                    }
                    Err(join_error) => {
                        tracing::error!(
                            "Scheduled job '{}' task failed: {}",
                            &task_job_id,
                            join_error
                        );
                    }
                }
            })
        })
@@ -709,6 +890,70 @@ impl Scheduler {
            None => Err(SchedulerError::JobNotFound(sched_id.to_string())),
        }
    }

    pub async fn kill_running_job(&self, sched_id: &str) -> Result<(), SchedulerError> {
        let mut jobs_guard = self.jobs.lock().await;
        match jobs_guard.get_mut(sched_id) {
            Some((_, job_def)) => {
                if !job_def.currently_running {
                    return Err(SchedulerError::AnyhowError(anyhow!(
                        "Schedule '{}' is not currently running",
                        sched_id
                    )));
                }

                tracing::info!("Killing running job '{}'", sched_id);

                // Abort the running task if it exists
                {
                    let mut running_tasks_guard = self.running_tasks.lock().await;
                    if let Some(abort_handle) = running_tasks_guard.remove(sched_id) {
                        abort_handle.abort();
                        tracing::info!("Aborted running task for job '{}'", sched_id);
                    } else {
                        tracing::warn!(
                            "No abort handle found for job '{}' in running tasks map",
                            sched_id
                        );
                    }
                }

                // Mark the job as no longer running
                job_def.currently_running = false;
                job_def.current_session_id = None;
                job_def.process_start_time = None;

                self.persist_jobs_to_storage_with_guard(&jobs_guard).await?;

                tracing::info!("Successfully killed job '{}'", sched_id);
                Ok(())
            }
            None => Err(SchedulerError::JobNotFound(sched_id.to_string())),
        }
    }

    pub async fn get_running_job_info(
        &self,
        sched_id: &str,
    ) -> Result<Option<(String, DateTime<Utc>)>, SchedulerError> {
        let jobs_guard = self.jobs.lock().await;
        match jobs_guard.get(sched_id) {
            Some((_, job_def)) => {
                if job_def.currently_running {
                    if let (Some(session_id), Some(start_time)) =
                        (&job_def.current_session_id, &job_def.process_start_time)
                    {
                        Ok(Some((session_id.clone(), *start_time)))
                    } else {
                        Ok(None)
                    }
                } else {
                    Ok(None)
                }
            }
            None => Err(SchedulerError::JobNotFound(sched_id.to_string())),
        }
    }
}

#[derive(Debug)]
@@ -720,6 +965,8 @@ struct JobExecutionError {
async fn run_scheduled_job_internal(
    job: ScheduledJob,
    provider_override: Option<Arc<dyn GooseProvider>>, // New optional parameter
    jobs_arc: Option<Arc<Mutex<JobsMap>>>,
    job_id: Option<String>,
) -> std::result::Result<String, JobExecutionError> {
    tracing::info!("Executing job: {} (Source: {})", job.id, job.source);

@@ -811,6 +1058,15 @@ async fn run_scheduled_job_internal(
    tracing::info!("Agent configured with provider for job '{}'", job.id);

    let session_id_for_return = session::generate_session_id();

    // Update the job with the session ID if we have access to the jobs arc
    if let (Some(jobs_arc), Some(job_id_str)) = (jobs_arc.as_ref(), job_id.as_ref()) {
        let mut jobs_guard = jobs_arc.lock().await;
        if let Some((_, job_def)) = jobs_guard.get_mut(job_id_str) {
            job_def.current_session_id = Some(session_id_for_return.clone());
        }
    }

    let session_file_path = crate::session::storage::get_path(
        crate::session::storage::Identifier::Name(session_id_for_return.clone()),
    );
@@ -843,13 +1099,19 @@ async fn run_scheduled_job_internal(
    use futures::StreamExt;

    while let Some(message_result) = stream.next().await {
        // Check if the task has been cancelled
        tokio::task::yield_now().await;

        match message_result {
            Ok(msg) => {
            Ok(AgentEvent::Message(msg)) => {
                if msg.role == mcp_core::role::Role::Assistant {
                    tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content);
                }
                all_session_messages.push(msg);
            }
            Ok(AgentEvent::McpNotification(_)) => {
                // Handle notifications if needed
            }
            Err(e) => {
                tracing::error!(
                    "[Job {}] Error receiving message from agent: {}",
@@ -1053,6 +1315,8 @@ mod tests {
            last_run: None,
            currently_running: false,
            paused: false,
            current_session_id: None,
            process_start_time: None,
        };

        // Create the mock provider instance for the test
@@ -1061,7 +1325,7 @@ mod tests {

        // Call run_scheduled_job_internal, passing the mock provider
        let created_session_id =
            run_scheduled_job_internal(dummy_job.clone(), Some(mock_provider_instance))
            run_scheduled_job_internal(dummy_job.clone(), Some(mock_provider_instance), None, None)
                .await
                .expect("run_scheduled_job_internal failed");

@@ -4,7 +4,7 @@ use std::sync::Arc;

use anyhow::Result;
use futures::StreamExt;
use goose::agents::Agent;
use goose::agents::{Agent, AgentEvent};
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::Provider;
@@ -132,7 +132,10 @@ async fn run_truncate_test(
    let mut responses = Vec::new();
    while let Some(response_result) = reply_stream.next().await {
        match response_result {
            Ok(response) => responses.push(response),
            Ok(AgentEvent::Message(response)) => responses.push(response),
            Ok(AgentEvent::McpNotification(n)) => {
                println!("MCP Notification: {n:?}");
            }
            Err(e) => {
                println!("Error: {:?}", e);
                return Err(e);
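Both call sites above follow the same migration: the reply stream now yields an event enum rather than bare messages, so every consumer must match on the variant. A minimal sketch of the shape (stand-in types; goose's actual AgentEvent carries richer payloads):

// Stand-in enum, for illustration only.
enum AgentEvent {
    Message(String),
    McpNotification(String),
}

fn handle(event: AgentEvent) {
    match event {
        AgentEvent::Message(msg) => println!("assistant: {msg}"),
        AgentEvent::McpNotification(n) => println!("notification: {n}"),
    }
}

fn main() {
    handle(AgentEvent::Message("hello".into()));
    handle(AgentEvent::McpNotification("progress: 1/5".into()));
}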

@@ -4,7 +4,7 @@ use goose::message::{Message, MessageContent};
use goose::providers::base::Provider;
use goose::providers::errors::ProviderError;
use goose::providers::{
    anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter,
    anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter, snowflake,
};
use mcp_core::content::Content;
use mcp_core::tool::Tool;
@@ -491,6 +491,17 @@ async fn test_google_provider() -> Result<()> {
    .await
}

#[tokio::test]
async fn test_snowflake_provider() -> Result<()> {
    test_provider(
        "Snowflake",
        &["SNOWFLAKE_HOST", "SNOWFLAKE_TOKEN"],
        None,
        snowflake::SnowflakeProvider::default,
    )
    .await
}
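The new test passes the required environment variables alongside the provider constructor, presumably so the suite can skip when credentials are absent. The test_provider helper itself is not shown in this diff, so the following is only a guess at that gating pattern, not its implementation:

// Returns Some(values) only when every required variable is set.
fn required_env(keys: &[&str]) -> Option<Vec<String>> {
    keys.iter().map(|k| std::env::var(k).ok()).collect()
}

fn main() {
    match required_env(&["SNOWFLAKE_HOST", "SNOWFLAKE_TOKEN"]) {
        Some(values) => println!("credentials present ({} vars), running test", values.len()),
        None => println!("skipping: Snowflake credentials not configured"),
    }
}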

// Print the final test report
#[ctor::dtor]
fn print_test_report() {

@@ -1,7 +1,6 @@
use mcp_client::{
    client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
    transport::{SseTransport, StdioTransport, Transport},
    McpService,
};
use rand::Rng;
use rand::SeedableRng;
@@ -20,18 +19,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

    let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
    let handle1 = transport1.start().await?;
    let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
    let client1 = McpClient::new(service1);
    let client1 = McpClient::connect(handle1, Duration::from_secs(30)).await?;

    let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
    let handle2 = transport2.start().await?;
    let service2 = McpService::with_timeout(handle2, Duration::from_secs(30));
    let client2 = McpClient::new(service2);
    let client2 = McpClient::connect(handle2, Duration::from_secs(30)).await?;

    let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
    let handle3 = transport3.start().await?;
    let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
    let client3 = McpClient::new(service3);
    let client3 = McpClient::connect(handle3, Duration::from_secs(10)).await?;

    // Initialize both clients
    let mut clients: Vec<Box<dyn McpClientTrait>> =

122
crates/mcp-client/examples/integration_test.rs
Normal file
@@ -0,0 +1,122 @@
use anyhow::Result;
use futures::lock::Mutex;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::StdioTransport;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing_subscriber::EnvFilter;

#[tokio::main]
async fn main() -> Result<()> {
    // Initialize logging
    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::from_default_env()
                .add_directive("mcp_client=debug".parse().unwrap())
                .add_directive("eventsource_client=info".parse().unwrap()),
        )
        .init();

    test_transport(sse_transport().await?).await?;
    test_transport(stdio_transport().await?).await?;

    Ok(())
}

async fn sse_transport() -> Result<SseTransport> {
    let port = "60053";

    tokio::process::Command::new("npx")
        .env("PORT", port)
        .arg("@modelcontextprotocol/server-everything")
        .arg("sse")
        .spawn()?;
    tokio::time::sleep(Duration::from_secs(1)).await;

    Ok(SseTransport::new(
        format!("http://localhost:{}/sse", port),
        HashMap::new(),
    ))
}

async fn stdio_transport() -> Result<StdioTransport> {
    Ok(StdioTransport::new(
        "npx",
        vec!["@modelcontextprotocol/server-everything"]
            .into_iter()
            .map(|s| s.to_string())
            .collect(),
        HashMap::new(),
    ))
}

async fn test_transport<T>(transport: T) -> Result<()>
where
    T: Transport + Send + 'static,
{
    // Start transport
    let handle = transport.start().await?;

    // Create client
    let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?;
    println!("Client created\n");

    let mut receiver = client.subscribe().await;
    let events = Arc::new(Mutex::new(Vec::new()));
    let events_clone = events.clone();
    tokio::spawn(async move {
        while let Some(event) = receiver.recv().await {
            println!("Received event: {event:?}");
            events_clone.lock().await.push(event);
        }
    });

    // Initialize
    let server_info = client
        .initialize(
            ClientInfo {
                name: "test-client".into(),
                version: "1.0.0".into(),
            },
            ClientCapabilities::default(),
        )
        .await?;
    println!("Connected to server: {server_info:?}\n");

    // Sleep for 500ms to allow the server to start - surprisingly this is required!
    tokio::time::sleep(Duration::from_millis(500)).await;

    // List tools
    let tools = client.list_tools(None).await?;
    println!("Available tools: {tools:#?}\n");

    // Call tool
    let tool_result = client
        .call_tool("echo", serde_json::json!({ "message": "honk" }))
        .await?;
    println!("Tool result: {tool_result:#?}\n");

    let collected_events_before = events.lock().await.len();
    let n_steps = 5;
    let long_op = client
        .call_tool(
            "longRunningOperation",
            serde_json::json!({ "duration": 3, "steps": n_steps }),
        )
        .await?;
    println!("Long op result: {long_op:#?}\n");
    let collected_events_after = events.lock().await.len();
    assert_eq!(collected_events_after - collected_events_before, n_steps);

    // List resources
    let resources = client.list_resources(None).await?;
    println!("Resources: {resources:#?}\n");

    // Read resource
    let resource = client.read_resource("test://static/resource/1").await?;
    println!("Resource: {resource:#?}\n");

    Ok(())
}
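Assuming the crate keeps the name of its directory (mcp-client), the example should be runnable with cargo run -p mcp-client --example integration_test; it also expects npx on PATH, since both transports launch @modelcontextprotocol/server-everything locally.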

@@ -1,7 +1,6 @@
use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
@@ -23,11 +22,8 @@ async fn main() -> Result<()> {
    // Start transport
    let handle = transport.start().await?;

    // Create the service with timeout middleware
    let service = McpService::with_timeout(handle, Duration::from_secs(3));

    // Create client
    let mut client = McpClient::new(service);
    let mut client = McpClient::connect(handle, Duration::from_secs(3)).await?;
    println!("Client created\n");

    // Initialize

@@ -2,7 +2,7 @@ use std::collections::HashMap;

use anyhow::Result;
use mcp_client::{
    ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService,
    ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
    StdioTransport, Transport,
};
use std::time::Duration;
@@ -25,11 +25,8 @@ async fn main() -> Result<(), ClientError> {
    // 2) Start the transport to get a handle
    let transport_handle = transport.start().await?;

    // 3) Create the service with timeout middleware
    let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));

    // 4) Create the client with the middleware-wrapped service
    let mut client = McpClient::new(service);
    // 3) Create the client with the middleware-wrapped service
    let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;

    // Initialize
    let server_info = client
Some files were not shown because too many files have changed in this diff.