Skip to content

Commit

Permalink
Add LuaNativeFn/LuaNativeFnMut/LuaNativeAsyncFn traits for usin…
Browse files Browse the repository at this point in the history
…g in `Function::wrap`
  • Loading branch information
khvzak committed Sep 24, 2024
1 parent 8274b5f commit 91fe02d
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 53 deletions.
68 changes: 56 additions & 12 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{mem, ptr, slice};
use crate::error::{Error, Result};
use crate::state::Lua;
use crate::table::Table;
use crate::traits::{LuaNativeFn, LuaNativeFnMut};
use crate::types::{Callback, LuaType, MaybeSend, ValueRef};
use crate::util::{
assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, StackGuard,
Expand All @@ -13,6 +14,7 @@ use crate::value::{FromLuaMulti, IntoLua, IntoLuaMulti, Value};

#[cfg(feature = "async")]
use {
crate::traits::LuaNativeAsyncFn,
crate::types::AsyncCallback,
std::future::{self, Future},
};
Expand Down Expand Up @@ -522,55 +524,97 @@ impl Function {
/// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`]
/// trait.
#[inline]
pub fn wrap<A, R, F>(func: F) -> impl IntoLua
pub fn wrap<F, A, R>(func: F) -> impl IntoLua
where
F: LuaNativeFn<A, Output = Result<R>> + MaybeSend + 'static,
A: FromLuaMulti,
R: IntoLuaMulti,
F: Fn(&Lua, A) -> Result<R> + MaybeSend + 'static,
{
WrappedFunction(Box::new(move |lua, nargs| unsafe {
let args = A::from_stack_args(nargs, 1, None, lua)?;
func(lua.lua(), args)?.push_into_stack_multi(lua)
func.call(args)?.push_into_stack_multi(lua)
}))
}

/// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait.
#[inline]
pub fn wrap_mut<A, R, F>(func: F) -> impl IntoLua
pub fn wrap_mut<F, A, R>(func: F) -> impl IntoLua
where
F: LuaNativeFnMut<A, Output = Result<R>> + MaybeSend + 'static,
A: FromLuaMulti,
R: IntoLuaMulti,
F: FnMut(&Lua, A) -> Result<R> + MaybeSend + 'static,
{
let func = RefCell::new(func);
WrappedFunction(Box::new(move |lua, nargs| unsafe {
let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?;
let args = A::from_stack_args(nargs, 1, None, lua)?;
func(lua.lua(), args)?.push_into_stack_multi(lua)
func.call(args)?.push_into_stack_multi(lua)
}))
}

#[inline]
pub fn wrap_raw<F, A>(func: F) -> impl IntoLua
where
F: LuaNativeFn<A> + MaybeSend + 'static,
A: FromLuaMulti,
{
WrappedFunction(Box::new(move |lua, nargs| unsafe {
let args = A::from_stack_args(nargs, 1, None, lua)?;
func.call(args).push_into_stack_multi(lua)
}))
}

#[inline]
pub fn wrap_raw_mut<F, A>(func: F) -> impl IntoLua
where
F: LuaNativeFnMut<A> + MaybeSend + 'static,
A: FromLuaMulti,
{
let func = RefCell::new(func);
WrappedFunction(Box::new(move |lua, nargs| unsafe {
let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?;
let args = A::from_stack_args(nargs, 1, None, lua)?;
func.call(args).push_into_stack_multi(lua)
}))
}

/// Wraps a Rust async function or closure, returning an opaque type that implements [`IntoLua`]
/// trait.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub fn wrap_async<A, R, F, FR>(func: F) -> impl IntoLua
pub fn wrap_async<F, A, R>(func: F) -> impl IntoLua
where
F: LuaNativeAsyncFn<A, Output = Result<R>> + MaybeSend + 'static,
A: FromLuaMulti,
R: IntoLuaMulti,
F: Fn(Lua, A) -> FR + MaybeSend + 'static,
FR: Future<Output = Result<R>> + MaybeSend + 'static,
{
WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe {
let args = match A::from_stack_args(nargs, 1, None, rawlua) {
Ok(args) => args,
Err(e) => return Box::pin(future::ready(Err(e))),
};
let lua = rawlua.lua().clone();
let fut = func(lua.clone(), args);
let lua = rawlua.lua();
let fut = func.call(args);
Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) })
}))
}

