moving load plugins into plugins from main, adding tests

This commit is contained in:
Evgeny Vakhteev
2023-04-17 09:33:01 -07:00
parent 167628c696
commit 08ad320d19
5 changed files with 85 additions and 33 deletions

View File

@@ -1,11 +1,15 @@
"""Handles loading of plugins."""
from ast import Module
import os
import zipfile
from glob import glob
from pathlib import Path
from zipimport import zipimporter
from typing import List, Optional, Tuple
from abstract_singleton import AbstractSingleton
from autogpt.config import Config
def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]:
"""
@@ -29,32 +33,66 @@ def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]:
return None
def scan_plugins(plugins_path: Path, debug: bool = False) -> List[Tuple[str, Path]]:
def scan_plugins(plugins_path: str, debug: bool = False) -> List[Tuple[str, Path]]:
"""Scan the plugins directory for plugins.
Args:
plugins_path (Path): Path to the plugins directory.
plugins_path (str): Path to the plugins directory.
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
List[Path]: List of plugins.
List[Tuple[str, Path]]: List of plugins.
"""
plugins = []
for plugin in plugins_path.glob("*.zip"):
plugins_path_path = Path(plugins_path)
for plugin in plugins_path_path.glob("*.zip"):
if module := inspect_zip_for_module(str(plugin), debug):
plugins.append((module, plugin))
return plugins
def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]:
def blacklist_whitelist_check(plugins: List[AbstractSingleton], cfg: Config):
"""Check if the plugin is in the whitelist or blacklist.
Args:
plugins (List[Tuple[str, Path]]): List of plugins.
cfg (Config): Config object.
Returns:
List[Tuple[str, Path]]: List of plugins.
"""
loaded_plugins = []
for plugin in plugins:
if plugin.__name__ in cfg.plugins_blacklist:
continue
if plugin.__name__ in cfg.plugins_whitelist:
loaded_plugins.append(plugin())
else:
ack = input(
f"WARNNG Plugin {plugin.__name__} found. But not in the"
" whitelist... Load? (y/n): "
)
if ack.lower() == "y":
loaded_plugins.append(plugin())
if loaded_plugins:
print(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
for plugin in loaded_plugins:
print(f"{plugin._name}: {plugin._version} - {plugin._description}")
return loaded_plugins
def load_plugins(cfg: Config = Config(), debug: bool = False) -> List[object]:
"""Load plugins from the plugins directory.
Args:
plugins_path (Path): Path to the plugins directory.
cfg (Config): Config instance inluding plugins config
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
List[Path]: List of plugins.
List[AbstractSingleton]: List of plugins initialized.
"""
plugins = scan_plugins(plugins_path)
plugins = scan_plugins(cfg.plugins_dir)
plugin_modules = []
for module, plugin in plugins:
plugin = Path(plugin)
@@ -70,4 +108,5 @@ def load_plugins(plugins_path: Path, debug: bool = False) -> List[Module]:
a_keys = dir(a_module)
if "_abc_impl" in a_keys and a_module.__name__ != "AutoGPTPluginTemplate":
plugin_modules.append(a_module)
return plugin_modules
loaded_plugin_modules = blacklist_whitelist_check(plugin_modules, cfg)
return loaded_plugin_modules