1406 plugin migrations improvements (#1420)
* Handle plugin migrations during CTF import
* Closes #1406
CTFd/plugins/__init__.py:

@@ -161,6 +161,17 @@ def bypass_csrf_protection(f):
     return f


+def get_plugin_names():
+    modules = sorted(glob.glob(app.plugins_dir + "/*"))
+    blacklist = {"__pycache__"}
+    plugins = []
+    for module in modules:
+        module_name = os.path.basename(module)
+        if os.path.isdir(module) and module_name not in blacklist:
+            plugins.append(module_name)
+    return plugins
+
+
 def init_plugins(app):
     """
     Searches for the load function in modules in the CTFd/plugins folder. This function is called with the current CTFd

@@ -179,15 +190,11 @@ def init_plugins(app):
     app.plugins_dir = os.path.dirname(__file__)

     if app.config.get("SAFE_MODE", False) is False:
-        modules = sorted(glob.glob(os.path.dirname(__file__) + "/*"))
-        blacklist = {"__pycache__"}
-        for module in modules:
-            module_name = os.path.basename(module)
-            if os.path.isdir(module) and module_name not in blacklist:
-                module = "." + module_name
-                module = importlib.import_module(module, package="CTFd.plugins")
-                module.load(app)
-                print(" * Loaded module, %s" % module)
+        for plugin in get_plugin_names():
+            module = "." + plugin
+            module = importlib.import_module(module, package="CTFd.plugins")
+            module.load(app)
+            print(" * Loaded module, %s" % module)

     app.jinja_env.globals.update(get_admin_plugin_menu_bar=get_admin_plugin_menu_bar)
     app.jinja_env.globals.update(get_user_page_menu_bar=get_user_page_menu_bar)
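This hunk extracts plugin discovery into get_plugin_names() so the import code can reuse it instead of duplicating the directory scan inside init_plugins(). A minimal sketch (not CTFd code) of the same scan, run against a scratch directory with hypothetical plugin names:

import glob
import os
import tempfile

# Build a scratch plugins directory: two plugin folders, a cache folder
# that should be skipped, and a plain file that is not a plugin.
plugins_dir = tempfile.mkdtemp()
for name in ("challenges_plugin", "flags_plugin", "__pycache__"):
    os.mkdir(os.path.join(plugins_dir, name))
open(os.path.join(plugins_dir, "README.md"), "w").close()

blacklist = {"__pycache__"}
plugins = []
for module in sorted(glob.glob(plugins_dir + "/*")):
    module_name = os.path.basename(module)
    # Only directories count as plugins, and blacklisted names are skipped
    if os.path.isdir(module) and module_name not in blacklist:
        plugins.append(module_name)

print(plugins)  # ['challenges_plugin', 'flags_plugin']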
CTFd/plugins/migrations.py:

@@ -11,7 +11,19 @@ from sqlalchemy import create_engine, pool
 from CTFd.utils import get_config, set_config


-def upgrade(plugin_name=None):
+def current(plugin_name=None):
+    if plugin_name is None:
+        # Get the directory name of the plugin if unspecified
+        # Doing it this way doesn't waste the rest of the inspect.stack call
+        frame = inspect.currentframe()
+        caller_info = inspect.getframeinfo(frame.f_back)
+        caller_path = caller_info[0]
+        plugin_name = os.path.basename(os.path.dirname(caller_path))
+
+    return get_config(plugin_name + "_alembic_version")
+
+
+def upgrade(plugin_name=None, revision=None):
     database_url = current_app.config.get("SQLALCHEMY_DATABASE_URI")
     if database_url.startswith("sqlite"):
         current_app.db.create_all()
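The new current() helper returns the Alembic revision stored for a plugin under the <plugin_name>_alembic_version config key. Like upgrade(), it infers the plugin name from the caller's file when none is passed, reading only the parent stack frame rather than paying for a full inspect.stack() walk. A standalone sketch of that inference (the plugin name in the comment is hypothetical):

import inspect
import os

def calling_plugin_name():
    # Read the caller's filename off the parent stack frame and take the
    # name of the directory containing it.
    frame = inspect.currentframe()
    caller_info = inspect.getframeinfo(frame.f_back)
    caller_path = os.path.abspath(caller_info[0])  # abspath only so the demo
    return os.path.basename(os.path.dirname(caller_path))  # works when run relatively

# Called from a hypothetical CTFd/plugins/my_plugin/__init__.py, this returns
# "my_plugin" without the plugin having to pass its own name around. Run from
# this file, it returns the name of the directory the file sits in.
print(calling_plugin_name())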
@@ -25,6 +37,13 @@ def upgrade(plugin_name=None):
         caller_path = caller_info[0]
         plugin_name = os.path.basename(os.path.dirname(caller_path))

+    # Check if the plugin has migrations
+    migrations_path = os.path.join(current_app.plugins_dir, plugin_name, "migrations")
+    if os.path.isdir(migrations_path) is False:
+        # Create any tables that the plugin may have
+        current_app.db.create_all()
+        return
+
     engine = create_engine(database_url, poolclass=pool.NullPool)
     conn = engine.connect()
     context = MigrationContext.configure(conn)
@@ -32,19 +51,18 @@ def upgrade(plugin_name=None):

     # Find the list of migrations to run
     config = Config()
-    config.set_main_option(
-        "script_location",
-        os.path.join(current_app.plugins_dir, plugin_name, "migrations"),
-    )
-    config.set_main_option(
-        "version_locations",
-        os.path.join(current_app.plugins_dir, plugin_name, "migrations"),
-    )
+    config.set_main_option("script_location", migrations_path)
+    config.set_main_option("version_locations", migrations_path)
     script = ScriptDirectory.from_config(config)

     # get current revision for plugin
     lower = get_config(plugin_name + "_alembic_version")
-    upper = script.get_current_head()
+
+    # Do we upgrade to head or to a specific revision
+    if revision is None:
+        upper = script.get_current_head()
+    else:
+        upper = revision

     # Apply from lower to upper
     revs = list(script.iterate_revisions(lower=lower, upper=upper))
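upgrade() now accepts a revision argument so the importer can replay a plugin's migrations only up to the revision captured in the export, rather than always migrating to head. A sketch of how the Alembic pieces used here fit together, assuming a hypothetical plugin migrations directory exists on disk:

import os
from alembic.config import Config
from alembic.script import ScriptDirectory

# Hypothetical plugin migrations directory containing Alembic revision files.
migrations_path = os.path.join("CTFd", "plugins", "my_plugin", "migrations")

config = Config()
config.set_main_option("script_location", migrations_path)
config.set_main_option("version_locations", migrations_path)
script = ScriptDirectory.from_config(config)

lower = None                       # e.g. a stored revision, or None for base
upper = script.get_current_head()  # newest revision on disk

# iterate_revisions() walks from upper down toward lower (newest first),
# so callers typically reverse the list before applying upgrades in order.
revs = list(script.iterate_revisions(lower=lower, upper=upper))
for rev in reversed(revs):
    print(rev.revision, rev.doc)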
CTFd/utils/exports/__init__.py:

@@ -9,13 +9,15 @@ import dataset
 import six
 from alembic.util import CommandError
 from flask import current_app as app
-from flask_migrate import upgrade
+from flask_migrate import upgrade as migration_upgrade
 from sqlalchemy.exc import OperationalError, ProgrammingError
 from sqlalchemy.sql import sqltypes

 from CTFd import __version__ as CTFD_VERSION
 from CTFd.cache import cache
 from CTFd.models import db, get_class_by_tablename
+from CTFd.plugins import get_plugin_names
+from CTFd.plugins.migrations import upgrade as plugin_upgrade, current as plugin_current
 from CTFd.utils import get_app_config, set_config
 from CTFd.utils.exports.freeze import freeze_export
 from CTFd.utils.migrations import (
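Both flask_migrate and CTFd.plugins.migrations expose a function named upgrade, so the imports are now aliased (migration_upgrade vs plugin_upgrade) instead of letting the second import silently shadow the first. A toy stdlib illustration of the shadowing the aliases avoid:

# Toy illustration with stdlib modules standing in for the two upgrade()s.
from os.path import join
from posixpath import join                 # silently shadows the first import!

from os.path import join as os_join
from posixpath import join as posix_join   # aliased: both names stay addressable
print(os_join("a", "b"), posix_join("a", "b"))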
@@ -52,7 +54,7 @@ def export_ctf():
         "results": [{"version_num": get_current_revision()}],
         "meta": {},
     }
-    result_file = six.BytesIO()
+    result_file = six.StringIO()
     json.dump(result, result_file)
     result_file.seek(0)
     backup_zip.writestr("db/alembic_version.json", result_file.read())
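The BytesIO-to-StringIO change matters on Python 3, where json.dump() writes str and raises TypeError against a binary buffer; ZipFile.writestr() accepts the resulting text. A minimal self-contained sketch of the corrected pattern (content hypothetical, using io directly in place of six):

import io
import json
import zipfile

result = {"count": 1, "results": [{"version_num": "abc123"}], "meta": {}}

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as backup_zip:
    # json.dump() emits str on Python 3, so a text buffer is required;
    # json.dump(result, io.BytesIO()) would raise TypeError here.
    result_file = io.StringIO()
    json.dump(result, result_file)
    result_file.seek(0)
    backup_zip.writestr("db/alembic_version.json", result_file.read())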
@@ -164,112 +166,139 @@ def import_ctf(backup, erase=True):
         "db/config.json",
     ]

+    # We want to insert certain database tables first so we are specifying
+    # the order with a list. The leftover tables are tables that are from a
+    # plugin (more likely) or a table where we do not care about insertion order
     for item in first:
         if item in members:
             members.remove(item)

-    members = first + members
-
-    upgrade(revision=alembic_version)
+    # Upgrade the database to the point in time that the import was taken from
+    migration_upgrade(revision=alembic_version)
+
+    members.remove("db/alembic_version.json")
+
+    # Combine the database insertion code into a function so that we can pause
+    # insertion between official database tables and plugin tables
+    def insertion(table_filenames):
+        for member in table_filenames:
+            if member.startswith("db/"):
+                table_name = member[3:-5]
+
+                try:
+                    # Try to open a file but skip if it doesn't exist.
+                    data = backup.open(member).read()
+                except KeyError:
+                    continue
+
+                if data:
+                    table = side_db[table_name]
+
+                    saved = json.loads(data)
+                    for entry in saved["results"]:
+                        # This is a hack to get SQLite to properly accept datetime values from dataset
+                        # See Issue #246
+                        if sqlite:
+                            direct_table = get_class_by_tablename(table.name)
+                            for k, v in entry.items():
+                                if isinstance(v, six.string_types):
+                                    # We only want to apply this hack to columns that are expecting a datetime object
+                                    try:
+                                        is_dt_column = (
+                                            type(getattr(direct_table, k).type)
+                                            == sqltypes.DateTime
+                                        )
+                                    except AttributeError:
+                                        is_dt_column = False
+
+                                    # If the table is expecting a datetime, we should check if the string is one and convert it
+                                    if is_dt_column:
+                                        match = re.match(
+                                            r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d",
+                                            v,
+                                        )
+                                        if match:
+                                            entry[k] = datetime.datetime.strptime(
+                                                v, "%Y-%m-%dT%H:%M:%S.%f"
+                                            )
+                                            continue
+                                        match = re.match(
+                                            r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", v
+                                        )
+                                        if match:
+                                            entry[k] = datetime.datetime.strptime(
+                                                v, "%Y-%m-%dT%H:%M:%S"
+                                            )
+                                            continue
+
+                        # From v2.0.0 to v2.1.0 requirements could have been a string or JSON because of a SQLAlchemy issue
+                        # This is a hack to ensure we can still accept older exports. See #867
+                        if member in (
+                            "db/challenges.json",
+                            "db/hints.json",
+                            "db/awards.json",
+                        ):
+                            requirements = entry.get("requirements")
+                            if requirements and isinstance(
+                                requirements, six.string_types
+                            ):
+                                entry["requirements"] = json.loads(requirements)
+
+                        try:
+                            table.insert(entry)
+                        except ProgrammingError:
+                            # MariaDB does not like JSON objects and prefers strings because it internally
+                            # represents JSON with LONGTEXT.
+                            # See Issue #973
+                            requirements = entry.get("requirements")
+                            if requirements and isinstance(requirements, dict):
+                                entry["requirements"] = json.dumps(requirements)
+                            table.insert(entry)
+
+                    db.session.commit()
+                    if postgres:
+                        # This command is to set the next primary key ID for the re-inserted tables in Postgres. However,
+                        # this command is very difficult to translate into SQLAlchemy code. Because Postgres is not
+                        # officially supported, no major work will go into this functionality.
+                        # https://stackoverflow.com/a/37972960
+                        if '"' not in table_name and "'" not in table_name:
+                            query = "SELECT setval(pg_get_serial_sequence('{table_name}', 'id'), coalesce(max(id)+1,1), false) FROM \"{table_name}\"".format(  # nosec
+                                table_name=table_name
+                            )
+                            side_db.engine.execute(query)
+                        else:
+                            raise Exception(
+                                "Table name {table_name} contains quotes".format(
+                                    table_name=table_name
+                                )
+                            )
+
+    # Insert data from official tables
+    insertion(first)
+
     # Create tables created by plugins
     try:
-        app.db.create_all()
+        # Run plugin migrations
+        plugins = get_plugin_names()
+        try:
+            for plugin in plugins:
+                revision = plugin_current(plugin_name=plugin)
+                plugin_upgrade(plugin_name=plugin, revision=revision)
+        finally:
+            # Create tables that don't have migrations
+            app.db.create_all()
     except OperationalError as e:
         if not postgres:
             raise e
         else:
             print("Allowing error during app.db.create_all() due to Postgres")

-    members.remove("db/alembic_version.json")
+    # Insert data for plugin tables
+    insertion(members)
+
-    for member in members:
-        if member.startswith("db/"):
-            table_name = member[3:-5]
-
-            try:
-                # Try to open a file but skip if it doesn't exist.
-                data = backup.open(member).read()
-            except KeyError:
-                continue
-
-            if data:
-                table = side_db[table_name]
-
-                saved = json.loads(data)
-                for entry in saved["results"]:

    [the rest of the removed inline loop body is line-for-line identical to the
    insertion() helper added above, one indentation level shallower: the SQLite
    datetime coercion (#246), the requirements string/JSON hack (#867), the
    MariaDB ProgrammingError retry (#973), and the Postgres sequence reset]

+    # Bring plugin tables up to head revision
+    plugins = get_plugin_names()
+    for plugin in plugins:
+        plugin_upgrade(plugin_name=plugin)

     # Extracting files
     files = [f for f in backup.namelist() if f.startswith("uploads/")]
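Inside insertion(), the SQLite branch coerces ISO-8601 strings back into datetime objects before dataset inserts them (Issue #246); the fractional-seconds pattern must be tried first, because the plain pattern also matches its prefix. A standalone sketch of that coercion (timestamps hypothetical):

import datetime
import re

def coerce_datetime(v):
    # Timestamps with fractional seconds, e.g. "2020-04-20T12:00:00.123456"
    if re.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d", v):
        return datetime.datetime.strptime(v, "%Y-%m-%dT%H:%M:%S.%f")
    # Timestamps without fractional seconds, e.g. "2020-04-20T12:00:00"
    if re.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", v):
        return datetime.datetime.strptime(v, "%Y-%m-%dT%H:%M:%S")
    return v  # not a datetime-looking string; leave untouched

print(coerce_datetime("2020-04-20T12:00:00.123456"))
print(coerce_datetime("2020-04-20T12:00:00"))
print(coerce_datetime("not-a-date"))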
@@ -288,7 +317,7 @@ def import_ctf(backup, erase=True):

     # Alembic sqlite support is lacking so we should just create_all anyway
     try:
-        upgrade(revision="head")
+        migration_upgrade(revision="head")
     except (OperationalError, CommandError, RuntimeError, SystemExit, Exception):
         app.db.create_all()
         stamp_latest_revision()
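Taken together, import_ctf() now sequences plugin schemas the same way it sequences core tables: pin each schema to the revision recorded at export time, insert the exported rows, then migrate everything to head. A condensed, runnable sketch of that ordering, with stub functions standing in for the real helpers (all names and revisions hypothetical):

# Stubs mirror flask_migrate.upgrade (migration_upgrade),
# CTFd.plugins.migrations.current/upgrade (plugin_current/plugin_upgrade),
# and the insertion() helper defined inside import_ctf.
def migration_upgrade(revision):
    print("core schema ->", revision)

def plugin_current(plugin_name):
    return "rev_in_export"  # revision recorded in config at export time

def plugin_upgrade(plugin_name, revision=None):
    print("plugin", plugin_name, "schema ->", revision or "head")

def insertion(table_filenames):
    print("insert rows from", table_filenames)

plugins = ["my_plugin"]

migration_upgrade(revision="rev_at_export")          # 1. pin core schema
insertion(["db/users.json", "db/challenges.json"])   # 2. core rows
for plugin in plugins:                               # 3. pin plugin schemas
    plugin_upgrade(plugin, revision=plugin_current(plugin))
insertion(["db/my_plugin_table.json"])               # 4. plugin rows
for plugin in plugins:                               # 5. plugin schemas -> head
    plugin_upgrade(plugin)
migration_upgrade(revision="head")                   # 6. core schema -> head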