mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-08 16:14:24 +01:00
feat: list Databricks-supported models and enable fuzzy search during model configuration (#3039)
This commit is contained in:
@@ -351,6 +351,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
|
||||
.map(|m| (m, m.as_str(), ""))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.filter_mode() // enable "fuzzy search" filtering for the list of models
|
||||
.interact()?
|
||||
.to_string(),
|
||||
Ok(None) => {
|
||||
|
||||
@@ -495,6 +495,89 @@ impl Provider for DatabricksProvider {
|
||||
.await
|
||||
.map_err(|e| ProviderError::ExecutionError(e.to_string()))
|
||||
}
|
||||
|
||||
async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
|
||||
let base_url = Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let url = base_url.join("api/2.0/serving-endpoints").map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
|
||||
})?;
|
||||
|
||||
let auth_header = match self.ensure_auth_header().await {
|
||||
Ok(header) => header,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to authorize with Databricks: {}", e);
|
||||
return Ok(None); // Return None to fall back to manual input
|
||||
}
|
||||
};
|
||||
|
||||
let response = match self
|
||||
.client
|
||||
.get(url)
|
||||
.header("Authorization", auth_header)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch Databricks models: {}", e);
|
||||
return Ok(None); // Return None to fall back to manual input
|
||||
}
|
||||
};
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
if let Ok(error_text) = response.text().await {
|
||||
tracing::warn!(
|
||||
"Failed to fetch Databricks models: {} - {}",
|
||||
status,
|
||||
error_text
|
||||
);
|
||||
} else {
|
||||
tracing::warn!("Failed to fetch Databricks models: {}", status);
|
||||
}
|
||||
return Ok(None); // Return None to fall back to manual input
|
||||
}
|
||||
|
||||
let json: Value = match response.json().await {
|
||||
Ok(json) => json,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse Databricks API response: {}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
|
||||
Some(endpoints) => endpoints,
|
||||
None => {
|
||||
tracing::warn!(
|
||||
"Unexpected response format from Databricks API: missing 'endpoints' array"
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let models: Vec<String> = endpoints
|
||||
.iter()
|
||||
.filter_map(|endpoint| {
|
||||
endpoint
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|name| name.to_string())
|
||||
})
|
||||
.collect();
|
||||
|
||||
if models.is_empty() {
|
||||
tracing::debug!("No serving endpoints found in Databricks workspace");
|
||||
Ok(None)
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"Found {} serving endpoints in Databricks workspace",
|
||||
models.len()
|
||||
);
|
||||
Ok(Some(models))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
Reference in New Issue
Block a user