Skip to content

Commit

Permalink
Add #[serde(with)] for containers
Browse files Browse the repository at this point in the history
  • Loading branch information
oblique committed Jul 18, 2023
1 parent 03da66c commit b71ce2c
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 0 deletions.
9 changes: 9 additions & 0 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment {
deserialize_from(type_from)
} else if let Some(type_try_from) = cont.attrs.type_try_from() {
deserialize_try_from(type_try_from)
} else if let Some(with) = cont.attrs.deserialize_with() {
deserialize_with(with)
} else if let attr::Identifier::No = cont.attrs.identifier() {
match &cont.data {
Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs),
Expand Down Expand Up @@ -311,6 +313,7 @@ fn deserialize_in_place_body(cont: &Container, params: &Parameters) -> Option<St
if cont.attrs.transparent()
|| cont.attrs.type_from().is_some()
|| cont.attrs.type_try_from().is_some()
|| cont.attrs.deserialize_with().is_some()
|| cont.attrs.identifier().is_some()
|| cont
.data
Expand Down Expand Up @@ -406,6 +409,12 @@ fn deserialize_try_from(type_try_from: &syn::Type) -> Fragment {
}
}

fn deserialize_with(with: &syn::ExprPath) -> Fragment {
quote_block! {
#with(__deserializer)
}
}

