Fix Config type hint problems caused by #4803 (#4840)

Co-authored-by: Luke <2609441+lc0rp@users.noreply.github.com>
This commit is contained in:
Reinier van der Leer
2023-06-30 14:15:00 +02:00
committed by GitHub
parent 975094fcdd
commit 5070cc32ac
16 changed files with 109 additions and 64 deletions

View File

@@ -2,6 +2,7 @@ import json
import signal
import sys
from datetime import datetime
from pathlib import Path
from colorama import Fore, Style
@@ -64,7 +65,7 @@ class Agent:
ai_config: AIConfig,
system_prompt: str,
triggering_prompt: str,
workspace_directory: str,
workspace_directory: str | Path,
config: Config,
):
self.ai_name = ai_name

View File

@@ -1,11 +1,12 @@
"""
This module contains the configuration classes for AutoGPT.
"""
from autogpt.config.ai_config import AIConfig
from autogpt.config.config import Config, check_openai_api_key
from .ai_config import AIConfig
from .config import Config, ConfigBuilder, check_openai_api_key
__all__ = [
"check_openai_api_key",
"AIConfig",
"Config",
"ConfigBuilder",
]

View File

@@ -1,4 +1,6 @@
"""Configuration class to store the state of bools for different scripts access."""
from __future__ import annotations
import contextlib
import os
import re
@@ -8,21 +10,22 @@ import yaml
from colorama import Fore
from autogpt.core.configuration.schema import Configurable, SystemSettings
from autogpt.plugins.plugins_config import PluginsConfig
AZURE_CONFIG_FILE = os.path.join(os.path.dirname(__file__), "../..", "azure.yaml")
from typing import Optional
class ConfigSettings(SystemSettings):
class Config(SystemSettings):
fast_llm_model: str
smart_llm_model: str
continuous_mode: bool
skip_news: bool
workspace_path: Optional[str]
file_logger_path: Optional[str]
workspace_path: Optional[str] = None
file_logger_path: Optional[str] = None
debug_mode: bool
plugins_dir: str
plugins_config: dict[str, str]
plugins_config: PluginsConfig
continuous_limit: int
speak_mode: bool
skip_reprompt: bool
@@ -37,31 +40,31 @@ class ConfigSettings(SystemSettings):
prompt_settings_file: str
embedding_model: str
browse_spacy_language_model: str
openai_api_key: Optional[str]
openai_organization: Optional[str]
openai_api_key: Optional[str] = None
openai_organization: Optional[str] = None
temperature: float
use_azure: bool
execute_local_commands: bool
restrict_to_workspace: bool
openai_api_type: Optional[str]
openai_api_base: Optional[str]
openai_api_version: Optional[str]
openai_api_type: Optional[str] = None
openai_api_base: Optional[str] = None
openai_api_version: Optional[str] = None
openai_functions: bool
elevenlabs_api_key: Optional[str]
elevenlabs_api_key: Optional[str] = None
streamelements_voice: str
text_to_speech_provider: str
github_api_key: Optional[str]
github_username: Optional[str]
google_api_key: Optional[str]
google_custom_search_engine_id: Optional[str]
image_provider: Optional[str]
github_api_key: Optional[str] = None
github_username: Optional[str] = None
google_api_key: Optional[str] = None
google_custom_search_engine_id: Optional[str] = None
image_provider: Optional[str] = None
image_size: int
huggingface_api_token: Optional[str]
huggingface_api_token: Optional[str] = None
huggingface_image_model: str
audio_to_text_provider: str
huggingface_audio_to_text_model: Optional[str]
sd_webui_url: Optional[str]
sd_webui_auth: Optional[str]
huggingface_audio_to_text_model: Optional[str] = None
sd_webui_url: Optional[str] = None
sd_webui_auth: Optional[str] = None
selenium_web_browser: str
selenium_headless: bool
user_agent: str
@@ -76,12 +79,17 @@ class ConfigSettings(SystemSettings):
plugins_openai: list[str]
plugins_config_file: str
chat_messages_enabled: bool
elevenlabs_voice_id: Optional[str]
elevenlabs_voice_id: Optional[str] = None
plugins: list[str]
authorise_key: str
# Executed immediately after init by Pydantic
def model_post_init(self, **kwargs) -> None:
if not self.plugins_config.plugins:
self.plugins_config = PluginsConfig.load_config(self)
class Config(Configurable):
class ConfigBuilder(Configurable[Config]):
default_plugins_config_file = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "..", "plugins_config.yaml"
)
@@ -96,7 +104,7 @@ class Config(Configurable):
else:
default_tts_provider = "gtts"
defaults_settings = ConfigSettings(
defaults_settings = Config(
name="Default Server Config",
description="This is a default server configuration",
smart_llm_model="gpt-3.5-turbo",
@@ -106,7 +114,7 @@ class Config(Configurable):
skip_news=False,
debug_mode=False,
plugins_dir="plugins",
plugins_config={},
plugins_config=PluginsConfig({}),
speak_mode=False,
skip_reprompt=False,
allow_downloads=False,

View File

@@ -1,7 +1,7 @@
import abc
import copy
import typing
from typing import Any
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
@@ -22,22 +22,26 @@ class SystemSettings(BaseModel, abc.ABC):
description: typing.Optional[str]
class Config:
arbitrary_types_allowed = True
extra = "forbid"
use_enum_values = True
class Configurable(abc.ABC):
S = TypeVar("S", bound=SystemSettings)
class Configurable(abc.ABC, Generic[S]):
"""A base class for all configurable objects."""
prefix: str = ""
defaults_settings: typing.ClassVar[SystemSettings]
defaults_settings: typing.ClassVar[S]
@classmethod
def get_user_config(cls) -> dict[str, Any]:
return _get_user_config_fields(cls.defaults_settings)
@classmethod
def build_agent_configuration(cls, configuration: dict = {}) -> SystemSettings:
def build_agent_configuration(cls, configuration: dict = {}) -> S:
"""Process the configuration for this object."""
defaults_settings = cls.defaults_settings.dict()

View File

@@ -1,15 +1,19 @@
"""Logging module for Auto-GPT."""
from __future__ import annotations
import logging
import os
import random
import re
import time
from logging import LogRecord
from typing import Any
from typing import TYPE_CHECKING, Any
from colorama import Fore, Style
from autogpt.config import Config
if TYPE_CHECKING:
from autogpt.config import Config
from autogpt.log_cycle.json_handler import JsonFileHandler, JsonFormatter
from autogpt.singleton import Singleton
from autogpt.speech import say_text

View File

@@ -1,11 +1,12 @@
"""The application entry point. Can be invoked by a CLI or any other front end application."""
import logging
import sys
from pathlib import Path
from colorama import Fore, Style
from autogpt.agent import Agent
from autogpt.config.config import Config, check_openai_api_key
from autogpt.config.config import ConfigBuilder, check_openai_api_key
from autogpt.configurator import create_config
from autogpt.logs import logger
from autogpt.memory.vector import get_memory
@@ -45,14 +46,14 @@ def run_auto_gpt(
browser_name: str,
allow_downloads: bool,
skip_news: bool,
workspace_directory: str,
workspace_directory: str | Path,
install_plugin_deps: bool,
):
# Configure logging before we do anything else.
logger.set_level(logging.DEBUG if debug else logging.INFO)
logger.speak_mode = speak
config = Config.build_config_from_env()
config = ConfigBuilder.build_config_from_env()
# TODO: fill in llm values here
check_openai_api_key(config)

View File

@@ -1,4 +1,5 @@
"""Handles loading of plugins."""
from __future__ import annotations
import importlib.util
import inspect
@@ -7,7 +8,7 @@ import os
import sys
import zipfile
from pathlib import Path
from typing import List
from typing import TYPE_CHECKING, List
from urllib.parse import urlparse
from zipimport import zipimporter
@@ -16,7 +17,9 @@ import requests
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from openapi_python_client.config import Config as OpenAPIConfig
from autogpt.config.config import Config
if TYPE_CHECKING:
from autogpt.config import Config
from autogpt.logs import logger
from autogpt.models.base_open_ai_plugin import BaseOpenAIPlugin

View File

@@ -1,9 +1,13 @@
from __future__ import annotations
import os
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union
import yaml
from autogpt.config.config import Config
if TYPE_CHECKING:
from autogpt.config import Config
from autogpt.logs import logger
from autogpt.plugins.plugin_config import PluginConfig
@@ -11,6 +15,8 @@ from autogpt.plugins.plugin_config import PluginConfig
class PluginsConfig:
"""Class for holding configuration of all plugins"""
plugins: dict[str, PluginConfig]
def __init__(self, plugins_config: dict[str, Any]):
self.plugins = {}
for name, plugin in plugins_config.items():
@@ -33,7 +39,7 @@ class PluginsConfig:
def is_enabled(self, name) -> bool:
plugin_config = self.plugins.get(name)
return plugin_config and plugin_config.enabled
return plugin_config is not None and plugin_config.enabled
@classmethod
def load_config(cls, global_config: Config) -> "PluginsConfig":

View File

@@ -1,9 +1,14 @@
"""Base class for all voice classes."""
from __future__ import annotations
import abc
import re
from threading import Lock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from autogpt.config import Config
from autogpt.config import Config
from autogpt.singleton import AbstractSingleton

View File

@@ -1,11 +1,15 @@
"""ElevenLabs speech module"""
from __future__ import annotations
import os
from typing import TYPE_CHECKING
import requests
from playsound import playsound
from autogpt.config.config import Config
from autogpt.speech.base import VoiceBase
if TYPE_CHECKING:
from autogpt.config import Config
from .base import VoiceBase
PLACEHOLDERS = {"your-voice-id"}

View File

@@ -1,13 +1,18 @@
""" Text to speech module """
from __future__ import annotations
import threading
from threading import Semaphore
from typing import TYPE_CHECKING
from autogpt.config.config import Config
from autogpt.speech.base import VoiceBase
from autogpt.speech.eleven_labs import ElevenLabsSpeech
from autogpt.speech.gtts import GTTSVoice
from autogpt.speech.macos_tts import MacOSTTS
from autogpt.speech.stream_elements_speech import StreamElementsSpeech
if TYPE_CHECKING:
from autogpt.config import Config
from .base import VoiceBase
from .eleven_labs import ElevenLabsSpeech
from .gtts import GTTSVoice
from .macos_tts import MacOSTTS
from .stream_elements_speech import StreamElementsSpeech
_QUEUE_SEMAPHORE = Semaphore(
1

View File

@@ -10,6 +10,7 @@ agent.
from __future__ import annotations
from pathlib import Path
from typing import Optional
from autogpt.config import Config
from autogpt.logs import logger
@@ -77,7 +78,7 @@ class Workspace:
@staticmethod
def _sanitize_path(
relative_path: str | Path,
root: str | Path = None,
root: Optional[str | Path] = None,
restrict_to_root: bool = True,
) -> Path:
"""Resolve the relative path within the given root if possible.
@@ -139,7 +140,7 @@ class Workspace:
return full_path
@staticmethod
def build_file_logger_path(config, workspace_directory):
def build_file_logger_path(config: Config, workspace_directory: Path):
file_logger_path = workspace_directory / "file_logger.txt"
if not file_logger_path.exists():
with file_logger_path.open(mode="w", encoding="utf-8") as f:
@@ -147,10 +148,12 @@ class Workspace:
config.file_logger_path = str(file_logger_path)
@staticmethod
def get_workspace_directory(config: Config, workspace_directory: str = None):
def get_workspace_directory(
config: Config, workspace_directory: Optional[str | Path] = None
):
if workspace_directory is None:
workspace_directory = Path(__file__).parent / "auto_gpt_workspace"
else:
elif type(workspace_directory) == str:
workspace_directory = Path(workspace_directory)
# TODO: pass in the ai_settings file and the env file and have them cloned into
# the workspace directory so we can bind them to the agent.

View File

@@ -1,5 +1,5 @@
from autogpt.agent import Agent
from autogpt.config import AIConfig, Config
from autogpt.config import AIConfig, Config, ConfigBuilder
from autogpt.main import COMMAND_CATEGORIES
from autogpt.memory.vector import get_memory
from autogpt.models.command_registry import CommandRegistry
@@ -13,7 +13,7 @@ def run_task(task) -> None:
def bootstrap_agent(task):
config = Config.build_config_from_env()
config = ConfigBuilder.build_config_from_env()
config.continuous_mode = False
config.temperature = 0
config.plain_output = True
@@ -42,7 +42,7 @@ def bootstrap_agent(task):
)
def get_command_registry(config):
def get_command_registry(config: Config):
command_registry = CommandRegistry()
enabled_command_categories = [
x for x in COMMAND_CATEGORIES if x not in config.disabled_command_categories

View File

@@ -2,10 +2,10 @@ import argparse
import logging
from autogpt.commands.file_operations import ingest_file, list_files
from autogpt.config import Config
from autogpt.config import ConfigBuilder
from autogpt.memory.vector import VectorMemory, get_memory
config = Config.build_config_from_env()
config = ConfigBuilder.build_config_from_env()
def configure_logging():

View File

@@ -7,8 +7,8 @@ import yaml
from pytest_mock import MockerFixture
from autogpt.agent.agent import Agent
from autogpt.config import AIConfig, Config, ConfigBuilder
from autogpt.config.ai_config import AIConfig
from autogpt.config.config import Config
from autogpt.llm.api_manager import ApiManager
from autogpt.logs import TypingConsoleHandler
from autogpt.memory.vector import get_memory
@@ -49,7 +49,7 @@ def temp_plugins_config_file():
def config(
temp_plugins_config_file: str, mocker: MockerFixture, workspace: Workspace
) -> Config:
config = Config.build_config_from_env()
config = ConfigBuilder.build_config_from_env()
if not os.environ.get("OPENAI_API_KEY"):
os.environ["OPENAI_API_KEY"] = "sk-dummy"

View File

@@ -7,7 +7,7 @@ from unittest.mock import patch
import pytest
from autogpt.config import Config
from autogpt.config import Config, ConfigBuilder
from autogpt.configurator import GPT_3_MODEL, GPT_4_MODEL, create_config
from autogpt.workspace.workspace import Workspace
@@ -131,13 +131,13 @@ def test_smart_and_fast_llm_models_set_to_gpt4(mock_list_models, config: Config)
config.smart_llm_model = smart_llm_model
def test_missing_azure_config(config: Config, workspace: Workspace):
def test_missing_azure_config(workspace: Workspace):
config_file = workspace.get_path("azure_config.yaml")
with pytest.raises(FileNotFoundError):
Config.load_azure_config(str(config_file))
ConfigBuilder.load_azure_config(str(config_file))
config_file.write_text("")
azure_config = Config.load_azure_config(str(config_file))
azure_config = ConfigBuilder.load_azure_config(str(config_file))
assert azure_config["openai_api_type"] == "azure"
assert azure_config["openai_api_base"] == ""