chore: initial commit

Co-authored-by: Lifei Zhou <lifei@squareup.com>
Co-authored-by: Mic Neale <micn@tbd.email>
Co-authored-by: Lily Delalande <ldelalande@squareup.com>
Co-authored-by: Bradley Axen <baxen@squareup.com>
Co-authored-by: Andy Lane <alane@squareup.com>
Co-authored-by: Elena Zherdeva <ezherdeva@squareup.com>
Co-authored-by: Zaki Ali <zaki@squareup.com>
Co-authored-by: Salman Mohammed <smohammed@squareup.com>
This commit is contained in:
Luke Alvoeiro
2024-08-23 16:39:04 -07:00
commit dd126afa6c
68 changed files with 4498 additions and 0 deletions

140
.gitignore vendored Normal file
View File

@@ -0,0 +1,140 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# Ignore mkdocs site files generated locally for testing/validation, but generated
# at buildtime in production
site/
docs/docs/notebooks*
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Hermit
.hermit
# VSCode
.vscode
# Autogenerated docs files
docs/docs/reference
## goose session files
.goose

6
.ruff.toml Normal file
View File

@@ -0,0 +1,6 @@
lint.select = ["E", "W", "F", "N", "ANN"]
lint.ignore = ["ANN101"]
exclude = [
"docs",
]
line-length = 120

166
ARCHITECTURE.md Normal file
View File

@@ -0,0 +1,166 @@
# Architecture
## The System
Goose extends the capabilities of high-performing LLMs through a small collection of tools.
This lets you instruct goose, currently via a CLI interface, to automatically solve problems
on your behalf. It attempts to not just tell you how you can do something, but to actually do it for you.
The primary mode of goose (the "developer" toolkit) has access to tools to
- maintain a plan
- run shell commands
- read, create, and edit files
Together these can solve all kinds of problems, and we emphasize performance on tasks like
fully automating adhoc scripts and tasks, interacting with existing code bases, and teaching how
to use new technology.
---
Here are some of the key design decisions about how we drive performance on these tasks with goose,
that you should be able to observe by using it.
- Encouraging it to write and maintain a plan, to allow it to accomplish longer sequences of automation
- Using tool usage as a generalizable and increasingly tuned approach to adding new capabilities (including plugins)
- Relying on reflection at every possible part of the stack
- Showing it clear output of each tool use
- Surfacing all possible errors to the model to give it a chance to correct
- Surfacing the plan to document what has been accomplished
> [!TIP]
> In addition, there are some implementation choices that we've found very performance driving. The share
> a theme of working well by default without constraining the model.
>
> - Encouraging the model to use `ripgrep` via the shell performs very well for navigating filesystems. It mostly
> just works, but enables the model to get clever with regexes or even additional shell operations as needed.
> - Using a replace operation for editing files requires fewer tokens to be generated and avoids laziness on large files,
> but we allow fall back to whole file overwrites to let it more coherently handle major refactors.
## Implementation
The core execution logic for generation and tool calling is handled by [exchange][exchange].
It hooks python functions into the model tool use loop, while defining very careful error handling
so any failures in tools are surfaced to the model.
Once we've created an *exchange* object, running the process is effectively just calling
`exchange.reply()`.
*The key is setting up an exchange with the capabilities we need.*
Goose builds that exchange:
- allows users to configure a profile to customize capabilities
- provides a pluggable system for adding tools and prompts
- sets up the tools to interact with state
We expect that goose will have multiple UXs over time, and be run in different
environments. The UX is expected to be able to load a `Profile` (e.g. in the CLI
we read profiles out of a config) and to provide a `Notifier` (e.g. in the CLI we put
notifications on stdout).
Goose then constructs the exchange for the UX, the UX only interacts with that exchange.
```
def build_exchange(profile: Profile, notifier: Notifier) -> Exchange:
...
```
But to setup a configurable system, Goose uses `Toolkit`s:
```
(Profile, Notifier) -> [Toolkits] -> Exchange
```
## Profile
A profile specifies some basic configuration in Goose, such as which models it should use, as well
as which toolkits it should include.
```yaml
processor: openai:gpt-4o
accelerator: openai:gpt-4o-mini
moderator: passive
toolkits:
- assistant
- calendar
- contacts
- name: scheduling
requires:
assistant: assistant
calendar: calendar
contacts: contacts
```
## Notifier
The notifier is a concrete implementation of the Notifier base class provided by each UX. It
needs to support two methods
```python
class Notifier:
def log(self, RichRenderable):
...
def status(self, str):
...
```
Log is meant to record something concrete that happened, such as a tool being called, and status is intended
for transient displays of the current status. For example, while a shell command is running, it might use
`.log` to record the command that started, and then update the status to `"shell command running"`. Log is durable
while Status is ephemeral.
## Toolkits
Toolkits are a collection of tools, along with the state and prompting they require.
Toolkits are what gives Goose its capabilities.
Tools need a way to report what's happening back to the user, which we treat similarly
to logging. To make that possible, toolkits get a reference to the interface described above.
```python
class ScheduleToolkit(Toolkit):
def __init__(self, notifier: Notifier, requires: Requirements, **kwargs):
super().__init__(notifier, requires, **kwargs) # handles the interface, exchangeview
# for a class that has requirements, you can get them like this
self.calendar = requires.get("calendar")
self.assistant = requires.get("assistant")
self.contacts = requires.get("contacts")
self.appointments_state = []
def prompt(self) -> str:
return "Try out the example tool."
@tool
def example(self):
self.interface.log(f"An example tool was called, current state is {self.state}")
```
### Advanced
**Dependencies**: Toolkits can depend on each other, to make it easier to get plugins to extend
or modify existing capabilities. In the config above, you can see this used for the scheduling toolkit.
You can refer to those requirements in code through:
```python
@tool
def example_dependency(self):
appointments = self.dependencies["calendar"].appointments
...
```
**ExchangeView**: It can also be useful for tools to have a read-only copy of the history
of the loop so far. So for advanced use cases, toolkits also have access to an
`ExchangeView` object.
```python
@tool
def example_history(self):
last_message = self.exchange_view.processor.messages[-1]
...
```
[exchange]: https://github.com/squareup/exchange

0
CHANGELOG.md Normal file
View File

122
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,122 @@
# Contributing
We welcome Pull Requests for general contributions. If you have a larger new feature or any questions on how
to develop a fix, we recommend you open an issue before starting.
## Prerequisites
*goose* uses [uv][uv] for dependency management, and formats with [ruff][ruff].
We provide a shortcut to standard commands using [just][just] in our `justfile`.
## Developing
Now that you have a local environment, you can make edits and run our tests.
```
uv run pytest tests -m "not integration"
```
or, as a shortcut,
```
just test
```
## Evaluations
Given that so much of *goose* involves interactions with LLMs, our unit tests only go so far to
confirming things work as intended.
We're currently developing a suite of evalutions, to make it easier to make improvements to *goose* more confidently.
In the meantime, we typically incubate any new additions that change the behavior of the *goose*
through **opt-in** plugins - `Toolkit`s, `Moderator`s, and `Provider`s. We welcome contributions of plugins
that add new capabilities to *goose*. We recommend sending in several examples of the new capabilities
in action with your pull request.
Additions to the [developer toolkit][developer] change the core performance, and so will need to be measured carefully.
## Build a Toolkit
To add a toolkit, start out with a plugin as mentioned above. In your code (which doesn't necessarily need to be
in the goose package thanks to [plugin metadata][plugin]!), create a class that derives from Toolkit.
```python
import os
import platform
from goose.toolkit.base import Toolkit, tool
class Demo(Toolkit):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Provide any additional tools as needed!
# The docstring of the tool provides instructions to the LLM, so they are important to tune
# you do not have to provide any tools, but any function decorated with @tool will be available
@tool
def authenticate(self, user: str):
"""Output an authentication code for this user
Args:
user (str): The username to authenticate for
"""
# notifier supports any rich renderable https://rich.readthedocs.io/en/stable/introduction.html#quick-start
self.notifier.log(f"[bold red]auth: {str(hash(user))}[/]")
# Provide any system instructions for the model
# This can be generated dynamically, and is run at startup time
def system(self) -> str:
print("new")
return f"""**You must preceed your first message by using the authenticate tool for the current user**
```
platform: {platform.system()}
cwd: {os.getcwd()}
user: {os.environ.get('USER')}
```
"""
```
To make the toolkit available, add it as a plugin. For example in a pyproject.toml
```
[project.entry-points."goose.toolkit"]
developer = "goose.toolkit.developer:Developer"
github = "goose.toolkit.github:Github"
# Add a line like this - the key becomes the name used in profiles
demo = "goose.toolkit.demo:Demo"
```
And then to setup a profile that uses it, add something to ~/.config/goose/profiles.yaml
```yaml
default:
provider: openai
processor: gpt-4o
accelerator: gpt-4o-mini
moderator: passive
toolkits:
- name: developer
requires: {}
demo:
provider: openai
processor: gpt-4o
accelerator: gpt-4o-mini
moderator: passive
toolkits:
- developer
- demo
```
And now you can run goose with this new profile to use the new toolkit!
```sh
goose session start --profile demo
```
[developer]: src/goose/toolkit/developer.py
[uv]: https://docs.astral.sh/uv/
[ruff]: https://docs.astral.sh/ruff/
[just]: https://github.com/casey/just
[plugin]: https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata

201
LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

167
README.md Normal file
View File

