From 87bd37f50f0942284265f6944a9307733afa4d61 Mon Sep 17 00:00:00 2001 From: turbocool3r Date: Wed, 9 Oct 2024 00:16:10 +0300 Subject: [PATCH] der: Custom error types in derive macros. `Sequence`, `Enumerated` and `Choice` macros now support `#[asn1(error = Ty)]` attribute that provides a custom error type for `Decode`/`DecodeValue` implementations. This addresses #1559. --- der/tests/derive.rs | 52 ++++++++++++++++++++++++++++++++++++ der_derive/src/attributes.rs | 48 ++++++++++++++++++++------------- der_derive/src/choice.rs | 33 ++++++++++++++++------- der_derive/src/enumerated.rs | 51 +++++++++++++++++++++++------------ der_derive/src/sequence.rs | 20 ++++++++++---- der_derive/src/tag.rs | 16 +++++++++++ 6 files changed, 169 insertions(+), 51 deletions(-) diff --git a/der/tests/derive.rs b/der/tests/derive.rs index fa2e1c7e9..510e9a621 100644 --- a/der/tests/derive.rs +++ b/der/tests/derive.rs @@ -11,10 +11,29 @@ // TODO: fix needless_question_mark in the derive crate #![allow(clippy::bool_assert_comparison, clippy::needless_question_mark)] +#[derive(Debug)] +#[allow(dead_code)] +pub struct CustomError(der::Error); + +impl From for CustomError { + fn from(value: der::Error) -> Self { + Self(value) + } +} + +impl From for CustomError { + fn from(_value: std::convert::Infallible) -> Self { + unreachable!() + } +} + /// Custom derive test cases for the `Choice` macro. mod choice { + use super::CustomError; + /// `Choice` with `EXPLICIT` tagging. mod explicit { + use super::CustomError; use der::{ asn1::{GeneralizedTime, UtcTime}, Choice, Decode, Encode, SliceWriter, @@ -50,6 +69,13 @@ mod choice { } } + #[derive(Choice)] + #[asn1(error = CustomError)] + pub enum WithCustomError { + #[asn1(type = "GeneralizedTime")] + Foo(GeneralizedTime), + } + const UTC_TIMESTAMP_DER: &[u8] = &hex!("17 0d 39 31 30 35 30 36 32 33 34 35 34 30 5a"); const GENERAL_TIMESTAMP_DER: &[u8] = &hex!("18 0f 31 39 39 31 30 35 30 36 32 33 34 35 34 30 5a"); @@ -61,6 +87,10 @@ mod choice { let general_time = Time::from_der(GENERAL_TIMESTAMP_DER).unwrap(); assert_eq!(general_time.to_unix_duration().as_secs(), 673573540); + + let WithCustomError::Foo(with_custom_error) = + WithCustomError::from_der(GENERAL_TIMESTAMP_DER).unwrap(); + assert_eq!(with_custom_error.to_unix_duration().as_secs(), 673573540); } #[test] @@ -154,6 +184,7 @@ mod choice { /// Custom derive test cases for the `Enumerated` macro. mod enumerated { + use super::CustomError; use der::{Decode, Encode, Enumerated, SliceWriter}; use hex_literal::hex; @@ -176,6 +207,14 @@ mod enumerated { const UNSPECIFIED_DER: &[u8] = &hex!("0a 01 00"); const KEY_COMPROMISE_DER: &[u8] = &hex!("0a 01 01"); + #[derive(Enumerated, Copy, Clone, Eq, PartialEq, Debug)] + #[asn1(error = CustomError)] + #[repr(u32)] + pub enum EnumWithCustomError { + Unspecified = 0, + Specified = 1, + } + #[test] fn decode() { let unspecified = CrlReason::from_der(UNSPECIFIED_DER).unwrap(); @@ -183,6 +222,9 @@ mod enumerated { let key_compromise = CrlReason::from_der(KEY_COMPROMISE_DER).unwrap(); assert_eq!(CrlReason::KeyCompromise, key_compromise); + + let custom_error_enum = EnumWithCustomError::from_der(UNSPECIFIED_DER).unwrap(); + assert_eq!(custom_error_enum, EnumWithCustomError::Unspecified); } #[test] @@ -202,6 +244,7 @@ mod enumerated { /// Custom derive test cases for the `Sequence` macro. #[cfg(feature = "oid")] mod sequence { + use super::CustomError; use core::marker::PhantomData; use der::{ asn1::{AnyRef, ObjectIdentifier, SetOf}, @@ -383,6 +426,12 @@ mod sequence { pub typed_context_specific_optional: Option<&'a [u8]>, } + #[derive(Sequence)] + #[asn1(error = CustomError)] + pub struct TypeWithCustomError { + pub simple: bool, + } + #[test] fn idp_test() { let idp = IssuingDistributionPointExample::from_der(&hex!("30038101FF")).unwrap(); @@ -444,6 +493,9 @@ mod sequence { PRIME256V1_OID, ObjectIdentifier::try_from(algorithm_identifier.parameters.unwrap()).unwrap() ); + + let t = TypeWithCustomError::from_der(&hex!("30030101FF")).unwrap(); + assert!(t.simple); } #[test] diff --git a/der_derive/src/attributes.rs b/der_derive/src/attributes.rs index fa050cbcb..74099703b 100644 --- a/der_derive/src/attributes.rs +++ b/der_derive/src/attributes.rs @@ -18,37 +18,47 @@ pub(crate) struct TypeAttrs { /// /// The default value is `EXPLICIT`. pub tag_mode: TagMode, + pub error: Option, } impl TypeAttrs { /// Parse attributes from a struct field or enum variant. pub fn parse(attrs: &[Attribute]) -> syn::Result { let mut tag_mode = None; + let mut error = None; - let mut parsed_attrs = Vec::new(); - AttrNameValue::from_attributes(attrs, &mut parsed_attrs)?; - - for attr in parsed_attrs { - // `tag_mode = "..."` attribute - let mode = attr.parse_value("tag_mode")?.ok_or_else(|| { - syn::Error::new_spanned( - &attr.name, - "invalid `asn1` attribute (valid options are `tag_mode`)", - ) - })?; - - if tag_mode.is_some() { - return Err(syn::Error::new_spanned( - &attr.name, - "duplicate ASN.1 `tag_mode` attribute", - )); + attrs.iter().try_for_each(|attr| { + if !attr.path().is_ident(ATTR_NAME) { + return Ok(()); } - tag_mode = Some(mode); - } + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("tag_mode") { + if tag_mode.is_some() { + abort!(attr, "duplicate ASN.1 `tag_mode` attribute"); + } + + tag_mode = Some(meta.value()?.parse()?); + } else if meta.path.is_ident("error") { + if error.is_some() { + abort!(attr, "duplicate ASN.1 `error` attribute"); + } + + error = Some(meta.value()?.parse()?); + } else { + return Err(syn::Error::new_spanned( + attr, + "invalid `asn1` attribute (valid options are `tag_mode` and `error`)", + )); + } + + Ok(()) + }) + })?; Ok(Self { tag_mode: tag_mode.unwrap_or_default(), + error, }) } } diff --git a/der_derive/src/choice.rs b/der_derive/src/choice.rs index 8683c6441..8f10aa89c 100644 --- a/der_derive/src/choice.rs +++ b/der_derive/src/choice.rs @@ -7,8 +7,8 @@ mod variant; use self::variant::ChoiceVariant; use crate::{default_lifetime, TypeAttrs}; use proc_macro2::TokenStream; -use quote::quote; -use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; +use quote::{quote, ToTokens}; +use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam, Path}; /// Derive the `Choice` trait for an enum. pub(crate) struct DeriveChoice { @@ -20,6 +20,9 @@ pub(crate) struct DeriveChoice { /// Variants of this `Choice`. variants: Vec, + + /// Error type for `DecodeValue` implementation. + error: Option, } impl DeriveChoice { @@ -33,7 +36,7 @@ impl DeriveChoice { ), }; - let type_attrs = TypeAttrs::parse(&input.attrs)?; + let mut type_attrs = TypeAttrs::parse(&input.attrs)?; let variants = data .variants .iter() @@ -44,6 +47,7 @@ impl DeriveChoice { ident: input.ident, generics: input.generics.clone(), variants, + error: type_attrs.error.take(), }) } @@ -84,6 +88,12 @@ impl DeriveChoice { tagged_body.push(variant.to_tagged_tokens()); } + let error = self + .error + .as_ref() + .map(ToTokens::to_token_stream) + .unwrap_or_else(|| quote! { ::der::Error }); + quote! { impl #impl_generics ::der::Choice<#lifetime> for #ident #ty_generics #where_clause { fn can_decode(tag: ::der::Tag) -> bool { @@ -92,17 +102,20 @@ impl DeriveChoice { } impl #impl_generics ::der::Decode<#lifetime> for #ident #ty_generics #where_clause { - type Error = ::der::Error; + type Error = #error; - fn decode>(reader: &mut R) -> ::der::Result { + fn decode>(reader: &mut R) -> ::core::result::Result { use der::Reader as _; match ::der::Tag::peek(reader)? { #(#decode_body)* - actual => Err(der::ErrorKind::TagUnexpected { - expected: None, - actual - } - .into()), + actual => Err(::der::Error::new( + ::der::ErrorKind::TagUnexpected { + expected: None, + actual + }, + reader.position() + ).into() + ), } } } diff --git a/der_derive/src/enumerated.rs b/der_derive/src/enumerated.rs index 303014140..f70a57407 100644 --- a/der_derive/src/enumerated.rs +++ b/der_derive/src/enumerated.rs @@ -2,11 +2,10 @@ //! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to //! enum variants. -use crate::attributes::AttrNameValue; use crate::{default_lifetime, ATTR_NAME}; use proc_macro2::TokenStream; -use quote::quote; -use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant}; +use quote::{quote, ToTokens}; +use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, LitStr, Path, Variant}; /// Valid options for the `#[repr]` attribute on `Enumerated` types. const REPR_TYPES: &[&str] = &["u8", "u16", "u32"]; @@ -24,6 +23,9 @@ pub(crate) struct DeriveEnumerated { /// Variants of this enum. variants: Vec, + + /// Error type for `DecodeValue` implementation. + error: Option, } impl DeriveEnumerated { @@ -40,22 +42,30 @@ impl DeriveEnumerated { // Reject `asn1` attributes, parse the `repr` attribute let mut repr: Option = None; let mut integer = false; + let mut error: Option = None; for attr in &input.attrs { if attr.path().is_ident(ATTR_NAME) { - let kvs = match AttrNameValue::parse_attribute(attr) { - Ok(kvs) => kvs, - Err(e) => abort!(attr, e), - }; - for anv in kvs { - if anv.name.is_ident("type") { - match anv.value.value().as_str() { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("type") { + let value: LitStr = meta.value()?.parse()?; + match value.value().as_str() { "ENUMERATED" => integer = false, "INTEGER" => integer = true, - s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")), + s => abort!(value, format_args!("`type = \"{s}\"` is unsupported")), } + } else if meta.path.is_ident("error") { + let path: Path = meta.value()?.parse()?; + error = Some(path); + } else { + return Err(syn::Error::new_spanned( + &meta.path, + "invalid `asn1` attribute (valid options are `type` and `error`)", + )); } - } + + Ok(()) + })?; } else if attr.path().is_ident("repr") { if repr.is_some() { abort!( @@ -97,6 +107,7 @@ impl DeriveEnumerated { })?, variants, integer, + error, }) } @@ -115,14 +126,20 @@ impl DeriveEnumerated { try_from_body.push(variant.to_try_from_tokens()); } + let error = self + .error + .as_ref() + .map(ToTokens::to_token_stream) + .unwrap_or_else(|| quote! { ::der::Error }); + quote! { impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident { - type Error = ::der::Error; + type Error = #error; fn decode_value>( reader: &mut R, header: ::der::Header - ) -> ::der::Result { + ) -> ::core::result::Result { <#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into() } } @@ -142,12 +159,12 @@ impl DeriveEnumerated { } impl TryFrom<#repr> for #ident { - type Error = ::der::Error; + type Error = #error; - fn try_from(n: #repr) -> ::der::Result { + fn try_from(n: #repr) -> ::core::result::Result { match n { #(#try_from_body)* - _ => Err(#tag.value_error()) + _ => Err(#tag.value_error().into()) } } } diff --git a/der_derive/src/sequence.rs b/der_derive/src/sequence.rs index 81ca3d729..360525a03 100644 --- a/der_derive/src/sequence.rs +++ b/der_derive/src/sequence.rs @@ -6,8 +6,8 @@ mod field; use crate::{default_lifetime, TypeAttrs}; use field::SequenceField; use proc_macro2::TokenStream; -use quote::quote; -use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; +use quote::{quote, ToTokens}; +use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam, Path}; /// Derive the `Sequence` trait for a struct pub(crate) struct DeriveSequence { @@ -19,6 +19,9 @@ pub(crate) struct DeriveSequence { /// Fields of the struct. fields: Vec, + + /// Error type for `DecodeValue` implementation. + error: Option, } impl DeriveSequence { @@ -32,7 +35,7 @@ impl DeriveSequence { ), }; - let type_attrs = TypeAttrs::parse(&input.attrs)?; + let mut type_attrs = TypeAttrs::parse(&input.attrs)?; let fields = data .fields @@ -44,6 +47,7 @@ impl DeriveSequence { ident: input.ident, generics: input.generics.clone(), fields, + error: type_attrs.error.take(), }) } @@ -84,14 +88,20 @@ impl DeriveSequence { encode_fields.push(quote!(#field.encode(writer)?;)); } + let error = self + .error + .as_ref() + .map(ToTokens::to_token_stream) + .unwrap_or_else(|| quote! { ::der::Error }); + quote! { impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause { - type Error = ::der::Error; + type Error = #error; fn decode_value>( reader: &mut R, header: ::der::Header, - ) -> ::der::Result { + ) -> ::core::result::Result { use ::der::{Decode as _, DecodeValue as _, Reader as _}; reader.read_nested(header.length, |reader| { diff --git a/der_derive/src/tag.rs b/der_derive/src/tag.rs index aab2899b5..a1cf529cb 100644 --- a/der_derive/src/tag.rs +++ b/der_derive/src/tag.rs @@ -7,6 +7,7 @@ use std::{ fmt::{self, Display}, str::FromStr, }; +use syn::{parse::Parse, LitStr}; /// Tag "IR" type. #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] @@ -78,6 +79,21 @@ impl TagMode { } } +impl Parse for TagMode { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let s: LitStr = input.parse()?; + + match s.value().as_str() { + "EXPLICIT" | "explicit" => Ok(TagMode::Explicit), + "IMPLICIT" | "implicit" => Ok(TagMode::Implicit), + _ => Err(syn::Error::new( + s.span(), + "invalid tag mode (supported modes are `EXPLICIT` and `IMPLICIT`)", + )), + } + } +} + impl FromStr for TagMode { type Err = ParseError;