mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-17 05:54:26 +01:00
feat(benchmark/cli): Add challenge list, challenge info subcommands
- Add `challenge list` command with options `--all`, `--names`, `--json` - Add `tabulate` dependency - Add `.utils.utils.sorted_by_enum_index` function to easily sort lists by an enum value/property based on the order of the enum's definition - Add `challenge info [name]` command with option `--json` - Add `.utils.utils.pretty_print_model` routine to pretty-print Pydantic models - Refactor `config` subcommand to use `pretty_print_model`
This commit is contained in:
@@ -202,15 +202,136 @@ def serve(port: Optional[int] = None):
|
|||||||
@cli.command()
def config():
    """Displays info regarding the present AGBenchmark config."""
    from .utils.utils import pretty_print_model

    try:
        benchmark_config = AgentBenchmarkConfig.load()
    except FileNotFoundError as e:
        click.echo(e, err=True)
        # Click discards a command callback's return value when it runs in
        # standalone mode, so a bare `return 1` would still end the process
        # with status 0. Exiting through the context sets the exit code to 1
        # instead (and never returns, so the code below is not reached).
        click.get_current_context().exit(1)

    # The model's own header line is suppressed; only `key = value` rows are shown.
    pretty_print_model(benchmark_config, include_header=False)
|
|
||||||
|
@cli.group()
def challenge():
    """Commands for listing and inspecting challenges."""
    # Raise the root logger's threshold to WARNING. This callback runs before
    # any subcommand of the group, so their output is not mixed with
    # lower-severity log records.
    logging.getLogger().setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
@challenge.command("list")
@click.option(
    "--all", "include_unavailable", is_flag=True, help="Include unavailable challenges."
)
@click.option(
    "--names", "only_names", is_flag=True, help="List only the challenge names."
)
@click.option("--json", "output_json", is_flag=True, help="Output the list as JSON.")
def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool):
    """Lists [available|all] challenges."""
    import json
    from itertools import cycle

    from tabulate import tabulate

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.data_types import Category, DifficultyLevel
    from .utils.utils import sorted_by_enum_index

    # The palettes are cycled instead of zipped one-to-one: a plain `zip` stops
    # at the shorter input, which would leave any enum member beyond the
    # palette length without a colour and raise KeyError when rendering.
    DIFFICULTY_COLORS = dict(
        zip(
            DifficultyLevel,
            cycle(
                ["black", "blue", "cyan", "green", "yellow", "red", "magenta", "white"]
            ),
        )
    )
    CATEGORY_COLORS = {
        category: f"bright_{color}"
        for category, color in zip(
            Category,
            cycle(
                ["blue", "cyan", "green", "yellow", "magenta", "red", "white", "black"]
            ),
        )
    }

    # Load challenges from both sources, drop the unavailable ones unless
    # `--all` was given, and order them by difficulty (in enum definition order).
    loaded = [
        *load_builtin_challenges(),
        *load_webarena_challenges(skip_unavailable=False),
    ]
    challenges = sorted_by_enum_index(
        (c for c in loaded if c.info.available or include_unavailable),
        DifficultyLevel,
        key=lambda c: c.info.difficulty,
    )

    if only_names:
        if output_json:
            click.echo(json.dumps([c.info.name for c in challenges]))
            return

        for c in challenges:
            click.echo(
                click.style(c.info.name, fg=None if c.info.available else "black")
            )
        return

    if output_json:
        # `.json()` produces a JSON string per challenge; parse it back so that
        # the combined output is a single JSON array of objects.
        click.echo(json.dumps([json.loads(c.info.json()) for c in challenges]))
        return

    def table_row(c) -> tuple:
        """Build the (name, difficulty, categories) cells for one challenge."""
        if c.info.difficulty:
            difficulty = click.style(
                c.info.difficulty.value,
                fg=DIFFICULTY_COLORS[c.info.difficulty],
            )
        else:
            difficulty = click.style("-", fg="black")

        categories = " ".join(
            click.style(cat.value, fg=CATEGORY_COLORS[cat])
            for cat in sorted_by_enum_index(c.info.category, Category)
        )

        cells = (c.info.name, difficulty, categories)
        if c.info.available:
            return cells
        # Rows of unavailable challenges are additionally styled black.
        return tuple(click.style(cell, fg="black") for cell in cells)

    headers = tuple(
        click.style(h, bold=True) for h in ("Name", "Difficulty", "Categories")
    )
    table = [table_row(c) for c in challenges]
    click.echo(tabulate(table, headers=headers))
