chore: use primitives instead of typing imports and fixes completion … (#149)

Signed-off-by: Adrian Cole <adrian.cole@elastic.co>
This commit is contained in:
Adrian Cole
2024-10-16 09:41:37 +11:00
committed by GitHub
parent e687b0b3bc
commit c247c8eb30
53 changed files with 235 additions and 257 deletions

View File

@@ -1,5 +1,4 @@
from copy import deepcopy
from typing import List
from attrs import define, field
@@ -31,7 +30,7 @@ class CheckpointData:
total_token_count: int = field(default=0)
# in order list of individual checkpoints in the exchange
checkpoints: List[Checkpoint] = field(factory=list)
checkpoints: list[Checkpoint] = field(factory=list)
# the offset to apply to the message index when calculating the last message index
# this is useful because messages on the exchange behave like a queue, where you can only

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Optional
from attrs import define, asdict
@@ -7,11 +7,11 @@ CONTENT_TYPES = {}
class Content:
def __init_subclass__(cls, **kwargs: Dict[str, Any]) -> None:
def __init_subclass__(cls, **kwargs: dict[str, any]) -> None:
super().__init_subclass__(**kwargs)
CONTENT_TYPES[cls.__name__] = cls
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, any]:
data = asdict(self, recurse=True)
data["type"] = self.__class__.__name__
return data
@@ -26,7 +26,7 @@ class Text(Content):
class ToolUse(Content):
id: str
name: str
parameters: Any
parameters: any
is_error: bool = False
error_message: Optional[str] = None

View File

@@ -1,8 +1,7 @@
import json
import traceback
from copy import deepcopy
from typing import Any, Dict, List, Mapping, Tuple
from typing import Mapping
from attrs import define, evolve, field, Factory
from tiktoken import get_encoding
@@ -41,8 +40,8 @@ class Exchange:
model: str
system: str
moderator: Moderator = field(default=ContextTruncate())
tools: Tuple[Tool] = field(factory=tuple, converter=tuple)
messages: List[Message] = field(factory=list)
tools: tuple[Tool, ...] = field(factory=tuple, converter=tuple)
messages: list[Message] = field(factory=list)
checkpoint_data: CheckpointData = field(factory=CheckpointData)
generation_args: dict = field(default=Factory(dict))
@@ -50,7 +49,7 @@ class Exchange:
def _toolmap(self) -> Mapping[str, Tool]:
return {tool.name: tool for tool in self.tools}
def replace(self, **kwargs: Dict[str, Any]) -> "Exchange":
def replace(self, **kwargs: dict[str, any]) -> "Exchange":
"""Make a copy of the exchange, replacing any passed arguments"""
# TODO: ensure that the checkpoint data is updated correctly. aka,
# if we replace the messages, we need to update the checkpoint data
@@ -264,7 +263,7 @@ class Exchange:
# we've removed all the checkpoints, so we need to reset the message index offset
self.checkpoint_data.message_index_offset = 0
def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
def pop_last_checkpoint(self) -> tuple[Checkpoint, list[Message]]:
"""
Reverts the exchange back to the last checkpoint, removing associated messages
"""
@@ -275,7 +274,7 @@ class Exchange:
messages.append(self.messages.pop())
return removed_checkpoint, messages
def pop_first_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
def pop_first_checkpoint(self) -> tuple[Checkpoint, list[Message]]:
"""
Pop the first checkpoint from the exchange, removing associated messages
"""
@@ -332,5 +331,6 @@ class Exchange:
# this to be a required method of the provider instead.
return len(self.messages) > 0 and self.messages[-1].role == "user"
def get_token_usage(self) -> Dict[str, Usage]:
@staticmethod
def get_token_usage() -> dict[str, Usage]:
return _token_usage_collector.get_token_usage_group_by_model()

View File

@@ -1,8 +1,5 @@
from typing import List
class InvalidChoiceError(Exception):
def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None:
def __init__(self, attribute_name: str, attribute_value: str, available_values: list[str]) -> None:
self.attribute_name = attribute_name
self.attribute_value = attribute_value
self.available_values = available_values

View File

