From ad0c3ebf07ec384c559202ced43c6ec23bccee67 Mon Sep 17 00:00:00 2001
From: Reinier van der Leer
Date: Sun, 8 Oct 2023 18:05:08 -0700
Subject: [PATCH] Implement functions API compatibility mode for older OpenAI
 models

---
 .../core/resource/model_providers/openai.py  | 151 +++++++++++++++++-
 .../core/resource/model_providers/schema.py  |   2 +-
 2 files changed, 148 insertions(+), 5 deletions(-)

diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
index 1cc2147c..37a672ea 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py
@@ -3,7 +3,7 @@ import functools
 import logging
 import math
 import time
-from typing import Callable, ParamSpec, TypeVar
+from typing import Callable, Optional, ParamSpec, TypeVar
 
 import openai
 import tiktoken
@@ -16,6 +16,7 @@ from autogpt.core.configuration import (
 )
 from autogpt.core.resource.model_providers.schema import (
     AssistantChatMessageDict,
+    AssistantFunctionCallDict,
     ChatMessage,
     ChatModelInfo,
     ChatModelProvider,
@@ -33,6 +34,7 @@ from autogpt.core.resource.model_providers.schema import (
     ModelProviderUsage,
     ModelTokenizer,
 )
+from autogpt.core.utils.json_schema import JSONSchema
 
 _T = TypeVar("_T")
 _P = ParamSpec("_P")
@@ -263,11 +265,17 @@ class OpenAIProvider(
         model_prompt: list[ChatMessage],
         model_name: OpenAIModelName,
         completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
-        functions: list[CompletionModelFunction] = [],
+        functions: Optional[list[CompletionModelFunction]] = None,
         **kwargs,
     ) -> ChatModelResponse[_T]:
         """Create a completion using the OpenAI API."""
+
         completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
+        functions_compat_mode = functions and "functions" not in completion_kwargs
+        if "messages" in completion_kwargs:
+            model_prompt += completion_kwargs["messages"]
+            del completion_kwargs["messages"]
+
         response = await self._create_chat_completion(
             messages=model_prompt,
             **completion_kwargs,
         )
@@ -279,6 +287,10 @@ class OpenAIProvider(
         }
 
         response_message = response.choices[0].message.to_dict_recursive()
+        if functions_compat_mode:
+            response_message["function_call"] = _functions_compat_extract_call(
+                response_message["content"]
+            )
         response = ChatModelResponse(
             response=response_message,
             parsed_result=completion_parser(response_message),
@@ -313,7 +325,7 @@ class OpenAIProvider(
     def _get_completion_kwargs(
         self,
         model_name: OpenAIModelName,
-        functions: list[CompletionModelFunction],
+        functions: Optional[list[CompletionModelFunction]] = None,
         **kwargs,
     ) -> dict:
         """Get kwargs for completion API call.
@@ -331,8 +343,13 @@ class OpenAIProvider(
             **kwargs,
             **self._credentials.unmasked(),
         }
+
         if functions:
-            completion_kwargs["functions"] = [f.schema for f in functions]
+            if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
+                completion_kwargs["functions"] = [f.schema for f in functions]
+            else:
+                # Provide compatibility with older models
+                _functions_compat_fix_kwargs(functions, completion_kwargs)
 
         return completion_kwargs
 
@@ -459,3 +476,129 @@ class _OpenAIRetryHandler:
                 self._backoff(attempt)
 
         return _wrapped
+
+
+def format_function_specs_as_typescript_ns(
+    functions: list[CompletionModelFunction],
+) -> str:
+    """Returns a function signature block in the format used by OpenAI internally:
+    https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
+
+    For use with `count_tokens` to determine token usage of provided functions.
+
+    Example:
+    ```ts
+    namespace functions {
+
+    // Get the current weather in a given location
+    type get_current_weather = (_: {
+    // The city and state, e.g. San Francisco, CA
+    location: string,
+    unit?: "celsius" | "fahrenheit",
+    }) => any;
+
+    } // namespace functions
+    ```
+    """
+    return (
+        "namespace functions {\n\n"
+        + "\n\n".join(format_openai_function_for_prompt(f) for f in functions)
+        + "\n\n} // namespace functions"
+    )
+
+
+def format_openai_function_for_prompt(func: CompletionModelFunction) -> str:
+    """Returns the function formatted similarly to the way OpenAI does it internally:
+    https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
+
+    Example:
+    ```ts
+    // Get the current weather in a given location
+    type get_current_weather = (_: {
+    // The city and state, e.g. San Francisco, CA
+    location: string,
+    unit?: "celsius" | "fahrenheit",
+    }) => any;
+    ```
+    """
+
+    def param_signature(name: str, spec: JSONSchema) -> str:
+        return (
+            f"// {spec.description}\n" if spec.description else ""
+        ) + f"{name}{'' if spec.required else '?'}: {spec.typescript_type},"
+
+    return "\n".join(
+        [
+            f"// {func.description}",
+            f"type {func.name} = (_: {{",
+            *[param_signature(name, p) for name, p in func.parameters.items()],
+            "}) => any;",
+        ]
+    )
+
+
+def count_openai_functions_tokens(
+    functions: list[CompletionModelFunction], count_tokens: Callable[[str], int]
+) -> int:
+    """Returns the number of tokens taken up by a set of function definitions
+
+    Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
+    """
+    return count_tokens(
+        f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}"
+    )
+
+
+def _functions_compat_fix_kwargs(
+    functions: list[CompletionModelFunction],
+    completion_kwargs: dict,
+):
+    function_definitions = format_function_specs_as_typescript_ns(functions)
+    function_call_schema = JSONSchema(
+        type=JSONSchema.Type.OBJECT,
+        properties={
+            "name": JSONSchema(
+                description="The name of the function to call",
+                enum=[f.name for f in functions],
+                required=True,
+            ),
+            "arguments": JSONSchema(
+                description="The arguments for the function call",
+                type=JSONSchema.Type.OBJECT,
+                required=True,
+            ),
+        },
+    )
+    completion_kwargs["messages"] = [
+        ChatMessage.system(
+            "# function_call instructions\n\n"
+            "Specify a '```function_call' block in your response,"
+            " enclosing a function call in the form of a valid JSON object"
+            " that adheres to the following schema:\n\n"
+            f"{function_call_schema.to_dict()}\n\n"
+            "Put the function_call block at the end of your response"
+            " and include its fences if it is not the only content.\n\n"
+            "## functions\n\n"
+            "For the function call itself, use one of the following"
+            f" functions:\n\n{function_definitions}"
+        ),
+    ]
+
+
+def _functions_compat_extract_call(response: str) -> AssistantFunctionCallDict:
+    import json
+    import re
+
+    logging.debug(f"Trying to extract function call from response:\n{response}")
+
+    if response[0] == "{":
+        function_call = json.loads(response)
+    else:
+        block = re.search(r"```(?:function_call)?\n(.*)\n```\s*$", response, re.DOTALL)
+        if not block:
+            raise ValueError("Could not find function call block in response")
+        function_call = json.loads(block.group(1))
+
+    function_call["arguments"] = str(function_call["arguments"])  # HACK
+    return function_call
diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
index 4989afd5..14e5618c 100644
--- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
+++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py
@@ -333,7 +333,7 @@ class ChatModelProvider(ModelProvider):
         model_prompt: list[ChatMessage],
         model_name: str,
         completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
-        functions: list[CompletionModelFunction] = [],
+        functions: Optional[list[CompletionModelFunction]] = None,
         **kwargs,
     ) -> ChatModelResponse[_T]:
         ...
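
Note for reviewers (not part of the patch): the compat path round-trips the function call through plain text, so the extraction step is the fragile part. Below is a minimal, self-contained sketch of that step; the `response` string is a hypothetical model reply formatted the way the injected system message requests, and the parsing logic mirrors `_functions_compat_extract_call` above.

```python
import json
import re

# Hypothetical model reply, formatted per the injected system message:
# prose first, then a fenced ```function_call block at the end.
response = (
    "Looking up the weather now.\n\n"
    "```function_call\n"
    '{"name": "get_current_weather",'
    ' "arguments": {"location": "San Francisco, CA"}}\n'
    "```"
)

# Mirrors _functions_compat_extract_call: accept either a bare JSON object,
# or a trailing fenced block whose "function_call" language tag is optional.
if response.startswith("{"):
    function_call = json.loads(response)
else:
    block = re.search(r"```(?:function_call)?\n(.*)\n```\s*$", response, re.DOTALL)
    if not block:
        raise ValueError("Could not find function call block in response")
    function_call = json.loads(block.group(1))

print(function_call["name"])       # -> get_current_weather
print(function_call["arguments"])  # -> {'location': 'San Francisco, CA'}
```

On the `# HACK` in the patch: the native functions API returns `arguments` as a string, so the compat path stringifies the parsed dict to match `AssistantFunctionCallDict`; `str()` on a dict produces Python repr (single quotes) rather than strict JSON, which is presumably why the line is flagged.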