feat: Support extending the system prompt (#1167)

This commit is contained in:
Bradley Axen
2025-02-11 20:03:32 -08:00
committed by GitHub
parent a5e2419380
commit 6220ef054f
8 changed files with 88 additions and 1 deletions

View File

@@ -0,0 +1,16 @@
/// Returns a system prompt extension that explains CLI-specific functionality
pub fn get_cli_prompt() -> String {
String::from(
"You are being accessed through a command-line interface. The following slash commands are available
- you can let the user know about them if they need help:
- /exit or /quit - Exit the session
- /t - Toggle between Light/Dark/Ansi themes
- /? or /help - Display help message
Additional keyboard shortcuts:
- Ctrl+C - Interrupt the current interaction (resets to before the interrupted request)
- Ctrl+J - Add a newline
- Up/Down arrows - Navigate command history"
)
}

View File

@@ -161,6 +161,11 @@ pub async fn build_session(
let prompt = Box::new(RustylinePrompt::new());
// Add CLI-specific system prompt extension
agent
.extend_system_prompt(crate::cli_prompt::get_cli_prompt())
.await;
display_session_info(resume, &provider_name, &model, &session_file);
Session::new(agent, prompt, session_file)
}

View File

@@ -9,6 +9,7 @@ pub static APP_STRATEGY: Lazy<AppStrategyArgs> = Lazy::new(|| AppStrategyArgs {
app_name: "goose".to_string(),
});
mod cli_prompt;
mod commands;
mod log_usage;
mod logging;

View File

@@ -17,6 +17,16 @@ struct VersionsResponse {
default_version: String,
}
#[derive(Deserialize)]
struct ExtendPromptRequest {
extension: String,
}
#[derive(Serialize)]
struct ExtendPromptResponse {
success: bool,
}
#[derive(Deserialize)]
struct CreateAgentRequest {
version: Option<String>,
@@ -61,6 +71,30 @@ async fn get_versions() -> Json<VersionsResponse> {
})
}
async fn extend_prompt(
State(state): State<AppState>,
headers: HeaderMap,
Json(payload): Json<ExtendPromptRequest>,
) -> Result<Json<ExtendPromptResponse>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}
let mut agent = state.agent.lock().await;
if let Some(ref mut agent) = *agent {
agent.extend_system_prompt(payload.extension).await;
Ok(Json(ExtendPromptResponse { success: true }))
} else {
Err(StatusCode::NOT_FOUND)
}
}
async fn create_agent(
State(state): State<AppState>,
headers: HeaderMap,
@@ -132,6 +166,7 @@ pub fn routes(state: AppState) -> Router {
Router::new()
.route("/agent/versions", get(get_versions))
.route("/agent/providers", get(list_providers))
.route("/agent/prompt", post(extend_prompt))
.route("/agent", post(create_agent))
.with_state(state)
}

View File

@@ -28,4 +28,7 @@ pub trait Agent: Send + Sync {
/// Get the total usage of the agent
async fn usage(&self) -> Vec<ProviderUsage>;
/// Add custom text to be included in the system prompt
async fn extend_system_prompt(&mut self, extension: String);
}

View File

@@ -30,6 +30,7 @@ pub struct Capabilities {
resource_capable_extensions: HashSet<String>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
system_prompt_extensions: Vec<String>,
}
/// A flattened representation of a resource used by the agent to prepare inference
@@ -88,6 +89,7 @@ impl Capabilities {
resource_capable_extensions: HashSet::new(),
provider,
provider_usage: Mutex::new(Vec::new()),
system_prompt_extensions: Vec::new(),
}
}
@@ -164,6 +166,11 @@ impl Capabilities {
Ok(())
}
/// Add a system prompt extension
pub fn add_system_prompt_extension(&mut self, extension: String) {
self.system_prompt_extensions.push(extension);
}
/// Get a reference to the provider
pub fn provider(&self) -> &dyn Provider {
&*self.provider
@@ -303,7 +310,17 @@ impl Capabilities {
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
context.insert("current_date_time", Value::String(current_date_time));
load_prompt_file("system.md", &context).expect("Prompt should render")
let base_prompt = load_prompt_file("system.md", &context).expect("Prompt should render");
if self.system_prompt_extensions.is_empty() {
base_prompt
} else {
format!(
"{}\n\n# Additional Instructions:\n\n{}",
base_prompt,
self.system_prompt_extensions.join("\n\n")
)
}
}
/// Find and return a reference to the appropriate client for a tool call

View File

@@ -184,6 +184,11 @@ impl Agent for ReferenceAgent {
let capabilities = self.capabilities.lock().await;
capabilities.get_usage().await
}
async fn extend_system_prompt(&mut self, extension: String) {
let mut capabilities = self.capabilities.lock().await;
capabilities.add_system_prompt_extension(extension);
}
}
register_agent!("reference", ReferenceAgent);

View File

@@ -292,6 +292,11 @@ impl Agent for TruncateAgent {
let capabilities = self.capabilities.lock().await;
capabilities.get_usage().await
}
async fn extend_system_prompt(&mut self, extension: String) {
let mut capabilities = self.capabilities.lock().await;
capabilities.add_system_prompt_extension(extension);
}
}
register_agent!("truncate", TruncateAgent);