Skip to content

Commit

Permalink
Lifetimes and performance
Browse files Browse the repository at this point in the history
- Relaxes the 'static bound on tracked types, they can now have lifetimes
- Improves performance through hash-based constraint generation and validation acceleration
- Fixes possible name clashes in tracked methods
  • Loading branch information
laurmaedje committed May 10, 2023
1 parent 2b1d9dd commit 0c141bb
Show file tree
Hide file tree
Showing 8 changed files with 583 additions and 345 deletions.
6 changes: 2 additions & 4 deletions macros/src/memoize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ fn prepare_arg(input: &syn::FnArg) -> Result<Argument> {
mutability: Some(_), ..
}) = typed.ty.as_ref()
{
bail!(
typed.ty,
"memoized functions cannot have mutable parameters"
)
bail!(typed.ty, "memoized functions cannot have mutable parameters")
}

Argument::Ident(mutability.clone(), ident.clone())
Expand Down Expand Up @@ -127,6 +124,7 @@ fn process(function: &Function) -> Result<TokenStream> {
::comemo::internal::memoized(
::core::any::TypeId::of::<#unique>(),
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
#closure,
)
} };
Expand Down
181 changes: 117 additions & 64 deletions macros/src/track.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@ pub fn expand(item: &syn::Item) -> Result<TokenStream> {
// Preprocess and validate the methods.
let mut methods = vec![];

let (ty, trait_) = match item {
let (ty, generics, trait_) = match item {
syn::Item::Impl(item) => {
for param in item.generics.params.iter() {
bail!(param, "tracked impl blocks cannot be generic")
match param {
syn::GenericParam::Lifetime(_) => {}
syn::GenericParam::Type(_) => {
bail!(param, "tracked impl blocks cannot use type generics")
}
syn::GenericParam::Const(_) => {
bail!(param, "tracked impl blocks cannot use const generics")
}
}
}

for item in &item.items {
methods.push(prepare_impl_method(&item)?);
}

let ty = item.self_ty.as_ref().clone();
(ty, None)
(ty, &item.generics, None)
}
syn::Item::Trait(item) => {
for param in item.generics.params.iter() {
Expand All @@ -28,17 +36,14 @@ pub fn expand(item: &syn::Item) -> Result<TokenStream> {
}

let name = &item.ident;
let ty = parse_quote! { dyn #name };
(ty, Some(name.clone()))
let ty = parse_quote! { dyn #name + '__comemo_dynamic };
(ty, &item.generics, Some(name.clone()))
}
_ => bail!(
item,
"`track` can only be applied to impl blocks and traits"
),
_ => bail!(item, "`track` can only be applied to impl blocks and traits"),
};

// Produce the necessary items for the type to become trackable.
let scope = create(&ty, trait_, &methods)?;
let scope = create(&ty, generics, trait_, &methods)?;

Ok(quote! {
#item
Expand Down Expand Up @@ -172,104 +177,152 @@ fn prepare_method(vis: syn::Visibility, sig: &syn::Signature) -> Result<Method>
/// Produce the necessary items for a type to become trackable.
fn create(
ty: &syn::Type,
generics: &syn::Generics,
trait_: Option<syn::Ident>,
methods: &[Method],
) -> Result<TokenStream> {
let prefix = trait_.map(|name| quote! { #name for });
let variants = methods.iter().map(create_variant);
let validations = methods.iter().map(create_validation);
let t: syn::GenericParam = parse_quote! { '__comemo_tracked };
let r: syn::GenericParam = parse_quote! { '__comemo_retrack };
let d: syn::GenericParam = parse_quote! { '__comemo_dynamic };
let maybe_cloned = if methods.iter().any(|it| it.mutable) {
quote! { ::core::clone::Clone::clone(self) }
} else {
quote! { self }
};

// Prepare generics.
let (impl_gen, type_gen, where_clause) = generics.split_for_impl();
let mut impl_params: syn::Generics = parse_quote! { #impl_gen };
let mut type_params: syn::Generics = parse_quote! { #type_gen };
if trait_.is_some() {
impl_params.params.push(d.clone());
type_params.params.push(d.clone());
}

let mut impl_params_t: syn::Generics = impl_params.clone();
let mut type_params_t: syn::Generics = type_params.clone();
impl_params_t.params.push(t.clone());
type_params_t.params.push(t.clone());

// Prepare validations.
let prefix = trait_.as_ref().map(|name| quote! { #name for });
let validations: Vec<_> = methods.iter().map(create_validation).collect();
let validate = if !methods.is_empty() {
quote! {
let mut this = #maybe_cloned;
constraint.validate(|call| match &call.0 { #(#validations,)* })
}
} else {
quote! { true }
};
let validate_with_id = if !methods.is_empty() {
quote! {
let mut this = #maybe_cloned;
constraint.validate_with_id(
|call| match &call.0 { #(#validations,)* },
id,
)
}
} else {
quote! { true }
};

// Prepare replying.
let replays = methods.iter().map(create_replay);
let replay = methods.iter().any(|m| m.mutable).then(|| {
quote! {
constraint.replay(|call| match &call.0 { #(#replays,)* });
}
});

// Prepare variants and wrapper methods.
let variants = methods.iter().map(create_variant);
let wrapper_methods = methods
.iter()
.filter(|m| !m.mutable)
.map(|m| create_wrapper(m, false));
let wrapper_methods_mut = methods.iter().map(|m| create_wrapper(m, true));
let maybe_cloned = if methods.iter().any(|it| it.mutable) {
quote! { ::std::clone::Clone::clone(self) }
} else {
quote! { self }
};

Ok(quote! {
#[derive(Clone, PartialEq)]
pub struct __ComemoCall(__ComemoVariant);
impl #impl_params ::comemo::Track for #ty #where_clause {}

#[derive(Clone, PartialEq)]
#[allow(non_camel_case_types)]
enum __ComemoVariant {
#(#variants,)*
}
impl #impl_params ::comemo::Validate for #ty #where_clause {
type Constraint = ::comemo::internal::Constraint<__ComemoCall>;

impl ::comemo::Track for #ty {
#[inline]
fn valid(&self, constraint: &::comemo::Constraint<Self>) -> bool {
let mut this = #maybe_cloned;
constraint.valid(|call| match &call.0 { #(#validations,)* })
fn validate(&self, constraint: &Self::Constraint) -> bool {
#validate
}
}

#[doc(hidden)]
impl ::comemo::internal::Trackable for #ty {
type Call = __ComemoCall;
type Surface = __ComemoSurfaceFamily;
type SurfaceMut = __ComemoSurfaceMutFamily;
#[inline]
fn validate_with_id(&self, constraint: &Self::Constraint, id: usize) -> bool {
#validate_with_id
}

#[inline]
#[allow(unused_variables)]
fn replay(&mut self, constraint: &::comemo::Constraint<Self>) {
constraint.replay(|call| match &call.0 { #(#replays,)* });
fn replay(&mut self, constraint: &Self::Constraint) {
#replay
}
}

#[derive(Clone, PartialEq, Hash)]
pub struct __ComemoCall(__ComemoVariant);

#[derive(Clone, PartialEq, Hash)]
#[allow(non_camel_case_types)]
enum __ComemoVariant {
#(#variants,)*
}

#[doc(hidden)]
impl #impl_params ::comemo::internal::Surfaces for #ty #where_clause {
type Surface<#t> = __ComemoSurface #type_params_t where Self: #t;
type SurfaceMut<#t> = __ComemoSurfaceMut #type_params_t where Self: #t;

#[inline]
fn surface_ref<'a, 'r>(
tracked: &'r ::comemo::Tracked<'a, Self>,
) -> &'r __ComemoSurface<'a> {
fn surface_ref<#t, #r>(
tracked: &#r ::comemo::Tracked<#t, Self>,
) -> &#r Self::Surface<#t> {
// Safety: __ComemoSurface is repr(transparent).
unsafe { &*(tracked as *const _ as *const _) }
}

#[inline]
fn surface_mut_ref<'a, 'r>(
tracked: &'r ::comemo::TrackedMut<'a, Self>,
) -> &'r __ComemoSurfaceMut<'a> {
fn surface_mut_ref<#t, #r>(
tracked: &#r ::comemo::TrackedMut<#t, Self>,
) -> &#r Self::SurfaceMut<#t> {
// Safety: __ComemoSurfaceMut is repr(transparent).
unsafe { &*(tracked as *const _ as *const _) }
}

#[inline]
fn surface_mut_mut<'a, 'r>(
tracked: &'r mut ::comemo::TrackedMut<'a, Self>,
) -> &'r mut __ComemoSurfaceMut<'a> {
fn surface_mut_mut<#t, #r>(
tracked: &#r mut ::comemo::TrackedMut<#t, Self>,
) -> &#r mut Self::SurfaceMut<#t> {
// Safety: __ComemoSurfaceMut is repr(transparent).
unsafe { &mut *(tracked as *mut _ as *mut _) }
}
}

#[repr(transparent)]
pub struct __ComemoSurface<'a>(::comemo::Tracked<'a, #ty>);
pub struct __ComemoSurface #impl_params_t(::comemo::Tracked<#t, #ty>)
#where_clause;

#[allow(dead_code)]
impl #prefix __ComemoSurface<'_> {
impl #impl_params_t #prefix __ComemoSurface #type_params_t {
#(#wrapper_methods)*
}

pub enum __ComemoSurfaceFamily {}
impl<'a> ::comemo::internal::Family<'a> for __ComemoSurfaceFamily {
type Out = __ComemoSurface<'a>;
}

#[repr(transparent)]
pub struct __ComemoSurfaceMut<'a>(::comemo::TrackedMut<'a, #ty>);
pub struct __ComemoSurfaceMut #impl_params_t(::comemo::TrackedMut<#t, #ty>)
#where_clause;

#[allow(dead_code)]
impl #prefix __ComemoSurfaceMut<'_> {
impl #impl_params_t #prefix __ComemoSurfaceMut #type_params_t {
#(#wrapper_methods_mut)*
}

pub enum __ComemoSurfaceMutFamily {}
impl<'a> ::comemo::internal::Family<'a> for __ComemoSurfaceMutFamily {
type Out = __ComemoSurfaceMut<'a>;
}
})
}

Expand Down Expand Up @@ -328,12 +381,12 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
#[track_caller]
#[inline]
#vis #sig {
let call = __ComemoVariant::#name(#(#args.to_owned()),*);
let (value, constraint) = ::comemo::internal::#to_parts;
let output = value.#name(#(#args,)*);
if let Some(constraint) = constraint {
let __comemo_variant = __ComemoVariant::#name(#(#args.to_owned()),*);
let (__comemo_value, __comemo_constraint) = ::comemo::internal::#to_parts;
let output = __comemo_value.#name(#(#args,)*);
if let Some(constraint) = __comemo_constraint {
constraint.push(
__ComemoCall(call),
__ComemoCall(__comemo_variant),
::comemo::internal::hash(&output),
#mutable,
);
Expand Down
Loading

0 comments on commit 0c141bb

Please sign in to comment.