@@ -1,7 +1,7 @@
import inspect
import time
from pathlib import Path
from typing import Any, Dict, List, Literal, Type
from typing import Literal
from attrs import define, field
from jinja2 import Environment, FileSystemLoader
@@ -12,7 +12,7 @@ from exchange.utils import create_object_id
Role = Literal["user", "assistant"]
def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: ANN401
def validate_role_and_content(instance: "Message", *_: any) -> None: # noqa: ANN401
if instance.role == "user":
if not (instance.text or instance.tool_result):
raise ValueError("User message must include a Text or ToolResult")
@@ -25,7 +25,7 @@ def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: AN
raise ValueError("Assistant message does not support ToolResult")
def content_converter(contents: List[Dict[str, Any]]) -> List[Content]:
def content_converter(contents: list[dict[str, any]]) -> list[Content]:
return [(CONTENT_TYPES[c.pop("type")](**c) if c.__class__ not in CONTENT_TYPES.values() else c) for c in contents]
@@ -48,9 +48,9 @@ class Message:
role: Role = field(default="user")
id: str = field(factory=lambda: str(create_object_id(prefix="msg")))
created: int = field(factory=lambda: int(time.time()))
content: List[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter)
content: list[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, any]:
return {
"role": self.role,
"id": self.id,
@@ -68,7 +68,7 @@ class Message:
return "\n".join(result)
@property
def tool_use(self) -> List[ToolUse]:
def tool_use(self) -> list[ToolUse]:
"""All tool use content of this message."""
result = []
for content in self.content:
@@ -77,7 +77,7 @@ class Message:
return result
@property
def tool_result(self) -> List[ToolResult]:
def tool_result(self) -> list[ToolResult]:
"""All tool result content of this message."""
result = []
for content in self.content:
@@ -87,10 +87,10 @@ class Message:
@classmethod
def load(
cls: Type["Message"],
cls: type["Message"],
filename: str,
role: Role = "user",
**kwargs: Dict[str, Any],
**kwargs: dict[str, any],
) -> "Message":
"""Load the message from filename relative to where the load is called.
@@ -113,9 +113,9 @@ class Message:
return cls(role=role, content=[Text(text=rendered_content)])
@classmethod
def user(cls: Type["Message"], text: str) -> "Message":
def user(cls: type["Message"], text: str) -> "Message":
return cls(role="user", content=[Text(text)])
@classmethod
def assistant(cls: Type["Message"], text: str) -> "Message":
def assistant(cls: type["Message"], text: str) -> "Message":
return cls(role="assistant", content=[Text(text)])

View File

@@ -1,5 +1,4 @@
from functools import cache
from typing import Type
from exchange.invalid_choice_error import InvalidChoiceError
from exchange.moderators.base import Moderator
@@ -10,7 +9,7 @@ from exchange.moderators.summarizer import ContextSummarizer # noqa
@cache
def get_moderator(name: str) -> Type[Moderator]:
def get_moderator(name: str) -> type[Moderator]:
moderators = load_plugins(group="exchange.moderator")
if name not in moderators:
raise InvalidChoiceError("moderator", name, moderators.keys())

View File

@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from typing import Type
class Moderator(ABC):
@abstractmethod
def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
pass

View File

@@ -1,7 +1,6 @@
from typing import Type
from exchange.moderators.base import Moderator
class PassiveModerator(Moderator):
def rewrite(self, _: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
def rewrite(self, _: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
pass

View File

@@ -1,12 +1,10 @@
from typing import Type
from exchange import Message
from exchange.checkpoint import CheckpointData
from exchange.moderators import ContextTruncate, PassiveModerator
class ContextSummarizer(ContextTruncate):
def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
"""Summarize the context history up to the last few messages in the exchange"""
self._update_system_prompt_token_count(exchange)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from exchange.checkpoint import CheckpointData
from exchange.message import Message
@@ -62,7 +62,7 @@ class ContextTruncate(Moderator):
exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count
exchange.checkpoint_data.total_token_count += self.system_prompt_token_count
def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]:
def _get_messages_to_remove(self, exchange: Exchange) -> list[Message]:
# this keeps all the messages/checkpoints
throwaway_exchange = exchange.replace(
moderator=PassiveModerator(),

View File

@@ -1,5 +1,4 @@
from functools import cache
from typing import Type
from exchange.invalid_choice_error import InvalidChoiceError
from exchange.providers.anthropic import AnthropicProvider # noqa
@@ -15,7 +14,7 @@ from exchange.utils import load_plugins
@cache
def get_provider(name: str) -> Type[Provider]:
def get_provider(name: str) -> type[Provider]:
providers = load_plugins(group="exchange.provider")
if name not in providers:
raise InvalidChoiceError("provider", name, providers.keys())

View File

@@ -1,5 +1,4 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
@@ -29,7 +28,7 @@ class AnthropicProvider(Provider):
self.client = client
@classmethod
def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
def from_env(cls: type["AnthropicProvider"]) -> "AnthropicProvider":
cls.check_env_vars()
url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
key = os.environ.get("ANTHROPIC_API_KEY")
@@ -45,7 +44,7 @@ class AnthropicProvider(Provider):
return cls(client)
@staticmethod
def get_usage(data: Dict) -> Usage: # noqa: ANN401
def get_usage(data: dict) -> Usage: # noqa: ANN401
usage = data.get("usage")
input_tokens = usage.get("input_tokens")
output_tokens = usage.get("output_tokens")
@@ -61,7 +60,7 @@ class AnthropicProvider(Provider):
)
@staticmethod
def anthropic_response_to_message(response: Dict) -> Message:
def anthropic_response_to_message(response: dict) -> Message:
content_blocks = response.get("content", [])
content = []
for block in content_blocks:
@@ -78,7 +77,7 @@ class AnthropicProvider(Provider):
return Message(role="assistant", content=content)
@staticmethod
def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]:
def tools_to_anthropic_spec(tools: tuple[Tool, ...]) -> list[dict[str, any]]:
return [
{
"name": tool.name,
@@ -89,7 +88,7 @@ class AnthropicProvider(Provider):
]
@staticmethod
def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]:
def messages_to_anthropic_spec(messages: list[Message]) -> list[dict[str, any]]:
messages_spec = []
# if messages is empty - just make a default
for message in messages:
@@ -127,10 +126,12 @@ class AnthropicProvider(Provider):
self,
model: str,
system: str,
messages: List[Message],
tools: List[Tool] = [],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: list[Tool] = None,
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
if tools is None:
tools = []
tools_set = set()
unique_tools = []
for tool in tools:

View File

@@ -1,5 +1,3 @@
from typing import Type
import httpx
import os
@@ -21,7 +19,7 @@ class AzureProvider(OpenAiProvider):
super().__init__(client)
@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
def from_env(cls: type["AzureProvider"]) -> "AzureProvider":
cls.check_env_vars()
url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME")
deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")

View File

@@ -1,7 +1,7 @@
import os
from abc import ABC, abstractmethod
from attrs import define, field
from typing import List, Optional, Tuple, Type
from typing import Optional
from exchange.message import Message
from exchange.tool import Tool
@@ -19,11 +19,11 @@ class Provider(ABC):
REQUIRED_ENV_VARS: list[str] = []
@classmethod
def from_env(cls: Type["Provider"]) -> "Provider":
def from_env(cls: type["Provider"]) -> "Provider":
return cls()
@classmethod
def check_env_vars(cls: Type["Provider"], instructions_url: Optional[str] = None) -> None:
def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None:
for env_var in cls.REQUIRED_ENV_VARS:
if env_var not in os.environ:
raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url)
@@ -33,9 +33,10 @@ class Provider(ABC):
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: tuple[Tool, ...],
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
"""Generate the next message using the specified model"""
pass

View File

@@ -4,7 +4,7 @@ import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Optional
from urllib.parse import quote, urlparse
import httpx
@@ -36,7 +36,7 @@ class AwsClient(httpx.Client):
aws_access_key: str,
aws_secret_key: str,
aws_session_token: Optional[str] = None,
**kwargs: Dict[str, Any],
**kwargs: dict[str, any],
) -> None:
self.region = aws_region
self.host = f"https://{SERVICE}.{aws_region}.amazonaws.com/"
@@ -45,7 +45,7 @@ class AwsClient(httpx.Client):
self.session_token = aws_session_token
super().__init__(base_url=self.host, timeout=600, **kwargs)
def post(self, path: str, json: Dict, **kwargs: Dict[str, Any]) -> httpx.Response:
def post(self, path: str, json: dict, **kwargs: dict[str, any]) -> httpx.Response:
signed_headers = self.sign_and_get_headers(
method="POST",
url=path,
@@ -60,7 +60,7 @@ class AwsClient(httpx.Client):
url: str,
payload: dict,
service: str,
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Sign the request and generate the necessary headers for AWS authentication.
@@ -72,10 +72,10 @@ class AwsClient(httpx.Client):
region (str): The AWS region.
access_key (str): The AWS access key.
secret_key (str): The AWS secret key.
session_token (Optional[str]): The AWS session token, if any.
session_token (Optional[str]): The AWS session token, if any.
Returns:
Dict[str, str]: The headers required for the request.
dict[str, str]: The headers required for the request.
"""
def sign(key: bytes, msg: str) -> bytes:
@@ -160,7 +160,7 @@ class BedrockProvider(Provider):
self.client = client
@classmethod
def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
def from_env(cls: type["BedrockProvider"]) -> "BedrockProvider":
cls.check_env_vars()
aws_region = os.environ.get("AWS_REGION", "us-east-1")
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
@@ -179,22 +179,22 @@ class BedrockProvider(Provider):
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: tuple[Tool, ...],
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
"""
Generate a completion response from the Bedrock gateway.
Args:
model (str): The model identifier.
system (str): The system prompt or configuration.
messages (List[Message]): A list of messages to be processed by the model.
tools (Tuple[Tool]): A tuple of tools to be used in the completion process.
messages (list[Message]): A list of messages to be processed by the model.
tools (tuple[Tool, ...]): A tuple of tools to be used in the completion process.
**kwargs: Additional keyword arguments for inference configuration.
Returns:
Tuple[Message, Usage]: A tuple containing the response message and usage data.
tuple[Message, Usage]: A tuple containing the response message and usage data.
"""
inference_config = dict(
@@ -231,7 +231,7 @@ class BedrockProvider(Provider):
return self.response_to_message(response_message), usage
@retry_procedure
def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401
def _post(self, payload: any, path: str) -> dict: # noqa: ANN401
response = self.client.post(path, json=payload)
return raise_for_status(response).json()
@@ -311,7 +311,7 @@ class BedrockProvider(Provider):
raise Exception("Invalid response")
@staticmethod
def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]:
def tools_to_bedrock_spec(tools: tuple[Tool, ...]) -> Optional[dict]:
if len(tools) == 0:
return None # API requires a non-empty tool config or None
tools_added = set()

View File

@@ -1,5 +1,3 @@
from typing import Any, Dict, List, Tuple, Type
import httpx
import os
@@ -43,7 +41,7 @@ class DatabricksProvider(Provider):
self.client = client
@classmethod
def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
def from_env(cls: type["DatabricksProvider"]) -> "DatabricksProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("DATABRICKS_HOST")
key = os.environ.get("DATABRICKS_TOKEN")
@@ -73,10 +71,10 @@ class DatabricksProvider(Provider):
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: tuple[Tool, ...],
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
payload = dict(
messages=[
{"role": "system", "content": system},

View File

@@ -1,5 +1,4 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
@@ -30,7 +29,7 @@ class GoogleProvider(Provider):
self.client = client
@classmethod
def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
def from_env(cls: type["GoogleProvider"]) -> "GoogleProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
key = os.environ.get("GOOGLE_API_KEY")
@@ -45,7 +44,7 @@ class GoogleProvider(Provider):
return cls(client)
@staticmethod
def get_usage(data: Dict) -> Usage: # noqa: ANN401
def get_usage(data: dict) -> Usage: # noqa: ANN401
usage = data.get("usageMetadata")
input_tokens = usage.get("promptTokenCount")
output_tokens = usage.get("candidatesTokenCount")
@@ -61,7 +60,7 @@ class GoogleProvider(Provider):
)
@staticmethod
def google_response_to_message(response: Dict) -> Message:
def google_response_to_message(response: dict) -> Message:
candidates = response.get("candidates", [])
if candidates:
# Only use first candidate for now
@@ -85,12 +84,12 @@ class GoogleProvider(Provider):
return Message(role="assistant", content=[])
@staticmethod
def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]:
def tools_to_google_spec(tools: tuple[Tool, ...]) -> dict[str, list[dict[str, any]]]:
if not tools:
return {}
converted_tools = []
for tool in tools:
converted_tool: Dict[str, Any] = {
converted_tool: dict[str, any] = {
"name": tool.name,
"description": tool.description or "",
}
@@ -100,7 +99,7 @@ class GoogleProvider(Provider):
return {"functionDeclarations": converted_tools}
@staticmethod
def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]:
def messages_to_google_spec(messages: list[Message]) -> list[dict[str, any]]:
messages_spec = []
for message in messages:
role = "user" if message.role == "user" else "model"
@@ -136,10 +135,10 @@ class GoogleProvider(Provider):
self,
model: str,
system: str,
messages: List[Message],
tools: List[Tool] = [],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: list[Tool] = None,
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
tools_set = set()
unique_tools = []
for tool in tools:

View File

@@ -1,5 +1,4 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
@@ -37,7 +36,7 @@ class GroqProvider(Provider):
self.client = client
@classmethod
def from_env(cls: Type["GroqProvider"]) -> "GroqProvider":
def from_env(cls: type["GroqProvider"]) -> "GroqProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("GROQ_HOST", GROQ_HOST)
key = os.environ.get("GROQ_API_KEY")
@@ -69,10 +68,10 @@ class GroqProvider(Provider):
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: tuple[Tool, ...],
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
system_message = [{"role": "system", "content": system}]
payload = dict(
messages=system_message + messages_to_openai_spec(messages),

View File

@@ -1,5 +1,4 @@
import os
from typing import Type
import httpx
@@ -31,7 +30,7 @@ ollama:
super().__init__(client)
@classmethod
def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
def from_env(cls: type["OllamaProvider"]) -> "OllamaProvider":
ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST)
timeout = httpx.Timeout(60 * 10)

View File

@@ -1,5 +1,4 @@
import os
from typing import Any, Dict, List, Tuple, Type
import httpx
@@ -37,7 +36,7 @@ class OpenAiProvider(Provider):
self.client = client
@classmethod
def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
def from_env(cls: type["OpenAiProvider"]) -> "OpenAiProvider":
cls.check_env_vars(cls.instructions_url)
url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
key = os.environ.get("OPENAI_API_KEY")
@@ -69,10 +68,10 @@ class OpenAiProvider(Provider):
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
messages: list[Message],
tools: tuple[Tool, ...],
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}]
payload = dict(
messages=system_message + messages_to_openai_spec(messages),

View File

@@ -1,7 +1,7 @@
import base64
import json
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Optional
import httpx
from exchange.content import Text, ToolResult, ToolUse
@@ -10,10 +10,10 @@ from exchange.tool import Tool
from tenacity import retry_if_exception
def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable:
def retry_if_status(codes: Optional[list[int]] = None, above: Optional[int] = None) -> callable:
codes = codes or []
def predicate(exc: Exception) -> bool:
def predicate(exc: BaseException) -> bool:
if isinstance(exc, httpx.HTTPStatusError):
if exc.response.status_code in codes:
return True
@@ -42,7 +42,7 @@ def encode_image(image_path: str) -> str:
return base64.b64encode(image_file.read()).decode("utf-8")
def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]:
def messages_to_openai_spec(messages: list[Message]) -> list[dict[str, any]]:
messages_spec = []
for message in messages:
converted = {"role": message.role}
@@ -106,7 +106,7 @@ def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]:
return messages_spec
def tools_to_openai_spec(tools: Tuple[Tool]) -> Dict[str, Any]:
def tools_to_openai_spec(tools: tuple[Tool, ...]) -> dict[str, any]:
tools_names = set()
result = []
for tool in tools:

