mirror of
https://github.com/aljazceru/goose.git
synced 2026-01-06 16:04:28 +01:00
feat: refactor register eval (#1713)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2352,6 +2352,7 @@ dependencies = [
|
||||
"mcp-core",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
|
||||
@@ -23,6 +23,7 @@ tracing-subscriber = { version = "0.3", features = ["registry"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
include_dir = "0.7.4"
|
||||
once_cell = "1.19"
|
||||
regex = "1.11.1"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
winapi = { version = "0.3", features = ["wincred"] }
|
||||
@@ -14,8 +14,6 @@ pub struct BenchmarkWorkDir {
|
||||
run_dir: PathBuf,
|
||||
cwd: PathBuf,
|
||||
run_name: String,
|
||||
suite: Option<String>,
|
||||
eval: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for BenchmarkWorkDir {
|
||||
@@ -59,8 +57,6 @@ impl BenchmarkWorkDir {
|
||||
run_dir,
|
||||
cwd: base_path.clone(),
|
||||
run_name,
|
||||
suite: None,
|
||||
eval: None,
|
||||
}
|
||||
}
|
||||
fn copy_auto_included_dirs(dest: &Path) {
|
||||
@@ -77,24 +73,10 @@ impl BenchmarkWorkDir {
|
||||
self.cwd = path;
|
||||
Ok(self)
|
||||
}
|
||||
pub fn set_suite(&mut self, suite: &str) {
|
||||
self.eval = None;
|
||||
self.suite = Some(suite.to_string());
|
||||
|
||||
let mut suite_dir = self.base_path.clone();
|
||||
suite_dir.push(self.run_name.clone());
|
||||
suite_dir.push(suite);
|
||||
|
||||
self.cd(suite_dir.clone()).unwrap_or_else(|_| {
|
||||
panic!("Failed to execute cd into {}", suite_dir.clone().display())
|
||||
});
|
||||
}
|
||||
pub fn set_eval(&mut self, eval: &str) {
|
||||
self.eval = Some(eval.to_string());
|
||||
|
||||
let eval = eval.replace(":", std::path::MAIN_SEPARATOR_STR);
|
||||
let mut eval_dir = self.base_path.clone();
|
||||
eval_dir.push(self.run_name.clone());
|
||||
eval_dir.push(self.suite.clone().unwrap());
|
||||
eval_dir.push(eval);
|
||||
|
||||
self.cd(eval_dir.clone())
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
// computer controller extension evals
|
||||
mod script;
|
||||
mod web_scrape;
|
||||
@@ -81,4 +81,4 @@ impl Evaluation for ComputerControllerScript {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("computercontroller", ComputerControllerScript);
|
||||
register_evaluation!(ComputerControllerScript);
|
||||
@@ -84,4 +84,4 @@ impl Evaluation for ComputerControllerWebScrape {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("computercontroller", ComputerControllerWebScrape);
|
||||
register_evaluation!(ComputerControllerWebScrape);
|
||||
@@ -124,4 +124,4 @@ impl Evaluation for DeveloperCreateFile {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("developer", DeveloperCreateFile);
|
||||
register_evaluation!(DeveloperCreateFile);
|
||||
@@ -85,4 +85,4 @@ impl Evaluation for DeveloperListFiles {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("developer", DeveloperListFiles);
|
||||
register_evaluation!(DeveloperListFiles);
|
||||
3
crates/goose-bench/src/eval_suites/core/developer/mod.rs
Normal file
3
crates/goose-bench/src/eval_suites/core/developer/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
// developer extension evals
|
||||
mod create_file;
|
||||
mod list_files;
|
||||
@@ -102,4 +102,4 @@ impl Evaluation for DeveloperImage {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("developer_image", DeveloperImage);
|
||||
register_evaluation!(DeveloperImage);
|
||||
@@ -0,0 +1 @@
|
||||
mod image;
|
||||
@@ -0,0 +1 @@
|
||||
mod search_replace;
|
||||
@@ -106,4 +106,4 @@ impl Evaluation for DeveloperSearchReplace {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("developer_search_replace", DeveloperSearchReplace);
|
||||
register_evaluation!(DeveloperSearchReplace);
|
||||
@@ -43,4 +43,4 @@ impl Evaluation for ExampleEval {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("core", ExampleEval);
|
||||
register_evaluation!(ExampleEval);
|
||||
|
||||
2
crates/goose-bench/src/eval_suites/core/memory/mod.rs
Normal file
2
crates/goose-bench/src/eval_suites/core/memory/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
// memory extension evals
|
||||
mod save_fact;
|
||||
@@ -86,4 +86,4 @@ impl Evaluation for MemoryRememberMemory {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("memory", MemoryRememberMemory);
|
||||
register_evaluation!(MemoryRememberMemory);
|
||||
@@ -1,11 +1,6 @@
|
||||
mod computercontroller;
|
||||
mod developer;
|
||||
mod developer_image;
|
||||
mod developer_search_replace;
|
||||
mod example;
|
||||
// developer extension evals
|
||||
mod create_file;
|
||||
mod image;
|
||||
mod list_files;
|
||||
mod search_replace;
|
||||
// computer controller extension evals
|
||||
mod script;
|
||||
mod web_scrape;
|
||||
// memory extension evals
|
||||
mod save_fact;
|
||||
mod memory;
|
||||
|
||||
@@ -1,62 +1,127 @@
|
||||
pub use super::Evaluation;
|
||||
use regex::Regex;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{OnceLock, RwLock};
|
||||
|
||||
type EvaluationConstructor = fn() -> Box<dyn Evaluation>;
|
||||
type Registry = &'static RwLock<HashMap<&'static str, EvaluationConstructor>>;
|
||||
|
||||
// Use std::sync::RwLock for interior mutability
|
||||
static EVALUATION_REGISTRY: OnceLock<RwLock<HashMap<&'static str, Vec<EvaluationConstructor>>>> =
|
||||
static EVAL_REGISTRY: OnceLock<RwLock<HashMap<&'static str, EvaluationConstructor>>> =
|
||||
OnceLock::new();
|
||||
|
||||
/// Initialize the registry if it hasn't been initialized
|
||||
fn registry() -> &'static RwLock<HashMap<&'static str, Vec<EvaluationConstructor>>> {
|
||||
EVALUATION_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
fn eval_registry() -> Registry {
|
||||
EVAL_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Register a new evaluation version
|
||||
pub fn register_evaluation(suite_name: &'static str, constructor: fn() -> Box<dyn Evaluation>) {
|
||||
let registry = registry();
|
||||
pub fn register_eval(selector: &'static str, constructor: fn() -> Box<dyn Evaluation>) {
|
||||
let registry = eval_registry();
|
||||
if let Ok(mut map) = registry.write() {
|
||||
map.entry(suite_name)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(constructor);
|
||||
map.insert(selector, constructor);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EvaluationSuiteFactory;
|
||||
pub struct EvaluationSuite;
|
||||
|
||||
impl EvaluationSuiteFactory {
|
||||
pub fn create(suite_name: &str) -> Option<Vec<Box<dyn Evaluation>>> {
|
||||
let registry = registry();
|
||||
impl EvaluationSuite {
|
||||
pub fn from(selector: &str) -> Option<Box<dyn Evaluation>> {
|
||||
let registry = eval_registry();
|
||||
let map = registry
|
||||
.read()
|
||||
.expect("Failed to read the benchmark evaluation registry.");
|
||||
|
||||
let constructors = map.get(suite_name)?;
|
||||
let instances = constructors
|
||||
.iter()
|
||||
.map(|&constructor| constructor())
|
||||
.collect::<Vec<_>>();
|
||||
let constructor = map.get(selector)?;
|
||||
let instance = constructor();
|
||||
|
||||
Some(instances)
|
||||
Some(instance)
|
||||
}
|
||||
|
||||
pub fn available_evaluations() -> Vec<&'static str> {
|
||||
registry()
|
||||
pub fn registered_evals() -> Vec<&'static str> {
|
||||
let registry = eval_registry();
|
||||
let map = registry
|
||||
.read()
|
||||
.map(|map| map.keys().copied().collect())
|
||||
.unwrap_or_default()
|
||||
.expect("Failed to read the benchmark evaluation registry.");
|
||||
|
||||
let evals: Vec<_> = map.keys().copied().collect();
|
||||
evals
|
||||
}
|
||||
pub fn select(selectors: Vec<String>) -> HashMap<String, Vec<&'static str>> {
|
||||
let eval_name_pattern = Regex::new(r":\w+$").unwrap();
|
||||
let grouped_by_suite: HashMap<String, Vec<&'static str>> =
|
||||
EvaluationSuite::registered_evals()
|
||||
.into_iter()
|
||||
.filter(|&eval| selectors.is_empty() || matches_any_selectors(eval, &selectors))
|
||||
.fold(HashMap::new(), |mut suites, eval| {
|
||||
let suite = match eval_name_pattern.replace(eval, "") {
|
||||
Cow::Borrowed(s) => s.to_string(),
|
||||
Cow::Owned(s) => s,
|
||||
};
|
||||
suites.entry(suite).or_default().push(eval);
|
||||
suites
|
||||
});
|
||||
|
||||
grouped_by_suite
|
||||
}
|
||||
|
||||
pub fn available_selectors() -> HashMap<String, usize> {
|
||||
let mut counts: HashMap<String, usize> = HashMap::new();
|
||||
for selector in EvaluationSuite::registered_evals() {
|
||||
let parts = selector.split(":").collect::<Vec<_>>();
|
||||
for i in 0..parts.len() {
|
||||
let sel = parts[..i + 1].join(":");
|
||||
*counts.entry(sel).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
counts
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports whether `eval` is chosen by at least one entry of `selectors`.
///
/// A selector chooses an evaluation when it equals the evaluation's full
/// `:`-separated path, or is a prefix of it that ends exactly on a segment
/// boundary. `core:developer` therefore chooses `core:developer:create_file`
/// but not `core:developer_image:image`. Empty selectors never match.
fn matches_any_selectors(eval: &str, selectors: &[String]) -> bool {
    // Each selector is tested against the complete `eval` path. Trimming one
    // shared copy of the path across selectors would leave only the top-level
    // segment for every selector after the first, hiding nested matches.
    selectors.iter().any(|selector| {
        !selector.is_empty()
            && eval
                .strip_prefix(selector.as_str())
                // Either an exact match, or the selector stops right before a `:`.
                .is_some_and(|rest| rest.is_empty() || rest.starts_with(':'))
    })
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! register_evaluation {
|
||||
($suite_name:expr, $evaluation_type:ty) => {
|
||||
($evaluation_type:ty) => {
|
||||
paste::paste! {
|
||||
#[ctor::ctor]
|
||||
#[allow(non_snake_case)]
|
||||
fn [<__register_evaluation_ $suite_name>]() {
|
||||
$crate::eval_suites::factory::register_evaluation($suite_name, || {
|
||||
fn [<__register_evaluation_ $evaluation_type>]() {
|
||||
let mut path = std::path::PathBuf::from(file!());
|
||||
path.set_extension("");
|
||||
let eval_suites_dir = "eval_suites";
|
||||
let eval_selector = {
|
||||
let s = path.components()
|
||||
.skip_while(|comp| comp.as_os_str() != eval_suites_dir)
|
||||
.skip(1)
|
||||
.map(|comp| comp.as_os_str().to_string_lossy().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(":");
|
||||
Box::leak(s.into_boxed_str())
|
||||
};
|
||||
|
||||
$crate::eval_suites::factory::register_eval(eval_selector, || {
|
||||
Box::new(<$evaluation_type>::new())
|
||||
});
|
||||
}
|
||||
|
||||
@@ -6,6 +6,6 @@ mod utils;
|
||||
mod vibes;
|
||||
|
||||
pub use evaluation::*;
|
||||
pub use factory::{register_evaluation, EvaluationSuiteFactory};
|
||||
pub use factory::{register_eval, EvaluationSuite};
|
||||
pub use metrics::*;
|
||||
pub use utils::*;
|
||||
|
||||
@@ -86,4 +86,4 @@ impl Evaluation for BlogSummary {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("vibes", BlogSummary);
|
||||
register_evaluation!(BlogSummary);
|
||||
|
||||
@@ -118,4 +118,4 @@ impl Evaluation for FlappyBird {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("vibes", FlappyBird);
|
||||
register_evaluation!(FlappyBird);
|
||||
|
||||
@@ -96,4 +96,4 @@ impl Evaluation for GooseWiki {
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("vibes", GooseWiki);
|
||||
register_evaluation!(GooseWiki);
|
||||
|
||||
@@ -106,4 +106,4 @@ Present the information in order of significance or quality. Focus specifically
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("vibes", RestaurantResearch);
|
||||
register_evaluation!(RestaurantResearch);
|
||||
|
||||
@@ -174,4 +174,4 @@ After writing the script, run it using python3 and show the results. Do not ask
|
||||
}
|
||||
}
|
||||
|
||||
register_evaluation!("vibes", SquirrelCensus);
|
||||
register_evaluation!(SquirrelCensus);
|
||||
|
||||
@@ -5,9 +5,8 @@ use async_trait::async_trait;
|
||||
use goose::config::Config;
|
||||
use goose::message::Message;
|
||||
use goose_bench::bench_work_dir::BenchmarkWorkDir;
|
||||
use goose_bench::eval_suites::{BenchAgent, BenchAgentError, Evaluation, EvaluationSuiteFactory};
|
||||
use goose_bench::eval_suites::{BenchAgent, BenchAgentError, Evaluation, EvaluationSuite};
|
||||
use goose_bench::reporting::{BenchmarkResults, EvaluationResult, SuiteResult};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
@@ -116,31 +115,10 @@ async fn run_eval(
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn run_suite(suite: &str, work_dir: &mut BenchmarkWorkDir) -> anyhow::Result<SuiteResult> {
|
||||
let mut suite_result = SuiteResult::new(suite.to_string());
|
||||
let eval_work_dir_guard = Mutex::new(work_dir);
|
||||
|
||||
if let Some(evals) = EvaluationSuiteFactory::create(suite) {
|
||||
for eval in evals {
|
||||
let mut eval_work_dir = eval_work_dir_guard.lock().await;
|
||||
eval_work_dir.set_eval(eval.name());
|
||||
let eval_result = run_eval(eval, &mut eval_work_dir).await?;
|
||||
suite_result.add_evaluation(eval_result);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(suite_result)
|
||||
}
|
||||
|
||||
pub async fn run_benchmark(
|
||||
suites: Vec<String>,
|
||||
selectors: Vec<String>,
|
||||
include_dirs: Vec<PathBuf>,
|
||||
) -> anyhow::Result<BenchmarkResults> {
|
||||
let suites = EvaluationSuiteFactory::available_evaluations()
|
||||
.into_iter()
|
||||
.filter(|&s| suites.contains(&s.to_string()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let config = Config::global();
|
||||
let goose_model: String = config
|
||||
.get_param("GOOSE_MODEL")
|
||||
@@ -151,30 +129,45 @@ pub async fn run_benchmark(
|
||||
|
||||
let mut results = BenchmarkResults::new(provider_name.clone());
|
||||
|
||||
let suite_work_dir = Mutex::new(BenchmarkWorkDir::new(
|
||||
let work_dir = Mutex::new(BenchmarkWorkDir::new(
|
||||
format!("{}-{}", provider_name, goose_model),
|
||||
include_dirs.clone(),
|
||||
));
|
||||
|
||||
for suite in suites {
|
||||
let mut work_dir = suite_work_dir.lock().await;
|
||||
work_dir.set_suite(suite);
|
||||
let suite_result = run_suite(suite, &mut work_dir).await?;
|
||||
for (suite, evals) in EvaluationSuite::select(selectors).iter() {
|
||||
let mut suite_result = SuiteResult::new(suite.clone());
|
||||
for eval_selector in evals {
|
||||
if let Some(eval) = EvaluationSuite::from(eval_selector) {
|
||||
let mut work_dir = work_dir.lock().await;
|
||||
work_dir.set_eval(eval_selector);
|
||||
let eval_result = run_eval(eval, &mut work_dir).await?;
|
||||
suite_result.add_evaluation(eval_result);
|
||||
}
|
||||
}
|
||||
|
||||
results.add_suite(suite_result);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
pub async fn list_suites() -> anyhow::Result<HashMap<String, usize>> {
|
||||
let suites = EvaluationSuiteFactory::available_evaluations();
|
||||
let mut suite_counts = HashMap::new();
|
||||
|
||||
for suite in suites {
|
||||
if let Some(evals) = EvaluationSuiteFactory::create(suite) {
|
||||
suite_counts.insert(suite.to_string(), evals.len());
|
||||
pub async fn list_selectors() -> anyhow::Result<()> {
|
||||
let selector_eval_counts = EvaluationSuite::available_selectors();
|
||||
let mut keys: Vec<_> = selector_eval_counts.keys().collect();
|
||||
keys.sort();
|
||||
let max_key_len = keys.iter().map(|k| k.len()).max().unwrap_or(0);
|
||||
println!(
|
||||
"selector {} => Eval Count",
|
||||
" ".repeat(max_key_len - "selector".len())
|
||||
);
|
||||
println!("{}", "-".repeat(max_key_len + 6));
|
||||
for selector in keys {
|
||||
println!(
|
||||
"{} {} => {}",
|
||||
selector,
|
||||
" ".repeat(max_key_len - selector.len()),
|
||||
selector_eval_counts.get(selector).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(suite_counts)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use clap::{Args, Parser, Subcommand};
|
||||
use goose::config::Config;
|
||||
|
||||
use goose_cli::commands::agent_version::AgentCommand;
|
||||
use goose_cli::commands::bench::{list_suites, run_benchmark};
|
||||
use goose_cli::commands::bench::{list_selectors, run_benchmark};
|
||||
use goose_cli::commands::configure::handle_configure;
|
||||
use goose_cli::commands::info::handle_info;
|
||||
use goose_cli::commands::mcp::run_server;
|
||||
@@ -237,13 +237,13 @@ enum Command {
|
||||
Bench {
|
||||
#[arg(
|
||||
short = 's',
|
||||
long = "suites",
|
||||
value_name = "BENCH_SUITE_NAME",
|
||||
long = "selectors",
|
||||
value_name = "EVALUATIONS_SELECTOR",
|
||||
help = "Run this list of bench-suites.",
|
||||
long_help = "Specify a comma-separated list of evaluation-suite names to be run.",
|
||||
value_delimiter = ','
|
||||
)]
|
||||
suites: Vec<String>,
|
||||
selectors: Vec<String>,
|
||||
|
||||
#[arg(
|
||||
short = 'i',
|
||||
@@ -266,7 +266,7 @@ enum Command {
|
||||
#[arg(
|
||||
long = "list",
|
||||
value_name = "LIST",
|
||||
help = "List all available bench suites."
|
||||
help = "List all selectors and the number of evaluations they select."
|
||||
)]
|
||||
list: bool,
|
||||
|
||||
@@ -416,7 +416,7 @@ async fn main() -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Bench {
|
||||
suites,
|
||||
selectors,
|
||||
include_dirs,
|
||||
repeat,
|
||||
list,
|
||||
@@ -425,24 +425,22 @@ async fn main() -> Result<()> {
|
||||
summary,
|
||||
}) => {
|
||||
if list {
|
||||
let suites = list_suites().await?;
|
||||
for suite in suites.keys() {
|
||||
println!("{}: {}", suite, suites.get(suite).unwrap());
|
||||
return list_selectors().await;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
let suites = if suites.is_empty() {
|
||||
|
||||
let selectors = if selectors.is_empty() {
|
||||
vec!["core".to_string()]
|
||||
} else {
|
||||
suites
|
||||
selectors
|
||||
};
|
||||
|
||||
let current_dir = std::env::current_dir()?;
|
||||
|
||||
for i in 0..repeat {
|
||||
if repeat > 1 {
|
||||
println!("\nRun {} of {}:", i + 1, repeat);
|
||||
}
|
||||
let results = run_benchmark(suites.clone(), include_dirs.clone()).await?;
|
||||
let results = run_benchmark(selectors.clone(), include_dirs.clone()).await?;
|
||||
|
||||
// Handle output based on format
|
||||
let output_str = match format.as_str() {
|
||||
|
||||
Reference in New Issue
Block a user