diff --git a/CTFd/plugins/__init__.py b/CTFd/plugins/__init__.py index 42da0da5..7e97824e 100644 --- a/CTFd/plugins/__init__.py +++ b/CTFd/plugins/__init__.py @@ -176,6 +176,7 @@ def init_plugins(app): app.admin_plugin_menu_bar = [] app.plugin_menu_bar = [] + app.plugins_dir = os.path.dirname(__file__) if app.config.get("SAFE_MODE", False) is False: modules = sorted(glob.glob(os.path.dirname(__file__) + "/*")) diff --git a/CTFd/plugins/dynamic_challenges/__init__.py b/CTFd/plugins/dynamic_challenges/__init__.py index b02af073..04ebc17f 100644 --- a/CTFd/plugins/dynamic_challenges/__init__.py +++ b/CTFd/plugins/dynamic_challenges/__init__.py @@ -15,6 +15,7 @@ from CTFd.models import ( db, ) from CTFd.plugins import register_plugin_assets_directory +from CTFd.plugins.migrations import upgrade from CTFd.plugins.challenges import CHALLENGE_CLASSES, BaseChallenge from CTFd.plugins.flags import get_flag_class from CTFd.utils.modes import get_model @@ -240,7 +241,7 @@ class DynamicValueChallenge(BaseChallenge): class DynamicChallenge(Challenges): __mapper_args__ = {"polymorphic_identity": "dynamic"} id = db.Column( - None, db.ForeignKey("challenges.id", ondelete="CASCADE"), primary_key=True + db.Integer, db.ForeignKey("challenges.id", ondelete="CASCADE"), primary_key=True ) initial = db.Column(db.Integer, default=0) minimum = db.Column(db.Integer, default=0) @@ -252,8 +253,7 @@ class DynamicChallenge(Challenges): def load(app): - # upgrade() - app.db.create_all() + upgrade() CHALLENGE_CLASSES["dynamic"] = DynamicValueChallenge register_plugin_assets_directory( app, base_path="/plugins/dynamic_challenges/assets/" diff --git a/migrations/versions/b37fb68807ea_add_cascading_delete_to_dynamic_.py b/CTFd/plugins/dynamic_challenges/migrations/b37fb68807ea_add_cascading_delete_to_dynamic_.py similarity index 82% rename from migrations/versions/b37fb68807ea_add_cascading_delete_to_dynamic_.py rename to CTFd/plugins/dynamic_challenges/migrations/b37fb68807ea_add_cascading_delete_to_dynamic_.py index af9a3fcc..2db3f9a5 100644 --- a/migrations/versions/b37fb68807ea_add_cascading_delete_to_dynamic_.py +++ b/CTFd/plugins/dynamic_challenges/migrations/b37fb68807ea_add_cascading_delete_to_dynamic_.py @@ -10,21 +10,21 @@ from alembic import op # revision identifiers, used by Alembic. revision = "b37fb68807ea" -down_revision = "1093835a1051" +down_revision = None branch_labels = None depends_on = None -def upgrade(): +def upgrade(op=None): # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, "dynamic_challenge", type_="foreignkey") + op.drop_constraint("dynamic_challenge_ibfk_1", "dynamic_challenge", type_="foreignkey") op.create_foreign_key( None, "dynamic_challenge", "challenges", ["id"], ["id"], ondelete="CASCADE" ) # ### end Alembic commands ### -def downgrade(): +def downgrade(op=None): # ### commands auto generated by Alembic - please adjust! ### op.drop_constraint(None, "dynamic_challenge", type_="foreignkey") op.create_foreign_key(None, "dynamic_challenge", "challenges", ["id"], ["id"]) diff --git a/CTFd/plugins/migrations.py b/CTFd/plugins/migrations.py new file mode 100644 index 00000000..5393da35 --- /dev/null +++ b/CTFd/plugins/migrations.py @@ -0,0 +1,65 @@ +import os + +from alembic.migration import MigrationContext +from flask import current_app +from flask_migrate import Migrate, stamp +from sqlalchemy import create_engine +from alembic.operations import Operations + +from alembic.script import ScriptDirectory +from alembic.config import Config +from sqlalchemy import pool + +from CTFd.utils import get_config, set_config + +import inspect +import os + + +def upgrade(plugin_name=None): + database_url = current_app.config.get("SQLALCHEMY_DATABASE_URI") + if database_url.startswith("sqlite"): + current_app.db.create_all() + return + + if plugin_name is None: + # Get the directory name of the plugin if unspecified + caller_path = inspect.stack()[1].filename + plugin_name = os.path.basename(os.path.dirname(caller_path)) + + engine = create_engine( + database_url, poolclass=pool.NullPool + ) + conn = engine.connect() + context = MigrationContext.configure(conn) + op = Operations(context) + + # Find the list of migrations to run + config = Config() + config.set_main_option( + "script_location", + os.path.join(current_app.plugins_dir, plugin_name, "migrations"), + ) + config.set_main_option( + "version_locations", + os.path.join(current_app.plugins_dir, plugin_name, "migrations"), + ) + script = ScriptDirectory.from_config(config) + + # get current revision for plugin + lower = get_config(plugin_name + "_alembic_version") + upper = script.get_current_head() + + # Apply from lower to upper + revs = list(script.iterate_revisions(lower=lower, upper=upper)) + revs.reverse() + + try: + for r in revs: + with context.begin_transaction(): + r.module.upgrade(op=op) + finally: + conn.close() + + # Set the new latest revision + set_config(plugin_name + "_alembic_version", upper)