mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-31 11:54:30 +01:00
refactor(benchmark): load_webarena_challenges
- Reduce duplicate and nested statements
- Add `skip_unavailable` parameter

Related changes:
- Add `available` and `unavailable_reason` attributes to `ChallengeInfo` and `WebArenaChallengeSpec`
- Add `pytest.skip` statement to `WebArenaChallenge.test_method` to make sure unavailable challenges are not run
This commit is contained in:
@@ -28,6 +28,9 @@ class ChallengeInfo(BaseModel):
|
||||
source_uri: str
|
||||
"""Internal reference indicating the source of the challenge specification"""
|
||||
|
||||
available: bool = True
|
||||
unavailable_reason: str = ""
|
||||
|
||||
|
||||
class BaseChallenge(ABC):
|
||||
"""
|
||||
|
||||
@@ -179,6 +179,9 @@ class WebArenaChallengeSpec(BaseModel):
|
||||
intent_template_id: int
|
||||
instantiation_dict: dict[str, str | list[str]]
|
||||
|
||||
available: bool = True
|
||||
unavailable_reason: str = ""
|
||||
|
||||
class EvalSet(BaseModel):
|
||||
class StringMatchEvalSet(BaseModel):
|
||||
exact_match: str | None
|
||||
@@ -288,6 +291,8 @@ class WebArenaChallenge(BaseChallenge):
|
||||
], # TODO: make categories more specific
|
||||
reference_answer=spec.eval.reference_answer_raw_annotation,
|
||||
source_uri=cls.SOURCE_URI_TEMPLATE.format(task_id=spec.task_id),
|
||||
available=spec.available,
|
||||
unavailable_reason=spec.unavailable_reason,
|
||||
)
|
||||
return type(
|
||||
f"Test{challenge_info.name}",
|
||||
@@ -362,6 +367,9 @@ class WebArenaChallenge(BaseChallenge):
|
||||
request: pytest.FixtureRequest,
|
||||
i_attempt: int = 0,
|
||||
) -> None:
|
||||
if not self._spec.available:
|
||||
pytest.skip(self._spec.unavailable_reason)
|
||||
|
||||
# if os.environ.get("HELICONE_API_KEY"):
|
||||
# from helicone.lock import HeliconeLockManager
|
||||
|
||||
@@ -426,11 +434,13 @@ class WebArenaChallenge(BaseChallenge):
|
||||
) + "\n".join(f"{repr(r[0])}\n -> {repr(r[1])}" for r in evals_results)
|
||||
|
||||
|
||||
def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]:
|
||||
def load_webarena_challenges(
|
||||
skip_unavailable: bool = True
|
||||
) -> Iterator[type[WebArenaChallenge]]:
|
||||
logger.info("Loading WebArena challenges...")
|
||||
|
||||
for site, info in site_info_map.items():
|
||||
if not info.available:
|
||||
if not info.available and skip_unavailable:
|
||||
logger.warning(
|
||||
f"JungleGym site '{site}' is not available: {info.unavailable_reason} "
|
||||
"Skipping all challenges which use this site."
|
||||
@@ -457,30 +467,38 @@ def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]:
|
||||
for entry in challenge_dicts:
|
||||
try:
|
||||
challenge_spec = WebArenaChallengeSpec.parse_obj(entry)
|
||||
for site in challenge_spec.sites:
|
||||
site_info = site_info_map.get(site)
|
||||
if site_info is None:
|
||||
logger.warning(
|
||||
f"WebArena task {challenge_spec.task_id} requires unknown site "
|
||||
f"'{site}'; skipping..."
|
||||
)
|
||||
break
|
||||
if not site_info.available:
|
||||
logger.debug(
|
||||
f"WebArena task {challenge_spec.task_id} requires unavailable "
|
||||
f"site '{site}'; skipping..."
|
||||
)
|
||||
break
|
||||
else:
|
||||
yield WebArenaChallenge.from_challenge_spec(challenge_spec)
|
||||
loaded += 1
|
||||
continue
|
||||
skipped += 1
|
||||
except ValidationError as e:
|
||||
failed += 1
|
||||
logger.warning(f"Error validating WebArena challenge entry: {entry}")
|
||||
logger.warning(f"Error details: {e}")
|
||||
continue
|
||||
|
||||
# Check all required sites for availability
|
||||
for site in challenge_spec.sites:
|
||||
site_info = site_info_map.get(site)
|
||||
if site_info is None:
|
||||
challenge_spec.available = False
|
||||
challenge_spec.unavailable_reason = (
|
||||
f"WebArena task {challenge_spec.task_id} requires unknown site "
|
||||
f"'{site}'"
|
||||
)
|
||||
elif not site_info.available:
|
||||
challenge_spec.available = False
|
||||
challenge_spec.unavailable_reason = (
|
||||
f"WebArena task {challenge_spec.task_id} requires unavailable "
|
||||
f"site '{site}'"
|
||||
)
|
||||
|
||||
if not challenge_spec.available and skip_unavailable:
|
||||
logger.debug(f"{challenge_spec.unavailable_reason}; skipping...")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
yield WebArenaChallenge.from_challenge_spec(challenge_spec)
|
||||
loaded += 1
|
||||
|
||||
logger.info(
|
||||
"Loading WebArena challenges complete: "
|
||||
f"loaded {loaded}, skipped {skipped}. {failed} challenge failed to load."
|
||||
f"loaded {loaded}, skipped {skipped}."
|
||||
+ (f" {failed} challenges failed to load." if failed else "")
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user