View File

@@ -1,5 +1,4 @@
from collections import defaultdict
from typing import Dict
from exchange.providers.base import Usage
@@ -11,7 +10,7 @@ class _TokenUsageCollector:
def collect(self, model: str, usage: Usage) -> None:
self.usage_data.append((model, usage))
def get_token_usage_group_by_model(self) -> Dict[str, Usage]:
def get_token_usage_group_by_model(self) -> dict[str, Usage]:
usage_group_by_model = defaultdict(lambda: Usage(0, 0, 0))
for model, usage in self.usage_data:
usage_by_model = usage_group_by_model[model]

View File

@@ -1,5 +1,4 @@
import inspect
from typing import Any, Callable, Type
from attrs import define
@@ -13,17 +12,17 @@ class Tool:
Attributes:
name (str): The name of the tool
description (str): A description of what the tool does
parameters dict[str, Any]: A json schema of the function signature
parameters (dict[str, Any]): A json schema of the function signature
function (Callable): The python function that powers the tool
"""
name: str
description: str
parameters: dict[str, Any]
function: Callable
parameters: dict[str, any]
function: callable
@classmethod
def from_function(cls: Type["Tool"], func: Any) -> "Tool": # noqa: ANN401
def from_function(cls: type["Tool"], func: any) -> "Tool": # noqa: ANN401
"""Create a tool instance from a function and its docstring
The function must have a docstring - we require it to load the description

