mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-06 07:04:37 +01:00
chore: remove needless clone() in goose/providers (#2528)
Signed-off-by: Mike Seddon <seddonm1@gmail.com>
This commit is contained in:
@@ -73,7 +73,7 @@ impl AnthropicProvider {
|
||||
})
|
||||
}
|
||||
|
||||
async fn post(&self, headers: HeaderMap, payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, headers: HeaderMap, payload: &Value) -> Result<Value, ProviderError> {
|
||||
let base_url = url::Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let url = base_url.join("v1/messages").map_err(|e| {
|
||||
@@ -84,7 +84,7 @@ impl AnthropicProvider {
|
||||
.client
|
||||
.post(url)
|
||||
.headers(headers)
|
||||
.json(&payload)
|
||||
.json(payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
@@ -198,10 +198,10 @@ impl Provider for AnthropicProvider {
|
||||
}
|
||||
|
||||
// Make request
|
||||
let response = self.post(headers, payload.clone()).await?;
|
||||
let response = self.post(headers, &payload).await?;
|
||||
|
||||
// Parse response
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = get_usage(&response)?;
|
||||
tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
|
||||
usage.input_tokens, usage.output_tokens, usage.total_tokens);
|
||||
|
||||
@@ -87,7 +87,7 @@ impl AzureProvider {
|
||||
})
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
|
||||
let mut base_url = url::Url::parse(&self.endpoint)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
|
||||
@@ -143,7 +143,7 @@ impl AzureProvider {
|
||||
}
|
||||
}
|
||||
|
||||
let response_result = request_builder.json(&payload).send().await;
|
||||
let response_result = request_builder.json(payload).send().await;
|
||||
|
||||
match response_result {
|
||||
Ok(response) => match handle_response_openai_compat(response).await {
|
||||
@@ -249,9 +249,9 @@ impl Provider for AzureProvider {
|
||||
tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&payload).await?;
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
|
||||
@@ -273,7 +273,7 @@ impl DatabricksProvider {
|
||||
}
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
|
||||
// Check if this is an embedding request by looking at the payload structure
|
||||
let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none();
|
||||
let path = if is_embedding {
|
||||
@@ -284,7 +284,7 @@ impl DatabricksProvider {
|
||||
format!("serving-endpoints/{}/invocations", self.model.model_name)
|
||||
};
|
||||
|
||||
match self.post_with_retry(path.as_str(), &payload).await {
|
||||
match self.post_with_retry(path.as_str(), payload).await {
|
||||
Ok(res) => res.json().await.map_err(|_| {
|
||||
ProviderError::RequestFailed("Response body is not valid JSON".to_string())
|
||||
}),
|
||||
@@ -451,10 +451,10 @@ impl Provider for DatabricksProvider {
|
||||
.expect("payload should have model key")
|
||||
.remove("model");
|
||||
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&payload).await?;
|
||||
|
||||
// Parse response
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
@@ -619,7 +619,7 @@ impl EmbeddingCapable for DatabricksProvider {
|
||||
"input": texts,
|
||||
});
|
||||
|
||||
let response = self.post(request).await?;
|
||||
let response = self.post(&request).await?;
|
||||
|
||||
let embeddings = response["data"]
|
||||
.as_array()
|
||||
|
||||
@@ -207,7 +207,7 @@ pub fn format_system(system: &str) -> Value {
|
||||
}
|
||||
|
||||
/// Convert Anthropic's API response to internal Message format
|
||||
pub fn response_to_message(response: Value) -> Result<Message> {
|
||||
pub fn response_to_message(response: &Value) -> Result<Message> {
|
||||
let content_blocks = response
|
||||
.get(CONTENT_FIELD)
|
||||
.and_then(|c| c.as_array())
|
||||
@@ -699,7 +699,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = get_usage(&response)?;
|
||||
|
||||
if let MessageContent::Text(text) = &message.content[0] {
|
||||
@@ -740,7 +740,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = get_usage(&response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
|
||||
@@ -790,7 +790,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = get_usage(&response)?;
|
||||
|
||||
assert_eq!(message.content.len(), 3);
|
||||
|
||||
@@ -268,8 +268,8 @@ pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
|
||||
}
|
||||
|
||||
/// Convert Databricks' API response to internal Message format
|
||||
pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
|
||||
let original = response["choices"][0]["message"].clone();
|
||||
pub fn response_to_message(response: &Value) -> anyhow::Result<Message> {
|
||||
let original = &response["choices"][0]["message"];
|
||||
let mut content = Vec::new();
|
||||
|
||||
// Handle array-based content
|
||||
@@ -737,7 +737,7 @@ mod tests {
|
||||
|
||||
// Get the ID from the tool request to use in the response
|
||||
let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] {
|
||||
request.id.clone()
|
||||
&request.id
|
||||
} else {
|
||||
panic!("should be tool request");
|
||||
};
|
||||
@@ -770,7 +770,7 @@ mod tests {
|
||||
|
||||
// Get the ID from the tool request to use in the response
|
||||
let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] {
|
||||
request.id.clone()
|
||||
&request.id
|
||||
} else {
|
||||
panic!("should be tool request");
|
||||
};
|
||||
@@ -891,7 +891,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
assert_eq!(message.content.len(), 1);
|
||||
if let MessageContent::Text(text) = &message.content[0] {
|
||||
assert_eq!(text.text, "Hello from John Cena!");
|
||||
@@ -906,7 +906,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_response_to_message_valid_toolrequest() -> anyhow::Result<()> {
|
||||
let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
assert_eq!(message.content.len(), 1);
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
@@ -926,7 +926,7 @@ mod tests {
|
||||
response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] =
|
||||
json!("invalid fn");
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
match &request.tool_call {
|
||||
@@ -948,7 +948,7 @@ mod tests {
|
||||
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
|
||||
json!("invalid json {");
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
match &request.tool_call {
|
||||
@@ -970,7 +970,7 @@ mod tests {
|
||||
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
|
||||
serde_json::Value::String("".to_string());
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
let tool_call = request.tool_call.as_ref().unwrap();
|
||||
@@ -1107,7 +1107,7 @@ mod tests {
|
||||
}]
|
||||
});
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
assert_eq!(message.content.len(), 2);
|
||||
|
||||
if let MessageContent::Thinking(thinking) = &message.content[0] {
|
||||
@@ -1154,7 +1154,7 @@ mod tests {
|
||||
}]
|
||||
});
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
assert_eq!(message.content.len(), 2);
|
||||
|
||||
if let MessageContent::RedactedThinking(redacted) = &message.content[0] {
|
||||
|
||||
@@ -332,7 +332,7 @@ pub fn create_request(
|
||||
/// * `Result<Message>` - Converted message
|
||||
pub fn response_to_message(response: Value, request_context: RequestContext) -> Result<Message> {
|
||||
match request_context.provider() {
|
||||
ModelProvider::Anthropic => anthropic::response_to_message(response),
|
||||
ModelProvider::Anthropic => anthropic::response_to_message(&response),
|
||||
ModelProvider::Google => google::response_to_message(response),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,8 +268,8 @@ pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
|
||||
}
|
||||
|
||||
/// Convert OpenAI's API response to internal Message format
|
||||
pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
|
||||
let original = response["choices"][0]["message"].clone();
|
||||
pub fn response_to_message(response: &Value) -> anyhow::Result<Message> {
|
||||
let original = &response["choices"][0]["message"];
|
||||
let mut content = Vec::new();
|
||||
|
||||
if let Some(text) = original.get("content") {
|
||||
@@ -910,7 +910,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
assert_eq!(message.content.len(), 1);
|
||||
if let MessageContent::Text(text) = &message.content[0] {
|
||||
assert_eq!(text.text, "Hello from John Cena!");
|
||||
@@ -925,7 +925,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_response_to_message_valid_toolrequest() -> anyhow::Result<()> {
|
||||
let response: Value = serde_json::from_str(OPENAI_TOOL_USE_RESPONSE)?;
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
assert_eq!(message.content.len(), 1);
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
@@ -945,7 +945,7 @@ mod tests {
|
||||
response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] =
|
||||
json!("invalid fn");
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
match &request.tool_call {
|
||||
@@ -967,7 +967,7 @@ mod tests {
|
||||
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
|
||||
json!("invalid json {");
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
match &request.tool_call {
|
||||
@@ -989,7 +989,7 @@ mod tests {
|
||||
response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"] =
|
||||
serde_json::Value::String("".to_string());
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
let tool_call = request.tool_call.as_ref().unwrap();
|
||||
|
||||
@@ -198,7 +198,7 @@ pub fn parse_streaming_response(sse_data: &str) -> Result<Message> {
|
||||
}
|
||||
|
||||
/// Convert Snowflake's API response to internal Message format
|
||||
pub fn response_to_message(response: Value) -> Result<Message> {
|
||||
pub fn response_to_message(response: &Value) -> Result<Message> {
|
||||
let mut message = Message::assistant();
|
||||
|
||||
let content_list = response.get("content_list").and_then(|cl| cl.as_array());
|
||||
@@ -380,7 +380,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = get_usage(&response)?;
|
||||
|
||||
if let MessageContent::Text(text) = &message.content[0] {
|
||||
@@ -417,7 +417,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = get_usage(&response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
|
||||
@@ -625,7 +625,7 @@ data: {"id":"a9537c2c-2017-4906-9817-2456168d89fa","model":"claude-3-5-sonnet","
|
||||
}
|
||||
});
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
// Should have both text and tool request content
|
||||
assert_eq!(message.content.len(), 2);
|
||||
|
||||
@@ -281,14 +281,14 @@ impl GcpVertexAIProvider {
|
||||
) -> Result<Url, GcpVertexAIError> {
|
||||
// Create host URL for the specified location
|
||||
let host_url = if self.location == location {
|
||||
self.host.clone()
|
||||
&self.host
|
||||
} else {
|
||||
// Only allocate a new string if location differs
|
||||
self.host.replace(&self.location, location)
|
||||
&self.host.replace(&self.location, location)
|
||||
};
|
||||
|
||||
let base_url =
|
||||
Url::parse(&host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
|
||||
Url::parse(host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
|
||||
|
||||
// Determine endpoint based on provider type
|
||||
let endpoint = match provider {
|
||||
@@ -470,10 +470,14 @@ impl GcpVertexAIProvider {
|
||||
/// # Arguments
|
||||
/// * `payload` - The request payload to send
|
||||
/// * `context` - Request context containing model information
|
||||
async fn post(&self, payload: Value, context: &RequestContext) -> Result<Value, ProviderError> {
|
||||
async fn post(
|
||||
&self,
|
||||
payload: &Value,
|
||||
context: &RequestContext,
|
||||
) -> Result<Value, ProviderError> {
|
||||
// Try with user-specified location first
|
||||
let result = self
|
||||
.post_with_location(&payload, context, &self.location)
|
||||
.post_with_location(payload, context, &self.location)
|
||||
.await;
|
||||
|
||||
// If location is already the known location for the model or request succeeded, return result
|
||||
@@ -492,7 +496,7 @@ impl GcpVertexAIProvider {
|
||||
"Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
|
||||
);
|
||||
|
||||
self.post_with_location(&payload, context, &known_location)
|
||||
self.post_with_location(payload, context, &known_location)
|
||||
.await
|
||||
}
|
||||
// For any other error, return the original result
|
||||
@@ -609,7 +613,7 @@ impl Provider for GcpVertexAIProvider {
|
||||
let (request, context) = create_request(&self.model, system, messages, tools)?;
|
||||
|
||||
// Send request and process response
|
||||
let response = self.post(request.clone(), &context).await?;
|
||||
let response = self.post(&request, &context).await?;
|
||||
let usage = get_usage(&response, &context)?;
|
||||
|
||||
emit_debug_trace(&self.model, &request, &response, &usage);
|
||||
|
||||
@@ -137,7 +137,7 @@ impl GithubCopilotProvider {
|
||||
})
|
||||
}
|
||||
|
||||
async fn post(&self, mut payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &mut Value) -> Result<Value, ProviderError> {
|
||||
use crate::providers::utils_universal_openai_stream::{OAIStreamChunk, OAIStreamCollector};
|
||||
use futures::StreamExt;
|
||||
// Detect gpt-4.1 and stream
|
||||
@@ -159,7 +159,7 @@ impl GithubCopilotProvider {
|
||||
.post(url)
|
||||
.headers(self.get_github_headers())
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.json(&payload)
|
||||
.json(payload)
|
||||
.send()
|
||||
.await?;
|
||||
if stream_only_model {
|
||||
@@ -408,13 +408,14 @@ impl Provider for GithubCopilotProvider {
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
|
||||
let mut payload =
|
||||
create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
|
||||
|
||||
// Make request
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&mut payload).await?;
|
||||
|
||||
// Parse response
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
|
||||
@@ -86,7 +86,7 @@ impl GoogleProvider {
|
||||
})
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
|
||||
let base_url = Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
|
||||
@@ -178,7 +178,7 @@ impl Provider for GoogleProvider {
|
||||
let payload = create_request(&self.model, system, messages, tools)?;
|
||||
|
||||
// Make request
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&payload).await?;
|
||||
|
||||
// Parse response
|
||||
let message = response_to_message(unescape_json_values(&response))?;
|
||||
|
||||
@@ -54,7 +54,7 @@ impl GroqProvider {
|
||||
})
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> anyhow::Result<Value, ProviderError> {
|
||||
let base_url = Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let url = base_url.join("openai/v1/chat/completions").map_err(|e| {
|
||||
@@ -65,7 +65,7 @@ impl GroqProvider {
|
||||
.client
|
||||
.post(url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&payload)
|
||||
.json(payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
@@ -136,9 +136,9 @@ impl Provider for GroqProvider {
|
||||
&super::utils::ImageFormat::OpenAi,
|
||||
)?;
|
||||
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&payload).await?;
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
|
||||
@@ -128,7 +128,7 @@ impl LiteLLMProvider {
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
|
||||
let base_url = Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let url = base_url.join(&self.base_path).map_err(|e| {
|
||||
@@ -142,7 +142,7 @@ impl LiteLLMProvider {
|
||||
|
||||
let request = self.add_headers(request);
|
||||
|
||||
let response = request.json(&payload).send().await?;
|
||||
let response = request.json(payload).send().await?;
|
||||
|
||||
handle_response_openai_compat(response).await
|
||||
}
|
||||
@@ -196,9 +196,9 @@ impl Provider for LiteLLMProvider {
|
||||
payload = update_request_for_cache_control(&payload);
|
||||
}
|
||||
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&payload).await?;
|
||||
|
||||
let message = super::formats::openai::response_to_message(response.clone())?;
|
||||
let message = super::formats::openai::response_to_message(&response)?;
|
||||
let usage = super::formats::openai::get_usage(&response);
|
||||
let model = get_model(&response);
|
||||
emit_debug_trace(&self.model, &payload, &response, &usage);
|
||||
|
||||
@@ -58,12 +58,12 @@ impl OllamaProvider {
|
||||
fn get_base_url(&self) -> Result<Url, ProviderError> {
|
||||
// OLLAMA_HOST is sometimes just the 'host' or 'host:port' without a scheme
|
||||
let base = if self.host.starts_with("http://") || self.host.starts_with("https://") {
|
||||
self.host.clone()
|
||||
&self.host
|
||||
} else {
|
||||
format!("http://{}", self.host)
|
||||
&format!("http://{}", self.host)
|
||||
};
|
||||
|
||||
let mut base_url = Url::parse(&base)
|
||||
let mut base_url = Url::parse(base)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
|
||||
// Set the default port if missing
|
||||
@@ -82,7 +82,7 @@ impl OllamaProvider {
|
||||
Ok(base_url)
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
|
||||
// TODO: remove this later when the UI handles provider config refresh
|
||||
let base_url = self.get_base_url()?;
|
||||
|
||||
@@ -90,7 +90,7 @@ impl OllamaProvider {
|
||||
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
|
||||
})?;
|
||||
|
||||
let response = self.client.post(url).json(&payload).send().await?;
|
||||
let response = self.client.post(url).json(payload).send().await?;
|
||||
|
||||
handle_response_openai_compat(response).await
|
||||
}
|
||||
@@ -143,8 +143,8 @@ impl Provider for OllamaProvider {
|
||||
filtered_tools,
|
||||
&super::utils::ImageFormat::OpenAi,
|
||||
)?;
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let message = response_to_message(response.clone())?;
|
||||
let response = self.post(&payload).await?;
|
||||
let message = response_to_message(&response)?;
|
||||
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
|
||||
@@ -113,7 +113,7 @@ impl OpenAiProvider {
|
||||
request
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Response, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> Result<Response, ProviderError> {
|
||||
let base_url = url::Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let url = base_url.join(&self.base_path).map_err(|e| {
|
||||
@@ -178,10 +178,10 @@ impl Provider for OpenAiProvider {
|
||||
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
|
||||
|
||||
// Make request
|
||||
let response = handle_response_openai_compat(self.post(payload.clone()).await?).await?;
|
||||
let response = handle_response_openai_compat(self.post(&payload).await?).await?;
|
||||
|
||||
// Parse response
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
@@ -258,7 +258,7 @@ impl Provider for OpenAiProvider {
|
||||
"include_usage": true,
|
||||
});
|
||||
|
||||
let response = handle_status_openai_compat(self.post(payload.clone()).await?).await?;
|
||||
let response = handle_status_openai_compat(self.post(&payload).await?).await?;
|
||||
|
||||
let stream = response.bytes_stream().map_err(io::Error::other);
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ impl OpenRouterProvider {
|
||||
})
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
|
||||
let base_url = Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let url = base_url.join("api/v1/chat/completions").map_err(|e| {
|
||||
@@ -79,12 +79,12 @@ impl OpenRouterProvider {
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.header("HTTP-Referer", "https://block.github.io/goose")
|
||||
.header("X-Title", "Goose")
|
||||
.json(&payload)
|
||||
.json(payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
// Handle Google-compatible model responses differently
|
||||
if is_google_model(&payload) {
|
||||
if is_google_model(payload) {
|
||||
return handle_response_google_compat(response).await;
|
||||
}
|
||||
|
||||
@@ -259,10 +259,10 @@ impl Provider for OpenRouterProvider {
|
||||
let payload = create_request_based_on_model(self, system, messages, tools)?;
|
||||
|
||||
// Make request
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&payload).await?;
|
||||
|
||||
// Parse response
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
|
||||
@@ -108,7 +108,7 @@ impl SnowflakeProvider {
|
||||
}
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
|
||||
let base_url_str =
|
||||
if !self.host.starts_with("https://") && !self.host.starts_with("http://") {
|
||||
format!("https://{}", self.host)
|
||||
@@ -318,7 +318,7 @@ impl SnowflakeProvider {
|
||||
.unwrap_or_else(|| "Invalid credentials".to_string());
|
||||
|
||||
Err(ProviderError::Authentication(format!(
|
||||
"Authentication failed. Please check your SNOWFLAKE_TOKEN and SNOWFLAKE_HOST configuration. Error: {}",
|
||||
"Authentication failed. Please check your SNOWFLAKE_TOKEN and SNOWFLAKE_HOST configuration. Error: {}",
|
||||
error_msg
|
||||
)))
|
||||
}
|
||||
@@ -426,10 +426,10 @@ impl Provider for SnowflakeProvider {
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
let payload = create_request(&self.model, system, messages, tools)?;
|
||||
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&payload).await?;
|
||||
|
||||
// Parse response
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = get_usage(&response)?;
|
||||
let model = get_model(&response);
|
||||
super::utils::emit_debug_trace(&self.model, &payload, &response, &usage);
|
||||
|
||||
@@ -89,12 +89,12 @@ impl OllamaInterpreter {
|
||||
|
||||
// Format the URL correctly with http:// prefix if needed
|
||||
let base = if host.starts_with("http://") || host.starts_with("https://") {
|
||||
host.clone()
|
||||
&host
|
||||
} else {
|
||||
format!("http://{}", host)
|
||||
&format!("http://{}", host)
|
||||
};
|
||||
|
||||
let mut base_url = url::Url::parse(&base)
|
||||
let mut base_url = url::Url::parse(base)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
|
||||
// Set the default port if missing
|
||||
|
||||
@@ -72,7 +72,7 @@ impl XaiProvider {
|
||||
})
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> anyhow::Result<Value, ProviderError> {
|
||||
async fn post(&self, payload: &Value) -> anyhow::Result<Value, ProviderError> {
|
||||
// Ensure the host ends with a slash for proper URL joining
|
||||
let host = if self.host.ends_with('/') {
|
||||
self.host.clone()
|
||||
@@ -163,9 +163,9 @@ impl Provider for XaiProvider {
|
||||
&super::utils::ImageFormat::OpenAi,
|
||||
)?;
|
||||
|
||||
let response = self.post(payload.clone()).await?;
|
||||
let response = self.post(&payload).await?;
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let message = response_to_message(&response)?;
|
||||
let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
|
||||
tracing::debug!("Failed to get usage data");
|
||||
Usage::default()
|
||||
|
||||
Reference in New Issue
Block a user