Implement directory-based plugin system (#4548)

* Implement directory-based plugin system

* Fix Selenium test

---------

Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Merwane Hamadi <merwanehamadi@gmail.com>
This commit is contained in:
Erik Peterson
2023-06-10 13:16:00 -07:00
committed by GitHub
parent 6ff8478118
commit 15c6b0c1c3
8 changed files with 319 additions and 6 deletions

View File

@@ -41,7 +41,7 @@ By following these guidelines, your PRs are more likely to be merged quickly aft
black . black .
isort . isort .
mypy mypy
autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports autogpt tests --in-place autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests --in-place
``` ```
<!-- If you haven't added tests, please explain why. If you have, check the appropriate box. If you've ensured your PR is atomic and well-documented, check the corresponding boxes. --> <!-- If you haven't added tests, please explain why. If you have, check the appropriate box. If you've ensured your PR is atomic and well-documented, check the corresponding boxes. -->

View File

@@ -69,7 +69,7 @@ jobs:
- name: Check for unused imports and pass statements - name: Check for unused imports and pass statements
run: | run: |
cmd="autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports autogpt tests" cmd="autoflake --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests"
$cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1) $cmd --check || (echo "You have unused imports or pass statements, please run '${cmd} --in-place'" && exit 1)
test: test:

View File

@@ -31,7 +31,7 @@ repos:
hooks: hooks:
- id: autoflake - id: autoflake
name: autoflake name: autoflake
entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports autogpt tests entry: autoflake --in-place --remove-all-unused-imports --recursive --ignore-init-module-imports --ignore-pass-after-docstring autogpt tests
language: python language: python
types: [ python ] types: [ python ]
- id: pytest-check - id: pytest-check

View File

@@ -1,8 +1,10 @@
"""Handles loading of plugins.""" """Handles loading of plugins."""
import importlib.util import importlib.util
import inspect
import json import json
import os import os
import sys
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from typing import List from typing import List
@@ -217,6 +219,28 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}") logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}") logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
# Directory-based plugins
for plugin_path in [f.path for f in os.scandir(cfg.plugins_dir) if f.is_dir()]:
# Avoid going into __pycache__ or other hidden directories
if plugin_path.startswith("__"):
continue
plugin_module_path = plugin_path.split(os.path.sep)
plugin_module_name = plugin_module_path[-1]
qualified_module_name = ".".join(plugin_module_path)
__import__(qualified_module_name)
plugin = sys.modules[qualified_module_name]
for _, class_obj in inspect.getmembers(plugin):
if (
hasattr(class_obj, "_abc_impl")
and AutoGPTPluginTemplate in class_obj.__bases__
and denylist_allowlist_check(plugin_module_name, cfg)
):
loaded_plugins.append(class_obj())
# Zip-based plugins
for plugin in plugins_path_path.glob("*.zip"): for plugin in plugins_path_path.glob("*.zip"):
if moduleList := inspect_zip_for_modules(str(plugin), debug): if moduleList := inspect_zip_for_modules(str(plugin), debug):
for module in moduleList: for module in moduleList:
@@ -236,6 +260,7 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
and denylist_allowlist_check(a_module.__name__, cfg) and denylist_allowlist_check(a_module.__name__, cfg)
): ):
loaded_plugins.append(a_module()) loaded_plugins.append(a_module())
# OpenAI plugins # OpenAI plugins
if cfg.plugins_openai: if cfg.plugins_openai:
manifests_specs = fetch_openai_plugins_manifest_and_spec(cfg) manifests_specs = fetch_openai_plugins_manifest_and_spec(cfg)

View File

@@ -2,6 +2,7 @@ import os
import subprocess import subprocess
import sys import sys
import zipfile import zipfile
from glob import glob
from pathlib import Path from pathlib import Path
@@ -16,6 +17,8 @@ def install_plugin_dependencies():
None None
""" """
plugins_dir = Path(os.getenv("PLUGINS_DIR", "plugins")) plugins_dir = Path(os.getenv("PLUGINS_DIR", "plugins"))
# Install zip-based plugins
for plugin in plugins_dir.glob("*.zip"): for plugin in plugins_dir.glob("*.zip"):
with zipfile.ZipFile(str(plugin), "r") as zfile: with zipfile.ZipFile(str(plugin), "r") as zfile:
try: try:
@@ -30,6 +33,13 @@ def install_plugin_dependencies():
except KeyError: except KeyError:
continue continue
# Install directory-based plugins
for requirements_file in glob(f"{plugins_dir}/*/requirements.txt"):
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", requirements_file],
stdout=subprocess.DEVNULL,
)
if __name__ == "__main__": if __name__ == "__main__":
install_plugin_dependencies() install_plugin_dependencies()

