Fix issue with partially updating field data

This commit is contained in:
Kevin Chung
2020-08-17 03:51:45 -04:00
parent baa5918134
commit 96c50f26b1

View File

@@ -1,6 +1,7 @@
from marshmallow import ValidationError, post_dump, pre_load, validate
from marshmallow.fields import Nested
from marshmallow_sqlalchemy import field_for
from sqlalchemy.orm import load_only
from CTFd.models import FieldEntries, UserFields, Users, ma
from CTFd.schemas.fields import FieldEntriesSchema
@@ -51,7 +52,9 @@ class UserSchema(ma.ModelSchema):
)
country = field_for(Users, "country", validate=[validate_country_code])
password = field_for(Users, "password")
fields = Nested(FieldEntriesSchema, partial=True, many=True, attribute="field_entries")
fields = Nested(
FieldEntriesSchema, partial=True, many=True, attribute="field_entries"
)
@pre_load
def validate_name(self, data):
@@ -196,6 +199,7 @@ class UserSchema(ma.ModelSchema):
user_id = data.get("id")
if user_id:
target_user = Users.query.filter_by(id=data["id"]).first()
provided_ids = []
for f in fields:
f.pop("id", None)
field_id = f.get("field_id")
@@ -209,10 +213,23 @@ class UserSchema(ma.ModelSchema):
).first()
if entry:
f["id"] = entry.id
# Extremely dirty hack to prevent deleting previously provided data.
# This needs a better soln.
entries = (
FieldEntries.query.options(load_only("id"))
.filter_by(user_id=current_user.id)
.all()
)
print(entries)
for entry in entries:
if entry.id not in provided_ids:
fields.append({"id": entry.id})
else:
# Marshmallow automatically links the fields to newly created users
pass
else:
provided_ids = []
for f in fields:
# Remove any existing set
f.pop("id", None)
@@ -233,6 +250,19 @@ class UserSchema(ma.ModelSchema):
if entry:
f["id"] = entry.id
provided_ids.append(entry.id)
# Extremely dirty hack to prevent deleting previously provided data.
# This needs a better soln.
entries = (
FieldEntries.query.options(load_only("id"))
.filter_by(user_id=current_user.id)
.all()
)
print(entries)
for entry in entries:
if entry.id not in provided_ids:
fields.append({"id": entry.id})
@post_dump()
def process_fields(self, data):