From cb7b5c12dddc6d3d4bd3f7b400fc8a030dbf7f88 Mon Sep 17 00:00:00 2001 From: David Gage Date: Wed, 7 Nov 2018 14:02:56 -0500 Subject: [PATCH] Set streaming serialization separately from many (#661) --- apps/metrics/views.py | 37 ++++++++++++++++++++++++----- hhs_oauth_server/request_logging.py | 2 +- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/apps/metrics/views.py b/apps/metrics/views.py index efddbafa8..bf43bf1fe 100644 --- a/apps/metrics/views.py +++ b/apps/metrics/views.py @@ -15,6 +15,7 @@ ListSerializer, IntegerField, DateTimeField, + LIST_SERIALIZER_KWARGS, ) from rest_framework.generics import ListAPIView from rest_framework.pagination import PageNumberPagination @@ -29,6 +30,8 @@ log = logging.getLogger('hhs_server.%s' % __name__) +STREAM_SERIALIZER_KWARGS = LIST_SERIALIZER_KWARGS + class StreamingSerializer(ListSerializer): @property @@ -45,13 +48,35 @@ def data(self): class StreamableSerializerMixin(object): + def __new__(cls, *args, **kwargs): + + # We override this method in order to automagically create + # `ListSerializer` classes instead when `many=True` is set. + if kwargs.pop('many', False): + if kwargs.pop('stream', False): + return cls.stream_init(*args, **kwargs) + return cls.many_init(*args, **kwargs) + + return super().__new__(cls, *args, **kwargs) + @classmethod - def many_init(cls, *args, **kwargs): - stream = kwargs.pop('stream', False) - if stream: - meta = getattr(cls, 'Meta', None) - setattr(meta, 'list_serializer_class', getattr(meta, 'stream_serializer_class', StreamingSerializer)) - return super().many_init(*args, **kwargs) + def stream_init(cls, *args, **kwargs): + allow_empty = kwargs.pop('allow_empty', None) + child_serializer = cls(*args, **kwargs) + stream_kwargs = { + 'child': child_serializer, + } + if allow_empty is not None: + stream_kwargs['allow_empty'] = allow_empty + + stream_kwargs.update({ + key: value for key, value in kwargs.items() + if key in STREAM_SERIALIZER_KWARGS + }) + + meta = getattr(cls, 'Meta', None) + stream_serializer_class = getattr(meta, 'stream_serializer_class', StreamingSerializer) + return stream_serializer_class(*args, **stream_kwargs) class UserSerializer(ModelSerializer): diff --git a/hhs_oauth_server/request_logging.py b/hhs_oauth_server/request_logging.py index 49eeee901..5f4650ee1 100644 --- a/hhs_oauth_server/request_logging.py +++ b/hhs_oauth_server/request_logging.py @@ -57,7 +57,7 @@ def __str__(self): if log_msg['response_code'] in (300, 301, 302, 307): log_msg['location'] = self.response.get('Location', '?') - elif self.response.content: + elif getattr(self.response, 'content', False): log_msg['size'] = len(self.response.content) log_msg['user'] = str(get_user_from_request(self.request))