adding plugin interface instantiation

This commit is contained in:
Evgeny Vakhteev
2023-04-17 17:13:53 -07:00
parent 7f4e38844f
commit 9ed5e0f1fc
3 changed files with 241 additions and 8 deletions

View File

@@ -1,21 +1,218 @@
"""Handles loading of plugins."""
import importlib
import json
import mimetypes
import os
import zipfile
from pathlib import Path
from urllib.parse import urlparse
from zipimport import zipimporter
from typing import List, Optional, Tuple
import openapi_python_client
import requests
from abstract_singleton import AbstractSingleton
from openapi_python_client.cli import _process_config, Config as OpenAPIConfig
import abc
from pathlib import Path
from typing import TypeVar
from urllib.parse import urlparse
from zipimport import zipimporter
from openapi_python_client.cli import Config as OpenAPIConfig
from typing import Any, Dict, List, Optional, Tuple, TypedDict
from abstract_singleton import AbstractSingleton, Singleton
from autogpt.config import Config
PromptGenerator = TypeVar("PromptGenerator")
class Message(TypedDict):
role: str
content: str
class BaseOpenAIPlugin():
"""
This is a template for Auto-GPT plugins.
"""
def __init__(self, manifests_specs_clients: dict):
# super().__init__()
self._name = manifests_specs_clients["manifest"]["name_for_model"]
self._version = manifests_specs_clients["manifest"]["schema_version"]
self._description = manifests_specs_clients["manifest"]["description_for_model"]
self.client = manifests_specs_clients["client"]
self.manifest = manifests_specs_clients["manifest"]
self.openapi_spec = manifests_specs_clients["openapi_spec"]
def can_handle_on_response(self) -> bool:
"""This method is called to check that the plugin can
handle the on_response method.
Returns:
bool: True if the plugin can handle the on_response method."""
return False
def on_response(self, response: str, *args, **kwargs) -> str:
"""This method is called when a response is received from the model."""
pass
def can_handle_post_prompt(self) -> bool:
"""This method is called to check that the plugin can
handle the post_prompt method.
Returns:
bool: True if the plugin can handle the post_prompt method."""
return False
def post_prompt(self, prompt: PromptGenerator) -> PromptGenerator:
"""This method is called just after the generate_prompt is called,
but actually before the prompt is generated.
Args:
prompt (PromptGenerator): The prompt generator.
Returns:
PromptGenerator: The prompt generator.
"""
pass
def can_handle_on_planning(self) -> bool:
"""This method is called to check that the plugin can
handle the on_planning method.
Returns:
bool: True if the plugin can handle the on_planning method."""
return False
def on_planning(
self, prompt: PromptGenerator, messages: List[Message]
) -> Optional[str]:
"""This method is called before the planning chat completion is done.
Args:
prompt (PromptGenerator): The prompt generator.
messages (List[str]): The list of messages.
"""
pass
def can_handle_post_planning(self) -> bool:
"""This method is called to check that the plugin can
handle the post_planning method.
Returns:
bool: True if the plugin can handle the post_planning method."""
return False
def post_planning(self, response: str) -> str:
"""This method is called after the planning chat completion is done.
Args:
response (str): The response.
Returns:
str: The resulting response.
"""
pass
def can_handle_pre_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the pre_instruction method.
Returns:
bool: True if the plugin can handle the pre_instruction method."""
return False
def pre_instruction(self, messages: List[Message]) -> List[Message]:
"""This method is called before the instruction chat is done.
Args:
messages (List[Message]): The list of context messages.
Returns:
List[Message]: The resulting list of messages.
"""
pass
def can_handle_on_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the on_instruction method.
Returns:
bool: True if the plugin can handle the on_instruction method."""
return False
def on_instruction(self, messages: List[Message]) -> Optional[str]:
"""This method is called when the instruction chat is done.
Args:
messages (List[Message]): The list of context messages.
Returns:
Optional[str]: The resulting message.
"""
pass
def can_handle_post_instruction(self) -> bool:
"""This method is called to check that the plugin can
handle the post_instruction method.
Returns:
bool: True if the plugin can handle the post_instruction method."""
return False
def post_instruction(self, response: str) -> str:
"""This method is called after the instruction chat is done.
Args:
response (str): The response.
Returns:
str: The resulting response.
"""
pass
def can_handle_pre_command(self) -> bool:
"""This method is called to check that the plugin can
handle the pre_command method.
Returns:
bool: True if the plugin can handle the pre_command method."""
return False
def pre_command(
self, command_name: str, arguments: Dict[str, Any]
) -> Tuple[str, Dict[str, Any]]:
"""This method is called before the command is executed.
Args:
command_name (str): The command name.
arguments (Dict[str, Any]): The arguments.
Returns:
Tuple[str, Dict[str, Any]]: The command name and the arguments.
"""
pass
def can_handle_post_command(self) -> bool:
"""This method is called to check that the plugin can
handle the post_command method.
Returns:
bool: True if the plugin can handle the post_command method."""
return False
def post_command(self, command_name: str, response: str) -> str:
"""This method is called after the command is executed.
Args:
command_name (str): The command name.
response (str): The response.
Returns:
str: The resulting response.
"""
pass
def can_handle_chat_completion(
self, messages: Dict[Any, Any], model: str, temperature: float, max_tokens: int
) -> bool:
"""This method is called to check that the plugin can
handle the chat_completion method.
Args:
messages (List[Message]): The messages.
model (str): The model name.
temperature (float): The temperature.
max_tokens (int): The max tokens.
Returns:
bool: True if the plugin can handle the chat_completion method."""
return False
def handle_chat_completion(
self, messages: List[Message], model: str, temperature: float, max_tokens: int
) -> str:
"""This method is called when the chat completion is done.
Args:
messages (List[Message]): The messages.
model (str): The model name.
temperature (float): The temperature.
max_tokens (int): The max tokens.
Returns:
str: The resulting response.
"""
pass
def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]:
"""
@@ -37,6 +234,8 @@ def inspect_zip_for_module(zip_path: str, debug: bool = False) -> Optional[str]:
if debug:
print(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
return None
def write_dict_to_json_file(data: dict, file_path: str):
"""
Write a dictionary to a JSON file.
@@ -156,6 +355,29 @@ def initialize_openai_plugins(manifests_specs: dict, cfg: Config, debug: bool =
return manifests_specs
def instantiate_openai_plugin_clients(manifests_specs_clients: dict, cfg: Config, debug: bool = False) -> dict:
"""
Instantiates BaseOpenAIPluginClient instances for each OpenAI plugin.
Args:
manifests_specs_clients (dict): per url dictionary of manifest, spec and client.
cfg (Config): Config instance including plugins config
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
plugins (dict): per url dictionary of BaseOpenAIPluginClient instances.
"""
plugins = {}
for url, manifest_spec_client in manifests_specs_clients.items():
plugins[url] = BaseOpenAIPluginClient(
manifest=manifest_spec_client['manifest'],
openapi_spec=manifest_spec_client['openapi_spec'],
client=manifest_spec_client['client'],
cfg=cfg,
debug=debug
)
return plugins
def scan_plugins(cfg: Config, debug: bool = False) -> List[Tuple[str, Path]]:
"""Scan the plugins directory for plugins.
@@ -177,6 +399,11 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[Tuple[str, Path]]:
manifests_specs = fetch_openai_plugins_manifest_and_spec(cfg)
if manifests_specs.keys():
manifests_specs_clients = initialize_openai_plugins(manifests_specs, cfg, debug)
for url, openai_plugin_meta in manifests_specs_clients.items():
plugin = BaseOpenAIPlugin(openai_plugin_meta)
plugins.append((plugin, url))
return plugins