Import backup improvements (#2078)

* Add progress tracking to import_ctf
* Make imports happen in the background so that we can see status
* Add GET /admin/import to see status of import
* Disable the public interface during imports
* Closes #1980
This commit is contained in:
Kevin Chung
2022-04-08 16:52:04 -04:00
committed by GitHub
parent 0c6e28315c
commit f24f2a18bb
8 changed files with 151 additions and 25 deletions

View File

@@ -45,9 +45,8 @@ from CTFd.utils import config as ctf_config
from CTFd.utils import get_config, set_config
from CTFd.utils.csv import dump_csv, load_challenges_csv, load_teams_csv, load_users_csv
from CTFd.utils.decorators import admins_only
from CTFd.utils.exports import background_import_ctf
from CTFd.utils.exports import export_ctf as export_ctf_util
from CTFd.utils.exports import import_ctf as import_ctf_util
from CTFd.utils.helpers import get_errors
from CTFd.utils.security.auth import logout_user
from CTFd.utils.uploads import delete_file
from CTFd.utils.user import is_admin
@@ -88,21 +87,25 @@ def plugin(plugin):
return "1"
@admin.route("/admin/import", methods=["POST"])
@admin.route("/admin/import", methods=["GET", "POST"])
@admins_only
def import_ctf():
if request.method == "GET":
start_time = cache.get("import_start_time")
end_time = cache.get("import_end_time")
import_status = cache.get("import_status")
import_error = cache.get("import_error")
return render_template(
"admin/import.html",
start_time=start_time,
end_time=end_time,
import_status=import_status,
import_error=import_error,
)
elif request.method == "POST":
backup = request.files["backup"]
errors = get_errors()
try:
import_ctf_util(backup)
except Exception as e:
print(e)
errors.append(repr(e))
if errors:
return errors[0], 500
else:
return redirect(url_for("admin.config"))
background_import_ctf(backup)
return redirect(url_for("admin.import_ctf"))
@admin.route("/admin/export", methods=["GET", "POST"])

View File

@@ -359,12 +359,7 @@ function importConfig(event) {
target: pg,
width: 100
});
setTimeout(function() {
pg.modal("hide");
}, 500);
setTimeout(function() {
window.location.reload();
}, 700);
location.href = CTFd.config.urlRoot + "/admin/import";
}
});
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,50 @@
{% extends "admin/base.html" %}
{% block content %}
<div class="jumbotron">
<div class="container">
<h1>Import Status</h1>
</div>
</div>
<div class="container">
<div class="row">
<div class="col-md-6 offset-md-3">
<p>
<b>Start Time:</b> <span id="start-time">{{ start_time }}</span>
</p>
{% if end_time %}
<p>
<b>End Time:</b> <span id="end-time">{{ end_time }}</span>
</p>
{% endif %}
{% if import_error %}
<p>
<b>Import Error:</b> {{ import_error }}
</p>
{% else %}
<p>
<b>Current Status:</b> {{ import_status }}
</p>
{% endif %}
</div>
</div>
</div>
{% endblock %}
{% block scripts %}
<script>
let start_time = "{{ start_time | tojson }}";
let end_time = "{{ end_time | tojson }}";
let start = document.getElementById("start-time");
start.innerText = new Date(parseInt(start_time) * 1000);
let end = document.getElementById("end-time");
end.innerText = new Date(parseInt(end_time) * 1000);
// Reload every 5 seconds to poll import status
if (!end_time) {
setTimeout(function(){
window.location.reload();
}, 5000);
}
</script>
{% endblock %}

View File

@@ -81,3 +81,14 @@ def set_config(key, value):
cache.delete_memoized(_get_config, key)
return config
def import_in_progress():
import_status = cache.get(key="import_status")
import_error = cache.get(key="import_error")
if import_error:
return False
elif import_status:
return True
else:
return False

View File