View File

@@ -39,7 +39,7 @@ def mock_config_openai_plugin():
plugins_dir = PLUGINS_TEST_DIR plugins_dir = PLUGINS_TEST_DIR
plugins_openai = [PLUGIN_TEST_OPENAI] plugins_openai = [PLUGIN_TEST_OPENAI]
plugins_denylist = ["AutoGPTPVicuna"] plugins_denylist = ["AutoGPTPVicuna", "auto_gpt_guanaco"]
plugins_allowlist = [PLUGIN_TEST_OPENAI] plugins_allowlist = [PLUGIN_TEST_OPENAI]
return MockConfig() return MockConfig()
@@ -60,7 +60,7 @@ def mock_config_generic_plugin():
plugins_dir = PLUGINS_TEST_DIR plugins_dir = PLUGINS_TEST_DIR
plugins_openai = [] plugins_openai = []
plugins_denylist = [] plugins_denylist = []
plugins_allowlist = ["AutoGPTPVicuna"] plugins_allowlist = ["AutoGPTPVicuna", "auto_gpt_guanaco"]
return MockConfig() return MockConfig()
@@ -68,4 +68,4 @@ def mock_config_generic_plugin():
def test_scan_plugins_generic(mock_config_generic_plugin): def test_scan_plugins_generic(mock_config_generic_plugin):
# Test that the function returns the correct number of plugins # Test that the function returns the correct number of plugins
result = scan_plugins(mock_config_generic_plugin, debug=True) result = scan_plugins(mock_config_generic_plugin, debug=True)
assert len(result) == 1 assert len(result) == 2

View File

@@ -1,9 +1,13 @@
import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from autogpt.commands.web_selenium import browse_website from autogpt.commands.web_selenium import browse_website
from autogpt.config import Config from autogpt.config import Config
from tests.utils import requires_api_key
@pytest.mark.vcr
@requires_api_key("OPENAI_API_KEY")
def test_browse_website(config: Config, patched_api_requestor: MockerFixture): def test_browse_website(config: Config, patched_api_requestor: MockerFixture):
url = "https://barrel-roll.com" url = "https://barrel-roll.com"
question = "How to execute a barrel roll" question = "How to execute a barrel roll"

View File

