mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-17 05:54:26 +01:00
feat(agent/core): Add Anthropic Claude 3 support (#7085)
- feat(agent/core): Add `AnthropicProvider`
- Add `ANTHROPIC_API_KEY` to .env.template and docs
Notable differences in logic compared to `OpenAIProvider`:
- Merges subsequent user messages in `AnthropicProvider._get_chat_completion_args`
- Merges and extracts all system messages into `system` parameter in `AnthropicProvider._get_chat_completion_args`
- Supports prefill; merges prefill content (if any) into generated response
- Prompt changes to improve compatibility with `AnthropicProvider`
Anthropic has a slightly different API compared to OpenAI, and has much stricter input validation. E.g. Anthropic only supports a single `system` prompt, where OpenAI allows multiple `system` messages. Anthropic also forbids sequences of multiple `user` or `assistant` messages and requires that messages alternate between roles.
- Move response format instruction from separate message into main system prompt
- Fix clock message format
- Add pre-fill to `OneShot` generated prompt
- refactor(agent/core): Tweak `model_providers.schema`
- Simplify `ModelProviderUsage`
- Remove attribute `total_tokens` as it is always equal to `prompt_tokens + completion_tokens`
- Modify signature of `update_usage(..)`; no longer requires a full `ModelResponse` object as input
- Improve `ModelProviderBudget`
- Change type of attribute `usage` to `defaultdict[str, ModelProviderUsage]` -> allow per-model usage tracking
- Modify signature of `update_usage_and_cost(..)`; no longer requires a full `ModelResponse` object as input
- Allow `ModelProviderBudget` zero-argument instantiation
- Fix type of `AssistantChatMessage.role` to match `ChatMessage.role` (str -> `ChatMessage.Role`)
- Add shared attributes and constructor to `ModelProvider` base class
- Add `max_output_tokens` parameter to `create_chat_completion` interface
- Add pre-filling as a global feature
- Add `prefill_response` field to `ChatPrompt` model
- Add `prefill_response` parameter to `create_chat_completion` interface
- Add `ChatModelProvider.get_available_models()` and remove `ApiManager`
- Remove unused `OpenAIChatParser` typedef in openai.py
- Remove redundant `budget` attribute definition on `OpenAISettings`
- Remove unnecessary `usage` in `OpenAIProvider` > `default_settings` > `budget`
- feat(agent): Allow use of any available LLM provider through `MultiProvider`
- Add `MultiProvider` (`model_providers.multi`)
- Replace all references to / uses of `OpenAIProvider` with `MultiProvider`
- Change type of `Config.smart_llm` and `Config.fast_llm` from `str` to `ModelName`
- feat(agent/core): Validate function call arguments in `create_chat_completion`
- Add `validate_call` method to `CompletionModelFunction` in `model_providers.schema`
- Add `validate_tool_calls` utility function in `model_providers.utils`
- Add tool call validation step to `create_chat_completion` in `OpenAIProvider` and `AnthropicProvider`
- Remove (now redundant) command argument validation logic in agent.py and models/command.py
- refactor(agent): Rename `get_openai_command_specs` to `function_specs_from_commands`
This commit is contained in:
committed by
GitHub
parent
78d83bb3ce
commit
39c46ef6be
@@ -2,8 +2,11 @@
|
||||
### AutoGPT - GENERAL SETTINGS
|
||||
################################################################################
|
||||
|
||||
## OPENAI_API_KEY - OpenAI API Key (Example: my-openai-api-key)
|
||||
OPENAI_API_KEY=your-openai-api-key
|
||||
## OPENAI_API_KEY - OpenAI API Key (Example: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
|
||||
# OPENAI_API_KEY=
|
||||
|
||||
## ANTHROPIC_API_KEY - Anthropic API Key (Example: sk-ant-api03-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
|
||||
# ANTHROPIC_API_KEY=
|
||||
|
||||
## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry.
|
||||
## This helps us to spot and solve problems earlier & faster. (Default: DISABLED)
|
||||
|
||||
@@ -5,8 +5,7 @@ from pathlib import Path
|
||||
|
||||
from autogpt.agent_manager.agent_manager import AgentManager
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptStrategy
|
||||
from autogpt.app.main import _configure_openai_provider, run_interaction_loop
|
||||
from autogpt.app.main import _configure_llm_provider, run_interaction_loop
|
||||
from autogpt.config import AIProfile, ConfigBuilder
|
||||
from autogpt.file_storage import FileStorageBackendName, get_storage
|
||||
from autogpt.logs.config import configure_logging
|
||||
@@ -38,10 +37,6 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
ai_goals=[task],
|
||||
)
|
||||
|
||||
agent_prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(
|
||||
deep=True
|
||||
)
|
||||
agent_prompt_config.use_functions_api = config.openai_functions
|
||||
agent_settings = AgentSettings(
|
||||
name=Agent.default_settings.name,
|
||||
agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
|
||||
@@ -53,7 +48,6 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
allow_fs_access=not config.restrict_to_workspace,
|
||||
use_functions_api=config.openai_functions,
|
||||
),
|
||||
prompt_config=agent_prompt_config,
|
||||
history=Agent.default_settings.history.copy(deep=True),
|
||||
)
|
||||
|
||||
@@ -66,7 +60,7 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
|
||||
agent = Agent(
|
||||
settings=agent_settings,
|
||||
llm_provider=_configure_openai_provider(config),
|
||||
llm_provider=_configure_llm_provider(config),
|
||||
file_storage=file_storage,
|
||||
legacy_config=config,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ from autogpt.components.event_history import EventHistoryComponent
|
||||
from autogpt.core.configuration import Configurable
|
||||
from autogpt.core.prompting import ChatPrompt
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
ChatMessage,
|
||||
ChatModelProvider,
|
||||
@@ -27,7 +26,7 @@ from autogpt.core.resource.model_providers import (
|
||||
)
|
||||
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.llm.providers.openai import get_openai_command_specs
|
||||
from autogpt.llm.providers.openai import function_specs_from_commands
|
||||
from autogpt.logs.log_cycle import (
|
||||
CURRENT_CONTEXT_FILE_NAME,
|
||||
NEXT_ACTION_FILE_NAME,
|
||||
@@ -46,7 +45,6 @@ from autogpt.utils.exceptions import (
|
||||
AgentException,
|
||||
AgentTerminated,
|
||||
CommandExecutionError,
|
||||
InvalidArgumentError,
|
||||
UnknownCommandError,
|
||||
)
|
||||
|
||||
@@ -104,7 +102,11 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
self.ai_profile = settings.ai_profile
|
||||
self.directives = settings.directives
|
||||
prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(deep=True)
|
||||
prompt_config.use_functions_api = settings.config.use_functions_api
|
||||
prompt_config.use_functions_api = (
|
||||
settings.config.use_functions_api
|
||||
# Anthropic currently doesn't support tools + prefilling :(
|
||||
and self.llm.provider_name != "anthropic"
|
||||
)
|
||||
self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
|
||||
self.commands: list[Command] = []
|
||||
|
||||
@@ -172,7 +174,7 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
task=self.state.task,
|
||||
ai_profile=self.state.ai_profile,
|
||||
ai_directives=directives,
|
||||
commands=get_openai_command_specs(self.commands),
|
||||
commands=function_specs_from_commands(self.commands),
|
||||
include_os_info=self.legacy_config.execute_local_commands,
|
||||
)
|
||||
|
||||
@@ -202,12 +204,9 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
] = await self.llm_provider.create_chat_completion(
|
||||
prompt.messages,
|
||||
model_name=self.llm.name,
|
||||
completion_parser=self.parse_and_validate_response,
|
||||
functions=(
|
||||
get_openai_command_specs(self.commands)
|
||||
if self.config.use_functions_api
|
||||
else []
|
||||
),
|
||||
completion_parser=self.prompt_strategy.parse_response_content,
|
||||
functions=prompt.functions,
|
||||
prefill_response=prompt.prefill_response,
|
||||
)
|
||||
result = response.parsed_result
|
||||
|
||||
@@ -223,28 +222,6 @@ class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
|
||||
return result
|
||||
|
||||
def parse_and_validate_response(
|
||||
self, llm_response: AssistantChatMessage
|
||||
) -> OneShotAgentActionProposal:
|
||||
parsed_response = self.prompt_strategy.parse_response_content(llm_response)
|
||||
|
||||
# Validate command arguments
|
||||
command_name = parsed_response.use_tool.name
|
||||
command = self._get_command(command_name)
|
||||
if arg_errors := command.validate_args(parsed_response.use_tool.arguments)[1]:
|
||||
fmt_errors = [
|
||||
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
|
||||
if f.path
|
||||
else f.message
|
||||
for f in arg_errors
|
||||
]
|
||||
raise InvalidArgumentError(
|
||||
f"The set of arguments supplied for {command_name} is invalid:\n"
|
||||
+ "\n".join(fmt_errors)
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
proposal: OneShotAgentActionProposal,
|
||||
|
||||
@@ -39,11 +39,12 @@ from autogpt.core.configuration import (
|
||||
SystemSettings,
|
||||
UserConfigurable,
|
||||
)
|
||||
from autogpt.core.resource.model_providers import AssistantFunctionCall
|
||||
from autogpt.core.resource.model_providers.openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OpenAIModelName,
|
||||
from autogpt.core.resource.model_providers import (
|
||||
CHAT_MODELS,
|
||||
AssistantFunctionCall,
|
||||
ModelName,
|
||||
)
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIModelName
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
|
||||
|
||||
@@ -56,8 +57,8 @@ P = ParamSpec("P")
|
||||
class BaseAgentConfiguration(SystemConfiguration):
|
||||
allow_fs_access: bool = UserConfigurable(default=False)
|
||||
|
||||
fast_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
|
||||
smart_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT4)
|
||||
fast_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
|
||||
smart_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT4)
|
||||
use_functions_api: bool = UserConfigurable(default=False)
|
||||
|
||||
default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT
|
||||
@@ -174,7 +175,7 @@ class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
llm_name = (
|
||||
self.config.smart_llm if self.config.big_brain else self.config.fast_llm
|
||||
)
|
||||
return OPEN_AI_CHAT_MODELS[llm_name]
|
||||
return CHAT_MODELS[llm_name]
|
||||
|
||||
@property
|
||||
def send_token_limit(self) -> int:
|
||||
|
||||
@@ -122,7 +122,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
1. System prompt
|
||||
3. `cycle_instruction`
|
||||
"""
|
||||
system_prompt = self.build_system_prompt(
|
||||
system_prompt, response_prefill = self.build_system_prompt(
|
||||
ai_profile=ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=commands,
|
||||
@@ -131,24 +131,34 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
|
||||
final_instruction_msg = ChatMessage.user(self.config.choose_action_instruction)
|
||||
|
||||
prompt = ChatPrompt(
|
||||
return ChatPrompt(
|
||||
messages=[
|
||||
ChatMessage.system(system_prompt),
|
||||
ChatMessage.user(f'"""{task}"""'),
|
||||
*messages,
|
||||
final_instruction_msg,
|
||||
],
|
||||
prefill_response=response_prefill,
|
||||
functions=commands if self.config.use_functions_api else [],
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
include_os_info: bool,
|
||||
) -> str:
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds the system prompt.
|
||||
|
||||
Returns:
|
||||
str: The system prompt body
|
||||
str: The desired start for the LLM's response; used to steer the output
|
||||
"""
|
||||
response_fmt_instruction, response_prefill = self.response_format_instruction(
|
||||
self.config.use_functions_api
|
||||
)
|
||||
system_prompt_parts = (
|
||||
self._generate_intro_prompt(ai_profile)
|
||||
+ (self._generate_os_info() if include_os_info else [])
|
||||
@@ -169,16 +179,16 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
" in the next message. Your job is to complete the task while following"
|
||||
" your directives as given above, and terminate when your task is done."
|
||||
]
|
||||
+ [
|
||||
"## RESPONSE FORMAT\n"
|
||||
+ self.response_format_instruction(self.config.use_functions_api)
|
||||
]
|
||||
+ ["## RESPONSE FORMAT\n" + response_fmt_instruction]
|
||||
)
|
||||
|
||||
# Join non-empty parts together into paragraph format
|
||||
return "\n\n".join(filter(None, system_prompt_parts)).strip("\n")
|
||||
return (
|
||||
"\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
|
||||
response_prefill,
|
||||
)
|
||||
|
||||
def response_format_instruction(self, use_functions_api: bool) -> str:
|
||||
def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
|
||||
response_schema = self.response_schema.copy(deep=True)
|
||||
if (
|
||||
use_functions_api
|
||||
@@ -193,11 +203,15 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
"\n",
|
||||
response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
|
||||
)
|
||||
response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'
|
||||
|
||||
return (
|
||||
(
|
||||
f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
|
||||
f"{response_format}"
|
||||
+ ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "")
|
||||
),
|
||||
response_prefill,
|
||||
)
|
||||
|
||||
def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:
|
||||
|
||||
@@ -34,7 +34,6 @@ from autogpt.agent_manager import AgentManager
|
||||
from autogpt.app.utils import is_port_free
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider, ModelProviderBudget
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIProvider
|
||||
from autogpt.file_storage import FileStorage
|
||||
from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult
|
||||
from autogpt.utils.exceptions import AgentFinished
|
||||
@@ -464,20 +463,18 @@ class AgentProtocolServer:
|
||||
if task.additional_input and (user_id := task.additional_input.get("user_id")):
|
||||
_extra_request_headers["AutoGPT-UserID"] = user_id
|
||||
|
||||
task_llm_provider = None
|
||||
if isinstance(self.llm_provider, OpenAIProvider):
|
||||
settings = self.llm_provider._settings.copy()
|
||||
settings.budget = task_llm_budget
|
||||
settings.configuration = task_llm_provider_config # type: ignore
|
||||
task_llm_provider = OpenAIProvider(
|
||||
settings.configuration = task_llm_provider_config
|
||||
task_llm_provider = self.llm_provider.__class__(
|
||||
settings=settings,
|
||||
logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"),
|
||||
logger=logger.getChild(
|
||||
f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}"
|
||||
),
|
||||
)
|
||||
self._task_budgets[task.task_id] = task_llm_provider._budget # type: ignore
|
||||
|
||||
if task_llm_provider and task_llm_provider._budget:
|
||||
self._task_budgets[task.task_id] = task_llm_provider._budget
|
||||
|
||||
return task_llm_provider or self.llm_provider
|
||||
return task_llm_provider
|
||||
|
||||
|
||||
def task_agent_id(task_id: str | int) -> str:
|
||||
|
||||
@@ -10,7 +10,7 @@ from colorama import Back, Fore, Style
|
||||
|
||||
from autogpt.config import Config
|
||||
from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIModelName, OpenAIProvider
|
||||
from autogpt.core.resource.model_providers import ModelName, MultiProvider
|
||||
from autogpt.logs.helpers import request_user_double_check
|
||||
from autogpt.memory.vector import get_supported_memory_backends
|
||||
from autogpt.utils import utils
|
||||
@@ -150,11 +150,11 @@ async def apply_overrides_to_config(
|
||||
|
||||
|
||||
async def check_model(
|
||||
model_name: OpenAIModelName, model_type: Literal["smart_llm", "fast_llm"]
|
||||
) -> OpenAIModelName:
|
||||
model_name: ModelName, model_type: Literal["smart_llm", "fast_llm"]
|
||||
) -> ModelName:
|
||||
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
|
||||
openai = OpenAIProvider()
|
||||
models = await openai.get_available_models()
|
||||
multi_provider = MultiProvider()
|
||||
models = await multi_provider.get_available_models()
|
||||
|
||||
if any(model_name == m.name for m in models):
|
||||
return model_name
|
||||
|
||||
@@ -35,7 +35,7 @@ from autogpt.config import (
|
||||
ConfigBuilder,
|
||||
assert_config_has_openai_api_key,
|
||||
)
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIProvider
|
||||
from autogpt.core.resource.model_providers import MultiProvider
|
||||
from autogpt.core.runner.client_lib.utils import coroutine
|
||||
from autogpt.file_storage import FileStorageBackendName, get_storage
|
||||
from autogpt.logs.config import configure_logging
|
||||
@@ -123,7 +123,7 @@ async def run_auto_gpt(
|
||||
skip_news=skip_news,
|
||||
)
|
||||
|
||||
llm_provider = _configure_openai_provider(config)
|
||||
llm_provider = _configure_llm_provider(config)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -399,7 +399,7 @@ async def run_auto_gpt_server(
|
||||
allow_downloads=allow_downloads,
|
||||
)
|
||||
|
||||
llm_provider = _configure_openai_provider(config)
|
||||
llm_provider = _configure_llm_provider(config)
|
||||
|
||||
# Set up & start server
|
||||
database = AgentDB(
|
||||
@@ -421,24 +421,12 @@ async def run_auto_gpt_server(
|
||||
)
|
||||
|
||||
|
||||
def _configure_openai_provider(config: Config) -> OpenAIProvider:
|
||||
"""Create a configured OpenAIProvider object.
|
||||
|
||||
Args:
|
||||
config: The program's configuration.
|
||||
|
||||
Returns:
|
||||
A configured OpenAIProvider object.
|
||||
"""
|
||||
if config.openai_credentials is None:
|
||||
raise RuntimeError("OpenAI key is not configured")
|
||||
|
||||
openai_settings = OpenAIProvider.default_settings.copy(deep=True)
|
||||
openai_settings.credentials = config.openai_credentials
|
||||
return OpenAIProvider(
|
||||
settings=openai_settings,
|
||||
logger=logging.getLogger("OpenAIProvider"),
|
||||
)
|
||||
def _configure_llm_provider(config: Config) -> MultiProvider:
|
||||
multi_provider = MultiProvider()
|
||||
for model in [config.smart_llm, config.fast_llm]:
|
||||
# Ensure model providers for configured LLMs are available
|
||||
multi_provider.get_model_provider(model)
|
||||
return multi_provider
|
||||
|
||||
|
||||
def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float:
|
||||
|
||||
@@ -31,7 +31,9 @@ class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
|
||||
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
# Clock
|
||||
yield ChatMessage.system(f"The current time and date is {time.strftime('%c')}")
|
||||
yield ChatMessage.system(
|
||||
f"## Clock\nThe current time and date is {time.strftime('%c')}"
|
||||
)
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.finish
|
||||
|
||||
@@ -17,8 +17,8 @@ from autogpt.core.configuration.schema import (
|
||||
SystemSettings,
|
||||
UserConfigurable,
|
||||
)
|
||||
from autogpt.core.resource.model_providers import CHAT_MODELS, ModelName
|
||||
from autogpt.core.resource.model_providers.openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OpenAICredentials,
|
||||
OpenAIModelName,
|
||||
)
|
||||
@@ -74,11 +74,11 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
)
|
||||
|
||||
# Model configuration
|
||||
fast_llm: OpenAIModelName = UserConfigurable(
|
||||
fast_llm: ModelName = UserConfigurable(
|
||||
default=OpenAIModelName.GPT3,
|
||||
from_env="FAST_LLM",
|
||||
)
|
||||
smart_llm: OpenAIModelName = UserConfigurable(
|
||||
smart_llm: ModelName = UserConfigurable(
|
||||
default=OpenAIModelName.GPT4_TURBO,
|
||||
from_env="SMART_LLM",
|
||||
)
|
||||
@@ -206,8 +206,8 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
|
||||
if v:
|
||||
smart_llm = values["smart_llm"]
|
||||
assert OPEN_AI_CHAT_MODELS[smart_llm].has_function_call_api, (
|
||||
f"Model {smart_llm} does not support OpenAI Functions. "
|
||||
assert CHAT_MODELS[smart_llm].has_function_call_api, (
|
||||
f"Model {smart_llm} does not support tool calling. "
|
||||
"Please disable OPENAI_FUNCTIONS or choose a suitable model."
|
||||
)
|
||||
return v
|
||||
|
||||
@@ -24,6 +24,7 @@ class LanguageModelClassification(str, enum.Enum):
|
||||
class ChatPrompt(BaseModel):
|
||||
messages: list[ChatMessage]
|
||||
functions: list[CompletionModelFunction] = Field(default_factory=list)
|
||||
prefill_response: str = ""
|
||||
|
||||
def raw(self) -> list[ChatMessageDict]:
|
||||
return [m.dict() for m in self.messages]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .multi import CHAT_MODELS, ModelName, MultiProvider
|
||||
from .openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OPEN_AI_EMBEDDING_MODELS,
|
||||
@@ -42,11 +43,13 @@ __all__ = [
|
||||
"ChatModelProvider",
|
||||
"ChatModelResponse",
|
||||
"CompletionModelFunction",
|
||||
"CHAT_MODELS",
|
||||
"Embedding",
|
||||
"EmbeddingModelInfo",
|
||||
"EmbeddingModelProvider",
|
||||
"EmbeddingModelResponse",
|
||||
"ModelInfo",
|
||||
"ModelName",
|
||||
"ModelProvider",
|
||||
"ModelProviderBudget",
|
||||
"ModelProviderCredentials",
|
||||
@@ -56,6 +59,7 @@ __all__ = [
|
||||
"ModelProviderUsage",
|
||||
"ModelResponse",
|
||||
"ModelTokenizer",
|
||||
"MultiProvider",
|
||||
"OPEN_AI_MODELS",
|
||||
"OPEN_AI_CHAT_MODELS",
|
||||
"OPEN_AI_EMBEDDING_MODELS",
|
||||
|
||||
@@ -0,0 +1,495 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
import tiktoken
|
||||
from anthropic import APIConnectionError, APIStatusError
|
||||
from pydantic import SecretStr
|
||||
|
||||
from autogpt.core.configuration import Configurable, UserConfigurable
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
ModelProviderConfiguration,
|
||||
ModelProviderCredentials,
|
||||
ModelProviderName,
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
ToolResultMessage,
|
||||
)
|
||||
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anthropic.types.beta.tools import MessageCreateParams
|
||||
from anthropic.types.beta.tools import ToolsBetaMessage as Message
|
||||
from anthropic.types.beta.tools import ToolsBetaMessageParam as MessageParam
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
class AnthropicModelName(str, enum.Enum):
|
||||
CLAUDE3_OPUS_v1 = "claude-3-opus-20240229"
|
||||
CLAUDE3_SONNET_v1 = "claude-3-sonnet-20240229"
|
||||
CLAUDE3_HAIKU_v1 = "claude-3-haiku-20240307"
|
||||
|
||||
|
||||
ANTHROPIC_CHAT_MODELS = {
|
||||
info.name: info
|
||||
for info in [
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_OPUS_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=15 / 1e6,
|
||||
completion_token_cost=75 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_SONNET_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=3 / 1e6,
|
||||
completion_token_cost=15 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_HAIKU_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=0.25 / 1e6,
|
||||
completion_token_cost=1.25 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class AnthropicConfiguration(ModelProviderConfiguration):
|
||||
fix_failed_parse_tries: int = UserConfigurable(3)
|
||||
|
||||
|
||||
class AnthropicCredentials(ModelProviderCredentials):
|
||||
"""Credentials for Anthropic."""
|
||||
|
||||
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY")
|
||||
api_base: Optional[SecretStr] = UserConfigurable(
|
||||
default=None, from_env="ANTHROPIC_API_BASE_URL"
|
||||
)
|
||||
|
||||
def get_api_access_kwargs(self) -> dict[str, str]:
|
||||
return {
|
||||
k: (v.get_secret_value() if type(v) is SecretStr else v)
|
||||
for k, v in {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.api_base,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
|
||||
class AnthropicSettings(ModelProviderSettings):
|
||||
configuration: AnthropicConfiguration
|
||||
credentials: Optional[AnthropicCredentials]
|
||||
budget: ModelProviderBudget
|
||||
|
||||
|
||||
class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
default_settings = AnthropicSettings(
|
||||
name="anthropic_provider",
|
||||
description="Provides access to Anthropic's API.",
|
||||
configuration=AnthropicConfiguration(
|
||||
retries_per_request=7,
|
||||
),
|
||||
credentials=None,
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_settings: AnthropicSettings
|
||||
_configuration: AnthropicConfiguration
|
||||
_credentials: AnthropicCredentials
|
||||
_budget: ModelProviderBudget
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[AnthropicSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
if not settings:
|
||||
settings = self.default_settings.copy(deep=True)
|
||||
if not settings.credentials:
|
||||
settings.credentials = AnthropicCredentials.from_env()
|
||||
|
||||
super(AnthropicProvider, self).__init__(settings=settings, logger=logger)
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
self._client = AsyncAnthropic(**self._credentials.get_api_access_kwargs())
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
return list(ANTHROPIC_CHAT_MODELS.values())
|
||||
|
||||
def get_token_limit(self, model_name: str) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return ANTHROPIC_CHAT_MODELS[model_name].max_tokens
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: AnthropicModelName) -> ModelTokenizer:
|
||||
# HACK: No official tokenizer is available for Claude 3
|
||||
return tiktoken.encoding_for_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: AnthropicModelName) -> int:
|
||||
return 0 # HACK: No official tokenizer is available for Claude 3
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls,
|
||||
messages: ChatMessage | list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
) -> int:
|
||||
return 0 # HACK: No official tokenizer is available for Claude 3
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the Anthropic API."""
|
||||
anthropic_messages, completion_kwargs = self._get_chat_completion_args(
|
||||
prompt_messages=model_prompt,
|
||||
model=model_name,
|
||||
functions=functions,
|
||||
max_output_tokens=max_output_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
total_cost = 0.0
|
||||
attempts = 0
|
||||
while True:
|
||||
completion_kwargs["messages"] = anthropic_messages.copy()
|
||||
if prefill_response:
|
||||
completion_kwargs["messages"].append(
|
||||
{"role": "assistant", "content": prefill_response}
|
||||
)
|
||||
|
||||
(
|
||||
_assistant_msg,
|
||||
cost,
|
||||
t_input,
|
||||
t_output,
|
||||
) = await self._create_chat_completion(completion_kwargs)
|
||||
total_cost += cost
|
||||
self._logger.debug(
|
||||
f"Completion usage: {t_input} input, {t_output} output "
|
||||
f"- ${round(cost, 5)}"
|
||||
)
|
||||
|
||||
# Merge prefill into generated response
|
||||
if prefill_response:
|
||||
first_text_block = next(
|
||||
b for b in _assistant_msg.content if b.type == "text"
|
||||
)
|
||||
first_text_block.text = prefill_response + first_text_block.text
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content="\n\n".join(
|
||||
b.text for b in _assistant_msg.content if b.type == "text"
|
||||
),
|
||||
tool_calls=self._parse_assistant_tool_calls(_assistant_msg),
|
||||
)
|
||||
|
||||
# If parsing the response fails, append the error to the prompt, and let the
|
||||
# LLM fix its mistake(s).
|
||||
attempts += 1
|
||||
tool_call_errors = []
|
||||
try:
|
||||
# Validate tool calls
|
||||
if assistant_msg.tool_calls and functions:
|
||||
tool_call_errors = validate_tool_calls(
|
||||
assistant_msg.tool_calls, functions
|
||||
)
|
||||
if tool_call_errors:
|
||||
raise ValueError(
|
||||
"Invalid tool use(s):\n"
|
||||
+ "\n".join(str(e) for e in tool_call_errors)
|
||||
)
|
||||
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.debug(
|
||||
f"Parsing failed on response: '''{_assistant_msg}'''"
|
||||
)
|
||||
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
|
||||
sentry_sdk.capture_exception(
|
||||
error=e,
|
||||
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
|
||||
)
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
anthropic_messages.append(
|
||||
_assistant_msg.dict(include={"role", "content"})
|
||||
)
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
# tool_result is required if last assistant message
|
||||
# had tool_use block(s)
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tc.id,
|
||||
"is_error": True,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Not executed because parsing "
|
||||
"of your last message failed"
|
||||
if not tool_call_errors
|
||||
else str(e)
|
||||
if (
|
||||
e := next(
|
||||
(
|
||||
tce
|
||||
for tce in tool_call_errors
|
||||
if tce.name
|
||||
== tc.function.name
|
||||
),
|
||||
None,
|
||||
)
|
||||
)
|
||||
else "Not executed because validation "
|
||||
"of tool input failed",
|
||||
}
|
||||
],
|
||||
}
|
||||
for tc in assistant_msg.tool_calls or []
|
||||
),
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
"ERROR PARSING YOUR RESPONSE:\n\n"
|
||||
f"{e.__class__.__name__}: {e}"
|
||||
),
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
if attempts > 1:
|
||||
self._logger.debug(
|
||||
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
|
||||
)
|
||||
|
||||
return ChatModelResponse(
|
||||
response=assistant_msg,
|
||||
parsed_result=parsed_result,
|
||||
model_info=ANTHROPIC_CHAT_MODELS[model_name],
|
||||
prompt_tokens_used=t_input,
|
||||
completion_tokens_used=t_output,
|
||||
)
|
||||
|
||||
def _get_chat_completion_args(
|
||||
self,
|
||||
prompt_messages: list[ChatMessage],
|
||||
model: AnthropicModelName,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> tuple[list[MessageParam], MessageCreateParams]:
|
||||
"""Prepare arguments for message completion API call.
|
||||
|
||||
Args:
|
||||
prompt_messages: List of ChatMessages.
|
||||
model: The model to use.
|
||||
functions: Optional list of functions available to the LLM.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
list[MessageParam]: Prompt messages for the Anthropic call
|
||||
dict[str, Any]: Any other kwargs for the Anthropic call
|
||||
"""
|
||||
kwargs["model"] = model
|
||||
|
||||
if functions:
|
||||
kwargs["tools"] = [
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
name: param.to_dict()
|
||||
for name, param in f.parameters.items()
|
||||
},
|
||||
"required": [
|
||||
name
|
||||
for name, param in f.parameters.items()
|
||||
if param.required
|
||||
],
|
||||
},
|
||||
}
|
||||
for f in functions
|
||||
]
|
||||
|
||||
kwargs["max_tokens"] = max_output_tokens or 4096
|
||||
|
||||
if extra_headers := self._configuration.extra_request_headers:
|
||||
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
|
||||
kwargs["extra_headers"].update(extra_headers.copy())
|
||||
|
||||
system_messages = [
|
||||
m for m in prompt_messages if m.role == ChatMessage.Role.SYSTEM
|
||||
]
|
||||
if (_n := len(system_messages)) > 1:
|
||||
self._logger.warning(
|
||||
f"Prompt has {_n} system messages; Anthropic supports only 1. "
|
||||
"They will be merged, and removed from the rest of the prompt."
|
||||
)
|
||||
kwargs["system"] = "\n\n".join(sm.content for sm in system_messages)
|
||||
|
||||
messages: list[MessageParam] = []
|
||||
for message in prompt_messages:
|
||||
if message.role == ChatMessage.Role.SYSTEM:
|
||||
continue
|
||||
elif message.role == ChatMessage.Role.USER:
|
||||
# Merge subsequent user messages
|
||||
if messages and (prev_msg := messages[-1])["role"] == "user":
|
||||
if isinstance(prev_msg["content"], str):
|
||||
prev_msg["content"] += f"\n\n{message.content}"
|
||||
else:
|
||||
assert isinstance(prev_msg["content"], list)
|
||||
prev_msg["content"].append(
|
||||
{"type": "text", "text": message.content}
|
||||
)
|
||||
else:
|
||||
messages.append({"role": "user", "content": message.content})
|
||||
# TODO: add support for image blocks
|
||||
elif message.role == ChatMessage.Role.ASSISTANT:
|
||||
if isinstance(message, AssistantChatMessage) and message.tool_calls:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
*(
|
||||
[{"type": "text", "text": message.content}]
|
||||
if message.content
|
||||
else []
|
||||
),
|
||||
*(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.id,
|
||||
"name": tc.function.name,
|
||||
"input": tc.function.arguments,
|
||||
}
|
||||
for tc in message.tool_calls
|
||||
),
|
||||
],
|
||||
}
|
||||
)
|
||||
elif message.content:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
}
|
||||
)
|
||||
elif isinstance(message, ToolResultMessage):
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.tool_call_id,
|
||||
"content": [{"type": "text", "text": message.content}],
|
||||
"is_error": message.is_error,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return messages, kwargs # type: ignore
|
||||
|
||||
async def _create_chat_completion(
|
||||
self, completion_kwargs: MessageCreateParams
|
||||
) -> tuple[Message, float, int, int]:
|
||||
"""
|
||||
Create a chat completion using the Anthropic API with retry handling.
|
||||
|
||||
Params:
|
||||
completion_kwargs: Keyword arguments for an Anthropic Messages API call
|
||||
|
||||
Returns:
|
||||
Message: The message completion object
|
||||
float: The cost ($) of this completion
|
||||
int: Number of input tokens used
|
||||
int: Number of output tokens used
|
||||
"""
|
||||
|
||||
@self._retry_api_request
|
||||
async def _create_chat_completion_with_retry(
|
||||
completion_kwargs: MessageCreateParams,
|
||||
) -> Message:
|
||||
return await self._client.beta.tools.messages.create(
|
||||
**completion_kwargs # type: ignore
|
||||
)
|
||||
|
||||
response = await _create_chat_completion_with_retry(completion_kwargs)
|
||||
|
||||
cost = self._budget.update_usage_and_cost(
|
||||
model_info=ANTHROPIC_CHAT_MODELS[completion_kwargs["model"]],
|
||||
input_tokens_used=response.usage.input_tokens,
|
||||
output_tokens_used=response.usage.output_tokens,
|
||||
)
|
||||
return response, cost, response.usage.input_tokens, response.usage.output_tokens
|
||||
|
||||
def _parse_assistant_tool_calls(
|
||||
self, assistant_message: Message
|
||||
) -> list[AssistantToolCall]:
|
||||
return [
|
||||
AssistantToolCall(
|
||||
id=c.id,
|
||||
type="function",
|
||||
function=AssistantFunctionCall(name=c.name, arguments=c.input),
|
||||
)
|
||||
for c in assistant_message.content
|
||||
if c.type == "tool_use"
|
||||
]
|
||||
|
||||
def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
return tenacity.retry(
|
||||
retry=(
|
||||
tenacity.retry_if_exception_type(APIConnectionError)
|
||||
| tenacity.retry_if_exception(
|
||||
lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
|
||||
)
|
||||
),
|
||||
wait=tenacity.wait_exponential(),
|
||||
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
|
||||
after=tenacity.after_log(self._logger, logging.DEBUG),
|
||||
)(func)
|
||||
|
||||
def __repr__(self):
|
||||
return "AnthropicProvider()"
|
||||
162
autogpts/autogpt/autogpt/core/resource/model_providers/multi.py
Normal file
162
autogpts/autogpt/autogpt/core/resource/model_providers/multi.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterator, Optional, TypeVar
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from autogpt.core.configuration import Configurable
|
||||
|
||||
from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
|
||||
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
|
||||
from .schema import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
ModelProviderConfiguration,
|
||||
ModelProviderName,
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
ModelName = AnthropicModelName | OpenAIModelName
|
||||
|
||||
CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
|
||||
|
||||
|
||||
class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
default_settings = ModelProviderSettings(
|
||||
name="multi_provider",
|
||||
description=(
|
||||
"Provides access to all of the available models, regardless of provider."
|
||||
),
|
||||
configuration=ModelProviderConfiguration(
|
||||
retries_per_request=7,
|
||||
),
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_budget: ModelProviderBudget
|
||||
|
||||
_provider_instances: dict[ModelProviderName, ChatModelProvider]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[ModelProviderSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
super(MultiProvider, self).__init__(settings=settings, logger=logger)
|
||||
self._budget = self._settings.budget or ModelProviderBudget()
|
||||
|
||||
self._provider_instances = {}
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
models = []
|
||||
for provider in self.get_available_providers():
|
||||
models.extend(await provider.get_available_models())
|
||||
return models
|
||||
|
||||
def get_token_limit(self, model_name: ModelName) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return self.get_model_provider(model_name).get_token_limit(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: ModelName) -> ModelTokenizer:
|
||||
return cls._get_model_provider_class(model_name).get_tokenizer(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: ModelName) -> int:
|
||||
return cls._get_model_provider_class(model_name).count_tokens(
|
||||
text=text, model_name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls, messages: ChatMessage | list[ChatMessage], model_name: ModelName
|
||||
) -> int:
|
||||
return cls._get_model_provider_class(model_name).count_message_tokens(
|
||||
messages=messages, model_name=model_name
|
||||
)
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the Anthropic API."""
|
||||
return await self.get_model_provider(model_name).create_chat_completion(
|
||||
model_prompt=model_prompt,
|
||||
model_name=model_name,
|
||||
completion_parser=completion_parser,
|
||||
functions=functions,
|
||||
max_output_tokens=max_output_tokens,
|
||||
prefill_response=prefill_response,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_model_provider(self, model: ModelName) -> ChatModelProvider:
|
||||
model_info = CHAT_MODELS[model]
|
||||
return self._get_provider(model_info.provider_name)
|
||||
|
||||
def get_available_providers(self) -> Iterator[ChatModelProvider]:
|
||||
for provider_name in ModelProviderName:
|
||||
try:
|
||||
yield self._get_provider(provider_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
|
||||
_provider = self._provider_instances.get(provider_name)
|
||||
if not _provider:
|
||||
Provider = self._get_provider_class(provider_name)
|
||||
settings = Provider.default_settings.copy(deep=True)
|
||||
settings.budget = self._budget
|
||||
settings.configuration.extra_request_headers.update(
|
||||
self._settings.configuration.extra_request_headers
|
||||
)
|
||||
if settings.credentials is None:
|
||||
try:
|
||||
Credentials = settings.__fields__["credentials"].type_
|
||||
settings.credentials = Credentials.from_env()
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
f"{provider_name} is unavailable: can't load credentials"
|
||||
) from e
|
||||
|
||||
self._provider_instances[provider_name] = _provider = Provider(
|
||||
settings=settings, logger=self._logger
|
||||
)
|
||||
_provider._budget = self._budget # Object binding not preserved by Pydantic
|
||||
return _provider
|
||||
|
||||
@classmethod
|
||||
def _get_model_provider_class(
|
||||
cls, model_name: ModelName
|
||||
) -> type[AnthropicProvider | OpenAIProvider]:
|
||||
return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
|
||||
|
||||
@classmethod
|
||||
def _get_provider_class(
|
||||
cls, provider_name: ModelProviderName
|
||||
) -> type[AnthropicProvider | OpenAIProvider]:
|
||||
try:
|
||||
return {
|
||||
ModelProviderName.ANTHROPIC: AnthropicProvider,
|
||||
ModelProviderName.OPENAI: OpenAIProvider,
|
||||
}[provider_name]
|
||||
except KeyError:
|
||||
raise ValueError(f"{provider_name} is not a known provider") from None
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}()"
|
||||
@@ -42,6 +42,8 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
@@ -298,6 +300,7 @@ class OpenAIProvider(
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_settings: OpenAISettings
|
||||
_configuration: OpenAIConfiguration
|
||||
_credentials: OpenAICredentials
|
||||
_budget: ModelProviderBudget
|
||||
@@ -312,11 +315,7 @@ class OpenAIProvider(
|
||||
if not settings.credentials:
|
||||
settings.credentials = OpenAICredentials.from_env()
|
||||
|
||||
self._settings = settings
|
||||
|
||||
self._configuration = settings.configuration
|
||||
self._credentials = settings.credentials
|
||||
self._budget = settings.budget
|
||||
super(OpenAIProvider, self).__init__(settings=settings, logger=logger)
|
||||
|
||||
if self._credentials.api_type == "azure":
|
||||
from openai import AsyncAzureOpenAI
|
||||
@@ -329,8 +328,6 @@ class OpenAIProvider(
|
||||
|
||||
self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
|
||||
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
_models = (await self._client.models.list()).data
|
||||
return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS]
|
||||
@@ -398,6 +395,7 @@ class OpenAIProvider(
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "", # not supported by OpenAI
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the OpenAI API and parse it."""
|
||||
@@ -432,6 +430,10 @@ class OpenAIProvider(
|
||||
)
|
||||
parse_errors += _errors
|
||||
|
||||
# Validate tool calls
|
||||
if not parse_errors and tool_calls and functions:
|
||||
parse_errors += validate_tool_calls(tool_calls, functions)
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
tool_calls=tool_calls or None,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import abc
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
@@ -28,6 +30,9 @@ from autogpt.core.resource.schema import (
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.logs.utils import fmt_kwargs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from jsonschema import ValidationError
|
||||
|
||||
|
||||
class ModelProviderService(str, enum.Enum):
|
||||
"""A ModelService describes what kind of service the model provides."""
|
||||
@@ -39,6 +44,7 @@ class ModelProviderService(str, enum.Enum):
|
||||
|
||||
class ModelProviderName(str, enum.Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@@ -100,6 +106,12 @@ class AssistantChatMessage(ChatMessage):
|
||||
tool_calls: Optional[list[AssistantToolCall]] = None
|
||||
|
||||
|
||||
class ToolResultMessage(ChatMessage):
|
||||
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL
|
||||
is_error: bool = False
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AssistantChatMessageDict(TypedDict, total=False):
|
||||
role: str
|
||||
content: str
|
||||
@@ -146,6 +158,30 @@ class CompletionModelFunction(BaseModel):
|
||||
)
|
||||
return f"{self.name}: {self.description}. Params: ({params})"
|
||||
|
||||
def validate_call(
|
||||
self, function_call: AssistantFunctionCall
|
||||
) -> tuple[bool, list["ValidationError"]]:
|
||||
"""
|
||||
Validates the given function call against the function's parameter specs
|
||||
|
||||
Returns:
|
||||
bool: Whether the given set of arguments is valid for this command
|
||||
list[ValidationError]: Issues with the set of arguments (if any)
|
||||
|
||||
Raises:
|
||||
ValueError: If the function_call doesn't call this function
|
||||
"""
|
||||
if function_call.name != self.name:
|
||||
raise ValueError(
|
||||
f"Can't validate {function_call.name} call using {self.name} spec"
|
||||
)
|
||||
|
||||
params_schema = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={name: spec for name, spec in self.parameters.items()},
|
||||
)
|
||||
return params_schema.validate_object(function_call.arguments)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Struct for model information.
|
||||
@@ -229,7 +265,7 @@ class ModelProviderBudget(ProviderBudget):
|
||||
class ModelProviderSettings(ProviderSettings):
|
||||
resource_type: ResourceType = ResourceType.MODEL
|
||||
configuration: ModelProviderConfiguration
|
||||
credentials: ModelProviderCredentials
|
||||
credentials: Optional[ModelProviderCredentials] = None
|
||||
budget: Optional[ModelProviderBudget] = None
|
||||
|
||||
|
||||
@@ -238,9 +274,28 @@ class ModelProvider(abc.ABC):
|
||||
|
||||
default_settings: ClassVar[ModelProviderSettings]
|
||||
|
||||
_settings: ModelProviderSettings
|
||||
_configuration: ModelProviderConfiguration
|
||||
_credentials: Optional[ModelProviderCredentials] = None
|
||||
_budget: Optional[ModelProviderBudget] = None
|
||||
|
||||
_logger: logging.Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[ModelProviderSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
if not settings:
|
||||
settings = self.default_settings.copy(deep=True)
|
||||
|
||||
self._settings = settings
|
||||
self._configuration = settings.configuration
|
||||
self._credentials = settings.credentials
|
||||
self._budget = settings.budget
|
||||
|
||||
self._logger = logger or logging.getLogger(self.__module__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
...
|
||||
@@ -358,6 +413,7 @@ class ChatModelProvider(ModelProvider):
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
...
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import Any
|
||||
|
||||
from .schema import AssistantToolCall, CompletionModelFunction
|
||||
|
||||
|
||||
class InvalidFunctionCallError(Exception):
|
||||
def __init__(self, name: str, arguments: dict[str, Any], message: str):
|
||||
self.message = message
|
||||
self.name = name
|
||||
self.arguments = arguments
|
||||
super().__init__(message)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Invalid function call for {self.name}: {self.message}"
|
||||
|
||||
|
||||
def validate_tool_calls(
|
||||
tool_calls: list[AssistantToolCall], functions: list[CompletionModelFunction]
|
||||
) -> list[InvalidFunctionCallError]:
|
||||
"""
|
||||
Validates a list of tool calls against a list of functions.
|
||||
|
||||
1. Tries to find a function matching each tool call
|
||||
2. If a matching function is found, validates the tool call's arguments,
|
||||
reporting any resulting errors
|
||||
2. If no matching function is found, an error "Unknown function X" is reported
|
||||
3. A list of all errors encountered during validation is returned
|
||||
|
||||
Params:
|
||||
tool_calls: A list of tool calls to validate.
|
||||
functions: A list of functions to validate against.
|
||||
|
||||
Returns:
|
||||
list[InvalidFunctionCallError]: All errors encountered during validation.
|
||||
"""
|
||||
errors: list[InvalidFunctionCallError] = []
|
||||
for tool_call in tool_calls:
|
||||
function_call = tool_call.function
|
||||
|
||||
if function := next(
|
||||
(f for f in functions if f.name == function_call.name),
|
||||
None,
|
||||
):
|
||||
is_valid, validation_errors = function.validate_call(function_call)
|
||||
if not is_valid:
|
||||
fmt_errors = [
|
||||
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
|
||||
if f.path
|
||||
else f.message
|
||||
for f in validation_errors
|
||||
]
|
||||
errors.append(
|
||||
InvalidFunctionCallError(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
message=(
|
||||
"The set of arguments supplied is invalid:\n"
|
||||
+ "\n".join(fmt_errors)
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
errors.append(
|
||||
InvalidFunctionCallError(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
message=f"Unknown function {function_call.name}",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
def get_openai_command_specs(
|
||||
def function_specs_from_commands(
|
||||
commands: Iterable[Command],
|
||||
) -> list[CompletionModelFunction]:
|
||||
"""Get OpenAI-consumable function specs for the agent's available commands.
|
||||
|
||||
@@ -3,8 +3,6 @@ from __future__ import annotations
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
from .command_parameter import CommandParameter
|
||||
from .context_item import ContextItem
|
||||
|
||||
@@ -42,20 +40,6 @@ class Command:
|
||||
def is_async(self) -> bool:
|
||||
return inspect.iscoroutinefunction(self.method)
|
||||
|
||||
def validate_args(self, args: dict[str, Any]):
|
||||
"""
|
||||
Validates the given arguments against the command's parameter specifications
|
||||
|
||||
Returns:
|
||||
bool: Whether the given set of arguments is valid for this command
|
||||
list[ValidationError]: Issues with the set of arguments (if any)
|
||||
"""
|
||||
params_schema = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={p.name: p.spec for p in self.parameters},
|
||||
)
|
||||
return params_schema.validate_object(args)
|
||||
|
||||
def _parameters_match(
|
||||
self, func: Callable, parameters: list[CommandParameter]
|
||||
) -> bool:
|
||||
|
||||
26
autogpts/autogpt/poetry.lock
generated
26
autogpts/autogpt/poetry.lock
generated
@@ -167,6 +167,30 @@ files = [
|
||||
[package.dependencies]
|
||||
frozenlist = ">=1.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "anthropic"
|
||||
version = "0.25.1"
|
||||
description = "The official Python library for the anthropic API"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "anthropic-0.25.1-py3-none-any.whl", hash = "sha256:95d0cedc2a4b5beae3a78f9030aea4001caea5f46c6d263cce377c891c594e71"},
|
||||
{file = "anthropic-0.25.1.tar.gz", hash = "sha256:0c01b30b77d041a8d07c532737bae69da58086031217150008e4541f52a64bd9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=3.5.0,<5"
|
||||
distro = ">=1.7.0,<2"
|
||||
httpx = ">=0.23.0,<1"
|
||||
pydantic = ">=1.9.0,<3"
|
||||
sniffio = "*"
|
||||
tokenizers = ">=0.13.0"
|
||||
typing-extensions = ">=4.7,<5"
|
||||
|
||||
[package.extras]
|
||||
bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
|
||||
vertex = ["google-auth (>=2,<3)"]
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.2.0"
|
||||
@@ -7234,4 +7258,4 @@ benchmark = ["agbenchmark"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "e6eab5c079d53f075ce701e86a2007e7ebeb635ac067d25f555bfea363bcc630"
|
||||
content-hash = "ad1e3c4706465733d04ddab975af630975bd528efce152c1da01eded53069eca"
|
||||
|
||||
@@ -22,6 +22,7 @@ serve = "autogpt.app.cli:serve"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
anthropic = "^0.25.1"
|
||||
# autogpt-forge = { path = "../forge" }
|
||||
autogpt-forge = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "autogpts/forge"}
|
||||
beautifulsoup4 = "^4.12.2"
|
||||
|
||||
@@ -8,9 +8,9 @@ import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.app.main import _configure_openai_provider
|
||||
from autogpt.app.main import _configure_llm_provider
|
||||
from autogpt.config import AIProfile, Config, ConfigBuilder
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider, OpenAIProvider
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider
|
||||
from autogpt.file_storage.local import (
|
||||
FileStorage,
|
||||
FileStorageConfiguration,
|
||||
@@ -73,8 +73,8 @@ def setup_logger(config: Config):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(config: Config) -> OpenAIProvider:
|
||||
return _configure_openai_provider(config)
|
||||
def llm_provider(config: Config) -> ChatModelProvider:
|
||||
return _configure_llm_provider(config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -14,7 +14,6 @@ from pydantic import SecretStr
|
||||
|
||||
from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, apply_overrides_to_config
|
||||
from autogpt.config import Config, ConfigBuilder
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIModelName
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
ChatModelInfo,
|
||||
ModelProviderName,
|
||||
@@ -39,8 +38,8 @@ async def test_fallback_to_gpt3_if_gpt4_not_available(
|
||||
"""
|
||||
Test if models update to gpt-3.5-turbo if gpt-4 is not available.
|
||||
"""
|
||||
config.fast_llm = OpenAIModelName.GPT4_TURBO
|
||||
config.smart_llm = OpenAIModelName.GPT4_TURBO
|
||||
config.fast_llm = GPT_4_MODEL
|
||||
config.smart_llm = GPT_4_MODEL
|
||||
|
||||
mock_list_models.return_value = asyncio.Future()
|
||||
mock_list_models.return_value.set_result(
|
||||
@@ -56,8 +55,8 @@ async def test_fallback_to_gpt3_if_gpt4_not_available(
|
||||
gpt4only=False,
|
||||
)
|
||||
|
||||
assert config.fast_llm == "gpt-3.5-turbo"
|
||||
assert config.smart_llm == "gpt-3.5-turbo"
|
||||
assert config.fast_llm == GPT_3_MODEL
|
||||
assert config.smart_llm == GPT_3_MODEL
|
||||
|
||||
|
||||
def test_missing_azure_config(config: Config) -> None:
|
||||
@@ -148,8 +147,7 @@ def test_azure_config(config_with_azure: Config) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_config_gpt4only(config: Config) -> None:
|
||||
with mock.patch(
|
||||
"autogpt.core.resource.model_providers.openai."
|
||||
"OpenAIProvider.get_available_models"
|
||||
"autogpt.core.resource.model_providers.multi.MultiProvider.get_available_models"
|
||||
) as mock_get_models:
|
||||
mock_get_models.return_value = [
|
||||
ChatModelInfo(
|
||||
@@ -169,8 +167,7 @@ async def test_create_config_gpt4only(config: Config) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_config_gpt3only(config: Config) -> None:
|
||||
with mock.patch(
|
||||
"autogpt.core.resource.model_providers.openai."
|
||||
"OpenAIProvider.get_available_models"
|
||||
"autogpt.core.resource.model_providers.multi.MultiProvider.get_available_models"
|
||||
) as mock_get_models:
|
||||
mock_get_models.return_value = [
|
||||
ChatModelInfo(
|
||||
|
||||
@@ -7,6 +7,7 @@ Configuration is controlled through the `Config` object. You can set configurati
|
||||
- `AI_SETTINGS_FILE`: Location of the AI Settings file relative to the AutoGPT root directory. Default: ai_settings.yaml
|
||||
- `AUDIO_TO_TEXT_PROVIDER`: Audio To Text Provider. Only option currently is `huggingface`. Default: huggingface
|
||||
- `AUTHORISE_COMMAND_KEY`: Key response accepted when authorising commands. Default: y
|
||||
- `ANTHROPIC_API_KEY`: Set this if you want to use Anthropic models with AutoGPT
|
||||
- `AZURE_CONFIG_FILE`: Location of the Azure Config file relative to the AutoGPT root directory. Default: azure.yaml
|
||||
- `BROWSE_CHUNK_MAX_LENGTH`: When browsing website, define the length of chunks to summarize. Default: 3000
|
||||
- `BROWSE_SPACY_LANGUAGE_MODEL`: [spaCy language model](https://spacy.io/usage/models) to use when creating chunks. Default: en_core_web_sm
|
||||
|
||||
Reference in New Issue
Block a user