From fc37ffdfcf399942d6f9e83fe3fb0406cabdc0bc Mon Sep 17 00:00:00 2001
From: Reinier van der Leer
Date: Fri, 19 Jan 2024 19:23:17 +0100
Subject: [PATCH] feat(agent/llm/openai): Include compatibility tool call
 extraction in LLM response parse-fix loop

---
 .../core/resource/model_providers/openai.py | 46 +++++++++----------
 1 file changed, 22 insertions(+), 24 deletions(-)

diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 50628694..af26ba96 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -377,29 +377,17 @@ class OpenAIProvider(
                 **completion_kwargs,
             )
 
-            _response_msg = _response.choices[0].message
-            if (
-                tool_calls_compat_mode
-                and _response_msg.content
-                and not _response_msg.tool_calls
-            ):
-                tool_calls = list(
-                    _tool_calls_compat_extract_calls(_response_msg.content)
-                )
-            elif _response_msg.tool_calls:
-                tool_calls = [
-                    AssistantToolCall(**tc.dict()) for tc in _response_msg.tool_calls
-                ]
-            else:
-                tool_calls = None
-
-            assistant_message = AssistantChatMessage(
-                content=_response_msg.content,
-                tool_calls=tool_calls,
+            _assistant_msg = _response.choices[0].message
+            assistant_msg = AssistantChatMessage(
+                content=_assistant_msg.content,
+                tool_calls=(
+                    [AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
+                    if _assistant_msg.tool_calls
+                    else None
+                ),
             )
-
             response = ChatModelResponse(
-                response=assistant_message,
+                response=assistant_msg,
                 model_info=OPEN_AI_CHAT_MODELS[model_name],
                 prompt_tokens_used=(
                     _response.usage.prompt_tokens if _response.usage else 0
@@ -418,11 +406,21 @@
             # LLM fix its mistake(s).
             try:
                 attempts += 1
-                response.parsed_result = completion_parser(assistant_message)
+
+                if (
+                    tool_calls_compat_mode
+                    and assistant_msg.content
+                    and not assistant_msg.tool_calls
+                ):
+                    assistant_msg.tool_calls = list(
+                        _tool_calls_compat_extract_calls(assistant_msg.content)
+                    )
+
+                response.parsed_result = completion_parser(assistant_msg)
                 break
             except Exception as e:
                 self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
-                self._logger.debug(f"Parsing failed on response: '''{_response_msg}'''")
+                self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
                 if attempts < self._configuration.fix_failed_parse_tries:
                     model_prompt.append(
                         ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
@@ -722,7 +720,7 @@ def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCal
     else:
         block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
         if not block:
-            raise ValueError("Could not find tool calls block in response")
+            raise ValueError("Could not find tool_calls block in response")
         tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
 
     for t in tool_calls:
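
Editor's illustration (not part of the patch): a minimal, self-contained sketch of the retry pattern this change enables. The names get_completion, parse, extract_tool_calls_from_text and run_with_parse_fix are hypothetical stand-ins, not the real autogpt APIs; extract_tool_calls_from_text reuses the same regex shown in the diff above. The point is that the compatibility extraction now runs inside the try block, so a malformed tool_calls block is handled like any other parse error and fed back to the model instead of escaping before the fix loop.

import json
import re


def extract_tool_calls_from_text(content: str) -> list[dict]:
    # Pull a JSON list of tool calls out of a fenced ```tool_calls``` block
    # (same regex as the patched _tool_calls_compat_extract_calls).
    block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", content, re.DOTALL)
    if not block:
        raise ValueError("Could not find tool_calls block in response")
    return json.loads(block.group(1))  # may raise json.JSONDecodeError


def run_with_parse_fix(get_completion, parse, prompt: list[dict], max_tries: int = 3):
    # Hypothetical parse-fix loop: get_completion() returns a dict like
    # {"content": str, "tool_calls": list | None}; parse() validates it.
    attempts = 0
    while True:
        msg = get_completion(prompt)
        try:
            attempts += 1
            # After this patch, compat extraction happens inside the try block,
            # so extraction failures are retried like any other parse error.
            if msg["content"] and not msg["tool_calls"]:
                msg["tool_calls"] = extract_tool_calls_from_text(msg["content"])
            return parse(msg)
        except Exception as e:
            if attempts >= max_tries:
                raise
            # Feed the error back so the model can fix its own output next try.
            prompt.append(
                {"role": "system", "content": f"ERROR PARSING YOUR RESPONSE:\n\n{e}"}
            )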