Fix issues with backup importing (#2092)

* Closes #2087
* Use `python manage.py import_ctf` instead of a new `Process` to import backups from the Admin Panel.
    * This avoids a number of issues with gevent and the web server's forking/threading models.
* Add a `--delete_import_on_finish` flag to `python manage.py import_ctf` (usage example below).
* Fix an issue where the `field_entries` table could not be imported when moving between MySQL and MariaDB.
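
For example, a freshly uploaded backup can be imported and then cleaned up in one step (the backup path here is illustrative):

```bash
python manage.py import_ctf /path/to/backup.zip --delete_import_on_finish
```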
Author: Kevin Chung
Committed by: GitHub
Date: 2022-04-17 18:28:30 -04:00
Parent: 90e81d7298
Commit: 9ac0bbba6c

5 changed files with 89 additions and 17 deletions

File: `CTFd/themes/admin/templates/import.html`

```diff
@@ -21,10 +21,17 @@
       <p>
         <b>Import Error:</b> {{ import_error }}
       </p>
+      <div class="alert alert-danger" role="alert">
+        An error occurred during the import. Please try again.
+      </div>
     {% else %}
       <p>
         <b>Current Status:</b> {{ import_status }}
       </p>
+      <div class="alert alert-secondary" role="alert">
+        Page will redirect upon completion. Refresh page to get latest status.<br>
+        Page will automatically refresh every 5 seconds.
+      </div>
     {% endif %}
   </div>
 </div>
```
```diff
@@ -33,18 +40,19 @@
 {% block scripts %}
 <script>
+  // Reload every 5 seconds to poll import status
+  setTimeout(function(){
+    window.location.reload();
+  }, 5000);
+
   let start_time = "{{ start_time | tojson }}";
   let end_time = "{{ end_time | tojson }}";
   let start = document.getElementById("start-time");
   start.innerText = new Date(parseInt(start_time) * 1000);

-  let end = document.getElementById("end-time");
-  end.innerText = new Date(parseInt(end_time) * 1000);
-
-  // Reload every 5 seconds to poll import status
-  if (!end_time) {
-    setTimeout(function(){
-      window.location.reload();
-    }, 5000);
+  if (end_time !== "null") {
+    let end = document.getElementById("end-time");
+    end.innerText = new Date(parseInt(end_time) * 1000);
   }
 </script>
 {% endblock %}
```

File: `CTFd/utils/exports/__init__.py`

```diff
@@ -2,13 +2,14 @@ import datetime
 import json
 import os
 import re
+import subprocess  # nosec B404
+import sys
 import tempfile
 import zipfile
 from io import BytesIO
-from multiprocessing import Process
+from pathlib import Path

 import dataset
-from flask import copy_current_request_context
 from flask import current_app as app
 from flask_migrate import upgrade as migration_upgrade
 from sqlalchemy.engine.url import make_url
```
```diff
@@ -24,6 +25,7 @@ from CTFd.plugins.migrations import current as plugin_current
 from CTFd.plugins.migrations import upgrade as plugin_upgrade
 from CTFd.utils import get_app_config, set_config, string_types
 from CTFd.utils.dates import unix_time
+from CTFd.utils.exports.databases import is_database_mariadb
 from CTFd.utils.exports.freeze import freeze_export
 from CTFd.utils.migrations import (
     create_database,
```
@@ -95,6 +97,10 @@ def import_ctf(backup, erase=True):
cache.set(key="import_status", value=val, timeout=cache_timeout) cache.set(key="import_status", value=val, timeout=cache_timeout)
print(val) print(val)
# Reset import cache keys and don't print these values
cache.set(key="import_error", value=None, timeout=cache_timeout)
cache.set(key="import_status", value=None, timeout=cache_timeout)
if not zipfile.is_zipfile(backup): if not zipfile.is_zipfile(backup):
set_error("zipfile.BadZipfile: zipfile is invalid") set_error("zipfile.BadZipfile: zipfile is invalid")
raise zipfile.BadZipfile raise zipfile.BadZipfile
```diff
@@ -165,6 +171,7 @@ def import_ctf(backup, erase=True):
     sqlite = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("sqlite")
     postgres = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("postgres")
     mysql = get_app_config("SQLALCHEMY_DATABASE_URI").startswith("mysql")
+    mariadb = is_database_mariadb()

     if erase:
         set_status("erasing")
```
```diff
@@ -258,7 +265,7 @@ def import_ctf(backup, erase=True):
                 table = side_db[table_name]
                 saved = json.loads(data)
-                count = saved["count"]
+                count = len(saved["results"])
                 for i, entry in enumerate(saved["results"]):
                     set_status(f"inserting {member} {i}/{count}")
                     # This is a hack to get SQLite to properly accept datetime values from dataset
```
```diff
@@ -306,6 +313,23 @@ def import_ctf(backup, erase=True):
                     if requirements and isinstance(requirements, string_types):
                         entry["requirements"] = json.loads(requirements)

+                    # From v3.1.0 to v3.5.0 FieldEntries could have been varying levels of JSON'ified strings.
+                    # For example "\"test\"" vs "test". This results in issues with importing backups between
+                    # databases. Specifically between MySQL and MariaDB. Because CTFd standardizes against MySQL
+                    # we need to have an edge case here.
+                    if member == "db/field_entries.json":
+                        value = entry.get("value")
+                        if value:
+                            try:
+                                # Attempt to convert anything to its original Python value
+                                entry["value"] = str(json.loads(value))
+                            except (json.JSONDecodeError, TypeError):
+                                pass
+                            finally:
+                                # Dump the value back into JSON if it's MariaDB; otherwise skip the conversion
+                                if mariadb:
+                                    entry["value"] = json.dumps(entry["value"])
+
                     try:
                         table.insert(entry)
                     except ProgrammingError:
```
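
As a standalone illustration of that normalization (the values are hypothetical), both the double-encoded and the plain form collapse to the same stored JSON string:

```python
import json

# Both shapes seen in v3.1.0-v3.5.0 backups end up as the same value.
for value in ['"test"', "test"]:
    try:
        value = str(json.loads(value))  # '"test"' decodes to 'test'
    except (json.JSONDecodeError, TypeError):
        pass  # "test" is not valid JSON; keep it as-is
    # On MariaDB the value is re-encoded so the JSON column gets real JSON
    print(json.dumps(value))  # prints "test" for both inputs
```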
@@ -413,9 +437,11 @@ def import_ctf(backup, erase=True):
def background_import_ctf(backup): def background_import_ctf(backup):
@copy_current_request_context # The manage.py script will delete the backup for us
def ctx_bridge(): f = tempfile.NamedTemporaryFile(delete=False)
import_ctf(backup) backup.save(f.name)
python = sys.executable # Get path of Python interpreter
p = Process(target=ctx_bridge) manage_py = Path(app.root_path).parent / "manage.py" # Path to manage.py
p.start() subprocess.Popen( # nosec B603
[python, manage_py, "import_ctf", "--delete_import_on_finish", f.name]
)

File: `CTFd/utils/exports/databases.py` (new file)

```diff
@@ -0,0 +1,12 @@
+from sqlalchemy.exc import OperationalError
+
+from CTFd.models import db
+
+
+def is_database_mariadb():
+    try:
+        result = db.session.execute("SELECT version()").fetchone()[0]
+        mariadb = "mariadb" in result.lower()
+    except OperationalError:
+        mariadb = False
+    return mariadb
```
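
The check works because MariaDB embeds its name in the version string while MySQL does not; the version values below are illustrative:

```python
# Representative "SELECT version()" results (illustrative):
#   MariaDB: "10.6.12-MariaDB-1:10.6.12+maria~ubu2004"
#   MySQL:   "8.0.32"
print("mariadb" in "10.6.12-MariaDB-1:10.6.12+maria~ubu2004".lower())  # True
print("mariadb" in "8.0.32".lower())                                   # False
```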

File: `CTFd/utils/exports/serializers.py`

```diff
@@ -4,6 +4,7 @@ from datetime import date, datetime
 from decimal import Decimal

 from CTFd.utils import string_types
+from CTFd.utils.exports.databases import is_database_mariadb


 class JSONEncoder(json.JSONEncoder):
```
```diff
@@ -35,6 +36,7 @@ class JSONSerializer(object):
         return result

     def close(self):
+        mariadb = is_database_mariadb()
         for _path, result in self.buckets.items():
             result = self.wrap(result)
```
```diff
@@ -42,6 +44,7 @@ class JSONSerializer(object):
             # Before emitting a file we should standardize to valid JSON (i.e. a dict)
             # See Issue #973
             for i, r in enumerate(result["results"]):
+                # Handle JSON used in tables that use requirements
                 data = r.get("requirements")
                 if data:
                     try:
```

```diff
@@ -50,5 +53,22 @@ class JSONSerializer(object):
                     except ValueError:
                         pass

+                # Handle JSON used in FieldEntries table
+                if mariadb:
+                    if sorted(r.keys()) == [
+                        "field_id",
+                        "id",
+                        "team_id",
+                        "type",
+                        "user_id",
+                        "value",
+                    ]:
+                        value = r.get("value")
+                        if value:
+                            try:
+                                result["results"][i]["value"] = json.loads(value)
+                            except ValueError:
+                                pass
+
             data = json.dumps(result, cls=JSONEncoder, separators=(",", ":"))
             self.fileobj.write(data.encode("utf-8"))
```
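
To make the export/import symmetry concrete, here is a sketch (row contents hypothetical) of what that MariaDB branch does to a FieldEntries row before it is written to the export:

```python
import json

# Hypothetical FieldEntries row as read back from a MariaDB JSON column:
# the value arrives as a JSON-encoded string rather than a native value.
row = {
    "id": 1, "field_id": 2, "user_id": 3, "team_id": None,
    "type": "user", "value": '"alice@example.com"',
}
# Decoding it here means the backup matches what a MySQL export contains.
row["value"] = json.loads(row["value"])
print(row["value"])  # alice@example.com
```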

File: `manage.py`

```diff
@@ -1,6 +1,8 @@
 import datetime
 import shutil
+from pathlib import Path

 from flask_migrate import MigrateCommand
 from flask_script import Manager
```
```diff
@@ -71,10 +73,14 @@ def export_ctf(path=None):
 @manager.command
-def import_ctf(path):
+def import_ctf(path, delete_import_on_finish=False):
     with app.app_context():
         import_ctf_util(path)

+    if delete_import_on_finish:
+        print(f"Deleting {path}")
+        Path(path).unlink()
+

 if __name__ == "__main__":
     manager.run()
```