View File

@@ -1,7 +1,7 @@
import inspect
import uuid
from importlib.metadata import entry_points
from typing import Any, Callable, Dict, List, Type, get_args, get_origin
from typing import get_args, get_origin
from griffe import (
Docstring,
@@ -20,7 +20,7 @@ def compact(content: str) -> str:
return " ".join(content.split())
def parse_docstring(func: Callable) -> tuple[str, List[Dict]]:
def parse_docstring(func: callable) -> tuple[str, list[dict]]:
"""Get description and parameters from function docstring"""
function_args = list(inspect.signature(func).parameters.keys())
text = str(func.__doc__)
@@ -71,7 +71,7 @@ def parse_docstring(func: Callable) -> tuple[str, List[Dict]]:
def _check_section_is_present(
parsed_docstring: List[DocstringSection], section_type: Type[DocstringSectionText]
parsed_docstring: list[DocstringSection], section_type: type[DocstringSectionText]
) -> bool:
for section in parsed_docstring:
if isinstance(section, section_type):
@@ -79,7 +79,7 @@ def _check_section_is_present(
return False
def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401
def json_schema(func: any) -> dict[str, any]: # noqa: ANN401
"""Get the json schema for a function"""
signature = inspect.signature(func)
parameters = signature.parameters
@@ -107,16 +107,16 @@ def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401
return schema
def _map_type_to_schema(py_type: Type) -> Dict[str, Any]: # noqa: ANN401
def _map_type_to_schema(py_type: type) -> dict[str, any]: # noqa: ANN401
origin = get_origin(py_type)
args = get_args(py_type)
if origin is list or origin is tuple:
return {"type": "array", "items": _map_type_to_schema(args[0] if args else Any)}
return {"type": "array", "items": _map_type_to_schema(args[0] if args else any)}
elif origin is dict:
return {
"type": "object",
"additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else Any),
"additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else any),
}
elif py_type is int:
return {"type": "integer"}

View File

@@ -1,7 +1,6 @@
import json
import os
import re
from typing import Type, Tuple
import pytest
import yaml
@@ -189,14 +188,14 @@ def scrub_response_headers(response):
return response
def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
def complete(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)
def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
def tools(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant. Expect to need to read a file using read_file."
messages = [Message.user("What are the contents of this file? test.txt")]
@@ -205,7 +204,7 @@ def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message,
)
def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
def vision(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [

View File

@@ -14,10 +14,10 @@ AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")
@pytest.mark.parametrize(
"env_var_name",
[
("AZURE_CHAT_COMPLETIONS_HOST_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"),
("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"),
("AZURE_CHAT_COMPLETIONS_KEY"),
"AZURE_CHAT_COMPLETIONS_HOST_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION",
"AZURE_CHAT_COMPLETIONS_KEY",
],
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):

View File

@@ -15,9 +15,9 @@ logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"env_var_name",
[
("AWS_ACCESS_KEY_ID"),
("AWS_SECRET_ACCESS_KEY"),
("AWS_SESSION_TOKEN"),
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
],
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):

View File

@@ -10,8 +10,8 @@ from exchange.providers.databricks import DatabricksProvider
@pytest.mark.parametrize(
"env_var_name",
[
("DATABRICKS_HOST"),
("DATABRICKS_TOKEN"),
"DATABRICKS_HOST",
"DATABRICKS_TOKEN",
],
)
def test_from_env_throw_error_when_missing_env_var(env_var_name):

View File

@@ -107,9 +107,9 @@ def test_messages_to_openai_spec() -> None:
Message(role="user", content=[Text("How are you?")]),
Message(
role="assistant",
content=[ToolUse(id=1, name="tool1", parameters={"param1": "value1"})],
content=[ToolUse(id="1", name="tool1", parameters={"param1": "value1"})],
),
Message(role="user", content=[ToolResult(tool_use_id=1, output="Result")]),
Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]),
]
spec = messages_to_openai_spec(messages)
@@ -121,7 +121,7 @@ def test_messages_to_openai_spec() -> None:
"role": "assistant",
"tool_calls": [
{
"id": 1,
"id": "1",
"type": "function",
"function": {
"name": "tool1",
@@ -133,7 +133,7 @@ def test_messages_to_openai_spec() -> None:
{
"role": "tool",
"content": "Result",
"tool_call_id": 1,
"tool_call_id": "1",
},
]
@@ -216,7 +216,7 @@ def test_openai_response_to_message_valid_tooluse() -> None:
expect = asdict(
Message(
role="assistant",
content=[ToolUse(id=1, name="example_fn", parameters={"param": "value"})],
content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})],
)
)
actual.pop("id")

View File