@@ -0,0 +1,167 @@
<h1 align="center">
goose
</h1>
<p align="center"><strong>goose</strong> <em>is a programming agent that runs on your machine.</em></p>
<p align="center">
<a href="https://opensource.org/licenses/Apache-2.0"><img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg"></a>
</p>
<p align="center">
<a href="#usage">Usage</a>
<a href="#installation">Installation</a>
<a href="#tips">Tips</a>
</p>
`goose` assists in solving a wide range of programming and operational tasks. It is a live virtual developer you can interact with, guide, and learn from.
To solve problems, goose breaks down instructions into sequences of tasks and carries them out using tools. Its ability to connect its changes with real outcomes (e.g. errors) and course correct is its most powerful and exciting feature. goose is free open source software and is built to be extensible and customizable.
## Usage
You interact with goose in conversational sessions - something like a natural language driven code interpreter.
The default toolkit lets it take actions through shell commands and file edits.
You can interrupt Goose at any time to help redirect its efforts.
From your terminal, navigate to the directory you'd like to start from and run:
```sh
goose session start
```
You will see a prompt `G`:
```
G type your instructions here exactly as you would tell a developer.
```
From here you can talk directly with goose - send along your instructions. If you are looking to exit, use `CTRL+D`,
although goose should help you figure that out if you forget.
When you exit a session, it will save the history and you can resume it later on:
``` sh
goose session resume
```
## Tips
Here are some collected tips we have for working efficiently with Goose
- **goose can and will edit files**. Use a git strategy to avoid losing anything - such as staging your
personal edits and leaving goose edits unstaged until reviewed. Or consider using indivdual commits which can be reverted.
- You can interrupt goose with `CTRL+C` to correct it or give it more info.
- goose works best when solving concrete problems - experiment with how far you need to break that problem
down to get goose to solve it. Be specific! E.g. it will likely fail to `"create a banking app"`,
but probably does a good job if prompted with `"create a Fastapi app with an endpoint for deposit and withdrawal
and with account balances stored in mysql keyed by id"`
- If goose doesn't have enough context to start with, it might go down the wrong direction. Tell it
to read files that you are refering to or search for objects in code. Even better, ask it to summarize
them for you, which will help it set up its own next steps.
- Refer to any objects in files with something that is easy to search for, such as `"the MyExample class"
- goose *loves* to know how to run tests to get a feedback loop going, just like you do. If you tell it how you test things locally and quickly, it can make use of that when working on your project
- You can use goose for tasks that would require scripting at times, even looking at your screen and correcting designs/helping you fix bugs, try asking it to help you in a way you would ask a person.
- goose will make mistakes, and go in the wrong direction from times, feel free to correct it, or start again.
- You can tell goose to run things for you continuously (and it will iterate, try, retry) but you can also tell it to check with you before doing things (and then later on tell it to go off on its own and do its best to solve).
- Goose can run anywhere, doesn't have to be in a repo, just ask it!
## Installation
To install goose, we recommend `pipx`
First make sure you've [installed pipx][pipx] - for example
``` sh
brew install pipx
pipx ensurepath
```
Then you can install goose with
``` sh
pipx install goose
```
### Config
Goose will try to detect what LLM it can work with and place a config in `~/.config/goose/profiles.yaml` automatically.
#### Toolkits
Goose can be extended with toolkits, and out of the box there are some available:
* `screen`: for letting goose take a look at your screen to help debug or work on designs (gives goose eyes)
* `github`: for awareness and suggestions on how to use github
* `repo_context`: for summarizing and understanding a repository you are working in.
To configure for example the screen toolkit, edit `~/.config/goose/profiles.yaml`:
```yaml
provider: openai
processor: gpt-4o
accelerator: gpt-4o-mini
moderator: passive
toolkits:
- name: developer
requires: {}
- name: screen
requires: {}
```
#### Advanced LLM config
goose works on top of LLMs (you bring your own LLM). If you need to customize goose, one way is via editing: `~/.config/goose/profiles.yaml`.
It will look by default something like:
```yaml
default:
provider: openai
processor: gpt-4o
accelerator: gpt-4o-mini
moderator: truncate
toolkits:
- name: developer
requires: {}
```
*Note: This requires the environment variable `OPENAI_API_KEY` to be set to your OpenAI API key. goose uses at least 2 LLMs: one for acceleration for fast operating, and processing for writing code and executing commands.*
You can tell it to use another provider for example for Anthropic:
```yaml
default:
provider: anthropic
processor: claude-3-5-sonnet-20240620
accelerator: claude-3-5-sonnet-20240620
...
```
*Note: This will then use the claude-sonnet model, you will need to set the `ANTHROPIC_API_KEY` environment variable to your anthropic API key.*
For Databricks hosted models:
```yaml
default:
provider: databricks
processor: databricks-meta-llama-3-1-70b-instruct
accelerator: databricks-meta-llama-3-1-70b-instruct
moderator: passive
toolkits:
- name: developer
requires: {}
```
This requires `DATABRICKS_HOST` and `DATABRICKS_TOKEN` to be set accordingly
(goose can be extended to support any LLM or combination of LLMs).
## Open Source
Yes, goose is open source and always will be. goose is released under the ASL2.0 license meaning you can use it however you like.
See LICENSE.md for more details.
[pipx]: https://github.com/pypa/pipx?tab=readme-ov-file#install-pipx

19
justfile Normal file
View File

@@ -0,0 +1,19 @@
# This is the default recipe when no arguments are provided
[private]
default:
@just --list --unsorted
test *FLAGS:
uv run pytest tests -m "not integration" {{FLAGS}}
integration *FLAGS:
uv run pytest tests -m integration {{FLAGS}}
format:
ruff check . --fix
ruff format .
coverage *FLAGS:
uv run coverage run -m pytest tests -m "not integration" {{FLAGS}}
uv run coverage report
uv run coverage lcov -o lcov.info

45
pyproject.toml Normal file
View File

@@ -0,0 +1,45 @@
[project]
name = "goose"
description = "a programming agent that runs on your machine"
version = "0.8.0"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"attrs>=23.2.0",
"rich>=13.7.1",
"ruamel-yaml>=0.18.6",
"exchange>=0.7.6",
"click>=8.1.7",
"prompt-toolkit>=3.0.47",
]
author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }]
packages = [{ include = "goose", from = "src" }]
[project.entry-points."goose.toolkit"]
developer = "goose.toolkit.developer:Developer"
github = "goose.toolkit.github:Github"
screen = "goose.toolkit.screen:Screen"
repo_context = "goose.toolkit.repo_context.repo_context:RepoContext"
[project.entry-points."goose.profile"]
default = "goose.profile:default_profile"
[project.entry-points."goose.command"]
file = "goose.command.file:FileCommand"
[project.entry-points."goose.cli"]
goose = "goose.cli.main:goose_cli"
[project.scripts]
goose = "goose.cli.main:cli"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.uv]
dev-dependencies = [
"pytest>=8.3.2",
"codecov>=2.1.13",
]

0
src/goose/__init__.py Normal file
View File

65
src/goose/build.py Normal file
View File

@@ -0,0 +1,65 @@
from itertools import chain
from exchange import Exchange, Message
from exchange.moderators import get_moderator
from exchange.providers import get_provider
from goose.notifier import Notifier
from goose.profile import Profile
from goose.toolkit import get_toolkit
from goose.toolkit.base import Requirements
from goose.view import ExchangeView
def build_exchange(profile: Profile, notifier: Notifier) -> Exchange:
"""Build an exchange configured through the profile
This will setup any toolkits and use that to build the exchange's collection
of tools.
Args:
profile (Profile): The profile specifying how to setup this exchange
notifier (Notifier): A notifier instance used by tools to send info
"""
provider = get_provider(profile.provider).from_env()
# Support instantating toolkits in *two* passes for now, no further nesting
concrete_toolkits = {}
# First instantiate all toolkits that are sub dependencies
for spec in profile.toolkits:
for required in spec.requires.values():
concrete_toolkits[required] = get_toolkit(required)(notifier=notifier, requires=Requirements(required))
# Now that we have the dependencies available, we can instantiate everything else
toolkits = []
for spec in profile.toolkits:
if spec.name in concrete_toolkits:
toolkits.append(concrete_toolkits[spec.name])
continue
requires = Requirements(
spec.name,
{key: concrete_toolkits[val] for key, val in spec.requires.items()},
)
toolkit = get_toolkit(spec.name)(notifier=notifier, requires=requires)
toolkits.append(toolkit)
# From the toolkits, we derive the exchange prompt and tools
system = "\n\n".join([Message.load("system.jinja").text] + [toolkit.system() for toolkit in toolkits])
tools = tuple(chain(*(toolkit.tools() for toolkit in toolkits)))
exchange = Exchange(
provider=provider,
system=system,
tools=tools,
moderator=get_moderator(profile.moderator)(),
model=profile.processor,
)
# This is a bit awkward, but we have to set this after the fact because building
# the exchange requires having the toolkits
for toolkit in toolkits:
toolkit.exchange_view = ExchangeView(profile.processor, profile.accelerator, exchange)
return exchange

View File

141
src/goose/cli/config.py Normal file
View File

@@ -0,0 +1,141 @@
from functools import cache
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Mapping, Tuple
from rich import print
from rich.panel import Panel
from rich.prompt import Confirm
from rich.text import Text
from ruamel.yaml import YAML
from goose.profile import Profile
from goose.utils import load_plugins
from goose.utils.diff import pretty_diff
GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser()
PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml")
SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions")
SESSION_FILE_SUFFIX = ".jsonl"
@cache
def default_profiles() -> Mapping[str, Callable]:
return load_plugins(group="goose.profile")
def session_path(name: str) -> Path:
SESSIONS_PATH.mkdir(parents=True, exist_ok=True)
return SESSIONS_PATH.joinpath(f"{name}{SESSION_FILE_SUFFIX}")
def write_config(profiles: Dict[str, Profile]) -> None:
"""Overwrite the config with the passed profiles"""
PROFILES_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
converted = {name: profile.to_dict() for name, profile in profiles.items()}
yaml = YAML()
with PROFILES_CONFIG_PATH.open("w") as f:
yaml.dump(converted, f)
def ensure_config(name: str) -> Profile:
"""Ensure that the config exists and has the default section"""
# TODO we should copy a templated default config in to better document
# but this is complicated a bit by autodetecting the provider
provider, processor, accelerator = default_model_configuration()
profile = default_profiles()[name](provider, processor, accelerator)
profiles = {}
if not PROFILES_CONFIG_PATH.exists():
print(
Panel(
f"[yellow]No configuration present, we will create a profile '{name}'"
+ f" at: [/]{str(PROFILES_CONFIG_PATH)}\n"
+ "You can add your own profile in this file to further configure goose!"
)
)
default = profile
profiles = {name: default}
write_config(profiles)
return profile
profiles = read_config()
if name not in profiles:
print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]"))
profiles.update({name: profile})
write_config(profiles)
elif name in profiles:
# if the profile stored differs from the default one, we should prompt the user to see if they want
# to update it! we need to recursively compare the two profiles, as object comparison will always return false
is_profile_eq = profile.to_dict() == profiles[name].to_dict()
if not is_profile_eq:
yaml = YAML()
before = StringIO()
after = StringIO()
yaml.dump(profiles[name].to_dict(), before)
yaml.dump(profile.to_dict(), after)
before.seek(0)
after.seek(0)
print(
Panel(
Text(
f"Your profile uses one of the default options - '{name}'"
+ " - but it differs from the latest version:\n\n",
)
+ pretty_diff(before.read(), after.read())
)
)
should_update = Confirm.ask(
"Do you want to update your profile to use the latest?",
default=False,
)
if should_update:
profiles[name] = profile
write_config(profiles)
else:
profile = profiles[name]
return profile
def read_config() -> Dict[str, Profile]:
"""Return config from the configuration file and validates its contents"""
yaml = YAML()
with PROFILES_CONFIG_PATH.open("r") as f:
data = yaml.load(f)
return {name: Profile(**profile) for name, profile in data.items()}
def default_model_configuration() -> Tuple[str, str, str]:
providers = load_plugins(group="exchange.provider")
for provider, cls in providers.items():
try:
cls.from_env()
print(Panel(f"[green]Detected an available provider: [/]{provider}"))
break
except Exception:
pass
else:
raise ValueError(
"Could not detect an available provider,"
+ " make sure to plugin a provider or set an env var such as OPENAI_API_KEY"
)
recommended = {
"openai": ("gpt-4o", "gpt-4o-mini"),
"anthropic": (
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20240620",
),
"databricks": (
# TODO when function calling is first rec should be: "databricks-meta-llama-3-1-405b-instruct"
"databricks-meta-llama-3-1-70b-instruct",
"databricks-meta-llama-3-1-70b-instruct",
),
}
processor, accelerator = recommended.get(provider, ("gpt-4o", "gpt-4o-mini"))
return provider, processor, accelerator

125
src/goose/cli/main.py Normal file
View File

@@ -0,0 +1,125 @@
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional
import click
from rich import print
from ruamel.yaml import YAML
from goose.cli.config import SESSIONS_PATH
from goose.cli.session import Session
from goose.utils import load_plugins
from goose.utils.session_file import list_sorted_session_files
@click.group()
def goose_cli() -> None:
pass
@goose_cli.command()
def version() -> None:
"""Lists the version of goose and any plugins"""
from importlib.metadata import entry_points, version
print(f"[green]Goose[/green]: [bold][cyan]{version('goose')}[/cyan][/bold]")
print("[green]Plugins[/green]:")
filtered_groups = {}
modules = set()
if sys.version_info.minor >= 12:
for ep in entry_points():
group = getattr(ep, "group", None)
if group and (group.startswith("exchange.") or group.startswith("goose.")):
filtered_groups.setdefault(group, []).append(ep)
for eps in filtered_groups.values():
for ep in eps:
module_name = ep.module.split(".")[0]
modules.add(module_name)
else:
eps = entry_points()
for group, entries in eps.items():
if group.startswith("exchange.") or group.startswith("goose."):
for entry in entries:
module_name = entry.value.split(".")[0]
modules.add(module_name)
modules.remove("goose")
for module in sorted(list(modules)):
# TODO: figure out how to get this to work for goose plugins block
# as the module name is set to block.goose.cli
# module_name = 'goose-plugins-block'
try:
module_version = version(module)
print(f" Module: [green]{module}[/green], Version: [bold][cyan]{module_version}[/cyan][/bold]")
except Exception as e:
print(f" [red]Could not retrieve version for {module}: {e}[/red]")
@goose_cli.group()
def session() -> None:
"""Start or manage sessions"""
pass
@session.command(name="start")
@click.option("--profile")
@click.option("--plan", type=click.Path(exists=True))
def session_start(profile: str, plan: Optional[str] = None) -> None:
"""Start a new goose session"""
if plan:
yaml = YAML()
with open(plan, "r") as f:
_plan = yaml.load(f)
else:
_plan = None
session = Session(profile=profile, plan=_plan)
session.run()
@session.command(name="resume")
@click.argument("name", required=False)
@click.option("--profile")
def session_resume(name: str, profile: str) -> None:
"""Resume an existing goose session"""
if name is None:
session_files = get_session_files()
if session_files:
name = list(session_files.keys())[0]
print(f"Resuming most recent session: {name} from {session_files[name]}")
else:
print("No sessions found.")
return
session = Session(name=name, profile=profile)
session.run()
@session.command(name="list")
def session_list() -> None:
session_files = get_session_files().items()
for session_name, session_file in session_files:
print(f"{datetime.fromtimestamp(session_file.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')} {session_name}")
@session.command(name="clear")
@click.option("--keep", default=3, help="Keep this many entries, default 3")
def session_clear(keep: int) -> None:
for i, (_, session_file) in enumerate(get_session_files().items()):
if i >= keep:
session_file.unlink()
def get_session_files() -> Dict[str, Path]:
return list_sorted_session_files(SESSIONS_PATH)
# merging goose cli with additional cli plugins.
def cli() -> None:
clis = load_plugins("goose.cli")
cli_list = list(clis.values()) or []
click.CommandCollection(sources=cli_list)()
if __name__ == "__main__":
cli()

View File

View File

@@ -0,0 +1,46 @@
import re
from typing import List
from prompt_toolkit.completion import CompleteEvent, Completer, Completion
from prompt_toolkit.document import Document
from goose.command.base import Command
class GoosePromptCompleter(Completer):
def __init__(self, commands: List[Command]) -> None:
self.commands = commands
def get_command_completions(self, document: Document) -> List[Completion]:
all_completions = []
for command_name, command_instance in self.commands.items():
pattern = rf"(?<!\S)\/{command_name}:([\S]*)$"
text = document.text_before_cursor
match = re.search(pattern=pattern, string=text)
if not match or text.endswith(" "):
continue
query = match.group(1)
completions = command_instance.get_completions(query)
all_completions.extend(completions)
return all_completions
def get_command_name_completions(self, document: Document) -> List[Completion]:
pattern = r"(?<!\S)\/([\S]*)$"
text = document.text_before_cursor
match = re.search(pattern=pattern, string=text)
if not match or text.endswith(" "):
return []
query = match.group(1)
completions = []
for command_name in self.commands:
if command_name.startswith(query):
completions.append(Completion(command_name, start_position=-len(query), display=command_name))
return completions
def get_completions(self, document: Document, _: CompleteEvent) -> List[Completion]:
command_completions = self.get_command_completions(document)
command_name_completions = self.get_command_name_completions(document)
return command_name_completions + command_completions

View File

@@ -0,0 +1,66 @@
from prompt_toolkit import PromptSession
from prompt_toolkit.application.current import get_app
from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent
from prompt_toolkit.keys import Keys
from prompt_toolkit.styles import Style
from goose.cli.prompt.completer import GoosePromptCompleter
from goose.cli.prompt.lexer import PromptLexer
from goose.command import get_commands
def create_prompt() -> PromptSession:
# Define custom style
style = Style.from_dict(
{
"parameter": "bold",
"command": "ansiblue bold",
"text": "default",
}
)
bindings = KeyBindings()
# Bind the "Option + Enter" key to insert a newline
@bindings.add(Keys.Escape, Keys.ControlM)
def _(event: KeyPressEvent) -> None:
buffer = event.app.current_buffer
buffer.insert_text("\n")
# Bind the "Enter" key to accept the completion if the completion menu is open
# otherwise just submit the input
@bindings.add(Keys.Enter)
def _(event: KeyPressEvent) -> None:
buffer = event.current_buffer
app = get_app()
if app.layout.has_focus(buffer):
# Check if the completion menu is open
if buffer.complete_state:
# accept completion
buffer.complete_state = None
else:
buffer.validate_and_handle()
@bindings.add(Keys.ControlY)
def _(event: KeyPressEvent) -> None:
buffer = event.app.current_buffer
app = get_app()
if app.layout.has_focus(buffer):
# Check if the completion menu is open
if buffer.complete_state:
# accept completion
buffer.complete_state = None
# instantiate the commands available in the prompt
commands = dict()
command_plugins = get_commands()
for command, command_cls in command_plugins.items():
commands[command] = command_cls()
return PromptSession(
completer=GoosePromptCompleter(commands=commands),
lexer=PromptLexer(list(commands.keys())),
style=style,
key_bindings=bindings,
)

View File

@@ -0,0 +1,34 @@
from typing import Optional
from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.validation import DummyValidator
from goose.cli.prompt.create import create_prompt
from goose.cli.prompt.prompt_validator import PromptValidator
from goose.cli.prompt.user_input import PromptAction, UserInput
class GoosePromptSession:
def __init__(self, prompt_session: PromptSession) -> None:
self.prompt_session = prompt_session
@staticmethod
def create_prompt_session() -> "GoosePromptSession":
return GoosePromptSession(create_prompt())
def get_user_input(self) -> "UserInput":
try:
text = FormattedText([("#00AEAE", "G ")]) # Define the prompt style and text.
message = self.prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False)
if message.strip() in ("exit", ":q"):
return UserInput(PromptAction.EXIT)
return UserInput(PromptAction.CONTINUE, message)
except (EOFError, KeyboardInterrupt):
return UserInput(PromptAction.EXIT)
def get_save_session_name(self) -> Optional[str]:
return self.prompt_session.prompt(
"Enter a name to save this session under. A name will be generated for you if empty: ",
validator=DummyValidator(),
)

View File

@@ -0,0 +1,53 @@
import re
from typing import Callable, List, Tuple
from prompt_toolkit.document import Document
from prompt_toolkit.lexers import Lexer
def completion_for_command(target_string: str) -> re.Pattern[str]:
escaped_string = re.escape(target_string)
vals = [f"(?:{escaped_string[:i]}(?=$))" for i in range(len(escaped_string), 0, -1)]
return re.compile(rf'(?<!\S)\/({"|".join(vals)})(?:\s^|$)')
def command_itself(target_string: str) -> re.Pattern[str]:
escaped_string = re.escape(target_string)
return re.compile(rf"(?<!\S)(\/{escaped_string})")
def value_for_command(command_string: str) -> re.Pattern[str]:
escaped_string = re.escape(command_string)
return re.compile(rf"(?<=(?<!\S)\/{escaped_string})([^\s]*)")
class PromptLexer(Lexer):
def __init__(self, command_names: List[str]) -> None:
self.patterns = []
for command_name in command_names:
full_command = command_name + ":"
self.patterns.append((completion_for_command(full_command), "class:command"))
self.patterns.append((value_for_command(full_command), "class:parameter"))
self.patterns.append((command_itself(full_command), "class:command"))
def lex_document(self, document: Document) -> Callable[[int], list]:
def get_line_tokens(line_number: int) -> Tuple[str, str]:
line = document.lines[line_number]
tokens = []
i = 0
while i < len(line):
match = None
for pattern, token in self.patterns:
match = pattern.match(line, i)
if match:
tokens.append((token, match.group()))
i = match.end()
break
if not match:
tokens.append(("class:text", line[i]))
i += 1
return tokens
return get_line_tokens

View File

@@ -0,0 +1,10 @@
from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError, Validator
class PromptValidator(Validator):
def validate(self, document: Document) -> None:
text = document.text
if text is not None and not text.strip():
message = "Enter your prompt to goose. If you would like to exit, use CTRL+D, or type 'exit' or ':q'"
raise ValidationError(message=message, cursor_position=0)

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class PromptAction(Enum):
CONTINUE = 1
EXIT = 2
@dataclass
class UserInput:
action: PromptAction
text: Optional[str] = None
def to_exit(self) -> bool:
return self.action == PromptAction.EXIT
def to_continue(self) -> bool:
return self.action == PromptAction.CONTINUE

237
src/goose/cli/session.py Normal file
View File

@@ -0,0 +1,237 @@
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional
from exchange import Message, ToolResult, ToolUse
from prompt_toolkit.shortcuts import confirm
from rich import print
from rich.console import RenderableType
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.status import Status
from goose.build import build_exchange
from goose.cli.config import (
default_profiles,
ensure_config,
read_config,
session_path,
)
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.notifier import Notifier
from goose.profile import Profile
from goose.utils import droid, load_plugins
from goose.utils.session_file import read_from_file, write_to_file
def load_provider() -> str:
# We try to infer a provider, by going in order of what will auth
providers = load_plugins(group="exchange.provider")
for provider, cls in providers.items():
try:
cls.from_env()
print(Panel(f"[green]Detected an available provider: [/]{provider}"))
return provider
except Exception:
pass
else:
# TODO link to auth docs
print(
Panel(
"[red]Could not authenticate any providers[/]\n"
+ "Returning a default pointing to openai, but you will need to set an API token env variable."
)
)
return "openai"
def load_profile(name: Optional[str]) -> Profile:
if name is None:
name = "default"
# If the name is one of the default values, we ensure a valid configuration
if name in default_profiles():
return ensure_config(name)
# Otherwise this is a custom config and we return it from the config file
return read_config()[name]
class SessionNotifier(Notifier):
def __init__(self, status_indicator: Status) -> None:
self.status_indicator = status_indicator
def log(self, content: RenderableType) -> None:
print(content)
def status(self, status: str) -> None:
self.status_indicator.update(status)
class Session:
"""A session handler for managing interactions between a user and the Goose exchange
This class encapsulates the entire user interaction cycle, from input prompt to response handling,
including interruptions and error management.
"""
def __init__(
self,
name: Optional[str] = None,
profile: Optional[str] = None,
plan: Optional[dict] = None,
**kwargs: Dict[str, Any],
) -> None:
self.name = name
self.status_indicator = Status("", spinner="dots")
notifier = SessionNotifier(self.status_indicator)
self.exchange = build_exchange(profile=load_profile(profile), notifier=notifier)
if name is not None and self.session_file_path.exists():
messages = self.load_session()
if messages and messages[-1].role == "user":
messages.pop()
self.exchange.messages.extend(messages)
if len(self.exchange.messages) == 0 and plan:
self.setup_plan(plan=plan)
self.prompt_session = GoosePromptSession.create_prompt_session()
def setup_plan(self, plan: dict) -> None:
if len(self.exchange.messages):
raise ValueError("The plan can only be set on an empty session.")
self.exchange.messages.append(Message.user(plan["kickoff_message"]))
tasks = []
if "tasks" in plan:
tasks = [dict(description=task, status="planned") for task in plan["tasks"]]
plan_tool_use = ToolUse(id="initialplan", name="update_plan", parameters=dict(tasks=tasks))
self.exchange.add_tool_use(plan_tool_use)
def process_first_message(self) -> Optional[Message]:
# Get a first input unless it has been specified, such as by a plan
if len(self.exchange.messages) == 0 or self.exchange.messages[-1].role == "assistant":
user_input = self.prompt_session.get_user_input()
if user_input.to_exit():
return None
return Message.user(text=user_input.text)
return self.exchange.messages.pop()
def run(self) -> None:
"""
Runs the main loop to handle user inputs and responses.
Continues until an empty string is returned from the prompt.
"""
message = self.process_first_message()
while message: # Loop until no input (empty string).
with Live(self.status_indicator, refresh_per_second=8, transient=True):
try:
self.exchange.add(message)
self.reply() # Process the user message.
except KeyboardInterrupt:
self.interrupt_reply()
except Exception:
print(traceback.format_exc())
if self.exchange.messages:
self.exchange.messages.pop()
print(
"\n[red]The error above was an exception we were not able to handle.\n\n[/]"
+ "These errors are often related to connection or authentication\n"
+ "We've removed your most recent input"
+ " - [yellow]depending on the error you may be able to continue[/]"
)
print() # Print a newline for separation.
user_input = self.prompt_session.get_user_input()
message = Message.user(text=user_input.text) if user_input.to_continue() else None
self.save_session()
def reply(self) -> None:
"""Reply to the last user message, calling tools as needed
Args:
text (str): The text input from the user.
"""
self.status_indicator.update("responding")
response = self.exchange.generate()
if response.text:
print(Markdown(response.text))
while response.tool_use:
content = []
for tool_use in response.tool_use:
tool_result = self.exchange.call_function(tool_use)
content.append(tool_result)
self.exchange.add(Message(role="user", content=content))
self.status_indicator.update("responding")
response = self.exchange.generate()
if response.text:
print(Markdown(response.text))
def interrupt_reply(self) -> None:
"""Recover from an interruption at an arbitrary state"""
# Default recovery message if no user message is pending.
recovery = "We interrupted before the next processing started."
if self.exchange.messages and self.exchange.messages[-1].role == "user":
# If the last message is from the user, remove it.
self.exchange.messages.pop()
recovery = "We interrupted before the model replied and removed the last message."
if (
self.exchange.messages
and self.exchange.messages[-1].role == "assistant"
and self.exchange.messages[-1].tool_use
):
content = []
# Append tool results as errors if interrupted.
for tool_use in self.exchange.messages[-1].tool_use:
content.append(
ToolResult(
tool_use_id=tool_use.id,
output="Interrupted by the user to make a correction",
is_error=True,
)
)
self.exchange.add(Message(role="user", content=content))
recovery = f"We interrupted the existing call to {tool_use.name}. How would you like to proceed?"
self.exchange.add(Message.assistant(recovery))
# Print the recovery message with markup for visibility.
print(f"[yellow]{recovery}[/]")
@property
def session_file_path(self) -> Path:
return session_path(self.name)
def save_session(self) -> None:
"""Save the current session to a file in JSON format."""
if self.name is None:
self.generate_session_name()
try:
if self.session_file_path.exists():
if not confirm(f"Session {self.name} exists in {self.session_file_path}, overwrite?"):
self.generate_session_name()
write_to_file(self.session_file_path, self.exchange.messages)
except PermissionError as e:
raise RuntimeError(f"Failed to save session due to permissions: {e}")
except (IOError, OSError) as e:
raise RuntimeError(f"Failed to save session due to I/O error: {e}")
def load_session(self) -> List[Message]:
"""Load a session from a JSON file."""
return read_from_file(self.session_file_path)
def generate_session_name(self) -> None:
user_entered_session_name = self.prompt_session.get_save_session_name()
self.name = user_entered_session_name if user_entered_session_name else droid()
print(f"Saving to [bold cyan]{self.session_file_path}[/bold cyan]")
if __name__ == "__main__":
session = Session()

View File

@@ -0,0 +1,15 @@
from functools import cache
from typing import Dict
from goose.command.base import Command
from goose.utils import load_plugins
@cache
def get_command(name: str) -> type[Command]:
return load_plugins(group="goose.command")[name]
@cache
def get_commands() -> Dict[str, type[Command]]:
return load_plugins(group="goose.command")

16
src/goose/command/base.py Normal file
View File

@@ -0,0 +1,16 @@
from abc import ABC
from typing import List, Optional
from prompt_toolkit.completion import Completion
class Command(ABC):
"""A command that can be executed by the CLI."""
def get_completions(self, query: str) -> List[Completion]:
"""Get completions for the command."""
return []
def execute(self, query: str) -> Optional[str]:
"""Execute's the command and replaces it with the output."""
return ""

61
src/goose/command/file.py Normal file
View File

@@ -0,0 +1,61 @@
import os
from typing import List
from prompt_toolkit.completion import Completion
from goose.command.base import Command
class FileCommand(Command):
def get_completions(self, query: str) -> List[Completion]:
if query.startswith("/"):
directory = os.path.dirname(query)
search_term = os.path.basename(query)
else:
directory = os.path.join(os.getcwd(), os.path.dirname(query))
search_term = os.path.basename(query)
# if query is a file, don't show completions
if os.path.isfile(directory):
return []
# Get the list of files in the directory
options = []
try:
for file_name in os.listdir(directory):
if file_name.startswith(search_term):
full_path = os.path.join(directory, file_name)
if os.path.isdir(full_path):
options.append(
dict(
display_text="" + file_name,
insert_text=file_name,
is_dir=True,
)
)
else:
options.append(
dict(
display_text="" + file_name,
insert_text=file_name + " ",
is_dir=False,
)
)
except FileNotFoundError:
return []
completions = []
options.sort(key=lambda x: (not x["is_dir"], x["insert_text"]), reverse=False)
for option in options:
completions.append(
Completion(
option["insert_text"],
start_position=-len(search_term),
display=option["display_text"],
)
)
return completions
def execute(self, query: str) -> str | None:
# GOOSE-TODO: return the query
pass

28
src/goose/notifier.py Normal file
View File

@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from rich.console import RenderableType
class Notifier(ABC):
"""The interface for a notifier
This is expected to be implemented concretely by the each UX
"""
@abstractmethod
def log(self, content: RenderableType) -> None:
"""Append content to the main display
Args:
content (str): The content to render
"""
pass
@abstractmethod
def status(self, status: str) -> None:
"""Log a status to ephemeral display
Args:
status (str): The status to display
"""
pass

54
src/goose/profile.py Normal file
View File

@@ -0,0 +1,54 @@
from typing import Any, Dict, List, Mapping, Type
from attrs import asdict, define, field
from goose.utils import ensure_list
@define
class ToolkitSpec:
"""Configuration for a Toolkit"""
name: str
requires: Mapping[str, str] = field(factory=dict)
@define
class Profile:
"""The configuration for a run of goose"""
provider: str
processor: str
accelerator: str
moderator: str
toolkits: List[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec))
@toolkits.validator
def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[ToolkitSpec]) -> None:
# checks that the list of toolkits in the profile have their requirements
installed_toolkits = set([toolkit.name for toolkit in toolkits])
for toolkit in toolkits:
toolkit_name = toolkit.name
toolkit_requirements = toolkit.requires
for _, req in toolkit_requirements.items():
if req not in installed_toolkits:
msg = f"Toolkit {toolkit_name} requires {req} but it is not present"
raise ValueError(msg)
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile:
"""Get the default profile"""
# TODO consider if the providers should have recommended models
return Profile(
provider=provider,
processor=processor,
accelerator=accelerator,
moderator="truncate",
toolkits=[ToolkitSpec("developer")],
)

1
src/goose/system.jinja Normal file
View File

@@ -0,0 +1 @@
You are an AI assistant named Goose. You solve problems using your tools.

View File

@@ -0,0 +1,9 @@
from functools import cache
from goose.toolkit.base import Toolkit
from goose.utils import load_plugins
@cache
def get_toolkit(name: str) -> type[Toolkit]:
return load_plugins(group="goose.toolkit")[name]

65
src/goose/toolkit/base.py Normal file
View File

@@ -0,0 +1,65 @@
import inspect
from abc import ABC
from typing import Callable, Mapping, Optional, Tuple, TypeVar
from attrs import define, field
from exchange import Tool
from goose.notifier import Notifier
# Create a type variable that can represent any function signature
F = TypeVar("F", bound=Callable)
def tool(func: F) -> F:
func._is_tool = True
return func
@define
class Requirements:
"""A collection of requirements for advanced toolkits
Requirements are an advanced use case, most toolkits will not need to
use these. They allow one toolkit to interact with another's state.
"""
_toolkit: str
_requirements: Mapping[str, "Toolkit"] = field(factory=dict)
def get(self, requirement: str) -> "Toolkit":
"""Get a requirement by name."""
if requirement not in self._requirements:
raise RuntimeError(
f"The toolkit '{self._toolkit}' requested a requirement '{requirement}' but none was passed!\n"
+ f" Make sure to include `requires: {{{requirement}: ...}}` in your profile config\n"
+ f" See the documentation for {self._toolkit} for more details"
)
return self._requirements[requirement]
class Toolkit(ABC):
"""A collection of tools with corresponding prompting
This class defines the interface that all toolkit implementations must follow,
providing a system prompt and a collection of tools. Both are allowed to be
empty if they are not required for the toolkit.
"""
def __init__(self, notifier: Notifier, requires: Optional[Requirements] = None) -> None:
self.notifier = notifier
# This needs to be updated after the fact via build_exchange
self.exchange_view = None
def system(self) -> str:
"""Get the addition to the system prompt for this toolkit."""
return ""
def tools(self) -> Tuple[Tool, ...]:
"""Get the tools for this toolkit
This default method looks for functions on the toolkit annotated
with @tool.
"""
candidates = inspect.getmembers(self, predicate=inspect.ismethod)
return (Tool.from_function(candidate) for _, candidate in candidates if getattr(candidate, "_is_tool", None))

View File

@@ -0,0 +1,201 @@
from pathlib import Path
from subprocess import CompletedProcess, run
from typing import List
from exchange import Message
from rich import box
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Confirm, PromptType
from rich.table import Table
from rich.text import Text
from goose.toolkit.base import Toolkit, tool
from goose.toolkit.utils import get_language
def keep_unsafe_command_prompt(command: str) -> PromptType:
command_text = Text(command, style="bold red")
message = (
Text("\nWe flagged the command: ")
+ command_text
+ Text(" as potentially unsafe, do you want to proceed? (yes/no)")
)
return Confirm.ask(message, default=True)
class Developer(Toolkit):
"""The developer toolkit provides a set of general purpose development capabilities
The tools include plan management, a general purpose shell execution tool, and file operations.
We also include some default shell strategies in the prompt, such as using ripgrep
"""
def system(self) -> str:
"""Retrieve system configuration details for developer"""
return Message.load("prompts/developer.jinja").text
@tool
def update_plan(self, tasks: List[dict]) -> List[dict]:
"""
Update the plan by overwriting all current tasks
This can be used to update the status of a task. This update will be
shown to the user directly, you do not need to reiterate it
Args:
tasks (List(dict)): The list of tasks, where each task is a dictionary
with a key for the task "description" and the task "status". The status
MUST be one of "planned", "complete", "failed", "in-progress".
"""
# Validate the status of each task to ensure it is one of the accepted values.
for task in tasks:
if task["status"] not in {"planned", "complete", "failed", "in-progress"}:
raise ValueError(f"Invalid task status: {task['status']}")
# Create a table with columns for the index, description, and status of each task.
table = Table(expand=True)
table.add_column("#", justify="right", style="magenta")
table.add_column("Task", justify="left")
table.add_column("Status", justify="left")
# Mapping of statuses to emojis for better visual representation in the table.
emoji = {"planned": "", "complete": "", "failed": "", "in-progress": "🕓"}
for i, entry in enumerate(tasks):
table.add_row(str(i), entry["description"], emoji[entry["status"]])
# Log the table to display it directly to the user
# `.log` method is used here to log the command execution in the application's UX
self.notifier.log(table)
# Return the tasks unchanged as the function's primary purpose is to update and display the task status.
return tasks
@tool
def patch_file(self, path: str, before: str, after: str) -> str:
"""Patch the file at the specified by replacing before with after
Before **must** be present exactly once in the file, so that it can safely
be replaced with after.
Args:
path (str): The path to the file, in the format "path/to/file.txt"
before (str): The content that will be replaced
after (str): The content it will be replaced with
"""
self.notifier.status(f"editing {path}")
_path = Path(path)
language = get_language(path)
content = _path.read_text()
if content.count(before) > 1:
raise ValueError("The before content is present multiple times in the file, be more specific.")
if content.count(before) < 1:
raise ValueError("The before content was not found in file, be careful that you recreate it exactly.")
content = content.replace(before, after)
_path.write_text(content)
output = f"""
```{language}
{before}
```
->
```{language}
{after}
```
"""
self.notifier.log(Panel.fit(Markdown(output), title=path))
return "Succesfully replaced before with after."
@tool
def read_file(self, path: str) -> str:
"""Read the content of the file at path
Args:
path (str): The destination file path, in the format "path/to/file.txt"
"""
language = get_language(path)
content = Path(path).expanduser().read_text()
self.notifier.log(Panel.fit(Markdown(f"```\ncat {path}\n```"), box=box.MINIMAL))
return f"```{language}\n{content}\n```"
@tool
def shell(self, command: str) -> str:
"""
Execute a command on the shell (in OSX)
This will return the output and error concatenated into a single string, as
you would see from running on the command line. There will also be an indication
of if the command succeeded or failed.
Args:
command (str): The shell command to run. It can support multiline statements
if you need to run more than one at a time
"""
self.notifier.status("running shell command")
# Log the command being executed in a visually structured format (Markdown).
# The `.log` method is used here to log the command execution in the application's UX
# this method is dynamically attached to functions in the Goose framework to handle user-visible
# logging and integrates with the overall UI logging system
self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell"))
safety_rails_exchange = self.exchange_view.processor.replace(
system=Message.load("prompts/safety_rails.jinja").text
)
# remove the previous message which was a tool_use Assistant message
safety_rails_exchange.messages.pop()
safety_rails_exchange.add(Message.assistant(f"Here is the command I'd like to run: `{command}`"))
safety_rails_exchange.add(Message.user("Please provide the danger rating of that command"))
rating = safety_rails_exchange.reply().text
try:
rating = int(rating)
except ValueError:
rating = 5 # if we can't interpret we default to unsafe
if int(rating) > 3:
if not keep_unsafe_command_prompt(command):
raise RuntimeError(
f"The command {command} was rejected as dangerous by the user."
+ " Do not proceed further, instead ask for instructions."
)
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
if result.returncode == 0:
output = "Command succeeded"
else:
output = f"Command failed with returncode {result.returncode}"
return "\n".join([output, result.stdout, result.stderr])
@tool
def write_file(self, path: str, content: str) -> str:
"""
Write a file at the specified path with the provided content. This will create any directories if they do not exist.
The content will fully overwrite the existing file.
Args:
path (str): The destination file path, in the format "path/to/file.txt"
content (str): The raw file content.
""" # noqa: E501
self.notifier.status("writing file")
# Get the programming language for syntax highlighting in logs
language = get_language(path)
md = f"```{language}\n{content}\n```"
# Log the content that will be written to the file
# .log` method is used here to log the command execution in the application's UX
# this method is dynamically attached to functions in the Goose framework
self.notifier.log(Panel.fit(Markdown(md), title=path))
# Prepare the path and create any necessary parent directories
_path = Path(path)
_path.parent.mkdir(parents=True, exist_ok=True)
# Write the content to the file
_path.write_text(content)
# Return a success message
return f"Succesfully wrote to {path}"

