Skip to content

Commit

Permalink
der: Custom error types in derive macros.
Browse files Browse the repository at this point in the history
`Sequence`, `Enumerated` and `Choice` macros now support `#[asn1(error = Ty)]` attribute that provides a custom error type for `Decode`/`DecodeValue` implementations.

This addresses RustCrypto#1559.
  • Loading branch information
turbocool3r committed Oct 8, 2024
1 parent 61db930 commit 87bd37f
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 51 deletions.
52 changes: 52 additions & 0 deletions der/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,29 @@
// TODO: fix needless_question_mark in the derive crate
#![allow(clippy::bool_assert_comparison, clippy::needless_question_mark)]

#[derive(Debug)]
#[allow(dead_code)]
pub struct CustomError(der::Error);

impl From<der::Error> for CustomError {
fn from(value: der::Error) -> Self {
Self(value)
}
}

impl From<std::convert::Infallible> for CustomError {
fn from(_value: std::convert::Infallible) -> Self {
unreachable!()
}
}

/// Custom derive test cases for the `Choice` macro.
mod choice {
use super::CustomError;

/// `Choice` with `EXPLICIT` tagging.
mod explicit {
use super::CustomError;
use der::{
asn1::{GeneralizedTime, UtcTime},
Choice, Decode, Encode, SliceWriter,
Expand Down Expand Up @@ -50,6 +69,13 @@ mod choice {
}
}

#[derive(Choice)]
#[asn1(error = CustomError)]
pub enum WithCustomError {
#[asn1(type = "GeneralizedTime")]
Foo(GeneralizedTime),
}

const UTC_TIMESTAMP_DER: &[u8] = &hex!("17 0d 39 31 30 35 30 36 32 33 34 35 34 30 5a");
const GENERAL_TIMESTAMP_DER: &[u8] =
&hex!("18 0f 31 39 39 31 30 35 30 36 32 33 34 35 34 30 5a");
Expand All @@ -61,6 +87,10 @@ mod choice {

let general_time = Time::from_der(GENERAL_TIMESTAMP_DER).unwrap();
assert_eq!(general_time.to_unix_duration().as_secs(), 673573540);

let WithCustomError::Foo(with_custom_error) =
WithCustomError::from_der(GENERAL_TIMESTAMP_DER).unwrap();
assert_eq!(with_custom_error.to_unix_duration().as_secs(), 673573540);
}

#[test]
Expand Down Expand Up @@ -154,6 +184,7 @@ mod choice {

/// Custom derive test cases for the `Enumerated` macro.
mod enumerated {
use super::CustomError;
use der::{Decode, Encode, Enumerated, SliceWriter};
use hex_literal::hex;

Expand All @@ -176,13 +207,24 @@ mod enumerated {
const UNSPECIFIED_DER: &[u8] = &hex!("0a 01 00");
const KEY_COMPROMISE_DER: &[u8] = &hex!("0a 01 01");

#[derive(Enumerated, Copy, Clone, Eq, PartialEq, Debug)]
#[asn1(error = CustomError)]
#[repr(u32)]
pub enum EnumWithCustomError {
Unspecified = 0,
Specified = 1,
}

#[test]
fn decode() {
let unspecified = CrlReason::from_der(UNSPECIFIED_DER).unwrap();
assert_eq!(CrlReason::Unspecified, unspecified);

let key_compromise = CrlReason::from_der(KEY_COMPROMISE_DER).unwrap();
assert_eq!(CrlReason::KeyCompromise, key_compromise);

let custom_error_enum = EnumWithCustomError::from_der(UNSPECIFIED_DER).unwrap();
assert_eq!(custom_error_enum, EnumWithCustomError::Unspecified);
}

#[test]
Expand All @@ -202,6 +244,7 @@ mod enumerated {
/// Custom derive test cases for the `Sequence` macro.
#[cfg(feature = "oid")]
mod sequence {
use super::CustomError;
use core::marker::PhantomData;
use der::{
asn1::{AnyRef, ObjectIdentifier, SetOf},
Expand Down Expand Up @@ -383,6 +426,12 @@ mod sequence {
pub typed_context_specific_optional: Option<&'a [u8]>,
}

#[derive(Sequence)]
#[asn1(error = CustomError)]
pub struct TypeWithCustomError {
pub simple: bool,
}

#[test]
fn idp_test() {
let idp = IssuingDistributionPointExample::from_der(&hex!("30038101FF")).unwrap();
Expand Down Expand Up @@ -444,6 +493,9 @@ mod sequence {
PRIME256V1_OID,
ObjectIdentifier::try_from(algorithm_identifier.parameters.unwrap()).unwrap()
);

let t = TypeWithCustomError::from_der(&hex!("30030101FF")).unwrap();
assert!(t.simple);
}

#[test]
Expand Down
48 changes: 29 additions & 19 deletions der_derive/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,47 @@ pub(crate) struct TypeAttrs {
///
/// The default value is `EXPLICIT`.
pub tag_mode: TagMode,
pub error: Option<Path>,
}

impl TypeAttrs {
/// Parse attributes from a struct field or enum variant.
pub fn parse(attrs: &[Attribute]) -> syn::Result<Self> {
let mut tag_mode = None;
let mut error = None;

let mut parsed_attrs = Vec::new();
AttrNameValue::from_attributes(attrs, &mut parsed_attrs)?;

for attr in parsed_attrs {
// `tag_mode = "..."` attribute
let mode = attr.parse_value("tag_mode")?.ok_or_else(|| {
syn::Error::new_spanned(
&attr.name,
"invalid `asn1` attribute (valid options are `tag_mode`)",
)
})?;

if tag_mode.is_some() {
return Err(syn::Error::new_spanned(
&attr.name,
"duplicate ASN.1 `tag_mode` attribute",
));
attrs.iter().try_for_each(|attr| {
if !attr.path().is_ident(ATTR_NAME) {
return Ok(());
}

tag_mode = Some(mode);
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("tag_mode") {
if tag_mode.is_some() {
abort!(attr, "duplicate ASN.1 `tag_mode` attribute");
}

tag_mode = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("error") {
if error.is_some() {
abort!(attr, "duplicate ASN.1 `error` attribute");
}

error = Some(meta.value()?.parse()?);
} else {
return Err(syn::Error::new_spanned(
attr,
"invalid `asn1` attribute (valid options are `tag_mode` and `error`)",
));
}

Ok(())
})
})?;

Ok(Self {
tag_mode: tag_mode.unwrap_or_default(),
error,
})
}
}
Expand Down
33 changes: 23 additions & 10 deletions der_derive/src/choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ mod variant;
use self::variant::ChoiceVariant;
use crate::{default_lifetime, TypeAttrs};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};
use quote::{quote, ToTokens};
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam, Path};

/// Derive the `Choice` trait for an enum.
pub(crate) struct DeriveChoice {
Expand All @@ -20,6 +20,9 @@ pub(crate) struct DeriveChoice {

/// Variants of this `Choice`.
variants: Vec<ChoiceVariant>,

/// Error type for `DecodeValue` implementation.
error: Option<Path>,
}

impl DeriveChoice {
Expand All @@ -33,7 +36,7 @@ impl DeriveChoice {
),
};

let type_attrs = TypeAttrs::parse(&input.attrs)?;
let mut type_attrs = TypeAttrs::parse(&input.attrs)?;
let variants = data
.variants
.iter()
Expand All @@ -44,6 +47,7 @@ impl DeriveChoice {
ident: input.ident,
generics: input.generics.clone(),
variants,
error: type_attrs.error.take(),
})
}

Expand Down Expand Up @@ -84,6 +88,12 @@ impl DeriveChoice {
tagged_body.push(variant.to_tagged_tokens());
}

let error = self
.error
.as_ref()
.map(ToTokens::to_token_stream)
.unwrap_or_else(|| quote! { ::der::Error });

quote! {
impl #impl_generics ::der::Choice<#lifetime> for #ident #ty_generics #where_clause {
fn can_decode(tag: ::der::Tag) -> bool {
Expand All @@ -92,17 +102,20 @@ impl DeriveChoice {
}

impl #impl_generics ::der::Decode<#lifetime> for #ident #ty_generics #where_clause {
type Error = ::der::Error;
type Error = #error;

fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::core::result::Result<Self, #error> {
use der::Reader as _;
match ::der::Tag::peek(reader)? {
#(#decode_body)*
actual => Err(der::ErrorKind::TagUnexpected {
expected: None,
actual
}
.into()),
actual => Err(::der::Error::new(
::der::ErrorKind::TagUnexpected {
expected: None,
actual
},
reader.position()
).into()
),
}
}
}
Expand Down
51 changes: 34 additions & 17 deletions der_derive/src/enumerated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
//! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to
//! enum variants.
use crate::attributes::AttrNameValue;
use crate::{default_lifetime, ATTR_NAME};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant};
use quote::{quote, ToTokens};
use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, LitStr, Path, Variant};

/// Valid options for the `#[repr]` attribute on `Enumerated` types.
const REPR_TYPES: &[&str] = &["u8", "u16", "u32"];
Expand All @@ -24,6 +23,9 @@ pub(crate) struct DeriveEnumerated {

/// Variants of this enum.
variants: Vec<EnumeratedVariant>,

/// Error type for `DecodeValue` implementation.
error: Option<Path>,
}

impl DeriveEnumerated {
Expand All @@ -40,22 +42,30 @@ impl DeriveEnumerated {
// Reject `asn1` attributes, parse the `repr` attribute
let mut repr: Option<Ident> = None;
let mut integer = false;
let mut error: Option<Path> = None;

for attr in &input.attrs {
if attr.path().is_ident(ATTR_NAME) {
let kvs = match AttrNameValue::parse_attribute(attr) {
Ok(kvs) => kvs,
Err(e) => abort!(attr, e),
};
for anv in kvs {
if anv.name.is_ident("type") {
match anv.value.value().as_str() {
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("type") {
let value: LitStr = meta.value()?.parse()?;
match value.value().as_str() {
"ENUMERATED" => integer = false,
"INTEGER" => integer = true,
s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")),
s => abort!(value, format_args!("`type = \"{s}\"` is unsupported")),
}
} else if meta.path.is_ident("error") {
let path: Path = meta.value()?.parse()?;
error = Some(path);
} else {
return Err(syn::Error::new_spanned(
&meta.path,
"invalid `asn1` attribute (valid options are `type` and `error`)",
));
}
}

Ok(())
})?;
} else if attr.path().is_ident("repr") {
if repr.is_some() {
abort!(
Expand Down Expand Up @@ -97,6 +107,7 @@ impl DeriveEnumerated {
})?,
variants,
integer,
error,
})
}

Expand All @@ -115,14 +126,20 @@ impl DeriveEnumerated {
try_from_body.push(variant.to_try_from_tokens());
}

let error = self
.error
.as_ref()
.map(ToTokens::to_token_stream)
.unwrap_or_else(|| quote! { ::der::Error });

quote! {
impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
type Error = ::der::Error;
type Error = #error;

fn decode_value<R: ::der::Reader<#default_lifetime>>(
reader: &mut R,
header: ::der::Header
) -> ::der::Result<Self> {
) -> ::core::result::Result<Self, #error> {
<#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
}
}
Expand All @@ -142,12 +159,12 @@ impl DeriveEnumerated {
}

impl TryFrom<#repr> for #ident {
type Error = ::der::Error;
type Error = #error;

fn try_from(n: #repr) -> ::der::Result<Self> {
fn try_from(n: #repr) -> ::core::result::Result<Self, #error> {
match n {
#(#try_from_body)*
_ => Err(#tag.value_error())
_ => Err(#tag.value_error().into())
}
}
}
Expand Down
Loading

0 comments on commit 87bd37f

Please sign in to comment.