mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-05 22:54:33 +01:00
fix:extra error handling for gemini (#1268)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
use reqwest::StatusCode;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
@@ -35,3 +36,41 @@ impl From<reqwest::Error> for ProviderError {
|
||||
ProviderError::ExecutionError(error.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum GoogleErrorCode {
|
||||
BadRequest = 400,
|
||||
Unauthorized = 401,
|
||||
Forbidden = 403,
|
||||
NotFound = 404,
|
||||
TooManyRequests = 429,
|
||||
InternalServerError = 500,
|
||||
ServiceUnavailable = 503,
|
||||
}
|
||||
|
||||
impl GoogleErrorCode {
|
||||
pub fn to_status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Self::BadRequest => StatusCode::BAD_REQUEST,
|
||||
Self::Unauthorized => StatusCode::UNAUTHORIZED,
|
||||
Self::Forbidden => StatusCode::FORBIDDEN,
|
||||
Self::NotFound => StatusCode::NOT_FOUND,
|
||||
Self::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
|
||||
Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::ServiceUnavailable => StatusCode::SERVICE_UNAVAILABLE,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_code(code: u64) -> Option<Self> {
|
||||
match code {
|
||||
400 => Some(Self::BadRequest),
|
||||
401 => Some(Self::Unauthorized),
|
||||
403 => Some(Self::Forbidden),
|
||||
404 => Some(Self::NotFound),
|
||||
429 => Some(Self::TooManyRequests),
|
||||
500 => Some(Self::InternalServerError),
|
||||
503 => Some(Self::ServiceUnavailable),
|
||||
_ => Some(Self::InternalServerError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,13 @@ use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
|
||||
use crate::providers::formats::google::{create_request, get_usage, response_to_message};
|
||||
use crate::providers::utils::{emit_debug_trace, unescape_json_values};
|
||||
use crate::providers::utils::{
|
||||
emit_debug_trace, handle_response_google_compat, unescape_json_values,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use mcp_core::tool::Tool;
|
||||
use reqwest::{Client, StatusCode};
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
use url::Url;
|
||||
@@ -84,44 +86,7 @@ impl GoogleProvider {
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
let payload: Option<Value> = response.json().await.ok();
|
||||
|
||||
match status {
|
||||
StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
|
||||
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
|
||||
Status: {}. Response: {:?}", status, payload )))
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
let mut error_msg = "Unknown error".to_string();
|
||||
if let Some(payload) = &payload {
|
||||
if let Some(error) = payload.get("error") {
|
||||
error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
|
||||
let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status");
|
||||
if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") {
|
||||
return Err(ProviderError::ContextLengthExceeded(error_msg.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::debug!(
|
||||
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)))
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
|
||||
Err(ProviderError::ServerError(format!("{:?}", payload)))
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!(
|
||||
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
|
||||
}
|
||||
}
|
||||
handle_response_google_compat(response).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,10 @@ use std::time::Duration;
|
||||
|
||||
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
use super::errors::ProviderError;
|
||||
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat};
|
||||
use super::utils::{
|
||||
emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat,
|
||||
is_google_model,
|
||||
};
|
||||
use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
|
||||
@@ -74,7 +77,11 @@ impl OpenRouterProvider {
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
handle_response_openai_compat(response).await
|
||||
if is_google_model(&payload) {
|
||||
handle_response_google_compat(response).await
|
||||
} else {
|
||||
handle_response_openai_compat(response).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use super::base::Usage;
|
||||
use super::errors::GoogleErrorCode;
|
||||
use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use regex::Regex;
|
||||
@@ -90,6 +91,97 @@ pub async fn handle_response_openai_compat(response: Response) -> Result<Value,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the model is a Google model based on the "model" field in the payload.
|
||||
///
|
||||
/// ### Arguments
|
||||
/// - `payload`: The JSON payload as a `serde_json::Value`.
|
||||
///
|
||||
/// ### Returns
|
||||
/// - `bool`: Returns `true` if the model is a Google model, otherwise `false`.
|
||||
pub fn is_google_model(payload: &Value) -> bool {
|
||||
if let Some(model) = payload.get("model").and_then(|m| m.as_str()) {
|
||||
// Check if the model name contains "google"
|
||||
return model.to_lowercase().contains("google");
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Extracts `StatusCode` from response status or payload error code.
|
||||
/// This function first checks the status code of the response. If the status is successful (2xx),
|
||||
/// it then checks the payload for any error codes and maps them to appropriate `StatusCode`.
|
||||
/// If the status is not successful (e.g., 4xx or 5xx), the original status code is returned.
|
||||
fn get_google_final_status(status: StatusCode, payload: Option<&Value>) -> StatusCode {
|
||||
// If the status is successful, check for an error in the payload
|
||||
if status.is_success() {
|
||||
if let Some(payload) = payload {
|
||||
if let Some(error) = payload.get("error") {
|
||||
if let Some(code) = error.get("code").and_then(|c| c.as_u64()) {
|
||||
if let Some(google_error) = GoogleErrorCode::from_code(code) {
|
||||
return google_error.to_status_code();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
status
|
||||
}
|
||||
|
||||
/// Handle response from Google Gemini API-compatible endpoints.
|
||||
///
|
||||
/// Processes HTTP responses, handling specific statuses and parsing the payload
|
||||
/// for error messages. Logs the response payload for debugging purposes.
|
||||
///
|
||||
/// ### References
|
||||
/// - Error Codes: https://ai.google.dev/gemini-api/docs/troubleshooting?lang=python
|
||||
///
|
||||
/// ### Arguments
|
||||
/// - `response`: The HTTP response to process.
|
||||
///
|
||||
/// ### Returns
|
||||
/// - `Ok(Value)`: Parsed JSON on success.
|
||||
/// - `Err(ProviderError)`: Describes the failure reason.
|
||||
pub async fn handle_response_google_compat(response: Response) -> Result<Value, ProviderError> {
|
||||
let status = response.status();
|
||||
let payload: Option<Value> = response.json().await.ok();
|
||||
let final_status = get_google_final_status(status, payload.as_ref());
|
||||
|
||||
match final_status {
|
||||
StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
|
||||
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
|
||||
Status: {}. Response: {:?}", final_status, payload )))
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
let mut error_msg = "Unknown error".to_string();
|
||||
if let Some(payload) = &payload {
|
||||
if let Some(error) = payload.get("error") {
|
||||
error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
|
||||
let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status");
|
||||
if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") {
|
||||
return Err(ProviderError::ContextLengthExceeded(error_msg.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::debug!(
|
||||
"{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", final_status, error_msg)))
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
|
||||
Err(ProviderError::ServerError(format!("{:?}", payload)))
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!(
|
||||
"{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", final_status)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sanitize_function_name(name: &str) -> String {
|
||||
let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap();
|
||||
re.replace_all(name, "_").to_string()
|
||||
@@ -253,8 +345,6 @@ pub fn emit_debug_trace<T: serde::Serialize>(
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[test]
|
||||
fn test_detect_image_path() {
|
||||
@@ -395,4 +485,70 @@ mod tests {
|
||||
let unescaped_value = unescape_json_values(&value);
|
||||
assert_eq!(unescaped_value, json!({"text": "Hello World"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_google_model() {
|
||||
// Define the test cases as a vector of tuples
|
||||
let test_cases = vec![
|
||||
// (input, expected_result)
|
||||
(json!({ "model": "google_gemini" }), true),
|
||||
(json!({ "model": "microsoft_bing" }), false),
|
||||
(json!({ "model": "" }), false),
|
||||
(json!({}), false),
|
||||
(json!({ "model": "Google_XYZ" }), true),
|
||||
(json!({ "model": "google_abc" }), true),
|
||||
];
|
||||
|
||||
// Iterate through each test case and assert the result
|
||||
for (payload, expected_result) in test_cases {
|
||||
assert_eq!(is_google_model(&payload), expected_result);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_google_final_status_success() {
|
||||
let status = StatusCode::OK;
|
||||
let payload = json!({});
|
||||
let result = get_google_final_status(status, Some(&payload));
|
||||
assert_eq!(result, StatusCode::OK);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_google_final_status_with_error_code() {
|
||||
// Test error code mappings for different payload error codes
|
||||
let test_cases = vec![
|
||||
// (error code, status, expected status code)
|
||||
(200, None, StatusCode::OK),
|
||||
(429, Some(StatusCode::OK), StatusCode::TOO_MANY_REQUESTS),
|
||||
(400, Some(StatusCode::OK), StatusCode::BAD_REQUEST),
|
||||
(401, Some(StatusCode::OK), StatusCode::UNAUTHORIZED),
|
||||
(403, Some(StatusCode::OK), StatusCode::FORBIDDEN),
|
||||
(404, Some(StatusCode::OK), StatusCode::NOT_FOUND),
|
||||
(500, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR),
|
||||
(503, Some(StatusCode::OK), StatusCode::SERVICE_UNAVAILABLE),
|
||||
(999, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR),
|
||||
(500, Some(StatusCode::BAD_REQUEST), StatusCode::BAD_REQUEST),
|
||||
(
|
||||
404,
|
||||
Some(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
];
|
||||
|
||||
for (error_code, status, expected_status) in test_cases {
|
||||
let payload = if let Some(_status) = status {
|
||||
json!({
|
||||
"error": {
|
||||
"code": error_code,
|
||||
"message": "Error message"
|
||||
}
|
||||
})
|
||||
} else {
|
||||
json!({})
|
||||
};
|
||||
|
||||
let result = get_google_final_status(status.unwrap_or(StatusCode::OK), Some(&payload));
|
||||
assert_eq!(result, expected_status);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user