fix: removed the diff when default profile changes. Printed the current profile info instead (#92)

This commit is contained in:
Lifei Zhou
2024-09-25 19:07:01 -07:00
committed by GitHub
parent 5c52138f38
commit d56c0d68cd
6 changed files with 76 additions and 127 deletions

View File

@@ -1,19 +1,15 @@
from functools import cache
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Mapping, Tuple
from typing import Callable, Dict, Mapping, Optional, Tuple
from rich import print
from rich.panel import Panel
from rich.prompt import Confirm
from rich.text import Text
from ruamel.yaml import YAML
from exchange.providers.ollama import OLLAMA_MODEL
from goose.profile import Profile
from goose.utils import load_plugins
from goose.utils.diff import pretty_diff
GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser()
PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml")
@@ -41,15 +37,18 @@ def write_config(profiles: Dict[str, Profile]) -> None:
yaml.dump(converted, f)
def ensure_config(name: str) -> Profile:
def ensure_config(name: Optional[str]) -> Tuple[str, Profile]:
"""Ensure that the config exists and has the default section"""
# TODO we should copy a templated default config in to better document
# but this is complicated a bit by autodetecting the provider
default_profile_name = "default"
name = name or default_profile_name
default_profiles_dict = default_profiles()
provider, processor, accelerator = default_model_configuration()
profile = default_profiles()[name](provider, processor, accelerator)
default_profile = default_profiles_dict.get(name, default_profiles_dict[default_profile_name])(
provider, processor, accelerator
)
profiles = {}
if not PROFILES_CONFIG_PATH.exists():
print(
Panel(
@@ -58,49 +57,16 @@ def ensure_config(name: str) -> Profile:
+ "You can add your own profile in this file to further configure goose!"
)
)
default = profile
profiles = {name: default}
write_config(profiles)
return profile
write_config({name: default_profile})
return (name, default_profile)
profiles = read_config()
if name not in profiles:
print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]"))
profiles.update({name: profile})
write_config(profiles)
elif name in profiles:
# if the profile stored differs from the default one, we should prompt the user to see if they want
# to update it! we need to recursively compare the two profiles, as object comparison will always return false
is_profile_eq = profile.to_dict() == profiles[name].to_dict()
if not is_profile_eq:
yaml = YAML()
before = StringIO()
after = StringIO()
yaml.dump(profiles[name].to_dict(), before)
yaml.dump(profile.to_dict(), after)
before.seek(0)
after.seek(0)
print(
Panel(
Text(
f"Your profile uses one of the default options - '{name}'"
+ " - but it differs from the latest version:\n\n",
)
+ pretty_diff(before.read(), after.read())
)
)
should_update = Confirm.ask(
"Do you want to update your profile to use the latest?",
default=False,
)
if should_update:
profiles[name] = profile
write_config(profiles)
else:
profile = profiles[name]
return profile
if name in profiles:
return (name, profiles[name])
print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]"))
profiles.update({name: default_profile})
write_config(profiles)
return (name, default_profile)
def read_config() -> Dict[str, Profile]:
@@ -118,7 +84,6 @@ def default_model_configuration() -> Tuple[str, str, str]:
for provider, cls in providers.items():
try:
cls.from_env()
print(Panel(f"[green]Detected an available provider: [/]{provider}"))
break
except Exception:
pass

View File

@@ -12,7 +12,7 @@ from rich.panel import Panel
from rich.status import Status
from goose.build import build_exchange
from goose.cli.config import default_profiles, ensure_config, read_config, session_path, LOG_PATH
from goose.cli.config import ensure_config, session_path, LOG_PATH
from goose._logger import get_logger, setup_logging
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.notifier import Notifier
@@ -46,15 +46,9 @@ def load_provider() -> str:
def load_profile(name: Optional[str]) -> Profile:
if name is None:
name = "default"
# If the name is one of the default values, we ensure a valid configuration
if name in default_profiles():
return ensure_config(name)
# Otherwise this is a custom config and we return it from the config file
return read_config()[name]
(profile_name, profile) = ensure_config(name)
print(Panel(f"[green]Using profile[/]: {profile_name}, {{{profile.profile_info()}}}"))
return profile
class SessionNotifier(Notifier):

View File

@@ -39,6 +39,10 @@ class Profile:
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
def profile_info(self) -> str:
tookit_names = [toolkit.name for toolkit in self.toolkits]
return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}"
def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile:
"""Get the default profile"""

