Mirror of https://github.com/aljazceru/Auto-GPT.git
Validate URLs in web commands before execution (#2616)
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
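Summary (as reflected in the diff below): URL validation moves out of the individual web commands into a reusable @validate_url decorator in the new autogpt/url_utils/validators.py module, which is then applied to get_text_summary, get_hyperlinks, clone_repository, get_response, and browse_website. The inline is_valid_url / sanitize_url / check_local_file_access helpers are relocated from the web-requests command module into the new validators module, and the unit tests are updated to expect a ValueError for an invalid URL.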
@@ -10,6 +10,7 @@ from autogpt.memory import get_memory
 from autogpt.processing.text import summarize_text
 from autogpt.prompts.generator import PromptGenerator
 from autogpt.speech import say_text
+from autogpt.url_utils.validators import validate_url

 CFG = Config()
 AGENT_MANAGER = AgentManager()
@@ -141,6 +142,7 @@ def execute_command(
 @command(
     "get_text_summary", "Get text summary", '"url": "<url>", "question": "<question>"'
 )
+@validate_url
 def get_text_summary(url: str, question: str) -> str:
     """Return the results of a Google search
@@ -157,6 +159,7 @@ def get_text_summary(url: str, question: str) -> str:


 @command("get_hyperlinks", "Get text summary", '"url": "<url>"')
+@validate_url
 def get_hyperlinks(url: str) -> Union[str, List[str]]:
     """Return the results of a Google search

@@ -3,6 +3,7 @@ from git.repo import Repo

 from autogpt.commands.command import command
 from autogpt.config import Config
+from autogpt.url_utils.validators import validate_url

 CFG = Config()
@@ -14,6 +15,7 @@ CFG = Config()
     CFG.github_username and CFG.github_api_key,
     "Configure github_username and github_api_key.",
 )
+@validate_url
 def clone_repository(repository_url: str, clone_path: str) -> str:
     """Clone a GitHub repository locally.
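Note on this hunk: the decorator added later in this diff validates whatever arrives as the wrapped function's first positional argument (its wrapper is def wrapper(url, *args, **kwargs)), so it protects clone_repository because repository_url is the first parameter; a command taking its URL in any other position would not be covered.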

@@ -1,15 +1,13 @@
 """Browse a webpage and summarize it using the LLM model"""
 from __future__ import annotations

-from urllib.parse import urljoin, urlparse
-
 import requests
 from bs4 import BeautifulSoup
 from requests import Response
-from requests.compat import urljoin

 from autogpt.config import Config
 from autogpt.processing.html import extract_hyperlinks, format_hyperlinks
+from autogpt.url_utils.validators import validate_url

 CFG = Config()
@@ -17,71 +15,7 @@ session = requests.Session()
 session.headers.update({"User-Agent": CFG.user_agent})


-def is_valid_url(url: str) -> bool:
-    """Check if the URL is valid
-
-    Args:
-        url (str): The URL to check
-
-    Returns:
-        bool: True if the URL is valid, False otherwise
-    """
-    try:
-        result = urlparse(url)
-        return all([result.scheme, result.netloc])
-    except ValueError:
-        return False
-
-
-def sanitize_url(url: str) -> str:
-    """Sanitize the URL
-
-    Args:
-        url (str): The URL to sanitize
-
-    Returns:
-        str: The sanitized URL
-    """
-    return urljoin(url, urlparse(url).path)
-
-
-def check_local_file_access(url: str) -> bool:
-    """Check if the URL is a local file
-
-    Args:
-        url (str): The URL to check
-
-    Returns:
-        bool: True if the URL is a local file, False otherwise
-    """
-    local_prefixes = [
-        "file:///",
-        "file://localhost/",
-        "file://localhost",
-        "http://localhost",
-        "http://localhost/",
-        "https://localhost",
-        "https://localhost/",
-        "http://2130706433",
-        "http://2130706433/",
-        "https://2130706433",
-        "https://2130706433/",
-        "http://127.0.0.1/",
-        "http://127.0.0.1",
-        "https://127.0.0.1/",
-        "https://127.0.0.1",
-        "https://0.0.0.0/",
-        "https://0.0.0.0",
-        "http://0.0.0.0/",
-        "http://0.0.0.0",
-        "http://0000",
-        "http://0000/",
-        "https://0000",
-        "https://0000/",
-    ]
-    return any(url.startswith(prefix) for prefix in local_prefixes)
+@validate_url
 def get_response(
     url: str, timeout: int = 10
 ) -> tuple[None, str] | tuple[Response, None]:
@@ -99,17 +33,7 @@ def get_response(
         requests.exceptions.RequestException: If the HTTP request fails
     """
     try:
-        # Restrict access to local files
-        if check_local_file_access(url):
-            raise ValueError("Access to local files is restricted")
-
-        # Most basic check if the URL is valid:
-        if not url.startswith("http://") and not url.startswith("https://"):
-            raise ValueError("Invalid URL format")
-
-        sanitized_url = sanitize_url(url)
-
-        response = session.get(sanitized_url, timeout=timeout)
+        response = session.get(url, timeout=timeout)

         # Check if the response contains an HTTP error
         if response.status_code >= 400:
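A consequence worth noting: the old inline checks raised ValueError inside get_response's try block, where it was converted into an "Error: ..." return value; with @validate_url the ValueError is raised in the wrapper before the function body runs, so callers now see the exception directly. The test changes at the bottom of this diff reflect exactly that shift (pytest.raises(ValueError, ...) for an invalid URL).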

@@ -21,6 +21,7 @@ import autogpt.processing.text as summary
 from autogpt.commands.command import command
 from autogpt.config import Config
 from autogpt.processing.html import extract_hyperlinks, format_hyperlinks
+from autogpt.url_utils.validators import validate_url

 FILE_DIR = Path(__file__).parent.parent
 CFG = Config()
@@ -31,6 +32,7 @@ CFG = Config()
     "Browse Website",
     '"url": "<url>", "question": "<what_you_want_to_find_on_website>"',
 )
+@validate_url
 def browse_website(url: str, question: str) -> tuple[str, WebDriver]:
     """Browse a website and return the answer and links to the user

New files:
  autogpt/url_utils/__init__.py   (0 lines)
  autogpt/url_utils/validators.py (101 lines)
@@ -0,0 +1,101 @@
+import functools
+from typing import Any, Callable
+from urllib.parse import urljoin, urlparse
+
+from requests.compat import urljoin
+
+
+def validate_url(func: Callable[..., Any]) -> Any:
+    """The method decorator validate_url is used to validate urls for any command that requires
+    a url as an argument"""
+
+    @functools.wraps(func)
+    def wrapper(url: str, *args, **kwargs) -> Any:
+        """Check if the URL is valid using a basic check, urllib check, and local file check
+
+        Args:
+            url (str): The URL to check
+
+        Returns:
+            the result of the wrapped function
+
+        Raises:
+            ValueError if the url fails any of the validation tests
+        """
+        # Most basic check if the URL is valid:
+        if not url.startswith("http://") and not url.startswith("https://"):
+            raise ValueError("Invalid URL format")
+        if not is_valid_url(url):
+            raise ValueError("Missing Scheme or Network location")
+        # Restrict access to local files
+        if check_local_file_access(url):
+            raise ValueError("Access to local files is restricted")
+
+        return func(sanitize_url(url), *args, **kwargs)
+
+    return wrapper
+
+
+def is_valid_url(url: str) -> bool:
+    """Check if the URL is valid
+
+    Args:
+        url (str): The URL to check
+
+    Returns:
+        bool: True if the URL is valid, False otherwise
+    """
+    try:
+        result = urlparse(url)
+        return all([result.scheme, result.netloc])
+    except ValueError:
+        return False
+
+
+def sanitize_url(url: str) -> str:
+    """Sanitize the URL
+
+    Args:
+        url (str): The URL to sanitize
+
+    Returns:
+        str: The sanitized URL
+    """
+    return urljoin(url, urlparse(url).path)
+
+
+def check_local_file_access(url: str) -> bool:
+    """Check if the URL is a local file
+
+    Args:
+        url (str): The URL to check
+
+    Returns:
+        bool: True if the URL is a local file, False otherwise
+    """
+    local_prefixes = [
+        "file:///",
+        "file://localhost/",
+        "file://localhost",
+        "http://localhost",
+        "http://localhost/",
+        "https://localhost",
+        "https://localhost/",
+        "http://2130706433",
+        "http://2130706433/",
+        "https://2130706433",
+        "https://2130706433/",
+        "http://127.0.0.1/",
+        "http://127.0.0.1",
+        "https://127.0.0.1/",
+        "https://127.0.0.1",
+        "https://0.0.0.0/",
+        "https://0.0.0.0",
+        "http://0.0.0.0/",
+        "http://0.0.0.0",
+        "http://0000",
+        "http://0000/",
+        "https://0000",
+        "https://0000/",
+    ]
+    return any(url.startswith(prefix) for prefix in local_prefixes)
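A minimal usage sketch of the new decorator (assuming the package from this commit is importable; fetch is a hypothetical function, not part of the codebase):

    from autogpt.url_utils.validators import validate_url

    @validate_url
    def fetch(url: str) -> str:
        # Receives the sanitized URL from the wrapper
        return f"would fetch {url}"

    # sanitize_url drops the query string before the call:
    print(fetch("https://example.com/page?q=1"))  # would fetch https://example.com/page

    for bad in ("htp://example.com", "http://localhost/admin", "file:///etc/passwd"):
        try:
            fetch(bad)
        except ValueError as e:
            print(f"rejected {bad!r}: {e}")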

@@ -1,5 +1,6 @@
 # Generated by CodiumAI

+import pytest
 import requests

 from autogpt.commands.web_requests import scrape_text
@@ -58,9 +59,14 @@ class TestScrapeText:
         url = "http://www.example.com"
         assert scrape_text(url) == expected_text

-    # Tests that the function returns an error message when an invalid or unreachable
+    # Tests that an error is raised when an invalid url is provided.
+    def test_invalid_url(self):
+        url = "invalidurl.com"
+        pytest.raises(ValueError, scrape_text, url)
+
+    # Tests that the function returns an error message when an unreachable
     # url is provided.
-    def test_invalid_url(self, mocker):
+    def test_unreachable_url(self, mocker):
         # Mock the requests.get() method to raise an exception
         mocker.patch(
             "requests.Session.get", side_effect=requests.exceptions.RequestException
@@ -68,7 +74,7 @@ class TestScrapeText:

         # Call the function with an invalid URL and assert that it returns an error
         # message
-        url = "http://www.invalidurl.com"
+        url = "http://thiswebsitedoesnotexist.net/"
         error_message = scrape_text(url)
         assert "Error:" in error_message
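To run just these tests locally, a keyword filter avoids needing the test file's path (a sketch, assuming a standard pytest setup at the repo root):

    pytest -k TestScrapeText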