diff --git a/CTFd/auth.py b/CTFd/auth.py index a43ce089..c4def6ff 100644 --- a/CTFd/auth.py +++ b/CTFd/auth.py @@ -252,17 +252,18 @@ def register(): if valid_affiliation is False: errors.append("Please provide a shorter affiliation") - from CTFd.models import Fields + from CTFd.models import UserFields fields = {} - for field in Fields.query.all(): + for field in UserFields.query.all(): fields[field.id] = field entries = {} for field_id, field in fields.items(): value = request.form.get(f"fields[{field_id}]", "").strip() if field.required is True and (value is None or value == ""): - errors.append("Please enter in all required fields") + errors.append("Please provide all required fields") + break entries[field_id] = value if len(errors) > 0: diff --git a/CTFd/forms/auth.py b/CTFd/forms/auth.py index 8066759c..74fa3aa2 100644 --- a/CTFd/forms/auth.py +++ b/CTFd/forms/auth.py @@ -4,7 +4,7 @@ from wtforms.validators import InputRequired from CTFd.forms import BaseForm from CTFd.forms.fields import SubmitField -from CTFd.models import Fields +from CTFd.models import UserFields def RegistrationForm(*args, **kwargs): @@ -17,13 +17,13 @@ def RegistrationForm(*args, **kwargs): @property def extra(self): fields = [] - new_fields = Fields.query.all() + new_fields = UserFields.query.all() for field in new_fields: entry = (field.name, getattr(self, f"fields[{field.id}]")) fields.append(entry) return fields - new_fields = Fields.query.all() + new_fields = UserFields.query.all() for field in new_fields: setattr(_RegistrationForm, f"fields[{field.id}]", StringField(field.name)) diff --git a/CTFd/forms/self.py b/CTFd/forms/self.py index 50b2c2d9..f0484631 100644 --- a/CTFd/forms/self.py +++ b/CTFd/forms/self.py @@ -4,7 +4,7 @@ from wtforms.fields.html5 import DateField, URLField from CTFd.forms import BaseForm from CTFd.forms.fields import SubmitField -from CTFd.models import FieldEntries, Fields +from CTFd.models import FieldEntries, UserFields from CTFd.utils.countries import SELECT_COUNTRIES_LIST @@ -22,7 +22,7 @@ def SettingsForm(*args, **kwargs): @property def extra(self): fields = [] - new_fields = Fields.query.all() + new_fields = UserFields.query.filter_by(editable=True).all() user_fields = {} for f in FieldEntries.query.filter_by(user_id=session["id"]).all(): @@ -35,7 +35,7 @@ def SettingsForm(*args, **kwargs): fields.append(entry) return fields - new_fields = Fields.query.all() + new_fields = UserFields.query.filter_by(editable=True).all() for field in new_fields: setattr(_SettingsForm, f"fields[{field.id}]", StringField(field.name)) diff --git a/CTFd/forms/users.py b/CTFd/forms/users.py index 8655f386..185408f4 100644 --- a/CTFd/forms/users.py +++ b/CTFd/forms/users.py @@ -4,7 +4,7 @@ from wtforms.validators import InputRequired from CTFd.forms import BaseForm from CTFd.forms.fields import SubmitField -from CTFd.models import FieldEntries, Fields +from CTFd.models import FieldEntries, UserFields from CTFd.utils.countries import SELECT_COUNTRIES_LIST @@ -62,7 +62,7 @@ def UserEditForm(*args, **kwargs): @property def extra(self): fields = [] - new_fields = Fields.query.all() + new_fields = UserFields.query.all() user_fields = {} for f in FieldEntries.query.filter_by(user_id=self.obj.id).all(): @@ -84,7 +84,7 @@ def UserEditForm(*args, **kwargs): if obj: self.obj = obj - new_fields = Fields.query.all() + new_fields = UserFields.query.all() for field in new_fields: setattr(_UserEditForm, f"fields[{field.id}]", StringField(field.name)) @@ -98,7 +98,7 @@ def UserCreateForm(*args, **kwargs): @property def extra(self): fields = [] - new_fields = Fields.query.all() + new_fields = UserFields.query.all() for field in new_fields: form_field = getattr(self, f"fields[{field.id}]") @@ -106,7 +106,7 @@ def UserCreateForm(*args, **kwargs): fields.append(entry) return fields - new_fields = Fields.query.all() + new_fields = UserFields.query.all() for field in new_fields: setattr(_UserCreateForm, f"fields[{field.id}]", StringField(field.name)) diff --git a/CTFd/models/__init__.py b/CTFd/models/__init__.py index 29e732cc..5a10f9a8 100644 --- a/CTFd/models/__init__.py +++ b/CTFd/models/__init__.py @@ -825,7 +825,7 @@ class Fields(db.Model): __mapper_args__ = {"polymorphic_identity": "standard", "polymorphic_on": type} -class UserFields(Comments): +class UserFields(Fields): __mapper_args__ = {"polymorphic_identity": "user"} diff --git a/CTFd/schemas/users.py b/CTFd/schemas/users.py index 2f259d98..359a97f3 100644 --- a/CTFd/schemas/users.py +++ b/CTFd/schemas/users.py @@ -2,7 +2,7 @@ from marshmallow import ValidationError, post_dump, pre_load, validate from marshmallow.fields import Nested from marshmallow_sqlalchemy import field_for -from CTFd.models import FieldEntries, Fields, Users, ma +from CTFd.models import FieldEntries, UserFields, Users, ma from CTFd.schemas.fields import FieldEntriesSchema from CTFd.utils import get_config, string_types from CTFd.utils.crypto import verify_password @@ -201,7 +201,7 @@ class UserSchema(ma.ModelSchema): field_id = f.get("field_id") # # Check that we have an existing field for this. May be unnecessary b/c the foriegn key should enforce - field = Fields.query.filter_by(id=field_id).first_or_404() + field = UserFields.query.filter_by(id=field_id).first_or_404() # Get the existing field entry if one exists entry = FieldEntries.query.filter_by( @@ -219,7 +219,7 @@ class UserSchema(ma.ModelSchema): field_id = f.get("field_id") # # Check that we have an existing field for this. May be unnecessary b/c the foriegn key should enforce - field = Fields.query.filter_by(id=field_id).first_or_404() + field = UserFields.query.filter_by(id=field_id).first_or_404() if field.editable is False: raise ValidationError( @@ -246,7 +246,7 @@ class UserSchema(ma.ModelSchema): """ # Gather all possible fields removed_field_ids = [] - fields = Fields.query.all() + fields = UserFields.query.all() # Select fields for removal based on current view and properties of the field for field in fields: