Forge/workshop (#5654)

* Added basic memory

* Added action history

* Deleted placeholder files

* adding memstore

* Added web search ability

* Added web search and reading web pages

* remove agent.py changes

Signed-off-by: Merwane Hamadi <merwanehamadi@gmail.com>

---------

Signed-off-by: Merwane Hamadi <merwanehamadi@gmail.com>
Co-authored-by: SwiftyOS <craigswift13@gmail.com>
This commit is contained in:
merwanehamadi
2023-10-09 08:32:52 -07:00
committed by GitHub
parent f77d383a9f
commit 3bd8ae4843
17 changed files with 1351 additions and 374 deletions

View File

@@ -7,7 +7,12 @@ from forge.sdk import (
Task, Task,
TaskRequestBody, TaskRequestBody,
Workspace, Workspace,
PromptEngine,
chat_completion_request,
ChromaMemStore
) )
import json
import pprint
LOG = ForgeLogger(__name__) LOG = ForgeLogger(__name__)

View File

@@ -1,11 +1,12 @@
import os import os
from forge.agent import ForgeAgent from forge.agent import ForgeAgent
from forge.sdk import AgentDB, LocalWorkspace from forge.sdk import LocalWorkspace
from .db import ForgeDatabase
database_name = os.getenv("DATABASE_STRING") database_name = os.getenv("DATABASE_STRING")
workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE")) workspace = LocalWorkspace(os.getenv("AGENT_WORKSPACE"))
database = AgentDB(database_name, debug_enabled=False) database = ForgeDatabase(database_name, debug_enabled=False)
agent = ForgeAgent(database=database, workspace=workspace) agent = ForgeAgent(database=database, workspace=workspace)
app = agent.get_agent_app() app = agent.get_agent_app()

145
autogpts/forge/forge/db.py Normal file
View File

@@ -0,0 +1,145 @@
from .sdk import AgentDB, ForgeLogger, NotFoundError, Base
from sqlalchemy.exc import SQLAlchemyError
import datetime
from sqlalchemy import (
Column,
DateTime,
String,
)
import uuid
LOG = ForgeLogger(__name__)
class ChatModel(Base):
    """ORM mapping for one chat message, stored in the ``chat`` table."""

    __tablename__ = "chat"
    # Primary key; ForgeDatabase.add_chat_message fills it with a uuid4 string.
    msg_id = Column(String, primary_key=True, index=True)
    # Task the message belongs to; history queries filter on this column.
    task_id = Column(String)
    # Message role and text, returned verbatim by get_chat_history.
    role = Column(String)
    content = Column(String)
    # Set on insert; history queries order by this column.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    # Set on insert and refreshed on every update of the row.
    modified_at = Column(
        DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
    )
class ActionModel(Base):
    """ORM mapping for one recorded action, stored in the ``action`` table."""

    __tablename__ = "action"
    # Primary key; ForgeDatabase.create_action fills it with a uuid4 string.
    action_id = Column(String, primary_key=True, index=True)
    # Task the action belongs to; history queries filter on this column.
    task_id = Column(String)
    # Name of the action that was run.
    name = Column(String)
    # Arguments as text; create_action stores ``str(args)`` here.
    args = Column(String)
    # Set on insert; history queries order by this column.
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    # Set on insert and refreshed on every update of the row.
    modified_at = Column(
        DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
    )
