can now put file extensions or names in files data

This commit is contained in:
Silen Naihin
2023-06-25 19:30:04 -04:00
parent 2411c35d0e
commit d6a6e69f2e
6 changed files with 57 additions and 18 deletions

View File

@@ -1,5 +1,5 @@
import os
from typing import Optional
import glob
from agbenchmark.challenges.define_task_types import Ground
@@ -14,6 +14,26 @@ class Challenge:
with open(workspace_dir, "r") as f:
return f.read()
@staticmethod
def open_files(workspace: str, file_patterns: list):
script_dir = os.path.abspath(workspace)
files_contents = []
for file_pattern in file_patterns:
# Check if it is a file extension
if file_pattern.startswith("."):
# Find all files with the given extension in the workspace
matching_files = glob.glob(os.path.join(script_dir, "*" + file_pattern))
else:
# Otherwise, it is a specific file
matching_files = [os.path.join(script_dir, file_pattern)]
for file_path in matching_files:
with open(file_path, "r") as f:
files_contents.append(f.read())
return files_contents
@staticmethod
def write_to_file(workspace: str, filename: str, content: str):
script_dir = os.path.abspath(workspace)

View File

@@ -23,10 +23,12 @@ class TestRetrieval1(RetrievalChallenge):
indirect=True,
)
def test_retrieval(self, workspace, current_challenge_data):
file = self.open_file(workspace, data.ground.files[0])
files_contents = self.open_files(workspace, data.ground.files)
score = self.scoring(file, data.ground)
scores = []
for file_content in files_contents:
score = self.scoring(file_content, data.ground)
print("Your score is:", score)
scores.append(score)
print("You score is:", score)
assert score
assert 1 in scores

View File

@@ -32,10 +32,12 @@ class TestReadFile(BasicChallenge):
)
@pytest.mark.depends(on=data.dependencies)
def test_read_file(self, workspace):
file = self.open_file(workspace, data.ground.files[0])
files_contents = self.open_files(workspace, data.ground.files)
score = self.scoring(file, data.ground)
scores = []
for file_content in files_contents:
score = self.scoring(file_content, data.ground)
print("Your score is:", score)
scores.append(score)
print("You score is:", score)
assert score
assert 1 in scores

View File

@@ -6,7 +6,7 @@
"answer": "Washington",
"should_contain": ["Washington"],
"should_not_contain": ["New York", "Los Angeles", "San Francisco"],
"files": ["file_to_check.txt"]
"files": [".txt"]
},
"mock_func": "basic_write_file_mock",
"info": {

View File

@@ -23,10 +23,12 @@ class TestWriteFile(BasicChallenge):
)
@pytest.mark.depends(name="test_write_file")
def test_write_file(self, workspace):
file = self.open_file(workspace, data.ground.files[0])
files_contents = self.open_files(workspace, data.ground.files)
score = self.scoring(file, data.ground)
scores = []
for file_content in files_contents:
score = self.scoring(file_content, data.ground)
print("Your score is:", score)
scores.append(score)
print("You score is:", score)
assert score
assert 1 in scores

View File

@@ -1 +1,14 @@
{}
{
"TestWriteFile": {
"difficulty": "basic",
"dependencies": [],
"test": "agbenchmark/tests/basic_abilities/write_file/write_file_test.py::TestWriteFile::test_write_file[regression_data0-server_response0]"
},
"TestReadFile": {
"difficulty": "basic",
"dependencies": [
"test_write_file"
],
"test": "agbenchmark/tests/basic_abilities/read_file/read_file_test.py::TestReadFile::test_read_file[regression_data0-server_response0]"
}
}