diff --git a/benchmark/agbenchmark/challenges/base.py b/benchmark/agbenchmark/challenges/base.py index 4fe73a2d..f77a08c6 100644 --- a/benchmark/agbenchmark/challenges/base.py +++ b/benchmark/agbenchmark/challenges/base.py @@ -28,6 +28,9 @@ class ChallengeInfo(BaseModel): source_uri: str """Internal reference indicating the source of the challenge specification""" + available: bool = True + unavailable_reason: str = "" + class BaseChallenge(ABC): """ diff --git a/benchmark/agbenchmark/challenges/webarena.py b/benchmark/agbenchmark/challenges/webarena.py index a11330c1..09f80108 100644 --- a/benchmark/agbenchmark/challenges/webarena.py +++ b/benchmark/agbenchmark/challenges/webarena.py @@ -179,6 +179,9 @@ class WebArenaChallengeSpec(BaseModel): intent_template_id: int instantiation_dict: dict[str, str | list[str]] + available: bool = True + unavailable_reason: str = "" + class EvalSet(BaseModel): class StringMatchEvalSet(BaseModel): exact_match: str | None @@ -288,6 +291,8 @@ class WebArenaChallenge(BaseChallenge): ], # TODO: make categories more specific reference_answer=spec.eval.reference_answer_raw_annotation, source_uri=cls.SOURCE_URI_TEMPLATE.format(task_id=spec.task_id), + available=spec.available, + unavailable_reason=spec.unavailable_reason, ) return type( f"Test{challenge_info.name}", @@ -362,6 +367,9 @@ class WebArenaChallenge(BaseChallenge): request: pytest.FixtureRequest, i_attempt: int = 0, ) -> None: + if not self._spec.available: + pytest.skip(self._spec.unavailable_reason) + # if os.environ.get("HELICONE_API_KEY"): # from helicone.lock import HeliconeLockManager @@ -426,11 +434,13 @@ class WebArenaChallenge(BaseChallenge): ) + "\n".join(f"{repr(r[0])}\n -> {repr(r[1])}" for r in evals_results) -def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]: +def load_webarena_challenges( + skip_unavailable: bool = True +) -> Iterator[type[WebArenaChallenge]]: logger.info("Loading WebArena challenges...") for site, info in site_info_map.items(): - if not info.available: + if not info.available and skip_unavailable: logger.warning( f"JungleGym site '{site}' is not available: {info.unavailable_reason} " "Skipping all challenges which use this site." @@ -457,30 +467,38 @@ def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]: for entry in challenge_dicts: try: challenge_spec = WebArenaChallengeSpec.parse_obj(entry) - for site in challenge_spec.sites: - site_info = site_info_map.get(site) - if site_info is None: - logger.warning( - f"WebArena task {challenge_spec.task_id} requires unknown site " - f"'{site}'; skipping..." - ) - break - if not site_info.available: - logger.debug( - f"WebArena task {challenge_spec.task_id} requires unavailable " - f"site '{site}'; skipping..." - ) - break - else: - yield WebArenaChallenge.from_challenge_spec(challenge_spec) - loaded += 1 - continue - skipped += 1 except ValidationError as e: failed += 1 logger.warning(f"Error validating WebArena challenge entry: {entry}") logger.warning(f"Error details: {e}") + continue + + # Check all required sites for availability + for site in challenge_spec.sites: + site_info = site_info_map.get(site) + if site_info is None: + challenge_spec.available = False + challenge_spec.unavailable_reason = ( + f"WebArena task {challenge_spec.task_id} requires unknown site " + f"'{site}'" + ) + elif not site_info.available: + challenge_spec.available = False + challenge_spec.unavailable_reason = ( + f"WebArena task {challenge_spec.task_id} requires unavailable " + f"site '{site}'" + ) + + if not challenge_spec.available and skip_unavailable: + logger.debug(f"{challenge_spec.unavailable_reason}; skipping...") + skipped += 1 + continue + + yield WebArenaChallenge.from_challenge_spec(challenge_spec) + loaded += 1 + logger.info( "Loading WebArena challenges complete: " - f"loaded {loaded}, skipped {skipped}. {failed} challenge failed to load." + f"loaded {loaded}, skipped {skipped}." + + (f" {failed} challenges failed to load." if failed else "") )