View File

@@ -0,0 +1,11 @@
from exchange import Message
from goose.toolkit.base import Toolkit
class Github(Toolkit):
"""Provides an additional prompt on how to interact with Github"""
def system(self) -> str:
"""Retrieve detailed configuration and procedural guidelines for GitHub operations"""
return Message.load("prompts/github.jinja").text

View File

@@ -0,0 +1,62 @@
To start solving problems, you will always create a plan, and then immediately
execute the next step of the plan. Do not wait for confirmation on your plan,
always proceed to the next step. After each step, update the plan based
on the output you recieve from any actions you take, including marking
finished tasks as complete!
To accomplish this, you will call your tools such as update_plan, shell, or write.
Always use the update_plan tool before taking any new actions, to show the client
an up to date plan. The plan will be displayed automatically any time you update it.
The plan should consist of as few entries as possible, and translate from the user
request into concrete tasks that you will use to get it done. These should reflect
the actions you will need to take, such as writing files or executing shell commands.
For example, here's a plan to unstage all edited files in a git repo
{"description": "Use the git status command to find edited files", "status": "pending"}
{"description": "For each file with changes, call git restore on the file", "status": "pending"}
After running git status, you would update to
{"description": "Use the git status command to find edited files", "status": "complete"}
{"description": "For each file with changes, call git restore on the file", "status": "pending"}
Here's another plan example to get the sum of the "payment" column in data.csv
{"description": "Install pandas", "status": "pending"}
{"description": "Write a python script in the file sum.py that loads the csv in pandas and sums the column", "status": "pending"}
{"description": "Run the python script with bash", "status": "pending"}
If you were to encounter an error along the way, you can update the plan to specify a new approach.
Always call update_plan before any other tool calls! You should specify the whole plan upfront as pending,
and then update status at each step. **Do not describe the plan in your text response, only communicate
the plan through the tool**
If you need to manipulate files, always prefer the write file tool. To edit a file
that already exists, first check the content with the shell and then overwrite it
Some of the files that you work with will be long. When you want to edit a long file,
prefer to use the patch tool. Make sure that you always read the file using the read tool
before you call patch.
The patch and write tools can accomplish the same operations, but patch is a more complex tool.
The patch tool is worth it when the file is large enough that it would be tedious to fully rewrite.
You are an expert with ripgrep - `rg`. When you need to locate content in the code base, use
`rg` exclusively. It will respect ignored files for efficiency.
To locate files by name, use
```bash
rg --files | rg example.py
```
To locate content inside files, use
```bash
rg 'class Example'
```

