1406 plugin migrations improvements (#1420)

* Handle plugin migrations during CTF import
* Closes #1406
Author: Kevin Chung
Date: 2020-05-19 21:21:31 -04:00 (committed by GitHub)
Commit: 5618f0d04c
Parent: 148bdccf26
3 changed files with 167 additions and 113 deletions

CTFd/plugins/__init__.py

@@ -161,6 +161,17 @@ def bypass_csrf_protection(f):
     return f
 
 
+def get_plugin_names():
+    modules = sorted(glob.glob(app.plugins_dir + "/*"))
+    blacklist = {"__pycache__"}
+    plugins = []
+    for module in modules:
+        module_name = os.path.basename(module)
+        if os.path.isdir(module) and module_name not in blacklist:
+            plugins.append(module_name)
+    return plugins
+
+
 def init_plugins(app):
     """
     Searches for the load function in modules in the CTFd/plugins folder. This function is called with the current CTFd
@@ -179,12 +190,8 @@ def init_plugins(app):
     app.plugins_dir = os.path.dirname(__file__)
 
     if app.config.get("SAFE_MODE", False) is False:
-        modules = sorted(glob.glob(os.path.dirname(__file__) + "/*"))
-        blacklist = {"__pycache__"}
-        for module in modules:
-            module_name = os.path.basename(module)
-            if os.path.isdir(module) and module_name not in blacklist:
-                module = "." + module_name
+        for plugin in get_plugin_names():
+            module = "." + plugin
             module = importlib.import_module(module, package="CTFd.plugins")
             module.load(app)
             print(" * Loaded module, %s" % module)

CTFd/plugins/migrations.py

@@ -11,7 +11,19 @@ from sqlalchemy import create_engine, pool
 from CTFd.utils import get_config, set_config
 
 
-def upgrade(plugin_name=None):
+def current(plugin_name=None):
+    if plugin_name is None:
+        # Get the directory name of the plugin if unspecified
+        # Doing it this way doesn't waste the rest of the inspect.stack call
+        frame = inspect.currentframe()
+        caller_info = inspect.getframeinfo(frame.f_back)
+        caller_path = caller_info[0]
+        plugin_name = os.path.basename(os.path.dirname(caller_path))
+
+    return get_config(plugin_name + "_alembic_version")
+
+
+def upgrade(plugin_name=None, revision=None):
     database_url = current_app.config.get("SQLALCHEMY_DATABASE_URI")
     if database_url.startswith("sqlite"):
         current_app.db.create_all()
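The plugin_name inference in current() (and in upgrade() below) relies on the caller's frame: inspect.currentframe().f_back points at the calling module, and the plugin name is taken from that file's parent directory, which avoids building the full inspect.stack(). A minimal standalone sketch of that lookup, with hypothetical names:

```python
import inspect
import os


def caller_plugin_name():
    # One frame up is the module that called us; getframeinfo(...)[0] is its filename.
    frame = inspect.currentframe()
    caller_info = inspect.getframeinfo(frame.f_back)
    caller_path = caller_info[0]
    return os.path.basename(os.path.dirname(caller_path))


# If this were called from CTFd/plugins/my_plugin/__init__.py (a made-up plugin),
# it would return "my_plugin", which keys the stored "my_plugin_alembic_version" config value.
```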
@@ -25,6 +37,13 @@ def upgrade(plugin_name=None):
         caller_path = caller_info[0]
         plugin_name = os.path.basename(os.path.dirname(caller_path))
 
+    # Check if the plugin has migrations
+    migrations_path = os.path.join(current_app.plugins_dir, plugin_name, "migrations")
+    if os.path.isdir(migrations_path) is False:
+        # Create any tables that the plugin may have
+        current_app.db.create_all()
+        return
+
     engine = create_engine(database_url, poolclass=pool.NullPool)
     conn = engine.connect()
     context = MigrationContext.configure(conn)
@@ -32,19 +51,18 @@ def upgrade(plugin_name=None):
     # Find the list of migrations to run
     config = Config()
-    config.set_main_option(
-        "script_location",
-        os.path.join(current_app.plugins_dir, plugin_name, "migrations"),
-    )
-    config.set_main_option(
-        "version_locations",
-        os.path.join(current_app.plugins_dir, plugin_name, "migrations"),
-    )
+    config.set_main_option("script_location", migrations_path)
+    config.set_main_option("version_locations", migrations_path)
     script = ScriptDirectory.from_config(config)
 
     # get current revision for plugin
     lower = get_config(plugin_name + "_alembic_version")
+
+    # Do we upgrade to head or to a specific revision
+    if revision is None:
         upper = script.get_current_head()
+    else:
+        upper = revision
 
     # Apply from lower to upper
     revs = list(script.iterate_revisions(lower=lower, upper=upper))
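As a usage sketch, a plugin that ships an alembic migrations/ directory next to its __init__.py can now call upgrade() from its load() entry point; the plugin name, config key, and file path below are hypothetical:

```python
# CTFd/plugins/my_plugin/__init__.py  (made-up plugin)
from CTFd.plugins.migrations import upgrade


def load(app):
    # With no arguments, upgrade() infers "my_plugin" from this file's directory,
    # uses the stored "my_plugin_alembic_version" config value as the lower bound,
    # and applies this plugin's revisions up to head (or up to an explicit
    # revision=... if one is passed). If the plugin has no migrations/ directory,
    # it falls back to db.create_all() for the plugin's tables.
    upgrade()
```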

CTFd/utils/exports/__init__.py

@@ -9,13 +9,15 @@ import dataset
 import six
 from alembic.util import CommandError
 from flask import current_app as app
-from flask_migrate import upgrade
+from flask_migrate import upgrade as migration_upgrade
 from sqlalchemy.exc import OperationalError, ProgrammingError
 from sqlalchemy.sql import sqltypes
 
 from CTFd import __version__ as CTFD_VERSION
 from CTFd.cache import cache
 from CTFd.models import db, get_class_by_tablename
+from CTFd.plugins import get_plugin_names
+from CTFd.plugins.migrations import upgrade as plugin_upgrade, current as plugin_current
 from CTFd.utils import get_app_config, set_config
 from CTFd.utils.exports.freeze import freeze_export
 from CTFd.utils.migrations import (
@@ -52,7 +54,7 @@ def export_ctf():
         "results": [{"version_num": get_current_revision()}],
         "meta": {},
     }
-    result_file = six.BytesIO()
+    result_file = six.StringIO()
     json.dump(result, result_file)
     result_file.seek(0)
     backup_zip.writestr("db/alembic_version.json", result_file.read())
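The BytesIO-to-StringIO switch matters because json.dump() writes str on Python 3, so the buffer has to be a text stream before the payload is handed to ZipFile.writestr(), which accepts either str or bytes. A small standalone illustration with throwaway values:

```python
import io
import json
import zipfile

buf = io.StringIO()  # json.dump() would raise TypeError on a BytesIO under Python 3
json.dump({"results": [{"version_num": "abc123"}]}, buf)  # "abc123" is a placeholder revision
buf.seek(0)

with zipfile.ZipFile("example_backup.zip", "w") as zf:  # throwaway archive name
    zf.writestr("db/alembic_version.json", buf.read())
```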
@@ -164,26 +166,22 @@ def import_ctf(backup, erase=True):
         "db/config.json",
     ]
 
+    # We want to insert certain database tables first so we are specifying
+    # the order with a list. The leftover tables are tables that are from a
+    # plugin (more likely) or a table where we do not care about insertion order
     for item in first:
         if item in members:
             members.remove(item)
 
-    members = first + members
-
-    upgrade(revision=alembic_version)
-
-    # Create tables created by plugins
-    try:
-        app.db.create_all()
-    except OperationalError as e:
-        if not postgres:
-            raise e
-        else:
-            print("Allowing error during app.db.create_all() due to Postgres")
+    # Upgrade the database to the point in time that the import was taken from
+    migration_upgrade(revision=alembic_version)
 
     members.remove("db/alembic_version.json")
 
-    for member in members:
+    # Combine the database insertion code into a function so that we can pause
+    # insertion between official database tables and plugin tables
+    def insertion(table_filenames):
+        for member in table_filenames:
             if member.startswith("db/"):
                 table_name = member[3:-5]
@@ -216,7 +214,8 @@ def import_ctf(backup, erase=True):
                                     # If the table is expecting a datetime, we should check if the string is one and convert it
                                     if is_dt_column:
                                         match = re.match(
-                                            r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d", v
+                                            r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d",
+                                            v,
                                         )
                                         if match:
                                             entry[k] = datetime.datetime.strptime(
@@ -239,7 +238,9 @@ def import_ctf(backup, erase=True):
                             "db/awards.json",
                         ):
                             requirements = entry.get("requirements")
-                            if requirements and isinstance(requirements, six.string_types):
+                            if requirements and isinstance(
+                                requirements, six.string_types
+                            ):
                                 entry["requirements"] = json.loads(requirements)
 
                         try:
@@ -271,6 +272,34 @@ def import_ctf(backup, erase=True):
                             )
                         )
 
+    # Insert data from official tables
+    insertion(first)
+
+    # Create tables created by plugins
+    try:
+        # Run plugin migrations
+        plugins = get_plugin_names()
+        try:
+            for plugin in plugins:
+                revision = plugin_current(plugin_name=plugin)
+                plugin_upgrade(plugin_name=plugin, revision=revision)
+        finally:
+            # Create tables that don't have migrations
+            app.db.create_all()
+    except OperationalError as e:
+        if not postgres:
+            raise e
+        else:
+            print("Allowing error during app.db.create_all() due to Postgres")
+
+    # Insert data for plugin tables
+    insertion(members)
+
+    # Bring plugin tables up to head revision
+    plugins = get_plugin_names()
+    for plugin in plugins:
+        plugin_upgrade(plugin_name=plugin)
+
     # Extracting files
     files = [f for f in backup.namelist() if f.startswith("uploads/")]
     uploader = get_uploader()
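Taken together, this block gives the import a specific ordering: the core schema is upgraded to the revision recorded in the backup and the official tables are inserted first (including db/config.json, which is where the <plugin>_alembic_version values read by plugin_current() are stored); each plugin's schema is then migrated to the revision the backup recorded for it before its data is inserted, with create_all() covering plugins that ship no migrations; only after the plugin data is in are the plugin schemas brought up to head. A condensed sketch of that flow, reusing the helpers imported and defined in this file (not the literal import_ctf() code):

```python
def restore_sketch(alembic_version, first, members, plugins):
    migration_upgrade(revision=alembic_version)  # core schema -> backup's revision
    insertion(first)                             # official table data (incl. config)

    for plugin in plugins:
        # plugin_current() now reads the revision that the imported config recorded
        plugin_upgrade(plugin_name=plugin, revision=plugin_current(plugin_name=plugin))
    app.db.create_all()                          # tables for plugins without migrations

    insertion(members)                           # plugin table data
    for plugin in plugins:
        plugin_upgrade(plugin_name=plugin)       # finally, plugin schemas -> head
```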
@@ -288,7 +317,7 @@ def import_ctf(backup, erase=True):
     # Alembic sqlite support is lacking so we should just create_all anyway
     try:
-        upgrade(revision="head")
+        migration_upgrade(revision="head")
     except (OperationalError, CommandError, RuntimeError, SystemExit, Exception):
         app.db.create_all()
         stamp_latest_revision()