Implement directory-based plugin system (#4548)

* Implement directory-based plugin system

* Fix Selenium test

---------

Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Merwane Hamadi <merwanehamadi@gmail.com>
This commit is contained in:
Erik Peterson
2023-06-10 13:16:00 -07:00
committed by GitHub
parent 6ff8478118
commit 15c6b0c1c3
8 changed files with 319 additions and 6 deletions

View File

@@ -41,7 +41,7 @@ By following these guidelines, your PRs are more likely to be merged quickly aft
black .
isort .
mypy
autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports autogpt tests --in-place
autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests --in-place
```
<!-- If you haven't added tests, please explain why. If you have, check the appropriate box. If you've ensured your PR is atomic and well-documented, check the corresponding boxes. -->

View File

@@ -69,7 +69,7 @@ jobs:
- name: Check for unused imports and pass statements
run: |
cmd="autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports autogpt tests"
cmd="autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests"
$cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
test:

View File

@@ -31,7 +31,7 @@ repos:
hooks:
- id: autoflake
name: autoflake
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports autogpt tests
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests
language: python
types: [ python ]
- id: pytest-check

View File

@@ -1,8 +1,10 @@
"""Handles loading of plugins."""
import importlib.util
import inspect
import json
import os
import sys
import zipfile
from pathlib import Path
from typing import List
@@ -217,6 +219,28 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
# Directory-based plugins
for plugin_path in [f.path for f in os.scandir(cfg.plugins_dir) if f.is_dir()]:
# Avoid going into __pycache__ or other hidden directories
if plugin_path.startswith("__"):
continue
plugin_module_path = plugin_path.split(os.path.sep)
plugin_module_name = plugin_module_path[-1]
qualified_module_name = ".".join(plugin_module_path)
__import__(qualified_module_name)
plugin = sys.modules[qualified_module_name]
for _, class_obj in inspect.getmembers(plugin):
if (
hasattr(class_obj, "_abc_impl")
and AutoGPTPluginTemplate in class_obj.__bases__
and denylist_allowlist_check(plugin_module_name, cfg)
):
loaded_plugins.append(class_obj())
# Zip-based plugins
for plugin in plugins_path_path.glob("*.zip"):
if moduleList := inspect_zip_for_modules(str(plugin), debug):
for module in moduleList:
@@ -236,6 +260,7 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
and denylist_allowlist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
# OpenAI plugins
if cfg.plugins_openai:
manifests_specs = fetch_openai_plugins_manifest_and_spec(cfg)

View File

@@ -2,6 +2,7 @@ import os
import subprocess
import sys
import zipfile
from glob import glob
from pathlib import Path
@@ -16,6 +17,8 @@ def install_plugin_dependencies():
None
"""
plugins_dir = Path(os.getenv("PLUGINS_DIR", "plugins"))
# Install zip-based plugins
for plugin in plugins_dir.glob("*.zip"):
with zipfile.ZipFile(str(plugin), "r") as zfile:
try:
@@ -30,6 +33,13 @@ def install_plugin_dependencies():
except KeyError:
continue
# Install directory-based plugins
for requirements_file in glob(f"{plugins_dir}/*/requirements.txt"):
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", requirements_file],
stdout=subprocess.DEVNULL,
)
if __name__ == "__main__":
install_plugin_dependencies()

View File

@@ -39,7 +39,7 @@ def mock_config_openai_plugin():
plugins_dir = PLUGINS_TEST_DIR
plugins_openai = [PLUGIN_TEST_OPENAI]
plugins_denylist = ["AutoGPTPVicuna"]
plugins_denylist = ["AutoGPTPVicuna", "auto_gpt_guanaco"]
plugins_allowlist = [PLUGIN_TEST_OPENAI]
return MockConfig()
@@ -60,7 +60,7 @@ def mock_config_generic_plugin():
plugins_dir = PLUGINS_TEST_DIR
plugins_openai = []
plugins_denylist = []
plugins_allowlist = ["AutoGPTPVicuna"]
plugins_allowlist = ["AutoGPTPVicuna", "auto_gpt_guanaco"]
return MockConfig()
@@ -68,4 +68,4 @@ def mock_config_generic_plugin():
def test_scan_plugins_generic(mock_config_generic_plugin):
# Test that the function returns the correct number of plugins
result = scan_plugins(mock_config_generic_plugin, debug=True)
assert len(result) == 1
assert len(result) == 2

View File

@@ -1,9 +1,13 @@
import pytest
from pytest_mock import MockerFixture
from autogpt.commands.web_selenium import browse_website
from autogpt.config import Config
from tests.utils import requires_api_key
@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY")
def test_browse_website(config: Config, patched_api_requestor: MockerFixture):
url = "https://barrel-roll.com"
question = "How to execute a barrel roll"

View File

@@ -0,0 +1,274 @@
"""This is the Test plugin for Auto-GPT."""
from typing import Any, Dict, List, Optional, Tuple, TypeVar
from auto_gpt_plugin_template import AutoGPTPluginTemplate
PromptGenerator = TypeVar("PromptGenerator")
class AutoGPTGuanaco(AutoGPTPluginTemplate):
"""
This is plugin for Auto-GPT.
"""
def __init__(self):
super().__init__()
self._name = "Auto-GPT-Guanaco"
self._version = "0.1.0"
self._description = "This is a Guanaco local model plugin."
def can_handle_on_response(self) -> bool:
"""This method is called to check that the plugin can
handle the on_response method.
Returns:
bool: True if the plugin can handle the on_response method."""
return False
def on_response(self, response: str, *args, **kwargs) -> str:
"""This method is called when a response is received from the model."""
if len(response):
print("OMG OMG It's Alive!")
else:
print("Is it alive?")
def can_handle_post_prompt(self) -> bool:
"""This method is called to check that the plugin can
handle the post_prompt method.
Returns:
bool: True if the plugin can handle the post_prompt method."""
return False
def post_prompt(self, prompt: PromptGenerator) -> PromptGenerator:
"""This method is called just after the generate_prompt is called,
but actually before the prompt is generated.
Args:
prompt (PromptGenerator): The prompt generator.
Returns:
PromptGenerator: The prompt generator.
"""
def can_handle_on_planning(self) -> bool:
"""This method is called to check that the plugin can
handle the on_planning method.
Returns:
bool: True if the plugin can handle the on_planning method."""
return False
def on_planning(
self, prompt: PromptGenerator, messages: List[str]
) -> Optional[str]:
"""This method is called before the planning chat completeion is done.
Args:
prompt (PromptGenerator): The prompt generator.
messages (List[str]): The list of messages.
"""
def can_handle_post_planning(self) -> bool:
"""This method is called to check that the plugin can
handle the post_planning method.
Returns:
bool: True if the plugin can handle the post_planning method."""
return False
def post_planning(self, response: str) -> str:
"""This method is called after the planning chat completeion is done.
Args:
response (str): The response.
Returns:
str: The resulting response.
"""
def can_handle_pre_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the pre_instruction method.
Returns:
bool: True if the plugin can handle the pre_instruction method."""
return False
def pre_instruction(self, messages: List[str]) -> List[str]:
"""This method is called before the instruction chat is done.
Args:
messages (List[str]): The list of context messages.
Returns:
List[str]: The resulting list of messages.
"""
def can_handle_on_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the on_instruction method.
Returns:
bool: True if the plugin can handle the on_instruction method."""
return False
def on_instruction(self, messages: List[str]) -> Optional[str]:
"""This method is called when the instruction chat is done.
Args:
messages (List[str]): The list of context messages.
Returns:
Optional[str]: The resulting message.
"""
def can_handle_post_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the post_instruction method.
Returns:
bool: True if the plugin can handle the post_instruction method."""
return False
def post_instruction(self, response: str) -> str:
"""This method is called after the instruction chat is done.
Args:
response (str): The response.
Returns:
str: The resulting response.
"""
def can_handle_pre_command(self) -> bool:
"""This method is called to check that the plugin can
handle the pre_command method.
Returns:
bool: True if the plugin can handle the pre_command method."""
return False
def pre_command(
self, command_name: str, arguments: Dict[str, Any]
) -> Tuple[str, Dict[str, Any]]:
"""This method is called before the command is executed.
Args:
command_name (str): The command name.
arguments (Dict[str, Any]): The arguments.
Returns:
Tuple[str, Dict[str, Any]]: The command name and the arguments.
"""
def can_handle_post_command(self) -> bool:
"""This method is called to check that the plugin can
handle the post_command method.
Returns:
bool: True if the plugin can handle the post_command method."""
return False
def post_command(self, command_name: str, response: str) -> str:
"""This method is called after the command is executed.
Args:
command_name (str): The command name.
response (str): The response.
Returns:
str: The resulting response.
"""
def can_handle_chat_completion(
self,
messages: list[Dict[Any, Any]],
model: str,
temperature: float,
max_tokens: int,
) -> bool:
"""This method is called to check that the plugin can
handle the chat_completion method.
Args:
messages (Dict[Any, Any]): The messages.
model (str): The model name.
temperature (float): The temperature.
max_tokens (int): The max tokens.
Returns:
bool: True if the plugin can handle the chat_completion method."""
return False
def handle_chat_completion(
self,
messages: list[Dict[Any, Any]],
model: str,
temperature: float,
max_tokens: int,
) -> str:
"""This method is called when the chat completion is done.
Args:
messages (Dict[Any, Any]): The messages.
model (str): The model name.
temperature (float): The temperature.
max_tokens (int): The max tokens.
Returns:
str: The resulting response.
"""
def can_handle_text_embedding(self, text: str) -> bool:
"""This method is called to check that the plugin can
handle the text_embedding method.
Args:
text (str): The text to be convert to embedding.
Returns:
bool: True if the plugin can handle the text_embedding method."""
return False
def handle_text_embedding(self, text: str) -> list:
"""This method is called when the chat completion is done.
Args:
text (str): The text to be convert to embedding.
Returns:
list: The text embedding.
"""
def can_handle_user_input(self, user_input: str) -> bool:
"""This method is called to check that the plugin can
handle the user_input method.
Args:
user_input (str): The user input.
Returns:
bool: True if the plugin can handle the user_input method."""
return False
def user_input(self, user_input: str) -> str:
"""This method is called to request user input to the user.
Args:
user_input (str): The question or prompt to ask the user.
Returns:
str: The user input.
"""
def can_handle_report(self) -> bool:
"""This method is called to check that the plugin can
handle the report method.
Returns:
bool: True if the plugin can handle the report method."""
return False
def report(self, message: str) -> None:
"""This method is called to report a message to the user.
Args:
message (str): The message to report.
"""