reshaping code and fixing tests

This commit is contained in:
Evgeny Vakhteev
2023-04-18 12:52:09 -07:00
parent 9fd80a8660
commit 894026cdd4
7 changed files with 117 additions and 298 deletions

View File

@@ -176,7 +176,7 @@ def instantiate_openai_plugin_clients(manifests_specs_clients: dict, cfg: Config
def scan_plugins(cfg: Config, debug: bool = False) -> List[Tuple[str, Path]]:
"""Scan the plugins directory for plugins.
"""Scan the plugins directory for plugins and loads them.
Args:
cfg (Config): Config instance including plugins config
@@ -185,46 +185,37 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[Tuple[str, Path]]:
Returns:
List[Tuple[str, Path]]: List of plugins.
"""
plugins = []
loaded_plugins = []
# Generic plugins
plugins_path_path = Path(cfg.plugins_dir)
for plugin in plugins_path_path.glob("*.zip"):
if module := inspect_zip_for_module(str(plugin), debug):
plugins.append((module, plugin))
plugin = Path(plugin)
module = Path(module)
if debug:
print(f"Plugin: {plugin} Module: {module}")
zipped_package = zipimporter(plugin)
zipped_module = zipped_package.load_module(str(module.parent))
for key in dir(zipped_module):
if key.startswith("__"):
continue
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
if (
"_abc_impl" in a_keys
and a_module.__name__ != "AutoGPTPluginTemplate"
and blacklist_whitelist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
# OpenAI plugins
if cfg.plugins_openai:
manifests_specs = fetch_openai_plugins_manifest_and_spec(cfg)
if manifests_specs.keys():
manifests_specs_clients = initialize_openai_plugins(manifests_specs, cfg, debug)
for url, openai_plugin_meta in manifests_specs_clients.items():
plugin = BaseOpenAIPlugin(openai_plugin_meta)
plugins.append((plugin, url))
return plugins
def blacklist_whitelist_check(plugins: List[AbstractSingleton], cfg: Config):
"""Check if the plugin is in the whitelist or blacklist.
Args:
plugins (List[Tuple[str, Path]]): List of plugins.
cfg (Config): Config object.
Returns:
List[Tuple[str, Path]]: List of plugins.
"""
loaded_plugins = []
for plugin in plugins:
if plugin.__name__ in cfg.plugins_blacklist:
continue
if plugin.__name__ in cfg.plugins_whitelist:
loaded_plugins.append(plugin())
else:
ack = input(
f"WARNNG Plugin {plugin.__name__} found. But not in the"
" whitelist... Load? (y/n): "
)
if ack.lower() == "y":
loaded_plugins.append(plugin())
if blacklist_whitelist_check(url, cfg):
plugin = BaseOpenAIPlugin(openai_plugin_meta)
loaded_plugins.append(plugin)
if loaded_plugins:
print(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
@@ -232,30 +223,22 @@ def blacklist_whitelist_check(plugins: List[AbstractSingleton], cfg: Config):
print(f"{plugin._name}: {plugin._version} - {plugin._description}")
return loaded_plugins
def load_plugins(cfg: Config = Config(), debug: bool = False) -> List[object]:
"""Load plugins from the plugins directory.
def blacklist_whitelist_check(plugin_name: str, cfg: Config) -> bool:
"""Check if the plugin is in the whitelist or blacklist.
Args:
cfg (Config): Config instance including plugins config
debug (bool, optional): Enable debug logging. Defaults to False.
plugin_name (str): Name of the plugin.
cfg (Config): Config object.
Returns:
List[AbstractSingleton]: List of plugins initialized.
True or False
"""
plugins = scan_plugins(cfg)
plugin_modules = []
for module, plugin in plugins:
plugin = Path(plugin)
module = Path(module)
if debug:
print(f"Plugin: {plugin} Module: {module}")
zipped_package = zipimporter(plugin)
zipped_module = zipped_package.load_module(str(module.parent))
for key in dir(zipped_module):
if key.startswith("__"):
continue
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
if "_abc_impl" in a_keys and a_module.__name__ != "AutoGPTPluginTemplate":
plugin_modules.append(a_module)
return blacklist_whitelist_check(plugin_modules, cfg)
if plugin_name in cfg.plugins_blacklist:
return False
if plugin_name in cfg.plugins_whitelist:
return True
ack = input(
f"WARNNG Plugin {plugin_name} found. But not in the"
" whitelist... Load? (y/n): "
)
return ack.lower() == "y"