feat(agent/core): Add Anthropic Claude 3 support (#7085)

- feat(agent/core): Add `AnthropicProvider`
  - Add `ANTHROPIC_API_KEY` to .env.template and docs

  Notable differences in logic compared to `OpenAIProvider`:
  - Merges subsequent user messages in `AnthropicProvider._get_chat_completion_args`
  - Merges and extracts all system messages into `system` parameter in `AnthropicProvider._get_chat_completion_args`
  - Supports prefill; merges prefill content (if any) into generated response

- Prompt changes to improve compatibility with `AnthropicProvider`
  Anthropic has a slightly different API compared to OpenAI, and has much stricter input validation. E.g. Anthropic only supports a single `system` prompt, where OpenAI allows multiple `system` messages. Anthropic also forbids sequences of multiple `user` or `assistant` messages and requires that messages alternate between roles.
  - Move response format instruction from separate message into main system prompt
  - Fix clock message format
  - Add pre-fill to `OneShot` generated prompt

- refactor(agent/core): Tweak `model_providers.schema`
  - Simplify `ModelProviderUsage`
     - Remove attribute `total_tokens` as it is always equal to `prompt_tokens + completion_tokens`
     - Modify signature of `update_usage(..)`; no longer requires a full `ModelResponse` object as input
  - Improve `ModelProviderBudget`
     - Change type of attribute `usage` to `defaultdict[str, ModelProviderUsage]` -> allow per-model usage tracking
     - Modify signature of `update_usage_and_cost(..)`; no longer requires a full `ModelResponse` object as input
     - Allow `ModelProviderBudget` zero-argument instantiation
  - Fix type of `AssistantChatMessage.role` to match `ChatMessage.role` (str -> `ChatMessage.Role`)
  - Add shared attributes and constructor to `ModelProvider` base class
  - Add `max_output_tokens` parameter to `create_chat_completion` interface
  - Add pre-filling as a global feature
    - Add `prefill_response` field to `ChatPrompt` model
    - Add `prefill_response` parameter to `create_chat_completion` interface
  - Add `ChatModelProvider.get_available_models()` and remove `ApiManager`
  - Remove unused `OpenAIChatParser` typedef in openai.py
  - Remove redundant `budget` attribute definition on `OpenAISettings`
  - Remove unnecessary `usage` in `OpenAIProvider` > `default_settings` > `budget`

- feat(agent): Allow use of any available LLM provider through `MultiProvider`
  - Add `MultiProvider` (`model_providers.multi`)
  - Replace all references to / uses of `OpenAIProvider` with `MultiProvider`
  - Change type of `Config.smart_llm` and `Config.fast_llm` from `str` to `ModelName`

- feat(agent/core): Validate function call arguments in `create_chat_completion`
    - Add `validate_call` method to `CompletionModelFunction` in `model_providers.schema`
    - Add `validate_tool_calls` utility function in `model_providers.utils`
    - Add tool call validation step to `create_chat_completion` in `OpenAIProvider` and `AnthropicProvider`
    - Remove (now redundant) command argument validation logic in agent.py and models/command.py

- refactor(agent): Rename `get_openai_command_specs` to `function_specs_from_commands`
This commit is contained in:
Reinier van der Leer
2024-05-04 20:33:25 +02:00
committed by GitHub
parent 78d83bb3ce
commit 39c46ef6be
24 changed files with 923 additions and 149 deletions

View File

@@ -2,8 +2,11 @@
### AutoGPT - GENERAL SETTINGS ### AutoGPT - GENERAL SETTINGS
################################################################################ ################################################################################
## OPENAI_API_KEY - OpenAI API Key (Example: my-openai-api-key) ## OPENAI_API_KEY - OpenAI API Key (Example: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
OPENAI_API_KEY=your-openai-api-key # OPENAI_API_KEY=
## ANTHROPIC_API_KEY - Anthropic API Key (Example: sk-ant-api03-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
# ANTHROPIC_API_KEY=
## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry. ## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry.
## This helps us to spot and solve problems earlier & faster. (Default: DISABLED) ## This helps us to spot and solve problems earlier & faster. (Default: DISABLED)

View File

@@ -5,8 +5,7 @@ from pathlib import Path
from autogpt.agent_manager.agent_manager import AgentManager from autogpt.agent_manager.agent_manager import AgentManager
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptStrategy from autogpt.app.main import _configure_llm_provider, run_interaction_loop
from autogpt.app.main import _configure_openai_provider, run_interaction_loop
from autogpt.config import AIProfile, ConfigBuilder from autogpt.config import AIProfile, ConfigBuilder
from autogpt.file_storage import FileStorageBackendName, get_storage from autogpt.file_storage import FileStorageBackendName, get_storage
from autogpt.logs.config import configure_logging from autogpt.logs.config import configure_logging
@@ -38,10 +37,6 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
ai_goals=[task], ai_goals=[task],
) )
agent_prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(
deep=True
)
agent_prompt_config.use_functions_api = config.openai_functions
agent_settings = AgentSettings( agent_settings = AgentSettings(
name=Agent.default_settings.name, name=Agent.default_settings.name,
agent_id=AgentManager.generate_id("AutoGPT-benchmark"), agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
@@ -53,7 +48,6 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
allow_fs_access=not config.restrict_to_workspace, allow_fs_access=not config.restrict_to_workspace,
use_functions_api=config.openai_functions, use_functions_api=config.openai_functions,
), ),
prompt_config=agent_prompt_config,
history=Agent.default_settings.history.copy(deep=True), history=Agent.default_settings.history.copy(deep=True),
) )
@@ -66,7 +60,7 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
agent = Agent( agent = Agent(
settings=agent_settings, settings=agent_settings,
llm_provider=_configure_openai_provider(config), llm_provider=_configure_llm_provider(config),
file_storage=file_storage, file_storage=file_storage,
legacy_config=config, legacy_config=config,
) )

View File