@@ -0,0 +1,274 @@
"""This is the Test plugin for Auto-GPT."""
from typing import Any, Dict, List, Optional, Tuple, TypeVar
from auto_gpt_plugin_template import AutoGPTPluginTemplate
PromptGenerator = TypeVar("PromptGenerator")
class AutoGPTGuanaco(AutoGPTPluginTemplate):
"""
This is plugin for Auto-GPT.
"""
def __init__(self):
super().__init__()
self._name = "Auto-GPT-Guanaco"
self._version = "0.1.0"
self._description = "This is a Guanaco local model plugin."
def can_handle_on_response(self) -> bool:
"""This method is called to check that the plugin can
handle the on_response method.
Returns:
bool: True if the plugin can handle the on_response method."""
return False
def on_response(self, response: str, *args, **kwargs) -> str:
"""This method is called when a response is received from the model."""
if len(response):
print("OMG OMG It's Alive!")
else:
print("Is it alive?")
def can_handle_post_prompt(self) -> bool:
"""This method is called to check that the plugin can
handle the post_prompt method.
Returns:
bool: True if the plugin can handle the post_prompt method."""
return False
def post_prompt(self, prompt: PromptGenerator) -> PromptGenerator:
"""This method is called just after the generate_prompt is called,
but actually before the prompt is generated.
Args:
prompt (PromptGenerator): The prompt generator.
Returns:
PromptGenerator: The prompt generator.
"""
def can_handle_on_planning(self) -> bool:
"""This method is called to check that the plugin can
handle the on_planning method.
Returns:
bool: True if the plugin can handle the on_planning method."""
return False
def on_planning(
self, prompt: PromptGenerator, messages: List[str]
) -> Optional[str]:
"""This method is called before the planning chat completeion is done.
Args:
prompt (PromptGenerator): The prompt generator.
messages (List[str]): The list of messages.
"""
def can_handle_post_planning(self) -> bool:
"""This method is called to check that the plugin can
handle the post_planning method.
Returns:
bool: True if the plugin can handle the post_planning method."""
return False
def post_planning(self, response: str) -> str:
"""This method is called after the planning chat completeion is done.
Args:
response (str): The response.
Returns:
str: The resulting response.
"""
def can_handle_pre_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the pre_instruction method.
Returns:
bool: True if the plugin can handle the pre_instruction method."""
return False
def pre_instruction(self, messages: List[str]) -> List[str]:
"""This method is called before the instruction chat is done.
Args:
messages (List[str]): The list of context messages.
Returns:
List[str]: The resulting list of messages.
"""
def can_handle_on_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the on_instruction method.
Returns:
bool: True if the plugin can handle the on_instruction method."""
return False
def on_instruction(self, messages: List[str]) -> Optional[str]:
"""This method is called when the instruction chat is done.
Args:
messages (List[str]): The list of context messages.
Returns:
Optional[str]: The resulting message.
"""
def can_handle_post_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the post_instruction method.
Returns:
bool: True if the plugin can handle the post_instruction method."""
return False
def post_instruction(self, response: str) -> str:
"""This method is called after the instruction chat is done.
Args:
response (str): The response.
Returns:
str: The resulting response.
"""
def can_handle_pre_command(self) -> bool:
"""This method is called to check that the plugin can
handle the pre_command method.
Returns:
bool: True if the plugin can handle the pre_command method."""
return False
def pre_command(
self, command_name: str, arguments: Dict[str, Any]
) -> Tuple[str, Dict[str, Any]]:
"""This method is called before the command is executed.
Args:
command_name (str): The command name.
arguments (Dict[str, Any]): The arguments.
Returns:
Tuple[str, Dict[str, Any]]: The command name and the arguments.
"""
def can_handle_post_command(self) -> bool:
"""This method is called to check that the plugin can
handle the post_command method.
Returns:
bool: True if the plugin can handle the post_command method."""
return False
def post_command(self, command_name: str, response: str) -> str:
"""This method is called after the command is executed.
Args:
command_name (str): The command name.
response (str): The response.
Returns:
str: The resulting response.
"""
def can_handle_chat_completion(
self,
messages: list[Dict[Any, Any]],
model: str,
temperature: float,
max_tokens: int,
) -> bool:
"""This method is called to check that the plugin can
handle the chat_completion method.
Args:
messages (Dict[Any, Any]): The messages.
model (str): The model name.
temperature (float): The temperature.
max_tokens (int): The max tokens.
Returns:
bool: True if the plugin can handle the chat_completion method."""
return False
def handle_chat_completion(
self,
messages: list[Dict[Any, Any]],
model: str,
temperature: float,
max_tokens: int,
) -> str:
"""This method is called when the chat completion is done.
Args:
messages (Dict[Any, Any]): The messages.
model (str): The model name.
temperature (float): The temperature.
max_tokens (int): The max tokens.
Returns:
str: The resulting response.
"""
def can_handle_text_embedding(self, text: str) -> bool:
"""This method is called to check that the plugin can
handle the text_embedding method.
Args:
text (str): The text to be convert to embedding.
Returns:
bool: True if the plugin can handle the text_embedding method."""
return False
def handle_text_embedding(self, text: str) -> list:
"""This method is called when the chat completion is done.
Args:
text (str): The text to be convert to embedding.
Returns:
list: The text embedding.
"""
def can_handle_user_input(self, user_input: str) -> bool:
"""This method is called to check that the plugin can
handle the user_input method.
Args:
user_input (str): The user input.
Returns:
bool: True if the plugin can handle the user_input method."""
return False
def user_input(self, user_input: str) -> str:
"""This method is called to request user input to the user.
Args:
user_input (str): The question or prompt to ask the user.
Returns:
str: The user input.
"""
def can_handle_report(self) -> bool:
"""This method is called to check that the plugin can
handle the report method.
Returns:
bool: True if the plugin can handle the report method."""
return False
def report(self, message: str) -> None:
"""This method is called to report a message to the user.
Args:
message (str): The message to report.
"""