mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-02-23 15:14:44 +01:00
feat: Add support for new models and features from OpenAI's November 6 update (#6147)
* feat: Add support for new models and features from OpenAI's November 6 update - Updated the `OpenAIModelName` enum to include new GPT-3.5 Turbo and GPT-4 models - Added support for the `GPT3_v3` and `GPT4_v3` models in the `OPEN_AI_CHAT_MODELS` dictionary - Modified the `OpenAIProvider` class to handle the new models and features - Updated the schema definitions in the `schema.py` module to include `AssistantToolCall` and `AssistantToolCallDict` models - Updated the `AssistantChatMessage` and `AssistantChatMessageDict` models to include the `tool_calls` field - Refactored the code in various modules to handle the new tool calls and function arguments Added support for the new models and features introduced with OpenAI's latest update. This commit allows the system to utilize the `GPT3_v3` and `GPT4_v3` models and includes all necessary modifications to the codebase to handle the new models and associated features. * Fix validation error in LLM response handling * fix: Fix profile generator in-prompt example for functions compatibility mode - Updated the in-prompt example in the profile generator to be compatible with functions compatibility mode. - Modified the example call section to correctly reflect the structure of function calls.
This commit is contained in:
committed by
GitHub
parent
578087ec96
commit
345ff6f88d
@@ -36,9 +36,10 @@ class AgentProfileGeneratorConfiguration(SystemConfiguration):
|
||||
"\n"
|
||||
"Example Input:\n"
|
||||
'"""Help me with marketing my business"""\n\n'
|
||||
"Example Function Call:\n"
|
||||
"Example Call:\n"
|
||||
"```\n"
|
||||
"{"
|
||||
"[" # tool_calls
|
||||
'{"type": "function", "function": {'
|
||||
'"name": "create_agent",'
|
||||
' "arguments": {'
|
||||
'"name": "CMOGPT",'
|
||||
@@ -65,7 +66,9 @@ class AgentProfileGeneratorConfiguration(SystemConfiguration):
|
||||
"]" # constraints
|
||||
"}" # directives
|
||||
"}" # arguments
|
||||
"}\n"
|
||||
"}" # function
|
||||
"}" # tool call
|
||||
"]\n" # tool_calls
|
||||
"```"
|
||||
)
|
||||
)
|
||||
@@ -172,7 +175,9 @@ class AgentProfileGenerator(PromptStrategy):
|
||||
|
||||
"""
|
||||
try:
|
||||
arguments = json_loads(response_content["function_call"]["arguments"])
|
||||
arguments = json_loads(
|
||||
response_content["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
ai_profile = AIProfile(
|
||||
ai_name=arguments.get("name"),
|
||||
ai_role=arguments.get("description"),
|
||||
|
||||
@@ -316,7 +316,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
)
|
||||
|
||||
return (
|
||||
f"Respond strictly with a JSON object{' containing your thoughts, and a function_call specifying the next command to use' if use_functions_api else ''}. "
|
||||
f"Respond strictly with a JSON object{' containing your thoughts, and a tool_call specifying the next command to use' if use_functions_api else ''}. "
|
||||
"The JSON object should be compatible with the TypeScript type `Response` from the following:\n"
|
||||
f"{response_format}"
|
||||
)
|
||||
@@ -431,11 +431,13 @@ def extract_command(
|
||||
Exception: If any other error occurs
|
||||
"""
|
||||
if use_openai_functions_api:
|
||||
if "function_call" not in assistant_reply:
|
||||
raise InvalidAgentResponseError("No 'function_call' in assistant reply")
|
||||
if not assistant_reply.get("tool_calls"):
|
||||
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
|
||||
assistant_reply_json["command"] = {
|
||||
"name": assistant_reply["function_call"]["name"],
|
||||
"args": json.loads(assistant_reply["function_call"]["arguments"]),
|
||||
"name": assistant_reply["tool_calls"][0]["function"]["name"],
|
||||
"args": json.loads(
|
||||
assistant_reply["tool_calls"][0]["function"]["arguments"]
|
||||
),
|
||||
}
|
||||
try:
|
||||
if not isinstance(assistant_reply_json, dict):
|
||||
|
||||
@@ -55,7 +55,7 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
prompt_settings_file: Path = project_root / PROMPT_SETTINGS_FILE
|
||||
# Model configuration
|
||||
fast_llm: str = "gpt-3.5-turbo-16k"
|
||||
smart_llm: str = "gpt-4-0314"
|
||||
smart_llm: str = "gpt-4"
|
||||
temperature: float = 0
|
||||
openai_functions: bool = False
|
||||
embedding_model: str = "text-embedding-ada-002"
|
||||
|
||||
@@ -169,7 +169,9 @@ class InitialPlan(PromptStrategy):
|
||||
The parsed response.
|
||||
"""
|
||||
try:
|
||||
parsed_response = json_loads(response_content["function_call"]["arguments"])
|
||||
parsed_response = json_loads(
|
||||
response_content["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
parsed_response["task_list"] = [
|
||||
Task.parse_obj(task) for task in parsed_response["task_list"]
|
||||
]
|
||||
|
||||
@@ -133,7 +133,9 @@ class NameAndGoals(PromptStrategy):
|
||||
|
||||
"""
|
||||
try:
|
||||
parsed_response = json_loads(response_content["function_call"]["arguments"])
|
||||
parsed_response = json_loads(
|
||||
response_content["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
except KeyError:
|
||||
logger.debug(f"Failed to parse this response content: {response_content}")
|
||||
raise
|
||||
|
||||
@@ -170,9 +170,9 @@ class NextAbility(PromptStrategy):
|
||||
|
||||
"""
|
||||
try:
|
||||
function_name = response_content["function_call"]["name"]
|
||||
function_name = response_content["tool_calls"][0]["function"]["name"]
|
||||
function_arguments = json_loads(
|
||||
response_content["function_call"]["arguments"]
|
||||
response_content["tool_calls"][0]["function"]["arguments"]
|
||||
)
|
||||
parsed_response = {
|
||||
"motivation": function_arguments.pop("motivation"),
|
||||
|
||||
@@ -16,7 +16,7 @@ from autogpt.core.configuration import (
|
||||
)
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessageDict,
|
||||
AssistantFunctionCallDict,
|
||||
AssistantToolCallDict,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
@@ -49,6 +49,7 @@ class OpenAIModelName(str, enum.Enum):
|
||||
GPT3_v1 = "gpt-3.5-turbo-0301"
|
||||
GPT3_v2 = "gpt-3.5-turbo-0613"
|
||||
GPT3_v2_16k = "gpt-3.5-turbo-16k-0613"
|
||||
GPT3_v3 = "gpt-3.5-turbo-1106"
|
||||
GPT3_ROLLING = "gpt-3.5-turbo"
|
||||
GPT3_ROLLING_16k = "gpt-3.5-turbo-16k"
|
||||
GPT3 = GPT3_ROLLING
|
||||
@@ -58,8 +59,10 @@ class OpenAIModelName(str, enum.Enum):
|
||||
GPT4_v1_32k = "gpt-4-32k-0314"
|
||||
GPT4_v2 = "gpt-4-0613"
|
||||
GPT4_v2_32k = "gpt-4-32k-0613"
|
||||
GPT4_v3 = "gpt-4-1106-preview"
|
||||
GPT4_ROLLING = "gpt-4"
|
||||
GPT4_ROLLING_32k = "gpt-4-32k"
|
||||
GPT4_VISION = "gpt-4-vision-preview"
|
||||
GPT4 = GPT4_ROLLING
|
||||
GPT4_32k = GPT4_ROLLING_32k
|
||||
|
||||
@@ -97,6 +100,15 @@ OPEN_AI_CHAT_MODELS = {
|
||||
max_tokens=16384,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT3_v3,
|
||||
service=ModelProviderService.CHAT,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.001 / 1000,
|
||||
completion_token_cost=0.002 / 1000,
|
||||
max_tokens=16384,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT4,
|
||||
service=ModelProviderService.CHAT,
|
||||
@@ -115,6 +127,15 @@ OPEN_AI_CHAT_MODELS = {
|
||||
max_tokens=32768,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT4_v3,
|
||||
service=ModelProviderService.CHAT,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.01 / 1000,
|
||||
completion_token_cost=0.03 / 1000,
|
||||
max_tokens=128000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
]
|
||||
}
|
||||
# Copy entries for models with equivalent specs
|
||||
@@ -271,7 +292,7 @@ class OpenAIProvider(
|
||||
"""Create a completion using the OpenAI API."""
|
||||
|
||||
completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
|
||||
functions_compat_mode = functions and "functions" not in completion_kwargs
|
||||
tool_calls_compat_mode = functions and "tools" not in completion_kwargs
|
||||
if "messages" in completion_kwargs:
|
||||
model_prompt += completion_kwargs["messages"]
|
||||
del completion_kwargs["messages"]
|
||||
@@ -287,8 +308,8 @@ class OpenAIProvider(
|
||||
}
|
||||
|
||||
response_message = response.choices[0].message.to_dict_recursive()
|
||||
if functions_compat_mode:
|
||||
response_message["function_call"] = _functions_compat_extract_call(
|
||||
if tool_calls_compat_mode:
|
||||
response_message["tool_calls"] = _tool_calls_compat_extract_calls(
|
||||
response_message["content"]
|
||||
)
|
||||
response = ChatModelResponse(
|
||||
@@ -346,10 +367,15 @@ class OpenAIProvider(
|
||||
|
||||
if functions:
|
||||
if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
|
||||
completion_kwargs["functions"] = [f.schema for f in functions]
|
||||
completion_kwargs["tools"] = [
|
||||
{"type": "function", "function": f.schema} for f in functions
|
||||
]
|
||||
if len(functions) == 1:
|
||||
# force the model to call the only specified function
|
||||
completion_kwargs["function_call"] = {"name": functions[0].name}
|
||||
completion_kwargs["tool_choice"] = {
|
||||
"type": "function",
|
||||
"function": {"name": functions[0].name},
|
||||
}
|
||||
else:
|
||||
# Provide compatibility with older models
|
||||
_functions_compat_fix_kwargs(functions, completion_kwargs)
|
||||
@@ -411,7 +437,7 @@ async def _create_chat_completion(
|
||||
The completion.
|
||||
"""
|
||||
raw_messages = [
|
||||
message.dict(include={"role", "content", "function_call", "name"})
|
||||
message.dict(include={"role", "content", "tool_calls", "name"})
|
||||
for message in messages
|
||||
]
|
||||
return await openai.ChatCompletion.acreate(
|
||||
@@ -573,14 +599,27 @@ def _functions_compat_fix_kwargs(
|
||||
),
|
||||
},
|
||||
)
|
||||
tool_calls_schema = JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
items=JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={
|
||||
"type": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
enum=["function"],
|
||||
),
|
||||
"function": function_call_schema,
|
||||
},
|
||||
),
|
||||
)
|
||||
completion_kwargs["messages"] = [
|
||||
ChatMessage.system(
|
||||
"# function_call instructions\n\n"
|
||||
"Specify a '```function_call' block in your response,"
|
||||
" enclosing a function call in the form of a valid JSON object"
|
||||
" that adheres to the following schema:\n\n"
|
||||
f"{function_call_schema.to_dict()}\n\n"
|
||||
"Put the function_call block at the end of your response"
|
||||
"# tool usage instructions\n\n"
|
||||
"Specify a '```tool_calls' block in your response,"
|
||||
" with a valid JSON object that adheres to the following schema:\n\n"
|
||||
f"{tool_calls_schema.to_dict()}\n\n"
|
||||
"Specify any tools that you need to use through this JSON object.\n\n"
|
||||
"Put the tool_calls block at the end of your response"
|
||||
" and include its fences if it is not the only content.\n\n"
|
||||
"## functions\n\n"
|
||||
"For the function call itself, use one of the following"
|
||||
@@ -589,19 +628,21 @@ def _functions_compat_fix_kwargs(
|
||||
]
|
||||
|
||||
|
||||
def _functions_compat_extract_call(response: str) -> AssistantFunctionCallDict:
|
||||
def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDict]:
|
||||
import json
|
||||
import re
|
||||
|
||||
logging.debug(f"Trying to extract function call from response:\n{response}")
|
||||
logging.debug(f"Trying to extract tool calls from response:\n{response}")
|
||||
|
||||
if response[0] == "{":
|
||||
function_call = json.loads(response)
|
||||
if response[0] == "[":
|
||||
tool_calls: list[AssistantToolCallDict] = json.loads(response)
|
||||
else:
|
||||
block = re.search(r"```(?:function_call)?\n(.*)\n```\s*$", response, re.DOTALL)
|
||||
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
|
||||
if not block:
|
||||
raise ValueError("Could not find function call block in response")
|
||||
function_call = json.loads(block.group(1))
|
||||
raise ValueError("Could not find tool calls block in response")
|
||||
tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
|
||||
|
||||
function_call["arguments"] = str(function_call["arguments"]) # HACK
|
||||
return function_call
|
||||
for t in tool_calls:
|
||||
t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK
|
||||
|
||||
return tool_calls
|
||||
|
||||
@@ -77,16 +77,28 @@ class AssistantFunctionCallDict(TypedDict):
|
||||
arguments: str
|
||||
|
||||
|
||||
class AssistantToolCall(BaseModel):
|
||||
# id: str
|
||||
type: Literal["function"]
|
||||
function: AssistantFunctionCall
|
||||
|
||||
|
||||
class AssistantToolCallDict(TypedDict):
|
||||
# id: str
|
||||
type: Literal["function"]
|
||||
function: AssistantFunctionCallDict
|
||||
|
||||
|
||||
class AssistantChatMessage(ChatMessage):
|
||||
role: Literal["assistant"]
|
||||
content: Optional[str]
|
||||
function_call: Optional[AssistantFunctionCall]
|
||||
tool_calls: Optional[list[AssistantToolCall]]
|
||||
|
||||
|
||||
class AssistantChatMessageDict(TypedDict, total=False):
|
||||
role: str
|
||||
content: str
|
||||
function_call: AssistantFunctionCallDict
|
||||
tool_calls: list[AssistantToolCallDict]
|
||||
|
||||
|
||||
class CompletionModelFunction(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user