|
1 | 1 | import json |
| 2 | +import warnings |
2 | 3 | from collections import OrderedDict |
3 | | -from collections.abc import Iterable |
4 | 4 |
|
5 | 5 | import inflection |
6 | 6 | from django.core.exceptions import ImproperlyConfigured |
7 | 7 | from django.urls import NoReverseMatch |
8 | 8 | from django.utils.translation import gettext_lazy as _ |
9 | | -from rest_framework.fields import MISSING_ERROR_MESSAGE, SkipField |
| 9 | +from rest_framework.fields import MISSING_ERROR_MESSAGE, Field, SkipField |
10 | 10 | from rest_framework.relations import MANY_RELATION_KWARGS |
11 | 11 | from rest_framework.relations import ManyRelatedField as DRFManyRelatedField |
12 | 12 | from rest_framework.relations import PrimaryKeyRelatedField, RelatedField |
@@ -347,51 +347,63 @@ def to_internal_value(self, data): |
347 | 347 | return super(ResourceRelatedField, self).to_internal_value(data['id']) |
348 | 348 |
|
349 | 349 |
|
350 | | -class SerializerMethodResourceRelatedField(ResourceRelatedField): |
| 350 | +class SerializerMethodFieldBase(Field): |
| 351 | + def __init__(self, method_name=None, **kwargs): |
| 352 | + if not method_name and kwargs.get('source'): |
| 353 | + method_name = kwargs.pop('source') |
| 354 | + warnings.warn(DeprecationWarning( |
| 355 | + "'source' argument of {cls} is deprecated, use 'method_name' " |
| 356 | + "as in SerializerMethodField".format(cls=self.__class__.__name__)), stacklevel=3) |
| 357 | + self.method_name = method_name |
| 358 | + kwargs['source'] = '*' |
| 359 | + kwargs['read_only'] = True |
| 360 | + super().__init__(**kwargs) |
| 361 | + |
| 362 | + def bind(self, field_name, parent): |
| 363 | + default_method_name = 'get_{field_name}'.format(field_name=field_name) |
| 364 | + if self.method_name is None: |
| 365 | + self.method_name = default_method_name |
| 366 | + super().bind(field_name, parent) |
| 367 | + |
| 368 | + def get_attribute(self, instance): |
| 369 | + serializer_method = getattr(self.parent, self.method_name) |
| 370 | + return serializer_method(instance) |
| 371 | + |
| 372 | + |
| 373 | +class ManySerializerMethodResourceRelatedField(SerializerMethodFieldBase, ResourceRelatedField): |
| 374 | + def __init__(self, child_relation=None, *args, **kwargs): |
| 375 | + assert child_relation is not None, '`child_relation` is a required argument.' |
| 376 | + self.child_relation = child_relation |
| 377 | + super().__init__(**kwargs) |
| 378 | + self.child_relation.bind(field_name='', parent=self) |
| 379 | + |
| 380 | + def to_representation(self, value): |
| 381 | + return [self.child_relation.to_representation(item) for item in value] |
| 382 | + |
| 383 | + |
| 384 | +class SerializerMethodResourceRelatedField(SerializerMethodFieldBase, ResourceRelatedField): |
351 | 385 | """ |
352 | 386 | Allows us to use serializer method RelatedFields |
353 | 387 | with return querysets |
354 | 388 | """ |
355 | | - def __new__(cls, *args, **kwargs): |
356 | | - """ |
357 | | - We override this because getting serializer methods |
358 | | - fails at the base class when many=True |
359 | | - """ |
360 | | - if kwargs.pop('many', False): |
361 | | - return cls.many_init(*args, **kwargs) |
362 | | - return super(ResourceRelatedField, cls).__new__(cls, *args, **kwargs) |
363 | 389 |
|
364 | | - def __init__(self, child_relation=None, *args, **kwargs): |
365 | | - model = kwargs.pop('model', None) |
366 | | - if child_relation is not None: |
367 | | - self.child_relation = child_relation |
368 | | - if model: |
369 | | - self.model = model |
370 | | - super(SerializerMethodResourceRelatedField, self).__init__(*args, **kwargs) |
| 390 | + many_kwargs = [*MANY_RELATION_KWARGS, *LINKS_PARAMS, 'method_name', 'model'] |
| 391 | + many_cls = ManySerializerMethodResourceRelatedField |
371 | 392 |
|
372 | 393 | @classmethod |
373 | 394 | def many_init(cls, *args, **kwargs): |
374 | | - list_kwargs = {k: kwargs.pop(k) for k in LINKS_PARAMS if k in kwargs} |
375 | | - list_kwargs['child_relation'] = cls(*args, **kwargs) |
376 | | - for key in kwargs.keys(): |
377 | | - if key in ('model',) + MANY_RELATION_KWARGS: |
| 395 | + list_kwargs = {'child_relation': cls(**kwargs)} |
| 396 | + for key in kwargs: |
| 397 | + if key in cls.many_kwargs: |
378 | 398 | list_kwargs[key] = kwargs[key] |
379 | | - return cls(**list_kwargs) |
| 399 | + return cls.many_cls(**list_kwargs) |
380 | 400 |
|
381 | | - def get_attribute(self, instance): |
382 | | - # check for a source fn defined on the serializer instead of the model |
383 | | - if self.source and hasattr(self.parent, self.source): |
384 | | - serializer_method = getattr(self.parent, self.source) |
385 | | - if hasattr(serializer_method, '__call__'): |
386 | | - return serializer_method(instance) |
387 | | - return super(SerializerMethodResourceRelatedField, self).get_attribute(instance) |
388 | 401 |
|
389 | | - def to_representation(self, value): |
390 | | - if isinstance(value, Iterable): |
391 | | - base = super(SerializerMethodResourceRelatedField, self) |
392 | | - return [base.to_representation(x) for x in value] |
393 | | - return super(SerializerMethodResourceRelatedField, self).to_representation(value) |
| 402 | +class ManySerializerMethodHyperlinkedRelatedField(SkipDataMixin, |
| 403 | + ManySerializerMethodResourceRelatedField): |
| 404 | + pass |
394 | 405 |
|
395 | 406 |
|
396 | | -class SerializerMethodHyperlinkedRelatedField(SkipDataMixin, SerializerMethodResourceRelatedField): |
397 | | - pass |
| 407 | +class SerializerMethodHyperlinkedRelatedField(SkipDataMixin, |
| 408 | + SerializerMethodResourceRelatedField): |
| 409 | + many_cls = ManySerializerMethodHyperlinkedRelatedField |
0 commit comments