diff --git a/CTFd/schemas/users.py b/CTFd/schemas/users.py index 359a97f3..07c49f65 100644 --- a/CTFd/schemas/users.py +++ b/CTFd/schemas/users.py @@ -1,6 +1,7 @@ from marshmallow import ValidationError, post_dump, pre_load, validate from marshmallow.fields import Nested from marshmallow_sqlalchemy import field_for +from sqlalchemy.orm import load_only from CTFd.models import FieldEntries, UserFields, Users, ma from CTFd.schemas.fields import FieldEntriesSchema @@ -51,7 +52,9 @@ class UserSchema(ma.ModelSchema): ) country = field_for(Users, "country", validate=[validate_country_code]) password = field_for(Users, "password") - fields = Nested(FieldEntriesSchema, partial=True, many=True, attribute="field_entries") + fields = Nested( + FieldEntriesSchema, partial=True, many=True, attribute="field_entries" + ) @pre_load def validate_name(self, data): @@ -196,6 +199,7 @@ class UserSchema(ma.ModelSchema): user_id = data.get("id") if user_id: target_user = Users.query.filter_by(id=data["id"]).first() + provided_ids = [] for f in fields: f.pop("id", None) field_id = f.get("field_id") @@ -209,10 +213,23 @@ class UserSchema(ma.ModelSchema): ).first() if entry: f["id"] = entry.id + + # Extremely dirty hack to prevent deleting previously provided data. + # This needs a better soln. + entries = ( + FieldEntries.query.options(load_only("id")) + .filter_by(user_id=current_user.id) + .all() + ) + print(entries) + for entry in entries: + if entry.id not in provided_ids: + fields.append({"id": entry.id}) else: # Marshmallow automatically links the fields to newly created users pass else: + provided_ids = [] for f in fields: # Remove any existing set f.pop("id", None) @@ -233,6 +250,19 @@ class UserSchema(ma.ModelSchema): if entry: f["id"] = entry.id + provided_ids.append(entry.id) + + # Extremely dirty hack to prevent deleting previously provided data. + # This needs a better soln. + entries = ( + FieldEntries.query.options(load_only("id")) + .filter_by(user_id=current_user.id) + .all() + ) + print(entries) + for entry in entries: + if entry.id not in provided_ids: + fields.append({"id": entry.id}) @post_dump() def process_fields(self, data):