refactor(benchmark): load_webarena_challenges

- Reduce duplicated code and flatten nested statements
- Add `skip_unavailable` parameter (default `True`); when `False`, unavailable challenges are still loaded but marked as unavailable

Related changes:
- Add `available` and `unavailable_reason` attributes to `ChallengeInfo` and `WebArenaChallengeSpec`
- Add a `pytest.skip` call to `WebArenaChallenge.test_method` so that unavailable challenges are skipped instead of run
This commit is contained in:
Reinier van der Leer
2024-02-16 14:58:53 +01:00
parent 650a701317
commit 70e345b2ce
2 changed files with 43 additions and 22 deletions

View File

@@ -28,6 +28,9 @@ class ChallengeInfo(BaseModel):
source_uri: str
"""Internal reference indicating the source of the challenge specification"""
available: bool = True
unavailable_reason: str = ""
class BaseChallenge(ABC):
"""

View File

@@ -179,6 +179,9 @@ class WebArenaChallengeSpec(BaseModel):
intent_template_id: int
instantiation_dict: dict[str, str | list[str]]
available: bool = True
unavailable_reason: str = ""
class EvalSet(BaseModel):
class StringMatchEvalSet(BaseModel):
exact_match: str | None
@@ -288,6 +291,8 @@ class WebArenaChallenge(BaseChallenge):
], # TODO: make categories more specific
reference_answer=spec.eval.reference_answer_raw_annotation,
source_uri=cls.SOURCE_URI_TEMPLATE.format(task_id=spec.task_id),
available=spec.available,
unavailable_reason=spec.unavailable_reason,
)
return type(
f"Test{challenge_info.name}",
@@ -362,6 +367,9 @@ class WebArenaChallenge(BaseChallenge):
request: pytest.FixtureRequest,
i_attempt: int = 0,
) -> None:
if not self._spec.available:
pytest.skip(self._spec.unavailable_reason)
# if os.environ.get("HELICONE_API_KEY"):
# from helicone.lock import HeliconeLockManager
@@ -426,11 +434,13 @@ class WebArenaChallenge(BaseChallenge):
) + "\n".join(f"{repr(r[0])}\n -> {repr(r[1])}" for r in evals_results)
def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]:
def load_webarena_challenges(
skip_unavailable: bool = True
) -> Iterator[type[WebArenaChallenge]]:
logger.info("Loading WebArena challenges...")
for site, info in site_info_map.items():
if not info.available:
if not info.available and skip_unavailable:
logger.warning(
f"JungleGym site '{site}' is not available: {info.unavailable_reason} "
"Skipping all challenges which use this site."
@@ -457,30 +467,38 @@ def load_webarena_challenges() -> Iterator[type[WebArenaChallenge]]:
for entry in challenge_dicts:
try:
challenge_spec = WebArenaChallengeSpec.parse_obj(entry)
for site in challenge_spec.sites:
site_info = site_info_map.get(site)
if site_info is None:
logger.warning(
f"WebArena task {challenge_spec.task_id} requires unknown site "
f"'{site}'; skipping..."
)
break
if not site_info.available:
logger.debug(
f"WebArena task {challenge_spec.task_id} requires unavailable "
f"site '{site}'; skipping..."
)
break
else:
yield WebArenaChallenge.from_challenge_spec(challenge_spec)
loaded += 1
continue
skipped += 1
except ValidationError as e:
failed += 1
logger.warning(f"Error validating WebArena challenge entry: {entry}")
logger.warning(f"Error details: {e}")
continue
# Check all required sites for availability
for site in challenge_spec.sites:
site_info = site_info_map.get(site)
if site_info is None:
challenge_spec.available = False
challenge_spec.unavailable_reason = (
f"WebArena task {challenge_spec.task_id} requires unknown site "
f"'{site}'"
)
elif not site_info.available:
challenge_spec.available = False
challenge_spec.unavailable_reason = (
f"WebArena task {challenge_spec.task_id} requires unavailable "
f"site '{site}'"
)
if not challenge_spec.available and skip_unavailable:
logger.debug(f"{challenge_spec.unavailable_reason}; skipping...")
skipped += 1
continue
yield WebArenaChallenge.from_challenge_spec(challenge_spec)
loaded += 1
logger.info(
"Loading WebArena challenges complete: "
f"loaded {loaded}, skipped {skipped}. {failed} challenge failed to load."
f"loaded {loaded}, skipped {skipped}."
+ (f" {failed} challenges failed to load." if failed else "")
)