class ForgeDatabase(AgentDB):
    """AgentDB extended with per-task chat-message and action history.

    Rows are written to the ``chat`` and ``action`` tables declared by
    ``ChatModel`` and ``ActionModel``. The methods rely on ``self.Session``
    and ``self.debug_enabled``, which are presumably provided by ``AgentDB``
    (not visible here).
    """

    async def add_chat_history(self, task_id, messages):
        """Store every message of *messages* for *task_id*, in iteration order.

        Args:
            task_id: Identifier of the task the messages belong to.
            messages: Iterable of mappings with ``"role"`` and ``"content"`` keys.
        """
        for message in messages:
            await self.add_chat_message(task_id, message["role"], message["content"])

    async def add_chat_message(self, task_id, role, content):
        """Insert one chat message and return the stored ``ChatModel`` row.

        Args:
            task_id: Identifier of the task the message belongs to.
            role: Role of the message author.
            content: Message text.

        Returns:
            The refreshed ``ChatModel`` instance, with a freshly generated
            uuid4 ``msg_id``.

        Raises:
            SQLAlchemyError: If the insert fails (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug("Creating new chat message")
        try:
            with self.Session() as session:
                new_msg = ChatModel(
                    msg_id=str(uuid.uuid4()),
                    task_id=task_id,
                    role=role,
                    content=content,
                )
                session.add(new_msg)
                session.commit()
                # Reload the row so its attributes are populated before return.
                session.refresh(new_msg)
                if self.debug_enabled:
                    LOG.debug(f"Created new chat message with msg_id: {new_msg.msg_id}")
                return new_msg
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating chat message: {e}")
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating chat message: {e}")
            raise

    async def get_chat_history(self, task_id):
        """Return the chat messages of *task_id*, ordered by creation time.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.

        Raises:
            NotFoundError: If no message is stored for *task_id*.
            SQLAlchemyError: If the query fails (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Getting chat history with task_id: {task_id}")
        try:
            with self.Session() as session:
                if messages := (
                    session.query(ChatModel)
                    .filter(ChatModel.task_id == task_id)
                    .order_by(ChatModel.created_at)
                    .all()
                ):
                    return [{"role": m.role, "content": m.content} for m in messages]
                LOG.error(f"Chat history not found with task_id: {task_id}")
                raise NotFoundError("Chat history not found")
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting chat history: {e}")
            raise
        except NotFoundError:
            # Already logged above; keep it out of the "unexpected" handler.
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting chat history: {e}")
            raise

    async def create_action(self, task_id, name, args):
        """Insert one action record and return the stored ``ActionModel`` row.

        Args:
            task_id: Identifier of the task the action belongs to.
            name: Name of the action.
            args: Action arguments; persisted as ``str(args)``, so the stored
                text is a Python repr rather than JSON.

        Returns:
            The refreshed ``ActionModel`` instance, with a freshly generated
            uuid4 ``action_id``.

        Raises:
            SQLAlchemyError: If the insert fails (logged, then re-raised).
        """
        try:
            with self.Session() as session:
                new_action = ActionModel(
                    action_id=str(uuid.uuid4()),
                    task_id=task_id,
                    name=name,
                    args=str(args),
                )
                session.add(new_action)
                session.commit()
                # Reload the row so its attributes are populated before return.
                session.refresh(new_action)
                if self.debug_enabled:
                    LOG.debug(
                        f"Created new Action with action_id: {new_action.action_id}"
                    )
                return new_action
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while creating action: {e}")
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while creating action: {e}")
            raise

    async def get_action_history(self, task_id):
        """Return the actions recorded for *task_id*, ordered by creation time.

        Returns:
            A list of ``{"name": ..., "args": ...}`` dicts, where ``args`` is
            the text stored by ``create_action``.

        Raises:
            NotFoundError: If no action is stored for *task_id*.
            SQLAlchemyError: If the query fails (logged, then re-raised).
        """
        if self.debug_enabled:
            LOG.debug(f"Getting action history with task_id: {task_id}")
        try:
            with self.Session() as session:
                if actions := (
                    session.query(ActionModel)
                    .filter(ActionModel.task_id == task_id)
                    .order_by(ActionModel.created_at)
                    .all()
                ):
                    return [{"name": a.name, "args": a.args} for a in actions]
                LOG.error(f"Action history not found with task_id: {task_id}")
                raise NotFoundError("Action history not found")
        except SQLAlchemyError as e:
            LOG.error(f"SQLAlchemy error while getting action history: {e}")
            raise
        except NotFoundError:
            # Already logged above; keep it out of the "unexpected" handler.
            raise
        except Exception as e:
            LOG.error(f"Unexpected error while getting action history: {e}")
            raise

View File

@@ -9,7 +9,7 @@ Reply only in json with the following format:
\"speak\": \"thoughts summary to say to user\", \"speak\": \"thoughts summary to say to user\",
}, },
\"ability\": { \"ability\": {
\"name\": {\"type\": \"string\"}, \"name\": \"ability name\",
\"args\": { \"args\": {
\"arg1\": \"value1", etc... \"arg1\": \"value1", etc...
} }