|
||||||
|
|
||||||
|
|
||||||
|
@challenge.command()
@click.option("--json", is_flag=True, help="Output the challenge info as JSON.")
@click.argument("name")
def info(name: str, json: bool):
    """Shows the info of the challenge with the given NAME."""
    from itertools import chain

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.utils import pretty_print_model

    candidates = chain(
        load_builtin_challenges(),
        load_webarena_challenges(skip_unavailable=False),
    )
    # Take the first challenge whose name matches exactly, if any.
    found = next((c for c in candidates if c.info.name == name), None)

    if found is None:
        click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True)
        # Report the failure through the exit status as well as stderr, so
        # scripts can detect an unknown name. `exit()` does not return.
        click.get_current_context().exit(1)

    if json:
        click.echo(found.info.json())
    else:
        pretty_print_model(found.info)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ STRING_DIFFICULTY_MAP = {e.value: DIFFICULTY_MAP[e] for e in DifficultyLevel}
|
|||||||
|
|
||||||
|
|
||||||
class Category(str, Enum):
|
class Category(str, Enum):
|
||||||
DATA = "data"
|
|
||||||
GENERALIST = "general"
|
GENERALIST = "general"
|
||||||
|
DATA = "data"
|
||||||
CODING = "coding"
|
CODING = "coding"
|
||||||
SCRAPE_SYNTHESIZE = "scrape_synthesize"
|
SCRAPE_SYNTHESIZE = "scrape_synthesize"
|
||||||
WEB = "web"
|
WEB = "web"
|
||||||
|
|||||||
@@ -3,10 +3,13 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Callable, Iterable, Optional, TypeVar, overload
|
||||||
|
|
||||||
|
import click
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from agbenchmark.reports.processing.report_types import Test
|
from agbenchmark.reports.processing.report_types import Test
|
||||||
from agbenchmark.utils.data_types import DIFFICULTY_MAP, DifficultyLevel
|
from agbenchmark.utils.data_types import DIFFICULTY_MAP, DifficultyLevel
|
||||||
@@ -17,6 +20,9 @@ AGENT_NAME = os.getenv("AGENT_NAME")
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
E = TypeVar("E", bound=Enum)
|
||||||
|
|
||||||
|
|
||||||
def replace_backslash(value: Any) -> Any:
|
def replace_backslash(value: Any) -> Any:
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
@@ -124,6 +130,42 @@ def write_pretty_json(data, json_file):
|
|||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _format_model_value(value: Any, indent: str) -> str:
    """Return the styled text that `pretty_print_model` prints for one value."""
    # Enum members, and non-empty lists made up entirely of enum members, are
    # shown by their `.value` in blue. Checking every element (not just the
    # first) keeps a mixed list from failing on `.value`; it falls back to repr.
    if isinstance(value, Enum):
        return click.style(value.value, fg="blue")
    if isinstance(value, list) and value and all(isinstance(i, Enum) for i in value):
        return ", ".join(click.style(i.value, fg="blue") for i in value)

    if value is None or value == "":
        return click.style(repr(value), fg="black")
    if isinstance(value, bool):
        return click.style(repr(value), fg="green" if value else "red")
    if isinstance(value, str) and "\n" in value:
        # Multi-line strings start on a fresh line, with every line prefixed
        # by the indent and a `|` gutter.
        gutter = f"\n{indent} {click.style('|', fg='black')} "
        return f"\n{value}".replace("\n", gutter)
    return repr(value)


