diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs
index b564a544..512bffbe 100644
--- a/crates/goose/src/providers/formats/openai.rs
+++ b/crates/goose/src/providers/formats/openai.rs
@@ -330,8 +330,28 @@ pub fn create_request(
     let is_o1 = model_config.model_name.starts_with("o1");
     let is_o3 = model_config.model_name.starts_with("o3");
 
+    // Only extract reasoning effort for O1/O3 models
+    let (model_name, reasoning_effort) = if is_o1 || is_o3 {
+        let parts: Vec<&str> = model_config.model_name.split('-').collect();
+        let last_part = parts.last().unwrap();
+
+        match *last_part {
+            "low" | "medium" | "high" => {
+                let base_name = parts[..parts.len() - 1].join("-");
+                (base_name, Some(last_part.to_string()))
+            }
+            _ => (
+                model_config.model_name.to_string(),
+                Some("medium".to_string()),
+            ),
+        }
+    } else {
+        // For non-O family models, use the model name as is and no reasoning effort
+        (model_config.model_name.to_string(), None)
+    };
+
     let system_message = json!({
-        "role": if is_o1 { "developer" } else { "system" },
+        "role": if is_o1 || is_o3 { "developer" } else { "system" },
         "content": system
     });
 
@@ -349,10 +369,17 @@ pub fn create_request(
     messages_array.extend(messages_spec);
 
     let mut payload = json!({
-        "model": model_config.model_name,
+        "model": model_name,
         "messages": messages_array
     });
 
+    if let Some(effort) = reasoning_effort {
+        payload
+            .as_object_mut()
+            .unwrap()
+            .insert("reasoning_effort".to_string(), json!(effort));
+    }
+
     if !tools_spec.is_empty() {
         payload
             .as_object_mut()
@@ -778,4 +805,96 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_create_request_gpt_4o() -> anyhow::Result<()> {
+        // Test that a non-O family model gets no reasoning effort
+        let model_config = ModelConfig {
+            model_name: "gpt-4o".to_string(),
+            tokenizer_name: "gpt-4o".to_string(),
+            context_limit: Some(4096),
+            temperature: None,
+            max_tokens: Some(1024),
+        };
+        let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
+        let obj = request.as_object().unwrap();
+        let expected = json!({
+            "model": "gpt-4o",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "system"
+                }
+            ],
+            "max_tokens": 1024
+        });
+
+        for (key, value) in expected.as_object().unwrap() {
+            assert_eq!(obj.get(key).unwrap(), value);
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_create_request_o1_default() -> anyhow::Result<()> {
+        // Test default medium reasoning effort for O1 model
+        let model_config = ModelConfig {
+            model_name: "o1".to_string(),
+            tokenizer_name: "o1".to_string(),
+            context_limit: Some(4096),
+            temperature: None,
+            max_tokens: Some(1024),
+        };
+        let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
+        let obj = request.as_object().unwrap();
+        let expected = json!({
+            "model": "o1",
+            "messages": [
+                {
+                    "role": "developer",
+                    "content": "system"
+                }
+            ],
+            "reasoning_effort": "medium",
+            "max_completion_tokens": 1024
+        });
+
+        for (key, value) in expected.as_object().unwrap() {
+            assert_eq!(obj.get(key).unwrap(), value);
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_create_request_o3_custom_reasoning_effort() -> anyhow::Result<()> {
+        // Test custom reasoning effort for O3 model
+        let model_config = ModelConfig {
+            model_name: "o3-mini-high".to_string(),
+            tokenizer_name: "o3-mini".to_string(),
+            context_limit: Some(4096),
+            temperature: None,
+            max_tokens: Some(1024),
+        };
+        let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
+        let obj = request.as_object().unwrap();
+        let expected = json!({
+            "model": "o3-mini",
+            "messages": [
+                {
+                    "role": "developer",
+                    "content": "system"
+                }
+            ],
+            "reasoning_effort": "high",
+            "max_completion_tokens": 1024
+        });
+
+        for (key, value) in expected.as_object().unwrap() {
+            assert_eq!(obj.get(key).unwrap(), value);
+        }
+
+        Ok(())
+    }
 }
diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs
index 4c8be6e9..4225797f 100644
--- a/crates/goose/tests/truncate_agent.rs
+++ b/crates/goose/tests/truncate_agent.rs
@@ -185,8 +185,8 @@ mod tests {
     async fn test_truncate_agent_with_openai() -> Result<()> {
         run_test_with_config(TestConfig {
             provider_type: ProviderType::OpenAi,
-            model: "gpt-4o-mini",
-            context_window: 128_000,
+            model: "o3-mini-low",
+            context_window: 200_000,
         })
         .await
     }
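
Reviewer note: a minimal standalone sketch of the suffix-parsing rule the first hunk introduces, showing how a model name maps to the `(model, reasoning_effort)` pair sent in the payload. `parse_reasoning_effort` is a hypothetical helper name used only for illustration; the patch inlines this logic directly in `create_request`.

```rust
/// Hypothetical helper mirroring the logic inlined in `create_request`.
fn parse_reasoning_effort(model_name: &str) -> (String, Option<String>) {
    // Only O1/O3 family models carry a reasoning effort.
    if !(model_name.starts_with("o1") || model_name.starts_with("o3")) {
        return (model_name.to_string(), None);
    }
    let parts: Vec<&str> = model_name.split('-').collect();
    let last = *parts.last().unwrap();
    match last {
        // A recognized suffix is stripped from the model name and
        // forwarded as the `reasoning_effort` payload field.
        "low" | "medium" | "high" => {
            (parts[..parts.len() - 1].join("-"), Some(last.to_string()))
        }
        // O family models without a suffix default to medium effort.
        _ => (model_name.to_string(), Some("medium".to_string())),
    }
}

fn main() {
    assert_eq!(
        parse_reasoning_effort("o3-mini-high"),
        ("o3-mini".to_string(), Some("high".to_string()))
    );
    assert_eq!(
        parse_reasoning_effort("o1"),
        ("o1".to_string(), Some("medium".to_string()))
    );
    assert_eq!(parse_reasoning_effort("gpt-4o"), ("gpt-4o".to_string(), None));
}
```

This is the behavior the new tests pin down: an O-family name with no recognized suffix still gets `reasoning_effort: "medium"` by default (the `o1` case), while non-O models such as `gpt-4o` send no `reasoning_effort` field at all.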