commit 11d6dabe37
parent 82f53aae54
Author: BillSchumacher
Date:   2023-04-15 14:55:13 -05:00

    Quality update

15 changed files with 161 additions and 91 deletions

----------------------------------------

@@ -34,7 +34,9 @@ def main() -> None:
     # Initialize memory and make sure it is empty.
     # this is particularly important for indexing and referencing pinecone memory
     memory = get_memory(cfg, init=True)
-    logger.typewriter_log(f"Using memory of type:", Fore.GREEN, f"{memory.__class__.__name__}")
+    logger.typewriter_log(
+        f"Using memory of type:", Fore.GREEN, f"{memory.__class__.__name__}"
+    )
     logger.typewriter_log(f"Using Browser:", Fore.GREEN, cfg.selenium_web_browser)
     agent = Agent(
         ai_name=ai_name,

----------------------------------------

@@ -89,13 +89,13 @@ def get_command(response: str):
 
 
 def map_command_synonyms(command_name: str):
-    """ Takes the original command name given by the AI, and checks if the
-    string matches a list of common/known hallucinations
+    """Takes the original command name given by the AI, and checks if the
+    string matches a list of common/known hallucinations
     """
     synonyms = [
-        ('write_file', 'write_to_file'),
-        ('create_file', 'write_to_file'),
-        ('search', 'google')
+        ("write_file", "write_to_file"),
+        ("create_file", "write_to_file"),
+        ("search", "google"),
     ]
     for seen_command, actual_command_name in synonyms:
         if command_name == seen_command:
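
As a self-contained illustration of the synonym table above (the pass-through return for unknown names is assumed from the visible loop, not shown in the hunk):

    SYNONYMS = [
        ("write_file", "write_to_file"),
        ("create_file", "write_to_file"),
        ("search", "google"),
    ]

    def map_command_synonyms(command_name: str) -> str:
        """Return the canonical command for a known hallucinated alias."""
        for seen_command, actual_command_name in SYNONYMS:
            if command_name == seen_command:
                return actual_command_name
        return command_name  # unknown names pass through unchanged

    assert map_command_synonyms("create_file") == "write_to_file"
    assert map_command_synonyms("read_file") == "read_file"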
@@ -125,7 +125,7 @@ def execute_command(command_name: str, arguments):
             google_result = google_official_search(arguments["input"])
         else:
             google_result = google_search(arguments["input"])
-        safe_message = google_result.encode('utf-8', 'ignore')
+        safe_message = google_result.encode("utf-8", "ignore")
         return str(safe_message)
     elif command_name == "memory_add":
         return memory.add(arguments["string"])
@@ -144,7 +144,9 @@ def execute_command(command_name: str, arguments):
     elif command_name == "get_hyperlinks":
         return get_hyperlinks(arguments["url"])
     elif command_name == "clone_repository":
-        return clone_repository(arguments["repository_url"], arguments["clone_path"])
+        return clone_repository(
+            arguments["repository_url"], arguments["clone_path"]
+        )
     elif command_name == "read_file":
         return read_file(arguments["file"])
     elif command_name == "write_to_file":
@@ -278,7 +280,9 @@ def list_agents():
     Returns:
         str: A list of all agents
     """
-    return "List of agents:\n" + "\n".join([str(x[0]) + ": " + x[1] for x in AGENT_MANAGER.list_agents()])
+    return "List of agents:\n" + "\n".join(
+        [str(x[0]) + ": " + x[1] for x in AGENT_MANAGER.list_agents()]
+    )
 
 
 def delete_agent(key: str) -> str:

----------------------------------------

@@ -54,7 +54,7 @@ def parse_arguments() -> None:
         "--use-browser",
         "-b",
         dest="browser_name",
-        help="Specifies which web-browser to use when using selenium to scrape the web."
+        help="Specifies which web-browser to use when using selenium to scrape the web.",
     )
     parser.add_argument(
         "--ai-settings",

----------------------------------------

@@ -99,8 +99,8 @@ def execute_shell(command_line: str) -> str:
         str: The output of the command
     """
     current_dir = os.getcwd()
-    if str(WORKING_DIRECTORY) not in current_dir: # Change dir into workspace if necessary
+    # Change dir into workspace if necessary
+    if str(WORKING_DIRECTORY) not in current_dir:
         work_dir = os.path.join(os.getcwd(), WORKING_DIRECTORY)
         os.chdir(work_dir)

----------------------------------------

@@ -1,14 +1,20 @@
 """Git operations for autogpt"""
 import git
 from autogpt.config import Config
 
-cfg = Config()
+CFG = Config()
 
 
-def clone_repository(repo_url, clone_path):
-    """Clone a github repository locally"""
+def clone_repository(repo_url: str, clone_path: str) -> str:
+    """Clone a github repository locally
+
+    Args:
+        repo_url (str): The URL of the repository to clone
+        clone_path (str): The path to clone the repository to
+
+    Returns:
+        str: The result of the clone operation"""
     split_url = repo_url.split("//")
-    auth_repo_url = f"//{cfg.github_username}:{cfg.github_api_key}@".join(split_url)
+    auth_repo_url = f"//{CFG.github_username}:{CFG.github_api_key}@".join(split_url)
     git.Repo.clone_from(auth_repo_url, clone_path)
-    result = f"""Cloned {repo_url} to {clone_path}"""
-    return result
+    return f"""Cloned {repo_url} to {clone_path}"""

----------------------------------------

@@ -53,7 +53,11 @@ def scrape_text_with_selenium(url: str) -> Tuple[WebDriver, str]:
     """
     logging.getLogger("selenium").setLevel(logging.CRITICAL)
 
-    options_available = {'chrome': ChromeOptions, 'safari': SafariOptions, 'firefox': FirefoxOptions}
+    options_available = {
+        "chrome": ChromeOptions,
+        "safari": SafariOptions,
+        "firefox": FirefoxOptions,
+    }
 
     options = options_available[CFG.selenium_web_browser]()
     options.add_argument(

----------------------------------------

@@ -137,7 +137,9 @@ class Config(metaclass=Singleton):
             config_params = {}
         self.openai_api_type = config_params.get("azure_api_type") or "azure"
         self.openai_api_base = config_params.get("azure_api_base") or ""
-        self.openai_api_version = config_params.get("azure_api_version") or "2023-03-15-preview"
+        self.openai_api_version = (
+            config_params.get("azure_api_version") or "2023-03-15-preview"
+        )
         self.azure_model_to_deployment_id_map = config_params.get("azure_model_map", [])
 
     def set_continuous_mode(self, value: bool) -> None:
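
Worth noting: the `config_params.get(...) or default` pattern above falls back on any falsy value, not just a missing key, which differs from passing a default to dict.get. A quick illustration:

    config_params = {"azure_api_version": ""}
    # dict.get's own default keeps the empty string:
    assert config_params.get("azure_api_version", "2023-03-15-preview") == ""
    # `or` replaces any falsy value with the fallback:
    assert (config_params.get("azure_api_version") or "2023-03-15-preview") == "2023-03-15-preview"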

----------------------------------------

@@ -4,11 +4,20 @@ import json
 from autogpt.llm_utils import call_ai_function
 from autogpt.logs import logger
 from autogpt.config import Config
 
-cfg = Config()
+CFG = Config()
 
 
 def fix_json(json_string: str, schema: str) -> str:
-    """Fix the given JSON string to make it parseable and fully compliant with the provided schema."""
+    """Fix the given JSON string to make it parseable and fully compliant with
+    the provided schema.
+
+    Args:
+        json_string (str): The JSON string to fix.
+        schema (str): The schema to use to fix the JSON.
+
+    Returns:
+        str: The fixed JSON string.
+    """
     # Try to fix the JSON using GPT:
     function_string = "def fix_json(json_string: str, schema:str=None) -> str:"
     args = [f"'''{json_string}'''", f"'''{schema}'''"]
@@ -24,7 +33,7 @@ def fix_json(json_string: str, schema: str) -> str:
     if not json_string.startswith("`"):
         json_string = "```json\n" + json_string + "\n```"
     result_string = call_ai_function(
-        function_string, args, description_string, model=cfg.fast_llm_model
+        function_string, args, description_string, model=CFG.fast_llm_model
    )
     logger.debug("------------ JSON FIX ATTEMPT ---------------")
     logger.debug(f"Original JSON: {json_string}")

----------------------------------------

@@ -50,8 +50,10 @@ def get_memory(cfg, init=False):
         memory = RedisMemory(cfg)
     elif cfg.memory_backend == "milvus":
         if not MilvusMemory:
-            print("Error: Milvus sdk is not installed."
-                  "Please install pymilvus to use Milvus as memory backend.")
+            print(
+                "Error: Milvus sdk is not installed."
+                "Please install pymilvus to use Milvus as memory backend."
+            )
         else:
             memory = MilvusMemory(cfg)
     elif cfg.memory_backend == "no_memory":
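
One caveat with the reflowed print above: the two adjacent string literals concatenate with no separator, so the message runs the sentences together. A quick check:

    msg = (
        "Error: Milvus sdk is not installed."
        "Please install pymilvus to use Milvus as memory backend."
    )
    assert "installed.Please" in msg  # no space between the sentences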
@@ -68,4 +70,11 @@ def get_supported_memory_backends():
     return supported_memory
 
 
-__all__ = ["get_memory", "LocalCache", "RedisMemory", "PineconeMemory", "NoMemory", "MilvusMemory"]
+__all__ = [
+    "get_memory",
+    "LocalCache",
+    "RedisMemory",
+    "PineconeMemory",
+    "NoMemory",
+    "MilvusMemory",
+]

----------------------------------------

@@ -1,3 +1,4 @@
+""" Milvus memory storage provider."""
 from pymilvus import (
     connections,
     FieldSchema,
@@ -10,8 +11,10 @@ from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
 class MilvusMemory(MemoryProviderSingleton):
-    def __init__(self, cfg):
-        """ Construct a milvus memory storage connection.
+    """Milvus memory storage provider."""
+
+    def __init__(self, cfg) -> None:
+        """Construct a milvus memory storage connection.
 
         Args:
             cfg (Config): Auto-GPT global config.
@@ -19,12 +22,9 @@ class MilvusMemory(MemoryProviderSingleton):
         # connect to milvus server.
         connections.connect(address=cfg.milvus_addr)
         fields = [
-            FieldSchema(name="pk", dtype=DataType.INT64,
-                        is_primary=True, auto_id=True),
-            FieldSchema(name="embeddings",
-                        dtype=DataType.FLOAT_VECTOR, dim=1536),
-            FieldSchema(name="raw_text", dtype=DataType.VARCHAR,
-                        max_length=65535)
+            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
+            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=1536),
+            FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
         ]
 
         # create collection if not exist and load it.
@@ -34,15 +34,19 @@ class MilvusMemory(MemoryProviderSingleton):
         # create index if not exist.
         if not self.collection.has_index():
             self.collection.release()
-            self.collection.create_index("embeddings", {
-                "metric_type": "IP",
-                "index_type": "HNSW",
-                "params": {"M": 8, "efConstruction": 64},
-            }, index_name="embeddings")
+            self.collection.create_index(
+                "embeddings",
+                {
+                    "metric_type": "IP",
+                    "index_type": "HNSW",
+                    "params": {"M": 8, "efConstruction": 64},
+                },
+                index_name="embeddings",
+            )
         self.collection.load()
 
-    def add(self, data):
-        """ Add a embedding of data into memory.
+    def add(self, data) -> str:
+        """Add a embedding of data into memory.
 
         Args:
             data (str): The raw text to construct embedding index.
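
For reference, the index settings above are standard Milvus HNSW parameters: "IP" ranks results by inner product, M caps each node's connectivity in the graph, and efConstruction sizes the candidate list used while building it. The same call, sketched standalone (assumes a running Milvus server and an existing collection; address and collection name are illustrative):

    from pymilvus import Collection, connections

    connections.connect(address="localhost:19530")  # assumed address
    collection = Collection("autogpt")              # assumed collection name
    collection.create_index(
        "embeddings",
        {
            "metric_type": "IP",   # inner-product similarity
            "index_type": "HNSW",  # graph-based approximate nearest neighbor
            "params": {"M": 8, "efConstruction": 64},
        },
        index_name="embeddings",
    )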
@@ -52,34 +56,48 @@ class MilvusMemory(MemoryProviderSingleton):
         """
         embedding = get_ada_embedding(data)
         result = self.collection.insert([[embedding], [data]])
-        _text = f"Inserting data into memory at primary key: {result.primary_keys[0]}:\n data: {data}"
+        _text = (
+            "Inserting data into memory at primary key: "
+            f"{result.primary_keys[0]}:\n data: {data}"
+        )
         return _text
 
     def get(self, data):
-        """ Return the most relevant data in memory.
+        """Return the most relevant data in memory.
+
         Args:
             data: The data to compare to.
         """
         return self.get_relevant(data, 1)
 
-    def clear(self):
-        """ Drop the index in memory.
+    def clear(self) -> str:
+        """Drop the index in memory.
+
         Returns:
             str: log.
         """
         self.collection.drop()
         self.collection = Collection(self.milvus_collection, self.schema)
-        self.collection.create_index("embeddings", {
-            "metric_type": "IP",
-            "index_type": "HNSW",
-            "params": {"M": 8, "efConstruction": 64},
-        }, index_name="embeddings")
+        self.collection.create_index(
+            "embeddings",
+            {
+                "metric_type": "IP",
+                "index_type": "HNSW",
+                "params": {"M": 8, "efConstruction": 64},
+            },
+            index_name="embeddings",
+        )
         self.collection.load()
         return "Obliviated"
 
-    def get_relevant(self, data, num_relevant=5):
-        """ Return the top-k relevant data in memory.
+    def get_relevant(self, data: str, num_relevant: int = 5):
+        """Return the top-k relevant data in memory.
+
         Args:
             data: The data to compare to.
-            num_relevant (int, optional): The max number of relevant data. Defaults to 5.
+            num_relevant (int, optional): The max number of relevant data.
+                                          Defaults to 5.
+
         Returns:
             list: The top-k relevant data.
         """
         # search the embedding and return the most relevant text.
         embedding = get_ada_embedding(data)
@@ -88,10 +106,15 @@ class MilvusMemory(MemoryProviderSingleton):
             "params": {"nprobe": 8},
         }
         result = self.collection.search(
-            [embedding], "embeddings", search_params, num_relevant, output_fields=["raw_text"])
+            [embedding],
+            "embeddings",
+            search_params,
+            num_relevant,
+            output_fields=["raw_text"],
+        )
         return [item.entity.value_of_field("raw_text") for item in result[0]]
 
-    def get_stats(self):
+    def get_stats(self) -> str:
         """
         Returns: The stats of the milvus cache.
         """

----------------------------------------

@@ -59,7 +59,11 @@ def get_prompt() -> str:
         ),
         ("List GPT Agents", "list_agents", {}),
         ("Delete GPT Agent", "delete_agent", {"key": "<key>"}),
-        ("Clone Repository", "clone_repository", {"repository_url": "<url>", "clone_path": "<directory>"}),
+        (
+            "Clone Repository",
+            "clone_repository",
+            {"repository_url": "<url>", "clone_path": "<directory>"},
+        ),
         ("Write to file", "write_to_file", {"file": "<file>", "text": "<text>"}),
         ("Read file", "read_file", {"file": "<file>"}),
         ("Append to file", "append_to_file", {"file": "<file>", "text": "<text>"}),

----------------------------------------

@@ -1,10 +1,13 @@
 import pkg_resources
 import sys
 
 
 def main():
     requirements_file = sys.argv[1]
-    with open(requirements_file, 'r') as f:
-        required_packages = [line.strip().split('#')[0].strip() for line in f.readlines()]
+    with open(requirements_file, "r") as f:
+        required_packages = [
+            line.strip().split("#")[0].strip() for line in f.readlines()
+        ]
 
     installed_packages = [package.key for package in pkg_resources.working_set]
@@ -12,16 +15,17 @@ def main():
     for package in required_packages:
         if not package:  # Skip empty lines
             continue
-        package_name = package.strip().split('==')[0]
+        package_name = package.strip().split("==")[0]
         if package_name.lower() not in installed_packages:
             missing_packages.append(package_name)
 
     if missing_packages:
-        print('Missing packages:')
-        print(', '.join(missing_packages))
+        print("Missing packages:")
+        print(", ".join(missing_packages))
         sys.exit(1)
     else:
-        print('All packages are installed.')
+        print("All packages are installed.")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
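
The comprehension above strips inline comments before extracting package names; a quick illustration of both parsing steps:

    line = "pymilvus==2.2.4  # optional milvus backend\n"
    package = line.strip().split("#")[0].strip()
    assert package == "pymilvus==2.2.4"
    assert package.split("==")[0] == "pymilvus"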

----------------------------------------

@@ -15,7 +15,6 @@ def MockConfig():
             "speak_mode": False,
             "milvus_collection": "autogpt",
             "milvus_addr": "localhost:19530",
         },
     )

----------------------------------------

@@ -5,56 +5,59 @@ import unittest
 from autogpt.file_operations import delete_file, read_file
 
-env_vars = {
-    'MEMORY_BACKEND': 'no_memory',
-    'TEMPERATURE': "0"
-}
+env_vars = {"MEMORY_BACKEND": "no_memory", "TEMPERATURE": "0"}
 
 
 class TestCommands(unittest.TestCase):
     def test_write_file(self):
         # Test case to check if the write_file command can successfully write 'Hello World' to a file
         # named 'hello_world.txt'.
 
         # Read the current ai_settings.yaml file and store its content.
         ai_settings = None
-        if os.path.exists('ai_settings.yaml'):
-            with open('ai_settings.yaml', 'r') as f:
+        if os.path.exists("ai_settings.yaml"):
+            with open("ai_settings.yaml", "r") as f:
                 ai_settings = f.read()
-            os.remove('ai_settings.yaml')
+            os.remove("ai_settings.yaml")
 
         try:
-            if os.path.exists('hello_world.txt'):
+            if os.path.exists("hello_world.txt"):
                 # Clean up any existing 'hello_world.txt' file before testing.
-                delete_file('hello_world.txt')
+                delete_file("hello_world.txt")
 
             # Prepare input data for the test.
-            input_data = '''write_file-GPT
+            input_data = """write_file-GPT
 an AI designed to use the write_file command to write 'Hello World' into a file named "hello_world.txt" and then use the task_complete command to complete the task.
 Use the write_file command to write 'Hello World' into a file named "hello_world.txt".
 Use the task_complete command to complete the task.
 Do not use any other commands.
 y -5
-EOF'''
-            command = f'{sys.executable} -m autogpt'
+EOF"""
+            command = f"{sys.executable} -m autogpt"
 
             # Execute the script with the input data.
-            process = subprocess.Popen(command, stdin=subprocess.PIPE, shell=True, env={**os.environ, **env_vars})
+            process = subprocess.Popen(
+                command,
+                stdin=subprocess.PIPE,
+                shell=True,
+                env={**os.environ, **env_vars},
+            )
             process.communicate(input_data.encode())
 
             # Read the content of the 'hello_world.txt' file created during the test.
-            content = read_file('hello_world.txt')
+            content = read_file("hello_world.txt")
         finally:
             if ai_settings:
                 # Restore the original ai_settings.yaml file.
-                with open('ai_settings.yaml', 'w') as f:
+                with open("ai_settings.yaml", "w") as f:
                     f.write(ai_settings)
 
         # Check if the content of the 'hello_world.txt' file is equal to 'Hello World'.
-        self.assertEqual(content, 'Hello World', f"Expected 'Hello World', got {content}")
+        self.assertEqual(
+            content, "Hello World", f"Expected 'Hello World', got {content}"
+        )
 
 
 # Run the test case.
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
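
The test drives the interactive CLI by piping a scripted transcript to stdin. The same Popen pattern, reduced to a runnable sketch with a stand-in child process (POSIX shell quoting assumed):

    import subprocess
    import sys

    # A stand-in child that just echoes whatever it reads on stdin.
    process = subprocess.Popen(
        f"{sys.executable} -c 'import sys; print(sys.stdin.read())'",
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        shell=True,
    )
    out, _ = process.communicate(b"scripted answers\ny -5\n")
    assert b"scripted answers" in out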

----------------------------------------

@@ -4,18 +4,17 @@ from autogpt.token_counter import count_message_tokens, count_string_tokens
 class TestTokenCounter(unittest.TestCase):
     def test_count_message_tokens(self):
         messages = [
             {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         self.assertEqual(count_message_tokens(messages), 17)
 
     def test_count_message_tokens_with_name(self):
         messages = [
             {"role": "user", "content": "Hello", "name": "John"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         self.assertEqual(count_message_tokens(messages), 17)
@@ -25,7 +24,7 @@ class TestTokenCounter(unittest.TestCase):
     def test_count_message_tokens_invalid_model(self):
         messages = [
             {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         with self.assertRaises(KeyError):
             count_message_tokens(messages, model="invalid_model")
@@ -33,13 +32,15 @@ class TestTokenCounter(unittest.TestCase):
     def test_count_message_tokens_gpt_4(self):
         messages = [
             {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)
 
     def test_count_string_tokens(self):
         string = "Hello, world!"
-        self.assertEqual(count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4)
+        self.assertEqual(
+            count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
+        )
 
     def test_count_string_tokens_empty_input(self):
         self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
@@ -47,7 +48,7 @@ class TestTokenCounter(unittest.TestCase):
     def test_count_message_tokens_invalid_model(self):
         messages = [
             {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there!"}
+            {"role": "assistant", "content": "Hi there!"},
         ]
         with self.assertRaises(NotImplementedError):
             count_message_tokens(messages, model="invalid_model")
@@ -57,5 +58,5 @@ class TestTokenCounter(unittest.TestCase):
         self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
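
For context on the expected values (17 and 15), a sketch of cookbook-style counting with tiktoken. This assumes count_message_tokens follows OpenAI's published recipe (4 tokens of per-message overhead on gpt-3.5-turbo-0301, 3 on gpt-4-0314, plus 3 for the primed reply) and omits the special "name" handling, so it is not necessarily the repository's implementation:

    import tiktoken

    def count_tokens(messages, model="gpt-3.5-turbo-0301"):
        encoding = tiktoken.encoding_for_model(model)
        tokens_per_message = 4 if model == "gpt-3.5-turbo-0301" else 3
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for value in message.values():
                num_tokens += len(encoding.encode(value))
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens

    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]
    assert count_tokens(messages) == 17
    assert count_tokens(messages, model="gpt-4-0314") == 15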