Skip to content

Commit

Permalink
Local memoization
Browse files Browse the repository at this point in the history
  • Loading branch information
Dherse committed Dec 29, 2023
1 parent ddb3773 commit 15c8d5b
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 54 deletions.
31 changes: 28 additions & 3 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ use syn::{parse_quote, Error, Result};
///
/// Furthermore, memoized functions cannot use destructuring patterns in their
/// arguments.
///
/// # Local memoization
///
/// In the case where you would want to explicitely restrict usage to a single
/// thread, you can use the [`#[memoize(local)]`](macro@memoize) attribute.
/// This will use thread-local storage for the cache and requires a call to the
/// [`local_evict`](comemo::local_evict) function to clear the cache.
///
/// Additionally, if you wish to pass borrowed arguments to a memoized function
/// you will need to manually annotate it with the `'local` lifetime. This allows
/// the macro to specify the lifetime of the tracked value correctly.
///
/// # Example
/// ```
Expand All @@ -71,12 +82,26 @@ use syn::{parse_quote, Error, Result};
/// })
/// .sum()
/// }
///
/// // Evaluate a `.calc` script in a thread-local cache.
/// // /!\ Notice the `'local` lifetime annotation.
/// #[comemo::memoize(local)]
/// fn evaluate(script: &'local str, files: &comemo::Tracked<'local, Files>) -> i32 {
/// script
/// .split('+')
/// .map(str::trim)
/// .map(|part| match part.strip_prefix("eval ") {
/// Some(path) => evaluate(&files.read(path), files),
/// None => part.parse::<i32>().unwrap(),
/// })
/// .sum()
/// }
/// ```
///
#[proc_macro_attribute]
pub fn memoize(_: BoundaryStream, stream: BoundaryStream) -> BoundaryStream {
let func = syn::parse_macro_input!(stream as syn::Item);
memoize::expand(&func)
pub fn memoize(stream: BoundaryStream, item: BoundaryStream) -> BoundaryStream {
let func = syn::parse_macro_input!(item as syn::Item);
memoize::expand(stream.into(), &func)
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
Expand Down
105 changes: 84 additions & 21 deletions macros/src/memoize.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,40 @@
use syn::{parse::{Parse, ParseStream}, token::Token};

use super::*;

/// Memoize a function.
pub fn expand(item: &syn::Item) -> Result<proc_macro2::TokenStream> {
pub fn expand(stream: TokenStream, item: &syn::Item) -> Result<proc_macro2::TokenStream> {
let meta: Meta = syn::parse2(stream)?;
let syn::Item::Fn(item) = item else {
bail!(item, "`memoize` can only be applied to functions and methods");
};

// Preprocess and validate the function.
let function = prepare(item)?;
let function = prepare(&meta, item)?;

// Rewrite the function's body to memoize it.
process(&function)
}

/// The `..` in `#[memoize(..)]`.
pub struct Meta {
pub local: bool,
}

impl Parse for Meta {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
local: parse_flag::<kw::local>(input)?,
})
}
}

/// Details about a function that should be memoized.
struct Function {
item: syn::ItemFn,
args: Vec<Argument>,
output: syn::Type,
local: bool,
}

/// An argument to a memoized function.
Expand All @@ -27,7 +44,7 @@ enum Argument {
}

/// Preprocess and validate a function.
fn prepare(function: &syn::ItemFn) -> Result<Function> {
fn prepare(meta: &Meta, function: &syn::ItemFn) -> Result<Function> {
let mut args = vec![];

for input in &function.sig.inputs {
Expand All @@ -39,7 +56,7 @@ fn prepare(function: &syn::ItemFn) -> Result<Function> {
syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
};

Ok(Function { item: function.clone(), args, output })
Ok(Function { item: function.clone(), args, output, local: meta.local })
}

/// Preprocess a function argument.
Expand Down Expand Up @@ -124,23 +141,69 @@ fn process(function: &Function) -> Result<TokenStream> {
ident.mutability = None;
}

wrapped.block = parse_quote! { {
static __CACHE: ::comemo::internal::Cache<
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age));
::core::default::Default::default()
});