#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub fn wrap_raw_async<F, A>(func: F) -> impl IntoLua
where
F: LuaNativeAsyncFn<A> + MaybeSend + 'static,
A: FromLuaMulti,
{
WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe {
let args = match A::from_stack_args(nargs, 1, None, rawlua) {
Ok(args) => args,
Err(e) => return Box::pin(future::ready(Err(e))),
};
let lua = rawlua.lua();
let fut = func.call(args);
Box::pin(async move { fut.await.push_into_stack_multi(lua.raw_lua()) })
}))
}
}

impl IntoLua for WrappedFunction {
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub use crate::stdlib::StdLib;
pub use crate::string::{BorrowedBytes, BorrowedStr, String};
pub use crate::table::{Table, TablePairs, TableSequence};
pub use crate::thread::{Thread, ThreadStatus};
pub use crate::traits::ObjectLike;
pub use crate::traits::{LuaNativeFn, LuaNativeFnMut, ObjectLike};
pub use crate::types::{
AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState,
};
Expand All @@ -133,7 +133,7 @@ pub use crate::hook::HookTriggers;
pub use crate::{chunk::Compiler, function::CoverageInfo, types::Vector};

#[cfg(feature = "async")]
pub use crate::thread::AsyncThread;
pub use crate::{thread::AsyncThread, traits::LuaNativeAsyncFn};

#[cfg(feature = "serialize")]
#[doc(inline)]
Expand Down
19 changes: 10 additions & 9 deletions src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ pub use crate::{
AnyUserData as LuaAnyUserData, Chunk as LuaChunk, Error as LuaError, ErrorContext as LuaErrorContext,
ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti,
Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, Integer as LuaInteger,
IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaOptions, MetaMethod as LuaMetaMethod,
MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, ObjectLike as LuaObjectLike,
RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, String as LuaString,
Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, Thread as LuaThread,
ThreadStatus as LuaThreadStatus, UserData as LuaUserData, UserDataFields as LuaUserDataFields,
UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods,
UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut,
UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, VmState as LuaVmState,
IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, LuaNativeFnMut, LuaOptions,
MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber,
ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib,
String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence,
Thread as LuaThread, ThreadStatus as LuaThreadStatus, UserData as LuaUserData,
UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable,
UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef,
UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue,
VmState as LuaVmState,
};

#[cfg(not(feature = "luau"))]
Expand All @@ -25,7 +26,7 @@ pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector};

#[cfg(feature = "async")]
#[doc(no_inline)]
pub use crate::AsyncThread as LuaAsyncThread;
pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn};

#[cfg(feature = "serialize")]
#[doc(no_inline)]
Expand Down
6 changes: 6 additions & 0 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,12 @@ impl Lua {
T::from_lua(value, self)
}

/// Converts a value that implements `IntoLua` into a `FromLua` variant.
#[inline]
pub fn convert<U: FromLua>(&self, value: impl IntoLua) -> Result<U> {
U::from_lua(value.into_lua(self)?, self)
}

/// Converts a value that implements `IntoLuaMulti` into a `MultiValue` instance.
#[inline]
pub fn pack_multi(&self, t: impl IntoLuaMulti) -> Result<MultiValue> {
Expand Down
92 changes: 92 additions & 0 deletions src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::string::String as StdString;

use crate::error::Result;
use crate::private::Sealed;
use crate::types::MaybeSend;
use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti};

#[cfg(feature = "async")]
Expand Down Expand Up @@ -76,3 +77,94 @@ pub trait ObjectLike: Sealed {
/// This might invoke the `__tostring` metamethod.
fn to_string(&self) -> Result<StdString>;
}

/// A trait for types that can be used as Lua functions.
pub trait LuaNativeFn<A: FromLuaMulti> {
type Output: IntoLuaMulti;

fn call(&self, args: A) -> Self::Output;
}

/// A trait for types with mutable state that can be used as Lua functions.
pub trait LuaNativeFnMut<A: FromLuaMulti> {
type Output: IntoLuaMulti;

fn call(&mut self, args: A) -> Self::Output;
}

