mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-17 22:14:28 +01:00
refactor(agent): Refactor & improve create_chat_completion (#7082)
* refactor(agent/core): Rearrange and split up `OpenAIProvider.create_chat_completion`
  - Rearrange to reduce complexity, improve separation/abstraction of concerns, and allow multiple points of failure during parsing
  - Move conversion from `ChatMessage` to `openai.types.ChatCompletionMessageParam` to `_get_chat_completion_args`
  - Move token usage and cost tracking boilerplate code to `_create_chat_completion`
  - Move tool call conversion/parsing to `_parse_assistant_tool_calls` (new)
* fix(agent/core): Handle decoding of function call arguments in `create_chat_completion`
  - Amend `model_providers.schema`: change type of `arguments` from `str` to `dict[str, Any]` on `AssistantFunctionCall` and `AssistantFunctionCallDict`
  - Implement robust and transparent parsing in `OpenAIProvider._parse_assistant_tool_calls`
  - Remove now unnecessary `json_loads` calls throughout codebase
* feat(agent/utils): Improve conditions and errors in `json_loads`
  - Include all decoding errors when raising a ValueError on decode failure
  - Use errors returned by `return_errors` instead of an error buffer
  - Fix check for decode failure
This commit is contained in:
committed by
GitHub
parent
d7f00a996f
commit
7082e63b11
@@ -15,7 +15,6 @@ from autogpt.core.resource.model_providers.schema import (
|
|||||||
CompletionModelFunction,
|
CompletionModelFunction,
|
||||||
)
|
)
|
||||||
from autogpt.core.utils.json_schema import JSONSchema
|
from autogpt.core.utils.json_schema import JSONSchema
|
||||||
from autogpt.core.utils.json_utils import json_loads
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -203,9 +202,7 @@ class AgentProfileGenerator(PromptStrategy):
|
|||||||
f"LLM did not call {self._create_agent_function.name} function; "
|
f"LLM did not call {self._create_agent_function.name} function; "
|
||||||
"agent profile creation failed"
|
"agent profile creation failed"
|
||||||
)
|
)
|
||||||
arguments: object = json_loads(
|
arguments: object = response_content.tool_calls[0].function.arguments
|
||||||
response_content.tool_calls[0].function.arguments
|
|
||||||
)
|
|
||||||
ai_profile = AIProfile(
|
ai_profile = AIProfile(
|
||||||
ai_name=arguments.get("name"),
|
ai_name=arguments.get("name"),
|
||||||
ai_role=arguments.get("description"),
|
ai_role=arguments.get("description"),
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import (
|
|||||||
CompletionModelFunction,
|
CompletionModelFunction,
|
||||||
)
|
)
|
||||||
from autogpt.core.utils.json_schema import JSONSchema
|
from autogpt.core.utils.json_schema import JSONSchema
|
||||||
from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads
|
from autogpt.core.utils.json_utils import extract_dict_from_json
|
||||||
from autogpt.prompts.utils import format_numbered_list, indent
|
from autogpt.prompts.utils import format_numbered_list, indent
|
||||||
|
|
||||||
|
|
||||||
@@ -436,7 +436,7 @@ def extract_command(
|
|||||||
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
|
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
|
||||||
assistant_reply_json["command"] = {
|
assistant_reply_json["command"] = {
|
||||||
"name": assistant_reply.tool_calls[0].function.name,
|
"name": assistant_reply.tool_calls[0].function.name,
|
||||||
"args": json_loads(assistant_reply.tool_calls[0].function.arguments),
|
"args": assistant_reply.tool_calls[0].function.arguments,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
if not isinstance(assistant_reply_json, dict):
|
if not isinstance(assistant_reply_json, dict):
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
|
|||||||
CompletionModelFunction,
|
CompletionModelFunction,
|
||||||
)
|
)
|
||||||
from autogpt.core.utils.json_schema import JSONSchema
|
from autogpt.core.utils.json_schema import JSONSchema
|
||||||
from autogpt.core.utils.json_utils import json_loads
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -195,9 +194,7 @@ class InitialPlan(PromptStrategy):
|
|||||||
f"LLM did not call {self._create_plan_function.name} function; "
|
f"LLM did not call {self._create_plan_function.name} function; "
|
||||||
"plan creation failed"
|
"plan creation failed"
|
||||||
)
|
)
|
||||||
parsed_response: object = json_loads(
|
parsed_response: object = response_content.tool_calls[0].function.arguments
|
||||||
response_content.tool_calls[0].function.arguments
|
|
||||||
)
|
|
||||||
parsed_response["task_list"] = [
|
parsed_response["task_list"] = [
|
||||||
Task.parse_obj(task) for task in parsed_response["task_list"]
|
Task.parse_obj(task) for task in parsed_response["task_list"]
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from autogpt.core.resource.model_providers import (
|
|||||||
CompletionModelFunction,
|
CompletionModelFunction,
|
||||||
)
|
)
|
||||||
from autogpt.core.utils.json_schema import JSONSchema
|
from autogpt.core.utils.json_schema import JSONSchema
|
||||||
from autogpt.core.utils.json_utils import json_loads
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -141,9 +140,7 @@ class NameAndGoals(PromptStrategy):
|
|||||||
f"LLM did not call {self._create_agent_function} function; "
|
f"LLM did not call {self._create_agent_function} function; "
|
||||||
"agent profile creation failed"
|
"agent profile creation failed"
|
||||||
)
|
)
|
||||||
parsed_response = json_loads(
|
parsed_response = response_content.tool_calls[0].function.arguments
|
||||||
response_content.tool_calls[0].function.arguments
|
|
||||||
)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.debug(f"Failed to parse this response content: {response_content}")
|
logger.debug(f"Failed to parse this response content: {response_content}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
|
|||||||
CompletionModelFunction,
|
CompletionModelFunction,
|
||||||
)
|
)
|
||||||
from autogpt.core.utils.json_schema import JSONSchema
|
from autogpt.core.utils.json_schema import JSONSchema
|
||||||
from autogpt.core.utils.json_utils import json_loads
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -188,9 +187,7 @@ class NextAbility(PromptStrategy):
|
|||||||
raise ValueError("LLM did not call any function")
|
raise ValueError("LLM did not call any function")
|
||||||
|
|
||||||
function_name = response_content.tool_calls[0].function.name
|
function_name = response_content.tool_calls[0].function.name
|
||||||
function_arguments = json_loads(
|
function_arguments = response_content.tool_calls[0].function.arguments
|
||||||
response_content.tool_calls[0].function.arguments
|
|
||||||
)
|
|
||||||
parsed_response = {
|
parsed_response = {
|
||||||
"motivation": function_arguments.pop("motivation"),
|
"motivation": function_arguments.pop("motivation"),
|
||||||
"self_criticism": function_arguments.pop("self_criticism"),
|
"self_criticism": function_arguments.pop("self_criticism"),
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
|
from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
|
||||||
|
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
import tenacity
|
import tenacity
|
||||||
@@ -11,12 +11,17 @@ import tiktoken
|
|||||||
import yaml
|
import yaml
|
||||||
from openai._exceptions import APIStatusError, RateLimitError
|
from openai._exceptions import APIStatusError, RateLimitError
|
||||||
from openai.types import CreateEmbeddingResponse
|
from openai.types import CreateEmbeddingResponse
|
||||||
from openai.types.chat import ChatCompletion
|
from openai.types.chat import (
|
||||||
|
ChatCompletion,
|
||||||
|
ChatCompletionMessage,
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
)
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from autogpt.core.configuration import Configurable, UserConfigurable
|
from autogpt.core.configuration import Configurable, UserConfigurable
|
||||||
from autogpt.core.resource.model_providers.schema import (
|
from autogpt.core.resource.model_providers.schema import (
|
||||||
AssistantChatMessage,
|
AssistantChatMessage,
|
||||||
|
AssistantFunctionCall,
|
||||||
AssistantToolCall,
|
AssistantToolCall,
|
||||||
AssistantToolCallDict,
|
AssistantToolCallDict,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
@@ -406,83 +411,90 @@ class OpenAIProvider(
|
|||||||
) -> ChatModelResponse[_T]:
|
) -> ChatModelResponse[_T]:
|
||||||
"""Create a completion using the OpenAI API."""
|
"""Create a completion using the OpenAI API."""
|
||||||
|
|
||||||
completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
|
openai_messages, completion_kwargs = self._get_chat_completion_args(
|
||||||
tool_calls_compat_mode = functions and "tools" not in completion_kwargs
|
model_prompt, model_name, functions, **kwargs
|
||||||
if "messages" in completion_kwargs:
|
)
|
||||||
model_prompt += completion_kwargs["messages"]
|
tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)
|
||||||
del completion_kwargs["messages"]
|
|
||||||
|
|
||||||
cost = 0.0
|
total_cost = 0.0
|
||||||
attempts = 0
|
attempts = 0
|
||||||
while True:
|
while True:
|
||||||
_response = await self._create_chat_completion(
|
_response, _cost, t_input, t_output = await self._create_chat_completion(
|
||||||
messages=model_prompt,
|
messages=openai_messages,
|
||||||
**completion_kwargs,
|
**completion_kwargs,
|
||||||
)
|
)
|
||||||
|
total_cost += _cost
|
||||||
_assistant_msg = _response.choices[0].message
|
|
||||||
assistant_msg = AssistantChatMessage(
|
|
||||||
content=_assistant_msg.content,
|
|
||||||
tool_calls=(
|
|
||||||
[AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
|
|
||||||
if _assistant_msg.tool_calls
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
response = ChatModelResponse(
|
|
||||||
response=assistant_msg,
|
|
||||||
model_info=OPEN_AI_CHAT_MODELS[model_name],
|
|
||||||
prompt_tokens_used=(
|
|
||||||
_response.usage.prompt_tokens if _response.usage else 0
|
|
||||||
),
|
|
||||||
completion_tokens_used=(
|
|
||||||
_response.usage.completion_tokens if _response.usage else 0
|
|
||||||
),
|
|
||||||
)
|
|
||||||
cost += self._budget.update_usage_and_cost(response)
|
|
||||||
self._logger.debug(
|
|
||||||
f"Completion usage: {response.prompt_tokens_used} input, "
|
|
||||||
f"{response.completion_tokens_used} output - ${round(cost, 5)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If parsing the response fails, append the error to the prompt, and let the
|
# If parsing the response fails, append the error to the prompt, and let the
|
||||||
# LLM fix its mistake(s).
|
# LLM fix its mistake(s).
|
||||||
try:
|
attempts += 1
|
||||||
attempts += 1
|
parse_errors: list[Exception] = []
|
||||||
|
|
||||||
if (
|
_assistant_msg = _response.choices[0].message
|
||||||
tool_calls_compat_mode
|
|
||||||
and assistant_msg.content
|
tool_calls, _errors = self._parse_assistant_tool_calls(
|
||||||
and not assistant_msg.tool_calls
|
_assistant_msg, tool_calls_compat_mode
|
||||||
):
|
)
|
||||||
assistant_msg.tool_calls = list(
|
parse_errors += _errors
|
||||||
_tool_calls_compat_extract_calls(assistant_msg.content)
|
|
||||||
|
assistant_msg = AssistantChatMessage(
|
||||||
|
content=_assistant_msg.content,
|
||||||
|
tool_calls=tool_calls or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed_result: _T = None # type: ignore
|
||||||
|
if not parse_errors:
|
||||||
|
try:
|
||||||
|
parsed_result = completion_parser(assistant_msg)
|
||||||
|
except Exception as e:
|
||||||
|
parse_errors.append(e)
|
||||||
|
|
||||||
|
if not parse_errors:
|
||||||
|
if attempts > 1:
|
||||||
|
self._logger.debug(
|
||||||
|
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
response.parsed_result = completion_parser(assistant_msg)
|
return ChatModelResponse(
|
||||||
break
|
response=AssistantChatMessage(
|
||||||
except Exception as e:
|
content=_assistant_msg.content,
|
||||||
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
|
tool_calls=tool_calls or None,
|
||||||
self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
|
),
|
||||||
sentry_sdk.capture_exception(
|
parsed_result=parsed_result,
|
||||||
error=e,
|
model_info=OPEN_AI_CHAT_MODELS[model_name],
|
||||||
extras={"assistant_msg": assistant_msg, "i_attempt": attempts},
|
prompt_tokens_used=t_input,
|
||||||
|
completion_tokens_used=t_output,
|
||||||
)
|
)
|
||||||
if attempts < self._configuration.fix_failed_parse_tries:
|
|
||||||
model_prompt.append(assistant_msg)
|
else:
|
||||||
model_prompt.append(
|
self._logger.debug(
|
||||||
ChatMessage.system(
|
f"Parsing failed on response: '''{_assistant_msg}'''"
|
||||||
"ERROR PARSING YOUR RESPONSE:\n\n"
|
)
|
||||||
f"{e.__class__.__name__}: {e}"
|
self._logger.warning(
|
||||||
)
|
f"Parsing attempt #{attempts} failed: {parse_errors}"
|
||||||
|
)
|
||||||
|
for e in parse_errors:
|
||||||
|
sentry_sdk.capture_exception(
|
||||||
|
error=e,
|
||||||
|
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if attempts < self._configuration.fix_failed_parse_tries:
|
||||||
|
openai_messages.append(_assistant_msg.dict(exclude_none=True))
|
||||||
|
openai_messages.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"ERROR PARSING YOUR RESPONSE:\n\n"
|
||||||
|
+ "\n\n".join(
|
||||||
|
f"{e.__class__.__name__}: {e}" for e in parse_errors
|
||||||
|
)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
raise
|
raise parse_errors[0]
|
||||||
|
|
||||||
if attempts > 1:
|
|
||||||
self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}")
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def create_embedding(
|
async def create_embedding(
|
||||||
self,
|
self,
|
||||||
@@ -504,21 +516,24 @@ class OpenAIProvider(
|
|||||||
self._budget.update_usage_and_cost(response)
|
self._budget.update_usage_and_cost(response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _get_completion_kwargs(
|
def _get_chat_completion_args(
|
||||||
self,
|
self,
|
||||||
|
model_prompt: list[ChatMessage],
|
||||||
model_name: OpenAIModelName,
|
model_name: OpenAIModelName,
|
||||||
functions: Optional[list[CompletionModelFunction]] = None,
|
functions: Optional[list[CompletionModelFunction]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict:
|
) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]:
|
||||||
"""Get kwargs for completion API call.
|
"""Prepare chat completion arguments and keyword arguments for API call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The model to use.
|
model_prompt: List of ChatMessages.
|
||||||
kwargs: Keyword arguments to override the default values.
|
model_name: The model to use.
|
||||||
|
functions: Optional list of functions available to the LLM.
|
||||||
|
kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The kwargs for the chat API call.
|
list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
|
||||||
|
dict[str, Any]: Any other kwargs for the OpenAI call
|
||||||
"""
|
"""
|
||||||
kwargs.update(self._credentials.get_model_access_kwargs(model_name))
|
kwargs.update(self._credentials.get_model_access_kwargs(model_name))
|
||||||
|
|
||||||
@@ -541,7 +556,19 @@ class OpenAIProvider(
|
|||||||
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
|
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
|
||||||
kwargs["extra_headers"].update(extra_headers.copy())
|
kwargs["extra_headers"].update(extra_headers.copy())
|
||||||
|
|
||||||
return kwargs
|
if "messages" in kwargs:
|
||||||
|
model_prompt += kwargs["messages"]
|
||||||
|
del kwargs["messages"]
|
||||||
|
|
||||||
|
openai_messages: list[ChatCompletionMessageParam] = [
|
||||||
|
message.dict(
|
||||||
|
include={"role", "content", "tool_calls", "name"},
|
||||||
|
exclude_none=True,
|
||||||
|
)
|
||||||
|
for message in model_prompt
|
||||||
|
]
|
||||||
|
|
||||||
|
return openai_messages, kwargs
|
||||||
|
|
||||||
def _get_embedding_kwargs(
|
def _get_embedding_kwargs(
|
||||||
self,
|
self,
|
||||||
@@ -566,28 +593,106 @@ class OpenAIProvider(
|
|||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def _create_chat_completion(
|
async def _create_chat_completion(
|
||||||
self, messages: list[ChatMessage], *_, **kwargs
|
self,
|
||||||
) -> Coroutine[None, None, ChatCompletion]:
|
messages: list[ChatCompletionMessageParam],
|
||||||
"""Create a chat completion using the OpenAI API with retry handling."""
|
model: OpenAIModelName,
|
||||||
|
*_,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[ChatCompletion, float, int, int]:
|
||||||
|
"""
|
||||||
|
Create a chat completion using the OpenAI API with retry handling.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
openai_messages: List of OpenAI-consumable message dict objects
|
||||||
|
model: The model to use for the completion
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatCompletion: The chat completion response object
|
||||||
|
float: The cost ($) of this completion
|
||||||
|
int: Number of prompt tokens used
|
||||||
|
int: Number of completion tokens used
|
||||||
|
"""
|
||||||
|
|
||||||
@self._retry_api_request
|
@self._retry_api_request
|
||||||
async def _create_chat_completion_with_retry(
|
async def _create_chat_completion_with_retry(
|
||||||
messages: list[ChatMessage], *_, **kwargs
|
messages: list[ChatCompletionMessageParam], **kwargs
|
||||||
) -> ChatCompletion:
|
) -> ChatCompletion:
|
||||||
raw_messages = [
|
|
||||||
message.dict(
|
|
||||||
include={"role", "content", "tool_calls", "name"},
|
|
||||||
exclude_none=True,
|
|
||||||
)
|
|
||||||
for message in messages
|
|
||||||
]
|
|
||||||
return await self._client.chat.completions.create(
|
return await self._client.chat.completions.create(
|
||||||
messages=raw_messages, # type: ignore
|
messages=messages, # type: ignore
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _create_chat_completion_with_retry(messages, *_, **kwargs)
|
completion = await _create_chat_completion_with_retry(
|
||||||
|
messages, model=model, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if completion.usage:
|
||||||
|
prompt_tokens_used = completion.usage.prompt_tokens
|
||||||
|
completion_tokens_used = completion.usage.completion_tokens
|
||||||
|
else:
|
||||||
|
prompt_tokens_used = completion_tokens_used = 0
|
||||||
|
|
||||||
|
cost = self._budget.update_usage_and_cost(
|
||||||
|
ChatModelResponse(
|
||||||
|
response=AssistantChatMessage(content=None),
|
||||||
|
model_info=OPEN_AI_CHAT_MODELS[model],
|
||||||
|
prompt_tokens_used=prompt_tokens_used,
|
||||||
|
completion_tokens_used=completion_tokens_used,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._logger.debug(
|
||||||
|
f"Completion usage: {prompt_tokens_used} input, "
|
||||||
|
f"{completion_tokens_used} output - ${round(cost, 5)}"
|
||||||
|
)
|
||||||
|
return completion, cost, prompt_tokens_used, completion_tokens_used
|
||||||
|
|
||||||
|
def _parse_assistant_tool_calls(
|
||||||
|
self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
|
||||||
|
):
|
||||||
|
tool_calls: list[AssistantToolCall] = []
|
||||||
|
parse_errors: list[Exception] = []
|
||||||
|
|
||||||
|
if assistant_message.tool_calls:
|
||||||
|
for _tc in assistant_message.tool_calls:
|
||||||
|
try:
|
||||||
|
parsed_arguments = json_loads(_tc.function.arguments)
|
||||||
|
except Exception as e:
|
||||||
|
err_message = (
|
||||||
|
f"Decoding arguments for {_tc.function.name} failed: "
|
||||||
|
+ str(e.args[0])
|
||||||
|
)
|
||||||
|
parse_errors.append(
|
||||||
|
type(e)(err_message, *e.args[1:]).with_traceback(
|
||||||
|
e.__traceback__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool_calls.append(
|
||||||
|
AssistantToolCall(
|
||||||
|
id=_tc.id,
|
||||||
|
type=_tc.type,
|
||||||
|
function=AssistantFunctionCall(
|
||||||
|
name=_tc.function.name,
|
||||||
|
arguments=parsed_arguments,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# If parsing of all tool calls succeeds in the end, we ignore any issues
|
||||||
|
if len(tool_calls) == len(assistant_message.tool_calls):
|
||||||
|
parse_errors = []
|
||||||
|
|
||||||
|
elif compat_mode and assistant_message.content:
|
||||||
|
try:
|
||||||
|
tool_calls = list(
|
||||||
|
_tool_calls_compat_extract_calls(assistant_message.content)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
parse_errors.append(e)
|
||||||
|
|
||||||
|
return tool_calls, parse_errors
|
||||||
|
|
||||||
def _create_embedding(
|
def _create_embedding(
|
||||||
self, text: str, *_, **kwargs
|
self, text: str, *_, **kwargs
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import abc
|
|||||||
import enum
|
import enum
|
||||||
import math
|
import math
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Generic,
|
Generic,
|
||||||
@@ -68,12 +69,12 @@ class ChatMessageDict(TypedDict):
|
|||||||
|
|
||||||
class AssistantFunctionCall(BaseModel):
|
class AssistantFunctionCall(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
arguments: str
|
arguments: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class AssistantFunctionCallDict(TypedDict):
|
class AssistantFunctionCallDict(TypedDict):
|
||||||
name: str
|
name: str
|
||||||
arguments: str
|
arguments: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
class AssistantToolCall(BaseModel):
|
class AssistantToolCall(BaseModel):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -32,16 +31,18 @@ def json_loads(json_str: str) -> Any:
|
|||||||
if match:
|
if match:
|
||||||
json_str = match.group(1).strip()
|
json_str = match.group(1).strip()
|
||||||
|
|
||||||
error_buffer = io.StringIO()
|
json_result = demjson3.decode(json_str, return_errors=True)
|
||||||
json_result = demjson3.decode(
|
assert json_result is not None # by virtue of return_errors=True
|
||||||
json_str, return_errors=True, write_errors=error_buffer
|
|
||||||
)
|
|
||||||
|
|
||||||
if error_buffer.getvalue():
|
if json_result.errors:
|
||||||
logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")
|
logger.debug(
|
||||||
|
"JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
|
||||||
|
)
|
||||||
|
|
||||||
if json_result is None:
|
if json_result.object is demjson3.undefined:
|
||||||
raise ValueError(f"Failed to parse JSON string: {json_str}")
|
raise ValueError(
|
||||||
|
f"Failed to parse JSON string: {json_str}", *json_result.errors
|
||||||
|
)
|
||||||
|
|
||||||
return json_result.object
|
return json_result.object
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user