View File

@@ -1,39 +0,0 @@
from typing import List
from rich.text import Text
def diff(a: str, b: str) -> List[str]:
"""Returns a string containing the unified diff of two strings."""
import difflib
a_lines = a.splitlines()
b_lines = b.splitlines()
# Create a Differ object
d = difflib.Differ()
# Generate the diff
diff = list(d.compare(a_lines, b_lines))
return diff
def pretty_diff(a: str, b: str) -> Text:
"""Returns a pretty-printed diff of two strings."""
diff_lines = diff(a, b)
result = Text()
for line in diff_lines:
if line.startswith("+"):
result.append(line, style="green")
elif line.startswith("-"):
result.append(line, style="red")
elif line.startswith("?"):
result.append(line, style="yellow")
else:
result.append(line, style="dim grey")
result.append("\n")
return result

View File

@@ -28,53 +28,66 @@ def test_read_write_config(mock_profile_config_path, profile_factory):
assert read_config() == profiles
def test_ensure_config_create_profiles_file_with_default_profile(
def test_ensure_config_create_profiles_file_with_default_profile_with_name_default(
mock_profile_config_path, mock_default_model_configuration
):
assert not mock_profile_config_path.exists()
ensure_config(name="default")
(profile_name, profile) = ensure_config(name=None)
expected_profile = default_profile(*mock_default_model_configuration())
assert profile_name == "default"
assert profile == expected_profile
assert mock_profile_config_path.exists()
assert read_config() == {"default": default_profile(*mock_default_model_configuration())}
assert read_config() == {"default": expected_profile}
def test_ensure_config_add_default_profile(mock_profile_config_path, profile_factory, mock_default_model_configuration):
def test_ensure_config_create_profiles_file_with_default_profile_with_profile_name(
mock_profile_config_path, mock_default_model_configuration
):
assert not mock_profile_config_path.exists()
(profile_name, profile) = ensure_config(name="my_profile")
expected_profile = default_profile(*mock_default_model_configuration())
assert profile_name == "my_profile"
assert profile == expected_profile
assert mock_profile_config_path.exists()
assert read_config() == {"my_profile": expected_profile}
def test_ensure_config_add_default_profile_when_profile_not_exist(
mock_profile_config_path, profile_factory, mock_default_model_configuration
):
existing_profile = profile_factory({"provider": "providerA"})
write_config({"profile1": existing_profile})
ensure_config(name="default")
(profile_name, new_profile) = ensure_config(name="my_new_profile")
expected_profile = default_profile(*mock_default_model_configuration())
assert profile_name == "my_new_profile"
assert new_profile == expected_profile
assert read_config() == {
"profile1": existing_profile,
"default": default_profile(*mock_default_model_configuration()),
"my_new_profile": expected_profile,
}
@patch("goose.cli.config.Confirm.ask", return_value=True)
def test_ensure_config_overwrite_default_profile(
mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration
def test_ensure_config_get_existing_profile_not_exist(
mock_profile_config_path, profile_factory, mock_default_model_configuration
):
existing_profile = profile_factory({"provider": "providerA"})
profile_name = "default"
write_config({profile_name: existing_profile})
write_config({"profile1": existing_profile})
expected_default_profile = default_profile(*mock_default_model_configuration())
assert ensure_config(name="default") == expected_default_profile
assert read_config() == {"default": expected_default_profile}
(profile_name, profile) = ensure_config(name="profile1")
@patch("goose.cli.config.Confirm.ask", return_value=False)
def test_ensure_config_keep_original_default_profile(
mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration
):
existing_profile = profile_factory({"provider": "providerA"})
profile_name = "default"
write_config({profile_name: existing_profile})
assert ensure_config(name="default") == existing_profile
assert read_config() == {"default": existing_profile}
assert profile_name == "profile1"
assert profile == existing_profile
assert read_config() == {
"profile1": existing_profile,
}
def test_session_path(mock_sessions_path):

12
tests/test_profile.py Normal file
View File

@@ -0,0 +1,12 @@
from goose.profile import ToolkitSpec
def test_profile_info(profile_factory):
profile = profile_factory(
{
"provider": "provider",
"processor": "processor",
"toolkits": [ToolkitSpec("developer"), ToolkitSpec("github")],
}
)
assert profile.profile_info() == "provider:provider, processor:processor toolkits: developer, github"