Mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2026-01-09 09:14:19 +01:00)
Implement watchdog feature for dynamic switching between smart & fast LLMs
@@ -66,11 +66,11 @@ OPENAI_API_KEY=your-openai-api-key
 ### LLM MODELS
 ################################################################################
 
-## SMART_LLM - Smart language model (Default: gpt-4)
-# SMART_LLM=gpt-4
+## SMART_LLM - Smart language model (Default: gpt-4-0314)
+# SMART_LLM=gpt-4-0314
 
-## FAST_LLM - Fast language model (Default: gpt-3.5-turbo)
-# FAST_LLM=gpt-3.5-turbo
+## FAST_LLM - Fast language model (Default: gpt-3.5-turbo-16k)
+# FAST_LLM=gpt-3.5-turbo-16k
 
 ## EMBEDDING_MODEL - Model to use for creating embeddings
 # EMBEDDING_MODEL=text-embedding-ada-002
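For reference, the effect of these defaults can be sketched in a few lines of Python. The resolution logic below is illustrative only (Auto-GPT's actual loading happens in its Config class), but the variable names and defaults match the diff:

import os

# Commented-out .env entries leave the defaults in force; uncommenting them
# overrides the model pair that the watchdog switches between.
FAST_LLM = os.getenv("FAST_LLM", "gpt-3.5-turbo-16k")  # routine thinking cycles
SMART_LLM = os.getenv("SMART_LLM", "gpt-4-0314")  # used when the watchdog escalates

print(f"fast={FAST_LLM} smart={SMART_LLM}")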
@@ -35,6 +35,7 @@ from autogpt.models.context_item import ContextItem
 
 from .base import BaseAgent
 from .features.context import ContextMixin
+from .features.watchdog import WatchdogMixin
 from .features.workspace import WorkspaceMixin
 from .utils.exceptions import (
     AgentException,
@@ -46,7 +47,7 @@ from .utils.exceptions import (
 logger = logging.getLogger(__name__)
 
 
-class Agent(ContextMixin, WorkspaceMixin, BaseAgent):
+class Agent(ContextMixin, WorkspaceMixin, WatchdogMixin, BaseAgent):
    """Agent class for interacting with Auto-GPT."""
 
    def __init__(
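The ordering of the bases matters here: WatchdogMixin must appear before BaseAgent so that Python's method resolution order routes Agent.think() through the mixin's override before reaching the base implementation. A toy sketch of the pattern (illustrative class names, not Auto-GPT code):

class Base:
    def think(self) -> str:
        return "base thought"

class Watchdog:
    def think(self) -> str:
        # super() follows the MRO of the instantiated class, so this
        # delegates to Base when composed as below
        return f"watched({super().think()})"

class Agent(Watchdog, Base):  # mixin before the base, as in the diff
    pass

print(Agent().think())  # -> watched(base thought)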
@@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
 
 if TYPE_CHECKING:
     from autogpt.config import AIConfig, Config
+    from autogpt.llm.base import ChatModelInfo, ChatModelResponse
     from autogpt.models.command_registry import CommandRegistry
 
 from autogpt.agents.utils.exceptions import InvalidAgentResponseError
 from autogpt.config.ai_directives import AIDirectives
-from autogpt.llm.base import ChatModelResponse, ChatSequence, Message
+from autogpt.llm.base import ChatSequence, Message
 from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS, get_openai_command_specs
 from autogpt.llm.utils import count_message_tokens, create_chat_completion
 from autogpt.memory.message_history import MessageHistory
@@ -83,10 +84,6 @@ class BaseAgent(metaclass=ABCMeta):
         self.cycle_count = 0
         """The number of cycles that the agent has run since its initialization."""
 
-        llm_name = self.config.smart_llm if self.big_brain else self.config.fast_llm
-        self.llm = OPEN_AI_CHAT_MODELS[llm_name]
-        """The LLM that the agent uses to think."""
-
         self.send_token_limit = send_token_limit or self.llm.max_tokens * 3 // 4
         """
         The token limit for prompt construction. Should leave room for the completion;
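The default send_token_limit reserves three quarters of the model's context window for the prompt and leaves the rest for the completion. Worked out for the new default fast model (gpt-3.5-turbo-16k has a 16,384-token window):

# Worked example of the 3/4 prompt budget for a 16k-context model
max_tokens = 16384
send_token_limit = max_tokens * 3 // 4  # 12288 tokens for the prompt
completion_budget = max_tokens - send_token_limit  # 4096 tokens for the reply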
@@ -111,6 +108,12 @@ class BaseAgent(metaclass=ABCMeta):
         """
         return self.prompt_generator.construct_system_prompt(self)
 
+    @property
+    def llm(self) -> ChatModelInfo:
+        """The LLM that the agent uses to think."""
+        llm_name = self.config.smart_llm if self.big_brain else self.config.fast_llm
+        return OPEN_AI_CHAT_MODELS[llm_name]
+
     def think(
         self,
         instruction: Optional[str] = None,
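Turning llm into a property is what makes the watchdog's switch take effect immediately: every access re-evaluates big_brain, instead of using a value frozen in __init__. A reduced sketch of the behavior (a toy model table standing in for OPEN_AI_CHAT_MODELS):

MODELS = {"gpt-3.5-turbo-16k": "fast-model-info", "gpt-4-0314": "smart-model-info"}

class AgentSketch:
    def __init__(self) -> None:
        self.big_brain = False

    @property
    def llm(self) -> str:
        # Re-evaluated on every access, so flipping big_brain switches models live
        name = "gpt-4-0314" if self.big_brain else "gpt-3.5-turbo-16k"
        return MODELS[name]

agent = AgentSketch()
assert agent.llm == "fast-model-info"
agent.big_brain = True
assert agent.llm == "smart-model-info"  # no re-initialization needed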
autogpt/agents/features/watchdog.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import logging
+from contextlib import ExitStack
+
+from autogpt.models.agent_actions import ActionHistory
+
+from ..base import BaseAgent
+
+logger = logging.getLogger(__name__)
+
+
+class WatchdogMixin:
+    """
+    Mixin that adds a watchdog feature to an agent class. Whenever the agent starts
+    looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
+    """
+
+    event_history: ActionHistory
+
+    def __init__(self, **kwargs) -> None:
+        # Initialize other bases first, because we need the event_history from BaseAgent
+        super(WatchdogMixin, self).__init__(**kwargs)
+
+        if not isinstance(self, BaseAgent):
+            raise NotImplementedError(
+                f"{__class__.__name__} can only be applied to BaseAgent derivatives"
+            )
+
+    def think(self, *args, **kwargs) -> BaseAgent.ThoughtProcessOutput:
+        command_name, command_args, thoughts = super(WatchdogMixin, self).think(
+            *args, **kwargs
+        )
+
+        if not self.big_brain and len(self.event_history) > 1:
+            # Detect repetitive commands
+            previous_cycle = self.event_history.cycles[self.event_history.cursor - 1]
+            if (
+                command_name == previous_cycle.action.name
+                and command_args == previous_cycle.action.args
+            ):
+                logger.info(
+                    f"Repetitive command detected ({command_name}), re-thinking with SMART_LLM..."
+                )
+                with ExitStack() as stack:
+
+                    @stack.callback
+                    def restore_state() -> None:
+                        # Executed after exiting the ExitStack context
+                        self.big_brain = False
+
+                    # Remove partial record of current cycle
+                    self.event_history.rewind()
+
+                    # Switch to SMART_LLM and re-think
+                    self.big_brain = True
+                    return self.think(*args, **kwargs)
+
+        return command_name, command_args, thoughts
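The control flow is worth tracing: the re-think returns from inside the with block, yet the ExitStack callback still runs on exit, so big_brain is reliably reset to False after the one-off escalation. A self-contained stub that reproduces the flow (stand-in classes, not the real agent):

from contextlib import ExitStack

class StubAgent:
    def __init__(self) -> None:
        self.big_brain = False
        self.history = ["browse", "browse"]  # the previous cycle ran the same command

    def base_think(self) -> str:
        # Stands in for the underlying think(): the fast model repeats itself
        return "smart plan" if self.big_brain else "browse"

    def think(self) -> str:
        command = self.base_think()
        if not self.big_brain and len(self.history) > 1 and command == self.history[-1]:
            with ExitStack() as stack:
                stack.callback(lambda: setattr(self, "big_brain", False))
                self.history.pop()  # rewind the partial record of this cycle
                self.big_brain = True
                return self.think()  # one re-think with the smart model
        return command

agent = StubAgent()
print(agent.think())    # -> smart plan
print(agent.big_brain)  # -> False (restored by the ExitStack callback)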
@@ -55,7 +55,7 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
     workspace_path: Optional[Path] = None
     file_logger_path: Optional[Path] = None
     # Model configuration
-    fast_llm: str = "gpt-3.5-turbo"
+    fast_llm: str = "gpt-3.5-turbo-16k"
     smart_llm: str = "gpt-4-0314"
     temperature: float = 0
     openai_functions: bool = False
@@ -103,6 +103,23 @@ class ActionHistory:
         self.current_record.result = result
         self.cursor = len(self.cycles)
 
+    def rewind(self, number_of_cycles: int = 0) -> None:
+        """Resets the history to an earlier state.
+
+        Params:
+            number_of_cycles (int): The number of cycles to rewind. Default is 0.
+                When set to 0, it will only reset the current cycle.
+        """
+        # Remove partial record of current cycle
+        if self.current_record:
+            if self.current_record.action and not self.current_record.result:
+                self.cycles.pop(self.cursor)
+
+        # Rewind the specified number of cycles
+        if number_of_cycles > 0:
+            self.cycles = self.cycles[:-number_of_cycles]
+            self.cursor = len(self.cycles)
+
     def fmt_list(self) -> str:
         return format_numbered_list(self.cycles)
 
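Concretely, rewind() with the default argument only drops an unfinished record (an action that has no result yet), while a positive argument also truncates completed cycles. A toy walkthrough using a plain list in place of the real cycle objects:

# History with two completed cycles and one partial (no result yet)
cycles = ["cycle-1", "cycle-2", "partial-cycle-3"]

# rewind(): number_of_cycles=0 drops only the partial current record
cycles.pop()
assert cycles == ["cycle-1", "cycle-2"]

# rewind(number_of_cycles=1): additionally discards one completed cycle
cycles = cycles[:-1]
cursor = len(cycles)
assert cycles == ["cycle-1"] and cursor == 1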
@@ -45,11 +45,13 @@ class Command:
     def __call__(self, *args, agent: BaseAgent, **kwargs) -> Any:
         if callable(self.enabled) and not self.enabled(agent.config):
             if self.disabled_reason:
-                return f"Command '{self.name}' is disabled: {self.disabled_reason}"
-            return f"Command '{self.name}' is disabled"
+                raise RuntimeError(
+                    f"Command '{self.name}' is disabled: {self.disabled_reason}"
+                )
+            raise RuntimeError(f"Command '{self.name}' is disabled")
 
         if callable(self.available) and not self.available(agent):
-            return f"Command '{self.name}' is not available"
+            raise RuntimeError(f"Command '{self.name}' is not available")
 
         return self.method(*args, **kwargs, agent=agent)
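With this change, a disabled or unavailable command fails loudly at the call site instead of returning a status string that would be fed back to the LLM as if it were a command result. A minimal sketch of the new failure mode (a simplified stand-in for the Command class):

class CommandSketch:
    def __init__(self, name: str, enabled: bool, disabled_reason: str = "") -> None:
        self.name = name
        self.enabled = enabled
        self.disabled_reason = disabled_reason

    def __call__(self) -> str:
        if not self.enabled:
            if self.disabled_reason:
                raise RuntimeError(
                    f"Command '{self.name}' is disabled: {self.disabled_reason}"
                )
            raise RuntimeError(f"Command '{self.name}' is disabled")
        return "ok"

try:
    CommandSketch("browse_website", enabled=False, disabled_reason="no browser")()
except RuntimeError as e:
    print(e)  # Command 'browse_website' is disabled: no browser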