|
1 | 1 | import json
|
| 2 | +import warnings |
2 | 3 | from collections import OrderedDict
|
3 |
| -from collections.abc import Iterable |
4 | 4 |
|
5 | 5 | import inflection
|
6 | 6 | from django.core.exceptions import ImproperlyConfigured
|
7 | 7 | from django.urls import NoReverseMatch
|
8 | 8 | from django.utils.translation import gettext_lazy as _
|
9 |
| -from rest_framework.fields import MISSING_ERROR_MESSAGE, SkipField |
| 9 | +from rest_framework.fields import MISSING_ERROR_MESSAGE, Field, SkipField |
10 | 10 | from rest_framework.relations import MANY_RELATION_KWARGS
|
11 | 11 | from rest_framework.relations import ManyRelatedField as DRFManyRelatedField
|
12 | 12 | from rest_framework.relations import PrimaryKeyRelatedField, RelatedField
|
@@ -347,51 +347,63 @@ def to_internal_value(self, data):
|
347 | 347 | return super(ResourceRelatedField, self).to_internal_value(data['id'])
|
348 | 348 |
|
349 | 349 |
|
350 |
| -class SerializerMethodResourceRelatedField(ResourceRelatedField): |
| 350 | +class SerializerMethodFieldBase(Field): |
| 351 | + def __init__(self, method_name=None, **kwargs): |
| 352 | + if not method_name and kwargs.get('source'): |
| 353 | + method_name = kwargs.pop('source') |
| 354 | + warnings.warn(DeprecationWarning( |
| 355 | + "'source' argument of {cls} is deprecated, use 'method_name' " |
| 356 | + "as in SerializerMethodField".format(cls=self.__class__.__name__)), stacklevel=3) |
| 357 | + self.method_name = method_name |
| 358 | + kwargs['source'] = '*' |
| 359 | + kwargs['read_only'] = True |
| 360 | + super().__init__(**kwargs) |
| 361 | + |
| 362 | + def bind(self, field_name, parent): |
| 363 | + default_method_name = 'get_{field_name}'.format(field_name=field_name) |
| 364 | + if self.method_name is None: |
| 365 | + self.method_name = default_method_name |
| 366 | + super().bind(field_name, parent) |
| 367 | + |
| 368 | + def get_attribute(self, instance): |
| 369 | + serializer_method = getattr(self.parent, self.method_name) |
| 370 | + return serializer_method(instance) |
| 371 | + |
| 372 | + |
| 373 | +class ManySerializerMethodResourceRelatedField(SerializerMethodFieldBase, ResourceRelatedField): |
| 374 | + def __init__(self, child_relation=None, *args, **kwargs): |
| 375 | + assert child_relation is not None, '`child_relation` is a required argument.' |
| 376 | + self.child_relation = child_relation |
| 377 | + super().__init__(**kwargs) |
| 378 | + self.child_relation.bind(field_name='', parent=self) |
| 379 | + |
| 380 | + def to_representation(self, value): |
| 381 | + return [self.child_relation.to_representation(item) for item in value] |
| 382 | + |
| 383 | + |
| 384 | +class SerializerMethodResourceRelatedField(SerializerMethodFieldBase, ResourceRelatedField): |
351 | 385 | """
|
352 | 386 | Allows us to use serializer method RelatedFields
|
353 | 387 | with return querysets
|
354 | 388 | """
|
355 |
| - def __new__(cls, *args, **kwargs): |
356 |
| - """ |
357 |
| - We override this because getting serializer methods |
358 |
| - fails at the base class when many=True |
359 |
| - """ |
360 |
| - if kwargs.pop('many', False): |
361 |
| - return cls.many_init(*args, **kwargs) |
362 |
| - return super(ResourceRelatedField, cls).__new__(cls, *args, **kwargs) |
363 | 389 |
|
364 |
| - def __init__(self, child_relation=None, *args, **kwargs): |
365 |
| - model = kwargs.pop('model', None) |
366 |
| - if child_relation is not None: |
367 |
| - self.child_relation = child_relation |
368 |
| - if model: |
369 |
| - self.model = model |
370 |
| - super(SerializerMethodResourceRelatedField, self).__init__(*args, **kwargs) |
| 390 | + many_kwargs = [*MANY_RELATION_KWARGS, *LINKS_PARAMS, 'method_name', 'model'] |
| 391 | + many_cls = ManySerializerMethodResourceRelatedField |
371 | 392 |
|
372 | 393 | @classmethod
|
373 | 394 | def many_init(cls, *args, **kwargs):
|
374 |
| - list_kwargs = {k: kwargs.pop(k) for k in LINKS_PARAMS if k in kwargs} |
375 |
| - list_kwargs['child_relation'] = cls(*args, **kwargs) |
376 |
| - for key in kwargs.keys(): |
377 |
| - if key in ('model',) + MANY_RELATION_KWARGS: |
| 395 | + list_kwargs = {'child_relation': cls(**kwargs)} |
| 396 | + for key in kwargs: |
| 397 | + if key in cls.many_kwargs: |
378 | 398 | list_kwargs[key] = kwargs[key]
|
379 |
| - return cls(**list_kwargs) |
| 399 | + return cls.many_cls(**list_kwargs) |
380 | 400 |
|
381 |
| - def get_attribute(self, instance): |
382 |
| - # check for a source fn defined on the serializer instead of the model |
383 |
| - if self.source and hasattr(self.parent, self.source): |
384 |
| - serializer_method = getattr(self.parent, self.source) |
385 |
| - if hasattr(serializer_method, '__call__'): |
386 |
| - return serializer_method(instance) |
387 |
| - return super(SerializerMethodResourceRelatedField, self).get_attribute(instance) |
388 | 401 |
|
389 |
| - def to_representation(self, value): |
390 |
| - if isinstance(value, Iterable): |
391 |
| - base = super(SerializerMethodResourceRelatedField, self) |
392 |
| - return [base.to_representation(x) for x in value] |
393 |
| - return super(SerializerMethodResourceRelatedField, self).to_representation(value) |
| 402 | +class ManySerializerMethodHyperlinkedRelatedField(SkipDataMixin, |
| 403 | + ManySerializerMethodResourceRelatedField): |
| 404 | + pass |
394 | 405 |
|
395 | 406 |
|
396 |
| -class SerializerMethodHyperlinkedRelatedField(SkipDataMixin, SerializerMethodResourceRelatedField): |
397 |
| - pass |
| 407 | +class SerializerMethodHyperlinkedRelatedField(SkipDataMixin, |
| 408 | + SerializerMethodResourceRelatedField): |
| 409 | + many_cls = ManySerializerMethodHyperlinkedRelatedField |
0 commit comments