diff --git a/example/app.py b/example/app.py index 34d064f2..30ba7094 100644 --- a/example/app.py +++ b/example/app.py @@ -305,6 +305,22 @@ class DictDocView(ResourceView): resource = DictDocResource methods = [Fetch, List, Create, Update] +# Document, resource, and view for testing renamed fields +class ReqTitlePost(db.Document): + title_str = db.StringField(required=True) + +class ReqTitlePostResource(Resource): + document = ReqTitlePost + schema = schemas.ReqTitlePost + rename_fields = { + 'title_str': 'title', + } + +@api.register(url='/title_post/') +class ReqTitlePostView(ResourceView): + resource = ReqTitlePostResource + methods = [Fetch, List, Create, Update] + if __name__ == "__main__": port = int(os.environ.get('PORT', 8000)) diff --git a/example/schemas.py b/example/schemas.py index 76089834..1add0e80 100644 --- a/example/schemas.py +++ b/example/schemas.py @@ -37,3 +37,9 @@ class Person(Schema): class DateTime(Schema): datetime = DateTime() + +class NoneString(String): + blank_value = None + +class ReqTitlePost(Schema): + title_str = NoneString(required=False, max_length=10) diff --git a/flask_mongorest/resources.py b/flask_mongorest/resources.py index 85c6968d..1e305691 100644 --- a/flask_mongorest/resources.py +++ b/flask_mongorest/resources.py @@ -509,6 +509,28 @@ def value_for_field(self, obj, field): """ raise UnknownFieldError + def _rename_dict(self, data): + self._rename_dict_with_mapping(data, self._rename_fields) + + def _reverse_rename_dict(self, data): + self._rename_dict_with_mapping(data, self._reverse_rename_fields) + + @staticmethod + def _rename_dict_with_mapping(data, mapping): + # Do renaming in two passes to prevent potential multiple renames + # depending on dict traversal order. + # E.g. if a -> b, b -> c, then a should never be renamed to c. + fields_to_delete = [] + fields_to_update = {} + for to_name, from_name in mapping.items(): + if from_name in data: + fields_to_update[to_name] = data[from_name] + fields_to_delete.append(from_name) + for k in fields_to_delete: + del data[k] + for k, v in fields_to_update.items(): + data[k] = v + def validate_request(self, obj=None): """ Validate the request that's currently being processed and fill in @@ -531,19 +553,7 @@ def validate_request(self, obj=None): # updates. self.data = self.raw_data.copy() - # Do renaming in two passes to prevent potential multiple renames - # depending on dict traversal order. - # E.g. if a -> b, b -> c, then a should never be renamed to c. - fields_to_delete = [] - fields_to_update = {} - for k, v in self._rename_fields.items(): - if v in self.data: - fields_to_update[k] = self.data[v] - fields_to_delete.append(v) - for k in fields_to_delete: - del self.data[k] - for k, v in fields_to_update.items(): - self.data[k] = v + self._rename_dict(self.data) # If CleanCat schema exists on this resource, use it to perform the # validation @@ -557,6 +567,7 @@ def validate_request(self, obj=None): try: self.data = schema.full_clean() except SchemaValidationError: + self._reverse_rename_dict(schema.field_errors) raise ValidationError({'field-errors': schema.field_errors, 'errors': schema.errors }) def get_queryset(self): diff --git a/flask_mongorest/views.py b/flask_mongorest/views.py index 9e353744..65c73f02 100644 --- a/flask_mongorest/views.py +++ b/flask_mongorest/views.py @@ -79,7 +79,10 @@ def handle_validation_error(self, e): if isinstance(e, ValidationError): raise elif isinstance(e, mongoengine.ValidationError): - raise ValidationError(serialize_mongoengine_validation_error(e)) + msg = serialize_mongoengine_validation_error(e) + if 'field-errors' in msg: + self._resource._reverse_rename_dict(msg['field-errors']) + raise ValidationError(msg) else: raise diff --git a/tests/__init__.py b/tests/__init__.py index ee6d5528..a092da6b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -104,6 +104,7 @@ def setUp(self): example.C.drop_collection() example.MethodTestDoc.drop_collection() example.DictDoc.drop_collection() + example.ReqTitlePost.drop_collection() # create user 1 resp = self.app.post('/user/', data=json.dumps(self.user_1)) @@ -1287,6 +1288,88 @@ def test_send_bad_json(self): # test list self.assertRaises(ValueError, self.app.get, '/dict_doc/') + def test_rename_fields_create(self): + """ + Make sure we can create objects by posting a renamed field consistent + with Resource#rename_fields. + """ + resp = self.app.post('/title_post/', data=json.dumps({'title': 'title'})) + response_success(resp) + self.assertEqual(example.ReqTitlePost.objects.first().title_str, 'title') + + def test_rename_fields_get(self): + """ + Make sure fetched objects contain a renamed field consistent with + Resource#rename_fields. + """ + resp = self.app.post('/title_post/', data=json.dumps({'title': 'title'})) + response_success(resp) + post = resp_json(resp) + + # list objects + resp = self.app.get('/title_post/') + response_success(resp) + self.assertEqual(resp_json(resp)['data'][0]['title'], 'title') + + # fetch a single object + resp = self.app.get('/title_post/%s/' % post['id']) + response_success(resp) + self.assertEqual(resp_json(resp)['title'], 'title') + + def test_rename_fields_error(self): + """ + Make sure field errors for a renamed field are returned correctly. + """ + # post with a missing required field + resp = self.app.post('/title_post/', data=json.dumps({ + 'title': None + })) + response_error(resp, code=400) + self.assertEqual(resp_json(resp), { + 'field-errors': {'title': 'Field is required'} + }) + + # create a valid object + resp = self.app.post('/title_post/', data=json.dumps({'title': 'title'})) + response_success(resp) + + # update with a missing required field + resp = self.app.put('/title_post/%s/' % resp_json(resp)['id'], data=json.dumps({ + 'title': None + })) + response_error(resp, code=400) + self.assertEqual(resp_json(resp), { + 'field-errors': {'title': 'Field is required'} + }) + + def test_rename_fields_schema_error(self): + """ + Make sure field errors for a renamed field are returned correctly. + """ + # post with a missing required field + resp = self.app.post('/title_post/', data=json.dumps({ + 'title': 'X'*20 + })) + response_error(resp, code=400) + self.assertEqual(resp_json(resp), { + 'field-errors': {'title': 'The value must be no longer than 10 characters.'}, + 'errors': [] + }) + + # create a valid object + resp = self.app.post('/title_post/', data=json.dumps({'title': 'title'})) + response_success(resp) + + # update with a missing required field + resp = self.app.put('/title_post/%s/' % resp_json(resp)['id'], data=json.dumps({ + 'title': 'X'*20 + })) + response_error(resp, code=400) + self.assertEqual(resp_json(resp), { + 'field-errors': {'title': 'The value must be no longer than 10 characters.'}, + 'errors': [], + }) + class InternalTestCase(unittest.TestCase): """ Test internal methods.