View File

@@ -0,0 +1,18 @@
You can interact through github via the `gh` command line generally.
If it fails to auth, prompt the user to run `gh auth login`
Typically when someone requests you to look at a pull request review, they mean to view
not just the top level comments and reviews, but also the comments nested within that review.
To do that, you need to first use the API to get reviews:
```bash
gh api -H "Accept: application/vnd.github+json" /repos/OWNER/REPO/pulls/PULL_NUMBER/reviews
```
And then for each individual review, get all of the comments:
```bash
gh api -H 'Accept: application/vnd.github+json' /repos/OWNER/REPO/pulls/PULL_NUMBER/reviews/ID/comments
```
When you work with a pull request review, use the above approach as well as `gh diff`
to get the full details before answering questions.

View File

@@ -0,0 +1,39 @@
You are a third party auditor, with the goal of ensuring safety for commands
that will run on the user's system. You will be provided a conversation,
along with the next command that is queued to run. Evaluate if that command
is safe.
Here are examples of harmful commands:
1. Commands that delete files that weren't created in the session, like `rm -rf /`
2. Commands that edit dotfiles or generally anything at the top level of the home `~/.` directory
3. Commands that globally modify packages or the environment, such as `brew install`
However, it's safe to allow edits scoped to a project:
1. Commands that create new directories or new files, or scaffold new projects. `mkdir example`
2. Commands that manage the dependencies of the current project, such as virtualenvs, cargo, npm etc. `poetry add requests`
3. Commands that edit dotfiles in a project directory, like a .gitignore
You should rate the potential harm of a command on a scale of 1 to 5, with 1 being totally safe and 5 being very dangerous.
Reply with a single character, the numerical rating.
Here are some examples.
Assistant: Here is the command I'd like to run: `rm -rf ~`
User: Please provide the danger rating of that command
Assistant: 5
Assistant: Here is the command I'd like to run: `brew install ripgrep`
User: Please provide the danger rating of that command
Assistant: 4
Assistant: Here is the command I'd like to run: `rm tests/test_exchange.py`
User: Please provide the danger rating of that command
Assistant: 3
Assistant: Here is the command I'd like to run: `cat demo.py`
User: Please provide the danger rating of that command
Assistant: 1
Assistant: Here is the command I'd like to run: `echo "export PATH=$HOME/.local/bin/:$PATH" >> ~/.zprofile`
User: Please provide the danger rating of that command
Assistant: 5

View File

@@ -0,0 +1,39 @@
Given a dictionary of files and directories in a project repository, please identify which files should be retained
based solely on their relevance to the core processing code of the project. Exclude configuration-related files and other
non-code files, except for necessary Dockerfiles and Markdown files. You do not need to read or open the files or
directories. Just make an educated guess.
**Important:** Return the file and directory names exactly as they appear in the input list. Do not modify, alter, or
assume different names. Any suggested file or directory must match an entry in the input list.
Return ONLY a dictionary of relevant files and potentially relevant directories that need further inspection. NO
MARKDOWN NOTATION.
Example:
Input:
{
'files': [
'LICENSE.md',
'ARCHITECTURE.md',
'mkdocs.yml',
'justfile',
'CHANGELOG.md',
'pyproject.toml',
'README.md',
'CONTRIBUTING.md',
'poetry.lock'
],
'directories': ['bin', 'tests', 'docs', 'mlruns', 'scripts', 'src']
}
Output:
{
'files': [
'ARCHITECTURE.md',
'README.md',
],
'directories': ['tests', 'scripts', 'src']
}

View File

@@ -0,0 +1,106 @@
import os
from functools import cache
from subprocess import CompletedProcess, run
from typing import Dict, Tuple
from exchange import Message
from goose.notifier import Notifier
from goose.toolkit import Toolkit
from goose.toolkit.base import Requirements, tool
from goose.toolkit.repo_context.utils import get_repo_size, goose_picks_files
from goose.toolkit.summarization.utils import load_summary_file_if_exists, summarize_files_concurrent
from goose.utils.ask import clear_exchange, replace_prompt
class RepoContext(Toolkit):
def __init__(self, notifier: Notifier, requires: Requirements) -> None:
super().__init__(notifier=notifier, requires=requires)
self.repo_project_root, self.is_git_repo, self.goose_session_root = self.determine_git_proj()
def determine_git_proj(self) -> Tuple[str, bool, str]:
"""Determines the root as well as where Goose is currently running
If the project is not part of a Github repo, the root of the project will be defined as the current working
directory
Returns:
str: path to the root of the project (if part of a local repository) or the CWD if not
boolean: if Goose is operating within local repository or not
str: path to where the Goose session is running (the CWD)
"""
# FIXME: monorepos
cwd = os.getcwd()
command = "git rev-parse --show-toplevel"
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
if result.returncode == 0:
project_root = result.stdout.strip()
return project_root, True, cwd
else:
self.notifier.log("Not part of a Git repository. Returning current working directory")
return cwd, False, cwd
@property
@cache
def repo_size(self) -> float:
"""Returns the size of the repo in MB (if Goose detects its running in a local repository
This measurement can be used to guess if the local repository is a monorepo
Returns:
float: size of project in MB
"""
# in MB
if self.is_git_repo:
return get_repo_size(self.repo_project_root)
else:
self.notifier.log("Not a git repo. Returning 0.")
return 0.0
@property
def is_mono_repo(self) -> bool:
"""An boolean indicator of whether the local repository is part of a monorepo
Returns:
boolean: True if above 2000 MB; False otherwise
"""
# java: 6394.367112159729
# go: 3729.93 MB
return self.repo_size > 2000
@tool
def summarize_current_project(self) -> Dict[str, str]:
"""Summarizes the current project based on repo root (if git repo) or current project_directory (if not)
Returns:
summary (Dict[str, str]): Keys are file paths and values are the summaries
"""
self.notifier.log("Summarizing the most relevant files in the current project. This may take a while...")
if self.is_mono_repo:
self.notifier.log("This might be a monorepo. Goose performs better on smaller projects. Using CWD.")
# TODO: prompt user to specify a subdirectory
project_directory = self.goose_session_root
else:
project_directory = self.repo_project_root
# before selecting files and summarizing look for summarization file
project_name = project_directory.split("/")[-1]
summary = load_summary_file_if_exists(project_name=project_name)
if summary:
self.notifier.log("Summary file for project exists already -- loading into the context")
return summary
# clear exchange and replace the system prompt with instructions on why and how to select files to summarize
file_select_exchange = clear_exchange(self.exchange_view.accelerator, clear_tools=True)
system = Message.load("prompts/repo_context.jinja").text
file_select_exchange = replace_prompt(exchange=file_select_exchange, prompt=system)
files = goose_picks_files(root=project_directory, exchange=file_select_exchange)
summary = summarize_files_concurrent(
exchange=self.exchange_view.accelerator, file_list=files, project_name=project_directory.split("/")[-1]
)
return summary

View File

