feat(agent/core): Add max_output_tokens parameter to create_chat_completion interface

This commit is contained in:
Reinier van der Leer
2024-04-18 21:48:22 +02:00
parent 35ebb10378
commit 7bb7c30842
3 changed files with 8 additions and 2 deletions

View File

@@ -416,12 +416,17 @@ class OpenAIProvider(
model_name: OpenAIModelName, model_name: OpenAIModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None, functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ChatModelResponse[_T]: ) -> ChatModelResponse[_T]:
"""Create a completion using the OpenAI API.""" """Create a completion using the OpenAI API."""
openai_messages, completion_kwargs = self._get_chat_completion_args( openai_messages, completion_kwargs = self._get_chat_completion_args(
model_prompt, model_name, functions, **kwargs model_prompt=model_prompt,
model_name=model_name,
functions=functions,
max_tokens=max_output_tokens,
**kwargs,
) )
tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs) tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)

View File

@@ -357,6 +357,7 @@ class ChatModelProvider(ModelProvider):
model_name: str, model_name: str,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None, completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None, functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ChatModelResponse[_T]: ) -> ChatModelResponse[_T]:
... ...

View File

@@ -160,7 +160,7 @@ async def _process_text(
model_prompt=prompt.messages, model_prompt=prompt.messages,
model_name=model, model_name=model,
temperature=0.5, temperature=0.5,
max_tokens=max_result_tokens, max_output_tokens=max_result_tokens,
completion_parser=lambda s: ( completion_parser=lambda s: (
extract_list_from_json(s.content) if output_type is not str else None extract_list_from_json(s.content) if output_type is not str else None
), ),