diff --git a/rustler_codegen/src/context.rs b/rustler_codegen/src/context.rs index e1066689..dc4e1187 100644 --- a/rustler_codegen/src/context.rs +++ b/rustler_codegen/src/context.rs @@ -1,5 +1,5 @@ use proc_macro2::{Span, TokenStream}; -use syn::{Data, Field, Ident, Lit, Meta, NestedMeta, Variant}; +use syn::{Data, Field, Fields, Ident, Lit, Meta, NestedMeta, Variant}; use super::RustlerAttr; @@ -14,6 +14,7 @@ pub(crate) struct Context<'a> { pub ident_with_lifetime: proc_macro2::TokenStream, pub variants: Option>, pub struct_fields: Option>, + pub is_tuple_struct: bool, } impl<'a> Context<'a> { @@ -56,12 +57,21 @@ impl<'a> Context<'a> { _ => None, }; + let is_tuple_struct = match ast.data { + Data::Struct(ref data_struct) => match data_struct.fields { + Fields::Unnamed(_) => true, + _ => false, + }, + _ => false, + }; + Self { attrs, ident, ident_with_lifetime, variants, struct_fields, + is_tuple_struct, } } diff --git a/rustler_codegen/src/record.rs b/rustler_codegen/src/record.rs index feee4de7..a67f2ffb 100644 --- a/rustler_codegen/src/record.rs +++ b/rustler_codegen/src/record.rs @@ -1,6 +1,6 @@ use proc_macro2::TokenStream; -use syn::{self, Field}; +use syn::{self, Field, Index}; use super::context::Context; use super::RustlerAttr; @@ -50,20 +50,29 @@ fn gen_decoder(ctx: &Context, atom_defs: &TokenStream, fields: &[&Field]) -> Tok .iter() .enumerate() .map(|(index, field)| { - let ident = field.ident.as_ref().unwrap(); + let ident = field.ident.as_ref(); + let pos_in_struct = if let Some(ident) = ident { + ident.to_string() + } else { + index.to_string() + }; let error_message = format!( - "Could not decode field :{} on Record {}", - ident.to_string(), + "Could not decode field {} on Record {}", + pos_in_struct, struct_name.to_string() ); + let actual_index = index + 1; let decoder = quote! { - match ::rustler::Decoder::decode(terms[#index + 1]) { + match ::rustler::Decoder::decode(terms[#actual_index]) { Err(_) => return Err(::rustler::Error::RaiseTerm(Box::new(#error_message))), Ok(value) => value } }; - quote! { #ident: #decoder } + match ident { + None => quote! { #decoder }, + Some(ident) => quote! { #ident: #decoder }, + } }) .collect(); @@ -71,13 +80,23 @@ fn gen_decoder(ctx: &Context, atom_defs: &TokenStream, fields: &[&Field]) -> Tok let struct_name_str = struct_name.to_string(); // The implementation itself + let construct = if ctx.is_tuple_struct { + quote! { + #struct_name ( #(#field_defs),* ) + } + } else { + quote! { + #struct_name { #(#field_defs),* } + } + }; let gen = quote! { impl<'a> ::rustler::Decoder<'a> for #struct_type { fn decode(term: ::rustler::Term<'a>) -> Result { #atom_defs let terms = match ::rustler::types::tuple::get_tuple(term) { - Err(_) => return Err(::rustler::Error::RaiseTerm(Box::new(format!("Invalid Record structure for {}", #struct_name_str)))), + Err(_) => return Err(::rustler::Error::RaiseTerm( + Box::new(format!("Invalid Record structure for {}", #struct_name_str)))), Ok(value) => value, }; @@ -92,9 +111,7 @@ fn gen_decoder(ctx: &Context, atom_defs: &TokenStream, fields: &[&Field]) -> Tok } Ok( - #struct_name { - #(#field_defs),* - } + #construct ) } } @@ -109,9 +126,14 @@ fn gen_encoder(ctx: &Context, atom_defs: &TokenStream, fields: &[&Field]) -> Tok // Make a field encoder expression for each of the items in the struct. let field_encoders: Vec = fields .iter() - .map(|field| { - let field_ident = field.ident.as_ref().unwrap(); - let field_source = quote! { self.#field_ident }; + .enumerate() + .map(|(index, field)| { + let literal_index = Index::from(index); + let field_source = match field.ident.as_ref() { + None => quote! { self.#literal_index }, + Some(ident) => quote! { self.#ident }, + }; + quote! { #field_source.encode(env) } }) .collect(); diff --git a/rustler_codegen/src/tuple.rs b/rustler_codegen/src/tuple.rs index ab3f32a5..487be523 100644 --- a/rustler_codegen/src/tuple.rs +++ b/rustler_codegen/src/tuple.rs @@ -1,6 +1,6 @@ use proc_macro2::TokenStream; -use syn::{self, Field}; +use syn::{self, Field, Index}; use super::context::Context; @@ -13,13 +13,13 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { .expect("NifTuple can only be used with structs"); let decoder = if ctx.decode() { - gen_decoder(&ctx, &struct_fields, false) + gen_decoder(&ctx, &struct_fields) } else { quote! {} }; let encoder = if ctx.encode() { - gen_encoder(&ctx, &struct_fields, false) + gen_encoder(&ctx, &struct_fields) } else { quote! {} }; @@ -32,7 +32,7 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { gen } -fn gen_decoder(ctx: &Context, fields: &[&Field], is_tuple: bool) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { let struct_type = &ctx.ident_with_lifetime; let struct_name = ctx.ident; @@ -41,7 +41,18 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], is_tuple: bool) -> TokenStream .iter() .enumerate() .map(|(index, field)| { - let error_message = format!("Could not decode index {} on tuple", index); + let ident = field.ident.as_ref(); + let pos_in_struct = if let Some(ident) = ident { + ident.to_string() + } else { + index.to_string() + }; + let error_message = format!( + "Could not decode field {} on {}", + pos_in_struct, + struct_name.to_string() + ); + let decoder = quote! { match ::rustler::Decoder::decode(terms[#index]) { Err(_) => return Err(::rustler::Error::RaiseTerm(Box::new(#error_message))), @@ -49,11 +60,9 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], is_tuple: bool) -> TokenStream } }; - if is_tuple { - unimplemented!(); - } else { - let ident = field.ident.as_ref().unwrap(); - quote! { #ident: #decoder } + match ident { + None => quote! { #decoder }, + Some(ident) => quote! { #ident: #decoder }, } }) .collect(); @@ -61,6 +70,15 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], is_tuple: bool) -> TokenStream let field_num = field_defs.len(); // The implementation itself + let construct = if ctx.is_tuple_struct { + quote! { + #struct_name ( #(#field_defs),* ) + } + } else { + quote! { + #struct_name { #(#field_defs),* } + } + }; let gen = quote! { impl<'a> ::rustler::Decoder<'a> for #struct_type { fn decode(term: ::rustler::Term<'a>) -> Result { @@ -69,9 +87,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], is_tuple: bool) -> TokenStream return Err(::rustler::Error::BadArg); } Ok( - #struct_name { - #(#field_defs),* - } + #construct ) } } @@ -80,19 +96,20 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], is_tuple: bool) -> TokenStream gen } -fn gen_encoder(ctx: &Context, fields: &[&Field], is_tuple: bool) -> TokenStream { +fn gen_encoder(ctx: &Context, fields: &[&Field]) -> TokenStream { let struct_type = &ctx.ident_with_lifetime; // Make a field encoder expression for each of the items in the struct. let field_encoders: Vec = fields .iter() - .map(|field| { - let field_source = if is_tuple { - unimplemented!(); - } else { - let field_ident = field.ident.as_ref().unwrap(); - quote! { self.#field_ident } + .enumerate() + .map(|(index, field)| { + let literal_index = Index::from(index); + let field_source = match field.ident.as_ref() { + None => quote! { self.#literal_index }, + Some(ident) => quote! { self.#ident }, }; + quote! { #field_source.encode(env) } }) .collect(); diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index f4fc24c6..59ad387b 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -61,6 +61,10 @@ defmodule RustlerTest do def unit_enum_echo(_), do: err() def untagged_enum_echo(_), do: err() def untagged_enum_with_truthy(_), do: err() + def newtype_echo(_), do: err() + def tuplestruct_echo(_), do: err() + def newtype_record_echo(_), do: err() + def tuplestruct_record_echo(_), do: err() def dirty_io(), do: err() def dirty_cpu(), do: err() diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index 71a6cb95..c00975f2 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -86,6 +86,14 @@ rustler::init!( 1, test_codegen::untagged_enum_with_truthy ), + ("newtype_echo", 1, test_codegen::newtype_echo), + ("tuplestruct_echo", 1, test_codegen::tuplestruct_echo), + ("newtype_record_echo", 1, test_codegen::newtype_record_echo), + ( + "tuplestruct_record_echo", + 1, + test_codegen::tuplestruct_record_echo + ), ("dirty_cpu", 0, test_dirty::dirty_cpu, DirtyCpu), ("dirty_io", 0, test_dirty::dirty_io, DirtyIo), ("sum_range", 1, test_range::sum_range), diff --git a/rustler_tests/native/rustler_test/src/test_codegen.rs b/rustler_tests/native/rustler_test/src/test_codegen.rs index 21d852a1..12d6c2aa 100644 --- a/rustler_tests/native/rustler_test/src/test_codegen.rs +++ b/rustler_tests/native/rustler_test/src/test_codegen.rs @@ -88,3 +88,40 @@ pub fn untagged_enum_with_truthy<'a>( let untagged_enum: UntaggedEnumWithTruthy = args[0].decode()?; Ok(untagged_enum) } + +#[derive(NifTuple)] +pub struct Newtype(i64); + +pub fn newtype_echo<'a>(_env: Env<'a>, args: &[Term<'a>]) -> NifResult { + let newtype: Newtype = args[0].decode()?; + Ok(newtype) +} + +#[derive(NifTuple)] +pub struct TupleStruct(i64, i64, i64); + +pub fn tuplestruct_echo<'a>(_env: Env<'a>, args: &[Term<'a>]) -> NifResult { + let tuplestruct: TupleStruct = args[0].decode()?; + Ok(tuplestruct) +} + +#[derive(NifRecord)] +#[tag = "newtype"] +pub struct NewtypeRecord(i64); + +pub fn newtype_record_echo<'a>(_env: Env<'a>, args: &[Term<'a>]) -> NifResult { + let newtype: NewtypeRecord = args[0].decode()?; + Ok(newtype) +} + +#[derive(NifRecord)] +#[tag = "tuplestruct"] +pub struct TupleStructRecord(i64, i64, i64); + +pub fn tuplestruct_record_echo<'a>( + _env: Env<'a>, + args: &[Term<'a>], +) -> NifResult { + let tuplestruct: TupleStructRecord = args[0].decode()?; + Ok(tuplestruct) +} diff --git a/rustler_tests/test/codegen_test.exs b/rustler_tests/test/codegen_test.exs index b448c9f6..91c38ae1 100644 --- a/rustler_tests/test/codegen_test.exs +++ b/rustler_tests/test/codegen_test.exs @@ -7,6 +7,16 @@ defmodule AddRecord do defrecord :record, [lhs: 1, rhs: 2] end +defmodule NewtypeRecord do + import Record + defrecord :newtype, [a: 1] +end + +defmodule TupleStructRecord do + import Record + defrecord :tuplestruct, [a: 1, b: 2, c: 3] +end + defmodule RustlerTest.CodegenTest do use ExUnit.Case, async: true @@ -19,7 +29,7 @@ defmodule RustlerTest.CodegenTest do test "with invalid tuple" do value = {"invalid", 2} - assert_raise ErlangError, "Erlang error: \"Could not decode index 0 on tuple\"", fn -> + assert_raise ErlangError, "Erlang error: \"Could not decode field lhs on AddTuple\"", fn -> RustlerTest.tuple_echo(value) end end @@ -75,7 +85,7 @@ defmodule RustlerTest.CodegenTest do require AddRecord value = AddRecord.record(lhs: 5, rhs: "invalid") - assert_raise ErlangError, "Erlang error: \"Could not decode field :rhs on Record AddRecord\"", fn -> + assert_raise ErlangError, "Erlang error: \"Could not decode field rhs on Record AddRecord\"", fn -> RustlerTest.record_echo(value) end end @@ -103,4 +113,56 @@ defmodule RustlerTest.CodegenTest do assert false == RustlerTest.untagged_enum_with_truthy(false) assert false == RustlerTest.untagged_enum_with_truthy(nil) end + + test "newtype tuple" do + assert {1} == RustlerTest.newtype_echo({1}) + assert_raise ErlangError, "Erlang error: \"Could not decode field 0 on Newtype\"", fn -> + RustlerTest.newtype_echo({"with error message"}) + end + assert_raise ArgumentError, fn -> + RustlerTest.newtype_echo("will result in argument error") + end + end + + test "tuplestruct tuple" do + assert {1, 2, 3} == RustlerTest.tuplestruct_echo({1, 2, 3}) + + assert_raise ArgumentError, fn -> + RustlerTest.tuplestruct_echo({1, 2}) + end + + assert_raise ErlangError, "Erlang error: \"Could not decode field 1 on TupleStruct\"", fn -> + RustlerTest.tuplestruct_echo({1, "with error message", 3}) + end + + assert_raise ArgumentError, fn -> + RustlerTest.tuplestruct_echo("will result in argument error") + end + end + + test "newtype record" do + require NewtypeRecord + value = NewtypeRecord.newtype() + assert value == RustlerTest.newtype_record_echo(value) + assert :invalid_record == RustlerTest.newtype_record_echo({"with error message"}) + + assert_raise ErlangError, "Erlang error: \"Invalid Record structure for NewtypeRecord\"", fn -> + RustlerTest.newtype_record_echo("error") + end + + assert_raise ErlangError, "Erlang error: \"Could not decode field 0 on Record NewtypeRecord\"", fn -> + RustlerTest.newtype_record_echo(NewtypeRecord.newtype(a: "error")) + end + end + + test "tuplestruct record" do + require TupleStructRecord + value = TupleStructRecord.tuplestruct() + assert value == RustlerTest.tuplestruct_record_echo(value) + assert :invalid_record == RustlerTest.tuplestruct_record_echo({"invalid"}) + + assert_raise ErlangError, "Erlang error: \"Invalid Record structure for TupleStructRecord\"", fn -> + RustlerTest.tuplestruct_record_echo("error") + end + end end