Implement watchdog feature for dynamic switching between smart & fast LLMs

This commit is contained in:
Reinier van der Leer
2023-08-31 01:39:51 +02:00
parent a4ef53c55c
commit e4370652e9
7 changed files with 96 additions and 14 deletions

View File

@@ -66,11 +66,11 @@ OPENAI_API_KEY=your-openai-api-key
### LLM MODELS
################################################################################
## SMART_LLM - Smart language model (Default: gpt-4)
# SMART_LLM=gpt-4
## SMART_LLM - Smart language model (Default: gpt-4-0314)
# SMART_LLM=gpt-4-0314
## FAST_LLM - Fast language model (Default: gpt-3.5-turbo)
# FAST_LLM=gpt-3.5-turbo
## FAST_LLM - Fast language model (Default: gpt-3.5-turbo-16k)
# FAST_LLM=gpt-3.5-turbo-16k
## EMBEDDING_MODEL - Model to use for creating embeddings
# EMBEDDING_MODEL=text-embedding-ada-002

View File

@@ -35,6 +35,7 @@ from autogpt.models.context_item import ContextItem
from .base import BaseAgent
from .features.context import ContextMixin
from .features.watchdog import WatchdogMixin
from .features.workspace import WorkspaceMixin
from .utils.exceptions import (
AgentException,
@@ -46,7 +47,7 @@ from .utils.exceptions import (
logger = logging.getLogger(__name__)
class Agent(ContextMixin, WorkspaceMixin, BaseAgent):
class Agent(ContextMixin, WorkspaceMixin, WatchdogMixin, BaseAgent):
"""Agent class for interacting with Auto-GPT."""
def __init__(

View File

@@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
if TYPE_CHECKING:
from autogpt.config import AIConfig, Config
from autogpt.llm.base import ChatModelInfo, ChatModelResponse
from autogpt.models.command_registry import CommandRegistry
from autogpt.agents.utils.exceptions import InvalidAgentResponseError
from autogpt.config.ai_directives import AIDirectives
from autogpt.llm.base import ChatModelResponse, ChatSequence, Message
from autogpt.llm.base import ChatSequence, Message
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS, get_openai_command_specs
from autogpt.llm.utils import count_message_tokens, create_chat_completion
from autogpt.memory.message_history import MessageHistory
@@ -83,10 +84,6 @@ class BaseAgent(metaclass=ABCMeta):
self.cycle_count = 0
"""The number of cycles that the agent has run since its initialization."""
llm_name = self.config.smart_llm if self.big_brain else self.config.fast_llm
self.llm = OPEN_AI_CHAT_MODELS[llm_name]
"""The LLM that the agent uses to think."""
self.send_token_limit = send_token_limit or self.llm.max_tokens * 3 // 4
"""
The token limit for prompt construction. Should leave room for the completion;
@@ -111,6 +108,12 @@ class BaseAgent(metaclass=ABCMeta):
"""
return self.prompt_generator.construct_system_prompt(self)
@property
def llm(self) -> ChatModelInfo:
    """The LLM that the agent uses to think.

    Resolved on every access so that toggling `self.big_brain` immediately
    switches between the configured SMART_LLM and FAST_LLM.
    """
    model_name = (
        self.config.fast_llm if not self.big_brain else self.config.smart_llm
    )
    return OPEN_AI_CHAT_MODELS[model_name]
def think(
self,
instruction: Optional[str] = None,

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
import logging
from contextlib import ExitStack
from autogpt.models.agent_actions import ActionHistory
from ..base import BaseAgent
logger = logging.getLogger(__name__)
class WatchdogMixin:
    """
    Mixin that adds a watchdog feature to an agent class. Whenever the agent starts
    looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
    """

    # Provided by BaseAgent; read in think() to detect command repetition.
    event_history: ActionHistory

    def __init__(self, **kwargs) -> None:
        # Initialize other bases first, because we need the event_history from BaseAgent
        super().__init__(**kwargs)

        if not isinstance(self, BaseAgent):
            raise NotImplementedError(
                f"{__class__.__name__} can only be applied to BaseAgent derivatives"
            )

    def think(self, *args, **kwargs) -> BaseAgent.ThoughtProcessOutput:
        """Proxy for BaseAgent.think() that re-thinks with the SMART_LLM when the
        proposed command exactly repeats the previous cycle's command.

        Returns:
            BaseAgent.ThoughtProcessOutput: (command_name, command_args, thoughts)
        """
        command_name, command_args, thoughts = super().think(*args, **kwargs)

        # Only intervene when running on the FAST_LLM; once big_brain is set the
        # recursive call below skips this branch and returns the SMART_LLM result.
        # NOTE(review): assumes ActionHistory implements __len__ — confirm.
        if not self.big_brain and len(self.event_history) > 1:
            # Detect repetitive commands
            previous_cycle = self.event_history.cycles[self.event_history.cursor - 1]
            if (
                command_name == previous_cycle.action.name
                and command_args == previous_cycle.action.args
            ):
                logger.info(
                    f"Repetitive command detected ({command_name}), re-thinking with SMART_LLM..."
                )
                with ExitStack() as stack:

                    @stack.callback
                    def restore_state() -> None:
                        # Executed after exiting the ExitStack context, i.e. after
                        # the recursive think() below has produced its result.
                        self.big_brain = False

                    # Remove partial record of current cycle
                    self.event_history.rewind()

                    # Switch to SMART_LLM and re-think
                    self.big_brain = True
                    return self.think(*args, **kwargs)

        return command_name, command_args, thoughts

View File

@@ -55,7 +55,7 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
workspace_path: Optional[Path] = None
file_logger_path: Optional[Path] = None
# Model configuration
fast_llm: str = "gpt-3.5-turbo"
fast_llm: str = "gpt-3.5-turbo-16k"
smart_llm: str = "gpt-4-0314"
temperature: float = 0
openai_functions: bool = False

View File

@@ -103,6 +103,23 @@ class ActionHistory:
self.current_record.result = result
self.cursor = len(self.cycles)
def rewind(self, number_of_cycles: int = 0) -> None:
    """Reset the history to an earlier state.

    Params:
        number_of_cycles (int): The number of cycles to rewind. Default is 0.
            When set to 0, it will only reset the current cycle.
    """
    # An in-progress cycle (action registered but no result yet) is dropped entirely.
    record = self.current_record
    if record and record.action and not record.result:
        self.cycles.pop(self.cursor)

    # Discard the requested number of most recent completed cycles.
    # Rebinds (rather than mutates) self.cycles, matching the original semantics.
    if number_of_cycles > 0:
        self.cycles = self.cycles[:-number_of_cycles]

    self.cursor = len(self.cycles)
def fmt_list(self) -> str:
    """Return the recorded cycles rendered as a numbered list (via format_numbered_list)."""
    return format_numbered_list(self.cycles)

View File

@@ -45,11 +45,13 @@ class Command:
def __call__(self, *args, agent: BaseAgent, **kwargs) -> Any:
if callable(self.enabled) and not self.enabled(agent.config):
if self.disabled_reason:
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
return f"Command '{self.name}' is disabled"
raise RuntimeError(
f"Command '{self.name}' is disabled: {self.disabled_reason}"
)
raise RuntimeError(f"Command '{self.name}' is disabled")
if callable(self.available) and not self.available(agent):
return f"Command '{self.name}' is not available"
raise RuntimeError(f"Command '{self.name}' is not available")
return self.method(*args, **kwargs, agent=agent)