Mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2025-12-19 06:54:22 +01:00)
refactor(agent/openai): Upgrade OpenAI library to v1
- Update `openai` dependency from ^v0.27.10 to ^v1.7.2
- Update poetry.lock
- Update code for changed endpoints and new output types of OpenAI library
- Replace uses of `AssistantChatMessageDict` by `AssistantChatMessage`
- Update `PromptStrategy`, `BaseAgent`, and all of their subclasses accordingly
- Update `OpenAIProvider`, `OpenAICredentials`, azure.yaml.template, .env.template and test_config.py to work with new separate `AzureOpenAI` client
- Remove `_OpenAIRetryHandler` and implement retry mechanism with `tenacity`
- Rewrite pytest fixture `cached_openai_client` (renamed from `patched_api_requestor`) for OpenAI v1 library
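For orientation, a minimal sketch (not part of the diff below) of the v0-to-v1 calling convention this commit migrates to: module-level functions are replaced by a client object, and responses are typed objects instead of dicts. Model name and prompt are placeholders, and a valid OPENAI_API_KEY is assumed in the environment.

    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    # v1 returns typed objects: attribute access instead of dict indexing
    print(completion.choices[0].message.content)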
@@ -76,9 +76,11 @@ OPENAI_API_KEY=your-openai-api-key
 ## USE_AZURE - Use Azure OpenAI or not (Default: False)
 # USE_AZURE=False
 
-## AZURE_CONFIG_FILE - The path to the azure.yaml file, relative to the AutoGPT root directory. (Default: azure.yaml)
+## AZURE_CONFIG_FILE - The path to the azure.yaml file, relative to the folder containing this file. (Default: azure.yaml)
 # AZURE_CONFIG_FILE=azure.yaml
 
+# AZURE_OPENAI_AD_TOKEN=
+# AZURE_OPENAI_ENDPOINT=
 
 ################################################################################
 ### LLM MODELS
@@ -10,7 +10,7 @@ from autogpt.core.prompting import (
 )
 from autogpt.core.prompting.utils import json_loads
 from autogpt.core.resource.model_providers.schema import (
-    AssistantChatMessageDict,
+    AssistantChatMessage,
     ChatMessage,
     ChatModelProvider,
     CompletionModelFunction,
@@ -186,7 +186,7 @@ class AgentProfileGenerator(PromptStrategy):
 
     def parse_response_content(
         self,
-        response_content: AssistantChatMessageDict,
+        response_content: AssistantChatMessage,
     ) -> tuple[AIProfile, AIDirectives]:
         """Parse the actual text response from the objective model.
 
@@ -198,16 +198,21 @@ class AgentProfileGenerator(PromptStrategy):
 
         """
        try:
-            arguments = json_loads(
-                response_content["tool_calls"][0]["function"]["arguments"]
+            if not response_content.tool_calls:
+                raise ValueError(
+                    f"LLM did not call {self._create_agent_function.name} function; "
+                    "agent profile creation failed"
+                )
+            arguments: object = json_loads(
+                response_content.tool_calls[0].function.arguments
             )
             ai_profile = AIProfile(
                 ai_name=arguments.get("name"),
                 ai_role=arguments.get("description"),
             )
             ai_directives = AIDirectives(
-                best_practices=arguments["directives"].get("best_practices"),
-                constraints=arguments["directives"].get("constraints"),
+                best_practices=arguments.get("directives", {}).get("best_practices"),
+                constraints=arguments.get("directives", {}).get("constraints"),
                 resources=[],
             )
         except KeyError:
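The strategies in this hunk now receive an `AssistantChatMessage` object rather than a dict. A small illustrative sketch of the access pattern; the stand-in classes below are hand-built for the example and are not the project's real schema types:

    import json

    # Hypothetical stand-ins mirroring the shape used above.
    class _Function:
        def __init__(self, name, arguments):
            self.name, self.arguments = name, arguments

    class _ToolCall:
        def __init__(self, function):
            self.function = function

    class _AssistantChatMessage:
        def __init__(self, content=None, tool_calls=None):
            self.content, self.tool_calls = content, tool_calls

    msg = _AssistantChatMessage(
        tool_calls=[_ToolCall(_Function("create_agent", '{"name": "ExampleGPT"}'))]
    )
    if not msg.tool_calls:  # guard added by this commit
        raise ValueError("LLM did not call the expected function")
    arguments = json.loads(msg.tool_calls[0].function.arguments)
    print(arguments.get("name"))  # -> ExampleGPT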
@@ -15,7 +15,7 @@ from pydantic import Field
 from autogpt.core.configuration import Configurable
 from autogpt.core.prompting import ChatPrompt
 from autogpt.core.resource.model_providers import (
-    AssistantChatMessageDict,
+    AssistantChatMessage,
     ChatMessage,
     ChatModelProvider,
 )
@@ -172,14 +172,12 @@ class Agent(
         return prompt
 
     def parse_and_process_response(
-        self, llm_response: AssistantChatMessageDict, *args, **kwargs
+        self, llm_response: AssistantChatMessage, *args, **kwargs
     ) -> Agent.ThoughtProcessOutput:
         for plugin in self.config.plugins:
             if not plugin.can_handle_post_planning():
                 continue
-            llm_response["content"] = plugin.post_planning(
-                llm_response.get("content", "")
-            )
+            llm_response.content = plugin.post_planning(llm_response.content or "")
 
         (
             command_name,
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
     from autogpt.config import Config
     from autogpt.core.prompting.base import PromptStrategy
     from autogpt.core.resource.model_providers.schema import (
-        AssistantChatMessageDict,
+        AssistantChatMessage,
         ChatModelInfo,
         ChatModelProvider,
         ChatModelResponse,
@@ -410,7 +410,7 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
     @abstractmethod
     def parse_and_process_response(
         self,
-        llm_response: AssistantChatMessageDict,
+        llm_response: AssistantChatMessage,
         prompt: ChatPrompt,
         scratchpad: PromptScratchpad,
     ) -> ThoughtProcessOutput:
@@ -21,7 +21,7 @@ from autogpt.core.prompting import (
     PromptStrategy,
 )
 from autogpt.core.resource.model_providers.schema import (
-    AssistantChatMessageDict,
+    AssistantChatMessage,
     ChatMessage,
     CompletionModelFunction,
 )
@@ -386,12 +386,12 @@ class OneShotAgentPromptStrategy(PromptStrategy):
 
     def parse_response_content(
         self,
-        response: AssistantChatMessageDict,
+        response: AssistantChatMessage,
     ) -> Agent.ThoughtProcessOutput:
-        if "content" not in response:
+        if not response.content:
             raise InvalidAgentResponseError("Assistant response has no text content")
 
-        assistant_reply_dict = extract_dict_from_response(response["content"])
+        assistant_reply_dict = extract_dict_from_response(response.content)
 
         _, errors = self.response_schema.validate_object(
             object=assistant_reply_dict,
@@ -417,14 +417,14 @@ class OneShotAgentPromptStrategy(PromptStrategy):
 
 def extract_command(
     assistant_reply_json: dict,
-    assistant_reply: AssistantChatMessageDict,
+    assistant_reply: AssistantChatMessage,
     use_openai_functions_api: bool,
 ) -> tuple[str, dict[str, str]]:
     """Parse the response and return the command name and arguments
 
     Args:
         assistant_reply_json (dict): The response object from the AI
-        assistant_reply (ChatModelResponse): The model response from the AI
+        assistant_reply (AssistantChatMessage): The model response from the AI
         config (Config): The config object
 
     Returns:
@@ -436,13 +436,11 @@ def extract_command(
         Exception: If any other error occurs
     """
     if use_openai_functions_api:
-        if not assistant_reply.get("tool_calls"):
+        if not assistant_reply.tool_calls:
             raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
         assistant_reply_json["command"] = {
-            "name": assistant_reply["tool_calls"][0]["function"]["name"],
-            "args": json.loads(
-                assistant_reply["tool_calls"][0]["function"]["arguments"]
-            ),
+            "name": assistant_reply.tool_calls[0].function.name,
+            "args": json.loads(assistant_reply.tool_calls[0].function.arguments),
         }
     try:
         if not isinstance(assistant_reply_json, dict):
@@ -1,4 +1,3 @@
-import copy
 import logging
 import os
 import pathlib
@@ -34,6 +33,7 @@ from autogpt.commands.system import finish
 from autogpt.commands.user_interaction import ask_user
 from autogpt.config import Config
 from autogpt.core.resource.model_providers import ChatModelProvider
+from autogpt.core.resource.model_providers.openai import OpenAIProvider
 from autogpt.file_workspace import (
     FileWorkspace,
     FileWorkspaceBackendName,
@@ -414,8 +414,8 @@ class AgentProtocolServer:
         """
         Configures the LLM provider with headers to link outgoing requests to the task.
         """
-        task_llm_provider = copy.deepcopy(self.llm_provider)
-        _extra_request_headers = task_llm_provider._configuration.extra_request_headers
+        task_llm_provider_config = self.llm_provider._configuration.copy(deep=True)
+        _extra_request_headers = task_llm_provider_config.extra_request_headers
 
         _extra_request_headers["AP-TaskID"] = task.task_id
         if step_id:
@@ -423,7 +423,15 @@ class AgentProtocolServer:
         if task.additional_input and (user_id := task.additional_input.get("user_id")):
             _extra_request_headers["AutoGPT-UserID"] = user_id
 
-        return task_llm_provider
+        if isinstance(self.llm_provider, OpenAIProvider):
+            settings = self.llm_provider._settings.copy()
+            settings.configuration = task_llm_provider_config
+            return OpenAIProvider(
+                settings=settings,
+                logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"),
+            )
+
+        return self.llm_provider
 
 
 def task_agent_id(task_id: str | int) -> str:
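A sketch of how per-task extra headers like the ones configured above can reach an OpenAI v1 client; `extra_headers` is a supported per-request option in the v1 SDK, and the header names are the ones used in this diff. A valid OPENAI_API_KEY is assumed; the task and user IDs are placeholders.

    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set
    client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ping"}],
        # Headers linking this request to an agent-protocol task, as configured above.
        extra_headers={"AP-TaskID": "task-123", "AutoGPT-UserID": "user-456"},
    )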
@@ -190,9 +190,9 @@ def check_model(
 ) -> str:
     """Check if model is available for use. If not, return gpt-3.5-turbo."""
     api_manager = ApiManager()
-    models = api_manager.get_models(**api_credentials.get_api_access_kwargs(model_name))
+    models = api_manager.get_models(api_credentials)
 
-    if any(model_name in m["id"] for m in models):
+    if any(model_name == m.id for m in models):
         return model_name
 
     logger.warning(
@@ -8,8 +8,8 @@ import uuid
 from base64 import b64decode
 from pathlib import Path
 
-import openai
 import requests
+from openai import OpenAI
 from PIL import Image
 
 from autogpt.agents.agent import Agent
@@ -142,17 +142,18 @@ def generate_image_with_dalle(
         )
         size = closest
 
-    response = openai.Image.create(
+    response = OpenAI(
+        api_key=agent.legacy_config.openai_credentials.api_key.get_secret_value()
+    ).images.generate(
         prompt=prompt,
         n=1,
         size=f"{size}x{size}",
         response_format="b64_json",
-        api_key=agent.legacy_config.openai_credentials.api_key.get_secret_value(),
     )
 
     logger.info(f"Image Generated for prompt:{prompt}")
 
-    image_data = b64decode(response["data"][0]["b64_json"])
+    image_data = b64decode(response.data[0].b64_json)
 
     with open(output_file, mode="wb") as png:
         png.write(image_data)
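For reference, a standalone sketch of the v1 image call used above; the API key, prompt, and output path are placeholders.

    from base64 import b64decode

    from openai import OpenAI

    client = OpenAI(api_key="sk-placeholder")  # placeholder key
    response = client.images.generate(
        prompt="a watercolor fox",
        n=1,
        size="512x512",
        response_format="b64_json",
    )
    # The v1 response is a typed object; the image bytes arrive base64-encoded.
    with open("fox.png", "wb") as png:
        png.write(b64decode(response.data[0].b64_json))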
@@ -6,7 +6,7 @@ from autogpt.core.prompting import PromptStrategy
 from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
 from autogpt.core.prompting.utils import json_loads, to_numbered_list
 from autogpt.core.resource.model_providers import (
-    AssistantChatMessageDict,
+    AssistantChatMessage,
     ChatMessage,
     CompletionModelFunction,
 )
@@ -178,7 +178,7 @@ class InitialPlan(PromptStrategy):
 
     def parse_response_content(
         self,
-        response_content: AssistantChatMessageDict,
+        response_content: AssistantChatMessage,
     ) -> dict:
         """Parse the actual text response from the objective model.
 
@@ -189,8 +189,13 @@ class InitialPlan(PromptStrategy):
             The parsed response.
         """
         try:
-            parsed_response = json_loads(
-                response_content["tool_calls"][0]["function"]["arguments"]
+            if not response_content.tool_calls:
+                raise ValueError(
+                    f"LLM did not call {self._create_plan_function.name} function; "
+                    "plan creation failed"
+                )
+            parsed_response: object = json_loads(
+                response_content.tool_calls[0].function.arguments
             )
             parsed_response["task_list"] = [
                 Task.parse_obj(task) for task in parsed_response["task_list"]
@@ -5,7 +5,7 @@ from autogpt.core.prompting import PromptStrategy
 from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
 from autogpt.core.prompting.utils import json_loads
 from autogpt.core.resource.model_providers import (
-    AssistantChatMessageDict,
+    AssistantChatMessage,
     ChatMessage,
     CompletionModelFunction,
 )
@@ -124,7 +124,7 @@ class NameAndGoals(PromptStrategy):
 
     def parse_response_content(
         self,
-        response_content: AssistantChatMessageDict,
+        response_content: AssistantChatMessage,
     ) -> dict:
         """Parse the actual text response from the objective model.
 
@@ -136,8 +136,13 @@ class NameAndGoals(PromptStrategy):
 
         """
         try:
+            if not response_content.tool_calls:
+                raise ValueError(
+                    f"LLM did not call {self._create_agent_function} function; "
+                    "agent profile creation failed"
+                )
             parsed_response = json_loads(
-                response_content["tool_calls"][0]["function"]["arguments"]
+                response_content.tool_calls[0].function.arguments
             )
         except KeyError:
             logger.debug(f"Failed to parse this response content: {response_content}")
@@ -6,7 +6,7 @@ from autogpt.core.prompting import PromptStrategy
 from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
 from autogpt.core.prompting.utils import json_loads, to_numbered_list
 from autogpt.core.resource.model_providers import (
-    AssistantChatMessageDict,
+    AssistantChatMessage,
     ChatMessage,
     CompletionModelFunction,
 )
@@ -171,7 +171,7 @@ class NextAbility(PromptStrategy):
 
     def parse_response_content(
         self,
-        response_content: AssistantChatMessageDict,
+        response_content: AssistantChatMessage,
     ) -> dict:
         """Parse the actual text response from the objective model.
 
@@ -183,9 +183,12 @@ class NextAbility(PromptStrategy):
 
         """
         try:
-            function_name = response_content["tool_calls"][0]["function"]["name"]
+            if not response_content.tool_calls:
+                raise ValueError("LLM did not call any function")
+
+            function_name = response_content.tool_calls[0].function.name
             function_arguments = json_loads(
-                response_content["tool_calls"][0]["function"]["arguments"]
+                response_content.tool_calls[0].function.arguments
             )
             parsed_response = {
                 "motivation": function_arguments.pop("motivation"),
@@ -1,7 +1,7 @@
 import abc
 
 from autogpt.core.configuration import SystemConfiguration
-from autogpt.core.resource.model_providers import AssistantChatMessageDict
+from autogpt.core.resource.model_providers import AssistantChatMessage
 
 from .schema import ChatPrompt, LanguageModelClassification
 
@@ -19,5 +19,5 @@ class PromptStrategy(abc.ABC):
         ...
 
     @abc.abstractmethod
-    def parse_response_content(self, response_content: AssistantChatMessageDict):
+    def parse_response_content(self, response_content: AssistantChatMessage):
         ...
@@ -1,21 +1,22 @@
 import enum
-import functools
 import logging
 import math
 import os
-import time
 from pathlib import Path
-from typing import Callable, Optional, ParamSpec, TypeVar
+from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
 
-import openai
+import tenacity
 import tiktoken
 import yaml
-from openai.error import APIError, RateLimitError
+from openai._exceptions import APIStatusError, RateLimitError
+from openai.types import CreateEmbeddingResponse
+from openai.types.chat import ChatCompletion
 from pydantic import SecretStr
 
 from autogpt.core.configuration import Configurable, UserConfigurable
 from autogpt.core.resource.model_providers.schema import (
-    AssistantChatMessageDict,
+    AssistantChatMessage,
+    AssistantToolCall,
     AssistantToolCallDict,
     ChatMessage,
     ChatModelInfo,
@@ -166,7 +167,6 @@ OPEN_AI_MODELS = {
 
 class OpenAIConfiguration(ModelProviderConfiguration):
     fix_failed_parse_tries: int = UserConfigurable(3)
-    pass
 
 
 class OpenAICredentials(ModelProviderCredentials):
@@ -187,32 +187,45 @@ class OpenAICredentials(ModelProviderCredentials):
         ),
     )
     api_version: str = UserConfigurable("", from_env="OPENAI_API_VERSION")
+    azure_endpoint: Optional[SecretStr] = None
     azure_model_to_deploy_id_map: Optional[dict[str, str]] = None
 
-    def get_api_access_kwargs(self, model: str = "") -> dict[str, str]:
-        credentials = {k: v for k, v in self.unmasked().items() if type(v) is str}
+    def get_api_access_kwargs(self) -> dict[str, str]:
+        kwargs = {
+            k: (v.get_secret_value() if type(v) is SecretStr else v)
+            for k, v in {
+                "api_key": self.api_key,
+                "base_url": self.api_base,
+                "organization": self.organization,
+            }.items()
+            if v is not None
+        }
+        if self.api_type == "azure":
+            kwargs["api_version"] = self.api_version
+            kwargs["azure_endpoint"] = self.azure_endpoint
+        return kwargs
+
+    def get_model_access_kwargs(self, model: str) -> dict[str, str]:
+        kwargs = {"model": model}
         if self.api_type == "azure" and model:
-            azure_credentials = self._get_azure_access_kwargs(model)
-            credentials.update(azure_credentials)
-        return credentials
+            azure_kwargs = self._get_azure_access_kwargs(model)
+            kwargs.update(azure_kwargs)
+        return kwargs
 
     def load_azure_config(self, config_file: Path) -> None:
         with open(config_file) as file:
             config_params = yaml.load(file, Loader=yaml.FullLoader) or {}
 
         try:
-            assert (
-                azure_api_base := config_params.get("azure_api_base", "")
-            ) != "", "Azure API base URL not set"
             assert config_params.get(
                 "azure_model_map", {}
             ), "Azure model->deployment_id map is empty"
         except AssertionError as e:
             raise ValueError(*e.args)
 
-        self.api_base = SecretStr(azure_api_base)
         self.api_type = config_params.get("azure_api_type", "azure")
         self.api_version = config_params.get("azure_api_version", "")
+        self.azure_endpoint = config_params.get("azure_endpoint")
         self.azure_model_to_deploy_id_map = config_params.get("azure_model_map")
 
     def _get_azure_access_kwargs(self, model: str) -> dict[str, str]:
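A sketch of how the two kwarg helpers above are meant to be consumed; all values shown are placeholders, and the Azure branch mirrors the client construction that appears later in this diff.

    from openai import AsyncAzureOpenAI, AsyncOpenAI

    # get_api_access_kwargs() yields {"api_key": ..., "base_url": ..., "organization": ...};
    # for Azure it additionally returns api_version and azure_endpoint.
    openai_client = AsyncOpenAI(api_key="sk-placeholder")
    azure_client = AsyncAzureOpenAI(
        api_key="azure-placeholder",
        api_version="2023-06-01-preview",
        azure_endpoint="https://dummy.openai.azure.com",
    )

    # get_model_access_kwargs("gpt-4") maps the model name to its Azure deployment ID,
    # e.g. {"model": "gpt4-deployment-id-for-azure"}, ready to splat into a request.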
@@ -225,10 +238,7 @@ class OpenAICredentials(ModelProviderCredentials):
             raise ValueError(f"No Azure deployment ID configured for model '{model}'")
         deployment_id = self.azure_model_to_deploy_id_map[model]
 
-        if model in OPEN_AI_EMBEDDING_MODELS:
-            return {"engine": deployment_id}
-        else:
-            return {"deployment_id": deployment_id}
+        return {"model": deployment_id}
 
 
 class OpenAIModelProviderBudget(ModelProviderBudget):
@@ -273,21 +283,26 @@ class OpenAIProvider(
         settings: OpenAISettings,
         logger: logging.Logger,
     ):
+        self._settings = settings
+
         assert settings.credentials, "Cannot create OpenAIProvider without credentials"
         self._configuration = settings.configuration
         self._credentials = settings.credentials
         self._budget = settings.budget
 
+        if self._credentials.api_type == "azure":
+            from openai import AsyncAzureOpenAI
+
+            # API key and org (if configured) are passed, the rest of the required
+            # credentials is loaded from the environment by the AzureOpenAI client.
+            self._client = AsyncAzureOpenAI(**self._credentials.get_api_access_kwargs())
+        else:
+            from openai import AsyncOpenAI
+
+            self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
+
         self._logger = logger
 
-        retry_handler = _OpenAIRetryHandler(
-            logger=self._logger,
-            num_retries=self._configuration.retries_per_request,
-        )
-
-        self._create_chat_completion = retry_handler(_create_chat_completion)
-        self._create_embedding = retry_handler(_create_embedding)
-
     def get_token_limit(self, model_name: str) -> int:
         """Get the token limit for a given model."""
         return OPEN_AI_MODELS[model_name].max_tokens
@@ -333,7 +348,7 @@ class OpenAIProvider(
         try:
             encoding = tiktoken.encoding_for_model(encoding_model)
         except KeyError:
-            cls._logger.warning(
+            logging.getLogger(__class__.__name__).warning(
                 f"Model {model_name} not found. Defaulting to cl100k_base encoding."
             )
             encoding = tiktoken.get_encoding("cl100k_base")
@@ -352,7 +367,7 @@ class OpenAIProvider(
         self,
         model_prompt: list[ChatMessage],
         model_name: OpenAIModelName,
-        completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
+        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
         **kwargs,
     ) -> ChatModelResponse[_T]:
@@ -370,23 +385,33 @@ class OpenAIProvider(
                 messages=model_prompt,
                 **completion_kwargs,
             )
-            response_args = {
-                "model_info": OPEN_AI_CHAT_MODELS[model_name],
-                "prompt_tokens_used": response.usage.prompt_tokens,
-                "completion_tokens_used": response.usage.completion_tokens,
-            }
 
-            response_message = response.choices[0].message.to_dict_recursive()
-            if tool_calls_compat_mode:
-                response_message["tool_calls"] = _tool_calls_compat_extract_calls(
-                    response_message["content"]
+            response_message = response.choices[0].message
+            if (
+                tool_calls_compat_mode
+                and response_message.content
+                and not response_message.tool_calls
+            ):
+                tool_calls = list(
+                    _tool_calls_compat_extract_calls(response_message.content)
                 )
+            elif response_message.tool_calls:
+                tool_calls = [
+                    AssistantToolCall(**tc.dict()) for tc in response_message.tool_calls
+                ]
+            else:
+                tool_calls = None
+
+            assistant_message = AssistantChatMessage(
+                content=response_message.content,
+                tool_calls=tool_calls,
+            )
+
             # If parsing the response fails, append the error to the prompt, and let the
             # LLM fix its mistake(s).
             try:
                 attempts += 1
-                parsed_response = completion_parser(response_message)
+                parsed_response = completion_parser(assistant_message)
                 break
             except Exception as e:
                 self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
@@ -401,9 +426,13 @@ class OpenAIProvider(
                     raise
 
         response = ChatModelResponse(
-            response=response_message,
+            response=assistant_message,
             parsed_result=parsed_response,
-            **response_args,
+            model_info=OPEN_AI_CHAT_MODELS[model_name],
+            prompt_tokens_used=response.usage.prompt_tokens if response.usage else 0,
+            completion_tokens_used=(
+                response.usage.completion_tokens if response.usage else 0
+            ),
         )
         self._budget.update_usage_and_cost(response)
         return response
@@ -419,14 +448,11 @@ class OpenAIProvider(
         embedding_kwargs = self._get_embedding_kwargs(model_name, **kwargs)
         response = await self._create_embedding(text=text, **embedding_kwargs)
 
-        response_args = {
-            "model_info": OPEN_AI_EMBEDDING_MODELS[model_name],
-            "prompt_tokens_used": response.usage.prompt_tokens,
-            "completion_tokens_used": response.usage.completion_tokens,
-        }
         response = EmbeddingModelResponse(
-            **response_args,
-            embedding=embedding_parser(response.embeddings[0]),
+            embedding=embedding_parser(response.data[0].embedding),
+            model_info=OPEN_AI_EMBEDDING_MODELS[model_name],
+            prompt_tokens_used=response.usage.prompt_tokens,
+            completion_tokens_used=0,
         )
         self._budget.update_usage_and_cost(response)
         return response
@@ -447,34 +473,29 @@ class OpenAIProvider(
             The kwargs for the chat API call.
 
         """
-        completion_kwargs = {
-            "model": model_name,
-            **kwargs,
-            **self._credentials.get_api_access_kwargs(model_name),
-        }
+        kwargs.update(self._credentials.get_model_access_kwargs(model_name))
 
         if functions:
             if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
-                completion_kwargs["tools"] = [
+                kwargs["tools"] = [
                     {"type": "function", "function": f.schema} for f in functions
                 ]
                 if len(functions) == 1:
                     # force the model to call the only specified function
-                    completion_kwargs["tool_choice"] = {
+                    kwargs["tool_choice"] = {
                         "type": "function",
                         "function": {"name": functions[0].name},
                     }
             else:
                 # Provide compatibility with older models
-                _functions_compat_fix_kwargs(functions, completion_kwargs)
+                _functions_compat_fix_kwargs(functions, kwargs)
 
         if extra_headers := self._configuration.extra_request_headers:
-            if completion_kwargs.get("headers"):
-                completion_kwargs["headers"].update(extra_headers)
-            else:
-                completion_kwargs["headers"] = extra_headers.copy()
+            kwargs["extra_headers"] = kwargs.get("extra_headers", {}).update(
+                extra_headers.copy()
+            )
 
-        return completion_kwargs
+        return kwargs
 
     def _get_embedding_kwargs(
         self,
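The resulting kwargs follow the v1 tools schema; a sketch of what the branch above produces when exactly one function is supplied. The function name and parameter schema are illustrative, not taken from the repository.

    # Illustrative kwargs as assembled for a single callable function.
    completion_kwargs = {
        "model": "gpt-3.5-turbo",
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "create_agent",
                    "description": "Create a new agent profile",
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        ],
        # With exactly one function, the model is forced to call it:
        "tool_choice": {"type": "function", "function": {"name": "create_agent"}},
    }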
@@ -491,122 +512,84 @@ class OpenAIProvider(
             The kwargs for the embedding API call.
 
         """
-        embedding_kwargs = {
-            "model": model_name,
-            **kwargs,
-            **self._credentials.unmasked(),
-        }
+        kwargs.update(self._credentials.get_model_access_kwargs(model_name))
 
         if extra_headers := self._configuration.extra_request_headers:
-            if embedding_kwargs.get("headers"):
-                embedding_kwargs["headers"].update(extra_headers)
-            else:
-                embedding_kwargs["headers"] = extra_headers.copy()
+            kwargs["extra_headers"] = kwargs.get("extra_headers", {}).update(
+                extra_headers.copy()
+            )
 
-        return embedding_kwargs
+        return kwargs
+
+    def _create_chat_completion(
+        self, messages: list[ChatMessage], *_, **kwargs
+    ) -> Coroutine[None, None, ChatCompletion]:
+        """Create a chat completion using the OpenAI API with retry handling."""
+
+        @self._retry_api_request
+        async def _create_chat_completion_with_retry(
+            messages: list[ChatMessage], *_, **kwargs
+        ) -> ChatCompletion:
+            raw_messages = [
+                message.dict(include={"role", "content", "tool_calls", "name"})
+                for message in messages
+            ]
+            return await self._client.chat.completions.create(
+                messages=raw_messages,  # type: ignore
+                **kwargs,
+            )
+
+        return _create_chat_completion_with_retry(messages, *_, **kwargs)
+
+    def _create_embedding(
+        self, text: str, *_, **kwargs
+    ) -> Coroutine[None, None, CreateEmbeddingResponse]:
+        """Create an embedding using the OpenAI API with retry handling."""
+
+        @self._retry_api_request
+        async def _create_embedding_with_retry(
+            text: str, *_, **kwargs
+        ) -> CreateEmbeddingResponse:
+            return await self._client.embeddings.create(
+                input=[text],
+                **kwargs,
+            )
+
+        return _create_embedding_with_retry(text, *_, **kwargs)
+
+    def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
+        _log_retry_debug_message = tenacity.after_log(self._logger, logging.DEBUG)
+
+        def _log_on_fail(retry_state: tenacity.RetryCallState) -> None:
+            _log_retry_debug_message(retry_state)
+
+            if (
+                retry_state.attempt_number == 0
+                and retry_state.outcome
+                and isinstance(retry_state.outcome.exception(), RateLimitError)
+            ):
+                self._logger.warning(
+                    "Please double check that you have setup a PAID OpenAI API Account."
+                    " You can read more here: "
+                    "https://docs.agpt.co/setup/#getting-an-openai-api-key"
+                )
+
+        return tenacity.retry(
+            retry=(
+                tenacity.retry_if_exception_type(RateLimitError)
+                | tenacity.retry_if_exception(
+                    lambda e: isinstance(e, APIStatusError) and e.status_code == 502
+                )
+            ),
+            wait=tenacity.wait_exponential(),
+            stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
+            after=_log_on_fail,
+        )(func)
 
     def __repr__(self):
         return "OpenAIProvider()"
 
 
-async def _create_embedding(text: str, *_, **kwargs) -> openai.Embedding:
-    """Embed text using the OpenAI API.
-
-    Args:
-        text str: The text to embed.
-        model str: The name of the model to use.
-
-    Returns:
-        str: The embedding.
-    """
-    return await openai.Embedding.acreate(
-        input=[text],
-        **kwargs,
-    )
-
-
-async def _create_chat_completion(
-    messages: list[ChatMessage], *_, **kwargs
-) -> openai.Completion:
-    """Create a chat completion using the OpenAI API.
-
-    Args:
-        messages: The prompt to use.
-
-    Returns:
-        The completion.
-    """
-    raw_messages = [
-        message.dict(include={"role", "content", "tool_calls", "name"})
-        for message in messages
-    ]
-    return await openai.ChatCompletion.acreate(
-        messages=raw_messages,
-        **kwargs,
-    )
-
-
-class _OpenAIRetryHandler:
-    """Retry Handler for OpenAI API call.
-
-    Args:
-        num_retries int: Number of retries. Defaults to 10.
-        backoff_base float: Base for exponential backoff. Defaults to 2.
-        warn_user bool: Whether to warn the user. Defaults to True.
-    """
-
-    _retry_limit_msg = "Error: Reached rate limit, passing..."
-    _api_key_error_msg = (
-        "Please double check that you have setup a PAID OpenAI API Account. You can "
-        "read more here: https://docs.agpt.co/setup/#getting-an-openai-api-key"
-    )
-    _backoff_msg = "Error: API Bad gateway. Waiting {backoff} seconds..."
-
-    def __init__(
-        self,
-        logger: logging.Logger,
-        num_retries: int = 10,
-        backoff_base: float = 2.0,
-        warn_user: bool = True,
-    ):
-        self._logger = logger
-        self._num_retries = num_retries
-        self._backoff_base = backoff_base
-        self._warn_user = warn_user
-
-    def _log_rate_limit_error(self) -> None:
-        self._logger.debug(self._retry_limit_msg)
-        if self._warn_user:
-            self._logger.warning(self._api_key_error_msg)
-            self._warn_user = False
-
-    def _backoff(self, attempt: int) -> None:
-        backoff = self._backoff_base ** (attempt + 2)
-        self._logger.debug(self._backoff_msg.format(backoff=backoff))
-        time.sleep(backoff)
-
-    def __call__(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
-        @functools.wraps(func)
-        async def _wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
-            num_attempts = self._num_retries + 1  # +1 for the first attempt
-            for attempt in range(1, num_attempts + 1):
-                try:
-                    return await func(*args, **kwargs)
-
-                except RateLimitError:
-                    if attempt == num_attempts:
-                        raise
-                    self._log_rate_limit_error()
-
-                except APIError as e:
-                    if (e.http_status != 502) or (attempt == num_attempts):
-                        raise
-
-                    self._backoff(attempt)
-
-        return _wrapped
-
-
 def format_function_specs_as_typescript_ns(
     functions: list[CompletionModelFunction],
 ) -> str:
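The `_retry_api_request` helper above is a thin wrapper around tenacity. A standalone sketch of the same policy applied to an arbitrary async callable; the retry count and the fake error type are assumptions made for the example.

    import asyncio
    import logging

    import tenacity

    class TransientError(Exception):
        """Stand-in for RateLimitError or a 502 APIStatusError."""

    attempts = 0

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(TransientError),
        wait=tenacity.wait_exponential(),
        stop=tenacity.stop_after_attempt(4),
        after=tenacity.after_log(logging.getLogger(__name__), logging.DEBUG),
    )
    async def flaky_call() -> str:
        global attempts
        attempts += 1
        if attempts < 3:
            raise TransientError("try again")
        return "ok"

    print(asyncio.run(flaky_call()))  # succeeds on the third attempt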
@@ -730,7 +713,7 @@ def _functions_compat_fix_kwargs(
     ]
 
 
-def _tool_calls_compat_extract_calls(response: str) -> list[AssistantToolCallDict]:
+def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
     import json
     import re
 
@@ -747,4 +730,4 @@ def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
     for t in tool_calls:
         t["function"]["arguments"] = str(t["function"]["arguments"])  # HACK
 
-    return tool_calls
+        yield AssistantToolCall.parse_obj(t)
@@ -90,7 +90,7 @@ class AssistantToolCallDict(TypedDict):
 
 
 class AssistantChatMessage(ChatMessage):
-    role: Literal["assistant"]
+    role: Literal["assistant"] = "assistant"
     content: Optional[str]
     tool_calls: Optional[list[AssistantToolCall]]
 
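Giving `role` a default means the assistant message can now be built from just its content and tool calls. A minimal pydantic sketch of the same pattern; the class below is a stand-in, not the project's actual schema class.

    from typing import Literal, Optional

    from pydantic import BaseModel

    class AssistantChatMessageSketch(BaseModel):
        role: Literal["assistant"] = "assistant"
        content: Optional[str] = None

    msg = AssistantChatMessageSketch(content="Hello")  # role defaults to "assistant"
    print(msg.role, msg.content)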
@@ -320,7 +320,7 @@ _T = TypeVar("_T")
 class ChatModelResponse(ModelResponse, Generic[_T]):
     """Standard response struct for a response from a language model."""
 
-    response: AssistantChatMessageDict
+    response: AssistantChatMessage
     parsed_result: _T = None
 
 
@@ -338,7 +338,7 @@ class ChatModelProvider(ModelProvider):
         self,
         model_prompt: list[ChatMessage],
         model_name: str,
-        completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
+        completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
         functions: Optional[list[CompletionModelFunction]] = None,
         **kwargs,
     ) -> ChatModelResponse[_T]:
@@ -2,7 +2,7 @@ import logging
 import sys
 
 from colorama import Fore, Style
-from openai.util import logger as openai_logger
+from openai._base_client import log as openai_logger
 
 SIMPLE_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
 DEBUG_LOG_FORMAT = (
@@ -3,10 +3,13 @@ from __future__ import annotations
 import logging
 from typing import List, Optional
 
-import openai
-from openai import Model
+from openai import OpenAI
+from openai.types import Model
 
-from autogpt.core.resource.model_providers.openai import OPEN_AI_MODELS
+from autogpt.core.resource.model_providers.openai import (
+    OPEN_AI_MODELS,
+    OpenAICredentials,
+)
 from autogpt.core.resource.model_providers.schema import ChatModelInfo
 from autogpt.singleton import Singleton
 
@@ -96,16 +99,17 @@ class ApiManager(metaclass=Singleton):
         """
         return self.total_budget
 
-    def get_models(self, **openai_credentials) -> List[Model]:
+    def get_models(self, openai_credentials: OpenAICredentials) -> List[Model]:
         """
         Get list of available GPT models.
 
         Returns:
-            list: List of available GPT models.
+            list[Model]: List of available GPT models.
 
         """
         if self.models is None:
-            all_models = openai.Model.list(**openai_credentials)["data"]
-            self.models = [model for model in all_models if "gpt" in model["id"]]
-
+            all_models = (
+                OpenAI(**openai_credentials.get_api_access_kwargs()).models.list().data
+            )
+            self.models = [model for model in all_models if "gpt" in model.id]
         return self.models
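A sketch of the v1 model-listing call the updated `get_models` relies on; the key is a placeholder. `models.list()` returns a paginated object whose `.data` holds typed `Model` entries.

    from openai import OpenAI

    client = OpenAI(api_key="sk-placeholder")  # placeholder key
    page = client.models.list()
    gpt_models = [m for m in page.data if "gpt" in m.id]
    for model in gpt_models:
        print(model.id, model.owned_by)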
@@ -9,7 +9,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Optional
 
 from auto_gpt_plugin_template import AutoGPTPluginTemplate
-from openai.util import logger as openai_logger
+from openai._base_client import log as openai_logger
 
 if TYPE_CHECKING:
     from autogpt.config import Config
@@ -184,7 +184,7 @@ def configure_logging(
     json_logger.propagate = False
 
     # Disable debug logging from OpenAI library
-    openai_logger.setLevel(logging.INFO)
+    openai_logger.setLevel(logging.WARNING)
 
 
 def configure_chat_plugins(config: Config) -> None:
@@ -1,6 +1,6 @@
 azure_api_type: azure
-azure_api_base: your-base-url-for-azure
 azure_api_version: api-version-for-azure
+azure_endpoint: your-azure-openai-endpoint
 azure_model_map:
   gpt-3.5-turbo: gpt35-deployment-id-for-azure
   gpt-4: gpt4-deployment-id-for-azure
autogpts/autogpt/poetry.lock (generated, 3270 lines changed)
File diff suppressed because one or more lines are too long
@@ -43,7 +43,7 @@ hypercorn = "^0.14.4"
 inflection = "*"
 jsonschema = "*"
 numpy = "*"
-openai = "^0.27.10"
+openai = "^1.7.2"
 orjson = "^3.8.10"
 Pillow = "*"
 pinecone-client = "^2.2.1"
@@ -60,6 +60,7 @@ redis = "*"
 requests = "*"
 selenium = "^4.11.2"
 spacy = "^3.0.0"
+tenacity = "^8.2.2"
 tiktoken = "^0.5.0"
 webdriver-manager = "*"
 
@@ -70,7 +70,7 @@ def temp_plugins_config_file():
     yield config_file
 
 
-@pytest.fixture()
+@pytest.fixture(scope="function")
 def config(
     temp_plugins_config_file: Path,
     tmp_project_root: Path,
@@ -97,7 +97,7 @@ def test_json_memory_load_index(config: Config, memory_item: MemoryItem):
 
 @pytest.mark.vcr
 @pytest.mark.requires_openai_api_key
-def test_json_memory_get_relevant(config: Config, patched_api_requestor: None) -> None:
+def test_json_memory_get_relevant(config: Config, cached_openai_client: None) -> None:
     index = JSONFileMemory(config)
     mem1 = MemoryItem.from_text_file("Sample text", "sample.txt", config)
     mem2 = MemoryItem.from_text_file(
@@ -18,7 +18,7 @@ def image_size(request):
 
 @pytest.mark.requires_openai_api_key
 @pytest.mark.vcr
-def test_dalle(agent: Agent, workspace, image_size, patched_api_requestor):
+def test_dalle(agent: Agent, workspace, image_size, cached_openai_client):
     """Test DALL-E image generation."""
     generate_and_validate(
         agent,
@@ -7,9 +7,7 @@ from autogpt.commands.web_selenium import BrowsingError, read_webpage
 @pytest.mark.vcr
 @pytest.mark.requires_openai_api_key
 @pytest.mark.asyncio
-async def test_browse_website_nonexistent_url(
-    agent: Agent, patched_api_requestor: None
-):
+async def test_browse_website_nonexistent_url(agent: Agent, cached_openai_client: None):
     url = "https://auto-gpt-thinks-this-website-does-not-exist.com"
     question = "How to execute a barrel roll"
 
@@ -1,5 +1,3 @@
-from unittest.mock import patch
-
 import pytest
 from pytest_mock import MockerFixture
 
@@ -77,13 +75,3 @@ class TestApiManager:
         assert api_manager.get_total_prompt_tokens() == prompt_tokens
         assert api_manager.get_total_completion_tokens() == 0
         assert api_manager.get_total_cost() == (prompt_tokens * 0.0004) / 1000
-
-    @staticmethod
-    def test_get_models():
-        """Test if getting models works correctly."""
-        with patch("openai.Model.list") as mock_list_models:
-            mock_list_models.return_value = {"data": [{"id": "gpt-3.5-turbo"}]}
-            result = api_manager.get_models()
-
-            assert result[0]["id"] == "gpt-3.5-turbo"
-            assert api_manager.models[0]["id"] == "gpt-3.5-turbo"
@@ -8,6 +8,8 @@ from unittest import mock
 from unittest.mock import patch
 
 import pytest
+from openai.pagination import SyncPage
+from openai.types import Model
 from pydantic import SecretStr
 
 from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, apply_overrides_to_config
@@ -80,8 +82,10 @@ def test_set_smart_llm(config: Config) -> None:
     config.smart_llm = smart_llm
 
 
-@patch("openai.Model.list")
-def test_smart_and_fast_llms_set_to_gpt4(mock_list_models: Any, config: Config) -> None:
+@patch("openai.resources.models.Models.list")
+def test_fallback_to_gpt3_if_gpt4_not_available(
+    mock_list_models: Any, config: Config
+) -> None:
     """
     Test if models update to gpt-3.5-turbo if gpt-4 is not available.
     """
@@ -91,7 +95,10 @@ def test_smart_and_fast_llms_set_to_gpt4(mock_list_models: Any, config: Config)
     config.fast_llm = "gpt-4"
     config.smart_llm = "gpt-4"
 
-    mock_list_models.return_value = {"data": [{"id": "gpt-3.5-turbo"}]}
+    mock_list_models.return_value = SyncPage(
+        data=[Model(id=GPT_3_MODEL, created=0, object="model", owned_by="AutoGPT")],
+        object="Models",  # no idea what this should be, but irrelevant
+    )
 
     apply_overrides_to_config(
         config=config,
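A sketch of the mocking approach used in the updated test: patching the v1 `Models.list` resource method and returning a `SyncPage` of typed `Model` objects. The test body below is illustrative and uses placeholder values.

    from unittest.mock import patch

    from openai import OpenAI
    from openai.pagination import SyncPage
    from openai.types import Model

    fake_page = SyncPage(
        data=[Model(id="gpt-3.5-turbo", created=0, object="model", owned_by="test")],
        object="list",
    )

    with patch("openai.resources.models.Models.list", return_value=fake_page):
        ids = [m.id for m in OpenAI(api_key="sk-test").models.list().data]

    assert ids == ["gpt-3.5-turbo"]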
@@ -123,74 +130,80 @@ def test_missing_azure_config(config: Config) -> None:
     assert config.openai_credentials.azure_model_to_deploy_id_map is None
 
 
-def test_azure_config(config: Config) -> None:
+@pytest.fixture
+def config_with_azure(config: Config):
     config_file = config.app_data_dir / "azure_config.yaml"
     config_file.write_text(
         f"""
 azure_api_type: azure
-azure_api_base: https://dummy.openai.azure.com
 azure_api_version: 2023-06-01-preview
+azure_endpoint: https://dummy.openai.azure.com
 azure_model_map:
     {config.fast_llm}: FAST-LLM_ID
     {config.smart_llm}: SMART-LLM_ID
     {config.embedding_model}: embedding-deployment-id-for-azure
 """
     )
 
     os.environ["USE_AZURE"] = "True"
     os.environ["AZURE_CONFIG_FILE"] = str(config_file)
-    config = ConfigBuilder.build_config_from_env(project_root=config.project_root)
-
-    assert (credentials := config.openai_credentials) is not None
-    assert credentials.api_type == "azure"
-    assert credentials.api_base == SecretStr("https://dummy.openai.azure.com")
-    assert credentials.api_version == "2023-06-01-preview"
-    assert credentials.azure_model_to_deploy_id_map == {
-        config.fast_llm: "FAST-LLM_ID",
-        config.smart_llm: "SMART-LLM_ID",
-        config.embedding_model: "embedding-deployment-id-for-azure",
-    }
-
-    fast_llm = config.fast_llm
-    smart_llm = config.smart_llm
-    assert (
-        credentials.get_api_access_kwargs(config.fast_llm)["deployment_id"]
-        == "FAST-LLM_ID"
-    )
-    assert (
-        credentials.get_api_access_kwargs(config.smart_llm)["deployment_id"]
-        == "SMART-LLM_ID"
-    )
-
-    # Emulate --gpt4only
-    config.fast_llm = smart_llm
-    assert (
-        credentials.get_api_access_kwargs(config.fast_llm)["deployment_id"]
-        == "SMART-LLM_ID"
-    )
-    assert (
-        credentials.get_api_access_kwargs(config.smart_llm)["deployment_id"]
-        == "SMART-LLM_ID"
-    )
-
-    # Emulate --gpt3only
-    config.fast_llm = config.smart_llm = fast_llm
-    assert (
-        credentials.get_api_access_kwargs(config.fast_llm)["deployment_id"]
-        == "FAST-LLM_ID"
-    )
-    assert (
-        credentials.get_api_access_kwargs(config.smart_llm)["deployment_id"]
-        == "FAST-LLM_ID"
-    )
+    config_with_azure = ConfigBuilder.build_config_from_env(
+        project_root=config.project_root
+    )
+    yield config_with_azure
 
     del os.environ["USE_AZURE"]
     del os.environ["AZURE_CONFIG_FILE"]
 
 
+def test_azure_config(config_with_azure: Config) -> None:
+    assert (credentials := config_with_azure.openai_credentials) is not None
+    assert credentials.api_type == "azure"
+    assert credentials.api_version == "2023-06-01-preview"
+    assert credentials.azure_endpoint == SecretStr("https://dummy.openai.azure.com")
+    assert credentials.azure_model_to_deploy_id_map == {
+        config_with_azure.fast_llm: "FAST-LLM_ID",
+        config_with_azure.smart_llm: "SMART-LLM_ID",
+        config_with_azure.embedding_model: "embedding-deployment-id-for-azure",
+    }
+
+    fast_llm = config_with_azure.fast_llm
+    smart_llm = config_with_azure.smart_llm
+    assert (
+        credentials.get_model_access_kwargs(config_with_azure.fast_llm)["model"]
+        == "FAST-LLM_ID"
+    )
+    assert (
+        credentials.get_model_access_kwargs(config_with_azure.smart_llm)["model"]
+        == "SMART-LLM_ID"
+    )
+
+    # Emulate --gpt4only
+    config_with_azure.fast_llm = smart_llm
+    assert (
+        credentials.get_model_access_kwargs(config_with_azure.fast_llm)["model"]
+        == "SMART-LLM_ID"
+    )
+    assert (
+        credentials.get_model_access_kwargs(config_with_azure.smart_llm)["model"]
+        == "SMART-LLM_ID"
+    )
+
+    # Emulate --gpt3only
+    config_with_azure.fast_llm = config_with_azure.smart_llm = fast_llm
+    assert (
+        credentials.get_model_access_kwargs(config_with_azure.fast_llm)["model"]
+        == "FAST-LLM_ID"
+    )
+    assert (
+        credentials.get_model_access_kwargs(config_with_azure.smart_llm)["model"]
+        == "FAST-LLM_ID"
+    )
+
+
 def test_create_config_gpt4only(config: Config) -> None:
     with mock.patch("autogpt.llm.api_manager.ApiManager.get_models") as mock_get_models:
-        mock_get_models.return_value = [{"id": GPT_4_MODEL}]
+        mock_get_models.return_value = [
+            Model(id=GPT_4_MODEL, created=0, object="model", owned_by="AutoGPT")
+        ]
         apply_overrides_to_config(
             config=config,
             gpt4only=True,
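The assertions above check that `get_model_access_kwargs()` maps a model name to its Azure deployment ID under the `model` key. A minimal sketch of how such kwargs are typically consumed with the v1 `AzureOpenAI` client; the endpoint, key, and deployment name below are placeholders, and the exact wiring inside `OpenAIProvider` is not shown in this hunk:

from openai import AzureOpenAI

client = AzureOpenAI(
    api_version="2023-06-01-preview",
    azure_endpoint="https://dummy.openai.azure.com",  # placeholder endpoint
    api_key="placeholder",
)
# With Azure OpenAI, the deployment ID is passed where a model name would normally go.
response = client.chat.completions.create(
    model="FAST-LLM_ID",  # deployment ID taken from azure_model_map
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)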
@@ -2,8 +2,11 @@ import logging
 import os
 from hashlib import sha256
 
-import openai.api_requestor
 import pytest
+from openai import OpenAI
+from openai._models import FinalRequestOptions
+from openai._types import Omit
+from openai._utils import is_given
 from pytest_mock import MockerFixture
 
 from .vcr_filter import (
@@ -52,30 +55,26 @@ def vcr_cassette_dir(request):
     return os.path.join("tests/vcr_cassettes", test_name)
 
 
-def patch_api_base(requestor: openai.api_requestor.APIRequestor):
-    new_api_base = f"{PROXY}/v1"
-    requestor.api_base = new_api_base
-    return requestor
-
-
 @pytest.fixture
-def patched_api_requestor(mocker: MockerFixture):
-    init_requestor = openai.api_requestor.APIRequestor.__init__
-    prepare_request = openai.api_requestor.APIRequestor._prepare_request_raw
+def cached_openai_client(mocker: MockerFixture) -> OpenAI:
+    client = OpenAI()
+    _prepare_options = client._prepare_options
 
-    def patched_init_requestor(requestor, *args, **kwargs):
-        init_requestor(requestor, *args, **kwargs)
-        patch_api_base(requestor)
+    def _patched_prepare_options(self, options: FinalRequestOptions):
+        _prepare_options(options)
 
-    def patched_prepare_request(self, *args, **kwargs):
-        url, headers, data = prepare_request(self, *args, **kwargs)
+        headers: dict[str, str | Omit] = (
+            {**options.headers} if is_given(options.headers) else {}
+        )
+        options.headers = headers
+        data: dict = options.json_data
 
         if PROXY:
-            headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
-            headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
+            headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit())
+            headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit())
 
-        logging.getLogger("patched_api_requestor").debug(
-            f"Outgoing API request: {headers}\n{data.decode() if data else None}"
+        logging.getLogger("cached_openai_client").debug(
+            f"Outgoing API request: {headers}\n{data if data else None}"
         )
 
         # Add hash header for cheap & fast matching on cassette playback
@@ -83,16 +82,12 @@ def patched_api_requestor(mocker: MockerFixture):
             freeze_request_body(data), usedforsecurity=False
         ).hexdigest()
 
-        return url, headers, data
-
     if PROXY:
-        mocker.patch.object(
-            openai.api_requestor.APIRequestor,
-            "__init__",
-            new=patched_init_requestor,
-        )
+        client.base_url = f"{PROXY}/v1"
     mocker.patch.object(
-        openai.api_requestor.APIRequestor,
-        "_prepare_request_raw",
-        new=patched_prepare_request,
+        client,
+        "_prepare_options",
+        new=_patched_prepare_options,
     )
+
+    return client
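A hypothetical usage sketch of the rewritten fixture: a test that requests `cached_openai_client` receives a v1 `OpenAI` client whose outgoing requests carry the request-body hash header (and the proxy headers when `PROXY` is set), so recorded cassettes can be matched cheaply on playback. The test name and prompt below are illustrative only, not part of this change:

from openai import OpenAI


def test_chat_completion_is_recorded(cached_openai_client: OpenAI):
    # The patched _prepare_options adds the hash header before the request is sent,
    # so the project's VCR setup can match it against an existing cassette.
    response = cached_openai_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    assert response.choices[0].message.content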
@@ -44,14 +44,9 @@ def replace_message_content(content: str, replacements: List[Dict[str, str]]) ->
     return content
 
 
-def freeze_request_body(json_body: str | bytes) -> bytes:
+def freeze_request_body(body: dict) -> bytes:
     """Remove any dynamic items from the request body"""
 
-    try:
-        body = json.loads(json_body)
-    except ValueError:
-        return json_body if type(json_body) is bytes else json_body.encode()
-
     if "messages" not in body:
         return json.dumps(body, sort_keys=True).encode()
 
@@ -74,9 +69,11 @@ def freeze_request(request: Request) -> Request:
 
     with contextlib.suppress(ValueError):
         request.body = freeze_request_body(
-            request.body.getvalue()
-            if isinstance(request.body, BytesIO)
-            else request.body
+            json.loads(
+                request.body.getvalue()
+                if isinstance(request.body, BytesIO)
+                else request.body
+            )
         )
 
     return request
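A small illustration of why the frozen body is serialized with `sort_keys=True`: two request bodies that differ only in key order serialize to identical bytes and therefore to the same hash used for cassette matching. The example body is made up and, lacking a "messages" key, follows the early-return path shown above:

import json
from hashlib import sha256

a = {"model": "text-embedding-ada-002", "input": "hello"}
b = {"input": "hello", "model": "text-embedding-ada-002"}

# Sorting keys makes the serialization order-independent and deterministic.
frozen_a = json.dumps(a, sort_keys=True).encode()
frozen_b = json.dumps(b, sort_keys=True).encode()
assert frozen_a == frozen_b
assert sha256(frozen_a).hexdigest() == sha256(frozen_b).hexdigest()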