mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 07:24:24 +01:00
chore: initial commit
Co-authored-by: Lifei Zhou <lifei@squareup.com> Co-authored-by: Mic Neale <micn@tbd.email> Co-authored-by: Lily Delalande <ldelalande@squareup.com> Co-authored-by: Bradley Axen <baxen@squareup.com> Co-authored-by: Andy Lane <alane@squareup.com> Co-authored-by: Elena Zherdeva <ezherdeva@squareup.com> Co-authored-by: Zaki Ali <zaki@squareup.com> Co-authored-by: Salman Mohammed <smohammed@squareup.com>
This commit is contained in:
140
.gitignore
vendored
Normal file
140
.gitignore
vendored
Normal file
@@ -0,0 +1,140 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# Ignore mkdocs site files generated locally for testing/validation, but generated
|
||||
# at buildtime in production
|
||||
site/
|
||||
docs/docs/notebooks*
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Hermit
|
||||
.hermit
|
||||
|
||||
# VSCode
|
||||
.vscode
|
||||
|
||||
# Autogenerated docs files
|
||||
docs/docs/reference
|
||||
|
||||
## goose session files
|
||||
.goose
|
||||
6
.ruff.toml
Normal file
6
.ruff.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
lint.select = ["E", "W", "F", "N", "ANN"]
|
||||
lint.ignore = ["ANN101"]
|
||||
exclude = [
|
||||
"docs",
|
||||
]
|
||||
line-length = 120
|
||||
166
ARCHITECTURE.md
Normal file
166
ARCHITECTURE.md
Normal file
@@ -0,0 +1,166 @@
|
||||
# Architecture
|
||||
|
||||
## The System
|
||||
|
||||
Goose extends the capabilities of high-performing LLMs through a small collection of tools.
|
||||
This lets you instruct goose, currently via a CLI interface, to automatically solve problems
|
||||
on your behalf. It attempts to not just tell you how you can do something, but to actually do it for you.
|
||||
|
||||
The primary mode of goose (the "developer" toolkit) has access to tools to
|
||||
|
||||
- maintain a plan
|
||||
- run shell commands
|
||||
- read, create, and edit files
|
||||
|
||||
Together these can solve all kinds of problems, and we emphasize performance on tasks like
|
||||
fully automating adhoc scripts and tasks, interacting with existing code bases, and teaching how
|
||||
to use new technology.
|
||||
|
||||
---
|
||||
|
||||
Here are some of the key design decisions about how we drive performance on these tasks with goose,
|
||||
that you should be able to observe by using it.
|
||||
|
||||
- Encouraging it to write and maintain a plan, to allow it to accomplish longer sequences of automation
|
||||
- Using tool usage as a generalizable and increasingly tuned approach to adding new capabilities (including plugins)
|
||||
- Relying on reflection at every possible part of the stack
|
||||
- Showing it clear output of each tool use
|
||||
- Surfacing all possible errors to the model to give it a chance to correct
|
||||
- Surfacing the plan to document what has been accomplished
|
||||
|
||||
> [!TIP]
|
||||
> In addition, there are some implementation choices that we've found very performance driving. The share
|
||||
> a theme of working well by default without constraining the model.
|
||||
>
|
||||
> - Encouraging the model to use `ripgrep` via the shell performs very well for navigating filesystems. It mostly
|
||||
> just works, but enables the model to get clever with regexes or even additional shell operations as needed.
|
||||
> - Using a replace operation for editing files requires fewer tokens to be generated and avoids laziness on large files,
|
||||
> but we allow fall back to whole file overwrites to let it more coherently handle major refactors.
|
||||
|
||||
## Implementation
|
||||
|
||||
The core execution logic for generation and tool calling is handled by [exchange][exchange].
|
||||
It hooks python functions into the model tool use loop, while defining very careful error handling
|
||||
so any failures in tools are surfaced to the model.
|
||||
|
||||
Once we've created an *exchange* object, running the process is effectively just calling
|
||||
`exchange.reply()`.
|
||||
|
||||
*The key is setting up an exchange with the capabilities we need.*
|
||||
|
||||
Goose builds that exchange:
|
||||
- allows users to configure a profile to customize capabilities
|
||||
- provides a pluggable system for adding tools and prompts
|
||||
- sets up the tools to interact with state
|
||||
|
||||
We expect that goose will have multiple UXs over time, and be run in different
|
||||
environments. The UX is expected to be able to load a `Profile` (e.g. in the CLI
|
||||
we read profiles out of a config) and to provide a `Notifier` (e.g. in the CLI we put
|
||||
notifications on stdout).
|
||||
|
||||
Goose then constructs the exchange for the UX, the UX only interacts with that exchange.
|
||||
|
||||
```
|
||||
def build_exchange(profile: Profile, notifier: Notifier) -> Exchange:
|
||||
...
|
||||
```
|
||||
|
||||
But to setup a configurable system, Goose uses `Toolkit`s:
|
||||
|
||||
```
|
||||
(Profile, Notifier) -> [Toolkits] -> Exchange
|
||||
```
|
||||
|
||||
## Profile
|
||||
|
||||
A profile specifies some basic configuration in Goose, such as which models it should use, as well
|
||||
as which toolkits it should include.
|
||||
|
||||
```yaml
|
||||
processor: openai:gpt-4o
|
||||
accelerator: openai:gpt-4o-mini
|
||||
moderator: passive
|
||||
toolkits:
|
||||
- assistant
|
||||
- calendar
|
||||
- contacts
|
||||
- name: scheduling
|
||||
requires:
|
||||
assistant: assistant
|
||||
calendar: calendar
|
||||
contacts: contacts
|
||||
```
|
||||
|
||||
## Notifier
|
||||
|
||||
The notifier is a concrete implementation of the Notifier base class provided by each UX. It
|
||||
needs to support two methods
|
||||
|
||||
```python
|
||||
class Notifier:
|
||||
def log(self, RichRenderable):
|
||||
...
|
||||
|
||||
def status(self, str):
|
||||
...
|
||||
```
|
||||
|
||||
Log is meant to record something concrete that happened, such as a tool being called, and status is intended
|
||||
for transient displays of the current status. For example, while a shell command is running, it might use
|
||||
`.log` to record the command that started, and then update the status to `"shell command running"`. Log is durable
|
||||
while Status is ephemeral.
|
||||
|
||||
## Toolkits
|
||||
|
||||
Toolkits are a collection of tools, along with the state and prompting they require.
|
||||
Toolkits are what gives Goose its capabilities.
|
||||
|
||||
Tools need a way to report what's happening back to the user, which we treat similarly
|
||||
to logging. To make that possible, toolkits get a reference to the interface described above.
|
||||
|
||||
```python
|
||||
class ScheduleToolkit(Toolkit):
|
||||
def __init__(self, notifier: Notifier, requires: Requirements, **kwargs):
|
||||
super().__init__(notifier, requires, **kwargs) # handles the interface, exchangeview
|
||||
|
||||
# for a class that has requirements, you can get them like this
|
||||
self.calendar = requires.get("calendar")
|
||||
self.assistant = requires.get("assistant")
|
||||
self.contacts = requires.get("contacts")
|
||||
|
||||
self.appointments_state = []
|
||||
|
||||
def prompt(self) -> str:
|
||||
return "Try out the example tool."
|
||||
|
||||
@tool
|
||||
def example(self):
|
||||
self.interface.log(f"An example tool was called, current state is {self.state}")
|
||||
```
|
||||
|
||||
### Advanced
|
||||
|
||||
**Dependencies**: Toolkits can depend on each other, to make it easier to get plugins to extend
|
||||
or modify existing capabilities. In the config above, you can see this used for the scheduling toolkit.
|
||||
You can refer to those requirements in code through:
|
||||
|
||||
```python
|
||||
@tool
|
||||
def example_dependency(self):
|
||||
appointments = self.dependencies["calendar"].appointments
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
**ExchangeView**: It can also be useful for tools to have a read-only copy of the history
|
||||
of the loop so far. So for advanced use cases, toolkits also have access to an
|
||||
`ExchangeView` object.
|
||||
|
||||
```python
|
||||
@tool
|
||||
def example_history(self):
|
||||
last_message = self.exchange_view.processor.messages[-1]
|
||||
...
|
||||
```
|
||||
|
||||
[exchange]: https://github.com/squareup/exchange
|
||||
0
CHANGELOG.md
Normal file
0
CHANGELOG.md
Normal file
122
CONTRIBUTING.md
Normal file
122
CONTRIBUTING.md
Normal file
@@ -0,0 +1,122 @@
|
||||
# Contributing
|
||||
|
||||
We welcome Pull Requests for general contributions. If you have a larger new feature or any questions on how
|
||||
to develop a fix, we recommend you open an issue before starting.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
*goose* uses [uv][uv] for dependency management, and formats with [ruff][ruff].
|
||||
|
||||
We provide a shortcut to standard commands using [just][just] in our `justfile`.
|
||||
|
||||
## Developing
|
||||
|
||||
Now that you have a local environment, you can make edits and run our tests.
|
||||
|
||||
```
|
||||
uv run pytest tests -m "not integration"
|
||||
```
|
||||
|
||||
or, as a shortcut,
|
||||
|
||||
```
|
||||
just test
|
||||
```
|
||||
|
||||
## Evaluations
|
||||
|
||||
Given that so much of *goose* involves interactions with LLMs, our unit tests only go so far to
|
||||
confirming things work as intended.
|
||||
|
||||
We're currently developing a suite of evalutions, to make it easier to make improvements to *goose* more confidently.
|
||||
|
||||
In the meantime, we typically incubate any new additions that change the behavior of the *goose*
|
||||
through **opt-in** plugins - `Toolkit`s, `Moderator`s, and `Provider`s. We welcome contributions of plugins
|
||||
that add new capabilities to *goose*. We recommend sending in several examples of the new capabilities
|
||||
in action with your pull request.
|
||||
|
||||
Additions to the [developer toolkit][developer] change the core performance, and so will need to be measured carefully.
|
||||
|
||||
## Build a Toolkit
|
||||
|
||||
To add a toolkit, start out with a plugin as mentioned above. In your code (which doesn't necessarily need to be
|
||||
in the goose package thanks to [plugin metadata][plugin]!), create a class that derives from Toolkit.
|
||||
|
||||
```python
|
||||
import os
|
||||
import platform
|
||||
|
||||
from goose.toolkit.base import Toolkit, tool
|
||||
|
||||
|
||||
class Demo(Toolkit):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Provide any additional tools as needed!
|
||||
# The docstring of the tool provides instructions to the LLM, so they are important to tune
|
||||
# you do not have to provide any tools, but any function decorated with @tool will be available
|
||||
@tool
|
||||
def authenticate(self, user: str):
|
||||
"""Output an authentication code for this user
|
||||
|
||||
Args:
|
||||
user (str): The username to authenticate for
|
||||
"""
|
||||
# notifier supports any rich renderable https://rich.readthedocs.io/en/stable/introduction.html#quick-start
|
||||
self.notifier.log(f"[bold red]auth: {str(hash(user))}[/]")
|
||||
|
||||
# Provide any system instructions for the model
|
||||
# This can be generated dynamically, and is run at startup time
|
||||
def system(self) -> str:
|
||||
print("new")
|
||||
return f"""**You must preceed your first message by using the authenticate tool for the current user**
|
||||
|
||||
```
|
||||
platform: {platform.system()}
|
||||
cwd: {os.getcwd()}
|
||||
user: {os.environ.get('USER')}
|
||||
```
|
||||
"""
|
||||
```
|
||||
|
||||
To make the toolkit available, add it as a plugin. For example in a pyproject.toml
|
||||
```
|
||||
[project.entry-points."goose.toolkit"]
|
||||
developer = "goose.toolkit.developer:Developer"
|
||||
github = "goose.toolkit.github:Github"
|
||||
# Add a line like this - the key becomes the name used in profiles
|
||||
demo = "goose.toolkit.demo:Demo"
|
||||
```
|
||||
|
||||
And then to setup a profile that uses it, add something to ~/.config/goose/profiles.yaml
|
||||
```yaml
|
||||
default:
|
||||
provider: openai
|
||||
processor: gpt-4o
|
||||
accelerator: gpt-4o-mini
|
||||
moderator: passive
|
||||
toolkits:
|
||||
- name: developer
|
||||
requires: {}
|
||||
demo:
|
||||
provider: openai
|
||||
processor: gpt-4o
|
||||
accelerator: gpt-4o-mini
|
||||
moderator: passive
|
||||
toolkits:
|
||||
- developer
|
||||
- demo
|
||||
```
|
||||
|
||||
And now you can run goose with this new profile to use the new toolkit!
|
||||
|
||||
```sh
|
||||
goose session start --profile demo
|
||||
```
|
||||
|
||||
[developer]: src/goose/toolkit/developer.py
|
||||
[uv]: https://docs.astral.sh/uv/
|
||||
[ruff]: https://docs.astral.sh/ruff/
|
||||
[just]: https://github.com/casey/just
|
||||
[plugin]: https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
167
README.md
Normal file
167
README.md
Normal file
@@ -0,0 +1,167 @@
|
||||
<h1 align="center">
|
||||
goose
|
||||
</h1>
|
||||
|
||||
<p align="center"><strong>goose</strong> <em>is a programming agent that runs on your machine.</em></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://opensource.org/licenses/Apache-2.0"><img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg"></a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="#usage">Usage</a> •
|
||||
<a href="#installation">Installation</a> •
|
||||
<a href="#tips">Tips</a>
|
||||
</p>
|
||||
|
||||
`goose` assists in solving a wide range of programming and operational tasks. It is a live virtual developer you can interact with, guide, and learn from.
|
||||
|
||||
To solve problems, goose breaks down instructions into sequences of tasks and carries them out using tools. Its ability to connect its changes with real outcomes (e.g. errors) and course correct is its most powerful and exciting feature. goose is free open source software and is built to be extensible and customizable.
|
||||
|
||||
## Usage
|
||||
|
||||
You interact with goose in conversational sessions - something like a natural language driven code interpreter.
|
||||
The default toolkit lets it take actions through shell commands and file edits.
|
||||
You can interrupt Goose at any time to help redirect its efforts.
|
||||
|
||||
From your terminal, navigate to the directory you'd like to start from and run:
|
||||
```sh
|
||||
goose session start
|
||||
```
|
||||
|
||||
You will see a prompt `G❯`:
|
||||
|
||||
```
|
||||
G❯ type your instructions here exactly as you would tell a developer.
|
||||
```
|
||||
|
||||
From here you can talk directly with goose - send along your instructions. If you are looking to exit, use `CTRL+D`,
|
||||
although goose should help you figure that out if you forget.
|
||||
|
||||
When you exit a session, it will save the history and you can resume it later on:
|
||||
|
||||
``` sh
|
||||
goose session resume
|
||||
```
|
||||
|
||||
## Tips
|
||||
|
||||
Here are some collected tips we have for working efficiently with Goose
|
||||
|
||||
- **goose can and will edit files**. Use a git strategy to avoid losing anything - such as staging your
|
||||
personal edits and leaving goose edits unstaged until reviewed. Or consider using indivdual commits which can be reverted.
|
||||
- You can interrupt goose with `CTRL+C` to correct it or give it more info.
|
||||
- goose works best when solving concrete problems - experiment with how far you need to break that problem
|
||||
down to get goose to solve it. Be specific! E.g. it will likely fail to `"create a banking app"`,
|
||||
but probably does a good job if prompted with `"create a Fastapi app with an endpoint for deposit and withdrawal
|
||||
and with account balances stored in mysql keyed by id"`
|
||||
- If goose doesn't have enough context to start with, it might go down the wrong direction. Tell it
|
||||
to read files that you are refering to or search for objects in code. Even better, ask it to summarize
|
||||
them for you, which will help it set up its own next steps.
|
||||
- Refer to any objects in files with something that is easy to search for, such as `"the MyExample class"
|
||||
- goose *loves* to know how to run tests to get a feedback loop going, just like you do. If you tell it how you test things locally and quickly, it can make use of that when working on your project
|
||||
- You can use goose for tasks that would require scripting at times, even looking at your screen and correcting designs/helping you fix bugs, try asking it to help you in a way you would ask a person.
|
||||
- goose will make mistakes, and go in the wrong direction from times, feel free to correct it, or start again.
|
||||
- You can tell goose to run things for you continuously (and it will iterate, try, retry) but you can also tell it to check with you before doing things (and then later on tell it to go off on its own and do its best to solve).
|
||||
- Goose can run anywhere, doesn't have to be in a repo, just ask it!
|
||||
|
||||
## Installation
|
||||
|
||||
To install goose, we recommend `pipx`
|
||||
|
||||
First make sure you've [installed pipx][pipx] - for example
|
||||
|
||||
``` sh
|
||||
brew install pipx
|
||||
pipx ensurepath
|
||||
```
|
||||
|
||||
Then you can install goose with
|
||||
|
||||
``` sh
|
||||
pipx install goose
|
||||
```
|
||||
|
||||
### Config
|
||||
|
||||
Goose will try to detect what LLM it can work with and place a config in `~/.config/goose/profiles.yaml` automatically.
|
||||
|
||||
#### Toolkits
|
||||
|
||||
Goose can be extended with toolkits, and out of the box there are some available:
|
||||
|
||||
* `screen`: for letting goose take a look at your screen to help debug or work on designs (gives goose eyes)
|
||||
* `github`: for awareness and suggestions on how to use github
|
||||
* `repo_context`: for summarizing and understanding a repository you are working in.
|
||||
|
||||
To configure for example the screen toolkit, edit `~/.config/goose/profiles.yaml`:
|
||||
|
||||
```yaml
|
||||
provider: openai
|
||||
processor: gpt-4o
|
||||
accelerator: gpt-4o-mini
|
||||
moderator: passive
|
||||
toolkits:
|
||||
- name: developer
|
||||
requires: {}
|
||||
- name: screen
|
||||
requires: {}
|
||||
```
|
||||
|
||||
|
||||
|
||||
#### Advanced LLM config
|
||||
|
||||
goose works on top of LLMs (you bring your own LLM). If you need to customize goose, one way is via editing: `~/.config/goose/profiles.yaml`.
|
||||
|
||||
It will look by default something like:
|
||||
|
||||
```yaml
|
||||
default:
|
||||
provider: openai
|
||||
processor: gpt-4o
|
||||
accelerator: gpt-4o-mini
|
||||
moderator: truncate
|
||||
toolkits:
|
||||
- name: developer
|
||||
requires: {}
|
||||
```
|
||||
|
||||
*Note: This requires the environment variable `OPENAI_API_KEY` to be set to your OpenAI API key. goose uses at least 2 LLMs: one for acceleration for fast operating, and processing for writing code and executing commands.*
|
||||
|
||||
You can tell it to use another provider for example for Anthropic:
|
||||
|
||||
```yaml
|
||||
default:
|
||||
provider: anthropic
|
||||
processor: claude-3-5-sonnet-20240620
|
||||
accelerator: claude-3-5-sonnet-20240620
|
||||
...
|
||||
```
|
||||
|
||||
*Note: This will then use the claude-sonnet model, you will need to set the `ANTHROPIC_API_KEY` environment variable to your anthropic API key.*
|
||||
|
||||
For Databricks hosted models:
|
||||
|
||||
```yaml
|
||||
default:
|
||||
provider: databricks
|
||||
processor: databricks-meta-llama-3-1-70b-instruct
|
||||
accelerator: databricks-meta-llama-3-1-70b-instruct
|
||||
moderator: passive
|
||||
toolkits:
|
||||
- name: developer
|
||||
requires: {}
|
||||
```
|
||||
|
||||
This requires `DATABRICKS_HOST` and `DATABRICKS_TOKEN` to be set accordingly
|
||||
|
||||
(goose can be extended to support any LLM or combination of LLMs).
|
||||
|
||||
## Open Source
|
||||
|
||||
Yes, goose is open source and always will be. goose is released under the ASL2.0 license meaning you can use it however you like.
|
||||
See LICENSE.md for more details.
|
||||
|
||||
|
||||
[pipx]: https://github.com/pypa/pipx?tab=readme-ov-file#install-pipx
|
||||
19
justfile
Normal file
19
justfile
Normal file
@@ -0,0 +1,19 @@
|
||||
# This is the default recipe when no arguments are provided
|
||||
[private]
|
||||
default:
|
||||
@just --list --unsorted
|
||||
|
||||
test *FLAGS:
|
||||
uv run pytest tests -m "not integration" {{FLAGS}}
|
||||
|
||||
integration *FLAGS:
|
||||
uv run pytest tests -m integration {{FLAGS}}
|
||||
|
||||
format:
|
||||
ruff check . --fix
|
||||
ruff format .
|
||||
|
||||
coverage *FLAGS:
|
||||
uv run coverage run -m pytest tests -m "not integration" {{FLAGS}}
|
||||
uv run coverage report
|
||||
uv run coverage lcov -o lcov.info
|
||||
45
pyproject.toml
Normal file
45
pyproject.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
[project]
|
||||
name = "goose"
|
||||
description = "a programming agent that runs on your machine"
|
||||
version = "0.8.0"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"attrs>=23.2.0",
|
||||
"rich>=13.7.1",
|
||||
"ruamel-yaml>=0.18.6",
|
||||
"exchange>=0.7.6",
|
||||
"click>=8.1.7",
|
||||
"prompt-toolkit>=3.0.47",
|
||||
]
|
||||
author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }]
|
||||
packages = [{ include = "goose", from = "src" }]
|
||||
|
||||
[project.entry-points."goose.toolkit"]
|
||||
developer = "goose.toolkit.developer:Developer"
|
||||
github = "goose.toolkit.github:Github"
|
||||
screen = "goose.toolkit.screen:Screen"
|
||||
repo_context = "goose.toolkit.repo_context.repo_context:RepoContext"
|
||||
|
||||
[project.entry-points."goose.profile"]
|
||||
default = "goose.profile:default_profile"
|
||||
|
||||
[project.entry-points."goose.command"]
|
||||
file = "goose.command.file:FileCommand"
|
||||
|
||||
[project.entry-points."goose.cli"]
|
||||
goose = "goose.cli.main:goose_cli"
|
||||
|
||||
[project.scripts]
|
||||
goose = "goose.cli.main:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"pytest>=8.3.2",
|
||||
"codecov>=2.1.13",
|
||||
]
|
||||
|
||||
0
src/goose/__init__.py
Normal file
0
src/goose/__init__.py
Normal file
65
src/goose/build.py
Normal file
65
src/goose/build.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from itertools import chain
|
||||
|
||||
from exchange import Exchange, Message
|
||||
from exchange.moderators import get_moderator
|
||||
from exchange.providers import get_provider
|
||||
|
||||
from goose.notifier import Notifier
|
||||
from goose.profile import Profile
|
||||
from goose.toolkit import get_toolkit
|
||||
from goose.toolkit.base import Requirements
|
||||
from goose.view import ExchangeView
|
||||
|
||||
|
||||
def build_exchange(profile: Profile, notifier: Notifier) -> Exchange:
|
||||
"""Build an exchange configured through the profile
|
||||
|
||||
This will setup any toolkits and use that to build the exchange's collection
|
||||
of tools.
|
||||
|
||||
Args:
|
||||
profile (Profile): The profile specifying how to setup this exchange
|
||||
notifier (Notifier): A notifier instance used by tools to send info
|
||||
"""
|
||||
|
||||
provider = get_provider(profile.provider).from_env()
|
||||
|
||||
# Support instantating toolkits in *two* passes for now, no further nesting
|
||||
concrete_toolkits = {}
|
||||
|
||||
# First instantiate all toolkits that are sub dependencies
|
||||
for spec in profile.toolkits:
|
||||
for required in spec.requires.values():
|
||||
concrete_toolkits[required] = get_toolkit(required)(notifier=notifier, requires=Requirements(required))
|
||||
|
||||
# Now that we have the dependencies available, we can instantiate everything else
|
||||
toolkits = []
|
||||
for spec in profile.toolkits:
|
||||
if spec.name in concrete_toolkits:
|
||||
toolkits.append(concrete_toolkits[spec.name])
|
||||
continue
|
||||
|
||||
requires = Requirements(
|
||||
spec.name,
|
||||
{key: concrete_toolkits[val] for key, val in spec.requires.items()},
|
||||
)
|
||||
toolkit = get_toolkit(spec.name)(notifier=notifier, requires=requires)
|
||||
toolkits.append(toolkit)
|
||||
|
||||
# From the toolkits, we derive the exchange prompt and tools
|
||||
system = "\n\n".join([Message.load("system.jinja").text] + [toolkit.system() for toolkit in toolkits])
|
||||
tools = tuple(chain(*(toolkit.tools() for toolkit in toolkits)))
|
||||
exchange = Exchange(
|
||||
provider=provider,
|
||||
system=system,
|
||||
tools=tools,
|
||||
moderator=get_moderator(profile.moderator)(),
|
||||
model=profile.processor,
|
||||
)
|
||||
|
||||
# This is a bit awkward, but we have to set this after the fact because building
|
||||
# the exchange requires having the toolkits
|
||||
for toolkit in toolkits:
|
||||
toolkit.exchange_view = ExchangeView(profile.processor, profile.accelerator, exchange)
|
||||
|
||||
return exchange
|
||||
0
src/goose/cli/__init__.py
Normal file
0
src/goose/cli/__init__.py
Normal file
141
src/goose/cli/config.py
Normal file
141
src/goose/cli/config.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from functools import cache
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Mapping, Tuple
|
||||
|
||||
from rich import print
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
from rich.text import Text
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from goose.profile import Profile
|
||||
from goose.utils import load_plugins
|
||||
from goose.utils.diff import pretty_diff
|
||||
|
||||
GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser()
|
||||
PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml")
|
||||
SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions")
|
||||
SESSION_FILE_SUFFIX = ".jsonl"
|
||||
|
||||
|
||||
@cache
|
||||
def default_profiles() -> Mapping[str, Callable]:
|
||||
return load_plugins(group="goose.profile")
|
||||
|
||||
|
||||
def session_path(name: str) -> Path:
|
||||
SESSIONS_PATH.mkdir(parents=True, exist_ok=True)
|
||||
return SESSIONS_PATH.joinpath(f"{name}{SESSION_FILE_SUFFIX}")
|
||||
|
||||
|
||||
def write_config(profiles: Dict[str, Profile]) -> None:
|
||||
"""Overwrite the config with the passed profiles"""
|
||||
PROFILES_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
converted = {name: profile.to_dict() for name, profile in profiles.items()}
|
||||
yaml = YAML()
|
||||
with PROFILES_CONFIG_PATH.open("w") as f:
|
||||
yaml.dump(converted, f)
|
||||
|
||||
|
||||
def ensure_config(name: str) -> Profile:
|
||||
"""Ensure that the config exists and has the default section"""
|
||||
# TODO we should copy a templated default config in to better document
|
||||
# but this is complicated a bit by autodetecting the provider
|
||||
|
||||
provider, processor, accelerator = default_model_configuration()
|
||||
profile = default_profiles()[name](provider, processor, accelerator)
|
||||
|
||||
profiles = {}
|
||||
if not PROFILES_CONFIG_PATH.exists():
|
||||
print(
|
||||
Panel(
|
||||
f"[yellow]No configuration present, we will create a profile '{name}'"
|
||||
+ f" at: [/]{str(PROFILES_CONFIG_PATH)}\n"
|
||||
+ "You can add your own profile in this file to further configure goose!"
|
||||
)
|
||||
)
|
||||
default = profile
|
||||
profiles = {name: default}
|
||||
write_config(profiles)
|
||||
return profile
|
||||
|
||||
profiles = read_config()
|
||||
if name not in profiles:
|
||||
print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]"))
|
||||
profiles.update({name: profile})
|
||||
write_config(profiles)
|
||||
elif name in profiles:
|
||||
# if the profile stored differs from the default one, we should prompt the user to see if they want
|
||||
# to update it! we need to recursively compare the two profiles, as object comparison will always return false
|
||||
is_profile_eq = profile.to_dict() == profiles[name].to_dict()
|
||||
if not is_profile_eq:
|
||||
yaml = YAML()
|
||||
before = StringIO()
|
||||
after = StringIO()
|
||||
yaml.dump(profiles[name].to_dict(), before)
|
||||
yaml.dump(profile.to_dict(), after)
|
||||
before.seek(0)
|
||||
after.seek(0)
|
||||
|
||||
print(
|
||||
Panel(
|
||||
Text(
|
||||
f"Your profile uses one of the default options - '{name}'"
|
||||
+ " - but it differs from the latest version:\n\n",
|
||||
)
|
||||
+ pretty_diff(before.read(), after.read())
|
||||
)
|
||||
)
|
||||
should_update = Confirm.ask(
|
||||
"Do you want to update your profile to use the latest?",
|
||||
default=False,
|
||||
)
|
||||
if should_update:
|
||||
profiles[name] = profile
|
||||
write_config(profiles)
|
||||
else:
|
||||
profile = profiles[name]
|
||||
|
||||
return profile
|
||||
|
||||
|
||||
def read_config() -> Dict[str, Profile]:
|
||||
"""Return config from the configuration file and validates its contents"""
|
||||
|
||||
yaml = YAML()
|
||||
with PROFILES_CONFIG_PATH.open("r") as f:
|
||||
data = yaml.load(f)
|
||||
|
||||
return {name: Profile(**profile) for name, profile in data.items()}
|
||||
|
||||
|
||||
def default_model_configuration() -> Tuple[str, str, str]:
|
||||
providers = load_plugins(group="exchange.provider")
|
||||
for provider, cls in providers.items():
|
||||
try:
|
||||
cls.from_env()
|
||||
print(Panel(f"[green]Detected an available provider: [/]{provider}"))
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not detect an available provider,"
|
||||
+ " make sure to plugin a provider or set an env var such as OPENAI_API_KEY"
|
||||
)
|
||||
|
||||
recommended = {
|
||||
"openai": ("gpt-4o", "gpt-4o-mini"),
|
||||
"anthropic": (
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
),
|
||||
"databricks": (
|
||||
# TODO when function calling is first rec should be: "databricks-meta-llama-3-1-405b-instruct"
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
),
|
||||
}
|
||||
processor, accelerator = recommended.get(provider, ("gpt-4o", "gpt-4o-mini"))
|
||||
return provider, processor, accelerator
|
||||
125
src/goose/cli/main.py
Normal file
125
src/goose/cli/main.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import click
|
||||
from rich import print
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from goose.cli.config import SESSIONS_PATH
|
||||
from goose.cli.session import Session
|
||||
from goose.utils import load_plugins
|
||||
from goose.utils.session_file import list_sorted_session_files
|
||||
|
||||
|
||||
@click.group()
|
||||
def goose_cli() -> None:
|
||||
pass
|
||||
|
||||
|
||||
@goose_cli.command()
|
||||
def version() -> None:
|
||||
"""Lists the version of goose and any plugins"""
|
||||
from importlib.metadata import entry_points, version
|
||||
|
||||
print(f"[green]Goose[/green]: [bold][cyan]{version('goose')}[/cyan][/bold]")
|
||||
print("[green]Plugins[/green]:")
|
||||
filtered_groups = {}
|
||||
modules = set()
|
||||
if sys.version_info.minor >= 12:
|
||||
for ep in entry_points():
|
||||
group = getattr(ep, "group", None)
|
||||
if group and (group.startswith("exchange.") or group.startswith("goose.")):
|
||||
filtered_groups.setdefault(group, []).append(ep)
|
||||
for eps in filtered_groups.values():
|
||||
for ep in eps:
|
||||
module_name = ep.module.split(".")[0]
|
||||
modules.add(module_name)
|
||||
else:
|
||||
eps = entry_points()
|
||||
for group, entries in eps.items():
|
||||
if group.startswith("exchange.") or group.startswith("goose."):
|
||||
for entry in entries:
|
||||
module_name = entry.value.split(".")[0]
|
||||
modules.add(module_name)
|
||||
|
||||
modules.remove("goose")
|
||||
for module in sorted(list(modules)):
|
||||
# TODO: figure out how to get this to work for goose plugins block
|
||||
# as the module name is set to block.goose.cli
|
||||
# module_name = 'goose-plugins-block'
|
||||
try:
|
||||
module_version = version(module)
|
||||
print(f" Module: [green]{module}[/green], Version: [bold][cyan]{module_version}[/cyan][/bold]")
|
||||
except Exception as e:
|
||||
print(f" [red]Could not retrieve version for {module}: {e}[/red]")
|
||||
|
||||
|
||||
@goose_cli.group()
|
||||
def session() -> None:
|
||||
"""Start or manage sessions"""
|
||||
pass
|
||||
|
||||
|
||||
@session.command(name="start")
|
||||
@click.option("--profile")
|
||||
@click.option("--plan", type=click.Path(exists=True))
|
||||
def session_start(profile: str, plan: Optional[str] = None) -> None:
|
||||
"""Start a new goose session"""
|
||||
if plan:
|
||||
yaml = YAML()
|
||||
with open(plan, "r") as f:
|
||||
_plan = yaml.load(f)
|
||||
else:
|
||||
_plan = None
|
||||
|
||||
session = Session(profile=profile, plan=_plan)
|
||||
session.run()
|
||||
|
||||
|
||||
@session.command(name="resume")
|
||||
@click.argument("name", required=False)
|
||||
@click.option("--profile")
|
||||
def session_resume(name: str, profile: str) -> None:
|
||||
"""Resume an existing goose session"""
|
||||
if name is None:
|
||||
session_files = get_session_files()
|
||||
if session_files:
|
||||
name = list(session_files.keys())[0]
|
||||
print(f"Resuming most recent session: {name} from {session_files[name]}")
|
||||
else:
|
||||
print("No sessions found.")
|
||||
return
|
||||
session = Session(name=name, profile=profile)
|
||||
session.run()
|
||||
|
||||
|
||||
@session.command(name="list")
|
||||
def session_list() -> None:
|
||||
session_files = get_session_files().items()
|
||||
for session_name, session_file in session_files:
|
||||
print(f"{datetime.fromtimestamp(session_file.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')} {session_name}")
|
||||
|
||||
|
||||
@session.command(name="clear")
|
||||
@click.option("--keep", default=3, help="Keep this many entries, default 3")
|
||||
def session_clear(keep: int) -> None:
|
||||
for i, (_, session_file) in enumerate(get_session_files().items()):
|
||||
if i >= keep:
|
||||
session_file.unlink()
|
||||
|
||||
|
||||
def get_session_files() -> Dict[str, Path]:
|
||||
return list_sorted_session_files(SESSIONS_PATH)
|
||||
|
||||
|
||||
# merging goose cli with additional cli plugins.
|
||||
def cli() -> None:
|
||||
clis = load_plugins("goose.cli")
|
||||
cli_list = list(clis.values()) or []
|
||||
click.CommandCollection(sources=cli_list)()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
0
src/goose/cli/prompt/__init__.py
Normal file
0
src/goose/cli/prompt/__init__.py
Normal file
46
src/goose/cli/prompt/completer.py
Normal file
46
src/goose/cli/prompt/completer.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from prompt_toolkit.completion import CompleteEvent, Completer, Completion
|
||||
from prompt_toolkit.document import Document
|
||||
|
||||
from goose.command.base import Command
|
||||
|
||||
|
||||
class GoosePromptCompleter(Completer):
|
||||
def __init__(self, commands: List[Command]) -> None:
|
||||
self.commands = commands
|
||||
|
||||
def get_command_completions(self, document: Document) -> List[Completion]:
|
||||
all_completions = []
|
||||
for command_name, command_instance in self.commands.items():
|
||||
pattern = rf"(?<!\S)\/{command_name}:([\S]*)$"
|
||||
text = document.text_before_cursor
|
||||
match = re.search(pattern=pattern, string=text)
|
||||
if not match or text.endswith(" "):
|
||||
continue
|
||||
|
||||
query = match.group(1)
|
||||
completions = command_instance.get_completions(query)
|
||||
all_completions.extend(completions)
|
||||
return all_completions
|
||||
|
||||
def get_command_name_completions(self, document: Document) -> List[Completion]:
|
||||
pattern = r"(?<!\S)\/([\S]*)$"
|
||||
text = document.text_before_cursor
|
||||
match = re.search(pattern=pattern, string=text)
|
||||
if not match or text.endswith(" "):
|
||||
return []
|
||||
|
||||
query = match.group(1)
|
||||
|
||||
completions = []
|
||||
for command_name in self.commands:
|
||||
if command_name.startswith(query):
|
||||
completions.append(Completion(command_name, start_position=-len(query), display=command_name))
|
||||
return completions
|
||||
|
||||
def get_completions(self, document: Document, _: CompleteEvent) -> List[Completion]:
|
||||
command_completions = self.get_command_completions(document)
|
||||
command_name_completions = self.get_command_name_completions(document)
|
||||
return command_name_completions + command_completions
|
||||
66
src/goose/cli/prompt/create.py
Normal file
66
src/goose/cli/prompt/create.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.application.current import get_app
|
||||
from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent
|
||||
from prompt_toolkit.keys import Keys
|
||||
from prompt_toolkit.styles import Style
|
||||
|
||||
from goose.cli.prompt.completer import GoosePromptCompleter
|
||||
from goose.cli.prompt.lexer import PromptLexer
|
||||
from goose.command import get_commands
|
||||
|
||||
|
||||
def create_prompt() -> PromptSession:
|
||||
# Define custom style
|
||||
style = Style.from_dict(
|
||||
{
|
||||
"parameter": "bold",
|
||||
"command": "ansiblue bold",
|
||||
"text": "default",
|
||||
}
|
||||
)
|
||||
|
||||
bindings = KeyBindings()
|
||||
|
||||
# Bind the "Option + Enter" key to insert a newline
|
||||
@bindings.add(Keys.Escape, Keys.ControlM)
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
buffer = event.app.current_buffer
|
||||
buffer.insert_text("\n")
|
||||
|
||||
# Bind the "Enter" key to accept the completion if the completion menu is open
|
||||
# otherwise just submit the input
|
||||
@bindings.add(Keys.Enter)
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
buffer = event.current_buffer
|
||||
app = get_app()
|
||||
|
||||
if app.layout.has_focus(buffer):
|
||||
# Check if the completion menu is open
|
||||
if buffer.complete_state:
|
||||
# accept completion
|
||||
buffer.complete_state = None
|
||||
else:
|
||||
buffer.validate_and_handle()
|
||||
|
||||
@bindings.add(Keys.ControlY)
|
||||
def _(event: KeyPressEvent) -> None:
|
||||
buffer = event.app.current_buffer
|
||||
app = get_app()
|
||||
if app.layout.has_focus(buffer):
|
||||
# Check if the completion menu is open
|
||||
if buffer.complete_state:
|
||||
# accept completion
|
||||
buffer.complete_state = None
|
||||
|
||||
# instantiate the commands available in the prompt
|
||||
commands = dict()
|
||||
command_plugins = get_commands()
|
||||
for command, command_cls in command_plugins.items():
|
||||
commands[command] = command_cls()
|
||||
|
||||
return PromptSession(
|
||||
completer=GoosePromptCompleter(commands=commands),
|
||||
lexer=PromptLexer(list(commands.keys())),
|
||||
style=style,
|
||||
key_bindings=bindings,
|
||||
)
|
||||
34
src/goose/cli/prompt/goose_prompt_session.py
Normal file
34
src/goose/cli/prompt/goose_prompt_session.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.formatted_text import FormattedText
|
||||
from prompt_toolkit.validation import DummyValidator
|
||||
|
||||
from goose.cli.prompt.create import create_prompt
|
||||
from goose.cli.prompt.prompt_validator import PromptValidator
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
|
||||
|
||||
class GoosePromptSession:
|
||||
def __init__(self, prompt_session: PromptSession) -> None:
|
||||
self.prompt_session = prompt_session
|
||||
|
||||
@staticmethod
|
||||
def create_prompt_session() -> "GoosePromptSession":
|
||||
return GoosePromptSession(create_prompt())
|
||||
|
||||
def get_user_input(self) -> "UserInput":
|
||||
try:
|
||||
text = FormattedText([("#00AEAE", "G❯ ")]) # Define the prompt style and text.
|
||||
message = self.prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False)
|
||||
if message.strip() in ("exit", ":q"):
|
||||
return UserInput(PromptAction.EXIT)
|
||||
return UserInput(PromptAction.CONTINUE, message)
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return UserInput(PromptAction.EXIT)
|
||||
|
||||
def get_save_session_name(self) -> Optional[str]:
|
||||
return self.prompt_session.prompt(
|
||||
"Enter a name to save this session under. A name will be generated for you if empty: ",
|
||||
validator=DummyValidator(),
|
||||
)
|
||||
53
src/goose/cli/prompt/lexer.py
Normal file
53
src/goose/cli/prompt/lexer.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import re
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.lexers import Lexer
|
||||
|
||||
|
||||
def completion_for_command(target_string: str) -> re.Pattern[str]:
|
||||
escaped_string = re.escape(target_string)
|
||||
vals = [f"(?:{escaped_string[:i]}(?=$))" for i in range(len(escaped_string), 0, -1)]
|
||||
return re.compile(rf'(?<!\S)\/({"|".join(vals)})(?:\s^|$)')
|
||||
|
||||
|
||||
def command_itself(target_string: str) -> re.Pattern[str]:
|
||||
escaped_string = re.escape(target_string)
|
||||
return re.compile(rf"(?<!\S)(\/{escaped_string})")
|
||||
|
||||
|
||||
def value_for_command(command_string: str) -> re.Pattern[str]:
|
||||
escaped_string = re.escape(command_string)
|
||||
return re.compile(rf"(?<=(?<!\S)\/{escaped_string})([^\s]*)")
|
||||
|
||||
|
||||
class PromptLexer(Lexer):
|
||||
def __init__(self, command_names: List[str]) -> None:
|
||||
self.patterns = []
|
||||
for command_name in command_names:
|
||||
full_command = command_name + ":"
|
||||
self.patterns.append((completion_for_command(full_command), "class:command"))
|
||||
self.patterns.append((value_for_command(full_command), "class:parameter"))
|
||||
self.patterns.append((command_itself(full_command), "class:command"))
|
||||
|
||||
def lex_document(self, document: Document) -> Callable[[int], list]:
|
||||
def get_line_tokens(line_number: int) -> Tuple[str, str]:
|
||||
line = document.lines[line_number]
|
||||
tokens = []
|
||||
|
||||
i = 0
|
||||
while i < len(line):
|
||||
match = None
|
||||
for pattern, token in self.patterns:
|
||||
match = pattern.match(line, i)
|
||||
if match:
|
||||
tokens.append((token, match.group()))
|
||||
i = match.end()
|
||||
break
|
||||
if not match:
|
||||
tokens.append(("class:text", line[i]))
|
||||
i += 1
|
||||
|
||||
return tokens
|
||||
|
||||
return get_line_tokens
|
||||
10
src/goose/cli/prompt/prompt_validator.py
Normal file
10
src/goose/cli/prompt/prompt_validator.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.validation import ValidationError, Validator
|
||||
|
||||
|
||||
class PromptValidator(Validator):
|
||||
def validate(self, document: Document) -> None:
|
||||
text = document.text
|
||||
if text is not None and not text.strip():
|
||||
message = "Enter your prompt to goose. If you would like to exit, use CTRL+D, or type 'exit' or ':q'"
|
||||
raise ValidationError(message=message, cursor_position=0)
|
||||
20
src/goose/cli/prompt/user_input.py
Normal file
20
src/goose/cli/prompt/user_input.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class PromptAction(Enum):
|
||||
CONTINUE = 1
|
||||
EXIT = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInput:
|
||||
action: PromptAction
|
||||
text: Optional[str] = None
|
||||
|
||||
def to_exit(self) -> bool:
|
||||
return self.action == PromptAction.EXIT
|
||||
|
||||
def to_continue(self) -> bool:
|
||||
return self.action == PromptAction.CONTINUE
|
||||
237
src/goose/cli/session.py
Normal file
237
src/goose/cli/session.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from exchange import Message, ToolResult, ToolUse
|
||||
from prompt_toolkit.shortcuts import confirm
|
||||
from rich import print
|
||||
from rich.console import RenderableType
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.status import Status
|
||||
|
||||
from goose.build import build_exchange
|
||||
from goose.cli.config import (
|
||||
default_profiles,
|
||||
ensure_config,
|
||||
read_config,
|
||||
session_path,
|
||||
)
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.notifier import Notifier
|
||||
from goose.profile import Profile
|
||||
from goose.utils import droid, load_plugins
|
||||
from goose.utils.session_file import read_from_file, write_to_file
|
||||
|
||||
|
||||
def load_provider() -> str:
|
||||
# We try to infer a provider, by going in order of what will auth
|
||||
providers = load_plugins(group="exchange.provider")
|
||||
for provider, cls in providers.items():
|
||||
try:
|
||||
cls.from_env()
|
||||
print(Panel(f"[green]Detected an available provider: [/]{provider}"))
|
||||
return provider
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# TODO link to auth docs
|
||||
print(
|
||||
Panel(
|
||||
"[red]Could not authenticate any providers[/]\n"
|
||||
+ "Returning a default pointing to openai, but you will need to set an API token env variable."
|
||||
)
|
||||
)
|
||||
return "openai"
|
||||
|
||||
|
||||
def load_profile(name: Optional[str]) -> Profile:
|
||||
if name is None:
|
||||
name = "default"
|
||||
|
||||
# If the name is one of the default values, we ensure a valid configuration
|
||||
if name in default_profiles():
|
||||
return ensure_config(name)
|
||||
|
||||
# Otherwise this is a custom config and we return it from the config file
|
||||
return read_config()[name]
|
||||
|
||||
|
||||
class SessionNotifier(Notifier):
|
||||
def __init__(self, status_indicator: Status) -> None:
|
||||
self.status_indicator = status_indicator
|
||||
|
||||
def log(self, content: RenderableType) -> None:
|
||||
print(content)
|
||||
|
||||
def status(self, status: str) -> None:
|
||||
self.status_indicator.update(status)
|
||||
|
||||
|
||||
class Session:
|
||||
"""A session handler for managing interactions between a user and the Goose exchange
|
||||
|
||||
This class encapsulates the entire user interaction cycle, from input prompt to response handling,
|
||||
including interruptions and error management.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
profile: Optional[str] = None,
|
||||
plan: Optional[dict] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.status_indicator = Status("", spinner="dots")
|
||||
notifier = SessionNotifier(self.status_indicator)
|
||||
|
||||
self.exchange = build_exchange(profile=load_profile(profile), notifier=notifier)
|
||||
|
||||
if name is not None and self.session_file_path.exists():
|
||||
messages = self.load_session()
|
||||
if messages and messages[-1].role == "user":
|
||||
messages.pop()
|
||||
self.exchange.messages.extend(messages)
|
||||
|
||||
if len(self.exchange.messages) == 0 and plan:
|
||||
self.setup_plan(plan=plan)
|
||||
|
||||
self.prompt_session = GoosePromptSession.create_prompt_session()
|
||||
|
||||
def setup_plan(self, plan: dict) -> None:
|
||||
if len(self.exchange.messages):
|
||||
raise ValueError("The plan can only be set on an empty session.")
|
||||
self.exchange.messages.append(Message.user(plan["kickoff_message"]))
|
||||
tasks = []
|
||||
if "tasks" in plan:
|
||||
tasks = [dict(description=task, status="planned") for task in plan["tasks"]]
|
||||
|
||||
plan_tool_use = ToolUse(id="initialplan", name="update_plan", parameters=dict(tasks=tasks))
|
||||
self.exchange.add_tool_use(plan_tool_use)
|
||||
|
||||
def process_first_message(self) -> Optional[Message]:
|
||||
# Get a first input unless it has been specified, such as by a plan
|
||||
if len(self.exchange.messages) == 0 or self.exchange.messages[-1].role == "assistant":
|
||||
user_input = self.prompt_session.get_user_input()
|
||||
if user_input.to_exit():
|
||||
return None
|
||||
return Message.user(text=user_input.text)
|
||||
return self.exchange.messages.pop()
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Runs the main loop to handle user inputs and responses.
|
||||
Continues until an empty string is returned from the prompt.
|
||||
"""
|
||||
message = self.process_first_message()
|
||||
while message: # Loop until no input (empty string).
|
||||
with Live(self.status_indicator, refresh_per_second=8, transient=True):
|
||||
try:
|
||||
self.exchange.add(message)
|
||||
self.reply() # Process the user message.
|
||||
except KeyboardInterrupt:
|
||||
self.interrupt_reply()
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
if self.exchange.messages:
|
||||
self.exchange.messages.pop()
|
||||
print(
|
||||
"\n[red]The error above was an exception we were not able to handle.\n\n[/]"
|
||||
+ "These errors are often related to connection or authentication\n"
|
||||
+ "We've removed your most recent input"
|
||||
+ " - [yellow]depending on the error you may be able to continue[/]"
|
||||
)
|
||||
|
||||
print() # Print a newline for separation.
|
||||
user_input = self.prompt_session.get_user_input()
|
||||
message = Message.user(text=user_input.text) if user_input.to_continue() else None
|
||||
|
||||
self.save_session()
|
||||
|
||||
def reply(self) -> None:
|
||||
"""Reply to the last user message, calling tools as needed
|
||||
|
||||
Args:
|
||||
text (str): The text input from the user.
|
||||
"""
|
||||
self.status_indicator.update("responding")
|
||||
response = self.exchange.generate()
|
||||
|
||||
if response.text:
|
||||
print(Markdown(response.text))
|
||||
|
||||
while response.tool_use:
|
||||
content = []
|
||||
for tool_use in response.tool_use:
|
||||
tool_result = self.exchange.call_function(tool_use)
|
||||
content.append(tool_result)
|
||||
self.exchange.add(Message(role="user", content=content))
|
||||
self.status_indicator.update("responding")
|
||||
response = self.exchange.generate()
|
||||
|
||||
if response.text:
|
||||
print(Markdown(response.text))
|
||||
|
||||
def interrupt_reply(self) -> None:
|
||||
"""Recover from an interruption at an arbitrary state"""
|
||||
# Default recovery message if no user message is pending.
|
||||
recovery = "We interrupted before the next processing started."
|
||||
if self.exchange.messages and self.exchange.messages[-1].role == "user":
|
||||
# If the last message is from the user, remove it.
|
||||
self.exchange.messages.pop()
|
||||
recovery = "We interrupted before the model replied and removed the last message."
|
||||
|
||||
if (
|
||||
self.exchange.messages
|
||||
and self.exchange.messages[-1].role == "assistant"
|
||||
and self.exchange.messages[-1].tool_use
|
||||
):
|
||||
content = []
|
||||
# Append tool results as errors if interrupted.
|
||||
for tool_use in self.exchange.messages[-1].tool_use:
|
||||
content.append(
|
||||
ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
output="Interrupted by the user to make a correction",
|
||||
is_error=True,
|
||||
)
|
||||
)
|
||||
self.exchange.add(Message(role="user", content=content))
|
||||
recovery = f"We interrupted the existing call to {tool_use.name}. How would you like to proceed?"
|
||||
self.exchange.add(Message.assistant(recovery))
|
||||
# Print the recovery message with markup for visibility.
|
||||
print(f"[yellow]{recovery}[/]")
|
||||
|
||||
@property
|
||||
def session_file_path(self) -> Path:
|
||||
return session_path(self.name)
|
||||
|
||||
def save_session(self) -> None:
|
||||
"""Save the current session to a file in JSON format."""
|
||||
if self.name is None:
|
||||
self.generate_session_name()
|
||||
|
||||
try:
|
||||
if self.session_file_path.exists():
|
||||
if not confirm(f"Session {self.name} exists in {self.session_file_path}, overwrite?"):
|
||||
self.generate_session_name()
|
||||
write_to_file(self.session_file_path, self.exchange.messages)
|
||||
except PermissionError as e:
|
||||
raise RuntimeError(f"Failed to save session due to permissions: {e}")
|
||||
except (IOError, OSError) as e:
|
||||
raise RuntimeError(f"Failed to save session due to I/O error: {e}")
|
||||
|
||||
def load_session(self) -> List[Message]:
|
||||
"""Load a session from a JSON file."""
|
||||
return read_from_file(self.session_file_path)
|
||||
|
||||
def generate_session_name(self) -> None:
|
||||
user_entered_session_name = self.prompt_session.get_save_session_name()
|
||||
self.name = user_entered_session_name if user_entered_session_name else droid()
|
||||
print(f"Saving to [bold cyan]{self.session_file_path}[/bold cyan]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
session = Session()
|
||||
15
src/goose/command/__init__.py
Normal file
15
src/goose/command/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from functools import cache
|
||||
from typing import Dict
|
||||
|
||||
from goose.command.base import Command
|
||||
from goose.utils import load_plugins
|
||||
|
||||
|
||||
@cache
|
||||
def get_command(name: str) -> type[Command]:
|
||||
return load_plugins(group="goose.command")[name]
|
||||
|
||||
|
||||
@cache
|
||||
def get_commands() -> Dict[str, type[Command]]:
|
||||
return load_plugins(group="goose.command")
|
||||
16
src/goose/command/base.py
Normal file
16
src/goose/command/base.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from abc import ABC
|
||||
from typing import List, Optional
|
||||
|
||||
from prompt_toolkit.completion import Completion
|
||||
|
||||
|
||||
class Command(ABC):
|
||||
"""A command that can be executed by the CLI."""
|
||||
|
||||
def get_completions(self, query: str) -> List[Completion]:
|
||||
"""Get completions for the command."""
|
||||
return []
|
||||
|
||||
def execute(self, query: str) -> Optional[str]:
|
||||
"""Execute's the command and replaces it with the output."""
|
||||
return ""
|
||||
61
src/goose/command/file.py
Normal file
61
src/goose/command/file.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from prompt_toolkit.completion import Completion
|
||||
|
||||
from goose.command.base import Command
|
||||
|
||||
|
||||
class FileCommand(Command):
|
||||
def get_completions(self, query: str) -> List[Completion]:
|
||||
if query.startswith("/"):
|
||||
directory = os.path.dirname(query)
|
||||
search_term = os.path.basename(query)
|
||||
else:
|
||||
directory = os.path.join(os.getcwd(), os.path.dirname(query))
|
||||
search_term = os.path.basename(query)
|
||||
|
||||
# if query is a file, don't show completions
|
||||
if os.path.isfile(directory):
|
||||
return []
|
||||
|
||||
# Get the list of files in the directory
|
||||
options = []
|
||||
try:
|
||||
for file_name in os.listdir(directory):
|
||||
if file_name.startswith(search_term):
|
||||
full_path = os.path.join(directory, file_name)
|
||||
if os.path.isdir(full_path):
|
||||
options.append(
|
||||
dict(
|
||||
display_text=" " + file_name,
|
||||
insert_text=file_name,
|
||||
is_dir=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
options.append(
|
||||
dict(
|
||||
display_text=" " + file_name,
|
||||
insert_text=file_name + " ",
|
||||
is_dir=False,
|
||||
)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
|
||||
completions = []
|
||||
options.sort(key=lambda x: (not x["is_dir"], x["insert_text"]), reverse=False)
|
||||
for option in options:
|
||||
completions.append(
|
||||
Completion(
|
||||
option["insert_text"],
|
||||
start_position=-len(search_term),
|
||||
display=option["display_text"],
|
||||
)
|
||||
)
|
||||
return completions
|
||||
|
||||
def execute(self, query: str) -> str | None:
|
||||
# GOOSE-TODO: return the query
|
||||
pass
|
||||
28
src/goose/notifier.py
Normal file
28
src/goose/notifier.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from rich.console import RenderableType
|
||||
|
||||
|
||||
class Notifier(ABC):
|
||||
"""The interface for a notifier
|
||||
|
||||
This is expected to be implemented concretely by the each UX
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def log(self, content: RenderableType) -> None:
|
||||
"""Append content to the main display
|
||||
|
||||
Args:
|
||||
content (str): The content to render
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def status(self, status: str) -> None:
|
||||
"""Log a status to ephemeral display
|
||||
|
||||
Args:
|
||||
status (str): The status to display
|
||||
"""
|
||||
pass
|
||||
54
src/goose/profile.py
Normal file
54
src/goose/profile.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any, Dict, List, Mapping, Type
|
||||
|
||||
from attrs import asdict, define, field
|
||||
|
||||
from goose.utils import ensure_list
|
||||
|
||||
|
||||
@define
|
||||
class ToolkitSpec:
|
||||
"""Configuration for a Toolkit"""
|
||||
|
||||
name: str
|
||||
requires: Mapping[str, str] = field(factory=dict)
|
||||
|
||||
|
||||
@define
|
||||
class Profile:
|
||||
"""The configuration for a run of goose"""
|
||||
|
||||
provider: str
|
||||
processor: str
|
||||
accelerator: str
|
||||
moderator: str
|
||||
toolkits: List[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec))
|
||||
|
||||
@toolkits.validator
|
||||
def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[ToolkitSpec]) -> None:
|
||||
# checks that the list of toolkits in the profile have their requirements
|
||||
installed_toolkits = set([toolkit.name for toolkit in toolkits])
|
||||
|
||||
for toolkit in toolkits:
|
||||
toolkit_name = toolkit.name
|
||||
toolkit_requirements = toolkit.requires
|
||||
for _, req in toolkit_requirements.items():
|
||||
if req not in installed_toolkits:
|
||||
msg = f"Toolkit {toolkit_name} requires {req} but it is not present"
|
||||
raise ValueError(msg)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile:
|
||||
"""Get the default profile"""
|
||||
|
||||
# TODO consider if the providers should have recommended models
|
||||
|
||||
return Profile(
|
||||
provider=provider,
|
||||
processor=processor,
|
||||
accelerator=accelerator,
|
||||
moderator="truncate",
|
||||
toolkits=[ToolkitSpec("developer")],
|
||||
)
|
||||
1
src/goose/system.jinja
Normal file
1
src/goose/system.jinja
Normal file
@@ -0,0 +1 @@
|
||||
You are an AI assistant named Goose. You solve problems using your tools.
|
||||
9
src/goose/toolkit/__init__.py
Normal file
9
src/goose/toolkit/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from functools import cache
|
||||
|
||||
from goose.toolkit.base import Toolkit
|
||||
from goose.utils import load_plugins
|
||||
|
||||
|
||||
@cache
|
||||
def get_toolkit(name: str) -> type[Toolkit]:
|
||||
return load_plugins(group="goose.toolkit")[name]
|
||||
65
src/goose/toolkit/base.py
Normal file
65
src/goose/toolkit/base.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from typing import Callable, Mapping, Optional, Tuple, TypeVar
|
||||
|
||||
from attrs import define, field
|
||||
from exchange import Tool
|
||||
|
||||
from goose.notifier import Notifier
|
||||
|
||||
# Create a type variable that can represent any function signature
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
|
||||
def tool(func: F) -> F:
|
||||
func._is_tool = True
|
||||
return func
|
||||
|
||||
|
||||
@define
|
||||
class Requirements:
|
||||
"""A collection of requirements for advanced toolkits
|
||||
|
||||
Requirements are an advanced use case, most toolkits will not need to
|
||||
use these. They allow one toolkit to interact with another's state.
|
||||
"""
|
||||
|
||||
_toolkit: str
|
||||
_requirements: Mapping[str, "Toolkit"] = field(factory=dict)
|
||||
|
||||
def get(self, requirement: str) -> "Toolkit":
|
||||
"""Get a requirement by name."""
|
||||
if requirement not in self._requirements:
|
||||
raise RuntimeError(
|
||||
f"The toolkit '{self._toolkit}' requested a requirement '{requirement}' but none was passed!\n"
|
||||
+ f" Make sure to include `requires: {{{requirement}: ...}}` in your profile config\n"
|
||||
+ f" See the documentation for {self._toolkit} for more details"
|
||||
)
|
||||
return self._requirements[requirement]
|
||||
|
||||
|
||||
class Toolkit(ABC):
|
||||
"""A collection of tools with corresponding prompting
|
||||
|
||||
This class defines the interface that all toolkit implementations must follow,
|
||||
providing a system prompt and a collection of tools. Both are allowed to be
|
||||
empty if they are not required for the toolkit.
|
||||
"""
|
||||
|
||||
def __init__(self, notifier: Notifier, requires: Optional[Requirements] = None) -> None:
|
||||
self.notifier = notifier
|
||||
# This needs to be updated after the fact via build_exchange
|
||||
self.exchange_view = None
|
||||
|
||||
def system(self) -> str:
|
||||
"""Get the addition to the system prompt for this toolkit."""
|
||||
return ""
|
||||
|
||||
def tools(self) -> Tuple[Tool, ...]:
|
||||
"""Get the tools for this toolkit
|
||||
|
||||
This default method looks for functions on the toolkit annotated
|
||||
with @tool.
|
||||
"""
|
||||
candidates = inspect.getmembers(self, predicate=inspect.ismethod)
|
||||
return (Tool.from_function(candidate) for _, candidate in candidates if getattr(candidate, "_is_tool", None))
|
||||
201
src/goose/toolkit/developer.py
Normal file
201
src/goose/toolkit/developer.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from pathlib import Path
|
||||
from subprocess import CompletedProcess, run
|
||||
from typing import List
|
||||
|
||||
from exchange import Message
|
||||
from rich import box
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm, PromptType
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from goose.toolkit.base import Toolkit, tool
|
||||
from goose.toolkit.utils import get_language
|
||||
|
||||
|
||||
def keep_unsafe_command_prompt(command: str) -> PromptType:
|
||||
command_text = Text(command, style="bold red")
|
||||
message = (
|
||||
Text("\nWe flagged the command: ")
|
||||
+ command_text
|
||||
+ Text(" as potentially unsafe, do you want to proceed? (yes/no)")
|
||||
)
|
||||
return Confirm.ask(message, default=True)
|
||||
|
||||
|
||||
class Developer(Toolkit):
|
||||
"""The developer toolkit provides a set of general purpose development capabilities
|
||||
|
||||
The tools include plan management, a general purpose shell execution tool, and file operations.
|
||||
We also include some default shell strategies in the prompt, such as using ripgrep
|
||||
"""
|
||||
|
||||
def system(self) -> str:
|
||||
"""Retrieve system configuration details for developer"""
|
||||
return Message.load("prompts/developer.jinja").text
|
||||
|
||||
@tool
|
||||
def update_plan(self, tasks: List[dict]) -> List[dict]:
|
||||
"""
|
||||
Update the plan by overwriting all current tasks
|
||||
|
||||
This can be used to update the status of a task. This update will be
|
||||
shown to the user directly, you do not need to reiterate it
|
||||
|
||||
Args:
|
||||
tasks (List(dict)): The list of tasks, where each task is a dictionary
|
||||
with a key for the task "description" and the task "status". The status
|
||||
MUST be one of "planned", "complete", "failed", "in-progress".
|
||||
|
||||
"""
|
||||
# Validate the status of each task to ensure it is one of the accepted values.
|
||||
for task in tasks:
|
||||
if task["status"] not in {"planned", "complete", "failed", "in-progress"}:
|
||||
raise ValueError(f"Invalid task status: {task['status']}")
|
||||
|
||||
# Create a table with columns for the index, description, and status of each task.
|
||||
table = Table(expand=True)
|
||||
table.add_column("#", justify="right", style="magenta")
|
||||
table.add_column("Task", justify="left")
|
||||
table.add_column("Status", justify="left")
|
||||
|
||||
# Mapping of statuses to emojis for better visual representation in the table.
|
||||
emoji = {"planned": "⏳", "complete": "✅", "failed": "❌", "in-progress": "🕓"}
|
||||
for i, entry in enumerate(tasks):
|
||||
table.add_row(str(i), entry["description"], emoji[entry["status"]])
|
||||
|
||||
# Log the table to display it directly to the user
|
||||
# `.log` method is used here to log the command execution in the application's UX
|
||||
self.notifier.log(table)
|
||||
|
||||
# Return the tasks unchanged as the function's primary purpose is to update and display the task status.
|
||||
return tasks
|
||||
|
||||
@tool
|
||||
def patch_file(self, path: str, before: str, after: str) -> str:
|
||||
"""Patch the file at the specified by replacing before with after
|
||||
|
||||
Before **must** be present exactly once in the file, so that it can safely
|
||||
be replaced with after.
|
||||
|
||||
Args:
|
||||
path (str): The path to the file, in the format "path/to/file.txt"
|
||||
before (str): The content that will be replaced
|
||||
after (str): The content it will be replaced with
|
||||
"""
|
||||
self.notifier.status(f"editing {path}")
|
||||
_path = Path(path)
|
||||
language = get_language(path)
|
||||
|
||||
content = _path.read_text()
|
||||
|
||||
if content.count(before) > 1:
|
||||
raise ValueError("The before content is present multiple times in the file, be more specific.")
|
||||
if content.count(before) < 1:
|
||||
raise ValueError("The before content was not found in file, be careful that you recreate it exactly.")
|
||||
|
||||
content = content.replace(before, after)
|
||||
_path.write_text(content)
|
||||
|
||||
output = f"""
|
||||
```{language}
|
||||
{before}
|
||||
```
|
||||
->
|
||||
```{language}
|
||||
{after}
|
||||
```
|
||||
"""
|
||||
self.notifier.log(Panel.fit(Markdown(output), title=path))
|
||||
return "Succesfully replaced before with after."
|
||||
|
||||
@tool
|
||||
def read_file(self, path: str) -> str:
|
||||
"""Read the content of the file at path
|
||||
|
||||
Args:
|
||||
path (str): The destination file path, in the format "path/to/file.txt"
|
||||
"""
|
||||
language = get_language(path)
|
||||
content = Path(path).expanduser().read_text()
|
||||
self.notifier.log(Panel.fit(Markdown(f"```\ncat {path}\n```"), box=box.MINIMAL))
|
||||
return f"```{language}\n{content}\n```"
|
||||
|
||||
@tool
|
||||
def shell(self, command: str) -> str:
|
||||
"""
|
||||
Execute a command on the shell (in OSX)
|
||||
|
||||
This will return the output and error concatenated into a single string, as
|
||||
you would see from running on the command line. There will also be an indication
|
||||
of if the command succeeded or failed.
|
||||
|
||||
Args:
|
||||
command (str): The shell command to run. It can support multiline statements
|
||||
if you need to run more than one at a time
|
||||
"""
|
||||
self.notifier.status("running shell command")
|
||||
# Log the command being executed in a visually structured format (Markdown).
|
||||
# The `.log` method is used here to log the command execution in the application's UX
|
||||
# this method is dynamically attached to functions in the Goose framework to handle user-visible
|
||||
# logging and integrates with the overall UI logging system
|
||||
self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell"))
|
||||
|
||||
safety_rails_exchange = self.exchange_view.processor.replace(
|
||||
system=Message.load("prompts/safety_rails.jinja").text
|
||||
)
|
||||
# remove the previous message which was a tool_use Assistant message
|
||||
safety_rails_exchange.messages.pop()
|
||||
|
||||
safety_rails_exchange.add(Message.assistant(f"Here is the command I'd like to run: `{command}`"))
|
||||
safety_rails_exchange.add(Message.user("Please provide the danger rating of that command"))
|
||||
rating = safety_rails_exchange.reply().text
|
||||
|
||||
try:
|
||||
rating = int(rating)
|
||||
except ValueError:
|
||||
rating = 5 # if we can't interpret we default to unsafe
|
||||
if int(rating) > 3:
|
||||
if not keep_unsafe_command_prompt(command):
|
||||
raise RuntimeError(
|
||||
f"The command {command} was rejected as dangerous by the user."
|
||||
+ " Do not proceed further, instead ask for instructions."
|
||||
)
|
||||
|
||||
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
|
||||
if result.returncode == 0:
|
||||
output = "Command succeeded"
|
||||
else:
|
||||
output = f"Command failed with returncode {result.returncode}"
|
||||
return "\n".join([output, result.stdout, result.stderr])
|
||||
|
||||
@tool
|
||||
def write_file(self, path: str, content: str) -> str:
|
||||
"""
|
||||
Write a file at the specified path with the provided content. This will create any directories if they do not exist.
|
||||
The content will fully overwrite the existing file.
|
||||
|
||||
Args:
|
||||
path (str): The destination file path, in the format "path/to/file.txt"
|
||||
content (str): The raw file content.
|
||||
""" # noqa: E501
|
||||
self.notifier.status("writing file")
|
||||
# Get the programming language for syntax highlighting in logs
|
||||
language = get_language(path)
|
||||
md = f"```{language}\n{content}\n```"
|
||||
|
||||
# Log the content that will be written to the file
|
||||
# .log` method is used here to log the command execution in the application's UX
|
||||
# this method is dynamically attached to functions in the Goose framework
|
||||
self.notifier.log(Panel.fit(Markdown(md), title=path))
|
||||
|
||||
# Prepare the path and create any necessary parent directories
|
||||
_path = Path(path)
|
||||
_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write the content to the file
|
||||
_path.write_text(content)
|
||||
|
||||
# Return a success message
|
||||
return f"Succesfully wrote to {path}"
|
||||
11
src/goose/toolkit/github.py
Normal file
11
src/goose/toolkit/github.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from exchange import Message
|
||||
|
||||
from goose.toolkit.base import Toolkit
|
||||
|
||||
|
||||
class Github(Toolkit):
|
||||
"""Provides an additional prompt on how to interact with Github"""
|
||||
|
||||
def system(self) -> str:
|
||||
"""Retrieve detailed configuration and procedural guidelines for GitHub operations"""
|
||||
return Message.load("prompts/github.jinja").text
|
||||
62
src/goose/toolkit/prompts/developer.jinja
Normal file
62
src/goose/toolkit/prompts/developer.jinja
Normal file
@@ -0,0 +1,62 @@
|
||||
To start solving problems, you will always create a plan, and then immediately
|
||||
execute the next step of the plan. Do not wait for confirmation on your plan,
|
||||
always proceed to the next step. After each step, update the plan based
|
||||
on the output you recieve from any actions you take, including marking
|
||||
finished tasks as complete!
|
||||
|
||||
To accomplish this, you will call your tools such as update_plan, shell, or write.
|
||||
|
||||
Always use the update_plan tool before taking any new actions, to show the client
|
||||
an up to date plan. The plan will be displayed automatically any time you update it.
|
||||
|
||||
The plan should consist of as few entries as possible, and translate from the user
|
||||
request into concrete tasks that you will use to get it done. These should reflect
|
||||
the actions you will need to take, such as writing files or executing shell commands.
|
||||
|
||||
For example, here's a plan to unstage all edited files in a git repo
|
||||
|
||||
{"description": "Use the git status command to find edited files", "status": "pending"}
|
||||
{"description": "For each file with changes, call git restore on the file", "status": "pending"}
|
||||
|
||||
After running git status, you would update to
|
||||
|
||||
{"description": "Use the git status command to find edited files", "status": "complete"}
|
||||
{"description": "For each file with changes, call git restore on the file", "status": "pending"}
|
||||
|
||||
|
||||
Here's another plan example to get the sum of the "payment" column in data.csv
|
||||
|
||||
{"description": "Install pandas", "status": "pending"}
|
||||
{"description": "Write a python script in the file sum.py that loads the csv in pandas and sums the column", "status": "pending"}
|
||||
{"description": "Run the python script with bash", "status": "pending"}
|
||||
|
||||
If you were to encounter an error along the way, you can update the plan to specify a new approach.
|
||||
Always call update_plan before any other tool calls! You should specify the whole plan upfront as pending,
|
||||
and then update status at each step. **Do not describe the plan in your text response, only communicate
|
||||
the plan through the tool**
|
||||
|
||||
If you need to manipulate files, always prefer the write file tool. To edit a file
|
||||
that already exists, first check the content with the shell and then overwrite it
|
||||
|
||||
|
||||
Some of the files that you work with will be long. When you want to edit a long file,
|
||||
prefer to use the patch tool. Make sure that you always read the file using the read tool
|
||||
before you call patch.
|
||||
|
||||
The patch and write tools can accomplish the same operations, but patch is a more complex tool.
|
||||
The patch tool is worth it when the file is large enough that it would be tedious to fully rewrite.
|
||||
|
||||
|
||||
You are an expert with ripgrep - `rg`. When you need to locate content in the code base, use
|
||||
`rg` exclusively. It will respect ignored files for efficiency.
|
||||
|
||||
To locate files by name, use
|
||||
|
||||
```bash
|
||||
rg --files | rg example.py
|
||||
```
|
||||
|
||||
To locate content inside files, use
|
||||
```bash
|
||||
rg 'class Example'
|
||||
```
|
||||
18
src/goose/toolkit/prompts/github.jinja
Normal file
18
src/goose/toolkit/prompts/github.jinja
Normal file
@@ -0,0 +1,18 @@
|
||||
You can interact through github via the `gh` command line generally.
|
||||
If it fails to auth, prompt the user to run `gh auth login`
|
||||
|
||||
Typically when someone requests you to look at a pull request review, they mean to view
|
||||
not just the top level comments and reviews, but also the comments nested within that review.
|
||||
|
||||
To do that, you need to first use the API to get reviews:
|
||||
```bash
|
||||
gh api -H "Accept: application/vnd.github+json" /repos/OWNER/REPO/pulls/PULL_NUMBER/reviews
|
||||
```
|
||||
|
||||
And then for each individual review, get all of the comments:
|
||||
```bash
|
||||
gh api -H 'Accept: application/vnd.github+json' /repos/OWNER/REPO/pulls/PULL_NUMBER/reviews/ID/comments
|
||||
```
|
||||
|
||||
When you work with a pull request review, use the above approach as well as `gh diff`
|
||||
to get the full details before answering questions.
|
||||
39
src/goose/toolkit/prompts/safety_rails.jinja
Normal file
39
src/goose/toolkit/prompts/safety_rails.jinja
Normal file
@@ -0,0 +1,39 @@
|
||||
You are a third party auditor, with the goal of ensuring safety for commands
|
||||
that will run on the user's system. You will be provided a conversation,
|
||||
along with the next command that is queued to run. Evaluate if that command
|
||||
is safe.
|
||||
|
||||
Here are examples of harmful commands:
|
||||
1. Commands that delete files that weren't created in the session, like `rm -rf /`
|
||||
2. Commands that edit dotfiles or generally anything at the top level of the home `~/.` directory
|
||||
3. Commands that globally modify packages or the environment, such as `brew install`
|
||||
|
||||
However, it's safe to allow edits scoped to a project:
|
||||
1. Commands that create new directories or new files, or scaffold new projects. `mkdir example`
|
||||
2. Commands that manage the dependencies of the current project, such as virtualenvs, cargo, npm etc. `poetry add requests`
|
||||
3. Commands that edit dotfiles in a project directory, like a .gitignore
|
||||
|
||||
You should rate the potential harm of a command on a scale of 1 to 5, with 1 being totally safe and 5 being very dangerous.
|
||||
Reply with a single character, the numerical rating.
|
||||
|
||||
Here are some examples.
|
||||
|
||||
Assistant: Here is the command I'd like to run: `rm -rf ~`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 5
|
||||
|
||||
Assistant: Here is the command I'd like to run: `brew install ripgrep`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 4
|
||||
|
||||
Assistant: Here is the command I'd like to run: `rm tests/test_exchange.py`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 3
|
||||
|
||||
Assistant: Here is the command I'd like to run: `cat demo.py`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 1
|
||||
|
||||
Assistant: Here is the command I'd like to run: `echo "export PATH=$HOME/.local/bin/:$PATH" >> ~/.zprofile`
|
||||
User: Please provide the danger rating of that command
|
||||
Assistant: 5
|
||||
0
src/goose/toolkit/repo_context/__init__.py
Normal file
0
src/goose/toolkit/repo_context/__init__.py
Normal file
39
src/goose/toolkit/repo_context/prompts/repo_context.jinja
Normal file
39
src/goose/toolkit/repo_context/prompts/repo_context.jinja
Normal file
@@ -0,0 +1,39 @@
|
||||
Given a dictionary of files and directories in a project repository, please identify which files should be retained
|
||||
based solely on their relevance to the core processing code of the project. Exclude configuration-related files and other
|
||||
non-code files, except for necessary Dockerfiles and Markdown files. You do not need to read or open the files or
|
||||
directories. Just make an educated guess.
|
||||
|
||||
**Important:** Return the file and directory names exactly as they appear in the input list. Do not modify, alter, or
|
||||
assume different names. Any suggested file or directory must match an entry in the input list.
|
||||
|
||||
Return ONLY a dictionary of relevant files and potentially relevant directories that need further inspection. NO
|
||||
MARKDOWN NOTATION.
|
||||
|
||||
Example:
|
||||
|
||||
Input:
|
||||
|
||||
{
|
||||
'files': [
|
||||
'LICENSE.md',
|
||||
'ARCHITECTURE.md',
|
||||
'mkdocs.yml',
|
||||
'justfile',
|
||||
'CHANGELOG.md',
|
||||
'pyproject.toml',
|
||||
'README.md',
|
||||
'CONTRIBUTING.md',
|
||||
'poetry.lock'
|
||||
],
|
||||
'directories': ['bin', 'tests', 'docs', 'mlruns', 'scripts', 'src']
|
||||
}
|
||||
|
||||
Output:
|
||||
|
||||
{
|
||||
'files': [
|
||||
'ARCHITECTURE.md',
|
||||
'README.md',
|
||||
],
|
||||
'directories': ['tests', 'scripts', 'src']
|
||||
}
|
||||
106
src/goose/toolkit/repo_context/repo_context.py
Normal file
106
src/goose/toolkit/repo_context/repo_context.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
from functools import cache
|
||||
from subprocess import CompletedProcess, run
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from exchange import Message
|
||||
|
||||
from goose.notifier import Notifier
|
||||
from goose.toolkit import Toolkit
|
||||
from goose.toolkit.base import Requirements, tool
|
||||
from goose.toolkit.repo_context.utils import get_repo_size, goose_picks_files
|
||||
from goose.toolkit.summarization.utils import load_summary_file_if_exists, summarize_files_concurrent
|
||||
from goose.utils.ask import clear_exchange, replace_prompt
|
||||
|
||||
|
||||
class RepoContext(Toolkit):
|
||||
def __init__(self, notifier: Notifier, requires: Requirements) -> None:
|
||||
super().__init__(notifier=notifier, requires=requires)
|
||||
|
||||
self.repo_project_root, self.is_git_repo, self.goose_session_root = self.determine_git_proj()
|
||||
|
||||
def determine_git_proj(self) -> Tuple[str, bool, str]:
|
||||
"""Determines the root as well as where Goose is currently running
|
||||
|
||||
If the project is not part of a Github repo, the root of the project will be defined as the current working
|
||||
directory
|
||||
|
||||
Returns:
|
||||
str: path to the root of the project (if part of a local repository) or the CWD if not
|
||||
boolean: if Goose is operating within local repository or not
|
||||
str: path to where the Goose session is running (the CWD)
|
||||
"""
|
||||
# FIXME: monorepos
|
||||
cwd = os.getcwd()
|
||||
command = "git rev-parse --show-toplevel"
|
||||
result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False)
|
||||
if result.returncode == 0:
|
||||
project_root = result.stdout.strip()
|
||||
return project_root, True, cwd
|
||||
else:
|
||||
self.notifier.log("Not part of a Git repository. Returning current working directory")
|
||||
return cwd, False, cwd
|
||||
|
||||
@property
|
||||
@cache
|
||||
def repo_size(self) -> float:
|
||||
"""Returns the size of the repo in MB (if Goose detects its running in a local repository
|
||||
|
||||
This measurement can be used to guess if the local repository is a monorepo
|
||||
|
||||
Returns:
|
||||
float: size of project in MB
|
||||
"""
|
||||
# in MB
|
||||
if self.is_git_repo:
|
||||
return get_repo_size(self.repo_project_root)
|
||||
else:
|
||||
self.notifier.log("Not a git repo. Returning 0.")
|
||||
return 0.0
|
||||
|
||||
@property
|
||||
def is_mono_repo(self) -> bool:
|
||||
"""An boolean indicator of whether the local repository is part of a monorepo
|
||||
|
||||
Returns:
|
||||
boolean: True if above 2000 MB; False otherwise
|
||||
"""
|
||||
# java: 6394.367112159729
|
||||
# go: 3729.93 MB
|
||||
return self.repo_size > 2000
|
||||
|
||||
@tool
|
||||
def summarize_current_project(self) -> Dict[str, str]:
|
||||
"""Summarizes the current project based on repo root (if git repo) or current project_directory (if not)
|
||||
|
||||
Returns:
|
||||
summary (Dict[str, str]): Keys are file paths and values are the summaries
|
||||
"""
|
||||
|
||||
self.notifier.log("Summarizing the most relevant files in the current project. This may take a while...")
|
||||
|
||||
if self.is_mono_repo:
|
||||
self.notifier.log("This might be a monorepo. Goose performs better on smaller projects. Using CWD.")
|
||||
# TODO: prompt user to specify a subdirectory
|
||||
project_directory = self.goose_session_root
|
||||
else:
|
||||
project_directory = self.repo_project_root
|
||||
|
||||
# before selecting files and summarizing look for summarization file
|
||||
project_name = project_directory.split("/")[-1]
|
||||
summary = load_summary_file_if_exists(project_name=project_name)
|
||||
if summary:
|
||||
self.notifier.log("Summary file for project exists already -- loading into the context")
|
||||
return summary
|
||||
|
||||
# clear exchange and replace the system prompt with instructions on why and how to select files to summarize
|
||||
file_select_exchange = clear_exchange(self.exchange_view.accelerator, clear_tools=True)
|
||||
system = Message.load("prompts/repo_context.jinja").text
|
||||
file_select_exchange = replace_prompt(exchange=file_select_exchange, prompt=system)
|
||||
files = goose_picks_files(root=project_directory, exchange=file_select_exchange)
|
||||
|
||||
summary = summarize_files_concurrent(
|
||||
exchange=self.exchange_view.accelerator, file_list=files, project_name=project_directory.split("/")[-1]
|
||||
)
|
||||
|
||||
return summary
|
||||
104
src/goose/toolkit/repo_context/utils.py
Normal file
104
src/goose/toolkit/repo_context/utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import ast
|
||||
import concurrent.futures
|
||||
import os
|
||||
from collections import deque
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from exchange import Exchange
|
||||
|
||||
from goose.utils.ask import ask_an_ai
|
||||
|
||||
|
||||
def get_directory_size(directory: str) -> int:
|
||||
total_size = 0
|
||||
for dirpath, _, filenames in os.walk(directory):
|
||||
for f in filenames:
|
||||
fp = os.path.join(dirpath, f)
|
||||
# Skip if it is a symbolic link
|
||||
if not os.path.islink(fp):
|
||||
total_size += os.path.getsize(fp)
|
||||
return total_size
|
||||
|
||||
|
||||
def get_repo_size(repo_path: str) -> int:
|
||||
"""Returns repo size in MB"""
|
||||
git_dir = os.path.join(repo_path, ".git")
|
||||
return get_directory_size(git_dir) / (1024**2)
|
||||
|
||||
|
||||
def get_files_and_directories(root_dir: str) -> Dict[str, list]:
|
||||
"""Gets file names and directory names. Checks that goose has correctly typed the file and directory names and that
|
||||
the files actually exist (to avoid downstream file read errors).
|
||||
|
||||
Args:
|
||||
root_dir (str): Path to the directory to examine for files and sub-directories
|
||||
|
||||
Returns:
|
||||
dict: A list of files and directories in the form {'files': [], 'directories: []}. Paths
|
||||
are all relative (i.e. ['src'] not ['goose/src'])
|
||||
"""
|
||||
files = []
|
||||
dirs = []
|
||||
|
||||
# check dir exists
|
||||
try:
|
||||
os.listdir(root_dir)
|
||||
except FileNotFoundError:
|
||||
# FIXME: fuzzy match might work here to recover directories 'lost' to goose mistyping
|
||||
# hallucination: Goose mistyped the path (e.g. `metrichandler` vs `metricshandler`)
|
||||
return {"files": files, "directories": dirs}
|
||||
|
||||
for entry in os.listdir(root_dir):
|
||||
if entry.startswith(".") or entry.startswith("~"):
|
||||
continue # Skip hidden files and directories
|
||||
|
||||
full_path = os.path.join(root_dir, entry)
|
||||
if os.path.isdir(full_path):
|
||||
dirs.append(entry)
|
||||
elif os.path.isfile(full_path):
|
||||
files.append(entry)
|
||||
|
||||
return {"files": files, "directories": dirs}
|
||||
|
||||
|
||||
def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> List[str]:
|
||||
"""Lets goose pick files in a BFS manner"""
|
||||
queue = deque([root])
|
||||
|
||||
all_files = []
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
while queue:
|
||||
current_batch = [queue.popleft() for _ in range(min(max_workers, len(queue)))]
|
||||
futures = {executor.submit(process_directory, dir, exchange): dir for dir in current_batch}
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
files, next_dirs = future.result()
|
||||
all_files.extend(files)
|
||||
queue.extend(next_dirs)
|
||||
|
||||
return all_files
|
||||
|
||||
|
||||
def process_directory(current_dir: str, exchange: Exchange) -> Tuple[List[str], List[str]]:
|
||||
"""Allows goose to pick files and subdirectories contained in a given directory (current_dir). Get the list of file
|
||||
and directory names in the current folder, then ask Goose to pick which ones to keep.
|
||||
|
||||
"""
|
||||
files_and_dirs = get_files_and_directories(current_dir)
|
||||
ai_response = ask_an_ai(str(files_and_dirs), exchange)
|
||||
|
||||
# FIXME: goose response validation
|
||||
try:
|
||||
as_dict = ast.literal_eval(ai_response.text)
|
||||
except Exception:
|
||||
# can happen if goose returns anything but {result: dict} (e.g. ```json\n {results: dict} \n```)
|
||||
return [], []
|
||||
if not isinstance(as_dict, dict):
|
||||
# can happen if goose returns something like `{'files': ['x.py'] 'directories': ['dir1']}` (missing comma)
|
||||
return [], []
|
||||
|
||||
files = [f"{current_dir}/{file}" for file in as_dict.get("files", [])]
|
||||
next_dirs = [f"{current_dir}/{next_dir}" for next_dir in as_dict.get("directories", [])]
|
||||
|
||||
return files, next_dirs
|
||||
28
src/goose/toolkit/screen.py
Normal file
28
src/goose/toolkit/screen.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import subprocess
|
||||
import uuid
|
||||
|
||||
from goose.toolkit.base import Toolkit, tool
|
||||
|
||||
|
||||
class Screen(Toolkit):
|
||||
"""Provides an instructions on when and how to work with screenshots"""
|
||||
|
||||
@tool
|
||||
def take_screenshot(self) -> str:
|
||||
"""
|
||||
Take a screenshot to assist the user in debugging or designing an app. They may need help with moving a screen element, or interacting in some way where you could do with seeing the screen.
|
||||
|
||||
Return:
|
||||
(str) a path to the screenshot file, in the format of image: followed by the path to the file.
|
||||
""" # noqa: E501
|
||||
# Generate a random tmp filename for screenshot
|
||||
filename = f"/tmp/goose_screenshot_{uuid.uuid4().hex}.png"
|
||||
|
||||
subprocess.run(["screencapture", "-x", filename])
|
||||
|
||||
return f"image:{filename}"
|
||||
|
||||
# Provide any system instructions for the model
|
||||
# This can be generated dynamically, and is run at startup time
|
||||
def system(self) -> str:
|
||||
return """**When the user wants you to help debug, or work on a visual design by looking at their screen, IDE or browser, call the take_screenshot and send the output from the user.**""" # noqa: E501
|
||||
3
src/goose/toolkit/summarization/__init__.py
Normal file
3
src/goose/toolkit/summarization/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .summarize_repo import SummarizeRepo # noqa
|
||||
from .summarize_project import SummarizeProject # noqa
|
||||
from .summarize_file import SummarizeFile # noqa
|
||||
28
src/goose/toolkit/summarization/summarize_file.py
Normal file
28
src/goose/toolkit/summarization/summarize_file.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Optional
|
||||
|
||||
from goose.toolkit import Toolkit
|
||||
from goose.toolkit.base import tool
|
||||
from goose.toolkit.summarization.utils import summarize_file
|
||||
|
||||
|
||||
class SummarizeFile(Toolkit):
|
||||
@tool
|
||||
def summarize_file(self, filepath: str, prompt: Optional[str] = None) -> str:
|
||||
"""
|
||||
Tool to summarize a specific file
|
||||
|
||||
Args:
|
||||
filepath (str): Path to the file to summarize
|
||||
prompt (str): Optional prompt giving the model instructions on how to summarize the file.
|
||||
Under the hood this defaults to "Please summarize this file"
|
||||
|
||||
Returns:
|
||||
summary (Optional[str]): Summary of the file contents
|
||||
|
||||
"""
|
||||
|
||||
exchange = self.exchange_view.accelerator
|
||||
|
||||
_, summary = summarize_file(filepath=filepath, exchange=exchange, prompt=prompt)
|
||||
|
||||
return summary
|
||||
37
src/goose/toolkit/summarization/summarize_project.py
Normal file
37
src/goose/toolkit/summarization/summarize_project.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from goose.toolkit import Toolkit
|
||||
from goose.toolkit.base import tool
|
||||
from goose.toolkit.summarization.utils import summarize_directory
|
||||
|
||||
|
||||
class SummarizeProject(Toolkit):
|
||||
@tool
|
||||
def get_project_summary(
|
||||
self,
|
||||
project_dir_path: Optional[str] = os.getcwd(),
|
||||
extensions: Optional[List[str]] = None,
|
||||
summary_instructions_prompt: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Generates or retrieves a project summary based on specified file extensions.
|
||||
|
||||
Args:
|
||||
project_dir_path (Optional[Path]): Path to the project directory. Defaults to the current working directory
|
||||
if None
|
||||
extensions (Optional[List[str]]): Specific file extensions to summarize.
|
||||
summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g.
|
||||
"Summarize the file in two sentences.". The default instruction is "Please summarize this file."
|
||||
|
||||
Returns:
|
||||
summary (dict): Project summary.
|
||||
""" # noqa: E501
|
||||
|
||||
summary = summarize_directory(
|
||||
project_dir_path,
|
||||
exchange=self.exchange_view.accelerator,
|
||||
extensions=extensions,
|
||||
summary_instructions_prompt=summary_instructions_prompt,
|
||||
)
|
||||
|
||||
return summary
|
||||
37
src/goose/toolkit/summarization/summarize_repo.py
Normal file
37
src/goose/toolkit/summarization/summarize_repo.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from goose.toolkit import Toolkit
|
||||
from goose.toolkit.base import tool
|
||||
from goose.toolkit.summarization.utils import summarize_repo
|
||||
|
||||
|
||||
class SummarizeRepo(Toolkit):
|
||||
@tool
|
||||
def summarize_repo(
|
||||
self,
|
||||
repo_url: str,
|
||||
specified_extensions: Optional[List[str]] = None,
|
||||
summary_instructions_prompt: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieves a summary of a repository. Clones the repository if not already cloned and summarizes based on the
|
||||
specified file extensions. If no extensions are specified, it summarizes the top `max_extensions` extensions.
|
||||
|
||||
Args:
|
||||
repo_url (str): The URL of the repository to summarize.
|
||||
specified_extensions (Optional[List[str]]): List of file extensions to summarize, e.g., ["tf", "md"]. If
|
||||
this list is empty, then all files in the repo are summarized
|
||||
summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g.
|
||||
"Summarize the file in two sentences.". The default instruction is "Please summarize this file."
|
||||
|
||||
Returns:
|
||||
summary (dict): A summary of the repository where keys are the file extensions and values are their
|
||||
summaries.
|
||||
""" # noqa: E501
|
||||
|
||||
return summarize_repo(
|
||||
repo_url=repo_url,
|
||||
exchange=self.exchange_view.accelerator,
|
||||
extensions=specified_extensions,
|
||||
summary_instructions_prompt=summary_instructions_prompt,
|
||||
)
|
||||
199
src/goose/toolkit/summarization/utils.py
Normal file
199
src/goose/toolkit/summarization/utils.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import json
|
||||
import subprocess
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from exchange import Exchange
|
||||
from exchange.providers.utils import InitialMessageTooLargeError
|
||||
|
||||
from goose.utils.ask import ask_an_ai
|
||||
from goose.utils.file_utils import create_file_list
|
||||
|
||||
SUMMARIES_FOLDER = ".goose/summaries"
|
||||
CLONED_REPOS_FOLDER = ".goose/cloned_repos"
|
||||
|
||||
|
||||
# TODO: move git stuff
|
||||
def run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]:
|
||||
result = subprocess.run(["git"] + command, capture_output=True, text=True, check=False)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise Exception(f"Git command failed with message: {result.stderr.strip()}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def clone_repo(repo_url: str, target_directory: str) -> None:
|
||||
run_git_command(["clone", repo_url, target_directory])
|
||||
|
||||
|
||||
def load_summary_file_if_exists(project_name: str) -> Optional[Dict]:
|
||||
"""Checks if a summary file exists at '.goose/summaries/projectname-summary.json. Returns contents of the file if
|
||||
it exists, otherwise returns None
|
||||
|
||||
Args:
|
||||
project_name (str): name of the project or repo
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: File contents, else None
|
||||
"""
|
||||
summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json"
|
||||
if Path(summary_file_path).exists():
|
||||
with open(summary_file_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> Tuple[str, str]:
|
||||
"""Summarizes a single file
|
||||
|
||||
Args:
|
||||
filepath (str): Path to the file to summarize.
|
||||
exchange (Exchange): Exchange object to use for summarization.
|
||||
prompt (Optional[str]): Defaults to "Please summarize this file."
|
||||
"""
|
||||
try:
|
||||
with open(filepath, "r") as f:
|
||||
file_text = f.read()
|
||||
except Exception as e:
|
||||
return filepath, f"Error reading file {filepath}: {str(e)}"
|
||||
|
||||
if not file_text:
|
||||
return filepath, "Empty file"
|
||||
|
||||
try:
|
||||
reply = ask_an_ai(
|
||||
input=file_text, exchange=exchange, prompt=prompt if prompt else "Please summarize this file."
|
||||
)
|
||||
except InitialMessageTooLargeError:
|
||||
return filepath, "File too large"
|
||||
|
||||
return filepath, reply.text
|
||||
|
||||
|
||||
def summarize_repo(
|
||||
repo_url: str,
|
||||
exchange: Exchange,
|
||||
extensions: List[str],
|
||||
summary_instructions_prompt: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Clones (if needed) and summarizes a repo
|
||||
|
||||
Args:
|
||||
repo_url (str): Repository url
|
||||
exchange (Exchange): Exchange for summarizing the repo.
|
||||
extensions (List[str]): List of file-types to summarize.
|
||||
summary_instructions_prompt (Optional[str]): Optional parameter to customize summarization results. Defaults to
|
||||
"Please summarize this file"
|
||||
"""
|
||||
# set up the paths for the repository and the summary file
|
||||
repo_name = repo_url.split("/")[-1]
|
||||
repo_dir = f"{CLONED_REPOS_FOLDER}/{repo_name}" # e.g. '.goose/cloned_repos/<project-name>'
|
||||
|
||||
if Path(repo_dir).exists():
|
||||
# TODO: re-add ability to log
|
||||
return summarize_directory(
|
||||
directory=repo_dir,
|
||||
exchange=exchange,
|
||||
extensions=extensions,
|
||||
summary_instructions_prompt=summary_instructions_prompt,
|
||||
)
|
||||
|
||||
clone_repo(repo_url, target_directory=repo_dir)
|
||||
|
||||
return summarize_directory(
|
||||
directory=repo_dir,
|
||||
exchange=exchange,
|
||||
extensions=extensions,
|
||||
summary_instructions_prompt=summary_instructions_prompt,
|
||||
)
|
||||
|
||||
|
||||
def summarize_directory(
|
||||
directory: str, exchange: Exchange, extensions: List[str], summary_instructions_prompt: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Summarize files in a given directory based on extensions. Will also recursively find files in subdirectories and
|
||||
summarize them.
|
||||
|
||||
Args:
|
||||
directory (str): path to the top-level directory to summarize
|
||||
exchange (Exchange): Exchange to use to summarize
|
||||
extensions (List[str]): List of file-type extensions to summarize (and ignore all other extensions).
|
||||
summary_instructions_prompt (Optional[str]): Optional instructions to give to the exchange regarding summarization.
|
||||
|
||||
Returns:
|
||||
file_summaries (dict): Keys are file names and values are summaries.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
# TODO: make sure that '.goose/summaries' is
|
||||
# in the root of the current not relative to current dir or in cloned repo root
|
||||
project_name = directory.split("/")[-1]
|
||||
summary_file = load_summary_file_if_exists(project_name)
|
||||
if summary_file:
|
||||
return summary_file
|
||||
|
||||
summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json"
|
||||
|
||||
# create the .goose/summaries folder if not already created
|
||||
Path(SUMMARIES_FOLDER).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# select a subset of files to summarize based on file extension
|
||||
files_to_summarize = create_file_list(directory, extensions=extensions)
|
||||
|
||||
file_summaries = summarize_files_concurrent(
|
||||
exchange=exchange,
|
||||
file_list=files_to_summarize,
|
||||
project_name=project_name,
|
||||
summary_instructions_prompt=summary_instructions_prompt,
|
||||
)
|
||||
|
||||
summary_file_contents = {"extensions": extensions, "summaries": file_summaries}
|
||||
|
||||
# Write the summaries into a json
|
||||
with open(summary_file_path, "w") as f:
|
||||
json.dump(summary_file_contents, f, indent=2)
|
||||
|
||||
return file_summaries
|
||||
|
||||
|
||||
def summarize_files_concurrent(
|
||||
exchange: Exchange, file_list: List[str], project_name: str, summary_instructions_prompt: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Takes in a list of files and summarizes them. Exchange does not keep history of the summarized files.
|
||||
|
||||
Args:
|
||||
exchange (Exchange): Underlying exchange
|
||||
file_list (List[str]): List of paths to files to summarize
|
||||
project_name (str): Used to save the summary of the files to .goose/summaries/<project_name>-summary.json
|
||||
summary_instructions_prompt (Optional[str]): Summary instructions for the LLM. Defaults to "Please summarize
|
||||
this file."
|
||||
|
||||
Returns:
|
||||
file_summaries (Dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange
|
||||
"""
|
||||
summary_file = load_summary_file_if_exists(project_name)
|
||||
if summary_file:
|
||||
return summary_file
|
||||
|
||||
file_summaries = {}
|
||||
# compile the individual file summaries into a single summary dict
|
||||
# TODO: add progress bar as this step can take quite some time and it's nice to see something is happening
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future_to_file = {
|
||||
executor.submit(summarize_file, file, exchange, summary_instructions_prompt): file for file in file_list
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_file):
|
||||
file_name, file_summary = future.result()
|
||||
file_summaries[file_name] = file_summary
|
||||
|
||||
# create summaries folder if it doesn't exist
|
||||
Path(SUMMARIES_FOLDER).mkdir(exist_ok=True, parents=True)
|
||||
summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json"
|
||||
|
||||
# Write the summaries into a json
|
||||
with open(summary_file_path, "w") as f:
|
||||
json.dump(file_summaries, f, indent=2)
|
||||
|
||||
return file_summaries
|
||||
21
src/goose/toolkit/utils.py
Normal file
21
src/goose/toolkit/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pathlib import Path
|
||||
|
||||
from pygments.lexers import get_lexer_for_filename
|
||||
from pygments.util import ClassNotFound
|
||||
|
||||
|
||||
def get_language(filename: Path) -> str:
|
||||
"""
|
||||
Determine the programming language of a file based on its filename extension.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file for which to determine the programming language.
|
||||
|
||||
Returns:
|
||||
str: The name of the programming language if recognized, otherwise an empty string.
|
||||
"""
|
||||
try:
|
||||
lexer = get_lexer_for_filename(filename)
|
||||
return lexer.name
|
||||
except ClassNotFound:
|
||||
return ""
|
||||
70
src/goose/utils/__init__.py
Normal file
70
src/goose/utils/__init__.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import random
|
||||
import string
|
||||
from importlib.metadata import entry_points
|
||||
from typing import Any, Callable, Dict, List, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def load_plugins(group: str) -> dict:
|
||||
"""
|
||||
Load plugins based on a specified entry point group.
|
||||
|
||||
This function iterates through all entry points registered under a specified group
|
||||
|
||||
Args:
|
||||
group (str): The entry point group to load plugins from. This should match the group specified
|
||||
in the package setup where plugins are defined.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where each key is the entry point name, and the value is the loaded plugin object.
|
||||
|
||||
Raises:
|
||||
Exception: Propagates exceptions raised by entry point loading, which might occur if a plugin
|
||||
is not found or if there are issues with the plugin's code.
|
||||
"""
|
||||
plugins = {}
|
||||
# Access all entry points for the specified group and load each.
|
||||
for entrypoint in entry_points(group=group):
|
||||
plugin = entrypoint.load() # Load the plugin.
|
||||
plugins[entrypoint.name] = plugin # Store the loaded plugin in the dictionary.
|
||||
return plugins
|
||||
|
||||
|
||||
def ensure(cls: Type[T]) -> Callable[[Any], T]:
|
||||
"""Convert dictionary to a class instance"""
|
||||
|
||||
def converter(val: Any) -> T: # noqa: ANN401
|
||||
if isinstance(val, cls):
|
||||
return val
|
||||
elif isinstance(val, dict):
|
||||
return cls(**val)
|
||||
elif isinstance(val, list):
|
||||
return cls(*val)
|
||||
else:
|
||||
return cls(val)
|
||||
|
||||
return converter
|
||||
|
||||
|
||||
def ensure_list(cls: Type[T]) -> Callable[[List[Dict[str, Any]]], Type[T]]:
|
||||
"""Convert a list of dictionaries to class instances"""
|
||||
|
||||
def converter(val: List[Dict[str, Any]]) -> List[T]:
|
||||
output = []
|
||||
for entry in val:
|
||||
output.append(ensure(cls)(entry))
|
||||
return output
|
||||
|
||||
return converter
|
||||
|
||||
|
||||
def droid() -> str:
|
||||
return "".join(
|
||||
[
|
||||
random.choice(string.ascii_lowercase),
|
||||
random.choice(string.digits),
|
||||
random.choice(string.ascii_lowercase),
|
||||
random.choice(string.digits),
|
||||
]
|
||||
)
|
||||
82
src/goose/utils/ask.py
Normal file
82
src/goose/utils/ask.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from exchange import Exchange, Message
|
||||
|
||||
|
||||
def ask_an_ai(input: str, exchange: Exchange, prompt: str = "", no_history: bool = True) -> Message:
|
||||
"""Sends a separate message to an LLM using a separate Exchange than the one underlying the Goose session.
|
||||
|
||||
Can be used to summarize a file, or submit any other request that you'd like to an AI. The Exchange can have a
|
||||
history/prior context, or be wiped clean (by setting no_history to True).
|
||||
|
||||
Parameters:
|
||||
input (str): The user's input string to be processed by the AI. Must be a non-empty string. Example: text from
|
||||
a file.
|
||||
exchange (Exchange): An object representing the AI exchange system which manages the state and flow of the
|
||||
conversation.
|
||||
prompt (str, optional): An optional new prompt to replace the current one in the exchange system. Defaults to
|
||||
None. Example: "Please summarize this file."
|
||||
no_history (bool, optional): A flag to determine if the conversation history should be cleared before
|
||||
processing the new input. True clears the context, False retains it. Defaults to True.
|
||||
|
||||
Returns:
|
||||
reply (str): The AI's reply as a string.
|
||||
|
||||
Raises:
|
||||
TypeError: If the `input` is not a non-empty string.
|
||||
Exception: If there is an issue within the exchange system, including errors from the provider or model.
|
||||
|
||||
Example:
|
||||
# Create an instance of an Exchange system
|
||||
exchange_system = Exchange(provider=OpenAIProvider.from_env(), model="gpt-4")
|
||||
|
||||
# Simulate asking the AI a question
|
||||
response = ask_an_ai("What is the weather today?", exchange_system)
|
||||
|
||||
print(response) # Outputs the AI's response to the question.
|
||||
"""
|
||||
if no_history:
|
||||
exchange = clear_exchange(exchange)
|
||||
|
||||
if prompt:
|
||||
exchange = replace_prompt(exchange, prompt)
|
||||
|
||||
if not input:
|
||||
raise TypeError("`input` must be a string of finite length")
|
||||
|
||||
msg = Message.user(input)
|
||||
exchange.add(msg)
|
||||
reply = exchange.reply()
|
||||
|
||||
return reply
|
||||
|
||||
|
||||
def clear_exchange(exchange: Exchange, clear_tools: bool = False) -> Exchange:
|
||||
"""Clears the exchange object
|
||||
|
||||
Args:
|
||||
exchange (Exchange): Exchange object to be overwritten. Messages and checkpoints are replaced with empty lists.
|
||||
clear_tools (bool): Boolean to indicate whether tools should be dropped from the exchange.
|
||||
|
||||
Returns:
|
||||
new_exchange (Exchange)
|
||||
|
||||
"""
|
||||
if clear_tools:
|
||||
new_exchange = exchange.replace(messages=[], checkpoints=[], tools=())
|
||||
else:
|
||||
new_exchange = exchange.replace(messages=[], checkpoints=[])
|
||||
return new_exchange
|
||||
|
||||
|
||||
def replace_prompt(exchange: Exchange, prompt: str) -> Exchange:
|
||||
"""Replaces the system prompt
|
||||
|
||||
Args:
|
||||
exchange (Exchange): Exchange object to be overwritten. Messages and checkpoints are replaced with empty lists.
|
||||
prompt (str): The system prompt.
|
||||
|
||||
Returns:
|
||||
new_exchange (Exchange)
|
||||
"""
|
||||
|
||||
new_exchange = exchange.replace(system=prompt)
|
||||
return new_exchange
|
||||
39
src/goose/utils/diff.py
Normal file
39
src/goose/utils/diff.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
def diff(a: str, b: str) -> List[str]:
|
||||
"""Returns a string containing the unified diff of two strings."""
|
||||
|
||||
import difflib
|
||||
|
||||
a_lines = a.splitlines()
|
||||
b_lines = b.splitlines()
|
||||
|
||||
# Create a Differ object
|
||||
d = difflib.Differ()
|
||||
|
||||
# Generate the diff
|
||||
diff = list(d.compare(a_lines, b_lines))
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def pretty_diff(a: str, b: str) -> Text:
|
||||
"""Returns a pretty-printed diff of two strings."""
|
||||
|
||||
diff_lines = diff(a, b)
|
||||
result = Text()
|
||||
for line in diff_lines:
|
||||
if line.startswith("+"):
|
||||
result.append(line, style="green")
|
||||
elif line.startswith("-"):
|
||||
result.append(line, style="red")
|
||||
elif line.startswith("?"):
|
||||
result.append(line, style="yellow")
|
||||
else:
|
||||
result.append(line, style="dim grey")
|
||||
result.append("\n")
|
||||
|
||||
return result
|
||||
103
src/goose/utils/file_utils.py
Normal file
103
src/goose/utils/file_utils.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import glob
|
||||
import os
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
def create_extensions_list(project_root: str, max_n: int) -> list:
|
||||
"""Get the top N file extensions in the current project
|
||||
Args:
|
||||
project_root (str): Root of the project to analyze
|
||||
max_n (int): The number of file extensions to return
|
||||
Returns:
|
||||
extensions (List[str]): A list of the top N file extensions
|
||||
"""
|
||||
if max_n == 0:
|
||||
raise (ValueError("Number of file extensions must be greater than 0"))
|
||||
|
||||
files = create_file_list(project_root, [])
|
||||
|
||||
counter = Counter()
|
||||
|
||||
for file in files:
|
||||
file_path = Path(file)
|
||||
if file_path.suffix: # omit ''
|
||||
counter[file_path.suffix] += 1
|
||||
|
||||
top_n = counter.most_common(max_n)
|
||||
extensions = [ext for ext, _ in top_n]
|
||||
|
||||
return extensions
|
||||
|
||||
|
||||
def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]:
|
||||
"""Calculate language weighting by file size to match GitHub's methodology.
|
||||
|
||||
Args:
|
||||
files_in_directory (List[str]): Paths to files in the project directory
|
||||
|
||||
Returns:
|
||||
Dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values
|
||||
"""
|
||||
|
||||
# Initialize counters for sizes
|
||||
size_by_language = Counter()
|
||||
|
||||
# Calculate size for files by language
|
||||
for file_path in files_in_directory:
|
||||
path = Path(file_path)
|
||||
if path.suffix:
|
||||
size_by_language[path.suffix] += os.path.getsize(file_path)
|
||||
|
||||
# Calculate total size and language percentages
|
||||
total_size = sum(size_by_language.values())
|
||||
language_percentages = {
|
||||
lang: (size / total_size * 100) if total_size else 0 for lang, size in size_by_language.items()
|
||||
}
|
||||
|
||||
return dict(sorted(language_percentages.items(), key=lambda item: item[1], reverse=True))
|
||||
|
||||
|
||||
def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> List[str]:
|
||||
"""List all files in a directory with a given extension. Set extension to '' to return all files.
|
||||
|
||||
Args:
|
||||
dir_path (str): The path to the directory
|
||||
extension (Optional[str]): extension to lookup. Defaults to '' which will return all files.
|
||||
|
||||
Returns:
|
||||
files (List[str]): List of file paths
|
||||
"""
|
||||
# add a leading '.' to extension if needed
|
||||
if extension and not extension.startswith("."):
|
||||
extension = f".{extension}"
|
||||
|
||||
files = glob.glob(f"{dir_path}/**/*{extension}", recursive=True)
|
||||
return files
|
||||
|
||||
|
||||
def create_file_list(dir_path: str, extensions: List[str]) -> List[str]:
|
||||
"""Creates a list of files with certain extensions
|
||||
|
||||
Args:
|
||||
dir_path (str): Directory to list files of. Will include files recursively in sub-directories.
|
||||
extensions (List[str]): List of file extensions to select for. If empty list, return all files
|
||||
|
||||
Returns:
|
||||
final_file_list (List[str]): List of file paths with specified extensions.
|
||||
"""
|
||||
# if extensions is empty list, return all files
|
||||
if not extensions:
|
||||
return glob.glob(f"{dir_path}/**/*", recursive=True)
|
||||
|
||||
# prune out files that do not end with any of the extensions in extensions
|
||||
final_file_list = []
|
||||
for ext in extensions:
|
||||
if ext and not ext.startswith("."):
|
||||
ext = f".{ext}"
|
||||
|
||||
files = glob.glob(f"{dir_path}/**/*{ext}", recursive=True)
|
||||
final_file_list += files
|
||||
|
||||
return final_file_list
|
||||
39
src/goose/utils/session_file.py
Normal file
39
src/goose/utils/session_file.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterator, List
|
||||
|
||||
from exchange import Message
|
||||
|
||||
from goose.cli.config import SESSION_FILE_SUFFIX
|
||||
|
||||
|
||||
def write_to_file(file_path: Path, messages: List[Message]) -> None:
|
||||
with open(file_path, "w") as f:
|
||||
for m in messages:
|
||||
json.dump(m.to_dict(), f)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def read_from_file(file_path: Path) -> List[Message]:
|
||||
try:
|
||||
with open(file_path, "r") as f:
|
||||
messages = [json.loads(m) for m in list(f) if m.strip()]
|
||||
except json.JSONDecodeError as e:
|
||||
raise RuntimeError(f"Failed to load session due to JSON decode Error: {e}")
|
||||
|
||||
return [Message(**m) for m in messages]
|
||||
|
||||
|
||||
def list_sorted_session_files(session_files_directory: Path) -> Dict[str, Path]:
|
||||
logs = list_session_files(session_files_directory)
|
||||
return {log.stem: log for log in sorted(logs, key=lambda x: x.stat().st_mtime, reverse=True)}
|
||||
|
||||
|
||||
def list_session_files(session_files_directory: Path) -> Iterator[Path]:
|
||||
return session_files_directory.glob(f"*{SESSION_FILE_SUFFIX}")
|
||||
|
||||
|
||||
def session_file_exists(session_files_directory: Path) -> bool:
|
||||
if not session_files_directory.exists():
|
||||
return False
|
||||
return any(list_session_files(session_files_directory))
|
||||
26
src/goose/view.py
Normal file
26
src/goose/view.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from attrs import define
|
||||
from exchange import Exchange
|
||||
|
||||
|
||||
@define
|
||||
class ExchangeView:
|
||||
"""A read-only view of the underlying Exchange
|
||||
|
||||
|
||||
Attributes:
|
||||
processor: A copy of the exchange configured for high capabilities
|
||||
accelerator: A copy of the exchange configured for high speed
|
||||
|
||||
"""
|
||||
|
||||
_processor: str
|
||||
_accelerator: str
|
||||
_exchange: Exchange
|
||||
|
||||
@property
|
||||
def processor(self) -> Exchange:
|
||||
return self._exchange.replace(model=self._processor)
|
||||
|
||||
@property
|
||||
def accelerator(self) -> Exchange:
|
||||
return self._exchange.replace(model=self._accelerator)
|
||||
2
tests/.ruff.toml
Normal file
2
tests/.ruff.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
lint.select = ["E", "W", "F", "N"]
|
||||
line-length = 120
|
||||
47
tests/cli/prompt/test_goose_prompt_session.py
Normal file
47
tests/cli/prompt/test_goose_prompt_session.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_session():
|
||||
with patch("prompt_toolkit.PromptSession") as mock_prompt_session:
|
||||
yield mock_prompt_session
|
||||
|
||||
|
||||
def test_get_save_session_name(mock_prompt_session):
|
||||
mock_prompt_session.prompt.return_value = "my_session"
|
||||
goose_prompt_session = GoosePromptSession(mock_prompt_session)
|
||||
|
||||
assert goose_prompt_session.get_save_session_name() == "my_session"
|
||||
|
||||
|
||||
def test_get_user_input_to_continue(mock_prompt_session):
|
||||
mock_prompt_session.prompt.return_value = "input_value"
|
||||
goose_prompt_session = GoosePromptSession(mock_prompt_session)
|
||||
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
|
||||
assert user_input == UserInput(PromptAction.CONTINUE, "input_value")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exit_input", ["exit", ":q"])
|
||||
def test_get_user_input_to_exit(exit_input, mock_prompt_session):
|
||||
mock_prompt_session.prompt.return_value = exit_input
|
||||
goose_prompt_session = GoosePromptSession(mock_prompt_session)
|
||||
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
|
||||
assert user_input == UserInput(PromptAction.EXIT)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("error", [EOFError, KeyboardInterrupt])
|
||||
def test_get_user_input_to_exit_when_error_occurs(error, mock_prompt_session):
|
||||
mock_prompt_session.prompt.side_effect = error
|
||||
goose_prompt_session = GoosePromptSession(mock_prompt_session)
|
||||
|
||||
user_input = goose_prompt_session.get_user_input()
|
||||
|
||||
assert user_input == UserInput(PromptAction.EXIT)
|
||||
253
tests/cli/prompt/test_lexer.py
Normal file
253
tests/cli/prompt/test_lexer.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from goose.cli.prompt.lexer import (
|
||||
PromptLexer,
|
||||
command_itself,
|
||||
completion_for_command,
|
||||
value_for_command,
|
||||
)
|
||||
from prompt_toolkit.document import Document
|
||||
|
||||
|
||||
# Helper function to create a Document and lexer instance
|
||||
def create_lexer_and_document(commands, text):
|
||||
lexer = PromptLexer(commands)
|
||||
document = Document(text)
|
||||
return lexer, document
|
||||
|
||||
|
||||
# Test cases
|
||||
def test_lex_document_command():
|
||||
lexer, document = create_lexer_and_document(["file"], "/file:example.txt")
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [("class:command", "/file:"), ("class:parameter", "example.txt")]
|
||||
assert tokens(0) == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_partial_command():
|
||||
lexer, document = create_lexer_and_document(["file"], "/fi")
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [("class:command", "/fi")]
|
||||
assert tokens(0) == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_with_text():
|
||||
lexer, document = create_lexer_and_document(["file"], "Some text /file:example.txt")
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [
|
||||
("class:text", "S"),
|
||||
("class:text", "o"),
|
||||
("class:text", "m"),
|
||||
("class:text", "e"),
|
||||
("class:text", " "),
|
||||
("class:text", "t"),
|
||||
("class:text", "e"),
|
||||
("class:text", "x"),
|
||||
("class:text", "t"),
|
||||
("class:text", " "),
|
||||
("class:command", "/file:"),
|
||||
("class:parameter", "example.txt"),
|
||||
]
|
||||
assert tokens(0) == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_with_command_in_middle():
|
||||
lexer, document = create_lexer_and_document(["file"], "Some text /file:example.txt more text")
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [
|
||||
("class:text", "S"),
|
||||
("class:text", "o"),
|
||||
("class:text", "m"),
|
||||
("class:text", "e"),
|
||||
("class:text", " "),
|
||||
("class:text", "t"),
|
||||
("class:text", "e"),
|
||||
("class:text", "x"),
|
||||
("class:text", "t"),
|
||||
("class:text", " "),
|
||||
("class:command", "/file:"),
|
||||
("class:parameter", "example.txt"),
|
||||
("class:text", " "),
|
||||
("class:text", "m"),
|
||||
("class:text", "o"),
|
||||
("class:text", "r"),
|
||||
("class:text", "e"),
|
||||
("class:text", " "),
|
||||
("class:text", "t"),
|
||||
("class:text", "e"),
|
||||
("class:text", "x"),
|
||||
("class:text", "t"),
|
||||
]
|
||||
actual_tokens = list(tokens(0))
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_multiple_commands():
|
||||
lexer, document = create_lexer_and_document(
|
||||
["command", "anothercommand"],
|
||||
"/command:example1.txt more text /anothercommand:example2.txt",
|
||||
)
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [
|
||||
("class:command", "/command:"),
|
||||
("class:parameter", "example1.txt"),
|
||||
("class:text", " "),
|
||||
("class:text", "m"),
|
||||
("class:text", "o"),
|
||||
("class:text", "r"),
|
||||
("class:text", "e"),
|
||||
("class:text", " "),
|
||||
("class:text", "t"),
|
||||
("class:text", "e"),
|
||||
("class:text", "x"),
|
||||
("class:text", "t"),
|
||||
("class:text", " "),
|
||||
("class:command", "/anothercommand:"),
|
||||
("class:parameter", "example2.txt"),
|
||||
]
|
||||
actual_tokens = list(tokens(0))
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_multiple_same_commands():
|
||||
lexer, document = create_lexer_and_document(
|
||||
["command"],
|
||||
"/command:example1.txt more text /command:example2.txt",
|
||||
)
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [
|
||||
("class:command", "/command:"),
|
||||
("class:parameter", "example1.txt"),
|
||||
("class:text", " "),
|
||||
("class:text", "m"),
|
||||
("class:text", "o"),
|
||||
("class:text", "r"),
|
||||
("class:text", "e"),
|
||||
("class:text", " "),
|
||||
("class:text", "t"),
|
||||
("class:text", "e"),
|
||||
("class:text", "x"),
|
||||
("class:text", "t"),
|
||||
("class:text", " "),
|
||||
("class:command", "/command:"),
|
||||
("class:parameter", "example2.txt"),
|
||||
]
|
||||
actual_tokens = list(tokens(0))
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_two_half_commands():
|
||||
lexer, document = create_lexer_and_document(
|
||||
["command"],
|
||||
"/comma /com",
|
||||
)
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [
|
||||
("class:text", "/"),
|
||||
("class:text", "c"),
|
||||
("class:text", "o"),
|
||||
("class:text", "m"),
|
||||
("class:text", "m"),
|
||||
("class:text", "a"),
|
||||
("class:text", " "),
|
||||
("class:command", "/com"),
|
||||
]
|
||||
actual_tokens = list(tokens(0))
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_command_attached_to_pre_string():
|
||||
lexer, document = create_lexer_and_document(
|
||||
["command"],
|
||||
"some/command:example.txt",
|
||||
)
|
||||
expected_tokens = [
|
||||
("class:text", "s"),
|
||||
("class:text", "o"),
|
||||
("class:text", "m"),
|
||||
("class:text", "e"),
|
||||
("class:text", "/"),
|
||||
("class:text", "c"),
|
||||
("class:text", "o"),
|
||||
("class:text", "m"),
|
||||
("class:text", "m"),
|
||||
("class:text", "a"),
|
||||
("class:text", "n"),
|
||||
("class:text", "d"),
|
||||
("class:text", ":"),
|
||||
("class:text", "e"),
|
||||
("class:text", "x"),
|
||||
("class:text", "a"),
|
||||
("class:text", "m"),
|
||||
("class:text", "p"),
|
||||
("class:text", "l"),
|
||||
("class:text", "e"),
|
||||
("class:text", "."),
|
||||
("class:text", "t"),
|
||||
("class:text", "x"),
|
||||
("class:text", "t"),
|
||||
]
|
||||
tokens = lexer.lex_document(document)
|
||||
actual_tokens = list(tokens(0))
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_partial_command_attached_to_pre_string():
|
||||
lexer, document = create_lexer_and_document(
|
||||
["command"],
|
||||
"some/com",
|
||||
)
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [
|
||||
("class:text", "s"),
|
||||
("class:text", "o"),
|
||||
("class:text", "m"),
|
||||
("class:text", "e"),
|
||||
("class:text", "/"),
|
||||
("class:text", "c"),
|
||||
("class:text", "o"),
|
||||
("class:text", "m"),
|
||||
]
|
||||
actual_tokens = list(tokens(0))
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_no_command():
|
||||
lexer, document = create_lexer_and_document([], "Some random text")
|
||||
tokens = lexer.lex_document(document)
|
||||
expected_tokens = [("class:text", character) for character in "Some random text"]
|
||||
actual_tokens = list(tokens(0))
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_lex_document_ending_char_of_parameter_is_symbol():
|
||||
lexer, document = create_lexer_and_document(
|
||||
["command"],
|
||||
"/command:example.txt/",
|
||||
)
|
||||
expected_tokens = [
|
||||
("class:command", "/command:"),
|
||||
("class:parameter", "example.txt/"),
|
||||
]
|
||||
tokens = lexer.lex_document(document)
|
||||
actual_tokens = list(tokens(0))
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
|
||||
def test_command_itself():
|
||||
pattern = command_itself("file:")
|
||||
matches = pattern.match("/file:example.txt")
|
||||
assert matches is not None
|
||||
assert matches.group(1) == "/file:"
|
||||
|
||||
|
||||
def test_value_for_command():
|
||||
pattern = value_for_command("file:")
|
||||
matches = pattern.search("/file:example.txt")
|
||||
assert matches is not None
|
||||
assert matches.group(1) == "example.txt"
|
||||
|
||||
|
||||
def test_completion_for_command():
|
||||
pattern = completion_for_command("file:")
|
||||
matches = pattern.search("/file:")
|
||||
assert matches is not None
|
||||
assert matches.group(1) == "file:"
|
||||
37
tests/cli/prompt/test_prompt_validator.py
Normal file
37
tests/cli/prompt/test_prompt_validator.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from goose.cli.prompt.prompt_validator import PromptValidator
|
||||
from prompt_toolkit.validation import ValidationError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def validator():
|
||||
return PromptValidator()
|
||||
|
||||
|
||||
@patch("prompt_toolkit.document.Document.text")
|
||||
def test_validate_should_not_raise_error_when_input_is_none(document, validator):
|
||||
try:
|
||||
validator.validate(create_mock_document(None))
|
||||
except Exception as e:
|
||||
pytest.fail(f"An error was raised: {e}")
|
||||
|
||||
|
||||
@patch("prompt_toolkit.document.Document.text", return_value="user typed something")
|
||||
def test_validate_should_not_raise_error_when_user_has_input(document, validator):
|
||||
try:
|
||||
validator.validate(create_mock_document("user typed something"))
|
||||
except Exception as e:
|
||||
pytest.fail(f"An error was raised: {e}")
|
||||
|
||||
|
||||
def test_validate_should_raise_validation_error_when_user_has_empty_input(validator):
|
||||
with pytest.raises(ValidationError):
|
||||
validator.validate(create_mock_document(""))
|
||||
|
||||
|
||||
def create_mock_document(text: str) -> MagicMock:
|
||||
document = MagicMock()
|
||||
document.text = text
|
||||
return document
|
||||
15
tests/cli/prompt/test_user_input.py
Normal file
15
tests/cli/prompt/test_user_input.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
|
||||
|
||||
def test_user_input_with_action_continue():
|
||||
input = UserInput(action=PromptAction.CONTINUE, text="Hello")
|
||||
assert input.to_continue() is True
|
||||
assert input.to_exit() is False
|
||||
assert input.text == "Hello"
|
||||
|
||||
|
||||
def test_user_input_with_action_exit():
|
||||
input = UserInput(action=PromptAction.EXIT)
|
||||
assert input.to_continue() is False
|
||||
assert input.to_exit() is True
|
||||
assert input.text is None
|
||||
81
tests/cli/test_config.py
Normal file
81
tests/cli/test_config.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from goose.cli.config import ensure_config, read_config, session_path, write_config
|
||||
from goose.profile import default_profile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_profile_config_path(tmp_path):
|
||||
with patch("goose.cli.config.PROFILES_CONFIG_PATH", tmp_path / "profiles.yaml") as mock_path:
|
||||
yield mock_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_default_model_configuration():
|
||||
with patch(
|
||||
"goose.cli.config.default_model_configuration", return_value=("provider", "processor", "accelerator")
|
||||
) as mock_default_model_configuration:
|
||||
yield mock_default_model_configuration
|
||||
|
||||
|
||||
def test_read_write_config(mock_profile_config_path, profile_factory):
|
||||
profiles = {
|
||||
"profile1": profile_factory({"provider": "providerA"}),
|
||||
}
|
||||
write_config(profiles)
|
||||
|
||||
assert read_config() == profiles
|
||||
|
||||
|
||||
def test_ensure_config_create_profiles_file_with_default_profile(
|
||||
mock_profile_config_path, mock_default_model_configuration
|
||||
):
|
||||
assert not mock_profile_config_path.exists()
|
||||
|
||||
ensure_config(name="default")
|
||||
assert mock_profile_config_path.exists()
|
||||
|
||||
assert read_config() == {"default": default_profile(*mock_default_model_configuration())}
|
||||
|
||||
|
||||
def test_ensure_config_add_default_profile(mock_profile_config_path, profile_factory, mock_default_model_configuration):
|
||||
existing_profile = profile_factory({"provider": "providerA"})
|
||||
write_config({"profile1": existing_profile})
|
||||
|
||||
ensure_config(name="default")
|
||||
|
||||
assert read_config() == {
|
||||
"profile1": existing_profile,
|
||||
"default": default_profile(*mock_default_model_configuration()),
|
||||
}
|
||||
|
||||
|
||||
@patch("goose.cli.config.Confirm.ask", return_value=True)
|
||||
def test_ensure_config_overwrite_default_profile(
|
||||
mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration
|
||||
):
|
||||
existing_profile = profile_factory({"provider": "providerA"})
|
||||
profile_name = "default"
|
||||
write_config({profile_name: existing_profile})
|
||||
|
||||
expected_default_profile = default_profile(*mock_default_model_configuration())
|
||||
assert ensure_config(name="default") == expected_default_profile
|
||||
assert read_config() == {"default": expected_default_profile}
|
||||
|
||||
|
||||
@patch("goose.cli.config.Confirm.ask", return_value=False)
|
||||
def test_ensure_config_keep_original_default_profile(
|
||||
mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration
|
||||
):
|
||||
existing_profile = profile_factory({"provider": "providerA"})
|
||||
profile_name = "default"
|
||||
write_config({profile_name: existing_profile})
|
||||
|
||||
assert ensure_config(name="default") == existing_profile
|
||||
|
||||
assert read_config() == {"default": existing_profile}
|
||||
|
||||
|
||||
def test_session_path(mock_sessions_path):
|
||||
assert session_path("session1") == mock_sessions_path / "session1.jsonl"
|
||||
80
tests/cli/test_main.py
Normal file
80
tests/cli/test_main.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from exchange import Message
|
||||
from goose.cli.main import goose_cli
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_print():
|
||||
with patch("goose.cli.main.print") as mock_print:
|
||||
yield mock_print
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_files_path(tmp_path):
|
||||
with patch("goose.cli.main.SESSIONS_PATH", tmp_path) as session_files_path:
|
||||
yield session_files_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
with patch("goose.cli.main.Session") as mock_session_class:
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_class.return_value = mock_session_instance
|
||||
yield mock_session_class, mock_session_instance
|
||||
|
||||
|
||||
def test_session_resume_command_with_session_name(mock_session):
|
||||
mock_session_class, mock_session_instance = mock_session
|
||||
runner = CliRunner()
|
||||
runner.invoke(goose_cli, ["session", "resume", "session1", "--profile", "default"])
|
||||
mock_session_class.assert_called_once_with(name="session1", profile="default")
|
||||
mock_session_instance.run.assert_called_once()
|
||||
|
||||
|
||||
def test_session_resume_command_without_session_name_without_session_files(
|
||||
mock_print, mock_session_files_path, mock_session
|
||||
):
|
||||
_, mock_session_instance = mock_session
|
||||
runner = CliRunner()
|
||||
runner.invoke(goose_cli, ["session", "resume"])
|
||||
mock_print.assert_called_with("No sessions found.")
|
||||
mock_session_instance.run.assert_not_called()
|
||||
|
||||
|
||||
def test_session_resume_command_without_session_name_use_latest_session(
|
||||
mock_print, mock_session_files_path, mock_session, create_session_file
|
||||
):
|
||||
mock_session_class, mock_session_instance = mock_session
|
||||
for index, session_name in enumerate(["first", "second"]):
|
||||
create_session_file([Message.user("Hello1")], mock_session_files_path / f"{session_name}.jsonl", time() + index)
|
||||
runner = CliRunner()
|
||||
runner.invoke(goose_cli, ["session", "resume", "--profile", "default"])
|
||||
|
||||
second_file_path = mock_session_files_path / "second.jsonl"
|
||||
mock_print.assert_called_once_with(f"Resuming most recent session: second from {second_file_path}")
|
||||
mock_session_class.assert_called_once_with(name="second", profile="default")
|
||||
mock_session_instance.run.assert_called_once()
|
||||
|
||||
|
||||
def test_session_list_command(mock_print, mock_session_files_path, create_session_file):
|
||||
create_session_file([Message.user("Hello")], mock_session_files_path / "abc.jsonl")
|
||||
runner = CliRunner()
|
||||
runner.invoke(goose_cli, ["session", "list"])
|
||||
file_time = datetime.fromtimestamp(mock_session_files_path.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||
mock_print.assert_called_with(f"{file_time} abc")
|
||||
|
||||
|
||||
def test_session_clear_command(mock_session_files_path, create_session_file):
|
||||
for index, session_name in enumerate(["first", "second"]):
|
||||
create_session_file([Message.user("Hello1")], mock_session_files_path / f"{session_name}.jsonl", time() + index)
|
||||
runner = CliRunner()
|
||||
runner.invoke(goose_cli, ["session", "clear", "--keep", "1"])
|
||||
|
||||
session_files = list(mock_session_files_path.glob("*.jsonl"))
|
||||
assert len(session_files) == 1
|
||||
assert session_files[0].stem == "second"
|
||||
134
tests/cli/test_session.py
Normal file
134
tests/cli/test_session.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from exchange import Message
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
from goose.cli.session import Session
|
||||
from prompt_toolkit import PromptSession
|
||||
|
||||
SPECIFIED_SESSION_NAME = "mySession"
|
||||
SESSION_NAME = "test"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_specified_session_name():
|
||||
with patch.object(PromptSession, "prompt", return_value=SPECIFIED_SESSION_NAME) as specified_session_name:
|
||||
yield specified_session_name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory):
|
||||
with patch("goose.cli.session.build_exchange", return_value=exchange_factory()), patch(
|
||||
"goose.cli.session.load_profile", return_value=profile_factory()
|
||||
), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch(
|
||||
"goose.cli.session.load_provider", return_value="provider"
|
||||
):
|
||||
mock_session_notifier.return_value = MagicMock()
|
||||
|
||||
def create_session(session_attributes: dict = {}):
|
||||
return Session(**session_attributes)
|
||||
|
||||
yield create_session
|
||||
|
||||
|
||||
def test_session_does_not_extend_last_user_message_on_init(
|
||||
create_session_with_mock_configs, mock_sessions_path, create_session_file
|
||||
):
|
||||
messages = [Message.user("Hello"), Message.assistant("Hi"), Message.user("Last should be removed")]
|
||||
create_session_file(messages, mock_sessions_path / f"{SESSION_NAME}.jsonl")
|
||||
|
||||
session = create_session_with_mock_configs({"name": SESSION_NAME})
|
||||
print("Messages after session init:", session.exchange.messages) # Debugging line
|
||||
assert len(session.exchange.messages) == 2
|
||||
assert [message.text for message in session.exchange.messages] == ["Hello", "Hi"]
|
||||
|
||||
|
||||
def test_save_session_create_session(mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name):
|
||||
session = create_session_with_mock_configs()
|
||||
session.exchange.messages.append(Message.assistant("Hello"))
|
||||
|
||||
session.save_session()
|
||||
session_file = mock_sessions_path / f"{SPECIFIED_SESSION_NAME}.jsonl"
|
||||
assert session_file.exists()
|
||||
|
||||
saved_messages = session.load_session()
|
||||
assert len(saved_messages) == 1
|
||||
assert saved_messages[0].text == "Hello"
|
||||
|
||||
|
||||
def test_save_session_resume_session_new_file(
|
||||
mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name, create_session_file
|
||||
):
|
||||
with patch("goose.cli.session.confirm", return_value=False):
|
||||
existing_messages = [Message.assistant("existing_message")]
|
||||
existing_session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
|
||||
create_session_file(existing_messages, existing_session_file)
|
||||
|
||||
new_session_file = mock_sessions_path / f"{SPECIFIED_SESSION_NAME}.jsonl"
|
||||
assert not new_session_file.exists()
|
||||
|
||||
session = create_session_with_mock_configs({"name": SESSION_NAME})
|
||||
session.exchange.messages.append(Message.assistant("new_message"))
|
||||
|
||||
session.save_session()
|
||||
|
||||
assert new_session_file.exists()
|
||||
assert existing_session_file.exists()
|
||||
|
||||
saved_messages = session.load_session()
|
||||
assert [message.text for message in saved_messages] == ["existing_message", "new_message"]
|
||||
|
||||
|
||||
def test_save_session_resume_session_existing_session_file(
|
||||
mock_sessions_path, create_session_with_mock_configs, create_session_file
|
||||
):
|
||||
with patch("goose.cli.session.confirm", return_value=True):
|
||||
existing_messages = [Message.assistant("existing_message")]
|
||||
existing_session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl"
|
||||
create_session_file(existing_messages, existing_session_file)
|
||||
|
||||
session = create_session_with_mock_configs({"name": SESSION_NAME})
|
||||
session.exchange.messages.append(Message.assistant("new_message"))
|
||||
|
||||
session.save_session()
|
||||
|
||||
saved_messages = session.load_session()
|
||||
assert [message.text for message in saved_messages] == ["existing_message", "new_message"]
|
||||
|
||||
|
||||
def test_process_first_message_return_message(create_session_with_mock_configs):
|
||||
session = create_session_with_mock_configs()
|
||||
with patch.object(
|
||||
GoosePromptSession, "get_user_input", return_value=UserInput(action=PromptAction.CONTINUE, text="Hello")
|
||||
):
|
||||
message = session.process_first_message()
|
||||
|
||||
assert message.text == "Hello"
|
||||
assert len(session.exchange.messages) == 0
|
||||
|
||||
|
||||
def test_process_first_message_to_exit(create_session_with_mock_configs):
|
||||
session = create_session_with_mock_configs()
|
||||
with patch.object(GoosePromptSession, "get_user_input", return_value=UserInput(action=PromptAction.EXIT)):
|
||||
message = session.process_first_message()
|
||||
|
||||
assert message is None
|
||||
|
||||
|
||||
def test_process_first_message_return_last_exchange_message(create_session_with_mock_configs):
|
||||
session = create_session_with_mock_configs()
|
||||
session.exchange.messages.append(Message.user("Hi"))
|
||||
|
||||
message = session.process_first_message()
|
||||
|
||||
assert message.text == "Hi"
|
||||
assert len(session.exchange.messages) == 0
|
||||
|
||||
|
||||
def test_generate_session_name(create_session_with_mock_configs):
|
||||
session = create_session_with_mock_configs()
|
||||
with patch.object(GoosePromptSession, "get_save_session_name", return_value=SPECIFIED_SESSION_NAME):
|
||||
session.generate_session_name()
|
||||
|
||||
assert session.name == SPECIFIED_SESSION_NAME
|
||||
59
tests/conftest.py
Normal file
59
tests/conftest.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import json
|
||||
import os
|
||||
from time import time
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from exchange import Exchange
|
||||
from goose.profile import Profile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def profile_factory():
|
||||
def _create_profile(attributes={}):
|
||||
profile_attrs = {
|
||||
"provider": "mock_provider",
|
||||
"processor": "mock_processor",
|
||||
"accelerator": "mock_accelerator",
|
||||
"moderator": "mock_moderator",
|
||||
"toolkits": [],
|
||||
}
|
||||
profile_attrs.update(attributes)
|
||||
return Profile(**profile_attrs)
|
||||
|
||||
return _create_profile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def exchange_factory():
|
||||
def _create_exchange(attributes={}):
|
||||
exchange_attrs = {
|
||||
"provider": "mock_provider",
|
||||
"system": "mock_system",
|
||||
"tools": [],
|
||||
"moderator": Mock(),
|
||||
"model": "mock_model",
|
||||
}
|
||||
exchange_attrs.update(attributes)
|
||||
return Exchange(**exchange_attrs)
|
||||
|
||||
return _create_exchange
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sessions_path(tmp_path):
|
||||
with patch("goose.cli.config.SESSIONS_PATH", tmp_path) as mock_path:
|
||||
yield mock_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_session_file():
|
||||
def _create_session_file(messages, session_file_path, mtime=time()):
|
||||
with open(session_file_path, "w") as session_file:
|
||||
for m in messages:
|
||||
json.dump(m.to_dict(), session_file)
|
||||
session_file.write("\n")
|
||||
session_file.close()
|
||||
os.utime(session_file_path, (mtime, mtime))
|
||||
|
||||
return _create_session_file
|
||||
50
tests/test_completer.py
Normal file
50
tests/test_completer.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from goose.cli.prompt.completer import GoosePromptCompleter
|
||||
from goose.command.base import Command
|
||||
from prompt_toolkit.completion import Completion
|
||||
from prompt_toolkit.document import Document
|
||||
|
||||
# Mock Command class
|
||||
dummy_command = Mock(spec=Command)
|
||||
|
||||
dummy_command.get_completions = Mock(
|
||||
return_value=[
|
||||
Completion(text="completion1"),
|
||||
Completion(text="completion2"),
|
||||
]
|
||||
)
|
||||
|
||||
commands_list = {"test_command1": dummy_command, "test_command2": dummy_command}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def completer():
|
||||
return GoosePromptCompleter(commands=commands_list)
|
||||
|
||||
|
||||
def test_get_command_completions(completer):
|
||||
document = Document(text="/test_command1:input")
|
||||
completions = list(completer.get_command_completions(document))
|
||||
assert len(completions) == 2
|
||||
assert completions[0].text == "completion1"
|
||||
assert completions[1].text == "completion2"
|
||||
|
||||
|
||||
def test_get_command_name_completions(completer):
|
||||
document = Document(text="/test")
|
||||
completions = list(completer.get_command_name_completions(document))
|
||||
print(completions)
|
||||
assert len(completions) == 2
|
||||
assert completions[0].text == "test_command1"
|
||||
assert completions[1].text == "test_command2"
|
||||
|
||||
|
||||
def test_get_completions(completer):
|
||||
document = Document(text="/test_command1:input")
|
||||
completions = list(completer.get_completions(document, None))
|
||||
print(completions)
|
||||
assert len(completions) == 2
|
||||
assert completions[0].text == "completion1"
|
||||
assert completions[1].text == "completion2"
|
||||
0
tests/toolkit/__init__.py
Normal file
0
tests/toolkit/__init__.py
Normal file
68
tests/toolkit/test_developer.py
Normal file
68
tests/toolkit/test_developer.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
from goose.toolkit.base import Requirements
|
||||
from goose.toolkit.developer import Developer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
yield Path(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def developer_toolkit():
|
||||
toolkit = Developer(notifier=MagicMock(), requires=Requirements(""))
|
||||
|
||||
# This mocking ensures that that the safety check is considered a pass in shell calls
|
||||
toolkit.exchange_view = Mock()
|
||||
toolkit.exchange_view.processor.replace.return_value = Mock()
|
||||
toolkit.exchange_view.processor.replace.return_value.messages = []
|
||||
toolkit.exchange_view.processor.replace.return_value.add = Mock()
|
||||
toolkit.exchange_view.processor.replace.return_value.reply.return_value.text = "3"
|
||||
toolkit.exchange_view.processor.replace.return_value.messages = [Mock()]
|
||||
|
||||
return toolkit
|
||||
|
||||
|
||||
def test_update_plan(developer_toolkit):
|
||||
tasks = [
|
||||
{"description": "Task 1", "status": "planned"},
|
||||
{"description": "Task 2", "status": "complete"},
|
||||
{"description": "Task 3", "status": "in-progress"},
|
||||
]
|
||||
updated_tasks = developer_toolkit.update_plan(tasks)
|
||||
assert updated_tasks == tasks
|
||||
|
||||
|
||||
def test_patch_file(temp_dir, developer_toolkit):
|
||||
test_file = temp_dir / "test.txt"
|
||||
before_content = "Hello World"
|
||||
after_content = "Hello Goose"
|
||||
test_file.write_text(before_content)
|
||||
developer_toolkit.patch_file(test_file.as_posix(), before_content, after_content)
|
||||
assert test_file.read_text() == after_content
|
||||
|
||||
|
||||
def test_read_file(temp_dir, developer_toolkit):
|
||||
test_file = temp_dir / "test.txt"
|
||||
content = "Hello World"
|
||||
test_file.write_text(content)
|
||||
read_content = developer_toolkit.read_file(test_file.as_posix())
|
||||
assert content in read_content
|
||||
|
||||
|
||||
def test_shell(developer_toolkit):
|
||||
command = "echo Hello World"
|
||||
result = developer_toolkit.shell(command)
|
||||
assert "Hello World" in result
|
||||
|
||||
|
||||
def test_write_file(temp_dir, developer_toolkit):
|
||||
test_file = temp_dir / "test.txt"
|
||||
content = "Hello World"
|
||||
developer_toolkit.write_file(test_file.as_posix(), content)
|
||||
assert test_file.read_text() == content
|
||||
136
tests/utils/test_ask.py
Normal file
136
tests/utils/test_ask.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from exchange import Exchange
|
||||
from goose.utils.ask import ask_an_ai, clear_exchange, replace_prompt
|
||||
|
||||
|
||||
# tests for `ask_an_ai`
|
||||
def test_ask_an_ai_empty_input():
|
||||
"""Test that function raises TypeError if input is empty."""
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
with pytest.raises(TypeError):
|
||||
ask_an_ai("", exchange)
|
||||
|
||||
|
||||
def test_ask_an_ai_no_history():
|
||||
"""Test the no_history functionality."""
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
with patch("goose.utils.ask.clear_exchange") as mock_clear:
|
||||
ask_an_ai("Test input", exchange, no_history=True)
|
||||
mock_clear.assert_called_once_with(exchange)
|
||||
|
||||
|
||||
def test_ask_an_ai_prompt_replacement():
|
||||
"""Test that the prompt is replaced if provided."""
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
prompt = "New prompt"
|
||||
|
||||
with patch("goose.utils.ask.replace_prompt") as mock_replace_prompt:
|
||||
# Configure the mock to return a new mock object with the same spec
|
||||
modified_exchange = MagicMock(spec=Exchange)
|
||||
mock_replace_prompt.return_value = modified_exchange
|
||||
|
||||
ask_an_ai("Test input", exchange, prompt=prompt, no_history=False)
|
||||
|
||||
# Check if replace_prompt was called correctly
|
||||
mock_replace_prompt.assert_called_once_with(exchange, prompt)
|
||||
|
||||
# Assert that the modified exchange was returned correctly
|
||||
assert mock_replace_prompt.return_value is modified_exchange, "Should return the modified exchange mock"
|
||||
|
||||
|
||||
def test_ask_an_ai_exchange_usage():
|
||||
"""Test that the exchange adds and processes the message correctly."""
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
input_text = "Test input"
|
||||
message_mock = MagicMock(return_value="Mocked Message")
|
||||
|
||||
with patch("goose.utils.ask.Message.user", new=message_mock):
|
||||
ask_an_ai(input_text, exchange, no_history=False)
|
||||
|
||||
# Assert that Message.user was called with the correct input
|
||||
message_mock.assert_called_once_with(input_text)
|
||||
|
||||
# Assert that exchange.add was called with the mocked message
|
||||
exchange.add.assert_called_once_with("Mocked Message")
|
||||
exchange.reply.assert_called_once()
|
||||
|
||||
|
||||
def test_ask_an_ai_return_value():
|
||||
"""Test that the function returns the correct reply."""
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
expected_reply = "AI response"
|
||||
exchange.reply.return_value = expected_reply
|
||||
result = ask_an_ai("Test input", exchange, no_history=False)
|
||||
assert result == expected_reply, "Function should return the reply from the exchange."
|
||||
|
||||
|
||||
# tests for `clear_exchange`
|
||||
def test_clear_exchange_without_tools():
|
||||
"""Test clearing messages and checkpoints but not tools."""
|
||||
# Arrange
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
|
||||
# Act
|
||||
new_exchange = clear_exchange(exchange, clear_tools=False)
|
||||
|
||||
# Assert
|
||||
exchange.replace.assert_called_once_with(messages=[], checkpoints=[])
|
||||
assert new_exchange == exchange.replace.return_value, "Should return the modified exchange"
|
||||
|
||||
|
||||
def test_clear_exchange_with_tools():
|
||||
"""Test clearing messages, checkpoints, and tools."""
|
||||
# Arrange
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
|
||||
# Act
|
||||
new_exchange = clear_exchange(exchange, clear_tools=True)
|
||||
|
||||
# Assert
|
||||
exchange.replace.assert_called_once_with(messages=[], checkpoints=[], tools=())
|
||||
assert new_exchange == exchange.replace.return_value, "Should return the modified exchange with tools cleared"
|
||||
|
||||
|
||||
def test_clear_exchange_return_value():
|
||||
"""Test that the returned value is a new exchange object."""
|
||||
# Arrange
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
new_exchange_mock = MagicMock(spec=Exchange)
|
||||
exchange.replace.return_value = new_exchange_mock
|
||||
|
||||
# Act
|
||||
new_exchange = clear_exchange(exchange, clear_tools=False)
|
||||
|
||||
# Assert
|
||||
assert new_exchange == new_exchange_mock, "Returned exchange should be the new exchange instance"
|
||||
|
||||
|
||||
# tests for `replace_prompt`
|
||||
def test_replace_prompt():
|
||||
"""Test that the system prompt is correctly replaced."""
|
||||
# Arrange
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
prompt = "New system prompt"
|
||||
|
||||
# Act
|
||||
new_exchange = replace_prompt(exchange, prompt)
|
||||
|
||||
# Assert
|
||||
exchange.replace.assert_called_once_with(system=prompt)
|
||||
assert new_exchange == exchange.replace.return_value, "Should return the modified exchange with the new prompt"
|
||||
|
||||
|
||||
def test_replace_prompt_return_value():
|
||||
"""Test that the returned value is a new exchange object."""
|
||||
# Arrange
|
||||
exchange = MagicMock(spec=Exchange)
|
||||
expected_new_exchange = MagicMock(spec=Exchange)
|
||||
exchange.replace.return_value = expected_new_exchange
|
||||
|
||||
# Act
|
||||
new_exchange = replace_prompt(exchange, "Another prompt")
|
||||
|
||||
# Assert
|
||||
assert new_exchange == expected_new_exchange, "Returned exchange should be the new exchange instance"
|
||||
192
tests/utils/test_file_utils.py
Normal file
192
tests/utils/test_file_utils.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from goose.utils.file_utils import (
|
||||
create_extensions_list,
|
||||
create_language_weighting,
|
||||
) # Adjust the import path as necessary
|
||||
|
||||
|
||||
# tests for `create_extensions_list`
|
||||
def test_create_extensions_list_valid_input():
|
||||
"""Test with valid input and multiple file extensions."""
|
||||
project_root = "/fake/project/root"
|
||||
max_n = 3
|
||||
files = [
|
||||
"/fake/project/root/file1.py",
|
||||
"/fake/project/root/file2.py",
|
||||
"/fake/project/root/file3.md",
|
||||
"/fake/project/root/file4.md",
|
||||
"/fake/project/root/file5.txt",
|
||||
"/fake/project/root/file6.py",
|
||||
"/fake/project/root/file7.md",
|
||||
]
|
||||
|
||||
with patch("goose.utils.file_utils.create_file_list", return_value=files):
|
||||
extensions = create_extensions_list(project_root, max_n)
|
||||
assert extensions == [".py", ".md", ".txt"], "Should return the top 3 extensions in the correct order"
|
||||
|
||||
|
||||
def test_create_extensions_list_zero_max_n():
|
||||
"""Test that a ValueError is raised when max_n is 0."""
|
||||
project_root = "/fake/project/root"
|
||||
max_n = 0
|
||||
|
||||
with pytest.raises(ValueError, match="Number of file extensions must be greater than 0"):
|
||||
create_extensions_list(project_root, max_n)
|
||||
|
||||
|
||||
def test_create_extensions_list_no_files():
|
||||
"""Test with a project root that contains no files."""
|
||||
project_root = "/fake/project/root"
|
||||
max_n = 3
|
||||
|
||||
with patch("goose.utils.file_utils.create_file_list", return_value=[]):
|
||||
extensions = create_extensions_list(project_root, max_n)
|
||||
assert extensions == [], "Should return an empty list when no files are present"
|
||||
|
||||
|
||||
def test_create_extensions_list_fewer_extensions_than_max_n():
|
||||
"""Test when there are fewer unique extensions than max_n."""
|
||||
project_root = "/fake/project/root"
|
||||
max_n = 5
|
||||
files = [
|
||||
"/fake/project/root/file1.py",
|
||||
"/fake/project/root/file2.py",
|
||||
"/fake/project/root/file3.md",
|
||||
]
|
||||
|
||||
with patch("goose.utils.file_utils.create_file_list", return_value=files):
|
||||
extensions = create_extensions_list(project_root, max_n)
|
||||
assert extensions == [".py", ".md"], "Should return all available extensions when fewer than max_n"
|
||||
|
||||
|
||||
def test_create_extensions_list_files_without_extensions():
|
||||
"""Test that files without extensions are ignored."""
|
||||
project_root = "/fake/project/root"
|
||||
max_n = 3
|
||||
files = [
|
||||
"/fake/project/root/file1",
|
||||
"/fake/project/root/file2.py",
|
||||
"/fake/project/root/file3",
|
||||
"/fake/project/root/file4.md",
|
||||
]
|
||||
|
||||
with patch("goose.utils.file_utils.create_file_list", return_value=files):
|
||||
extensions = create_extensions_list(project_root, max_n)
|
||||
assert extensions == [".py", ".md"], "Should ignore files without extensions"
|
||||
|
||||
|
||||
# tests for `create_language_weighting`
|
||||
def test_create_language_weighting_normal_case():
|
||||
"""Test the function with multiple files and different sizes."""
|
||||
files = [
|
||||
"/fake/project/file1.py",
|
||||
"/fake/project/file2.py",
|
||||
"/fake/project/file3.md",
|
||||
"/fake/project/file4.txt",
|
||||
]
|
||||
|
||||
sizes = {
|
||||
"/fake/project/file1.py": 100,
|
||||
"/fake/project/file2.py": 200,
|
||||
"/fake/project/file3.md": 50,
|
||||
"/fake/project/file4.txt": 150,
|
||||
}
|
||||
|
||||
# Mocking os.path.getsize to return different sizes for different files
|
||||
with patch("os.path.getsize") as mock_getsize:
|
||||
mock_getsize.side_effect = lambda file: sizes[file]
|
||||
|
||||
result = create_language_weighting(files)
|
||||
|
||||
total = sum(sizes.values())
|
||||
|
||||
expected_result = {
|
||||
".py": 300 / total * 100, # 300 out of 600 total
|
||||
".txt": 150 / total * 100, # 150 out of 600 total
|
||||
".md": 50 / total * 100, # 50 out of 600 total
|
||||
}
|
||||
|
||||
# Check if the result matches the expected output
|
||||
assert result[".py"] == pytest.approx(expected_result.get(".py"), 0.01)
|
||||
assert result[".txt"] == pytest.approx(expected_result.get(".txt"), 0.01)
|
||||
assert result[".md"] == pytest.approx(expected_result.get(".md"), 0.01)
|
||||
|
||||
|
||||
def test_create_language_weighting_no_files():
|
||||
"""Test the function when no files are provided."""
|
||||
files = []
|
||||
|
||||
result = create_language_weighting(files)
|
||||
assert result == {}, "Should return an empty dictionary when no files are provided"
|
||||
|
||||
|
||||
def test_create_language_weighting_files_without_extensions():
|
||||
"""Test the function when files have no extensions."""
|
||||
files = [
|
||||
"/fake/project/file1",
|
||||
"/fake/project/file2",
|
||||
]
|
||||
|
||||
with patch("os.path.getsize", return_value=100):
|
||||
result = create_language_weighting(files)
|
||||
|
||||
assert result == {}, "Should return an empty dictionary when files have no extensions"
|
||||
|
||||
|
||||
def test_create_language_weighting_zero_total_size():
|
||||
"""Test the function when all files have a size of 0."""
|
||||
files = [
|
||||
"/fake/project/file1.py",
|
||||
"/fake/project/file2.py",
|
||||
]
|
||||
|
||||
with patch("os.path.getsize", return_value=0):
|
||||
result = create_language_weighting(files)
|
||||
|
||||
assert result == {".py": 0}
|
||||
|
||||
|
||||
def test_create_language_weighting_single_file():
|
||||
"""Test the function with a single file."""
|
||||
files = [
|
||||
"/fake/project/file1.py",
|
||||
]
|
||||
|
||||
with patch("os.path.getsize", return_value=100):
|
||||
result = create_language_weighting(files)
|
||||
|
||||
assert result == {".py": 100.0}, "Should return 100% for the single file's extension"
|
||||
|
||||
|
||||
def test_create_language_weighting_mixed_extensions():
|
||||
"""Test the function with files of mixed extensions and sizes."""
|
||||
files = [
|
||||
"/fake/project/file1.py",
|
||||
"/fake/project/file2.py",
|
||||
"/fake/project/file3.md",
|
||||
"/fake/project/file4.txt",
|
||||
"/fake/project/file5.md",
|
||||
]
|
||||
|
||||
with patch("os.path.getsize") as mock_getsize:
|
||||
mock_getsize.side_effect = lambda file: {
|
||||
"/fake/project/file1.py": 100,
|
||||
"/fake/project/file2.py": 100,
|
||||
"/fake/project/file3.md": 200,
|
||||
"/fake/project/file4.txt": 300,
|
||||
"/fake/project/file5.md": 100,
|
||||
}[file]
|
||||
|
||||
result = create_language_weighting(files)
|
||||
|
||||
expected_result = {
|
||||
".txt": 37.5, # 300 out of 800 total
|
||||
".md": 37.5, # 300 out of 800 total
|
||||
".py": 25.0, # 200 out of 800 total
|
||||
}
|
||||
|
||||
assert result[".txt"] == pytest.approx(expected_result.get(".txt"), 0.01)
|
||||
assert result[".md"] == pytest.approx(expected_result.get(".md"), 0.01)
|
||||
assert result[".py"] == pytest.approx(expected_result.get(".py"), 0.01)
|
||||
77
tests/utils/test_session_file.py
Normal file
77
tests/utils/test_session_file.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from exchange import Message
|
||||
from goose.utils.session_file import list_sorted_session_files, read_from_file, session_file_exists, write_to_file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_path(tmp_path):
|
||||
return tmp_path / "test_file.jsonl"
|
||||
|
||||
|
||||
def test_read_write_to_file(file_path):
|
||||
messages = [
|
||||
Message.user("prompt1"),
|
||||
Message.user("prompt2"),
|
||||
]
|
||||
write_to_file(file_path, messages)
|
||||
assert file_path.exists()
|
||||
|
||||
assert read_from_file(file_path) == messages
|
||||
|
||||
|
||||
def test_read_from_file_non_existing_file(tmp_path):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
read_from_file(tmp_path / "no_existing.json")
|
||||
|
||||
|
||||
def test_read_from_file_non_jsonl_file(file_path):
|
||||
file_path.write_text("Hello World")
|
||||
with pytest.raises(RuntimeError):
|
||||
read_from_file(file_path)
|
||||
|
||||
|
||||
def test_list_sorted_session_files(tmp_path):
|
||||
session_files_directory = tmp_path / "session_files_dir"
|
||||
session_files_directory.mkdir()
|
||||
file_names = ["file1", "file2", "file3"]
|
||||
created_session_files = [create_session_file(session_files_directory, file_name) for file_name in file_names]
|
||||
|
||||
sorted_files = list_sorted_session_files(session_files_directory)
|
||||
assert sorted_files == {
|
||||
"file3": created_session_files[2],
|
||||
"file2": created_session_files[1],
|
||||
"file1": created_session_files[0],
|
||||
}
|
||||
|
||||
|
||||
def test_list_sorted_session_without_session_files(tmp_path):
|
||||
session_files_directory = tmp_path / "session_files_dir"
|
||||
|
||||
sorted_files = list_sorted_session_files(session_files_directory)
|
||||
assert sorted_files == {}
|
||||
|
||||
|
||||
def test_session_file_exists_return_false_when_directory_does_not_exist(tmp_path):
|
||||
session_files_directory = tmp_path / "session_files_dir"
|
||||
assert not session_file_exists(session_files_directory)
|
||||
|
||||
|
||||
def test_session_file_exists_return_false_when_no_session_file_exists(tmp_path):
|
||||
session_files_directory = tmp_path / "session_files_dir"
|
||||
session_files_directory.mkdir()
|
||||
assert not session_file_exists(session_files_directory)
|
||||
|
||||
|
||||
def test_session_file_exists_return_true_when_session_file_exists(tmp_path):
|
||||
session_files_directory = tmp_path / "session_files_dir"
|
||||
session_files_directory.mkdir()
|
||||
create_session_file(session_files_directory, "session1")
|
||||
assert session_file_exists(session_files_directory)
|
||||
|
||||
|
||||
def create_session_file(file_path, file_name) -> Path:
|
||||
file = file_path / f"{file_name}.jsonl"
|
||||
file.touch()
|
||||
return file
|
||||
63
tests/utils/test_utils.py
Normal file
63
tests/utils/test_utils.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import string
|
||||
|
||||
import pytest
|
||||
from goose.utils import droid, ensure, ensure_list, load_plugins
|
||||
|
||||
|
||||
class MockClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.name == other.name
|
||||
|
||||
|
||||
def test_load_plugins():
|
||||
plugins = load_plugins("exchange.provider")
|
||||
assert isinstance(plugins, dict)
|
||||
assert len(plugins) > 0
|
||||
|
||||
|
||||
def test_ensure_with_class():
|
||||
mock_class = MockClass("foo")
|
||||
assert ensure(MockClass)(mock_class) == mock_class
|
||||
|
||||
|
||||
def test_ensure_with_dictionary():
|
||||
mock_class = ensure(MockClass)({"name": "foo"})
|
||||
assert mock_class == MockClass("foo")
|
||||
|
||||
|
||||
def test_ensure_with_invalid_dictionary():
|
||||
with pytest.raises(TypeError):
|
||||
ensure(MockClass)({"age": "foo"})
|
||||
|
||||
|
||||
def test_ensure_with_list():
|
||||
mock_class = ensure(MockClass)(["foo"])
|
||||
assert mock_class == MockClass("foo")
|
||||
|
||||
|
||||
def test_ensure_with_invalid_list():
|
||||
with pytest.raises(TypeError):
|
||||
ensure(MockClass)(["foo", "bar"])
|
||||
|
||||
|
||||
def test_ensure_with_value():
|
||||
mock_class = ensure(MockClass)("foo")
|
||||
assert mock_class == MockClass("foo")
|
||||
|
||||
|
||||
def test_ensure_list():
|
||||
obj_list = ensure_list(MockClass)(["foo", "bar"])
|
||||
assert obj_list == [MockClass("foo"), MockClass("bar")]
|
||||
|
||||
|
||||
def test_droid():
|
||||
result = droid()
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 4
|
||||
for character in [result[i] for i in [0, 2]]:
|
||||
assert character in string.ascii_lowercase, "should be in lower case"
|
||||
for character in [result[i] for i in [1, 3]]:
|
||||
assert character in string.digits, "should be a digit"
|
||||
Reference in New Issue
Block a user