@@ -19,7 +19,6 @@ from autogpt.components.event_history import EventHistoryComponent
from autogpt.core.configuration import Configurable from autogpt.core.configuration import Configurable
from autogpt.core.prompting import ChatPrompt from autogpt.core.prompting import ChatPrompt
from autogpt.core.resource.model_providers import ( from autogpt.core.resource.model_providers import (
AssistantChatMessage,
AssistantFunctionCall, AssistantFunctionCall,
ChatMessage, ChatMessage,
ChatModelProvider, ChatModelProvider,
@@ -27,7 +26,7 @@ from autogpt.core.resource.model_providers import (
) )
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
from autogpt.file_storage.base import FileStorage from autogpt.file_storage.base import FileStorage
from autogpt.llm.providers.openai import get_openai_command_specs from autogpt.llm.providers.openai import function_specs_from_commands
from autogpt.logs.log_cycle import ( from autogpt.logs.log_cycle import (
CURRENT_CONTEXT_FILE_NAME, CURRENT_CONTEXT_FILE_NAME,
NEXT_ACTION_FILE_NAME, NEXT_ACTION_FILE_NAME,
@@ -46,7 +45,6 @@ from autogpt.utils.exceptions import (
AgentException, AgentException,
AgentTerminated, AgentTerminated,
CommandExecutionError, CommandExecutionError,
InvalidArgumentError,
UnknownCommandError, UnknownCommandError,
) )
@@ -104,7 +102,11 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
self.ai_profile = settings.ai_profile self.ai_profile = settings.ai_profile
self.directives = settings.directives self.directives = settings.directives
prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(deep=True) prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(deep=True)
prompt_config.use_functions_api = settings.config.use_functions_api prompt_config.use_functions_api = (
settings.config.use_functions_api
# Anthropic currently doesn't support tools + prefilling :(
and self.llm.provider_name != "anthropic"
)
self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger) self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
self.commands: list[Command] = [] self.commands: list[Command] = []
@@ -172,7 +174,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
task=self.state.task, task=self.state.task,
ai_profile=self.state.ai_profile, ai_profile=self.state.ai_profile,
ai_directives=directives, ai_directives=directives,
commands=get_openai_command_specs(self.commands), commands=function_specs_from_commands(self.commands),
include_os_info=self.legacy_config.execute_local_commands, include_os_info=self.legacy_config.execute_local_commands,
) )
@@ -202,12 +204,9 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
] = await self.llm_provider.create_chat_completion( ] = await self.llm_provider.create_chat_completion(
prompt.messages, prompt.messages,
model_name=self.llm.name, model_name=self.llm.name,
completion_parser=self.parse_and_validate_response, completion_parser=self.prompt_strategy.parse_response_content,
functions=( functions=prompt.functions,
get_openai_command_specs(self.commands) prefill_response=prompt.prefill_response,
if self.config.use_functions_api
else []
),
) )
result = response.parsed_result result = response.parsed_result
@@ -223,28 +222,6 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
return result return result
def parse_and_validate_response(
self, llm_response: AssistantChatMessage
) -> OneShotAgentActionProposal:
parsed_response = self.prompt_strategy.parse_response_content(llm_response)
# Validate command arguments
command_name = parsed_response.use_tool.name
command = self._get_command(command_name)
if arg_errors := command.validate_args(parsed_response.use_tool.arguments)[1]:
fmt_errors = [
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
if f.path
else f.message
for f in arg_errors
]
raise InvalidArgumentError(
f"The set of arguments supplied for {command_name} is invalid:\n"
+ "\n".join(fmt_errors)
)
return parsed_response
async def execute( async def execute(
self, self,
proposal: OneShotAgentActionProposal, proposal: OneShotAgentActionProposal,

View File

@@ -39,11 +39,12 @@ from autogpt.core.configuration import (
SystemSettings, SystemSettings,
UserConfigurable, UserConfigurable,
) )
from autogpt.core.resource.model_providers import AssistantFunctionCall from autogpt.core.resource.model_providers import (
from autogpt.core.resource.model_providers.openai import ( CHAT_MODELS,
OPEN_AI_CHAT_MODELS, AssistantFunctionCall,
OpenAIModelName, ModelName,
) )
from autogpt.core.resource.model_providers.openai import OpenAIModelName
from autogpt.models.utils import ModelWithSummary from autogpt.models.utils import ModelWithSummary
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
@@ -56,8 +57,8 @@ P = ParamSpec("P")
class BaseAgentConfiguration(SystemConfiguration): class BaseAgentConfiguration(SystemConfiguration):
allow_fs_access: bool = UserConfigurable(default=False) allow_fs_access: bool = UserConfigurable(default=False)
fast_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k) fast_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
smart_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT4) smart_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT4)
use_functions_api: bool = UserConfigurable(default=False) use_functions_api: bool = UserConfigurable(default=False)
default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT
@@ -174,7 +175,7 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
llm_name = ( llm_name = (
self.config.smart_llm if self.config.big_brain else self.config.fast_llm self.config.smart_llm if self.config.big_brain else self.config.fast_llm
) )
return OPEN_AI_CHAT_MODELS[llm_name] return CHAT_MODELS[llm_name]
@property @property
def send_token_limit(self) -> int: def send_token_limit(self) -> int:

View File

@@ -122,7 +122,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
1. System prompt 1. System prompt
3. `cycle_instruction` 3. `cycle_instruction`
""" """
system_prompt = self.build_system_prompt( system_prompt, response_prefill = self.build_system_prompt(
ai_profile=ai_profile, ai_profile=ai_profile,
ai_directives=ai_directives, ai_directives=ai_directives,
commands=commands, commands=commands,
@@ -131,24 +131,34 @@ class OneShotAgentPromptStrategy(PromptStrategy):
final_instruction_msg = ChatMessage.user(self.config.choose_action_instruction) final_instruction_msg = ChatMessage.user(self.config.choose_action_instruction)
prompt = ChatPrompt( return ChatPrompt(
messages=[ messages=[
ChatMessage.system(system_prompt), ChatMessage.system(system_prompt),
ChatMessage.user(f'"""{task}"""'), ChatMessage.user(f'"""{task}"""'),
*messages, *messages,
final_instruction_msg, final_instruction_msg,
], ],
prefill_response=response_prefill,
functions=commands if self.config.use_functions_api else [],
) )
return prompt
def build_system_prompt( def build_system_prompt(
self, self,
ai_profile: AIProfile, ai_profile: AIProfile,
ai_directives: AIDirectives, ai_directives: AIDirectives,
commands: list[CompletionModelFunction], commands: list[CompletionModelFunction],
include_os_info: bool, include_os_info: bool,
) -> str: ) -> tuple[str, str]:
"""
Builds the system prompt.
Returns:
str: The system prompt body
str: The desired start for the LLM's response; used to steer the output
"""
response_fmt_instruction, response_prefill = self.response_format_instruction(
self.config.use_functions_api
)
system_prompt_parts = ( system_prompt_parts = (
self._generate_intro_prompt(ai_profile) self._generate_intro_prompt(ai_profile)
+ (self._generate_os_info() if include_os_info else []) + (self._generate_os_info() if include_os_info else [])
@@ -169,16 +179,16 @@ class OneShotAgentPromptStrategy(PromptStrategy):
" in the next message. Your job is to complete the task while following" " in the next message. Your job is to complete the task while following"
" your directives as given above, and terminate when your task is done." " your directives as given above, and terminate when your task is done."
] ]
+ [ + ["## RESPONSE FORMAT\n" + response_fmt_instruction]
"## RESPONSE FORMAT\n"
+ self.response_format_instruction(self.config.use_functions_api)
]
) )
# Join non-empty parts together into paragraph format # Join non-empty parts together into paragraph format
return "\n\n".join(filter(None, system_prompt_parts)).strip("\n") return (
"\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
response_prefill,
)
def response_format_instruction(self, use_functions_api: bool) -> str: def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
response_schema = self.response_schema.copy(deep=True) response_schema = self.response_schema.copy(deep=True)
if ( if (
use_functions_api use_functions_api
@@ -193,11 +203,15 @@ class OneShotAgentPromptStrategy(PromptStrategy):
"\n", "\n",
response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME), response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
) )
response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'
return ( return (
f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n" (
f"{response_format}" f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
+ ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "") f"{response_format}"
+ ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "")
),
response_prefill,
) )
def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]: def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:

View File

