From db48e7849beb4366460c08b874249dff78f50b55 Mon Sep 17 00:00:00 2001 From: merwanehamadi Date: Sun, 6 Aug 2023 20:59:53 -0700 Subject: [PATCH] Add product advisor tests (#267) --- agbenchmark/challenges | 2 +- agbenchmark/generate_test.py | 1 + agbenchmark/start_benchmark.py | 5 +++++ agbenchmark/utils/challenge.py | 15 +++++++++++++++ agbenchmark/utils/utils.py | 11 ++++++++++- 5 files changed, 32 insertions(+), 2 deletions(-) diff --git a/agbenchmark/challenges b/agbenchmark/challenges index f6bafad9..c3f6ae1f 160000 --- a/agbenchmark/challenges +++ b/agbenchmark/challenges @@ -1 +1 @@ -Subproject commit f6bafad9e45093099a4e4bec97bca17c447b530e +Subproject commit c3f6ae1f4b7282ce68e0f8808379d72d84611225 diff --git a/agbenchmark/generate_test.py b/agbenchmark/generate_test.py index b912d5b1..062b39d5 100644 --- a/agbenchmark/generate_test.py +++ b/agbenchmark/generate_test.py @@ -97,6 +97,7 @@ def create_single_test( # Define test method within the dynamically created class def test_method(self, config: Dict[str, Any], request) -> None: # type: ignore + self.skip_optional_categories(config) from helicone.lock import HeliconeLockManager if os.environ.get("HELICONE_API_KEY"): diff --git a/agbenchmark/start_benchmark.py b/agbenchmark/start_benchmark.py index 559eecee..39bd357a 100644 --- a/agbenchmark/start_benchmark.py +++ b/agbenchmark/start_benchmark.py @@ -33,6 +33,11 @@ if os.environ.get("HELICONE_API_KEY"): ) = calculate_dynamic_paths() BENCHMARK_GIT_COMMIT_SHA = get_git_commit_sha(HOME_DIRECTORY / ".." / "..") AGENT_GIT_COMMIT_SHA = get_git_commit_sha(HOME_DIRECTORY) +# open a file in the challenges/optional_categories +with open( + Path(__file__).resolve().parent / "challenges" / "optional_categories.json" +) as f: + OPTIONAL_CATEGORIES = json.load(f)["optional_categories"] @click.group() diff --git a/agbenchmark/utils/challenge.py b/agbenchmark/utils/challenge.py index 0831621b..6495957e 100644 --- a/agbenchmark/utils/challenge.py +++ b/agbenchmark/utils/challenge.py @@ -7,8 +7,10 @@ from abc import ABC from typing import Any, Dict, List import openai +import pytest from agbenchmark.agent_interface import MOCK_FLAG +from agbenchmark.start_benchmark import OPTIONAL_CATEGORIES from agbenchmark.utils.data_types import ChallengeData, Ground from agbenchmark.utils.prompts import ( END_PROMPT, @@ -16,6 +18,7 @@ from agbenchmark.utils.prompts import ( PROMPT_MAP, SCORING_MAP, ) +from agbenchmark.utils.utils import agent_eligibible_for_optional_categories class Challenge(ABC): @@ -262,3 +265,15 @@ class Challenge(ABC): return 1 return None + + def skip_optional_categories(self, config: Dict[str, Any]) -> None: + challenge_category = self.data.category + categories = [ + category + for category in OPTIONAL_CATEGORIES + if category in challenge_category + ] + if not agent_eligibible_for_optional_categories( + categories, config.get("category", []) + ): + pytest.skip("Agent is not eligible for this category") diff --git a/agbenchmark/utils/utils.py b/agbenchmark/utils/utils.py index d42fd69b..41b659e9 100644 --- a/agbenchmark/utils/utils.py +++ b/agbenchmark/utils/utils.py @@ -5,7 +5,7 @@ import re import sys from datetime import datetime from pathlib import Path -from typing import Any, Optional +from typing import Any, List, Optional import git from dotenv import load_dotenv @@ -285,3 +285,12 @@ def get_git_commit_sha(directory: Path) -> Optional[str]: except Exception: print(f"{directory} is not a git repository!") return None + + +def agent_eligibible_for_optional_categories( + optional_challenge_categories: List, agent_categories: List +) -> bool: + for element in optional_challenge_categories: + if element not in agent_categories: + return False + return True