diff --git a/autogpt/config/config.py b/autogpt/config/config.py index 1c2084f7..5711764c 100644 --- a/autogpt/config/config.py +++ b/autogpt/config/config.py @@ -4,7 +4,7 @@ from __future__ import annotations import contextlib import os import re -from typing import Dict +from typing import Dict, Union import yaml from colorama import Fore @@ -83,18 +83,6 @@ class Config(SystemSettings): plugins: list[str] authorise_key: str - def __init__(self, **kwargs): - super().__init__(**kwargs) - - # Hotfix: Call model_post_init explictly as it doesn't seem to be called for pydantic<2.0.0 - # https://github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214 - self.model_post_init(**kwargs) - - # Executed immediately after init by Pydantic - def model_post_init(self, **kwargs) -> None: - if not self.plugins_config.plugins: - self.plugins_config = PluginsConfig.load_config(self) - class ConfigBuilder(Configurable[Config]): default_plugins_config_file = os.path.join( @@ -213,21 +201,16 @@ class ConfigBuilder(Configurable[Config]): "chat_messages_enabled": os.getenv("CHAT_MESSAGES_ENABLED") == "True", } - # Converting to a list from comma-separated string - disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES") - if disabled_command_categories: - config_dict[ - "disabled_command_categories" - ] = disabled_command_categories.split(",") + config_dict["disabled_command_categories"] = _safe_split( + os.getenv("DISABLED_COMMAND_CATEGORIES") + ) - # Converting to a list from comma-separated string - shell_denylist = os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS")) - if shell_denylist: - config_dict["shell_denylist"] = shell_denylist.split(",") - - shell_allowlist = os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS")) - if shell_allowlist: - config_dict["shell_allowlist"] = shell_allowlist.split(",") + config_dict["shell_denylist"] = _safe_split( + os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS")) + ) + config_dict["shell_allowlist"] = _safe_split( + os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS")) + ) config_dict["google_custom_search_engine_id"] = os.getenv( "GOOGLE_CUSTOM_SEARCH_ENGINE_ID", os.getenv("CUSTOM_SEARCH_ENGINE_ID") @@ -237,13 +220,13 @@ class ConfigBuilder(Configurable[Config]): "ELEVENLABS_VOICE_ID", os.getenv("ELEVENLABS_VOICE_1_ID") ) - plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS") - if plugins_allowlist: - config_dict["plugins_allowlist"] = plugins_allowlist.split(",") - - plugins_denylist = os.getenv("DENYLISTED_PLUGINS") - if plugins_denylist: - config_dict["plugins_denylist"] = plugins_denylist.split(",") + config_dict["plugins_allowlist"] = _safe_split(os.getenv("ALLOWLISTED_PLUGINS")) + config_dict["plugins_denylist"] = _safe_split(os.getenv("DENYLISTED_PLUGINS")) + config_dict["plugins_config"] = PluginsConfig.load_config( + config_dict["plugins_config_file"], + config_dict["plugins_denylist"], + config_dict["plugins_allowlist"], + ) with contextlib.suppress(TypeError): config_dict["image_size"] = int(os.getenv("IMAGE_SIZE")) @@ -325,3 +308,10 @@ def check_openai_api_key(config: Config) -> None: else: print("Invalid OpenAI API key!") exit(1) + + +def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]: + """Split a string by a separator. Return an empty list if the string is None.""" + if s is None: + return [] + return s.split(sep) diff --git a/autogpt/plugins/plugins_config.py b/autogpt/plugins/plugins_config.py index 7fcb5197..13b87130 100644 --- a/autogpt/plugins/plugins_config.py +++ b/autogpt/plugins/plugins_config.py @@ -1,13 +1,9 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Union +from typing import Union import yaml - -if TYPE_CHECKING: - from autogpt.config import Config - from pydantic import BaseModel from autogpt.logs import logger @@ -30,11 +26,20 @@ class PluginsConfig(BaseModel): return plugin_config is not None and plugin_config.enabled @classmethod - def load_config(cls, global_config: Config) -> "PluginsConfig": + def load_config( + cls, + plugins_config_file: str, + plugins_denylist: list[str], + plugins_allowlist: list[str], + ) -> "PluginsConfig": empty_config = cls(plugins={}) try: - config_data = cls.deserialize_config_file(global_config=global_config) + config_data = cls.deserialize_config_file( + plugins_config_file, + plugins_denylist, + plugins_allowlist, + ) if type(config_data) != dict: logger.error( f"Expected plugins config to be a dict, got {type(config_data)}, continuing without plugins" @@ -49,13 +54,21 @@ class PluginsConfig(BaseModel): return empty_config @classmethod - def deserialize_config_file(cls, global_config: Config) -> dict[str, PluginConfig]: - plugins_config_path = global_config.plugins_config_file - if not os.path.exists(plugins_config_path): + def deserialize_config_file( + cls, + plugins_config_file: str, + plugins_denylist: list[str], + plugins_allowlist: list[str], + ) -> dict[str, PluginConfig]: + if not os.path.exists(plugins_config_file): logger.warn("plugins_config.yaml does not exist, creating base config.") - cls.create_empty_plugins_config(global_config=global_config) + cls.create_empty_plugins_config( + plugins_config_file, + plugins_denylist, + plugins_allowlist, + ) - with open(plugins_config_path, "r") as f: + with open(plugins_config_file, "r") as f: plugins_config = yaml.load(f, Loader=yaml.FullLoader) plugins = {} @@ -73,23 +86,27 @@ class PluginsConfig(BaseModel): return plugins @staticmethod - def create_empty_plugins_config(global_config: Config): + def create_empty_plugins_config( + plugins_config_file: str, + plugins_denylist: list[str], + plugins_allowlist: list[str], + ): """Create an empty plugins_config.yaml file. Fill it with values from old env variables.""" base_config = {} - logger.debug(f"Legacy plugin denylist: {global_config.plugins_denylist}") - logger.debug(f"Legacy plugin allowlist: {global_config.plugins_allowlist}") + logger.debug(f"Legacy plugin denylist: {plugins_denylist}") + logger.debug(f"Legacy plugin allowlist: {plugins_allowlist}") # Backwards-compatibility shim - for plugin_name in global_config.plugins_denylist: + for plugin_name in plugins_denylist: base_config[plugin_name] = {"enabled": False, "config": {}} - for plugin_name in global_config.plugins_allowlist: + for plugin_name in plugins_allowlist: base_config[plugin_name] = {"enabled": True, "config": {}} logger.debug(f"Constructed base plugins config: {base_config}") - logger.debug(f"Creating plugin config file {global_config.plugins_config_file}") - with open(global_config.plugins_config_file, "w+") as f: + logger.debug(f"Creating plugin config file {plugins_config_file}") + with open(plugins_config_file, "w+") as f: f.write(yaml.dump(base_config)) return base_config diff --git a/scripts/install_plugin_deps.py b/scripts/install_plugin_deps.py index 00d9f8a3..1cd0bd1a 100644 --- a/scripts/install_plugin_deps.py +++ b/scripts/install_plugin_deps.py @@ -5,6 +5,8 @@ import zipfile from glob import glob from pathlib import Path +from autogpt.logs import logger + def install_plugin_dependencies(): """ @@ -18,28 +20,46 @@ def install_plugin_dependencies(): """ plugins_dir = Path(os.getenv("PLUGINS_DIR", "plugins")) + logger.debug(f"Checking for dependencies in zipped plugins...") + # Install zip-based plugins - for plugin in plugins_dir.glob("*.zip"): - with zipfile.ZipFile(str(plugin), "r") as zfile: - try: - basedir = zfile.namelist()[0] - basereqs = os.path.join(basedir, "requirements.txt") - extracted = zfile.extract(basereqs, path=plugins_dir) - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "-r", extracted] - ) - os.remove(extracted) - os.rmdir(os.path.join(plugins_dir, basedir)) - except KeyError: + for plugin_archive in plugins_dir.glob("*.zip"): + logger.debug(f"Checking for requirements in '{plugin_archive}'...") + with zipfile.ZipFile(str(plugin_archive), "r") as zfile: + if not zfile.namelist(): continue + # Assume the first entry in the list will be (in) the lowest common dir + first_entry = zfile.namelist()[0] + basedir = first_entry.rsplit("/", 1)[0] if "/" in first_entry else "" + logger.debug(f"Looking for requirements.txt in '{basedir}'") + + basereqs = os.path.join(basedir, "requirements.txt") + try: + extracted = zfile.extract(basereqs, path=plugins_dir) + except KeyError as e: + logger.debug(e.args[0]) + continue + + logger.debug(f"Installing dependencies from '{basereqs}'...") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "-r", extracted] + ) + os.remove(extracted) + os.rmdir(os.path.join(plugins_dir, basedir)) + + logger.debug(f"Checking for dependencies in other plugin folders...") + # Install directory-based plugins for requirements_file in glob(f"{plugins_dir}/*/requirements.txt"): + logger.debug(f"Installing dependencies from '{requirements_file}'...") subprocess.check_call( [sys.executable, "-m", "pip", "install", "-r", requirements_file], stdout=subprocess.DEVNULL, ) + logger.debug("Finished installing plugin dependencies") + if __name__ == "__main__": install_plugin_dependencies() diff --git a/tests/conftest.py b/tests/conftest.py index 920fc4e4..f2ca5904 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,7 +59,11 @@ def config( # avoid circular dependency from autogpt.plugins.plugins_config import PluginsConfig - config.plugins_config = PluginsConfig.load_config(global_config=config) + config.plugins_config = PluginsConfig.load_config( + plugins_config_file=config.plugins_config_file, + plugins_denylist=config.plugins_denylist, + plugins_allowlist=config.plugins_allowlist, + ) # Do a little setup and teardown since the config object is a singleton mocker.patch.multiple( diff --git a/tests/unit/test_plugins.py b/tests/unit/test_plugins.py index 24b7d1e9..981715ac 100644 --- a/tests/unit/test_plugins.py +++ b/tests/unit/test_plugins.py @@ -70,7 +70,11 @@ def test_create_base_config(config: Config): config.plugins_denylist = ["c", "d"] os.remove(config.plugins_config_file) - plugins_config = PluginsConfig.load_config(global_config=config) + plugins_config = PluginsConfig.load_config( + plugins_config_file=config.plugins_config_file, + plugins_denylist=config.plugins_denylist, + plugins_allowlist=config.plugins_allowlist, + ) # Check the structure of the plugins config data assert len(plugins_config.plugins) == 4 @@ -102,7 +106,11 @@ def test_load_config(config: Config): f.write(yaml.dump(test_config)) # Load the config from disk - plugins_config = PluginsConfig.load_config(global_config=config) + plugins_config = PluginsConfig.load_config( + plugins_config_file=config.plugins_config_file, + plugins_denylist=config.plugins_denylist, + plugins_allowlist=config.plugins_allowlist, + ) # Check that the loaded config is equal to the test config assert len(plugins_config.plugins) == 2