Quality update
@@ -34,7 +34,9 @@ def main() -> None:
     # Initialize memory and make sure it is empty.
     # this is particularly important for indexing and referencing pinecone memory
     memory = get_memory(cfg, init=True)
-    logger.typewriter_log(f"Using memory of type:", Fore.GREEN, f"{memory.__class__.__name__}")
+    logger.typewriter_log(
+        f"Using memory of type:", Fore.GREEN, f"{memory.__class__.__name__}"
+    )
     logger.typewriter_log(f"Using Browser:", Fore.GREEN, cfg.selenium_web_browser)
     agent = Agent(
         ai_name=ai_name,
@@ -93,9 +93,9 @@ def map_command_synonyms(command_name: str):
     string matches a list of common/known hallucinations
     """
     synonyms = [
-        ('write_file', 'write_to_file'),
-        ('create_file', 'write_to_file'),
-        ('search', 'google')
+        ("write_file", "write_to_file"),
+        ("create_file", "write_to_file"),
+        ("search", "google"),
     ]
     for seen_command, actual_command_name in synonyms:
         if command_name == seen_command:
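For context, a runnable sketch of the lookup this table presumably feeds; the hunk shows only the list and the loop header, so the return statements and the fall-through are assumptions here:

def map_command_synonyms(command_name: str):
    # Map a hallucinated command name to its real counterpart.
    synonyms = [
        ("write_file", "write_to_file"),
        ("create_file", "write_to_file"),
        ("search", "google"),
    ]
    for seen_command, actual_command_name in synonyms:
        if command_name == seen_command:
            return actual_command_name
    return command_name  # assumed: unknown names pass through unchanged

print(map_command_synonyms("search"))      # google
print(map_command_synonyms("browse_web"))  # browse_web (unchanged)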
@@ -125,7 +125,7 @@ def execute_command(command_name: str, arguments):
             google_result = google_official_search(arguments["input"])
         else:
             google_result = google_search(arguments["input"])
-        safe_message = google_result.encode('utf-8', 'ignore')
+        safe_message = google_result.encode("utf-8", "ignore")
         return str(safe_message)
     elif command_name == "memory_add":
         return memory.add(arguments["string"])
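The encode("utf-8", "ignore") call above silently drops characters that cannot be UTF-8 encoded (for example lone surrogates that sometimes appear in scraped text) instead of raising. A standalone illustration, not taken from the diff:

text = "caf\u00e9 \ud800"  # ends with a lone surrogate, invalid in UTF-8
safe = text.encode("utf-8", "ignore")  # the surrogate is dropped, no exception
print(safe)  # b'caf\xc3\xa9 '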
@@ -144,7 +144,9 @@ def execute_command(command_name: str, arguments):
     elif command_name == "get_hyperlinks":
         return get_hyperlinks(arguments["url"])
     elif command_name == "clone_repository":
-        return clone_repository(arguments["repository_url"], arguments["clone_path"])
+        return clone_repository(
+            arguments["repository_url"], arguments["clone_path"]
+        )
     elif command_name == "read_file":
         return read_file(arguments["file"])
     elif command_name == "write_to_file":
@@ -278,7 +280,9 @@ def list_agents():
     Returns:
         str: A list of all agents
     """
-    return "List of agents:\n" + "\n".join([str(x[0]) + ": " + x[1] for x in AGENT_MANAGER.list_agents()])
+    return "List of agents:\n" + "\n".join(
+        [str(x[0]) + ": " + x[1] for x in AGENT_MANAGER.list_agents()]
+    )


 def delete_agent(key: str) -> str:
@@ -54,7 +54,7 @@ def parse_arguments() -> None:
         "--use-browser",
         "-b",
         dest="browser_name",
-        help="Specifies which web-browser to use when using selenium to scrape the web."
+        help="Specifies which web-browser to use when using selenium to scrape the web.",
     )
     parser.add_argument(
         "--ai-settings",
@@ -99,8 +99,8 @@ def execute_shell(command_line: str) -> str:
         str: The output of the command
     """
     current_dir = os.getcwd()
-    if str(WORKING_DIRECTORY) not in current_dir:  # Change dir into workspace if necessary
+    # Change dir into workspace if necessary
+    if str(WORKING_DIRECTORY) not in current_dir:
         work_dir = os.path.join(os.getcwd(), WORKING_DIRECTORY)
         os.chdir(work_dir)
@@ -1,14 +1,20 @@
+"""Git operations for autogpt"""
 import git
 from autogpt.config import Config

-cfg = Config()
+
+CFG = Config()


-def clone_repository(repo_url, clone_path):
-    """Clone a github repository locally"""
+def clone_repository(repo_url: str, clone_path: str) -> str:
+    """Clone a github repository locally
+
+    Args:
+        repo_url (str): The URL of the repository to clone
+        clone_path (str): The path to clone the repository to
+
+    Returns:
+        str: The result of the clone operation"""
     split_url = repo_url.split("//")
-    auth_repo_url = f"//{cfg.github_username}:{cfg.github_api_key}@".join(split_url)
+    auth_repo_url = f"//{CFG.github_username}:{CFG.github_api_key}@".join(split_url)
     git.Repo.clone_from(auth_repo_url, clone_path)
-    result = f"""Cloned {repo_url} to {clone_path}"""
-
-    return result
+    return f"""Cloned {repo_url} to {clone_path}"""
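The split("//") / join pair above splices credentials into the URL right after the scheme. A worked example with placeholder credentials (hypothetical values, for illustration only):

repo_url = "https://github.com/example/repo.git"
split_url = repo_url.split("//")  # ['https:', 'github.com/example/repo.git']
auth_repo_url = "//user:token@".join(split_url)
print(auth_repo_url)  # https://user:token@github.com/example/repo.git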
@@ -53,7 +53,11 @@ def scrape_text_with_selenium(url: str) -> Tuple[WebDriver, str]:
     """
     logging.getLogger("selenium").setLevel(logging.CRITICAL)

-    options_available = {'chrome': ChromeOptions, 'safari': SafariOptions, 'firefox': FirefoxOptions}
+    options_available = {
+        "chrome": ChromeOptions,
+        "safari": SafariOptions,
+        "firefox": FirefoxOptions,
+    }

    options = options_available[CFG.selenium_web_browser]()
    options.add_argument(
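The dict maps the configured browser name to the matching Selenium options class, so a lookup plus a call instantiates the right one; unknown names raise KeyError. A minimal sketch of the dispatch, assuming a Selenium 4 install and a hypothetical configured value:

from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.firefox.options import Options as FirefoxOptions
from selenium.webdriver.safari.options import Options as SafariOptions

options_available = {
    "chrome": ChromeOptions,
    "safari": SafariOptions,
    "firefox": FirefoxOptions,
}
options = options_available["firefox"]()  # hypothetical configured browser
options.add_argument("--headless")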
@@ -137,7 +137,9 @@ class Config(metaclass=Singleton):
             config_params = {}
         self.openai_api_type = config_params.get("azure_api_type") or "azure"
         self.openai_api_base = config_params.get("azure_api_base") or ""
-        self.openai_api_version = config_params.get("azure_api_version") or "2023-03-15-preview"
+        self.openai_api_version = (
+            config_params.get("azure_api_version") or "2023-03-15-preview"
+        )
         self.azure_model_to_deployment_id_map = config_params.get("azure_model_map", [])

     def set_continuous_mode(self, value: bool) -> None:
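Worth noting about the wrapped expression: config_params.get(...) or default falls back not only when the key is missing but also when it maps to any falsy value such as an empty string. A quick standalone illustration:

config_params = {"azure_api_version": ""}
version = config_params.get("azure_api_version") or "2023-03-15-preview"
print(version)  # 2023-03-15-preview -- the empty string is treated as absent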
@@ -4,11 +4,20 @@ import json
 from autogpt.llm_utils import call_ai_function
 from autogpt.logs import logger
 from autogpt.config import Config
-cfg = Config()
+
+CFG = Config()


 def fix_json(json_string: str, schema: str) -> str:
-    """Fix the given JSON string to make it parseable and fully compliant with the provided schema."""
+    """Fix the given JSON string to make it parseable and fully compliant with
+    the provided schema.
+
+    Args:
+        json_string (str): The JSON string to fix.
+        schema (str): The schema to use to fix the JSON.
+    Returns:
+        str: The fixed JSON string.
+    """
     # Try to fix the JSON using GPT:
     function_string = "def fix_json(json_string: str, schema:str=None) -> str:"
     args = [f"'''{json_string}'''", f"'''{schema}'''"]
@@ -24,7 +33,7 @@ def fix_json(json_string: str, schema: str) -> str:
     if not json_string.startswith("`"):
         json_string = "```json\n" + json_string + "\n```"
     result_string = call_ai_function(
-        function_string, args, description_string, model=cfg.fast_llm_model
+        function_string, args, description_string, model=CFG.fast_llm_model
     )
     logger.debug("------------ JSON FIX ATTEMPT ---------------")
     logger.debug(f"Original JSON: {json_string}")
@@ -50,8 +50,10 @@ def get_memory(cfg, init=False):
         memory = RedisMemory(cfg)
     elif cfg.memory_backend == "milvus":
         if not MilvusMemory:
-            print("Error: Milvus sdk is not installed."
-                  "Please install pymilvus to use Milvus as memory backend.")
+            print(
+                "Error: Milvus sdk is not installed."
+                "Please install pymilvus to use Milvus as memory backend."
+            )
         else:
             memory = MilvusMemory(cfg)
     elif cfg.memory_backend == "no_memory":
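The `if not MilvusMemory:` test works because the module presumably guards the import of the optional pymilvus dependency higher up; a sketch of that conventional pattern (assumed, since the import block is outside this hunk):

try:
    from autogpt.memory.milvus import MilvusMemory
except ImportError:
    MilvusMemory = None  # falsy sentinel left behind when pymilvus is missing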
@@ -68,4 +70,11 @@ def get_supported_memory_backends():
     return supported_memory


-__all__ = ["get_memory", "LocalCache", "RedisMemory", "PineconeMemory", "NoMemory", "MilvusMemory"]
+__all__ = [
+    "get_memory",
+    "LocalCache",
+    "RedisMemory",
+    "PineconeMemory",
+    "NoMemory",
+    "MilvusMemory",
+]
@@ -1,3 +1,4 @@
+""" Milvus memory storage provider."""
 from pymilvus import (
     connections,
     FieldSchema,
@@ -10,7 +11,9 @@ from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding


 class MilvusMemory(MemoryProviderSingleton):
-    def __init__(self, cfg):
+    """Milvus memory storage provider."""
+
+    def __init__(self, cfg) -> None:
         """Construct a milvus memory storage connection.

         Args:
@@ -19,12 +22,9 @@ class MilvusMemory(MemoryProviderSingleton):
         # connect to milvus server.
         connections.connect(address=cfg.milvus_addr)
         fields = [
-            FieldSchema(name="pk", dtype=DataType.INT64,
-                        is_primary=True, auto_id=True),
-            FieldSchema(name="embeddings",
-                        dtype=DataType.FLOAT_VECTOR, dim=1536),
-            FieldSchema(name="raw_text", dtype=DataType.VARCHAR,
-                        max_length=65535)
+            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
+            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=1536),
+            FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
         ]

         # create collection if not exist and load it.
@@ -34,14 +34,18 @@ class MilvusMemory(MemoryProviderSingleton):
         # create index if not exist.
         if not self.collection.has_index():
             self.collection.release()
-            self.collection.create_index("embeddings", {
-                "metric_type": "IP",
-                "index_type": "HNSW",
-                "params": {"M": 8, "efConstruction": 64},
-            }, index_name="embeddings")
+            self.collection.create_index(
+                "embeddings",
+                {
+                    "metric_type": "IP",
+                    "index_type": "HNSW",
+                    "params": {"M": 8, "efConstruction": 64},
+                },
+                index_name="embeddings",
+            )
         self.collection.load()

-    def add(self, data):
+    def add(self, data) -> str:
         """Add a embedding of data into memory.

         Args:
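For readers unfamiliar with Milvus index options, the arguments being reformatted above break down as follows; the annotations reflect standard Milvus/HNSW semantics and are not part of the diff:

index_params = {
    "metric_type": "IP",   # rank neighbours by inner product
    "index_type": "HNSW",  # graph-based approximate nearest-neighbour index
    "params": {
        "M": 8,                # maximum out-edges per node in the HNSW graph
        "efConstruction": 64,  # candidate-list width while building the graph
    },
}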
@@ -52,7 +56,10 @@ class MilvusMemory(MemoryProviderSingleton):
         """
         embedding = get_ada_embedding(data)
         result = self.collection.insert([[embedding], [data]])
-        _text = f"Inserting data into memory at primary key: {result.primary_keys[0]}:\n data: {data}"
+        _text = (
+            "Inserting data into memory at primary key: "
+            f"{result.primary_keys[0]}:\n data: {data}"
+        )
         return _text

     def get(self, data):
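The `_text = (...)` rewrite relies on Python's implicit concatenation of adjacent string literals inside parentheses, which is how Black typically splits long strings. A standalone illustration with hypothetical values:

primary_key, data = 42, "example"  # hypothetical values
_text = (
    "Inserting data into memory at primary key: "
    f"{primary_key}:\n data: {data}"
)
print(_text)  # one string: the two literals are joined at compile time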
@@ -62,24 +69,35 @@ class MilvusMemory(MemoryProviderSingleton):
         """
         return self.get_relevant(data, 1)

-    def clear(self):
+    def clear(self) -> str:
         """Drop the index in memory.
+
+        Returns:
+            str: log.
         """
         self.collection.drop()
         self.collection = Collection(self.milvus_collection, self.schema)
-        self.collection.create_index("embeddings", {
-            "metric_type": "IP",
-            "index_type": "HNSW",
-            "params": {"M": 8, "efConstruction": 64},
-        }, index_name="embeddings")
+        self.collection.create_index(
+            "embeddings",
+            {
+                "metric_type": "IP",
+                "index_type": "HNSW",
+                "params": {"M": 8, "efConstruction": 64},
+            },
+            index_name="embeddings",
+        )
         self.collection.load()
         return "Obliviated"

-    def get_relevant(self, data, num_relevant=5):
+    def get_relevant(self, data: str, num_relevant: int = 5):
         """Return the top-k relevant data in memory.
         Args:
             data: The data to compare to.
-            num_relevant (int, optional): The max number of relevant data. Defaults to 5.
+            num_relevant (int, optional): The max number of relevant data.
+                Defaults to 5.
+
+        Returns:
+            list: The top-k relevant data.
         """
         # search the embedding and return the most relevant text.
         embedding = get_ada_embedding(data)
@@ -88,10 +106,15 @@ class MilvusMemory(MemoryProviderSingleton):
             "params": {"nprobe": 8},
         }
         result = self.collection.search(
-            [embedding], "embeddings", search_params, num_relevant, output_fields=["raw_text"])
+            [embedding],
+            "embeddings",
+            search_params,
+            num_relevant,
+            output_fields=["raw_text"],
+        )
         return [item.entity.value_of_field("raw_text") for item in result[0]]

-    def get_stats(self):
+    def get_stats(self) -> str:
         """
         Returns: The stats of the milvus cache.
         """
@@ -59,7 +59,11 @@ def get_prompt() -> str:
         ),
         ("List GPT Agents", "list_agents", {}),
         ("Delete GPT Agent", "delete_agent", {"key": "<key>"}),
-        ("Clone Repository", "clone_repository", {"repository_url": "<url>", "clone_path": "<directory>"}),
+        (
+            "Clone Repository",
+            "clone_repository",
+            {"repository_url": "<url>", "clone_path": "<directory>"},
+        ),
         ("Write to file", "write_to_file", {"file": "<file>", "text": "<text>"}),
         ("Read file", "read_file", {"file": "<file>"}),
         ("Append to file", "append_to_file", {"file": "<file>", "text": "<text>"}),
@@ -1,10 +1,13 @@
 import pkg_resources
 import sys

+
 def main():
     requirements_file = sys.argv[1]
-    with open(requirements_file, 'r') as f:
-        required_packages = [line.strip().split('#')[0].strip() for line in f.readlines()]
+    with open(requirements_file, "r") as f:
+        required_packages = [
+            line.strip().split("#")[0].strip() for line in f.readlines()
+        ]

     installed_packages = [package.key for package in pkg_resources.working_set]
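For orientation, the script takes the requirements file as its first argument, so invocation presumably looks like `python check_requirements.py requirements.txt` (file name assumed). The comment-stripping line shown above behaves like this, standalone:

line = "pymilvus==2.2.0  # optional milvus backend\n"
requirement = line.strip().split("#")[0].strip()  # drop inline comments
print(requirement)  # pymilvus==2.2.0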
@@ -12,16 +15,17 @@ def main():
     for package in required_packages:
         if not package:  # Skip empty lines
             continue
-        package_name = package.strip().split('==')[0]
+        package_name = package.strip().split("==")[0]
         if package_name.lower() not in installed_packages:
             missing_packages.append(package_name)

     if missing_packages:
-        print('Missing packages:')
-        print(', '.join(missing_packages))
+        print("Missing packages:")
+        print(", ".join(missing_packages))
         sys.exit(1)
     else:
-        print('All packages are installed.')
+        print("All packages are installed.")

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
@@ -15,7 +15,6 @@ def MockConfig():
             "speak_mode": False,
             "milvus_collection": "autogpt",
             "milvus_addr": "localhost:19530",
-
         },
     )
@@ -5,56 +5,59 @@ import unittest

 from autogpt.file_operations import delete_file, read_file

-env_vars = {
-    'MEMORY_BACKEND': 'no_memory',
-    'TEMPERATURE': "0"
-}
+env_vars = {"MEMORY_BACKEND": "no_memory", "TEMPERATURE": "0"}


 class TestCommands(unittest.TestCase):
-
     def test_write_file(self):
         # Test case to check if the write_file command can successfully write 'Hello World' to a file
         # named 'hello_world.txt'.

         # Read the current ai_settings.yaml file and store its content.
         ai_settings = None
-        if os.path.exists('ai_settings.yaml'):
-            with open('ai_settings.yaml', 'r') as f:
+        if os.path.exists("ai_settings.yaml"):
+            with open("ai_settings.yaml", "r") as f:
                 ai_settings = f.read()
-            os.remove('ai_settings.yaml')
+            os.remove("ai_settings.yaml")

         try:
-            if os.path.exists('hello_world.txt'):
+            if os.path.exists("hello_world.txt"):
                 # Clean up any existing 'hello_world.txt' file before testing.
-                delete_file('hello_world.txt')
+                delete_file("hello_world.txt")
             # Prepare input data for the test.
-            input_data = '''write_file-GPT
+            input_data = """write_file-GPT
 an AI designed to use the write_file command to write 'Hello World' into a file named "hello_world.txt" and then use the task_complete command to complete the task.
 Use the write_file command to write 'Hello World' into a file named "hello_world.txt".
 Use the task_complete command to complete the task.
 Do not use any other commands.

 y -5
-EOF'''
-            command = f'{sys.executable} -m autogpt'
+EOF"""
+            command = f"{sys.executable} -m autogpt"

             # Execute the script with the input data.
-            process = subprocess.Popen(command, stdin=subprocess.PIPE, shell=True, env={**os.environ, **env_vars})
+            process = subprocess.Popen(
+                command,
+                stdin=subprocess.PIPE,
+                shell=True,
+                env={**os.environ, **env_vars},
+            )
             process.communicate(input_data.encode())

             # Read the content of the 'hello_world.txt' file created during the test.
-            content = read_file('hello_world.txt')
+            content = read_file("hello_world.txt")
         finally:
             if ai_settings:
                 # Restore the original ai_settings.yaml file.
-                with open('ai_settings.yaml', 'w') as f:
+                with open("ai_settings.yaml", "w") as f:
                     f.write(ai_settings)

         # Check if the content of the 'hello_world.txt' file is equal to 'Hello World'.
-        self.assertEqual(content, 'Hello World', f"Expected 'Hello World', got {content}")
+        self.assertEqual(
+            content, "Hello World", f"Expected 'Hello World', got {content}"
+        )


 # Run the test case.
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
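The `env={**os.environ, **env_vars}` argument builds the child-process environment by merging dicts; keys from env_vars win over inherited variables of the same name because later mappings override earlier ones. A standalone illustration:

import os

env_vars = {"MEMORY_BACKEND": "no_memory", "TEMPERATURE": "0"}
merged = {**os.environ, **env_vars}  # test overrides beat the parent env
print(merged["MEMORY_BACKEND"])  # no_memory, regardless of the parent env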
@@ -4,18 +4,17 @@ from autogpt.token_counter import count_message_tokens, count_string_tokens


 class TestTokenCounter(unittest.TestCase):
-
     def test_count_message_tokens(self):
         messages = [
             {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         self.assertEqual(count_message_tokens(messages), 17)

     def test_count_message_tokens_with_name(self):
         messages = [
             {"role": "user", "content": "Hello", "name": "John"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         self.assertEqual(count_message_tokens(messages), 17)
@@ -25,7 +24,7 @@ class TestTokenCounter(unittest.TestCase):
     def test_count_message_tokens_invalid_model(self):
         messages = [
             {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         with self.assertRaises(KeyError):
             count_message_tokens(messages, model="invalid_model")
@@ -33,13 +32,15 @@ class TestTokenCounter(unittest.TestCase):
     def test_count_message_tokens_gpt_4(self):
         messages = [
             {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)

     def test_count_string_tokens(self):
         string = "Hello, world!"
-        self.assertEqual(count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4)
+        self.assertEqual(
+            count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
+        )

     def test_count_string_tokens_empty_input(self):
         self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
@@ -47,7 +48,7 @@ class TestTokenCounter(unittest.TestCase):
     def test_count_message_tokens_invalid_model(self):
         messages = [
             {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
        ]
        with self.assertRaises(NotImplementedError):
            count_message_tokens(messages, model="invalid_model")
@@ -57,5 +58,5 @@ class TestTokenCounter(unittest.TestCase):
         self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()