mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-02-15 11:14:27 +01:00
AutoGPT: Convert dataclasses to Pydantic models
This commit is contained in:
@@ -147,7 +147,7 @@ class Agent(ContextMixin, WorkspaceMixin, WatchdogMixin, BaseAgent):
|
||||
result: ActionResult
|
||||
|
||||
if command_name == "human_feedback":
|
||||
result = ActionInterruptedByHuman(user_input)
|
||||
result = ActionInterruptedByHuman(feedback=user_input)
|
||||
self.message_history.add(
|
||||
"user",
|
||||
"I interrupted the execution of the command you proposed "
|
||||
@@ -185,9 +185,9 @@ class Agent(ContextMixin, WorkspaceMixin, WatchdogMixin, BaseAgent):
|
||||
)
|
||||
self.context.add(context_item)
|
||||
|
||||
result = ActionSuccessResult(return_value)
|
||||
result = ActionSuccessResult(outputs=return_value)
|
||||
except AgentException as e:
|
||||
result = ActionErrorResult(e.message, e)
|
||||
result = ActionErrorResult(reason=e.message, error=e)
|
||||
|
||||
result_tlength = count_string_tokens(str(result), self.llm.name)
|
||||
history_tlength = count_string_tokens(
|
||||
|
||||
@@ -16,7 +16,7 @@ from autogpt.llm.base import ChatSequence, Message
|
||||
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS, get_openai_command_specs
|
||||
from autogpt.llm.utils import count_message_tokens, create_chat_completion
|
||||
from autogpt.memory.message_history import MessageHistory
|
||||
from autogpt.models.agent_actions import ActionHistory, ActionResult
|
||||
from autogpt.models.agent_actions import EpisodicActionHistory, ActionResult
|
||||
from autogpt.prompts.generator import PromptGenerator
|
||||
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
|
||||
|
||||
@@ -90,10 +90,10 @@ class BaseAgent(metaclass=ABCMeta):
|
||||
defaults to 75% of `llm.max_tokens`.
|
||||
"""
|
||||
|
||||
self.event_history = ActionHistory()
|
||||
self.event_history = EpisodicActionHistory()
|
||||
|
||||
self.message_history = MessageHistory(
|
||||
self.llm,
|
||||
model=self.llm,
|
||||
max_summary_tlength=summary_max_tlength or self.send_token_limit // 6,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from contextlib import ExitStack
|
||||
|
||||
from autogpt.models.agent_actions import ActionHistory
|
||||
from autogpt.models.agent_actions import EpisodicActionHistory
|
||||
|
||||
from ..base import BaseAgent
|
||||
|
||||
@@ -16,7 +16,7 @@ class WatchdogMixin:
|
||||
looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
|
||||
"""
|
||||
|
||||
event_history: ActionHistory
|
||||
event_history: EpisodicActionHistory
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
# Initialize other bases first, because we need the event_history from BaseAgent
|
||||
@@ -38,7 +38,7 @@ class WatchdogMixin:
|
||||
and self.config.fast_llm != self.config.smart_llm
|
||||
):
|
||||
# Detect repetitive commands
|
||||
previous_cycle = self.event_history.cycles[self.event_history.cursor - 1]
|
||||
previous_cycle = self.event_history.episodes[self.event_history.cursor - 1]
|
||||
if (
|
||||
command_name == previous_cycle.action.name
|
||||
and command_args == previous_cycle.action.args
|
||||
|
||||
@@ -23,7 +23,7 @@ from autogpt.logs.log_cycle import (
|
||||
)
|
||||
from autogpt.models.agent_actions import (
|
||||
ActionErrorResult,
|
||||
ActionHistory,
|
||||
EpisodicActionHistory,
|
||||
ActionInterruptedByHuman,
|
||||
ActionResult,
|
||||
ActionSuccessResult,
|
||||
@@ -69,7 +69,7 @@ class PlanningAgent(ContextMixin, WorkspaceMixin, BaseAgent):
|
||||
self.log_cycle_handler = LogCycleHandler()
|
||||
"""LogCycleHandler for structured debug logging."""
|
||||
|
||||
self.action_history = ActionHistory()
|
||||
self.action_history = EpisodicActionHistory()
|
||||
|
||||
self.plan: list[str] = []
|
||||
"""List of steps that the Agent plans to take"""
|
||||
@@ -229,7 +229,7 @@ class PlanningAgent(ContextMixin, WorkspaceMixin, BaseAgent):
|
||||
self.ai_config.ai_name,
|
||||
self.created_at,
|
||||
self.cycle_count,
|
||||
self.action_history.cycles,
|
||||
self.action_history.episodes,
|
||||
"action_history.json",
|
||||
)
|
||||
self.log_cycle_handler.log_cycle(
|
||||
@@ -250,7 +250,7 @@ class PlanningAgent(ContextMixin, WorkspaceMixin, BaseAgent):
|
||||
result: ActionResult
|
||||
|
||||
if command_name == "human_feedback":
|
||||
result = ActionInterruptedByHuman(user_input)
|
||||
result = ActionInterruptedByHuman(feedback=user_input)
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.ai_config.ai_name,
|
||||
self.created_at,
|
||||
@@ -279,9 +279,9 @@ class PlanningAgent(ContextMixin, WorkspaceMixin, BaseAgent):
|
||||
self.context.add(return_value[1])
|
||||
return_value = return_value[0]
|
||||
|
||||
result = ActionSuccessResult(return_value)
|
||||
result = ActionSuccessResult(outputs=return_value)
|
||||
except AgentException as e:
|
||||
result = ActionErrorResult(e.message, e)
|
||||
result = ActionErrorResult(reason=e.message, error=e)
|
||||
|
||||
result_tlength = count_string_tokens(str(result), self.llm.name)
|
||||
memory_tlength = count_string_tokens(
|
||||
|
||||
@@ -67,7 +67,10 @@ def open_file(file_path: Path, agent: Agent) -> tuple[str, FileContextItem]:
|
||||
|
||||
file_path = relative_file_path or file_path
|
||||
|
||||
file = FileContextItem(file_path, agent.workspace.root)
|
||||
file = FileContextItem(
|
||||
file_path_in_workspace=file_path,
|
||||
workspace_path=agent.workspace.root,
|
||||
)
|
||||
if file in agent_context:
|
||||
raise DuplicateOperationError(f"The file {file_path} is already open")
|
||||
|
||||
@@ -114,7 +117,10 @@ def open_folder(path: Path, agent: Agent) -> tuple[str, FolderContextItem]:
|
||||
|
||||
path = relative_path or path
|
||||
|
||||
folder = FolderContextItem(path, agent.workspace.root)
|
||||
folder = FolderContextItem(
|
||||
path_in_workspace=path,
|
||||
workspace_path=agent.workspace.root,
|
||||
)
|
||||
if folder in agent_context:
|
||||
raise DuplicateOperationError(f"The folder {path} is already open")
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
"""A module that contains the AIConfig class object that contains the configuration"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
class AIConfig(BaseModel):
|
||||
"""
|
||||
A class object that contains the configuration information for the AI
|
||||
|
||||
@@ -21,7 +20,7 @@ class AIConfig:
|
||||
|
||||
ai_name: str = ""
|
||||
ai_role: str = ""
|
||||
ai_goals: list[str] = field(default_factory=list[str])
|
||||
ai_goals: list[str] = Field(default_factory=list[str])
|
||||
api_budget: float = 0.0
|
||||
|
||||
@staticmethod
|
||||
@@ -53,7 +52,12 @@ class AIConfig:
|
||||
]
|
||||
api_budget = config_params.get("api_budget", 0.0)
|
||||
|
||||
return AIConfig(ai_name, ai_role, ai_goals, api_budget)
|
||||
return AIConfig(
|
||||
ai_name=ai_name,
|
||||
ai_role=ai_role,
|
||||
ai_goals=ai_goals,
|
||||
api_budget=api_budget
|
||||
)
|
||||
|
||||
def save(self, ai_settings_file: str | Path) -> None:
|
||||
"""
|
||||
@@ -66,11 +70,5 @@ class AIConfig:
|
||||
None
|
||||
"""
|
||||
|
||||
config = {
|
||||
"ai_name": self.ai_name,
|
||||
"ai_role": self.ai_role,
|
||||
"ai_goals": self.ai_goals,
|
||||
"api_budget": self.api_budget,
|
||||
}
|
||||
with open(ai_settings_file, "w", encoding="utf-8") as file:
|
||||
yaml.dump(config, file, allow_unicode=True)
|
||||
yaml.dump(self.dict(), file, allow_unicode=True)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt.logs.helpers import request_user_double_check
|
||||
from autogpt.utils import validate_yaml_file
|
||||
@@ -11,8 +11,7 @@ from autogpt.utils import validate_yaml_file
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIDirectives:
|
||||
class AIDirectives(BaseModel):
|
||||
"""An object that contains the basic directives for the AI prompt.
|
||||
|
||||
Attributes:
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from math import ceil, floor
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Type, TypedDict, TypeVar, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.llm.providers.openai import OpenAIFunctionCall
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Literal, Optional, Type, TypedDict, TypeVar, overload
|
||||
|
||||
MessageRole = Literal["system", "user", "assistant", "function"]
|
||||
MessageType = Literal["ai_response", "action_result"]
|
||||
@@ -31,20 +28,30 @@ class FunctionCallDict(TypedDict):
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
class Message(BaseModel):
|
||||
"""OpenAI Message object containing a role and the message content"""
|
||||
|
||||
role: MessageRole
|
||||
content: str
|
||||
type: MessageType | None = None
|
||||
type: Optional[MessageType]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: MessageRole,
|
||||
content: str,
|
||||
type: Optional[MessageType] = None
|
||||
):
|
||||
super().__init__(
|
||||
role=role,
|
||||
content=content,
|
||||
type=type,
|
||||
)
|
||||
|
||||
def raw(self) -> MessageDict:
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
class ModelInfo(BaseModel):
|
||||
"""Struct for model information.
|
||||
|
||||
Would be lovely to eventually get this directly from APIs, but needs to be scraped from
|
||||
@@ -56,26 +63,22 @@ class ModelInfo:
|
||||
prompt_token_cost: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionModelInfo(ModelInfo):
|
||||
"""Struct for generic completion model information."""
|
||||
|
||||
completion_token_cost: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatModelInfo(CompletionModelInfo):
|
||||
"""Struct for chat model information."""
|
||||
|
||||
supports_functions: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextModelInfo(CompletionModelInfo):
|
||||
"""Struct for text completion model information."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelInfo(ModelInfo):
|
||||
"""Struct for embedding model information."""
|
||||
|
||||
@@ -86,12 +89,11 @@ class EmbeddingModelInfo(ModelInfo):
|
||||
TChatSequence = TypeVar("TChatSequence", bound="ChatSequence")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSequence:
|
||||
class ChatSequence(BaseModel):
|
||||
"""Utility container for a chat sequence"""
|
||||
|
||||
model: ChatModelInfo
|
||||
messages: list[Message] = field(default_factory=list[Message])
|
||||
messages: list[Message] = Field(default_factory=list[Message])
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> Message:
|
||||
@@ -103,7 +105,7 @@ class ChatSequence:
|
||||
|
||||
def __getitem__(self: TChatSequence, key: int | slice) -> Message | TChatSequence:
|
||||
if isinstance(key, slice):
|
||||
copy = deepcopy(self)
|
||||
copy = self.copy(deep=True)
|
||||
copy.messages = self.messages[key]
|
||||
return copy
|
||||
return self.messages[key]
|
||||
@@ -141,7 +143,7 @@ class ChatSequence:
|
||||
) -> TChatSequence:
|
||||
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
|
||||
|
||||
if not model_name in OPEN_AI_CHAT_MODELS:
|
||||
if model_name not in OPEN_AI_CHAT_MODELS:
|
||||
raise ValueError(f"Unknown chat model '{model_name}'")
|
||||
|
||||
return cls(
|
||||
@@ -175,23 +177,43 @@ Length: {self.token_length} tokens; {len(self.messages)} messages
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
class LLMResponse(BaseModel):
|
||||
"""Standard response struct for a response from an LLM model."""
|
||||
|
||||
model_info: ModelInfo
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelResponse(LLMResponse):
|
||||
"""Standard response struct for a response from an embedding model."""
|
||||
|
||||
embedding: list[float] = field(default_factory=list)
|
||||
embedding: list[float] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatModelResponse(LLMResponse):
|
||||
"""Standard response struct for a response from a chat LLM."""
|
||||
|
||||
content: Optional[str]
|
||||
function_call: Optional[OpenAIFunctionCall]
|
||||
function_call: Optional[LLMFunctionCall]
|
||||
|
||||
|
||||
class LLMFunctionCall(BaseModel):
|
||||
"""Represents a function call as generated by an OpenAI model
|
||||
|
||||
Attributes:
|
||||
name: the name of the function that the LLM wants to call
|
||||
arguments: a stringified JSON object (unverified) containing `arg: value` pairs
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: dict[str, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def parse(raw: FunctionCallDict):
|
||||
return LLMFunctionCall(
|
||||
name=raw["name"],
|
||||
arguments=json.loads(raw["arguments"]),
|
||||
)
|
||||
|
||||
|
||||
# Complete model initialization; necessary because of order of definition
|
||||
ChatModelResponse.update_forward_refs()
|
||||
|
||||
@@ -4,7 +4,7 @@ import functools
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, List, Optional, TypeVar
|
||||
from unittest.mock import patch
|
||||
|
||||
import openai
|
||||
@@ -118,8 +118,10 @@ OPEN_AI_MODELS: dict[str, ChatModelInfo | EmbeddingModelInfo | TextModelInfo] =
|
||||
**OPEN_AI_EMBEDDING_MODELS,
|
||||
}
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
def meter_api(func: Callable):
|
||||
|
||||
def meter_api(func: T) -> T:
|
||||
"""Adds ApiManager metering to functions which make OpenAI API calls"""
|
||||
from autogpt.llm.api_manager import ApiManager
|
||||
|
||||
@@ -145,6 +147,7 @@ def meter_api(func: Callable):
|
||||
update_usage_with_response(openai_obj)
|
||||
return openai_obj
|
||||
|
||||
@functools.wraps(func)
|
||||
def metered_func(*args, **kwargs):
|
||||
with patch.object(
|
||||
engine_api_resource.util,
|
||||
@@ -179,7 +182,7 @@ def retry_api(
|
||||
)
|
||||
backoff_msg = "Waiting {backoff} seconds..."
|
||||
|
||||
def _wrapper(func: Callable):
|
||||
def _wrapper(func: T) -> T:
|
||||
@functools.wraps(func)
|
||||
def _wrapped(*args, **kwargs):
|
||||
user_warned = not warn_user
|
||||
@@ -286,19 +289,6 @@ def create_embedding(
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIFunctionCall:
|
||||
"""Represents a function call as generated by an OpenAI model
|
||||
|
||||
Attributes:
|
||||
name: the name of the function that the LLM wants to call
|
||||
arguments: a stringified JSON object (unverified) containing `arg: value` pairs
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIFunctionSpec:
|
||||
"""Represents a "function" in OpenAI, which is mapped to a Command in Auto-GPT"""
|
||||
|
||||
@@ -1,27 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Literal, Optional
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from autogpt.config import Config
|
||||
|
||||
from ..api_manager import ApiManager
|
||||
from ..base import (
|
||||
ChatModelResponse,
|
||||
ChatSequence,
|
||||
FunctionCallDict,
|
||||
LLMFunctionCall,
|
||||
Message,
|
||||
ResponseMessageDict,
|
||||
)
|
||||
from ..providers import openai as iopenai
|
||||
from ..providers.openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OpenAIFunctionCall,
|
||||
OpenAIFunctionSpec,
|
||||
count_openai_functions_tokens,
|
||||
)
|
||||
from .token_counter import *
|
||||
|
||||
from .token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def call_ai_function(
|
||||
@@ -96,7 +99,7 @@ def create_text_completion(
|
||||
def create_chat_completion(
|
||||
prompt: ChatSequence,
|
||||
config: Config,
|
||||
functions: Optional[List[OpenAIFunctionSpec]] = None,
|
||||
functions: Optional[list[OpenAIFunctionSpec]] = None,
|
||||
model: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
@@ -182,9 +185,5 @@ def create_chat_completion(
|
||||
return ChatModelResponse(
|
||||
model_info=OPEN_AI_CHAT_MODELS[model],
|
||||
content=content,
|
||||
function_call=OpenAIFunctionCall(
|
||||
name=function_call["name"], arguments=function_call["arguments"]
|
||||
)
|
||||
if function_call
|
||||
else None,
|
||||
function_call=LLMFunctionCall.parse(function_call) if function_call else None,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Iterator, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -23,7 +22,6 @@ from autogpt.logs import PROMPT_SUMMARY_FILE_NAME, SUMMARY_FILE_NAME, LogCycleHa
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageHistory(ChatSequence):
|
||||
max_summary_tlength: int = 500
|
||||
agent: Optional[BaseAgent | Agent] = None
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import ftfy
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt.config import Config
|
||||
from autogpt.llm import Message
|
||||
@@ -20,8 +20,8 @@ logger = logging.getLogger(__name__)
|
||||
MemoryDocType = Literal["webpage", "text_file", "code_file", "agent_history"]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MemoryItem:
|
||||
# FIXME: implement validators instead of allowing arbitrary types
|
||||
class MemoryItem(BaseModel, arbitrary_types_allowed=True):
|
||||
"""Memory object containing raw content as well as embeddings"""
|
||||
|
||||
raw_content: str
|
||||
@@ -94,12 +94,12 @@ class MemoryItem:
|
||||
metadata["source_type"] = source_type
|
||||
|
||||
return MemoryItem(
|
||||
text,
|
||||
summary,
|
||||
chunks,
|
||||
chunk_summaries,
|
||||
e_summary,
|
||||
e_chunks,
|
||||
raw_content=text,
|
||||
summary=summary,
|
||||
chunks=chunks,
|
||||
chunk_summaries=chunk_summaries,
|
||||
e_summary=e_summary,
|
||||
e_chunks=e_chunks,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@@ -198,8 +198,7 @@ Metadata: {json.dumps(self.metadata, indent=2)}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MemoryItemRelevance:
|
||||
class MemoryItemRelevance(BaseModel):
|
||||
"""
|
||||
Class that encapsulates memory relevance search functionality and data.
|
||||
Instances contain a MemoryItem and its relevance scores for a given query.
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterator, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt.prompts.utils import format_numbered_list, indent
|
||||
|
||||
|
||||
@dataclass
|
||||
class Action:
|
||||
class Action(BaseModel):
|
||||
name: str
|
||||
args: dict[str, Any]
|
||||
reasoning: str
|
||||
@@ -16,8 +16,7 @@ class Action:
|
||||
return f"{self.name}({', '.join([f'{a}={repr(v)}' for a, v in self.args.items()])})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionSuccessResult:
|
||||
class ActionSuccessResult(BaseModel):
|
||||
outputs: Any
|
||||
status: Literal["success"] = "success"
|
||||
|
||||
@@ -27,8 +26,8 @@ class ActionSuccessResult:
|
||||
return f"```\n{self.outputs}\n```" if multiline else str(self.outputs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionErrorResult:
|
||||
# FIXME: implement validators instead of allowing arbitrary types
|
||||
class ActionErrorResult(BaseModel, arbitrary_types_allowed=True):
|
||||
reason: str
|
||||
error: Optional[Exception] = None
|
||||
status: Literal["error"] = "error"
|
||||
@@ -37,8 +36,7 @@ class ActionErrorResult:
|
||||
return f"Action failed: '{self.reason}'"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionInterruptedByHuman:
|
||||
class ActionInterruptedByHuman(BaseModel):
|
||||
feedback: str
|
||||
status: Literal["interrupted_by_human"] = "interrupted_by_human"
|
||||
|
||||
@@ -49,61 +47,63 @@ class ActionInterruptedByHuman:
|
||||
ActionResult = ActionSuccessResult | ActionErrorResult | ActionInterruptedByHuman
|
||||
|
||||
|
||||
class ActionHistory:
|
||||
class Episode(BaseModel):
|
||||
action: Action
|
||||
result: ActionResult | None
|
||||
|
||||
def __str__(self) -> str:
|
||||
executed_action = f"Executed `{self.action.format_call()}`"
|
||||
action_result = f": {self.result}" if self.result else "."
|
||||
return executed_action + action_result
|
||||
|
||||
|
||||
class EpisodicActionHistory(BaseModel):
|
||||
"""Utility container for an action history"""
|
||||
|
||||
@dataclass
|
||||
class CycleRecord:
|
||||
action: Action
|
||||
result: ActionResult | None
|
||||
|
||||
def __str__(self) -> str:
|
||||
executed_action = f"Executed `{self.action.format_call()}`"
|
||||
action_result = f": {self.result}" if self.result else "."
|
||||
return executed_action + action_result
|
||||
|
||||
cursor: int
|
||||
cycles: list[CycleRecord]
|
||||
episodes: list[Episode]
|
||||
|
||||
def __init__(self, cycles: list[CycleRecord] = []):
|
||||
self.cycles = cycles
|
||||
self.cursor = len(self.cycles)
|
||||
def __init__(self, episodes: list[Episode] = []):
|
||||
super().__init__(
|
||||
episodes=episodes,
|
||||
cursor=len(episodes),
|
||||
)
|
||||
|
||||
@property
|
||||
def current_record(self) -> CycleRecord | None:
|
||||
def current_episode(self) -> Episode | None:
|
||||
if self.cursor == len(self):
|
||||
return None
|
||||
return self[self.cursor]
|
||||
|
||||
def __getitem__(self, key: int) -> CycleRecord:
|
||||
return self.cycles[key]
|
||||
def __getitem__(self, key: int) -> Episode:
|
||||
return self.episodes[key]
|
||||
|
||||
def __iter__(self) -> Iterator[CycleRecord]:
|
||||
return iter(self.cycles)
|
||||
def __iter__(self) -> Iterator[Episode]:
|
||||
return iter(self.episodes)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.cycles)
|
||||
return len(self.episodes)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return len(self.cycles) > 0
|
||||
return len(self.episodes) > 0
|
||||
|
||||
def register_action(self, action: Action) -> None:
|
||||
if not self.current_record:
|
||||
self.cycles.append(self.CycleRecord(action, None))
|
||||
assert self.current_record
|
||||
elif self.current_record.action:
|
||||
if not self.current_episode:
|
||||
self.episodes.append(Episode(action=action, result=None))
|
||||
assert self.current_episode
|
||||
elif self.current_episode.action:
|
||||
raise ValueError("Action for current cycle already set")
|
||||
|
||||
def register_result(self, result: ActionResult) -> None:
|
||||
if not self.current_record:
|
||||
if not self.current_episode:
|
||||
raise RuntimeError("Cannot register result for cycle without action")
|
||||
elif self.current_record.result:
|
||||
elif self.current_episode.result:
|
||||
raise ValueError("Result for current cycle already set")
|
||||
|
||||
self.current_record.result = result
|
||||
self.cursor = len(self.cycles)
|
||||
self.current_episode.result = result
|
||||
self.cursor = len(self.episodes)
|
||||
|
||||
def rewind(self, number_of_cycles: int = 0) -> None:
|
||||
def rewind(self, number_of_episodes: int = 0) -> None:
|
||||
"""Resets the history to an earlier state.
|
||||
|
||||
Params:
|
||||
@@ -111,22 +111,22 @@ class ActionHistory:
|
||||
When set to 0, it will only reset the current cycle.
|
||||
"""
|
||||
# Remove partial record of current cycle
|
||||
if self.current_record:
|
||||
if self.current_record.action and not self.current_record.result:
|
||||
self.cycles.pop(self.cursor)
|
||||
if self.current_episode:
|
||||
if self.current_episode.action and not self.current_episode.result:
|
||||
self.episodes.pop(self.cursor)
|
||||
|
||||
# Rewind the specified number of cycles
|
||||
if number_of_cycles > 0:
|
||||
self.cycles = self.cycles[:-number_of_cycles]
|
||||
self.cursor = len(self.cycles)
|
||||
if number_of_episodes > 0:
|
||||
self.episodes = self.episodes[:-number_of_episodes]
|
||||
self.cursor = len(self.episodes)
|
||||
|
||||
def fmt_list(self) -> str:
|
||||
return format_numbered_list(self.cycles)
|
||||
return format_numbered_list(self.episodes)
|
||||
|
||||
def fmt_paragraph(self) -> str:
|
||||
steps: list[str] = []
|
||||
|
||||
for i, c in enumerate(self.cycles, 1):
|
||||
for i, c in enumerate(self.episodes, 1):
|
||||
step = f"### Step {i}: Executed `{c.action.format_call()}`\n"
|
||||
step += f'- **Reasoning:** "{c.action.reasoning}"\n'
|
||||
step += (
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt.commands.file_operations_utils import read_textual_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,8 +38,7 @@ class ContextItem(ABC):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileContextItem(ContextItem):
|
||||
class FileContextItem(BaseModel, ContextItem):
|
||||
file_path_in_workspace: Path
|
||||
workspace_path: Path
|
||||
|
||||
@@ -59,8 +59,7 @@ class FileContextItem(ContextItem):
|
||||
return read_textual_file(self.file_path, logger)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FolderContextItem(ContextItem):
|
||||
class FolderContextItem(BaseModel, ContextItem):
|
||||
path_in_workspace: Path
|
||||
workspace_path: Path
|
||||
|
||||
@@ -87,8 +86,7 @@ class FolderContextItem(ContextItem):
|
||||
return "\n".join(items)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticContextItem(ContextItem):
|
||||
description: str
|
||||
source: Optional[str]
|
||||
content: str
|
||||
class StaticContextItem(BaseModel, ContextItem):
|
||||
item_description: str = Field(alias="description")
|
||||
item_source: Optional[str] = Field(alias="source")
|
||||
item_content: str = Field(alias="content")
|
||||
|
||||
@@ -31,7 +31,7 @@ def agent(config: Config):
|
||||
|
||||
|
||||
def test_message_history_batch_summary(mocker, agent: Agent, config: Config):
|
||||
history = MessageHistory(agent.llm, agent=agent)
|
||||
history = MessageHistory(model=agent.llm, agent=agent)
|
||||
model = config.fast_llm
|
||||
message_tlength = 0
|
||||
message_count = 0
|
||||
|
||||
Reference in New Issue
Block a user