feat(agent/llm/openai): Include compatibility tool call extraction in LLM response parse-fix loop

This commit is contained in:
Reinier van der Leer
2024-01-19 19:23:17 +01:00
parent 8c65f3c748
commit fc37ffdfcf

View File

@@ -377,29 +377,17 @@ class OpenAIProvider(
**completion_kwargs,
)
_response_msg = _response.choices[0].message
if (
tool_calls_compat_mode
and _response_msg.content
and not _response_msg.tool_calls
):
tool_calls = list(
_tool_calls_compat_extract_calls(_response_msg.content)
)
elif _response_msg.tool_calls:
tool_calls = [
AssistantToolCall(**tc.dict()) for tc in _response_msg.tool_calls
]
else:
tool_calls = None
assistant_message = AssistantChatMessage(
content=_response_msg.content,
tool_calls=tool_calls,
_assistant_msg = _response.choices[0].message
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content,
tool_calls=(
[AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
if _assistant_msg.tool_calls
else None
),
)
response = ChatModelResponse(
response=assistant_message,
response=assistant_msg,
model_info=OPEN_AI_CHAT_MODELS[model_name],
prompt_tokens_used=(
_response.usage.prompt_tokens if _response.usage else 0
@@ -418,11 +406,21 @@ class OpenAIProvider(
# LLM fix its mistake(s).
try:
attempts += 1
response.parsed_result = completion_parser(assistant_message)
if (
tool_calls_compat_mode
and assistant_msg.content
and not assistant_msg.tool_calls
):
assistant_msg.tool_calls = list(
_tool_calls_compat_extract_calls(assistant_msg.content)
)
response.parsed_result = completion_parser(assistant_msg)
break
except Exception as e:
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
self._logger.debug(f"Parsing failed on response: '''{_response_msg}'''")
self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
if attempts < self._configuration.fix_failed_parse_tries:
model_prompt.append(
ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
@@ -722,7 +720,7 @@ def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCal
else:
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
raise ValueError("Could not find tool calls block in response")
raise ValueError("Could not find tool_calls block in response")
tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
for t in tool_calls: