fix:extra error handling for gemini (#1268)

This commit is contained in:
Yingjie He
2025-02-18 12:57:29 -08:00
committed by GitHub
parent 7c5aaa1c75
commit d8ca7a3e1b
4 changed files with 211 additions and 44 deletions

View File

@@ -1,3 +1,4 @@
use reqwest::StatusCode;
use thiserror::Error;
#[derive(Error, Debug)]
@@ -35,3 +36,41 @@ impl From<reqwest::Error> for ProviderError {
ProviderError::ExecutionError(error.to_string())
}
}
#[derive(Debug)]
pub enum GoogleErrorCode {
BadRequest = 400,
Unauthorized = 401,
Forbidden = 403,
NotFound = 404,
TooManyRequests = 429,
InternalServerError = 500,
ServiceUnavailable = 503,
}
impl GoogleErrorCode {
pub fn to_status_code(&self) -> StatusCode {
match self {
Self::BadRequest => StatusCode::BAD_REQUEST,
Self::Unauthorized => StatusCode::UNAUTHORIZED,
Self::Forbidden => StatusCode::FORBIDDEN,
Self::NotFound => StatusCode::NOT_FOUND,
Self::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
Self::InternalServerError => StatusCode::INTERNAL_SERVER_ERROR,
Self::ServiceUnavailable => StatusCode::SERVICE_UNAVAILABLE,
}
}
pub fn from_code(code: u64) -> Option<Self> {
match code {
400 => Some(Self::BadRequest),
401 => Some(Self::Unauthorized),
403 => Some(Self::Forbidden),
404 => Some(Self::NotFound),
429 => Some(Self::TooManyRequests),
500 => Some(Self::InternalServerError),
503 => Some(Self::ServiceUnavailable),
_ => Some(Self::InternalServerError),
}
}
}

View File

@@ -3,11 +3,13 @@ use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use crate::providers::formats::google::{create_request, get_usage, response_to_message};
use crate::providers::utils::{emit_debug_trace, unescape_json_values};
use crate::providers::utils::{
emit_debug_trace, handle_response_google_compat, unescape_json_values,
};
use anyhow::Result;
use async_trait::async_trait;
use mcp_core::tool::Tool;
use reqwest::{Client, StatusCode};
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;
use url::Url;
@@ -84,44 +86,7 @@ impl GoogleProvider {
.send()
.await?;
let status = response.status();
let payload: Option<Value> = response.json().await.ok();
match status {
StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", status, payload )))
}
StatusCode::BAD_REQUEST => {
let mut error_msg = "Unknown error".to_string();
if let Some(payload) = &payload {
if let Some(error) = payload.get("error") {
error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status");
if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") {
return Err(ProviderError::ContextLengthExceeded(error_msg.to_string()));
}
}
}
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)))
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
}
}
handle_response_google_compat(response).await
}
}

View File

@@ -6,7 +6,10 @@ use std::time::Duration;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat};
use super::utils::{
emit_debug_trace, get_model, handle_response_google_compat, handle_response_openai_compat,
is_google_model,
};
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
@@ -74,7 +77,11 @@ impl OpenRouterProvider {
.send()
.await?;
handle_response_openai_compat(response).await
if is_google_model(&payload) {
handle_response_google_compat(response).await
} else {
handle_response_openai_compat(response).await
}
}
}

View File

