Skip to content

Commit

Permalink
Merge pull request #5 from paulcwatts/sparse-fieldsets
Browse files Browse the repository at this point in the history
Sparse fieldsets
  • Loading branch information
paulcwatts authored May 5, 2017
2 parents 28775fa + bfcea85 commit 219aabf
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 55 deletions.
3 changes: 2 additions & 1 deletion rest_framework_json_schema/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .exceptions import TypeConflict
from .renderers import JSONAPIRenderer
from .schema import Context


class Conflict(exceptions.APIException):
Expand Down Expand Up @@ -31,7 +32,7 @@ def parse(self, stream, media_type=None, parser_context=None):
raise exceptions.ValidationError('No primary data.')

try:
parsed = schema().parse(data, parser_context.get('request', None))
parsed = schema().parse(data, Context(parser_context.get('request', None)))
except TypeConflict as e:
raise Conflict(str(e))

Expand Down
30 changes: 24 additions & 6 deletions rest_framework_json_schema/renderers.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
from collections import OrderedDict
import re
import six

from rest_framework.renderers import JSONRenderer

from .schema import Context
from .exceptions import NoSchema
from .utils import parse_include


RX_FIELDS = re.compile(r'^fields\[([a-zA-Z0-9\-_]+)\]$')


class JSONAPIRenderer(JSONRenderer):
media_type = 'application/vnd.api+json'
format = 'vnd.api+json'
# You can specify top-level items here.
meta = None
jsonapi = None

def render_obj(self, obj, schema, renderer_context, include):
return schema.render(obj, renderer_context.get('request', None), include)
def render_obj(self, obj, schema, renderer_context, context):
return schema.render(obj, context)

def render_list(self, obj_list, schema, renderer_context, include):
def render_list(self, obj_list, schema, renderer_context, context):
primary = []
included = []
for obj in obj_list:
obj, inc = self.render_obj(obj, schema, renderer_context, include)
obj, inc = self.render_obj(obj, schema, renderer_context, context)
primary.append(obj)
included.extend(inc)

Expand All @@ -29,12 +35,14 @@ def render_list(self, obj_list, schema, renderer_context, include):
def render_data(self, data, renderer_context, include):
schema = self.get_schema(data, renderer_context)
assert schema, 'Unable to get schema class'
fields = self.get_fields(renderer_context)
context = Context(renderer_context.get('request', None), include, fields)

if isinstance(data, dict):
return self.render_obj(data, schema(), renderer_context, include)
return self.render_obj(data, schema(), renderer_context, context)

elif isinstance(data, list):
return self.render_list(data, schema(), renderer_context, include)
return self.render_list(data, schema(), renderer_context, context)

def render_exception(self, data, renderer_context):
return [data]
Expand Down Expand Up @@ -66,6 +74,16 @@ def get_include(self, renderer_context):
else:
return {}

def get_fields(self, renderer_context):
request = renderer_context.get('request', None)
fields = {}
if request:
for key, value in six.iteritems(request.query_params):
m = RX_FIELDS.match(key)
if m:
fields[m.group(1)] = value.split(',')
return fields

def render(self, data, media_type=None, renderer_context=None):
if data is None:
return bytes()
Expand Down
88 changes: 58 additions & 30 deletions rest_framework_json_schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,24 @@
from .transforms import NullTransform


class Context(object):
"""
Collection of arguments needed for rendering/parsing.
"""
def __init__(self, request, include=None, fields=None):
self.request = request
self.include = include or {}
self.fields = fields or {}


class BaseLinkedObject(object):
def render_links(self, data, request):
def render_links(self, data, context):
return OrderedDict(
(link_name, link_obj.render(data, request)) for (link_name, link_obj) in self.links
(link_name, link_obj.render(data, context.request))
for (link_name, link_obj) in self.links
)

def render_meta(self, data, request):
def render_meta(self, data, context):
"""
Implement this in your subclass if you have more complex meta information
that depends on data or the request.
Expand Down Expand Up @@ -53,7 +64,7 @@ def _normalize_rel(rel):
for (name, rel) in self.relationships:
self.transformed_names[name] = transformer.transform(name)

def parse(self, data, request):
def parse(self, data, context):
"""
Parses a Resource Object representation into an internal representation.
Verifies that the object is of the correct type.
Expand All @@ -73,65 +84,82 @@ def parse(self, data, request):
})
return result

def render(self, data, request, include):
def render(self, data, context):
"""
Renders data to a Resource Object representation.
"""
result = OrderedDict((
('id', str(data[self.id])),
('type', self.type)
))
attributes = self.render_attributes(data, request)
attributes = self.render_attributes(data, context)
if attributes:
result['attributes'] = attributes

relationships, included = self.render_relationships(data, request, include)
relationships, included = self.render_relationships(data, context)
if relationships:
result['relationships'] = relationships

links = self.render_links(data, request)
links = self.render_links(data, context)
if links:
result['links'] = links

meta = self.render_meta(data, request)
meta = self.render_meta(data, context)
if meta:
result['meta'] = meta
return result, included

