diff --git a/autogpt/__main__.py b/autogpt/__main__.py index 5fc9a1ea..f995fb12 100644 --- a/autogpt/__main__.py +++ b/autogpt/__main__.py @@ -24,27 +24,7 @@ def main() -> None: check_openai_api_key() parse_arguments() logger.set_level(logging.DEBUG if cfg.debug_mode else logging.INFO) - plugins_found = load_plugins(Path(os.getcwd()) / "plugins") - loaded_plugins = [] - for plugin in plugins_found: - if plugin.__name__ in cfg.plugins_blacklist: - continue - if plugin.__name__ in cfg.plugins_whitelist: - loaded_plugins.append(plugin()) - else: - ack = input( - f"WARNNG Plugin {plugin.__name__} found. But not in the" - " whitelist... Load? (y/n): " - ) - if ack.lower() == "y": - loaded_plugins.append(plugin()) - - if loaded_plugins: - print(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------") - for plugin in loaded_plugins: - print(f"{plugin._name}: {plugin._version} - {plugin._description}") - - cfg.set_plugins(loaded_plugins) + cfg.set_plugins(load_plugins(cfg)) # Create a CommandRegistry instance and scan default folder command_registry = CommandRegistry() command_registry.import_commands("scripts.ai_functions") diff --git a/autogpt/config/config.py b/autogpt/config/config.py index 46ab95d8..66a23086 100644 --- a/autogpt/config/config.py +++ b/autogpt/config/config.py @@ -107,6 +107,7 @@ class Config(metaclass=Singleton): # Initialize the OpenAI API client openai.api_key = self.openai_api_key + self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins") self.plugins = [] self.plugins_whitelist = [] self.plugins_blacklist = [] diff --git a/autogpt/plugins.py b/autogpt/plugins.py index 7b843a6a..18680cba 100644 --- a/autogpt/plugins.py +++ b/autogpt/plugins.py @@ -1,11 +1,15 @@ """Handles loading of plugins.""" - -from ast import Module +import os import zipfile +from glob import glob from pathlib import Path from zipimport import zipimporter from typing import List, Optional, Tuple +from abstract_singleton import AbstractSingleton + +from autogpt.config import Config + def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]: """ @@ -29,32 +33,66 @@ def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]: return None -def scan_plugins(plugins_path: Path, debug: bool = False) -> List[Tuple[str, Path]]: +def scan_plugins(plugins_path: str, debug: bool = False) -> List[Tuple[str, Path]]: """Scan the plugins directory for plugins. Args: - plugins_path (Path): Path to the plugins directory. + plugins_path (str): Path to the plugins directory. + debug (bool, optional): Enable debug logging. Defaults to False. Returns: - List[Path]: List of plugins. + List[Tuple[str, Path]]: List of plugins. """ plugins = [] - for plugin in plugins_path.glob("*.zip"): + plugins_path_path = Path(plugins_path) + + for plugin in plugins_path_path.glob("*.zip"): if module := inspect_zip_for_module(str(plugin), debug): plugins.append((module, plugin)) return plugins -def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]: +def blacklist_whitelist_check(plugins: List[AbstractSingleton], cfg: Config): + """Check if the plugin is in the whitelist or blacklist. + + Args: + plugins (List[Tuple[str, Path]]): List of plugins. + cfg (Config): Config object. + + Returns: + List[Tuple[str, Path]]: List of plugins. + """ + loaded_plugins = [] + for plugin in plugins: + if plugin.__name__ in cfg.plugins_blacklist: + continue + if plugin.__name__ in cfg.plugins_whitelist: + loaded_plugins.append(plugin()) + else: + ack = input( + f"WARNNG Plugin {plugin.__name__} found. But not in the" + " whitelist... Load? (y/n): " + ) + if ack.lower() == "y": + loaded_plugins.append(plugin()) + + if loaded_plugins: + print(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------") + for plugin in loaded_plugins: + print(f"{plugin._name}: {plugin._version} - {plugin._description}") + return loaded_plugins + + +def load_plugins(cfg: Config = Config(), debug: bool = False) -> List[object]: """Load plugins from the plugins directory. Args: - plugins_path (Path): Path to the plugins directory. - + cfg (Config): Config instance inluding plugins config + debug (bool, optional): Enable debug logging. Defaults to False. Returns: - List[Path]: List of plugins. + List[AbstractSingleton]: List of plugins initialized. """ - plugins = scan_plugins(plugins_path) + plugins = scan_plugins(cfg.plugins_dir) plugin_modules = [] for module, plugin in plugins: plugin = Path(plugin) @@ -70,4 +108,5 @@ def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]: a_keys = dir(a_module) if "_abc_impl" in a_keys and a_module.__name__ != "AutoGPTPluginTemplate": plugin_modules.append(a_module) - return plugin_modules + loaded_plugin_modules = blacklist_whitelist_check(plugin_modules, cfg) + return loaded_plugin_modules diff --git a/tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip b/tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip new file mode 100644 index 00000000..f0db121e Binary files /dev/null and b/tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip differ diff --git a/tests/unit/test_plugins.py b/tests/unit/test_plugins.py new file mode 100644 index 00000000..a1d9d5e7 --- /dev/null +++ b/tests/unit/test_plugins.py @@ -0,0 +1,32 @@ +import pytest +from pathlib import Path +from zipfile import ZipFile +from autogpt.plugins import inspect_zip_for_module, scan_plugins, load_plugins +from autogpt.config import Config + +PLUGINS_TEST_DIR = "tests/unit/data/test_plugins/" +PLUGIN_TEST_ZIP_FILE = "Auto-GPT-Plugin-Test-master.zip" +PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_plugin_template/__init__.py" + + +@pytest.fixture +def config_with_plugins(): + cfg = Config() + cfg.plugins_dir = PLUGINS_TEST_DIR + return cfg + + +def test_inspect_zip_for_module(): + result = inspect_zip_for_module(str(PLUGINS_TEST_DIR + PLUGIN_TEST_ZIP_FILE)) + assert result == PLUGIN_TEST_INIT_PY + +def test_scan_plugins(): + result = scan_plugins(PLUGINS_TEST_DIR, debug=True) + assert len(result) == 1 + assert result[0][0] == PLUGIN_TEST_INIT_PY + + +def test_load_plugins_blacklisted(config_with_plugins): + config_with_plugins.plugins_blacklist = ['AbstractSingleton'] + result = load_plugins(cfg=config_with_plugins) + assert len(result) == 0