@@ -1,4 +1,5 @@
use super::base::Usage;
use super::errors::GoogleErrorCode;
use anyhow::Result;
use base64::Engine;
use regex::Regex;
@@ -90,6 +91,97 @@ pub async fn handle_response_openai_compat(response: Response) -> Result<Value,
}
}
/// Check if the model is a Google model based on the "model" field in the payload.
///
/// ### Arguments
/// - `payload`: The JSON payload as a `serde_json::Value`.
///
/// ### Returns
/// - `bool`: Returns `true` if the model is a Google model, otherwise `false`.
pub fn is_google_model(payload: &Value) -> bool {
if let Some(model) = payload.get("model").and_then(|m| m.as_str()) {
// Check if the model name contains "google"
return model.to_lowercase().contains("google");
}
false
}
/// Extracts `StatusCode` from response status or payload error code.
/// This function first checks the status code of the response. If the status is successful (2xx),
/// it then checks the payload for any error codes and maps them to appropriate `StatusCode`.
/// If the status is not successful (e.g., 4xx or 5xx), the original status code is returned.
fn get_google_final_status(status: StatusCode, payload: Option<&Value>) -> StatusCode {
// If the status is successful, check for an error in the payload
if status.is_success() {
if let Some(payload) = payload {
if let Some(error) = payload.get("error") {
if let Some(code) = error.get("code").and_then(|c| c.as_u64()) {
if let Some(google_error) = GoogleErrorCode::from_code(code) {
return google_error.to_status_code();
}
}
}
}
}
status
}
/// Handle response from Google Gemini API-compatible endpoints.
///
/// Processes HTTP responses, handling specific statuses and parsing the payload
/// for error messages. Logs the response payload for debugging purposes.
///
/// ### References
/// - Error Codes: https://ai.google.dev/gemini-api/docs/troubleshooting?lang=python
///
/// ### Arguments
/// - `response`: The HTTP response to process.
///
/// ### Returns
/// - `Ok(Value)`: Parsed JSON on success.
/// - `Err(ProviderError)`: Describes the failure reason.
pub async fn handle_response_google_compat(response: Response) -> Result<Value, ProviderError> {
let status = response.status();
let payload: Option<Value> = response.json().await.ok();
let final_status = get_google_final_status(status, payload.as_ref());
match final_status {
StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", final_status, payload )))
}
StatusCode::BAD_REQUEST => {
let mut error_msg = "Unknown error".to_string();
if let Some(payload) = &payload {
if let Some(error) = payload.get("error") {
error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status");
if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") {
return Err(ProviderError::ContextLengthExceeded(error_msg.to_string()));
}
}
}
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", final_status, error_msg)))
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", final_status)))
}
}
}
pub fn sanitize_function_name(name: &str) -> String {
let re = Regex::new(r"[^a-zA-Z0-9_-]").unwrap();
re.replace_all(name, "_").to_string()
@@ -253,8 +345,6 @@ pub fn emit_debug_trace<T: serde::Serialize>(
mod tests {
use super::*;
use serde_json::json;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_detect_image_path() {
@@ -395,4 +485,70 @@ mod tests {
let unescaped_value = unescape_json_values(&value);
assert_eq!(unescaped_value, json!({"text": "Hello World"}));
}
#[test]
fn test_is_google_model() {
// Define the test cases as a vector of tuples
let test_cases = vec![
// (input, expected_result)
(json!({ "model": "google_gemini" }), true),
(json!({ "model": "microsoft_bing" }), false),
(json!({ "model": "" }), false),
(json!({}), false),
(json!({ "model": "Google_XYZ" }), true),
(json!({ "model": "google_abc" }), true),
];
// Iterate through each test case and assert the result
for (payload, expected_result) in test_cases {
assert_eq!(is_google_model(&payload), expected_result);
}
}
#[test]
fn test_get_google_final_status_success() {
let status = StatusCode::OK;
let payload = json!({});
let result = get_google_final_status(status, Some(&payload));
assert_eq!(result, StatusCode::OK);
}
#[test]
fn test_get_google_final_status_with_error_code() {
// Test error code mappings for different payload error codes
let test_cases = vec![
// (error code, status, expected status code)
(200, None, StatusCode::OK),
(429, Some(StatusCode::OK), StatusCode::TOO_MANY_REQUESTS),
(400, Some(StatusCode::OK), StatusCode::BAD_REQUEST),
(401, Some(StatusCode::OK), StatusCode::UNAUTHORIZED),
(403, Some(StatusCode::OK), StatusCode::FORBIDDEN),
(404, Some(StatusCode::OK), StatusCode::NOT_FOUND),
(500, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR),
(503, Some(StatusCode::OK), StatusCode::SERVICE_UNAVAILABLE),
(999, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR),
(500, Some(StatusCode::BAD_REQUEST), StatusCode::BAD_REQUEST),
(
404,
Some(StatusCode::INTERNAL_SERVER_ERROR),
StatusCode::INTERNAL_SERVER_ERROR,
),
];
for (error_code, status, expected_status) in test_cases {
let payload = if let Some(_status) = status {
json!({
"error": {
"code": error_code,
"message": "Error message"
}
})
} else {
json!({})
};
let result = get_google_final_status(status.unwrap_or(StatusCode::OK), Some(&payload));
assert_eq!(result, expected_status);
}
}
}