Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement attribute to disallow deserialization of struct and struct variant with named fields from sequence #2639

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions serde_derive/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -961,10 +961,10 @@ fn deserialize_struct(
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs);

// untagged struct variants do not get a visit_seq method. The same applies to
// structs that only have a map representation.
// structs that only have a map representation or are deserialized strict_or_some_other_name_ly.
let visit_seq = match form {
StructForm::Untagged(..) => None,
_ if cattrs.has_flatten() => None,
_ if cattrs.has_flatten() || cattrs.is_strict_or_some_other_name() => None,
_ => {
let mut_seq = if field_names_idents.is_empty() {
quote!(_)
Expand Down Expand Up @@ -1118,12 +1118,26 @@ fn deserialize_struct_in_place(

let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs);

let mut_seq = if field_names_idents.is_empty() {
quote!(_)
let visit_seq_fn = if !cattrs.is_strict_or_some_other_name() {
let mut_seq = if field_names_idents.is_empty() {
quote!(_)
} else {
quote!(mut __seq)
};
let visit_seq = Stmts(deserialize_seq_in_place(params, fields, cattrs, expecting));

Some(quote!(
#[inline]
fn visit_seq<__A>(self, #mut_seq: __A) -> _serde::__private::Result<Self::Value, __A::Error>
where
__A: _serde::de::SeqAccess<#delife>,
{
#visit_seq
}
))
} else {
quote!(mut __seq)
None
};
let visit_seq = Stmts(deserialize_seq_in_place(params, fields, cattrs, expecting));
let visit_map = Stmts(deserialize_map_in_place(params, fields, cattrs));
let field_names = field_names_idents
.iter()
Expand All @@ -1150,13 +1164,7 @@ fn deserialize_struct_in_place(
_serde::__private::Formatter::write_str(__formatter, #expecting)
}

#[inline]
fn visit_seq<__A>(self, #mut_seq: __A) -> _serde::__private::Result<Self::Value, __A::Error>
where
__A: _serde::de::SeqAccess<#delife>,
{
#visit_seq
}
#visit_seq_fn

#[inline]
fn visit_map<__A>(self, mut __map: __A) -> _serde::__private::Result<Self::Value, __A::Error>
Expand Down Expand Up @@ -2001,6 +2009,7 @@ fn deserialize_generated_identifier(
None,
!is_variant && cattrs.has_flatten(),
None,
cattrs.is_strict_or_some_other_name(),
));

let lifetime = if !is_variant && cattrs.has_flatten() {
Expand Down Expand Up @@ -2155,6 +2164,7 @@ fn deserialize_custom_identifier(
fallthrough_borrowed,
false,
cattrs.expecting(),
false,
));

quote_block! {
Expand Down Expand Up @@ -2188,6 +2198,7 @@ fn deserialize_identifier(
fallthrough_borrowed: Option<TokenStream>,
collect_other_fields: bool,
expecting: Option<&str>,
is_strict_or_some_other_name: bool,
) -> Fragment {
let str_mapping = fields.iter().map(|(_, ident, aliases)| {
// `aliases` also contains a main name
Expand Down Expand Up @@ -2255,7 +2266,7 @@ fn deserialize_identifier(
};

let visit_other = if collect_other_fields {
quote! {
Some(quote! {
fn visit_bool<__E>(self, __value: bool) -> _serde::__private::Result<Self::Value, __E>
where
__E: _serde::de::Error,
Expand Down Expand Up @@ -2346,8 +2357,8 @@ fn deserialize_identifier(
{
_serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Unit))
}
}
} else {
})
} else if !is_strict_or_some_other_name {
let u64_mapping = fields.iter().enumerate().map(|(i, (_, ident, _))| {
let i = i as u64;
quote!(#i => _serde::__private::Ok(#this_value::#ident))
Expand All @@ -2368,7 +2379,7 @@ fn deserialize_identifier(
&u64_fallthrough_arm_tokens
};

quote! {
Some(quote! {
fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result<Self::Value, __E>
where
__E: _serde::de::Error,
Expand All @@ -2378,7 +2389,9 @@ fn deserialize_identifier(
_ => #u64_fallthrough_arm,
}
}
}
})
} else {
None
};

let visit_borrowed = if fallthrough_borrowed.is_some() || collect_other_fields {
Expand Down
28 changes: 28 additions & 0 deletions serde_derive/src/internals/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ pub struct Container {
rename_all_fields_rules: RenameAllRules,
ser_bound: Option<Vec<syn::WherePredicate>>,
de_bound: Option<Vec<syn::WherePredicate>>,
is_strict_or_some_other_name: bool,
tag: TagType,
type_from: Option<syn::Type>,
type_try_from: Option<syn::Type>,
Expand Down Expand Up @@ -296,6 +297,7 @@ impl Container {
let mut rename_all_fields_de_rule = Attr::none(cx, RENAME_ALL_FIELDS);
let mut ser_bound = Attr::none(cx, BOUND);
let mut de_bound = Attr::none(cx, BOUND);
let mut strict_or_some_other_name = BoolAttr::none(cx, STRICT_OR_SOME_OTHER_NAME);
let mut untagged = BoolAttr::none(cx, UNTAGGED);
let mut internal_tag = Attr::none(cx, TAG);
let mut content = Attr::none(cx, CONTENT);
Expand Down Expand Up @@ -446,6 +448,27 @@ impl Container {
let (ser, de) = get_where_predicates(cx, &meta)?;
ser_bound.set_opt(&meta.path, ser);
de_bound.set_opt(&meta.path, de);
} else if meta.path == STRICT_OR_SOME_OTHER_NAME {
// #[serde(strict_or_some_other_name)]
let msg = "#[serde(strict_or_some_other_name)] can only be used on structs with named fields or enums";
match &item.data {
syn::Data::Struct(syn::DataStruct { fields, .. }) => {
match fields {
syn::Fields::Named(_) => {
strict_or_some_other_name.set_true(&meta.path);
}
_ => {
cx.syn_error(meta.error(msg));
}
};
}
syn::Data::Enum(_) => {
strict_or_some_other_name.set_true(&meta.path);
}
_ => {
cx.syn_error(meta.error(msg));
}
}
} else if meta.path == UNTAGGED {
// #[serde(untagged)]
match item.data {
Expand Down Expand Up @@ -581,6 +604,7 @@ impl Container {
},
ser_bound: ser_bound.get(),
de_bound: de_bound.get(),
is_strict_or_some_other_name: strict_or_some_other_name.get(),
tag: decide_tag(cx, item, untagged, internal_tag, content),
type_from: type_from.get(),
type_try_from: type_try_from.get(),
Expand Down Expand Up @@ -627,6 +651,10 @@ impl Container {
self.de_bound.as_ref().map(|vec| &vec[..])
}

pub fn is_strict_or_some_other_name(&self) -> bool {
self.is_strict_or_some_other_name
}

pub fn tag(&self) -> &TagType {
&self.tag
}
Expand Down
2 changes: 2 additions & 0 deletions serde_derive/src/internals/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub const SKIP: Symbol = Symbol("skip");
pub const SKIP_DESERIALIZING: Symbol = Symbol("skip_deserializing");
pub const SKIP_SERIALIZING: Symbol = Symbol("skip_serializing");
pub const SKIP_SERIALIZING_IF: Symbol = Symbol("skip_serializing_if");

pub const STRICT_OR_SOME_OTHER_NAME: Symbol = Symbol("strict_or_some_other_name");
pub const TAG: Symbol = Symbol("tag");
pub const TRANSPARENT: Symbol = Symbol("transparent");
pub const TRY_FROM: Symbol = Symbol("try_from");
Expand Down
35 changes: 35 additions & 0 deletions test_suite/tests/test_de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ struct StructDefault<T> {
b: T,
}

#[derive(PartialEq, Debug, Deserialize)]
struct StructNonStrictOrSomeOtherName {
a: i32,
}

#[derive(PartialEq, Debug, Deserialize)]
#[serde(strict_or_some_other_name)]
struct StructStrictOrSomeOtherName {
a: i32,
}

impl Default for StructDefault<String> {
fn default() -> Self {
StructDefault {
Expand Down Expand Up @@ -1592,6 +1603,30 @@ fn test_struct_default() {
);
}

#[test]
fn test_struct_non_strict_or_some_other_name() {
test(
StructNonStrictOrSomeOtherName { a: 50 },
&[Token::Seq { len: Some(1) }, Token::I32(50), Token::SeqEnd],
);
}

#[test]
fn test_struct_strict_or_some_other_name() {
test(
StructStrictOrSomeOtherName { a: 50 },
&[
Token::Struct {
name: "StructStrictOrSomeOtherName",
len: 1,
},
Token::Str("a"),
Token::I32(50),
Token::StructEnd,
],
);
}

#[test]
fn test_enum_unit() {
test(
Expand Down
103 changes: 103 additions & 0 deletions test_suite/tests/test_de_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,27 @@ struct StructSkipAllDenyUnknown {
a: i32,
}

#[derive(PartialEq, Debug, Deserialize)]
#[serde(strict_or_some_other_name)]
struct StructStrictOrSomeOtherName {
a: i32,
}

#[derive(Default, PartialEq, Debug)]
struct NotDeserializable;

#[derive(Debug, PartialEq, Deserialize)]
struct FlattenStrictStrictOrSomeOtherName {
#[serde(flatten)]
data: EnumStrictStrictOrSomeOtherName,
}

#[derive(PartialEq, Debug, Deserialize)]
#[serde(strict_or_some_other_name)]
enum EnumStrictStrictOrSomeOtherName {
Map { a: i32, b: i32, c: i32 },
}

#[derive(PartialEq, Debug, Deserialize)]
enum Enum {
#[allow(dead_code)]
Expand All @@ -68,6 +86,26 @@ enum EnumSkipAll {
Skipped,
}

#[derive(Debug, PartialEq, Deserialize)]
#[serde(tag = "type")]
#[serde(strict_or_some_other_name)]
enum InternallyTaggedStrictOrSomeOtherName {
A { a: u8 },
B(StructStrictOrSomeOtherName),
}

#[derive(Debug, PartialEq, Deserialize)]
#[serde(tag = "type")]
enum OuterStrictOrSomeOtherName {
Inner(InnerStrictOrSomeOtherName),
}

#[derive(Debug, PartialEq, Deserialize)]
#[serde(strict_or_some_other_name)]
enum InnerStrictOrSomeOtherName {
Struct { f: u8 },
}

#[test]
fn test_i8() {
let test = assert_de_tokens_error::<i8>;
Expand Down Expand Up @@ -1179,6 +1217,14 @@ fn test_skip_all_deny_unknown() {
);
}

#[test]
fn test_strict_or_some_other_name() {
assert_de_tokens_error::<StructStrictOrSomeOtherName>(
&[Token::Seq { len: Some(1) }],
"invalid type: sequence, expected struct StructStrictOrSomeOtherName",
);
}

#[test]
fn test_unknown_variant() {
assert_de_tokens_error::<Enum>(
Expand Down Expand Up @@ -1248,6 +1294,63 @@ fn test_enum_out_of_range() {
);
}

#[test]
fn test_struct_variant_strict_or_some_other_name_flatten() {
assert_de_tokens_error::<FlattenStrictStrictOrSomeOtherName>(
&[
Token::Map { len: None },
Token::Str("Map"), // variant
Token::Seq { len: Some(3) },
Token::U32(0), // a
Token::U32(42), // b
Token::U32(69), // c
Token::SeqEnd,
Token::MapEnd,
],
"invalid type: sequence, expected struct variant EnumStrictStrictOrSomeOtherName::Map",
);
}

#[test]
fn test_struct_variant_strict_or_some_other_name_internally_tagged() {
assert_de_tokens_error::<InternallyTaggedStrictOrSomeOtherName>(
&[
Token::Seq { len: Some(2) },
Token::Str("A"),
Token::U8(1),
Token::SeqEnd,
],
"invalid type: sequence, expected struct variant InternallyTaggedStrictOrSomeOtherName::A",
);
assert_de_tokens_error::<InternallyTaggedStrictOrSomeOtherName>(
&[
Token::Seq { len: Some(2) },
Token::Str("B"),
Token::I32(0),
Token::I32(42),
Token::SeqEnd,
],
"invalid type: sequence, expected struct StructStrictOrSomeOtherName",
);
}

#[test]
fn test_struct_variant_strict_or_some_other_name_enum_in_internally_tagged_enum() {
assert_de_tokens_error::<OuterStrictOrSomeOtherName>(
&[
Token::Map { len: Some(2) },
Token::Str("type"),
Token::Str("Inner"),
Token::Str("Struct"),
Token::Seq { len: Some(1) },
Token::U8(69),
Token::SeqEnd,
Token::MapEnd,
],
"invalid type: sequence, expected struct variant InnerStrictOrSomeOtherName::Struct",
);
}

#[test]
fn test_short_tuple() {
assert_de_tokens_error::<(u8, u8, u8)>(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use serde_derive::Deserialize;

#[derive(Deserialize)]
#[serde(strict_or_some_other_name)]
struct S(u8);

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: #[serde(strict_or_some_other_name)] can only be used on structs with named fields or enums
--> tests/ui/strict-or-some-other-name/newtype-struct.rs:4:9
|
4 | #[serde(strict_or_some_other_name)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^
7 changes: 7 additions & 0 deletions test_suite/tests/ui/strict-or-some-other-name/tuple-struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use serde_derive::Deserialize;

#[derive(Deserialize)]
#[serde(strict_or_some_other_name)]
struct S(u8, u8);

fn main() {}
Loading