Skip to content

Commit

Permalink
Add fast ser/de paths for structs of prims with repr(C) (#2251)
Browse files Browse the repository at this point in the history
  • Loading branch information
Centril authored Feb 11, 2025
1 parent fe68fe6 commit 0b2364b
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 4 deletions.
15 changes: 15 additions & 0 deletions crates/bindings-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,25 @@ mod sym {
symbol!(primary_key);
symbol!(private);
symbol!(public);
symbol!(repr);
symbol!(sats);
symbol!(scheduled);
symbol!(unique);
symbol!(update);

symbol!(u8);
symbol!(i8);
symbol!(u16);
symbol!(i16);
symbol!(u32);
symbol!(i32);
symbol!(u64);
symbol!(i64);
symbol!(u128);
symbol!(i128);
symbol!(f32);
symbol!(f64);

impl PartialEq<Symbol> for syn::Ident {
fn eq(&self, sym: &Symbol) -> bool {
self == sym.0
Expand Down Expand Up @@ -350,6 +364,7 @@ pub fn schema_type(input: StdTokenStream) -> StdTokenStream {
sats_derive(input, true, |ty| {
let ident = ty.ident;
let name = &ty.name;

let krate = &ty.krate;
TokenStream::from_iter([
sats::derive_satstype(ty),
Expand Down
121 changes: 121 additions & 0 deletions crates/bindings-macro/src/sats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub(crate) struct SatsType<'a> {
#[allow(unused)]
pub original_attrs: &'a [syn::Attribute],
pub data: SatsTypeData<'a>,
/// Was the type marked as `#[repr(C)]`?
pub is_repr_c: bool,
}

pub(crate) enum SatsTypeData<'a> {
Expand Down Expand Up @@ -78,6 +80,17 @@ pub(crate) fn sats_type_from_derive(
extract_sats_type(&input.ident, &input.generics, &input.attrs, data, crate_fallback)
}

/// Reports whether any `#[repr(...)]` attribute in `attrs` lists `C`.
///
/// The result of `parse_nested_meta` is intentionally discarded, so a parse
/// error inside a `repr` attribute does not clear a `C` that was already seen.
fn is_repr_c(attrs: &[syn::Attribute]) -> bool {
    attrs
        .iter()
        .filter(|attr| attr.path() == sym::repr)
        .fold(false, |mut seen_c, attr| {
            let _ = attr.parse_nested_meta(|meta| {
                seen_c |= meta.path.is_ident("C");
                Ok(())
            });
            seen_c
        })
}

pub(crate) fn extract_sats_type<'a>(
ident: &'a syn::Ident,
generics: &'a syn::Generics,
Expand Down Expand Up @@ -112,13 +125,16 @@ pub(crate) fn extract_sats_type<'a>(
let krate = krate.unwrap_or(crate_fallback);
let name = name.unwrap_or_else(|| crate::util::ident_to_litstr(ident));

let is_repr_c = is_repr_c(attrs);

Ok(SatsType {
ident,
generics,
name,
krate,
original_attrs: attrs,
data,
is_repr_c,
})
}

Expand Down Expand Up @@ -220,6 +236,48 @@ fn add_type_bounds(generics: &mut syn::Generics, trait_bound: &TokenStream) {
}
}

/// Returns the list of field types if syntactically we see that the `ty`
/// is a non-empty `#[repr(C)]` struct made only of primitives.
///
/// Returns `None` when `ty` is not marked `#[repr(C)]`, is not a product type,
/// has no fields, or has any field whose type is not a bare identifier
/// naming one of the primitive integer / float types listed in `PRIM_TY`.
///
/// We later assert semantically in generated code that the list of types
/// actually are primitives.
/// We'll also check that `ty` is paddingless.
fn extract_repr_c_primitive<'a>(ty: &'a SatsType) -> Option<Vec<&'a syn::Ident>> {
    // Ensure we have a `#[repr(C)]` struct.
    if !ty.is_repr_c {
        return None;
    }
    let SatsTypeData::Product(fields) = &ty.data else {
        return None;
    };

    // A field-less struct gains nothing from the fast path, and the generated
    // padding check `size_of::<Self>() == #(size_of::<#fields>())+*` would
    // expand to a comparison with no right-hand side, which does not compile.
    if fields.is_empty() {
        return None;
    }

    // Ensure every field is a primitive and collect the idents.
    const PRIM_TY: &[sym::Symbol] = &[
        sym::u8,
        sym::i8,
        sym::u16,
        sym::i16,
        sym::u32,
        sym::i32,
        sym::u64,
        sym::i64,
        sym::u128,
        sym::i128,
        sym::f32,
        sym::f64,
    ];
    let mut field_tys = Vec::with_capacity(fields.len());
    for field in fields {
        // Only plain path types (e.g. `u32`) can name a primitive.
        let syn::Type::Path(type_path) = &field.ty else {
            return None;
        };
        // `get_ident` rejects qualified or generic paths; the filter rejects
        // any identifier that is not in the primitive list.
        let ident = type_path
            .path
            .get_ident()
            .filter(|ident| PRIM_TY.iter().any(|p| ident == p))?;
        field_tys.push(ident);
    }
    Some(field_tys)
}

pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
let (name, tuple_name) = (&ty.ident, &ty.name);
let spacetimedb_lib = &ty.krate;
Expand Down Expand Up @@ -249,6 +307,33 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {

match &ty.data {
SatsTypeData::Product(fields) => {
let mut fast_body = None;
if let Some(fields) = extract_repr_c_primitive(ty) {
fast_body = Some(quote! {
// Fast path for `#[repr(C)]` structs whose fields are all primitives:
// reinterpret the next `size_of::<Self>()` input bytes as `Self`.
// NOTE(review): this copies the fields in native byte order — confirm that
// this matches the BSATN encoding on every supported target.
#[inline(always)]
fn deserialize_from_bsatn<R: #spacetimedb_lib::buffer::BufReader<'de>>(
    mut deserializer: #spacetimedb_lib::bsatn::Deserializer<'de, R>
) -> Result<Self, #spacetimedb_lib::bsatn::DecodeError> {
    // Compile-time check that every field type really is a primitive,
    // even if the user shadowed a primitive's name with their own type.
    const _: () = {
        #(#spacetimedb_lib::bsatn::assert_is_primitive_type::<#fields>();)*
    };
    // This guarantees that `Self` has no padding.
    if const { core::mem::size_of::<Self>() == #(core::mem::size_of::<#fields>())+* } {
        let bytes = deserializer.get_slice(core::mem::size_of::<Self>())?;
        let ptr = bytes.as_ptr().cast::<Self>();
        // SAFETY:
        // - `ptr` is valid for reads of `size_of::<Self>()` bytes,
        //   as that is the length requested from `get_slice`.
        // - A byte slice is only guaranteed to be 1-aligned while `Self` may
        //   require a larger alignment, so we must use `read_unaligned`,
        //   which has no alignment requirement, rather than `read`.
        // - The bytes form a valid `Self`: there is no padding (checked above)
        //   and every field is a primitive integer or float (asserted above),
        //   for which every bit pattern is a valid value.
        Ok(unsafe { core::ptr::read_unaligned(ptr) })
    } else {
        Self::deserialize(deserializer)
    }
}
});
}

let n_fields = fields.len();

let field_names = fields.iter().map(|f| f.ident.unwrap()).collect::<Vec<_>>();
Expand All @@ -260,6 +345,8 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream {
#[allow(clippy::all)]
const _: () = {
impl #de_impl_generics #spacetimedb_lib::de::Deserialize<'de> for #name #ty_generics #de_where_clause {
#fast_body

fn deserialize<D: #spacetimedb_lib::de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
deserializer.deserialize_product(__ProductVisitor {
_marker: std::marker::PhantomData::<fn() -> #name #ty_generics>,
Expand Down Expand Up @@ -422,8 +509,41 @@ pub(crate) fn derive_serialize(ty: &SatsType) -> TokenStream {

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let mut fast_body = None;
let body = match &ty.data {
SatsTypeData::Product(fields) => {
if let Some(fields) = extract_repr_c_primitive(ty) {
fast_body = Some(quote! {
// Fast path for `#[repr(C)]` structs whose fields are all primitives:
// hand the in-memory bytes of `self` to the writer in a single call.
// NOTE(review): this copies the fields in native byte order — confirm that
// this matches the BSATN encoding on every supported target.
#[inline(always)]
fn serialize_into_bsatn<W: #spacetimedb_lib::buffer::BufWriter>(
&self,
serializer: #spacetimedb_lib::bsatn::Serializer<'_, W>
) -> Result<(), #spacetimedb_lib::bsatn::EncodeError> {
// Compile-time check that every field type really is a primitive,
// even if the user shadowed a primitive's name with their own type.
const _: () = {
#(#spacetimedb_lib::bsatn::assert_is_primitive_type::<#fields>();)*
};
// This guarantees that `Self` has no padding:
// its size equals the sum of the sizes of its fields.
if const { core::mem::size_of::<Self>() == #(core::mem::size_of::<#fields>())+* } {
// SAFETY:
// - We know `self` is non-null as it's a shared reference
// and we know it's valid for reads for `core::mem::size_of::<Self>()` bytes.
// Alignment of `u8` is 1, so it's trivially satisfied.
// - The slice is all within `self`, so in the same allocated object.
// - `self` does point to `core::mem::size_of::<Self>()` consecutive initialized `u8`s:
// the size check above rules out padding between or after the fields,
// and as per `assert_is_primitive_type` above,
// we know none of the fields of `Self` have any padding themselves.
// - We're not going to mutate the memory within `bytes`.
// - We know `core::mem::size_of::<Self>() < isize::MAX`.
let bytes = unsafe { core::slice::from_raw_parts(self as *const _ as *const u8, core::mem::size_of::<Self>()) };
serializer.raw_write_bytes(bytes);
Ok(())
} else {
// Padding is present, so fall back to the field-by-field path.
self.serialize(serializer)
}
}
});
}

let fieldnames = fields.iter().map(|field| field.ident.unwrap());
let tys = fields.iter().map(|f| &f.ty);
let fieldnamestrings = fields.iter().map(|field| field.name.as_ref().unwrap());
Expand Down Expand Up @@ -456,6 +576,7 @@ pub(crate) fn derive_serialize(ty: &SatsType) -> TokenStream {
};
quote! {
impl #impl_generics #spacetimedb_lib::ser::Serialize for #name #ty_generics #where_clause {
#fast_body
fn serialize<S: #spacetimedb_lib::ser::Serializer>(&self, __serializer: S) -> Result<S::Ok, S::Error> {
#body
}
Expand Down
29 changes: 28 additions & 1 deletion crates/sats/src/bsatn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ pub use crate::buffer::DecodeError;
pub use ser::BsatnError as EncodeError;

/// Serialize `value` into the buffered writer `w` in the BSATN format.
#[inline]
pub fn to_writer<W: BufWriter, T: Serialize + ?Sized>(w: &mut W, value: &T) -> Result<(), EncodeError> {
value.serialize(Serializer::new(w))
value.serialize_into_bsatn(Serializer::new(w))
}

/// Serialize `value` into a `Vec<u8>` in the BSATN format.
Expand Down Expand Up @@ -150,6 +151,32 @@ impl ToBsatn for ProductValue {
}
}

// Private home of the `Sealed` supertrait. Because this module is not public,
// code outside this module cannot name `Sealed`
// and therefore cannot implement `IsPrimitiveType` for its own types.
mod private_is_primitive_type {
    pub trait Sealed {}
}

/// Marker trait for primitive types.
/// This is purely intended for use in `crates/bindings-macro`.
///
/// # Safety
///
/// Implementing this guarantees that the type has no padding, recursively.
///
/// Note for maintainers: the fast paths generated by `crates/bindings-macro`
/// also rebuild such values directly from raw input bytes,
/// so only add types for which every bit pattern is a valid value.
#[doc(hidden)]
pub unsafe trait IsPrimitiveType: private_is_primitive_type::Sealed {}

// Implements `Sealed` and `IsPrimitiveType` for each listed type.
macro_rules! is_primitive_type {
    ($($t:ty),*) => {
        $(
            impl private_is_primitive_type::Sealed for $t {}

            // SAFETY: the type is primitive and has no padding.
            unsafe impl IsPrimitiveType for $t {}
        )*
    };
}
is_primitive_type!(u8, i8, u16, i16, u32, i32, u64, i64, u128, i128, f32, f64);

/// Compile-time assertion that `T` is a primitive type.
///
/// The function body is empty; all the work is done by the trait bound,
/// so a call with a non-primitive `T` fails to type-check.
/// This is purely intended for use in `crates/bindings-macro`.
pub const fn assert_is_primitive_type<T>()
where
    T: IsPrimitiveType,
{
}

#[cfg(test)]
mod tests {
use super::{to_vec, DecodeError};
Expand Down
4 changes: 3 additions & 1 deletion crates/sats/src/bsatn/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ impl<'a, 'de, R: BufReader<'de>> Deserializer<'a, R> {
}

/// Reads a slice of `len` elements.
pub(crate) fn get_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError> {
#[inline]
#[doc(hidden)]
pub fn get_slice(&mut self, len: usize) -> Result<&'de [u8], DecodeError> {
self.reader.get_slice(len)
}

Expand Down
10 changes: 10 additions & 0 deletions crates/sats/src/bsatn/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ impl<'a, W> Serializer<'a, W> {
}
}

impl<W> Serializer<'_, W>
where
    W: BufWriter,
{
    /// Forwards `bytes` verbatim to the underlying writer via `put_slice`.
    ///
    /// This is a raw API: this method writes nothing besides `bytes`.
    /// Only use this if you know what you are doing.
    #[doc(hidden)]
    #[inline(always)]
    pub fn raw_write_bytes(self, bytes: &[u8]) {
        self.writer.put_slice(bytes);
    }
}

/// An error during BSATN serialization.
#[derive(Debug, Clone)]
// TODO: rename to EncodeError
Expand Down
11 changes: 10 additions & 1 deletion crates/sats/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ pub mod serde;
#[doc(hidden)]
pub use impls::{visit_named_product, visit_seq_product, WithBound};

use crate::{i256, u256};
use crate::buffer::BufReader;
use crate::{bsatn, i256, u256};
use core::fmt;
use core::marker::PhantomData;
use smallvec::SmallVec;
Expand Down Expand Up @@ -557,6 +558,14 @@ pub trait Deserialize<'de>: Sized {
/// Deserialize this value from the given `deserializer`.
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error>;

#[doc(hidden)]
/// Deserialize this value from the given the BSATN `deserializer`.
fn deserialize_from_bsatn<R: BufReader<'de>>(
deserializer: bsatn::Deserializer<'de, R>,
) -> Result<Self, bsatn::DecodeError> {
Self::deserialize(deserializer)
}

/// used in the Deserialize for Vec<T> impl to allow specializing deserializing Vec<T> as bytes
#[doc(hidden)]
#[inline(always)]
Expand Down
11 changes: 10 additions & 1 deletion crates/sats/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ pub trait Serializer: Sized {
use ethnum::{i256, u256};
pub use spacetimedb_bindings_macro::Serialize;

use crate::AlgebraicType;
use crate::{bsatn, buffer::BufWriter, AlgebraicType};

/// A **data structure** that can be serialized into any data format supported by SATS.
///
Expand All @@ -216,6 +216,15 @@ pub trait Serialize {
/// Serialize `self` in the data format of `S` using the provided `serializer`.
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error>;

#[doc(hidden)]
/// Serialize `self` in the data format BSATN using the provided BSATN `serializer`.
fn serialize_into_bsatn<W: BufWriter>(
&self,
serializer: bsatn::Serializer<'_, W>,
) -> Result<(), bsatn::EncodeError> {
self.serialize(serializer)
}

/// Used in the `Serialize for Vec<T>` implementation
/// to allow a specialized serialization of `Vec<T>` as bytes.
#[doc(hidden)]
Expand Down

2 comments on commit 0b2364b

@github-actions
Copy link

@github-actions github-actions bot commented on 0b2364b Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Callgrind benchmark results Error when comparing benchmarks: Couldn't find AWS credentials in environment, credentials file, or IAM role.

Caused by:
Couldn't find AWS credentials in environment, credentials file, or IAM role.

@github-actions
Copy link

@github-actions github-actions bot commented on 0b2364b Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Criterion benchmark results

Error when comparing benchmarks: Couldn't find AWS credentials in environment, credentials file, or IAM role.

Caused by:
Couldn't find AWS credentials in environment, credentials file, or IAM role.

Please sign in to comment.