#(#bounds;)*
::comemo::internal::memoized(
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&__CACHE,
#closure,
)
} };
if function.local {
wrapped.block = parse_quote! { {
type __ARGS<'local> = <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint;
::std::thread_local! {
static __CACHE: ::comemo::internal::Cache<
__ARGS<'static>,
#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_local_evictor(|max_age| __CACHE.with(|cache| cache.evict(max_age)));
::core::default::Default::default()
});
}

#(#bounds;)*
__CACHE.with(|cache| {
::comemo::internal::memoized(
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&cache,
#closure,
)
})
} };
} else {
wrapped.block = parse_quote! { {
static __CACHE: ::comemo::internal::Cache<
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age));
::core::default::Default::default()
});

#(#bounds;)*
::comemo::internal::memoized(
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&__CACHE,
#closure,
)
} };
}

Ok(quote! { #wrapped })
}

/// Parse a metadata flag that can be present or not.
pub fn parse_flag<K: Token + Default + Parse>(input: ParseStream) -> Result<bool> {
if input.peek(|_| K::default()) {
let _: K = input.parse()?;
eat_comma(input);
return Ok(true);
}
Ok(false)
}

/// Parse a comma if there is one.
fn eat_comma(input: ParseStream) {
if input.peek(syn::Token![,]) {
let _: syn::Token![,] = input.parse().unwrap();
}
}

pub mod kw {
syn::custom_keyword!(local);
}
32 changes: 32 additions & 0 deletions src/cache.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

Expand All @@ -9,6 +10,11 @@ use crate::accelerate;
use crate::constraint::Join;
use crate::input::Input;

thread_local! {
/// The thread-local list of eviction functions.
static LOCAL_EVICTORS: RefCell<Vec<fn(usize)>> = const { RefCell::new(Vec::new()) };
}

/// The global list of eviction functions.
static EVICTORS: RwLock<Vec<fn(usize)>> = RwLock::new(Vec::new());

Expand Down Expand Up @@ -90,11 +96,37 @@ pub fn evict(max_age: usize) {
accelerate::evict();
}

/// Evict the thread local cache.
///
/// This removes all memoized results from the cache whose age is larger than or
/// equal to `max_age`. The age of a result grows by one during each eviction
/// and is reset to zero when the result produces a cache hit. Set `max_age` to
/// zero to completely clear the cache.
///
/// Comemo's cache is thread-local, meaning that this only evicts this thread's
/// cache.
pub fn local_evict(max_age: usize) {
LOCAL_EVICTORS.with_borrow(|cell| {
for subevict in cell.iter() {
subevict(max_age);
}
});

accelerate::evict();
}

/// Register an eviction function in the global list.
pub fn register_evictor(evict: fn(usize)) {
EVICTORS.write().push(evict);
}

/// Register an eviction function in the global list.
pub fn register_local_evictor(evict: fn(usize)) {
LOCAL_EVICTORS.with_borrow_mut(|cell| {
cell.push(evict);
})
}

/// Whether the last call was a hit.
#[cfg(feature = "testing")]
pub fn last_was_hit() -> bool {
Expand Down
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ mod input;
mod prehashed;
mod track;

pub use crate::cache::evict;
pub use crate::cache::{evict, local_evict};
pub use crate::prehashed::Prehashed;
pub use crate::track::{Track, Tracked, TrackedMut, Validate};
pub use comemo_macros::{memoize, track};
Expand All @@ -99,7 +99,9 @@ pub use comemo_macros::{memoize, track};
pub mod internal {
pub use parking_lot::RwLock;

pub use crate::cache::{memoized, register_evictor, Cache, CacheData};
pub use crate::cache::{
memoized, register_evictor, register_local_evictor, Cache, CacheData,
};
pub use crate::constraint::{hash, Call, ImmutableConstraint, MutableConstraint};
pub use crate::input::{assert_hashable_or_trackable, Args, Input};
pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces};
Expand Down
Loading

0 comments on commit 15c8d5b

Please sign in to comment.