@@ -1,5 +1,3 @@
from typing import List, Tuple
import pytest
from exchange.checkpoint import Checkpoint, CheckpointData
@@ -29,12 +27,12 @@ def no_overlapping_checkpoints(exchange: Exchange) -> bool:
return True
def checkpoint_to_index_pairs(checkpoints: List[Checkpoint]) -> List[Tuple[int, int]]:
def checkpoint_to_index_pairs(checkpoints: list[Checkpoint]) -> list[tuple[int, int]]:
return [(checkpoint.start_index, checkpoint.end_index) for checkpoint in checkpoints]
class MockProvider(Provider):
def __init__(self, sequence: List[Message], usage_dicts: List[dict]):
def __init__(self, sequence: list[Message], usage_dicts: list[dict]):
# We'll use init to provide a preplanned reply sequence
self.sequence = sequence
self.call_count = 0
@@ -56,11 +54,18 @@ class MockProvider(Provider):
total_tokens=total_tokens,
)
def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Message:
def complete(
self,
model: str,
system: str,
messages: list[Message],
tools: tuple[Tool, ...],
**kwargs: dict[str, any],
) -> tuple[Message, Usage]:
output = self.sequence[self.call_count]
usage = self.get_usage(self.usage_dicts[self.call_count])
self.call_count += 1
return (output, usage)
return output, usage
def test_reply_with_unsupported_tool():
@@ -116,7 +121,7 @@ def test_invalid_tool_parameters():
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(dummy_tool)],
tools=(Tool.from_function(dummy_tool),),
moderator=PassiveModerator(),
)
@@ -154,7 +159,7 @@ def test_max_tool_use_when_limit_reached():
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(dummy_tool)],
tools=(Tool.from_function(dummy_tool),),
moderator=PassiveModerator(),
)
@@ -195,7 +200,7 @@ def test_tool_output_too_long_character_error():
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(long_output_tool_char)],
tools=(Tool.from_function(long_output_tool_char),),
moderator=PassiveModerator(),
)
@@ -236,7 +241,7 @@ def test_tool_output_too_long_token_error():
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(long_output_tool_token)],
tools=(Tool.from_function(long_output_tool_token),),
moderator=PassiveModerator(),
)
@@ -301,7 +306,7 @@ def resumed_exchange() -> Exchange:
ex = Exchange(
provider=provider,
messages=messages,
tools=[],
tools=(),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
checkpoint_data=CheckpointData(),
@@ -399,7 +404,7 @@ def test_pop_first_message_no_messages():
provider=MockProvider(sequence=[], usage_dicts=[]),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(dummy_tool)],
tools=(Tool.from_function(dummy_tool),),
moderator=PassiveModerator(),
)
@@ -741,7 +746,7 @@ def test_rewind_with_tool_usage():
),
model="gpt-4o-2024-05-13",
system="You are a helpful assistant.",
tools=[Tool.from_function(dummy_tool)],
tools=(Tool.from_function(dummy_tool),),
moderator=PassiveModerator(),
)
ex.add(Message(role="user", content=[Text(text="test")]))

View File

@@ -9,7 +9,7 @@ from exchange.tool import Tool
class MockProvider(Provider):
def complete(self, model, system, messages, tools=None):
def complete(self, model, system, messages, tools, **kwargs):
return Message(role="assistant", content=[Text(text="This is a mock response.")]), Usage.from_dict(
{"total_tokens": 35}
)

View File

@@ -3,11 +3,11 @@ from exchange import Exchange, Message
from exchange.content import ToolResult, ToolUse
from exchange.moderators.passive import PassiveModerator
from exchange.moderators.summarizer import ContextSummarizer
from exchange.providers import Usage
from exchange.providers import Usage, Provider
class MockProvider:
def complete(self, model, system, messages, tools):
class MockProvider(Provider):
def complete(self, model, system, messages, tools, **kwargs):
assistant_message_text = "Summarized content here."
output_tokens = len(assistant_message_text)
total_input_tokens = sum(len(msg.text) for msg in messages)
@@ -138,14 +138,14 @@ MESSAGE_SEQUENCE = [
]
class AnotherMockProvider:
class AnotherMockProvider(Provider):
def __init__(self):
self.sequence = MESSAGE_SEQUENCE
self.current_index = 1
self.summarize_next = False
self.summarized_count = 0
def complete(self, model, system, messages, tools):
def complete(self, model, system, messages, tools, **kwargs):
system_prompt_tokens = 100
input_token_count = system_prompt_tokens

View File

@@ -73,7 +73,7 @@ class TruncateLinearProvider(Provider):
self.summarize_next = False
self.summarized_count = 0
def complete(self, model, system, messages, tools):
def complete(self, model, system, messages, tools, **kwargs):
input_token_count = SYSTEM_PROMPT_TOKENS
message = self.sequence[self.current_index]

View File

@@ -1,6 +1,6 @@
from functools import cache
from pathlib import Path
from typing import Callable, Dict, Mapping, Optional, Tuple
from typing import Mapping, Optional
from rich import print
from rich.panel import Panel
@@ -20,7 +20,7 @@ RECOMMENDED_DEFAULT_PROVIDER = "openai"
@cache
def default_profiles() -> Mapping[str, Callable]:
def default_profiles() -> Mapping[str, callable]:
return load_plugins(group="goose.profile")
@@ -29,7 +29,7 @@ def session_path(name: str) -> Path:
return SESSIONS_PATH.joinpath(f"{name}{SESSION_FILE_SUFFIX}")
def write_config(profiles: Dict[str, Profile]) -> None:
def write_config(profiles: dict[str, Profile]) -> None:
"""Overwrite the config with the passed profiles"""
PROFILES_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
converted = {name: profile.to_dict() for name, profile in profiles.items()}
@@ -38,7 +38,7 @@ def write_config(profiles: Dict[str, Profile]) -> None:
yaml.dump(converted, f)
def ensure_config(name: Optional[str]) -> Tuple[str, Profile]:
def ensure_config(name: Optional[str]) -> tuple[str, Profile]:
"""Ensure that the config exists and has the default section"""
# TODO we should copy a templated default config in to better document
# but this is complicated a bit by autodetecting the provider
@@ -70,7 +70,7 @@ def ensure_config(name: Optional[str]) -> Tuple[str, Profile]:
return (name, default_profile)
def read_config() -> Dict[str, Profile]:
def read_config() -> dict[str, Profile]:
"""Return config from the configuration file and validates its contents"""
yaml = YAML()
@@ -80,7 +80,7 @@ def read_config() -> Dict[str, Profile]:
return {name: Profile(**profile) for name, profile in data.items()}
def default_model_configuration() -> Tuple[str, str, str]:
def default_model_configuration() -> tuple[str, str, str]:
providers = load_plugins(group="exchange.provider")
for provider, cls in providers.items():
try:

View File

@@ -1,5 +1,4 @@
import re
from typing import List
from prompt_toolkit.completion import CompleteEvent, Completer, Completion
from prompt_toolkit.document import Document
@@ -8,10 +7,10 @@ from goose.command.base import Command
class GoosePromptCompleter(Completer):
def __init__(self, commands: List[Command]) -> None:
def __init__(self, commands: list[Command]) -> None:
self.commands = commands
def get_command_completions(self, document: Document) -> List[Completion]:
def get_command_completions(self, document: Document) -> list[Completion]:
all_completions = []
for command_name, command_instance in self.commands.items():
pattern = rf"(?<!\S)\/{command_name}:([\S]*)$"
@@ -25,7 +24,7 @@ class GoosePromptCompleter(Completer):
all_completions.extend(completions)
return all_completions
def get_command_name_completions(self, document: Document) -> List[Completion]:
def get_command_name_completions(self, document: Document) -> list[Completion]:
pattern = r"(?<!\S)\/([\S]*)$"
text = document.text_before_cursor
match = re.search(pattern=pattern, string=text)
@@ -40,7 +39,7 @@ class GoosePromptCompleter(Completer):
completions.append(Completion(command_name, start_position=-len(query), display=command_name))
return completions
def get_completions(self, document: Document, _: CompleteEvent) -> List[Completion]:
def get_completions(self, document: Document, _: CompleteEvent) -> list[Completion]:
command_completions = self.get_command_completions(document)
command_name_completions = self.get_command_name_completions(document)
return command_name_completions + command_completions

