diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index cb98b8ea2..d55a9188e 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -283,6 +283,8 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment { deserialize_from(type_from) } else if let Some(type_try_from) = cont.attrs.type_try_from() { deserialize_try_from(type_try_from) + } else if let Some(with) = cont.attrs.deserialize_with() { + deserialize_with(with) } else if let attr::Identifier::No = cont.attrs.identifier() { match &cont.data { Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs), @@ -311,6 +313,7 @@ fn deserialize_in_place_body(cont: &Container, params: &Parameters) -> Option Fragment { } } +fn deserialize_with(with: &syn::ExprPath) -> Fragment { + quote_block! { + #with(__deserializer) + } +} + fn deserialize_unit_struct(params: &Parameters, cattrs: &attr::Container) -> Fragment { let this_type = ¶ms.this_type; let this_value = ¶ms.this_value; diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index 42212a64d..89f43ed45 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -218,6 +218,8 @@ pub struct Container { is_packed: bool, /// Error message generated when type can't be deserialized expecting: Option, + serialize_with: Option, + deserialize_with: Option, } /// Styles of representing an enum. @@ -301,6 +303,8 @@ impl Container { let mut variant_identifier = BoolAttr::none(cx, VARIANT_IDENTIFIER); let mut serde_path = Attr::none(cx, CRATE); let mut expecting = Attr::none(cx, EXPECTING); + let mut serialize_with = Attr::none(cx, SERIALIZE_WITH); + let mut deserialize_with = Attr::none(cx, DESERIALIZE_WITH); for attr in &item.attrs { if attr.path() != SERDE { @@ -493,6 +497,32 @@ impl Container { if let Some(s) = get_lit_str(cx, EXPECTING, &meta)? { expecting.set(&meta.path, s.value()); } + } else if meta.path == WITH { + // #[serde(with = "...")] + if let Some(path) = parse_lit_into_expr_path(cx, WITH, &meta)? { + let mut ser_path = path.clone(); + ser_path + .path + .segments + .push(Ident::new("serialize", Span::call_site()).into()); + serialize_with.set(&meta.path, ser_path); + let mut de_path = path; + de_path + .path + .segments + .push(Ident::new("deserialize", Span::call_site()).into()); + deserialize_with.set(&meta.path, de_path); + } + } else if meta.path == SERIALIZE_WITH { + // #[serde(serialize_with = "...")] + if let Some(path) = parse_lit_into_expr_path(cx, SERIALIZE_WITH, &meta)? { + serialize_with.set(&meta.path, path); + } + } else if meta.path == DESERIALIZE_WITH { + // #[serde(deserialize_with = "...")] + if let Some(path) = parse_lit_into_expr_path(cx, DESERIALIZE_WITH, &meta)? { + deserialize_with.set(&meta.path, path); + } } else { let path = meta.path.to_token_stream().to_string().replace(' ', ""); return Err( @@ -540,6 +570,8 @@ impl Container { serde_path: serde_path.get(), is_packed, expecting: expecting.get(), + serialize_with: serialize_with.get(), + deserialize_with: deserialize_with.get(), } } @@ -621,6 +653,14 @@ impl Container { pub fn expecting(&self) -> Option<&str> { self.expecting.as_ref().map(String::as_ref) } + + pub fn serialize_with(&self) -> Option<&syn::ExprPath> { + self.serialize_with.as_ref() + } + + pub fn deserialize_with(&self) -> Option<&syn::ExprPath> { + self.deserialize_with.as_ref() + } } fn decide_tag( diff --git a/serde_derive/src/internals/check.rs b/serde_derive/src/internals/check.rs index 4a7f52c6c..f95b94538 100644 --- a/serde_derive/src/internals/check.rs +++ b/serde_derive/src/internals/check.rs @@ -15,6 +15,8 @@ pub fn check(cx: &Ctxt, cont: &mut Container, derive: Derive) { check_adjacent_tag_conflict(cx, cont); check_transparent(cx, cont, derive); check_from_and_try_from(cx, cont); + check_serialize_with(cx, cont); + check_deserialize_with(cx, cont); } // Remote derive definition type must have either all of the generics of the @@ -354,6 +356,20 @@ fn check_transparent(cx: &Ctxt, cont: &mut Container, derive: Derive) { ); } + if cont.attrs.serialize_with().is_some() { + cx.error_spanned_by( + cont.original, + "#[serde(transparent)] is not allowed with #[serde(serialize_with = \"...\")]", + ); + } + + if cont.attrs.deserialize_with().is_some() { + cx.error_spanned_by( + cont.original, + "#[serde(transparent)] is not allowed with #[serde(deserialize_with = \"...\")]", + ); + } + let fields = match &mut cont.data { Data::Enum(_) => { cx.error_spanned_by( @@ -436,3 +452,36 @@ fn check_from_and_try_from(cx: &Ctxt, cont: &mut Container) { ); } } + +fn check_serialize_with(cx: &Ctxt, cont: &mut Container) { + if cont.attrs.serialize_with().is_none() { + return; + } + + if cont.attrs.type_into().is_some() { + cx.error_spanned_by( + cont.original, + "#[serde(into = \"...\")] and #[serde(serialize_with = \"...\")] conflict with each other", + ); + } +} + +fn check_deserialize_with(cx: &Ctxt, cont: &mut Container) { + if cont.attrs.deserialize_with().is_none() { + return; + } + + if cont.attrs.type_from().is_some() { + cx.error_spanned_by( + cont.original, + "#[serde(from = \"...\")] and #[serde(deserialize_with = \"...\")] conflict with each other", + ); + } + + if cont.attrs.type_try_from().is_some() { + cx.error_spanned_by( + cont.original, + "#[serde(try_from = \"...\")] and #[serde(deserialize_with = \"...\")] conflict with each other", + ); + } +} diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index b9a9dce21..a6049ca71 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -171,6 +171,8 @@ fn serialize_body(cont: &Container, params: &Parameters) -> Fragment { serialize_transparent(cont, params) } else if let Some(type_into) = cont.attrs.type_into() { serialize_into(params, type_into) + } else if let Some(with) = cont.attrs.serialize_with() { + serialize_with(params, with) } else { match &cont.data { Data::Enum(variants) => serialize_enum(params, variants, &cont.attrs), @@ -218,6 +220,13 @@ fn serialize_into(params: &Parameters, type_into: &syn::Type) -> Fragment { } } +fn serialize_with(params: &Parameters, with: &syn::ExprPath) -> Fragment { + let self_var = ¶ms.self_var; + quote_block! { + #with(#self_var, __serializer) + } +} + fn serialize_unit_struct(cattrs: &attr::Container) -> Fragment { let type_name = cattrs.name().serialize_name(); diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 82a98c36a..6e9e9b7cd 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -3115,3 +3115,85 @@ fn test_expecting_message_identifier_enum() { r#"invalid type: map, expected something strange..."#, ); } + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[serde(with = "enum_with")] +enum EnumWith { + One, + Two, +} + +mod enum_with { + use super::EnumWith; + use serde::{Deserialize, Deserializer, Serializer}; + + pub(super) fn serialize(value: &EnumWith, serializer: S) -> Result + where + S: Serializer, + { + match value { + EnumWith::One => serializer.serialize_u32(1), + EnumWith::Two => serializer.serialize_u32(2), + } + } + + pub(super) fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + match u32::deserialize(deserializer)? { + 1 => Ok(EnumWith::One), + 2 => Ok(EnumWith::Two), + _ => Err(serde::de::Error::custom("out of range")), + } + } +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[serde(with = "num_str")] +struct NumStr(String); + +impl NumStr { + fn validate(&self) -> Result<(), &'static str> { + if self.0.chars().all(|c| c >= '0' && c <= '9') { + Ok(()) + } else { + Err("non-numeric string") + } + } +} + +mod num_str { + use crate::NumStr; + use serde::{Deserialize, Deserializer, Serializer}; + + pub(super) fn serialize(value: &NumStr, serializer: S) -> Result + where + S: Serializer, + { + value.validate().map_err(serde::ser::Error::custom)?; + serializer.serialize_str(&value.0) + } + + pub(super) fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + let n = NumStr(s); + n.validate().map_err(serde::de::Error::custom)?; + Ok(n) + } +} + +#[test] +fn test_container_with() { + assert_ser_tokens(&EnumWith::One, &[Token::U32(1)]); + assert_de_tokens(&EnumWith::Two, &[Token::U32(2)]); + assert_de_tokens_error::(&[Token::U32(5)], "out of range"); + + assert_ser_tokens(&NumStr("123".to_string()), &[Token::Str("123")]); + assert_ser_tokens_error(&NumStr("12ab3".to_string()), &[], "non-numeric string"); + assert_de_tokens(&NumStr("567".to_string()), &[Token::Str("567")]); + assert_de_tokens_error::(&[Token::Str("abc")], "non-numeric string"); +}