Mirror of https://github.com/aljazceru/Auto-GPT.git, synced 2026-01-31 11:54:30 +01:00
refactor(benchmark): Interface & type consolidation and architecture change to allow adding challenge providers
Squashed commit of the following:

commit 7d6476d3297860f74c276d571da995d958a8cc1a
Author: Reinier van der Leer <pwuts@agpt.co>
Date:   Tue Jan 9 18:10:45 2024 +0100

    refactor(benchmark/challenge): Set up structure to support more challenge providers

    - Move `Challenge`, `ChallengeData`, `load_challenges` to `challenges/builtin.py` and rename to `BuiltinChallenge`, `BuiltinChallengeSpec`, `load_builtin_challenges`
    - Create `BaseChallenge` to serve as interface and base class for different challenge implementations
    - Create `ChallengeInfo` model to serve as universal challenge info object
    - Create `get_challenge_from_source_uri` function in `challenges/__init__.py`
    - Replace `ChallengeData` by `ChallengeInfo` everywhere except in `BuiltinChallenge`
    - Add strong typing to `task_informations` store in app.py
    - Use `call.duration` in `finalize_test_report` and remove `timer` fixture
    - Update docstring on `challenges/__init__.py:get_unique_categories`
    - Add docstring to `generate_test.py`

commit 5df2aa7939b45d85a2c2b5de9ac0522330d1502a
Author: Reinier van der Leer <pwuts@agpt.co>
Date:   Tue Jan 9 16:58:01 2024 +0100

    refactor(benchmark): Refactor & rename functions in agent_interface.py and agent_api_interface.py

    - `copy_artifacts_into_temp_folder` -> `copy_challenge_artifacts_into_workspace`
    - `copy_agent_artifacts_into_folder` -> `download_agent_artifacts_into_folder`
    - Reorder parameters of `run_api_agent`, `copy_challenge_artifacts_into_workspace`; use `Path` instead of `str`

commit 6a256fef4c7950b7ee82fb801e70c83afe6b6f8b
Author: Reinier van der Leer <pwuts@agpt.co>
Date:   Tue Jan 9 16:02:25 2024 +0100

    refactor(benchmark): Refactor & typefix report generation and handling logic

    - Rename functions in reports.py and ReportManager.py to better reflect what they do:
      - `get_previous_test_results` -> `get_and_update_success_history`
      - `generate_single_call_report` -> `initialize_test_report`
      - `finalize_reports` -> `finalize_test_report`
      - `ReportManager.end_info_report` -> `SessionReportManager.finalize_session_report`
    - Modify `pytest_runtest_makereport` hook in conftest.py to finalize the report immediately after the challenge finishes running, instead of after teardown
    - Move result processing logic from `initialize_test_report` to `finalize_test_report` in reports.py
    - Use `Test` and `Report` types from report_types.py where possible instead of untyped dicts: reports.py, utils.py, ReportManager.py
    - Differentiate `ReportManager` into `SessionReportManager`, `RegressionTestsTracker`, `SuccessRateTracker`
    - Move filtering of optional challenge categories from challenge.py (`Challenge.skip_optional_categories`) to conftest.py (`pytest_collection_modifyitems`)
    - Remove unused `scores` fixture in conftest.py

commit 370d6dbf5df75d78e3878877968e8cd309d6d7fb
Author: Reinier van der Leer <pwuts@agpt.co>
Date:   Tue Jan 9 15:16:43 2024 +0100

    refactor(benchmark): Simplify models in report_types.py

    - Removed `ForbidOptionalMeta` and `BaseModelBenchmark` classes
    - Changed model attributes to optional: `Metrics.difficulty`, `Metrics.success`, `Metrics.success_percentage`, `Metrics.run_time`, and `Test.reached_cutoff`
    - Added validator to `Metrics` model to require `success` and `run_time` fields if `attempted=True`
    - Added default values to all optional model fields
    - Removed duplicate imports
    - Added condition in process_report.py to prevent null lookups if `metrics.difficulty` is not set
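At the heart of the change is a provider dispatch keyed on a `source_uri` prefix. The sketch below condenses the logic this commit adds in `challenges/__init__.py` and `challenges/builtin.py`; the commented-out branch is a hypothetical illustration of how another provider would plug in, not part of the commit.

from agbenchmark.challenges.base import BaseChallenge
from agbenchmark.challenges.builtin import BuiltinChallenge


def get_challenge_from_source_uri(source_uri: str) -> type[BaseChallenge]:
    # The part of the URI before the first "/" identifies the challenge provider.
    provider_prefix = source_uri.split("/", 1)[0]

    if provider_prefix == BuiltinChallenge.SOURCE_URI_PREFIX:  # "__BUILTIN__"
        return BuiltinChallenge.from_source_uri(source_uri)

    # A future provider would add its own branch here, e.g. (hypothetical):
    # if provider_prefix == ExampleChallenge.SOURCE_URI_PREFIX:
    #     return ExampleChallenge.from_source_uri(source_uri)

    raise ValueError(f"Cannot resolve source_uri '{source_uri}'")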
@@ -2,27 +2,32 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from agent_protocol_client import AgentApi, ApiClient, Configuration, TaskRequestBody
|
||||
from agent_protocol_client import (
|
||||
AgentApi,
|
||||
ApiClient,
|
||||
Configuration,
|
||||
Step,
|
||||
TaskRequestBody,
|
||||
)
|
||||
|
||||
from agbenchmark.agent_interface import get_list_of_file_paths
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.utils.data_types import ChallengeData
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_api_agent(
|
||||
task: ChallengeData,
|
||||
task: str,
|
||||
config: AgentBenchmarkConfig,
|
||||
artifacts_location: str,
|
||||
timeout: int,
|
||||
) -> None:
|
||||
artifacts_location: Optional[Path] = None,
|
||||
) -> AsyncIterator[Step]:
|
||||
configuration = Configuration(host=config.host)
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
task_request_body = TaskRequestBody(input=task.task)
|
||||
task_request_body = TaskRequestBody(input=task)
|
||||
|
||||
start_time = time.time()
|
||||
response = await api_instance.create_agent_task(
|
||||
@@ -30,37 +35,33 @@ async def run_api_agent(
|
||||
)
|
||||
task_id = response.task_id
|
||||
|
||||
await upload_artifacts(
|
||||
api_instance, artifacts_location, task_id, "artifacts_in"
|
||||
)
|
||||
|
||||
i = 1
|
||||
steps_remaining = True
|
||||
while steps_remaining:
|
||||
# Read the existing JSON data from the file
|
||||
if artifacts_location:
|
||||
await upload_artifacts(
|
||||
api_instance, artifacts_location, task_id, "artifacts_in"
|
||||
)
|
||||
|
||||
while True:
|
||||
step = await api_instance.execute_agent_task_step(task_id=task_id)
|
||||
|
||||
print(f"[{task.name}] - step {step.name} ({i}. request)")
|
||||
i += 1
|
||||
yield step
|
||||
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError("Time limit exceeded")
|
||||
if not step or step.is_last:
|
||||
steps_remaining = False
|
||||
break
|
||||
|
||||
# In "mock" mode, we cheat by giving the correct artifacts to pass the challenge
|
||||
if os.getenv("IS_MOCK"):
|
||||
await upload_artifacts(
|
||||
api_instance, artifacts_location, task_id, "artifacts_out"
|
||||
if artifacts_location:
|
||||
# In "mock" mode, we cheat by giving the correct artifacts to pass the test
|
||||
if os.getenv("IS_MOCK"):
|
||||
await upload_artifacts(
|
||||
api_instance, artifacts_location, task_id, "artifacts_out"
|
||||
)
|
||||
|
||||
await download_agent_artifacts_into_folder(
|
||||
api_instance, task_id, config.temp_folder
|
||||
)
|
||||
|
||||
await copy_agent_artifacts_into_folder(
|
||||
api_instance, task_id, config.temp_folder
|
||||
)
|
||||
|
||||
|
||||
async def copy_agent_artifacts_into_folder(
|
||||
async def download_agent_artifacts_into_folder(
|
||||
api_instance: AgentApi, task_id: str, folder: Path
|
||||
):
|
||||
artifacts = await api_instance.list_agent_task_artifacts(task_id=task_id)
|
||||
@@ -76,11 +77,10 @@ async def copy_agent_artifacts_into_folder(
|
||||
folder = (folder / path).parent
|
||||
|
||||
if not folder.exists():
|
||||
LOG.info(f"Creating directory {folder}")
|
||||
folder.mkdir(parents=True)
|
||||
|
||||
file_path = folder / artifact.file_name
|
||||
LOG.info(f"Writing file {file_path}")
|
||||
logger.debug(f"Downloading agent artifact {artifact.file_name} to {folder}")
|
||||
with open(file_path, "wb") as f:
|
||||
content = await api_instance.download_agent_task_artifact(
|
||||
task_id=task_id, artifact_id=artifact.artifact_id
|
||||
@@ -90,7 +90,7 @@ async def copy_agent_artifacts_into_folder(
|
||||
|
||||
|
||||
async def upload_artifacts(
|
||||
api_instance: AgentApi, artifacts_location: str, task_id: str, type: str
|
||||
api_instance: AgentApi, artifacts_location: Path, task_id: str, type: str
|
||||
) -> None:
|
||||
for file_path in get_list_of_file_paths(artifacts_location, type):
|
||||
relative_path: Optional[str] = "/".join(
|
||||
|
||||
@@ -18,8 +18,8 @@ def get_list_of_file_paths(
|
||||
return list(source_dir.iterdir())
|
||||
|
||||
|
||||
def copy_artifacts_into_temp_folder(
|
||||
workspace: str | Path, artifact_folder_name: str, challenge_dir_path: str | Path
|
||||
def copy_challenge_artifacts_into_workspace(
|
||||
challenge_dir_path: str | Path, artifact_folder_name: str, workspace: str | Path
|
||||
) -> None:
|
||||
file_paths = get_list_of_file_paths(challenge_dir_path, artifact_folder_name)
|
||||
for file_path in file_paths:
|
||||
|
||||
@@ -5,10 +5,10 @@ import logging
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from collections import deque
|
||||
from multiprocessing import Process
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import psutil
|
||||
@@ -18,6 +18,7 @@ from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Extra, ValidationError
|
||||
|
||||
from agbenchmark.challenges import ChallengeInfo
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.reports.processing.report_types_v2 import (
|
||||
BenchmarkRun,
|
||||
@@ -27,14 +28,13 @@ from agbenchmark.reports.processing.report_types_v2 import (
|
||||
TaskInfo,
|
||||
)
|
||||
from agbenchmark.schema import TaskEvalRequestBody
|
||||
from agbenchmark.utils.data_types import ChallengeData
|
||||
from agbenchmark.utils.utils import write_pretty_json
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHALLENGES: dict[str, ChallengeData] = {}
|
||||
CHALLENGES: dict[str, ChallengeInfo] = {}
|
||||
challenges_path = Path(__file__).parent / "challenges"
|
||||
challenge_spec_files = deque(
|
||||
glob.glob(
|
||||
@@ -52,7 +52,7 @@ while challenge_spec_files:
|
||||
|
||||
logger.debug(f"Loading {challenge_relpath}...")
|
||||
try:
|
||||
challenge_info = ChallengeData.parse_file(challenge_spec_file)
|
||||
challenge_info = ChallengeInfo.parse_file(challenge_spec_file)
|
||||
except ValidationError as e:
|
||||
if logging.getLogger().level == logging.DEBUG:
|
||||
logger.warning(f"Spec file {challenge_relpath} failed to load:\n{e}")
|
||||
@@ -68,7 +68,14 @@ while challenge_spec_files:
|
||||
|
||||
CHALLENGES[challenge_info.eval_id] = challenge_info
|
||||
|
||||
task_informations = defaultdict(dict[str, Any])
|
||||
|
||||
class BenchmarkTaskInfo(BaseModel):
|
||||
task_id: str
|
||||
start_time: datetime.datetime
|
||||
challenge_info: ChallengeInfo
|
||||
|
||||
|
||||
task_informations: dict[str, BenchmarkTaskInfo] = {}
|
||||
|
||||
|
||||
def find_agbenchmark_without_uvicorn():
|
||||
@@ -124,12 +131,8 @@ def stream_output(pipe):
|
||||
|
||||
|
||||
def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
|
||||
from agbenchmark.agent_api_interface import (
|
||||
copy_agent_artifacts_into_folder,
|
||||
upload_artifacts,
|
||||
)
|
||||
from agbenchmark.agent_interface import copy_artifacts_into_temp_folder
|
||||
from agbenchmark.generate_test import create_challenge_from_spec_file
|
||||
from agbenchmark.agent_api_interface import upload_artifacts
|
||||
from agbenchmark.challenges import get_challenge_from_source_uri
|
||||
from agbenchmark.main import run_benchmark
|
||||
|
||||
configuration = Configuration(
|
||||
@@ -231,28 +234,29 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
|
||||
}
|
||||
"""
|
||||
try:
|
||||
challenge_info = CHALLENGES[task_eval_request.eval_id]
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
task_input = CHALLENGES[task_eval_request.eval_id].task
|
||||
task_input = challenge_info.task
|
||||
|
||||
task_request_body = TaskRequestBody(input=task_input)
|
||||
task_response = await api_instance.create_agent_task(
|
||||
task_request_body=task_request_body
|
||||
)
|
||||
task_informations[task_response.task_id][
|
||||
"benchmark_start_time"
|
||||
] = datetime.datetime.now(datetime.timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S+00:00"
|
||||
)
|
||||
task_informations[task_response.task_id][
|
||||
"eval_id"
|
||||
] = task_eval_request.eval_id
|
||||
await upload_artifacts(
|
||||
api_instance,
|
||||
str(CHALLENGES[task_eval_request.eval_id].spec_file.parent),
|
||||
task_response.task_id,
|
||||
"artifacts_in",
|
||||
task_info = BenchmarkTaskInfo(
|
||||
task_id=task_response.task_id,
|
||||
start_time=datetime.datetime.now(datetime.timezone.utc),
|
||||
challenge_info=challenge_info,
|
||||
)
|
||||
task_informations[task_info.task_id] = task_info
|
||||
|
||||
if input_artifacts_dir := challenge_info.task_artifacts_dir:
|
||||
await upload_artifacts(
|
||||
api_instance,
|
||||
input_artifacts_dir,
|
||||
task_response.task_id,
|
||||
"artifacts_in",
|
||||
)
|
||||
return task_response
|
||||
except ApiException as e:
|
||||
logger.error(f"Error whilst trying to create a task:\n{e}")
|
||||
@@ -281,41 +285,39 @@ def setup_fastapi_app(agbenchmark_config: AgentBenchmarkConfig) -> FastAPI:
|
||||
|
||||
@router.post("/agent/tasks/{task_id}/evaluations")
|
||||
async def create_evaluation(task_id: str) -> BenchmarkRun:
|
||||
challenge_info = CHALLENGES[task_informations[task_id]["eval_id"]]
|
||||
workspace = agbenchmark_config.temp_folder
|
||||
task_info = task_informations[task_id]
|
||||
challenge = get_challenge_from_source_uri(task_info.challenge_info.source_uri)
|
||||
try:
|
||||
async with ApiClient(configuration) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
await copy_agent_artifacts_into_folder(api_instance, task_id, workspace)
|
||||
|
||||
artifact_path = challenge_info.spec_file.parent
|
||||
copy_artifacts_into_temp_folder(workspace, "custom_python", artifact_path)
|
||||
|
||||
challenge = create_challenge_from_spec_file(challenge_info.spec_file)
|
||||
scores = challenge.get_scores(workspace)
|
||||
is_score_100 = 1 in scores["values"]
|
||||
eval_results = await challenge.evaluate_task_state(
|
||||
api_instance, task_id
|
||||
)
|
||||
|
||||
eval_info = BenchmarkRun(
|
||||
repository_info=RepositoryInfo(),
|
||||
run_details=RunDetails(
|
||||
command=f"agbenchmark --test={challenge_info.name}",
|
||||
command=f"agbenchmark --test={challenge.info.name}",
|
||||
benchmark_start_time=(
|
||||
task_informations[task_id]["benchmark_start_time"]
|
||||
task_info.start_time.strftime("%Y-%m-%dT%H:%M:%S+00:00")
|
||||
),
|
||||
test_name=challenge_info.name,
|
||||
test_name=challenge.info.name,
|
||||
),
|
||||
task_info=TaskInfo(
|
||||
data_path=str(
|
||||
challenge_info.spec_file.relative_to(challenges_path.parent)
|
||||
),
|
||||
data_path=challenge.info.source_uri,
|
||||
is_regression=None,
|
||||
category=[c.value for c in challenge_info.category],
|
||||
task=challenge_info.task,
|
||||
answer=challenge_info.ground.answer,
|
||||
description=challenge_info.info.description,
|
||||
category=[c.value for c in challenge.info.category],
|
||||
task=challenge.info.task,
|
||||
answer=challenge.info.reference_answer or "",
|
||||
description=challenge.info.description or "",
|
||||
),
|
||||
metrics=Metrics(
|
||||
success=is_score_100,
|
||||
success=all(e.passed for e in eval_results),
|
||||
success_percentage=(
|
||||
100 * sum(e.score for e in eval_results) / len(eval_results)
|
||||
if eval_results # avoid division by 0
|
||||
else 0
|
||||
),
|
||||
attempted=True,
|
||||
),
|
||||
config={},
|
||||
|
||||
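Stripped of the FastAPI plumbing, the evaluation flow introduced by the app.py changes above boils down to the following sketch; `evaluate_existing_task` is a hypothetical helper name, not part of the commit.

from agent_protocol_client import AgentApi, ApiClient, Configuration

from agbenchmark.challenges import get_challenge_from_source_uri


async def evaluate_existing_task(
    task_id: str, source_uri: str, configuration: Configuration
) -> bool:
    # Resolve the challenge class from the stored source_uri and reuse its
    # evaluate_task_state() logic, as the reworked create_evaluation() above does.
    challenge = get_challenge_from_source_uri(source_uri)
    async with ApiClient(configuration) as api_client:
        eval_results = await challenge.evaluate_task_state(
            AgentApi(api_client), task_id
        )
    return all(e.passed for e in eval_results)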
@@ -3,14 +3,26 @@ import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from .base import BaseChallenge, ChallengeInfo
|
||||
from .builtin import OPTIONAL_CATEGORIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_challenge_from_source_uri(source_uri: str) -> type[BaseChallenge]:
|
||||
from .builtin import BuiltinChallenge
|
||||
|
||||
provider_prefix = source_uri.split("/", 1)[0]
|
||||
|
||||
if provider_prefix == BuiltinChallenge.SOURCE_URI_PREFIX:
|
||||
return BuiltinChallenge.from_source_uri(source_uri)
|
||||
|
||||
raise ValueError(f"Cannot resolve source_uri '{source_uri}'")
|
||||
|
||||
|
||||
def get_unique_categories() -> set[str]:
|
||||
"""
|
||||
Find all data.json files in the directory relative to this file and its
|
||||
subdirectories, read the "category" field from each file, and return a set of unique
|
||||
categories.
|
||||
Reads all challenge spec files and returns a set of all their categories.
|
||||
"""
|
||||
categories = set()
|
||||
|
||||
@@ -30,3 +42,11 @@ def get_unique_categories() -> set[str]:
|
||||
continue
|
||||
|
||||
return categories
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseChallenge",
|
||||
"ChallengeInfo",
|
||||
"get_unique_categories",
|
||||
"OPTIONAL_CATEGORIES",
|
||||
]
|
||||
|
||||
benchmark/agbenchmark/challenges/base.py (new file, 99 lines)
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, ClassVar, Optional
|
||||
|
||||
import pytest
|
||||
from agent_protocol_client import AgentApi, Step
|
||||
from colorama import Fore, Style
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.utils.data_types import Category, DifficultyLevel, EvalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChallengeInfo(BaseModel):
|
||||
eval_id: str = ""
|
||||
name: str
|
||||
task: str
|
||||
task_artifacts_dir: Optional[Path] = None
|
||||
category: list[Category]
|
||||
difficulty: Optional[DifficultyLevel] = None
|
||||
description: Optional[str] = None
|
||||
dependencies: list[str] = Field(default_factory=list)
|
||||
reference_answer: Optional[str]
|
||||
|
||||
source_uri: str
|
||||
"""Internal reference indicating the source of the challenge specification"""
|
||||
|
||||
|
||||
class BaseChallenge(ABC):
|
||||
"""
|
||||
The base class and shared interface for all specific challenge implementations.
|
||||
"""
|
||||
|
||||
info: ClassVar[ChallengeInfo]
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_source_uri(cls, source_uri: str) -> type["BaseChallenge"]:
|
||||
"""
|
||||
Construct an individual challenge subclass from a suitable `source_uri` (as in
|
||||
`ChallengeInfo.source_uri`).
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def test_method(
|
||||
self, config: AgentBenchmarkConfig, request: pytest.FixtureRequest
|
||||
) -> None:
|
||||
"""
|
||||
Test method for use by Pytest-based benchmark sessions. Should return normally
|
||||
if the challenge passes, and raise a (preferably descriptive) error otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
@classmethod
|
||||
async def run_challenge(
|
||||
cls, config: AgentBenchmarkConfig, timeout: int
|
||||
) -> AsyncIterator[Step]:
|
||||
"""
|
||||
Runs the challenge on the subject agent with the specified timeout.
|
||||
Also prints basic challenge and status info to STDOUT.
|
||||
|
||||
Params:
|
||||
config: The subject agent's benchmark config.
|
||||
timeout: Timeout (seconds) after which to stop the run if not finished.
|
||||
|
||||
Yields:
|
||||
Step: The steps generated by the agent for the challenge task.
|
||||
"""
|
||||
# avoid circular import
|
||||
from agbenchmark.agent_api_interface import run_api_agent
|
||||
|
||||
print()
|
||||
print(
|
||||
f"{Fore.MAGENTA + Style.BRIGHT}{'='*24} "
|
||||
f"Starting {cls.info.name} challenge"
|
||||
f" {'='*24}{Style.RESET_ALL}"
|
||||
)
|
||||
print(f"{Fore.CYAN}Timeout:{Fore.RESET} {timeout} seconds")
|
||||
print(f"{Fore.CYAN}Task:{Fore.RESET} {cls.info.task}")
|
||||
|
||||
print()
|
||||
logger.debug(f"Starting {cls.info.name} challenge run")
|
||||
i = 0
|
||||
async for step in run_api_agent(cls.info.task, config, timeout):
|
||||
i += 1
|
||||
print(f"[{cls.info.name}] - step {step.name} ({i}. request)")
|
||||
yield step
|
||||
logger.debug(f"Finished {cls.info.name} challenge run")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def evaluate_task_state(
|
||||
cls, agent: AgentApi, task_id: str
|
||||
) -> list[EvalResult]:
|
||||
...
|
||||
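Before the built-in provider that follows, here is roughly what a minimal additional provider built on the `BaseChallenge` interface above could look like. `ExampleChallenge`, its prefix, and all its values are hypothetical and not part of this commit; the sketch only mirrors the pattern that `BuiltinChallenge` uses.

import pytest
from agent_protocol_client import AgentApi, ApiClient, Configuration as ClientConfig

from agbenchmark.challenges.base import BaseChallenge, ChallengeInfo
from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.utils.data_types import EvalResult


class ExampleChallenge(BaseChallenge):
    SOURCE_URI_PREFIX = "__EXAMPLE__"  # hypothetical provider prefix

    @classmethod
    def from_source_uri(cls, source_uri: str) -> type["ExampleChallenge"]:
        # Build a challenge subclass with its `info` derived from the URI.
        name = source_uri.split("/", 1)[1]
        info = ChallengeInfo(
            name=name,
            task="Write the word 'hello' to a file",  # hypothetical task
            category=[],
            reference_answer="hello",
            source_uri=source_uri,
        )
        return type(f"Test{name}", (ExampleChallenge,), {"info": info})

    @pytest.mark.asyncio
    async def test_method(
        self, config: AgentBenchmarkConfig, request: pytest.FixtureRequest
    ) -> None:
        # Drive the agent through the shared run_challenge() helper...
        task_id = ""
        async for step in self.run_challenge(config, timeout=60):
            task_id = step.task_id

        # ...then evaluate the resulting task state and assert on the outcome.
        async with ApiClient(ClientConfig(host=config.host)) as api_client:
            eval_results = await self.evaluate_task_state(
                AgentApi(api_client), task_id
            )
        assert all(r.passed for r in eval_results)

    @classmethod
    async def evaluate_task_state(
        cls, agent: AgentApi, task_id: str
    ) -> list[EvalResult]:
        # A real provider would inspect the agent's artifacts or steps here.
        return []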
benchmark/agbenchmark/challenges/builtin.py (new file, 422 lines)
@@ -0,0 +1,422 @@
|
||||
from collections import deque
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Iterator, Literal, Optional
|
||||
|
||||
import pytest
|
||||
from agent_protocol_client import AgentApi, ApiClient, Configuration as ClientConfig
|
||||
from colorama import Fore, Style
|
||||
from openai import _load_client as get_openai_client
|
||||
from pydantic import BaseModel, constr, Field, validator
|
||||
|
||||
from agbenchmark.agent_api_interface import download_agent_artifacts_into_folder
|
||||
from agbenchmark.agent_interface import copy_challenge_artifacts_into_workspace
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.utils.data_types import Category, DifficultyLevel, EvalResult
|
||||
from agbenchmark.utils.prompts import (
|
||||
END_PROMPT,
|
||||
FEW_SHOT_EXAMPLES,
|
||||
PROMPT_MAP,
|
||||
SCORING_MAP,
|
||||
)
|
||||
|
||||
from .base import BaseChallenge, ChallengeInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with open(Path(__file__).parent / "optional_categories.json") as f:
|
||||
OPTIONAL_CATEGORIES: list[str] = json.load(f)["optional_categories"]
|
||||
|
||||
|
||||
class BuiltinChallengeSpec(BaseModel):
|
||||
eval_id: str = ""
|
||||
name: str
|
||||
task: str
|
||||
category: list[Category]
|
||||
dependencies: list[str]
|
||||
cutoff: int
|
||||
|
||||
class Info(BaseModel):
|
||||
difficulty: DifficultyLevel
|
||||
description: constr(regex=r"^Tests if the agent can.*")
|
||||
side_effects: list[str] = Field(default_factory=list)
|
||||
|
||||
info: Info
|
||||
|
||||
class Ground(BaseModel):
|
||||
answer: str
|
||||
should_contain: Optional[list[str]] = None
|
||||
should_not_contain: Optional[list[str]] = None
|
||||
files: list[str]
|
||||
case_sensitive: Optional[bool] = True
|
||||
|
||||
class Eval(BaseModel):
|
||||
type: str
|
||||
scoring: Optional[Literal["percentage", "scale", "binary"]]
|
||||
template: Optional[Literal["rubric", "reference", "question", "custom"]]
|
||||
examples: Optional[str]
|
||||
|
||||
@validator("scoring", "template", always=True)
|
||||
def validate_eval_fields(cls, v, values, field):
|
||||
if "type" in values and values["type"] == "llm":
|
||||
if v is None:
|
||||
raise ValueError(
|
||||
f"{field.name} must be provided when eval type is 'llm'"
|
||||
)
|
||||
else:
|
||||
if v is not None:
|
||||
raise ValueError(
|
||||
f"{field.name} should only exist when eval type is 'llm'"
|
||||
)
|
||||
return v
|
||||
|
||||
eval: Eval
|
||||
|
||||
ground: Ground
|
||||
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
spec_file: Path | None = Field(None, exclude=True)
|
||||
|
||||
|
||||
class BuiltinChallenge(BaseChallenge):
|
||||
"""
|
||||
Base class for AGBenchmark's built-in challenges (challenges/**/*.json).
|
||||
|
||||
All of the logic is present in this class. Individual challenges are created as
|
||||
subclasses of `BuiltinChallenge` with challenge-specific values assigned to the
|
||||
ClassVars `_spec` etc.
|
||||
|
||||
Dynamically constructing subclasses rather than class instances for the individual
|
||||
challenges makes them suitable for collection by Pytest, which will run their
|
||||
`test_method` like any regular test item.
|
||||
"""
|
||||
|
||||
_spec: ClassVar[BuiltinChallengeSpec]
|
||||
CHALLENGE_LOCATION: ClassVar[str]
|
||||
ARTIFACTS_LOCATION: ClassVar[str]
|
||||
|
||||
SOURCE_URI_PREFIX = "__BUILTIN__"
|
||||
|
||||
@classmethod
|
||||
def from_challenge_spec(
|
||||
cls, spec: BuiltinChallengeSpec
|
||||
) -> type["BuiltinChallenge"]:
|
||||
if not spec.spec_file:
|
||||
raise ValueError("spec.spec_file not defined")
|
||||
|
||||
challenge_info = ChallengeInfo(
|
||||
eval_id=spec.eval_id,
|
||||
name=spec.name,
|
||||
task=spec.task,
|
||||
task_artifacts_dir=spec.spec_file.parent,
|
||||
category=spec.category,
|
||||
difficulty=spec.info.difficulty,
|
||||
description=spec.info.description,
|
||||
dependencies=spec.dependencies,
|
||||
reference_answer=spec.ground.answer,
|
||||
source_uri=(
|
||||
f"__BUILTIN__/{spec.spec_file.relative_to(Path(__file__).parent)}"
|
||||
),
|
||||
)
|
||||
|
||||
challenge_class_name = f"Test{challenge_info.name}"
|
||||
logger.debug(f"Creating {challenge_class_name} from spec: {spec.spec_file}")
|
||||
return type(
|
||||
challenge_class_name,
|
||||
(BuiltinChallenge,),
|
||||
{
|
||||
"info": challenge_info,
|
||||
"_spec": spec,
|
||||
"CHALLENGE_LOCATION": str(spec.spec_file),
|
||||
"ARTIFACTS_LOCATION": str(spec.spec_file.resolve().parent),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_challenge_spec_file(cls, spec_file: Path) -> type["BuiltinChallenge"]:
|
||||
challenge_spec = BuiltinChallengeSpec.parse_file(spec_file)
|
||||
challenge_spec.spec_file = spec_file
|
||||
return cls.from_challenge_spec(challenge_spec)
|
||||
|
||||
@classmethod
|
||||
def from_source_uri(cls, source_uri: str) -> type["BuiltinChallenge"]:
|
||||
if not source_uri.startswith(cls.SOURCE_URI_PREFIX):
|
||||
raise ValueError(f"Invalid source_uri for BuiltinChallenge: {source_uri}")
|
||||
|
||||
path = source_uri.split("/", 1)[1]
|
||||
spec_file = Path(__file__).parent / path
|
||||
return cls.from_challenge_spec_file(spec_file)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_method(
|
||||
self, config: AgentBenchmarkConfig, request: pytest.FixtureRequest
|
||||
) -> None:
|
||||
if os.environ.get("HELICONE_API_KEY"):
|
||||
from helicone.lock import HeliconeLockManager
|
||||
|
||||
HeliconeLockManager.write_custom_property("challenge", self.info.name)
|
||||
|
||||
timeout = self._spec.cutoff or 60
|
||||
|
||||
if request.config.getoption("--nc"):
|
||||
timeout = 100000
|
||||
elif cutoff := request.config.getoption("--cutoff"):
|
||||
timeout = int(cutoff) # type: ignore
|
||||
|
||||
task_id = ""
|
||||
timed_out = None
|
||||
try:
|
||||
async for step in self.run_challenge(config, timeout):
|
||||
if not task_id:
|
||||
task_id = step.task_id
|
||||
if request.config.getoption("--mock"):
|
||||
# Run only one step in mock mode
|
||||
break
|
||||
timed_out = False
|
||||
except TimeoutError:
|
||||
timed_out = True
|
||||
request.node.user_properties.append(("timed_out", timed_out))
|
||||
|
||||
agent_client_config = ClientConfig(host=config.host)
|
||||
async with ApiClient(agent_client_config) as api_client:
|
||||
api_instance = AgentApi(api_client)
|
||||
eval_results = await self.evaluate_task_state(api_instance, task_id)
|
||||
|
||||
if not eval_results:
|
||||
if timed_out:
|
||||
raise TimeoutError("Timed out, no results to evaluate")
|
||||
else:
|
||||
raise ValueError("No results to evaluate")
|
||||
|
||||
request.node.user_properties.append(
|
||||
(
|
||||
"answers",
|
||||
[r.result for r in eval_results]
|
||||
if request.config.getoption("--keep-answers")
|
||||
else None,
|
||||
)
|
||||
)
|
||||
request.node.user_properties.append(("scores", [r.score for r in eval_results]))
|
||||
|
||||
# FIXME: this allows partial failure
|
||||
assert any(r.passed for r in eval_results), (
|
||||
f"No passed evals: {eval_results}"
|
||||
if not timed_out
|
||||
else f"Timed out; no passed evals: {eval_results}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def evaluate_task_state(
|
||||
cls, agent: AgentApi, task_id: str
|
||||
) -> list[EvalResult]:
|
||||
with tempfile.TemporaryDirectory() as workspace:
|
||||
workspace = Path(workspace)
|
||||
await download_agent_artifacts_into_folder(agent, task_id, workspace)
|
||||
if cls.info.task_artifacts_dir:
|
||||
copy_challenge_artifacts_into_workspace(
|
||||
cls.info.task_artifacts_dir, "custom_python", workspace
|
||||
)
|
||||
|
||||
return list(cls.evaluate_workspace_content(workspace))
|
||||
|
||||
@classmethod
|
||||
def evaluate_workspace_content(cls, workspace: Path) -> Iterator[EvalResult]:
|
||||
if cls._spec.task == "" and os.getenv("IS_MOCK"):
|
||||
yield EvalResult(
|
||||
result="This is a mock answer",
|
||||
result_source="step_output",
|
||||
score=1.0,
|
||||
passed=True,
|
||||
)
|
||||
return
|
||||
|
||||
result_ground = cls._spec.ground
|
||||
outputs_for_eval = cls.get_outputs_for_eval(workspace, result_ground)
|
||||
|
||||
if result_ground.should_contain or result_ground.should_not_contain:
|
||||
for source, content in outputs_for_eval:
|
||||
score = cls.score_result(content, result_ground)
|
||||
if score is not None:
|
||||
print(f"{Fore.GREEN}Your score is:{Style.RESET_ALL}", score)
|
||||
yield EvalResult(
|
||||
result=content,
|
||||
result_source=str(source),
|
||||
score=score,
|
||||
passed=score > 0.9, # FIXME: arbitrary threshold
|
||||
)
|
||||
|
||||
if result_ground.eval.type == "llm":
|
||||
combined_results = "\n".join(output[1] for output in outputs_for_eval)
|
||||
llm_eval = cls.score_result_with_llm(combined_results, result_ground)
|
||||
print(f"{Fore.GREEN}Your score is:{Style.RESET_ALL}", llm_eval)
|
||||
if result_ground.eval.scoring == "percentage":
|
||||
score = llm_eval / 100
|
||||
elif result_ground.eval.scoring == "scale":
|
||||
score = llm_eval / 10
|
||||
else:
|
||||
score = llm_eval
|
||||
|
||||
yield EvalResult(
|
||||
result=combined_results,
|
||||
result_source=", ".join(str(res[0]) for res in outputs_for_eval),
|
||||
score=score,
|
||||
passed=score > 0.9, # FIXME: arbitrary threshold
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_outputs_for_eval(
|
||||
workspace: str | Path | dict[str, str], ground: BuiltinChallengeSpec.Ground
|
||||
) -> Iterator[tuple[str | Path, str]]:
|
||||
if isinstance(workspace, dict):
|
||||
workspace = workspace["output"]
|
||||
|
||||
script_dir = workspace
|
||||
|
||||
for file_pattern in ground.files:
|
||||
# Check if it is a file extension
|
||||
if file_pattern.startswith("."):
|
||||
# Find all files with the given extension in the workspace
|
||||
matching_files = glob.glob(os.path.join(script_dir, "*" + file_pattern))
|
||||
else:
|
||||
# Otherwise, it is a specific file
|
||||
matching_files = [os.path.join(script_dir, file_pattern)]
|
||||
|
||||
for file_path in matching_files:
|
||||
if ground.eval.type == "python":
|
||||
result = subprocess.run(
|
||||
[sys.executable, file_path],
|
||||
cwd=os.path.abspath(workspace),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if "error" in result.stderr or result.returncode != 0:
|
||||
print(result.stderr)
|
||||
assert False, result.stderr
|
||||
yield (
|
||||
Path(file_path).relative_to(workspace),
|
||||
f"Output: {result.stdout}\n",
|
||||
)
|
||||
else:
|
||||
with open(file_path, "r") as f:
|
||||
yield Path(file_path).relative_to(workspace), f.read()
|
||||
else:
|
||||
if ground.eval.type == "pytest":
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pytest"],
|
||||
cwd=os.path.abspath(workspace),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if "error" in result.stderr or result.returncode != 0:
|
||||
print(result.stderr)
|
||||
assert False, result.stderr
|
||||
yield "pytest", f"Output: {result.stdout}\n"
|
||||
|
||||
@staticmethod
|
||||
def score_result(content: str, ground: BuiltinChallengeSpec.Ground) -> float | None:
|
||||
print(f"{Fore.BLUE}Scoring content:{Style.RESET_ALL}", content)
|
||||
if ground.should_contain:
|
||||
for should_contain_word in ground.should_contain:
|
||||
if not ground.case_sensitive:
|
||||
should_contain_word = should_contain_word.lower()
|
||||
content = content.lower()
|
||||
print_content = (
|
||||
f"{Fore.BLUE}Word that should exist{Style.RESET_ALL}"
|
||||
f" - {should_contain_word}:"
|
||||
)
|
||||
if should_contain_word not in content:
|
||||
print(print_content, "False")
|
||||
return 0.0
|
||||
else:
|
||||
print(print_content, "True")
|
||||
return 1.0
|
||||
|
||||
if ground.should_not_contain:
|
||||
for should_not_contain_word in ground.should_not_contain:
|
||||
if not ground.case_sensitive:
|
||||
should_not_contain_word = should_not_contain_word.lower()
|
||||
content = content.lower()
|
||||
print_content = (
|
||||
f"{Fore.BLUE}Word that should not exist{Style.RESET_ALL}"
|
||||
f" - {should_not_contain_word}:"
|
||||
)
|
||||
if should_not_contain_word in content:
|
||||
print(print_content, "False")
|
||||
return 0.0
|
||||
else:
|
||||
print(print_content, "True")
|
||||
return 1.0
|
||||
|
||||
@classmethod
|
||||
def score_result_with_llm(
|
||||
cls, content: str, ground: BuiltinChallengeSpec.Ground
|
||||
) -> float:
|
||||
if os.getenv("IS_MOCK"):
|
||||
return 1.0
|
||||
|
||||
# the validation for this is done in the Eval BaseModel
|
||||
scoring = SCORING_MAP[ground.eval.scoring] # type: ignore
|
||||
prompt = PROMPT_MAP[ground.eval.template].format( # type: ignore
|
||||
task=cls._spec.task, scoring=scoring, answer=ground.answer, response=content
|
||||
)
|
||||
|
||||
if ground.eval.examples:
|
||||
prompt += FEW_SHOT_EXAMPLES.format(examples=ground.eval.examples)
|
||||
|
||||
prompt += END_PROMPT
|
||||
|
||||
answer = get_openai_client().chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[
|
||||
{"role": "system", "content": prompt},
|
||||
],
|
||||
)
|
||||
|
||||
return float(answer.choices[0].message.content) # type: ignore
|
||||
|
||||
|
||||
def load_builtin_challenges() -> Iterator[type[BuiltinChallenge]]:
|
||||
logger.info("Loading built-in challenges...")
|
||||
|
||||
challenges_path = os.path.dirname(__file__)
|
||||
logger.debug(f"Looking for challenge spec files in {challenges_path}...")
|
||||
|
||||
json_files = deque(
|
||||
glob.glob(
|
||||
f"{challenges_path}/**/data.json",
|
||||
recursive=True,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Found {len(json_files)} built-in challenges.")
|
||||
|
||||
loaded, ignored = 0, 0
|
||||
while json_files:
|
||||
# Take and remove the first element from json_files
|
||||
json_file = json_files.popleft()
|
||||
if _challenge_should_be_ignored(json_file):
|
||||
ignored += 1
|
||||
continue
|
||||
|
||||
challenge = BuiltinChallenge.from_challenge_spec_file(Path(json_file))
|
||||
logger.debug(f"Generated test for {challenge.info.name}")
|
||||
yield challenge
|
||||
|
||||
loaded += 1
|
||||
|
||||
logger.info(
|
||||
f"Loading built-in challenges complete: loaded {loaded}, ignored {ignored}."
|
||||
)
|
||||
|
||||
|
||||
def _challenge_should_be_ignored(json_file_path: str):
|
||||
return (
|
||||
"challenges/deprecated" in json_file_path
|
||||
or "challenges/library" in json_file_path
|
||||
)
|
||||
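As a usage note for the provider above: because each generated challenge class records its `source_uri`, a built-in challenge can later be reconstructed from the URI stored in a report. The spec path below is illustrative, not a reference to a specific file in the repo.

from agbenchmark.challenges import get_challenge_from_source_uri

# source_uri as stored in reports by this commit: "__BUILTIN__/<path to data.json>"
source_uri = "__BUILTIN__/abilities/write_file/data.json"  # illustrative path

challenge_cls = get_challenge_from_source_uri(source_uri)
print(challenge_cls.info.name, "-", challenge_cls.info.task)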
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseSettings
|
||||
from pydantic import BaseSettings, Field
|
||||
|
||||
|
||||
def _calculate_info_test_path(base_path: Path, benchmark_start_time: datetime) -> Path:
|
||||
@@ -57,7 +57,7 @@ class AgentBenchmarkConfig(BaseSettings, extra="allow"):
|
||||
subject application exposes an Agent Protocol compliant API.
|
||||
"""
|
||||
|
||||
agbenchmark_config_dir: Path
|
||||
agbenchmark_config_dir: Path = Field(..., exclude=True)
|
||||
"""Path to the agbenchmark_config folder of the subject agent application."""
|
||||
|
||||
categories: list[str] | None = None
|
||||
|
||||
@@ -6,17 +6,18 @@ import shutil
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from agbenchmark.challenges import OPTIONAL_CATEGORIES, BaseChallenge
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.reports.ReportManager import RegressionTestsTracker
|
||||
from agbenchmark.reports.reports import (
|
||||
finalize_reports,
|
||||
generate_single_call_report,
|
||||
finalize_test_report,
|
||||
initialize_test_report,
|
||||
session_finish,
|
||||
)
|
||||
from agbenchmark.utils.challenge import Challenge
|
||||
from agbenchmark.utils.data_types import Category
|
||||
|
||||
GLOBAL_TIMEOUT = (
|
||||
@@ -28,7 +29,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
pytest_plugins = ["agbenchmark.utils.dependencies"]
|
||||
collect_ignore = ["challenges"]
|
||||
suite_reports: dict[str, list] = {}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -118,18 +118,18 @@ def check_regression(request: pytest.FixtureRequest) -> None:
|
||||
request: The request object from which the test name and the benchmark
|
||||
configuration are retrieved.
|
||||
"""
|
||||
test_name = request.node.parent.name
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
regression_report = agbenchmark_config.regression_tests_file
|
||||
data = json.loads(regression_report.read_bytes())
|
||||
challenge_location = getattr(request.node.parent.cls, "CHALLENGE_LOCATION", "")
|
||||
rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
|
||||
|
||||
test_name = request.node.parent.name
|
||||
challenge_location = getattr(request.node.parent.cls, "CHALLENGE_LOCATION", "")
|
||||
skip_string = f"Skipping {test_name} at {challenge_location}"
|
||||
|
||||
# Check if the test name exists in the regression tests
|
||||
if request.config.getoption("--improve") and data.get(test_name, None):
|
||||
is_regression_test = rt_tracker.has_regression_test(test_name)
|
||||
if request.config.getoption("--improve") and is_regression_test:
|
||||
pytest.skip(f"{skip_string} because it's a regression test")
|
||||
elif request.config.getoption("--maintain") and not data.get(test_name, None):
|
||||
elif request.config.getoption("--maintain") and not is_regression_test:
|
||||
pytest.skip(f"{skip_string} because it's not a regression test")
|
||||
|
||||
|
||||
@@ -149,24 +149,6 @@ def mock(request: pytest.FixtureRequest) -> bool:
|
||||
return request.config.getoption("--mock")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def timer(request: pytest.FixtureRequest) -> Generator[None, None, None]:
|
||||
"""
|
||||
Pytest fixture that times the execution of each test.
|
||||
At the start of each test, it records the current time.
|
||||
After the test function completes, it calculates the run time and adds it to
|
||||
the test node's `user_properties`.
|
||||
|
||||
Args:
|
||||
request: The `pytest.FixtureRequest` object through which the run time is stored
|
||||
in the test node's `user_properties`.
|
||||
"""
|
||||
start_time = time.time()
|
||||
yield
|
||||
run_time = time.time() - start_time
|
||||
request.node.user_properties.append(("run_time", run_time))
|
||||
|
||||
|
||||
def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None:
|
||||
"""
|
||||
Pytest hook that is called when a test report is being generated.
|
||||
@@ -176,21 +158,15 @@ def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo) -> None:
|
||||
item: The test item for which the report is being generated.
|
||||
call: The call object from which the test result is retrieved.
|
||||
"""
|
||||
challenge: type[Challenge] = item.cls # type: ignore
|
||||
challenge_data = challenge.data
|
||||
challenge_location = challenge.CHALLENGE_LOCATION
|
||||
challenge: type[BaseChallenge] = item.cls # type: ignore
|
||||
|
||||
if call.when == "setup":
|
||||
test_name = item.nodeid.split("::")[1]
|
||||
item.user_properties.append(("test_name", test_name))
|
||||
initialize_test_report(item, challenge.info)
|
||||
|
||||
if call.when == "call":
|
||||
answers = getattr(item, "answers", None)
|
||||
test_name = item.nodeid.split("::")[1]
|
||||
item.test_name = test_name
|
||||
|
||||
generate_single_call_report(
|
||||
item, call, challenge_data, answers, challenge_location, test_name
|
||||
)
|
||||
|
||||
if call.when == "teardown":
|
||||
finalize_reports(agbenchmark_config, item, challenge_data)
|
||||
finalize_test_report(item, call, agbenchmark_config)
|
||||
|
||||
|
||||
def timeout_monitor(start_time: int) -> None:
|
||||
@@ -226,21 +202,7 @@ def pytest_sessionfinish(session: pytest.Session) -> None:
|
||||
|
||||
Finalizes and saves the test reports.
|
||||
"""
|
||||
session_finish(agbenchmark_config, suite_reports)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scores(request: pytest.FixtureRequest) -> None:
|
||||
"""
|
||||
Pytest fixture that retrieves the scores of the test class.
|
||||
The scores are retrieved from the `Challenge.scores` attribute
|
||||
using the test class name.
|
||||
|
||||
Args:
|
||||
request: The request object.
|
||||
"""
|
||||
challenge: type[Challenge] = request.node.cls
|
||||
return challenge.scores.get(challenge.__name__)
|
||||
session_finish(agbenchmark_config)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(
|
||||
@@ -255,10 +217,7 @@ def pytest_collection_modifyitems(
|
||||
items: The collected test items to be modified.
|
||||
config: The active pytest configuration.
|
||||
"""
|
||||
regression_file = agbenchmark_config.regression_tests_file
|
||||
regression_tests: dict[str, Any] = (
|
||||
json.loads(regression_file.read_bytes()) if regression_file.is_file() else {}
|
||||
)
|
||||
rt_tracker = RegressionTestsTracker(agbenchmark_config.regression_tests_file)
|
||||
|
||||
try:
|
||||
challenges_beaten_in_the_past = json.loads(
|
||||
@@ -277,7 +236,7 @@ def pytest_collection_modifyitems(
|
||||
challenge = item.cls
|
||||
challenge_name = item.cls.__name__
|
||||
|
||||
if not issubclass(challenge, Challenge):
|
||||
if not issubclass(challenge, BaseChallenge):
|
||||
item.warn(
|
||||
pytest.PytestCollectionWarning(
|
||||
f"Non-challenge item collected: {challenge}"
|
||||
@@ -287,7 +246,7 @@ def pytest_collection_modifyitems(
|
||||
continue
|
||||
|
||||
# --test: remove the test from the set if it's not specifically selected
|
||||
if selected_tests and challenge.data.name not in selected_tests:
|
||||
if selected_tests and challenge.info.name not in selected_tests:
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
@@ -295,8 +254,8 @@ def pytest_collection_modifyitems(
|
||||
# --maintain -> only challenges expected to be passed (= regression tests)
|
||||
# --improve -> only challenges that so far are not passed (reliably)
|
||||
# --explore -> only challenges that have never been passed
|
||||
is_regression_test = regression_tests.get(challenge.data.name, None)
|
||||
has_been_passed = challenges_beaten_in_the_past.get(challenge.data.name, False)
|
||||
is_regression_test = rt_tracker.has_regression_test(challenge.info.name)
|
||||
has_been_passed = challenges_beaten_in_the_past.get(challenge.info.name, False)
|
||||
if (
|
||||
(config.getoption("--maintain") and not is_regression_test)
|
||||
or (config.getoption("--improve") and is_regression_test)
|
||||
@@ -305,7 +264,7 @@ def pytest_collection_modifyitems(
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
dependencies = challenge.data.dependencies
|
||||
dependencies = challenge.info.dependencies
|
||||
if (
|
||||
config.getoption("--test")
|
||||
or config.getoption("--no-dep")
|
||||
@@ -319,17 +278,17 @@ def pytest_collection_modifyitems(
|
||||
elif config.getoption("--improve"):
|
||||
# Filter dependencies, keep only deps that are not "regression" tests
|
||||
dependencies = [
|
||||
d for d in dependencies if not regression_tests.get(d, None)
|
||||
d for d in dependencies if not rt_tracker.has_regression_test(d)
|
||||
]
|
||||
|
||||
# Set category markers
|
||||
challenge_categories = [c.value for c in challenge.data.category]
|
||||
challenge_categories = set(c.value for c in challenge.info.category)
|
||||
for category in challenge_categories:
|
||||
item.add_marker(category)
|
||||
|
||||
# Enforce category selection
|
||||
if selected_categories:
|
||||
if not set(challenge_categories).intersection(set(selected_categories)):
|
||||
if not challenge_categories.intersection(set(selected_categories)):
|
||||
items.remove(item)
|
||||
continue
|
||||
# # Filter dependencies, keep only deps from selected categories
|
||||
@@ -338,6 +297,22 @@ def pytest_collection_modifyitems(
|
||||
# if not set(d.categories).intersection(set(selected_categories))
|
||||
# ]
|
||||
|
||||
# Skip items in optional categories that are not selected for the subject agent
|
||||
challenge_optional_categories = challenge_categories & set(OPTIONAL_CATEGORIES)
|
||||
if challenge_optional_categories and not (
|
||||
agbenchmark_config.categories
|
||||
and challenge_optional_categories.issubset(
|
||||
set(agbenchmark_config.categories)
|
||||
)
|
||||
):
|
||||
logger.debug(
|
||||
f"Skipping {challenge_name}: "
|
||||
f"category {' and '.join(challenge_optional_categories)} is optional, "
|
||||
"and not explicitly selected in the benchmark config."
|
||||
)
|
||||
items.remove(item)
|
||||
continue
|
||||
|
||||
# Add marker for the DependencyManager
|
||||
item.add_marker(pytest.mark.depends(on=dependencies, name=challenge_name))
|
||||
|
||||
|
||||
@@ -1,75 +1,24 @@
|
||||
import glob
|
||||
"""
|
||||
AGBenchmark's test discovery endpoint for Pytest.
|
||||
|
||||
This module is picked up by Pytest's *_test.py file matching pattern, and all challenge
|
||||
classes in the module that conform to the `Test*` pattern are collected.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
from agbenchmark.utils.challenge import Challenge
|
||||
from agbenchmark.utils.data_types import ChallengeData
|
||||
|
||||
DATA_CATEGORY = {}
|
||||
from agbenchmark.challenges.builtin import load_builtin_challenges
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_CATEGORY = {}
|
||||
|
||||
def create_challenge_from_spec_file(spec_file: Path) -> type[Challenge]:
|
||||
challenge = Challenge.from_challenge_spec(spec_file)
|
||||
DATA_CATEGORY[challenge.data.name] = challenge.data.category[0].value
|
||||
return challenge
|
||||
|
||||
|
||||
def create_challenge_from_spec_file_path(spec_file_path: str) -> type[Challenge]:
|
||||
spec_file = Path(spec_file_path).resolve()
|
||||
return create_challenge_from_spec_file(spec_file)
|
||||
|
||||
|
||||
def load_challenges() -> None:
|
||||
logger.info("Loading challenges...")
|
||||
|
||||
challenges_path = os.path.join(os.path.dirname(__file__), "challenges")
|
||||
logger.debug(f"Looking for challenges in {challenges_path}...")
|
||||
|
||||
json_files = deque(
|
||||
glob.glob(
|
||||
f"{challenges_path}/**/data.json",
|
||||
recursive=True,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Found {len(json_files)} challenges.")
|
||||
logger.debug(f"Sample path: {json_files[0]}")
|
||||
|
||||
loaded, ignored = 0, 0
|
||||
while json_files:
|
||||
# Take and remove the first element from json_files
|
||||
json_file = json_files.popleft()
|
||||
if challenge_should_be_ignored(json_file):
|
||||
ignored += 1
|
||||
continue
|
||||
|
||||
challenge_info = ChallengeData.parse_file(json_file)
|
||||
|
||||
challenge_class = create_challenge_from_spec_file_path(json_file)
|
||||
|
||||
logger.debug(f"Generated test for {challenge_info.name}")
|
||||
_add_challenge_to_module(challenge_class)
|
||||
loaded += 1
|
||||
|
||||
logger.info(f"Loading challenges complete: loaded {loaded}, ignored {ignored}.")
|
||||
|
||||
|
||||
def challenge_should_be_ignored(json_file_path: str):
|
||||
return (
|
||||
"challenges/deprecated" in json_file_path
|
||||
or "challenges/library" in json_file_path
|
||||
)
|
||||
|
||||
|
||||
def _add_challenge_to_module(challenge: type[Challenge]):
|
||||
# Load challenges and attach them to this module
|
||||
for challenge in load_builtin_challenges():
|
||||
# Attach the Challenge class to this module so it can be discovered by pytest
|
||||
module = importlib.import_module(__name__)
|
||||
setattr(module, f"{challenge.__name__}", challenge)
|
||||
setattr(module, challenge.__name__, challenge)
|
||||
|
||||
|
||||
load_challenges()
|
||||
# Build a map of challenge names and their primary category
|
||||
DATA_CATEGORY[challenge.info.name] = challenge.info.category[0].value
|
||||
|
||||
@@ -1,21 +1,29 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from agbenchmark.config import AgentBenchmarkConfig
|
||||
from agbenchmark.reports.processing.graphs import save_single_radar_chart
|
||||
from agbenchmark.reports.processing.process_report import get_agent_category
|
||||
from agbenchmark.reports.processing.report_types import Report
|
||||
from agbenchmark.reports.processing.report_types import MetricsOverall, Report, Test
|
||||
from agbenchmark.utils.utils import get_highest_success_difficulty
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SingletonReportManager:
|
||||
instance = None
|
||||
|
||||
INFO_MANAGER: "SessionReportManager"
|
||||
REGRESSION_MANAGER: "RegressionTestsTracker"
|
||||
SUCCESS_RATE_TRACKER: "SuccessRatesTracker"
|
||||
|
||||
def __new__(cls):
|
||||
if not cls.instance:
|
||||
cls.instance = super(SingletonReportManager, cls).__new__(cls)
|
||||
@@ -26,17 +34,16 @@ class SingletonReportManager:
|
||||
) # or any logic to fetch the datetime
|
||||
|
||||
# Make the Managers class attributes
|
||||
cls.REGRESSION_MANAGER = ReportManager(
|
||||
agent_benchmark_config.regression_tests_file,
|
||||
benchmark_start_time_dt,
|
||||
)
|
||||
cls.INFO_MANAGER = ReportManager(
|
||||
cls.INFO_MANAGER = SessionReportManager(
|
||||
agent_benchmark_config.get_report_dir(benchmark_start_time_dt)
|
||||
/ "report.json",
|
||||
benchmark_start_time_dt,
|
||||
)
|
||||
cls.INTERNAL_INFO_MANAGER = ReportManager(
|
||||
agent_benchmark_config.success_rate_file, benchmark_start_time_dt
|
||||
cls.REGRESSION_MANAGER = RegressionTestsTracker(
|
||||
agent_benchmark_config.regression_tests_file
|
||||
)
|
||||
cls.SUCCESS_RATE_TRACKER = SuccessRatesTracker(
|
||||
agent_benchmark_config.success_rate_file
|
||||
)
|
||||
|
||||
return cls.instance
|
||||
@@ -44,39 +51,33 @@ class SingletonReportManager:
|
||||
@classmethod
|
||||
def clear_instance(cls):
|
||||
cls.instance = None
|
||||
cls.REGRESSION_MANAGER = None
|
||||
cls.INFO_MANAGER = None
|
||||
cls.INTERNAL_INFO_MANAGER = None
|
||||
cls.REGRESSION_MANAGER = None
|
||||
cls.SUCCESS_RATE_TRACKER = None
|
||||
|
||||
|
||||
class ReportManager:
|
||||
class BaseReportManager:
|
||||
"""Abstracts interaction with the regression tests file"""
|
||||
|
||||
def __init__(self, report_file: Path, benchmark_start_time: datetime):
|
||||
tests: dict[str, Any]
|
||||
|
||||
def __init__(self, report_file: Path):
|
||||
self.report_file = report_file
|
||||
self.start_time = time.time()
|
||||
self.benchmark_start_time = benchmark_start_time
|
||||
|
||||
self.load()
|
||||
|
||||
def load(self) -> None:
|
||||
if not self.report_file.exists():
|
||||
self.report_file.parent.mkdir(exist_ok=True)
|
||||
self.report_file.touch()
|
||||
|
||||
try:
|
||||
with self.report_file.open("r") as f:
|
||||
file_content = (
|
||||
f.read().strip()
|
||||
) # read the content and remove any leading/trailing whitespace
|
||||
if file_content: # if file is not empty, load the json
|
||||
data = json.loads(file_content)
|
||||
self.tests = {k: data[k] for k in sorted(data)}
|
||||
else: # if file is empty, assign an empty dictionary
|
||||
self.tests = {}
|
||||
data = json.load(f)
|
||||
self.tests = {k: data[k] for k in sorted(data)}
|
||||
except FileNotFoundError:
|
||||
self.tests = {}
|
||||
except json.decoder.JSONDecodeError: # If JSON is invalid
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
logger.warning(f"Could not parse {self.report_file}: {e}")
|
||||
self.tests = {}
|
||||
self.save()
|
||||
|
||||
@@ -84,13 +85,6 @@ class ReportManager:
|
||||
with self.report_file.open("w") as f:
|
||||
json.dump(self.tests, f, indent=4)
|
||||
|
||||
def add_test(self, test_name: str, test_details: dict | list) -> None:
|
||||
if test_name.startswith("Test"):
|
||||
test_name = test_name[4:]
|
||||
self.tests[test_name] = test_details
|
||||
|
||||
self.save()
|
||||
|
||||
def remove_test(self, test_name: str) -> None:
|
||||
if test_name in self.tests:
|
||||
del self.tests[test_name]
|
||||
@@ -100,34 +94,61 @@ class ReportManager:
|
||||
self.tests = {}
|
||||
self.save()
|
||||
|
||||
def end_info_report(self, config: AgentBenchmarkConfig) -> None:
|
||||
|
||||
class SessionReportManager(BaseReportManager):
|
||||
"""Abstracts interaction with the regression tests file"""
|
||||
|
||||
tests: dict[str, Test] | Report
|
||||
|
||||
def __init__(self, report_file: Path, benchmark_start_time: datetime):
|
||||
super().__init__(report_file)
|
||||
|
||||
        self.start_time = time.time()
        self.benchmark_start_time = benchmark_start_time

    def save(self) -> None:
        with self.report_file.open("w") as f:
            if isinstance(self.tests, Report):
                f.write(self.tests.json(indent=4))
            else:
                json.dump({k: v.dict() for k, v in self.tests.items()}, f, indent=4)

    def add_test_report(self, test_name: str, test_report: Test) -> None:
        if isinstance(self.tests, Report):
            raise RuntimeError("Session report already finalized")

        if test_name.startswith("Test"):
            test_name = test_name[4:]
        self.tests[test_name] = test_report

        self.save()

    def finalize_session_report(self, config: AgentBenchmarkConfig) -> None:
        command = " ".join(sys.argv)

-        self.tests = {
-            "command": command.split(os.sep)[-1],
-            "benchmark_git_commit_sha": "---",
-            "agent_git_commit_sha": "---",
-            "completion_time": datetime.now(timezone.utc).strftime(
+        if isinstance(self.tests, Report):
+            raise RuntimeError("Session report already finalized")
+
+        self.tests = Report(
+            command=command.split(os.sep)[-1],
+            benchmark_git_commit_sha="---",
+            agent_git_commit_sha="---",
+            completion_time=datetime.now(timezone.utc).strftime(
                "%Y-%m-%dT%H:%M:%S+00:00"
            ),
-            "benchmark_start_time": self.benchmark_start_time.strftime(
+            benchmark_start_time=self.benchmark_start_time.strftime(
                "%Y-%m-%dT%H:%M:%S+00:00"
            ),
-            "metrics": {
-                "run_time": str(round(time.time() - self.start_time, 2)) + " seconds",
-                "highest_difficulty": get_highest_success_difficulty(self.tests),
-                "total_cost": self.get_total_costs(),
-            },
-            "tests": copy.copy(self.tests),
-            "config": {
-                k: v for k, v in json.loads(config.json()).items() if v is not None
-            },
-        }
-        Report.parse_obj(self.tests)
+            metrics=MetricsOverall(
+                run_time=str(round(time.time() - self.start_time, 2)) + " seconds",
+                highest_difficulty=get_highest_success_difficulty(self.tests),
+                total_cost=self.get_total_costs(),
+            ),
+            tests=copy.copy(self.tests),
+            config=config.dict(exclude_none=True),
+        )

-        converted_data = Report.parse_obj(self.tests)
-
-        agent_categories = get_agent_category(converted_data)
+        agent_categories = get_agent_category(self.tests)
        if len(agent_categories) > 1:
            save_single_radar_chart(
                agent_categories,
@@ -137,12 +158,15 @@ class ReportManager:
        self.save()

    def get_total_costs(self):
+        if isinstance(self.tests, Report):
+            tests = self.tests.tests
+        else:
+            tests = self.tests
+
        total_cost = 0
        all_costs_none = True
-        for test_name, test_data in self.tests.items():
-            cost = test_data["metrics"].get(
-                "cost", 0
-            )  # gets the cost or defaults to 0 if cost is missing
+        for test_data in tests.values():
+            cost = test_data.metrics.cost or 0.0

            if cost is not None:  # check if cost is not None
                all_costs_none = False
@@ -150,3 +174,32 @@ class ReportManager:
        if all_costs_none:
            total_cost = None
        return total_cost
+
+
+class RegressionTestsTracker(BaseReportManager):
+    """Abstracts interaction with the regression tests file"""
+
+    tests: dict[str, dict]
+
+    def add_test(self, test_name: str, test_details: dict) -> None:
+        if test_name.startswith("Test"):
+            test_name = test_name[4:]
+        self.tests[test_name] = test_details
+
+        self.save()
+
+    def has_regression_test(self, test_name: str) -> bool:
+        return self.tests.get(test_name) is not None
+
+
+class SuccessRatesTracker(BaseReportManager):
+    """Abstracts interaction with the regression tests file"""
+
+    tests: dict[str, list[bool]]
+
+    def update(self, test_name: str, success_history: list[bool]) -> None:
+        if test_name.startswith("Test"):
+            test_name = test_name[4:]
+        self.tests[test_name] = success_history
+
+        self.save()
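
For orientation on the hunk above: `finalize_session_report` now assembles a typed `Report` instead of a raw dict, and `save()` handles either form. Below is a minimal, hedged sketch of constructing such a `Report` directly; it assumes the models shown later in this diff (report_types.py) and that `Report` declares a `tests` field, which `tests=copy.copy(self.tests)` implies. All field values are invented.

from datetime import datetime, timezone

from agbenchmark.reports.processing.report_types import MetricsOverall, Report

# Invented values for illustration; benchmark_start_time must match the
# datetime_format regex enforced by ReportBase.
report = Report(
    command="agbenchmark run",
    benchmark_start_time=datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
    metrics=MetricsOverall(run_time="42.0 seconds", highest_difficulty="basic"),
    tests={},
    config={},
)
print(report.json(indent=4))
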

# File: process_report.py
@@ -46,7 +46,7 @@ def get_agent_category(report: Report) -> dict[str, Any]:
            ):
                continue
            categories.setdefault(category, 0)
-            if data.metrics.success:
+            if data.metrics.success and data.metrics.difficulty:
                num_dif = STRING_DIFFICULTY_MAP[data.metrics.difficulty]
                if num_dif > categories[category]:
                    categories[category] = num_dif

# File: report_types.py
@@ -1,48 +1,38 @@
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, constr, validator

datetime_format = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00$"
-from pydantic import BaseModel, constr
-
-
-class ForbidOptionalMeta(type(BaseModel)):  # metaclass to forbid optional fields
-    def __new__(cls, name: str, bases: tuple, dct: Dict[str, Any]) -> Any:
-        for attr_name, attr_value in dct.items():
-            if (
-                getattr(attr_value, "__origin__", None) == Union
-                and type(None) in attr_value.__args__
-            ):
-                raise TypeError(
-                    f"Optional fields are forbidden, but found in {attr_name}"
-                )
-
-        return super().__new__(cls, name, bases, dct)
-
-
-class BaseModelBenchmark(BaseModel, metaclass=ForbidOptionalMeta):
-    class Config:
-        extra = "forbid"
-
-
-class Metrics(BaseModelBenchmark):
-    difficulty: str
-    success: bool
-    success_percentage: float = Field(..., alias="success_%")
-    run_time: str
-    fail_reason: str | None
+class Metrics(BaseModel):
+    difficulty: str | None
+    success: bool | None = None
+    run_time: str | None = None
+    fail_reason: str | None = None
+    success_percentage: float | None = Field(default=None, alias="success_%")
    attempted: bool
-    cost: float | None
+    cost: float | None = None
+
+    @validator("attempted")
+    def require_metrics_if_attempted(cls, v: bool, values: dict[str, Any]):
+        required_fields_if_attempted = ["success", "run_time"]
+        if v:
+            for f in required_fields_if_attempted:
+                assert (
+                    values.get(f) is not None
+                ), f"'{f}' must be defined if attempted is True"
+        return v


-class MetricsOverall(BaseModelBenchmark):
+class MetricsOverall(BaseModel):
    run_time: str
    highest_difficulty: str
-    percentage: float | None
-    total_cost: float | None
+    percentage: float | None = None
+    total_cost: float | None = None


-class Test(BaseModelBenchmark):
+class Test(BaseModel):
    data_path: str
    is_regression: bool
    answer: str
@@ -50,19 +40,19 @@ class Test(BaseModelBenchmark):
    metrics: Metrics
    category: List[str]
    task: str
-    reached_cutoff: bool
-    metadata: Any
+    reached_cutoff: bool | None = None  # None if in progress
+    metadata: dict[str, Any] | None = Field(default_factory=dict)


-class ReportBase(BaseModelBenchmark):
+class ReportBase(BaseModel):
    command: str
-    completion_time: str | None
+    completion_time: str | None = None
    benchmark_start_time: constr(regex=datetime_format)
    metrics: MetricsOverall
    config: Dict[str, str | dict[str, str]]
-    agent_git_commit_sha: str | None
-    benchmark_git_commit_sha: str | None
-    repo_url: str | None
+    agent_git_commit_sha: str | None = None
+    benchmark_git_commit_sha: str | None = None
+    repo_url: str | None = None


class Report(ReportBase):
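
To make the new `attempted` validator above concrete, here is a small hedged example of how it behaves under pydantic v1 semantics (which the `@validator` decorator implies); the values are invented:

from pydantic import ValidationError

from agbenchmark.reports.processing.report_types import Metrics

# No success/run_time required when the test was not attempted.
print(Metrics(attempted=False))

# Attempted tests must carry both success and run_time.
print(Metrics(attempted=True, success=True, run_time="1.234 seconds", difficulty="basic"))

try:
    Metrics(attempted=True)  # missing success and run_time -> rejected by the validator
except ValidationError as e:
    print(e)
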

# File: reports.py
@@ -3,13 +3,14 @@ import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict

import pytest

+from agbenchmark.challenges import ChallengeInfo
from agbenchmark.config import AgentBenchmarkConfig
+from agbenchmark.reports.processing.report_types import Metrics, Test
from agbenchmark.reports.ReportManager import SingletonReportManager
-from agbenchmark.utils.data_types import ChallengeData, DifficultyLevel
+from agbenchmark.utils.data_types import DifficultyLevel
from agbenchmark.utils.utils import calculate_success_percentage

# from agbenchmark.utils.get_data_from_helicone import get_data_from_helicone
@@ -17,24 +18,22 @@ from agbenchmark.utils.utils import calculate_success_percentage
logger = logging.getLogger(__name__)


-def get_previous_test_results(
-    test_name: str, info_details: dict[str, Any]
-) -> list[bool]:
+def get_and_update_success_history(test_name: str, info_details: Test) -> list[bool]:
    mock = os.getenv("IS_MOCK")  # Check if --mock is in sys.argv

-    prev_test_results = SingletonReportManager().INTERNAL_INFO_MANAGER.tests.get(
+    prev_test_results = SingletonReportManager().SUCCESS_RATE_TRACKER.tests.get(
        test_name, []
    )

-    if not mock:
+    if not mock and info_details.metrics.success is not None:
        # only add if it's an actual test
-        prev_test_results.append(info_details["metrics"]["success"])
-        SingletonReportManager().INTERNAL_INFO_MANAGER.add_test(
+        prev_test_results.append(info_details.metrics.success)
+        SingletonReportManager().SUCCESS_RATE_TRACKER.update(
            test_name, prev_test_results
        )

    # can calculate success rate regardless of mock
-    info_details["metrics"]["success_%"] = calculate_success_percentage(
+    info_details.metrics.success_percentage = calculate_success_percentage(
        prev_test_results
    )

@@ -43,26 +42,22 @@ def get_previous_test_results(

def update_regression_tests(
    prev_test_results: list[bool],
-    info_details: dict,
+    info_details: Test,
    test_name: str,
-    test_details: dict,
) -> None:
    if len(prev_test_results) >= 3 and prev_test_results[-3:] == [True, True, True]:
        # if the last 3 tests were successful, add to the regression tests
-        info_details["is_regression"] = True
-        SingletonReportManager().REGRESSION_MANAGER.add_test(test_name, test_details)
+        info_details.is_regression = True
+        SingletonReportManager().REGRESSION_MANAGER.add_test(
+            test_name, info_details.dict(include={"difficulty", "data_path"})
+        )


-def generate_single_call_report(
+def initialize_test_report(
    item: pytest.Item,
-    call: pytest.CallInfo,
-    challenge_data: ChallengeData,
-    answers: dict[str, Any],
-    challenge_location: str,
-    test_name: str,
-) -> None:
-    difficulty = challenge_data.info.difficulty
-
+    challenge_info: ChallengeInfo,
+):
+    difficulty = challenge_info.difficulty
    if isinstance(difficulty, DifficultyLevel):
        difficulty = difficulty.value

@@ -71,105 +66,73 @@ def generate_single_call_report(
    # test_name = item.nodeid.split("::")[1]
    # item.test_name = test_name

-    test_details = {
-        "difficulty": difficulty,
-        "data_path": challenge_location,
-    }
-
-    info_details: Any = {
-        "data_path": challenge_location,
-        "is_regression": False,
-        "category": challenge_data.category,
-        "task": challenge_data.task,
-        "answer": challenge_data.ground.answer,
-        "description": challenge_data.info.description,
-        "metrics": {
-            "difficulty": difficulty,
-            "success": False,
-            "attempted": True,
-        },
-        # "answers": answers,
-    }
-    if answers:
-        info_details["answers"] = answers
-
-    if challenge_data.metadata:
-        info_details["metadata"] = challenge_data.metadata
-
-    mock = os.getenv("IS_MOCK")  # Check if --mock is in sys.argv
-    if call:
-        if call.excinfo is None:
-            info_details["metrics"]["success"] = True
-        else:
-            if not mock:  # don't remove if it's a mock test
-                SingletonReportManager().REGRESSION_MANAGER.remove_test(test_name)
-            info_details["metrics"]["fail_reason"] = str(call.excinfo.value)
-            if call.excinfo.typename == "Skipped":
-                info_details["metrics"]["attempted"] = False
-
-    prev_test_results: list[bool] = get_previous_test_results(test_name, info_details)
-
-    update_regression_tests(prev_test_results, info_details, test_name, test_details)
+    test_info = dict(item.user_properties).get("info_details") or Test(
+        data_path=challenge_info.source_uri,
+        is_regression=False,
+        category=[c.value for c in challenge_info.category],
+        task=challenge_info.task,
+        answer=challenge_info.reference_answer or "",
+        description=challenge_info.description or "",
+        metrics=Metrics(
+            difficulty=difficulty,
+            attempted=False,
+        ),
+    )

    # user facing reporting
    if item:
-        item.info_details = info_details
+        item.user_properties.append(("info_details", test_info))

-    return info_details
+    return test_info


-def finalize_reports(
-    config: AgentBenchmarkConfig, item: pytest.Item, challenge_data: ChallengeData
+def finalize_test_report(
+    item: pytest.Item, call: pytest.CallInfo, config: AgentBenchmarkConfig
) -> None:
-    run_time = dict(item.user_properties).get("run_time")
+    user_properties: dict = dict(item.user_properties)

-    info_details = getattr(item, "info_details", {})
-    test_name = getattr(item, "test_name", "")
+    info_details: Test = user_properties.get("info_details", {})
+    test_name: str = user_properties.get("test_name", "")

    mock = os.getenv("IS_MOCK")  # Check if --mock is in sys.argv

+    logger.debug(f"Finalizing report with CallInfo: {vars(call)}")
+    if call.excinfo is None:
+        info_details.metrics.success = True
+    else:
+        if not mock:  # don't remove if it's a mock test
+            SingletonReportManager().REGRESSION_MANAGER.remove_test(test_name)
+        info_details.metrics.fail_reason = str(call.excinfo.value)
+        if call.excinfo.typename == "Skipped":
+            info_details.metrics.attempted = False
+    info_details.metrics.attempted = True
+    info_details.metrics.run_time = f"{str(round(call.duration, 3))} seconds"
+    info_details.reached_cutoff = user_properties.get("timed_out", False)

+    prev_test_results: list[bool] = get_and_update_success_history(
+        test_name, info_details
+    )

+    update_regression_tests(prev_test_results, info_details, test_name)

-    if info_details and test_name:
-        if run_time is not None:
-            cost = None
-            # if "--mock" not in sys.argv and os.environ.get("HELICONE_API_KEY"):
-            #     logger.debug("Getting cost from Helicone")
-            #     cost = get_data_from_helicone(test_name)
-            #     logger.debug(f"Cost: {cost}")
+    # if "--mock" not in sys.argv and os.environ.get("HELICONE_API_KEY"):
+    #     logger.debug("Getting cost from Helicone")
+    #     info_details.metrics.cost = get_data_from_helicone(test_name)
+    #     logger.debug(f"Cost: {cost}")

-            info_details["metrics"]["cost"] = cost
+    if "--mock" not in sys.argv:
+        update_challenges_already_beaten(
+            config.challenges_already_beaten_file, info_details, test_name
+        )

-            if info_details["metrics"].get("success", None) is None:
-                info_details["metrics"]["attempted"] = False
-                info_details["metrics"]["success"] = False
-            elif (
-                info_details["metrics"].get("success") is False
-                and "attempted" not in info_details["metrics"]
-            ):
-                info_details["metrics"]["attempted"] = False
-
-            info_details["metrics"]["run_time"] = f"{str(round(run_time, 3))} seconds"
-
-            info_details["reached_cutoff"] = float(run_time) > challenge_data.cutoff
-
-            if "--mock" not in sys.argv:
-                update_challenges_already_beaten(
-                    config.challenges_already_beaten_file, info_details, test_name
-                )
-                if info_details.get("tests") is not None:
-                    for nested_test_name, nested_test_info in info_details[
-                        "tests"
-                    ].items():
-                        update_challenges_already_beaten(
-                            config.challenges_already_beaten_file,
-                            nested_test_info,
-                            nested_test_name,
-                        )
-
-        SingletonReportManager().INFO_MANAGER.add_test(test_name, info_details)
+    SingletonReportManager().INFO_MANAGER.add_test_report(test_name, info_details)


def update_challenges_already_beaten(
-    challenges_already_beaten_file: Path, info_details: Dict[str, Any], test_name: str
+    challenges_already_beaten_file: Path, info_details: Test, test_name: str
) -> None:
-    current_run_successful = info_details["metrics"]["success"]
+    current_run_successful = info_details.metrics.success
    try:
        with open(challenges_already_beaten_file, "r") as f:
            challenge_data = json.load(f)
@@ -185,9 +148,7 @@ def update_challenges_already_beaten(
        json.dump(challenge_data, f, indent=4)


-def session_finish(
-    agbenchmark_config: AgentBenchmarkConfig, suite_reports: dict
-) -> None:
-    SingletonReportManager().INTERNAL_INFO_MANAGER.save()
-    SingletonReportManager().INFO_MANAGER.end_info_report(agbenchmark_config)
+def session_finish(agbenchmark_config: AgentBenchmarkConfig) -> None:
+    SingletonReportManager().INFO_MANAGER.finalize_session_report(agbenchmark_config)
    SingletonReportManager().REGRESSION_MANAGER.save()
+    SingletonReportManager().SUCCESS_RATE_TRACKER.save()
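
The reshaped reports.py above splits per-test reporting into `initialize_test_report` (before the challenge runs) and `finalize_test_report` (immediately after it finishes, so `call.duration` reflects the challenge run time rather than run time plus teardown). A rough sketch of how a conftest.py hook could wire up the finalization step; the hook body and the way the config object is obtained are assumptions for illustration, not the project's actual conftest:

import pytest

from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.reports.reports import finalize_test_report


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item: pytest.Item, call: pytest.CallInfo):
    yield
    if call.when == "call":
        # Hypothetical attribute used only for this sketch; the real suite
        # obtains its AgentBenchmarkConfig elsewhere.
        config: AgentBenchmarkConfig = item.config._agbenchmark_config
        finalize_test_report(item, call, config)
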

# File: challenge.py (deleted in this commit)
@@ -1,284 +0,0 @@
-import glob
-import json
-import logging
-import math
-import os
-import subprocess
-import sys
-from abc import ABC
-from pathlib import Path
-from typing import Any, ClassVar, List
-
-import pytest
-from colorama import Fore, Style
-from openai import OpenAI
-
-from agbenchmark.agent_api_interface import run_api_agent
-from agbenchmark.config import AgentBenchmarkConfig
-from agbenchmark.utils.data_types import ChallengeData, Ground
-from agbenchmark.utils.prompts import (
-    END_PROMPT,
-    FEW_SHOT_EXAMPLES,
-    PROMPT_MAP,
-    SCORING_MAP,
-)
-
-logger = logging.getLogger(__name__)
-
-with open(
-    Path(__file__).parent.parent / "challenges" / "optional_categories.json"
-) as f:
-    OPTIONAL_CATEGORIES: list[str] = json.load(f)["optional_categories"]
-
-
-class Challenge(ABC):
-    """The parent class to all specific challenges classes.
-    Defines helper methods for running a challenge"""
-
-    data: ChallengeData
-    CHALLENGE_LOCATION: ClassVar[str]
-    ARTIFACTS_LOCATION: ClassVar[str]
-    scores: ClassVar[dict[str, Any]] = {}  # this is for suites
-
-    @staticmethod
-    def from_challenge_spec(spec_file: Path) -> type["Challenge"]:
-        challenge_data = ChallengeData.parse_file(spec_file)
-
-        challenge_class_name = f"Test{challenge_data.name}"
-        logger.debug(f"Creating {challenge_class_name} from spec: {spec_file}")
-        return type(
-            challenge_class_name,
-            (Challenge,),
-            {
-                "data": challenge_data,
-                "CHALLENGE_LOCATION": str(spec_file),
-                "ARTIFACTS_LOCATION": str(spec_file.resolve().parent),
-            },
-        )
-
-    # Define test method within the dynamically created class
-    @pytest.mark.asyncio
-    async def test_method(
-        self, config: AgentBenchmarkConfig, request: pytest.FixtureRequest
-    ) -> None:
-        # skip optional categories
-        self.skip_optional_categories(config)
-
-        # if os.environ.get("HELICONE_API_KEY"):
-        #     from helicone.lock import HeliconeLockManager
-
-        #     HeliconeLockManager.write_custom_property("challenge", self.data.name)
-
-        timeout = self.data.cutoff or 60
-
-        if request.config.getoption("--nc"):
-            timeout = 100000
-        elif cutoff := request.config.getoption("--cutoff"):
-            timeout = int(cutoff)
-
-        await self.run_challenge(config, timeout)
-
-        scores = self.get_scores(config.temp_folder)
-        request.node.answers = (
-            scores["answers"] if request.config.getoption("--keep-answers") else None
-        )
-        del scores["answers"]  # remove answers from scores
-        request.node.scores = scores  # store scores in request.node
-        is_score_100 = 1 in scores["values"]
-
-        assert is_score_100
-
-    async def run_challenge(self, config: AgentBenchmarkConfig, cutoff: int) -> None:
-        from agbenchmark.agent_interface import copy_artifacts_into_temp_folder
-
-        if not self.data.task:
-            return
-
-        print(
-            f"{Fore.MAGENTA + Style.BRIGHT}{'='*24} "
-            f"Starting {self.data.name} challenge"
-            f" {'='*24}{Style.RESET_ALL}"
-        )
-        print(f"{Fore.BLACK}Task: {self.data.task}{Fore.RESET}")
-
-        await run_api_agent(self.data, config, self.ARTIFACTS_LOCATION, cutoff)
-
-        # hidden files are added after the agent runs. Hidden files can be python test files.
-        # We copy them in the temporary folder to make it easy to import the code produced by the agent
-        artifact_paths = [
-            self.ARTIFACTS_LOCATION,
-            str(Path(self.CHALLENGE_LOCATION).parent),
-        ]
-        for path in artifact_paths:
-            copy_artifacts_into_temp_folder(config.temp_folder, "custom_python", path)
-
-    @staticmethod
-    def get_artifacts_out(
-        workspace: str | Path | dict[str, str], ground: Ground
-    ) -> List[str]:
-        if isinstance(workspace, dict):
-            workspace = workspace["output"]
-
-        script_dir = workspace
-        files_contents = []
-
-        for file_pattern in ground.files:
-            # Check if it is a file extension
-            if file_pattern.startswith("."):
-                # Find all files with the given extension in the workspace
-                matching_files = glob.glob(os.path.join(script_dir, "*" + file_pattern))
-            else:
-                # Otherwise, it is a specific file
-                matching_files = [os.path.join(script_dir, file_pattern)]
-
-            for file_path in matching_files:
-                if ground.eval.type == "python":
-                    result = subprocess.run(
-                        [sys.executable, file_path],
-                        cwd=os.path.abspath(workspace),
-                        capture_output=True,
-                        text=True,
-                    )
-                    if "error" in result.stderr or result.returncode != 0:
-                        print(result.stderr)
-                        assert False, result.stderr
-                    files_contents.append(f"Output: {result.stdout}\n")
-                else:
-                    with open(file_path, "r") as f:
-                        files_contents.append(f.read())
-        else:
-            if ground.eval.type == "pytest":
-                result = subprocess.run(
-                    [sys.executable, "-m", "pytest"],
-                    cwd=os.path.abspath(workspace),
-                    capture_output=True,
-                    text=True,
-                )
-                if "error" in result.stderr or result.returncode != 0:
-                    print(result.stderr)
-                    assert False, result.stderr
-                files_contents.append(f"Output: {result.stdout}\n")
-
-        return files_contents
-
-    @staticmethod
-    def scoring(content: str, ground: Ground) -> float:
-        print(f"{Fore.BLUE}Scoring content:{Style.RESET_ALL}", content)
-        if ground.should_contain:
-            for should_contain_word in ground.should_contain:
-                if not getattr(ground, "case_sensitive", True):
-                    should_contain_word = should_contain_word.lower()
-                    content = content.lower()
-                print_content = (
-                    f"{Fore.BLUE}Word that should exist{Style.RESET_ALL}"
-                    f" - {should_contain_word}:"
-                )
-                if should_contain_word not in content:
-                    print(print_content, "False")
-                    return 0.0
-                else:
-                    print(print_content, "True")
-
-        if ground.should_not_contain:
-            for should_not_contain_word in ground.should_not_contain:
-                if not getattr(ground, "case_sensitive", True):
-                    should_not_contain_word = should_not_contain_word.lower()
-                    content = content.lower()
-                print_content = (
-                    f"{Fore.BLUE}Word that should not exist{Style.RESET_ALL}"
-                    f" - {should_not_contain_word}:"
-                )
-                if should_not_contain_word in content:
-                    print(print_content, "False")
-                    return 0.0
-                else:
-                    print(print_content, "True")
-
-        return 1.0
-
-    @classmethod
-    def llm_eval(cls, content: str, ground: Ground) -> float:
-        openai_client = OpenAI()
-        if os.getenv("IS_MOCK"):
-            return 1.0
-
-        # the validation for this is done in the Eval BaseModel
-        scoring = SCORING_MAP[ground.eval.scoring]  # type: ignore
-        prompt = PROMPT_MAP[ground.eval.template].format(  # type: ignore
-            task=cls.data.task, scoring=scoring, answer=ground.answer, response=content
-        )
-
-        if ground.eval.examples:
-            prompt += FEW_SHOT_EXAMPLES.format(examples=ground.eval.examples)
-
-        prompt += END_PROMPT
-
-        answer = openai_client.chat.completions.create(
-            model="gpt-4",
-            messages=[
-                {"role": "system", "content": prompt},
-            ],
-        )
-
-        return float(answer.choices[0].message.content)  # type: ignore
-
-    @classmethod
-    def get_scores(cls, workspace: Path) -> dict[str, Any]:
-        scores = []
-        scores_dict: Any = {}
-        percentage = None
-        answers = {}
-        try:
-            if cls.data.task == "" and os.getenv("IS_MOCK"):
-                scores = [1.0]
-                answers = {"mock": "This is a mock answer"}
-            elif isinstance(cls.data.ground, Ground):
-                files_contents = cls.get_artifacts_out(workspace, cls.data.ground)
-                answers = {"answer": files_contents}
-                for file_content in files_contents:
-                    score = cls.scoring(file_content, cls.data.ground)
-                    print(f"{Fore.GREEN}Your score is:{Style.RESET_ALL}", score)
-                    scores.append(score)
-
-                if cls.data.ground.eval.type == "llm":
-                    llm_eval = cls.llm_eval("\n".join(files_contents), cls.data.ground)
-                    if cls.data.ground.eval.scoring == "percentage":
-                        scores.append(math.ceil(llm_eval / 100))
-                    elif cls.data.ground.eval.scoring == "scale":
-                        scores.append(math.ceil(llm_eval / 10))
-                    print(f"{Fore.GREEN}Your score is:{Style.RESET_ALL}", llm_eval)
-
-                    scores.append(llm_eval)
-        except Exception as e:
-            print("Error getting scores", e)
-
-        scores_data = {
-            "values": scores,
-            "scores_obj": scores_dict,
-            "percentage": percentage,
-            "answers": answers,
-        }
-
-        cls.scores[cls.__name__] = scores_data
-
-        return scores_data
-
-    def get_dummy_scores(self, test_name: str, scores: dict[str, Any]) -> int | None:
-        return 1  # remove this once this works
-        if 1 in scores.get("scores_obj", {}).get(test_name, []):
-            return 1
-
-        return None
-
-    @classmethod
-    def skip_optional_categories(cls, config: AgentBenchmarkConfig) -> None:
-        challenge_categories = set(c.value for c in cls.data.category)
-        challenge_optional_categories = challenge_categories & set(OPTIONAL_CATEGORIES)
-        if challenge_optional_categories and not (
-            config.categories
-            and set(challenge_optional_categories).issubset(set(config.categories))
-        ):
-            pytest.skip(
-                f"Category {', '.join(challenge_optional_categories)} is optional, "
-                "and not explicitly selected in the benchmark config."
-            )
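
The deleted `Challenge.scoring` above (whose logic lives on in the builtin challenge provider, per the commit message) boils down to a simple keyword rule: every `should_contain` term must appear and no `should_not_contain` term may appear, optionally case-insensitively. A stripped-down, hedged restatement of that rule, standalone and not the project's API:

def keyword_score(
    content: str,
    should_contain: list[str] | None = None,
    should_not_contain: list[str] | None = None,
    case_sensitive: bool = True,
) -> float:
    # Returns 1.0 only if all required terms are present and no forbidden term is.
    if not case_sensitive:
        content = content.lower()
        should_contain = [w.lower() for w in (should_contain or [])]
        should_not_contain = [w.lower() for w in (should_not_contain or [])]
    if any(word not in content for word in (should_contain or [])):
        return 0.0
    if any(word in content for word in (should_not_contain or [])):
        return 0.0
    return 1.0


print(keyword_score("The answer is 42.", should_contain=["42"]))  # 1.0
print(keyword_score("No idea.", should_contain=["42"]))           # 0.0
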

# File: data_types.py
@@ -1,8 +1,7 @@
from enum import Enum
-from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Literal

-from pydantic import BaseModel, Field, constr, validator
+from pydantic import BaseModel


class DifficultyLevel(Enum):
@@ -29,87 +28,19 @@ DIFFICULTY_MAP = {
STRING_DIFFICULTY_MAP = {e.value: DIFFICULTY_MAP[e] for e in DifficultyLevel}


-class Info(BaseModel):
-    difficulty: DifficultyLevel
-    description: constr(regex=r"^Tests if the agent can.*")
-    side_effects: List[str]
-
-    @validator("difficulty", pre=True)
-    def difficulty_to_enum(cls: "Info", v: str | DifficultyLevel) -> DifficultyLevel:
-        """Convert a string to an instance of DifficultyLevel."""
-        if isinstance(v, DifficultyLevel):
-            return v
-
-        if isinstance(v, str):
-            try:
-                return DifficultyLevel(v.lower())
-            except ValueError:
-                pass
-
-        raise ValueError(f"Cannot convert {v} to DifficultyLevel.")
-
-
-class Eval(BaseModel):
-    type: str
-    scoring: Optional[str]
-    template: Optional[str]
-    examples: Optional[str]
-
-    @validator("scoring", "template", always=True)
-    def validate_eval_fields(cls, v, values, field):
-        if "type" in values and values["type"] == "llm":
-            if v is None:
-                raise ValueError(f"{field.name} must be provided when type is 'llm'")
-        else:
-            if v is not None:
-                raise ValueError(f"{field.name} should only exist when type is 'llm'")
-        return v
-
-    @validator("scoring")
-    def validate_scoring(cls, v):
-        if v is not None and v not in ["percentage", "scale", "binary"]:
-            raise ValueError(
-                "scoring must be either 'percentage', 'scale', or 'binary'"
-            )
-        return v
-
-    @validator("template")
-    def validate_template(cls, v):
-        if v is not None and v not in ["rubric", "reference", "question", "custom"]:
-            raise ValueError(
-                "template must be either 'rubric', 'reference', 'question', or 'custom'"
-            )
-        return v
-
-
-class Ground(BaseModel):
-    answer: str
-    should_contain: Optional[List[str]] = None
-    should_not_contain: Optional[List[str]] = None
-    files: List[str]
-    case_sensitive: Optional[bool] = True
-    eval: Eval


class Category(str, Enum):
    DATA = "data"
    GENERALIST = "general"
    CODING = "coding"
    SCRAPE_SYNTHESIZE = "scrape_synthesize"
    WEB = "web"
    GAIA_1 = "GAIA_1"
    GAIA_2 = "GAIA_2"
    GAIA_3 = "GAIA_3"


-class ChallengeData(BaseModel):
-    eval_id: str = ""
-    name: str
-    category: List[Category]
-    task: str
-    dependencies: List[str]
-    cutoff: int
-    ground: Ground | Dict[str, Ground]
-    info: Info | Dict[str, Info]
-    metadata: Optional[Dict[str, Any]] = None
-
-    spec_file: Path | None = Field(None, exclude=True)
+class EvalResult(BaseModel):
+    result: str
+    result_source: Literal["step_output"] | str
+    score: float
+    passed: bool
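
The new `EvalResult` model above is small enough to show in use; a quick hedged example with invented values:

from agbenchmark.utils.data_types import EvalResult

result = EvalResult(
    result="Hello World!",
    result_source="step_output",
    score=1.0,
    passed=True,
)
print(result.json(indent=2))
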

# File: utils.py
@@ -8,6 +8,7 @@ from typing import Any, Optional

from dotenv import load_dotenv

+from agbenchmark.reports.processing.report_types import Test
from agbenchmark.utils.data_types import DIFFICULTY_MAP, DifficultyLevel

load_dotenv()
@@ -63,41 +64,31 @@ def get_test_path(json_file: str | Path) -> str:


def get_highest_success_difficulty(
-    data: dict, just_string: Optional[bool] = None
+    data: dict[str, Test], just_string: Optional[bool] = None
) -> str:
    highest_difficulty = None
    highest_difficulty_level = 0

    for test_name, test_data in data.items():
        try:
-            if test_data.get("tests", None):
-                highest_difficulty_str = test_data["metrics"]["highest_difficulty"]
+            if test_data.metrics.success:
+                difficulty_str = test_data.metrics.difficulty
+                if not difficulty_str:
+                    continue

                try:
-                    highest_difficulty = DifficultyLevel[highest_difficulty_str]
-                    highest_difficulty_level = DIFFICULTY_MAP[highest_difficulty]
+                    difficulty_enum = DifficultyLevel[difficulty_str.lower()]
+                    difficulty_level = DIFFICULTY_MAP[difficulty_enum]
+
+                    if difficulty_level > highest_difficulty_level:
+                        highest_difficulty = difficulty_enum
+                        highest_difficulty_level = difficulty_level
                except KeyError:
                    logger.warning(
-                        f"Unexpected difficulty level '{highest_difficulty_str}' "
+                        f"Unexpected difficulty level '{difficulty_str}' "
                        f"in test '{test_name}'"
                    )
                    continue
-            else:
-                if test_data["metrics"]["success"]:
-                    difficulty_str = test_data["metrics"]["difficulty"]
-
-                    try:
-                        difficulty_enum = DifficultyLevel[difficulty_str.lower()]
-                        difficulty_level = DIFFICULTY_MAP[difficulty_enum]
-
-                        if difficulty_level > highest_difficulty_level:
-                            highest_difficulty = difficulty_enum
-                            highest_difficulty_level = difficulty_level
-                    except KeyError:
-                        logger.warning(
-                            f"Unexpected difficulty level '{difficulty_str}' "
-                            f"in test '{test_name}'"
-                        )
-                        continue
        except Exception as e:
            logger.warning(
                "An unexpected error [1] occurred while analyzing report [2]."
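
To make the rewritten loop above concrete: `get_highest_success_difficulty` now walks typed `Test` objects and keeps the hardest difficulty among the successful tests. A hedged usage sketch follows; the field values are invented, and the exact return format is not visible in this excerpt, so the print is only indicative.

from agbenchmark.reports.processing.report_types import Metrics, Test
from agbenchmark.utils.utils import get_highest_success_difficulty

# Invented report data; paths, names, and values are illustrative only.
tests = {
    "WriteFile": Test(
        data_path="__BUILTIN__/abilities/write_file/data.json",
        is_regression=True,
        category=["general"],
        task="Write 'Hello World' to a file",
        answer="Hello World",
        description="Tests if the agent can write a file",
        metrics=Metrics(
            attempted=True, success=True, run_time="1.2 seconds", difficulty="basic"
        ),
    ),
}

# Reports the hardest difficulty reached among the successful tests.
print(get_highest_success_difficulty(tests))
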