feat: refactor register eval (#1713)

This commit is contained in:
marcelle
2025-03-18 15:18:09 -04:00
committed by GitHub
parent 5ed1f048ae
commit 4c03b34058
26 changed files with 166 additions and 121 deletions

1
Cargo.lock generated
View File

@@ -2352,6 +2352,7 @@ dependencies = [
"mcp-core",
"once_cell",
"paste",
"regex",
"serde",
"serde_json",
"tokio",

View File

@@ -23,6 +23,7 @@ tracing-subscriber = { version = "0.3", features = ["registry"] }
tokio = { version = "1.0", features = ["full"] }
include_dir = "0.7.4"
once_cell = "1.19"
regex = "1.11.1"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -14,8 +14,6 @@ pub struct BenchmarkWorkDir {
run_dir: PathBuf,
cwd: PathBuf,
run_name: String,
suite: Option<String>,
eval: Option<String>,
}
impl Default for BenchmarkWorkDir {
@@ -59,8 +57,6 @@ impl BenchmarkWorkDir {
run_dir,
cwd: base_path.clone(),
run_name,
suite: None,
eval: None,
}
}
fn copy_auto_included_dirs(dest: &Path) {
@@ -77,24 +73,10 @@ impl BenchmarkWorkDir {
self.cwd = path;
Ok(self)
}
pub fn set_suite(&mut self, suite: &str) {
self.eval = None;
self.suite = Some(suite.to_string());
let mut suite_dir = self.base_path.clone();
suite_dir.push(self.run_name.clone());
suite_dir.push(suite);
self.cd(suite_dir.clone()).unwrap_or_else(|_| {
panic!("Failed to execute cd into {}", suite_dir.clone().display())
});
}
pub fn set_eval(&mut self, eval: &str) {
self.eval = Some(eval.to_string());
let eval = eval.replace(":", std::path::MAIN_SEPARATOR_STR);
let mut eval_dir = self.base_path.clone();
eval_dir.push(self.run_name.clone());
eval_dir.push(self.suite.clone().unwrap());
eval_dir.push(eval);
self.cd(eval_dir.clone())

View File

@@ -0,0 +1,3 @@
// computer controller extension evals
mod script;
mod web_scrape;

View File

@@ -81,4 +81,4 @@ impl Evaluation for ComputerControllerScript {
}
}
register_evaluation!("computercontroller", ComputerControllerScript);
register_evaluation!(ComputerControllerScript);

View File

@@ -84,4 +84,4 @@ impl Evaluation for ComputerControllerWebScrape {
}
}
register_evaluation!("computercontroller", ComputerControllerWebScrape);
register_evaluation!(ComputerControllerWebScrape);

View File

@@ -124,4 +124,4 @@ impl Evaluation for DeveloperCreateFile {
}
}
register_evaluation!("developer", DeveloperCreateFile);
register_evaluation!(DeveloperCreateFile);

View File

@@ -85,4 +85,4 @@ impl Evaluation for DeveloperListFiles {
}
}
register_evaluation!("developer", DeveloperListFiles);
register_evaluation!(DeveloperListFiles);

View File

@@ -0,0 +1,3 @@
// developer extension evals
mod create_file;
mod list_files;

View File

@@ -102,4 +102,4 @@ impl Evaluation for DeveloperImage {
}
}
register_evaluation!("developer_image", DeveloperImage);
register_evaluation!(DeveloperImage);

View File

@@ -0,0 +1 @@
mod image;

View File

@@ -0,0 +1 @@
mod search_replace;

View File

@@ -106,4 +106,4 @@ impl Evaluation for DeveloperSearchReplace {
}
}
register_evaluation!("developer_search_replace", DeveloperSearchReplace);
register_evaluation!(DeveloperSearchReplace);

View File

@@ -43,4 +43,4 @@ impl Evaluation for ExampleEval {
}
}
register_evaluation!("core", ExampleEval);
register_evaluation!(ExampleEval);

View File

@@ -0,0 +1,2 @@
// memory extension evals
mod save_fact;

View File

@@ -86,4 +86,4 @@ impl Evaluation for MemoryRememberMemory {
}
}
register_evaluation!("memory", MemoryRememberMemory);
register_evaluation!(MemoryRememberMemory);

View File

@@ -1,11 +1,6 @@
mod computercontroller;
mod developer;
mod developer_image;
mod developer_search_replace;
mod example;
// developer extension evals
mod create_file;
mod image;
mod list_files;
mod search_replace;
// computer controller extension evals
mod script;
mod web_scrape;
// memory extension evals
mod save_fact;
mod memory;

