mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 15:34:27 +01:00
fix: removed the diff when default profile changes. Printed the current profile info instead (#92)
This commit is contained in:
@@ -1,19 +1,15 @@
|
||||
from functools import cache
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Mapping, Tuple
|
||||
from typing import Callable, Dict, Mapping, Optional, Tuple
|
||||
|
||||
from rich import print
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
from rich.text import Text
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from exchange.providers.ollama import OLLAMA_MODEL
|
||||
|
||||
from goose.profile import Profile
|
||||
from goose.utils import load_plugins
|
||||
from goose.utils.diff import pretty_diff
|
||||
|
||||
GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser()
|
||||
PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml")
|
||||
@@ -41,15 +37,18 @@ def write_config(profiles: Dict[str, Profile]) -> None:
|
||||
yaml.dump(converted, f)
|
||||
|
||||
|
||||
def ensure_config(name: str) -> Profile:
|
||||
def ensure_config(name: Optional[str]) -> Tuple[str, Profile]:
|
||||
"""Ensure that the config exists and has the default section"""
|
||||
# TODO we should copy a templated default config in to better document
|
||||
# but this is complicated a bit by autodetecting the provider
|
||||
|
||||
default_profile_name = "default"
|
||||
name = name or default_profile_name
|
||||
default_profiles_dict = default_profiles()
|
||||
provider, processor, accelerator = default_model_configuration()
|
||||
profile = default_profiles()[name](provider, processor, accelerator)
|
||||
default_profile = default_profiles_dict.get(name, default_profiles_dict[default_profile_name])(
|
||||
provider, processor, accelerator
|
||||
)
|
||||
|
||||
profiles = {}
|
||||
if not PROFILES_CONFIG_PATH.exists():
|
||||
print(
|
||||
Panel(
|
||||
@@ -58,49 +57,16 @@ def ensure_config(name: str) -> Profile:
|
||||
+ "You can add your own profile in this file to further configure goose!"
|
||||
)
|
||||
)
|
||||
default = profile
|
||||
profiles = {name: default}
|
||||
write_config(profiles)
|
||||
return profile
|
||||
write_config({name: default_profile})
|
||||
return (name, default_profile)
|
||||
|
||||
profiles = read_config()
|
||||
if name not in profiles:
|
||||
print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]"))
|
||||
profiles.update({name: profile})
|
||||
write_config(profiles)
|
||||
elif name in profiles:
|
||||
# if the profile stored differs from the default one, we should prompt the user to see if they want
|
||||
# to update it! we need to recursively compare the two profiles, as object comparison will always return false
|
||||
is_profile_eq = profile.to_dict() == profiles[name].to_dict()
|
||||
if not is_profile_eq:
|
||||
yaml = YAML()
|
||||
before = StringIO()
|
||||
after = StringIO()
|
||||
yaml.dump(profiles[name].to_dict(), before)
|
||||
yaml.dump(profile.to_dict(), after)
|
||||
before.seek(0)
|
||||
after.seek(0)
|
||||
|
||||
print(
|
||||
Panel(
|
||||
Text(
|
||||
f"Your profile uses one of the default options - '{name}'"
|
||||
+ " - but it differs from the latest version:\n\n",
|
||||
)
|
||||
+ pretty_diff(before.read(), after.read())
|
||||
)
|
||||
)
|
||||
should_update = Confirm.ask(
|
||||
"Do you want to update your profile to use the latest?",
|
||||
default=False,
|
||||
)
|
||||
if should_update:
|
||||
profiles[name] = profile
|
||||
write_config(profiles)
|
||||
else:
|
||||
profile = profiles[name]
|
||||
|
||||
return profile
|
||||
if name in profiles:
|
||||
return (name, profiles[name])
|
||||
print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]"))
|
||||
profiles.update({name: default_profile})
|
||||
write_config(profiles)
|
||||
return (name, default_profile)
|
||||
|
||||
|
||||
def read_config() -> Dict[str, Profile]:
|
||||
@@ -118,7 +84,6 @@ def default_model_configuration() -> Tuple[str, str, str]:
|
||||
for provider, cls in providers.items():
|
||||
try:
|
||||
cls.from_env()
|
||||
print(Panel(f"[green]Detected an available provider: [/]{provider}"))
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -12,7 +12,7 @@ from rich.panel import Panel
|
||||
from rich.status import Status
|
||||
|
||||
from goose.build import build_exchange
|
||||
from goose.cli.config import default_profiles, ensure_config, read_config, session_path, LOG_PATH
|
||||
from goose.cli.config import ensure_config, session_path, LOG_PATH
|
||||
from goose._logger import get_logger, setup_logging
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.notifier import Notifier
|
||||
@@ -46,15 +46,9 @@ def load_provider() -> str:
|
||||
|
||||
|
||||
def load_profile(name: Optional[str]) -> Profile:
|
||||
if name is None:
|
||||
name = "default"
|
||||
|
||||
# If the name is one of the default values, we ensure a valid configuration
|
||||
if name in default_profiles():
|
||||
return ensure_config(name)
|
||||
|
||||
# Otherwise this is a custom config and we return it from the config file
|
||||
return read_config()[name]
|
||||
(profile_name, profile) = ensure_config(name)
|
||||
print(Panel(f"[green]Using profile[/]: {profile_name}, {{{profile.profile_info()}}}"))
|
||||
return profile
|
||||
|
||||
|
||||
class SessionNotifier(Notifier):
|
||||
|
||||
@@ -39,6 +39,10 @@ class Profile:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
def profile_info(self) -> str:
|
||||
tookit_names = [toolkit.name for toolkit in self.toolkits]
|
||||
return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}"
|
||||
|
||||
|
||||
def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile:
|
||||
"""Get the default profile"""
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
def diff(a: str, b: str) -> List[str]:
|
||||
"""Returns a string containing the unified diff of two strings."""
|
||||
|
||||
import difflib
|
||||
|
||||
a_lines = a.splitlines()
|
||||
b_lines = b.splitlines()
|
||||
|
||||
# Create a Differ object
|
||||
d = difflib.Differ()
|
||||
|
||||
# Generate the diff
|
||||
diff = list(d.compare(a_lines, b_lines))
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def pretty_diff(a: str, b: str) -> Text:
|
||||
"""Returns a pretty-printed diff of two strings."""
|
||||
|
||||
diff_lines = diff(a, b)
|
||||
result = Text()
|
||||
for line in diff_lines:
|
||||
if line.startswith("+"):
|
||||
result.append(line, style="green")
|
||||
elif line.startswith("-"):
|
||||
result.append(line, style="red")
|
||||
elif line.startswith("?"):
|
||||
result.append(line, style="yellow")
|
||||
else:
|
||||
result.append(line, style="dim grey")
|
||||
result.append("\n")
|
||||
|
||||
return result
|
||||
@@ -28,53 +28,66 @@ def test_read_write_config(mock_profile_config_path, profile_factory):
|
||||
assert read_config() == profiles
|
||||
|
||||
|
||||
def test_ensure_config_create_profiles_file_with_default_profile(
|
||||
def test_ensure_config_create_profiles_file_with_default_profile_with_name_default(
|
||||
mock_profile_config_path, mock_default_model_configuration
|
||||
):
|
||||
assert not mock_profile_config_path.exists()
|
||||
|
||||
ensure_config(name="default")
|
||||
(profile_name, profile) = ensure_config(name=None)
|
||||
|
||||
expected_profile = default_profile(*mock_default_model_configuration())
|
||||
|
||||
assert profile_name == "default"
|
||||
assert profile == expected_profile
|
||||
assert mock_profile_config_path.exists()
|
||||
|
||||
assert read_config() == {"default": default_profile(*mock_default_model_configuration())}
|
||||
assert read_config() == {"default": expected_profile}
|
||||
|
||||
|
||||
def test_ensure_config_add_default_profile(mock_profile_config_path, profile_factory, mock_default_model_configuration):
|
||||
def test_ensure_config_create_profiles_file_with_default_profile_with_profile_name(
|
||||
mock_profile_config_path, mock_default_model_configuration
|
||||
):
|
||||
assert not mock_profile_config_path.exists()
|
||||
|
||||
(profile_name, profile) = ensure_config(name="my_profile")
|
||||
|
||||
expected_profile = default_profile(*mock_default_model_configuration())
|
||||
|
||||
assert profile_name == "my_profile"
|
||||
assert profile == expected_profile
|
||||
assert mock_profile_config_path.exists()
|
||||
assert read_config() == {"my_profile": expected_profile}
|
||||
|
||||
|
||||
def test_ensure_config_add_default_profile_when_profile_not_exist(
|
||||
mock_profile_config_path, profile_factory, mock_default_model_configuration
|
||||
):
|
||||
existing_profile = profile_factory({"provider": "providerA"})
|
||||
write_config({"profile1": existing_profile})
|
||||
|
||||
ensure_config(name="default")
|
||||
(profile_name, new_profile) = ensure_config(name="my_new_profile")
|
||||
|
||||
expected_profile = default_profile(*mock_default_model_configuration())
|
||||
assert profile_name == "my_new_profile"
|
||||
assert new_profile == expected_profile
|
||||
assert read_config() == {
|
||||
"profile1": existing_profile,
|
||||
"default": default_profile(*mock_default_model_configuration()),
|
||||
"my_new_profile": expected_profile,
|
||||
}
|
||||
|
||||
|
||||
@patch("goose.cli.config.Confirm.ask", return_value=True)
|
||||
def test_ensure_config_overwrite_default_profile(
|
||||
mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration
|
||||
def test_ensure_config_get_existing_profile_not_exist(
|
||||
mock_profile_config_path, profile_factory, mock_default_model_configuration
|
||||
):
|
||||
existing_profile = profile_factory({"provider": "providerA"})
|
||||
profile_name = "default"
|
||||
write_config({profile_name: existing_profile})
|
||||
write_config({"profile1": existing_profile})
|
||||
|
||||
expected_default_profile = default_profile(*mock_default_model_configuration())
|
||||
assert ensure_config(name="default") == expected_default_profile
|
||||
assert read_config() == {"default": expected_default_profile}
|
||||
(profile_name, profile) = ensure_config(name="profile1")
|
||||
|
||||
|
||||
@patch("goose.cli.config.Confirm.ask", return_value=False)
|
||||
def test_ensure_config_keep_original_default_profile(
|
||||
mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration
|
||||
):
|
||||
existing_profile = profile_factory({"provider": "providerA"})
|
||||
profile_name = "default"
|
||||
write_config({profile_name: existing_profile})
|
||||
|
||||
assert ensure_config(name="default") == existing_profile
|
||||
|
||||
assert read_config() == {"default": existing_profile}
|
||||
assert profile_name == "profile1"
|
||||
assert profile == existing_profile
|
||||
assert read_config() == {
|
||||
"profile1": existing_profile,
|
||||
}
|
||||
|
||||
|
||||
def test_session_path(mock_sessions_path):
|
||||
|
||||
12
tests/test_profile.py
Normal file
12
tests/test_profile.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from goose.profile import ToolkitSpec
|
||||
|
||||
|
||||
def test_profile_info(profile_factory):
|
||||
profile = profile_factory(
|
||||
{
|
||||
"provider": "provider",
|
||||
"processor": "processor",
|
||||
"toolkits": [ToolkitSpec("developer"), ToolkitSpec("github")],
|
||||
}
|
||||
)
|
||||
assert profile.profile_info() == "provider:provider, processor:processor toolkits: developer, github"
|
||||
Reference in New Issue
Block a user