View File

@@ -1,5 +1,5 @@
import re
from typing import Callable, List, Tuple
from typing import Callable
from prompt_toolkit.document import Document
from prompt_toolkit.lexers import Lexer
@@ -27,7 +27,7 @@ def value_for_command(command_string: str) -> re.Pattern[str]:
class PromptLexer(Lexer):
def __init__(self, command_names: List[str]) -> None:
def __init__(self, command_names: list[str]) -> None:
self.patterns = []
for command_name in command_names:
self.patterns.append((completion_for_command(command_name), "class:command"))
@@ -35,7 +35,7 @@ class PromptLexer(Lexer):
self.patterns.append((command_itself(command_name), "class:command"))
def lex_document(self, document: Document) -> Callable[[int], list]:
def get_line_tokens(line_number: int) -> Tuple[str, str]:
def get_line_tokens(line_number: int) -> tuple[str, str]:
line = document.lines[line_number]
tokens = []

View File

@@ -1,10 +1,8 @@
from typing import Any
from rich.prompt import Prompt
class OverwriteSessionPrompt(Prompt):
def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
def __init__(self, *args: tuple[any], **kwargs: dict[str, any]) -> None:
super().__init__(*args, **kwargs)
self.choices = {
"yes": "Overwrite the existing session",

View File

@@ -1,6 +1,6 @@
import traceback
from pathlib import Path
from typing import Any, Optional
from typing import Optional
from exchange import Message, Text, ToolResult, ToolUse
from rich import print
@@ -62,7 +62,7 @@ class Session:
profile: Optional[str] = None,
plan: Optional[dict] = None,
log_level: Optional[str] = "INFO",
**kwargs: dict[str, Any],
**kwargs: dict[str, any],
) -> None:
if name is None:
self.name = droid()

View File

@@ -1,5 +1,4 @@
from functools import cache
from typing import Dict
from goose.command.base import Command
from goose.utils import load_plugins
@@ -11,5 +10,5 @@ def get_command(name: str) -> type[Command]:
@cache
def get_commands() -> Dict[str, type[Command]]:
def get_commands() -> dict[str, type[Command]]:
return load_plugins(group="goose.command")

View File

@@ -1,5 +1,5 @@
from abc import ABC
from typing import List, Optional
from typing import Optional
from prompt_toolkit.completion import Completion
@@ -7,7 +7,7 @@ from prompt_toolkit.completion import Completion
class Command(ABC):
"""A command that can be executed by the CLI."""
def get_completions(self, query: str) -> List[Completion]:
def get_completions(self, query: str) -> list[Completion]:
"""
Get completions for the command.

View File

@@ -1,5 +1,4 @@
import os
from typing import List
from prompt_toolkit.completion import Completion
@@ -7,7 +6,7 @@ from goose.command.base import Command
class FileCommand(Command):
def get_completions(self, query: str) -> List[Completion]:
def get_completions(self, query: str) -> list[Completion]:
if query.startswith("/"):
directory = os.path.dirname(query)
search_term = os.path.basename(query)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Mapping, Type
from typing import Mapping
from attrs import asdict, define, field
@@ -21,10 +21,10 @@ class Profile:
processor: str
accelerator: str
moderator: str
toolkits: List[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec))
toolkits: list[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec))
@toolkits.validator
def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[ToolkitSpec]) -> None:
def check_toolkit_requirements(self, _: type["ToolkitSpec"], toolkits: list[ToolkitSpec]) -> None:
# checks that the list of toolkits in the profile have their requirements
installed_toolkits = set([toolkit.name for toolkit in toolkits])
@@ -36,7 +36,7 @@ class Profile:
msg = f"Toolkit {toolkit_name} requires {req} but it is not present"
raise ValueError(msg)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, any]:
return asdict(self)
def profile_info(self) -> str:
@@ -44,7 +44,7 @@ class Profile:
return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}"
def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile:
def default_profile(provider: str, processor: str, accelerator: str, **kwargs: dict[str, any]) -> Profile:
"""Get the default profile"""
# TODO consider if the providers should have recommended models

View File

@@ -1,6 +1,6 @@
import inspect
from abc import ABC
from typing import Callable, Mapping, Optional, Tuple, TypeVar
from typing import Mapping, Optional, TypeVar
from attrs import define, field
from exchange import Tool
@@ -8,7 +8,7 @@ from exchange import Tool
from goose.notifier import Notifier
# Create a type variable that can represent any function signature
F = TypeVar("F", bound=Callable)
F = TypeVar("F", bound=callable)
def tool(func: F) -> F:
@@ -55,7 +55,7 @@ class Toolkit(ABC):
"""Get the addition to the system prompt for this toolkit."""
return ""
def tools(self) -> Tuple[Tool, ...]:
def tools(self) -> tuple[Tool, ...]:
"""Get the tools for this toolkit
This default method looks for functions on the toolkit annotated

View File

@@ -3,7 +3,6 @@ import re
import subprocess
import time
from pathlib import Path
from typing import Dict, List
from exchange import Message
from goose.toolkit.base import Toolkit, tool
@@ -35,9 +34,9 @@ class Developer(Toolkit):
We also include some default shell strategies in the prompt, such as using ripgrep
"""
def __init__(self, *args: object, **kwargs: Dict[str, object]) -> None:
def __init__(self, *args: object, **kwargs: dict[str, object]) -> None:
super().__init__(*args, **kwargs)
self.timestamps: Dict[str, float] = {}
self.timestamps: dict[str, float] = {}
def system(self) -> str:
"""Retrieve system configuration details for developer"""
@@ -55,7 +54,7 @@ class Developer(Toolkit):
return system_prompt
@tool
def update_plan(self, tasks: List[dict]) -> List[dict]:
def update_plan(self, tasks: list[dict]) -> list[dict]:
"""
Update the plan by overwriting all current tasks
@@ -63,7 +62,7 @@ class Developer(Toolkit):
shown to the user directly, you do not need to reiterate it
Args:
tasks (List(dict)): The list of tasks, where each task is a dictionary
tasks (list(dict)): The list of tasks, where each task is a dictionary
with a key for the task "description" and the task "status". The status
MUST be one of "planned", "complete", "failed", "in-progress".

View File