fn deserialize_unit_struct(params: &Parameters, cattrs: &attr::Container) -> Fragment {
let this_type = &params.this_type;
let this_value = &params.this_value;
Expand Down
40 changes: 40 additions & 0 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ pub struct Container {
is_packed: bool,
/// Error message generated when type can't be deserialized
expecting: Option<String>,
serialize_with: Option<syn::ExprPath>,
deserialize_with: Option<syn::ExprPath>,
}

/// Styles of representing an enum.
Expand Down Expand Up @@ -301,6 +303,8 @@ impl Container {
let mut variant_identifier = BoolAttr::none(cx, VARIANT_IDENTIFIER);
let mut serde_path = Attr::none(cx, CRATE);
let mut expecting = Attr::none(cx, EXPECTING);
let mut serialize_with = Attr::none(cx, SERIALIZE_WITH);
let mut deserialize_with = Attr::none(cx, DESERIALIZE_WITH);

for attr in &item.attrs {
if attr.path() != SERDE {
Expand Down Expand Up @@ -493,6 +497,32 @@ impl Container {
if let Some(s) = get_lit_str(cx, EXPECTING, &meta)? {
expecting.set(&meta.path, s.value());
}
} else if meta.path == WITH {
// #[serde(with = "...")]
if let Some(path) = parse_lit_into_expr_path(cx, WITH, &meta)? {
let mut ser_path = path.clone();
ser_path
.path
.segments
.push(Ident::new("serialize", Span::call_site()).into());
serialize_with.set(&meta.path, ser_path);
let mut de_path = path;
de_path
.path
.segments
.push(Ident::new("deserialize", Span::call_site()).into());
deserialize_with.set(&meta.path, de_path);
}
} else if meta.path == SERIALIZE_WITH {
// #[serde(serialize_with = "...")]
if let Some(path) = parse_lit_into_expr_path(cx, SERIALIZE_WITH, &meta)? {
serialize_with.set(&meta.path, path);
}
} else if meta.path == DESERIALIZE_WITH {
// #[serde(deserialize_with = "...")]
if let Some(path) = parse_lit_into_expr_path(cx, DESERIALIZE_WITH, &meta)? {
deserialize_with.set(&meta.path, path);
}
} else {
let path = meta.path.to_token_stream().to_string().replace(' ', "");
return Err(
Expand Down Expand Up @@ -540,6 +570,8 @@ impl Container {
serde_path: serde_path.get(),
is_packed,
expecting: expecting.get(),
serialize_with: serialize_with.get(),
deserialize_with: deserialize_with.get(),
}
}

Expand Down Expand Up @@ -621,6 +653,14 @@ impl Container {
pub fn expecting(&self) -> Option<&str> {
self.expecting.as_ref().map(String::as_ref)
}

pub fn serialize_with(&self) -> Option<&syn::ExprPath> {
self.serialize_with.as_ref()
}

pub fn deserialize_with(&self) -> Option<&syn::ExprPath> {
self.deserialize_with.as_ref()
}
}

fn decide_tag(
Expand Down
49 changes: 49 additions & 0 deletions serde_derive/src/internals/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub fn check(cx: &Ctxt, cont: &mut Container, derive: Derive) {
check_adjacent_tag_conflict(cx, cont);
check_transparent(cx, cont, derive);
check_from_and_try_from(cx, cont);
check_serialize_with(cx, cont);
check_deserialize_with(cx, cont);
}

// Remote derive definition type must have either all of the generics of the
Expand Down Expand Up @@ -354,6 +356,20 @@ fn check_transparent(cx: &Ctxt, cont: &mut Container, derive: Derive) {
);
}

if cont.attrs.serialize_with().is_some() {
cx.error_spanned_by(
cont.original,
"#[serde(transparent)] is not allowed with #[serde(serialize_with = \"...\")]",
);
}

if cont.attrs.deserialize_with().is_some() {
cx.error_spanned_by(
cont.original,
"#[serde(transparent)] is not allowed with #[serde(deserialize_with = \"...\")]",
);
}

let fields = match &mut cont.data {
Data::Enum(_) => {
cx.error_spanned_by(
Expand Down Expand Up @@ -436,3 +452,36 @@ fn check_from_and_try_from(cx: &Ctxt, cont: &mut Container) {
);
}
}

fn check_serialize_with(cx: &Ctxt, cont: &mut Container) {
if cont.attrs.serialize_with().is_none() {
return;
}

if cont.attrs.type_into().is_some() {
cx.error_spanned_by(
cont.original,
"#[serde(into = \"...\")] and #[serde(serialize_with = \"...\")] conflict with each other",
);
}
}

fn check_deserialize_with(cx: &Ctxt, cont: &mut Container) {
if cont.attrs.deserialize_with().is_none() {
return;
}

if cont.attrs.type_from().is_some() {
cx.error_spanned_by(
cont.original,
"#[serde(from = \"...\")] and #[serde(deserialize_with = \"...\")] conflict with each other",
);
}

if cont.attrs.type_try_from().is_some() {
cx.error_spanned_by(
cont.original,
"#[serde(try_from = \"...\")] and #[serde(deserialize_with = \"...\")] conflict with each other",
);
}
}
9 changes: 9 additions & 0 deletions serde_derive/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ fn serialize_body(cont: &Container, params: &Parameters) -> Fragment {
serialize_transparent(cont, params)
} else if let Some(type_into) = cont.attrs.type_into() {
serialize_into(params, type_into)
} else if let Some(with) = cont.attrs.serialize_with() {
serialize_with(params, with)
} else {
match &cont.data {
Data::Enum(variants) => serialize_enum(params, variants, &cont.attrs),
Expand Down Expand Up @@ -218,6 +220,13 @@ fn serialize_into(params: &Parameters, type_into: &syn::Type) -> Fragment {
}
}

fn serialize_with(params: &Parameters, with: &syn::ExprPath) -> Fragment {
let self_var = &params.self_var;
quote_block! {
#with(#self_var, __serializer)
}
}

fn serialize_unit_struct(cattrs: &attr::Container) -> Fragment {
let type_name = cattrs.name().serialize_name();

Expand Down
82 changes: 82 additions & 0 deletions test_suite/tests/test_annotations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3115,3 +3115,85 @@ fn test_expecting_message_identifier_enum() {
r#"invalid type: map, expected something strange..."#,
);
}

#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(with = "enum_with")]
enum EnumWith {
One,
Two,
}

mod enum_with {
use super::EnumWith;
use serde::{Deserialize, Deserializer, Serializer};

pub(super) fn serialize<S>(value: &EnumWith, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
EnumWith::One => serializer.serialize_u32(1),
EnumWith::Two => serializer.serialize_u32(2),
}
}

pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<EnumWith, D::Error>
where
D: Deserializer<'de>,
{
match u32::deserialize(deserializer)? {
1 => Ok(EnumWith::One),
2 => Ok(EnumWith::Two),
_ => Err(serde::de::Error::custom("out of range")),
}
}
}

#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(with = "num_str")]
struct NumStr(String);

impl NumStr {
fn validate(&self) -> Result<(), &'static str> {
if self.0.chars().all(|c| c >= '0' && c <= '9') {
Ok(())
} else {
Err("non-numeric string")
}
}
}

mod num_str {
use crate::NumStr;
use serde::{Deserialize, Deserializer, Serializer};

pub(super) fn serialize<S>(value: &NumStr, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
value.validate().map_err(serde::ser::Error::custom)?;
serializer.serialize_str(&value.0)
}

pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<NumStr, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let n = NumStr(s);
n.validate().map_err(serde::de::Error::custom)?;
Ok(n)
}
}

#[test]
fn test_container_with() {
assert_ser_tokens(&EnumWith::One, &[Token::U32(1)]);
assert_de_tokens(&EnumWith::Two, &[Token::U32(2)]);
assert_de_tokens_error::<EnumWith>(&[Token::U32(5)], "out of range");

assert_ser_tokens(&NumStr("123".to_string()), &[Token::Str("123")]);
assert_ser_tokens_error(&NumStr("12ab3".to_string()), &[], "non-numeric string");
assert_de_tokens(&NumStr("567".to_string()), &[Token::Str("567")]);
assert_de_tokens_error::<NumStr>(&[Token::Str("abc")], "non-numeric string");
}

0 comments on commit b71ce2c

Please sign in to comment.