mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
fix(ui): enable selection of zero-config providers in desktop GUI (#3378)
Signed-off-by: Kyle Santiago <kyle@privkey.io>
This commit is contained in:
@@ -109,6 +109,13 @@ pub fn inspect_keys(
|
|||||||
pub fn check_provider_configured(metadata: &ProviderMetadata) -> bool {
|
pub fn check_provider_configured(metadata: &ProviderMetadata) -> bool {
|
||||||
let config = Config::global();
|
let config = Config::global();
|
||||||
|
|
||||||
|
// Special case: Zero-config providers (no config keys)
|
||||||
|
if metadata.config_keys.is_empty() {
|
||||||
|
// Check if the provider has been explicitly configured via the UI
|
||||||
|
let configured_marker = format!("{}_configured", metadata.name);
|
||||||
|
return config.get_param::<bool>(&configured_marker).is_ok();
|
||||||
|
}
|
||||||
|
|
||||||
// Get all required keys
|
// Get all required keys
|
||||||
let required_keys: Vec<&ConfigKey> = metadata
|
let required_keys: Vec<&ConfigKey> = metadata
|
||||||
.config_keys
|
.config_keys
|
||||||
@@ -128,6 +135,21 @@ pub fn check_provider_configured(metadata: &ProviderMetadata) -> bool {
|
|||||||
return is_set_in_env || is_set_in_config;
|
return is_set_in_env || is_set_in_config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Special case: If a provider has only optional keys with defaults,
|
||||||
|
// check if a configuration marker exists
|
||||||
|
if required_keys.is_empty() && !metadata.config_keys.is_empty() {
|
||||||
|
let all_optional_with_defaults = metadata
|
||||||
|
.config_keys
|
||||||
|
.iter()
|
||||||
|
.all(|key| !key.required && key.default.is_some());
|
||||||
|
|
||||||
|
if all_optional_with_defaults {
|
||||||
|
// Check if the provider has been explicitly configured via the UI
|
||||||
|
let configured_marker = format!("{}_configured", metadata.name);
|
||||||
|
return config.get_param::<bool>(&configured_marker).is_ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// For providers with multiple keys or keys without defaults:
|
// For providers with multiple keys or keys without defaults:
|
||||||
// Find required keys that don't have default values
|
// Find required keys that don't have default values
|
||||||
let required_non_default_keys: Vec<&ConfigKey> = required_keys
|
let required_non_default_keys: Vec<&ConfigKey> = required_keys
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::process::Stdio;
|
use std::process::Stdio;
|
||||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
@@ -8,13 +9,14 @@ use tokio::process::Command;
|
|||||||
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
|
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||||
use super::errors::ProviderError;
|
use super::errors::ProviderError;
|
||||||
use super::utils::emit_debug_trace;
|
use super::utils::emit_debug_trace;
|
||||||
|
use crate::config::Config;
|
||||||
use crate::message::{Message, MessageContent};
|
use crate::message::{Message, MessageContent};
|
||||||
use crate::model::ModelConfig;
|
use crate::model::ModelConfig;
|
||||||
use mcp_core::tool::Tool;
|
use mcp_core::tool::Tool;
|
||||||
use rmcp::model::Role;
|
use mcp_core::Role;
|
||||||
|
|
||||||
pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "default";
|
pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "claude-3-5-sonnet-latest";
|
||||||
pub const CLAUDE_CODE_KNOWN_MODELS: &[&str] = &["default"];
|
pub const CLAUDE_CODE_KNOWN_MODELS: &[&str] = &["sonnet", "opus", "claude-3-5-sonnet-latest"];
|
||||||
|
|
||||||
pub const CLAUDE_CODE_DOC_URL: &str = "https://claude.ai/cli";
|
pub const CLAUDE_CODE_DOC_URL: &str = "https://claude.ai/cli";
|
||||||
|
|
||||||
@@ -38,7 +40,71 @@ impl ClaudeCodeProvider {
|
|||||||
.get_param("CLAUDE_CODE_COMMAND")
|
.get_param("CLAUDE_CODE_COMMAND")
|
||||||
.unwrap_or_else(|_| "claude".to_string());
|
.unwrap_or_else(|_| "claude".to_string());
|
||||||
|
|
||||||
Ok(Self { command, model })
|
let resolved_command = if !command.contains('/') {
|
||||||
|
Self::find_claude_executable(&command).unwrap_or(command)
|
||||||
|
} else {
|
||||||
|
command
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
command: resolved_command,
|
||||||
|
model,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search for claude executable in common installation locations
|
||||||
|
fn find_claude_executable(command_name: &str) -> Option<String> {
|
||||||
|
let home = std::env::var("HOME").ok()?;
|
||||||
|
|
||||||
|
let search_paths = vec![
|
||||||
|
format!("{}/.claude/local/{}", home, command_name),
|
||||||
|
format!("{}/.local/bin/{}", home, command_name),
|
||||||
|
format!("{}/bin/{}", home, command_name),
|
||||||
|
format!("/usr/local/bin/{}", command_name),
|
||||||
|
format!("/usr/bin/{}", command_name),
|
||||||
|
format!("/opt/claude/{}", command_name),
|
||||||
|
];
|
||||||
|
|
||||||
|
for path in search_paths {
|
||||||
|
let path_buf = PathBuf::from(&path);
|
||||||
|
if path_buf.exists() && path_buf.is_file() {
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
if let Ok(metadata) = std::fs::metadata(&path_buf) {
|
||||||
|
let permissions = metadata.permissions();
|
||||||
|
if permissions.mode() & 0o111 != 0 {
|
||||||
|
tracing::info!("Found claude executable at: {}", path);
|
||||||
|
return Some(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
{
|
||||||
|
tracing::info!("Found claude executable at: {}", path);
|
||||||
|
return Some(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(path_var) = std::env::var("PATH") {
|
||||||
|
#[cfg(unix)]
|
||||||
|
let path_separator = ':';
|
||||||
|
#[cfg(windows)]
|
||||||
|
let path_separator = ';';
|
||||||
|
|
||||||
|
for dir in path_var.split(path_separator) {
|
||||||
|
let path_buf = PathBuf::from(dir).join(command_name);
|
||||||
|
if path_buf.exists() && path_buf.is_file() {
|
||||||
|
let full_path = path_buf.to_string_lossy().to_string();
|
||||||
|
tracing::info!("Found claude executable in PATH at: {}", full_path);
|
||||||
|
return Some(full_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!("Could not find claude executable in common locations");
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Filter out the Extensions section from the system prompt
|
/// Filter out the Extensions section from the system prompt
|
||||||
@@ -97,8 +163,13 @@ impl ClaudeCodeProvider {
|
|||||||
// Convert tool result contents to text
|
// Convert tool result contents to text
|
||||||
let content_text = tool_contents
|
let content_text = tool_contents
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|content| content.as_text().map(|t| t.text.clone()))
|
.filter_map(|content| match &content.raw {
|
||||||
.collect::<Vec<_>>()
|
rmcp::model::RawContent::Text(text_content) => {
|
||||||
|
Some(text_content.text.as_str())
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
.join("\n");
|
.join("\n");
|
||||||
|
|
||||||
content_parts.push(json!({
|
content_parts.push(json!({
|
||||||
@@ -215,11 +286,12 @@ impl ClaudeCodeProvider {
|
|||||||
|
|
||||||
let message_content = vec![MessageContent::text(combined_text)];
|
let message_content = vec![MessageContent::text(combined_text)];
|
||||||
|
|
||||||
let response_message = Message::new(
|
let response_message = Message {
|
||||||
Role::Assistant,
|
id: None,
|
||||||
chrono::Utc::now().timestamp(),
|
role: Role::Assistant,
|
||||||
message_content,
|
created: chrono::Utc::now().timestamp(),
|
||||||
);
|
content: message_content,
|
||||||
|
};
|
||||||
|
|
||||||
Ok((response_message, usage))
|
Ok((response_message, usage))
|
||||||
}
|
}
|
||||||
@@ -261,10 +333,20 @@ impl ClaudeCodeProvider {
|
|||||||
.arg(messages_json.to_string())
|
.arg(messages_json.to_string())
|
||||||
.arg("--system-prompt")
|
.arg("--system-prompt")
|
||||||
.arg(&filtered_system)
|
.arg(&filtered_system)
|
||||||
|
.arg("--model")
|
||||||
|
.arg(&self.model.model_name)
|
||||||
.arg("--verbose")
|
.arg("--verbose")
|
||||||
.arg("--output-format")
|
.arg("--output-format")
|
||||||
.arg("json");
|
.arg("json");
|
||||||
|
|
||||||
|
// Add permission mode based on GOOSE_MODE setting
|
||||||
|
let config = Config::global();
|
||||||
|
if let Ok(goose_mode) = config.get_param::<String>("GOOSE_MODE") {
|
||||||
|
if goose_mode.as_str() == "auto" {
|
||||||
|
cmd.arg("--permission-mode").arg("acceptEdits");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
|
cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
|
||||||
|
|
||||||
let mut child = cmd
|
let mut child = cmd
|
||||||
@@ -326,7 +408,7 @@ impl ClaudeCodeProvider {
|
|||||||
// Extract the first user message text
|
// Extract the first user message text
|
||||||
let description = messages
|
let description = messages
|
||||||
.iter()
|
.iter()
|
||||||
.find(|m| m.role == rmcp::model::Role::User)
|
.find(|m| m.role == mcp_core::Role::User)
|
||||||
.and_then(|m| {
|
.and_then(|m| {
|
||||||
m.content.iter().find_map(|c| match c {
|
m.content.iter().find_map(|c| match c {
|
||||||
MessageContent::Text(text_content) => Some(&text_content.text),
|
MessageContent::Text(text_content) => Some(&text_content.text),
|
||||||
@@ -349,11 +431,12 @@ impl ClaudeCodeProvider {
|
|||||||
println!("================================");
|
println!("================================");
|
||||||
}
|
}
|
||||||
|
|
||||||
let message = Message::new(
|
let message = Message {
|
||||||
rmcp::model::Role::Assistant,
|
id: None,
|
||||||
chrono::Utc::now().timestamp(),
|
role: mcp_core::Role::Assistant,
|
||||||
vec![MessageContent::text(description.clone())],
|
created: chrono::Utc::now().timestamp(),
|
||||||
);
|
content: vec![MessageContent::text(description.clone())],
|
||||||
|
};
|
||||||
|
|
||||||
let usage = Usage::default();
|
let usage = Usage::default();
|
||||||
|
|
||||||
@@ -384,8 +467,8 @@ impl Provider for ClaudeCodeProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_model_config(&self) -> ModelConfig {
|
fn get_model_config(&self) -> ModelConfig {
|
||||||
// Return a custom config with 200K token limit for Claude Code
|
// Return the model config with appropriate context limit for Claude models
|
||||||
ModelConfig::new("claude-3-5-sonnet-latest".to_string()).with_context_limit(Some(200_000))
|
self.model.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@@ -439,6 +522,19 @@ mod tests {
|
|||||||
let config = provider.get_model_config();
|
let config = provider.get_model_config();
|
||||||
|
|
||||||
assert_eq!(config.model_name, "claude-3-5-sonnet-latest");
|
assert_eq!(config.model_name, "claude-3-5-sonnet-latest");
|
||||||
assert_eq!(config.context_limit(), 200_000);
|
// Context limit should be set by the ModelConfig
|
||||||
|
assert!(config.context_limit() > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_permission_mode_flag_construction() {
|
||||||
|
// Test that in auto mode, the --permission-mode acceptEdits flag is added
|
||||||
|
std::env::set_var("GOOSE_MODE", "auto");
|
||||||
|
|
||||||
|
let config = Config::global();
|
||||||
|
let goose_mode: String = config.get_param("GOOSE_MODE").unwrap();
|
||||||
|
assert_eq!(goose_mode, "auto");
|
||||||
|
|
||||||
|
std::env::remove_var("GOOSE_MODE");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::process::Stdio;
|
use std::process::Stdio;
|
||||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
@@ -13,8 +14,8 @@ use crate::model::ModelConfig;
|
|||||||
use mcp_core::tool::Tool;
|
use mcp_core::tool::Tool;
|
||||||
use rmcp::model::Role;
|
use rmcp::model::Role;
|
||||||
|
|
||||||
pub const GEMINI_CLI_DEFAULT_MODEL: &str = "default";
|
pub const GEMINI_CLI_DEFAULT_MODEL: &str = "gemini-2.5-pro";
|
||||||
pub const GEMINI_CLI_KNOWN_MODELS: &[&str] = &["default"];
|
pub const GEMINI_CLI_KNOWN_MODELS: &[&str] = &["gemini-2.5-pro"];
|
||||||
|
|
||||||
pub const GEMINI_CLI_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs";
|
pub const GEMINI_CLI_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs";
|
||||||
|
|
||||||
@@ -33,9 +34,76 @@ impl Default for GeminiCliProvider {
|
|||||||
|
|
||||||
impl GeminiCliProvider {
|
impl GeminiCliProvider {
|
||||||
pub fn from_env(model: ModelConfig) -> Result<Self> {
|
pub fn from_env(model: ModelConfig) -> Result<Self> {
|
||||||
let command = "gemini".to_string(); // Fixed command, no configuration needed
|
let config = crate::config::Config::global();
|
||||||
|
let command: String = config
|
||||||
|
.get_param("GEMINI_CLI_COMMAND")
|
||||||
|
.unwrap_or_else(|_| "gemini".to_string());
|
||||||
|
|
||||||
Ok(Self { command, model })
|
let resolved_command = if !command.contains('/') {
|
||||||
|
Self::find_gemini_executable(&command).unwrap_or(command)
|
||||||
|
} else {
|
||||||
|
command
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
command: resolved_command,
|
||||||
|
model,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Search for gemini executable in common installation locations
|
||||||
|
fn find_gemini_executable(command_name: &str) -> Option<String> {
|
||||||
|
let home = std::env::var("HOME").ok()?;
|
||||||
|
|
||||||
|
// Common locations where gemini might be installed
|
||||||
|
let search_paths = vec![
|
||||||
|
format!("{}/.gemini/local/{}", home, command_name),
|
||||||
|
format!("{}/.local/bin/{}", home, command_name),
|
||||||
|
format!("{}/bin/{}", home, command_name),
|
||||||
|
format!("/usr/local/bin/{}", command_name),
|
||||||
|
format!("/usr/bin/{}", command_name),
|
||||||
|
format!("/opt/gemini/{}", command_name),
|
||||||
|
format!("/opt/google/{}", command_name),
|
||||||
|
];
|
||||||
|
|
||||||
|
for path in search_paths {
|
||||||
|
let path_buf = PathBuf::from(&path);
|
||||||
|
if path_buf.exists() && path_buf.is_file() {
|
||||||
|
// Check if it's executable
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
if let Ok(metadata) = std::fs::metadata(&path_buf) {
|
||||||
|
let permissions = metadata.permissions();
|
||||||
|
if permissions.mode() & 0o111 != 0 {
|
||||||
|
tracing::info!("Found gemini executable at: {}", path);
|
||||||
|
return Some(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
{
|
||||||
|
// On non-Unix systems, just check if file exists
|
||||||
|
tracing::info!("Found gemini executable at: {}", path);
|
||||||
|
return Some(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not found in common locations, check if it's in PATH
|
||||||
|
if let Ok(path_var) = std::env::var("PATH") {
|
||||||
|
for dir in path_var.split(':') {
|
||||||
|
let full_path = format!("{}/{}", dir, command_name);
|
||||||
|
let path_buf = PathBuf::from(&full_path);
|
||||||
|
if path_buf.exists() && path_buf.is_file() {
|
||||||
|
tracing::info!("Found gemini executable in PATH at: {}", full_path);
|
||||||
|
return Some(full_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!("Could not find gemini executable in common locations");
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Filter out the Extensions section from the system prompt
|
/// Filter out the Extensions section from the system prompt
|
||||||
@@ -102,7 +170,11 @@ impl GeminiCliProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut cmd = Command::new(&self.command);
|
let mut cmd = Command::new(&self.command);
|
||||||
cmd.arg("-p").arg(&full_prompt).arg("--yolo");
|
cmd.arg("-m")
|
||||||
|
.arg(&self.model.model_name)
|
||||||
|
.arg("-p")
|
||||||
|
.arg(&full_prompt)
|
||||||
|
.arg("--yolo");
|
||||||
|
|
||||||
cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
|
cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
|
||||||
|
|
||||||
@@ -125,7 +197,7 @@ impl GeminiCliProvider {
|
|||||||
Ok(0) => break, // EOF
|
Ok(0) => break, // EOF
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
let trimmed = line.trim();
|
let trimmed = line.trim();
|
||||||
if !trimmed.is_empty() {
|
if !trimmed.is_empty() && !trimmed.starts_with("Loaded cached credentials") {
|
||||||
lines.push(trimmed.to_string());
|
lines.push(trimmed.to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -240,8 +312,8 @@ impl Provider for GeminiCliProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn get_model_config(&self) -> ModelConfig {
|
fn get_model_config(&self) -> ModelConfig {
|
||||||
// Return a custom config with 1M token limit for Gemini CLI
|
// Return the model config with appropriate context limit for Gemini models
|
||||||
ModelConfig::new("gemini-1.5-pro".to_string()).with_context_limit(Some(1_000_000))
|
self.model.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
@@ -294,7 +366,8 @@ mod tests {
|
|||||||
let provider = GeminiCliProvider::default();
|
let provider = GeminiCliProvider::default();
|
||||||
let config = provider.get_model_config();
|
let config = provider.get_model_config();
|
||||||
|
|
||||||
assert_eq!(config.model_name, "gemini-1.5-pro");
|
assert_eq!(config.model_name, "gemini-2.5-pro");
|
||||||
assert_eq!(config.context_limit(), 1_000_000);
|
// Context limit should be set by the ModelConfig
|
||||||
|
assert!(config.context_limit() > 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -176,6 +176,7 @@ goose session
|
|||||||
| Environment Variable | Description | Default |
|
| Environment Variable | Description | Default |
|
||||||
|---------------------|-------------|---------|
|
|---------------------|-------------|---------|
|
||||||
| `GOOSE_PROVIDER` | Set to `gemini-cli` to use this provider | None |
|
| `GOOSE_PROVIDER` | Set to `gemini-cli` to use this provider | None |
|
||||||
|
| `GEMINI_CLI_COMMAND` | Path to the Gemini CLI command | `gemini` |
|
||||||
|
|
||||||
## How It Works
|
## How It Works
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ const ModelAndProviderContext = createContext<ModelAndProviderContextType | unde
|
|||||||
export const ModelAndProviderProvider: React.FC<ModelAndProviderProviderProps> = ({ children }) => {
|
export const ModelAndProviderProvider: React.FC<ModelAndProviderProviderProps> = ({ children }) => {
|
||||||
const [currentModel, setCurrentModel] = useState<string | null>(null);
|
const [currentModel, setCurrentModel] = useState<string | null>(null);
|
||||||
const [currentProvider, setCurrentProvider] = useState<string | null>(null);
|
const [currentProvider, setCurrentProvider] = useState<string | null>(null);
|
||||||
const { read, upsert, getProviders } = useConfig();
|
const { read, upsert, getProviders, config } = useConfig();
|
||||||
|
|
||||||
const changeModel = useCallback(
|
const changeModel = useCallback(
|
||||||
async (model: Model) => {
|
async (model: Model) => {
|
||||||
@@ -183,6 +183,19 @@ export const ModelAndProviderProvider: React.FC<ModelAndProviderProviderProps> =
|
|||||||
refreshCurrentModelAndProvider();
|
refreshCurrentModelAndProvider();
|
||||||
}, [refreshCurrentModelAndProvider]);
|
}, [refreshCurrentModelAndProvider]);
|
||||||
|
|
||||||
|
// Extract config values for dependency array
|
||||||
|
const configObj = config as Record<string, unknown>;
|
||||||
|
const gooseModel = configObj?.GOOSE_MODEL;
|
||||||
|
const gooseProvider = configObj?.GOOSE_PROVIDER;
|
||||||
|
|
||||||
|
// Listen for config changes and refresh when GOOSE_MODEL or GOOSE_PROVIDER changes
|
||||||
|
useEffect(() => {
|
||||||
|
// Only refresh if the config has loaded and model/provider values exist
|
||||||
|
if (config && Object.keys(config).length > 0 && (gooseModel || gooseProvider)) {
|
||||||
|
refreshCurrentModelAndProvider();
|
||||||
|
}
|
||||||
|
}, [config, gooseModel, gooseProvider, refreshCurrentModelAndProvider]);
|
||||||
|
|
||||||
const contextValue = useMemo(
|
const contextValue = useMemo(
|
||||||
() => ({
|
() => ({
|
||||||
currentModel,
|
currentModel,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import {
|
|||||||
} from '../../../ui/dropdown-menu';
|
} from '../../../ui/dropdown-menu';
|
||||||
import { useCurrentModelInfo } from '../../../BaseChat';
|
import { useCurrentModelInfo } from '../../../BaseChat';
|
||||||
import { useConfig } from '../../../ConfigContext';
|
import { useConfig } from '../../../ConfigContext';
|
||||||
|
import { getProviderMetadata } from '../modelInterface';
|
||||||
import { Alert } from '../../../alerts';
|
import { Alert } from '../../../alerts';
|
||||||
import BottomMenuAlertPopover from '../../../bottom_menu/BottomMenuAlertPopover';
|
import BottomMenuAlertPopover from '../../../bottom_menu/BottomMenuAlertPopover';
|
||||||
import { Recipe } from '../../../../recipe';
|
import { Recipe } from '../../../../recipe';
|
||||||
@@ -42,12 +43,13 @@ export default function ModelsBottomBar({
|
|||||||
getCurrentProviderDisplayName,
|
getCurrentProviderDisplayName,
|
||||||
} = useModelAndProvider();
|
} = useModelAndProvider();
|
||||||
const currentModelInfo = useCurrentModelInfo();
|
const currentModelInfo = useCurrentModelInfo();
|
||||||
const { read } = useConfig();
|
const { read, getProviders } = useConfig();
|
||||||
const [displayProvider, setDisplayProvider] = useState<string | null>(null);
|
const [displayProvider, setDisplayProvider] = useState<string | null>(null);
|
||||||
const [displayModelName, setDisplayModelName] = useState<string>('Select Model');
|
const [displayModelName, setDisplayModelName] = useState<string>('Select Model');
|
||||||
const [isAddModelModalOpen, setIsAddModelModalOpen] = useState(false);
|
const [isAddModelModalOpen, setIsAddModelModalOpen] = useState(false);
|
||||||
const [isLeadWorkerModalOpen, setIsLeadWorkerModalOpen] = useState(false);
|
const [isLeadWorkerModalOpen, setIsLeadWorkerModalOpen] = useState(false);
|
||||||
const [isLeadWorkerActive, setIsLeadWorkerActive] = useState(false);
|
const [isLeadWorkerActive, setIsLeadWorkerActive] = useState(false);
|
||||||
|
const [providerDefaultModel, setProviderDefaultModel] = useState<string | null>(null);
|
||||||
|
|
||||||
// Save recipe dialog state (like in RecipeEditor.tsx)
|
// Save recipe dialog state (like in RecipeEditor.tsx)
|
||||||
const [showSaveDialog, setShowSaveDialog] = useState(false);
|
const [showSaveDialog, setShowSaveDialog] = useState(false);
|
||||||
@@ -91,10 +93,6 @@ export default function ModelsBottomBar({
|
|||||||
checkLeadWorker();
|
checkLeadWorker();
|
||||||
};
|
};
|
||||||
|
|
||||||
// Determine which model to display - activeModel takes priority when lead/worker is active
|
|
||||||
const displayModel =
|
|
||||||
isLeadWorkerActive && currentModelInfo?.model ? currentModelInfo.model : displayModelName;
|
|
||||||
|
|
||||||
// Since currentModelInfo.mode is not working, let's determine mode differently
|
// Since currentModelInfo.mode is not working, let's determine mode differently
|
||||||
// We'll need to get the lead model and compare it with the current model
|
// We'll need to get the lead model and compare it with the current model
|
||||||
const [leadModelName, setLeadModelName] = useState<string>('');
|
const [leadModelName, setLeadModelName] = useState<string>('');
|
||||||
@@ -122,6 +120,12 @@ export default function ModelsBottomBar({
|
|||||||
: 'worker'
|
: 'worker'
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
|
// Determine which model to display - activeModel takes priority when lead/worker is active
|
||||||
|
const displayModel =
|
||||||
|
isLeadWorkerActive && currentModelInfo?.model
|
||||||
|
? currentModelInfo.model
|
||||||
|
: currentModel || providerDefaultModel || displayModelName;
|
||||||
|
|
||||||
// Update display provider when current provider changes
|
// Update display provider when current provider changes
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (currentProvider) {
|
if (currentProvider) {
|
||||||
@@ -137,6 +141,24 @@ export default function ModelsBottomBar({
|
|||||||
}
|
}
|
||||||
}, [currentProvider, getCurrentProviderDisplayName, getCurrentModelAndProviderForDisplay]);
|
}, [currentProvider, getCurrentProviderDisplayName, getCurrentModelAndProviderForDisplay]);
|
||||||
|
|
||||||
|
// Fetch provider default model when provider changes and no current model
|
||||||
|
useEffect(() => {
|
||||||
|
if (currentProvider && !currentModel) {
|
||||||
|
(async () => {
|
||||||
|
try {
|
||||||
|
const metadata = await getProviderMetadata(currentProvider, getProviders);
|
||||||
|
setProviderDefaultModel(metadata.default_model);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to get provider default model:', error);
|
||||||
|
setProviderDefaultModel(null);
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
} else if (currentModel) {
|
||||||
|
// Clear provider default when we have a current model
|
||||||
|
setProviderDefaultModel(null);
|
||||||
|
}
|
||||||
|
}, [currentProvider, currentModel, getProviders]);
|
||||||
|
|
||||||
// Update display model name when current model changes
|
// Update display model name when current model changes
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
(async () => {
|
(async () => {
|
||||||
|
|||||||
@@ -116,14 +116,24 @@ export default function ProviderSetupActions({
|
|||||||
</Button>
|
</Button>
|
||||||
</>
|
</>
|
||||||
) : (
|
) : (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
type="submit"
|
||||||
|
variant="ghost"
|
||||||
|
onClick={onSubmit}
|
||||||
|
className="w-full h-[60px] rounded-none border-t border-borderSubtle text-md hover:bg-bgSubtle text-textProminent font-medium"
|
||||||
|
>
|
||||||
|
Enable Provider
|
||||||
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
onClick={onCancel}
|
onClick={onCancel}
|
||||||
className="w-full h-[60px] rounded-none border-t border-borderSubtle hover:text-textStandard text-textSubtle hover:bg-bgSubtle text-md font-regular"
|
className="w-full h-[60px] rounded-none border-t border-borderSubtle hover:text-textStandard text-textSubtle hover:bg-bgSubtle text-md font-regular"
|
||||||
>
|
>
|
||||||
Close
|
Cancel
|
||||||
</Button>
|
</Button>
|
||||||
|
</>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
export const DefaultSubmitHandler = async (
|
export const DefaultSubmitHandler = async (
|
||||||
upsertFn: (key: string, value: unknown, isSecret: boolean) => Promise<void>,
|
upsertFn: (key: string, value: unknown, isSecret: boolean) => Promise<void>,
|
||||||
provider: {
|
provider: {
|
||||||
|
name: string;
|
||||||
metadata: {
|
metadata: {
|
||||||
config_keys?: Array<{
|
config_keys?: Array<{
|
||||||
name: string;
|
name: string;
|
||||||
@@ -18,6 +19,37 @@ export const DefaultSubmitHandler = async (
|
|||||||
) => {
|
) => {
|
||||||
const parameters = provider.metadata.config_keys || [];
|
const parameters = provider.metadata.config_keys || [];
|
||||||
|
|
||||||
|
if (parameters.length === 0) {
|
||||||
|
// For zero-config providers, mark them as configured
|
||||||
|
const configKey = `${provider.name}_configured`;
|
||||||
|
await upsertFn(configKey, true, false);
|
||||||
|
|
||||||
|
await upsertFn('GOOSE_PROVIDER', provider.name, false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const requiredParams = parameters.filter((param) => param.required);
|
||||||
|
if (requiredParams.length === 0 && parameters.length > 0) {
|
||||||
|
const allOptionalWithDefaults = parameters.every(
|
||||||
|
(param) => !param.required && param.default !== undefined
|
||||||
|
);
|
||||||
|
if (allOptionalWithDefaults) {
|
||||||
|
const promises: Promise<void>[] = [];
|
||||||
|
const configKey = `${provider.name}_configured`;
|
||||||
|
promises.push(upsertFn(configKey, true, false));
|
||||||
|
|
||||||
|
for (const param of parameters) {
|
||||||
|
if (param.default !== undefined) {
|
||||||
|
const value =
|
||||||
|
configValues[param.name] !== undefined ? configValues[param.name] : param.default;
|
||||||
|
promises.push(upsertFn(param.name, value, param.secret === true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Promise.all(promises);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const upsertPromises = parameters.map(
|
const upsertPromises = parameters.map(
|
||||||
(parameter: { name: string; required?: boolean; default?: unknown; secret?: boolean }) => {
|
(parameter: { name: string; required?: boolean; default?: unknown; secret?: boolean }) => {
|
||||||
// Skip parameters that don't have a value and aren't required
|
// Skip parameters that don't have a value and aren't required
|
||||||
|
|||||||
Reference in New Issue
Block a user