@@ -1,7 +1,6 @@
import os
from functools import cache
from subprocess import CompletedProcess, run
from typing import Dict, Tuple
from exchange import Message
@@ -21,7 +20,7 @@ class RepoContext(Toolkit):
self.repo_project_root, self.is_git_repo, self.goose_session_root = self.determine_git_proj()
def determine_git_proj(self) -> Tuple[str, bool, str]:
def determine_git_proj(self) -> tuple[str, bool, str]:
"""Determines the root as well as where Goose is currently running
If the project is not part of a git repo, the root of the project will be defined as the current working
@@ -72,11 +71,11 @@ class RepoContext(Toolkit):
return self.repo_size > 2000
@tool
def summarize_current_project(self) -> Dict[str, str]:
def summarize_current_project(self) -> dict[str, str]:
"""Summarizes the current project based on repo root (if git repo) or current project_directory (if not)
Returns:
summary (Dict[str, str]): Keys are file paths and values are the summaries
summary (dict[str, str]): Keys are file paths and values are the summaries
"""
self.notifier.log("Summarizing the most relevant files in the current project. This may take a while...")

View File

@@ -2,7 +2,6 @@ import ast
import concurrent.futures
import os
from collections import deque
from typing import Dict, List, Tuple
from exchange import Exchange
@@ -26,7 +25,7 @@ def get_repo_size(repo_path: str) -> int:
return get_directory_size(git_dir) / (1024**2)
def get_files_and_directories(root_dir: str) -> Dict[str, list]:
def get_files_and_directories(root_dir: str) -> dict[str, list]:
"""Gets file names and directory names. Checks that goose has correctly typed the file and directory names and that
the files actually exist (to avoid downstream file read errors).
@@ -61,7 +60,7 @@ def get_files_and_directories(root_dir: str) -> Dict[str, list]:
return {"files": files, "directories": dirs}
def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> List[str]:
def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> list[str]:
"""Lets goose pick files in a BFS manner"""
queue = deque([root])
@@ -80,7 +79,7 @@ def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> Li
return all_files
def process_directory(current_dir: str, exchange: Exchange) -> Tuple[List[str], List[str]]:
def process_directory(current_dir: str, exchange: Exchange) -> tuple[list[str], list[str]]:
"""Allows goose to pick files and subdirectories contained in a given directory (current_dir). Get the list of file
and directory names in the current folder, then ask Goose to pick which ones to keep.

View File

@@ -1,5 +1,5 @@
import os
from typing import List, Optional
from typing import Optional
from goose.toolkit import Toolkit
from goose.toolkit.base import tool
@@ -11,7 +11,7 @@ class SummarizeProject(Toolkit):
def get_project_summary(
self,
project_dir_path: Optional[str] = os.getcwd(),
extensions: Optional[List[str]] = None,
extensions: Optional[list[str]] = None,
summary_instructions_prompt: Optional[str] = None,
) -> dict:
"""Generates or retrieves a project summary based on specified file extensions.
@@ -19,7 +19,7 @@ class SummarizeProject(Toolkit):
Args:
project_dir_path (Optional[Path]): Path to the project directory. Defaults to the current working directory
if None
extensions (Optional[List[str]]): Specific file extensions to summarize.
extensions (Optional[list[str]]): Specific file extensions to summarize.
summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g.
"Summarize the file in two sentences.". The default instruction is "Please summarize this file."

