From 894026cdd4094e94429ecbc022f52df45f773626 Mon Sep 17 00:00:00 2001 From: Evgeny Vakhteev Date: Tue, 18 Apr 2023 12:52:09 -0700 Subject: [PATCH] reshaping code and fixing tests --- autogpt/__main__.py | 4 +- autogpt/models/base_open_ai_plugin.py | 13 +- autogpt/plugins.py | 91 +++---- requirements.txt | 4 +- .../Auto-GPT-Plugin-Test-master.zip | Bin 15284 -> 14927 bytes tests/unit/test_plugins.py | 81 ++++++- tests/unit/test_plugins_gpt_generated.py | 222 ------------------ 7 files changed, 117 insertions(+), 298 deletions(-) delete mode 100644 tests/unit/test_plugins_gpt_generated.py diff --git a/autogpt/__main__.py b/autogpt/__main__.py index d694fd59..f8d20487 100644 --- a/autogpt/__main__.py +++ b/autogpt/__main__.py @@ -11,7 +11,7 @@ from autogpt.logs import logger from autogpt.memory import get_memory from autogpt.prompts.prompt import construct_main_ai_config -from autogpt.plugins import load_plugins +from autogpt.plugins import scan_plugins # Load environment variables from .env file @@ -24,7 +24,7 @@ def main() -> None: check_openai_api_key() parse_arguments() logger.set_level(logging.DEBUG if cfg.debug_mode else logging.INFO) - cfg.set_plugins(load_plugins(cfg, cfg.debug_mode)) + cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) # Create a CommandRegistry instance and scan default folder command_registry = CommandRegistry() command_registry.import_commands("scripts.ai_functions") diff --git a/autogpt/models/base_open_ai_plugin.py b/autogpt/models/base_open_ai_plugin.py index 3aafff84..fafd3932 100644 --- a/autogpt/models/base_open_ai_plugin.py +++ b/autogpt/models/base_open_ai_plugin.py @@ -2,6 +2,8 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict from typing import TypeVar +from auto_gpt_plugin_template import AutoGPTPluginTemplate + PromptGenerator = TypeVar("PromptGenerator") @@ -10,9 +12,9 @@ class Message(TypedDict): content: str -class BaseOpenAIPlugin: +class BaseOpenAIPlugin(AutoGPTPluginTemplate): """ - This is a template for Auto-GPT plugins. + This is a BaseOpenAIPlugin class for generating Auto-GPT plugins. """ def __init__(self, manifests_specs_clients: dict): @@ -20,9 +22,9 @@ class BaseOpenAIPlugin: self._name = manifests_specs_clients["manifest"]["name_for_model"] self._version = manifests_specs_clients["manifest"]["schema_version"] self._description = manifests_specs_clients["manifest"]["description_for_model"] - self.client = manifests_specs_clients["client"] - self.manifest = manifests_specs_clients["manifest"] - self.openapi_spec = manifests_specs_clients["openapi_spec"] + self._client = manifests_specs_clients["client"] + self._manifest = manifests_specs_clients["manifest"] + self._openapi_spec = manifests_specs_clients["openapi_spec"] def can_handle_on_response(self) -> bool: """This method is called to check that the plugin can @@ -196,4 +198,3 @@ class BaseOpenAIPlugin: str: The resulting response. """ pass - diff --git a/autogpt/plugins.py b/autogpt/plugins.py index 2455a89e..974adddc 100644 --- a/autogpt/plugins.py +++ b/autogpt/plugins.py @@ -176,7 +176,7 @@ def instantiate_openai_plugin_clients(manifests_specs_clients: dict, cfg: Config def scan_plugins(cfg: Config, debug: bool = False) -> List[Tuple[str, Path]]: - """Scan the plugins directory for plugins. + """Scan the plugins directory for plugins and loads them. Args: cfg (Config): Config instance including plugins config @@ -185,46 +185,37 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[Tuple[str, Path]]: Returns: List[Tuple[str, Path]]: List of plugins. """ - plugins = [] + loaded_plugins = [] # Generic plugins plugins_path_path = Path(cfg.plugins_dir) for plugin in plugins_path_path.glob("*.zip"): if module := inspect_zip_for_module(str(plugin), debug): - plugins.append((module, plugin)) + plugin = Path(plugin) + module = Path(module) + if debug: + print(f"Plugin: {plugin} Module: {module}") + zipped_package = zipimporter(plugin) + zipped_module = zipped_package.load_module(str(module.parent)) + for key in dir(zipped_module): + if key.startswith("__"): + continue + a_module = getattr(zipped_module, key) + a_keys = dir(a_module) + if ( + "_abc_impl" in a_keys + and a_module.__name__ != "AutoGPTPluginTemplate" + and blacklist_whitelist_check(a_module.__name__, cfg) + ): + loaded_plugins.append(a_module()) # OpenAI plugins if cfg.plugins_openai: manifests_specs = fetch_openai_plugins_manifest_and_spec(cfg) if manifests_specs.keys(): manifests_specs_clients = initialize_openai_plugins(manifests_specs, cfg, debug) for url, openai_plugin_meta in manifests_specs_clients.items(): - plugin = BaseOpenAIPlugin(openai_plugin_meta) - plugins.append((plugin, url)) - return plugins - - -def blacklist_whitelist_check(plugins: List[AbstractSingleton], cfg: Config): - """Check if the plugin is in the whitelist or blacklist. - - Args: - plugins (List[Tuple[str, Path]]): List of plugins. - cfg (Config): Config object. - - Returns: - List[Tuple[str, Path]]: List of plugins. - """ - loaded_plugins = [] - for plugin in plugins: - if plugin.__name__ in cfg.plugins_blacklist: - continue - if plugin.__name__ in cfg.plugins_whitelist: - loaded_plugins.append(plugin()) - else: - ack = input( - f"WARNNG Plugin {plugin.__name__} found. But not in the" - " whitelist... Load? (y/n): " - ) - if ack.lower() == "y": - loaded_plugins.append(plugin()) + if blacklist_whitelist_check(url, cfg): + plugin = BaseOpenAIPlugin(openai_plugin_meta) + loaded_plugins.append(plugin) if loaded_plugins: print(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------") @@ -232,30 +223,22 @@ def blacklist_whitelist_check(plugins: List[AbstractSingleton], cfg: Config): print(f"{plugin._name}: {plugin._version} - {plugin._description}") return loaded_plugins - -def load_plugins(cfg: Config = Config(), debug: bool = False) -> List[object]: - """Load plugins from the plugins directory. +def blacklist_whitelist_check(plugin_name: str, cfg: Config) -> bool: + """Check if the plugin is in the whitelist or blacklist. Args: - cfg (Config): Config instance including plugins config - debug (bool, optional): Enable debug logging. Defaults to False. + plugin_name (str): Name of the plugin. + cfg (Config): Config object. + Returns: - List[AbstractSingleton]: List of plugins initialized. + True or False """ - plugins = scan_plugins(cfg) - plugin_modules = [] - for module, plugin in plugins: - plugin = Path(plugin) - module = Path(module) - if debug: - print(f"Plugin: {plugin} Module: {module}") - zipped_package = zipimporter(plugin) - zipped_module = zipped_package.load_module(str(module.parent)) - for key in dir(zipped_module): - if key.startswith("__"): - continue - a_module = getattr(zipped_module, key) - a_keys = dir(a_module) - if "_abc_impl" in a_keys and a_module.__name__ != "AutoGPTPluginTemplate": - plugin_modules.append(a_module) - return blacklist_whitelist_check(plugin_modules, cfg) + if plugin_name in cfg.plugins_blacklist: + return False + if plugin_name in cfg.plugins_whitelist: + return True + ack = input( + f"WARNNG Plugin {plugin_name} found. But not in the" + " whitelist... Load? (y/n): " + ) + return ack.lower() == "y" diff --git a/requirements.txt b/requirements.txt index 6583d65a..9f015f92 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,7 +30,7 @@ pytest-mock tweepy -# OpenAI plugins import +# OpenAI and Generic plugins import openapi-python-client==0.13.4 abstract-singleton -auto-vicuna +auto-gpt-plugin-template diff --git a/tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip b/tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip index 7a6421af1c23f354c0915f8b729343dc8de32e5f..00bc1f4f58dc1c8e07c7ca3adce6be605fe6a3ce 100644 GIT binary patch delta 2349 zcmZuz3p7;w8sBEfL{o#Qlo;xS#^cWHd7I)8m7&IX3&$hVcok-pL=CxQQo^PZlF$*w zaYkA?bUg|kd2|z%bBw&=;CSWMRdU9j?paHB_FjAK{rj!o_kF+h{nz@wT@}?CiFU9< zDyjf0*RWqkIvP~lJ!u~BKN1C#4ttPNvNG+G0#d4hbP*OS81hGKhw*s?1IA5=GcewZ z+`WRu=a9WfnDA1D2`ESTcNoW@Y+*qQ%3VQCn$Z>cC)U6k2>_XE000khkh+kAiaIRe zs<34_9-!eO;(2r;LY4%>gsawIgfM2RnZr0%jSb@tH8VJas`_>;Oc3Iw(_x@)8FN_&(VEY~*?P>S_T;cT6x zVOIpFIq zG{N0L!^sh1pyb@YzpUyuM8m`F6-VQhWc`jeqOLj=_@tm#qkXs)FROxGlY4?KIf?+F zss#W{*fYTl3;7yjge%&Y?vk)RD(375pNJQ2!;Qp;y}$&<*WJN)Z$iVE9RYPWw)%co zk*?#h&ywi_BCJhYW!9^G`c;*Whn5yU&ZatMe3x>tAnwaJx)G6;l+in3RnKw_cMFLY zEdO-Pz!vmy;j5l@+vJaaQk|KV#H`F|n6@eB&F~7wrmqSeH#-u{6KuTsMfu($9}h}U z)t1+7J;fbcaoNqso@Vb0yF5JnX9H{Vy>KvE_jgB0Vf}?5l?gW<_2706m}jrg;IX1T zeNt*S8EcdG)a3!*`po%>o0?Yj>?)d_m*tg>x4Jy8;D^h4GHRT?{j(0e`@ww5nk=}n zw%X=by@+aI*mVpP_$KNuHa_U++1zDtS+nhG$z17AaXAdrQ)@0U20-4nYQCtpJy6Fw zpOk*=nmhklQ7hiL7^{-wH883X3uQzxcABuOOKAL{-_nky+I6YVp-0))xZ5V@56`qd z>OKD1J!qheLe2jeEi5cIDlmI6`g45!ty0m>veqV{THA5&b{CJn8Ra1rCU&r!H=q+{ zIUc+Q=YHSy{rubU^EKkGv;oC{v!M=M{vWR$aP&)k5|u&m6rHOsdyxY1hYcNnq)sJp zTob#3=0+-F2+I}|`J%Dh&@Y*h#(P;24pSsGYb0r2NMds1REwRfm{YeO(l;Y~c%zMR zBSS*3WT$+~`EXTEdgV3iBL?;JVF#aN-6OkFn@<#8Fg;`u(ZjCMFZSyF?o9)&=z90< zvHQ6Nm^$WQu)PA);*vm*&I}2m#fIt)PHxa#OPKwzbK|?c+~f;CW4e9HIZ3${^H&RA z&DnBWf9x-tQ)bY246!OHrJjAKad!-tj<}czL(4bxz15ge6^-2LL}xOhM@pWPw_Jae ziRfgB9eTNncREX64!vK;AK_gLa21Oyp4_|fAO&0f{PP=S-QdU_i-+AB`kS#5O{bxM zdGKf_*O_YEaVn{y5+h58N;rDjFIs=d%h47@^WQJ!<>)l^uytF3bqPD0{fF|KU9J#7 zpv7mS?7m$yQ-^LFdO6Pm3P?G9OjxHX8^eG?BOV6-|m3wbkz6 zgx0rvT0_iUx=e6}XVjPeII%GZB@#s5b4!1>L^cW*d!ByCDhxj3ebd(F3OmHXbYbkD zzpd}<#7}uL>q>h)vDaCt|210(RGxlpN*C(&O*>t>X~7L}5@e@g@`)9f>&_=WPYUIq ze0FNteG-bV!UdI$kF;zWOU(RW1qcd1^|VC1oRpNMQ^HS*RH!$~BmR>>J`+RyW%sB4 z{UfAAvk(cvo-}q9Snkdv+p~pL#t$SVCT$7Pqhe1n>g==F6lz%U37d(^!HjWV^Wd0{ zs#q&l+TwmAIW>(MSIL*#*DIsKHKV{UlG)LV+;|B-S?F7H8=Usp`Em8;)CvzN;sryY>XKc&g9i>L4VMlb760^LSiPCs4$USQD3 z|3%6=XwJ|=fOu$@tOG?e)a53>!O&8Oh1AS;kX@XW0EN|c;QFr81N{3U0BgT~W2EmU zX&{}jItH%1dQp9iQGg@28pD3R?yYY;$!vF+X(8GPuq+f@FIR$RY9YLcu;denwvt&; zlKx)+z$WRv4Fv!TL;x@ZKq8Gypm2#aDtOe-j|dV8U;vpwCmx~E2}c8HWG7N)C3H9`p`3|j2uBHXJ6C8pa{ z_H8l|B3w(BOG2ZR`j2n!|9_6o{J-z~zH^@EJn#EH=lz}Yy>HK%nzO==HZUFmVDq6y zr=}!|<{JTVG{OoLv>=EOF?L3G*rLQ9GZ+A5@B)B3#R-O@I0}e^k{|&GPL9XI zU=iwy@Kq=$iKqcswO*7AVgs=QAigN(0OD>jeK3Qd_$~>M(AsSb5~SVooD3PNdSFD4 zsyQgwp{5IBCp9b36~S9v-jdxHH}=tByHcREt09S-?Saz|io|Fkc%Faz(bOJ@c3r_YIK{FciAZ-~M`TPSRL*b?jxwf*+#85{tZ3IJdLdfh-vP(n0B z$(!0f&vz#$@S>O2JVY;~OQxi1u2&<*;ITfv@T3!rr?YWlh#JGXzSazsvwg4mPnxOq z{X>US7GB|<6|bv^5N=}=Os$1t{Te&<=td8DB5#-Aitl&Mn>V*MmbLz-nUtLUE<%=v zj`GShS2;xITX{fCOHT^Qv$J}4G5Hqxrm^8|X|EYp7|T%ED7aiI-4~nklZ-ZZ*Sfc@ z?a5x!^ZXCuiy^oW3m2Wcm6>}Y(knxn)6M4F1@&#*n}n9Tp3SBUj5xgC1CNW}C-|CK z>laq0mw9k(*~9P5IHr6VR-b)>6wz@lD(ju{JDq%wY#jj; zQctoAyl*N3bC58LqocaiK3I?_8P#tc&)`t!+9q2^^Sr0ctw}{FC0Zg*;7ym%w6YZ@ zD^Z!BFuBqn+j#DX3DV+Wnq@8`x|zQi{d@pg^{wcW;?h&Ac$KAr70eGQ-}PnbrJ8mL zt3U-gi3P#(L{*k+C$PQB+iIRHEqn-n3M1ChqBp^RgsD-=lznmCxEo@SRa<(0+I@op)PK)&Eb z+EM_KX?R0Fv?N2trahQ2_K&UlSP)%L$4qHQep`*|a?0t0Ub*h5tR~w?V0->cp@ngw zgjGml3d!dYEeJoSTP$N>9Xc4A*_2!QS6mso`A^xH45dKtzJk^X;y|INgT;9_Yh5SQ zS=BFv+Y&D~?(%juUB2Q@7ECcxke8pzjpc<$qgI~qT*EefiBA`BZtYhM3&rY>Og6Kq zuVR_mC6C64>#jHT&u3|PS#~?$njZ07HdRQU9#1-l^E&!>?E<7{6t^S38{UnYohXtP z7;<2~&aG$ezrca5U-YfqR*yj0`dGKkQ%jH+CY&79~ITmFH-e?nwyy zIIueN_8H&73w|@2+5$2wXNW-$_lNnsXBD(97R7u#4b6*pll9@HuDma1$W#bPej;Jo zaA&`}{o`ghNB1+hDRSig8=o~fv9`~nN-hDyeOq=xK>qd*Kp38-jGz4d-~Y)WOm=%GS;P?zD4s$tt%>)m#^oPHHWdn7gGC5 zTP?<-2&bKxX$7KrsH(HHgu(Yi;k_3E8yYn6^i;;@8oA?>4dSaNtulK>&ntN^Bt*D% zk5be6Fi}YYQF`Pf8#B?z*IE8i7g)4g?WNB?K0YNCseRfe=KNgL&Pcm_)vSsML{nr% z{&!TrGQFm4+s1NL%^z9{D~$y&<{a|HKX+}%9jipj<}yS|XxDW;H+n@{gd4ktrtA8R zkM4PVd_Xz&>NbAlfd#;3tQ=pKChEjopLJQ;82H1(3sLivtTO`j=&{Xws-HqTds%J? zNqlLnAtmxo>gUH0QqBVxB%yUB@6?Ghig6v!+T3vEiI+7}%0F;Qq@*C9HL2emGtH$f zD8@8LEwSNj$a14o)M}6X$Vwq<$^KMoc;PZ85>8=|1{=dYn~P6G46n(-#x-=aW?(fu`& z)4@Sq8K}BCMG+%U30D{Anw+2~13iP@l*p1I1NOPO!U4$EVjL5xY+VDWqimq1wYEyY zYFn29fS-Nu0005M;u@fbe|Ljld~mEDC&)lCm`z!_J}*UB8^N{29a>PoJv*fn`zUWI zLH{>a*o0#Z&fe(&K$88+A^_kp764QMf}W0+wl4N)fPbK_HZ~v-AE2X+4