View File

@@ -1,62 +1,127 @@
pub use super::Evaluation;
use regex::Regex;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};
type EvaluationConstructor = fn() -> Box<dyn Evaluation>;
type Registry = &'static RwLock<HashMap<&'static str, EvaluationConstructor>>;
// Use std::sync::RwLock for interior mutability
static EVALUATION_REGISTRY: OnceLock<RwLock<HashMap<&'static str, Vec<EvaluationConstructor>>>> =
static EVAL_REGISTRY: OnceLock<RwLock<HashMap<&'static str, EvaluationConstructor>>> =
OnceLock::new();
/// Initialize the registry if it hasn't been initialized
fn registry() -> &'static RwLock<HashMap<&'static str, Vec<EvaluationConstructor>>> {
EVALUATION_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
fn eval_registry() -> Registry {
EVAL_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
}
/// Register a new evaluation version
pub fn register_evaluation(suite_name: &'static str, constructor: fn() -> Box<dyn Evaluation>) {
let registry = registry();
pub fn register_eval(selector: &'static str, constructor: fn() -> Box<dyn Evaluation>) {
let registry = eval_registry();
if let Ok(mut map) = registry.write() {
map.entry(suite_name)
.or_insert_with(Vec::new)
.push(constructor);
map.insert(selector, constructor);
}
}
pub struct EvaluationSuiteFactory;
pub struct EvaluationSuite;
impl EvaluationSuiteFactory {
pub fn create(suite_name: &str) -> Option<Vec<Box<dyn Evaluation>>> {
let registry = registry();
impl EvaluationSuite {
pub fn from(selector: &str) -> Option<Box<dyn Evaluation>> {
let registry = eval_registry();
let map = registry
.read()
.expect("Failed to read the benchmark evaluation registry.");
let constructors = map.get(suite_name)?;
let instances = constructors
.iter()
.map(|&constructor| constructor())
.collect::<Vec<_>>();
let constructor = map.get(selector)?;
let instance = constructor();
Some(instances)
Some(instance)
}
pub fn available_evaluations() -> Vec<&'static str> {
registry()
pub fn registered_evals() -> Vec<&'static str> {
let registry = eval_registry();
let map = registry
.read()
.map(|map| map.keys().copied().collect())
.unwrap_or_default()
.expect("Failed to read the benchmark evaluation registry.");
let evals: Vec<_> = map.keys().copied().collect();
evals
}
/// Groups the registered evals chosen by `selectors` under their suite.
///
/// An empty `selectors` list selects every registered eval. The suite key is the
/// eval's selector with its final `:name` segment removed; a selector without such
/// a segment is used as the key unchanged.
pub fn select(selectors: Vec<String>) -> HashMap<String, Vec<&'static str>> {
    // Matches the trailing `:eval_name` segment of a selector.
    let trailing_segment = Regex::new(r":\w+$").unwrap();
    let select_all = selectors.is_empty();

    let mut by_suite: HashMap<String, Vec<&'static str>> = HashMap::new();
    for eval in EvaluationSuite::registered_evals() {
        if !select_all && !matches_any_selectors(eval, &selectors) {
            continue;
        }
        let suite = trailing_segment.replace(eval, "").into_owned();
        by_suite.entry(suite).or_default().push(eval);
    }
    by_suite
}
/// Counts, for every selector prefix, how many registered evals it selects.
///
/// Each registered selector adds one to its own entry and to the entry of every
/// prefix that ends just before one of its `:` separators, so `a:b:c` counts
/// toward `a`, `a:b` and `a:b:c`.
pub fn available_selectors() -> HashMap<String, usize> {
    let mut tally: HashMap<String, usize> = HashMap::new();
    for registered in EvaluationSuite::registered_evals() {
        // End offsets of every prefix: one per `:` plus the full selector itself.
        let prefix_ends = registered
            .match_indices(':')
            .map(|(colon_at, _)| colon_at)
            .chain(std::iter::once(registered.len()));
        for end in prefix_ends {
            *tally.entry(registered[..end].to_string()).or_default() += 1;
        }
    }
    tally
}
}
/// Reports whether `eval` is selected by at least one entry of `selectors`.
///
/// A selector selects an eval when it equals the eval's full selector or names one
/// of its enclosing levels, i.e. the eval's selector continues with `:` immediately
/// after the selector text. `developer` selects `developer:create_file`, while
/// `dev` does not. An empty selector selects nothing.
fn matches_any_selectors(eval: &str, selectors: &[String]) -> bool {
    // Every selector is tested against the full, untouched `eval`, so the outcome
    // does not depend on the order in which selectors are listed.
    selectors.iter().any(|selector| {
        if selector.is_empty() {
            return false;
        }
        match eval.strip_prefix(selector.as_str()) {
            // Whole-level match only: nothing left over, or the remainder starts
            // a new `:`-separated level (never a match half-way through a word).
            Some(rest) => rest.is_empty() || rest.starts_with(':'),
            None => false,
        }
    })
}
#[macro_export]
macro_rules! register_evaluation {
($suite_name:expr, $evaluation_type:ty) => {
($evaluation_type:ty) => {
paste::paste! {
#[ctor::ctor]
#[allow(non_snake_case)]
fn [<__register_evaluation_ $suite_name>]() {
$crate::eval_suites::factory::register_evaluation($suite_name, || {
fn [<__register_evaluation_ $evaluation_type>]() {
let mut path = std::path::PathBuf::from(file!());
path.set_extension("");
let eval_suites_dir = "eval_suites";
let eval_selector = {
let s = path.components()
.skip_while(|comp| comp.as_os_str() != eval_suites_dir)
.skip(1)
.map(|comp| comp.as_os_str().to_string_lossy().to_string())
.collect::<Vec<_>>()
.join(":");
Box::leak(s.into_boxed_str())
};
$crate::eval_suites::factory::register_eval(eval_selector, || {
Box::new(<$evaluation_type>::new())
});
}

View File

@@ -6,6 +6,6 @@ mod utils;
mod vibes;
pub use evaluation::*;
pub use factory::{register_evaluation, EvaluationSuiteFactory};
pub use factory::{register_eval, EvaluationSuite};
pub use metrics::*;
pub use utils::*;

View File

@@ -86,4 +86,4 @@ impl Evaluation for BlogSummary {
}
}
register_evaluation!("vibes", BlogSummary);
register_evaluation!(BlogSummary);

View File

@@ -118,4 +118,4 @@ impl Evaluation for FlappyBird {
}
}
register_evaluation!("vibes", FlappyBird);
register_evaluation!(FlappyBird);

View File

@@ -96,4 +96,4 @@ impl Evaluation for GooseWiki {
}
}
register_evaluation!("vibes", GooseWiki);
register_evaluation!(GooseWiki);

View File

@@ -106,4 +106,4 @@ Present the information in order of significance or quality. Focus specifically
}
}
register_evaluation!("vibes", RestaurantResearch);
register_evaluation!(RestaurantResearch);

View File

@@ -174,4 +174,4 @@ After writing the script, run it using python3 and show the results. Do not ask
}
}
register_evaluation!("vibes", SquirrelCensus);
register_evaluation!(SquirrelCensus);

View File

@@ -5,9 +5,8 @@ use async_trait::async_trait;
use goose::config::Config;
use goose::message::Message;
use goose_bench::bench_work_dir::BenchmarkWorkDir;
use goose_bench::eval_suites::{BenchAgent, BenchAgentError, Evaluation, EvaluationSuiteFactory};
use goose_bench::eval_suites::{BenchAgent, BenchAgentError, Evaluation, EvaluationSuite};
use goose_bench::reporting::{BenchmarkResults, EvaluationResult, SuiteResult};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
@@ -116,31 +115,10 @@ async fn run_eval(
Ok(result)
}
async fn run_suite(suite: &str, work_dir: &mut BenchmarkWorkDir) -> anyhow::Result<SuiteResult> {
let mut suite_result = SuiteResult::new(suite.to_string());
let eval_work_dir_guard = Mutex::new(work_dir);
if let Some(evals) = EvaluationSuiteFactory::create(suite) {
for eval in evals {
let mut eval_work_dir = eval_work_dir_guard.lock().await;
eval_work_dir.set_eval(eval.name());
let eval_result = run_eval(eval, &mut eval_work_dir).await?;
suite_result.add_evaluation(eval_result);
}
}
Ok(suite_result)
}
pub async fn run_benchmark(
suites: Vec<String>,
selectors: Vec<String>,
include_dirs: Vec<PathBuf>,
) -> anyhow::Result<BenchmarkResults> {
let suites = EvaluationSuiteFactory::available_evaluations()
.into_iter()
.filter(|&s| suites.contains(&s.to_string()))
.collect::<Vec<_>>();
let config = Config::global();
let goose_model: String = config
.get_param("GOOSE_MODEL")
@@ -151,30 +129,45 @@ pub async fn run_benchmark(
let mut results = BenchmarkResults::new(provider_name.clone());
let suite_work_dir = Mutex::new(BenchmarkWorkDir::new(
let work_dir = Mutex::new(BenchmarkWorkDir::new(
format!("{}-{}", provider_name, goose_model),
include_dirs.clone(),
));
for suite in suites {
let mut work_dir = suite_work_dir.lock().await;
work_dir.set_suite(suite);
let suite_result = run_suite(suite, &mut work_dir).await?;
for (suite, evals) in EvaluationSuite::select(selectors).iter() {
let mut suite_result = SuiteResult::new(suite.clone());
for eval_selector in evals {
if let Some(eval) = EvaluationSuite::from(eval_selector) {
let mut work_dir = work_dir.lock().await;
work_dir.set_eval(eval_selector);
let eval_result = run_eval(eval, &mut work_dir).await?;
suite_result.add_evaluation(eval_result);
}
}
results.add_suite(suite_result);
}
Ok(results)
}
pub async fn list_suites() -> anyhow::Result<HashMap<String, usize>> {
let suites = EvaluationSuiteFactory::available_evaluations();
let mut suite_counts = HashMap::new();
for suite in suites {
if let Some(evals) = EvaluationSuiteFactory::create(suite) {
suite_counts.insert(suite.to_string(), evals.len());
pub async fn list_selectors() -> anyhow::Result<()> {
let selector_eval_counts = EvaluationSuite::available_selectors();
let mut keys: Vec<_> = selector_eval_counts.keys().collect();
keys.sort();
let max_key_len = keys.iter().map(|k| k.len()).max().unwrap_or(0);
println!(
"selector {} => Eval Count",
" ".repeat(max_key_len - "selector".len())
);
println!("{}", "-".repeat(max_key_len + 6));
for selector in keys {
println!(
"{} {} => {}",
selector,
" ".repeat(max_key_len - selector.len()),
selector_eval_counts.get(selector).unwrap()
);
}
}
Ok(suite_counts)
Ok(())
}

View File

@@ -4,7 +4,7 @@ use clap::{Args, Parser, Subcommand};
use goose::config::Config;
use goose_cli::commands::agent_version::AgentCommand;
use goose_cli::commands::bench::{list_suites, run_benchmark};
use goose_cli::commands::bench::{list_selectors, run_benchmark};
use goose_cli::commands::configure::handle_configure;
use goose_cli::commands::info::handle_info;
use goose_cli::commands::mcp::run_server;
@@ -237,13 +237,13 @@ enum Command {
Bench {
#[arg(
short = 's',
long = "suites",
value_name = "BENCH_SUITE_NAME",
long = "selectors",
value_name = "EVALUATIONS_SELECTOR",
help = "Run this list of bench-suites.",
long_help = "Specify a comma-separated list of evaluation-suite names to be run.",
value_delimiter = ','
)]
suites: Vec<String>,
selectors: Vec<String>,
#[arg(
short = 'i',
@@ -266,7 +266,7 @@ enum Command {
#[arg(
long = "list",
value_name = "LIST",
help = "List all available bench suites."
help = "List all selectors and the number of evaluations they select."
)]
list: bool,
@@ -416,7 +416,7 @@ async fn main() -> Result<()> {
return Ok(());
}
Some(Command::Bench {
suites,
selectors,
include_dirs,
repeat,
list,
@@ -425,24 +425,22 @@ async fn main() -> Result<()> {
summary,
}) => {
if list {
let suites = list_suites().await?;
for suite in suites.keys() {
println!("{}: {}", suite, suites.get(suite).unwrap());
return list_selectors().await;
}
return Ok(());
}
let suites = if suites.is_empty() {
let selectors = if selectors.is_empty() {
vec!["core".to_string()]
} else {
suites
selectors
};
let current_dir = std::env::current_dir()?;
for i in 0..repeat {
if repeat > 1 {
println!("\nRun {} of {}:", i + 1, repeat);
}
let results = run_benchmark(suites.clone(), include_dirs.clone()).await?;
let results = run_benchmark(selectors.clone(), include_dirs.clone()).await?;
// Handle output based on format
let output_str = match format.as_str() {