/// A trait for types that returns a future and can be used as Lua functions.
#[cfg(feature = "async")]
pub trait LuaNativeAsyncFn<A: FromLuaMulti> {
type Output: IntoLuaMulti;

fn call(&self, args: A) -> impl Future<Output = Self::Output> + MaybeSend + 'static;
}

macro_rules! impl_lua_native_fn {
($($A:ident),*) => {
impl<FN, $($A,)* R> LuaNativeFn<($($A,)*)> for FN
where
FN: Fn($($A,)*) -> R + MaybeSend + 'static,
($($A,)*): FromLuaMulti,
R: IntoLuaMulti,
{
type Output = R;

#[allow(non_snake_case)]
fn call(&self, args: ($($A,)*)) -> Self::Output {
let ($($A,)*) = args;
self($($A,)*)
}
}

impl<FN, $($A,)* R> LuaNativeFnMut<($($A,)*)> for FN
where
FN: FnMut($($A,)*) -> R + MaybeSend + 'static,
($($A,)*): FromLuaMulti,
R: IntoLuaMulti,
{
type Output = R;

#[allow(non_snake_case)]
fn call(&mut self, args: ($($A,)*)) -> Self::Output {
let ($($A,)*) = args;
self($($A,)*)
}
}

#[cfg(feature = "async")]
impl<FN, $($A,)* Fut, R> LuaNativeAsyncFn<($($A,)*)> for FN
where
FN: Fn($($A,)*) -> Fut + MaybeSend + 'static,
($($A,)*): FromLuaMulti,
Fut: Future<Output = R> + MaybeSend + 'static,
R: IntoLuaMulti,
{
type Output = R;

#[allow(non_snake_case)]
fn call(&self, args: ($($A,)*)) -> impl Future<Output = Self::Output> + MaybeSend + 'static {
let ($($A,)*) = args;
self($($A,)*)
}
}
};
}

impl_lua_native_fn!();
impl_lua_native_fn!(A);
impl_lua_native_fn!(A, B);
impl_lua_native_fn!(A, B, C);
impl_lua_native_fn!(A, B, C, D);
impl_lua_native_fn!(A, B, C, D, E);
impl_lua_native_fn!(A, B, C, D, E, F);
impl_lua_native_fn!(A, B, C, D, E, F, G);
impl_lua_native_fn!(A, B, C, D, E, F, G, H);
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I);
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J);
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K);
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L);
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M);
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N);
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O);
impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);
29 changes: 28 additions & 1 deletion tests/async.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![cfg(feature = "async")]

use std::string::String as StdString;
use std::sync::Arc;
use std::time::Duration;

Expand Down Expand Up @@ -39,12 +40,38 @@ async fn test_async_function() -> Result<()> {
async fn test_async_function_wrap() -> Result<()> {
let lua = Lua::new();

let f = Function::wrap_async(|_, s: String| async move { Ok(s) });
let f = Function::wrap_async(|s: StdString| async move {
tokio::task::yield_now().await;
Ok(s)
});
lua.globals().set("f", f)?;
let res: String = lua.load(r#"f("hello")"#).eval_async().await?;
assert_eq!(res, "hello");

Ok(())
}

#[tokio::test]
async fn test_async_function_wrap_raw() -> Result<()> {
let lua = Lua::new();

let f = Function::wrap_raw_async(|s: StdString| async move {
tokio::task::yield_now().await;
s
});
lua.globals().set("f", f)?;
let res: String = lua.load(r#"f("hello")"#).eval_async().await?;
assert_eq!(res, "hello");

// Return error
let ferr = Function::wrap_raw_async(|| async move {
tokio::task::yield_now().await;
Err::<(), _>("some error")
});
lua.globals().set("ferr", ferr)?;
let (_, err): (Value, String) = lua.load(r#"ferr()"#).eval_async().await?;
assert_eq!(err, "some error");

Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion tests/chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn test_chunk_macro() -> Result<()> {
data.raw_set("num", 1)?;

let ud = mlua::AnyUserData::wrap("hello");
let f = mlua::Function::wrap(|_lua, ()| Ok(()));
let f = mlua::Function::wrap(|| Ok(()));

lua.globals().set("g", 123)?;

Expand Down
Loading

0 comments on commit 91fe02d

Please sign in to comment.