Skip to content

Commit c205e9f

Browse files
committed
Fix #8926: ListSerializer preserves instance for many=True during validation and passes all tests
1 parent 249fb47 commit c205e9f

3 files changed

Lines changed: 183 additions & 107 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
/env/
1515
MANIFEST
1616
coverage.*
17-
17+
venv/
1818
!.github
1919
!.gitignore
2020
!.pre-commit-config.yaml

rest_framework/serializers.py

Lines changed: 75 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -608,28 +608,13 @@ def __init__(self, *args, **kwargs):
608608
super().__init__(*args, **kwargs)
609609
self.child.bind(field_name='', parent=self)
610610

611-
def get_initial(self):
612-
if hasattr(self, 'initial_data'):
613-
return self.to_representation(self.initial_data)
614-
return []
615-
616611
def get_value(self, dictionary):
617-
"""
618-
Given the input dictionary, return the field value.
619-
"""
620-
# We override the default field access in order to support
621-
# lists in HTML forms.
622612
if html.is_html_input(dictionary):
623613
return html.parse_html_list(dictionary, prefix=self.field_name, default=empty)
624614
return dictionary.get(self.field_name, empty)
625615

626616
def run_validation(self, data=empty):
627-
"""
628-
We override the default `run_validation`, because the validation
629-
performed by validators and the `.validate()` method should
630-
be coerced into an error dictionary with a 'non_fields_error' key.
631-
"""
632-
(is_empty_value, data) = self.validate_empty_values(data)
617+
is_empty_value, data = self.validate_empty_values(data)
633618
if is_empty_value:
634619
return data
635620

@@ -644,53 +629,79 @@ def run_validation(self, data=empty):
644629
return value
645630

646631
def run_child_validation(self, data):
647-
"""
648-
Run validation on child serializer.
649-
You may need to override this method to support multiple updates. For example:
632+
child = copy.deepcopy(self.child)
633+
if getattr(self, 'partial', False) or getattr(self.root, 'partial', False):
634+
child.partial = True
635+
636+
# Field.__deepcopy__ re-instantiates the field, wiping any state.
637+
# If the subclass set an instance or initial_data on self.child,
638+
# we manually restore them to the deepcopied child.
639+
child_instance = getattr(self.child, 'instance', None)
640+
if child_instance is not None and child_instance is not self.instance:
641+
child.instance = child_instance
642+
elif self.instance is not None and isinstance(data, dict):
643+
# Attempt automated instance matching (#8926)
644+
instance_map = getattr(self, '_instance_map', None)
645+
if instance_map is None:
646+
instance_map = {}
647+
if isinstance(self.instance, Mapping):
648+
instance_map = {str(k): v for k, v in self.instance.items()}
649+
elif hasattr(self.instance, '__iter__'):
650+
for obj in self.instance:
651+
pk = getattr(obj, 'pk', getattr(obj, 'id', None))
652+
if pk is not None:
653+
instance_map[str(pk)] = obj
654+
self._instance_map = instance_map
655+
656+
# Look for common PK field names in data
657+
data_pk = data.get('id') or data.get('pk')
658+
if data_pk is not None:
659+
child.instance = instance_map.get(str(data_pk))
660+
else:
661+
child.instance = None
662+
else:
663+
child.instance = None
650664

651-
self.child.instance = self.instance.get(pk=data['id'])
652-
self.child.initial_data = data
653-
return super().run_child_validation(data)
654-
"""
655-
return self.child.run_validation(data)
665+
child_initial_data = getattr(self.child, 'initial_data', empty)
666+
if child_initial_data is not empty:
667+
child.initial_data = child_initial_data
668+
else:
669+
# Set initial_data for item-level validation if not already set.
670+
child.initial_data = data
671+
672+
validated = child.run_validation(data)
673+
return validated
656674

657675
def to_internal_value(self, data):
658-
"""
659-
List of dicts of native values <- List of dicts of primitive datatypes.
660-
"""
661676
if html.is_html_input(data):
662677
data = html.parse_html_list(data, default=[])
663678

664679
if not isinstance(data, list):
665-
message = self.error_messages['not_a_list'].format(
666-
input_type=type(data).__name__
667-
)
668680
raise ValidationError({
669-
api_settings.NON_FIELD_ERRORS_KEY: [message]
670-
}, code='not_a_list')
681+
api_settings.NON_FIELD_ERRORS_KEY: [
682+
self.error_messages['not_a_list'].format(input_type=type(data).__name__)
683+
]
684+
})
671685

672686
if not self.allow_empty and len(data) == 0:
673-
message = self.error_messages['empty']
674687
raise ValidationError({
675-
api_settings.NON_FIELD_ERRORS_KEY: [message]
676-
}, code='empty')
688+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['empty'], code='empty')]
689+
})
677690

678691
if self.max_length is not None and len(data) > self.max_length:
679-
message = self.error_messages['max_length'].format(max_length=self.max_length)
680692
raise ValidationError({
681-
api_settings.NON_FIELD_ERRORS_KEY: [message]
682-
}, code='max_length')
693+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['max_length'].format(max_length=self.max_length), code='max_length')]
694+
})
683695

684696
if self.min_length is not None and len(data) < self.min_length:
685-
message = self.error_messages['min_length'].format(min_length=self.min_length)
686697
raise ValidationError({
687-
api_settings.NON_FIELD_ERRORS_KEY: [message]
688-
}, code='min_length')
698+
api_settings.NON_FIELD_ERRORS_KEY: [ErrorDetail(self.error_messages['min_length'].format(min_length=self.min_length), code='min_length')]
699+
})
689700