@@ -34,7 +34,6 @@ from autogpt.agent_manager import AgentManager
from autogpt.app.utils import is_port_free from autogpt.app.utils import is_port_free
from autogpt.config import Config from autogpt.config import Config
from autogpt.core.resource.model_providers import ChatModelProvider, ModelProviderBudget from autogpt.core.resource.model_providers import ChatModelProvider, ModelProviderBudget
from autogpt.core.resource.model_providers.openai import OpenAIProvider
from autogpt.file_storage import FileStorage from autogpt.file_storage import FileStorage
from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult
from autogpt.utils.exceptions import AgentFinished from autogpt.utils.exceptions import AgentFinished
@@ -464,20 +463,18 @@ class AgentProtocolServer:
if task.additional_input and (user_id := task.additional_input.get("user_id")): if task.additional_input and (user_id := task.additional_input.get("user_id")):
_extra_request_headers["AutoGPT-UserID"] = user_id _extra_request_headers["AutoGPT-UserID"] = user_id
task_llm_provider = None settings = self.llm_provider._settings.copy()
if isinstance(self.llm_provider, OpenAIProvider): settings.budget = task_llm_budget
settings = self.llm_provider._settings.copy() settings.configuration = task_llm_provider_config
settings.budget = task_llm_budget task_llm_provider = self.llm_provider.__class__(
settings.configuration = task_llm_provider_config # type: ignore settings=settings,
task_llm_provider = OpenAIProvider( logger=logger.getChild(
settings=settings, f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}"
logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"), ),
) )
self._task_budgets[task.task_id] = task_llm_provider._budget # type: ignore
if task_llm_provider and task_llm_provider._budget: return task_llm_provider
self._task_budgets[task.task_id] = task_llm_provider._budget
return task_llm_provider or self.llm_provider
def task_agent_id(task_id: str | int) -> str: def task_agent_id(task_id: str | int) -> str:

View File

@@ -10,7 +10,7 @@ from colorama import Back, Fore, Style
from autogpt.config import Config from autogpt.config import Config
from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL
from autogpt.core.resource.model_providers.openai import OpenAIModelName, OpenAIProvider from autogpt.core.resource.model_providers import ModelName, MultiProvider
from autogpt.logs.helpers import request_user_double_check from autogpt.logs.helpers import request_user_double_check
from autogpt.memory.vector import get_supported_memory_backends from autogpt.memory.vector import get_supported_memory_backends
from autogpt.utils import utils from autogpt.utils import utils
@@ -150,11 +150,11 @@ async def apply_overrides_to_config(
async def check_model( async def check_model(
model_name: OpenAIModelName, model_type: Literal["smart_llm", "fast_llm"] model_name: ModelName, model_type: Literal["smart_llm", "fast_llm"]
) -> OpenAIModelName: ) -> ModelName:
"""Check if model is available for use. If not, return gpt-3.5-turbo.""" """Check if model is available for use. If not, return gpt-3.5-turbo."""
openai = OpenAIProvider() multi_provider = MultiProvider()
models = await openai.get_available_models() models = await multi_provider.get_available_models()
if any(model_name == m.name for m in models): if any(model_name == m.name for m in models):
return model_name return model_name

View File

@@ -35,7 +35,7 @@ from autogpt.config import (
ConfigBuilder, ConfigBuilder,
assert_config_has_openai_api_key, assert_config_has_openai_api_key,
) )
from autogpt.core.resource.model_providers.openai import OpenAIProvider from autogpt.core.resource.model_providers import MultiProvider
from autogpt.core.runner.client_lib.utils import coroutine from autogpt.core.runner.client_lib.utils import coroutine
from autogpt.file_storage import FileStorageBackendName, get_storage from autogpt.file_storage import FileStorageBackendName, get_storage
from autogpt.logs.config import configure_logging from autogpt.logs.config import configure_logging
@@ -123,7 +123,7 @@ async def run_auto_gpt(
skip_news=skip_news, skip_news=skip_news,
) )
llm_provider = _configure_openai_provider(config) llm_provider = _configure_llm_provider(config)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -399,7 +399,7 @@ async def run_auto_gpt_server(
allow_downloads=allow_downloads, allow_downloads=allow_downloads,
) )
llm_provider = _configure_openai_provider(config) llm_provider = _configure_llm_provider(config)
# Set up & start server # Set up & start server
database = AgentDB( database = AgentDB(
@@ -421,24 +421,12 @@ async def run_auto_gpt_server(
) )
def _configure_openai_provider(config: Config) -> OpenAIProvider: def _configure_llm_provider(config: Config) -> MultiProvider:
"""Create a configured OpenAIProvider object. multi_provider = MultiProvider()
for model in [config.smart_llm, config.fast_llm]:
Args: # Ensure model providers for configured LLMs are available
config: The program's configuration. multi_provider.get_model_provider(model)
return multi_provider
Returns:
A configured OpenAIProvider object.
"""
if config.openai_credentials is None:
raise RuntimeError("OpenAI key is not configured")
openai_settings = OpenAIProvider.default_settings.copy(deep=True)
openai_settings.credentials = config.openai_credentials
return OpenAIProvider(
settings=openai_settings,
logger=logging.getLogger("OpenAIProvider"),
)
def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float: def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float:

View File

@@ -31,7 +31,9 @@ class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
def get_messages(self) -> Iterator[ChatMessage]: def get_messages(self) -> Iterator[ChatMessage]:
# Clock # Clock
yield ChatMessage.system(f"The current time and date is {time.strftime('%c')}") yield ChatMessage.system(
f"## Clock\nThe current time and date is {time.strftime('%c')}"
)
def get_commands(self) -> Iterator[Command]: def get_commands(self) -> Iterator[Command]:
yield self.finish yield self.finish

View File

@@ -17,8 +17,8 @@ from autogpt.core.configuration.schema import (
SystemSettings, SystemSettings,
UserConfigurable, UserConfigurable,
) )
from autogpt.core.resource.model_providers import CHAT_MODELS, ModelName
from autogpt.core.resource.model_providers.openai import ( from autogpt.core.resource.model_providers.openai import (
OPEN_AI_CHAT_MODELS,
OpenAICredentials, OpenAICredentials,
OpenAIModelName, OpenAIModelName,
) )
@@ -74,11 +74,11 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
) )
# Model configuration # Model configuration
fast_llm: OpenAIModelName = UserConfigurable( fast_llm: ModelName = UserConfigurable(
default=OpenAIModelName.GPT3, default=OpenAIModelName.GPT3,
from_env="FAST_LLM", from_env="FAST_LLM",
) )
smart_llm: OpenAIModelName = UserConfigurable( smart_llm: ModelName = UserConfigurable(
default=OpenAIModelName.GPT4_TURBO, default=OpenAIModelName.GPT4_TURBO,
from_env="SMART_LLM", from_env="SMART_LLM",
) )
@@ -206,8 +206,8 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
def validate_openai_functions(cls, v: bool, values: dict[str, Any]): def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
if v: if v:
smart_llm = values["smart_llm"] smart_llm = values["smart_llm"]
assert OPEN_AI_CHAT_MODELS[smart_llm].has_function_call_api, ( assert CHAT_MODELS[smart_llm].has_function_call_api, (
f"Model {smart_llm} does not support OpenAI Functions. " f"Model {smart_llm} does not support tool calling. "
"Please disable OPENAI_FUNCTIONS or choose a suitable model." "Please disable OPENAI_FUNCTIONS or choose a suitable model."
) )
return v return v