View File

@@ -40,4 +40,11 @@ You have access to the following abilities you can call:
- {{ best_practice }} - {{ best_practice }}
{% endfor %} {% endfor %}
{% endif %} {% endif %}
{% if previous_actions %}
## History of Abilities Used
{% for action in previous_actions %}
- {{ action }}
{% endfor %}
{% endif %}
{% endblock %} {% endblock %}

View File

@@ -3,7 +3,7 @@ The Forge SDK. This is the core of the Forge. It contains the agent protocol, wh
core of the Forge. core of the Forge.
""" """
from .agent import Agent from .agent import Agent
from .db import AgentDB from .db import AgentDB, Base
from .forge_log import ForgeLogger from .forge_log import ForgeLogger
from .llm import chat_completion_request, create_embedding_request, transcribe_audio from .llm import chat_completion_request, create_embedding_request, transcribe_audio
from .prompting import PromptEngine from .prompting import PromptEngine
@@ -22,3 +22,6 @@ from .schema import (
TaskStepsListResponse, TaskStepsListResponse,
) )
from .workspace import LocalWorkspace, Workspace from .workspace import LocalWorkspace, Workspace
from .errors import *
from .memory.chroma_memstore import ChromaMemStore
from .memory.memstore import MemStore

View File

@@ -2,7 +2,6 @@ from typing import List
from ..registry import ability from ..registry import ability
@ability( @ability(
name="list_files", name="list_files",
description="List files in a directory", description="List files in a directory",
@@ -20,7 +19,7 @@ async def list_files(agent, task_id: str, path: str) -> List[str]:
""" """
List files in a workspace directory List files in a workspace directory
""" """
return agent.workspace.list(task_id=task_id, path=path) return agent.workspace.list(task_id=task_id, path=str(path))
@ability( @ability(
@@ -42,7 +41,7 @@ async def list_files(agent, task_id: str, path: str) -> List[str]:
], ],
output_type="None", output_type="None",
) )
async def write_file(agent, task_id: str, file_path: str, data: bytes) -> None: async def write_file(agent, task_id: str, file_path: str, data: bytes):
""" """
Write data to a file Write data to a file
""" """
@@ -50,7 +49,7 @@ async def write_file(agent, task_id: str, file_path: str, data: bytes) -> None:
data = data.encode() data = data.encode()
agent.workspace.write(task_id=task_id, path=file_path, data=data) agent.workspace.write(task_id=task_id, path=file_path, data=data)
await agent.db.create_artifact( return await agent.db.create_artifact(
task_id=task_id, task_id=task_id,
file_name=file_path.split("/")[-1], file_name=file_path.split("/")[-1],
relative_path=file_path, relative_path=file_path,

View File

@@ -0,0 +1,75 @@
from __future__ import annotations
import json
import time
from itertools import islice
from duckduckgo_search import DDGS
from ..registry import ability
DUCKDUCKGO_MAX_ATTEMPTS = 3


@ability(
    name="web_search",
    description="Searches the web",
    parameters=[
        {
            "name": "query",
            "description": "The search query",
            "type": "string",
            "required": True,
        }
    ],
    output_type="list[str]",
)
async def web_search(agent, task_id: str, query: str) -> str:
    """Search the web with DuckDuckGo and return the hits as a JSON string.

    Up to ``DUCKDUCKGO_MAX_ATTEMPTS`` searches are made, one second apart,
    until one of them yields at least one result. At most eight results are
    kept.

    Args:
        agent: The calling agent (unused).
        task_id (str): The current task id (unused).
        query (str): The search query.

    Returns:
        str: The results serialized as JSON; ``"[]"`` for an empty query.
    """
    # Local import: the module header does not import asyncio.
    import asyncio

    num_results = 8
    search_results = []

    # An empty query can never produce results, so skip the retry loop.
    if not query:
        return json.dumps(search_results)

    for attempt in range(DUCKDUCKGO_MAX_ATTEMPTS):
        results = DDGS().text(query)
        search_results = list(islice(results, num_results))
        if search_results:
            break
        if attempt < DUCKDUCKGO_MAX_ATTEMPTS - 1:
            # Await instead of time.sleep so the event loop keeps running
            # while we back off before the next attempt.
            await asyncio.sleep(1)

    results = json.dumps(search_results, ensure_ascii=False, indent=4)
    return safe_google_results(results)
def _drop_unencodable(value):
    """Recursively remove characters that cannot be encoded as UTF-8."""
    if isinstance(value, str):
        return value.encode("utf-8", "ignore").decode("utf-8")
    if isinstance(value, list):
        return [_drop_unencodable(item) for item in value]
    if isinstance(value, dict):
        return {
            _drop_unencodable(key): _drop_unencodable(item)
            for key, item in value.items()
        }
    # Numbers, booleans and None need no cleaning.
    return value


def safe_google_results(results: str | list) -> str:
    """Return search results as text that is safe to encode as UTF-8.

    Args:
        results (str | list): Either already-serialized results, or a list of
            results. List elements may be strings, or nested lists/dicts
            (search hits are dicts), which are cleaned recursively.

    Returns:
        str: For a list, its JSON serialization after cleaning; for a string,
            the string with unencodable characters (such as lone surrogates)
            removed.
    """
    if isinstance(results, list):
        return json.dumps(_drop_unencodable(results))
    return results.encode("utf-8", "ignore").decode("utf-8")

View File

@@ -0,0 +1,375 @@
"""Commands for browsing a website"""
from __future__ import annotations
COMMAND_CATEGORY = "web_browse"
COMMAND_CATEGORY_TITLE = "Web Browsing"
import logging
import re
from pathlib import Path
from sys import platform
from typing import TYPE_CHECKING, Optional, Type, List, Tuple
from bs4 import BeautifulSoup
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeDriverService
from selenium.webdriver.chrome.webdriver import WebDriver as ChromeDriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.options import ArgOptions as BrowserOptions
from selenium.webdriver.edge.options import Options as EdgeOptions
from selenium.webdriver.edge.service import Service as EdgeDriverService
from selenium.webdriver.edge.webdriver import WebDriver as EdgeDriver
from selenium.webdriver.firefox.options import Options as FirefoxOptions
from selenium.webdriver.firefox.service import Service as GeckoDriverService
from selenium.webdriver.firefox.webdriver import WebDriver as FirefoxDriver
from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.safari.options import Options as SafariOptions
from selenium.webdriver.safari.webdriver import WebDriver as SafariDriver
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.firefox import GeckoDriverManager
from webdriver_manager.microsoft import EdgeChromiumDriverManager as EdgeDriverManager
from ..registry import ability
from forge.sdk.errors import *
import functools
import re
from typing import Any, Callable
from urllib.parse import urljoin, urlparse
from requests.compat import urljoin
from bs4 import BeautifulSoup
from requests.compat import urljoin
def extract_hyperlinks(soup: BeautifulSoup, base_url: str) -> list[tuple[str, str]]:
    """Collect the anchors of a parsed page as ``(text, absolute_url)`` pairs.

    Only ``<a>`` elements that carry an ``href`` attribute are considered;
    each ``href`` is resolved against *base_url*.

    Args:
        soup (BeautifulSoup): The parsed document
        base_url (str): The URL used to resolve relative links

    Returns:
        List[Tuple[str, str]]: One ``(link text, resolved URL)`` pair per anchor
    """
    hyperlinks = []
    for anchor in soup.find_all("a", href=True):
        label = anchor.text
        target = urljoin(base_url, anchor["href"])
        hyperlinks.append((label, target))
    return hyperlinks
def format_hyperlinks(hyperlinks: list[tuple[str, str]]) -> list[str]:
    """Render ``(text, url)`` pairs as ``"text (url)"`` strings.

    Args:
        hyperlinks (List[Tuple[str, str]]): The ``(text, url)`` pairs to render

    Returns:
        List[str]: One formatted string per pair, in the same order
    """
    formatted = []
    for label, target in hyperlinks:
        formatted.append(f"{label} ({target})")
    return formatted
def validate_url(func: Callable[..., Any]) -> Any:
    """Decorator that validates and sanitizes the ``url`` argument of *func*.

    The URL is looked up by parameter name in the signature of *func*, so it
    may be passed positionally or by keyword and does not have to be the
    first parameter (e.g. ``read_webpage(agent, task_id, url, ...)``). When
    *func* has no parameter named ``url``, the legacy convention applies: the
    ``url`` keyword, or else the first positional argument, is validated and
    passed on as the first positional argument.

    Raises (from the wrapper):
        ValueError: if the url fails any of the validation tests
    """
    # Local import: the module header does not import inspect.
    import inspect

    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Some callables expose no signature; use the legacy convention.
        signature = None
    binds_url = signature is not None and "url" in signature.parameters

    def checked(url: Any) -> str:
        """Run every validation test on *url* and return the sanitized URL."""
        # Most basic check if the URL is valid:
        if not isinstance(url, str) or not re.match(r"^https?://", url):
            raise ValueError("Invalid URL format")
        if not is_valid_url(url):
            raise ValueError("Missing Scheme or Network location")
        # Restrict access to local files
        if check_local_file_access(url):
            raise ValueError("Access to local files is restricted")
        # Check URL length
        if len(url) > 2000:
            raise ValueError("URL is too long")
        return sanitize_url(url)

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        """Validate the url argument, then call the wrapped function.

        Returns:
            the result of the wrapped function

        Raises:
            ValueError if the url fails any of the validation tests
        """
        if binds_url:
            # bind() raises the same TypeError the call itself would raise
            # for a malformed argument list.
            bound = signature.bind(*args, **kwargs)
            bound.arguments["url"] = checked(bound.arguments.get("url"))
            return func(*bound.args, **bound.kwargs)
        if "url" in kwargs:
            return func(checked(kwargs.pop("url")), *args, **kwargs)
        if not args:
            raise TypeError("validate_url: missing required argument 'url'")
        return func(checked(args[0]), *args[1:], **kwargs)

    return wrapper
def is_valid_url(url: str) -> bool:
    """Tell whether *url* parses with both a scheme and a network location.

    Args:
        url (str): The URL to check

    Returns:
        bool: True when both parts are present; False when either is missing
            or the URL cannot be parsed
    """
    try:
        parts = urlparse(url)
    except ValueError:
        return False
    return bool(parts.scheme) and bool(parts.netloc)
def sanitize_url(url: str) -> str:
    """Sanitize the URL by dropping its fragment.

    The URL is rebuilt from its parsed components, so the scheme, host, path,
    ``;params`` and query are kept exactly as validated. Rebuilding through
    ``urljoin`` is deliberately avoided: a path starting with ``//`` (as in
    ``http://a.com//localhost/x``) would be read as a network-path reference
    and silently replace the host after validation has already passed.

    Args:
        url (str): The URL to sanitize

    Returns:
        str: The sanitized URL
    """
    parsed_url = urlparse(url)
    return parsed_url._replace(fragment="").geturl()
def check_local_file_access(url: str) -> bool:
    """Check if the URL points at a local file or at this machine.

    Two tests are applied, case-insensitively:

    1. A prefix test against known local spellings (``file://``,
       ``localhost``, ``127.0.0.1``, ``0.0.0.0`` and the integer forms
       ``2130706433`` / ``0000``).
    2. A host test on the parsed URL, which additionally catches
       ``*.localhost`` names, any loopback address (``127.0.0.0/8``,
       ``::1``), the unspecified address, and all-digit hosts that encode
       such an address as an integer.

    Args:
        url (str): The URL to check

    Returns:
        bool: True if the URL is a local file or local address, False otherwise
    """
    # Local import: the module header does not import ipaddress.
    import ipaddress

    # Each prefix also covers its trailing-slash spelling.
    local_prefixes = (
        "file:///",
        "file://localhost",
        "http://localhost",
        "https://localhost",
        "http://2130706433",
        "https://2130706433",
        "http://127.0.0.1",
        "https://127.0.0.1",
        "http://0.0.0.0",
        "https://0.0.0.0",
        "http://0000",
        "https://0000",
    )
    lowered = url.lower()
    if lowered.startswith(local_prefixes):
        return True

    try:
        host = urlparse(lowered).hostname
    except ValueError:
        # Unparseable URLs are rejected by the other validation steps.
        return False
    if not host:
        return False

    host = host.rstrip(".")
    if host == "localhost" or host.endswith(".localhost"):
        return True

    try:
        # An all-digit host is an IP address written as a single integer.
        address = ipaddress.ip_address(int(host) if host.isdigit() else host)
    except ValueError:
        # Not an IP literal: an ordinary host name.
        return False

    # Unwrap IPv4-mapped IPv6 addresses such as ::ffff:127.0.0.1.
    mapped = getattr(address, "ipv4_mapped", None)
    if mapped is not None:
        address = mapped
    return address.is_loopback or address.is_unspecified
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Directory one level above the directory that contains this file.
FILE_DIR = Path(__file__).parent.parent
# NOTE(review): not referenced in the visible code; presumably a token
# threshold for summarising page text. Confirm before relying on it.
TOKENS_TO_TRIGGER_SUMMARY = 50
# Maximum number of links read_webpage returns together with the page text.
LINKS_TO_RETURN = 20
class BrowsingError(CommandExecutionError):
    """An error occurred while trying to browse the page.

    read_webpage raises it when the first line of a WebDriver error message
    contains ``net::``, i.e. for networking failures while loading a page.
    """
@ability(
    name="read_webpage",
    description="Read a webpage, and extract specific information from it if a question is specified. If you are looking to extract specific information from the webpage, you should specify a question.",
    parameters=[
        {
            "name": "url",
            "description": "The URL to visit",
            "type": "string",
            "required": True,
        },
        {
            "name": "question",
            "description": "A question that you want to answer using the content of the webpage.",
            "type": "string",
            "required": False,
        },
    ],
    output_type="string",
)
@validate_url
async def read_webpage(
    agent, task_id: str, url: str, question: str = ""
) -> Tuple[str, List[str]] | str:
    """Open a web page in a browser and return its text and links.

    Args:
        agent: The calling agent (unused).
        task_id (str): The current task id (unused).
        url (str): The url of the website to browse
        question (str): The question to answer using the content of the
            webpage. Currently accepted but not used.

    Returns:
        ``(text, links)`` with at most ``LINKS_TO_RETURN`` formatted links,
        or a single message string (listing the links) when the page
        contains no text.

    Raises:
        BrowsingError: If the page could not be loaded because of a
            networking error.
        CommandExecutionError: For any other WebDriver failure.
    """
    driver = None
    try:
        driver = open_page_in_browser(url)

        text = scrape_text_with_selenium(driver)
        # Limit links to LINKS_TO_RETURN on every return path.
        links = scrape_links_with_selenium(driver, url)[:LINKS_TO_RETURN]

        if not text:
            return f"Website did not contain any text.\n\nLinks: {links}"
        return (text, links)

    except WebDriverException as e:
        # These errors are often quite long and include lots of context.
        # Just grab the first line. ``msg`` may be None.
        msg = (e.msg or "").split("\n")[0]
        if "net::" in msg:
            raise BrowsingError(
                "A networking error occurred while trying to load the page: "
                + re.sub(r"^unknown error: ", "", msg)
            ) from e
        raise CommandExecutionError(msg) from e
    finally:
        # Always release the browser, including on the early return above.
        if driver:
            close_browser(driver)
def scrape_text_with_selenium(driver: WebDriver) -> str:
    """Scrape the visible text from a browser window using selenium

    Args:
        driver (WebDriver): A driver object representing the browser window to scrape

    Returns:
        str: the text of the page, one non-empty chunk per line, with
            ``<script>`` and ``<style>`` content removed
    """
    # Get the HTML content directly from the browser's DOM
    page_source = driver.execute_script("return document.body.outerHTML;")
    soup = BeautifulSoup(page_source, "html.parser")

    for script in soup(["script", "style"]):
        script.extract()

    text = soup.get_text()
    lines = (line.strip() for line in text.splitlines())
    # Split on a DOUBLE space: it separates headline-style phrases laid out on
    # one line while keeping sentences intact. Splitting on a single space
    # would put every word on its own line.
    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
    # Drop the empty chunks left by blank lines and runs of spaces.
    text = "\n".join(chunk for chunk in chunks if chunk)
    return text
def scrape_links_with_selenium(driver: WebDriver, base_url: str) -> list[str]:
    """Return the hyperlinks of the page currently loaded in *driver*.

    Args:
        driver (WebDriver): A driver object representing the browser window to scrape
        base_url (str): The base URL to use for resolving relative links

    Returns:
        List[str]: The page's links, each formatted as ``"text (url)"``
    """
    soup = BeautifulSoup(driver.page_source, "html.parser")

    # Remove script and style elements before collecting anchors.
    for tag in soup(["script", "style"]):
        tag.extract()

    return format_hyperlinks(extract_hyperlinks(soup, base_url))
def open_page_in_browser(
    url: str, *, browser: str = "chrome", headless: bool = True
) -> WebDriver:
    """Open a browser window and load a web page using Selenium

    Params:
        url (str): The URL of the page to load
        browser (str): Which browser to drive: "chrome" (default), "edge",
            "firefox" or "safari"
        headless (bool): Run Chrome/Firefox without a visible window
            (default True)

    Returns:
        driver (WebDriver): A driver object representing the browser window to scrape

    Raises:
        ValueError: If *browser* is not one of the supported names.
        WebDriverException: If the browser cannot be started or the page
            cannot be loaded; the browser is shut down before re-raising.
    """
    logging.getLogger("selenium").setLevel(logging.CRITICAL)

    options_available: dict[str, Type[BrowserOptions]] = {
        "chrome": ChromeOptions,
        "edge": EdgeOptions,
        "firefox": FirefoxOptions,
        "safari": SafariOptions,
    }
    if browser not in options_available:
        raise ValueError(f"Unsupported browser: {browser!r}")

    options: BrowserOptions = options_available[browser]()
    options.add_argument(
        "user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.5615.49 Safari/537.36"
    )

    if browser == "firefox":
        if headless:
            # The ``options.headless`` setter no longer exists in the pinned
            # Selenium 4.13; the command-line flag is the supported form.
            options.add_argument("-headless")
            options.add_argument("--disable-gpu")
        driver = FirefoxDriver(
            service=GeckoDriverService(GeckoDriverManager().install()), options=options
        )
    elif browser == "edge":
        driver = EdgeDriver(
            service=EdgeDriverService(EdgeDriverManager().install()), options=options
        )
    elif browser == "safari":
        # Requires a bit more setup on the users end
        # See https://developer.apple.com/documentation/webkit/testing_with_webdriver_in_safari
        driver = SafariDriver(options=options)
    else:
        if platform in ("linux", "linux2"):
            options.add_argument("--disable-dev-shm-usage")
            options.add_argument("--remote-debugging-port=9222")

        options.add_argument("--no-sandbox")
        if headless:
            options.add_argument("--headless=new")
            options.add_argument("--disable-gpu")

        # Prefer a system-wide chromedriver; otherwise download one.
        chromium_driver_path = Path("/usr/bin/chromedriver")

        driver = ChromeDriver(
            service=ChromeDriverService(str(chromium_driver_path))
            if chromium_driver_path.exists()
            else ChromeDriverService(ChromeDriverManager().install()),
            options=options,
        )

    try:
        driver.get(url)
        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )
    except Exception:
        # The caller never receives the driver on failure, so it cannot
        # close it; quit here to avoid leaking the browser process.
        driver.quit()
        raise

    return driver
def close_browser(driver: WebDriver) -> None:
    """Shut down the browser session behind *driver*.

    Args:
        driver (WebDriver): The webdriver to close

    Returns:
        None
    """
    driver.quit()
    return None

View File

@@ -1,25 +0,0 @@
"""
PROFILE CONCEPT:
The profile generator is used to intiliase and configure an ai agent.
It came from the obsivation that if an llm is provided with a profile such as:
```
Expert:
```
Then it's performance at a task can impove. Here we use the profile to generate
a system prompt for the agent to use. However, it can be used to configure other
aspects of the agent such as memory, planning, and actions available.
The possibilities are limited just by your imagination.
"""
from forge.sdk import PromptEngine
class ProfileGenerator:
def __init__(self, task: str, PromptEngine: PromptEngine):
"""
Initialize the profile generator with the task to be performed.
"""
self.task = task

View File

@@ -37,6 +37,7 @@ class LocalWorkspace(Workspace):
self.base_path = Path(base_path).resolve() self.base_path = Path(base_path).resolve()
def _resolve_path(self, task_id: str, path: str) -> Path: def _resolve_path(self, task_id: str, path: str) -> Path:
path = str(path)
path = path if not path.startswith("/") else path[1:] path = path if not path.startswith("/") else path[1:]
abs_path = (self.base_path / task_id / path).resolve() abs_path = (self.base_path / task_id / path).resolve()
if not str(abs_path).startswith(str(self.base_path)): if not str(abs_path).startswith(str(self.base_path)):
@@ -77,4 +78,6 @@ class LocalWorkspace(Workspace):
def list(self, task_id: str, path: str) -> typing.List[str]: def list(self, task_id: str, path: str) -> typing.List[str]:
path = self.base_path / task_id / path path = self.base_path / task_id / path
base = self._resolve_path(task_id, path) base = self._resolve_path(task_id, path)
if not base.exists() or not base.is_dir():
return []
return [str(p.relative_to(self.base_path / task_id)) for p in base.iterdir()] return [str(p.relative_to(self.base_path / task_id)) for p in base.iterdir()]

File diff suppressed because it is too large Load Diff

View File

@@ -21,6 +21,9 @@ toml = "^0.10.2"
jinja2 = "^3.1.2" jinja2 = "^3.1.2"
uvicorn = "^0.23.2" uvicorn = "^0.23.2"
litellm = "^0.1.821" litellm = "^0.1.821"
duckduckgo-search = "^3.9.3"
selenium = "^4.13.0"
bs4 = "^0.0.1"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
isort = "^5.12.0" isort = "^5.12.0"

16
cli.py
View File

@@ -301,14 +301,22 @@ def stop():
import subprocess import subprocess
try: try:
pid = int(subprocess.check_output(["lsof", "-t", "-i", ":8000"])) pids = subprocess.check_output(["lsof", "-t", "-i", ":8000"]).split()
os.kill(pid, signal.SIGTERM) if isinstance(pids, int):
os.kill(int(pids), signal.SIGTERM)
else:
for pid in pids:
os.kill(int(pid), signal.SIGTERM)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
click.echo("No process is running on port 8000") click.echo("No process is running on port 8000")
try: try:
pid = int(subprocess.check_output(["lsof", "-t", "-i", ":8080"])) pids = int(subprocess.check_output(["lsof", "-t", "-i", ":8080"]))
os.kill(pid, signal.SIGTERM) if isinstance(pids, int):
os.kill(int(pids), signal.SIGTERM)
else:
for pid in pids:
os.kill(int(pid), signal.SIGTERM)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
click.echo("No process is running on port 8080") click.echo("No process is running on port 8080")