Skip to content

Commit

Permalink
Add limit for codec indexes
Browse files Browse the repository at this point in the history
closes #507
  • Loading branch information
pkhry committed Sep 11, 2024
1 parent 224b834 commit e6b0b0f
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 44 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ The derive implementation supports the following attributes:
- `codec(encoded_as = "OtherType")`: Needs to be placed above a field and makes the field being
encoded by using `OtherType`.
- `codec(index = 0)`: Needs to be placed above an enum variant to make the variant use the given
index when encoded. By default the index is determined by counting from `0` beginning wth the
index when encoded. By default the index is determined by counting from `0` beginning with the
first variant.
- `codec(encode_bound)`, `codec(decode_bound)` and `codec(mel_bound)`: All 3 attributes take
in a `where` clause for the `Encode`, `Decode` and `MaxEncodedLen` trait implementation for
Expand Down
23 changes: 16 additions & 7 deletions derive/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use proc_macro2::{Ident, Span, TokenStream};
use syn::{spanned::Spanned, Data, Error, Field, Fields};

use crate::utils;
use crate::utils::{self, UsedIndexes};

/// Generate function block for function `Decode::decode`.
///
Expand Down Expand Up @@ -57,9 +57,17 @@ pub fn quote(
.to_compile_error();
}

let recurse = data_variants().enumerate().map(|(i, v)| {
let mut used_indexes = match UsedIndexes::from_iter(data_variants()) {
Ok(index) => index,
Err(e) => return e.into_compile_error(),
};
let mut items = vec![];
for v in data_variants() {
let name = &v.ident;
let index = utils::variant_index(v, i);
let index = match used_indexes.variant_index(v) {
Ok(index) => index,
Err(e) => return e.into_compile_error(),
};

let create = create_instance(
quote! { #type_name #type_generics :: #name },
Expand All @@ -69,7 +77,7 @@ pub fn quote(
crate_path,
);

quote_spanned! { v.span() =>
let item = quote_spanned! { v.span() =>
#[allow(clippy::unnecessary_cast)]
__codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => {
// NOTE: This lambda is necessary to work around an upstream bug
Expand All @@ -80,8 +88,9 @@ pub fn quote(
#create
})();
},
}
});
};
items.push(item);
}

let read_byte_err_msg =
format!("Could not decode `{type_name}`, failed to read variant byte");
Expand All @@ -91,7 +100,7 @@ pub fn quote(
match #input.read_byte()
.map_err(|e| e.chain(#read_byte_err_msg))?
{
#( #recurse )*
#( #items )*
_ => {
#[allow(clippy::redundant_closure_call)]
return (move || {
Expand Down
27 changes: 17 additions & 10 deletions derive/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::str::from_utf8;
use proc_macro2::{Ident, Span, TokenStream};
use syn::{punctuated::Punctuated, spanned::Spanned, token::Comma, Data, Error, Field, Fields};

use crate::utils;
use crate::{utils, utils::UsedIndexes};

type FieldsList = Punctuated<Field, Comma>;

Expand Down Expand Up @@ -313,12 +313,18 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
if data_variants().count() == 0 {
return quote!();
}

let recurse = data_variants().enumerate().map(|(i, f)| {
let mut used_indexes = match UsedIndexes::from_iter(data_variants()) {
Ok(index) => index,
Err(e) => return e.into_compile_error(),
};
let mut items = vec![];
for f in data_variants() {
let name = &f.ident;
let index = utils::variant_index(f, i);

match f.fields {
let index = match used_indexes.variant_index(f) {
Ok(index) => index,
Err(e) => return e.into_compile_error(),
};
let item = match f.fields {
Fields::Named(ref fields) => {
let fields = &fields.named;
let field_name = |_, ident: &Option<Ident>| quote!(#ident);
Expand Down Expand Up @@ -396,11 +402,12 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS

[hinting, encoding]
},
}
});
};
items.push(item)
}

let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting);
let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding);
let recurse_hinting = items.iter().map(|[hinting, _]| hinting);
let recurse_encoding = items.iter().map(|[_, encoding]| encoding);

let hinting = quote! {
// The variant index uses 1 byte.
Expand Down
113 changes: 89 additions & 24 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
//! NOTE: attributes finder must be checked using check_attribute first,
//! otherwise the macro can panic.

use std::str::FromStr;
use std::{collections::HashSet, str::FromStr};

use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput,
Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant,
ExprLit, Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path,
Variant,
};

fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option<R>
Expand All @@ -37,32 +38,96 @@ where
})
}

/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
// first look for an attribute
let index = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some(byte);
pub struct UsedIndexes {
used_set: HashSet<u8>,
current: u8,
}

impl UsedIndexes {
/// Build a Set of used indexes for use with #[scale(index = $int)] attribute on variant
pub fn from_iter<'a, I: Iterator<Item = &'a Variant>>(values: I) -> syn::Result<Self> {
let mut set = HashSet::new();
for (i, v) in values.enumerate() {
if let Some((index, nv)) = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some((byte, nv.span()));
}
}
}
None
}) {
if !set.insert(index) {
return Err(syn::Error::new(nv.span(), "Duplicate variant index. qed"))
}
set.insert(i.try_into().expect("Will never happen. qed"));
} else {
match v.discriminant.as_ref() {
Some((
_,
expr @ syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(lit_int), .. }),
)) => {
let index = lit_int
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
if !set.insert(index) {
return Err(syn::Error::new(expr.span(), "Duplicate variant index. qed"))
}
set.insert(i.try_into().expect("Will never happen. qed"));
},
_ => (),
}
}
}
Ok(Self { current: 0, used_set: set })
}

None
});

// then fallback to discriminant or just index
index.map(|i| quote! { #i }).unwrap_or_else(|| {
v.discriminant
.as_ref()
.map(|(_, expr)| quote! { #expr })
.unwrap_or_else(|| quote! { #i })
})
/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(&mut self, v: &Variant) -> syn::Result<TokenStream> {
// first look for an attribute
let index = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some(byte);
}
}
}

None
});

index.map_or_else(
|| match v.discriminant.as_ref() {
Some((_, expr)) => return Ok(quote! { #expr }),
None => {
let idx = self.next_index();
return Ok(quote! { #idx })
},
},
|i| Ok(quote! { #i }),
)
}

fn next_index(&mut self) -> u8 {
loop {
if self.used_set.contains(&self.current) {
self.current += 1;
} else {
let index = self.current;
self.current += 1;
return index
}
}
}
}

/// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given
Expand Down
4 changes: 2 additions & 2 deletions tests/variant_number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn discriminant_variant_counted_in_default_index() {
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![2]);
}

#[test]
Expand All @@ -36,5 +36,5 @@ fn index_attr_variant_counted_and_reused_in_default_index() {
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![2]);
}

0 comments on commit e6b0b0f

Please sign in to comment.