From d8ca7a3e1bf2a462bc6ea9d078ca637a957646d6 Mon Sep 17 00:00:00 2001 From: Yingjie He Date: Tue, 18 Feb 2025 12:57:29 -0800 Subject: [PATCH] fix:extra error handling for gemini (#1268) --- crates/goose/src/providers/errors.rs | 39 ++++++ crates/goose/src/providers/google.rs | 45 +------ crates/goose/src/providers/openrouter.rs | 11 +- crates/goose/src/providers/utils.rs | 160 ++++++++++++++++++++++- 4 files changed, 211 insertions(+), 44 deletions(-) diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index 7bd50f92..1536107d 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -1,3 +1,4 @@ +use reqwest::StatusCode; use thiserror::Error; #[derive(Error, Debug)] @@ -35,3 +36,41 @@ impl From for ProviderError { ProviderError::ExecutionError(error.to_string()) } } + +#[derive(Debug)] +pub enum GoogleErrorCode { + BadRequest = 400, + Unauthorized = 401, + Forbidden = 403, + NotFound = 404, + TooManyRequests = 429, + InternalServerError = 500, + ServiceUnavailable = 503, +} + +impl GoogleErrorCode { + pub fn to_status_code(&self) -> StatusCode { + match self { + Self::BadRequest => StatusCode::BAD_REQUEST, + Self::Unauthorized => StatusCode::UNAUTHORIZED, + Self::Forbidden => StatusCode::FORBIDDEN, + Self::NotFound => StatusCode::NOT_FOUND, + Self::TooManyRequests => StatusCode::TOO_MANY_REQUESTS, + Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR, + Self::ServiceUnavailable => StatusCode::SERVICE_UNAVAILABLE, + } + } + + pub fn from_code(code: u64) -> Option { + match code { + 400 => Some(Self::BadRequest), + 401 => Some(Self::Unauthorized), + 403 => Some(Self::Forbidden), + 404 => Some(Self::NotFound), + 429 => Some(Self::TooManyRequests), + 500 => Some(Self::InternalServerError), + 503 => Some(Self::ServiceUnavailable), + _ => Some(Self::InternalServerError), + } + } +} diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index 740b2d65..03186bba 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -3,11 +3,13 @@ use crate::message::Message; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use crate::providers::formats::google::{create_request, get_usage, response_to_message}; -use crate::providers::utils::{emit_debug_trace, unescape_json_values}; +use crate::providers::utils::{ + emit_debug_trace, handle_response_google_compat, unescape_json_values, +}; use anyhow::Result; use async_trait::async_trait; use mcp_core::tool::Tool; -use reqwest::{Client, StatusCode}; +use reqwest::Client; use serde_json::Value; use std::time::Duration; use url::Url; @@ -84,44 +86,7 @@ impl GoogleProvider { .send() .await?; - let status = response.status(); - let payload: Option = response.json().await.ok(); - - match status { - StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}. Response: {:?}", status, payload ))) - } - StatusCode::BAD_REQUEST => { - let mut error_msg = "Unknown error".to_string(); - if let Some(payload) = &payload { - if let Some(error) = payload.get("error") { - error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string(); - let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status"); - if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") { - return Err(ProviderError::ContextLengthExceeded(error_msg.to_string())); - } - } - } - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg))) - } - StatusCode::TOO_MANY_REQUESTS => { - Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) - } - StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { - Err(ProviderError::ServerError(format!("{:?}", payload))) - } - _ => { - tracing::debug!( - "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) - ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status))) - } - } + handle_response_google_compat(response).await } } diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 290bc03f..e0028aac 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -6,7 +6,10 @@ use std::time::Duration; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; -use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat}; +use super::utils::{ + emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat, + is_google_model, +}; use crate::message::Message; use crate::model::ModelConfig; use crate::providers::formats::openai::{create_request, get_usage, response_to_message}; @@ -74,7 +77,11 @@ impl OpenRouterProvider { .send() .await?; - handle_response_openai_compat(response).await + if is_google_model(&payload) { + handle_response_google_compat(response).await + } else { + handle_response_openai_compat(response).await + } } } diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 52881470..8b84bff7 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -1,4 +1,5 @@ use super::base::Usage; +use super::errors::GoogleErrorCode; use anyhow::Result; use base64::Engine; use regex::Regex; @@ -90,6 +91,97 @@ pub async fn handle_response_openai_compat(response: Response) -> Result bool { + if let Some(model) = payload.get("model").and_then(|m| m.as_str()) { + // Check if the model name contains "google" + return model.to_lowercase().contains("google"); + } + false +} + +/// Extracts `StatusCode` from response status or payload error code. +/// This function first checks the status code of the response. If the status is successful (2xx), +/// it then checks the payload for any error codes and maps them to appropriate `StatusCode`. +/// If the status is not successful (e.g., 4xx or 5xx), the original status code is returned. +fn get_google_final_status(status: StatusCode, payload: Option<&Value>) -> StatusCode { + // If the status is successful, check for an error in the payload + if status.is_success() { + if let Some(payload) = payload { + if let Some(error) = payload.get("error") { + if let Some(code) = error.get("code").and_then(|c| c.as_u64()) { + if let Some(google_error) = GoogleErrorCode::from_code(code) { + return google_error.to_status_code(); + } + } + } + } + } + status +} + +/// Handle response from Google Gemini API-compatible endpoints. +/// +/// Processes HTTP responses, handling specific statuses and parsing the payload +/// for error messages. Logs the response payload for debugging purposes. +/// +/// ### References +/// - Error Codes: https://ai.google.dev/gemini-api/docs/troubleshooting?lang=python +/// +/// ### Arguments +/// - `response`: The HTTP response to process. +/// +/// ### Returns +/// - `Ok(Value)`: Parsed JSON on success. +/// - `Err(ProviderError)`: Describes the failure reason. +pub async fn handle_response_google_compat(response: Response) -> Result { + let status = response.status(); + let payload: Option = response.json().await.ok(); + let final_status = get_google_final_status(status, payload.as_ref()); + + match final_status { + StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \ + Status: {}. Response: {:?}", final_status, payload ))) + } + StatusCode::BAD_REQUEST => { + let mut error_msg = "Unknown error".to_string(); + if let Some(payload) = &payload { + if let Some(error) = payload.get("error") { + error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string(); + let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status"); + if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") { + return Err(ProviderError::ContextLengthExceeded(error_msg.to_string())); + } + } + } + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload) + ); + Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", final_status, error_msg))) + } + StatusCode::TOO_MANY_REQUESTS => { + Err(ProviderError::RateLimitExceeded(format!("{:?}", payload))) + } + StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => { + Err(ProviderError::ServerError(format!("{:?}", payload))) + } + _ => { + tracing::debug!( + "{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload) + ); + Err(ProviderError::RequestFailed(format!("Request failed with status: {}", final_status))) + } + } +} + pub fn sanitize_function_name(name: &str) -> String { let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap(); re.replace_all(name, "_").to_string() @@ -253,8 +345,6 @@ pub fn emit_debug_trace( mod tests { use super::*; use serde_json::json; - use std::io::Write; - use tempfile::NamedTempFile; #[test] fn test_detect_image_path() { @@ -395,4 +485,70 @@ mod tests { let unescaped_value = unescape_json_values(&value); assert_eq!(unescaped_value, json!({"text": "Hello World"})); } + + #[test] + fn test_is_google_model() { + // Define the test cases as a vector of tuples + let test_cases = vec![ + // (input, expected_result) + (json!({ "model": "google_gemini" }), true), + (json!({ "model": "microsoft_bing" }), false), + (json!({ "model": "" }), false), + (json!({}), false), + (json!({ "model": "Google_XYZ" }), true), + (json!({ "model": "google_abc" }), true), + ]; + + // Iterate through each test case and assert the result + for (payload, expected_result) in test_cases { + assert_eq!(is_google_model(&payload), expected_result); + } + } + + #[test] + fn test_get_google_final_status_success() { + let status = StatusCode::OK; + let payload = json!({}); + let result = get_google_final_status(status, Some(&payload)); + assert_eq!(result, StatusCode::OK); + } + + #[test] + fn test_get_google_final_status_with_error_code() { + // Test error code mappings for different payload error codes + let test_cases = vec![ + // (error code, status, expected status code) + (200, None, StatusCode::OK), + (429, Some(StatusCode::OK), StatusCode::TOO_MANY_REQUESTS), + (400, Some(StatusCode::OK), StatusCode::BAD_REQUEST), + (401, Some(StatusCode::OK), StatusCode::UNAUTHORIZED), + (403, Some(StatusCode::OK), StatusCode::FORBIDDEN), + (404, Some(StatusCode::OK), StatusCode::NOT_FOUND), + (500, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR), + (503, Some(StatusCode::OK), StatusCode::SERVICE_UNAVAILABLE), + (999, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR), + (500, Some(StatusCode::BAD_REQUEST), StatusCode::BAD_REQUEST), + ( + 404, + Some(StatusCode::INTERNAL_SERVER_ERROR), + StatusCode::INTERNAL_SERVER_ERROR, + ), + ]; + + for (error_code, status, expected_status) in test_cases { + let payload = if let Some(_status) = status { + json!({ + "error": { + "code": error_code, + "message": "Error message" + } + }) + } else { + json!({}) + }; + + let result = get_google_final_status(status.unwrap_or(StatusCode::OK), Some(&payload)); + assert_eq!(result, expected_status); + } + } }