@@ -0,0 +1,104 @@
import ast
import concurrent.futures
import os
from collections import deque
from typing import Dict, List, Tuple
from exchange import Exchange
from goose.utils.ask import ask_an_ai
def get_directory_size(directory: str) -> int:
total_size = 0
for dirpath, _, filenames in os.walk(directory):
for f in filenames:
fp = os.path.join(dirpath, f)
# Skip if it is a symbolic link
if not os.path.islink(fp):
total_size += os.path.getsize(fp)
return total_size
def get_repo_size(repo_path: str) -> int:
"""Returns repo size in MB"""
git_dir = os.path.join(repo_path, ".git")
return get_directory_size(git_dir) / (1024**2)
def get_files_and_directories(root_dir: str) -> Dict[str, list]:
"""Gets file names and directory names. Checks that goose has correctly typed the file and directory names and that
the files actually exist (to avoid downstream file read errors).
Args:
root_dir (str): Path to the directory to examine for files and sub-directories
Returns:
dict: A list of files and directories in the form {'files': [], 'directories: []}. Paths
are all relative (i.e. ['src'] not ['goose/src'])
"""
files = []
dirs = []
# check dir exists
try:
os.listdir(root_dir)
except FileNotFoundError:
# FIXME: fuzzy match might work here to recover directories 'lost' to goose mistyping
# hallucination: Goose mistyped the path (e.g. `metrichandler` vs `metricshandler`)
return {"files": files, "directories": dirs}
for entry in os.listdir(root_dir):
if entry.startswith(".") or entry.startswith("~"):
continue # Skip hidden files and directories
full_path = os.path.join(root_dir, entry)
if os.path.isdir(full_path):
dirs.append(entry)
elif os.path.isfile(full_path):
files.append(entry)
return {"files": files, "directories": dirs}
def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> List[str]:
"""Lets goose pick files in a BFS manner"""
queue = deque([root])
all_files = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
while queue:
current_batch = [queue.popleft() for _ in range(min(max_workers, len(queue)))]
futures = {executor.submit(process_directory, dir, exchange): dir for dir in current_batch}
for future in concurrent.futures.as_completed(futures):
files, next_dirs = future.result()
all_files.extend(files)
queue.extend(next_dirs)
return all_files
def process_directory(current_dir: str, exchange: Exchange) -> Tuple[List[str], List[str]]:
"""Allows goose to pick files and subdirectories contained in a given directory (current_dir). Get the list of file
and directory names in the current folder, then ask Goose to pick which ones to keep.
"""
files_and_dirs = get_files_and_directories(current_dir)
ai_response = ask_an_ai(str(files_and_dirs), exchange)
# FIXME: goose response validation
try:
as_dict = ast.literal_eval(ai_response.text)
except Exception:
# can happen if goose returns anything but {result: dict} (e.g. ```json\n {results: dict} \n```)
return [], []
if not isinstance(as_dict, dict):
# can happen if goose returns something like `{'files': ['x.py'] 'directories': ['dir1']}` (missing comma)
return [], []
files = [f"{current_dir}/{file}" for file in as_dict.get("files", [])]
next_dirs = [f"{current_dir}/{next_dir}" for next_dir in as_dict.get("directories", [])]
return files, next_dirs

View File

@@ -0,0 +1,28 @@
import subprocess
import uuid
from goose.toolkit.base import Toolkit, tool
class Screen(Toolkit):
"""Provides an instructions on when and how to work with screenshots"""
@tool
def take_screenshot(self) -> str:
"""
Take a screenshot to assist the user in debugging or designing an app. They may need help with moving a screen element, or interacting in some way where you could do with seeing the screen.
Return:
(str) a path to the screenshot file, in the format of image: followed by the path to the file.
""" # noqa: E501
# Generate a random tmp filename for screenshot
filename = f"/tmp/goose_screenshot_{uuid.uuid4().hex}.png"
subprocess.run(["screencapture", "-x", filename])
return f"image:{filename}"
# Provide any system instructions for the model
# This can be generated dynamically, and is run at startup time
def system(self) -> str:
return """**When the user wants you to help debug, or work on a visual design by looking at their screen, IDE or browser, call the take_screenshot and send the output from the user.**""" # noqa: E501

View File

@@ -0,0 +1,3 @@
from .summarize_repo import SummarizeRepo # noqa
from .summarize_project import SummarizeProject # noqa
from .summarize_file import SummarizeFile # noqa

View File

@@ -0,0 +1,28 @@
from typing import Optional
from goose.toolkit import Toolkit
from goose.toolkit.base import tool
from goose.toolkit.summarization.utils import summarize_file
class SummarizeFile(Toolkit):
@tool
def summarize_file(self, filepath: str, prompt: Optional[str] = None) -> str:
"""
Tool to summarize a specific file
Args:
filepath (str): Path to the file to summarize
prompt (str): Optional prompt giving the model instructions on how to summarize the file.
Under the hood this defaults to "Please summarize this file"
Returns:
summary (Optional[str]): Summary of the file contents
"""
exchange = self.exchange_view.accelerator
_, summary = summarize_file(filepath=filepath, exchange=exchange, prompt=prompt)
return summary

View File

@@ -0,0 +1,37 @@
import os
from typing import List, Optional
from goose.toolkit import Toolkit
from goose.toolkit.base import tool
from goose.toolkit.summarization.utils import summarize_directory
class SummarizeProject(Toolkit):
@tool
def get_project_summary(
self,
project_dir_path: Optional[str] = os.getcwd(),
extensions: Optional[List[str]] = None,
summary_instructions_prompt: Optional[str] = None,
) -> dict:
"""Generates or retrieves a project summary based on specified file extensions.
Args:
project_dir_path (Optional[Path]): Path to the project directory. Defaults to the current working directory
if None
extensions (Optional[List[str]]): Specific file extensions to summarize.
summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g.
"Summarize the file in two sentences.". The default instruction is "Please summarize this file."
Returns:
summary (dict): Project summary.
""" # noqa: E501
summary = summarize_directory(
project_dir_path,
exchange=self.exchange_view.accelerator,
extensions=extensions,
summary_instructions_prompt=summary_instructions_prompt,
)
return summary

View File

@@ -0,0 +1,37 @@
from typing import List, Optional
from goose.toolkit import Toolkit
from goose.toolkit.base import tool
from goose.toolkit.summarization.utils import summarize_repo
class SummarizeRepo(Toolkit):
@tool
def summarize_repo(
self,
repo_url: str,
specified_extensions: Optional[List[str]] = None,
summary_instructions_prompt: Optional[str] = None,
) -> dict:
"""
Retrieves a summary of a repository. Clones the repository if not already cloned and summarizes based on the
specified file extensions. If no extensions are specified, it summarizes the top `max_extensions` extensions.
Args:
repo_url (str): The URL of the repository to summarize.
specified_extensions (Optional[List[str]]): List of file extensions to summarize, e.g., ["tf", "md"]. If
this list is empty, then all files in the repo are summarized
summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g.
"Summarize the file in two sentences.". The default instruction is "Please summarize this file."
Returns:
summary (dict): A summary of the repository where keys are the file extensions and values are their
summaries.
""" # noqa: E501
return summarize_repo(
repo_url=repo_url,
exchange=self.exchange_view.accelerator,
extensions=specified_extensions,
summary_instructions_prompt=summary_instructions_prompt,
)

View File

@@ -0,0 +1,199 @@
import json
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from exchange import Exchange
from exchange.providers.utils import InitialMessageTooLargeError
from goose.utils.ask import ask_an_ai
from goose.utils.file_utils import create_file_list
SUMMARIES_FOLDER = ".goose/summaries"
CLONED_REPOS_FOLDER = ".goose/cloned_repos"
# TODO: move git stuff
def run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]:
result = subprocess.run(["git"] + command, capture_output=True, text=True, check=False)
if result.returncode != 0:
raise Exception(f"Git command failed with message: {result.stderr.strip()}")
return result
def clone_repo(repo_url: str, target_directory: str) -> None:
run_git_command(["clone", repo_url, target_directory])
def load_summary_file_if_exists(project_name: str) -> Optional[Dict]:
"""Checks if a summary file exists at '.goose/summaries/projectname-summary.json. Returns contents of the file if
it exists, otherwise returns None
Args:
project_name (str): name of the project or repo
Returns:
Optional[Dict]: File contents, else None
"""
summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json"
if Path(summary_file_path).exists():
with open(summary_file_path, "r") as f:
return json.load(f)
def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> Tuple[str, str]:
"""Summarizes a single file
Args:
filepath (str): Path to the file to summarize.
exchange (Exchange): Exchange object to use for summarization.
prompt (Optional[str]): Defaults to "Please summarize this file."
"""
try:
with open(filepath, "r") as f:
file_text = f.read()
except Exception as e:
return filepath, f"Error reading file {filepath}: {str(e)}"
if not file_text:
return filepath, "Empty file"
try:
reply = ask_an_ai(
input=file_text, exchange=exchange, prompt=prompt if prompt else "Please summarize this file."
)
except InitialMessageTooLargeError:
return filepath, "File too large"
return filepath, reply.text
def summarize_repo(
repo_url: str,
exchange: Exchange,
extensions: List[str],
summary_instructions_prompt: Optional[str] = None,
) -> Dict[str, str]:
"""Clones (if needed) and summarizes a repo
Args:
repo_url (str): Repository url
exchange (Exchange): Exchange for summarizing the repo.
extensions (List[str]): List of file-types to summarize.
summary_instructions_prompt (Optional[str]): Optional parameter to customize summarization results. Defaults to
"Please summarize this file"
"""
# set up the paths for the repository and the summary file
repo_name = repo_url.split("/")[-1]
repo_dir = f"{CLONED_REPOS_FOLDER}/{repo_name}" # e.g. '.goose/cloned_repos/<project-name>'
if Path(repo_dir).exists():
# TODO: re-add ability to log
return summarize_directory(
directory=repo_dir,
exchange=exchange,
extensions=extensions,
summary_instructions_prompt=summary_instructions_prompt,
)
clone_repo(repo_url, target_directory=repo_dir)
return summarize_directory(
directory=repo_dir,
exchange=exchange,
extensions=extensions,
summary_instructions_prompt=summary_instructions_prompt,
)
def summarize_directory(
directory: str, exchange: Exchange, extensions: List[str], summary_instructions_prompt: Optional[str] = None
) -> Dict[str, str]:
"""Summarize files in a given directory based on extensions. Will also recursively find files in subdirectories and
summarize them.
Args:
directory (str): path to the top-level directory to summarize
exchange (Exchange): Exchange to use to summarize
extensions (List[str]): List of file-type extensions to summarize (and ignore all other extensions).
summary_instructions_prompt (Optional[str]): Optional instructions to give to the exchange regarding summarization.
Returns:
file_summaries (dict): Keys are file names and values are summaries.
""" # noqa: E501
# TODO: make sure that '.goose/summaries' is
# in the root of the current not relative to current dir or in cloned repo root
project_name = directory.split("/")[-1]
summary_file = load_summary_file_if_exists(project_name)
if summary_file:
return summary_file
summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json"
# create the .goose/summaries folder if not already created
Path(SUMMARIES_FOLDER).mkdir(exist_ok=True, parents=True)
# select a subset of files to summarize based on file extension
files_to_summarize = create_file_list(directory, extensions=extensions)
file_summaries = summarize_files_concurrent(
exchange=exchange,
file_list=files_to_summarize,
project_name=project_name,
summary_instructions_prompt=summary_instructions_prompt,
)
summary_file_contents = {"extensions": extensions, "summaries": file_summaries}
# Write the summaries into a json
with open(summary_file_path, "w") as f:
json.dump(summary_file_contents, f, indent=2)
return file_summaries
def summarize_files_concurrent(
exchange: Exchange, file_list: List[str], project_name: str, summary_instructions_prompt: Optional[str] = None
) -> Dict[str, str]:
"""Takes in a list of files and summarizes them. Exchange does not keep history of the summarized files.
Args:
exchange (Exchange): Underlying exchange
file_list (List[str]): List of paths to files to summarize
project_name (str): Used to save the summary of the files to .goose/summaries/<project_name>-summary.json
summary_instructions_prompt (Optional[str]): Summary instructions for the LLM. Defaults to "Please summarize
this file."
Returns:
file_summaries (Dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange
"""
summary_file = load_summary_file_if_exists(project_name)
if summary_file:
return summary_file
file_summaries = {}
# compile the individual file summaries into a single summary dict
# TODO: add progress bar as this step can take quite some time and it's nice to see something is happening
with ThreadPoolExecutor() as executor:
future_to_file = {
executor.submit(summarize_file, file, exchange, summary_instructions_prompt): file for file in file_list
}
for future in as_completed(future_to_file):
file_name, file_summary = future.result()
file_summaries[file_name] = file_summary
# create summaries folder if it doesn't exist
Path(SUMMARIES_FOLDER).mkdir(exist_ok=True, parents=True)
summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json"
# Write the summaries into a json
with open(summary_file_path, "w") as f:
json.dump(file_summaries, f, indent=2)
return file_summaries

View File

@@ -0,0 +1,21 @@
from pathlib import Path
from pygments.lexers import get_lexer_for_filename
from pygments.util import ClassNotFound
def get_language(filename: Path) -> str:
"""
Determine the programming language of a file based on its filename extension.
Args:
filename (str): The name of the file for which to determine the programming language.
Returns:
str: The name of the programming language if recognized, otherwise an empty string.
"""
try:
lexer = get_lexer_for_filename(filename)
return lexer.name
except ClassNotFound:
return ""

View File

@@ -0,0 +1,70 @@
import random
import string
from importlib.metadata import entry_points
from typing import Any, Callable, Dict, List, Type, TypeVar
T = TypeVar("T")
def load_plugins(group: str) -> dict:
"""
Load plugins based on a specified entry point group.
This function iterates through all entry points registered under a specified group
Args:
group (str): The entry point group to load plugins from. This should match the group specified
in the package setup where plugins are defined.
Returns:
dict: A dictionary where each key is the entry point name, and the value is the loaded plugin object.
Raises:
Exception: Propagates exceptions raised by entry point loading, which might occur if a plugin
is not found or if there are issues with the plugin's code.
"""
plugins = {}
# Access all entry points for the specified group and load each.
for entrypoint in entry_points(group=group):
plugin = entrypoint.load() # Load the plugin.
plugins[entrypoint.name] = plugin # Store the loaded plugin in the dictionary.
return plugins
def ensure(cls: Type[T]) -> Callable[[Any], T]:
"""Convert dictionary to a class instance"""
def converter(val: Any) -> T: # noqa: ANN401
if isinstance(val, cls):
return val
elif isinstance(val, dict):
return cls(**val)
elif isinstance(val, list):
return cls(*val)
else:
return cls(val)
return converter
def ensure_list(cls: Type[T]) -> Callable[[List[Dict[str, Any]]], Type[T]]:
"""Convert a list of dictionaries to class instances"""
def converter(val: List[Dict[str, Any]]) -> List[T]:
output = []
for entry in val:
output.append(ensure(cls)(entry))
return output
return converter
def droid() -> str:
return "".join(
[
random.choice(string.ascii_lowercase),
random.choice(string.digits),
random.choice(string.ascii_lowercase),
random.choice(string.digits),
]
)

82
src/goose/utils/ask.py Normal file
View File

