Files
Auto-GPT/agbenchmark/challenge.py
2023-07-08 03:27:31 -04:00

154 lines
5.1 KiB
Python

import glob
import inspect
import os
import subprocess
import types
from abc import ABC, ABCMeta
from typing import Any, Dict, List, Optional, Tuple, Type, cast
import pytest
from dotenv import load_dotenv
from agbenchmark.challenges.define_task_types import ChallengeData, Ground
load_dotenv()
mock_test_str = os.getenv("MOCK_TEST")
MOCK_TEST = mock_test_str.lower() == "true" if mock_test_str else False
class ChallengeMeta(ABCMeta):
def __init__(self, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> None:
super().__init__(name, bases, dct)
try:
frame = cast(types.FrameType, inspect.currentframe())
assert frame.f_back is not None
self.CHALLENGE_LOCATION = os.path.dirname(inspect.getfile(frame.f_back))
except Exception as e:
print(f"Unable to get the file from 8 frames back due to: {str(e)}")
raise e
class Challenge(ABC, metaclass=ChallengeMeta):
"""The parent class to all specific challenges classes.
Defines helper methods for running a challenge"""
_data_cache: Dict[str, ChallengeData] = {}
CHALLENGE_LOCATION: str
@property
def data(self) -> ChallengeData:
file_path = f"{self.CHALLENGE_LOCATION}/data.json"
Challenge._data_cache[file_path] = ChallengeData.deserialize(file_path)
return Challenge._data_cache[file_path]
@property
def mock(self) -> Optional[str]:
return self.data.mock.mock_func if self.data.mock else None
@property
def task(self) -> str:
return str(
self.data.mock.mock_task if self.data.mock and MOCK_TEST else self.data.task
)
@property
def dependencies(self) -> list:
return self.data.dependencies
def setup_challenge(self, config: Dict[str, Any]) -> None:
from agbenchmark.agent_interface import copy_artifacts_into_workspace, run_agent
copy_artifacts_into_workspace(
config["workspace"], "artifacts_in", self.__class__.CHALLENGE_LOCATION
)
run_agent(self.task, self.mock, config, self.__class__.CHALLENGE_LOCATION)
@property
def name(self) -> str:
return self.data.name
@pytest.mark.parametrize(
"challenge_data",
[data],
indirect=True,
)
def test_method(self, config: Dict[str, Any]) -> None:
raise NotImplementedError
@staticmethod
def open_file(workspace: str, filename: str) -> str:
script_dir = workspace
workspace_dir = os.path.join(script_dir, filename)
with open(workspace_dir, "r") as f:
return f.read()
def get_artifacts_out(self, workspace: str, file_patterns: list) -> List[str]:
script_dir = workspace
files_contents = []
for file_pattern in file_patterns:
# Check if it is a file extension
if file_pattern.startswith("."):
# Find all files with the given extension in the workspace
matching_files = glob.glob(os.path.join(script_dir, "*" + file_pattern))
else:
# Otherwise, it is a specific file
matching_files = [os.path.join(script_dir, file_pattern)]
for file_path in matching_files:
if self.data.ground.type == "execute_python_code":
result = subprocess.run(
["python3", file_path],
cwd=os.path.abspath(workspace),
capture_output=True,
text=True,
)
files_contents.append(result.stdout)
else:
with open(file_path, "r") as f:
files_contents.append(f.read())
return files_contents
@staticmethod
def write_to_file(workspace: str, filename: str, content: str) -> None:
script_dir = workspace
print("Writing file at", script_dir)
workspace_dir = os.path.join(script_dir, filename)
# Open the file in write mode.
with open(workspace_dir, "w") as f:
# Write the content to the file.
f.write(content)
def get_filenames_in_workspace(self, workspace: str) -> List[str]:
return [
filename
for filename in os.listdir(workspace)
if os.path.isfile(os.path.join(workspace, filename))
]
def scoring(self, content: str, ground: Ground) -> float:
if ground.should_contain:
for should_contain_word in ground.should_contain:
if should_contain_word not in content:
return 0.0
else:
print(
f"Word that should exist: {should_contain_word} exists in the content"
)
if ground.should_not_contain:
for should_not_contain_word in ground.should_not_contain:
if should_not_contain_word in content:
return 0.0
else:
print(
f"Word that should not exist: {should_not_contain_word} does not exist in the content"
)
return 1.0