From e20af6076756767164403cd07d8d558da70ad850 Mon Sep 17 00:00:00 2001 From: Felipe Martins Diel Date: Sat, 27 Jan 2024 03:34:08 -0300 Subject: [PATCH] Improve ListSerializer --- rest_framework/serializers.py | 45 +++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 77c181b6cc..dd7bab845c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -28,7 +28,7 @@ from rest_framework.compat import postgres_fields from rest_framework.exceptions import ErrorDetail, ValidationError -from rest_framework.fields import get_error_detail +from rest_framework.fields import get_attribute, get_error_detail from rest_framework.settings import api_settings from rest_framework.utils import html, model_meta, representation from rest_framework.utils.field_mapping import ( @@ -610,11 +610,6 @@ def __init__(self, *args, **kwargs): assert self.child is not None, '`child` is a required argument.' assert not inspect.isclass(self.child), '`child` has not been instantiated.' - instance = kwargs.get('instance', []) - data = kwargs.get('data', []) - if instance and data: - assert len(data) == len(instance), 'Data and instance should have same length' - super().__init__(*args, **kwargs) self.child.bind(field_name='', parent=self) @@ -700,13 +695,37 @@ def to_internal_value(self, data): ret = [] errors = [] - for idx, item in enumerate(data): - if ( - hasattr(self, 'instance') - and self.instance - and len(self.instance) > idx - ): - self.child.instance = self.instance[idx] + if not self.instance and self.parent and self.parent.instance: + rel_mgr = get_attribute(self.parent.instance, self.source_attrs) + self.instance = rel_mgr.all() if rel_mgr else None + + pk_field_name = next( + field_name + for field_name, field in self.child.fields.items() + if field.source == self.child.Meta.model._meta.pk.attname + ) + + item_instance_dict = {} + if self.instance: + item_pks = [] + + for item in data: + item_pk = item.get(pk_field_name) + if item_pk: + item_pks.append(item_pk) + + for item_instance in self.instance.filter(pk__in=item_pks): + item_instance_dict[item_instance.pk] = item_instance + + for item in data: + self.child.initial_data = item + + if self.instance: + item_pk = item.get(pk_field_name) + self.child.instance = ( + item_instance_dict.get(item_pk) if item_pk else None + ) + try: validated = self.run_child_validation(item) except ValidationError as exc: