feat(benchmark): Get agent task cost from Step.additional_output

This commit is contained in:
Reinier van der Leer
2024-02-16 18:10:46 +01:00
parent 752bac099b
commit 21f1e64559
3 changed files with 18 additions and 0 deletions

View File

@@ -175,18 +175,26 @@ class BuiltinChallenge(BaseChallenge):
task_id = ""
n_steps = 0
timed_out = None
agent_task_cost = None
try:
async for step in self.run_challenge(
config, timeout, mock=request.config.getoption("--mock")
):
if not task_id:
task_id = step.task_id
n_steps += 1
if step.additional_output:
agent_task_cost = step.additional_output.get(
"task_total_cost",
step.additional_output.get("task_cumulative_cost"),
)
timed_out = False
except TimeoutError:
timed_out = True
request.node.user_properties.append(("n_steps", n_steps))
request.node.user_properties.append(("timed_out", timed_out))
request.node.user_properties.append(("agent_task_cost", agent_task_cost))
agent_client_config = ClientConfig(host=config.host)
async with ApiClient(agent_client_config) as api_client:

View File

@@ -395,6 +395,7 @@ class WebArenaChallenge(BaseChallenge):
n_steps = 0
timed_out = None
agent_task_cost = None
eval_results_per_step: list[list[tuple[_Eval, EvalResult]]] = []
try:
async for step in self.run_challenge(
@@ -403,7 +404,14 @@ class WebArenaChallenge(BaseChallenge):
if not step.output:
logger.warn(f"Step has no output: {step}")
continue
n_steps += 1
if step.additional_output:
agent_task_cost = step.additional_output.get(
"task_total_cost",
step.additional_output.get("task_cumulative_cost"),
)
step_eval_results = self.evaluate_step_result(
step, mock=request.config.getoption("--mock")
)
@@ -423,6 +431,7 @@ class WebArenaChallenge(BaseChallenge):
timed_out = True
request.node.user_properties.append(("n_steps", n_steps))
request.node.user_properties.append(("timed_out", timed_out))
request.node.user_properties.append(("agent_task_cost", agent_task_cost))
# Get the column aggregate (highest score for each Eval)
# from the matrix of EvalResults per step.

View File

@@ -93,6 +93,7 @@ def add_test_result_to_report(
fail_reason=str(call.excinfo.value) if call.excinfo else None,
reached_cutoff=user_properties.get("timed_out", False),
n_steps=user_properties.get("n_steps"),
cost=user_properties.get("agent_task_cost"),
)
)
test_report.metrics.success_percentage = (