Mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2025-12-29 03:44:28 +01:00)

Merge branch 'master' into release-v0.4.4

Changed files:

5    .github/workflows/ci.yml    vendored
@@ -153,7 +153,8 @@ jobs:

      - name: Run pytest with coverage
        run: |
-         pytest -n auto --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
+         pytest -vv --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
+           --numprocesses=logical --durations=10 \
            tests/unit tests/integration tests/challenges
          python tests/challenges/utils/build_current_score.py
        env:

@@ -251,7 +252,7 @@ jobs:
          gh api repos/$REPO/issues/$PR_NUMBER/comments -X POST -F body="You changed AutoGPT's behaviour. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
        fi

-     - name: Upload logs as artifact
+     - name: Upload logs to artifact
        if: always()
        uses: actions/upload-artifact@v3
        with:
10    .github/workflows/docker-ci.yml    vendored
@@ -73,16 +73,13 @@ jobs:
        run: .github/workflows/scripts/docker-ci-summary.sh >> $GITHUB_STEP_SUMMARY
        continue-on-error: true

  # Docker setup needs fixing before this is going to work: #1843
  test:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    needs: build
    timeout-minutes: 10
    steps:
      - name: Check out repository
        uses: actions/checkout@v3
        with:
          fetch-depth: 0
          submodules: true

      - name: Set up Docker Buildx

@@ -102,14 +99,15 @@ jobs:
      - id: test
        name: Run tests
        env:
          PLAIN_OUTPUT: True
          CI: true
          PLAIN_OUTPUT: True
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        run: |
          set +e
          test_output=$(
            docker run --env CI --env OPENAI_API_KEY --entrypoint python ${{ env.IMAGE_NAME }} -m \
              pytest -n auto --cov=autogpt --cov-branch --cov-report term-missing \
              pytest -v --cov=autogpt --cov-branch --cov-report term-missing \
              --numprocesses=4 --durations=10 \
              tests/unit tests/integration 2>&1
          )
          test_failure=$?
2    .gitignore    vendored
@@ -12,7 +12,7 @@ last_run_ai_settings.yaml
 auto-gpt.json
 log.txt
 log-ingestion.txt
-logs
+/logs
 *.log
 *.mp3
 mem.sqlite3

@@ -1,4 +1,3 @@
 from autogpt.agent.agent import Agent
-from autogpt.agent.agent_manager import AgentManager

-__all__ = ["Agent", "AgentManager"]
+__all__ = ["Agent"]
@@ -12,13 +12,15 @@ from autogpt.json_utils.utilities import extract_json_from_response, validate_js
|
||||
from autogpt.llm.chat import chat_with_ai
|
||||
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
|
||||
from autogpt.llm.utils import count_string_tokens
|
||||
from autogpt.log_cycle.log_cycle import (
|
||||
from autogpt.logs import (
|
||||
FULL_MESSAGE_HISTORY_FILE_NAME,
|
||||
NEXT_ACTION_FILE_NAME,
|
||||
USER_INPUT_FILE_NAME,
|
||||
LogCycleHandler,
|
||||
logger,
|
||||
print_assistant_thoughts,
|
||||
remove_ansi_escape,
|
||||
)
|
||||
from autogpt.logs import logger, print_assistant_thoughts, remove_ansi_escape
|
||||
from autogpt.memory.message_history import MessageHistory
|
||||
from autogpt.memory.vector import VectorMemory
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
@@ -70,7 +72,7 @@ class Agent:
|
||||
):
|
||||
self.ai_name = ai_name
|
||||
self.memory = memory
|
||||
self.history = MessageHistory(self)
|
||||
self.history = MessageHistory.for_model(config.smart_llm, agent=self)
|
||||
self.next_action_count = next_action_count
|
||||
self.command_registry = command_registry
|
||||
self.config = config
|
||||
@@ -167,8 +169,6 @@ class Agent:
|
||||
if self.config.speak_mode:
|
||||
say_text(f"I want to execute {command_name}", self.config)
|
||||
|
||||
arguments = self._resolve_pathlike_command_args(arguments)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error: \n", str(e))
|
||||
self.log_cycle_handler.log_cycle(
|
||||
@@ -307,14 +307,3 @@ class Agent:
|
||||
logger.typewriter_log(
|
||||
"SYSTEM: ", Fore.YELLOW, "Unable to execute command"
|
||||
)
|
||||
|
||||
def _resolve_pathlike_command_args(self, command_args):
|
||||
if "directory" in command_args and command_args["directory"] in {"", "/"}:
|
||||
command_args["directory"] = str(self.workspace.root)
|
||||
else:
|
||||
for pathlike in ["filename", "directory", "clone_path"]:
|
||||
if pathlike in command_args:
|
||||
command_args[pathlike] = str(
|
||||
self.workspace.get_path(command_args[pathlike])
|
||||
)
|
||||
return command_args
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
"""Agent manager for managing GPT agents"""
|
||||
from __future__ import annotations
|
||||
|
||||
from autogpt.config import Config
|
||||
from autogpt.llm.base import ChatSequence
|
||||
from autogpt.llm.chat import Message, create_chat_completion
|
||||
from autogpt.singleton import Singleton
|
||||
|
||||
|
||||
class AgentManager(metaclass=Singleton):
|
||||
"""Agent manager for managing GPT agents"""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.next_key = 0
|
||||
self.agents: dict[
|
||||
int, tuple[str, list[Message], str]
|
||||
] = {} # key, (task, full_message_history, model)
|
||||
self.config = config
|
||||
|
||||
# Create new GPT agent
|
||||
# TODO: Centralise use of create_chat_completion() to globally enforce token limit
|
||||
|
||||
def create_agent(
|
||||
self, task: str, creation_prompt: str, model: str
|
||||
) -> tuple[int, str]:
|
||||
"""Create a new agent and return its key
|
||||
|
||||
Args:
|
||||
task: The task to perform
|
||||
creation_prompt: Prompt passed to the LLM at creation
|
||||
model: The model to use to run this agent
|
||||
|
||||
Returns:
|
||||
The key of the new agent
|
||||
"""
|
||||
messages = ChatSequence.for_model(model, [Message("user", creation_prompt)])
|
||||
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_pre_instruction():
|
||||
continue
|
||||
if plugin_messages := plugin.pre_instruction(messages.raw()):
|
||||
messages.extend([Message(**raw_msg) for raw_msg in plugin_messages])
|
||||
# Start GPT instance
|
||||
agent_reply = create_chat_completion(
|
||||
prompt=messages, config=self.config
|
||||
).content
|
||||
|
||||
messages.add("assistant", agent_reply)
|
||||
|
||||
plugins_reply = ""
|
||||
for i, plugin in enumerate(self.config.plugins):
|
||||
if not plugin.can_handle_on_instruction():
|
||||
continue
|
||||
if plugin_result := plugin.on_instruction([m.raw() for m in messages]):
|
||||
sep = "\n" if i else ""
|
||||
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
|
||||
|
||||
if plugins_reply and plugins_reply != "":
|
||||
messages.add("assistant", plugins_reply)
|
||||
key = self.next_key
|
||||
# This is done instead of len(agents) to make keys unique even if agents
|
||||
# are deleted
|
||||
self.next_key += 1
|
||||
|
||||
self.agents[key] = (task, list(messages), model)
|
||||
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_post_instruction():
|
||||
continue
|
||||
agent_reply = plugin.post_instruction(agent_reply)
|
||||
|
||||
return key, agent_reply
|
||||
|
||||
def message_agent(self, key: str | int, message: str) -> str:
|
||||
"""Send a message to an agent and return its response
|
||||
|
||||
Args:
|
||||
key: The key of the agent to message
|
||||
message: The message to send to the agent
|
||||
|
||||
Returns:
|
||||
The agent's response
|
||||
"""
|
||||
task, messages, model = self.agents[int(key)]
|
||||
|
||||
# Add user message to message history before sending to agent
|
||||
messages = ChatSequence.for_model(model, messages)
|
||||
messages.add("user", message)
|
||||
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_pre_instruction():
|
||||
continue
|
||||
if plugin_messages := plugin.pre_instruction([m.raw() for m in messages]):
|
||||
messages.extend([Message(**raw_msg) for raw_msg in plugin_messages])
|
||||
|
||||
# Start GPT instance
|
||||
agent_reply = create_chat_completion(
|
||||
prompt=messages, config=self.config
|
||||
).content
|
||||
|
||||
messages.add("assistant", agent_reply)
|
||||
|
||||
plugins_reply = agent_reply
|
||||
for i, plugin in enumerate(self.config.plugins):
|
||||
if not plugin.can_handle_on_instruction():
|
||||
continue
|
||||
if plugin_result := plugin.on_instruction([m.raw() for m in messages]):
|
||||
sep = "\n" if i else ""
|
||||
plugins_reply = f"{plugins_reply}{sep}{plugin_result}"
|
||||
# Update full message history
|
||||
if plugins_reply and plugins_reply != "":
|
||||
messages.add("assistant", plugins_reply)
|
||||
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_post_instruction():
|
||||
continue
|
||||
agent_reply = plugin.post_instruction(agent_reply)
|
||||
|
||||
return agent_reply
|
||||
|
||||
def list_agents(self) -> list[tuple[str | int, str]]:
|
||||
"""Return a list of all agents
|
||||
|
||||
Returns:
|
||||
A list of tuples of the form (key, task)
|
||||
"""
|
||||
|
||||
# Return a list of agent keys and their tasks
|
||||
return [(key, task) for key, (task, _, _) in self.agents.items()]
|
||||
|
||||
def delete_agent(self, key: str | int) -> bool:
|
||||
"""Delete an agent from the agent manager
|
||||
|
||||
Args:
|
||||
key: The key of the agent to delete
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
|
||||
try:
|
||||
del self.agents[int(key)]
|
||||
return True
|
||||
except KeyError:
|
||||
return False
|
||||
64    autogpt/commands/decorators.py    Normal file
@@ -0,0 +1,64 @@
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from autogpt.agent.agent import Agent
|
||||
from autogpt.logs import logger
|
||||
|
||||
|
||||
def sanitize_path_arg(arg_name: str):
|
||||
def decorator(func: Callable):
|
||||
# Get position of path parameter, in case it is passed as a positional argument
|
||||
try:
|
||||
arg_index = list(func.__annotations__.keys()).index(arg_name)
|
||||
except ValueError:
|
||||
raise TypeError(
|
||||
f"Sanitized parameter '{arg_name}' absent or not annotated on function '{func.__name__}'"
|
||||
)
|
||||
|
||||
# Get position of agent parameter, in case it is passed as a positional argument
|
||||
try:
|
||||
agent_arg_index = list(func.__annotations__.keys()).index("agent")
|
||||
except ValueError:
|
||||
raise TypeError(
|
||||
f"Parameter 'agent' absent or not annotated on function '{func.__name__}'"
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
logger.debug(f"Sanitizing arg '{arg_name}' on function '{func.__name__}'")
|
||||
logger.debug(f"Function annotations: {func.__annotations__}")
|
||||
|
||||
# Get Agent from the called function's arguments
|
||||
agent = kwargs.get(
|
||||
"agent", len(args) > agent_arg_index and args[agent_arg_index]
|
||||
)
|
||||
logger.debug(f"Args: {args}")
|
||||
logger.debug(f"KWArgs: {kwargs}")
|
||||
logger.debug(f"Agent argument lifted from function call: {agent}")
|
||||
if not isinstance(agent, Agent):
|
||||
raise RuntimeError("Could not get Agent from decorated command's args")
|
||||
|
||||
# Sanitize the specified path argument, if one is given
|
||||
given_path: str | Path | None = kwargs.get(
|
||||
arg_name, len(args) > arg_index and args[arg_index] or None
|
||||
)
|
||||
if given_path:
|
||||
if given_path in {"", "/"}:
|
||||
sanitized_path = str(agent.workspace.root)
|
||||
else:
|
||||
sanitized_path = str(agent.workspace.get_path(given_path))
|
||||
|
||||
if arg_name in kwargs:
|
||||
kwargs[arg_name] = sanitized_path
|
||||
else:
|
||||
# args is an immutable tuple; must be converted to a list to update
|
||||
arg_list = list(args)
|
||||
arg_list[arg_index] = sanitized_path
|
||||
args = tuple(arg_list)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
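
For orientation, a minimal usage sketch of the new `sanitize_path_arg` decorator added above (assuming an `Agent` whose `workspace` exposes `root` and `get_path()`, as the decorator expects); the command body is simplified and illustrative only:

```python
# Illustrative only: a trimmed-down command using the decorator from this diff.
from pathlib import Path

from autogpt.agent.agent import Agent
from autogpt.commands.decorators import sanitize_path_arg


@sanitize_path_arg("filename")
def read_file(filename: str, agent: Agent) -> str:
    # By the time this body runs, `filename` has been rewritten to an absolute
    # path inside agent.workspace ("" or "/" resolve to the workspace root).
    return Path(filename).read_text()
```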
|
||||
@@ -12,6 +12,8 @@ from autogpt.command_decorator import command
|
||||
from autogpt.config import Config
|
||||
from autogpt.logs import logger
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
|
||||
ALLOWLIST_CONTROL = "allowlist"
|
||||
DENYLIST_CONTROL = "denylist"
|
||||
|
||||
@@ -43,14 +45,14 @@ def execute_python_code(code: str, name: str, agent: Agent) -> str:
|
||||
Returns:
|
||||
str: The STDOUT captured from the code when it ran
|
||||
"""
|
||||
ai_name = agent.ai_name
|
||||
ai_name = agent.ai_config.ai_name
|
||||
code_dir = agent.workspace.get_path(Path(ai_name, "executed_code"))
|
||||
os.makedirs(code_dir, exist_ok=True)
|
||||
|
||||
if not name.endswith(".py"):
|
||||
name = name + ".py"
|
||||
|
||||
# The `name` arg is not covered by Agent._resolve_pathlike_command_args(),
|
||||
# The `name` arg is not covered by @sanitize_path_arg,
|
||||
# so sanitization must be done here to prevent path traversal.
|
||||
file_path = agent.workspace.get_path(code_dir / name)
|
||||
if not file_path.is_relative_to(code_dir):
|
||||
@@ -76,6 +78,7 @@ def execute_python_code(code: str, name: str, agent: Agent) -> str:
|
||||
},
|
||||
},
|
||||
)
|
||||
@sanitize_path_arg("filename")
|
||||
def execute_python_file(filename: str, agent: Agent) -> str:
|
||||
"""Execute a Python file in a Docker container and return the output
|
||||
|
||||
@@ -100,6 +103,9 @@ def execute_python_file(filename: str, agent: Agent) -> str:
|
||||
)
|
||||
|
||||
if we_are_running_in_a_docker_container():
|
||||
logger.debug(
|
||||
f"Auto-GPT is running in a Docker container; executing {file_path} directly..."
|
||||
)
|
||||
result = subprocess.run(
|
||||
["python", str(file_path)],
|
||||
capture_output=True,
|
||||
@@ -111,6 +117,7 @@ def execute_python_file(filename: str, agent: Agent) -> str:
|
||||
else:
|
||||
return f"Error: {result.stderr}"
|
||||
|
||||
logger.debug("Auto-GPT is not running in a Docker container")
|
||||
try:
|
||||
client = docker.from_env()
|
||||
# You can replace this with the desired Python image/version
|
||||
@@ -119,10 +126,10 @@ def execute_python_file(filename: str, agent: Agent) -> str:
|
||||
image_name = "python:3-alpine"
|
||||
try:
|
||||
client.images.get(image_name)
|
||||
logger.warn(f"Image '{image_name}' found locally")
|
||||
logger.debug(f"Image '{image_name}' found locally")
|
||||
except ImageNotFound:
|
||||
logger.info(
|
||||
f"Image '{image_name}' not found locally, pulling from Docker Hub"
|
||||
f"Image '{image_name}' not found locally, pulling from Docker Hub..."
|
||||
)
|
||||
# Use the low-level API to stream the pull response
|
||||
low_level_client = docker.APIClient()
|
||||
@@ -135,6 +142,7 @@ def execute_python_file(filename: str, agent: Agent) -> str:
|
||||
elif status:
|
||||
logger.info(status)
|
||||
|
||||
logger.debug(f"Running {file_path} in a {image_name} container...")
|
||||
container: DockerContainer = client.containers.run(
|
||||
image_name,
|
||||
["python", str(file_path.relative_to(agent.workspace.root))],
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
"""File operations for AutoGPT"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import os
|
||||
import os.path
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
from confection import Config
|
||||
|
||||
from autogpt.agent.agent import Agent
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.commands.file_operations_utils import read_textual_file
|
||||
from autogpt.config import Config
|
||||
from autogpt.logs import logger
|
||||
from autogpt.memory.vector import MemoryItem, VectorMemory
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
from .file_operations_utils import read_textual_file
|
||||
|
||||
Operation = Literal["write", "append", "delete"]
|
||||
|
||||
|
||||
@@ -74,21 +75,26 @@ def file_operations_state(log_path: str) -> dict[str, str]:
|
||||
return state
|
||||
|
||||
|
||||
@sanitize_path_arg("filename")
|
||||
def is_duplicate_operation(
|
||||
operation: Operation, filename: str, config: Config, checksum: str | None = None
|
||||
operation: Operation, filename: str, agent: Agent, checksum: str | None = None
|
||||
) -> bool:
|
||||
"""Check if the operation has already been performed
|
||||
|
||||
Args:
|
||||
operation: The operation to check for
|
||||
filename: The name of the file to check for
|
||||
config: The agent config
|
||||
agent: The agent
|
||||
checksum: The checksum of the contents to be written
|
||||
|
||||
Returns:
|
||||
True if the operation has already been performed on the file
|
||||
"""
|
||||
state = file_operations_state(config.file_logger_path)
|
||||
# Make the filename into a relative path if possible
|
||||
with contextlib.suppress(ValueError):
|
||||
filename = str(Path(filename).relative_to(agent.workspace.root))
|
||||
|
||||
state = file_operations_state(agent.config.file_logger_path)
|
||||
if operation == "delete" and filename not in state:
|
||||
return True
|
||||
if operation == "write" and state.get(filename) == checksum:
|
||||
@@ -96,8 +102,9 @@ def is_duplicate_operation(
|
||||
return False
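
The duplicate-operation check above works by comparing a checksum of the content to be written against the last state recorded in the file operations log. A self-contained sketch of that idea (the hash algorithm is not shown in this hunk, so md5 here is an assumption; names are illustrative):

```python
import hashlib


def text_checksum(text: str) -> str:
    # Hash of the content to be written; md5 is assumed, not confirmed by this diff.
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def is_duplicate_write(filename: str, text: str, state: dict[str, str]) -> bool:
    # `state` maps filename -> checksum of the last logged write for that file.
    return state.get(filename) == text_checksum(text)


# Example: a repeated identical write is flagged, a changed write is not.
state = {"notes.txt": text_checksum("hello")}
assert is_duplicate_write("notes.txt", "hello", state)
assert not is_duplicate_write("notes.txt", "hello world", state)
```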
|
||||
|
||||
|
||||
@sanitize_path_arg("filename")
|
||||
def log_operation(
|
||||
operation: str, filename: str, agent: Agent, checksum: str | None = None
|
||||
operation: Operation, filename: str, agent: Agent, checksum: str | None = None
|
||||
) -> None:
|
||||
"""Log the file operation to the file_logger.txt
|
||||
|
||||
@@ -106,6 +113,10 @@ def log_operation(
|
||||
filename: The name of the file the operation was performed on
|
||||
checksum: The checksum of the contents to be written
|
||||
"""
|
||||
# Make the filename into a relative path if possible
|
||||
with contextlib.suppress(ValueError):
|
||||
filename = str(Path(filename).relative_to(agent.workspace.root))
|
||||
|
||||
log_entry = f"{operation}: {filename}"
|
||||
if checksum is not None:
|
||||
log_entry += f" #{checksum}"
|
||||
@@ -126,6 +137,7 @@ def log_operation(
|
||||
}
|
||||
},
|
||||
)
|
||||
@sanitize_path_arg("filename")
|
||||
def read_file(filename: str, agent: Agent) -> str:
|
||||
"""Read a file and return the contents
|
||||
|
||||
@@ -191,6 +203,7 @@ def ingest_file(
|
||||
},
|
||||
aliases=["write_file", "create_file"],
|
||||
)
|
||||
@sanitize_path_arg("filename")
|
||||
def write_to_file(filename: str, text: str, agent: Agent) -> str:
|
||||
"""Write text to a file
|
||||
|
||||
@@ -202,7 +215,7 @@ def write_to_file(filename: str, text: str, agent: Agent) -> str:
|
||||
str: A message indicating success or failure
|
||||
"""
|
||||
checksum = text_checksum(text)
|
||||
if is_duplicate_operation("write", filename, agent.config, checksum):
|
||||
if is_duplicate_operation("write", filename, agent, checksum):
|
||||
return "Error: File has already been updated."
|
||||
try:
|
||||
directory = os.path.dirname(filename)
|
||||
@@ -231,6 +244,7 @@ def write_to_file(filename: str, text: str, agent: Agent) -> str:
|
||||
},
|
||||
},
|
||||
)
|
||||
@sanitize_path_arg("filename")
|
||||
def append_to_file(
|
||||
filename: str, text: str, agent: Agent, should_log: bool = True
|
||||
) -> str:
|
||||
@@ -271,6 +285,7 @@ def append_to_file(
|
||||
}
|
||||
},
|
||||
)
|
||||
@sanitize_path_arg("filename")
|
||||
def delete_file(filename: str, agent: Agent) -> str:
|
||||
"""Delete a file
|
||||
|
||||
@@ -280,7 +295,7 @@ def delete_file(filename: str, agent: Agent) -> str:
|
||||
Returns:
|
||||
str: A message indicating success or failure
|
||||
"""
|
||||
if is_duplicate_operation("delete", filename, agent.config):
|
||||
if is_duplicate_operation("delete", filename, agent):
|
||||
return "Error: File has already been deleted."
|
||||
try:
|
||||
os.remove(filename)
|
||||
@@ -301,6 +316,7 @@ def delete_file(filename: str, agent: Agent) -> str:
|
||||
}
|
||||
},
|
||||
)
|
||||
@sanitize_path_arg("directory")
|
||||
def list_files(directory: str, agent: Agent) -> list[str]:
|
||||
"""lists files in a directory recursively
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ from autogpt.agent.agent import Agent
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.url_utils.validators import validate_url
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
|
||||
|
||||
@command(
|
||||
"clone_repository",
|
||||
@@ -22,9 +24,10 @@ from autogpt.url_utils.validators import validate_url
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
lambda config: config.github_username and config.github_api_key,
|
||||
lambda config: bool(config.github_username and config.github_api_key),
|
||||
"Configure github_username and github_api_key.",
|
||||
)
|
||||
@sanitize_path_arg("clone_path")
|
||||
@validate_url
|
||||
def clone_repository(url: str, clone_path: str, agent: Agent) -> str:
|
||||
"""Clone a GitHub repository locally.
|
||||
|
||||
@@ -24,7 +24,7 @@ from autogpt.logs import logger
|
||||
"required": True,
|
||||
},
|
||||
},
|
||||
lambda config: config.image_provider,
|
||||
lambda config: bool(config.image_provider),
|
||||
"Requires a image provider to be set.",
|
||||
)
|
||||
def generate_image(prompt: str, agent: Agent, size: int = 256) -> str:
|
||||
|
||||
@@ -4,87 +4,145 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import yaml
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from colorama import Fore
|
||||
from pydantic import Field, validator
|
||||
|
||||
from autogpt.core.configuration.schema import Configurable, SystemSettings
|
||||
from autogpt.plugins.plugins_config import PluginsConfig
|
||||
|
||||
AZURE_CONFIG_FILE = os.path.join(os.path.dirname(__file__), "../..", "azure.yaml")
|
||||
PLUGINS_CONFIG_FILE = os.path.join(
|
||||
os.path.dirname(__file__), "../..", "plugins_config.yaml"
|
||||
)
|
||||
GPT_4_MODEL = "gpt-4"
|
||||
GPT_3_MODEL = "gpt-3.5-turbo"
|
||||
|
||||
|
||||
class Config(SystemSettings):
|
||||
fast_llm: str
|
||||
smart_llm: str
|
||||
continuous_mode: bool
|
||||
skip_news: bool
|
||||
class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
name: str = "Auto-GPT configuration"
|
||||
description: str = "Default configuration for the Auto-GPT application."
|
||||
########################
|
||||
# Application Settings #
|
||||
########################
|
||||
skip_news: bool = False
|
||||
skip_reprompt: bool = False
|
||||
authorise_key: str = "y"
|
||||
exit_key: str = "n"
|
||||
debug_mode: bool = False
|
||||
plain_output: bool = False
|
||||
chat_messages_enabled: bool = True
|
||||
# TTS configuration
|
||||
speak_mode: bool = False
|
||||
text_to_speech_provider: str = "gtts"
|
||||
streamelements_voice: str = "Brian"
|
||||
elevenlabs_voice_id: Optional[str] = None
|
||||
|
||||
##########################
|
||||
# Agent Control Settings #
|
||||
##########################
|
||||
# Paths
|
||||
ai_settings_file: str = "ai_settings.yaml"
|
||||
prompt_settings_file: str = "prompt_settings.yaml"
|
||||
workspace_path: Optional[str] = None
|
||||
file_logger_path: Optional[str] = None
|
||||
debug_mode: bool
|
||||
plugins_dir: str
|
||||
plugins_config: PluginsConfig
|
||||
continuous_limit: int
|
||||
speak_mode: bool
|
||||
skip_reprompt: bool
|
||||
allow_downloads: bool
|
||||
exit_key: str
|
||||
plain_output: bool
|
||||
disabled_command_categories: list[str]
|
||||
shell_command_control: str
|
||||
shell_denylist: list[str]
|
||||
shell_allowlist: list[str]
|
||||
ai_settings_file: str
|
||||
prompt_settings_file: str
|
||||
embedding_model: str
|
||||
browse_spacy_language_model: str
|
||||
# Model configuration
|
||||
fast_llm: str = "gpt-3.5-turbo"
|
||||
smart_llm: str = "gpt-4"
|
||||
temperature: float = 0
|
||||
openai_functions: bool = False
|
||||
embedding_model: str = "text-embedding-ada-002"
|
||||
browse_spacy_language_model: str = "en_core_web_sm"
|
||||
# Run loop configuration
|
||||
continuous_mode: bool = False
|
||||
continuous_limit: int = 0
|
||||
|
||||
##########
|
||||
# Memory #
|
||||
##########
|
||||
memory_backend: str = "json_file"
|
||||
memory_index: str = "auto-gpt-memory"
|
||||
redis_host: str = "localhost"
|
||||
redis_port: int = 6379
|
||||
redis_password: str = ""
|
||||
wipe_redis_on_start: bool = True
|
||||
|
||||
############
|
||||
# Commands #
|
||||
############
|
||||
# General
|
||||
disabled_command_categories: list[str] = Field(default_factory=list)
|
||||
# File ops
|
||||
restrict_to_workspace: bool = True
|
||||
allow_downloads: bool = False
|
||||
# Shell commands
|
||||
shell_command_control: str = "denylist"
|
||||
execute_local_commands: bool = False
|
||||
shell_denylist: list[str] = Field(default_factory=lambda: ["sudo", "su"])
|
||||
shell_allowlist: list[str] = Field(default_factory=list)
|
||||
# Text to image
|
||||
image_provider: Optional[str] = None
|
||||
huggingface_image_model: str = "CompVis/stable-diffusion-v1-4"
|
||||
sd_webui_url: Optional[str] = "http://localhost:7860"
|
||||
image_size: int = 256
|
||||
# Audio to text
|
||||
audio_to_text_provider: str = "huggingface"
|
||||
huggingface_audio_to_text_model: Optional[str] = None
|
||||
# Web browsing
|
||||
selenium_web_browser: str = "chrome"
|
||||
selenium_headless: bool = True
|
||||
user_agent: str = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"
|
||||
|
||||
###################
|
||||
# Plugin Settings #
|
||||
###################
|
||||
plugins_dir: str = "plugins"
|
||||
plugins_config_file: str = PLUGINS_CONFIG_FILE
|
||||
plugins_config: PluginsConfig = Field(
|
||||
default_factory=lambda: PluginsConfig(plugins={})
|
||||
)
|
||||
plugins: list[AutoGPTPluginTemplate] = Field(default_factory=list, exclude=True)
|
||||
plugins_allowlist: list[str] = Field(default_factory=list)
|
||||
plugins_denylist: list[str] = Field(default_factory=list)
|
||||
plugins_openai: list[str] = Field(default_factory=list)
|
||||
|
||||
###############
|
||||
# Credentials #
|
||||
###############
|
||||
# OpenAI
|
||||
openai_api_key: Optional[str] = None
|
||||
openai_organization: Optional[str] = None
|
||||
temperature: float
|
||||
use_azure: bool
|
||||
azure_config_file: Optional[str] = None
|
||||
azure_model_to_deployment_id_map: Optional[Dict[str, str]] = None
|
||||
execute_local_commands: bool
|
||||
restrict_to_workspace: bool
|
||||
openai_api_type: Optional[str] = None
|
||||
openai_api_base: Optional[str] = None
|
||||
openai_api_version: Optional[str] = None
|
||||
openai_functions: bool
|
||||
openai_organization: Optional[str] = None
|
||||
use_azure: bool = False
|
||||
azure_config_file: Optional[str] = AZURE_CONFIG_FILE
|
||||
azure_model_to_deployment_id_map: Optional[Dict[str, str]] = None
|
||||
# Elevenlabs
|
||||
elevenlabs_api_key: Optional[str] = None
|
||||
streamelements_voice: str
|
||||
text_to_speech_provider: str
|
||||
# Github
|
||||
github_api_key: Optional[str] = None
|
||||
github_username: Optional[str] = None
|
||||
# Google
|
||||
google_api_key: Optional[str] = None
|
||||
google_custom_search_engine_id: Optional[str] = None
|
||||
image_provider: Optional[str] = None
|
||||
image_size: int
|
||||
# Huggingface
|
||||
huggingface_api_token: Optional[str] = None
|
||||
huggingface_image_model: str
|
||||
audio_to_text_provider: str
|
||||
huggingface_audio_to_text_model: Optional[str] = None
|
||||
sd_webui_url: Optional[str] = None
|
||||
# Stable Diffusion
|
||||
sd_webui_auth: Optional[str] = None
|
||||
selenium_web_browser: str
|
||||
selenium_headless: bool
|
||||
user_agent: str
|
||||
memory_backend: str
|
||||
memory_index: str
|
||||
redis_host: str
|
||||
redis_port: int
|
||||
redis_password: str
|
||||
wipe_redis_on_start: bool
|
||||
plugins_allowlist: list[str]
|
||||
plugins_denylist: list[str]
|
||||
plugins_openai: list[str]
|
||||
plugins_config_file: str
|
||||
chat_messages_enabled: bool
|
||||
elevenlabs_voice_id: Optional[str] = None
|
||||
plugins: list[str]
|
||||
authorise_key: str
|
||||
|
||||
@validator("plugins", each_item=True)
|
||||
def validate_plugins(cls, p: AutoGPTPluginTemplate | Any):
|
||||
assert issubclass(
|
||||
p.__class__, AutoGPTPluginTemplate
|
||||
), f"{p} does not subclass AutoGPTPluginTemplate"
|
||||
assert (
|
||||
p.__class__.__name__ != "AutoGPTPluginTemplate"
|
||||
), f"Plugins must subclass AutoGPTPluginTemplate; {p} is a template instance"
|
||||
return p
|
||||
|
||||
def get_openai_credentials(self, model: str) -> dict[str, str]:
|
||||
credentials = {
|
||||
@@ -149,73 +207,7 @@ class Config(SystemSettings):
|
||||
|
||||
|
||||
class ConfigBuilder(Configurable[Config]):
|
||||
default_plugins_config_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "..", "..", "plugins_config.yaml"
|
||||
)
|
||||
|
||||
elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
|
||||
if os.getenv("USE_MAC_OS_TTS"):
|
||||
default_tts_provider = "macos"
|
||||
elif elevenlabs_api_key:
|
||||
default_tts_provider = "elevenlabs"
|
||||
elif os.getenv("USE_BRIAN_TTS"):
|
||||
default_tts_provider = "streamelements"
|
||||
else:
|
||||
default_tts_provider = "gtts"
|
||||
|
||||
default_settings = Config(
|
||||
name="Default Server Config",
|
||||
description="This is a default server configuration",
|
||||
smart_llm="gpt-4",
|
||||
fast_llm="gpt-3.5-turbo",
|
||||
continuous_mode=False,
|
||||
continuous_limit=0,
|
||||
skip_news=False,
|
||||
debug_mode=False,
|
||||
plugins_dir="plugins",
|
||||
plugins_config=PluginsConfig(plugins={}),
|
||||
speak_mode=False,
|
||||
skip_reprompt=False,
|
||||
allow_downloads=False,
|
||||
exit_key="n",
|
||||
plain_output=False,
|
||||
disabled_command_categories=[],
|
||||
shell_command_control="denylist",
|
||||
shell_denylist=["sudo", "su"],
|
||||
shell_allowlist=[],
|
||||
ai_settings_file="ai_settings.yaml",
|
||||
prompt_settings_file="prompt_settings.yaml",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
browse_spacy_language_model="en_core_web_sm",
|
||||
temperature=0,
|
||||
use_azure=False,
|
||||
azure_config_file=AZURE_CONFIG_FILE,
|
||||
execute_local_commands=False,
|
||||
restrict_to_workspace=True,
|
||||
openai_functions=False,
|
||||
streamelements_voice="Brian",
|
||||
text_to_speech_provider=default_tts_provider,
|
||||
image_size=256,
|
||||
huggingface_image_model="CompVis/stable-diffusion-v1-4",
|
||||
audio_to_text_provider="huggingface",
|
||||
sd_webui_url="http://localhost:7860",
|
||||
selenium_web_browser="chrome",
|
||||
selenium_headless=True,
|
||||
user_agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
|
||||
memory_backend="json_file",
|
||||
memory_index="auto-gpt-memory",
|
||||
redis_host="localhost",
|
||||
redis_port=6379,
|
||||
wipe_redis_on_start=True,
|
||||
plugins_allowlist=[],
|
||||
plugins_denylist=[],
|
||||
plugins_openai=[],
|
||||
plugins_config_file=default_plugins_config_file,
|
||||
chat_messages_enabled=True,
|
||||
plugins=[],
|
||||
authorise_key="y",
|
||||
redis_password="",
|
||||
)
|
||||
default_settings = Config()
|
||||
|
||||
@classmethod
|
||||
def build_config_from_env(cls) -> Config:
|
||||
@@ -285,6 +277,16 @@ class ConfigBuilder(Configurable[Config]):
|
||||
config_dict["elevenlabs_voice_id"] = os.getenv(
|
||||
"ELEVENLABS_VOICE_ID", os.getenv("ELEVENLABS_VOICE_1_ID")
|
||||
)
|
||||
elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
|
||||
if os.getenv("USE_MAC_OS_TTS"):
|
||||
default_tts_provider = "macos"
|
||||
elif elevenlabs_api_key:
|
||||
default_tts_provider = "elevenlabs"
|
||||
elif os.getenv("USE_BRIAN_TTS"):
|
||||
default_tts_provider = "streamelements"
|
||||
else:
|
||||
default_tts_provider = "gtts"
|
||||
config_dict["text_to_speech_provider"] = default_tts_provider
|
||||
|
||||
config_dict["plugins_allowlist"] = _safe_split(os.getenv("ALLOWLISTED_PLUGINS"))
|
||||
config_dict["plugins_denylist"] = _safe_split(os.getenv("DENYLISTED_PLUGINS"))
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from math import ceil, floor
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, TypedDict
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Type, TypedDict, TypeVar, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.llm.providers.openai import OpenAIFunctionCall
|
||||
|
||||
MessageRole = Literal["system", "user", "assistant"]
|
||||
MessageRole = Literal["system", "user", "assistant", "function"]
|
||||
MessageType = Literal["ai_response", "action_result"]
|
||||
|
||||
TText = list[int]
|
||||
@@ -19,6 +20,17 @@ class MessageDict(TypedDict):
|
||||
content: str
|
||||
|
||||
|
||||
class ResponseMessageDict(TypedDict):
|
||||
role: Literal["assistant"]
|
||||
content: Optional[str]
|
||||
function_call: Optional[FunctionCallDict]
|
||||
|
||||
|
||||
class FunctionCallDict(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""OpenAI Message object containing a role and the message content"""
|
||||
@@ -68,15 +80,31 @@ class EmbeddingModelInfo(ModelInfo):
|
||||
embedding_dimensions: int
|
||||
|
||||
|
||||
# Can be replaced by Self in Python 3.11
|
||||
TChatSequence = TypeVar("TChatSequence", bound="ChatSequence")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSequence:
|
||||
"""Utility container for a chat sequence"""
|
||||
|
||||
model: ChatModelInfo
|
||||
messages: list[Message] = field(default_factory=list)
|
||||
messages: list[Message] = field(default_factory=list[Message])
|
||||
|
||||
def __getitem__(self, i: int):
|
||||
return self.messages[i]
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> Message:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self: TChatSequence, key: slice) -> TChatSequence:
|
||||
...
|
||||
|
||||
def __getitem__(self: TChatSequence, key: int | slice) -> Message | TChatSequence:
|
||||
if isinstance(key, slice):
|
||||
copy = deepcopy(self)
|
||||
copy.messages = self.messages[key]
|
||||
return copy
|
||||
return self.messages[key]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.messages)
|
||||
@@ -84,6 +112,14 @@ class ChatSequence:
|
||||
def __len__(self):
|
||||
return len(self.messages)
|
||||
|
||||
def add(
|
||||
self,
|
||||
message_role: MessageRole,
|
||||
content: str,
|
||||
type: MessageType | None = None,
|
||||
) -> None:
|
||||
self.append(Message(message_role, content, type))
|
||||
|
||||
def append(self, message: Message):
|
||||
return self.messages.append(message)
|
||||
|
||||
@@ -95,21 +131,23 @@ class ChatSequence:
|
||||
self.messages.insert(index, message)
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, messages: list[Message] | ChatSequence = []):
|
||||
def for_model(
|
||||
cls: Type[TChatSequence],
|
||||
model_name: str,
|
||||
messages: list[Message] | ChatSequence = [],
|
||||
**kwargs,
|
||||
) -> TChatSequence:
|
||||
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
|
||||
|
||||
if not model_name in OPEN_AI_CHAT_MODELS:
|
||||
raise ValueError(f"Unknown chat model '{model_name}'")
|
||||
|
||||
return ChatSequence(
|
||||
model=OPEN_AI_CHAT_MODELS[model_name], messages=list(messages)
|
||||
return cls(
|
||||
model=OPEN_AI_CHAT_MODELS[model_name], messages=list(messages), **kwargs
|
||||
)
|
||||
|
||||
def add(self, message_role: MessageRole, content: str):
|
||||
self.messages.append(Message(message_role, content))
|
||||
|
||||
@property
|
||||
def token_length(self):
|
||||
def token_length(self) -> int:
|
||||
from autogpt.llm.utils import count_message_tokens
|
||||
|
||||
return count_message_tokens(self.messages, self.model.name)
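
A hedged usage sketch of the updated `ChatSequence.for_model` (which now returns an instance of the calling class and forwards extra keyword arguments) together with the new slicing behaviour; the model name is only an example and must exist in `OPEN_AI_CHAT_MODELS`:

```python
from autogpt.llm.base import ChatSequence, Message

# Build a sequence bound to a known chat model.
prompt = ChatSequence.for_model(
    "gpt-3.5-turbo",
    [Message("system", "You are a helpful assistant.")],
)
prompt.add("user", "Summarise the last run.")

# Slicing now returns another ChatSequence (see the __getitem__ overloads above),
# while indexing with an int still returns a single Message.
tail = prompt[-2:]
print(prompt.token_length, len(tail))
```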
|
||||
@@ -128,7 +166,7 @@ class ChatSequence:
|
||||
[f"{separator(m.role)}\n{m.content}" for m in self.messages]
|
||||
)
|
||||
return f"""
|
||||
============== ChatSequence ==============
|
||||
============== {__class__.__name__} ==============
|
||||
Length: {self.token_length} tokens; {len(self.messages)} messages
|
||||
{formatted_messages}
|
||||
==========================================
|
||||
@@ -140,24 +178,18 @@ class LLMResponse:
|
||||
"""Standard response struct for a response from an LLM model."""
|
||||
|
||||
model_info: ModelInfo
|
||||
prompt_tokens_used: int = 0
|
||||
completion_tokens_used: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelResponse(LLMResponse):
|
||||
"""Standard response struct for a response from an embedding model."""
|
||||
|
||||
embedding: List[float] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.completion_tokens_used:
|
||||
raise ValueError("Embeddings should not have completion tokens used.")
|
||||
embedding: list[float] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatModelResponse(LLMResponse):
|
||||
"""Standard response struct for a response from an LLM model."""
|
||||
"""Standard response struct for a response from a chat LLM."""
|
||||
|
||||
content: Optional[str] = None
|
||||
function_call: Optional[OpenAIFunctionCall] = None
|
||||
content: Optional[str]
|
||||
function_call: Optional[OpenAIFunctionCall]
|
||||
|
||||
@@ -3,17 +3,18 @@ from __future__ import annotations
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt.llm.providers.openai import get_openai_command_specs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agent.agent import Agent
|
||||
|
||||
from autogpt.config import Config
|
||||
from autogpt.llm.api_manager import ApiManager
|
||||
from autogpt.llm.base import ChatSequence, Message
|
||||
from autogpt.llm.providers.openai import (
|
||||
count_openai_functions_tokens,
|
||||
get_openai_command_specs,
|
||||
)
|
||||
from autogpt.llm.utils import count_message_tokens, create_chat_completion
|
||||
from autogpt.log_cycle.log_cycle import CURRENT_CONTEXT_FILE_NAME
|
||||
from autogpt.logs import logger
|
||||
from autogpt.logs import CURRENT_CONTEXT_FILE_NAME, logger
|
||||
|
||||
|
||||
# TODO: Change debug from hardcode to argument
|
||||
@@ -73,33 +74,28 @@ def chat_with_ai(
|
||||
],
|
||||
)
|
||||
|
||||
# Add messages from the full message history until we reach the token limit
|
||||
next_message_to_add_index = len(agent.history) - 1
|
||||
insertion_index = len(message_sequence)
|
||||
# Count the currently used tokens
|
||||
current_tokens_used = message_sequence.token_length
|
||||
insertion_index = len(message_sequence)
|
||||
|
||||
# while current_tokens_used > 2500:
|
||||
# # remove memories until we are under 2500 tokens
|
||||
# relevant_memory = relevant_memory[:-1]
|
||||
# (
|
||||
# next_message_to_add_index,
|
||||
# current_tokens_used,
|
||||
# insertion_index,
|
||||
# current_context,
|
||||
# ) = generate_context(
|
||||
# prompt, relevant_memory, agent.history, model
|
||||
# )
|
||||
# Account for tokens used by OpenAI functions
|
||||
openai_functions = None
|
||||
if agent.config.openai_functions:
|
||||
openai_functions = get_openai_command_specs(agent.command_registry)
|
||||
functions_tlength = count_openai_functions_tokens(openai_functions, model)
|
||||
current_tokens_used += functions_tlength
|
||||
logger.debug(f"OpenAI Functions take up {functions_tlength} tokens in API call")
|
||||
|
||||
# Account for user input (appended later)
|
||||
user_input_msg = Message("user", triggering_prompt)
|
||||
current_tokens_used += count_message_tokens([user_input_msg], model)
|
||||
current_tokens_used += count_message_tokens(user_input_msg, model)
|
||||
|
||||
current_tokens_used += 500 # Reserve space for new_summary_message
|
||||
current_tokens_used += agent.history.max_summary_tlength # Reserve space
|
||||
current_tokens_used += 500 # Reserve space for the openai functions TODO improve
|
||||
|
||||
# Add Messages until the token limit is reached or there are no more messages to add.
|
||||
for cycle in reversed(list(agent.history.per_cycle(agent.config))):
|
||||
# Add historical Messages until the token limit is reached
|
||||
# or there are no more messages to add.
|
||||
for cycle in reversed(list(agent.history.per_cycle())):
|
||||
messages_to_add = [msg for msg in cycle if msg is not None]
|
||||
tokens_to_add = count_message_tokens(messages_to_add, model)
|
||||
if current_tokens_used + tokens_to_add > send_token_limit:
|
||||
@@ -115,9 +111,9 @@ def chat_with_ai(
|
||||
new_summary_message, trimmed_messages = agent.history.trim_messages(
|
||||
current_message_chain=list(message_sequence), config=agent.config
|
||||
)
|
||||
tokens_to_add = count_message_tokens([new_summary_message], model)
|
||||
tokens_to_add = count_message_tokens(new_summary_message, model)
|
||||
message_sequence.insert(insertion_index, new_summary_message)
|
||||
current_tokens_used += tokens_to_add - 500
|
||||
current_tokens_used += tokens_to_add - agent.history.max_summary_tlength
|
||||
|
||||
# FIXME: uncomment when memory is back in use
|
||||
# memory_store = get_memory(config)
|
||||
@@ -143,7 +139,7 @@ def chat_with_ai(
|
||||
)
|
||||
logger.debug(budget_message)
|
||||
message_sequence.add("system", budget_message)
|
||||
current_tokens_used += count_message_tokens([message_sequence[-1]], model)
|
||||
current_tokens_used += count_message_tokens(message_sequence[-1], model)
|
||||
|
||||
# Append user input, the length of this is accounted for above
|
||||
message_sequence.append(user_input_msg)
|
||||
@@ -157,14 +153,14 @@ def chat_with_ai(
|
||||
)
|
||||
if not plugin_response or plugin_response == "":
|
||||
continue
|
||||
tokens_to_add = count_message_tokens(
|
||||
[Message("system", plugin_response)], model
|
||||
)
|
||||
tokens_to_add = count_message_tokens(Message("system", plugin_response), model)
|
||||
if current_tokens_used + tokens_to_add > send_token_limit:
|
||||
logger.debug(f"Plugin response too long, skipping: {plugin_response}")
|
||||
logger.debug(f"Plugins remaining at stop: {plugin_count - i}")
|
||||
break
|
||||
message_sequence.add("system", plugin_response)
|
||||
current_tokens_used += tokens_to_add
|
||||
|
||||
# Calculate remaining tokens
|
||||
tokens_remaining = token_limit - current_tokens_used
|
||||
# assert tokens_remaining >= 0, "Tokens remaining is negative.
|
||||
@@ -196,7 +192,7 @@ def chat_with_ai(
|
||||
assistant_reply = create_chat_completion(
|
||||
prompt=message_sequence,
|
||||
config=agent.config,
|
||||
functions=get_openai_command_specs(agent),
|
||||
functions=openai_functions,
|
||||
max_tokens=tokens_remaining,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import functools
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import Callable, List, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import openai
|
||||
@@ -12,9 +12,6 @@ from colorama import Fore, Style
|
||||
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
|
||||
from openai.openai_object import OpenAIObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agent.agent import Agent
|
||||
|
||||
from autogpt.llm.base import (
|
||||
ChatModelInfo,
|
||||
EmbeddingModelInfo,
|
||||
@@ -23,6 +20,7 @@ from autogpt.llm.base import (
|
||||
TText,
|
||||
)
|
||||
from autogpt.logs import logger
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
|
||||
OPEN_AI_CHAT_MODELS = {
|
||||
info.name: info
|
||||
@@ -114,7 +112,7 @@ OPEN_AI_MODELS: dict[str, ChatModelInfo | EmbeddingModelInfo | TextModelInfo] =
|
||||
}
|
||||
|
||||
|
||||
def meter_api(func):
|
||||
def meter_api(func: Callable):
|
||||
"""Adds ApiManager metering to functions which make OpenAI API calls"""
|
||||
from autogpt.llm.api_manager import ApiManager
|
||||
|
||||
@@ -152,7 +150,7 @@ def meter_api(func):
|
||||
|
||||
|
||||
def retry_api(
|
||||
num_retries: int = 10,
|
||||
max_retries: int = 10,
|
||||
backoff_base: float = 2.0,
|
||||
warn_user: bool = True,
|
||||
):
|
||||
@@ -164,43 +162,49 @@ def retry_api(
|
||||
warn_user bool: Whether to warn the user. Defaults to True.
|
||||
"""
|
||||
error_messages = {
|
||||
ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded, passing...{Fore.RESET}",
|
||||
RateLimitError: f"{Fore.RED}Error: Reached rate limit, passing...{Fore.RESET}",
|
||||
ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded{Fore.RESET}",
|
||||
RateLimitError: f"{Fore.RED}Error: Reached rate limit{Fore.RESET}",
|
||||
}
|
||||
api_key_error_msg = (
|
||||
f"Please double check that you have setup a "
|
||||
f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
|
||||
f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
|
||||
)
|
||||
backoff_msg = (
|
||||
f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
|
||||
)
|
||||
backoff_msg = f"{Fore.RED}Waiting {{backoff}} seconds...{Fore.RESET}"
|
||||
|
||||
def _wrapper(func):
|
||||
def _wrapper(func: Callable):
|
||||
@functools.wraps(func)
|
||||
def _wrapped(*args, **kwargs):
|
||||
user_warned = not warn_user
|
||||
num_attempts = num_retries + 1 # +1 for the first attempt
|
||||
for attempt in range(1, num_attempts + 1):
|
||||
max_attempts = max_retries + 1 # +1 for the first attempt
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
except (RateLimitError, ServiceUnavailableError) as e:
|
||||
if attempt == num_attempts:
|
||||
if attempt >= max_attempts or (
|
||||
# User's API quota exceeded
|
||||
isinstance(e, RateLimitError)
|
||||
and (err := getattr(e, "error", {}))
|
||||
and err.get("code") == "insufficient_quota"
|
||||
):
|
||||
raise
|
||||
|
||||
error_msg = error_messages[type(e)]
|
||||
logger.debug(error_msg)
|
||||
logger.warn(error_msg)
|
||||
if not user_warned:
|
||||
logger.double_check(api_key_error_msg)
|
||||
logger.debug(f"Status: {e.http_status}")
|
||||
logger.debug(f"Response body: {e.json_body}")
|
||||
logger.debug(f"Response headers: {e.headers}")
|
||||
user_warned = True
|
||||
|
||||
except (APIError, Timeout) as e:
|
||||
if (e.http_status not in [429, 502]) or (attempt == num_attempts):
|
||||
if (e.http_status not in [429, 502]) or (attempt == max_attempts):
|
||||
raise
|
||||
|
||||
backoff = backoff_base ** (attempt + 2)
|
||||
logger.debug(backoff_msg.format(backoff=backoff))
|
||||
logger.warn(backoff_msg.format(backoff=backoff))
|
||||
time.sleep(backoff)
|
||||
|
||||
return _wrapped
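
The retry wrapper above grows its delay as `backoff_base ** (attempt + 2)` and re-raises once the attempt budget is spent (or immediately on an insufficient-quota error). Below is a standalone sketch of the same backoff pattern, decoupled from the OpenAI client; the error type is a placeholder, not a real library exception:

```python
import functools
import time


class TransientError(Exception):
    """Placeholder for errors worth retrying (rate limits, bad gateways, ...)."""


def retry_with_backoff(max_retries: int = 3, backoff_base: float = 2.0):
    """Retry a callable on TransientError with exponential backoff (illustrative)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            max_attempts = max_retries + 1  # +1 for the first attempt
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except TransientError:
                    if attempt >= max_attempts:
                        raise
                    # Same growth curve as retry_api(): backoff_base ** (attempt + 2)
                    time.sleep(backoff_base ** (attempt + 2))

        return wrapped

    return decorator
```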
|
||||
@@ -301,13 +305,13 @@ class OpenAIFunctionSpec:
|
||||
@dataclass
|
||||
class ParameterSpec:
|
||||
name: str
|
||||
type: str
|
||||
type: str # TODO: add enum support
|
||||
description: Optional[str]
|
||||
required: bool = False
|
||||
|
||||
@property
|
||||
def __dict__(self):
|
||||
"""Output an OpenAI-consumable function specification"""
|
||||
def schema(self) -> dict[str, str | dict | list]:
|
||||
"""Returns an OpenAI-consumable function specification"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
@@ -326,14 +330,44 @@ class OpenAIFunctionSpec:
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def prompt_format(self) -> str:
|
||||
"""Returns the function formatted similarly to the way OpenAI does it internally:
|
||||
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
|
||||
|
||||
def get_openai_command_specs(agent: Agent) -> list[OpenAIFunctionSpec]:
|
||||
Example:
|
||||
```ts
|
||||
// Get the current weather in a given location
|
||||
type get_current_weather = (_: {
|
||||
// The city and state, e.g. San Francisco, CA
|
||||
location: string,
|
||||
unit?: "celsius" | "fahrenheit",
|
||||
}) => any;
|
||||
```
|
||||
"""
|
||||
|
||||
def param_signature(p_spec: OpenAIFunctionSpec.ParameterSpec) -> str:
|
||||
# TODO: enum type support
|
||||
return (
|
||||
f"// {p_spec.description}\n" if p_spec.description else ""
|
||||
) + f"{p_spec.name}{'' if p_spec.required else '?'}: {p_spec.type},"
|
||||
|
||||
return "\n".join(
|
||||
[
|
||||
f"// {self.description}",
|
||||
f"type {self.name} = (_ :{{",
|
||||
*[param_signature(p) for p in self.parameters.values()],
|
||||
"}) => any;",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_openai_command_specs(
|
||||
command_registry: CommandRegistry,
|
||||
) -> list[OpenAIFunctionSpec]:
|
||||
"""Get OpenAI-consumable function specs for the agent's available commands.
|
||||
see https://platform.openai.com/docs/guides/gpt/function-calling
|
||||
"""
|
||||
if not agent.config.openai_functions:
|
||||
return []
|
||||
|
||||
return [
|
||||
OpenAIFunctionSpec(
|
||||
name=command.name,
|
||||
@@ -348,5 +382,48 @@ def get_openai_command_specs(agent: Agent) -> list[OpenAIFunctionSpec]:
|
||||
for param in command.parameters
|
||||
},
|
||||
)
|
||||
for command in agent.command_registry.commands.values()
|
||||
for command in command_registry.commands.values()
|
||||
]
|
||||
|
||||
|
||||
def count_openai_functions_tokens(
|
||||
functions: list[OpenAIFunctionSpec], for_model: str
|
||||
) -> int:
|
||||
"""Returns the number of tokens taken up by a set of function definitions
|
||||
|
||||
Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
|
||||
"""
|
||||
from autogpt.llm.utils import count_string_tokens
|
||||
|
||||
return count_string_tokens(
|
||||
f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}",
|
||||
for_model,
|
||||
)
|
||||
|
||||
|
||||
def format_function_specs_as_typescript_ns(functions: list[OpenAIFunctionSpec]) -> str:
|
||||
"""Returns a function signature block in the format used by OpenAI internally:
|
||||
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
|
||||
|
||||
For use with `count_string_tokens` to determine token usage of provided functions.
|
||||
|
||||
Example:
|
||||
```ts
|
||||
namespace functions {
|
||||
|
||||
// Get the current weather in a given location
|
||||
type get_current_weather = (_: {
|
||||
// The city and state, e.g. San Francisco, CA
|
||||
location: string,
|
||||
unit?: "celsius" | "fahrenheit",
|
||||
}) => any;
|
||||
|
||||
} // namespace functions
|
||||
```
|
||||
"""
|
||||
|
||||
return (
|
||||
"namespace functions {\n\n"
|
||||
+ "\n\n".join(f.prompt_format for f in functions)
|
||||
+ "\n\n} // namespace functions"
|
||||
)
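
To make the token-counting format concrete, this is roughly what `format_function_specs_as_typescript_ns` produces for one spec. The weather function is lifted from the docstring example above, and the assumption that `OpenAIFunctionSpec` takes `name`, `description`, and a `parameters` dict of `ParameterSpec`s is inferred from `get_openai_command_specs`, not stated explicitly in this diff:

```python
from autogpt.llm.providers.openai import (
    OpenAIFunctionSpec,
    format_function_specs_as_typescript_ns,
)

weather = OpenAIFunctionSpec(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
        "location": OpenAIFunctionSpec.ParameterSpec(
            name="location",
            type="string",
            description="The city and state, e.g. San Francisco, CA",
            required=True,
        ),
        "unit": OpenAIFunctionSpec.ParameterSpec(
            name="unit", type='"celsius" | "fahrenheit"', description=None
        ),
    },
)

# Prints a namespace block like the docstring example; feed the same string to
# count_string_tokens() to estimate the token overhead of the function definitions.
print(format_function_specs_as_typescript_ns([weather]))
```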
|
||||
|
||||
@@ -7,12 +7,19 @@ from colorama import Fore
|
||||
from autogpt.config import Config
|
||||
|
||||
from ..api_manager import ApiManager
|
||||
from ..base import ChatModelResponse, ChatSequence, Message
|
||||
from ..base import (
|
||||
ChatModelResponse,
|
||||
ChatSequence,
|
||||
FunctionCallDict,
|
||||
Message,
|
||||
ResponseMessageDict,
|
||||
)
|
||||
from ..providers import openai as iopenai
|
||||
from ..providers.openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OpenAIFunctionCall,
|
||||
OpenAIFunctionSpec,
|
||||
count_openai_functions_tokens,
|
||||
)
|
||||
from .token_counter import *
|
||||
|
||||
@@ -111,7 +118,13 @@ def create_chat_completion(
|
||||
if temperature is None:
|
||||
temperature = config.temperature
|
||||
if max_tokens is None:
|
||||
max_tokens = OPEN_AI_CHAT_MODELS[model].max_tokens - prompt.token_length
|
||||
prompt_tlength = prompt.token_length
|
||||
max_tokens = OPEN_AI_CHAT_MODELS[model].max_tokens - prompt_tlength
|
||||
logger.debug(f"Prompt length: {prompt_tlength} tokens")
|
||||
if functions:
|
||||
functions_tlength = count_openai_functions_tokens(functions, model)
|
||||
max_tokens -= functions_tlength
|
||||
logger.debug(f"Functions take up {functions_tlength} tokens in API call")
|
||||
|
||||
logger.debug(
|
||||
f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}"
|
||||
@@ -138,9 +151,8 @@ def create_chat_completion(
|
||||
|
||||
if functions:
|
||||
chat_completion_kwargs["functions"] = [
|
||||
function.__dict__ for function in functions
|
||||
function.schema for function in functions
|
||||
]
|
||||
logger.debug(f"Function dicts: {chat_completion_kwargs['functions']}")
|
||||
|
||||
response = iopenai.create_chat_completion(
|
||||
messages=prompt.raw(),
|
||||
@@ -152,19 +164,24 @@ def create_chat_completion(
|
||||
logger.error(response.error)
|
||||
raise RuntimeError(response.error)
|
||||
|
||||
first_message = response.choices[0].message
|
||||
first_message: ResponseMessageDict = response.choices[0].message
|
||||
content: str | None = first_message.get("content")
|
||||
function_call: OpenAIFunctionCall | None = first_message.get("function_call")
|
||||
function_call: FunctionCallDict | None = first_message.get("function_call")
|
||||
|
||||
for plugin in config.plugins:
|
||||
if not plugin.can_handle_on_response():
|
||||
continue
|
||||
# TODO: function call support in plugin.on_response()
|
||||
content = plugin.on_response(content)
|
||||
|
||||
return ChatModelResponse(
|
||||
model_info=OPEN_AI_CHAT_MODELS[model],
|
||||
content=content,
|
||||
function_call=function_call,
|
||||
function_call=OpenAIFunctionCall(
|
||||
name=function_call["name"], arguments=function_call["arguments"]
|
||||
)
|
||||
if function_call
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Functions for counting the number of tokens in a message or string."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import List, overload
|
||||
|
||||
import tiktoken
|
||||
|
||||
@@ -9,8 +9,18 @@ from autogpt.llm.base import Message
|
||||
from autogpt.logs import logger
|
||||
|
||||
|
||||
@overload
|
||||
def count_message_tokens(messages: Message, model: str = "gpt-3.5-turbo") -> int:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def count_message_tokens(messages: List[Message], model: str = "gpt-3.5-turbo") -> int:
|
||||
...
|
||||
|
||||
|
||||
def count_message_tokens(
|
||||
messages: List[Message], model: str = "gpt-3.5-turbo-0301"
|
||||
messages: Message | List[Message], model: str = "gpt-3.5-turbo"
|
||||
) -> int:
|
||||
"""
|
||||
Returns the number of tokens used by a list of messages.
|
||||
@@ -24,6 +34,9 @@ def count_message_tokens(
|
||||
Returns:
|
||||
int: The number of tokens used by the list of messages.
|
||||
"""
|
||||
if isinstance(messages, Message):
|
||||
messages = [messages]
|
||||
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
tokens_per_message = (
|
||||
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
class JsonFileHandler(logging.FileHandler):
|
||||
def __init__(self, filename, mode="a", encoding=None, delay=False):
|
||||
super().__init__(filename, mode, encoding, delay)
|
||||
|
||||
def emit(self, record):
|
||||
json_data = json.loads(self.format(record))
|
||||
with open(self.baseFilename, "w", encoding="utf-8") as f:
|
||||
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
return record.msg
|
||||
15    autogpt/logs/__init__.py    Normal file
@@ -0,0 +1,15 @@
|
||||
from .formatters import AutoGptFormatter, JsonFormatter, remove_color_codes
|
||||
from .handlers import ConsoleHandler, JsonFileHandler, TypingConsoleHandler
|
||||
from .log_cycle import (
|
||||
CURRENT_CONTEXT_FILE_NAME,
|
||||
FULL_MESSAGE_HISTORY_FILE_NAME,
|
||||
NEXT_ACTION_FILE_NAME,
|
||||
PROMPT_SUMMARY_FILE_NAME,
|
||||
PROMPT_SUPERVISOR_FEEDBACK_FILE_NAME,
|
||||
SUMMARY_FILE_NAME,
|
||||
SUPERVISOR_FEEDBACK_FILE_NAME,
|
||||
USER_INPUT_FILE_NAME,
|
||||
LogCycleHandler,
|
||||
)
|
||||
from .logger import Logger, logger
|
||||
from .utils import print_assistant_thoughts, remove_ansi_escape
|
||||
41    autogpt/logs/formatters.py    Normal file
@@ -0,0 +1,41 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from colorama import Style
|
||||
|
||||
|
||||
class AutoGptFormatter(logging.Formatter):
|
||||
"""
|
||||
Allows to handle custom placeholders 'title_color' and 'message_no_color'.
|
||||
To use this formatter, make sure to pass 'color', 'title' as log extras.
|
||||
"""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if hasattr(record, "color"):
|
||||
record.title_color = (
|
||||
getattr(record, "color")
|
||||
+ getattr(record, "title", "")
|
||||
+ " "
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
else:
|
||||
record.title_color = getattr(record, "title", "")
|
||||
|
||||
# Add this line to set 'title' to an empty string if it doesn't exist
|
||||
record.title = getattr(record, "title", "")
|
||||
|
||||
if hasattr(record, "msg"):
|
||||
record.message_no_color = remove_color_codes(getattr(record, "msg"))
|
||||
else:
|
||||
record.message_no_color = ""
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def remove_color_codes(s: str) -> str:
|
||||
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
||||
return ansi_escape.sub("", s)
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord):
|
||||
return record.msg
|
||||
47
autogpt/logs/handlers.py
Normal file
@@ -0,0 +1,47 @@
import json
import logging
import random
import time

class ConsoleHandler(logging.StreamHandler):
def emit(self, record: logging.LogRecord) -> None:
msg = self.format(record)
try:
print(msg)
except Exception:
self.handleError(record)

class TypingConsoleHandler(logging.StreamHandler):
"""Output stream to console using simulated typing"""

def emit(self, record: logging.LogRecord):
min_typing_speed = 0.05
max_typing_speed = 0.01

msg = self.format(record)
try:
words = msg.split()
for i, word in enumerate(words):
print(word, end="", flush=True)
if i < len(words) - 1:
print(" ", end="", flush=True)
typing_speed = random.uniform(min_typing_speed, max_typing_speed)
time.sleep(typing_speed)
# type faster after each word
min_typing_speed = min_typing_speed * 0.95
max_typing_speed = max_typing_speed * 0.95
print()
except Exception:
self.handleError(record)

class JsonFileHandler(logging.FileHandler):
def __init__(self, filename: str, mode="a", encoding=None, delay=False):
super().__init__(filename, mode, encoding, delay)

def emit(self, record: logging.LogRecord):
json_data = json.loads(self.format(record))
with open(self.baseFilename, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=4)
@@ -2,7 +2,7 @@ import json
import os
from typing import Any, Dict, Union

from autogpt.logs import logger
from .logger import logger

DEFAULT_PREFIX = "agent"
FULL_MESSAGE_HISTORY_FILE_NAME = "full_message_history.json"
@@ -42,7 +42,7 @@ class LogCycleHandler:

return outer_folder_path

def get_agent_short_name(self, ai_name):
def get_agent_short_name(self, ai_name: str) -> str:
return ai_name[:15].rstrip() if ai_name else DEFAULT_PREFIX

def create_inner_directory(self, outer_folder_path: str, cycle_count: int) -> str:
@@ -3,20 +3,18 @@ from __future__ import annotations

import logging
import os
import random
import re
import time
from logging import LogRecord
from typing import TYPE_CHECKING, Any, Optional

from colorama import Fore, Style
from colorama import Fore

if TYPE_CHECKING:
from autogpt.config import Config

from autogpt.log_cycle.json_handler import JsonFileHandler, JsonFormatter
from autogpt.singleton import Singleton

from .formatters import AutoGptFormatter, JsonFormatter
from .handlers import ConsoleHandler, JsonFileHandler, TypingConsoleHandler

class Logger(metaclass=Singleton):
"""
@@ -100,8 +98,13 @@ class Logger(metaclass=Singleton):
self.typing_logger.addHandler(self.console_handler)

def typewriter_log(
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
):
self,
title: str = "",
title_color: str = "",
content: str = "",
speak_text: bool = False,
level: int = logging.INFO,
) -> None:
from autogpt.speech import say_text

if speak_text and self.config and self.config.speak_mode:
@@ -122,29 +125,29 @@ class Logger(metaclass=Singleton):

def debug(
self,
message,
title="",
title_color="",
):
message: str,
title: str = "",
title_color: str = "",
) -> None:
self._log(title, title_color, message, logging.DEBUG)

def info(
self,
message,
title="",
title_color="",
):
message: str,
title: str = "",
title_color: str = "",
) -> None:
self._log(title, title_color, message, logging.INFO)

def warn(
self,
message,
title="",
title_color="",
):
message: str,
title: str = "",
title_color: str = "",
) -> None:
self._log(title, title_color, message, logging.WARN)

def error(self, title, message=""):
def error(self, title: str, message: str = "") -> None:
self._log(title, Fore.RED, message, logging.ERROR)

def _log(
@@ -152,8 +155,8 @@ class Logger(metaclass=Singleton):
title: str = "",
title_color: str = "",
message: str = "",
level=logging.INFO,
):
level: int = logging.INFO,
) -> None:
if message:
if isinstance(message, list):
message = " ".join(message)
@@ -161,11 +164,11 @@ class Logger(metaclass=Singleton):
level, message, extra={"title": str(title), "color": str(title_color)}
)

def set_level(self, level):
def set_level(self, level: logging._Level) -> None:
self.logger.setLevel(level)
self.typing_logger.setLevel(level)

def double_check(self, additionalText=None):
def double_check(self, additionalText: Optional[str] = None) -> None:
if not additionalText:
additionalText = (
"Please ensure you've setup and configured everything"
@@ -191,131 +194,10 @@ class Logger(metaclass=Singleton):
self.json_logger.debug(data)
self.json_logger.removeHandler(json_data_handler)

def get_log_directory(self):
def get_log_directory(self) -> str:
this_files_dir_path = os.path.dirname(__file__)
log_dir = os.path.join(this_files_dir_path, "../logs")
log_dir = os.path.join(this_files_dir_path, "../../logs")
return os.path.abspath(log_dir)

"""
Output stream to console using simulated typing
"""

class TypingConsoleHandler(logging.StreamHandler):
def emit(self, record):
min_typing_speed = 0.05
max_typing_speed = 0.01

msg = self.format(record)
try:
words = msg.split()
for i, word in enumerate(words):
print(word, end="", flush=True)
if i < len(words) - 1:
print(" ", end="", flush=True)
typing_speed = random.uniform(min_typing_speed, max_typing_speed)
time.sleep(typing_speed)
# type faster after each word
min_typing_speed = min_typing_speed * 0.95
max_typing_speed = max_typing_speed * 0.95
print()
except Exception:
self.handleError(record)

class ConsoleHandler(logging.StreamHandler):
def emit(self, record) -> None:
msg = self.format(record)
try:
print(msg)
except Exception:
self.handleError(record)

class AutoGptFormatter(logging.Formatter):
"""
Allows to handle custom placeholders 'title_color' and 'message_no_color'.
To use this formatter, make sure to pass 'color', 'title' as log extras.
"""

def format(self, record: LogRecord) -> str:
if hasattr(record, "color"):
record.title_color = (
getattr(record, "color")
+ getattr(record, "title", "")
+ " "
+ Style.RESET_ALL
)
else:
record.title_color = getattr(record, "title", "")

# Add this line to set 'title' to an empty string if it doesn't exist
record.title = getattr(record, "title", "")

if hasattr(record, "msg"):
record.message_no_color = remove_color_codes(getattr(record, "msg"))
else:
record.message_no_color = ""
return super().format(record)

def remove_color_codes(s: str) -> str:
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape.sub("", s)

def remove_ansi_escape(s: str) -> str:
return s.replace("\x1B", "")

logger = Logger()

def print_assistant_thoughts(
ai_name: object,
assistant_reply_json_valid: object,
config: Config,
) -> None:
from autogpt.speech import say_text

assistant_thoughts_reasoning = None
assistant_thoughts_plan = None
assistant_thoughts_speak = None
assistant_thoughts_criticism = None

assistant_thoughts = assistant_reply_json_valid.get("thoughts", {})
assistant_thoughts_text = remove_ansi_escape(assistant_thoughts.get("text", ""))
if assistant_thoughts:
assistant_thoughts_reasoning = remove_ansi_escape(
assistant_thoughts.get("reasoning")
)
assistant_thoughts_plan = remove_ansi_escape(assistant_thoughts.get("plan"))
assistant_thoughts_criticism = remove_ansi_escape(
assistant_thoughts.get("criticism")
)
assistant_thoughts_speak = remove_ansi_escape(assistant_thoughts.get("speak"))
logger.typewriter_log(
f"{ai_name.upper()} THOUGHTS:", Fore.YELLOW, f"{assistant_thoughts_text}"
)
logger.typewriter_log("REASONING:", Fore.YELLOW, f"{assistant_thoughts_reasoning}")
if assistant_thoughts_plan:
logger.typewriter_log("PLAN:", Fore.YELLOW, "")
# If it's a list, join it into a string
if isinstance(assistant_thoughts_plan, list):
assistant_thoughts_plan = "\n".join(assistant_thoughts_plan)
elif isinstance(assistant_thoughts_plan, dict):
assistant_thoughts_plan = str(assistant_thoughts_plan)

# Split the input_string using the newline character and dashes
lines = assistant_thoughts_plan.split("\n")
for line in lines:
line = line.lstrip("- ")
logger.typewriter_log("- ", Fore.GREEN, line.strip())
logger.typewriter_log("CRITICISM:", Fore.YELLOW, f"{assistant_thoughts_criticism}")
# Speak the assistant's thoughts
if assistant_thoughts_speak:
if config.speak_mode:
say_text(assistant_thoughts_speak, config)
else:
logger.typewriter_log("SPEAK:", Fore.YELLOW, f"{assistant_thoughts_speak}")
65
autogpt/logs/utils.py
Normal file
@@ -0,0 +1,65 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from colorama import Fore

if TYPE_CHECKING:
from autogpt.config import Config

from .logger import logger

def print_assistant_thoughts(
ai_name: str,
assistant_reply_json_valid: dict,
config: Config,
) -> None:
from autogpt.speech import say_text

assistant_thoughts_reasoning = None
assistant_thoughts_plan = None
assistant_thoughts_speak = None
assistant_thoughts_criticism = None

assistant_thoughts = assistant_reply_json_valid.get("thoughts", {})
assistant_thoughts_text = remove_ansi_escape(assistant_thoughts.get("text", ""))
if assistant_thoughts:
assistant_thoughts_reasoning = remove_ansi_escape(
assistant_thoughts.get("reasoning", "")
)
assistant_thoughts_plan = remove_ansi_escape(assistant_thoughts.get("plan", ""))
assistant_thoughts_criticism = remove_ansi_escape(
assistant_thoughts.get("criticism", "")
)
assistant_thoughts_speak = remove_ansi_escape(
assistant_thoughts.get("speak", "")
)
logger.typewriter_log(
f"{ai_name.upper()} THOUGHTS:", Fore.YELLOW, assistant_thoughts_text
)
logger.typewriter_log("REASONING:", Fore.YELLOW, str(assistant_thoughts_reasoning))
if assistant_thoughts_plan:
logger.typewriter_log("PLAN:", Fore.YELLOW, "")
# If it's a list, join it into a string
if isinstance(assistant_thoughts_plan, list):
assistant_thoughts_plan = "\n".join(assistant_thoughts_plan)
elif isinstance(assistant_thoughts_plan, dict):
assistant_thoughts_plan = str(assistant_thoughts_plan)

# Split the input_string using the newline character and dashes
lines = assistant_thoughts_plan.split("\n")
for line in lines:
line = line.lstrip("- ")
logger.typewriter_log("- ", Fore.GREEN, line.strip())
logger.typewriter_log("CRITICISM:", Fore.YELLOW, f"{assistant_thoughts_criticism}")
# Speak the assistant's thoughts
if assistant_thoughts_speak:
if config.speak_mode:
say_text(assistant_thoughts_speak, config)
else:
logger.typewriter_log("SPEAK:", Fore.YELLOW, f"{assistant_thoughts_speak}")

def remove_ansi_escape(s: str) -> str:
return s.replace("\x1B", "")
@@ -2,49 +2,45 @@ from __future__ import annotations

import copy
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from autogpt.agent import Agent

from autogpt.config import Config
from autogpt.json_utils.utilities import extract_json_from_response
from autogpt.llm.base import ChatSequence, Message, MessageRole, MessageType
from autogpt.llm.base import ChatSequence, Message
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
from autogpt.llm.utils import count_string_tokens, create_chat_completion
from autogpt.log_cycle.log_cycle import PROMPT_SUMMARY_FILE_NAME, SUMMARY_FILE_NAME
from autogpt.logs import logger
from autogpt.llm.utils import (
count_message_tokens,
count_string_tokens,
create_chat_completion,
)
from autogpt.logs import PROMPT_SUMMARY_FILE_NAME, SUMMARY_FILE_NAME, logger

@dataclass
class MessageHistory:
agent: Agent

messages: list[Message] = field(default_factory=list)
class MessageHistory(ChatSequence):
max_summary_tlength: int = 500
agent: Optional[Agent] = None
summary: str = "I was created"

last_trimmed_index: int = 0

def __getitem__(self, i: int):
return self.messages[i]
SUMMARIZATION_PROMPT = '''Your task is to create a concise running summary of actions and information results in the provided text, focusing on key and potentially important information to remember.

def __iter__(self):
return iter(self.messages)
You will receive the current summary and your latest actions. Combine them, adding relevant key information from the latest development in 1st person past tense and keeping the summary concise.

def __len__(self):
return len(self.messages)
Summary So Far:
"""
{summary}
"""

def add(
self,
role: MessageRole,
content: str,
type: MessageType | None = None,
):
return self.append(Message(role, content, type))

def append(self, message: Message):
return self.messages.append(message)
Latest Development:
"""
{new_events}
"""
'''

def trim_messages(
self, current_message_chain: list[Message], config: Config
@@ -84,7 +80,7 @@ class MessageHistory:

return new_summary_message, new_messages_not_in_chain

def per_cycle(self, config: Config, messages: list[Message] | None = None):
def per_cycle(self, messages: list[Message] | None = None):
"""
Yields:
Message: a message containing user input
@@ -119,26 +115,33 @@ class MessageHistory:
)

def update_running_summary(
self, new_events: list[Message], config: Config
self,
new_events: list[Message],
config: Config,
max_summary_length: Optional[int] = None,
) -> Message:
"""
This function takes a list of dictionaries representing new events and combines them with the current summary,
focusing on key and potentially important information to remember. The updated summary is returned in a message
formatted in the 1st person past tense.
This function takes a list of Message objects and updates the running summary
to include the events they describe. The updated summary is returned
in a Message formatted in the 1st person past tense.

Args:
new_events (List[Dict]): A list of dictionaries containing the latest events to be added to the summary.
new_events: A list of Messages containing the latest events to be added to the summary.

Returns:
str: A message containing the updated summary of actions, formatted in the 1st person past tense.
Message: a Message containing the updated running summary.

Example:
```py
new_events = [{"event": "entered the kitchen."}, {"event": "found a scrawled note with the number 7"}]
update_running_summary(new_events)
# Returns: "This reminds you of these events from your past: \nI entered the kitchen and found a scrawled note saying 7."
```
"""
if not new_events:
return self.summary_message()
if not max_summary_length:
max_summary_length = self.max_summary_tlength

# Create a copy of the new_events list to prevent modifying the original list
new_events = copy.deepcopy(new_events)
@@ -166,29 +169,29 @@ class MessageHistory:
elif event.role == "user":
new_events.remove(event)

# Summarize events and current summary in batch to a new running summary
summ_model = OPEN_AI_CHAT_MODELS[config.fast_llm]

# Assume an upper bound length for the summary prompt template, i.e. Your task is to create a concise running summary...., in summarize_batch func
# TODO make this default dynamic
prompt_template_length = 100
max_tokens = OPEN_AI_CHAT_MODELS.get(config.fast_llm).max_tokens
summary_tlength = count_string_tokens(str(self.summary), config.fast_llm)
# Determine token lengths for use in batching
prompt_template_length = len(
MessageHistory.SUMMARIZATION_PROMPT.format(summary="", new_events="")
)
max_input_tokens = summ_model.max_tokens - max_summary_length
summary_tlength = count_string_tokens(self.summary, summ_model.name)
batch = []
batch_tlength = 0

# TODO Can put a cap on length of total new events and drop some previous events to save API cost, but need to think thru more how to do it without losing the context
# TODO: Put a cap on length of total new events and drop some previous events to
# save API cost. Need to think thru more how to do it without losing the context.
for event in new_events:
event_tlength = count_string_tokens(str(event), config.fast_llm)
event_tlength = count_message_tokens(event, summ_model.name)

if (
batch_tlength + event_tlength
> max_tokens - prompt_template_length - summary_tlength
> max_input_tokens - prompt_template_length - summary_tlength
):
# The batch is full. Summarize it and start a new one.
self.summarize_batch(batch, config)
summary_tlength = count_string_tokens(
str(self.summary), config.fast_llm
)
self.summarize_batch(batch, config, max_summary_length)
summary_tlength = count_string_tokens(self.summary, summ_model.name)
batch = [event]
batch_tlength = event_tlength
else:
@@ -197,41 +200,36 @@ class MessageHistory:

if batch:
# There's an unprocessed batch. Summarize it.
self.summarize_batch(batch, config)
self.summarize_batch(batch, config, max_summary_length)

return self.summary_message()

def summarize_batch(self, new_events_batch, config):
prompt = f'''Your task is to create a concise running summary of actions and information results in the provided text, focusing on key and potentially important information to remember.

You will receive the current summary and your latest actions. Combine them, adding relevant key information from the latest development in 1st person past tense and keeping the summary concise.

Summary So Far:
"""
{self.summary}
"""

Latest Development:
"""
{new_events_batch or "Nothing new happened."}
"""
'''
def summarize_batch(
self, new_events_batch: list[Message], config: Config, max_output_length: int
):
prompt = MessageHistory.SUMMARIZATION_PROMPT.format(
summary=self.summary, new_events=new_events_batch
)

prompt = ChatSequence.for_model(config.fast_llm, [Message("user", prompt)])
self.agent.log_cycle_handler.log_cycle(
self.agent.ai_name,
self.agent.created_at,
self.agent.cycle_count,
prompt.raw(),
PROMPT_SUMMARY_FILE_NAME,
)
if self.agent:
self.agent.log_cycle_handler.log_cycle(
self.agent.ai_config.ai_name,
self.agent.created_at,
self.agent.cycle_count,
prompt.raw(),
PROMPT_SUMMARY_FILE_NAME,
)

self.summary = create_chat_completion(prompt, config).content
self.summary = create_chat_completion(
prompt, config, max_tokens=max_output_length
).content

self.agent.log_cycle_handler.log_cycle(
self.agent.ai_name,
self.agent.created_at,
self.agent.cycle_count,
self.summary,
SUMMARY_FILE_NAME,
)
if self.agent:
self.agent.log_cycle_handler.log_cycle(
self.agent.ai_config.ai_name,
self.agent.created_at,
self.agent.cycle_count,
self.summary,
SUMMARY_FILE_NAME,
)
@@ -6,13 +6,12 @@ import numpy as np

from autogpt.config.config import Config
from autogpt.logs import logger
from autogpt.singleton import AbstractSingleton

from .. import MemoryItem, MemoryItemRelevance
from ..utils import Embedding, get_embedding

class VectorMemoryProvider(MutableSet[MemoryItem], AbstractSingleton):
class VectorMemoryProvider(MutableSet[MemoryItem]):
@abc.abstractmethod
def __init__(self, config: Config):
pass
@@ -15,8 +15,12 @@ class CommandRegistry:
directory.
"""

commands: dict[str, Command] = {}
commands_aliases: dict[str, Command] = {}
commands: dict[str, Command]
commands_aliases: dict[str, Command]

def __init__(self):
self.commands = {}
self.commands_aliases = {}

def __contains__(self, command_name: str):
return command_name in self.commands or command_name in self.commands_aliases
@@ -7,7 +7,6 @@ constraints: [
resources: [
'Internet access for searches and information gathering.',
'Long Term memory management.',
'GPT-3.5 powered Agents for delegation of simple tasks.',
'File output.'
]
performance_evaluations: [
Submodule tests/Auto-GPT-test-cassettes updated: 4485d191a4...d584872257
@@ -6,7 +6,7 @@ from typing import Any, Generator

import pytest

from autogpt.log_cycle.log_cycle import LogCycleHandler
from autogpt.logs import LogCycleHandler
from autogpt.workspace import Workspace
from benchmarks import run_task
from tests.challenges.schema import Task
@@ -10,6 +10,7 @@ from autogpt.agent.agent import Agent
from autogpt.config import AIConfig, Config, ConfigBuilder
from autogpt.config.ai_config import AIConfig
from autogpt.llm.api_manager import ApiManager
from autogpt.logs import logger
from autogpt.memory.vector import get_memory
from autogpt.models.command_registry import CommandRegistry
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
@@ -52,6 +53,9 @@ def config(
if not os.environ.get("OPENAI_API_KEY"):
os.environ["OPENAI_API_KEY"] = "sk-dummy"

# HACK: this is necessary to ensure PLAIN_OUTPUT takes effect
logger.config = config

config.plugins_dir = "tests/unit/data/test_plugins"
config.plugins_config_file = temp_plugins_config_file
@@ -8,12 +8,6 @@ from autogpt.memory.vector import JSONFileMemory, MemoryItem
from autogpt.workspace import Workspace

@pytest.fixture(autouse=True)
def cleanup_sut_singleton():
if JSONFileMemory in JSONFileMemory._instances:
del JSONFileMemory._instances[JSONFileMemory]

def test_json_memory_init_without_backing_file(config: Config, workspace: Workspace):
index_file = workspace.root / f"{config.memory_index}.json"
@@ -1,5 +1,6 @@
import os
import random
import re
import string
import tempfile

@@ -88,13 +89,9 @@ def test_execute_python_file_invalid(agent: Agent):

def test_execute_python_file_not_found(agent: Agent):
assert all(
s in sut.execute_python_file("notexist.py", agent).lower()
for s in [
"python: can't open file 'notexist.py'",
"[errno 2] no such file or directory",
]
)
result = sut.execute_python_file("notexist.py", agent).lower()
assert re.match(r"python: can't open file '([A-Z]:)?[/\\\-\w]*notexist.py'", result)
assert "[errno 2] no such file or directory" in result

def test_execute_shell(random_string: str, agent: Agent):
@@ -12,6 +12,6 @@ def test_browse_website(agent: Agent, patched_api_requestor: MockerFixture):
question = "How to execute a barrel roll"

response = browse_website(url, question, agent)
assert "Error" in response
assert "error" in response.lower()
# Sanity check that the response is not too long
assert len(response) < 200
@@ -1,70 +0,0 @@
import pytest

from autogpt.agent.agent_manager import AgentManager
from autogpt.llm import ChatModelResponse
from autogpt.llm.chat import create_chat_completion
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS

@pytest.fixture
def agent_manager(config):
# Hack, real gross. Singletons are not good times.
yield AgentManager(config)
del AgentManager._instances[AgentManager]

@pytest.fixture
def task():
return "translate English to French"

@pytest.fixture
def prompt():
return "Translate the following English text to French: 'Hello, how are you?'"

@pytest.fixture
def model():
return "gpt-3.5-turbo"

@pytest.fixture(autouse=True)
def mock_create_chat_completion(mocker, config):
mock_create_chat_completion = mocker.patch(
"autogpt.agent.agent_manager.create_chat_completion",
wraps=create_chat_completion,
)
mock_create_chat_completion.return_value = ChatModelResponse(
model_info=OPEN_AI_CHAT_MODELS[config.fast_llm],
content="irrelevant",
function_call={},
)
return mock_create_chat_completion

def test_create_agent(agent_manager: AgentManager, task, prompt, model):
key, agent_reply = agent_manager.create_agent(task, prompt, model)
assert isinstance(key, int)
assert isinstance(agent_reply, str)
assert key in agent_manager.agents

def test_message_agent(agent_manager: AgentManager, task, prompt, model):
key, _ = agent_manager.create_agent(task, prompt, model)
user_message = "Please translate 'Good morning' to French."
agent_reply = agent_manager.message_agent(key, user_message)
assert isinstance(agent_reply, str)

def test_list_agents(agent_manager: AgentManager, task, prompt, model):
key, _ = agent_manager.create_agent(task, prompt, model)
agents_list = agent_manager.list_agents()
assert isinstance(agents_list, list)
assert (key, task) in agents_list

def test_delete_agent(agent_manager: AgentManager, task, prompt, model):
key, _ = agent_manager.create_agent(task, prompt, model)
success = agent_manager.delete_agent(key)
assert success
assert key not in agent_manager.agents
@@ -44,8 +44,13 @@ def mock_MemoryItem_from_text(

@pytest.fixture()
def test_file_path(workspace: Workspace):
return workspace.get_path("test_file.txt")
def test_file_name():
return Path("test_file.txt")

@pytest.fixture
def test_file_path(test_file_name: Path, workspace: Workspace):
return workspace.get_path(test_file_name)

@pytest.fixture()
@@ -130,42 +135,34 @@ def test_is_duplicate_operation(agent: Agent, mocker: MockerFixture):
# Test cases with write operations
assert (
file_ops.is_duplicate_operation(
"write", "path/to/file1.txt", agent.config, "checksum1"
"write", "path/to/file1.txt", agent, "checksum1"
)
is True
)
assert (
file_ops.is_duplicate_operation(
"write", "path/to/file1.txt", agent.config, "checksum2"
"write", "path/to/file1.txt", agent, "checksum2"
)
is False
)
assert (
file_ops.is_duplicate_operation(
"write", "path/to/file3.txt", agent.config, "checksum3"
"write", "path/to/file3.txt", agent, "checksum3"
)
is False
)
# Test cases with append operations
assert (
file_ops.is_duplicate_operation(
"append", "path/to/file1.txt", agent.config, "checksum1"
"append", "path/to/file1.txt", agent, "checksum1"
)
is False
)
# Test cases with delete operations
assert (
file_ops.is_duplicate_operation(
"delete", "path/to/file1.txt", config=agent.config
)
is False
)
assert (
file_ops.is_duplicate_operation(
"delete", "path/to/file3.txt", config=agent.config
)
is True
file_ops.is_duplicate_operation("delete", "path/to/file1.txt", agent) is False
)
assert file_ops.is_duplicate_operation("delete", "path/to/file3.txt", agent) is True

# Test logging a file operation
@@ -206,7 +203,15 @@ def test_read_file_not_found(agent: Agent):
assert "Error:" in content and filename in content and "no such file" in content

def test_write_to_file(test_file_path: Path, agent: Agent):
def test_write_to_file_relative_path(test_file_name: Path, agent: Agent):
new_content = "This is new content.\n"
file_ops.write_to_file(str(test_file_name), new_content, agent=agent)
with open(agent.workspace.get_path(test_file_name), "r", encoding="utf-8") as f:
content = f.read()
assert content == new_content

def test_write_to_file_absolute_path(test_file_path: Path, agent: Agent):
new_content = "This is new content.\n"
file_ops.write_to_file(str(test_file_path), new_content, agent=agent)
with open(test_file_path, "r", encoding="utf-8") as f:
@@ -214,24 +219,24 @@ def test_write_to_file(test_file_path: Path, agent: Agent):
assert content == new_content

def test_write_file_logs_checksum(test_file_path: Path, agent: Agent):
def test_write_file_logs_checksum(test_file_name: Path, agent: Agent):
new_content = "This is new content.\n"
new_checksum = file_ops.text_checksum(new_content)
file_ops.write_to_file(str(test_file_path), new_content, agent=agent)
file_ops.write_to_file(str(test_file_name), new_content, agent=agent)
with open(agent.config.file_logger_path, "r", encoding="utf-8") as f:
log_entry = f.read()
assert log_entry == f"write: {test_file_path} #{new_checksum}\n"
assert log_entry == f"write: {test_file_name} #{new_checksum}\n"

def test_write_file_fails_if_content_exists(test_file_path: Path, agent: Agent):
def test_write_file_fails_if_content_exists(test_file_name: Path, agent: Agent):
new_content = "This is new content.\n"
file_ops.log_operation(
"write",
str(test_file_path),
str(test_file_name),
agent=agent,
checksum=file_ops.text_checksum(new_content),
)
result = file_ops.write_to_file(str(test_file_path), new_content, agent=agent)
result = file_ops.write_to_file(str(test_file_name), new_content, agent=agent)
assert result == "Error: File has already been updated."
@@ -258,11 +263,11 @@ def test_append_to_file(test_nested_file: Path, agent: Agent):

def test_append_to_file_uses_checksum_from_appended_file(
test_file_path: Path, agent: Agent
test_file_name: Path, agent: Agent
):
append_text = "This is appended text.\n"
file_ops.append_to_file(test_file_path, append_text, agent=agent)
file_ops.append_to_file(test_file_path, append_text, agent=agent)
file_ops.append_to_file(test_file_name, append_text, agent=agent)
file_ops.append_to_file(test_file_name, append_text, agent=agent)
with open(agent.config.file_logger_path, "r", encoding="utf-8") as f:
log_contents = f.read()

@@ -272,8 +277,8 @@ def test_append_to_file_uses_checksum_from_appended_file(
digest.update(append_text.encode("utf-8"))
checksum2 = digest.hexdigest()
assert log_contents == (
f"append: {test_file_path} #{checksum1}\n"
f"append: {test_file_path} #{checksum2}\n"
f"append: {test_file_name} #{checksum1}\n"
f"append: {test_file_name} #{checksum2}\n"
)
@@ -288,7 +293,7 @@ def test_delete_missing_file(agent: Agent):
# confuse the log
file_ops.log_operation("write", filename, agent=agent, checksum="fake")
try:
os.remove(filename)
os.remove(agent.workspace.get_path(filename))
except FileNotFoundError as err:
assert str(err) in file_ops.delete_file(filename, agent=agent)
return
@@ -38,8 +38,8 @@ def agent(config: Config):
return agent

def test_message_history_batch_summary(mocker, agent, config):
history = MessageHistory(agent)
def test_message_history_batch_summary(mocker, agent: Agent, config: Config):
history = MessageHistory.for_model(agent.config.smart_llm, agent=agent)
model = config.fast_llm
message_tlength = 0
message_count = 0
@@ -48,7 +48,7 @@ def test_message_history_batch_summary(mocker, agent, config):
mock_summary_response = ChatModelResponse(
model_info=OPEN_AI_CHAT_MODELS[model],
content="I executed browse_website command for each of the websites returned from Google search, but none of them have any job openings.",
function_call={},
function_call=None,
)
mock_summary = mocker.patch(
"autogpt.memory.message_history.create_chat_completion",
@@ -105,7 +105,7 @@ def test_message_history_batch_summary(mocker, agent, config):
result = (
"Command browse_website returned: Answer gathered from website: The text in job"
+ str(i)
+ " does not provide information on specific job requirements or a job URL.]",
+ " does not provide information on specific job requirements or a job URL.]"
)
msg = Message("system", result, "action_result")
history.append(msg)
@@ -117,7 +117,7 @@ def test_message_history_batch_summary(mocker, agent, config):
history.append(user_input_msg)

# only take the last cycle of the message history, trim the rest of previous messages, and generate a summary for them
for cycle in reversed(list(history.per_cycle(config))):
for cycle in reversed(list(history.per_cycle())):
messages_to_add = [msg for msg in cycle if msg is not None]
message_sequence.insert(insertion_index, *messages_to_add)
break
@@ -134,7 +134,7 @@ def test_message_history_batch_summary(mocker, agent, config):
)

expected_call_count = math.ceil(
message_tlength / (OPEN_AI_CHAT_MODELS.get(config.fast_llm).max_tokens)
message_tlength / (OPEN_AI_CHAT_MODELS[config.fast_llm].max_tokens)
)
# Expecting 2 batches because of over max token
assert mock_summary.call_count == expected_call_count # 2 at the time of writing
@@ -20,7 +20,7 @@ def error_factory(error_instance, error_count, retry_count, warn_user=True):
self.count = 0

@openai.retry_api(
num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
max_retries=retry_count, backoff_base=0.001, warn_user=warn_user
)
def __call__(self):
self.count += 1
@@ -69,16 +69,11 @@ def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure

if error_count and retry_count:
if type(error) == RateLimitError:
assert "Reached rate limit, passing..." in output.out
assert "Reached rate limit" in output.out
assert "Please double check" in output.out
if type(error) == ServiceUnavailableError:
assert (
"The OpenAI API engine is currently overloaded, passing..."
in output.out
)
assert "The OpenAI API engine is currently overloaded" in output.out
assert "Please double check" in output.out
if type(error) == APIError:
assert "API Bad gateway" in output.out
else:
assert output.out == ""

@@ -96,7 +91,7 @@ def test_retry_open_api_rate_limit_no_warn(capsys):

output = capsys.readouterr()

assert "Reached rate limit, passing..." in output.out
assert "Reached rate limit" in output.out
assert "Please double check" not in output.out

@@ -115,7 +110,7 @@ def test_retry_open_api_service_unavairable_no_warn(capsys):

output = capsys.readouterr()

assert "The OpenAI API engine is currently overloaded, passing..." in output.out
assert "The OpenAI API engine is currently overloaded" in output.out
assert "Please double check" not in output.out
@@ -1,10 +1,16 @@
import os
from hashlib import sha256

import openai.api_requestor
import pytest
from pytest_mock import MockerFixture

from .vcr_filter import PROXY, before_record_request, before_record_response
from .vcr_filter import (
PROXY,
before_record_request,
before_record_response,
freeze_request_body,
)

DEFAULT_RECORD_MODE = "new_episodes"
BASE_VCR_CONFIG = {
@@ -12,10 +18,13 @@ BASE_VCR_CONFIG = {
"before_record_response": before_record_response,
"filter_headers": [
"Authorization",
"AGENT-MODE",
"AGENT-TYPE",
"OpenAI-Organization",
"X-OpenAI-Client-User-Agent",
"User-Agent",
],
"match_on": ["method", "body"],
"match_on": ["method", "headers"],
}

@@ -41,7 +50,7 @@ def vcr_cassette_dir(request):
return os.path.join("tests/Auto-GPT-test-cassettes", test_name)

def patch_api_base(requestor):
def patch_api_base(requestor: openai.api_requestor.APIRequestor):
new_api_base = f"{PROXY}/v1"
requestor.api_base = new_api_base
return requestor
@@ -49,23 +58,35 @@ def patch_api_base(requestor):

@pytest.fixture
def patched_api_requestor(mocker: MockerFixture):
original_init = openai.api_requestor.APIRequestor.__init__
original_validate_headers = openai.api_requestor.APIRequestor._validate_headers
init_requestor = openai.api_requestor.APIRequestor.__init__
prepare_request = openai.api_requestor.APIRequestor._prepare_request_raw

def patched_init(requestor, *args, **kwargs):
original_init(requestor, *args, **kwargs)
def patched_init_requestor(requestor, *args, **kwargs):
init_requestor(requestor, *args, **kwargs)
patch_api_base(requestor)

def patched_validate_headers(self, supplied_headers):
headers = original_validate_headers(self, supplied_headers)
headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
return headers
def patched_prepare_request(self, *args, **kwargs):
url, headers, data = prepare_request(self, *args, **kwargs)

if PROXY:
headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")

# Add hash header for cheap & fast matching on cassette playback
headers["X-Content-Hash"] = sha256(
freeze_request_body(data), usedforsecurity=False
).hexdigest()

return url, headers, data

if PROXY:
mocker.patch("openai.api_requestor.APIRequestor.__init__", new=patched_init)
mocker.patch.object(
openai.api_requestor.APIRequestor,
"_validate_headers",
new=patched_validate_headers,
"__init__",
new=patched_init_requestor,
)
mocker.patch.object(
openai.api_requestor.APIRequestor,
"_prepare_request_raw",
new=patched_prepare_request,
)
@@ -1,52 +0,0 @@
import json
import re

def replace_timestamp_in_request(request):
# Check if the request body contains a JSON object

try:
if not request or not request.body:
return request
body = json.loads(request.body)
except ValueError:
return request

if "messages" not in body:
return request

for message in body["messages"]:
if "content" in message and "role" in message and message["role"] == "system":
timestamp_regex = re.compile(r"\w{3} \w{3} \d{2} \d{2}:\d{2}:\d{2} \d{4}")
message["content"] = timestamp_regex.sub(
"Tue Jan 01 00:00:00 2000", message["content"]
)

request.body = json.dumps(body)
return request

def before_record_response(response):
if "Transfer-Encoding" in response["headers"]:
del response["headers"]["Transfer-Encoding"]
return response

def before_record_request(request):
filtered_request = filter_hostnames(request)
filtered_request_without_dynamic_data = replace_timestamp_in_request(
filtered_request
)
return filtered_request_without_dynamic_data

def filter_hostnames(request):
allowed_hostnames = [
"api.openai.com",
"localhost:50337",
] # List of hostnames you want to allow

if any(hostname in request.url for hostname in allowed_hostnames):
return request
else:
return None
@@ -1,8 +1,12 @@
import contextlib
import json
import os
import re
from io import BytesIO
from typing import Any, Dict, List

from vcr.request import Request

PROXY = os.environ.get("PROXY")

REPLACEMENTS: List[Dict[str, str]] = [
@@ -39,19 +43,20 @@ def replace_message_content(content: str, replacements: List[Dict[str, str]]) ->
return content

def replace_timestamp_in_request(request: Any) -> Any:
def freeze_request_body(json_body: str | bytes) -> bytes:
"""Remove any dynamic items from the request body"""

try:
if not request or not request.body:
return request
body = json.loads(request.body)
body = json.loads(json_body)
except ValueError:
return request
return json_body if type(json_body) == bytes else json_body.encode()

if "messages" not in body:
return request
body[
"max_tokens"
] = 0 # this field is inconsistent between requests and not used at the moment.
return json.dumps(body, sort_keys=True).encode()

if "max_tokens" in body:
del body["max_tokens"]

for message in body["messages"]:
if "content" in message and "role" in message:
if message["role"] == "system":
@@ -59,7 +64,20 @@ def replace_timestamp_in_request(request: Any) -> Any:
message["content"], REPLACEMENTS
)

request.body = json.dumps(body)
return json.dumps(body, sort_keys=True).encode()

def freeze_request(request: Request) -> Request:
if not request or not request.body:
return request

with contextlib.suppress(ValueError):
request.body = freeze_request_body(
request.body.getvalue()
if isinstance(request.body, BytesIO)
else request.body
)

return request

@@ -69,20 +87,23 @@ def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]:
return response

def before_record_request(request: Any) -> Any:
def before_record_request(request: Request) -> Request | None:
request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)

filtered_request = filter_hostnames(request)
filtered_request_without_dynamic_data = replace_timestamp_in_request(
filtered_request
)
if not filtered_request:
return None

filtered_request_without_dynamic_data = freeze_request(filtered_request)
return filtered_request_without_dynamic_data

from urllib.parse import urlparse, urlunparse

def replace_request_hostname(request: Any, original_url: str, new_hostname: str) -> Any:
def replace_request_hostname(
request: Request, original_url: str, new_hostname: str
) -> Request:
parsed_url = urlparse(request.uri)

if parsed_url.hostname in original_url:
@@ -94,7 +115,7 @@ def replace_request_hostname(request: Any, original_url: str, new_hostname: str)
return request

def filter_hostnames(request: Any) -> Any:
def filter_hostnames(request: Request) -> Request | None:
# Add your implementation here for filtering hostnames
if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
return request