View File

@@ -24,6 +24,7 @@ class LanguageModelClassification(str, enum.Enum):
class ChatPrompt(BaseModel): class ChatPrompt(BaseModel):
messages: list[ChatMessage] messages: list[ChatMessage]
functions: list[CompletionModelFunction] = Field(default_factory=list) functions: list[CompletionModelFunction] = Field(default_factory=list)
prefill_response: str = ""
def raw(self) -> list[ChatMessageDict]: def raw(self) -> list[ChatMessageDict]:
return [m.dict() for m in self.messages] return [m.dict() for m in self.messages]

View File

@@ -1,3 +1,4 @@
from .multi import CHAT_MODELS, ModelName, MultiProvider
from .openai import ( from .openai import (
OPEN_AI_CHAT_MODELS, OPEN_AI_CHAT_MODELS,
OPEN_AI_EMBEDDING_MODELS, OPEN_AI_EMBEDDING_MODELS,
@@ -42,11 +43,13 @@ __all__ = [
"ChatModelProvider", "ChatModelProvider",
"ChatModelResponse", "ChatModelResponse",
"CompletionModelFunction", "CompletionModelFunction",
"CHAT_MODELS",
"Embedding", "Embedding",
"EmbeddingModelInfo", "EmbeddingModelInfo",
"EmbeddingModelProvider", "EmbeddingModelProvider",
"EmbeddingModelResponse", "EmbeddingModelResponse",
"ModelInfo", "ModelInfo",
"ModelName",
"ModelProvider", "ModelProvider",
"ModelProviderBudget", "ModelProviderBudget",
"ModelProviderCredentials", "ModelProviderCredentials",
@@ -56,6 +59,7 @@ __all__ = [
"ModelProviderUsage", "ModelProviderUsage",
"ModelResponse", "ModelResponse",
"ModelTokenizer", "ModelTokenizer",
"MultiProvider",
"OPEN_AI_MODELS", "OPEN_AI_MODELS",
"OPEN_AI_CHAT_MODELS", "OPEN_AI_CHAT_MODELS",
"OPEN_AI_EMBEDDING_MODELS", "OPEN_AI_EMBEDDING_MODELS",

View File

@@ -0,0 +1,495 @@
from __future__ import annotations
import enum
import logging
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
import sentry_sdk
import tenacity
import tiktoken
from anthropic import APIConnectionError, APIStatusError
from pydantic import SecretStr
from autogpt.core.configuration import Configurable, UserConfigurable
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessage,
AssistantFunctionCall,
AssistantToolCall,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
CompletionModelFunction,
ModelProviderBudget,
ModelProviderConfiguration,
ModelProviderCredentials,
ModelProviderName,
ModelProviderSettings,
ModelTokenizer,
ToolResultMessage,
)
from .utils import validate_tool_calls
if TYPE_CHECKING:
from anthropic.types.beta.tools import MessageCreateParams
from anthropic.types.beta.tools import ToolsBetaMessage as Message
from anthropic.types.beta.tools import ToolsBetaMessageParam as MessageParam
_T = TypeVar("_T")
_P = ParamSpec("_P")
class AnthropicModelName(str, enum.Enum):
CLAUDE3_OPUS_v1 = "claude-3-opus-20240229"
CLAUDE3_SONNET_v1 = "claude-3-sonnet-20240229"
CLAUDE3_HAIKU_v1 = "claude-3-haiku-20240307"
ANTHROPIC_CHAT_MODELS = {
info.name: info
for info in [
ChatModelInfo(
name=AnthropicModelName.CLAUDE3_OPUS_v1,
provider_name=ModelProviderName.ANTHROPIC,
prompt_token_cost=15 / 1e6,
completion_token_cost=75 / 1e6,
max_tokens=200000,
has_function_call_api=True,
),
ChatModelInfo(
name=AnthropicModelName.CLAUDE3_SONNET_v1,
provider_name=ModelProviderName.ANTHROPIC,
prompt_token_cost=3 / 1e6,
completion_token_cost=15 / 1e6,
max_tokens=200000,
has_function_call_api=True,
),
ChatModelInfo(
name=AnthropicModelName.CLAUDE3_HAIKU_v1,
provider_name=ModelProviderName.ANTHROPIC,
prompt_token_cost=0.25 / 1e6,
completion_token_cost=1.25 / 1e6,
max_tokens=200000,
has_function_call_api=True,
),
]
}
class AnthropicConfiguration(ModelProviderConfiguration):
fix_failed_parse_tries: int = UserConfigurable(3)
class AnthropicCredentials(ModelProviderCredentials):
"""Credentials for Anthropic."""
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY")
api_base: Optional[SecretStr] = UserConfigurable(
default=None, from_env="ANTHROPIC_API_BASE_URL"
)
def get_api_access_kwargs(self) -> dict[str, str]:
return {
k: (v.get_secret_value() if type(v) is SecretStr else v)
for k, v in {
"api_key": self.api_key,
"base_url": self.api_base,
}.items()
if v is not None
}
class AnthropicSettings(ModelProviderSettings):
configuration: AnthropicConfiguration
credentials: Optional[AnthropicCredentials]
budget: ModelProviderBudget
class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
default_settings = AnthropicSettings(
name="anthropic_provider",
description="Provides access to Anthropic's API.",
configuration=AnthropicConfiguration(
retries_per_request=7,
),
credentials=None,
budget=ModelProviderBudget(),
)
_settings: AnthropicSettings
_configuration: AnthropicConfiguration
_credentials: AnthropicCredentials
_budget: ModelProviderBudget
def __init__(
self,
settings: Optional[AnthropicSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not settings:
settings = self.default_settings.copy(deep=True)
if not settings.credentials:
settings.credentials = AnthropicCredentials.from_env()
super(AnthropicProvider, self).__init__(settings=settings, logger=logger)
from anthropic import AsyncAnthropic
self._client = AsyncAnthropic(**self._credentials.get_api_access_kwargs())
async def get_available_models(self) -> list[ChatModelInfo]:
return list(ANTHROPIC_CHAT_MODELS.values())
def get_token_limit(self, model_name: str) -> int:
"""Get the token limit for a given model."""
return ANTHROPIC_CHAT_MODELS[model_name].max_tokens
@classmethod
def get_tokenizer(cls, model_name: AnthropicModelName) -> ModelTokenizer:
# HACK: No official tokenizer is available for Claude 3
return tiktoken.encoding_for_model(model_name)
@classmethod
def count_tokens(cls, text: str, model_name: AnthropicModelName) -> int:
return 0 # HACK: No official tokenizer is available for Claude 3
@classmethod
def count_message_tokens(
cls,
messages: ChatMessage | list[ChatMessage],
model_name: AnthropicModelName,
) -> int:
return 0 # HACK: No official tokenizer is available for Claude 3
async def create_chat_completion(
self,
model_prompt: list[ChatMessage],
model_name: AnthropicModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the Anthropic API."""
anthropic_messages, completion_kwargs = self._get_chat_completion_args(
prompt_messages=model_prompt,
model=model_name,
functions=functions,
max_output_tokens=max_output_tokens,
**kwargs,
)
total_cost = 0.0
attempts = 0
while True:
completion_kwargs["messages"] = anthropic_messages.copy()
if prefill_response:
completion_kwargs["messages"].append(
{"role": "assistant", "content": prefill_response}
)
(
_assistant_msg,
cost,
t_input,
t_output,
) = await self._create_chat_completion(completion_kwargs)
total_cost += cost
self._logger.debug(
f"Completion usage: {t_input} input, {t_output} output "
f"- ${round(cost, 5)}"
)
# Merge prefill into generated response
if prefill_response:
first_text_block = next(
b for b in _assistant_msg.content if b.type == "text"
)
first_text_block.text = prefill_response + first_text_block.text
assistant_msg = AssistantChatMessage(
content="\n\n".join(
b.text for b in _assistant_msg.content if b.type == "text"
),
tool_calls=self._parse_assistant_tool_calls(_assistant_msg),
)
# If parsing the response fails, append the error to the prompt, and let the
# LLM fix its mistake(s).
attempts += 1
tool_call_errors = []
try:
# Validate tool calls
if assistant_msg.tool_calls and functions:
tool_call_errors = validate_tool_calls(
assistant_msg.tool_calls, functions
)
if tool_call_errors:
raise ValueError(
"Invalid tool use(s):\n"
+ "\n".join(str(e) for e in tool_call_errors)
)
parsed_result = completion_parser(assistant_msg)
break
except Exception as e:
self._logger.debug(
f"Parsing failed on response: '''{_assistant_msg}'''"
)
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
sentry_sdk.capture_exception(
error=e,
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
)
if attempts < self._configuration.fix_failed_parse_tries:
anthropic_messages.append(
_assistant_msg.dict(include={"role", "content"})
)
anthropic_messages.append(
{
"role": "user",
"content": [
*(
# tool_result is required if last assistant message
# had tool_use block(s)
{
"type": "tool_result",
"tool_use_id": tc.id,
"is_error": True,
"content": [
{
"type": "text",
"text": "Not executed because parsing "
"of your last message failed"
if not tool_call_errors
else str(e)
if (
e := next(
(
tce
for tce in tool_call_errors
if tce.name
== tc.function.name
),
None,
)
)
else "Not executed because validation "
"of tool input failed",
}
],
}
for tc in assistant_msg.tool_calls or []
),
{
"type": "text",
"text": (
"ERROR PARSING YOUR RESPONSE:\n\n"
f"{e.__class__.__name__}: {e}"
),
},
],
}
)
else:
raise
if attempts > 1:
self._logger.debug(
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
)
return ChatModelResponse(
response=assistant_msg,
parsed_result=parsed_result,
model_info=ANTHROPIC_CHAT_MODELS[model_name],
prompt_tokens_used=t_input,
completion_tokens_used=t_output,
)
def _get_chat_completion_args(
self,
prompt_messages: list[ChatMessage],
model: AnthropicModelName,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> tuple[list[MessageParam], MessageCreateParams]:
"""Prepare arguments for message completion API call.
Args:
prompt_messages: List of ChatMessages.
model: The model to use.
functions: Optional list of functions available to the LLM.
kwargs: Additional keyword arguments.
Returns:
list[MessageParam]: Prompt messages for the Anthropic call
dict[str, Any]: Any other kwargs for the Anthropic call
"""
kwargs["model"] = model
if functions:
kwargs["tools"] = [
{
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": {
name: param.to_dict()
for name, param in f.parameters.items()
},
"required": [
name
for name, param in f.parameters.items()
if param.required
],
},
}
for f in functions
]
kwargs["max_tokens"] = max_output_tokens or 4096
if extra_headers := self._configuration.extra_request_headers:
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
kwargs["extra_headers"].update(extra_headers.copy())
system_messages = [
m for m in prompt_messages if m.role == ChatMessage.Role.SYSTEM
]
if (_n := len(system_messages)) > 1:
self._logger.warning(
f"Prompt has {_n} system messages; Anthropic supports only 1. "
"They will be merged, and removed from the rest of the prompt."
)
kwargs["system"] = "\n\n".join(sm.content for sm in system_messages)
messages: list[MessageParam] = []
for message in prompt_messages:
if message.role == ChatMessage.Role.SYSTEM:
continue
elif message.role == ChatMessage.Role.USER:
# Merge subsequent user messages
if messages and (prev_msg := messages[-1])["role"] == "user":
if isinstance(prev_msg["content"], str):
prev_msg["content"] += f"\n\n{message.content}"
else:
assert isinstance(prev_msg["content"], list)
prev_msg["content"].append(
{"type": "text", "text": message.content}
)
else:
messages.append({"role": "user", "content": message.content})
# TODO: add support for image blocks
elif message.role == ChatMessage.Role.ASSISTANT:
if isinstance(message, AssistantChatMessage) and message.tool_calls:
messages.append(
{
"role": "assistant",
"content": [
*(
[{"type": "text", "text": message.content}]
if message.content
else []
),
*(
{
"type": "tool_use",
"id": tc.id,
"name": tc.function.name,
"input": tc.function.arguments,
}
for tc in message.tool_calls
),
],
}
)
elif message.content:
messages.append(
{
"role": "assistant",
"content": message.content,
}
)
elif isinstance(message, ToolResultMessage):
messages.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": message.tool_call_id,
"content": [{"type": "text", "text": message.content}],
"is_error": message.is_error,
}
],
}
)
return messages, kwargs # type: ignore
async def _create_chat_completion(
self, completion_kwargs: MessageCreateParams
) -> tuple[Message, float, int, int]:
"""
Create a chat completion using the Anthropic API with retry handling.
Params:
completion_kwargs: Keyword arguments for an Anthropic Messages API call
Returns:
Message: The message completion object
float: The cost ($) of this completion
int: Number of input tokens used
int: Number of output tokens used
"""
@self._retry_api_request
async def _create_chat_completion_with_retry(
completion_kwargs: MessageCreateParams,
) -> Message:
return await self._client.beta.tools.messages.create(
**completion_kwargs # type: ignore
)
response = await _create_chat_completion_with_retry(completion_kwargs)
cost = self._budget.update_usage_and_cost(
model_info=ANTHROPIC_CHAT_MODELS[completion_kwargs["model"]],
input_tokens_used=response.usage.input_tokens,
output_tokens_used=response.usage.output_tokens,
)
return response, cost, response.usage.input_tokens, response.usage.output_tokens
def _parse_assistant_tool_calls(
self, assistant_message: Message
) -> list[AssistantToolCall]:
return [
AssistantToolCall(
id=c.id,
type="function",
function=AssistantFunctionCall(name=c.name, arguments=c.input),
)
for c in assistant_message.content
if c.type == "tool_use"
]
def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
return tenacity.retry(
retry=(
tenacity.retry_if_exception_type(APIConnectionError)
| tenacity.retry_if_exception(
lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
)
),
wait=tenacity.wait_exponential(),
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
after=tenacity.after_log(self._logger, logging.DEBUG),
)(func)
def __repr__(self):
return "AnthropicProvider()"

View File

@@ -0,0 +1,162 @@
from __future__ import annotations
import logging
from typing import Callable, Iterator, Optional, TypeVar
from pydantic import ValidationError
from autogpt.core.configuration import Configurable
from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
from .schema import (
AssistantChatMessage,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
ChatModelResponse,
CompletionModelFunction,
ModelProviderBudget,
ModelProviderConfiguration,
ModelProviderName,
ModelProviderSettings,
ModelTokenizer,
)
_T = TypeVar("_T")
ModelName = AnthropicModelName | OpenAIModelName
CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
default_settings = ModelProviderSettings(
name="multi_provider",
description=(
"Provides access to all of the available models, regardless of provider."
),
configuration=ModelProviderConfiguration(
retries_per_request=7,
),
budget=ModelProviderBudget(),
)
_budget: ModelProviderBudget
_provider_instances: dict[ModelProviderName, ChatModelProvider]
def __init__(
self,
settings: Optional[ModelProviderSettings] = None,
logger: Optional[logging.Logger] = None,
):
super(MultiProvider, self).__init__(settings=settings, logger=logger)
self._budget = self._settings.budget or ModelProviderBudget()
self._provider_instances = {}
async def get_available_models(self) -> list[ChatModelInfo]:
models = []
for provider in self.get_available_providers():
models.extend(await provider.get_available_models())
return models
def get_token_limit(self, model_name: ModelName) -> int:
"""Get the token limit for a given model."""
return self.get_model_provider(model_name).get_token_limit(model_name)
@classmethod
def get_tokenizer(cls, model_name: ModelName) -> ModelTokenizer:
return cls._get_model_provider_class(model_name).get_tokenizer(model_name)
@classmethod
def count_tokens(cls, text: str, model_name: ModelName) -> int:
return cls._get_model_provider_class(model_name).count_tokens(
text=text, model_name=model_name
)
@classmethod
def count_message_tokens(
cls, messages: ChatMessage | list[ChatMessage], model_name: ModelName
) -> int:
return cls._get_model_provider_class(model_name).count_message_tokens(
messages=messages, model_name=model_name
)
async def create_chat_completion(
self,
model_prompt: list[ChatMessage],
model_name: ModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
prefill_response: str = "",
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the Anthropic API."""
return await self.get_model_provider(model_name).create_chat_completion(
model_prompt=model_prompt,
model_name=model_name,
completion_parser=completion_parser,
functions=functions,
max_output_tokens=max_output_tokens,
prefill_response=prefill_response,
**kwargs,
)
def get_model_provider(self, model: ModelName) -> ChatModelProvider:
model_info = CHAT_MODELS[model]
return self._get_provider(model_info.provider_name)
def get_available_providers(self) -> Iterator[ChatModelProvider]:
for provider_name in ModelProviderName:
try:
yield self._get_provider(provider_name)
except Exception:
pass
def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
_provider = self._provider_instances.get(provider_name)
if not _provider:
Provider = self._get_provider_class(provider_name)
settings = Provider.default_settings.copy(deep=True)
settings.budget = self._budget
settings.configuration.extra_request_headers.update(
self._settings.configuration.extra_request_headers
)
if settings.credentials is None:
try:
Credentials = settings.__fields__["credentials"].type_
settings.credentials = Credentials.from_env()
except ValidationError as e:
raise ValueError(
f"{provider_name} is unavailable: can't load credentials"
) from e
self._provider_instances[provider_name] = _provider = Provider(
settings=settings, logger=self._logger
)
_provider._budget = self._budget # Object binding not preserved by Pydantic
return _provider
@classmethod
def _get_model_provider_class(
cls, model_name: ModelName
) -> type[AnthropicProvider | OpenAIProvider]:
return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
@classmethod
def _get_provider_class(
cls, provider_name: ModelProviderName
) -> type[AnthropicProvider | OpenAIProvider]:
try:
return {
ModelProviderName.ANTHROPIC: AnthropicProvider,
ModelProviderName.OPENAI: OpenAIProvider,
}[provider_name]
except KeyError:
raise ValueError(f"{provider_name} is not a known provider") from None
def __repr__(self):
return f"{self.__class__.__name__}()"

View File

@@ -42,6 +42,8 @@ from autogpt.core.resource.model_providers.schema import (
from autogpt.core.utils.json_schema import JSONSchema from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads from autogpt.core.utils.json_utils import json_loads
from .utils import validate_tool_calls
_T = TypeVar("_T") _T = TypeVar("_T")
_P = ParamSpec("_P") _P = ParamSpec("_P")
@@ -298,6 +300,7 @@ class OpenAIProvider(
budget=ModelProviderBudget(), budget=ModelProviderBudget(),
) )
_settings: OpenAISettings
_configuration: OpenAIConfiguration _configuration: OpenAIConfiguration
_credentials: OpenAICredentials _credentials: OpenAICredentials
_budget: ModelProviderBudget _budget: ModelProviderBudget
@@ -312,11 +315,7 @@ class OpenAIProvider(
if not settings.credentials: if not settings.credentials:
settings.credentials = OpenAICredentials.from_env() settings.credentials = OpenAICredentials.from_env()
self._settings = settings super(OpenAIProvider, self).__init__(settings=settings, logger=logger)
self._configuration = settings.configuration
self._credentials = settings.credentials
self._budget = settings.budget
if self._credentials.api_type == "azure": if self._credentials.api_type == "azure":
from openai import AsyncAzureOpenAI from openai import AsyncAzureOpenAI
@@ -329,8 +328,6 @@ class OpenAIProvider(
self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs()) self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
self._logger = logger or logging.getLogger(__name__)
async def get_available_models(self) -> list[ChatModelInfo]: async def get_available_models(self) -> list[ChatModelInfo]:
_models = (await self._client.models.list()).data _models = (await self._client.models.list()).data
return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS] return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS]
@@ -398,6 +395,7 @@ class OpenAIProvider(
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None, functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
prefill_response: str = "", # not supported by OpenAI
**kwargs, **kwargs,
) -> ChatModelResponse[_T]: ) -> ChatModelResponse[_T]:
"""Create a completion using the OpenAI API and parse it.""" """Create a completion using the OpenAI API and parse it."""
@@ -432,6 +430,10 @@ class OpenAIProvider(
) )
parse_errors += _errors parse_errors += _errors
# Validate tool calls
if not parse_errors and tool_calls and functions:
parse_errors += validate_tool_calls(tool_calls, functions)
assistant_msg = AssistantChatMessage( assistant_msg = AssistantChatMessage(
content=_assistant_msg.content, content=_assistant_msg.content,
tool_calls=tool_calls or None, tool_calls=tool_calls or None,

View File

@@ -1,8 +1,10 @@
import abc import abc
import enum import enum
import logging
import math import math
from collections import defaultdict from collections import defaultdict
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
ClassVar, ClassVar,
@@ -28,6 +30,9 @@ from autogpt.core.resource.schema import (
from autogpt.core.utils.json_schema import JSONSchema from autogpt.core.utils.json_schema import JSONSchema
from autogpt.logs.utils import fmt_kwargs from autogpt.logs.utils import fmt_kwargs
if TYPE_CHECKING:
from jsonschema import ValidationError
class ModelProviderService(str, enum.Enum): class ModelProviderService(str, enum.Enum):
"""A ModelService describes what kind of service the model provides.""" """A ModelService describes what kind of service the model provides."""
@@ -39,6 +44,7 @@ class ModelProviderService(str, enum.Enum):
class ModelProviderName(str, enum.Enum): class ModelProviderName(str, enum.Enum):
OPENAI = "openai" OPENAI = "openai"
ANTHROPIC = "anthropic"
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
@@ -100,6 +106,12 @@ class AssistantChatMessage(ChatMessage):
tool_calls: Optional[list[AssistantToolCall]] = None tool_calls: Optional[list[AssistantToolCall]] = None
class ToolResultMessage(ChatMessage):
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL
is_error: bool = False
tool_call_id: str
class AssistantChatMessageDict(TypedDict, total=False): class AssistantChatMessageDict(TypedDict, total=False):
role: str role: str
content: str content: str
@@ -146,6 +158,30 @@ class CompletionModelFunction(BaseModel):
) )
return f"{self.name}: {self.description}. Params: ({params})" return f"{self.name}: {self.description}. Params: ({params})"
def validate_call(
self, function_call: AssistantFunctionCall
) -> tuple[bool, list["ValidationError"]]:
"""
Validates the given function call against the function's parameter specs
Returns:
bool: Whether the given set of arguments is valid for this command
list[ValidationError]: Issues with the set of arguments (if any)
Raises:
ValueError: If the function_call doesn't call this function
"""
if function_call.name != self.name:
raise ValueError(
f"Can't validate {function_call.name} call using {self.name} spec"
)
params_schema = JSONSchema(
type=JSONSchema.Type.OBJECT,
properties={name: spec for name, spec in self.parameters.items()},
)
return params_schema.validate_object(function_call.arguments)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
"""Struct for model information. """Struct for model information.
@@ -229,7 +265,7 @@ class ModelProviderBudget(ProviderBudget):
class ModelProviderSettings(ProviderSettings): class ModelProviderSettings(ProviderSettings):
resource_type: ResourceType = ResourceType.MODEL resource_type: ResourceType = ResourceType.MODEL
configuration: ModelProviderConfiguration configuration: ModelProviderConfiguration
credentials: ModelProviderCredentials credentials: Optional[ModelProviderCredentials] = None
budget: Optional[ModelProviderBudget] = None budget: Optional[ModelProviderBudget] = None
@@ -238,9 +274,28 @@ class ModelProvider(abc.ABC):
default_settings: ClassVar[ModelProviderSettings] default_settings: ClassVar[ModelProviderSettings]
_settings: ModelProviderSettings
_configuration: ModelProviderConfiguration _configuration: ModelProviderConfiguration
_credentials: Optional[ModelProviderCredentials] = None
_budget: Optional[ModelProviderBudget] = None _budget: Optional[ModelProviderBudget] = None
_logger: logging.Logger
def __init__(
self,
settings: Optional[ModelProviderSettings] = None,
logger: Optional[logging.Logger] = None,
):
if not settings:
settings = self.default_settings.copy(deep=True)
self._settings = settings
self._configuration = settings.configuration
self._credentials = settings.credentials
self._budget = settings.budget
self._logger = logger or logging.getLogger(self.__module__)
@abc.abstractmethod @abc.abstractmethod
def count_tokens(self, text: str, model_name: str) -> int: def count_tokens(self, text: str, model_name: str) -> int:
... ...
@@ -358,6 +413,7 @@ class ChatModelProvider(ModelProvider):
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None, functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None, max_output_tokens: Optional[int] = None,
prefill_response: str = "",
**kwargs, **kwargs,
) -> ChatModelResponse[_T]: ) -> ChatModelResponse[_T]:
... ...

View File

@@ -0,0 +1,71 @@
from typing import Any
from .schema import AssistantToolCall, CompletionModelFunction
class InvalidFunctionCallError(Exception):
def __init__(self, name: str, arguments: dict[str, Any], message: str):
self.message = message
self.name = name
self.arguments = arguments
super().__init__(message)
def __str__(self) -> str:
return f"Invalid function call for {self.name}: {self.message}"
def validate_tool_calls(
tool_calls: list[AssistantToolCall], functions: list[CompletionModelFunction]
) -> list[InvalidFunctionCallError]:
"""
Validates a list of tool calls against a list of functions.
1. Tries to find a function matching each tool call
2. If a matching function is found, validates the tool call's arguments,
reporting any resulting errors
2. If no matching function is found, an error "Unknown function X" is reported
3. A list of all errors encountered during validation is returned
Params:
tool_calls: A list of tool calls to validate.
functions: A list of functions to validate against.
Returns:
list[InvalidFunctionCallError]: All errors encountered during validation.
"""
errors: list[InvalidFunctionCallError] = []
for tool_call in tool_calls:
function_call = tool_call.function
if function := next(
(f for f in functions if f.name == function_call.name),
None,
):
is_valid, validation_errors = function.validate_call(function_call)
if not is_valid:
fmt_errors = [
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
if f.path
else f.message
for f in validation_errors
]
errors.append(
InvalidFunctionCallError(
name=function_call.name,
arguments=function_call.arguments,
message=(
"The set of arguments supplied is invalid:\n"
+ "\n".join(fmt_errors)
),
)
)
else:
errors.append(
InvalidFunctionCallError(
name=function_call.name,
arguments=function_call.arguments,
message=f"Unknown function {function_call.name}",
)
)
return errors

View File

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Callable) T = TypeVar("T", bound=Callable)
def get_openai_command_specs( def function_specs_from_commands(
commands: Iterable[Command], commands: Iterable[Command],
) -> list[CompletionModelFunction]: ) -> list[CompletionModelFunction]:
"""Get OpenAI-consumable function specs for the agent's available commands. """Get OpenAI-consumable function specs for the agent's available commands.

View File

@@ -3,8 +3,6 @@ from __future__ import annotations
import inspect import inspect
from typing import Any, Callable from typing import Any, Callable
from autogpt.core.utils.json_schema import JSONSchema
from .command_parameter import CommandParameter from .command_parameter import CommandParameter
from .context_item import ContextItem from .context_item import ContextItem
@@ -42,20 +40,6 @@ class Command:
def is_async(self) -> bool: def is_async(self) -> bool:
return inspect.iscoroutinefunction(self.method) return inspect.iscoroutinefunction(self.method)
def validate_args(self, args: dict[str, Any]):
"""
Validates the given arguments against the command's parameter specifications
Returns:
bool: Whether the given set of arguments is valid for this command
list[ValidationError]: Issues with the set of arguments (if any)
"""
params_schema = JSONSchema(
type=JSONSchema.Type.OBJECT,
properties={p.name: p.spec for p in self.parameters},
)
return params_schema.validate_object(args)
def _parameters_match( def _parameters_match(
self, func: Callable, parameters: list[CommandParameter] self, func: Callable, parameters: list[CommandParameter]
) -> bool: ) -> bool:

View File

@@ -167,6 +167,30 @@ files = [
[package.dependencies] [package.dependencies]
frozenlist = ">=1.1.0" frozenlist = ">=1.1.0"
[[package]]
name = "anthropic"
version = "0.25.1"
description = "The official Python library for the anthropic API"
optional = false
python-versions = ">=3.7"
files = [
{file = "anthropic-0.25.1-py3-none-any.whl", hash = "sha256:95d0cedc2a4b5beae3a78f9030aea4001caea5f46c6d263cce377c891c594e71"},
{file = "anthropic-0.25.1.tar.gz", hash = "sha256:0c01b30b77d041a8d07c532737bae69da58086031217150008e4541f52a64bd9"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
tokenizers = ">=0.13.0"
typing-extensions = ">=4.7,<5"
[package.extras]
bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
vertex = ["google-auth (>=2,<3)"]
[[package]] [[package]]
name = "anyio" name = "anyio"
version = "4.2.0" version = "4.2.0"
@@ -7234,4 +7258,4 @@ benchmark = ["agbenchmark"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "e6eab5c079d53f075ce701e86a2007e7ebeb635ac067d25f555bfea363bcc630" content-hash = "ad1e3c4706465733d04ddab975af630975bd528efce152c1da01eded53069eca"

View File

@@ -22,6 +22,7 @@ serve = "autogpt.app.cli:serve"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.10" python = "^3.10"
anthropic = "^0.25.1"
# autogpt-forge = { path = "../forge" } # autogpt-forge = { path = "../forge" }
autogpt-forge = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "autogpts/forge"} autogpt-forge = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "autogpts/forge"}
beautifulsoup4 = "^4.12.2" beautifulsoup4 = "^4.12.2"

View File

@@ -8,9 +8,9 @@ import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.app.main import _configure_openai_provider from autogpt.app.main import _configure_llm_provider
from autogpt.config import AIProfile, Config, ConfigBuilder from autogpt.config import AIProfile, Config, ConfigBuilder
from autogpt.core.resource.model_providers import ChatModelProvider, OpenAIProvider from autogpt.core.resource.model_providers import ChatModelProvider
from autogpt.file_storage.local import ( from autogpt.file_storage.local import (
FileStorage, FileStorage,
FileStorageConfiguration, FileStorageConfiguration,
@@ -73,8 +73,8 @@ def setup_logger(config: Config):
@pytest.fixture @pytest.fixture
def llm_provider(config: Config) -> OpenAIProvider: def llm_provider(config: Config) -> ChatModelProvider:
return _configure_openai_provider(config) return _configure_llm_provider(config)
@pytest.fixture @pytest.fixture

View File

@@ -14,7 +14,6 @@ from pydantic import SecretStr
from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, apply_overrides_to_config from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, apply_overrides_to_config
from autogpt.config import Config, ConfigBuilder from autogpt.config import Config, ConfigBuilder
from autogpt.core.resource.model_providers.openai import OpenAIModelName
from autogpt.core.resource.model_providers.schema import ( from autogpt.core.resource.model_providers.schema import (
ChatModelInfo, ChatModelInfo,
ModelProviderName, ModelProviderName,
@@ -39,8 +38,8 @@ async def test_fallback_to_gpt3_if_gpt4_not_available(
""" """
Test if models update to gpt-3.5-turbo if gpt-4 is not available. Test if models update to gpt-3.5-turbo if gpt-4 is not available.
""" """
config.fast_llm = OpenAIModelName.GPT4_TURBO config.fast_llm = GPT_4_MODEL
config.smart_llm = OpenAIModelName.GPT4_TURBO config.smart_llm = GPT_4_MODEL
mock_list_models.return_value = asyncio.Future() mock_list_models.return_value = asyncio.Future()
mock_list_models.return_value.set_result( mock_list_models.return_value.set_result(
@@ -56,8 +55,8 @@ async def test_fallback_to_gpt3_if_gpt4_not_available(
gpt4only=False, gpt4only=False,
) )
assert config.fast_llm == "gpt-3.5-turbo" assert config.fast_llm == GPT_3_MODEL
assert config.smart_llm == "gpt-3.5-turbo" assert config.smart_llm == GPT_3_MODEL
def test_missing_azure_config(config: Config) -> None: def test_missing_azure_config(config: Config) -> None:
@@ -148,8 +147,7 @@ def test_azure_config(config_with_azure: Config) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_config_gpt4only(config: Config) -> None: async def test_create_config_gpt4only(config: Config) -> None:
with mock.patch( with mock.patch(
"autogpt.core.resource.model_providers.openai." "autogpt.core.resource.model_providers.multi.MultiProvider.get_available_models"
"OpenAIProvider.get_available_models"
) as mock_get_models: ) as mock_get_models:
mock_get_models.return_value = [ mock_get_models.return_value = [
ChatModelInfo( ChatModelInfo(
@@ -169,8 +167,7 @@ async def test_create_config_gpt4only(config: Config) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_config_gpt3only(config: Config) -> None: async def test_create_config_gpt3only(config: Config) -> None:
with mock.patch( with mock.patch(
"autogpt.core.resource.model_providers.openai." "autogpt.core.resource.model_providers.multi.MultiProvider.get_available_models"
"OpenAIProvider.get_available_models"
) as mock_get_models: ) as mock_get_models:
mock_get_models.return_value = [ mock_get_models.return_value = [
ChatModelInfo( ChatModelInfo(

View File

@@ -7,6 +7,7 @@ Configuration is controlled through the `Config` object. You can set configurati
- `AI_SETTINGS_FILE`: Location of the AI Settings file relative to the AutoGPT root directory. Default: ai_settings.yaml - `AI_SETTINGS_FILE`: Location of the AI Settings file relative to the AutoGPT root directory. Default: ai_settings.yaml
- `AUDIO_TO_TEXT_PROVIDER`: Audio To Text Provider. Only option currently is `huggingface`. Default: huggingface - `AUDIO_TO_TEXT_PROVIDER`: Audio To Text Provider. Only option currently is `huggingface`. Default: huggingface
- `AUTHORISE_COMMAND_KEY`: Key response accepted when authorising commands. Default: y - `AUTHORISE_COMMAND_KEY`: Key response accepted when authorising commands. Default: y
- `ANTHROPIC_API_KEY`: Set this if you want to use Anthropic models with AutoGPT
- `AZURE_CONFIG_FILE`: Location of the Azure Config file relative to the AutoGPT root directory. Default: azure.yaml - `AZURE_CONFIG_FILE`: Location of the Azure Config file relative to the AutoGPT root directory. Default: azure.yaml
- `BROWSE_CHUNK_MAX_LENGTH`: When browsing website, define the length of chunks to summarize. Default: 3000 - `BROWSE_CHUNK_MAX_LENGTH`: When browsing website, define the length of chunks to summarize. Default: 3000
- `BROWSE_SPACY_LANGUAGE_MODEL`: [spaCy language model](https://spacy.io/usage/models) to use when creating chunks. Default: en_core_web_sm - `BROWSE_SPACY_LANGUAGE_MODEL`: [spaCy language model](https://spacy.io/usage/models) to use when creating chunks. Default: en_core_web_sm