Mirror of https://github.com/aljazceru/goose.git (synced 2026-01-01 21:44:26 +01:00)
chore: use primitives instead of typing imports and fixes completion … (#149)
Signed-off-by: Adrian Cole <adrian.cole@elastic.co>
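The pattern applied throughout this diff is mechanical: container annotations that previously needed aliases from typing (List, Dict, Tuple, Type) switch to the built-in generics that Python 3.9+ accepts directly (PEP 585), while names that only exist in typing (Optional, Mapping, Literal, TYPE_CHECKING) keep a narrowed import; the provider complete() signatures also gain a **kwargs parameter. A minimal before/after sketch of that pattern, illustrative only and not copied from any file in the diff:

    # before: container types imported from typing
    from typing import Dict, List, Optional, Tuple

    def complete(messages: List[str], tools: Tuple[str, ...], extra: Optional[Dict[str, int]] = None) -> Tuple[str, int]:
        ...

    # after: built-in generics; Optional still comes from typing
    from typing import Optional

    def complete(messages: list[str], tools: tuple[str, ...], extra: Optional[dict[str, int]] = None, **kwargs: dict) -> tuple[str, int]:
        ...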
@@ -1,5 +1,4 @@
 from copy import deepcopy
-from typing import List
 from attrs import define, field

@@ -31,7 +30,7 @@ class CheckpointData:
 total_token_count: int = field(default=0)
 # in order list of individual checkpoints in the exchange
-checkpoints: List[Checkpoint] = field(factory=list)
+checkpoints: list[Checkpoint] = field(factory=list)
 # the offset to apply to the message index when calculating the last message index
 # this is useful because messages on the exchange behave like a queue, where you can only
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Optional
 from attrs import define, asdict

@@ -7,11 +7,11 @@ CONTENT_TYPES = {}
 class Content:
-    def __init_subclass__(cls, **kwargs: Dict[str, Any]) -> None:
+    def __init_subclass__(cls, **kwargs: dict[str, any]) -> None:
 super().__init_subclass__(**kwargs)
 CONTENT_TYPES[cls.__name__] = cls
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, any]:
 data = asdict(self, recurse=True)
 data["type"] = self.__class__.__name__
 return data

@@ -26,7 +26,7 @@ class Text(Content):
 class ToolUse(Content):
 id: str
 name: str
-    parameters: Any
+    parameters: any
 is_error: bool = False
 error_message: Optional[str] = None
@@ -1,8 +1,7 @@
 import json
 import traceback
 from copy import deepcopy
-from typing import Any, Dict, List, Mapping, Tuple
+from typing import Mapping
 from attrs import define, evolve, field, Factory
 from tiktoken import get_encoding

@@ -41,8 +40,8 @@ class Exchange:
 model: str
 system: str
 moderator: Moderator = field(default=ContextTruncate())
-    tools: Tuple[Tool] = field(factory=tuple, converter=tuple)
-    messages: List[Message] = field(factory=list)
+    tools: tuple[Tool, ...] = field(factory=tuple, converter=tuple)
+    messages: list[Message] = field(factory=list)
 checkpoint_data: CheckpointData = field(factory=CheckpointData)
 generation_args: dict = field(default=Factory(dict))

@@ -50,7 +49,7 @@ class Exchange:
 def _toolmap(self) -> Mapping[str, Tool]:
 return {tool.name: tool for tool in self.tools}
-    def replace(self, **kwargs: Dict[str, Any]) -> "Exchange":
+    def replace(self, **kwargs: dict[str, any]) -> "Exchange":
 """Make a copy of the exchange, replacing any passed arguments"""
 # TODO: ensure that the checkpoint data is updated correctly. aka,
 # if we replace the messages, we need to update the checkpoint data

@@ -264,7 +263,7 @@ class Exchange:
 # we've removed all the checkpoints, so we need to reset the message index offset
 self.checkpoint_data.message_index_offset = 0
-    def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
+    def pop_last_checkpoint(self) -> tuple[Checkpoint, list[Message]]:
 """
 Reverts the exchange back to the last checkpoint, removing associated messages
 """

@@ -275,7 +274,7 @@ class Exchange:
 messages.append(self.messages.pop())
 return removed_checkpoint, messages
-    def pop_first_checkpoint(self) -> Tuple[Checkpoint, List[Message]]:
+    def pop_first_checkpoint(self) -> tuple[Checkpoint, list[Message]]:
 """
 Pop the first checkpoint from the exchange, removing associated messages
 """

@@ -332,5 +331,6 @@ class Exchange:
 # this to be a required method of the provider instead.
 return len(self.messages) > 0 and self.messages[-1].role == "user"
-    def get_token_usage(self) -> Dict[str, Usage]:
+    @staticmethod
+    def get_token_usage() -> dict[str, Usage]:
 return _token_usage_collector.get_token_usage_group_by_model()
@@ -1,8 +1,5 @@
-from typing import List
 class InvalidChoiceError(Exception):
-    def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None:
+    def __init__(self, attribute_name: str, attribute_value: str, available_values: list[str]) -> None:
 self.attribute_name = attribute_name
 self.attribute_value = attribute_value
 self.available_values = available_values
@@ -1,7 +1,7 @@
 import inspect
 import time
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Type
+from typing import Literal
 from attrs import define, field
 from jinja2 import Environment, FileSystemLoader

@@ -12,7 +12,7 @@ from exchange.utils import create_object_id
 Role = Literal["user", "assistant"]
-def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: ANN401
+def validate_role_and_content(instance: "Message", *_: any) -> None: # noqa: ANN401
 if instance.role == "user":
 if not (instance.text or instance.tool_result):
 raise ValueError("User message must include a Text or ToolResult")

@@ -25,7 +25,7 @@ def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: AN
 raise ValueError("Assistant message does not support ToolResult")
-def content_converter(contents: List[Dict[str, Any]]) -> List[Content]:
+def content_converter(contents: list[dict[str, any]]) -> list[Content]:
 return [(CONTENT_TYPES[c.pop("type")](**c) if c.__class__ not in CONTENT_TYPES.values() else c) for c in contents]

@@ -48,9 +48,9 @@ class Message:
 role: Role = field(default="user")
 id: str = field(factory=lambda: str(create_object_id(prefix="msg")))
 created: int = field(factory=lambda: int(time.time()))
-    content: List[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter)
+    content: list[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter)
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, any]:
 return {
 "role": self.role,
 "id": self.id,

@@ -68,7 +68,7 @@ class Message:
 return "\n".join(result)
 @property
-    def tool_use(self) -> List[ToolUse]:
+    def tool_use(self) -> list[ToolUse]:
 """All tool use content of this message."""
 result = []
 for content in self.content:

@@ -77,7 +77,7 @@ class Message:
 return result
 @property
-    def tool_result(self) -> List[ToolResult]:
+    def tool_result(self) -> list[ToolResult]:
 """All tool result content of this message."""
 result = []
 for content in self.content:

@@ -87,10 +87,10 @@ class Message:
 @classmethod
 def load(
-        cls: Type["Message"],
+        cls: type["Message"],
 filename: str,
 role: Role = "user",
-        **kwargs: Dict[str, Any],
+        **kwargs: dict[str, any],
 ) -> "Message":
 """Load the message from filename relative to where the load is called.

@@ -113,9 +113,9 @@ class Message:
 return cls(role=role, content=[Text(text=rendered_content)])
 @classmethod
-    def user(cls: Type["Message"], text: str) -> "Message":
+    def user(cls: type["Message"], text: str) -> "Message":
 return cls(role="user", content=[Text(text)])
 @classmethod
-    def assistant(cls: Type["Message"], text: str) -> "Message":
+    def assistant(cls: type["Message"], text: str) -> "Message":
 return cls(role="assistant", content=[Text(text)])
@@ -1,5 +1,4 @@
 from functools import cache
-from typing import Type
 from exchange.invalid_choice_error import InvalidChoiceError
 from exchange.moderators.base import Moderator

@@ -10,7 +9,7 @@ from exchange.moderators.summarizer import ContextSummarizer # noqa
 @cache
-def get_moderator(name: str) -> Type[Moderator]:
+def get_moderator(name: str) -> type[Moderator]:
 moderators = load_plugins(group="exchange.moderator")
 if name not in moderators:
 raise InvalidChoiceError("moderator", name, moderators.keys())

@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
-from typing import Type
 class Moderator(ABC):
 @abstractmethod
-    def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
+    def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
 pass

@@ -1,7 +1,6 @@
-from typing import Type
 from exchange.moderators.base import Moderator
 class PassiveModerator(Moderator):
-    def rewrite(self, _: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
+    def rewrite(self, _: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
 pass

@@ -1,12 +1,10 @@
-from typing import Type
 from exchange import Message
 from exchange.checkpoint import CheckpointData
 from exchange.moderators import ContextTruncate, PassiveModerator
 class ContextSummarizer(ContextTruncate):
-    def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821
+    def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821
 """Summarize the context history up to the last few messages in the exchange"""
 self._update_system_prompt_token_count(exchange)
@@ -1,6 +1,6 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional
 from exchange.checkpoint import CheckpointData
 from exchange.message import Message

@@ -62,7 +62,7 @@ class ContextTruncate(Moderator):
 exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count
 exchange.checkpoint_data.total_token_count += self.system_prompt_token_count
-    def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]:
+    def _get_messages_to_remove(self, exchange: Exchange) -> list[Message]:
 # this keeps all the messages/checkpoints
 throwaway_exchange = exchange.replace(
 moderator=PassiveModerator(),
@@ -1,5 +1,4 @@
 from functools import cache
-from typing import Type
 from exchange.invalid_choice_error import InvalidChoiceError
 from exchange.providers.anthropic import AnthropicProvider # noqa

@@ -15,7 +14,7 @@ from exchange.utils import load_plugins
 @cache
-def get_provider(name: str) -> Type[Provider]:
+def get_provider(name: str) -> type[Provider]:
 providers = load_plugins(group="exchange.provider")
 if name not in providers:
 raise InvalidChoiceError("provider", name, providers.keys())
@@ -1,5 +1,4 @@
 import os
-from typing import Any, Dict, List, Tuple, Type
 import httpx

@@ -29,7 +28,7 @@ class AnthropicProvider(Provider):
 self.client = client
 @classmethod
-    def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider":
+    def from_env(cls: type["AnthropicProvider"]) -> "AnthropicProvider":
 cls.check_env_vars()
 url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST)
 key = os.environ.get("ANTHROPIC_API_KEY")

@@ -45,7 +44,7 @@ class AnthropicProvider(Provider):
 return cls(client)
 @staticmethod
-    def get_usage(data: Dict) -> Usage: # noqa: ANN401
+    def get_usage(data: dict) -> Usage: # noqa: ANN401
 usage = data.get("usage")
 input_tokens = usage.get("input_tokens")
 output_tokens = usage.get("output_tokens")

@@ -61,7 +60,7 @@ class AnthropicProvider(Provider):
 )
 @staticmethod
-    def anthropic_response_to_message(response: Dict) -> Message:
+    def anthropic_response_to_message(response: dict) -> Message:
 content_blocks = response.get("content", [])
 content = []
 for block in content_blocks:

@@ -78,7 +77,7 @@ class AnthropicProvider(Provider):
 return Message(role="assistant", content=content)
 @staticmethod
-    def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]:
+    def tools_to_anthropic_spec(tools: tuple[Tool, ...]) -> list[dict[str, any]]:
 return [
 {
 "name": tool.name,

@@ -89,7 +88,7 @@ class AnthropicProvider(Provider):
 ]
 @staticmethod
-    def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]:
+    def messages_to_anthropic_spec(messages: list[Message]) -> list[dict[str, any]]:
 messages_spec = []
 # if messages is empty - just make a default
 for message in messages:

@@ -127,10 +126,12 @@ class AnthropicProvider(Provider):
 self,
 model: str,
 system: str,
-        messages: List[Message],
-        tools: List[Tool] = [],
-        **kwargs: Dict[str, Any],
-    ) -> Tuple[Message, Usage]:
+        messages: list[Message],
+        tools: list[Tool] = None,
+        **kwargs: dict[str, any],
+    ) -> tuple[Message, Usage]:
+        if tools is None:
+            tools = []
 tools_set = set()
 unique_tools = []
 for tool in tools:
@@ -1,5 +1,3 @@
-from typing import Type
 import httpx
 import os

@@ -21,7 +19,7 @@ class AzureProvider(OpenAiProvider):
 super().__init__(client)
 @classmethod
-    def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
+    def from_env(cls: type["AzureProvider"]) -> "AzureProvider":
 cls.check_env_vars()
 url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME")
 deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME")
@@ -1,7 +1,7 @@
 import os
 from abc import ABC, abstractmethod
 from attrs import define, field
-from typing import List, Optional, Tuple, Type
+from typing import Optional
 from exchange.message import Message
 from exchange.tool import Tool

@@ -19,11 +19,11 @@ class Provider(ABC):
 REQUIRED_ENV_VARS: list[str] = []
 @classmethod
-    def from_env(cls: Type["Provider"]) -> "Provider":
+    def from_env(cls: type["Provider"]) -> "Provider":
 return cls()
 @classmethod
-    def check_env_vars(cls: Type["Provider"], instructions_url: Optional[str] = None) -> None:
+    def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None:
 for env_var in cls.REQUIRED_ENV_VARS:
 if env_var not in os.environ:
 raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url)

@@ -33,9 +33,10 @@ class Provider(ABC):
 self,
 model: str,
 system: str,
-        messages: List[Message],
-        tools: Tuple[Tool],
-    ) -> Tuple[Message, Usage]:
+        messages: list[Message],
+        tools: tuple[Tool, ...],
+        **kwargs: dict[str, any],
+    ) -> tuple[Message, Usage]:
 """Generate the next message using the specified model"""
 pass
@@ -4,7 +4,7 @@ import json
 import logging
 import os
 from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Optional
 from urllib.parse import quote, urlparse
 import httpx

@@ -36,7 +36,7 @@ class AwsClient(httpx.Client):
 aws_access_key: str,
 aws_secret_key: str,
 aws_session_token: Optional[str] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: dict[str, any],
 ) -> None:
 self.region = aws_region
 self.host = f"https://{SERVICE}.{aws_region}.amazonaws.com/"

@@ -45,7 +45,7 @@ class AwsClient(httpx.Client):
 self.session_token = aws_session_token
 super().__init__(base_url=self.host, timeout=600, **kwargs)
-    def post(self, path: str, json: Dict, **kwargs: Dict[str, Any]) -> httpx.Response:
+    def post(self, path: str, json: dict, **kwargs: dict[str, any]) -> httpx.Response:
 signed_headers = self.sign_and_get_headers(
 method="POST",
 url=path,

@@ -60,7 +60,7 @@ class AwsClient(httpx.Client):
 url: str,
 payload: dict,
 service: str,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
 """
 Sign the request and generate the necessary headers for AWS authentication.

@@ -72,10 +72,10 @@ class AwsClient(httpx.Client):
 region (str): The AWS region.
 access_key (str): The AWS access key.
 secret_key (str): The AWS secret key.
-        session_token (Optional[str]): The AWS session token, if any.
+        session_token (optional[str]): The AWS session token, if any.
 Returns:
-        Dict[str, str]: The headers required for the request.
+        dict[str, str]: The headers required for the request.
 """
 def sign(key: bytes, msg: str) -> bytes:

@@ -160,7 +160,7 @@ class BedrockProvider(Provider):
 self.client = client
 @classmethod
-    def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider":
+    def from_env(cls: type["BedrockProvider"]) -> "BedrockProvider":
 cls.check_env_vars()
 aws_region = os.environ.get("AWS_REGION", "us-east-1")
 aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID")

@@ -179,22 +179,22 @@ class BedrockProvider(Provider):
 self,
 model: str,
 system: str,
-        messages: List[Message],
-        tools: Tuple[Tool],
-        **kwargs: Dict[str, Any],
-    ) -> Tuple[Message, Usage]:
+        messages: list[Message],
+        tools: tuple[Tool, ...],
+        **kwargs: dict[str, any],
+    ) -> tuple[Message, Usage]:
 """
 Generate a completion response from the Bedrock gateway.
 Args:
 model (str): The model identifier.
 system (str): The system prompt or configuration.
-            messages (List[Message]): A list of messages to be processed by the model.
-            tools (Tuple[Tool]): A tuple of tools to be used in the completion process.
+            messages (list[Message]): A list of messages to be processed by the model.
+            tools (tuple[Tool]): A tuple of tools to be used in the completion process.
 **kwargs: Additional keyword arguments for inference configuration.
 Returns:
-            Tuple[Message, Usage]: A tuple containing the response message and usage data.
+            tuple[Message, Usage]: A tuple containing the response message and usage data.
 """
 inference_config = dict(

@@ -231,7 +231,7 @@ class BedrockProvider(Provider):
 return self.response_to_message(response_message), usage
 @retry_procedure
-    def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401
+    def _post(self, payload: any, path: str) -> dict: # noqa: ANN401
 response = self.client.post(path, json=payload)
 return raise_for_status(response).json()

@@ -311,7 +311,7 @@ class BedrockProvider(Provider):
 raise Exception("Invalid response")
 @staticmethod
-    def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]:
+    def tools_to_bedrock_spec(tools: tuple[Tool, ...]) -> Optional[dict]:
 if len(tools) == 0:
 return None # API requires a non-empty tool config or None
 tools_added = set()
@@ -1,5 +1,3 @@
-from typing import Any, Dict, List, Tuple, Type
 import httpx
 import os

@@ -43,7 +41,7 @@ class DatabricksProvider(Provider):
 self.client = client
 @classmethod
-    def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider":
+    def from_env(cls: type["DatabricksProvider"]) -> "DatabricksProvider":
 cls.check_env_vars(cls.instructions_url)
 url = os.environ.get("DATABRICKS_HOST")
 key = os.environ.get("DATABRICKS_TOKEN")

@@ -73,10 +71,10 @@ class DatabricksProvider(Provider):
 self,
 model: str,
 system: str,
-        messages: List[Message],
-        tools: Tuple[Tool],
-        **kwargs: Dict[str, Any],
-    ) -> Tuple[Message, Usage]:
+        messages: list[Message],
+        tools: tuple[Tool, ...],
+        **kwargs: dict[str, any],
+    ) -> tuple[Message, Usage]:
 payload = dict(
 messages=[
 {"role": "system", "content": system},
@@ -1,5 +1,4 @@
 import os
-from typing import Any, Dict, List, Tuple, Type
 import httpx

@@ -30,7 +29,7 @@ class GoogleProvider(Provider):
 self.client = client
 @classmethod
-    def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider":
+    def from_env(cls: type["GoogleProvider"]) -> "GoogleProvider":
 cls.check_env_vars(cls.instructions_url)
 url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST)
 key = os.environ.get("GOOGLE_API_KEY")

@@ -45,7 +44,7 @@ class GoogleProvider(Provider):
 return cls(client)
 @staticmethod
-    def get_usage(data: Dict) -> Usage: # noqa: ANN401
+    def get_usage(data: dict) -> Usage: # noqa: ANN401
 usage = data.get("usageMetadata")
 input_tokens = usage.get("promptTokenCount")
 output_tokens = usage.get("candidatesTokenCount")

@@ -61,7 +60,7 @@ class GoogleProvider(Provider):
 )
 @staticmethod
-    def google_response_to_message(response: Dict) -> Message:
+    def google_response_to_message(response: dict) -> Message:
 candidates = response.get("candidates", [])
 if candidates:
 # Only use first candidate for now

@@ -85,12 +84,12 @@ class GoogleProvider(Provider):
 return Message(role="assistant", content=[])
 @staticmethod
-    def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]:
+    def tools_to_google_spec(tools: tuple[Tool, ...]) -> dict[str, list[dict[str, any]]]:
 if not tools:
 return {}
 converted_tools = []
 for tool in tools:
-            converted_tool: Dict[str, Any] = {
+            converted_tool: dict[str, any] = {
 "name": tool.name,
 "description": tool.description or "",
 }

@@ -100,7 +99,7 @@ class GoogleProvider(Provider):
 return {"functionDeclarations": converted_tools}
 @staticmethod
-    def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]:
+    def messages_to_google_spec(messages: list[Message]) -> list[dict[str, any]]:
 messages_spec = []
 for message in messages:
 role = "user" if message.role == "user" else "model"

@@ -136,10 +135,10 @@ class GoogleProvider(Provider):
 self,
 model: str,
 system: str,
-        messages: List[Message],
-        tools: List[Tool] = [],
-        **kwargs: Dict[str, Any],
-    ) -> Tuple[Message, Usage]:
+        messages: list[Message],
+        tools: list[Tool] = None,
+        **kwargs: dict[str, any],
+    ) -> tuple[Message, Usage]:
 tools_set = set()
 unique_tools = []
 for tool in tools:
@@ -1,5 +1,4 @@
 import os
-from typing import Any, Dict, List, Tuple, Type
 import httpx

@@ -37,7 +36,7 @@ class GroqProvider(Provider):
 self.client = client
 @classmethod
-    def from_env(cls: Type["GroqProvider"]) -> "GroqProvider":
+    def from_env(cls: type["GroqProvider"]) -> "GroqProvider":
 cls.check_env_vars(cls.instructions_url)
 url = os.environ.get("GROQ_HOST", GROQ_HOST)
 key = os.environ.get("GROQ_API_KEY")

@@ -69,10 +68,10 @@ class GroqProvider(Provider):
 self,
 model: str,
 system: str,
-        messages: List[Message],
-        tools: Tuple[Tool],
-        **kwargs: Dict[str, Any],
-    ) -> Tuple[Message, Usage]:
+        messages: list[Message],
+        tools: tuple[Tool, ...],
+        **kwargs: dict[str, any],
+    ) -> tuple[Message, Usage]:
 system_message = [{"role": "system", "content": system}]
 payload = dict(
 messages=system_message + messages_to_openai_spec(messages),
@@ -1,5 +1,4 @@
 import os
-from typing import Type
 import httpx

@@ -31,7 +30,7 @@ ollama:
 super().__init__(client)
 @classmethod
-    def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider":
+    def from_env(cls: type["OllamaProvider"]) -> "OllamaProvider":
 ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST)
 timeout = httpx.Timeout(60 * 10)
@@ -1,5 +1,4 @@
 import os
-from typing import Any, Dict, List, Tuple, Type
 import httpx

@@ -37,7 +36,7 @@ class OpenAiProvider(Provider):
 self.client = client
 @classmethod
-    def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider":
+    def from_env(cls: type["OpenAiProvider"]) -> "OpenAiProvider":
 cls.check_env_vars(cls.instructions_url)
 url = os.environ.get("OPENAI_HOST", OPENAI_HOST)
 key = os.environ.get("OPENAI_API_KEY")

@@ -69,10 +68,10 @@ class OpenAiProvider(Provider):
 self,
 model: str,
 system: str,
-        messages: List[Message],
-        tools: Tuple[Tool],
-        **kwargs: Dict[str, Any],
-    ) -> Tuple[Message, Usage]:
+        messages: list[Message],
+        tools: tuple[Tool, ...],
+        **kwargs: dict[str, any],
+    ) -> tuple[Message, Usage]:
 system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}]
 payload = dict(
 messages=system_message + messages_to_openai_spec(messages),
@@ -1,7 +1,7 @@
 import base64
 import json
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Optional
 import httpx
 from exchange.content import Text, ToolResult, ToolUse

@@ -10,10 +10,10 @@ from exchange.tool import Tool
 from tenacity import retry_if_exception
-def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable:
+def retry_if_status(codes: Optional[list[int]] = None, above: Optional[int] = None) -> callable:
 codes = codes or []
-    def predicate(exc: Exception) -> bool:
+    def predicate(exc: BaseException) -> bool:
 if isinstance(exc, httpx.HTTPStatusError):
 if exc.response.status_code in codes:
 return True

@@ -42,7 +42,7 @@ def encode_image(image_path: str) -> str:
 return base64.b64encode(image_file.read()).decode("utf-8")
-def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]:
+def messages_to_openai_spec(messages: list[Message]) -> list[dict[str, any]]:
 messages_spec = []
 for message in messages:
 converted = {"role": message.role}

@@ -106,7 +106,7 @@ def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]:
 return messages_spec
-def tools_to_openai_spec(tools: Tuple[Tool]) -> Dict[str, Any]:
+def tools_to_openai_spec(tools: tuple[Tool, ...]) -> dict[str, any]:
 tools_names = set()
 result = []
 for tool in tools:
@@ -1,5 +1,4 @@
 from collections import defaultdict
-from typing import Dict
 from exchange.providers.base import Usage

@@ -11,7 +10,7 @@ class _TokenUsageCollector:
 def collect(self, model: str, usage: Usage) -> None:
 self.usage_data.append((model, usage))
-    def get_token_usage_group_by_model(self) -> Dict[str, Usage]:
+    def get_token_usage_group_by_model(self) -> dict[str, Usage]:
 usage_group_by_model = defaultdict(lambda: Usage(0, 0, 0))
 for model, usage in self.usage_data:
 usage_by_model = usage_group_by_model[model]
@@ -1,5 +1,4 @@
 import inspect
-from typing import Any, Callable, Type
 from attrs import define

@@ -13,17 +12,17 @@ class Tool:
 Attributes:
 name (str): The name of the tool
 description (str): A description of what the tool does
-        parameters dict[str, Any]: A json schema of the function signature
+        parameters dict[str, any]: A json schema of the function signature
 function (Callable): The python function that powers the tool
 """
 name: str
 description: str
-    parameters: dict[str, Any]
-    function: Callable
+    parameters: dict[str, any]
+    function: callable
 @classmethod
-    def from_function(cls: Type["Tool"], func: Any) -> "Tool": # noqa: ANN401
+    def from_function(cls: type["Tool"], func: any) -> "Tool": # noqa: ANN401
 """Create a tool instance from a function and its docstring
 The function must have a docstring - we require it to load the description
@@ -1,7 +1,7 @@
 import inspect
 import uuid
 from importlib.metadata import entry_points
-from typing import Any, Callable, Dict, List, Type, get_args, get_origin
+from typing import get_args, get_origin
 from griffe import (
 Docstring,

@@ -20,7 +20,7 @@ def compact(content: str) -> str:
 return " ".join(content.split())
-def parse_docstring(func: Callable) -> tuple[str, List[Dict]]:
+def parse_docstring(func: callable) -> tuple[str, list[dict]]:
 """Get description and parameters from function docstring"""
 function_args = list(inspect.signature(func).parameters.keys())
 text = str(func.__doc__)

@@ -71,7 +71,7 @@ def parse_docstring(func: Callable) -> tuple[str, List[Dict]]:
 def _check_section_is_present(
-    parsed_docstring: List[DocstringSection], section_type: Type[DocstringSectionText]
+    parsed_docstring: list[DocstringSection], section_type: type[DocstringSectionText]
 ) -> bool:
 for section in parsed_docstring:
 if isinstance(section, section_type):

@@ -79,7 +79,7 @@ def _check_section_is_present(
 return False
-def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401
+def json_schema(func: any) -> dict[str, any]: # noqa: ANN401
 """Get the json schema for a function"""
 signature = inspect.signature(func)
 parameters = signature.parameters

@@ -107,16 +107,16 @@ def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401
 return schema
-def _map_type_to_schema(py_type: Type) -> Dict[str, Any]: # noqa: ANN401
+def _map_type_to_schema(py_type: type) -> dict[str, any]: # noqa: ANN401
 origin = get_origin(py_type)
 args = get_args(py_type)
 if origin is list or origin is tuple:
-        return {"type": "array", "items": _map_type_to_schema(args[0] if args else Any)}
+        return {"type": "array", "items": _map_type_to_schema(args[0] if args else any)}
 elif origin is dict:
 return {
 "type": "object",
-            "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else Any),
+            "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else any),
 }
 elif py_type is int:
 return {"type": "integer"}
@@ -1,7 +1,6 @@
 import json
 import os
 import re
-from typing import Type, Tuple
 import pytest
 import yaml

@@ -189,14 +188,14 @@ def scrub_response_headers(response):
 return response
-def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
+def complete(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]:
 provider = provider_cls.from_env()
 system = "You are a helpful assistant."
 messages = [Message.user("Hello")]
 return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)
-def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
+def tools(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]:
 provider = provider_cls.from_env()
 system = "You are a helpful assistant. Expect to need to read a file using read_file."
 messages = [Message.user("What are the contents of this file? test.txt")]

@@ -205,7 +204,7 @@ def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message,
 )
-def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
+def vision(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]:
 provider = provider_cls.from_env()
 system = "You are a helpful assistant."
 messages = [
@@ -14,10 +14,10 @@ AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")
 @pytest.mark.parametrize(
 "env_var_name",
 [
-        ("AZURE_CHAT_COMPLETIONS_HOST_NAME"),
-        ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"),
-        ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"),
-        ("AZURE_CHAT_COMPLETIONS_KEY"),
+        "AZURE_CHAT_COMPLETIONS_HOST_NAME",
+        "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME",
+        "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION",
+        "AZURE_CHAT_COMPLETIONS_KEY",
 ],
 )
 def test_from_env_throw_error_when_missing_env_var(env_var_name):

@@ -15,9 +15,9 @@ logger = logging.getLogger(__name__)
 @pytest.mark.parametrize(
 "env_var_name",
 [
-        ("AWS_ACCESS_KEY_ID"),
-        ("AWS_SECRET_ACCESS_KEY"),
-        ("AWS_SESSION_TOKEN"),
+        "AWS_ACCESS_KEY_ID",
+        "AWS_SECRET_ACCESS_KEY",
+        "AWS_SESSION_TOKEN",
 ],
 )
 def test_from_env_throw_error_when_missing_env_var(env_var_name):

@@ -10,8 +10,8 @@ from exchange.providers.databricks import DatabricksProvider
 @pytest.mark.parametrize(
 "env_var_name",
 [
-        ("DATABRICKS_HOST"),
-        ("DATABRICKS_TOKEN"),
+        "DATABRICKS_HOST",
+        "DATABRICKS_TOKEN",
 ],
 )
 def test_from_env_throw_error_when_missing_env_var(env_var_name):
@@ -107,9 +107,9 @@ def test_messages_to_openai_spec() -> None:
 Message(role="user", content=[Text("How are you?")]),
 Message(
 role="assistant",
-            content=[ToolUse(id=1, name="tool1", parameters={"param1": "value1"})],
+            content=[ToolUse(id="1", name="tool1", parameters={"param1": "value1"})],
 ),
-        Message(role="user", content=[ToolResult(tool_use_id=1, output="Result")]),
+        Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]),
 ]
 spec = messages_to_openai_spec(messages)

@@ -121,7 +121,7 @@ def test_messages_to_openai_spec() -> None:
 "role": "assistant",
 "tool_calls": [
 {
-                    "id": 1,
+                    "id": "1",
 "type": "function",
 "function": {
 "name": "tool1",

@@ -133,7 +133,7 @@ def test_messages_to_openai_spec() -> None:
 {
 "role": "tool",
 "content": "Result",
-            "tool_call_id": 1,
+            "tool_call_id": "1",
 },
 ]

@@ -216,7 +216,7 @@ def test_openai_response_to_message_valid_tooluse() -> None:
 expect = asdict(
 Message(
 role="assistant",
-            content=[ToolUse(id=1, name="example_fn", parameters={"param": "value"})],
+            content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})],
 )
 )
 actual.pop("id")
@@ -1,5 +1,3 @@
-from typing import List, Tuple
 import pytest
 from exchange.checkpoint import Checkpoint, CheckpointData

@@ -29,12 +27,12 @@ def no_overlapping_checkpoints(exchange: Exchange) -> bool:
 return True
-def checkpoint_to_index_pairs(checkpoints: List[Checkpoint]) -> List[Tuple[int, int]]:
+def checkpoint_to_index_pairs(checkpoints: list[Checkpoint]) -> list[tuple[int, int]]:
 return [(checkpoint.start_index, checkpoint.end_index) for checkpoint in checkpoints]
 class MockProvider(Provider):
-    def __init__(self, sequence: List[Message], usage_dicts: List[dict]):
+    def __init__(self, sequence: list[Message], usage_dicts: list[dict]):
 # We'll use init to provide a preplanned reply sequence
 self.sequence = sequence
 self.call_count = 0

@@ -56,11 +54,18 @@ class MockProvider(Provider):
 total_tokens=total_tokens,
 )
-    def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Message:
+    def complete(
+        self,
+        model: str,
+        system: str,
+        messages: list[Message],
+        tools: tuple[Tool, ...],
+        **kwargs: dict[str, any],
+    ) -> tuple[Message, Usage]:
 output = self.sequence[self.call_count]
 usage = self.get_usage(self.usage_dicts[self.call_count])
 self.call_count += 1
-        return (output, usage)
+        return output, usage
 def test_reply_with_unsupported_tool():

@@ -116,7 +121,7 @@ def test_invalid_tool_parameters():
 ),
 model="gpt-4o-2024-05-13",
 system="You are a helpful assistant.",
-        tools=[Tool.from_function(dummy_tool)],
+        tools=(Tool.from_function(dummy_tool),),
 moderator=PassiveModerator(),
 )

@@ -154,7 +159,7 @@ def test_max_tool_use_when_limit_reached():
 ),
 model="gpt-4o-2024-05-13",
 system="You are a helpful assistant.",
-        tools=[Tool.from_function(dummy_tool)],
+        tools=(Tool.from_function(dummy_tool),),
 moderator=PassiveModerator(),
 )

@@ -195,7 +200,7 @@ def test_tool_output_too_long_character_error():
 ),
 model="gpt-4o-2024-05-13",
 system="You are a helpful assistant.",
-        tools=[Tool.from_function(long_output_tool_char)],
+        tools=(Tool.from_function(long_output_tool_char),),
 moderator=PassiveModerator(),
 )

@@ -236,7 +241,7 @@ def test_tool_output_too_long_token_error():
 ),
 model="gpt-4o-2024-05-13",
 system="You are a helpful assistant.",
-        tools=[Tool.from_function(long_output_tool_token)],
+        tools=(Tool.from_function(long_output_tool_token),),
 moderator=PassiveModerator(),
 )

@@ -301,7 +306,7 @@ def resumed_exchange() -> Exchange:
 ex = Exchange(
 provider=provider,
 messages=messages,
-        tools=[],
+        tools=(),
 model="gpt-4o-2024-05-13",
 system="You are a helpful assistant.",
 checkpoint_data=CheckpointData(),

@@ -399,7 +404,7 @@ def test_pop_first_message_no_messages():
 provider=MockProvider(sequence=[], usage_dicts=[]),
 model="gpt-4o-2024-05-13",
 system="You are a helpful assistant.",
-        tools=[Tool.from_function(dummy_tool)],
+        tools=(Tool.from_function(dummy_tool),),
 moderator=PassiveModerator(),
 )

@@ -741,7 +746,7 @@ def test_rewind_with_tool_usage():
 ),
 model="gpt-4o-2024-05-13",
 system="You are a helpful assistant.",
-        tools=[Tool.from_function(dummy_tool)],
+        tools=(Tool.from_function(dummy_tool),),
 moderator=PassiveModerator(),
 )
 ex.add(Message(role="user", content=[Text(text="test")]))
@@ -9,7 +9,7 @@ from exchange.tool import Tool
 class MockProvider(Provider):
-    def complete(self, model, system, messages, tools=None):
+    def complete(self, model, system, messages, tools, **kwargs):
 return Message(role="assistant", content=[Text(text="This is a mock response.")]), Usage.from_dict(
 {"total_tokens": 35}
 )

@@ -3,11 +3,11 @@ from exchange import Exchange, Message
 from exchange.content import ToolResult, ToolUse
 from exchange.moderators.passive import PassiveModerator
 from exchange.moderators.summarizer import ContextSummarizer
-from exchange.providers import Usage
+from exchange.providers import Usage, Provider
-class MockProvider:
-    def complete(self, model, system, messages, tools):
+class MockProvider(Provider):
+    def complete(self, model, system, messages, tools, **kwargs):
 assistant_message_text = "Summarized content here."
 output_tokens = len(assistant_message_text)
 total_input_tokens = sum(len(msg.text) for msg in messages)

@@ -138,14 +138,14 @@ MESSAGE_SEQUENCE = [
 ]
-class AnotherMockProvider:
+class AnotherMockProvider(Provider):
 def __init__(self):
 self.sequence = MESSAGE_SEQUENCE
 self.current_index = 1
 self.summarize_next = False
 self.summarized_count = 0
-    def complete(self, model, system, messages, tools):
+    def complete(self, model, system, messages, tools, **kwargs):
 system_prompt_tokens = 100
 input_token_count = system_prompt_tokens

@@ -73,7 +73,7 @@ class TruncateLinearProvider(Provider):
 self.summarize_next = False
 self.summarized_count = 0
-    def complete(self, model, system, messages, tools):
+    def complete(self, model, system, messages, tools, **kwargs):
 input_token_count = SYSTEM_PROMPT_TOKENS
 message = self.sequence[self.current_index]
@@ -1,6 +1,6 @@
 from functools import cache
 from pathlib import Path
-from typing import Callable, Dict, Mapping, Optional, Tuple
+from typing import Mapping, Optional
 from rich import print
 from rich.panel import Panel

@@ -20,7 +20,7 @@ RECOMMENDED_DEFAULT_PROVIDER = "openai"
 @cache
-def default_profiles() -> Mapping[str, Callable]:
+def default_profiles() -> Mapping[str, callable]:
 return load_plugins(group="goose.profile")

@@ -29,7 +29,7 @@ def session_path(name: str) -> Path:
 return SESSIONS_PATH.joinpath(f"{name}{SESSION_FILE_SUFFIX}")
-def write_config(profiles: Dict[str, Profile]) -> None:
+def write_config(profiles: dict[str, Profile]) -> None:
 """Overwrite the config with the passed profiles"""
 PROFILES_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
 converted = {name: profile.to_dict() for name, profile in profiles.items()}

@@ -38,7 +38,7 @@ def write_config(profiles: Dict[str, Profile]) -> None:
 yaml.dump(converted, f)
-def ensure_config(name: Optional[str]) -> Tuple[str, Profile]:
+def ensure_config(name: Optional[str]) -> tuple[str, Profile]:
 """Ensure that the config exists and has the default section"""
 # TODO we should copy a templated default config in to better document
 # but this is complicated a bit by autodetecting the provider

@@ -70,7 +70,7 @@ def ensure_config(name: Optional[str]) -> Tuple[str, Profile]:
 return (name, default_profile)
-def read_config() -> Dict[str, Profile]:
+def read_config() -> dict[str, Profile]:
 """Return config from the configuration file and validates its contents"""
 yaml = YAML()

@@ -80,7 +80,7 @@ def read_config() -> Dict[str, Profile]:
 return {name: Profile(**profile) for name, profile in data.items()}
-def default_model_configuration() -> Tuple[str, str, str]:
+def default_model_configuration() -> tuple[str, str, str]:
 providers = load_plugins(group="exchange.provider")
 for provider, cls in providers.items():
 try:
@@ -1,5 +1,4 @@
 import re
-from typing import List
 from prompt_toolkit.completion import CompleteEvent, Completer, Completion
 from prompt_toolkit.document import Document

@@ -8,10 +7,10 @@ from goose.command.base import Command
 class GoosePromptCompleter(Completer):
-    def __init__(self, commands: List[Command]) -> None:
+    def __init__(self, commands: list[Command]) -> None:
 self.commands = commands
-    def get_command_completions(self, document: Document) -> List[Completion]:
+    def get_command_completions(self, document: Document) -> list[Completion]:
 all_completions = []
 for command_name, command_instance in self.commands.items():
 pattern = rf"(?<!\S)\/{command_name}:([\S]*)$"

@@ -25,7 +24,7 @@ class GoosePromptCompleter(Completer):
 all_completions.extend(completions)
 return all_completions
-    def get_command_name_completions(self, document: Document) -> List[Completion]:
+    def get_command_name_completions(self, document: Document) -> list[Completion]:
 pattern = r"(?<!\S)\/([\S]*)$"
 text = document.text_before_cursor
 match = re.search(pattern=pattern, string=text)

@@ -40,7 +39,7 @@ class GoosePromptCompleter(Completer):
 completions.append(Completion(command_name, start_position=-len(query), display=command_name))
 return completions
-    def get_completions(self, document: Document, _: CompleteEvent) -> List[Completion]:
+    def get_completions(self, document: Document, _: CompleteEvent) -> list[Completion]:
 command_completions = self.get_command_completions(document)
 command_name_completions = self.get_command_name_completions(document)
 return command_name_completions + command_completions

@@ -1,5 +1,5 @@
 import re
-from typing import Callable, List, Tuple
+from typing import Callable
 from prompt_toolkit.document import Document
 from prompt_toolkit.lexers import Lexer

@@ -27,7 +27,7 @@ def value_for_command(command_string: str) -> re.Pattern[str]:
 class PromptLexer(Lexer):
-    def __init__(self, command_names: List[str]) -> None:
+    def __init__(self, command_names: list[str]) -> None:
 self.patterns = []
 for command_name in command_names:
 self.patterns.append((completion_for_command(command_name), "class:command"))

@@ -35,7 +35,7 @@ class PromptLexer(Lexer):
 self.patterns.append((command_itself(command_name), "class:command"))
 def lex_document(self, document: Document) -> Callable[[int], list]:
-        def get_line_tokens(line_number: int) -> Tuple[str, str]:
+        def get_line_tokens(line_number: int) -> tuple[str, str]:
 line = document.lines[line_number]
 tokens = []
@@ -1,10 +1,8 @@
-from typing import Any
 from rich.prompt import Prompt
 class OverwriteSessionPrompt(Prompt):
-    def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
+    def __init__(self, *args: tuple[any], **kwargs: dict[str, any]) -> None:
 super().__init__(*args, **kwargs)
 self.choices = {
 "yes": "Overwrite the existing session",

@@ -1,6 +1,6 @@
 import traceback
 from pathlib import Path
-from typing import Any, Optional
+from typing import Optional
 from exchange import Message, Text, ToolResult, ToolUse
 from rich import print

@@ -62,7 +62,7 @@ class Session:
 profile: Optional[str] = None,
 plan: Optional[dict] = None,
 log_level: Optional[str] = "INFO",
-        **kwargs: dict[str, Any],
+        **kwargs: dict[str, any],
 ) -> None:
 if name is None:
 self.name = droid()
@@ -1,5 +1,4 @@
 from functools import cache
-from typing import Dict
 from goose.command.base import Command
 from goose.utils import load_plugins

@@ -11,5 +10,5 @@ def get_command(name: str) -> type[Command]:
 @cache
-def get_commands() -> Dict[str, type[Command]]:
+def get_commands() -> dict[str, type[Command]]:
 return load_plugins(group="goose.command")

@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import List, Optional
+from typing import Optional
 from prompt_toolkit.completion import Completion

@@ -7,7 +7,7 @@ from prompt_toolkit.completion import Completion
 class Command(ABC):
 """A command that can be executed by the CLI."""
-    def get_completions(self, query: str) -> List[Completion]:
+    def get_completions(self, query: str) -> list[Completion]:
 """
 Get completions for the command.

@@ -1,5 +1,4 @@
 import os
-from typing import List
 from prompt_toolkit.completion import Completion

@@ -7,7 +6,7 @@ from goose.command.base import Command
 class FileCommand(Command):
-    def get_completions(self, query: str) -> List[Completion]:
+    def get_completions(self, query: str) -> list[Completion]:
 if query.startswith("/"):
 directory = os.path.dirname(query)
 search_term = os.path.basename(query)
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Mapping, Type
+from typing import Mapping
 from attrs import asdict, define, field

@@ -21,10 +21,10 @@ class Profile:
 processor: str
 accelerator: str
 moderator: str
-    toolkits: List[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec))
+    toolkits: list[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec))
 @toolkits.validator
-    def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[ToolkitSpec]) -> None:
+    def check_toolkit_requirements(self, _: type["ToolkitSpec"], toolkits: list[ToolkitSpec]) -> None:
 # checks that the list of toolkits in the profile have their requirements
 installed_toolkits = set([toolkit.name for toolkit in toolkits])

@@ -36,7 +36,7 @@ class Profile:
 msg = f"Toolkit {toolkit_name} requires {req} but it is not present"
 raise ValueError(msg)
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, any]:
 return asdict(self)
 def profile_info(self) -> str:

@@ -44,7 +44,7 @@ class Profile:
 return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}"
-def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile:
+def default_profile(provider: str, processor: str, accelerator: str, **kwargs: dict[str, any]) -> Profile:
 """Get the default profile"""
 # TODO consider if the providers should have recommended models
@@ -1,6 +1,6 @@
 import inspect
 from abc import ABC
-from typing import Callable, Mapping, Optional, Tuple, TypeVar
+from typing import Mapping, Optional, TypeVar
 from attrs import define, field
 from exchange import Tool

@@ -8,7 +8,7 @@ from exchange import Tool
 from goose.notifier import Notifier
 # Create a type variable that can represent any function signature
-F = TypeVar("F", bound=Callable)
+F = TypeVar("F", bound=callable)
 def tool(func: F) -> F:

@@ -55,7 +55,7 @@ class Toolkit(ABC):
 """Get the addition to the system prompt for this toolkit."""
 return ""
-    def tools(self) -> Tuple[Tool, ...]:
+    def tools(self) -> tuple[Tool, ...]:
 """Get the tools for this toolkit
 This default method looks for functions on the toolkit annotated
@@ -3,7 +3,6 @@ import re
 import subprocess
 import time
 from pathlib import Path
-from typing import Dict, List
 from exchange import Message
 from goose.toolkit.base import Toolkit, tool

@@ -35,9 +34,9 @@ class Developer(Toolkit):
 We also include some default shell strategies in the prompt, such as using ripgrep
 """
-    def __init__(self, *args: object, **kwargs: Dict[str, object]) -> None:
+    def __init__(self, *args: object, **kwargs: dict[str, object]) -> None:
 super().__init__(*args, **kwargs)
-        self.timestamps: Dict[str, float] = {}
+        self.timestamps: dict[str, float] = {}
 def system(self) -> str:
 """Retrieve system configuration details for developer"""

@@ -55,7 +54,7 @@ class Developer(Toolkit):
 return system_prompt
 @tool
-    def update_plan(self, tasks: List[dict]) -> List[dict]:
+    def update_plan(self, tasks: list[dict]) -> list[dict]:
 """
 Update the plan by overwriting all current tasks

@@ -63,7 +62,7 @@ class Developer(Toolkit):
 shown to the user directly, you do not need to reiterate it
 Args:
-            tasks (List(dict)): The list of tasks, where each task is a dictionary
+            tasks (list(dict)): The list of tasks, where each task is a dictionary
 with a key for the task "description" and the task "status". The status
 MUST be one of "planned", "complete", "failed", "in-progress".
@@ -1,7 +1,6 @@
 import os
 from functools import cache
 from subprocess import CompletedProcess, run
-from typing import Dict, Tuple
 from exchange import Message

@@ -21,7 +20,7 @@ class RepoContext(Toolkit):
 self.repo_project_root, self.is_git_repo, self.goose_session_root = self.determine_git_proj()
-    def determine_git_proj(self) -> Tuple[str, bool, str]:
+    def determine_git_proj(self) -> tuple[str, bool, str]:
 """Determines the root as well as where Goose is currently running
 If the project is not part of a Github repo, the root of the project will be defined as the current working

@@ -72,11 +71,11 @@ class RepoContext(Toolkit):
 return self.repo_size > 2000
 @tool
-    def summarize_current_project(self) -> Dict[str, str]:
+    def summarize_current_project(self) -> dict[str, str]:
 """Summarizes the current project based on repo root (if git repo) or current project_directory (if not)
 Returns:
-            summary (Dict[str, str]): Keys are file paths and values are the summaries
+            summary (dict[str, str]): Keys are file paths and values are the summaries
 """
 self.notifier.log("Summarizing the most relevant files in the current project. This may take a while...")
@@ -2,7 +2,6 @@ import ast
import concurrent.futures
import os
from collections import deque
from typing import Dict, List, Tuple

from exchange import Exchange

@@ -26,7 +25,7 @@ def get_repo_size(repo_path: str) -> int:
return get_directory_size(git_dir) / (1024**2)


def get_files_and_directories(root_dir: str) -> Dict[str, list]:
def get_files_and_directories(root_dir: str) -> dict[str, list]:
"""Gets file names and directory names. Checks that goose has correctly typed the file and directory names and that
the files actually exist (to avoid downstream file read errors).

@@ -61,7 +60,7 @@ def get_files_and_directories(root_dir: str) -> Dict[str, list]:
return {"files": files, "directories": dirs}


def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> List[str]:
def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> list[str]:
"""Lets goose pick files in a BFS manner"""
queue = deque([root])

@@ -80,7 +79,7 @@ def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> Li
return all_files


def process_directory(current_dir: str, exchange: Exchange) -> Tuple[List[str], List[str]]:
def process_directory(current_dir: str, exchange: Exchange) -> tuple[list[str], list[str]]:
"""Allows goose to pick files and subdirectories contained in a given directory (current_dir). Get the list of file
and directory names in the current folder, then ask Goose to pick which ones to keep.

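goose_picks_files above seeds a deque with the root and its docstring calls the traversal breadth-first; a minimal sketch of that BFS shape, with the Exchange-driven selection from process_directory stubbed out as a plain callback:

from collections import deque
from typing import Callable

def pick_files_bfs(root: str, select: Callable[[str], tuple[list[str], list[str]]]) -> list[str]:
    # select(current_dir) -> (files_to_keep, subdirs_to_visit); stands in for
    # process_directory, which asks the Exchange which entries to keep.
    queue = deque([root])
    all_files: list[str] = []
    while queue:
        current_dir = queue.popleft()
        files, dirs = select(current_dir)
        all_files.extend(files)
        queue.extend(dirs)
    return all_files
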
@@ -1,5 +1,5 @@
import os
from typing import List, Optional
from typing import Optional

from goose.toolkit import Toolkit
from goose.toolkit.base import tool
@@ -11,7 +11,7 @@ class SummarizeProject(Toolkit):
def get_project_summary(
self,
project_dir_path: Optional[str] = os.getcwd(),
extensions: Optional[List[str]] = None,
extensions: Optional[list[str]] = None,
summary_instructions_prompt: Optional[str] = None,
) -> dict:
"""Generates or retrieves a project summary based on specified file extensions.
@@ -19,7 +19,7 @@ class SummarizeProject(Toolkit):
Args:
project_dir_path (Optional[Path]): Path to the project directory. Defaults to the current working directory
if None
extensions (Optional[List[str]]): Specific file extensions to summarize.
extensions (Optional[list[str]]): Specific file extensions to summarize.
summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g.
"Summarize the file in two sentences.". The default instruction is "Please summarize this file."

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional

from goose.toolkit import Toolkit
from goose.toolkit.base import tool
@@ -10,7 +10,7 @@ class SummarizeRepo(Toolkit):
def summarize_repo(
self,
repo_url: str,
specified_extensions: Optional[List[str]] = None,
specified_extensions: Optional[list[str]] = None,
summary_instructions_prompt: Optional[str] = None,
) -> dict:
"""
@@ -19,7 +19,7 @@ class SummarizeRepo(Toolkit):

Args:
repo_url (str): The URL of the repository to summarize.
specified_extensions (Optional[List[str]]): List of file extensions to summarize, e.g., ["tf", "md"]. If
specified_extensions (Optional[list[str]]): list of file extensions to summarize, e.g., ["tf", "md"]. If
this list is empty, then all files in the repo are summarized
summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g.
"Summarize the file in two sentences.". The default instruction is "Please summarize this file."

@@ -2,7 +2,7 @@ import json
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Optional

from exchange import Exchange
from exchange.providers.utils import InitialMessageTooLargeError
@@ -15,7 +15,7 @@ CLONED_REPOS_FOLDER = ".goose/cloned_repos"


# TODO: move git stuff
def run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]:
def run_git_command(command: list[str]) -> subprocess.CompletedProcess[str]:
result = subprocess.run(["git"] + command, capture_output=True, text=True, check=False)

if result.returncode != 0:
@@ -28,7 +28,7 @@ def clone_repo(repo_url: str, target_directory: str) -> None:
run_git_command(["clone", repo_url, target_directory])


def load_summary_file_if_exists(project_name: str) -> Optional[Dict]:
def load_summary_file_if_exists(project_name: str) -> Optional[dict]:
"""Checks if a summary file exists at '.goose/summaries/projectname-summary.json. Returns contents of the file if
it exists, otherwise returns None

@@ -36,7 +36,7 @@ def load_summary_file_if_exists(project_name: str) -> Optional[Dict]:
project_name (str): name of the project or repo

Returns:
Optional[Dict]: File contents, else None
Optional[dict]: File contents, else None
"""
summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json"
if Path(summary_file_path).exists():
@@ -44,7 +44,7 @@ def load_summary_file_if_exists(project_name: str) -> Optional[Dict]:
return json.load(f)


def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> Tuple[str, str]:
def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> tuple[str, str]:
"""Summarizes a single file

Args:
@@ -74,15 +74,15 @@ def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = No
def summarize_repo(
repo_url: str,
exchange: Exchange,
extensions: List[str],
extensions: list[str],
summary_instructions_prompt: Optional[str] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Clones (if needed) and summarizes a repo

Args:
repo_url (str): Repository url
exchange (Exchange): Exchange for summarizing the repo.
extensions (List[str]): List of file-types to summarize.
extensions (list[str]): list of file-types to summarize.
summary_instructions_prompt (Optional[str]): Optional parameter to customize summarization results. Defaults to
"Please summarize this file"
"""
@@ -110,15 +110,15 @@ def summarize_repo(


def summarize_directory(
directory: str, exchange: Exchange, extensions: List[str], summary_instructions_prompt: Optional[str] = None
) -> Dict[str, str]:
directory: str, exchange: Exchange, extensions: list[str], summary_instructions_prompt: Optional[str] = None
) -> dict[str, str]:
"""Summarize files in a given directory based on extensions. Will also recursively find files in subdirectories and
summarize them.

Args:
directory (str): path to the top-level directory to summarize
exchange (Exchange): Exchange to use to summarize
extensions (List[str]): List of file-type extensions to summarize (and ignore all other extensions).
extensions (list[str]): list of file-type extensions to summarize (and ignore all other extensions).
summary_instructions_prompt (Optional[str]): Optional instructions to give to the exchange regarding summarization.

Returns:
@@ -158,19 +158,19 @@ def summarize_directory(


def summarize_files_concurrent(
exchange: Exchange, file_list: List[str], project_name: str, summary_instructions_prompt: Optional[str] = None
) -> Dict[str, str]:
exchange: Exchange, file_list: list[str], project_name: str, summary_instructions_prompt: Optional[str] = None
) -> dict[str, str]:
"""Takes in a list of files and summarizes them. Exchange does not keep history of the summarized files.

Args:
exchange (Exchange): Underlying exchange
file_list (List[str]): List of paths to files to summarize
file_list (list[str]): list of paths to files to summarize
project_name (str): Used to save the summary of the files to .goose/summaries/<project_name>-summary.json
summary_instructions_prompt (Optional[str]): Summary instructions for the LLM. Defaults to "Please summarize
this file."

Returns:
file_summaries (Dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange
file_summaries (dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange
"""
summary_file = load_summary_file_if_exists(project_name)
if summary_file:

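Given the ThreadPoolExecutor/as_completed imports and the summarize_file -> tuple[str, str] signature above, a plausible sketch of how summarize_files_concurrent could fan the work out. The body, and the assumption that the tuple is (filepath, summary), is illustrative rather than taken from the repository:

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable

def summarize_concurrently(
    file_list: list[str],
    summarize_one: Callable[[str], tuple[str, str]],
    max_workers: int = 4,
) -> dict[str, str]:
    # summarize_one(filepath) -> (filepath, summary); stands in for summarize_file
    # bound to an Exchange and an optional instructions prompt.
    summaries: dict[str, str] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(summarize_one, path) for path in file_list]
        for future in as_completed(futures):
            path, summary = future.result()
            summaries[path] = summary
    return summaries
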
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional, Dict
from typing import Optional

from pygments.lexers import get_lexer_for_filename
from pygments.util import ClassNotFound
@@ -67,7 +67,7 @@ def find_last_task_group_index(input_str: str) -> int:
return last_group_start_index


def parse_plan(input_plan_str: str) -> Dict:
def parse_plan(input_plan_str: str) -> dict:
last_group_start_index = find_last_task_group_index(input_plan_str)
if last_group_start_index == -1:
return {"kickoff_message": input_plan_str, "tasks": []}

@@ -1,7 +1,7 @@
import random
import string
from importlib.metadata import entry_points
from typing import Any, Callable, Dict, List, Type, TypeVar
from typing import TypeVar, Callable

T = TypeVar("T")

@@ -31,10 +31,10 @@ def load_plugins(group: str) -> dict:
return plugins


def ensure(cls: Type[T]) -> Callable[[Any], T]:
def ensure(cls: type[T]) -> Callable[[any], T]:
"""Convert dictionary to a class instance"""

def converter(val: Any) -> T: # noqa: ANN401
def converter(val: any) -> T: # noqa: ANN401
if isinstance(val, cls):
return val
elif isinstance(val, dict):
@@ -47,10 +47,10 @@ def ensure(cls: Type[T]) -> Callable[[Any], T]:
return converter


def ensure_list(cls: Type[T]) -> Callable[[List[Dict[str, Any]]], Type[T]]:
def ensure_list(cls: type[T]) -> Callable[[list[dict[str, any]]], type[T]]:
"""Convert a list of dictionaries to class instances"""

def converter(val: List[Dict[str, Any]]) -> List[T]:
def converter(val: list[dict[str, any]]) -> list[T]:
output = []
for entry in val:
output.append(ensure(cls)(entry))

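ensure and ensure_list above return converter callables, which lines up with how attrs fields elsewhere in this change take a converter= argument. A hedged usage sketch follows; the Requirement/Profile classes and the dict-unpacking branch are illustrative assumptions, not repository code:

from typing import Callable
from attrs import define, field

@define
class Requirement:
    name: str
    version: str = "latest"

def ensure(cls: type) -> Callable:
    # Accept an instance or a plain dict and normalize it to an instance of cls.
    def converter(val):
        if isinstance(val, cls):
            return val
        elif isinstance(val, dict):
            return cls(**val)
        raise ValueError(f"cannot convert {val!r} to {cls.__name__}")
    return converter

def ensure_list(cls: type) -> Callable:
    def converter(val):
        return [ensure(cls)(entry) for entry in val]
    return converter

@define
class Profile:
    requirements: list[Requirement] = field(converter=ensure_list(Requirement), factory=list)

profile = Profile(requirements=[{"name": "goose"}, Requirement("exchange", "0.1")])
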
@@ -2,7 +2,7 @@ import glob
import os
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional
from typing import Optional


def create_extensions_list(project_root: str, max_n: int) -> list:
@@ -11,7 +11,7 @@ def create_extensions_list(project_root: str, max_n: int) -> list:
project_root (str): Root of the project to analyze
max_n (int): The number of file extensions to return
Returns:
extensions (List[str]): A list of the top N file extensions
extensions (list[str]): A list of the top N file extensions
"""
if max_n == 0:
raise (ValueError("Number of file extensions must be greater than 0"))
@@ -31,14 +31,14 @@ def create_extensions_list(project_root: str, max_n: int) -> list:
return extensions


def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]:
def create_language_weighting(files_in_directory: list[str]) -> dict[str, float]:
"""Calculate language weighting by file size to match GitHub's methodology.

Args:
files_in_directory (List[str]): Paths to files in the project directory
files_in_directory (list[str]): Paths to files in the project directory

Returns:
Dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values
dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values
"""

# Initialize counters for sizes
@@ -59,7 +59,7 @@ def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]
return dict(sorted(language_percentages.items(), key=lambda item: item[1], reverse=True))


def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> List[str]:
def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> list[str]:
"""List all files in a directory with a given extension. Set extension to '' to return all files.

Args:
@@ -67,7 +67,7 @@ def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> L
extension (Optional[str]): extension to lookup. Defaults to '' which will return all files.

Returns:
files (List[str]): List of file paths
files (list[str]): list of file paths
"""
# add a leading '.' to extension if needed
if extension and not extension.startswith("."):
@@ -77,15 +77,15 @@ def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> L
return files


def create_file_list(dir_path: str, extensions: List[str]) -> List[str]:
def create_file_list(dir_path: str, extensions: list[str]) -> list[str]:
"""Creates a list of files with certain extensions

Args:
dir_path (str): Directory to list files of. Will include files recursively in sub-directories.
extensions (List[str]): List of file extensions to select for. If empty list, return all files
extensions (list[str]): list of file extensions to select for. If empty list, return all files

Returns:
final_file_list (List[str]): List of file paths with specified extensions.
final_file_list (list[str]): list of file paths with specified extensions.
"""
# if extensions is empty list, return all files
if not extensions:

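For create_language_weighting, a sketch of the size-weighted percentage computation the docstring describes. Using pygments' get_lexer_for_filename for language detection is an assumption borrowed from the import shown earlier in this change; only the final sorted-dict return mirrors a line that appears in the diff:

import os
from collections import Counter

from pygments.lexers import get_lexer_for_filename
from pygments.util import ClassNotFound

def language_weighting(files_in_directory: list[str]) -> dict[str, float]:
    sizes: Counter = Counter()
    for path in files_in_directory:
        try:
            language = get_lexer_for_filename(path).name
        except ClassNotFound:
            continue  # skip files pygments cannot classify
        sizes[language] += os.path.getsize(path)
    total = sum(sizes.values())
    if total == 0:
        return {}
    percentages = {lang: 100 * size / total for lang, size in sizes.items()}
    return dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))
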
@@ -2,7 +2,7 @@ import json
import os
import tempfile
from pathlib import Path
from typing import Dict, Iterator, List
from typing import Iterator

from exchange import Message

@@ -17,12 +17,12 @@ def is_empty_session(path: Path) -> bool:
return path.is_file() and path.stat().st_size == 0


def write_to_file(file_path: Path, messages: List[Message]) -> None:
def write_to_file(file_path: Path, messages: list[Message]) -> None:
with open(file_path, "w") as f:
_write_messages_to_file(f, messages)


def read_or_create_file(file_path: Path) -> List[Message]:
def read_or_create_file(file_path: Path) -> list[Message]:
if file_path.exists():
return read_from_file(file_path)
with open(file_path, "w"):
@@ -30,7 +30,7 @@ def read_or_create_file(file_path: Path) -> List[Message]:
return []


def read_from_file(file_path: Path) -> List[Message]:
def read_from_file(file_path: Path) -> list[Message]:
try:
with open(file_path, "r") as f:
messages = [json.loads(m) for m in list(f) if m.strip()]
@@ -40,7 +40,7 @@ def read_from_file(file_path: Path) -> List[Message]:
return [Message(**m) for m in messages]


def list_sorted_session_files(session_files_directory: Path) -> Dict[str, Path]:
def list_sorted_session_files(session_files_directory: Path) -> dict[str, Path]:
logs = list_session_files(session_files_directory)
return {log.stem: log for log in sorted(logs, key=lambda x: x.stat().st_mtime, reverse=True)}

@@ -55,7 +55,7 @@ def session_file_exists(session_files_directory: Path) -> bool:
return any(list_session_files(session_files_directory))


def save_latest_session(file_path: Path, messages: List[Message]) -> None:
def save_latest_session(file_path: Path, messages: list[Message]) -> None:
with tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
_write_messages_to_file(temp_file, messages)
temp_file_path = temp_file.name
@@ -63,7 +63,7 @@ def save_latest_session(file_path: Path, messages: List[Message]) -> None:
os.replace(temp_file_path, file_path)


def _write_messages_to_file(file: any, messages: List[Message]) -> None:
def _write_messages_to_file(file: any, messages: list[Message]) -> None:
for m in messages:
json.dump(m.to_dict(), file)
file.write("\n")

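The session helpers above persist messages as JSON Lines, and save_latest_session writes to a temporary file before an atomic os.replace. A self-contained sketch of that round trip, using plain dicts in place of exchange.Message:

import json
import os
import tempfile
from pathlib import Path

def write_jsonl(file_path: Path, messages: list[dict]) -> None:
    with open(file_path, "w") as f:
        for m in messages:
            json.dump(m, f)
            f.write("\n")

def read_jsonl(file_path: Path) -> list[dict]:
    with open(file_path, "r") as f:
        return [json.loads(line) for line in f if line.strip()]

def save_latest(file_path: Path, messages: list[dict]) -> None:
    # Write to a temporary file first, then atomically swap it into place so a
    # crash mid-write never leaves a truncated session file behind.
    with tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
        for m in messages:
            json.dump(m, temp_file)
            temp_file.write("\n")
        temp_path = temp_file.name
    os.replace(temp_path, file_path)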