mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
feat: efficient benching (#1921)
Co-authored-by: Tyler Rockwood <rockwotj@gmail.com> Co-authored-by: Kalvin C <kalvinnchau@users.noreply.github.com> Co-authored-by: Alice Hau <110418948+ahau-square@users.noreply.github.com>
This commit is contained in:
@@ -1,88 +1,37 @@
|
||||
use crate::logging;
|
||||
use crate::session::build_session;
|
||||
use crate::Session;
|
||||
use crate::{logging, session, Session};
|
||||
use async_trait::async_trait;
|
||||
use goose::config::Config;
|
||||
use goose::message::Message;
|
||||
use goose_bench::bench_work_dir::BenchmarkWorkDir;
|
||||
use goose_bench::eval_suites::{BenchAgent, BenchAgentError, Evaluation, EvaluationSuite};
|
||||
use goose_bench::reporting::{BenchmarkResults, EvaluationResult, SuiteResult};
|
||||
use goose_bench::bench_session::{BenchAgent, BenchBaseSession};
|
||||
use goose_bench::eval_suites::ExtensionRequirements;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub struct BenchSession {
|
||||
session: Session,
|
||||
errors: Arc<Mutex<Vec<BenchAgentError>>>,
|
||||
}
|
||||
|
||||
impl BenchSession {
|
||||
pub fn new(session: Session) -> Self {
|
||||
let errors = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
// Initialize logging with error capture
|
||||
logging::setup_logging(Some("bench"), Some(errors.clone()))
|
||||
.expect("Failed to initialize logging");
|
||||
|
||||
Self { session, errors }
|
||||
}
|
||||
}
|
||||
|
||||
// allow session obj to be used in benchmarking
|
||||
#[async_trait]
|
||||
impl BenchAgent for BenchSession {
|
||||
async fn prompt(&mut self, p: String) -> anyhow::Result<Vec<Message>> {
|
||||
// Clear previous errors
|
||||
{
|
||||
let mut errors = self.errors.lock().await;
|
||||
errors.clear();
|
||||
}
|
||||
|
||||
self.session.headless(p).await?;
|
||||
Ok(self.session.message_history())
|
||||
impl BenchBaseSession for Session {
|
||||
async fn headless(&mut self, message: String) -> anyhow::Result<()> {
|
||||
self.headless(message).await
|
||||
}
|
||||
|
||||
async fn get_errors(&self) -> Vec<BenchAgentError> {
|
||||
let errors = self.errors.lock().await;
|
||||
errors.clone()
|
||||
fn session_file(&self) -> PathBuf {
|
||||
self.session_file()
|
||||
}
|
||||
|
||||
async fn get_token_usage(&self) -> Option<i32> {
|
||||
self.session.get_total_token_usage().ok().flatten()
|
||||
fn message_history(&self) -> Vec<Message> {
|
||||
self.message_history()
|
||||
}
|
||||
fn get_total_token_usage(&self) -> anyhow::Result<Option<i32>> {
|
||||
self.get_total_token_usage()
|
||||
}
|
||||
}
|
||||
pub async fn agent_generator(
|
||||
requirements: ExtensionRequirements,
|
||||
session_id: String,
|
||||
) -> BenchAgent {
|
||||
let identifier = Some(session::Identifier::Name(session_id));
|
||||
|
||||
// Wrapper struct to implement BenchAgent for Arc<Mutex<BenchSession>>
|
||||
struct BenchAgentWrapper(Arc<Mutex<BenchSession>>);
|
||||
|
||||
#[async_trait]
|
||||
impl BenchAgent for BenchAgentWrapper {
|
||||
async fn prompt(&mut self, p: String) -> anyhow::Result<Vec<Message>> {
|
||||
let mut session = self.0.lock().await;
|
||||
session.prompt(p).await
|
||||
}
|
||||
|
||||
async fn get_errors(&self) -> Vec<BenchAgentError> {
|
||||
let session = self.0.lock().await;
|
||||
session.get_errors().await
|
||||
}
|
||||
|
||||
async fn get_token_usage(&self) -> Option<i32> {
|
||||
let session = self.0.lock().await;
|
||||
session.get_token_usage().await
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_eval(
|
||||
evaluation: Box<dyn Evaluation>,
|
||||
work_dir: &mut BenchmarkWorkDir,
|
||||
) -> anyhow::Result<EvaluationResult> {
|
||||
let mut result = EvaluationResult::new(evaluation.name().to_string());
|
||||
|
||||
let requirements = evaluation.required_extensions();
|
||||
|
||||
// Create session with error capture
|
||||
let base_session = build_session(
|
||||
None,
|
||||
identifier,
|
||||
false,
|
||||
requirements.external,
|
||||
requirements.remote,
|
||||
@@ -91,84 +40,12 @@ async fn run_eval(
|
||||
)
|
||||
.await;
|
||||
|
||||
let bench_session = Arc::new(Mutex::new(BenchSession::new(base_session)));
|
||||
let bench_session_clone = bench_session.clone();
|
||||
// package session obj into benchmark-compatible struct
|
||||
let bench_agent = BenchAgent::new(Box::new(base_session));
|
||||
|
||||
if let Ok(metrics) = evaluation
|
||||
.run(Box::new(BenchAgentWrapper(bench_session)), work_dir)
|
||||
.await
|
||||
{
|
||||
for (name, metric) in metrics {
|
||||
result.add_metric(name, metric);
|
||||
}
|
||||
// Initialize logging with error capture
|
||||
let errors = Some(Arc::new(Mutex::new(bench_agent.get_errors().await)));
|
||||
logging::setup_logging(Some("bench"), errors).expect("Failed to initialize logging");
|
||||
|
||||
// Add any errors that occurred
|
||||
let agent = BenchAgentWrapper(bench_session_clone);
|
||||
for error in agent.get_errors().await {
|
||||
result.add_error(error);
|
||||
}
|
||||
}
|
||||
|
||||
let current_dir = std::env::current_dir()?;
|
||||
let output_str = serde_json::to_string_pretty(&result)?;
|
||||
std::fs::write(current_dir.join("eval_result.json"), &output_str)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub async fn run_benchmark(
|
||||
selectors: Vec<String>,
|
||||
include_dirs: Vec<PathBuf>,
|
||||
) -> anyhow::Result<BenchmarkResults> {
|
||||
let config = Config::global();
|
||||
let goose_model: String = config
|
||||
.get_param("GOOSE_MODEL")
|
||||
.expect("No model configured. Run 'goose configure' first");
|
||||
let provider_name: String = config
|
||||
.get_param("GOOSE_PROVIDER")
|
||||
.expect("No provider configured. Run 'goose configure' first");
|
||||
|
||||
let mut results = BenchmarkResults::new(provider_name.clone());
|
||||
|
||||
let work_dir = Mutex::new(BenchmarkWorkDir::new(
|
||||
format!("{}-{}", provider_name, goose_model),
|
||||
include_dirs.clone(),
|
||||
));
|
||||
|
||||
for (suite, evals) in EvaluationSuite::select(selectors).iter() {
|
||||
let mut suite_result = SuiteResult::new(suite.clone());
|
||||
for eval_selector in evals {
|
||||
if let Some(eval) = EvaluationSuite::from(eval_selector) {
|
||||
let mut work_dir = work_dir.lock().await;
|
||||
work_dir.set_eval(eval_selector);
|
||||
let eval_result = run_eval(eval, &mut work_dir).await?;
|
||||
suite_result.add_evaluation(eval_result);
|
||||
}
|
||||
}
|
||||
|
||||
results.add_suite(suite_result);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
pub async fn list_selectors() -> anyhow::Result<()> {
|
||||
let selector_eval_counts = EvaluationSuite::available_selectors();
|
||||
let mut keys: Vec<_> = selector_eval_counts.keys().collect();
|
||||
keys.sort();
|
||||
let max_key_len = keys.iter().map(|k| k.len()).max().unwrap_or(0);
|
||||
println!(
|
||||
"selector {} => Eval Count",
|
||||
" ".repeat(max_key_len - "selector".len())
|
||||
);
|
||||
println!("{}", "-".repeat(max_key_len + 6));
|
||||
for selector in keys {
|
||||
println!(
|
||||
"{} {} => {}",
|
||||
selector,
|
||||
" ".repeat(max_key_len - selector.len()),
|
||||
selector_eval_counts.get(selector).unwrap()
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
bench_agent
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user