From e6190185c3640e09118da35abefbfeb7e0887227 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 8 Sep 2022 23:49:17 +0200 Subject: [PATCH] fields.Enum: merge by_value and field arguments --- src/marshmallow/fields.py | 57 ++++++++++++++++--------------- tests/base.py | 4 +-- tests/test_deserialization.py | 64 +++++++++++++++++++++++------------ tests/test_serialization.py | 21 ++++++++---- 4 files changed, 88 insertions(+), 58 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index ffd1fea5c..2d2920f52 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -1856,13 +1856,15 @@ class IPv6Interface(IPInterface): class Enum(Field): - """An Enum field (de)serializing enum members by symbol (name) as string or by value. + """An Enum field (de)serializing enum members by symbol (name) or by value. :param enum Enum: Enum class - :param boolean by_value: Whether to (de)serialize by value or by name. Defaults to False. - :param field: Field class or instance to use if (de)serializing by value. Defaults to Field. + :param boolean|Schema|Field by_value: Whether to (de)serialize by value or by name, + or Field class or instance to use to (de)serialize by value. Defaults to False. - ``field`` argument may only be passed if (de)serializing by value. + If `by_value` is `False` (default), enum members are (de)serialized by symbol (name). + If it is `True`, they are (de)serialized by value using :class:`Field`. + If it is a field instance of class, they are (de)serialized by value using this field. .. versionadded:: 3.18.0 """ @@ -1874,8 +1876,8 @@ class Enum(Field): def __init__( self, enum: type[EnumType], - by_value: bool = False, - field: Field | type | None = None, + *, + by_value: bool | Field | type = False, **kwargs, ): super().__init__(**kwargs) @@ -1883,48 +1885,47 @@ def __init__( self.by_value = by_value # Serialization by name - if self.by_value is False: - if field is not None: - raise ValueError('"field" can not be passed when serializing by name.') + if by_value is False: self.field: Field = String() - self.choices = ", ".join( - [str(self.field._serialize(m, None, None)) for m in enum.__members__] + self.choices_text = ", ".join( + str(self.field._serialize(m, None, None)) for m in enum.__members__ ) # Serialization by value else: - if field is not None: + if by_value is True: + self.field = Field() + else: try: - self.field = resolve_field_instance(field) + self.field = resolve_field_instance(by_value) except FieldInstanceResolutionError as error: raise ValueError( - '"field" must be a subclass or instance of ' + '"by_value" must be either a bool or a subclass or instance of ' "marshmallow.base.FieldABC." ) from error - else: - self.field = Field() - self.choices = ", ".join( - [str(self.field._serialize(m.value, None, None)) for m in enum] + self.choices_text = ", ".join( + str(self.field._serialize(m.value, None, None)) for m in enum ) def _serialize(self, value, attr, obj, **kwargs): if value is None: return None if self.by_value: - return self.field._serialize(value.value, attr, obj, **kwargs) - return value.name + val = value.value + else: + val = value.name + return self.field._serialize(val, attr, obj, **kwargs) def _deserialize(self, value, attr, data, **kwargs): + val = self.field._deserialize(value, attr, data, **kwargs) if self.by_value: - value = self.field._deserialize(value, attr, data, **kwargs) try: - return self.enum(value) - except ValueError as exc: - raise self.make_error("unknown", choices=self.choices) from exc - value = self.field._deserialize(value, attr, data, **kwargs) + return self.enum(val) + except ValueError as error: + raise self.make_error("unknown", choices=self.choices_text) from error try: - return getattr(self.enum, value) - except AttributeError as exc: - raise self.make_error("unknown", choices=self.choices) from exc + return getattr(self.enum, val) + except AttributeError as error: + raise self.make_error("unknown", choices=self.choices_text) from error class Method(Field): diff --git a/tests/base.py b/tests/base.py index 36014c35e..fea1e1b47 100644 --- a/tests/base.py +++ b/tests/base.py @@ -55,8 +55,8 @@ class DateEnum(Enum): fields.IPv4Interface, fields.IPv6Interface, functools.partial(fields.Enum, GenderEnum), - functools.partial(fields.Enum, HairColorEnum, fields.String), - functools.partial(fields.Enum, GenderEnum, fields.Integer), + functools.partial(fields.Enum, HairColorEnum, by_value=fields.String), + functools.partial(fields.Enum, GenderEnum, by_value=fields.Integer), ] diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index ecebc3a71..c80a2581f 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -1097,60 +1097,82 @@ def test_invalid_ipv6interface_deserialization(self, in_value): assert excinfo.value.args[0] == "Not a valid IPv6 interface." - def test_enum_by_symbol_field_deserialization(self): + def test_enum_field_by_symbol_deserialization(self): field = fields.Enum(GenderEnum) assert field.deserialize("male") == GenderEnum.male - def test_enum_by_symbol_field_invalid_value(self): + def test_enum_field_by_symbol_invalid_value(self): field = fields.Enum(GenderEnum) with pytest.raises( ValidationError, match="Must be one of: male, female, non_binary." ): field.deserialize("dummy") - def test_enum_by_symbol_field_not_string(self): + def test_enum_field_by_symbol_not_string(self): field = fields.Enum(GenderEnum) with pytest.raises(ValidationError, match="Not a valid string."): field.deserialize(12) - def test_enum_by_value_field_deserialization(self): - field = fields.Enum(HairColorEnum, by_value=True, field=fields.String) + def test_enum_field_by_value_true_deserialization(self): + field = fields.Enum(HairColorEnum, by_value=True) assert field.deserialize("black hair") == HairColorEnum.black - field = fields.Enum(GenderEnum, by_value=True, field=fields.Integer) + field = fields.Enum(GenderEnum, by_value=True) assert field.deserialize(1) == GenderEnum.male - field = fields.Enum( - DateEnum, by_value=True, field=fields.Date(format="%d/%m/%Y") - ) + + def test_enum_field_by_value_field_deserialization(self): + field = fields.Enum(HairColorEnum, by_value=fields.String) + assert field.deserialize("black hair") == HairColorEnum.black + field = fields.Enum(GenderEnum, by_value=fields.Integer) + assert field.deserialize(1) == GenderEnum.male + field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y")) assert field.deserialize("29/02/2004") == DateEnum.date_1 - def test_enum_by_value_field_invalid_value(self): - field = fields.Enum(HairColorEnum, by_value=True, field=fields.String) + def test_enum_field_by_value_true_invalid_value(self): + field = fields.Enum(HairColorEnum, by_value=True) with pytest.raises( ValidationError, match="Must be one of: black hair, brown hair, blond hair, red hair.", ): field.deserialize("dummy") - field = fields.Enum(GenderEnum, by_value=True, field=fields.Integer) + field = fields.Enum(GenderEnum, by_value=True) with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."): field.deserialize(12) - field = fields.Enum( - DateEnum, by_value=True, field=fields.Date(format="%d/%m/%Y") - ) + + def test_enum_field_by_value_field_invalid_value(self): + field = fields.Enum(HairColorEnum, by_value=fields.String) + with pytest.raises( + ValidationError, + match="Must be one of: black hair, brown hair, blond hair, red hair.", + ): + field.deserialize("dummy") + field = fields.Enum(GenderEnum, by_value=fields.Integer) + with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."): + field.deserialize(12) + field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y")) with pytest.raises( ValidationError, match="Must be one of: 29/02/2004, 29/02/2008, 29/02/2012." ): field.deserialize("28/02/2004") - def test_enum_by_value_field_wrong_type(self): - field = fields.Enum(HairColorEnum, by_value=True, field=fields.String) + def test_enum_field_by_value_true_wrong_type(self): + field = fields.Enum(HairColorEnum, by_value=True) + with pytest.raises( + ValidationError, + match="Must be one of: black hair, brown hair, blond hair, red hair.", + ): + field.deserialize("dummy") + field = fields.Enum(GenderEnum, by_value=True) + with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."): + field.deserialize(12) + + def test_enum_field_by_value_field_wrong_type(self): + field = fields.Enum(HairColorEnum, by_value=fields.String) with pytest.raises(ValidationError, match="Not a valid string."): field.deserialize(12) - field = fields.Enum(GenderEnum, by_value=True, field=fields.Integer) + field = fields.Enum(GenderEnum, by_value=fields.Integer) with pytest.raises(ValidationError, match="Not a valid integer."): field.deserialize("dummy") - field = fields.Enum( - DateEnum, by_value=True, field=fields.Date(format="%d/%m/%Y") - ) + field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y")) with pytest.raises(ValidationError, match="Not a valid date."): field.deserialize("30/02/2004") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 793ca60c1..51671ddf6 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -255,22 +255,29 @@ def test_ipv6_interface_field(self, user): == ipv6interface_exploded_string ) - def test_enum_by_symbol_field_serialization(self, user): + def test_enum_field_by_symbol_serialization(self, user): user.sex = GenderEnum.male field = fields.Enum(GenderEnum) assert field.serialize("sex", user) == "male" - def test_enum_by_value_field_serialization(self, user): + def test_enum_field_by_value_true_serialization(self, user): user.hair_color = HairColorEnum.black - field = fields.Enum(HairColorEnum, by_value=True, field=fields.String) + field = fields.Enum(HairColorEnum, by_value=True) assert field.serialize("hair_color", user) == "black hair" user.sex = GenderEnum.male - field = fields.Enum(GenderEnum, by_value=True, field=fields.Integer) + field = fields.Enum(GenderEnum, by_value=True) assert field.serialize("sex", user) == 1 user.some_date = DateEnum.date_1 - field = fields.Enum( - DateEnum, by_value=True, field=fields.Date(format="%d/%m/%Y") - ) + + def test_enum_field_by_value_field_serialization(self, user): + user.hair_color = HairColorEnum.black + field = fields.Enum(HairColorEnum, by_value=fields.String) + assert field.serialize("hair_color", user) == "black hair" + user.sex = GenderEnum.male + field = fields.Enum(GenderEnum, by_value=fields.Integer) + assert field.serialize("sex", user) == 1 + user.some_date = DateEnum.date_1 + field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y")) assert field.serialize("some_date", user) == "29/02/2004" def test_decimal_field(self, user):