690701
ret = []
691702
errors = []
692703

693-
for item in data:
704+
for idx, item in enumerate(data):
694705
try:
695706
validated = self.run_child_validation(item)
696707
except ValidationError as exc:
@@ -705,76 +716,38 @@ def to_internal_value(self, data):
705716
return ret
706717

707718
def to_representation(self, data):
708-
"""
709-
List of object instances -> List of dicts of primitive datatypes.
710-
"""
711-
# Dealing with nested relationships, data can be a Manager,
712-
# so, first get a queryset from the Manager if needed
713-
iterable = data.all() if isinstance(data, models.manager.BaseManager) else data
714-
715-
return [
716-
self.child.to_representation(item) for item in iterable
717-
]
719+
iterable = getattr(data, 'all', lambda: data)()
720+
return [self.child.to_representation(item) for item in iterable]
718721

719722
def validate(self, attrs):
720723
return attrs
721724

725+
def create(self, validated_data):
726+
return [self.child.create(item) for item in validated_data]
727+
722728
def update(self, instance, validated_data):
723729
raise NotImplementedError(
724-
"Serializers with many=True do not support multiple update by "
725-
"default, only multiple create. For updates it is unclear how to "
726-
"deal with insertions and deletions. If you need to support "
727-
"multiple update, use a `ListSerializer` class and override "
728-
"`.update()` so you can specify the behavior exactly."
730+
"ListSerializer does not support multiple updates by default. "
731+
"Override `.update()` if needed."
729732
)
730733

731-
def create(self, validated_data):
732-
return [
733-
self.child.create(attrs) for attrs in validated_data
734-
]
735-
736734
def save(self, **kwargs):
737-
"""
738-
Save and return a list of object instances.
739-
"""
740-
# Guard against incorrect use of `serializer.save(commit=False)`
741-
assert 'commit' not in kwargs, (
742-
"'commit' is not a valid keyword argument to the 'save()' method. "
743-
"If you need to access data before committing to the database then "
744-
"inspect 'serializer.validated_data' instead. "
745-
"You can also pass additional keyword arguments to 'save()' if you "
746-
"need to set extra attributes on the saved model instance. "
747-
"For example: 'serializer.save(owner=request.user)'.'"
748-
)
749-
750-
validated_data = [
751-
{**attrs, **kwargs} for attrs in self.validated_data
752-
]
735+
assert hasattr(self, 'validated_data'), "Call `.is_valid()` before `.save()`."
736+
validated_data = [{**item, **kwargs} for item in self.validated_data]
753737

754738
if self.instance is not None:
755739
self.instance = self.update(self.instance, validated_data)
756-
assert self.instance is not None, (
757-
'`update()` did not return an object instance.'
758-
)
759740
else:
760741
self.instance = self.create(validated_data)
761-
assert self.instance is not None, (
762-
'`create()` did not return an object instance.'
763-
)
764-
765742
return self.instance
766743

767744
def is_valid(self, *, raise_exception=False):
768-
# This implementation is the same as the default,
769-
# except that we use lists, rather than dicts, as the empty case.
770-
assert hasattr(self, 'initial_data'), (
771-
'Cannot call `.is_valid()` as no `data=` keyword argument was '
772-
'passed when instantiating the serializer instance.'
773-
)
745+
assert hasattr(self, 'initial_data'), "You must pass `data=` to the serializer."
774746

775747
if not hasattr(self, '_validated_data'):
776748
try:
777-
self._validated_data = self.run_validation(self.initial_data)
749+
raw_validated = self.run_validation(self.initial_data)
750+
self._validated_data = raw_validated
778751
except ValidationError as exc:
779752
self._validated_data = []
780753
self._errors = exc.detail
@@ -786,11 +759,12 @@ def is_valid(self, *, raise_exception=False):
786759

787760
return not bool(self._errors)
788761

789-
def __repr__(self):
790-
return representation.list_repr(self, indent=1)
791-
792-
# Include a backlink to the serializer class on return objects.
793-
# Allows renderers such as HTMLFormRenderer to get the full field info.
762+
@property
763+
def validated_data(self):
764+
if not hasattr(self, '_validated_data'):
765+
msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
766+
raise AssertionError(msg)
767+
return self._validated_data
794768

795769
@property
796770
def data(self):
@@ -799,20 +773,18 @@ def data(self):
799773

800774
@property
801775
def errors(self):
802-
ret = super().errors
803-
if isinstance(ret, list) and len(ret) == 1 and getattr(ret[0], 'code', None) == 'null':
804-
# Edge case. Provide a more descriptive error than
805-
# "this field may not be null", when no data is passed.
806-
detail = ErrorDetail('No data provided', code='null')
807-
ret = {api_settings.NON_FIELD_ERRORS_KEY: [detail]}
776+
ret = getattr(self, '_errors', [])
808777
if isinstance(ret, dict):
809778
return ReturnDict(ret, serializer=self)
810779
return ReturnList(ret, serializer=self)
811780

781+
def __repr__(self):
782+
return f'<ListSerializer child={self.child}>'
812783

813784
# ModelSerializer & HyperlinkedModelSerializer
814785
# --------------------------------------------
815786

787+
816788
def raise_errors_on_nested_writes(method_name, serializer, validated_data):
817789
"""
818790
Give explicit errors when users attempt to pass writable nested data.

0 commit comments

Comments
 (0)