View File

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional
from goose.toolkit import Toolkit
from goose.toolkit.base import tool
@@ -10,7 +10,7 @@ class SummarizeRepo(Toolkit):
def summarize_repo(
self,
repo_url: str,
specified_extensions: Optional[List[str]] = None,
specified_extensions: Optional[list[str]] = None,
summary_instructions_prompt: Optional[str] = None,
) -> dict:
"""
@@ -19,7 +19,7 @@ class SummarizeRepo(Toolkit):
Args:
repo_url (str): The URL of the repository to summarize.
specified_extensions (Optional[List[str]]): List of file extensions to summarize, e.g., ["tf", "md"]. If
specified_extensions (Optional[list[str]]): List of file extensions to summarize, e.g., ["tf", "md"]. If
this list is empty, then all files in the repo are summarized
summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g.
"Summarize the file in two sentences.". The default instruction is "Please summarize this file."

View File

@@ -2,7 +2,7 @@ import json
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Optional
from exchange import Exchange
from exchange.providers.utils import InitialMessageTooLargeError
@@ -15,7 +15,7 @@ CLONED_REPOS_FOLDER = ".goose/cloned_repos"
# TODO: move git stuff
def run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]:
def run_git_command(command: list[str]) -> subprocess.CompletedProcess[str]:
result = subprocess.run(["git"] + command, capture_output=True, text=True, check=False)
if result.returncode != 0:
@@ -28,7 +28,7 @@ def clone_repo(repo_url: str, target_directory: str) -> None:
run_git_command(["clone", repo_url, target_directory])
def load_summary_file_if_exists(project_name: str) -> Optional[Dict]:
def load_summary_file_if_exists(project_name: str) -> Optional[dict]:
"""Checks if a summary file exists at '.goose/summaries/projectname-summary.json. Returns contents of the file if
it exists, otherwise returns None
@@ -36,7 +36,7 @@ def load_summary_file_if_exists(project_name: str) -> Optional[Dict]:
project_name (str): name of the project or repo
Returns:
Optional[Dict]: File contents, else None
Optional[dict]: File contents, else None
"""
summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json"
if Path(summary_file_path).exists():
@@ -44,7 +44,7 @@ def load_summary_file_if_exists(project_name: str) -> Optional[Dict]:
return json.load(f)
def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> Tuple[str, str]:
def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> tuple[str, str]:
"""Summarizes a single file
Args:
@@ -74,15 +74,15 @@ def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = No
def summarize_repo(
repo_url: str,
exchange: Exchange,
extensions: List[str],
extensions: list[str],
summary_instructions_prompt: Optional[str] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Clones (if needed) and summarizes a repo
Args:
repo_url (str): Repository url
exchange (Exchange): Exchange for summarizing the repo.
extensions (List[str]): List of file-types to summarize.
extensions (list[str]): List of file-types to summarize.
summary_instructions_prompt (Optional[str]): Optional parameter to customize summarization results. Defaults to
"Please summarize this file"
"""
@@ -110,15 +110,15 @@ def summarize_repo(
def summarize_directory(
directory: str, exchange: Exchange, extensions: List[str], summary_instructions_prompt: Optional[str] = None
) -> Dict[str, str]:
directory: str, exchange: Exchange, extensions: list[str], summary_instructions_prompt: Optional[str] = None
) -> dict[str, str]:
"""Summarize files in a given directory based on extensions. Will also recursively find files in subdirectories and
summarize them.
Args:
directory (str): path to the top-level directory to summarize
exchange (Exchange): Exchange to use to summarize
extensions (List[str]): List of file-type extensions to summarize (and ignore all other extensions).
extensions (list[str]): List of file-type extensions to summarize (and ignore all other extensions).
summary_instructions_prompt (Optional[str]): Optional instructions to give to the exchange regarding summarization.
Returns:
@@ -158,19 +158,19 @@ def summarize_directory(
def summarize_files_concurrent(
exchange: Exchange, file_list: List[str], project_name: str, summary_instructions_prompt: Optional[str] = None
) -> Dict[str, str]:
exchange: Exchange, file_list: list[str], project_name: str, summary_instructions_prompt: Optional[str] = None
) -> dict[str, str]:
"""Takes in a list of files and summarizes them. Exchange does not keep history of the summarized files.
Args:
exchange (Exchange): Underlying exchange
file_list (List[str]): List of paths to files to summarize
file_list (list[str]): List of paths to files to summarize
project_name (str): Used to save the summary of the files to .goose/summaries/<project_name>-summary.json
summary_instructions_prompt (Optional[str]): Summary instructions for the LLM. Defaults to "Please summarize
this file."
Returns:
file_summaries (Dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange
file_summaries (dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange
"""
summary_file = load_summary_file_if_exists(project_name)
if summary_file:

View File

@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional, Dict
from typing import Optional
from pygments.lexers import get_lexer_for_filename
from pygments.util import ClassNotFound
@@ -67,7 +67,7 @@ def find_last_task_group_index(input_str: str) -> int:
return last_group_start_index
def parse_plan(input_plan_str: str) -> Dict:
def parse_plan(input_plan_str: str) -> dict:
last_group_start_index = find_last_task_group_index(input_plan_str)
if last_group_start_index == -1:
return {"kickoff_message": input_plan_str, "tasks": []}

View File

@@ -1,7 +1,7 @@
import random
import string
from importlib.metadata import entry_points
from typing import Any, Callable, Dict, List, Type, TypeVar
from typing import TypeVar, Callable
T = TypeVar("T")
@@ -31,10 +31,10 @@ def load_plugins(group: str) -> dict:
return plugins
def ensure(cls: Type[T]) -> Callable[[Any], T]:
def ensure(cls: type[T]) -> Callable[[any], T]:
"""Convert dictionary to a class instance"""
def converter(val: Any) -> T: # noqa: ANN401
def converter(val: any) -> T: # noqa: ANN401
if isinstance(val, cls):
return val
elif isinstance(val, dict):
@@ -47,10 +47,10 @@ def ensure(cls: Type[T]) -> Callable[[Any], T]:
return converter
def ensure_list(cls: Type[T]) -> Callable[[List[Dict[str, Any]]], Type[T]]:
def ensure_list(cls: type[T]) -> Callable[[list[dict[str, any]]], type[T]]:
"""Convert a list of dictionaries to class instances"""
def converter(val: List[Dict[str, Any]]) -> List[T]:
def converter(val: list[dict[str, any]]) -> list[T]:
output = []
for entry in val:
output.append(ensure(cls)(entry))

View File

@@ -2,7 +2,7 @@ import glob
import os
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional
from typing import Optional
def create_extensions_list(project_root: str, max_n: int) -> list:
@@ -11,7 +11,7 @@ def create_extensions_list(project_root: str, max_n: int) -> list:
project_root (str): Root of the project to analyze
max_n (int): The number of file extensions to return
Returns:
extensions (List[str]): A list of the top N file extensions
extensions (list[str]): A list of the top N file extensions
"""
if max_n == 0:
raise (ValueError("Number of file extensions must be greater than 0"))
@@ -31,14 +31,14 @@ def create_extensions_list(project_root: str, max_n: int) -> list:
return extensions
def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]:
def create_language_weighting(files_in_directory: list[str]) -> dict[str, float]:
"""Calculate language weighting by file size to match GitHub's methodology.
Args:
files_in_directory (List[str]): Paths to files in the project directory
files_in_directory (list[str]): Paths to files in the project directory
Returns:
Dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values
dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values
"""
# Initialize counters for sizes
@@ -59,7 +59,7 @@ def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]
return dict(sorted(language_percentages.items(), key=lambda item: item[1], reverse=True))
def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> List[str]:
def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> list[str]:
"""List all files in a directory with a given extension. Set extension to '' to return all files.
Args:
@@ -67,7 +67,7 @@ def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> L
extension (Optional[str]): extension to lookup. Defaults to '' which will return all files.
Returns:
files (List[str]): List of file paths
files (list[str]): List of file paths
"""
# add a leading '.' to extension if needed
if extension and not extension.startswith("."):
@@ -77,15 +77,15 @@ def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> L
return files
def create_file_list(dir_path: str, extensions: List[str]) -> List[str]:
def create_file_list(dir_path: str, extensions: list[str]) -> list[str]:
"""Creates a list of files with certain extensions
Args:
dir_path (str): Directory to list files of. Will include files recursively in sub-directories.
extensions (List[str]): List of file extensions to select for. If empty list, return all files
extensions (list[str]): List of file extensions to select for. If empty list, return all files
Returns:
final_file_list (List[str]): List of file paths with specified extensions.
final_file_list (list[str]): List of file paths with specified extensions.
"""
# if extensions is empty list, return all files
if not extensions:

View File

@@ -2,7 +2,7 @@ import json
import os
import tempfile
from pathlib import Path
from typing import Dict, Iterator, List
from typing import Iterator
from exchange import Message
@@ -17,12 +17,12 @@ def is_empty_session(path: Path) -> bool:
return path.is_file() and path.stat().st_size == 0
def write_to_file(file_path: Path, messages: List[Message]) -> None:
def write_to_file(file_path: Path, messages: list[Message]) -> None:
with open(file_path, "w") as f:
_write_messages_to_file(f, messages)
def read_or_create_file(file_path: Path) -> List[Message]:
def read_or_create_file(file_path: Path) -> list[Message]:
if file_path.exists():
return read_from_file(file_path)
with open(file_path, "w"):
@@ -30,7 +30,7 @@ def read_or_create_file(file_path: Path) -> List[Message]:
return []
def read_from_file(file_path: Path) -> List[Message]:
def read_from_file(file_path: Path) -> list[Message]:
try:
with open(file_path, "r") as f:
messages = [json.loads(m) for m in list(f) if m.strip()]
@@ -40,7 +40,7 @@ def read_from_file(file_path: Path) -> List[Message]:
return [Message(**m) for m in messages]
def list_sorted_session_files(session_files_directory: Path) -> Dict[str, Path]:
def list_sorted_session_files(session_files_directory: Path) -> dict[str, Path]:
logs = list_session_files(session_files_directory)
return {log.stem: log for log in sorted(logs, key=lambda x: x.stat().st_mtime, reverse=True)}
@@ -55,7 +55,7 @@ def session_file_exists(session_files_directory: Path) -> bool:
return any(list_session_files(session_files_directory))
def save_latest_session(file_path: Path, messages: List[Message]) -> None:
def save_latest_session(file_path: Path, messages: list[Message]) -> None:
with tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
_write_messages_to_file(temp_file, messages)
temp_file_path = temp_file.name
@@ -63,7 +63,7 @@ def save_latest_session(file_path: Path, messages: List[Message]) -> None:
os.replace(temp_file_path, file_path)
def _write_messages_to_file(file: any, messages: List[Message]) -> None:
def _write_messages_to_file(file: any, messages: list[Message]) -> None:
for m in messages:
json.dump(m.to_dict(), file)
file.write("\n")