Import backup improvements (#2078)

* Add progress tracking to import_ctf
* Make imports happen in the background so that we can see status
* Add GET /admin/import to see status of import
* Disable the public interface during imports
* Closes #1980
Kevin Chung committed 2022-04-08 16:52:04 -04:00 (committed by GitHub)
parent 0c6e28315c · commit f24f2a18bb
8 changed files with 151 additions and 25 deletions
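
Taken together, the bullets above describe a single pattern: the POST handler no longer runs the restore inline; it hands the work to a child process that records its progress in the cache, and a GET handler (backed by a new status template) renders those cache keys back to the admin. A condensed sketch of that shape follows; every name in it is a hypothetical stand-in rather than CTFd's API, and it assumes a cache backend that both processes can see, such as Redis.

from multiprocessing import Process

from flask import Flask, copy_current_request_context, redirect, request, url_for
from flask_caching import Cache

app = Flask(__name__)
cache = Cache(app, config={"CACHE_TYPE": "RedisCache"})  # shared across processes


def do_import(backup_bytes):
    # Stand-in for CTFd's import_ctf(): report progress through the cache.
    cache.set("import_status", "started", timeout=604800)
    ...  # long-running restore work goes here
    cache.set("import_status", "finished", timeout=604800)


@app.route("/admin/import", methods=["GET", "POST"])
def import_view():
    if request.method == "POST":
        backup_bytes = request.files["backup"].read()

        @copy_current_request_context
        def worker():
            do_import(backup_bytes)

        Process(target=worker).start()  # schedule and return immediately
        return redirect(url_for("import_view"))
    # GET: show whatever progress the worker has recorded so far
    return cache.get("import_status") or "no import in progress"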


@@ -45,9 +45,8 @@ from CTFd.utils import config as ctf_config
 from CTFd.utils import get_config, set_config
 from CTFd.utils.csv import dump_csv, load_challenges_csv, load_teams_csv, load_users_csv
 from CTFd.utils.decorators import admins_only
+from CTFd.utils.exports import background_import_ctf
 from CTFd.utils.exports import export_ctf as export_ctf_util
-from CTFd.utils.exports import import_ctf as import_ctf_util
-from CTFd.utils.helpers import get_errors
 from CTFd.utils.security.auth import logout_user
 from CTFd.utils.uploads import delete_file
 from CTFd.utils.user import is_admin
@@ -88,21 +87,25 @@ def plugin(plugin):
     return "1"


-@admin.route("/admin/import", methods=["POST"])
+@admin.route("/admin/import", methods=["GET", "POST"])
 @admins_only
 def import_ctf():
-    backup = request.files["backup"]
-    errors = get_errors()
-    try:
-        import_ctf_util(backup)
-    except Exception as e:
-        print(e)
-        errors.append(repr(e))
-
-    if errors:
-        return errors[0], 500
-    else:
-        return redirect(url_for("admin.config"))
+    if request.method == "GET":
+        start_time = cache.get("import_start_time")
+        end_time = cache.get("import_end_time")
+        import_status = cache.get("import_status")
+        import_error = cache.get("import_error")
+        return render_template(
+            "admin/import.html",
+            start_time=start_time,
+            end_time=end_time,
+            import_status=import_status,
+            import_error=import_error,
+        )
+    elif request.method == "POST":
+        backup = request.files["backup"]
+        background_import_ctf(backup)
+        return redirect(url_for("admin.import_ctf"))


 @admin.route("/admin/export", methods=["GET", "POST"])
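
The route is now two endpoints in one: POST schedules the import and redirects back to itself, while GET renders a snapshot of the cache-backed progress keys. A hypothetical client-side view of that flow (assumes an already-authenticated requests.Session; the CSRF nonce CTFd normally requires on POSTs is omitted for brevity):

import time

import requests


def import_and_wait(base_url, session, backup_path="backup.zip"):
    # session: an authenticated requests.Session (assumed)
    with open(backup_path, "rb") as f:
        session.post(f"{base_url}/admin/import", files={"backup": f})
    while True:
        page = session.get(f"{base_url}/admin/import").text
        # "End Time:" and "Import Error:" only render once the import has
        # finished or failed (see the import.html template further down).
        if "End Time:" in page or "Import Error:" in page:
            return page
        time.sleep(5)  # mirror the status page's 5-second reload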


@@ -359,12 +359,7 @@ function importConfig(event) {
         target: pg,
         width: 100
       });
-      setTimeout(function() {
-        pg.modal("hide");
-      }, 500);
-      setTimeout(function() {
-        window.location.reload();
-      }, 700);
+      location.href = CTFd.config.urlRoot + "/admin/import";
     }
   });
 }

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -0,0 +1,50 @@
+{% extends "admin/base.html" %}
+
+{% block content %}
+<div class="jumbotron">
+  <div class="container">
+    <h1>Import Status</h1>
+  </div>
+</div>
+<div class="container">
+  <div class="row">
+    <div class="col-md-6 offset-md-3">
+      <p>
+        <b>Start Time:</b> <span id="start-time">{{ start_time }}</span>
+      </p>
+      {% if end_time %}
+      <p>
+        <b>End Time:</b> <span id="end-time">{{ end_time }}</span>
+      </p>
+      {% endif %}
+      {% if import_error %}
+      <p>
+        <b>Import Error:</b> {{ import_error }}
+      </p>
+      {% else %}
+      <p>
+        <b>Current Status:</b> {{ import_status }}
+      </p>
+      {% endif %}
+    </div>
+  </div>
+</div>
+{% endblock %}
+
+{% block scripts %}
+<script>
+  let start_time = "{{ start_time | tojson }}";
+  let end_time = "{{ end_time | tojson }}";
+  let start = document.getElementById("start-time");
+  start.innerText = new Date(parseInt(start_time) * 1000);
+  let end = document.getElementById("end-time");
+  end.innerText = new Date(parseInt(end_time) * 1000);
+
+  // Reload every 5 seconds to poll import status
+  if (!end_time) {
+    setTimeout(function(){
+      window.location.reload();
+    }, 5000);
+  }
+</script>
+{% endblock %}


@@ -81,3 +81,14 @@ def set_config(key, value):
     cache.delete_memoized(_get_config, key)

     return config


+def import_in_progress():
+    import_status = cache.get(key="import_status")
+    import_error = cache.get(key="import_error")
+    if import_error:
+        return False
+    elif import_status:
+        return True
+    else:
+        return False
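
One subtlety in this helper: a recorded import_error takes precedence over import_status, so a failed import unlocks the instance again while the error remains visible on the status page. A minimal check of those semantics, as a sketch that assumes a CTFd application context:

from CTFd import create_app
from CTFd.cache import cache
from CTFd.utils import import_in_progress

app = create_app()
with app.app_context():
    cache.set("import_status", "started", timeout=60)
    assert import_in_progress() is True  # site is locked while running

    # An error wins over the status, so the instance unlocks and the
    # admin can reach /admin/import to read what went wrong.
    cache.set("import_error", "Exception: db folder is missing", timeout=60)
    assert import_in_progress() is False

    cache.delete("import_error")  # clean up the sketch's test keys
    cache.delete("import_status")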


@@ -5,8 +5,10 @@ import re
 import tempfile
 import zipfile
 from io import BytesIO
+from multiprocessing import Process

 import dataset
+from flask import copy_current_request_context
 from flask import current_app as app
 from flask_migrate import upgrade as migration_upgrade
 from sqlalchemy.engine.url import make_url
@@ -21,6 +23,7 @@ from CTFd.plugins import get_plugin_names
 from CTFd.plugins.migrations import current as plugin_current
 from CTFd.plugins.migrations import upgrade as plugin_upgrade
 from CTFd.utils import get_app_config, set_config, string_types
+from CTFd.utils.dates import unix_time
 from CTFd.utils.exports.freeze import freeze_export
 from CTFd.utils.migrations import (
     create_database,
@@ -82,7 +85,18 @@ def export_ctf():
 def import_ctf(backup, erase=True):
+    cache_timeout = 604800  # 604800 is 1 week in seconds
+
+    def set_error(val):
+        cache.set(key="import_error", value=val, timeout=cache_timeout)
+        print(val)
+
+    def set_status(val):
+        cache.set(key="import_status", value=val, timeout=cache_timeout)
+        print(val)
+
     if not zipfile.is_zipfile(backup):
+        set_error("zipfile.BadZipfile: zipfile is invalid")
         raise zipfile.BadZipfile

     backup = zipfile.ZipFile(backup)
@@ -92,15 +106,18 @@ def import_ctf(backup, erase=True):
     for f in members:
         if f.startswith("/") or ".." in f:
             # Abort on malicious zip files
+            set_error("zipfile.BadZipfile: zipfile is malicious")
             raise zipfile.BadZipfile

         info = backup.getinfo(f)
         if max_content_length:
             if info.file_size > max_content_length:
+                set_error("zipfile.LargeZipFile: zipfile is too large")
                 raise zipfile.LargeZipFile

     # Get list of directories in zipfile
     member_dirs = [os.path.split(m)[0] for m in members if "/" in m]
     if "db" not in member_dirs:
+        set_error("Exception: db folder is missing")
         raise Exception(
             'CTFd couldn\'t find the "db" folder in this backup. '
             "The backup may be malformed or corrupted and the import process cannot continue."
@@ -110,6 +127,7 @@ def import_ctf(backup, erase=True):
         alembic_version = json.loads(backup.open("db/alembic_version.json").read())
         alembic_version = alembic_version["results"][0]["version_num"]
     except Exception:
+        set_error("Exception: Could not determine appropriate database version")
         raise Exception(
             "Could not determine appropriate database version. This backup cannot be automatically imported."
         )
@@ -130,15 +148,26 @@ def import_ctf(backup, erase=True):
         "dab615389702",
         "e62fd69bd417",
     ):
+        set_error(
+            "Exception: The version of CTFd that this backup is from is too old to be automatically imported."
+        )
         raise Exception(
             "The version of CTFd that this backup is from is too old to be automatically imported."
         )

+    start_time = unix_time(datetime.datetime.utcnow())
+    cache.set(key="import_start_time", value=start_time, timeout=cache_timeout)
+    cache.set(key="import_end_time", value=None, timeout=cache_timeout)
+
+    set_status("started")
+
     sqlite = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("sqlite")
     postgres = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("postgres")
     mysql = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("mysql")

     if erase:
+        set_status("erasing")
         # Clear out existing connections to release any locks
         db.session.close()
         db.engine.dispose()
@@ -165,10 +194,12 @@ def import_ctf(backup, erase=True):
         create_database()

         # We explicitly do not want to upgrade or stamp here.
         # The import will have this information.
+        set_status("erased")

     side_db = dataset.connect(get_app_config("SQLALCHEMY_DATABASE_URI"))

     try:
+        set_status("disabling foreign key checks")
         if postgres:
             side_db.query("SET session_replication_role=replica;")
         else:
@@ -213,6 +244,7 @@ def import_ctf(backup, erase=True):
     # insertion between official database tables and plugin tables
     def insertion(table_filenames):
         for member in table_filenames:
+            set_status(f"inserting {member}")
             if member.startswith("db/"):
                 table_name = member[3:-5]
@@ -226,7 +258,9 @@ def import_ctf(backup, erase=True):
                 table = side_db[table_name]
                 saved = json.loads(data)
-                for entry in saved["results"]:
+                count = saved["count"]
+                for i, entry in enumerate(saved["results"]):
+                    set_status(f"inserting {member} {i}/{count}")
                     # This is a hack to get SQLite to properly accept datetime values from dataset
                     # See Issue #246
                     if sqlite:
@@ -295,6 +329,9 @@ def import_ctf(backup, erase=True):
                             )
                             side_db.engine.execute(query)
                         else:
+                            set_error(
+                                f"Exception: Table name {table_name} contains quotes"
+                            )
                             raise Exception(
                                 "Table name {table_name} contains quotes".format(
                                     table_name=table_name
@@ -302,12 +339,15 @@ def import_ctf(backup, erase=True):
                             )

     # Insert data from official tables
+    set_status("inserting tables")
     insertion(first)

     # Create tables created by plugins
     # Run plugin migrations
+    set_status("inserting plugins")
     plugins = get_plugin_names()
     for plugin in plugins:
+        set_status(f"inserting plugin {plugin}")
         revision = plugin_current(plugin_name=plugin)
         plugin_upgrade(plugin_name=plugin, revision=revision, lower=None)
@@ -320,6 +360,7 @@ def import_ctf(backup, erase=True):
         plugin_upgrade(plugin_name=plugin)

     # Extracting files
+    set_status("uploading files")
     files = [f for f in backup.namelist() if f.startswith("uploads/")]
     uploader = get_uploader()
     for f in files:
@@ -335,6 +376,7 @@ def import_ctf(backup, erase=True):
             uploader.store(fileobj=source, filename=filename)

     # Alembic sqlite support is lacking so we should just create_all anyway
+    set_status("running head migrations")
     if sqlite:
         app.db.create_all()
         stamp_latest_revision()
@@ -345,6 +387,7 @@ def import_ctf(backup, erase=True):
         app.db.create_all()

     try:
+        set_status("reenabling foreign key checks")
         if postgres:
             side_db.query("SET session_replication_role=DEFAULT;")
         else:
@@ -353,8 +396,26 @@ def import_ctf(backup, erase=True):
         print("Failed to enable foreign key checks. Continuing.")

     # Invalidate all cached data
+    set_status("clearing caches")
     cache.clear()

     # Set default theme in case the current instance or the import does not provide it
     set_config("ctf_theme", DEFAULT_THEME)
     set_config("ctf_version", CTFD_VERSION)
+
+    # Set config variables to mark import completed
+    cache.set(key="import_start_time", value=start_time, timeout=cache_timeout)
+    cache.set(
+        key="import_end_time",
+        value=unix_time(datetime.datetime.utcnow()),
+        timeout=cache_timeout,
+    )
+
+
+def background_import_ctf(backup):
+    @copy_current_request_context
+    def ctx_bridge():
+        import_ctf(backup)
+
+    p = Process(target=ctx_bridge)
+    p.start()
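
background_import_ctf is the piece that ties this file together: copy_current_request_context captures the live request (so import_ctf can keep reading the uploaded file and other request-bound state), and multiprocessing.Process moves the work out of the web worker. Note also that import_ctf re-writes import_start_time after cache.clear(), since clearing the cache would otherwise wipe the progress keys themselves. Below is the same pattern in isolation, as a sketch under two assumptions: a fork-style process start method (the decorated closure is inherited by the child, not pickled), and a cache backend shared between the child and the web workers so the status keys are visible everywhere.

from multiprocessing import Process

from flask import Flask, copy_current_request_context, request

app = Flask(__name__)


def long_task(data):
    ...  # stand-in for import_ctf(backup)


@app.route("/kick", methods=["POST"])
def kick():
    data = request.get_data()

    # Copy the live request context into the closure so code running in
    # the child can use flask.request / current_app as if it were still
    # inside this view.
    @copy_current_request_context
    def bridge():
        long_task(data)

    Process(target=bridge).start()  # returns at once; work continues outside
    return "scheduled", 202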


@@ -10,7 +10,7 @@ from werkzeug.middleware.dispatcher import DispatcherMiddleware
 from CTFd.cache import clear_user_recent_ips
 from CTFd.exceptions import UserNotFoundException, UserTokenExpiredException
 from CTFd.models import Tracking, db
-from CTFd.utils import config, get_config, markdown
+from CTFd.utils import config, get_config, import_in_progress, markdown
 from CTFd.utils.config import (
     can_send_mail,
     ctf_logo,
@@ -208,6 +208,12 @@ def init_request_processors(app):
         if request.endpoint == "views.themes":
             return

+        if import_in_progress():
+            if request.endpoint == "admin.import_ctf":
+                return
+            else:
+                abort(403, description="Import currently in progress")
+
         if authed():
             user_ips = get_current_user_recent_ips()
             ip = get_ip()
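
With import_in_progress() wired into the global before_request hook, every endpoint except the import status page (and theme assets, exempted just above) answers 403 for the duration of an import. A hypothetical check of that lockout, assuming an already-configured CTFd instance:

from CTFd import create_app
from CTFd.cache import cache

app = create_app()
with app.app_context():
    cache.set("import_status", "inserting tables", timeout=60)

with app.test_client() as client:
    assert client.get("/challenges").status_code == 403    # public page blocked
    assert client.get("/admin/import").status_code != 403  # status page exempt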