Validate URLs in web commands before execution (#2616)

Co-authored-by: Reinier van der Leer <github@pwuts.nl>
Eddie Cohen authored on 2023-04-24 06:33:44 -04:00, committed by GitHub
parent 794a164098
commit 40a75c804c
7 changed files with 120 additions and 82 deletions

View File

@@ -10,6 +10,7 @@ from autogpt.memory import get_memory
 from autogpt.processing.text import summarize_text
 from autogpt.prompts.generator import PromptGenerator
 from autogpt.speech import say_text
+from autogpt.url_utils.validators import validate_url

 CFG = Config()
 AGENT_MANAGER = AgentManager()
@@ -141,6 +142,7 @@ def execute_command(
 @command(
     "get_text_summary", "Get text summary", '"url": "<url>", "question": "<question>"'
 )
+@validate_url
 def get_text_summary(url: str, question: str) -> str:
     """Return the results of a Google search
@@ -157,6 +159,7 @@ def get_text_summary(url: str, question: str) -> str:
 @command("get_hyperlinks", "Get text summary", '"url": "<url>"')
+@validate_url
 def get_hyperlinks(url: str) -> Union[str, List[str]]:
     """Return the results of a Google search

View File

@@ -3,6 +3,7 @@ from git.repo import Repo
 from autogpt.commands.command import command
 from autogpt.config import Config
+from autogpt.url_utils.validators import validate_url

 CFG = Config()
@@ -14,6 +15,7 @@ CFG = Config()
     CFG.github_username and CFG.github_api_key,
     "Configure github_username and github_api_key.",
 )
+@validate_url
 def clone_repository(repository_url: str, clone_path: str) -> str:
     """Clone a GitHub repository locally.

View File

@@ -1,15 +1,13 @@
 """Browse a webpage and summarize it using the LLM model"""
 from __future__ import annotations

-from urllib.parse import urljoin, urlparse
-
 import requests
 from bs4 import BeautifulSoup
 from requests import Response
-from requests.compat import urljoin

 from autogpt.config import Config
 from autogpt.processing.html import extract_hyperlinks, format_hyperlinks
+from autogpt.url_utils.validators import validate_url

 CFG = Config()
@@ -17,71 +15,7 @@ session = requests.Session()
 session.headers.update({"User-Agent": CFG.user_agent})


-def is_valid_url(url: str) -> bool:
-    """Check if the URL is valid
-
-    Args:
-        url (str): The URL to check
-
-    Returns:
-        bool: True if the URL is valid, False otherwise
-    """
-    try:
-        result = urlparse(url)
-        return all([result.scheme, result.netloc])
-    except ValueError:
-        return False
-
-
-def sanitize_url(url: str) -> str:
-    """Sanitize the URL
-
-    Args:
-        url (str): The URL to sanitize
-
-    Returns:
-        str: The sanitized URL
-    """
-    return urljoin(url, urlparse(url).path)
-
-
-def check_local_file_access(url: str) -> bool:
-    """Check if the URL is a local file
-
-    Args:
-        url (str): The URL to check
-
-    Returns:
-        bool: True if the URL is a local file, False otherwise
-    """
-    local_prefixes = [
-        "file:///",
-        "file://localhost/",
-        "file://localhost",
-        "http://localhost",
-        "http://localhost/",
-        "https://localhost",
-        "https://localhost/",
-        "http://2130706433",
-        "http://2130706433/",
-        "https://2130706433",
-        "https://2130706433/",
-        "http://127.0.0.1/",
-        "http://127.0.0.1",
-        "https://127.0.0.1/",
-        "https://127.0.0.1",
-        "https://0.0.0.0/",
-        "https://0.0.0.0",
-        "http://0.0.0.0/",
-        "http://0.0.0.0",
-        "http://0000",
-        "http://0000/",
-        "https://0000",
-        "https://0000/",
-    ]
-    return any(url.startswith(prefix) for prefix in local_prefixes)
+@validate_url
 def get_response(
     url: str, timeout: int = 10
 ) -> tuple[None, str] | tuple[Response, None]:
@@ -99,17 +33,7 @@ def get_response(
         requests.exceptions.RequestException: If the HTTP request fails
     """
     try:
-        # Restrict access to local files
-        if check_local_file_access(url):
-            raise ValueError("Access to local files is restricted")
-
-        # Most basic check if the URL is valid:
-        if not url.startswith("http://") and not url.startswith("https://"):
-            raise ValueError("Invalid URL format")
-
-        sanitized_url = sanitize_url(url)
-
-        response = session.get(sanitized_url, timeout=timeout)
+        response = session.get(url, timeout=timeout)

         # Check if the response contains an HTTP error
         if response.status_code >= 400:
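
The checks deleted here are not lost: they move into the new autogpt.url_utils.validators module shown below and are applied through the @validate_url decorator. One behavioral detail carried over is sanitize_url, which the decorator applies before the wrapped command runs; it rebuilds the URL from its parsed path, so query strings and fragments are stripped. A standalone sketch of that behavior (the function body is copied from the diff; the example URL is illustrative):

    from urllib.parse import urljoin, urlparse

    def sanitize_url(url: str) -> str:
        # Copied from the new validators module: re-join the URL with its own
        # path, which drops the query string and fragment.
        return urljoin(url, urlparse(url).path)

    print(sanitize_url("https://example.com/search?q=secret#results"))
    # -> https://example.com/search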

View File

@@ -21,6 +21,7 @@ import autogpt.processing.text as summary
 from autogpt.commands.command import command
 from autogpt.config import Config
 from autogpt.processing.html import extract_hyperlinks, format_hyperlinks
+from autogpt.url_utils.validators import validate_url

 FILE_DIR = Path(__file__).parent.parent
 CFG = Config()
@@ -31,6 +32,7 @@ CFG = Config()
     "Browse Website",
     '"url": "<url>", "question": "<what_you_want_to_find_on_website>"',
 )
+@validate_url
 def browse_website(url: str, question: str) -> tuple[str, WebDriver]:
     """Browse a website and return the answer and links to the user

View File

View File

@@ -0,0 +1,101 @@
+import functools
+from typing import Any, Callable
+from urllib.parse import urljoin, urlparse
+
+from requests.compat import urljoin
+
+
+def validate_url(func: Callable[..., Any]) -> Any:
+    """The method decorator validate_url is used to validate urls for any command that requires
+    a url as an argument"""
+
+    @functools.wraps(func)
+    def wrapper(url: str, *args, **kwargs) -> Any:
+        """Check if the URL is valid using a basic check, urllib check, and local file check
+
+        Args:
+            url (str): The URL to check
+
+        Returns:
+            the result of the wrapped function
+
+        Raises:
+            ValueError if the url fails any of the validation tests
+        """
+        # Most basic check if the URL is valid:
+        if not url.startswith("http://") and not url.startswith("https://"):
+            raise ValueError("Invalid URL format")
+        if not is_valid_url(url):
+            raise ValueError("Missing Scheme or Network location")
+        # Restrict access to local files
+        if check_local_file_access(url):
+            raise ValueError("Access to local files is restricted")
+
+        return func(sanitize_url(url), *args, **kwargs)
+
+    return wrapper
+
+
+def is_valid_url(url: str) -> bool:
+    """Check if the URL is valid
+
+    Args:
+        url (str): The URL to check
+
+    Returns:
+        bool: True if the URL is valid, False otherwise
+    """
+    try:
+        result = urlparse(url)
+        return all([result.scheme, result.netloc])
+    except ValueError:
+        return False
+
+
+def sanitize_url(url: str) -> str:
+    """Sanitize the URL
+
+    Args:
+        url (str): The URL to sanitize
+
+    Returns:
+        str: The sanitized URL
+    """
+    return urljoin(url, urlparse(url).path)
+
+
+def check_local_file_access(url: str) -> bool:
+    """Check if the URL is a local file
+
+    Args:
+        url (str): The URL to check
+
+    Returns:
+        bool: True if the URL is a local file, False otherwise
+    """
+    local_prefixes = [
+        "file:///",
+        "file://localhost/",
+        "file://localhost",
+        "http://localhost",
+        "http://localhost/",
+        "https://localhost",
+        "https://localhost/",
+        "http://2130706433",
+        "http://2130706433/",
+        "https://2130706433",
+        "https://2130706433/",
+        "http://127.0.0.1/",
+        "http://127.0.0.1",
+        "https://127.0.0.1/",
+        "https://127.0.0.1",
+        "https://0.0.0.0/",
+        "https://0.0.0.0",
+        "http://0.0.0.0/",
+        "http://0.0.0.0",
+        "http://0000",
+        "http://0000/",
+        "https://0000",
+        "https://0000/",
+    ]
+    return any(url.startswith(prefix) for prefix in local_prefixes)
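
Taken together, the wrapper applies three checks in order: the http(s):// prefix test, is_valid_url, and check_local_file_access, then calls the command with the sanitized URL. Because the prefix test runs first, a file:// URL is rejected as an invalid format before the local-file list is ever consulted. (Also worth noting: requests.compat.urljoin shadows the urljoin imported from urllib.parse two lines earlier; the two names refer to the same function, so behavior is unaffected.) A small usage sketch, assuming the autogpt package from this commit is importable; the fetch command is a hypothetical stand-in, not part of the commit:

    from autogpt.url_utils.validators import validate_url

    @validate_url
    def fetch(url: str) -> str:
        # Hypothetical stand-in for a real command; receives the sanitized URL.
        return f"would fetch {url}"

    print(fetch("https://example.com/page?q=1"))  # ok; decorator passes .../page
    for bad in (
        "example.com",             # fails the http(s):// prefix check
        "http://localhost/admin",  # matches a local prefix -> access restricted
        "file:///etc/passwd",      # also caught by the prefix check, not the local list
    ):
        try:
            fetch(bad)
        except ValueError as err:
            print(f"rejected {bad!r}: {err}")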

View File

@@ -1,5 +1,6 @@
# Generated by CodiumAI # Generated by CodiumAI
import pytest
import requests import requests
from autogpt.commands.web_requests import scrape_text from autogpt.commands.web_requests import scrape_text
@@ -58,9 +59,14 @@ class TestScrapeText:
url = "http://www.example.com" url = "http://www.example.com"
assert scrape_text(url) == expected_text assert scrape_text(url) == expected_text
# Tests that the function returns an error message when an invalid or unreachable # Tests that an error is raised when an invalid url is provided.
def test_invalid_url(self):
url = "invalidurl.com"
pytest.raises(ValueError, scrape_text, url)
# Tests that the function returns an error message when an unreachable
# url is provided. # url is provided.
def test_invalid_url(self, mocker): def test_unreachable_url(self, mocker):
# Mock the requests.get() method to raise an exception # Mock the requests.get() method to raise an exception
mocker.patch( mocker.patch(
"requests.Session.get", side_effect=requests.exceptions.RequestException "requests.Session.get", side_effect=requests.exceptions.RequestException
@@ -68,7 +74,7 @@ class TestScrapeText:
# Call the function with an invalid URL and assert that it returns an error # Call the function with an invalid URL and assert that it returns an error
# message # message
url = "http://www.invalidurl.com" url = "http://thiswebsitedoesnotexist.net/"
error_message = scrape_text(url) error_message = scrape_text(url)
assert "Error:" in error_message assert "Error:" in error_message