diff --git a/CTFd/auth.py b/CTFd/auth.py index f24bf342..cac5ef11 100644 --- a/CTFd/auth.py +++ b/CTFd/auth.py @@ -7,7 +7,7 @@ from flask import redirect, render_template, request, session, url_for from itsdangerous.exc import BadSignature, BadTimeSignature, SignatureExpired from CTFd.cache import clear_team_session, clear_user_session -from CTFd.models import FieldEntries, Teams, UserFields, Users, db +from CTFd.models import Teams, UserFieldEntries, UserFields, Users, db from CTFd.utils import config, email, get_app_config, get_config from CTFd.utils import user as current_user from CTFd.utils import validators @@ -288,7 +288,7 @@ def register(): db.session.flush() for field_id, value in entries.items(): - entry = FieldEntries( + entry = UserFieldEntries( field_id=field_id, value=value, user_id=user.id ) db.session.add(entry) diff --git a/CTFd/forms/auth.py b/CTFd/forms/auth.py index ad3a3403..77c4e1ec 100644 --- a/CTFd/forms/auth.py +++ b/CTFd/forms/auth.py @@ -5,7 +5,6 @@ from wtforms.validators import InputRequired from CTFd.forms import BaseForm from CTFd.forms.fields import SubmitField from CTFd.forms.users import attach_custom_user_fields, build_custom_user_fields -from CTFd.models import UserFields def RegistrationForm(*args, **kwargs): diff --git a/CTFd/forms/self.py b/CTFd/forms/self.py index aaf9be6f..13a32035 100644 --- a/CTFd/forms/self.py +++ b/CTFd/forms/self.py @@ -5,7 +5,6 @@ from wtforms.fields.html5 import DateField, URLField from CTFd.forms import BaseForm from CTFd.forms.fields import SubmitField from CTFd.forms.users import attach_custom_user_fields, build_custom_user_fields -from CTFd.models import FieldEntries, UserFields from CTFd.utils.countries import SELECT_COUNTRIES_LIST diff --git a/CTFd/forms/users.py b/CTFd/forms/users.py index 13c8e0a1..9b1882a5 100644 --- a/CTFd/forms/users.py +++ b/CTFd/forms/users.py @@ -4,7 +4,7 @@ from wtforms.validators import InputRequired from CTFd.forms import BaseForm from CTFd.forms.fields import SubmitField -from CTFd.models import FieldEntries, UserFields +from CTFd.models import UserFieldEntries, UserFields from CTFd.utils.countries import SELECT_COUNTRIES_LIST @@ -25,7 +25,7 @@ def build_custom_user_fields( # Only include preexisting values if asked if include_entries is True: - for f in FieldEntries.query.filter_by(**field_entries_kwargs).all(): + for f in UserFieldEntries.query.filter_by(**field_entries_kwargs).all(): user_fields[f.field_id] = f.value for field in new_fields: diff --git a/CTFd/models/__init__.py b/CTFd/models/__init__.py index 01e994e2..a849bb24 100644 --- a/CTFd/models/__init__.py +++ b/CTFd/models/__init__.py @@ -277,7 +277,7 @@ class Users(db.Model): team_id = db.Column(db.Integer, db.ForeignKey("teams.id")) field_entries = db.relationship( - "FieldEntries", foreign_keys="FieldEntries.user_id", lazy="joined" + "UserFieldEntries", foreign_keys="UserFieldEntries.user_id", lazy="joined" ) created = db.Column(db.DateTime, default=datetime.datetime.utcnow) @@ -818,12 +818,15 @@ class UserFields(Fields): class FieldEntries(db.Model): __tablename__ = "field_entries" id = db.Column(db.Integer, primary_key=True) + type = db.Column(db.String(80), default="standard") value = db.Column(db.Text) field_id = db.Column(db.Integer, db.ForeignKey("fields.id", ondelete="CASCADE")) - user_id = db.Column(db.Integer, db.ForeignKey("users.id", ondelete="CASCADE")) - user = db.relationship("Users", foreign_keys="FieldEntries.user_id") - field = db.relationship("Fields", foreign_keys="FieldEntries.field_id", lazy="joined") + field = db.relationship( + "Fields", foreign_keys="FieldEntries.field_id", lazy="joined" + ) + + __mapper_args__ = {"polymorphic_identity": "standard", "polymorphic_on": type} @hybrid_property def name(self): @@ -832,3 +835,9 @@ class FieldEntries(db.Model): @hybrid_property def description(self): return self.field.description + + +class UserFieldEntries(FieldEntries): + __mapper_args__ = {"polymorphic_identity": "user"} + user_id = db.Column(db.Integer, db.ForeignKey("users.id", ondelete="CASCADE")) + user = db.relationship("Users", foreign_keys="UserFieldEntries.user_id") diff --git a/CTFd/schemas/fields.py b/CTFd/schemas/fields.py index d1c13207..f2271e3d 100644 --- a/CTFd/schemas/fields.py +++ b/CTFd/schemas/fields.py @@ -1,7 +1,6 @@ -from marshmallow import fields, pre_load +from marshmallow import fields -from CTFd.models import FieldEntries, Fields, db, ma -from CTFd.utils.user import get_current_user, is_admin +from CTFd.models import Fields, UserFieldEntries, ma class FieldSchema(ma.ModelSchema): @@ -11,9 +10,9 @@ class FieldSchema(ma.ModelSchema): dump_only = ("id",) -class FieldEntriesSchema(ma.ModelSchema): +class UserFieldEntriesSchema(ma.ModelSchema): class Meta: - model = FieldEntries + model = UserFieldEntries include_fk = True load_only = ("id",) exclude = ("field", "user", "user_id") diff --git a/CTFd/schemas/users.py b/CTFd/schemas/users.py index 47c960b9..2dc95558 100644 --- a/CTFd/schemas/users.py +++ b/CTFd/schemas/users.py @@ -3,8 +3,8 @@ from marshmallow.fields import Nested from marshmallow_sqlalchemy import field_for from sqlalchemy.orm import load_only -from CTFd.models import FieldEntries, UserFields, Users, ma -from CTFd.schemas.fields import FieldEntriesSchema +from CTFd.models import UserFieldEntries, UserFields, Users, ma +from CTFd.schemas.fields import UserFieldEntriesSchema from CTFd.utils import get_config, string_types from CTFd.utils.crypto import verify_password from CTFd.utils.email import check_email_is_whitelisted @@ -53,7 +53,7 @@ class UserSchema(ma.ModelSchema): country = field_for(Users, "country", validate=[validate_country_code]) password = field_for(Users, "password") fields = Nested( - FieldEntriesSchema, partial=True, many=True, attribute="field_entries" + UserFieldEntriesSchema, partial=True, many=True, attribute="field_entries" ) @pre_load @@ -211,7 +211,7 @@ class UserSchema(ma.ModelSchema): field = UserFields.query.filter_by(id=field_id).first_or_404() # Get the existing field entry if one exists - entry = FieldEntries.query.filter_by( + entry = UserFieldEntries.query.filter_by( field_id=field.id, user_id=target_user.id ).first() if entry: @@ -221,7 +221,7 @@ class UserSchema(ma.ModelSchema): # Extremely dirty hack to prevent deleting previously provided data. # This needs a better soln. entries = ( - FieldEntries.query.options(load_only("id")) + UserFieldEntries.query.options(load_only("id")) .filter_by(user_id=target_user.id) .all() ) @@ -244,7 +244,7 @@ class UserSchema(ma.ModelSchema): ) # Get the existing field entry if one exists - entry = FieldEntries.query.filter_by( + entry = UserFieldEntries.query.filter_by( field_id=field.id, user_id=current_user.id ).first() @@ -255,7 +255,7 @@ class UserSchema(ma.ModelSchema): # Extremely dirty hack to prevent deleting previously provided data. # This needs a better soln. entries = ( - FieldEntries.query.options(load_only("id")) + UserFieldEntries.query.options(load_only("id")) .filter_by(user_id=current_user.id) .all() )