def pretty_print_model(model: BaseModel, include_header: bool = True) -> None:
    """Print the fields of a Pydantic model as aligned `key = value` lines.

    Args:
        model: The model to print; its fields are read through `model.dict()`.
        include_header: If true, first print a header line consisting of the
            model's repr name, followed by the name and/or ID found among its
            fields, and indent the field lines beneath it.
    """
    # Serialize once; the header lookup, width calculation and the printing
    # loop all read from the same mapping.
    fields = model.dict()

    indent = ""
    if include_header:
        # Try to find the ID and/or name attribute of the model
        model_id, model_name = None, None
        for attr, value in fields.items():
            if attr == "id" or attr.endswith("_id"):
                model_id = value
            if attr.endswith("name"):
                model_name = value
            if model_id and model_name:
                break
        identifiers = [v for v in [model_name, model_id] if v]
        click.echo(
            f"{model.__repr_name__()}{repr(identifiers) if identifiers else ''}:"
        )
        indent = " " * 2

    # `default=0` keeps a model without any fields from raising ValueError.
    k_col_width = max((len(k) for k in fields), default=0)
    for k, v in fields.items():
        v_fmt = _format_model_value(v, indent)
        click.echo(f"{indent}{k: <{k_col_width}} = {v_fmt}")
|
||||||
|
|
||||||
|
|
||||||
def deep_sort(obj):
|
def deep_sort(obj):
|
||||||
"""
|
"""
|
||||||
Recursively sort the keys in JSON object
|
Recursively sort the keys in JSON object
|
||||||
@@ -133,3 +175,38 @@ def deep_sort(obj):
|
|||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
return [deep_sort(elem) for elem in obj]
|
return [deep_sort(elem) for elem in obj]
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def sorted_by_enum_index(
|
||||||
|
sortable: Iterable[E],
|
||||||
|
enum: type[E],
|
||||||
|
*,
|
||||||
|
reverse: bool = False,
|
||||||
|
) -> list[E]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def sorted_by_enum_index(
|
||||||
|
sortable: Iterable[T],
|
||||||
|
enum: type[Enum],
|
||||||
|
*,
|
||||||
|
key: Callable[[T], Enum | None],
|
||||||
|
reverse: bool = False,
|
||||||
|
) -> list[T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def sorted_by_enum_index(
|
||||||
|
sortable: Iterable[T],
|
||||||
|
enum: type[Enum],
|
||||||
|
*,
|
||||||
|
key: Callable[[T], Enum | None] = lambda x: x, # type: ignore
|
||||||
|
reverse: bool = False,
|
||||||
|
) -> list[T]:
|
||||||
|
return sorted(
|
||||||
|
sortable,
|
||||||
|
key=lambda x: enum._member_names_.index(e.name) if (e := key(x)) else 420e3,
|
||||||
|
reverse=reverse,
|
||||||
|
)
|
||||||
|
|||||||
16
benchmark/poetry.lock
generated
16
benchmark/poetry.lock
generated
@@ -2431,6 +2431,20 @@ anyio = ">=3.4.0,<5"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
|
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tabulate"
|
||||||
|
version = "0.9.0"
|
||||||
|
description = "Pretty-print tabular data"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
|
||||||
|
{file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
widechars = ["wcwidth"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml"
|
name = "toml"
|
||||||
version = "0.10.2"
|
version = "0.10.2"
|
||||||
@@ -2760,4 +2774,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "d7893a88906b5a8eda566e13e6a9492d012c910ded0da1b1ef12b69a14f8e047"
|
content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9"
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ toml = "^0.10.2"
|
|||||||
httpx = "^0.24.0"
|
httpx = "^0.24.0"
|
||||||
agent-protocol-client = "^1.1.0"
|
agent-protocol-client = "^1.1.0"
|
||||||
click-default-group = "^1.2.4"
|
click-default-group = "^1.2.4"
|
||||||
|
tabulate = "^0.9.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
flake8 = "^3.9.2"
|
flake8 = "^3.9.2"
|
||||||
|
|||||||
Reference in New Issue
Block a user