@@ -0,0 +1,82 @@
from exchange import Exchange, Message
def ask_an_ai(input: str, exchange: Exchange, prompt: str = "", no_history: bool = True) -> Message:
"""Sends a separate message to an LLM using a separate Exchange than the one underlying the Goose session.
Can be used to summarize a file, or submit any other request that you'd like to an AI. The Exchange can have a
history/prior context, or be wiped clean (by setting no_history to True).
Parameters:
input (str): The user's input string to be processed by the AI. Must be a non-empty string. Example: text from
a file.
exchange (Exchange): An object representing the AI exchange system which manages the state and flow of the
conversation.
prompt (str, optional): An optional new prompt to replace the current one in the exchange system. Defaults to
None. Example: "Please summarize this file."
no_history (bool, optional): A flag to determine if the conversation history should be cleared before
processing the new input. True clears the context, False retains it. Defaults to True.
Returns:
reply (str): The AI's reply as a string.
Raises:
TypeError: If the `input` is not a non-empty string.
Exception: If there is an issue within the exchange system, including errors from the provider or model.
Example:
# Create an instance of an Exchange system
exchange_system = Exchange(provider=OpenAIProvider.from_env(), model="gpt-4")
# Simulate asking the AI a question
response = ask_an_ai("What is the weather today?", exchange_system)
print(response) # Outputs the AI's response to the question.
"""
if no_history:
exchange = clear_exchange(exchange)
if prompt:
exchange = replace_prompt(exchange, prompt)
if not input:
raise TypeError("`input` must be a string of finite length")
msg = Message.user(input)
exchange.add(msg)
reply = exchange.reply()
return reply
def clear_exchange(exchange: Exchange, clear_tools: bool = False) -> Exchange:
"""Clears the exchange object
Args:
exchange (Exchange): Exchange object to be overwritten. Messages and checkpoints are replaced with empty lists.
clear_tools (bool): Boolean to indicate whether tools should be dropped from the exchange.
Returns:
new_exchange (Exchange)
"""
if clear_tools:
new_exchange = exchange.replace(messages=[], checkpoints=[], tools=())
else:
new_exchange = exchange.replace(messages=[], checkpoints=[])
return new_exchange
def replace_prompt(exchange: Exchange, prompt: str) -> Exchange:
"""Replaces the system prompt
Args:
exchange (Exchange): Exchange object to be overwritten. Messages and checkpoints are replaced with empty lists.
prompt (str): The system prompt.
Returns:
new_exchange (Exchange)
"""
new_exchange = exchange.replace(system=prompt)
return new_exchange

39
src/goose/utils/diff.py Normal file
View File

@@ -0,0 +1,39 @@
from typing import List
from rich.text import Text
def diff(a: str, b: str) -> List[str]:
"""Returns a string containing the unified diff of two strings."""
import difflib
a_lines = a.splitlines()
b_lines = b.splitlines()
# Create a Differ object
d = difflib.Differ()
# Generate the diff
diff = list(d.compare(a_lines, b_lines))
return diff
def pretty_diff(a: str, b: str) -> Text:
"""Returns a pretty-printed diff of two strings."""
diff_lines = diff(a, b)
result = Text()
for line in diff_lines:
if line.startswith("+"):
result.append(line, style="green")
elif line.startswith("-"):
result.append(line, style="red")
elif line.startswith("?"):
result.append(line, style="yellow")
else:
result.append(line, style="dim grey")
result.append("\n")
return result

View File

@@ -0,0 +1,103 @@
import glob
import os
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional
def create_extensions_list(project_root: str, max_n: int) -> list:
"""Get the top N file extensions in the current project
Args:
project_root (str): Root of the project to analyze
max_n (int): The number of file extensions to return
Returns:
extensions (List[str]): A list of the top N file extensions
"""
if max_n == 0:
raise (ValueError("Number of file extensions must be greater than 0"))
files = create_file_list(project_root, [])
counter = Counter()
for file in files:
file_path = Path(file)
if file_path.suffix: # omit ''
counter[file_path.suffix] += 1
top_n = counter.most_common(max_n)
extensions = [ext for ext, _ in top_n]
return extensions
def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]:
"""Calculate language weighting by file size to match GitHub's methodology.
Args:
files_in_directory (List[str]): Paths to files in the project directory
Returns:
Dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values
"""
# Initialize counters for sizes
size_by_language = Counter()
# Calculate size for files by language
for file_path in files_in_directory:
path = Path(file_path)
if path.suffix:
size_by_language[path.suffix] += os.path.getsize(file_path)
# Calculate total size and language percentages
total_size = sum(size_by_language.values())
language_percentages = {
lang: (size / total_size * 100) if total_size else 0 for lang, size in size_by_language.items()
}
return dict(sorted(language_percentages.items(), key=lambda item: item[1], reverse=True))
def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> List[str]:
"""List all files in a directory with a given extension. Set extension to '' to return all files.
Args:
dir_path (str): The path to the directory
extension (Optional[str]): extension to lookup. Defaults to '' which will return all files.
Returns:
files (List[str]): List of file paths
"""
# add a leading '.' to extension if needed
if extension and not extension.startswith("."):
extension = f".{extension}"
files = glob.glob(f"{dir_path}/**/*{extension}", recursive=True)
return files
def create_file_list(dir_path: str, extensions: List[str]) -> List[str]:
"""Creates a list of files with certain extensions
Args:
dir_path (str): Directory to list files of. Will include files recursively in sub-directories.
extensions (List[str]): List of file extensions to select for. If empty list, return all files
Returns:
final_file_list (List[str]): List of file paths with specified extensions.
"""
# if extensions is empty list, return all files
if not extensions:
return glob.glob(f"{dir_path}/**/*", recursive=True)
# prune out files that do not end with any of the extensions in extensions
final_file_list = []
for ext in extensions:
if ext and not ext.startswith("."):
ext = f".{ext}"
files = glob.glob(f"{dir_path}/**/*{ext}", recursive=True)
final_file_list += files
return final_file_list

View File

@@ -0,0 +1,39 @@
import json
from pathlib import Path
from typing import Dict, Iterator, List
from exchange import Message
from goose.cli.config import SESSION_FILE_SUFFIX
def write_to_file(file_path: Path, messages: List[Message]) -> None:
with open(file_path, "w") as f:
for m in messages:
json.dump(m.to_dict(), f)
f.write("\n")
def read_from_file(file_path: Path) -> List[Message]:
try:
with open(file_path, "r") as f:
messages = [json.loads(m) for m in list(f) if m.strip()]
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to load session due to JSON decode Error: {e}")
return [Message(**m) for m in messages]
def list_sorted_session_files(session_files_directory: Path) -> Dict[str, Path]:
logs = list_session_files(session_files_directory)
return {log.stem: log for log in sorted(logs, key=lambda x: x.stat().st_mtime, reverse=True)}
def list_session_files(session_files_directory: Path) -> Iterator[Path]:
return session_files_directory.glob(f"*{SESSION_FILE_SUFFIX}")
def session_file_exists(session_files_directory: Path) -> bool:
if not session_files_directory.exists():
return False
return any(list_session_files(session_files_directory))

26
src/goose/view.py Normal file
View File

@@ -0,0 +1,26 @@
from attrs import define
from exchange import Exchange
@define
class ExchangeView:
"""A read-only view of the underlying Exchange
Attributes:
processor: A copy of the exchange configured for high capabilities
accelerator: A copy of the exchange configured for high speed
"""
_processor: str
_accelerator: str
_exchange: Exchange
@property
def processor(self) -> Exchange:
return self._exchange.replace(model=self._processor)
@property
def accelerator(self) -> Exchange:
return self._exchange.replace(model=self._accelerator)

2
tests/.ruff.toml Normal file
View File

@@ -0,0 +1,2 @@
lint.select = ["E", "W", "F", "N"]
line-length = 120

View File

@@ -0,0 +1,47 @@
from unittest.mock import patch
import pytest
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput
@pytest.fixture
def mock_prompt_session():
with patch("prompt_toolkit.PromptSession") as mock_prompt_session:
yield mock_prompt_session
def test_get_save_session_name(mock_prompt_session):
mock_prompt_session.prompt.return_value = "my_session"
goose_prompt_session = GoosePromptSession(mock_prompt_session)
assert goose_prompt_session.get_save_session_name() == "my_session"
def test_get_user_input_to_continue(mock_prompt_session):
mock_prompt_session.prompt.return_value = "input_value"
goose_prompt_session = GoosePromptSession(mock_prompt_session)
user_input = goose_prompt_session.get_user_input()
assert user_input == UserInput(PromptAction.CONTINUE, "input_value")
@pytest.mark.parametrize("exit_input", ["exit", ":q"])
def test_get_user_input_to_exit(exit_input, mock_prompt_session):
mock_prompt_session.prompt.return_value = exit_input
goose_prompt_session = GoosePromptSession(mock_prompt_session)
user_input = goose_prompt_session.get_user_input()
assert user_input == UserInput(PromptAction.EXIT)
@pytest.mark.parametrize("error", [EOFError, KeyboardInterrupt])
def test_get_user_input_to_exit_when_error_occurs(error, mock_prompt_session):
mock_prompt_session.prompt.side_effect = error
goose_prompt_session = GoosePromptSession(mock_prompt_session)
user_input = goose_prompt_session.get_user_input()
assert user_input == UserInput(PromptAction.EXIT)

View File

@@ -0,0 +1,253 @@
from goose.cli.prompt.lexer import (
PromptLexer,
command_itself,
completion_for_command,
value_for_command,
)
from prompt_toolkit.document import Document
# Helper function to create a Document and lexer instance
def create_lexer_and_document(commands, text):
lexer = PromptLexer(commands)
document = Document(text)
return lexer, document
# Test cases
def test_lex_document_command():
lexer, document = create_lexer_and_document(["file"], "/file:example.txt")
tokens = lexer.lex_document(document)
expected_tokens = [("class:command", "/file:"), ("class:parameter", "example.txt")]
assert tokens(0) == expected_tokens
def test_lex_document_partial_command():
lexer, document = create_lexer_and_document(["file"], "/fi")
tokens = lexer.lex_document(document)
expected_tokens = [("class:command", "/fi")]
assert tokens(0) == expected_tokens
def test_lex_document_with_text():
lexer, document = create_lexer_and_document(["file"], "Some text /file:example.txt")
tokens = lexer.lex_document(document)
expected_tokens = [
("class:text", "S"),
("class:text", "o"),
("class:text", "m"),
("class:text", "e"),
("class:text", " "),
("class:text", "t"),
("class:text", "e"),
("class:text", "x"),
("class:text", "t"),
("class:text", " "),
("class:command", "/file:"),
("class:parameter", "example.txt"),
]
assert tokens(0) == expected_tokens
def test_lex_document_with_command_in_middle():
lexer, document = create_lexer_and_document(["file"], "Some text /file:example.txt more text")
tokens = lexer.lex_document(document)
expected_tokens = [
("class:text", "S"),
("class:text", "o"),
("class:text", "m"),
("class:text", "e"),
("class:text", " "),
("class:text", "t"),
("class:text", "e"),
("class:text", "x"),
("class:text", "t"),
("class:text", " "),
("class:command", "/file:"),
("class:parameter", "example.txt"),
("class:text", " "),
("class:text", "m"),
("class:text", "o"),
("class:text", "r"),
("class:text", "e"),
("class:text", " "),
("class:text", "t"),
("class:text", "e"),
("class:text", "x"),
("class:text", "t"),
]
actual_tokens = list(tokens(0))
assert actual_tokens == expected_tokens
def test_lex_document_multiple_commands():
lexer, document = create_lexer_and_document(
["command", "anothercommand"],
"/command:example1.txt more text /anothercommand:example2.txt",
)
tokens = lexer.lex_document(document)
expected_tokens = [
("class:command", "/command:"),
("class:parameter", "example1.txt"),
("class:text", " "),
("class:text", "m"),
("class:text", "o"),
("class:text", "r"),
("class:text", "e"),
("class:text", " "),
("class:text", "t"),
("class:text", "e"),
("class:text", "x"),
("class:text", "t"),
("class:text", " "),
("class:command", "/anothercommand:"),
("class:parameter", "example2.txt"),
]
actual_tokens = list(tokens(0))
assert actual_tokens == expected_tokens
def test_lex_document_multiple_same_commands():
lexer, document = create_lexer_and_document(
["command"],
"/command:example1.txt more text /command:example2.txt",
)
tokens = lexer.lex_document(document)
expected_tokens = [
("class:command", "/command:"),
("class:parameter", "example1.txt"),
("class:text", " "),
("class:text", "m"),
("class:text", "o"),
("class:text", "r"),
("class:text", "e"),
("class:text", " "),
("class:text", "t"),
("class:text", "e"),
("class:text", "x"),
("class:text", "t"),
("class:text", " "),
("class:command", "/command:"),
("class:parameter", "example2.txt"),
]
actual_tokens = list(tokens(0))
assert actual_tokens == expected_tokens
def test_lex_document_two_half_commands():
lexer, document = create_lexer_and_document(
["command"],
"/comma /com",
)
tokens = lexer.lex_document(document)
expected_tokens = [
("class:text", "/"),
("class:text", "c"),
("class:text", "o"),
("class:text", "m"),
("class:text", "m"),
("class:text", "a"),
("class:text", " "),
("class:command", "/com"),
]
actual_tokens = list(tokens(0))
assert actual_tokens == expected_tokens
def test_lex_document_command_attached_to_pre_string():
lexer, document = create_lexer_and_document(
["command"],
"some/command:example.txt",
)
expected_tokens = [
("class:text", "s"),
("class:text", "o"),
("class:text", "m"),
("class:text", "e"),
("class:text", "/"),
("class:text", "c"),
("class:text", "o"),
("class:text", "m"),
("class:text", "m"),
("class:text", "a"),
("class:text", "n"),
("class:text", "d"),
("class:text", ":"),
("class:text", "e"),
("class:text", "x"),
("class:text", "a"),
("class:text", "m"),
("class:text", "p"),
("class:text", "l"),
("class:text", "e"),
("class:text", "."),
("class:text", "t"),
("class:text", "x"),
("class:text", "t"),
]
tokens = lexer.lex_document(document)
actual_tokens = list(tokens(0))
assert actual_tokens == expected_tokens
def test_lex_document_partial_command_attached_to_pre_string():
lexer, document = create_lexer_and_document(
["command"],
"some/com",
)
tokens = lexer.lex_document(document)
expected_tokens = [
("class:text", "s"),
("class:text", "o"),
("class:text", "m"),
("class:text", "e"),
("class:text", "/"),
("class:text", "c"),
("class:text", "o"),
("class:text", "m"),
]
actual_tokens = list(tokens(0))
assert actual_tokens == expected_tokens
def test_lex_document_no_command():
lexer, document = create_lexer_and_document([], "Some random text")
tokens = lexer.lex_document(document)
expected_tokens = [("class:text", character) for character in "Some random text"]
actual_tokens = list(tokens(0))
assert actual_tokens == expected_tokens
def test_lex_document_ending_char_of_parameter_is_symbol():
lexer, document = create_lexer_and_document(
["command"],
"/command:example.txt/",
)
expected_tokens = [
("class:command", "/command:"),
("class:parameter", "example.txt/"),
]
tokens = lexer.lex_document(document)
actual_tokens = list(tokens(0))
assert actual_tokens == expected_tokens
def test_command_itself():
pattern = command_itself("file:")
matches = pattern.match("/file:example.txt")
assert matches is not None
assert matches.group(1) == "/file:"
def test_value_for_command():
pattern = value_for_command("file:")
matches = pattern.search("/file:example.txt")
assert matches is not None
assert matches.group(1) == "example.txt"
def test_completion_for_command():
pattern = completion_for_command("file:")
matches = pattern.search("/file:")
assert matches is not None
assert matches.group(1) == "file:"