def render_attributes(self, data, request):
def render_attributes(self, data, context):
attributes = self.filter_by_fields(self.attributes, context.fields)
return OrderedDict(
(self.transformed_names[attr], self.from_data(data, attr)) for attr in self.attributes
(self.transformed_names[attr], self.from_data(data, attr)) for attr in attributes
)

def render_relationships(self, data, request, include):
def render_relationships(self, data, context):
relationships = OrderedDict()
included = []
# Validate that all top-level include keys are actually relationships
rel_keys = {rel[0] for rel in self.relationships}
for key in include:
for key in context.include:
if key not in rel_keys:
raise IncludeInvalid('Invalid relationship to include: %s' % key)

for (name, rel) in self.relationships:
relationship, rel_included = self.render_relationship(data, name, rel, request, include)
filtered = self.filter_by_fields(self.relationships, context.fields, lambda x: x[0])
for (name, rel) in filtered:
relationship, rel_included = self.render_relationship(data, name, rel, context)
relationships[self.transformed_names[name]] = relationship
included.extend(rel_included)

return relationships, included

def render_relationship(self, data, rel_name, rel, request, include):
def render_relationship(self, data, rel_name, rel, context):
# This relationship is included if rel_name is in the include paths.
include_this = rel_name in include
include_paths = include.get(rel_name, {})

include_this = rel_name in context.include
# Create a new context by going one level deeper into the include paths.
rel_context = Context(
context.request,
context.include.get(rel_name, {}),
context.fields
)
rel_data = self.from_data(data, rel_name)
return rel.render(data, rel_data, request, include_this, include_paths)
return rel.render(data, rel_data, rel_context, include_this)

def from_data(self, data, attr):
# This is easy for now, but eventually we want to be able to specify
# functions and the like
return data[attr]

def filter_by_fields(self, names, fields, name_fn=lambda name: name):
"""
Filters the list of names by the list of fields.
"""
if self.type not in fields:
return names
type_fields = fields[self.type]
# This is essentially an intersection, but we preserve the order
# of the attributes/relationships specified by the schema.
return [name for name in names if self.transformed_names[name_fn(name)] in type_fields]


class ResourceIdObject(BaseLinkedObject):
"""
Expand Down Expand Up @@ -178,38 +206,38 @@ def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)

def render_included(self, rel_data, request, include_paths):
def render_included(self, rel_data, context):
# This recursively calls the resource's schema to render the full object.
schema = rel_data.get_schema()
obj, included = schema.render(rel_data.get_data(), request, include_paths)
obj, included = schema.render(rel_data.get_data(), context)
return [obj] + included

def render(self, obj_data, rel_data, request, include_this, include_paths):
def render(self, obj_data, rel_data, context, include_this):
result = OrderedDict()
included = []

if not rel_data:
# None or []
result['data'] = rel_data
elif isinstance(rel_data, ResourceIdObject):
result['data'] = rel_data.render(request)
result['data'] = rel_data.render(context.request)
if include_this:
included.extend(self.render_included(rel_data, request, include_paths))
included.extend(self.render_included(rel_data, context))
else:
# Probably a list of resource objects
if include_this:
result['data'] = []
for obj in rel_data:
result['data'].append(obj.render(request))
included.extend(self.render_included(obj, request, include_paths))
result['data'].append(obj.render(context.request))
included.extend(self.render_included(obj, context))
else:
result['data'] = [obj.render(request) for obj in rel_data]
result['data'] = [obj.render(context.request) for obj in rel_data]

links = self.render_links(obj_data, request)
links = self.render_links(obj_data, context)
if links:
result['links'] = links

meta = self.render_meta(obj_data, request)
meta = self.render_meta(obj_data, context)
if meta:
result['meta'] = meta
return result, included
Expand Down
45 changes: 45 additions & 0 deletions rest_framework_json_schema/tests/test_renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ def test_options(self):
}
})

def test_fields(self):
url = reverse('artist-detail', kwargs={'pk': 1})
request = self.factory.get(url, {'fields[artist]': 'firstName'})
response = self.view_detail(request, pk=1)
response.render()
self.assertEqual(response['Content-Type'], 'application/vnd.api+json')
self.assertJSONEqual(response.content.decode(), {
'data': {
'id': '1',
'type': 'artist',
'attributes': {
'firstName': 'John',
}
}
})


@override_settings(ROOT_URLCONF='rest_framework_json_schema.test_support.urls')
class JSONAPIRelationshipsRendererTestCase(APISimpleTestCase):
Expand Down Expand Up @@ -344,3 +360,32 @@ def test_include_to_many_and_paths(self):
}
]
})

def test_fields(self):
request = self.factory.get(reverse('album-detail', kwargs={'pk': 0}),
{'fields[album]': 'artist',
'fields[artist]': 'firstName',
'include': 'artist'})
response = self.view_detail(request, pk=0)
response.render()
self.assertEqual(response['Content-Type'], 'application/vnd.api+json')
self.assertJSONEqual(response.content.decode(), {
'data': {
'id': '0',
'type': 'album',
'relationships': {
'artist': {
'data': {'id': '1', 'type': 'artist'}
}
}
},
'included': [
{
'id': '1',
'type': 'artist',
'attributes': {
'firstName': 'John'
}
}
]
})
Loading

0 comments on commit 219aabf

Please sign in to comment.