browse: (1) apply validation also to scrape_links(), (2) add tests for scrape_links()

Itamar Friedman
2023-04-11 11:17:07 +03:00
parent 5a6e565c52
commit 2d5d0131bb
2 changed files with 146 additions and 14 deletions


@@ -5,25 +5,38 @@ from llm_utils import create_chat_completion
 cfg = Config()
 
-# Define and check for local file address prefixes
 def check_local_file_access(url):
+    # Define and check for local file address prefixes
     local_prefixes = ['file:///', 'file://localhost', 'http://localhost', 'https://localhost']
     return any(url.startswith(prefix) for prefix in local_prefixes)
 
+def get_validated_response(url, headers=cfg.user_agent_header):
+    try:
+        # Restrict access to local files
+        if check_local_file_access(url):
+            raise ValueError('Access to local files is restricted')
+
+        # Most basic check if the URL is valid:
+        if not url.startswith('http://') and not url.startswith('https://'):
+            raise ValueError('Invalid URL format')
+
+        # Make the HTTP request and return the response
+        response = requests.get(url, headers=headers)
+        response.raise_for_status()  # Raise an exception if the response contains an HTTP error status code
+
+        return response, None
+    except ValueError as ve:
+        # Handle invalid URL format
+        return None, "Error: " + str(ve)
+    except requests.exceptions.RequestException as re:
+        # Handle exceptions related to the HTTP request (e.g., connection errors, timeouts, etc.)
+        return None, "Error: " + str(re)
+
 def scrape_text(url):
     """Scrape text from a webpage"""
-    # Most basic check if the URL is valid:
-    if not url.startswith('http'):
-        return "Error: Invalid URL"
-
-    # Restrict access to local files
-    if check_local_file_access(url):
-        return "Error: Access to local files is restricted"
-
-    try:
-        response = requests.get(url, headers=cfg.user_agent_header)
-    except requests.exceptions.RequestException as e:
-        return "Error: " + str(e)
+    response, error_message = get_validated_response(url)
+    if error_message:
+        return error_message
 
     # Check if the response contains an HTTP error
     if response.status_code >= 400:
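
For reference, get_validated_response() returns a (response, error_message) tuple in which exactly one element is None, so callers branch on the error string instead of wrapping every request in try/except. A minimal usage sketch, assuming the module is importable as browse and using a placeholder URL (neither detail is taken from this commit):

# Hedged usage sketch -- module name and URL are illustrative only.
from browse import get_validated_response

response, error_message = get_validated_response('https://example.com')
if error_message:
    print(error_message)          # e.g. "Error: Invalid URL format"
else:
    print(response.status_code)   # validated response; raise_for_status() already passed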
@@ -60,7 +73,9 @@ def format_hyperlinks(hyperlinks):
 def scrape_links(url):
     """Scrape links from a webpage"""
-    response = requests.get(url, headers=cfg.user_agent_header)
+    response, error_message = get_validated_response(url)
+    if error_message:
+        return error_message
 
     # Check if the response contains an HTTP error
     if response.status_code >= 400:
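
The second changed file, the new tests for scrape_links(), is not shown in this view. A minimal sketch of what such tests could look like, assuming unittest with requests.get patched out; the class name, assertions, and HTML fixture are illustrative, not taken from the commit:

# Hedged test sketch -- the commit's actual test file is not reproduced here.
import unittest
from unittest.mock import MagicMock, patch

from browse import scrape_links  # assumes the scripts directory is on the path

class TestScrapeLinks(unittest.TestCase):
    @patch('browse.requests.get')
    def test_valid_url_returns_formatted_links(self, mock_get):
        # Simulate a 200 response whose body contains a single hyperlink
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.text = '<html><body><a href="https://example.com">Example</a></body></html>'
        mock_get.return_value = mock_response

        result = scrape_links('https://www.example.com')
        self.assertIsInstance(result, list)  # format_hyperlinks() returns a list of strings

    def test_invalid_url_returns_error(self):
        # No http:// or https:// prefix, so get_validated_response() rejects the URL
        # before any network call is made.
        result = scrape_links('invalid-url')
        self.assertIn('Error', result)

if __name__ == '__main__':
    unittest.main()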