AutoGPT/v2: First pass with small fixes

* Typing fixes & improvements

* Improved console output formatting

* Added support for all OpenAI GPT-3.5-turbo and GPT-4 model versions

* Added token counting functions to ModelProviders
Author: Reinier van der Leer
Date: 2023-09-17 16:40:56 +02:00
parent f4d319cee4
commit 11920b8fe5
18 changed files with 449 additions and 149 deletions
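A minimal usage sketch of the token counting added here; the provider construction is an assumption (its full constructor is not shown in this diff), while count_tokens/count_message_tokens and the imports come from this commit:

import logging

from autogpt.core.resource.model_providers import (
    LanguageModelMessage,
    MessageRole,
    OpenAIModelName,
    OpenAIProvider,
)

# Assumed wiring: default_settings without real credentials is enough for counting.
provider = OpenAIProvider(
    settings=OpenAIProvider.default_settings,
    logger=logging.getLogger(__name__),
)

# Count tokens in plain text for a given model...
text_tokens = provider.count_tokens("Hello world", OpenAIModelName.GPT3)

# ...or in chat messages, including per-message framing overhead.
message_tokens = provider.count_message_tokens(
    LanguageModelMessage(role=MessageRole.USER, content="Hello world"),
    OpenAIModelName.GPT4,
)
print(text_tokens, message_tokens)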

View File

@@ -48,7 +48,7 @@ class SimpleAbilityRegistry(AbilityRegistry, Configurable):
         self._memory = memory
         self._workspace = workspace
         self._model_providers = model_providers
-        self._abilities = []
+        self._abilities: list[Ability] = []
         for (
             ability_name,
             ability_configuration,

View File

@@ -2,7 +2,6 @@
 from autogpt.core.planning.schema import (
     LanguageModelClassification,
     LanguageModelPrompt,
-    LanguageModelResponse,
     Task,
     TaskStatus,
     TaskType,

View File

@@ -1,4 +1,5 @@
 import enum
+from typing import Optional

 from pydantic import BaseModel, Field
@@ -6,7 +7,6 @@ from autogpt.core.ability.schema import AbilityResult
 from autogpt.core.resource.model_providers.schema import (
     LanguageModelFunction,
     LanguageModelMessage,
-    LanguageModelProviderModelResponse,
 )
@@ -19,8 +19,8 @@ class LanguageModelClassification(str, enum.Enum):
"""
FAST_MODEL: str = "fast_model"
SMART_MODEL: str = "smart_model"
FAST_MODEL = "fast_model"
SMART_MODEL = "smart_model"
class LanguageModelPrompt(BaseModel):
@@ -28,34 +28,33 @@ class LanguageModelPrompt(BaseModel):
     functions: list[LanguageModelFunction] = Field(default_factory=list)

     def __str__(self):
-        return "\n\n".join([f"{m.role.value}: {m.content}" for m in self.messages])
-
-
-class LanguageModelResponse(LanguageModelProviderModelResponse):
-    """Standard response struct for a response from a language model."""
+        return "\n\n".join(
+            f"{m.role.value.upper()}: {m.content}"
+            for m in self.messages
+        )


 class TaskType(str, enum.Enum):
-    RESEARCH: str = "research"
-    WRITE: str = "write"
-    EDIT: str = "edit"
-    CODE: str = "code"
-    DESIGN: str = "design"
-    TEST: str = "test"
-    PLAN: str = "plan"
+    RESEARCH = "research"
+    WRITE = "write"
+    EDIT = "edit"
+    CODE = "code"
+    DESIGN = "design"
+    TEST = "test"
+    PLAN = "plan"


 class TaskStatus(str, enum.Enum):
-    BACKLOG: str = "backlog"
-    READY: str = "ready"
-    IN_PROGRESS: str = "in_progress"
-    DONE: str = "done"
+    BACKLOG = "backlog"
+    READY = "ready"
+    IN_PROGRESS = "in_progress"
+    DONE = "done"


 class TaskContext(BaseModel):
     cycle_count: int = 0
     status: TaskStatus = TaskStatus.BACKLOG
-    parent: "Task" = None
+    parent: Optional["Task"] = None
     prior_actions: list[AbilityResult] = Field(default_factory=list)
     memories: list = Field(default_factory=list)
     user_input: list[str] = Field(default_factory=list)
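With the reflowed __str__ above, prompts render with upper-cased roles. An illustrative example (message contents invented):

from autogpt.core.planning import LanguageModelPrompt
from autogpt.core.resource.model_providers import LanguageModelMessage, MessageRole

prompt = LanguageModelPrompt(
    messages=[
        LanguageModelMessage(role=MessageRole.SYSTEM, content="You are a planner."),
        LanguageModelMessage(role=MessageRole.USER, content="Plan my week."),
    ]
)
print(prompt)
# SYSTEM: You are a planner.
#
# USER: Plan my week.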

View File

@@ -14,14 +14,15 @@ from autogpt.core.planning import strategies
 from autogpt.core.planning.base import PromptStrategy
 from autogpt.core.planning.schema import (
     LanguageModelClassification,
-    LanguageModelResponse,
     Task,
 )
 from autogpt.core.resource.model_providers import (
     LanguageModelProvider,
+    LanguageModelResponse,
     ModelProviderName,
     OpenAIModelName,
 )
+from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
 from autogpt.core.workspace import Workspace
@@ -153,7 +154,7 @@ class SimplePlanner(Configurable):
         template_kwargs.update(kwargs)
         prompt = prompt_strategy.build_prompt(**template_kwargs)

-        self._logger.debug(f"Using prompt:\n{prompt}\n\n")
+        self._logger.debug(f"Using prompt:\n{dump_prompt(prompt)}\n")
         response = await provider.create_language_completion(
             model_prompt=prompt.messages,
             functions=prompt.functions,

View File

@@ -1,3 +1,5 @@
+import logging
+
 from autogpt.core.configuration import SystemConfiguration, UserConfigurable
 from autogpt.core.planning.base import PromptStrategy
 from autogpt.core.planning.schema import (
@@ -13,6 +15,8 @@ from autogpt.core.resource.model_providers import (
     MessageRole,
 )

+logger = logging.getLogger(__name__)
+

 class InitialPlanConfiguration(SystemConfiguration):
     model_classification: LanguageModelClassification = UserConfigurable()
@@ -98,7 +102,7 @@ class InitialPlan(PromptStrategy):
         },
     }

-    default_configuration = InitialPlanConfiguration(
+    default_configuration: InitialPlanConfiguration = InitialPlanConfiguration(
         model_classification=LanguageModelClassification.SMART_MODEL,
         system_prompt_template=DEFAULT_SYSTEM_PROMPT_TEMPLATE,
         system_info=DEFAULT_SYSTEM_INFO,
@@ -183,8 +187,12 @@
             The parsed response.
         """
-        parsed_response = json_loads(response_content["function_call"]["arguments"])
-        parsed_response["task_list"] = [
-            Task.parse_obj(task) for task in parsed_response["task_list"]
-        ]
+        try:
+            parsed_response = json_loads(response_content["function_call"]["arguments"])
+            parsed_response["task_list"] = [
+                Task.parse_obj(task) for task in parsed_response["task_list"]
+            ]
+        except KeyError:
+            logger.debug(f"Failed to parse this response content: {response_content}")
+            raise
         return parsed_response
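The guarded parsing expects OpenAI's function-calling payload shape; a standalone sketch of both paths (values invented, stdlib json.loads standing in for the module's json_loads helper):

import json

response_content = {
    "function_call": {
        "name": "create_initial_agent_plan",  # illustrative name
        "arguments": '{"task_list": [{"objective": "Draft an outline"}]}',
    }
}
parsed = json.loads(response_content["function_call"]["arguments"])

# A plain-text reply carries no "function_call" key; the resulting KeyError is
# now logged at debug level together with the raw response content, then re-raised.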

View File

@@ -1,3 +1,5 @@
+import logging
+
 from autogpt.core.configuration import SystemConfiguration, UserConfigurable
 from autogpt.core.planning.base import PromptStrategy
 from autogpt.core.planning.schema import (
@@ -11,6 +13,7 @@ from autogpt.core.resource.model_providers import (
     MessageRole,
 )

+logger = logging.getLogger(__name__)

 class NameAndGoalsConfiguration(SystemConfiguration):
     model_classification: LanguageModelClassification = UserConfigurable()
@@ -21,12 +24,16 @@ class NameAndGoalsConfiguration(SystemConfiguration):

 class NameAndGoals(PromptStrategy):
     DEFAULT_SYSTEM_PROMPT = (
-        "Your job is to respond to a user-defined task by invoking the `create_agent` function "
-        "to generate an autonomous agent to complete the task. You should supply a role-based "
-        "name for the agent, an informative description for what the agent does, and 1 to 5 "
-        "goals that are optimally aligned with the successful completion of its assigned task.\n\n"
+        "Your job is to respond to a user-defined task, given in triple quotes, by "
+        "invoking the `create_agent` function to generate an autonomous agent to "
+        "complete the task. "
+        "You should supply a role-based name for the agent, "
+        "an informative description for what the agent does, and "
+        "1 to 5 goals that are optimally aligned with the successful completion of "
+        "its assigned task.\n"
+        "\n"
         "Example Input:\n"
-        "Help me with marketing my business\n\n"
+        '"""Help me with marketing my business"""\n\n'
         "Example Function Call:\n"
         "create_agent(name='CMOGPT', "
         "description='A professional digital marketer AI that assists Solopreneurs in "
@@ -43,7 +50,7 @@ class NameAndGoals(PromptStrategy):
         "remains on track.'])"
     )

-    DEFAULT_USER_PROMPT_TEMPLATE = "'{user_objective}'"
+    DEFAULT_USER_PROMPT_TEMPLATE = '"""{user_objective}"""'

     DEFAULT_CREATE_AGENT_FUNCTION = {
         "name": "create_agent",
@@ -77,7 +84,7 @@ class NameAndGoals(PromptStrategy):
         },
     }

-    default_configuration = NameAndGoalsConfiguration(
+    default_configuration: NameAndGoalsConfiguration = NameAndGoalsConfiguration(
         model_classification=LanguageModelClassification.SMART_MODEL,
         system_prompt=DEFAULT_SYSTEM_PROMPT,
         user_prompt_template=DEFAULT_USER_PROMPT_TEMPLATE,
@@ -135,5 +142,9 @@ class NameAndGoals(PromptStrategy):
             The parsed response.
         """
-        parsed_response = json_loads(response_content["function_call"]["arguments"])
+        try:
+            parsed_response = json_loads(response_content["function_call"]["arguments"])
+        except KeyError:
+            logger.debug(f"Failed to parse this response content: {response_content}")
+            raise
         return parsed_response

View File

@@ -1,3 +1,5 @@
+import logging
+
 from autogpt.core.configuration import SystemConfiguration, UserConfigurable
 from autogpt.core.planning.base import PromptStrategy
 from autogpt.core.planning.schema import (
@@ -12,6 +14,8 @@ from autogpt.core.resource.model_providers import (
     MessageRole,
 )

+logger = logging.getLogger(__name__)
+

 class NextAbilityConfiguration(SystemConfiguration):
     model_classification: LanguageModelClassification = UserConfigurable()
@@ -61,7 +65,7 @@ class NextAbility(PromptStrategy):
         },
     }

-    default_configuration = NextAbilityConfiguration(
+    default_configuration: NextAbilityConfiguration = NextAbilityConfiguration(
         model_classification=LanguageModelClassification.SMART_MODEL,
         system_prompt_template=DEFAULT_SYSTEM_PROMPT_TEMPLATE,
         system_info=DEFAULT_SYSTEM_INFO,
@@ -171,13 +175,17 @@
             The parsed response.
         """
-        function_name = response_content["function_call"]["name"]
-        function_arguments = json_loads(response_content["function_call"]["arguments"])
-        parsed_response = {
-            "motivation": function_arguments.pop("motivation"),
-            "self_criticism": function_arguments.pop("self_criticism"),
-            "reasoning": function_arguments.pop("reasoning"),
-            "next_ability": function_name,
-            "ability_arguments": function_arguments,
-        }
+        try:
+            function_name = response_content["function_call"]["name"]
+            function_arguments = json_loads(response_content["function_call"]["arguments"])
+            parsed_response = {
+                "motivation": function_arguments.pop("motivation"),
+                "self_criticism": function_arguments.pop("self_criticism"),
+                "reasoning": function_arguments.pop("reasoning"),
+                "next_ability": function_name,
+                "ability_arguments": function_arguments,
+            }
+        except KeyError:
+            logger.debug(f"Failed to parse this response content: {response_content}")
+            raise
         return parsed_response

View File

@@ -832,6 +832,103 @@ files = [
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
]
[[package]]
name = "regex"
version = "2023.8.8"
description = "Alternative regular expression module, to replace re."
optional = false
python-versions = ">=3.6"
files = [
{file = "regex-2023.8.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:88900f521c645f784260a8d346e12a1590f79e96403971241e64c3a265c8ecdb"},
{file = "regex-2023.8.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3611576aff55918af2697410ff0293d6071b7e00f4b09e005d614686ac4cd57c"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8a0ccc8f2698f120e9e5742f4b38dc944c38744d4bdfc427616f3a163dd9de5"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c662a4cbdd6280ee56f841f14620787215a171c4e2d1744c9528bed8f5816c96"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cf0633e4a1b667bfe0bb10b5e53fe0d5f34a6243ea2530eb342491f1adf4f739"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551ad543fa19e94943c5b2cebc54c73353ffff08228ee5f3376bd27b3d5b9800"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54de2619f5ea58474f2ac211ceea6b615af2d7e4306220d4f3fe690c91988a61"},
{file = "regex-2023.8.8-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5ec4b3f0aebbbe2fc0134ee30a791af522a92ad9f164858805a77442d7d18570"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3ae646c35cb9f820491760ac62c25b6d6b496757fda2d51be429e0e7b67ae0ab"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ca339088839582d01654e6f83a637a4b8194d0960477b9769d2ff2cfa0fa36d2"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:d9b6627408021452dcd0d2cdf8da0534e19d93d070bfa8b6b4176f99711e7f90"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:bd3366aceedf274f765a3a4bc95d6cd97b130d1dda524d8f25225d14123c01db"},
{file = "regex-2023.8.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7aed90a72fc3654fba9bc4b7f851571dcc368120432ad68b226bd593f3f6c0b7"},
{file = "regex-2023.8.8-cp310-cp310-win32.whl", hash = "sha256:80b80b889cb767cc47f31d2b2f3dec2db8126fbcd0cff31b3925b4dc6609dcdb"},
{file = "regex-2023.8.8-cp310-cp310-win_amd64.whl", hash = "sha256:b82edc98d107cbc7357da7a5a695901b47d6eb0420e587256ba3ad24b80b7d0b"},
{file = "regex-2023.8.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1e7d84d64c84ad97bf06f3c8cb5e48941f135ace28f450d86af6b6512f1c9a71"},
{file = "regex-2023.8.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce0f9fbe7d295f9922c0424a3637b88c6c472b75eafeaff6f910494a1fa719ef"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06c57e14ac723b04458df5956cfb7e2d9caa6e9d353c0b4c7d5d54fcb1325c46"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7a9aaa5a1267125eef22cef3b63484c3241aaec6f48949b366d26c7250e0357"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b7408511fca48a82a119d78a77c2f5eb1b22fe88b0d2450ed0756d194fe7a9a"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14dc6f2d88192a67d708341f3085df6a4f5a0c7b03dec08d763ca2cd86e9f559"},
{file = "regex-2023.8.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48c640b99213643d141550326f34f0502fedb1798adb3c9eb79650b1ecb2f177"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0085da0f6c6393428bf0d9c08d8b1874d805bb55e17cb1dfa5ddb7cfb11140bf"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:964b16dcc10c79a4a2be9f1273fcc2684a9eedb3906439720598029a797b46e6"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7ce606c14bb195b0e5108544b540e2c5faed6843367e4ab3deb5c6aa5e681208"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:40f029d73b10fac448c73d6eb33d57b34607f40116e9f6e9f0d32e9229b147d7"},
{file = "regex-2023.8.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3b8e6ea6be6d64104d8e9afc34c151926f8182f84e7ac290a93925c0db004bfd"},
{file = "regex-2023.8.8-cp311-cp311-win32.whl", hash = "sha256:942f8b1f3b223638b02df7df79140646c03938d488fbfb771824f3d05fc083a8"},
{file = "regex-2023.8.8-cp311-cp311-win_amd64.whl", hash = "sha256:51d8ea2a3a1a8fe4f67de21b8b93757005213e8ac3917567872f2865185fa7fb"},
{file = "regex-2023.8.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e951d1a8e9963ea51efd7f150450803e3b95db5939f994ad3d5edac2b6f6e2b4"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:704f63b774218207b8ccc6c47fcef5340741e5d839d11d606f70af93ee78e4d4"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22283c769a7b01c8ac355d5be0715bf6929b6267619505e289f792b01304d898"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91129ff1bb0619bc1f4ad19485718cc623a2dc433dff95baadbf89405c7f6b57"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de35342190deb7b866ad6ba5cbcccb2d22c0487ee0cbb251efef0843d705f0d4"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b993b6f524d1e274a5062488a43e3f9f8764ee9745ccd8e8193df743dbe5ee61"},
{file = "regex-2023.8.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3026cbcf11d79095a32d9a13bbc572a458727bd5b1ca332df4a79faecd45281c"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:293352710172239bf579c90a9864d0df57340b6fd21272345222fb6371bf82b3"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d909b5a3fff619dc7e48b6b1bedc2f30ec43033ba7af32f936c10839e81b9217"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:3d370ff652323c5307d9c8e4c62efd1956fb08051b0e9210212bc51168b4ff56"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:b076da1ed19dc37788f6a934c60adf97bd02c7eea461b73730513921a85d4235"},
{file = "regex-2023.8.8-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e9941a4ada58f6218694f382e43fdd256e97615db9da135e77359da257a7168b"},
{file = "regex-2023.8.8-cp36-cp36m-win32.whl", hash = "sha256:a8c65c17aed7e15a0c824cdc63a6b104dfc530f6fa8cb6ac51c437af52b481c7"},
{file = "regex-2023.8.8-cp36-cp36m-win_amd64.whl", hash = "sha256:aadf28046e77a72f30dcc1ab185639e8de7f4104b8cb5c6dfa5d8ed860e57236"},
{file = "regex-2023.8.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:423adfa872b4908843ac3e7a30f957f5d5282944b81ca0a3b8a7ccbbfaa06103"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ae594c66f4a7e1ea67232a0846649a7c94c188d6c071ac0210c3e86a5f92109"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e51c80c168074faa793685656c38eb7a06cbad7774c8cbc3ea05552d615393d8"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:09b7f4c66aa9d1522b06e31a54f15581c37286237208df1345108fcf4e050c18"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e73e5243af12d9cd6a9d6a45a43570dbe2e5b1cdfc862f5ae2b031e44dd95a8"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:941460db8fe3bd613db52f05259c9336f5a47ccae7d7def44cc277184030a116"},
{file = "regex-2023.8.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f0ccf3e01afeb412a1a9993049cb160d0352dba635bbca7762b2dc722aa5742a"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2e9216e0d2cdce7dbc9be48cb3eacb962740a09b011a116fd7af8c832ab116ca"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:5cd9cd7170459b9223c5e592ac036e0704bee765706445c353d96f2890e816c8"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:4873ef92e03a4309b3ccd8281454801b291b689f6ad45ef8c3658b6fa761d7ac"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:239c3c2a339d3b3ddd51c2daef10874410917cd2b998f043c13e2084cb191684"},
{file = "regex-2023.8.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1005c60ed7037be0d9dea1f9c53cc42f836188227366370867222bda4c3c6bd7"},
{file = "regex-2023.8.8-cp37-cp37m-win32.whl", hash = "sha256:e6bd1e9b95bc5614a7a9c9c44fde9539cba1c823b43a9f7bc11266446dd568e3"},
{file = "regex-2023.8.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9a96edd79661e93327cfeac4edec72a4046e14550a1d22aa0dd2e3ca52aec921"},
{file = "regex-2023.8.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f2181c20ef18747d5f4a7ea513e09ea03bdd50884a11ce46066bb90fe4213675"},
{file = "regex-2023.8.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a2ad5add903eb7cdde2b7c64aaca405f3957ab34f16594d2b78d53b8b1a6a7d6"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9233ac249b354c54146e392e8a451e465dd2d967fc773690811d3a8c240ac601"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:920974009fb37b20d32afcdf0227a2e707eb83fe418713f7a8b7de038b870d0b"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd2b6c5dfe0929b6c23dde9624483380b170b6e34ed79054ad131b20203a1a63"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96979d753b1dc3b2169003e1854dc67bfc86edf93c01e84757927f810b8c3c93"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ae54a338191e1356253e7883d9d19f8679b6143703086245fb14d1f20196be9"},
{file = "regex-2023.8.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2162ae2eb8b079622176a81b65d486ba50b888271302190870b8cc488587d280"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c884d1a59e69e03b93cf0dfee8794c63d7de0ee8f7ffb76e5f75be8131b6400a"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:cf9273e96f3ee2ac89ffcb17627a78f78e7516b08f94dc435844ae72576a276e"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:83215147121e15d5f3a45d99abeed9cf1fe16869d5c233b08c56cdf75f43a504"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3f7454aa427b8ab9101f3787eb178057c5250478e39b99540cfc2b889c7d0586"},
{file = "regex-2023.8.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f0640913d2c1044d97e30d7c41728195fc37e54d190c5385eacb52115127b882"},
{file = "regex-2023.8.8-cp38-cp38-win32.whl", hash = "sha256:0c59122ceccb905a941fb23b087b8eafc5290bf983ebcb14d2301febcbe199c7"},
{file = "regex-2023.8.8-cp38-cp38-win_amd64.whl", hash = "sha256:c12f6f67495ea05c3d542d119d270007090bad5b843f642d418eb601ec0fa7be"},
{file = "regex-2023.8.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:82cd0a69cd28f6cc3789cc6adeb1027f79526b1ab50b1f6062bbc3a0ccb2dbc3"},
{file = "regex-2023.8.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bb34d1605f96a245fc39790a117ac1bac8de84ab7691637b26ab2c5efb8f228c"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:987b9ac04d0b38ef4f89fbc035e84a7efad9cdd5f1e29024f9289182c8d99e09"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9dd6082f4e2aec9b6a0927202c85bc1b09dcab113f97265127c1dc20e2e32495"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7eb95fe8222932c10d4436e7a6f7c99991e3fdd9f36c949eff16a69246dee2dc"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7098c524ba9f20717a56a8d551d2ed491ea89cbf37e540759ed3b776a4f8d6eb"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b694430b3f00eb02c594ff5a16db30e054c1b9589a043fe9174584c6efa8033"},
{file = "regex-2023.8.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b2aeab3895d778155054abea5238d0eb9a72e9242bd4b43f42fd911ef9a13470"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:988631b9d78b546e284478c2ec15c8a85960e262e247b35ca5eaf7ee22f6050a"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:67ecd894e56a0c6108ec5ab1d8fa8418ec0cff45844a855966b875d1039a2e34"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:14898830f0a0eb67cae2bbbc787c1a7d6e34ecc06fbd39d3af5fe29a4468e2c9"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:f2200e00b62568cfd920127782c61bc1c546062a879cdc741cfcc6976668dfcf"},
{file = "regex-2023.8.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9691a549c19c22d26a4f3b948071e93517bdf86e41b81d8c6ac8a964bb71e5a6"},
{file = "regex-2023.8.8-cp39-cp39-win32.whl", hash = "sha256:6ab2ed84bf0137927846b37e882745a827458689eb969028af8032b1b3dac78e"},
{file = "regex-2023.8.8-cp39-cp39-win_amd64.whl", hash = "sha256:5543c055d8ec7801901e1193a51570643d6a6ab8751b1f7dd9af71af467538bb"},
{file = "regex-2023.8.8.tar.gz", hash = "sha256:fcbdc5f2b0f1cd0f6a56cdb46fe41d2cce1e644e3b68832f3eeebc5fb0f7712e"},
]
[[package]]
name = "requests"
version = "2.31.0"
@@ -881,6 +978,51 @@ anyio = ">=3.4.0,<5"
[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"]
[[package]]
name = "tiktoken"
version = "0.5.1"
description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
optional = false
python-versions = ">=3.8"
files = [
{file = "tiktoken-0.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2b0bae3fd56de1c0a5874fb6577667a3c75bf231a6cef599338820210c16e40a"},
{file = "tiktoken-0.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e529578d017045e2f0ed12d2e00e7e99f780f477234da4aae799ec4afca89f37"},
{file = "tiktoken-0.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edd2ffbb789712d83fee19ab009949f998a35c51ad9f9beb39109357416344ff"},
{file = "tiktoken-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4c73d47bdc1a3f1f66ffa019af0386c48effdc6e8797e5e76875f6388ff72e9"},
{file = "tiktoken-0.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:46b8554b9f351561b1989157c6bb54462056f3d44e43aa4e671367c5d62535fc"},
{file = "tiktoken-0.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:92ed3bbf71a175a6a4e5fbfcdb2c422bdd72d9b20407e00f435cf22a68b4ea9b"},
{file = "tiktoken-0.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:714efb2f4a082635d9f5afe0bf7e62989b72b65ac52f004eb7ac939f506c03a4"},
{file = "tiktoken-0.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a10488d1d1a5f9c9d2b2052fdb4cf807bba545818cb1ef724a7f5d44d9f7c3d4"},
{file = "tiktoken-0.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8079ac065572fe0e7c696dbd63e1fdc12ce4cdca9933935d038689d4732451df"},
{file = "tiktoken-0.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ef730db4097f5b13df8d960f7fdda2744fe21d203ea2bb80c120bb58661b155"},
{file = "tiktoken-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426e7def5f3f23645dada816be119fa61e587dfb4755de250e136b47a045c365"},
{file = "tiktoken-0.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:323cec0031358bc09aa965c2c5c1f9f59baf76e5b17e62dcc06d1bb9bc3a3c7c"},
{file = "tiktoken-0.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5abd9436f02e2c8eda5cce2ff8015ce91f33e782a7423de2a1859f772928f714"},
{file = "tiktoken-0.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:1fe99953b63aabc0c9536fbc91c3c9000d78e4755edc28cc2e10825372046a2d"},
{file = "tiktoken-0.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dcdc630461927718b317e6f8be7707bd0fc768cee1fdc78ddaa1e93f4dc6b2b1"},
{file = "tiktoken-0.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1f2b3b253e22322b7f53a111e1f6d7ecfa199b4f08f3efdeb0480f4033b5cdc6"},
{file = "tiktoken-0.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43ce0199f315776dec3ea7bf86f35df86d24b6fcde1babd3e53c38f17352442f"},
{file = "tiktoken-0.5.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a84657c083d458593c0235926b5c993eec0b586a2508d6a2020556e5347c2f0d"},
{file = "tiktoken-0.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c008375c0f3d97c36e81725308699116cd5804fdac0f9b7afc732056329d2790"},
{file = "tiktoken-0.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:779c4dea5edd1d3178734d144d32231e0b814976bec1ec09636d1003ffe4725f"},
{file = "tiktoken-0.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:b5dcfcf9bfb798e86fbce76d40a1d5d9e3f92131aecfa3d1e5c9ea1a20f1ef1a"},
{file = "tiktoken-0.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9b180a22db0bbcc447f691ffc3cf7a580e9e0587d87379e35e58b826ebf5bc7b"},
{file = "tiktoken-0.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b756a65d98b7cf760617a6b68762a23ab8b6ef79922be5afdb00f5e8a9f4e76"},
{file = "tiktoken-0.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba9873c253ca1f670e662192a0afcb72b41e0ba3e730f16c665099e12f4dac2d"},
{file = "tiktoken-0.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74c90d2be0b4c1a2b3f7dde95cd976757817d4df080d6af0ee8d461568c2e2ad"},
{file = "tiktoken-0.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:709a5220891f2b56caad8327fab86281787704931ed484d9548f65598dea9ce4"},
{file = "tiktoken-0.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5d5a187ff9c786fae6aadd49f47f019ff19e99071dc5b0fe91bfecc94d37c686"},
{file = "tiktoken-0.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:e21840043dbe2e280e99ad41951c00eff8ee3b63daf57cd4c1508a3fd8583ea2"},
{file = "tiktoken-0.5.1.tar.gz", hash = "sha256:27e773564232004f4f810fd1f85236673ec3a56ed7f1206fc9ed8670ebedb97a"},
]
[package.dependencies]
regex = ">=2022.1.18"
requests = ">=2.26.0"
[package.extras]
blobfile = ["blobfile (>=2)"]
[[package]]
name = "tomli"
version = "2.0.1"
@@ -1044,4 +1186,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "f0352811def6ea015f9ea62162354da6ae075102b6dc19893ca3b93c8f1179f8"
+content-hash = "50ad53581d2716ee6927df7200b2522acfaad35aadc76909bdab4073c49a824e"

View File

@@ -28,6 +28,7 @@ inflection = "^0.5.1"
 openai = "^0.28.0"
 pydantic = "^1.10.12"
 pyyaml = "^6.0.0"
+tiktoken = "^0.5.1"

 [build-system]
 requires = ["poetry-core"]

View File

@@ -6,39 +6,48 @@ from autogpt.core.resource.model_providers.openai import (
 )
 from autogpt.core.resource.model_providers.schema import (
     Embedding,
+    EmbeddingModelInfo,
     EmbeddingModelProvider,
-    EmbeddingModelProviderModelInfo,
-    EmbeddingModelProviderModelResponse,
+    EmbeddingModelResponse,
     LanguageModelFunction,
+    LanguageModelInfo,
     LanguageModelMessage,
     LanguageModelProvider,
-    LanguageModelProviderModelInfo,
-    LanguageModelProviderModelResponse,
+    LanguageModelResponse,
     MessageRole,
+    ModelInfo,
     ModelProvider,
     ModelProviderBudget,
     ModelProviderCredentials,
-    ModelProviderModelInfo,
-    ModelProviderModelResponse,
     ModelProviderName,
     ModelProviderService,
     ModelProviderSettings,
     ModelProviderUsage,
+    ModelResponse,
 )

 __all__ = [
-    "ModelProvider",
-    "ModelProviderName",
-    "ModelProviderSettings",
     "Embedding",
+    "EmbeddingModelInfo",
     "EmbeddingModelProvider",
-    "EmbeddingModelProviderModelResponse",
-    "LanguageModelProvider",
-    "LanguageModelProviderModelResponse",
+    "EmbeddingModelResponse",
+    "LanguageModelFunction",
+    "LanguageModelInfo",
+    "LanguageModelMessage",
+    "LanguageModelProvider",
+    "LanguageModelResponse",
     "MessageRole",
-    "OpenAIModelName",
+    "ModelInfo",
+    "ModelProvider",
+    "ModelProviderBudget",
+    "ModelProviderCredentials",
+    "ModelProviderName",
+    "ModelProviderService",
+    "ModelProviderSettings",
+    "ModelProviderUsage",
+    "ModelResponse",
+    "OPEN_AI_MODELS",
+    "OpenAIModelName",
     "OpenAIProvider",
     "OpenAISettings",
 ]

View File

@@ -6,6 +6,7 @@ import time
 from typing import Callable, ParamSpec, TypeVar

 import openai
+import tiktoken
 from openai.error import APIError, RateLimitError

 from autogpt.core.configuration import (
@@ -16,13 +17,13 @@ from autogpt.core.configuration import (
 from autogpt.core.resource.model_providers.schema import (
     Embedding,
     EmbeddingModelProvider,
-    EmbeddingModelProviderModelInfo,
-    EmbeddingModelProviderModelResponse,
+    EmbeddingModelInfo,
+    EmbeddingModelResponse,
     LanguageModelFunction,
     LanguageModelMessage,
     LanguageModelProvider,
-    LanguageModelProviderModelInfo,
-    LanguageModelProviderModelResponse,
+    LanguageModelInfo,
+    LanguageModelResponse,
     ModelProviderBudget,
     ModelProviderCredentials,
     ModelProviderName,
@@ -37,19 +38,31 @@ OpenAIChatParser = Callable[[str], dict]
 class OpenAIModelName(str, enum.Enum):
     ADA = "text-embedding-ada-002"
-    GPT3 = "gpt-3.5-turbo-0613"
-    GPT3_16K = "gpt-3.5-turbo-16k-0613"
-    GPT4 = "gpt-4-0613"
-    GPT4_32K = "gpt-4-32k-0613"
+    GPT3_v1 = "gpt-3.5-turbo-0301"
+    GPT3_v2 = "gpt-3.5-turbo-0613"
+    GPT3_v2_16k = "gpt-3.5-turbo-16k-0613"
+    GPT3_ROLLING = "gpt-3.5-turbo"
+    GPT3_ROLLING_16k = "gpt-3.5-turbo-16k"
+    GPT3 = GPT3_ROLLING
+    GPT3_16k = GPT3_ROLLING_16k
+    GPT4_v1 = "gpt-4-0314"
+    GPT4_v1_32k = "gpt-4-32k-0314"
+    GPT4_v2 = "gpt-4-0613"
+    GPT4_v2_32k = "gpt-4-32k-0613"
+    GPT4_ROLLING = "gpt-4"
+    GPT4_ROLLING_32k = "gpt-4-32k"
+    GPT4 = GPT4_ROLLING
+    GPT4_32k = GPT4_ROLLING_32k


 OPEN_AI_EMBEDDING_MODELS = {
-    OpenAIModelName.ADA: EmbeddingModelProviderModelInfo(
+    OpenAIModelName.ADA: EmbeddingModelInfo(
         name=OpenAIModelName.ADA,
         service=ModelProviderService.EMBEDDING,
         provider_name=ModelProviderName.OPENAI,
-        prompt_token_cost=0.0004,
-        completion_token_cost=0.0,
+        prompt_token_cost=0.0001/1000,
         max_tokens=8191,
         embedding_dimensions=1536,
     ),
@@ -57,39 +70,54 @@ OPEN_AI_EMBEDDING_MODELS = {
 OPEN_AI_LANGUAGE_MODELS = {
-    OpenAIModelName.GPT3: LanguageModelProviderModelInfo(
-        name=OpenAIModelName.GPT3,
-        service=ModelProviderService.LANGUAGE,
-        provider_name=ModelProviderName.OPENAI,
-        prompt_token_cost=0.0015,
-        completion_token_cost=0.002,
-        max_tokens=4096,
-    ),
-    OpenAIModelName.GPT3_16K: LanguageModelProviderModelInfo(
-        name=OpenAIModelName.GPT3,
-        service=ModelProviderService.LANGUAGE,
-        provider_name=ModelProviderName.OPENAI,
-        prompt_token_cost=0.003,
-        completion_token_cost=0.002,
-        max_tokens=16384,
-    ),
-    OpenAIModelName.GPT4: LanguageModelProviderModelInfo(
-        name=OpenAIModelName.GPT4,
-        service=ModelProviderService.LANGUAGE,
-        provider_name=ModelProviderName.OPENAI,
-        prompt_token_cost=0.03,
-        completion_token_cost=0.06,
-        max_tokens=8192,
-    ),
-    OpenAIModelName.GPT4_32K: LanguageModelProviderModelInfo(
-        name=OpenAIModelName.GPT4_32K,
-        service=ModelProviderService.LANGUAGE,
-        provider_name=ModelProviderName.OPENAI,
-        prompt_token_cost=0.06,
-        completion_token_cost=0.12,
-        max_tokens=32768,
-    ),
+    info.name: info
+    for info in [
+        LanguageModelInfo(
+            name=OpenAIModelName.GPT3,
+            service=ModelProviderService.LANGUAGE,
+            provider_name=ModelProviderName.OPENAI,
+            prompt_token_cost=0.0015/1000,
+            completion_token_cost=0.002/1000,
+            max_tokens=4096,
+        ),
+        LanguageModelInfo(
+            name=OpenAIModelName.GPT3_16k,
+            service=ModelProviderService.LANGUAGE,
+            provider_name=ModelProviderName.OPENAI,
+            prompt_token_cost=0.003/1000,
+            completion_token_cost=0.004/1000,
+            max_tokens=16384,
+        ),
+        LanguageModelInfo(
+            name=OpenAIModelName.GPT4,
+            service=ModelProviderService.LANGUAGE,
+            provider_name=ModelProviderName.OPENAI,
+            prompt_token_cost=0.03/1000,
+            completion_token_cost=0.06/1000,
+            max_tokens=8192,
+        ),
+        LanguageModelInfo(
+            name=OpenAIModelName.GPT4_32k,
+            service=ModelProviderService.LANGUAGE,
+            provider_name=ModelProviderName.OPENAI,
+            prompt_token_cost=0.06/1000,
+            completion_token_cost=0.12/1000,
+            max_tokens=32768,
+        ),
+    ]
 }
+# Copy entries for equivalent models
+chat_model_mapping = {
+    OpenAIModelName.GPT3: [OpenAIModelName.GPT3_v1, OpenAIModelName.GPT3_v2],
+    OpenAIModelName.GPT3_16k: [OpenAIModelName.GPT3_v2_16k],
+    OpenAIModelName.GPT4: [OpenAIModelName.GPT4_v1, OpenAIModelName.GPT4_v2],
+    OpenAIModelName.GPT4_32k: [OpenAIModelName.GPT4_v1_32k, OpenAIModelName.GPT4_v2_32k],
+}
+for base, copies in chat_model_mapping.items():
+    for copy in copies:
+        copy_info = LanguageModelInfo(**OPEN_AI_LANGUAGE_MODELS[base].__dict__)
+        copy_info.name = copy
+        OPEN_AI_LANGUAGE_MODELS[copy] = copy_info


 OPEN_AI_MODELS = {
@@ -113,11 +141,7 @@ class OpenAISettings(ModelProviderSettings):
     budget: OpenAIModelProviderBudget


-class OpenAIProvider(
-    Configurable,
-    LanguageModelProvider,
-    EmbeddingModelProvider,
-):
+class OpenAIProvider(Configurable, LanguageModelProvider, EmbeddingModelProvider):
     default_settings = OpenAISettings(
         name="openai_provider",
         description="Provides access to OpenAI's API.",
@@ -139,6 +163,8 @@ class OpenAIProvider(
         ),
     )

+    logger = logging.getLogger("model_providers.OpenAIProvider")
+
     def __init__(
         self,
         settings: OpenAISettings,
@@ -166,14 +192,60 @@ class OpenAIProvider(
"""Get the remaining budget."""
return self._budget.remaining_budget
def count_tokens(self, text: str, model_name: OpenAIModelName) -> int:
encoding = tiktoken.encoding_for_model(model_name)
return len(encoding.encode(text))
def count_message_tokens(
self,
messages: LanguageModelMessage | list[LanguageModelMessage],
model_name: OpenAIModelName,
) -> int:
if isinstance(messages, LanguageModelMessage):
messages = [messages]
if model_name.startswith("gpt-3.5-turbo"):
tokens_per_message = (
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
)
tokens_per_name = -1 # if there's a name, the role is omitted
encoding_model = "gpt-3.5-turbo"
elif model_name.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
encoding_model = "gpt-4"
else:
raise NotImplementedError(
f"count_message_tokens() is not implemented for model {model_name}.\n"
" See https://github.com/openai/openai-python/blob/main/chatml.md for"
" information on how messages are converted to tokens."
)
try:
encoding = tiktoken.encoding_for_model(encoding_model)
except KeyError:
self.logger.warn(
f"Model {model_name} not found. Defaulting to cl100k_base encoding."
)
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.dict().items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
async def create_language_completion(
self,
model_prompt: list[LanguageModelMessage],
functions: list[LanguageModelFunction],
model_name: OpenAIModelName,
completion_parser: Callable[[dict], dict],
functions: list[LanguageModelFunction] = [],
**kwargs,
) -> LanguageModelProviderModelResponse:
) -> LanguageModelResponse:
"""Create a completion using the OpenAI API."""
completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
response = await self._create_completion(
@@ -189,7 +261,7 @@ class OpenAIProvider(
parsed_response = completion_parser(
response.choices[0].message.to_dict_recursive()
)
response = LanguageModelProviderModelResponse(
response = LanguageModelResponse(
content=parsed_response, **response_args
)
self._budget.update_usage_and_cost(response)
@@ -201,7 +273,7 @@ class OpenAIProvider(
model_name: OpenAIModelName,
embedding_parser: Callable[[Embedding], Embedding],
**kwargs,
) -> EmbeddingModelProviderModelResponse:
) -> EmbeddingModelResponse:
"""Create an embedding using the OpenAI API."""
embedding_kwargs = self._get_embedding_kwargs(model_name, **kwargs)
response = await self._create_embedding(text=text, **embedding_kwargs)
@@ -211,7 +283,7 @@ class OpenAIProvider(
"prompt_tokens_used": response.usage.prompt_tokens,
"completion_tokens_used": response.usage.completion_tokens,
}
response = EmbeddingModelProviderModelResponse(
response = EmbeddingModelResponse(
**response_args,
embedding=embedding_parser(response.embeddings[0]),
)
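A quick sanity check of the restructured model table (a sketch, not part of the commit): dated snapshots are copies of their base entry thanks to the alias-copying loop above, and prices are now stored per token rather than per 1k tokens.

from autogpt.core.resource.model_providers.openai import (
    OPEN_AI_LANGUAGE_MODELS,
    OpenAIModelName,
)

info = OPEN_AI_LANGUAGE_MODELS[OpenAIModelName.GPT3_v2]  # "gpt-3.5-turbo-0613"
assert info.max_tokens == 4096
assert info.prompt_token_cost == 0.0015 / 1000  # per token

# The dated snapshot shares everything but its name with the rolling alias:
assert OPEN_AI_LANGUAGE_MODELS[OpenAIModelName.GPT3].max_tokens == info.max_tokens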

View File

@@ -18,13 +18,13 @@ from autogpt.core.resource.schema import (
 class ModelProviderService(str, enum.Enum):
     """A ModelService describes what kind of service the model provides."""

-    EMBEDDING: str = "embedding"
-    LANGUAGE: str = "language"
-    TEXT: str = "text"
+    EMBEDDING = "embedding"
+    LANGUAGE = "language"
+    TEXT = "text"


 class ModelProviderName(str, enum.Enum):
-    OPENAI: str = "openai"
+    OPENAI = "openai"


 class MessageRole(str, enum.Enum):
@@ -42,7 +42,7 @@ class LanguageModelFunction(BaseModel):
     json_schema: dict


-class ModelProviderModelInfo(BaseModel):
+class ModelInfo(BaseModel):
     """Struct for model information.

     Would be lovely to eventually get this directly from APIs, but needs to be
@@ -57,12 +57,12 @@ class ModelProviderModelInfo(BaseModel):
     completion_token_cost: float = 0.0


-class ModelProviderModelResponse(BaseModel):
+class ModelResponse(BaseModel):
     """Standard response struct for a response from a model."""

     prompt_tokens_used: int
     completion_tokens_used: int
-    model_info: ModelProviderModelInfo
+    model_info: ModelInfo


 class ModelProviderCredentials(ProviderCredentials):
@@ -101,7 +101,7 @@ class ModelProviderUsage(ProviderUsage):
     def update_usage(
         self,
-        model_response: ModelProviderModelResponse,
+        model_response: ModelResponse,
     ) -> None:
         self.completion_tokens += model_response.completion_tokens_used
         self.prompt_tokens += model_response.prompt_tokens_used
@@ -118,7 +118,7 @@ class ModelProviderBudget(ProviderBudget):
     def update_usage_and_cost(
         self,
-        model_response: ModelProviderModelResponse,
+        model_response: ModelResponse,
     ) -> None:
         """Update the usage and cost of the provider."""
         model_info = model_response.model_info
@@ -126,7 +126,7 @@ class ModelProviderBudget(ProviderBudget):
         incremental_cost = (
             model_response.completion_tokens_used * model_info.completion_token_cost
             + model_response.prompt_tokens_used * model_info.prompt_token_cost
-        ) / 1000.0
+        )
         self.total_cost += incremental_cost
         self.remaining_budget -= incremental_cost
@@ -142,6 +142,10 @@ class ModelProvider(abc.ABC):
     defaults: ClassVar[ModelProviderSettings]

+    @abc.abstractmethod
+    def count_tokens(self, text: str, model_name: str) -> int:
+        ...
+
     @abc.abstractmethod
     def get_token_limit(self, model_name: str) -> int:
         ...
@@ -156,14 +160,15 @@ class ModelProvider(abc.ABC):
 ####################


-class EmbeddingModelProviderModelInfo(ModelProviderModelInfo):
+class EmbeddingModelInfo(ModelInfo):
     """Struct for embedding model information."""

-    llm_service: ModelProviderService = ModelProviderService.EMBEDDING
+    llm_service = ModelProviderService.EMBEDDING
     max_tokens: int
     embedding_dimensions: int


-class EmbeddingModelProviderModelResponse(ModelProviderModelResponse):
+class EmbeddingModelResponse(ModelResponse):
     """Standard response struct for a response from an embedding model."""

     embedding: Embedding = Field(default_factory=list)
@@ -184,7 +189,7 @@ class EmbeddingModelProvider(ModelProvider):
         model_name: str,
         embedding_parser: Callable[[Embedding], Embedding],
         **kwargs,
-    ) -> EmbeddingModelProviderModelResponse:
+    ) -> EmbeddingModelResponse:
         ...
@@ -193,20 +198,28 @@ class EmbeddingModelProvider(ModelProvider):
 ###################


-class LanguageModelProviderModelInfo(ModelProviderModelInfo):
+class LanguageModelInfo(ModelInfo):
     """Struct for language model information."""

-    llm_service: ModelProviderService = ModelProviderService.LANGUAGE
+    llm_service = ModelProviderService.LANGUAGE
     max_tokens: int


-class LanguageModelProviderModelResponse(ModelProviderModelResponse):
+class LanguageModelResponse(ModelResponse):
     """Standard response struct for a response from a language model."""

     content: dict = None


 class LanguageModelProvider(ModelProvider):
+    @abc.abstractmethod
+    def count_message_tokens(
+        self,
+        messages: LanguageModelMessage | list[LanguageModelMessage],
+        model_name: str,
+    ) -> int:
+        ...
+
     @abc.abstractmethod
     async def create_language_completion(
         self,
@@ -215,5 +228,5 @@ class LanguageModelProvider(ModelProvider):
         model_name: str,
         completion_parser: Callable[[dict], dict],
         **kwargs,
-    ) -> LanguageModelProviderModelResponse:
+    ) -> LanguageModelResponse:
         ...
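Dropping the / 1000.0 in ModelProviderBudget is correct because the model tables now store per-token prices; e.g. a GPT-4 call with 1,000 prompt tokens and 500 completion tokens (numbers illustrative):

prompt_token_cost = 0.03 / 1000      # GPT-4 prompt price per token
completion_token_cost = 0.06 / 1000  # GPT-4 completion price per token

incremental_cost = (
    500 * completion_token_cost
    + 1000 * prompt_token_cost
)
print(f"${incremental_cost:.3f}")  # $0.060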

View File

@@ -44,20 +44,20 @@ async def run_auto_gpt(user_configuration: dict):
         agent_settings,
         client_logger,
     )
-    print(parse_agent_name_and_goals(name_and_goals))
+    print("\n" + parse_agent_name_and_goals(name_and_goals))

     # Finally, update the agent settings with the name and goals.
     agent_settings.update_agent_name_and_goals(name_and_goals)

     # Step 3. Provision the agent.
     agent_workspace = SimpleAgent.provision_agent(agent_settings, client_logger)
-    print("agent is provisioned")
+    client_logger.info("Agent is provisioned")

     # launch agent interaction loop
     agent = SimpleAgent.from_workspace(
         agent_workspace,
         client_logger,
     )
-    print("agent is loaded")
+    client_logger.info("Agent is loaded")

     plan = await agent.build_initial_plan()
     print(parse_agent_plan(plan))

View File

@@ -0,0 +1,22 @@
import logging

from .config import configure_root_logger, FancyConsoleFormatter, BelowLevelFilter
from .helpers import dump_prompt


def get_client_logger():
    # Configure logging before we do anything else.
    # Application logs need a place to live.
    client_logger = logging.getLogger("autogpt_client_application")
    client_logger.setLevel(logging.DEBUG)

    return client_logger


__all__ = [
    "configure_root_logger",
    "get_client_logger",
    "FancyConsoleFormatter",
    "BelowLevelFilter",
    "dump_prompt",
]
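The client entry point can pull both helpers from this one package; a small usage sketch (the log message is invented):

from autogpt.core.runner.client_lib.logging import (
    configure_root_logger,
    get_client_logger,
)

configure_root_logger()
client_logger = get_client_logger()
client_logger.debug("Client logging is live.")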

View File

@@ -24,16 +24,7 @@ def configure_root_logger():
     logging.basicConfig(level=logging.DEBUG, handlers=[stdout, stderr])

     # Disable debug logging from OpenAI library
-    openai_logger.setLevel(logging.INFO)
-
-
-def get_client_logger():
-    # Configure logging before we do anything else.
-    # Application logs need a place to live.
-    client_logger = logging.getLogger("autogpt_client_application")
-    client_logger.setLevel(logging.DEBUG)
-
-    return client_logger
+    openai_logger.setLevel(logging.WARNING)


 class FancyConsoleFormatter(logging.Formatter):

View File

@@ -0,0 +1,21 @@
from math import ceil, floor

from autogpt.core.planning import LanguageModelPrompt

SEPARATOR_LENGTH = 42


def dump_prompt(prompt: LanguageModelPrompt) -> str:
    def separator(text: str):
        half_sep_len = (SEPARATOR_LENGTH - 2 - len(text)) / 2
        return f"{floor(half_sep_len)*'-'} {text.upper()} {ceil(half_sep_len)*'-'}"

    formatted_messages = "\n".join(
        [f"{separator(m.role)}\n{m.content}" for m in prompt.messages]
    )
    return f"""
============== {prompt.__class__.__name__} ==============
Length: {len(prompt.messages)} messages
{formatted_messages}
==========================================
"""

View File

@@ -8,15 +8,15 @@ def parse_agent_name_and_goals(name_and_goals: dict) -> str:
 def parse_agent_plan(plan: dict) -> str:
-    parsed_response = f"Agent Plan:\n"
+    parsed_response = "Agent Plan:\n"
     for i, task in enumerate(plan["task_list"]):
         parsed_response += f"{i+1}. {task['objective']}\n"
         parsed_response += f"Task type: {task['type']} "
         parsed_response += f"Priority: {task['priority']}\n"
-        parsed_response += f"Ready Criteria:\n"
+        parsed_response += "Ready Criteria:\n"
         for j, criteria in enumerate(task["ready_criteria"]):
             parsed_response += f"  {j+1}. {criteria}\n"
-        parsed_response += f"Acceptance Criteria:\n"
+        parsed_response += "Acceptance Criteria:\n"
         for j, criteria in enumerate(task["acceptance_criteria"]):
             parsed_response += f"  {j+1}. {criteria}\n"
         parsed_response += "\n"

View File

@@ -21,7 +21,10 @@ class AutoGptFormatter(FancyConsoleFormatter):
         # Determine color for title
         title = getattr(record, "title", "")
-        title_color = getattr(record, "title_color", "") or self.LEVEL_COLOR_MAP.get(record.levelno, "")
+        title_color = (
+            getattr(record, "title_color", "")
+            or self.LEVEL_COLOR_MAP.get(record.levelno, "")
+        )
         if title and title_color:
             title = f"{title_color + Style.BRIGHT}{title}{Style.RESET_ALL}"

         # Make sure record.title is set, and padded with a space if not empty
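The title/title_color pair is read off the LogRecord, so callers supply them through logging's `extra` mechanism; a sketch, assuming a handler with AutoGptFormatter is attached (the message and color choice are illustrative):

import logging

from colorama import Fore

logger = logging.getLogger(__name__)
logger.info(
    "Agent created",
    extra={"title": "SETUP", "title_color": Fore.CYAN},
)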