feat(agent/core): Add max_output_tokens parameter to create_chat_completion interface

This commit is contained in:
Reinier van der Leer
2024-04-18 21:48:22 +02:00
parent 35ebb10378
commit 7bb7c30842
3 changed files with 8 additions and 2 deletions

View File

@@ -416,12 +416,17 @@ class OpenAIProvider(
model_name: OpenAIModelName,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the OpenAI API."""
openai_messages, completion_kwargs = self._get_chat_completion_args(
model_prompt, model_name, functions, **kwargs
model_prompt=model_prompt,
model_name=model_name,
functions=functions,
max_tokens=max_output_tokens,
**kwargs,
)
tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)

View File

@@ -357,6 +357,7 @@ class ChatModelProvider(ModelProvider):
model_name: str,
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ChatModelResponse[_T]:
...

View File

@@ -160,7 +160,7 @@ async def _process_text(
model_prompt=prompt.messages,
model_name=model,
temperature=0.5,
max_tokens=max_result_tokens,
max_output_tokens=max_result_tokens,
completion_parser=lambda s: (
extract_list_from_json(s.content) if output_type is not str else None
),