diff --git a/agbenchmark/Challenge.py b/agbenchmark/Challenge.py index 9828a0e9..d159296b 100644 --- a/agbenchmark/Challenge.py +++ b/agbenchmark/Challenge.py @@ -1,5 +1,5 @@ import os -from typing import Optional +import glob from agbenchmark.challenges.define_task_types import Ground @@ -14,6 +14,26 @@ class Challenge: with open(workspace_dir, "r") as f: return f.read() + @staticmethod + def open_files(workspace: str, file_patterns: list): + script_dir = os.path.abspath(workspace) + files_contents = [] + + for file_pattern in file_patterns: + # Check if it is a file extension + if file_pattern.startswith("."): + # Find all files with the given extension in the workspace + matching_files = glob.glob(os.path.join(script_dir, "*" + file_pattern)) + else: + # Otherwise, it is a specific file + matching_files = [os.path.join(script_dir, file_pattern)] + + for file_path in matching_files: + with open(file_path, "r") as f: + files_contents.append(f.read()) + + return files_contents + @staticmethod def write_to_file(workspace: str, filename: str, content: str): script_dir = os.path.abspath(workspace) diff --git a/agbenchmark/challenges/retrieval/r1/r1_test.py b/agbenchmark/challenges/retrieval/r1/r1_test.py index 489d298f..2a7d92a7 100644 --- a/agbenchmark/challenges/retrieval/r1/r1_test.py +++ b/agbenchmark/challenges/retrieval/r1/r1_test.py @@ -23,10 +23,12 @@ class TestRetrieval1(RetrievalChallenge): indirect=True, ) def test_retrieval(self, workspace, current_challenge_data): - file = self.open_file(workspace, data.ground.files[0]) + files_contents = self.open_files(workspace, data.ground.files) - score = self.scoring(file, data.ground) + scores = [] + for file_content in files_contents: + score = self.scoring(file_content, data.ground) + print("Your score is:", score) + scores.append(score) - print("You score is:", score) - - assert score + assert 1 in scores diff --git a/agbenchmark/tests/basic_abilities/read_file/read_file_test.py b/agbenchmark/tests/basic_abilities/read_file/read_file_test.py index 7d14228c..90946670 100644 --- a/agbenchmark/tests/basic_abilities/read_file/read_file_test.py +++ b/agbenchmark/tests/basic_abilities/read_file/read_file_test.py @@ -32,10 +32,12 @@ class TestReadFile(BasicChallenge): ) @pytest.mark.depends(on=data.dependencies) def test_read_file(self, workspace): - file = self.open_file(workspace, data.ground.files[0]) + files_contents = self.open_files(workspace, data.ground.files) - score = self.scoring(file, data.ground) + scores = [] + for file_content in files_contents: + score = self.scoring(file_content, data.ground) + print("Your score is:", score) + scores.append(score) - print("You score is:", score) - - assert score + assert 1 in scores diff --git a/agbenchmark/tests/basic_abilities/write_file/w_file_data.json b/agbenchmark/tests/basic_abilities/write_file/w_file_data.json index 1d262108..037c5bd8 100644 --- a/agbenchmark/tests/basic_abilities/write_file/w_file_data.json +++ b/agbenchmark/tests/basic_abilities/write_file/w_file_data.json @@ -6,7 +6,7 @@ "answer": "Washington", "should_contain": ["Washington"], "should_not_contain": ["New York", "Los Angeles", "San Francisco"], - "files": ["file_to_check.txt"] + "files": [".txt"] }, "mock_func": "basic_write_file_mock", "info": { diff --git a/agbenchmark/tests/basic_abilities/write_file/write_file_test.py b/agbenchmark/tests/basic_abilities/write_file/write_file_test.py index 33012889..187378ff 100644 --- a/agbenchmark/tests/basic_abilities/write_file/write_file_test.py +++ b/agbenchmark/tests/basic_abilities/write_file/write_file_test.py @@ -23,10 +23,12 @@ class TestWriteFile(BasicChallenge): ) @pytest.mark.depends(name="test_write_file") def test_write_file(self, workspace): - file = self.open_file(workspace, data.ground.files[0]) + files_contents = self.open_files(workspace, data.ground.files) - score = self.scoring(file, data.ground) + scores = [] + for file_content in files_contents: + score = self.scoring(file_content, data.ground) + print("Your score is:", score) + scores.append(score) - print("You score is:", score) - - assert score + assert 1 in scores diff --git a/agbenchmark/tests/regression/regression_tests.json b/agbenchmark/tests/regression/regression_tests.json index 9e26dfee..c84fc9c9 100644 --- a/agbenchmark/tests/regression/regression_tests.json +++ b/agbenchmark/tests/regression/regression_tests.json @@ -1 +1,14 @@ -{} \ No newline at end of file +{ + "TestWriteFile": { + "difficulty": "basic", + "dependencies": [], + "test": "agbenchmark/tests/basic_abilities/write_file/write_file_test.py::TestWriteFile::test_write_file[regression_data0-server_response0]" + }, + "TestReadFile": { + "difficulty": "basic", + "dependencies": [ + "test_write_file" + ], + "test": "agbenchmark/tests/basic_abilities/read_file/read_file_test.py::TestReadFile::test_read_file[regression_data0-server_response0]" + } +} \ No newline at end of file