mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-09 16:44:24 +01:00
fix: improve configure process with error message (#919)
This commit is contained in:
@@ -4,7 +4,8 @@ use goose::agents::{extension::Envs, ExtensionConfig};
|
||||
use goose::config::{Config, ConfigError, ExtensionEntry, ExtensionManager};
|
||||
use goose::message::Message;
|
||||
use goose::providers::{create, providers};
|
||||
use serde_json::Value;
|
||||
use mcp_core::Tool;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
|
||||
@@ -270,37 +271,36 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
|
||||
let model_config = goose::model::ModelConfig::new(model.clone()).with_max_tokens(Some(10));
|
||||
let provider = create(provider_name, model_config)?;
|
||||
|
||||
let message = Message::user().with_text(
|
||||
"Please give a nice welcome message (one sentence) and let them know they are all set to use this agent"
|
||||
let messages =
|
||||
vec![Message::user().with_text("What is the weather like in San Francisco today?")];
|
||||
let sample_tool = Tool::new(
|
||||
"get_weather".to_string(),
|
||||
"Get current temperature for a given location.".to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let result = provider
|
||||
.complete(
|
||||
"You are an AI agent called Goose. You use tools of connected extensions to solve problems.",
|
||||
&[message],
|
||||
&[]
|
||||
&messages,
|
||||
&[sample_tool]
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok((message, _usage)) => {
|
||||
if let Some(content) = message.content.first() {
|
||||
if let Some(text) = content.as_text() {
|
||||
spin.stop(text);
|
||||
} else {
|
||||
spin.stop("No response text available");
|
||||
}
|
||||
} else {
|
||||
spin.stop("No response content available");
|
||||
}
|
||||
|
||||
Ok((_message, _usage)) => {
|
||||
cliclack::outro("Configuration saved successfully")?;
|
||||
Ok(true)
|
||||
}
|
||||
Err(e) => {
|
||||
println!("{:?}", e);
|
||||
spin.stop("We could not connect!");
|
||||
let _ = cliclack::outro("The provider configuration was invalid");
|
||||
spin.stop(style(e.to_string()).red());
|
||||
cliclack::outro(style("Failed to configure provider: init chat completion request with tool did not succeed.").on_red().white())?;
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,10 +80,11 @@ impl AnthropicProvider {
|
||||
Status: {}. Response: {:?}", status, payload)))
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
let mut error_msg = "Unknown error".to_string();
|
||||
if let Some(payload) = &payload {
|
||||
if let Some(error) = payload.get("error") {
|
||||
tracing::debug!("Bad Request Error: {error:?}");
|
||||
let error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error");
|
||||
error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
|
||||
if error_msg.to_lowercase().contains("too long") || error_msg.to_lowercase().contains("too many") {
|
||||
return Err(ProviderError::ContextLengthExceeded(error_msg.to_string()));
|
||||
}
|
||||
@@ -91,7 +92,7 @@ impl AnthropicProvider {
|
||||
tracing::debug!(
|
||||
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)))
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
|
||||
|
||||
@@ -177,10 +177,15 @@ impl DatabricksProvider {
|
||||
return Err(ProviderError::ContextLengthExceeded(payload_str));
|
||||
}
|
||||
|
||||
let mut error_msg = "Unknown error".to_string();
|
||||
if let Some(payload) = &payload {
|
||||
error_msg = payload.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)))
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
|
||||
|
||||
@@ -86,9 +86,10 @@ impl GoogleProvider {
|
||||
Status: {}. Response: {:?}", status, payload )))
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
let mut error_msg = "Unknown error".to_string();
|
||||
if let Some(payload) = &payload {
|
||||
if let Some(error) = payload.get("error") {
|
||||
let error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error");
|
||||
error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
|
||||
let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status");
|
||||
if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") {
|
||||
return Err(ProviderError::ContextLengthExceeded(error_msg.to_string()));
|
||||
@@ -98,7 +99,7 @@ impl GoogleProvider {
|
||||
tracing::debug!(
|
||||
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, error_msg)))
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
|
||||
|
||||
@@ -49,18 +49,18 @@ pub async fn handle_response_openai_compat(response: Response) -> Result<Value,
|
||||
Status: {}. Response: {:?}", status, payload)))
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
let mut message = "Unknown error".to_string();
|
||||
if let Some(payload) = &payload {
|
||||
if let Some(error) = payload.get("error") {
|
||||
tracing::debug!("Bad Request Error: {error:?}");
|
||||
if let Some(code) = error.get("code").and_then(|c| c.as_str()) {
|
||||
if code == "context_length_exceeded" || code == "string_above_max_length" {
|
||||
let message = error
|
||||
message = error
|
||||
.get("message")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("Unknown error")
|
||||
.to_string();
|
||||
|
||||
|
||||
if let Some(code) = error.get("code").and_then(|c| c.as_str()) {
|
||||
if code == "context_length_exceeded" || code == "string_above_max_length" {
|
||||
return Err(ProviderError::ContextLengthExceeded(message));
|
||||
}
|
||||
}
|
||||
@@ -68,7 +68,7 @@ pub async fn handle_response_openai_compat(response: Response) -> Result<Value,
|
||||
tracing::debug!(
|
||||
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}", status)))
|
||||
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, message)))
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
|
||||
|
||||
Reference in New Issue
Block a user