View File

@@ -0,0 +1,37 @@
from unittest.mock import MagicMock, patch
import pytest
from goose.cli.prompt.prompt_validator import PromptValidator
from prompt_toolkit.validation import ValidationError
@pytest.fixture
def validator():
return PromptValidator()
@patch("prompt_toolkit.document.Document.text")
def test_validate_should_not_raise_error_when_input_is_none(document, validator):
try:
validator.validate(create_mock_document(None))
except Exception as e:
pytest.fail(f"An error was raised: {e}")
@patch("prompt_toolkit.document.Document.text", return_value="user typed something")
def test_validate_should_not_raise_error_when_user_has_input(document, validator):
try:
validator.validate(create_mock_document("user typed something"))
except Exception as e:
pytest.fail(f"An error was raised: {e}")
def test_validate_should_raise_validation_error_when_user_has_empty_input(validator):
with pytest.raises(ValidationError):
validator.validate(create_mock_document(""))
def create_mock_document(text: str) -> MagicMock:
document = MagicMock()
document.text = text
return document

View File

@@ -0,0 +1,15 @@
from goose.cli.prompt.user_input import PromptAction, UserInput
def test_user_input_with_action_continue():
input = UserInput(action=PromptAction.CONTINUE, text="Hello")
assert input.to_continue() is True
assert input.to_exit() is False
assert input.text == "Hello"
def test_user_input_with_action_exit():
input = UserInput(action=PromptAction.EXIT)
assert input.to_continue() is False
assert input.to_exit() is True
assert input.text is None

81
tests/cli/test_config.py Normal file
View File

@@ -0,0 +1,81 @@
from unittest.mock import patch
import pytest
from goose.cli.config import ensure_config, read_config, session_path, write_config
from goose.profile import default_profile
@pytest.fixture
def mock_profile_config_path(tmp_path):
with patch("goose.cli.config.PROFILES_CONFIG_PATH", tmp_path / "profiles.yaml") as mock_path:
yield mock_path
@pytest.fixture
def mock_default_model_configuration():
with patch(
"goose.cli.config.default_model_configuration", return_value=("provider", "processor", "accelerator")
) as mock_default_model_configuration:
yield mock_default_model_configuration
def test_read_write_config(mock_profile_config_path, profile_factory):
profiles = {
"profile1": profile_factory({"provider": "providerA"}),
}
write_config(profiles)
assert read_config() == profiles
def test_ensure_config_create_profiles_file_with_default_profile(
mock_profile_config_path, mock_default_model_configuration
):
assert not mock_profile_config_path.exists()
ensure_config(name="default")
assert mock_profile_config_path.exists()
assert read_config() == {"default": default_profile(*mock_default_model_configuration())}
def test_ensure_config_add_default_profile(mock_profile_config_path, profile_factory, mock_default_model_configuration):
existing_profile = profile_factory({"provider": "providerA"})
write_config({"profile1": existing_profile})
ensure_config(name="default")
assert read_config() == {
"profile1": existing_profile,
"default": default_profile(*mock_default_model_configuration()),
}
@patch("goose.cli.config.Confirm.ask", return_value=True)
def test_ensure_config_overwrite_default_profile(
mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration
):
existing_profile = profile_factory({"provider": "providerA"})
profile_name = "default"
write_config({profile_name: existing_profile})
expected_default_profile = default_profile(*mock_default_model_configuration())
assert ensure_config(name="default") == expected_default_profile
assert read_config() == {"default": expected_default_profile}
@patch("goose.cli.config.Confirm.ask", return_value=False)
def test_ensure_config_keep_original_default_profile(
mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration
):
existing_profile = profile_factory({"provider": "providerA"})
profile_name = "default"
write_config({profile_name: existing_profile})
assert ensure_config(name="default") == existing_profile
assert read_config() == {"default": existing_profile}
def test_session_path(mock_sessions_path):
assert session_path("session1") == mock_sessions_path / "session1.jsonl"

80
tests/cli/test_main.py Normal file
View File

@@ -0,0 +1,80 @@
from datetime import datetime
from time import time
from unittest.mock import MagicMock, patch
import pytest
from click.testing import CliRunner
from exchange import Message
from goose.cli.main import goose_cli
@pytest.fixture
def mock_print():
with patch("goose.cli.main.print") as mock_print:
yield mock_print
@pytest.fixture
def mock_session_files_path(tmp_path):
with patch("goose.cli.main.SESSIONS_PATH", tmp_path) as session_files_path:
yield session_files_path
@pytest.fixture
def mock_session():
with patch("goose.cli.main.Session") as mock_session_class:
mock_session_instance = MagicMock()
mock_session_class.return_value = mock_session_instance
yield mock_session_class, mock_session_instance
def test_session_resume_command_with_session_name(mock_session):
mock_session_class, mock_session_instance = mock_session
runner = CliRunner()
runner.invoke(goose_cli, ["session", "resume", "session1", "--profile", "default"])
mock_session_class.assert_called_once_with(name="session1", profile="default")
mock_session_instance.run.assert_called_once()
def test_session_resume_command_without_session_name_without_session_files(
mock_print, mock_session_files_path, mock_session
):
_, mock_session_instance = mock_session
runner = CliRunner()
runner.invoke(goose_cli, ["session", "resume"])
mock_print.assert_called_with("No sessions found.")
mock_session_instance.run.assert_not_called()
def test_session_resume_command_without_session_name_use_latest_session(
mock_print, mock_session_files_path, mock_session, create_session_file
):
mock_session_class, mock_session_instance = mock_session
for index, session_name in enumerate(["first", "second"]):
create_session_file([Message.user("Hello1")], mock_session_files_path / f"{session_name}.jsonl", time() + index)
runner = CliRunner()
runner.invoke(goose_cli, ["session", "resume", "--profile", "default"])
second_file_path = mock_session_files_path / "second.jsonl"
mock_print.assert_called_once_with(f"Resuming most recent session: second from {second_file_path}")
mock_session_class.assert_called_once_with(name="second", profile="default")
mock_session_instance.run.assert_called_once()
def test_session_list_command(mock_print, mock_session_files_path, create_session_file):
create_session_file([Message.user("Hello")], mock_session_files_path / "abc.jsonl")
runner = CliRunner()
runner.invoke(goose_cli, ["session", "list"])
file_time = datetime.fromtimestamp(mock_session_files_path.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")
mock_print.assert_called_with(f"{file_time} abc")
def test_session_clear_command(mock_session_files_path, create_session_file):
for index, session_name in enumerate(["first", "second"]):
create_session_file([Message.user("Hello1")], mock_session_files_path / f"{session_name}.jsonl", time() + index)
runner = CliRunner()
runner.invoke(goose_cli, ["session", "clear", "--keep", "1"])
session_files = list(mock_session_files_path.glob("*.jsonl"))
assert len(session_files) == 1
assert session_files[0].stem == "second"

134
tests/cli/test_session.py Normal file
View File

@@ -0,0 +1,134 @@
from unittest.mock import MagicMock, patch
import pytest
from exchange import Message
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.cli.session import Session
from prompt_toolkit import PromptSession
SPECIFIED_SESSION_NAME = "mySession"
SESSION_NAME = "test"
@pytest.fixture
def mock_specified_session_name():
with patch.object(PromptSession, "prompt", return_value=SPECIFIED_SESSION_NAME) as specified_session_name:
yield specified_session_name
@pytest.fixture
def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory):
with patch("goose.cli.session.build_exchange", return_value=exchange_factory()), patch(
"goose.cli.session.load_profile", return_value=profile_factory()
), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch(
"goose.cli.session.load_provider", return_value="provider"
):
mock_session_notifier.return_value = MagicMock()
def create_session(session_attributes: dict = {}):
return Session(**session_attributes)
yield create_session
def test_session_does_not_extend_last_user_message_on_init(
create_session_with_mock_configs, mock_sessions_path, create_session_file
):
messages = [Message.user("Hello"), Message.assistant("Hi"), Message.user("Last should be removed")]
create_session_file(messages, mock_sessions_path / f"{SESSION_NAME}.jsonl")
session = create_session_with_mock_configs({"name": SESSION_NAME})
print("Messages after session init:", session.exchange.messages) # Debugging line
assert len(session.exchange.messages) == 2
assert [message.text for message in session.exchange.messages] == ["Hello", "Hi"]
def test_save_session_create_session(mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name):
session = create_session_with_mock_configs()
session.exchange.messages.append(Message.assistant("Hello"))
session.save_session()
session_file = mock_sessions_path / f"{SPECIFIED_SESSION_NAME}.jsonl"
assert session_file.exists()
saved_messages = session.load_session()
assert len(saved_messages) == 1
assert saved_messages[0].text == "Hello"
def test_save_session_resume_session_new_file(
mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name, create_session_file
):
with patch("goose.cli.session.confirm", return_value=False):
existing_messages = [Message.assistant("existing_message")]
existing_session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
create_session_file(existing_messages, existing_session_file)
new_session_file = mock_sessions_path / f"{SPECIFIED_SESSION_NAME}.jsonl"
assert not new_session_file.exists()
session = create_session_with_mock_configs({"name": SESSION_NAME})
session.exchange.messages.append(Message.assistant("new_message"))
session.save_session()
assert new_session_file.exists()
assert existing_session_file.exists()
saved_messages = session.load_session()
assert [message.text for message in saved_messages] == ["existing_message", "new_message"]
def test_save_session_resume_session_existing_session_file(
mock_sessions_path, create_session_with_mock_configs, create_session_file
):
with patch("goose.cli.session.confirm", return_value=True):
existing_messages = [Message.assistant("existing_message")]
existing_session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
create_session_file(existing_messages, existing_session_file)
session = create_session_with_mock_configs({"name": SESSION_NAME})
session.exchange.messages.append(Message.assistant("new_message"))
session.save_session()
saved_messages = session.load_session()
assert [message.text for message in saved_messages] == ["existing_message", "new_message"]
def test_process_first_message_return_message(create_session_with_mock_configs):
session = create_session_with_mock_configs()
with patch.object(
GoosePromptSession, "get_user_input", return_value=UserInput(action=PromptAction.CONTINUE, text="Hello")
):
message = session.process_first_message()
assert message.text == "Hello"
assert len(session.exchange.messages) == 0
def test_process_first_message_to_exit(create_session_with_mock_configs):
session = create_session_with_mock_configs()
with patch.object(GoosePromptSession, "get_user_input", return_value=UserInput(action=PromptAction.EXIT)):
message = session.process_first_message()
assert message is None
def test_process_first_message_return_last_exchange_message(create_session_with_mock_configs):
session = create_session_with_mock_configs()
session.exchange.messages.append(Message.user("Hi"))
message = session.process_first_message()
assert message.text == "Hi"
assert len(session.exchange.messages) == 0
def test_generate_session_name(create_session_with_mock_configs):
session = create_session_with_mock_configs()
with patch.object(GoosePromptSession, "get_save_session_name", return_value=SPECIFIED_SESSION_NAME):
session.generate_session_name()
assert session.name == SPECIFIED_SESSION_NAME

59
tests/conftest.py Normal file
View File

@@ -0,0 +1,59 @@
import json
import os
from time import time
from unittest.mock import Mock, patch
import pytest
from exchange import Exchange
from goose.profile import Profile
@pytest.fixture
def profile_factory():
def _create_profile(attributes={}):
profile_attrs = {
"provider": "mock_provider",
"processor": "mock_processor",
"accelerator": "mock_accelerator",
"moderator": "mock_moderator",
"toolkits": [],
}
profile_attrs.update(attributes)
return Profile(**profile_attrs)
return _create_profile
@pytest.fixture
def exchange_factory():
def _create_exchange(attributes={}):
exchange_attrs = {
"provider": "mock_provider",
"system": "mock_system",
"tools": [],
"moderator": Mock(),
"model": "mock_model",
}
exchange_attrs.update(attributes)
return Exchange(**exchange_attrs)
return _create_exchange
@pytest.fixture
def mock_sessions_path(tmp_path):
with patch("goose.cli.config.SESSIONS_PATH", tmp_path) as mock_path:
yield mock_path
@pytest.fixture
def create_session_file():
def _create_session_file(messages, session_file_path, mtime=time()):
with open(session_file_path, "w") as session_file:
for m in messages:
json.dump(m.to_dict(), session_file)
session_file.write("\n")
session_file.close()
os.utime(session_file_path, (mtime, mtime))
return _create_session_file

50
tests/test_completer.py Normal file
View File

@@ -0,0 +1,50 @@
from unittest.mock import Mock
import pytest
from goose.cli.prompt.completer import GoosePromptCompleter
from goose.command.base import Command
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
# Mock Command class
dummy_command = Mock(spec=Command)
dummy_command.get_completions = Mock(
return_value=[
Completion(text="completion1"),
Completion(text="completion2"),
]
)
commands_list = {"test_command1": dummy_command, "test_command2": dummy_command}
@pytest.fixture
def completer():
return GoosePromptCompleter(commands=commands_list)
def test_get_command_completions(completer):
document = Document(text="/test_command1:input")
completions = list(completer.get_command_completions(document))
assert len(completions) == 2
assert completions[0].text == "completion1"
assert completions[1].text == "completion2"
def test_get_command_name_completions(completer):
document = Document(text="/test")
completions = list(completer.get_command_name_completions(document))
print(completions)
assert len(completions) == 2
assert completions[0].text == "test_command1"
assert completions[1].text == "test_command2"
def test_get_completions(completer):
document = Document(text="/test_command1:input")
completions = list(completer.get_completions(document, None))
print(completions)
assert len(completions) == 2
assert completions[0].text == "completion1"
assert completions[1].text == "completion2"

View File

View File