@@ -5,8 +5,10 @@ import re
import tempfile
import zipfile
from io import BytesIO
from multiprocessing import Process
import dataset
from flask import copy_current_request_context
from flask import current_app as app
from flask_migrate import upgrade as migration_upgrade
from sqlalchemy.engine.url import make_url
@@ -21,6 +23,7 @@ from CTFd.plugins import get_plugin_names
from CTFd.plugins.migrations import current as plugin_current
from CTFd.plugins.migrations import upgrade as plugin_upgrade
from CTFd.utils import get_app_config, set_config, string_types
from CTFd.utils.dates import unix_time
from CTFd.utils.exports.freeze import freeze_export
from CTFd.utils.migrations import (
create_database,
@@ -82,7 +85,18 @@ def export_ctf():
def import_ctf(backup, erase=True):
cache_timeout = 604800 # 604800 is 1 week in seconds
def set_error(val):
cache.set(key="import_error", value=val, timeout=cache_timeout)
print(val)
def set_status(val):
cache.set(key="import_status", value=val, timeout=cache_timeout)
print(val)
if not zipfile.is_zipfile(backup):
set_error("zipfile.BadZipfile: zipfile is invalid")
raise zipfile.BadZipfile
backup = zipfile.ZipFile(backup)
@@ -92,15 +106,18 @@ def import_ctf(backup, erase=True):
for f in members:
if f.startswith("/") or ".." in f:
# Abort on malicious zip files
set_error("zipfile.BadZipfile: zipfile is malicious")
raise zipfile.BadZipfile
info = backup.getinfo(f)
if max_content_length:
if info.file_size > max_content_length:
set_error("zipfile.LargeZipFile: zipfile is too large")
raise zipfile.LargeZipFile
# Get list of directories in zipfile
member_dirs = [os.path.split(m)[0] for m in members if "/" in m]
if "db" not in member_dirs:
set_error("Exception: db folder is missing")
raise Exception(
'CTFd couldn\'t find the "db" folder in this backup. '
"The backup may be malformed or corrupted and the import process cannot continue."
@@ -110,6 +127,7 @@ def import_ctf(backup, erase=True):
alembic_version = json.loads(backup.open("db/alembic_version.json").read())
alembic_version = alembic_version["results"][0]["version_num"]
except Exception:
set_error("Exception: Could not determine appropriate database version")
raise Exception(
"Could not determine appropriate database version. This backup cannot be automatically imported."
)
@@ -130,15 +148,26 @@ def import_ctf(backup, erase=True):
"dab615389702",
"e62fd69bd417",
):
set_error(
"Exception: The version of CTFd that this backup is from is too old to be automatically imported."
)
raise Exception(
"The version of CTFd that this backup is from is too old to be automatically imported."
)
start_time = unix_time(datetime.datetime.utcnow())
cache.set(key="import_start_time", value=start_time, timeout=cache_timeout)
cache.set(key="import_end_time", value=None, timeout=cache_timeout)
set_status("started")
sqlite = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("sqlite")
postgres = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("postgres")
mysql = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("mysql")
if erase:
set_status("erasing")
# Clear out existing connections to release any locks
db.session.close()
db.engine.dispose()
@@ -165,10 +194,12 @@ def import_ctf(backup, erase=True):
create_database()
# We explicitly do not want to upgrade or stamp here.
# The import will have this information.
set_status("erased")
side_db = dataset.connect(get_app_config("SQLALCHEMY_DATABASE_URI"))
try:
set_status("disabling foreign key checks")
if postgres:
side_db.query("SET session_replication_role=replica;")
else:
@@ -213,6 +244,7 @@ def import_ctf(backup, erase=True):
# insertion between official database tables and plugin tables
def insertion(table_filenames):
for member in table_filenames:
set_status(f"inserting {member}")
if member.startswith("db/"):
table_name = member[3:-5]
@@ -226,7 +258,9 @@ def import_ctf(backup, erase=True):
table = side_db[table_name]
saved = json.loads(data)
for entry in saved["results"]:
count = saved["count"]
for i, entry in enumerate(saved["results"]):
set_status(f"inserting {member} {i}/{count}")
# This is a hack to get SQLite to properly accept datetime values from dataset
# See Issue #246
if sqlite:
@@ -295,6 +329,9 @@ def import_ctf(backup, erase=True):
)
side_db.engine.execute(query)
else:
set_error(
f"Exception: Table name {table_name} contains quotes"
)
raise Exception(
"Table name {table_name} contains quotes".format(
table_name=table_name
@@ -302,12 +339,15 @@ def import_ctf(backup, erase=True):
)
# Insert data from official tables
set_status("inserting tables")
insertion(first)
# Create tables created by plugins
# Run plugin migrations
set_status("inserting plugins")
plugins = get_plugin_names()
for plugin in plugins:
set_status(f"inserting plugin {plugin}")
revision = plugin_current(plugin_name=plugin)
plugin_upgrade(plugin_name=plugin, revision=revision, lower=None)
@@ -320,6 +360,7 @@ def import_ctf(backup, erase=True):
plugin_upgrade(plugin_name=plugin)
# Extracting files
set_status("uploading files")
files = [f for f in backup.namelist() if f.startswith("uploads/")]
uploader = get_uploader()
for f in files:
@@ -335,6 +376,7 @@ def import_ctf(backup, erase=True):
uploader.store(fileobj=source, filename=filename)
# Alembic sqlite support is lacking so we should just create_all anyway
set_status("running head migrations")
if sqlite:
app.db.create_all()
stamp_latest_revision()
@@ -345,6 +387,7 @@ def import_ctf(backup, erase=True):
app.db.create_all()
try:
set_status("reenabling foreign key checks")
if postgres:
side_db.query("SET session_replication_role=DEFAULT;")
else:
@@ -353,8 +396,26 @@ def import_ctf(backup, erase=True):
print("Failed to enable foreign key checks. Continuing.")
# Invalidate all cached data
set_status("clearing caches")
cache.clear()
# Set default theme in case the current instance or the import does not provide it
set_config("ctf_theme", DEFAULT_THEME)
set_config("ctf_version", CTFD_VERSION)
# Set config variables to mark import completed
cache.set(key="import_start_time", value=start_time, timeout=cache_timeout)
cache.set(
key="import_end_time",
value=unix_time(datetime.datetime.utcnow()),
timeout=cache_timeout,
)
def background_import_ctf(backup):
@copy_current_request_context
def ctx_bridge():
import_ctf(backup)
p = Process(target=ctx_bridge)
p.start()

View File

@@ -10,7 +10,7 @@ from werkzeug.middleware.dispatcher import DispatcherMiddleware
from CTFd.cache import clear_user_recent_ips
from CTFd.exceptions import UserNotFoundException, UserTokenExpiredException
from CTFd.models import Tracking, db
from CTFd.utils import config, get_config, markdown
from CTFd.utils import config, get_config, import_in_progress, markdown
from CTFd.utils.config import (
can_send_mail,
ctf_logo,
@@ -208,6 +208,12 @@ def init_request_processors(app):
if request.endpoint == "views.themes":
return
if import_in_progress():
if request.endpoint == "admin.import_ctf":
return
else:
abort(403, description="Import currently in progress")
if authed():
user_ips = get_current_user_recent_ips()
ip = get_ip()