refactor(agent): Refactor & improve create_chat_completion (#7082)

* refactor(agent/core): Rearrange and split up `OpenAIProvider.create_chat_completion`
   - Rearrange to reduce complexity, improve separation/abstraction of concerns, and allow handling multiple points of failure during parsing (control flow sketched below)
   - Move the conversion of `ChatMessage` objects to `openai.types.ChatCompletionMessageParam` dicts into `_get_chat_completion_args`
   - Move token usage and cost tracking boilerplate to `_create_chat_completion`
   - Move tool call conversion/parsing into `_parse_assistant_tool_calls` (new)
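
In outline, the reworked method now reads roughly like this (a condensed sketch paraphrased from the `openai.py` diff further down, not the verbatim code):

    openai_messages, completion_kwargs = self._get_chat_completion_args(
        model_prompt, model_name, functions, **kwargs
    )
    total_cost = 0.0
    while True:
        # usage/cost bookkeeping now happens inside _create_chat_completion
        _response, _cost, t_input, t_output = await self._create_chat_completion(
            messages=openai_messages, **completion_kwargs
        )
        total_cost += _cost
        # tool call parsing is isolated, so one bad tool call no longer
        # aborts the attempt with a single opaque exception
        tool_calls, parse_errors = self._parse_assistant_tool_calls(
            _response.choices[0].message, tool_calls_compat_mode
        )
        # on parse errors: append error feedback to openai_messages and retry;
        # otherwise: build and return the ChatModelResponse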

* fix(agent/core): Handle decoding of function call arguments in `create_chat_completion`
   - Amend `model_providers.schema`: change the type of `arguments` from `str` to `dict[str, Any]` on `AssistantFunctionCall` and `AssistantFunctionCallDict` (see the example below)
   - Implement robust and transparent argument parsing in `OpenAIProvider._parse_assistant_tool_calls`
   - Remove the now-unnecessary `json_loads` calls throughout the codebase
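
Call sites now receive `arguments` as an already-parsed dict and can index into it directly. A minimal sketch (the function name and argument values here are hypothetical):

    from autogpt.core.resource.model_providers.schema import AssistantFunctionCall

    # The provider parses the JSON arguments once, centrally
    call = AssistantFunctionCall(
        name="create_agent",  # hypothetical function name
        arguments={"name": "ResearchGPT", "description": "a research assistant"},
    )
    ai_name = call.arguments.get("name")  # no json_loads at the call site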

* feat(agent/utils): Improve conditions and errors in `json_loads`
   - Include all decoding errors when raising a ValueError on decode failure (demonstrated below)
   - Use the errors attached to demjson3's `return_errors=True` result instead of a separate error buffer
   - Fix the check for decode failure
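
In practice, assuming demjson3's default non-strict decoding, the new behavior looks roughly like this (a hedged illustration, not code from this commit):

    from autogpt.core.utils.json_utils import json_loads

    # demjson3 tolerates common LLM output mistakes such as single quotes,
    # unquoted keys, and trailing commas
    assert json_loads("{'a': 1, b: 2,}") == {"a": 1, "b": 2}

    # On an unrecoverable input, the raised ValueError now carries all
    # demjson3 error objects instead of only a generic message
    try:
        json_loads("<not json>")
    except ValueError as e:
        print(e.args)  # (message, <demjson3 error>, ...)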
Author:    Reinier van der Leer
Date:      2024-04-16 10:38:49 +02:00
Committer: GitHub
Parent:    d7f00a996f
Commit:    7082e63b11

8 changed files with 211 additions and 116 deletions

View File

@@ -15,7 +15,6 @@ from autogpt.core.resource.model_providers.schema import (
     CompletionModelFunction,
 )
 from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import json_loads
 
 logger = logging.getLogger(__name__)
@@ -203,9 +202,7 @@ class AgentProfileGenerator(PromptStrategy):
                    f"LLM did not call {self._create_agent_function.name} function; "
                    "agent profile creation failed"
                )
-            arguments: object = json_loads(
-                response_content.tool_calls[0].function.arguments
-            )
+            arguments: object = response_content.tool_calls[0].function.arguments
            ai_profile = AIProfile(
                ai_name=arguments.get("name"),
                ai_role=arguments.get("description"),

View File

@@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import (
     CompletionModelFunction,
 )
 from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads
+from autogpt.core.utils.json_utils import extract_dict_from_json
 from autogpt.prompts.utils import format_numbered_list, indent
@@ -436,7 +436,7 @@ def extract_command(
            raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
        assistant_reply_json["command"] = {
            "name": assistant_reply.tool_calls[0].function.name,
-            "args": json_loads(assistant_reply.tool_calls[0].function.arguments),
+            "args": assistant_reply.tool_calls[0].function.arguments,
        }
    try:
        if not isinstance(assistant_reply_json, dict):

View File

@@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
     CompletionModelFunction,
 )
 from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import json_loads
 
 logger = logging.getLogger(__name__)
@@ -195,9 +194,7 @@ class InitialPlan(PromptStrategy):
                    f"LLM did not call {self._create_plan_function.name} function; "
                    "plan creation failed"
                )
-            parsed_response: object = json_loads(
-                response_content.tool_calls[0].function.arguments
-            )
+            parsed_response: object = response_content.tool_calls[0].function.arguments
            parsed_response["task_list"] = [
                Task.parse_obj(task) for task in parsed_response["task_list"]
            ]

View File

@@ -9,7 +9,6 @@ from autogpt.core.resource.model_providers import (
     CompletionModelFunction,
 )
 from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import json_loads
 
 logger = logging.getLogger(__name__)
@@ -141,9 +140,7 @@ class NameAndGoals(PromptStrategy):
                    f"LLM did not call {self._create_agent_function} function; "
                    "agent profile creation failed"
                )
-            parsed_response = json_loads(
-                response_content.tool_calls[0].function.arguments
-            )
+            parsed_response = response_content.tool_calls[0].function.arguments
        except KeyError:
            logger.debug(f"Failed to parse this response content: {response_content}")
            raise

View File

@@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
     CompletionModelFunction,
 )
 from autogpt.core.utils.json_schema import JSONSchema
-from autogpt.core.utils.json_utils import json_loads
 
 logger = logging.getLogger(__name__)
@@ -188,9 +187,7 @@ class NextAbility(PromptStrategy):
                raise ValueError("LLM did not call any function")
            function_name = response_content.tool_calls[0].function.name
-            function_arguments = json_loads(
-                response_content.tool_calls[0].function.arguments
-            )
+            function_arguments = response_content.tool_calls[0].function.arguments
            parsed_response = {
                "motivation": function_arguments.pop("motivation"),
                "self_criticism": function_arguments.pop("self_criticism"),

View File

@@ -3,7 +3,7 @@ import logging
 import math
 import os
 from pathlib import Path
-from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
+from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
 
 import sentry_sdk
 import tenacity
@@ -11,12 +11,17 @@ import tiktoken
 import yaml
 from openai._exceptions import APIStatusError, RateLimitError
 from openai.types import CreateEmbeddingResponse
-from openai.types.chat import ChatCompletion
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionMessage,
+    ChatCompletionMessageParam,
+)
 from pydantic import SecretStr
 
 from autogpt.core.configuration import Configurable, UserConfigurable
 from autogpt.core.resource.model_providers.schema import (
     AssistantChatMessage,
+    AssistantFunctionCall,
     AssistantToolCall,
     AssistantToolCallDict,
     ChatMessage,
@@ -406,83 +411,90 @@ class OpenAIProvider(
     ) -> ChatModelResponse[_T]:
         """Create a completion using the OpenAI API."""
-        completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
-        tool_calls_compat_mode = functions and "tools" not in completion_kwargs
-        if "messages" in completion_kwargs:
-            model_prompt += completion_kwargs["messages"]
-            del completion_kwargs["messages"]
+        openai_messages, completion_kwargs = self._get_chat_completion_args(
+            model_prompt, model_name, functions, **kwargs
+        )
+        tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)
 
-        cost = 0.0
+        total_cost = 0.0
         attempts = 0
         while True:
-            _response = await self._create_chat_completion(
-                messages=model_prompt,
+            _response, _cost, t_input, t_output = await self._create_chat_completion(
+                messages=openai_messages,
                 **completion_kwargs,
             )
+            total_cost += _cost
 
-            _assistant_msg = _response.choices[0].message
-            assistant_msg = AssistantChatMessage(
-                content=_assistant_msg.content,
-                tool_calls=(
-                    [AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
-                    if _assistant_msg.tool_calls
-                    else None
-                ),
-            )
-            response = ChatModelResponse(
-                response=assistant_msg,
-                model_info=OPEN_AI_CHAT_MODELS[model_name],
-                prompt_tokens_used=(
-                    _response.usage.prompt_tokens if _response.usage else 0
-                ),
-                completion_tokens_used=(
-                    _response.usage.completion_tokens if _response.usage else 0
-                ),
-            )
-            cost += self._budget.update_usage_and_cost(response)
-            self._logger.debug(
-                f"Completion usage: {response.prompt_tokens_used} input, "
-                f"{response.completion_tokens_used} output - ${round(cost, 5)}"
-            )
-
             # If parsing the response fails, append the error to the prompt, and let the
             # LLM fix its mistake(s).
-            try:
-                attempts += 1
-                if (
-                    tool_calls_compat_mode
-                    and assistant_msg.content
-                    and not assistant_msg.tool_calls
-                ):
-                    assistant_msg.tool_calls = list(
-                        _tool_calls_compat_extract_calls(assistant_msg.content)
-                    )
-                response.parsed_result = completion_parser(assistant_msg)
-                break
-            except Exception as e:
-                self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
-                self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
-                sentry_sdk.capture_exception(
-                    error=e,
-                    extras={"assistant_msg": assistant_msg, "i_attempt": attempts},
-                )
-                if attempts < self._configuration.fix_failed_parse_tries:
-                    model_prompt.append(assistant_msg)
-                    model_prompt.append(
-                        ChatMessage.system(
-                            "ERROR PARSING YOUR RESPONSE:\n\n"
-                            f"{e.__class__.__name__}: {e}"
-                        )
-                    )
-                else:
-                    raise
-
-        if attempts > 1:
-            self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}")
-        return response
+            attempts += 1
+            parse_errors: list[Exception] = []
+
+            _assistant_msg = _response.choices[0].message
+
+            tool_calls, _errors = self._parse_assistant_tool_calls(
+                _assistant_msg, tool_calls_compat_mode
+            )
+            parse_errors += _errors
+
+            assistant_msg = AssistantChatMessage(
+                content=_assistant_msg.content,
+                tool_calls=tool_calls or None,
+            )
+
+            parsed_result: _T = None  # type: ignore
+            if not parse_errors:
+                try:
+                    parsed_result = completion_parser(assistant_msg)
+                except Exception as e:
+                    parse_errors.append(e)
+
+            if not parse_errors:
+                if attempts > 1:
+                    self._logger.debug(
+                        f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
+                    )
+
+                return ChatModelResponse(
+                    response=AssistantChatMessage(
+                        content=_assistant_msg.content,
+                        tool_calls=tool_calls or None,
+                    ),
+                    parsed_result=parsed_result,
+                    model_info=OPEN_AI_CHAT_MODELS[model_name],
+                    prompt_tokens_used=t_input,
+                    completion_tokens_used=t_output,
+                )
+            else:
+                self._logger.debug(
+                    f"Parsing failed on response: '''{_assistant_msg}'''"
+                )
+                self._logger.warning(
+                    f"Parsing attempt #{attempts} failed: {parse_errors}"
+                )
+                for e in parse_errors:
+                    sentry_sdk.capture_exception(
+                        error=e,
+                        extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
+                    )
+
+                if attempts < self._configuration.fix_failed_parse_tries:
+                    openai_messages.append(_assistant_msg.dict(exclude_none=True))
+                    openai_messages.append(
+                        {
+                            "role": "system",
+                            "content": (
+                                "ERROR PARSING YOUR RESPONSE:\n\n"
+                                + "\n\n".join(
+                                    f"{e.__class__.__name__}: {e}" for e in parse_errors
+                                )
+                            ),
+                        }
+                    )
+                    continue
+                else:
+                    raise parse_errors[0]
 
     async def create_embedding(
         self,
@@ -504,21 +516,24 @@ class OpenAIProvider(
         self._budget.update_usage_and_cost(response)
         return response
 
-    def _get_completion_kwargs(
+    def _get_chat_completion_args(
         self,
+        model_prompt: list[ChatMessage],
         model_name: OpenAIModelName,
         functions: Optional[list[CompletionModelFunction]] = None,
         **kwargs,
-    ) -> dict:
-        """Get kwargs for completion API call.
+    ) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]:
+        """Prepare chat completion arguments and keyword arguments for API call.
 
         Args:
-            model: The model to use.
-            kwargs: Keyword arguments to override the default values.
+            model_prompt: List of ChatMessages.
+            model_name: The model to use.
+            functions: Optional list of functions available to the LLM.
+            kwargs: Additional keyword arguments.
 
         Returns:
-            The kwargs for the chat API call.
+            list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
+            dict[str, Any]: Any other kwargs for the OpenAI call
         """
         kwargs.update(self._credentials.get_model_access_kwargs(model_name))
@@ -541,7 +556,19 @@ class OpenAIProvider(
             kwargs["extra_headers"] = kwargs.get("extra_headers", {})
             kwargs["extra_headers"].update(extra_headers.copy())
 
-        return kwargs
+        if "messages" in kwargs:
+            model_prompt += kwargs["messages"]
+            del kwargs["messages"]
+
+        openai_messages: list[ChatCompletionMessageParam] = [
+            message.dict(
+                include={"role", "content", "tool_calls", "name"},
+                exclude_none=True,
+            )
+            for message in model_prompt
+        ]
+
+        return openai_messages, kwargs
 
     def _get_embedding_kwargs(
         self,
@@ -566,28 +593,106 @@ class OpenAIProvider(
         return kwargs
 
-    def _create_chat_completion(
-        self, messages: list[ChatMessage], *_, **kwargs
-    ) -> Coroutine[None, None, ChatCompletion]:
-        """Create a chat completion using the OpenAI API with retry handling."""
+    async def _create_chat_completion(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        model: OpenAIModelName,
+        *_,
+        **kwargs,
+    ) -> tuple[ChatCompletion, float, int, int]:
+        """
+        Create a chat completion using the OpenAI API with retry handling.
+
+        Params:
+            openai_messages: List of OpenAI-consumable message dict objects
+            model: The model to use for the completion
+
+        Returns:
+            ChatCompletion: The chat completion response object
+            float: The cost ($) of this completion
+            int: Number of prompt tokens used
+            int: Number of completion tokens used
+        """
 
         @self._retry_api_request
         async def _create_chat_completion_with_retry(
-            messages: list[ChatMessage], *_, **kwargs
+            messages: list[ChatCompletionMessageParam], **kwargs
         ) -> ChatCompletion:
-            raw_messages = [
-                message.dict(
-                    include={"role", "content", "tool_calls", "name"},
-                    exclude_none=True,
-                )
-                for message in messages
-            ]
             return await self._client.chat.completions.create(
-                messages=raw_messages,  # type: ignore
+                messages=messages,  # type: ignore
                 **kwargs,
             )
 
-        return _create_chat_completion_with_retry(messages, *_, **kwargs)
+        completion = await _create_chat_completion_with_retry(
+            messages, model=model, **kwargs
+        )
+
+        if completion.usage:
+            prompt_tokens_used = completion.usage.prompt_tokens
+            completion_tokens_used = completion.usage.completion_tokens
+        else:
+            prompt_tokens_used = completion_tokens_used = 0
+
+        cost = self._budget.update_usage_and_cost(
+            ChatModelResponse(
+                response=AssistantChatMessage(content=None),
+                model_info=OPEN_AI_CHAT_MODELS[model],
+                prompt_tokens_used=prompt_tokens_used,
+                completion_tokens_used=completion_tokens_used,
+            )
+        )
+        self._logger.debug(
+            f"Completion usage: {prompt_tokens_used} input, "
+            f"{completion_tokens_used} output - ${round(cost, 5)}"
+        )
+        return completion, cost, prompt_tokens_used, completion_tokens_used
+
+    def _parse_assistant_tool_calls(
+        self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
+    ):
+        tool_calls: list[AssistantToolCall] = []
+        parse_errors: list[Exception] = []
+
+        if assistant_message.tool_calls:
+            for _tc in assistant_message.tool_calls:
+                try:
+                    parsed_arguments = json_loads(_tc.function.arguments)
+                except Exception as e:
+                    err_message = (
+                        f"Decoding arguments for {_tc.function.name} failed: "
+                        + str(e.args[0])
+                    )
+                    parse_errors.append(
+                        type(e)(err_message, *e.args[1:]).with_traceback(
+                            e.__traceback__
+                        )
+                    )
+                    continue
+
+                tool_calls.append(
+                    AssistantToolCall(
+                        id=_tc.id,
+                        type=_tc.type,
+                        function=AssistantFunctionCall(
+                            name=_tc.function.name,
+                            arguments=parsed_arguments,
+                        ),
+                    )
+                )
+
+            # If parsing of all tool calls succeeds in the end, we ignore any issues
+            if len(tool_calls) == len(assistant_message.tool_calls):
+                parse_errors = []
+        elif compat_mode and assistant_message.content:
+            try:
+                tool_calls = list(
+                    _tool_calls_compat_extract_calls(assistant_message.content)
+                )
+            except Exception as e:
+                parse_errors.append(e)
+
+        return tool_calls, parse_errors
 
     def _create_embedding(
         self, text: str, *_, **kwargs
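
One detail worth noting in the `_parse_assistant_tool_calls` hunk above: decoding failures are re-raised with the offending function named in the message, while the original exception type and traceback are preserved. A standalone sketch of that idiom (`_prefix_exception` is a hypothetical helper, not part of the commit):

    # Rebuild an exception with a prefixed message, keeping its type and traceback
    def _prefix_exception(e: Exception, prefix: str) -> Exception:
        err_message = prefix + str(e.args[0])
        return type(e)(err_message, *e.args[1:]).with_traceback(e.__traceback__)

    try:
        raise ValueError("unexpected token at position 3")
    except ValueError as e:
        wrapped = _prefix_exception(e, "Decoding arguments for create_agent failed: ")
        print(wrapped)  # same type, enriched message, original traceback retained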

View File

@@ -2,6 +2,7 @@ import abc
 import enum
 import math
 from typing import (
+    Any,
     Callable,
     ClassVar,
     Generic,
@@ -68,12 +69,12 @@ class ChatMessageDict(TypedDict):
 
 class AssistantFunctionCall(BaseModel):
     name: str
-    arguments: str
+    arguments: dict[str, Any]
 
 
 class AssistantFunctionCallDict(TypedDict):
     name: str
-    arguments: str
+    arguments: dict[str, Any]
 
 
 class AssistantToolCall(BaseModel):

View File

@@ -1,4 +1,3 @@
-import io
 import logging
 import re
 from typing import Any
@@ -32,16 +31,18 @@ def json_loads(json_str: str) -> Any:
     if match:
         json_str = match.group(1).strip()
 
-    error_buffer = io.StringIO()
-    json_result = demjson3.decode(
-        json_str, return_errors=True, write_errors=error_buffer
-    )
+    json_result = demjson3.decode(json_str, return_errors=True)
+    assert json_result is not None  # by virtue of return_errors=True
 
-    if error_buffer.getvalue():
-        logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")
+    if json_result.errors:
+        logger.debug(
+            "JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
+        )
 
-    if json_result is None:
-        raise ValueError(f"Failed to parse JSON string: {json_str}")
+    if json_result.object is demjson3.undefined:
+        raise ValueError(
+            f"Failed to parse JSON string: {json_str}", *json_result.errors
+        )
 
     return json_result.object