@@ -0,0 +1,68 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, Mock
import pytest
from goose.toolkit.base import Requirements
from goose.toolkit.developer import Developer
@pytest.fixture
def temp_dir():
with TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
@pytest.fixture
def developer_toolkit():
toolkit = Developer(notifier=MagicMock(), requires=Requirements(""))
# This mocking ensures that that the safety check is considered a pass in shell calls
toolkit.exchange_view = Mock()
toolkit.exchange_view.processor.replace.return_value = Mock()
toolkit.exchange_view.processor.replace.return_value.messages = []
toolkit.exchange_view.processor.replace.return_value.add = Mock()
toolkit.exchange_view.processor.replace.return_value.reply.return_value.text = "3"
toolkit.exchange_view.processor.replace.return_value.messages = [Mock()]
return toolkit
def test_update_plan(developer_toolkit):
tasks = [
{"description": "Task 1", "status": "planned"},
{"description": "Task 2", "status": "complete"},
{"description": "Task 3", "status": "in-progress"},
]
updated_tasks = developer_toolkit.update_plan(tasks)
assert updated_tasks == tasks
def test_patch_file(temp_dir, developer_toolkit):
test_file = temp_dir / "test.txt"
before_content = "Hello World"
after_content = "Hello Goose"
test_file.write_text(before_content)
developer_toolkit.patch_file(test_file.as_posix(), before_content, after_content)
assert test_file.read_text() == after_content
def test_read_file(temp_dir, developer_toolkit):
test_file = temp_dir / "test.txt"
content = "Hello World"
test_file.write_text(content)
read_content = developer_toolkit.read_file(test_file.as_posix())
assert content in read_content
def test_shell(developer_toolkit):
command = "echo Hello World"
result = developer_toolkit.shell(command)
assert "Hello World" in result
def test_write_file(temp_dir, developer_toolkit):
test_file = temp_dir / "test.txt"
content = "Hello World"
developer_toolkit.write_file(test_file.as_posix(), content)
assert test_file.read_text() == content

136
tests/utils/test_ask.py Normal file
View File

@@ -0,0 +1,136 @@
from unittest.mock import MagicMock, patch
import pytest
from exchange import Exchange
from goose.utils.ask import ask_an_ai, clear_exchange, replace_prompt
# tests for `ask_an_ai`
def test_ask_an_ai_empty_input():
"""Test that function raises TypeError if input is empty."""
exchange = MagicMock(spec=Exchange)
with pytest.raises(TypeError):
ask_an_ai("", exchange)
def test_ask_an_ai_no_history():
"""Test the no_history functionality."""
exchange = MagicMock(spec=Exchange)
with patch("goose.utils.ask.clear_exchange") as mock_clear:
ask_an_ai("Test input", exchange, no_history=True)
mock_clear.assert_called_once_with(exchange)
def test_ask_an_ai_prompt_replacement():
"""Test that the prompt is replaced if provided."""
exchange = MagicMock(spec=Exchange)
prompt = "New prompt"
with patch("goose.utils.ask.replace_prompt") as mock_replace_prompt:
# Configure the mock to return a new mock object with the same spec
modified_exchange = MagicMock(spec=Exchange)
mock_replace_prompt.return_value = modified_exchange
ask_an_ai("Test input", exchange, prompt=prompt, no_history=False)
# Check if replace_prompt was called correctly
mock_replace_prompt.assert_called_once_with(exchange, prompt)
# Assert that the modified exchange was returned correctly
assert mock_replace_prompt.return_value is modified_exchange, "Should return the modified exchange mock"
def test_ask_an_ai_exchange_usage():
"""Test that the exchange adds and processes the message correctly."""
exchange = MagicMock(spec=Exchange)
input_text = "Test input"
message_mock = MagicMock(return_value="Mocked Message")
with patch("goose.utils.ask.Message.user", new=message_mock):
ask_an_ai(input_text, exchange, no_history=False)
# Assert that Message.user was called with the correct input
message_mock.assert_called_once_with(input_text)
# Assert that exchange.add was called with the mocked message
exchange.add.assert_called_once_with("Mocked Message")
exchange.reply.assert_called_once()
def test_ask_an_ai_return_value():
"""Test that the function returns the correct reply."""
exchange = MagicMock(spec=Exchange)
expected_reply = "AI response"
exchange.reply.return_value = expected_reply
result = ask_an_ai("Test input", exchange, no_history=False)
assert result == expected_reply, "Function should return the reply from the exchange."
# tests for `clear_exchange`
def test_clear_exchange_without_tools():
"""Test clearing messages and checkpoints but not tools."""
# Arrange
exchange = MagicMock(spec=Exchange)
# Act
new_exchange = clear_exchange(exchange, clear_tools=False)
# Assert
exchange.replace.assert_called_once_with(messages=[], checkpoints=[])
assert new_exchange == exchange.replace.return_value, "Should return the modified exchange"
def test_clear_exchange_with_tools():
"""Test clearing messages, checkpoints, and tools."""
# Arrange
exchange = MagicMock(spec=Exchange)
# Act
new_exchange = clear_exchange(exchange, clear_tools=True)
# Assert
exchange.replace.assert_called_once_with(messages=[], checkpoints=[], tools=())
assert new_exchange == exchange.replace.return_value, "Should return the modified exchange with tools cleared"
def test_clear_exchange_return_value():
"""Test that the returned value is a new exchange object."""
# Arrange
exchange = MagicMock(spec=Exchange)
new_exchange_mock = MagicMock(spec=Exchange)
exchange.replace.return_value = new_exchange_mock
# Act
new_exchange = clear_exchange(exchange, clear_tools=False)
# Assert
assert new_exchange == new_exchange_mock, "Returned exchange should be the new exchange instance"
# tests for `replace_prompt`
def test_replace_prompt():
"""Test that the system prompt is correctly replaced."""
# Arrange
exchange = MagicMock(spec=Exchange)
prompt = "New system prompt"
# Act
new_exchange = replace_prompt(exchange, prompt)
# Assert
exchange.replace.assert_called_once_with(system=prompt)
assert new_exchange == exchange.replace.return_value, "Should return the modified exchange with the new prompt"
def test_replace_prompt_return_value():
"""Test that the returned value is a new exchange object."""
# Arrange
exchange = MagicMock(spec=Exchange)
expected_new_exchange = MagicMock(spec=Exchange)
exchange.replace.return_value = expected_new_exchange
# Act
new_exchange = replace_prompt(exchange, "Another prompt")
# Assert
assert new_exchange == expected_new_exchange, "Returned exchange should be the new exchange instance"

View File

@@ -0,0 +1,192 @@
from unittest.mock import patch
import pytest
from goose.utils.file_utils import (
create_extensions_list,
create_language_weighting,
) # Adjust the import path as necessary
# tests for `create_extensions_list`
def test_create_extensions_list_valid_input():
"""Test with valid input and multiple file extensions."""
project_root = "/fake/project/root"
max_n = 3
files = [
"/fake/project/root/file1.py",
"/fake/project/root/file2.py",
"/fake/project/root/file3.md",
"/fake/project/root/file4.md",
"/fake/project/root/file5.txt",
"/fake/project/root/file6.py",
"/fake/project/root/file7.md",
]
with patch("goose.utils.file_utils.create_file_list", return_value=files):
extensions = create_extensions_list(project_root, max_n)
assert extensions == [".py", ".md", ".txt"], "Should return the top 3 extensions in the correct order"
def test_create_extensions_list_zero_max_n():
"""Test that a ValueError is raised when max_n is 0."""
project_root = "/fake/project/root"
max_n = 0
with pytest.raises(ValueError, match="Number of file extensions must be greater than 0"):
create_extensions_list(project_root, max_n)
def test_create_extensions_list_no_files():
"""Test with a project root that contains no files."""
project_root = "/fake/project/root"
max_n = 3
with patch("goose.utils.file_utils.create_file_list", return_value=[]):
extensions = create_extensions_list(project_root, max_n)
assert extensions == [], "Should return an empty list when no files are present"
def test_create_extensions_list_fewer_extensions_than_max_n():
"""Test when there are fewer unique extensions than max_n."""
project_root = "/fake/project/root"
max_n = 5
files = [
"/fake/project/root/file1.py",
"/fake/project/root/file2.py",
"/fake/project/root/file3.md",
]
with patch("goose.utils.file_utils.create_file_list", return_value=files):
extensions = create_extensions_list(project_root, max_n)
assert extensions == [".py", ".md"], "Should return all available extensions when fewer than max_n"
def test_create_extensions_list_files_without_extensions():
"""Test that files without extensions are ignored."""
project_root = "/fake/project/root"
max_n = 3
files = [
"/fake/project/root/file1",
"/fake/project/root/file2.py",
"/fake/project/root/file3",
"/fake/project/root/file4.md",
]
with patch("goose.utils.file_utils.create_file_list", return_value=files):
extensions = create_extensions_list(project_root, max_n)
assert extensions == [".py", ".md"], "Should ignore files without extensions"
# tests for `create_language_weighting`
def test_create_language_weighting_normal_case():
"""Test the function with multiple files and different sizes."""
files = [
"/fake/project/file1.py",
"/fake/project/file2.py",
"/fake/project/file3.md",
"/fake/project/file4.txt",
]
sizes = {
"/fake/project/file1.py": 100,
"/fake/project/file2.py": 200,
"/fake/project/file3.md": 50,
"/fake/project/file4.txt": 150,
}
# Mocking os.path.getsize to return different sizes for different files
with patch("os.path.getsize") as mock_getsize:
mock_getsize.side_effect = lambda file: sizes[file]
result = create_language_weighting(files)
total = sum(sizes.values())
expected_result = {
".py": 300 / total * 100, # 300 out of 600 total
".txt": 150 / total * 100, # 150 out of 600 total
".md": 50 / total * 100, # 50 out of 600 total
}
# Check if the result matches the expected output
assert result[".py"] == pytest.approx(expected_result.get(".py"), 0.01)
assert result[".txt"] == pytest.approx(expected_result.get(".txt"), 0.01)
assert result[".md"] == pytest.approx(expected_result.get(".md"), 0.01)
def test_create_language_weighting_no_files():
"""Test the function when no files are provided."""
files = []
result = create_language_weighting(files)
assert result == {}, "Should return an empty dictionary when no files are provided"
def test_create_language_weighting_files_without_extensions():
"""Test the function when files have no extensions."""
files = [
"/fake/project/file1",
"/fake/project/file2",
]
with patch("os.path.getsize", return_value=100):
result = create_language_weighting(files)
assert result == {}, "Should return an empty dictionary when files have no extensions"
def test_create_language_weighting_zero_total_size():
"""Test the function when all files have a size of 0."""
files = [
"/fake/project/file1.py",
"/fake/project/file2.py",
]
with patch("os.path.getsize", return_value=0):
result = create_language_weighting(files)
assert result == {".py": 0}
def test_create_language_weighting_single_file():
"""Test the function with a single file."""
files = [
"/fake/project/file1.py",
]
with patch("os.path.getsize", return_value=100):
result = create_language_weighting(files)
assert result == {".py": 100.0}, "Should return 100% for the single file's extension"
def test_create_language_weighting_mixed_extensions():
"""Test the function with files of mixed extensions and sizes."""
files = [
"/fake/project/file1.py",
"/fake/project/file2.py",
"/fake/project/file3.md",
"/fake/project/file4.txt",
"/fake/project/file5.md",
]
with patch("os.path.getsize") as mock_getsize:
mock_getsize.side_effect = lambda file: {
"/fake/project/file1.py": 100,
"/fake/project/file2.py": 100,
"/fake/project/file3.md": 200,
"/fake/project/file4.txt": 300,
"/fake/project/file5.md": 100,
}[file]
result = create_language_weighting(files)
expected_result = {
".txt": 37.5, # 300 out of 800 total
".md": 37.5, # 300 out of 800 total
".py": 25.0, # 200 out of 800 total
}
assert result[".txt"] == pytest.approx(expected_result.get(".txt"), 0.01)
assert result[".md"] == pytest.approx(expected_result.get(".md"), 0.01)
assert result[".py"] == pytest.approx(expected_result.get(".py"), 0.01)

View File

@@ -0,0 +1,77 @@
from pathlib import Path
import pytest
from exchange import Message
from goose.utils.session_file import list_sorted_session_files, read_from_file, session_file_exists, write_to_file
@pytest.fixture
def file_path(tmp_path):
return tmp_path / "test_file.jsonl"
def test_read_write_to_file(file_path):
messages = [
Message.user("prompt1"),
Message.user("prompt2"),
]
write_to_file(file_path, messages)
assert file_path.exists()
assert read_from_file(file_path) == messages
def test_read_from_file_non_existing_file(tmp_path):
with pytest.raises(FileNotFoundError):
read_from_file(tmp_path / "no_existing.json")
def test_read_from_file_non_jsonl_file(file_path):
file_path.write_text("Hello World")
with pytest.raises(RuntimeError):
read_from_file(file_path)
def test_list_sorted_session_files(tmp_path):
session_files_directory = tmp_path / "session_files_dir"
session_files_directory.mkdir()
file_names = ["file1", "file2", "file3"]
created_session_files = [create_session_file(session_files_directory, file_name) for file_name in file_names]
sorted_files = list_sorted_session_files(session_files_directory)
assert sorted_files == {
"file3": created_session_files[2],
"file2": created_session_files[1],
"file1": created_session_files[0],
}
def test_list_sorted_session_without_session_files(tmp_path):
session_files_directory = tmp_path / "session_files_dir"
sorted_files = list_sorted_session_files(session_files_directory)
assert sorted_files == {}
def test_session_file_exists_return_false_when_directory_does_not_exist(tmp_path):
session_files_directory = tmp_path / "session_files_dir"
assert not session_file_exists(session_files_directory)
def test_session_file_exists_return_false_when_no_session_file_exists(tmp_path):
session_files_directory = tmp_path / "session_files_dir"
session_files_directory.mkdir()
assert not session_file_exists(session_files_directory)
def test_session_file_exists_return_true_when_session_file_exists(tmp_path):
session_files_directory = tmp_path / "session_files_dir"
session_files_directory.mkdir()
create_session_file(session_files_directory, "session1")
assert session_file_exists(session_files_directory)
def create_session_file(file_path, file_name) -> Path:
file = file_path / f"{file_name}.jsonl"
file.touch()
return file

63
tests/utils/test_utils.py Normal file
View File

@@ -0,0 +1,63 @@
import string
import pytest
from goose.utils import droid, ensure, ensure_list, load_plugins
class MockClass:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return self.name == other.name
def test_load_plugins():
plugins = load_plugins("exchange.provider")
assert isinstance(plugins, dict)
assert len(plugins) > 0
def test_ensure_with_class():
mock_class = MockClass("foo")
assert ensure(MockClass)(mock_class) == mock_class
def test_ensure_with_dictionary():
mock_class = ensure(MockClass)({"name": "foo"})
assert mock_class == MockClass("foo")
def test_ensure_with_invalid_dictionary():
with pytest.raises(TypeError):
ensure(MockClass)({"age": "foo"})
def test_ensure_with_list():
mock_class = ensure(MockClass)(["foo"])
assert mock_class == MockClass("foo")
def test_ensure_with_invalid_list():
with pytest.raises(TypeError):
ensure(MockClass)(["foo", "bar"])
def test_ensure_with_value():
mock_class = ensure(MockClass)("foo")
assert mock_class == MockClass("foo")
def test_ensure_list():
obj_list = ensure_list(MockClass)(["foo", "bar"])
assert obj_list == [MockClass("foo"), MockClass("bar")]
def test_droid():
result = droid()
assert isinstance(result, str)
assert len(result) == 4
for character in [result[i] for i in [0, 2]]:
assert character in string.ascii_lowercase, "should be in lower case"
for character in [result[i] for i in [1, 3]]:
assert character in string.digits, "should be a digit"