diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index b76835a6..2005c468 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -1,19 +1,15 @@ from functools import cache -from io import StringIO from pathlib import Path -from typing import Callable, Dict, Mapping, Tuple +from typing import Callable, Dict, Mapping, Optional, Tuple from rich import print from rich.panel import Panel -from rich.prompt import Confirm -from rich.text import Text from ruamel.yaml import YAML from exchange.providers.ollama import OLLAMA_MODEL from goose.profile import Profile from goose.utils import load_plugins -from goose.utils.diff import pretty_diff GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser() PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml") @@ -41,15 +37,18 @@ def write_config(profiles: Dict[str, Profile]) -> None: yaml.dump(converted, f) -def ensure_config(name: str) -> Profile: +def ensure_config(name: Optional[str]) -> Tuple[str, Profile]: """Ensure that the config exists and has the default section""" # TODO we should copy a templated default config in to better document # but this is complicated a bit by autodetecting the provider - + default_profile_name = "default" + name = name or default_profile_name + default_profiles_dict = default_profiles() provider, processor, accelerator = default_model_configuration() - profile = default_profiles()[name](provider, processor, accelerator) + default_profile = default_profiles_dict.get(name, default_profiles_dict[default_profile_name])( + provider, processor, accelerator + ) - profiles = {} if not PROFILES_CONFIG_PATH.exists(): print( Panel( @@ -58,49 +57,16 @@ def ensure_config(name: str) -> Profile: + "You can add your own profile in this file to further configure goose!" ) ) - default = profile - profiles = {name: default} - write_config(profiles) - return profile + write_config({name: default_profile}) + return (name, default_profile) profiles = read_config() - if name not in profiles: - print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]")) - profiles.update({name: profile}) - write_config(profiles) - elif name in profiles: - # if the profile stored differs from the default one, we should prompt the user to see if they want - # to update it! we need to recursively compare the two profiles, as object comparison will always return false - is_profile_eq = profile.to_dict() == profiles[name].to_dict() - if not is_profile_eq: - yaml = YAML() - before = StringIO() - after = StringIO() - yaml.dump(profiles[name].to_dict(), before) - yaml.dump(profile.to_dict(), after) - before.seek(0) - after.seek(0) - - print( - Panel( - Text( - f"Your profile uses one of the default options - '{name}'" - + " - but it differs from the latest version:\n\n", - ) - + pretty_diff(before.read(), after.read()) - ) - ) - should_update = Confirm.ask( - "Do you want to update your profile to use the latest?", - default=False, - ) - if should_update: - profiles[name] = profile - write_config(profiles) - else: - profile = profiles[name] - - return profile + if name in profiles: + return (name, profiles[name]) + print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]")) + profiles.update({name: default_profile}) + write_config(profiles) + return (name, default_profile) def read_config() -> Dict[str, Profile]: @@ -118,7 +84,6 @@ def default_model_configuration() -> Tuple[str, str, str]: for provider, cls in providers.items(): try: cls.from_env() - print(Panel(f"[green]Detected an available provider: [/]{provider}")) break except Exception: pass diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index b7cdeaa0..5674eb48 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -12,7 +12,7 @@ from rich.panel import Panel from rich.status import Status from goose.build import build_exchange -from goose.cli.config import default_profiles, ensure_config, read_config, session_path, LOG_PATH +from goose.cli.config import ensure_config, session_path, LOG_PATH from goose._logger import get_logger, setup_logging from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.notifier import Notifier @@ -46,15 +46,9 @@ def load_provider() -> str: def load_profile(name: Optional[str]) -> Profile: - if name is None: - name = "default" - - # If the name is one of the default values, we ensure a valid configuration - if name in default_profiles(): - return ensure_config(name) - - # Otherwise this is a custom config and we return it from the config file - return read_config()[name] + (profile_name, profile) = ensure_config(name) + print(Panel(f"[green]Using profile[/]: {profile_name}, {{{profile.profile_info()}}}")) + return profile class SessionNotifier(Notifier): diff --git a/src/goose/profile.py b/src/goose/profile.py index ec0a12a5..b1fb2ad8 100644 --- a/src/goose/profile.py +++ b/src/goose/profile.py @@ -39,6 +39,10 @@ class Profile: def to_dict(self) -> Dict[str, Any]: return asdict(self) + def profile_info(self) -> str: + tookit_names = [toolkit.name for toolkit in self.toolkits] + return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}" + def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile: """Get the default profile""" diff --git a/src/goose/utils/diff.py b/src/goose/utils/diff.py deleted file mode 100644 index e3583be0..00000000 --- a/src/goose/utils/diff.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import List - -from rich.text import Text - - -def diff(a: str, b: str) -> List[str]: - """Returns a string containing the unified diff of two strings.""" - - import difflib - - a_lines = a.splitlines() - b_lines = b.splitlines() - - # Create a Differ object - d = difflib.Differ() - - # Generate the diff - diff = list(d.compare(a_lines, b_lines)) - - return diff - - -def pretty_diff(a: str, b: str) -> Text: - """Returns a pretty-printed diff of two strings.""" - - diff_lines = diff(a, b) - result = Text() - for line in diff_lines: - if line.startswith("+"): - result.append(line, style="green") - elif line.startswith("-"): - result.append(line, style="red") - elif line.startswith("?"): - result.append(line, style="yellow") - else: - result.append(line, style="dim grey") - result.append("\n") - - return result diff --git a/tests/cli/test_config.py b/tests/cli/test_config.py index b857f8b9..0694034d 100644 --- a/tests/cli/test_config.py +++ b/tests/cli/test_config.py @@ -28,53 +28,66 @@ def test_read_write_config(mock_profile_config_path, profile_factory): assert read_config() == profiles -def test_ensure_config_create_profiles_file_with_default_profile( +def test_ensure_config_create_profiles_file_with_default_profile_with_name_default( mock_profile_config_path, mock_default_model_configuration ): assert not mock_profile_config_path.exists() - ensure_config(name="default") + (profile_name, profile) = ensure_config(name=None) + + expected_profile = default_profile(*mock_default_model_configuration()) + + assert profile_name == "default" + assert profile == expected_profile assert mock_profile_config_path.exists() - - assert read_config() == {"default": default_profile(*mock_default_model_configuration())} + assert read_config() == {"default": expected_profile} -def test_ensure_config_add_default_profile(mock_profile_config_path, profile_factory, mock_default_model_configuration): +def test_ensure_config_create_profiles_file_with_default_profile_with_profile_name( + mock_profile_config_path, mock_default_model_configuration +): + assert not mock_profile_config_path.exists() + + (profile_name, profile) = ensure_config(name="my_profile") + + expected_profile = default_profile(*mock_default_model_configuration()) + + assert profile_name == "my_profile" + assert profile == expected_profile + assert mock_profile_config_path.exists() + assert read_config() == {"my_profile": expected_profile} + + +def test_ensure_config_add_default_profile_when_profile_not_exist( + mock_profile_config_path, profile_factory, mock_default_model_configuration +): existing_profile = profile_factory({"provider": "providerA"}) write_config({"profile1": existing_profile}) - ensure_config(name="default") + (profile_name, new_profile) = ensure_config(name="my_new_profile") + expected_profile = default_profile(*mock_default_model_configuration()) + assert profile_name == "my_new_profile" + assert new_profile == expected_profile assert read_config() == { "profile1": existing_profile, - "default": default_profile(*mock_default_model_configuration()), + "my_new_profile": expected_profile, } -@patch("goose.cli.config.Confirm.ask", return_value=True) -def test_ensure_config_overwrite_default_profile( - mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration +def test_ensure_config_get_existing_profile_not_exist( + mock_profile_config_path, profile_factory, mock_default_model_configuration ): existing_profile = profile_factory({"provider": "providerA"}) - profile_name = "default" - write_config({profile_name: existing_profile}) + write_config({"profile1": existing_profile}) - expected_default_profile = default_profile(*mock_default_model_configuration()) - assert ensure_config(name="default") == expected_default_profile - assert read_config() == {"default": expected_default_profile} + (profile_name, profile) = ensure_config(name="profile1") - -@patch("goose.cli.config.Confirm.ask", return_value=False) -def test_ensure_config_keep_original_default_profile( - mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration -): - existing_profile = profile_factory({"provider": "providerA"}) - profile_name = "default" - write_config({profile_name: existing_profile}) - - assert ensure_config(name="default") == existing_profile - - assert read_config() == {"default": existing_profile} + assert profile_name == "profile1" + assert profile == existing_profile + assert read_config() == { + "profile1": existing_profile, + } def test_session_path(mock_sessions_path): diff --git a/tests/test_profile.py b/tests/test_profile.py new file mode 100644 index 00000000..3c022f74 --- /dev/null +++ b/tests/test_profile.py @@ -0,0 +1,12 @@ +from goose.profile import ToolkitSpec + + +def test_profile_info(profile_factory): + profile = profile_factory( + { + "provider": "provider", + "processor": "processor", + "toolkits": [ToolkitSpec("developer"), ToolkitSpec("github")], + } + ) + assert profile.profile_info() == "provider:provider, processor:processor toolkits: developer, github"