AutoGPT: Convert dataclasses to Pydantic models

Author: Reinier van der Leer
Date:   2023-09-18 19:07:37 +02:00
parent  d8f1d34345
commit  6b22abd526

15 changed files with 165 additions and 156 deletions

View File

@@ -147,7 +147,7 @@ class Agent(ContextMixin, WorkspaceMixin, WatchdogMixin, BaseAgent):
         result: ActionResult

         if command_name == "human_feedback":
-            result = ActionInterruptedByHuman(user_input)
+            result = ActionInterruptedByHuman(feedback=user_input)
             self.message_history.add(
                 "user",
                 "I interrupted the execution of the command you proposed "
@@ -185,9 +185,9 @@ class Agent(ContextMixin, WorkspaceMixin, WatchdogMixin, BaseAgent):
                     )
                     self.context.add(context_item)

-                result = ActionSuccessResult(return_value)
+                result = ActionSuccessResult(outputs=return_value)
             except AgentException as e:
-                result = ActionErrorResult(e.message, e)
+                result = ActionErrorResult(reason=e.message, error=e)

         result_tlength = count_string_tokens(str(result), self.llm.name)
         history_tlength = count_string_tokens(
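
Note: Pydantic models are keyword-only (BaseModel.__init__ accepts **data), so the
positional construction that dataclasses allowed no longer works; that is why every
result-type call site above gains explicit field names. A minimal runnable sketch,
reusing the ActionSuccessResult definition from this commit:

    from typing import Any, Literal

    from pydantic import BaseModel

    class ActionSuccessResult(BaseModel):
        outputs: Any
        status: Literal["success"] = "success"

    ActionSuccessResult(outputs="done")  # OK: fields passed by name
    # ActionSuccessResult("done")       # TypeError: positional args are rejected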

View File

@@ -16,7 +16,7 @@ from autogpt.llm.base import ChatSequence, Message
 from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS, get_openai_command_specs
 from autogpt.llm.utils import count_message_tokens, create_chat_completion
 from autogpt.memory.message_history import MessageHistory
-from autogpt.models.agent_actions import ActionHistory, ActionResult
+from autogpt.models.agent_actions import EpisodicActionHistory, ActionResult
 from autogpt.prompts.generator import PromptGenerator
 from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
@@ -90,10 +90,10 @@ class BaseAgent(metaclass=ABCMeta):
                 defaults to 75% of `llm.max_tokens`.
         """

-        self.event_history = ActionHistory()
+        self.event_history = EpisodicActionHistory()

         self.message_history = MessageHistory(
-            self.llm,
+            model=self.llm,
             max_summary_tlength=summary_max_tlength or self.send_token_limit // 6,
         )

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from contextlib import ExitStack

-from autogpt.models.agent_actions import ActionHistory
+from autogpt.models.agent_actions import EpisodicActionHistory

 from ..base import BaseAgent
@@ -16,7 +16,7 @@ class WatchdogMixin:
     looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
     """

-    event_history: ActionHistory
+    event_history: EpisodicActionHistory

     def __init__(self, **kwargs) -> None:
         # Initialize other bases first, because we need the event_history from BaseAgent
@@ -38,7 +38,7 @@ class WatchdogMixin:
                 and self.config.fast_llm != self.config.smart_llm
             ):
                 # Detect repetitive commands
-                previous_cycle = self.event_history.cycles[self.event_history.cursor - 1]
+                previous_cycle = self.event_history.episodes[self.event_history.cursor - 1]
                 if (
                     command_name == previous_cycle.action.name
                     and command_args == previous_cycle.action.args

View File

@@ -23,7 +23,7 @@ from autogpt.logs.log_cycle import (
 )
 from autogpt.models.agent_actions import (
     ActionErrorResult,
-    ActionHistory,
+    EpisodicActionHistory,
     ActionInterruptedByHuman,
     ActionResult,
     ActionSuccessResult,
@@ -69,7 +69,7 @@ class PlanningAgent(ContextMixin, WorkspaceMixin, BaseAgent):
         self.log_cycle_handler = LogCycleHandler()
         """LogCycleHandler for structured debug logging."""

-        self.action_history = ActionHistory()
+        self.action_history = EpisodicActionHistory()

         self.plan: list[str] = []
         """List of steps that the Agent plans to take"""
@@ -229,7 +229,7 @@ class PlanningAgent(ContextMixin, WorkspaceMixin, BaseAgent):
             self.ai_config.ai_name,
             self.created_at,
             self.cycle_count,
-            self.action_history.cycles,
+            self.action_history.episodes,
             "action_history.json",
         )
         self.log_cycle_handler.log_cycle(
@@ -250,7 +250,7 @@ class PlanningAgent(ContextMixin, WorkspaceMixin, BaseAgent):
         result: ActionResult

         if command_name == "human_feedback":
-            result = ActionInterruptedByHuman(user_input)
+            result = ActionInterruptedByHuman(feedback=user_input)
             self.log_cycle_handler.log_cycle(
                 self.ai_config.ai_name,
                 self.created_at,
@@ -279,9 +279,9 @@ class PlanningAgent(ContextMixin, WorkspaceMixin, BaseAgent):
                         self.context.add(return_value[1])
                         return_value = return_value[0]

-                result = ActionSuccessResult(return_value)
+                result = ActionSuccessResult(outputs=return_value)
             except AgentException as e:
-                result = ActionErrorResult(e.message, e)
+                result = ActionErrorResult(reason=e.message, error=e)

         result_tlength = count_string_tokens(str(result), self.llm.name)
         memory_tlength = count_string_tokens(

View File

@@ -67,7 +67,10 @@ def open_file(file_path: Path, agent: Agent) -> tuple[str, FileContextItem]:
         file_path = relative_file_path or file_path

-    file = FileContextItem(file_path, agent.workspace.root)
+    file = FileContextItem(
+        file_path_in_workspace=file_path,
+        workspace_path=agent.workspace.root,
+    )
     if file in agent_context:
         raise DuplicateOperationError(f"The file {file_path} is already open")
@@ -114,7 +117,10 @@ def open_folder(path: Path, agent: Agent) -> tuple[str, FolderContextItem]:
         path = relative_path or path

-    folder = FolderContextItem(path, agent.workspace.root)
+    folder = FolderContextItem(
+        path_in_workspace=path,
+        workspace_path=agent.workspace.root,
+    )
     if folder in agent_context:
         raise DuplicateOperationError(f"The folder {path} is already open")

View File

@@ -1,14 +1,13 @@
 """A module that contains the AIConfig class object that contains the configuration"""
 from __future__ import annotations

-from dataclasses import dataclass, field
 from pathlib import Path

+from pydantic import BaseModel, Field
 import yaml


-@dataclass
-class AIConfig:
+class AIConfig(BaseModel):
     """
     A class object that contains the configuration information for the AI
@@ -21,7 +20,7 @@ class AIConfig:
     ai_name: str = ""
     ai_role: str = ""
-    ai_goals: list[str] = field(default_factory=list[str])
+    ai_goals: list[str] = Field(default_factory=list[str])
     api_budget: float = 0.0

     @staticmethod
@@ -53,7 +52,12 @@ class AIConfig:
         ]
         api_budget = config_params.get("api_budget", 0.0)

-        return AIConfig(ai_name, ai_role, ai_goals, api_budget)
+        return AIConfig(
+            ai_name=ai_name,
+            ai_role=ai_role,
+            ai_goals=ai_goals,
+            api_budget=api_budget
+        )

     def save(self, ai_settings_file: str | Path) -> None:
         """
@@ -66,11 +70,5 @@ class AIConfig:
             None
         """

-        config = {
-            "ai_name": self.ai_name,
-            "ai_role": self.ai_role,
-            "ai_goals": self.ai_goals,
-            "api_budget": self.api_budget,
-        }
         with open(ai_settings_file, "w", encoding="utf-8") as file:
-            yaml.dump(config, file, allow_unicode=True)
+            yaml.dump(self.dict(), file, allow_unicode=True)
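
Note: two Pydantic idioms replace the dataclass machinery here. Field(default_factory=...)
mirrors dataclasses.field for per-instance mutable defaults, and .dict() (the Pydantic v1
serializer) supersedes the hand-built mapping in save(). A runnable sketch with
illustrative values:

    import yaml
    from pydantic import BaseModel, Field

    class AIConfig(BaseModel):
        ai_name: str = ""
        ai_role: str = ""
        # per-instance default list, the Pydantic counterpart of dataclasses.field
        ai_goals: list[str] = Field(default_factory=list)
        api_budget: float = 0.0

    config = AIConfig(ai_name="DemoGPT", ai_goals=["write a haiku"])
    print(yaml.dump(config.dict(), allow_unicode=True))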

View File

@@ -1,9 +1,9 @@
 from __future__ import annotations

 import logging
-from dataclasses import dataclass

 import yaml
+from pydantic import BaseModel

 from autogpt.logs.helpers import request_user_double_check
 from autogpt.utils import validate_yaml_file
@@ -11,8 +11,7 @@ from autogpt.utils import validate_yaml_file
 logger = logging.getLogger(__name__)


-@dataclass
-class AIDirectives:
+class AIDirectives(BaseModel):
     """An object that contains the basic directives for the AI prompt.

     Attributes:
View File

@@ -1,12 +1,9 @@
 from __future__ import annotations

 import json
 from copy import deepcopy
-from dataclasses import dataclass, field
 from math import ceil, floor
-from typing import TYPE_CHECKING, Literal, Optional, Type, TypedDict, TypeVar, overload
-
-if TYPE_CHECKING:
-    from autogpt.llm.providers.openai import OpenAIFunctionCall
+from pydantic import BaseModel, Field
+from typing import Any, Literal, Optional, Type, TypedDict, TypeVar, overload

 MessageRole = Literal["system", "user", "assistant", "function"]
 MessageType = Literal["ai_response", "action_result"]
@@ -31,20 +28,30 @@ class FunctionCallDict(TypedDict):
     arguments: str


-@dataclass
-class Message:
+class Message(BaseModel):
     """OpenAI Message object containing a role and the message content"""

     role: MessageRole
     content: str
-    type: MessageType | None = None
+    type: Optional[MessageType]
+
+    def __init__(
+        self,
+        role: MessageRole,
+        content: str,
+        type: Optional[MessageType] = None
+    ):
+        super().__init__(
+            role=role,
+            content=content,
+            type=type,
+        )

     def raw(self) -> MessageDict:
         return {"role": self.role, "content": self.content}
-@dataclass
-class ModelInfo:
+class ModelInfo(BaseModel):
     """Struct for model information.

     Would be lovely to eventually get this directly from APIs, but needs to be scraped from
@@ -56,26 +63,22 @@ class ModelInfo:
     prompt_token_cost: float


-@dataclass
 class CompletionModelInfo(ModelInfo):
     """Struct for generic completion model information."""

     completion_token_cost: float


-@dataclass
 class ChatModelInfo(CompletionModelInfo):
     """Struct for chat model information."""

     supports_functions: bool = False


-@dataclass
 class TextModelInfo(CompletionModelInfo):
     """Struct for text completion model information."""


-@dataclass
 class EmbeddingModelInfo(ModelInfo):
     """Struct for embedding model information."""
@@ -86,12 +89,11 @@ class EmbeddingModelInfo(ModelInfo):
 TChatSequence = TypeVar("TChatSequence", bound="ChatSequence")


-@dataclass
-class ChatSequence:
+class ChatSequence(BaseModel):
     """Utility container for a chat sequence"""

     model: ChatModelInfo
-    messages: list[Message] = field(default_factory=list[Message])
+    messages: list[Message] = Field(default_factory=list[Message])

     @overload
     def __getitem__(self, key: int) -> Message:
@@ -103,7 +105,7 @@ class ChatSequence:
     def __getitem__(self: TChatSequence, key: int | slice) -> Message | TChatSequence:
         if isinstance(key, slice):
-            copy = deepcopy(self)
+            copy = self.copy(deep=True)
             copy.messages = self.messages[key]
             return copy
         return self.messages[key]
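
Note: Pydantic models ship their own clone method, so slicing now uses the v1 API
.copy(deep=True) in place of copy.deepcopy. Sketch of the behaviour, with the element
type simplified to str:

    from pydantic import BaseModel, Field

    class ChatSequence(BaseModel):
        messages: list[str] = Field(default_factory=list)

    seq = ChatSequence(messages=["a", "b", "c"])
    head = seq.copy(deep=True)  # model-aware deep clone
    head.messages = seq.messages[:2]
    assert head.messages == ["a", "b"] and seq.messages == ["a", "b", "c"]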
@@ -141,7 +143,7 @@ class ChatSequence:
     ) -> TChatSequence:
         from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS

-        if not model_name in OPEN_AI_CHAT_MODELS:
+        if model_name not in OPEN_AI_CHAT_MODELS:
             raise ValueError(f"Unknown chat model '{model_name}'")

         return cls(
@@ -175,23 +177,43 @@ Length: {self.token_length} tokens; {len(self.messages)} messages
     """


-@dataclass
-class LLMResponse:
+class LLMResponse(BaseModel):
     """Standard response struct for a response from an LLM model."""

     model_info: ModelInfo


-@dataclass
 class EmbeddingModelResponse(LLMResponse):
     """Standard response struct for a response from an embedding model."""

-    embedding: list[float] = field(default_factory=list)
+    embedding: list[float] = Field(default_factory=list)


-@dataclass
 class ChatModelResponse(LLMResponse):
     """Standard response struct for a response from a chat LLM."""

     content: Optional[str]
-    function_call: Optional[OpenAIFunctionCall]
+    function_call: Optional[LLMFunctionCall]
+
+
+class LLMFunctionCall(BaseModel):
+    """Represents a function call as generated by an OpenAI model
+
+    Attributes:
+        name: the name of the function that the LLM wants to call
+        arguments: a stringified JSON object (unverified) containing `arg: value` pairs
+    """
+
+    name: str
+    arguments: dict[str, Any] = {}
+
+    @staticmethod
+    def parse(raw: FunctionCallDict):
+        return LLMFunctionCall(
+            name=raw["name"],
+            arguments=json.loads(raw["arguments"]),
+        )
+
+
+# Complete model initialization; necessary because of order of definition
+ChatModelResponse.update_forward_refs()
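
Note: ChatModelResponse references LLMFunctionCall before that class is defined (and
the module-level `from __future__ import annotations` turns every annotation into a
string), so the model has to be finalized with update_forward_refs() once the name
exists. Reduced sketch of the pattern:

    from typing import Optional

    from pydantic import BaseModel

    class ChatModelResponse(BaseModel):
        content: Optional[str]
        function_call: Optional["LLMFunctionCall"]  # forward reference by name

    class LLMFunctionCall(BaseModel):
        name: str

    # resolve the string annotation now that LLMFunctionCall exists;
    # instantiating before this call would raise a ConfigError
    ChatModelResponse.update_forward_refs()

    resp = ChatModelResponse(content=None, function_call=LLMFunctionCall(name="open_file"))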

View File

@@ -4,7 +4,7 @@ import functools
 import logging
 import time
 from dataclasses import dataclass
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, TypeVar
 from unittest.mock import patch

 import openai
@@ -118,8 +118,10 @@ OPEN_AI_MODELS: dict[str, ChatModelInfo | EmbeddingModelInfo | TextModelInfo] =
     **OPEN_AI_EMBEDDING_MODELS,
 }

+T = TypeVar("T", bound=Callable)
+

-def meter_api(func: Callable):
+def meter_api(func: T) -> T:
     """Adds ApiManager metering to functions which make OpenAI API calls"""
     from autogpt.llm.api_manager import ApiManager
@@ -145,6 +147,7 @@ def meter_api(func: Callable):
         update_usage_with_response(openai_obj)
         return openai_obj

+    @functools.wraps(func)
     def metered_func(*args, **kwargs):
         with patch.object(
             engine_api_resource.util,
@@ -179,7 +182,7 @@ def retry_api(
     )
     backoff_msg = "Waiting {backoff} seconds..."

-    def _wrapper(func: Callable):
+    def _wrapper(func: T) -> T:
         @functools.wraps(func)
         def _wrapped(*args, **kwargs):
             user_warned = not warn_user
@@ -286,19 +289,6 @@ def create_embedding(
     )


-@dataclass
-class OpenAIFunctionCall:
-    """Represents a function call as generated by an OpenAI model
-
-    Attributes:
-        name: the name of the function that the LLM wants to call
-        arguments: a stringified JSON object (unverified) containing `arg: value` pairs
-    """
-
-    name: str
-    arguments: str
-
-
 @dataclass
 class OpenAIFunctionSpec:
     """Represents a "function" in OpenAI, which is mapped to a Command in Auto-GPT"""

View File

@@ -1,27 +1,30 @@
 from __future__ import annotations

-from typing import List, Literal, Optional
+import logging
+from typing import Optional

 from colorama import Fore

 from autogpt.config import Config

 from ..api_manager import ApiManager
 from ..base import (
     ChatModelResponse,
     ChatSequence,
+    FunctionCallDict,
+    LLMFunctionCall,
     Message,
     ResponseMessageDict,
 )
 from ..providers import openai as iopenai
 from ..providers.openai import (
     OPEN_AI_CHAT_MODELS,
-    OpenAIFunctionCall,
     OpenAIFunctionSpec,
     count_openai_functions_tokens,
 )
-from .token_counter import *
+from .token_counter import count_message_tokens, count_string_tokens

 logger = logging.getLogger(__name__)


 def call_ai_function(
@@ -96,7 +99,7 @@ def create_text_completion(
 def create_chat_completion(
     prompt: ChatSequence,
     config: Config,
-    functions: Optional[List[OpenAIFunctionSpec]] = None,
+    functions: Optional[list[OpenAIFunctionSpec]] = None,
     model: Optional[str] = None,
     temperature: Optional[float] = None,
     max_tokens: Optional[int] = None,
@@ -182,9 +185,5 @@ def create_chat_completion(
     return ChatModelResponse(
         model_info=OPEN_AI_CHAT_MODELS[model],
         content=content,
-        function_call=OpenAIFunctionCall(
-            name=function_call["name"], arguments=function_call["arguments"]
-        )
-        if function_call
-        else None,
+        function_call=LLMFunctionCall.parse(function_call) if function_call else None,
     )
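
Note: LLMFunctionCall.parse replaces the inline construction and converts the
`arguments` payload, which the OpenAI API returns as a JSON string, into a real dict
at the boundary. Usage sketch (json.loads will raise on malformed model output):

    import json
    from typing import Any

    from pydantic import BaseModel

    class LLMFunctionCall(BaseModel):
        name: str
        arguments: dict[str, Any] = {}

        @staticmethod
        def parse(raw: dict) -> "LLMFunctionCall":
            return LLMFunctionCall(
                name=raw["name"],
                arguments=json.loads(raw["arguments"]),  # stringified JSON -> dict
            )

    call = LLMFunctionCall.parse({"name": "open_file", "arguments": '{"path": "notes.md"}'})
    assert call.arguments == {"path": "notes.md"}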

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
 import copy
 import json
 import logging
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Iterator, Optional

 if TYPE_CHECKING:
@@ -23,7 +22,6 @@ from autogpt.logs import PROMPT_SUMMARY_FILE_NAME, SUMMARY_FILE_NAME, LogCycleHa
 logger = logging.getLogger(__name__)


-@dataclass
 class MessageHistory(ChatSequence):
     max_summary_tlength: int = 500
     agent: Optional[BaseAgent | Agent] = None
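
Note: with ChatSequence now a BaseModel, MessageHistory extends it by plain
inheritance: no decorator needed, and the subclass's annotated attributes become
additional validated fields. Reduced sketch:

    from pydantic import BaseModel, Field

    class ChatSequence(BaseModel):
        messages: list[str] = Field(default_factory=list)  # element type simplified

    class MessageHistory(ChatSequence):
        max_summary_tlength: int = 500

    history = MessageHistory(messages=["hello"], max_summary_tlength=200)
    assert history.max_summary_tlength == 200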

View File

@@ -1,12 +1,12 @@
 from __future__ import annotations

-import dataclasses
 import json
 import logging
 from typing import Literal

 import ftfy
 import numpy as np
+from pydantic import BaseModel

 from autogpt.config import Config
 from autogpt.llm import Message
@@ -20,8 +20,8 @@ logger = logging.getLogger(__name__)
 MemoryDocType = Literal["webpage", "text_file", "code_file", "agent_history"]


-@dataclasses.dataclass
-class MemoryItem:
+# FIXME: implement validators instead of allowing arbitrary types
+class MemoryItem(BaseModel, arbitrary_types_allowed=True):
     """Memory object containing raw content as well as embeddings"""

     raw_content: str
@@ -94,12 +94,12 @@ class MemoryItem:
         metadata["source_type"] = source_type

         return MemoryItem(
-            text,
-            summary,
-            chunks,
-            chunk_summaries,
-            e_summary,
-            e_chunks,
+            raw_content=text,
+            summary=summary,
+            chunks=chunks,
+            chunk_summaries=chunk_summaries,
+            e_summary=e_summary,
+            e_chunks=e_chunks,
             metadata=metadata,
         )
@@ -198,8 +198,7 @@ Metadata: {json.dumps(self.metadata, indent=2)}
 )


-@dataclasses.dataclass
-class MemoryItemRelevance:
+class MemoryItemRelevance(BaseModel):
     """
     Class that encapsulates memory relevance search functionality and data.
     Instances contain a MemoryItem and its relevance scores for a given query.
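
Note: MemoryItem carries embeddings for which Pydantic has no validator; passing
arbitrary_types_allowed=True as a class keyword (flagged with a FIXME above) makes
validation fall back to plain isinstance() checks for such fields. Reduced sketch,
assuming the embedding fields are numpy arrays, as the np import suggests:

    import numpy as np
    from pydantic import BaseModel

    class MemoryItem(BaseModel, arbitrary_types_allowed=True):
        raw_content: str
        e_summary: np.ndarray  # arbitrary type: only an isinstance() check

    item = MemoryItem(raw_content="hello", e_summary=np.zeros(3))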

View File

@@ -1,13 +1,13 @@
 from __future__ import annotations

-from dataclasses import dataclass
 from typing import Any, Iterator, Literal, Optional

+from pydantic import BaseModel
+
 from autogpt.prompts.utils import format_numbered_list, indent


-@dataclass
-class Action:
+class Action(BaseModel):
     name: str
     args: dict[str, Any]
     reasoning: str
@@ -16,8 +16,7 @@ class Action:
         return f"{self.name}({', '.join([f'{a}={repr(v)}' for a, v in self.args.items()])})"


-@dataclass
-class ActionSuccessResult:
+class ActionSuccessResult(BaseModel):
     outputs: Any
     status: Literal["success"] = "success"
@@ -27,8 +26,8 @@ class ActionSuccessResult:
         return f"```\n{self.outputs}\n```" if multiline else str(self.outputs)


-@dataclass
-class ActionErrorResult:
+# FIXME: implement validators instead of allowing arbitrary types
+class ActionErrorResult(BaseModel, arbitrary_types_allowed=True):
     reason: str
     error: Optional[Exception] = None
     status: Literal["error"] = "error"
@@ -37,8 +36,7 @@ class ActionErrorResult:
         return f"Action failed: '{self.reason}'"


-@dataclass
-class ActionInterruptedByHuman:
+class ActionInterruptedByHuman(BaseModel):
     feedback: str
     status: Literal["interrupted_by_human"] = "interrupted_by_human"
@@ -49,61 +47,63 @@ class ActionInterruptedByHuman:
 ActionResult = ActionSuccessResult | ActionErrorResult | ActionInterruptedByHuman


-class ActionHistory:
-    """Utility container for an action history"""
+class Episode(BaseModel):
+    action: Action
+    result: ActionResult | None

-    @dataclass
-    class CycleRecord:
-        action: Action
-        result: ActionResult | None
+    def __str__(self) -> str:
+        executed_action = f"Executed `{self.action.format_call()}`"
+        action_result = f": {self.result}" if self.result else "."
+        return executed_action + action_result

-        def __str__(self) -> str:
-            executed_action = f"Executed `{self.action.format_call()}`"
-            action_result = f": {self.result}" if self.result else "."
-            return executed_action + action_result
+
+class EpisodicActionHistory(BaseModel):
+    """Utility container for an action history"""

     cursor: int
-    cycles: list[CycleRecord]
+    episodes: list[Episode]

-    def __init__(self, cycles: list[CycleRecord] = []):
-        self.cycles = cycles
-        self.cursor = len(self.cycles)
+    def __init__(self, episodes: list[Episode] = []):
+        super().__init__(
+            episodes=episodes,
+            cursor=len(episodes),
+        )

     @property
-    def current_record(self) -> CycleRecord | None:
+    def current_episode(self) -> Episode | None:
         if self.cursor == len(self):
             return None
         return self[self.cursor]

-    def __getitem__(self, key: int) -> CycleRecord:
-        return self.cycles[key]
+    def __getitem__(self, key: int) -> Episode:
+        return self.episodes[key]

-    def __iter__(self) -> Iterator[CycleRecord]:
-        return iter(self.cycles)
+    def __iter__(self) -> Iterator[Episode]:
+        return iter(self.episodes)

     def __len__(self) -> int:
-        return len(self.cycles)
+        return len(self.episodes)

     def __bool__(self) -> bool:
-        return len(self.cycles) > 0
+        return len(self.episodes) > 0

     def register_action(self, action: Action) -> None:
-        if not self.current_record:
-            self.cycles.append(self.CycleRecord(action, None))
-            assert self.current_record
-        elif self.current_record.action:
+        if not self.current_episode:
+            self.episodes.append(Episode(action=action, result=None))
+            assert self.current_episode
+        elif self.current_episode.action:
             raise ValueError("Action for current cycle already set")

     def register_result(self, result: ActionResult) -> None:
-        if not self.current_record:
+        if not self.current_episode:
             raise RuntimeError("Cannot register result for cycle without action")
-        elif self.current_record.result:
+        elif self.current_episode.result:
             raise ValueError("Result for current cycle already set")
-        self.current_record.result = result
-        self.cursor = len(self.cycles)
+        self.current_episode.result = result
+        self.cursor = len(self.episodes)

-    def rewind(self, number_of_cycles: int = 0) -> None:
+    def rewind(self, number_of_episodes: int = 0) -> None:
         """Resets the history to an earlier state.

         Params:
@@ -111,22 +111,22 @@ class EpisodicActionHistory(BaseModel):
             When set to 0, it will only reset the current cycle.
         """
         # Remove partial record of current cycle
-        if self.current_record:
-            if self.current_record.action and not self.current_record.result:
-                self.cycles.pop(self.cursor)
+        if self.current_episode:
+            if self.current_episode.action and not self.current_episode.result:
+                self.episodes.pop(self.cursor)

         # Rewind the specified number of cycles
-        if number_of_cycles > 0:
-            self.cycles = self.cycles[:-number_of_cycles]
-            self.cursor = len(self.cycles)
+        if number_of_episodes > 0:
+            self.episodes = self.episodes[:-number_of_episodes]
+            self.cursor = len(self.episodes)

     def fmt_list(self) -> str:
-        return format_numbered_list(self.cycles)
+        return format_numbered_list(self.episodes)

     def fmt_paragraph(self) -> str:
         steps: list[str] = []

-        for i, c in enumerate(self.cycles, 1):
+        for i, c in enumerate(self.episodes, 1):
             step = f"### Step {i}: Executed `{c.action.format_call()}`\n"
             step += f'- **Reasoning:** "{c.action.reasoning}"\n'
             step += (
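
Note: the status Literals keep the variants of the ActionResult union distinguishable
during validation (a wrong variant fails its Literal check), and the mutable default
in __init__(self, episodes: list[Episode] = []) is less hazardous than in plain
Python, since Pydantic builds a fresh list when validating a list field. Usage sketch,
importing the classes introduced by this commit:

    from autogpt.models.agent_actions import (
        Action,
        ActionSuccessResult,
        EpisodicActionHistory,
    )

    history = EpisodicActionHistory()
    history.register_action(
        Action(name="open_file", args={"file_path": "notes.md"}, reasoning="check notes")
    )
    history.register_result(ActionSuccessResult(outputs="file opened"))

    assert history.current_episode is None  # cursor moved past the finished episode
    print(history.fmt_list())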

View File

@@ -1,9 +1,10 @@
 import logging
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional

+from pydantic import BaseModel, Field
+
 from autogpt.commands.file_operations_utils import read_textual_file

 logger = logging.getLogger(__name__)
@@ -37,8 +38,7 @@ class ContextItem(ABC):
         )


-@dataclass
-class FileContextItem(ContextItem):
+class FileContextItem(BaseModel, ContextItem):
     file_path_in_workspace: Path
     workspace_path: Path
@@ -59,8 +59,7 @@ class FileContextItem(ContextItem):
         return read_textual_file(self.file_path, logger)


-@dataclass
-class FolderContextItem(ContextItem):
+class FolderContextItem(BaseModel, ContextItem):
     path_in_workspace: Path
     workspace_path: Path
@@ -87,8 +86,7 @@ class FolderContextItem(ContextItem):
         return "\n".join(items)


-@dataclass
-class StaticContextItem(ContextItem):
-    description: str
-    source: Optional[str]
-    content: str
+class StaticContextItem(BaseModel, ContextItem):
+    item_description: str = Field(alias="description")
+    item_source: Optional[str] = Field(alias="source")
+    item_content: str = Field(alias="content")

View File

@@ -31,7 +31,7 @@ def agent(config: Config):
 def test_message_history_batch_summary(mocker, agent: Agent, config: Config):
-    history = MessageHistory(agent.llm, agent=agent)
+    history = MessageHistory(model=agent.llm, agent=agent)
     model = config.fast_llm
     message_tlength = 0
     message_count = 0