Skip to content

Commit

Permalink
Simplify pushFunction and pushClosure
Browse files Browse the repository at this point in the history
This makes the api work for all versions of Lua. If Luau users wish to
set the debugname for the function, they can use the pushFunctionNamed
and pushClosureNamed functions.
  • Loading branch information
natecraddock committed Feb 1, 2024
1 parent 4cbe5ef commit a213c57
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 65 deletions.
56 changes: 27 additions & 29 deletions src/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1443,31 +1443,37 @@ pub const Lua = struct {
c.lua_pushboolean(lua.state, @intFromBool(b));
}

fn pushClosureLua(lua: *Lua, c_fn: CFn, n: i32) void {
c.lua_pushcclosure(lua.state, c_fn, n);
}

fn pushClosureLuau(lua: *Lua, c_fn: CFn, name: [:0]const u8, n: i32) void {
c.lua_pushcclosurek(lua.state, c_fn, name, n, null);
}

/// Pushes a new Closure onto the stack
/// `n` tells how many upvalues this function will have
/// See https://www.lua.org/manual/5.1/manual.html#lua_pushcclosure
pub const pushClosure = if (lang == .luau) pushClosureLuau else pushClosureLua;
/// See https://www.lua.org/manual/5.4/manual.html#lua_pushcclosure
pub fn pushClosure(lua: *Lua, c_fn: CFn, n: i32) void {
switch (lang) {
.luau => c.lua_pushcclosurek(lua.state, c_fn, "ZigFn", n, null),
else => c.lua_pushcclosure(lua.state, c_fn, n),
}
}

fn pushFunctionLua(lua: *Lua, c_fn: CFn) void {
lua.pushClosure(c_fn, 0);
/// Pushes a new Closure onto the stack with a debugname
/// `n` tells how many upvalues this function will have
/// See https://www.lua.org/manual/5.4/manual.html#lua_pushcclosure
pub fn pushClosureNamed(lua: *Lua, c_fn: CFn, debugname: [:0]const u8, n: i32) void {
c.lua_pushcclosurek(lua.state, c_fn, debugname, n, null);
}

fn pushFunctionLuau(lua: *Lua, c_fn: CFn, name: [:0]const u8) void {
lua.pushClosure(c_fn, name, 0);
/// Pushes a new function onto the stack
/// See https://www.lua.org/manual/5.4/manual.html#lua_pushcfunction
pub fn pushFunction(lua: *Lua, c_fn: CFn) void {
switch (lang) {
.luau => c.lua_pushcclosurek(lua.state, c_fn, "ZigFn", 0, null),
else => c.lua_pushcfunction(lua.state, c_fn),
}
}

/// Pushes a new Closure onto the stack
/// `n` tells how many upvalues this function will have
/// See https://www.lua.org/manual/5.4/manual.html#lua_pushcclosure
pub const pushFunction = if (lang == .luau) pushFunctionLuau else pushFunctionLua;
/// Pushes a new function onto the stack with a debugname
/// See https://www.lua.org/manual/5.4/manual.html#lua_pushcfunction
pub fn pushFunctionNamed(lua: *Lua, c_fn: CFn, debugname: [:0]const u8) void {
c.lua_pushcclosurek(lua.state, c_fn, debugname, 0, null);
}

/// Push a formatted string onto the stack and return a pointer to the string
/// See https://www.lua.org/manual/5.4/manual.html#lua_pushfstring
Expand Down Expand Up @@ -1672,7 +1678,7 @@ pub const Lua = struct {
pub fn register(lua: *Lua, name: [:0]const u8, c_fn: CFn) void {
switch (lang) {
.luau => {
lua.pushFunction(c_fn, name);
lua.pushFunction(c_fn);
lua.setGlobal(name);
},
else => c.lua_register(lua.state, name.ptr, c_fn),
Expand Down Expand Up @@ -2815,10 +2821,7 @@ pub const Lua = struct {
}
for (funcs) |f| {
// TODO: handle null functions
switch (lang) {
.luau => lua.pushFunction(f.func.?, f.name),
else => lua.pushFunction(f.func.?),
}
lua.pushFunction(f.func.?);
lua.setField(-2, f.name);
}
}
Expand All @@ -2829,12 +2832,7 @@ pub const Lua = struct {
pub fn requireF(lua: *Lua, mod_name: [:0]const u8, open_fn: CFn, global: bool) void {
switch (lang) {
.lua51, .luajit, .luau => {
if (lang == .luau) {
lua.pushFunction(open_fn, mod_name);
} else {
lua.pushFunction(open_fn);
}

lua.pushFunction(open_fn);
_ = lua.pushString(mod_name);
lua.call(1, 0);
},
Expand Down
62 changes: 26 additions & 36 deletions src/tests.zig
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ inline fn toNumber(lua: *Lua, index: i32) !ziglua.Number {
} else return try lua.toNumber(index);
}

/// pushFunction that sets the name for Luau
inline fn pushFunction(lua: *Lua, c_fn: ziglua.CFn) void {
if (ziglua.lang == .luau) return lua.pushFunction(c_fn, "");
lua.pushFunction(c_fn);
}

fn alloc(data: ?*anyopaque, ptr: ?*anyopaque, osize: usize, nsize: usize) callconv(.C) ?*anyopaque {
_ = data;

Expand Down Expand Up @@ -145,7 +139,7 @@ test "Zig allocator access" {
}
}.inner;

pushFunction(&lua, ziglua.wrap(inner));
lua.pushFunction(ziglua.wrap(inner));
lua.pushInteger(10);
try lua.protectedCall(1, 1, 0);

Expand Down Expand Up @@ -347,7 +341,7 @@ test "type of and getting values" {
try expectEqual(.string, lua.typeOf(-1));
try expect(lua.isString(-1));

if (ziglua.lang == .luau) lua.pushFunction(ziglua.wrap(add), "add") else lua.pushFunction(ziglua.wrap(add));
lua.pushFunction(ziglua.wrap(add));
try expectEqual(.function, lua.typeOf(-1));
try expect(lua.isCFunction(-1));
try expect(lua.isFunction(-1));
Expand Down Expand Up @@ -875,10 +869,7 @@ test "table access" {
// create a metatable (it isn't a useful one)
lua.newTable();

if (ziglua.lang == .luau)
lua.pushFunction(ziglua.wrap(add), "add")
else
lua.pushFunction(ziglua.wrap(add));
lua.pushFunction(ziglua.wrap(add));
lua.setField(-2, "__len");
lua.setMetatable(1);

Expand Down Expand Up @@ -1087,7 +1078,7 @@ test "upvalues" {

// Initialize the counter at 0
lua.pushInteger(0);
if (ziglua.lang == .luau) lua.pushClosure(ziglua.wrap(counter), "counter", 1) else lua.pushClosure(ziglua.wrap(counter), 1);
lua.pushClosure(ziglua.wrap(counter), 1);
lua.setGlobal("counter");

// call the function repeatedly, each time ensuring the result increases by one
Expand Down Expand Up @@ -1191,7 +1182,7 @@ test "raise error" {
}
}.inner;

if (ziglua.lang == .luau) lua.pushFunction(ziglua.wrap(makeError), "makeError") else lua.pushFunction(ziglua.wrap(makeError));
lua.pushFunction(ziglua.wrap(makeError));
try expectError(error.Runtime, lua.protectedCall(0, 0, 0));
try expectEqualStrings("makeError made an error", try lua.toBytes(-1));
}
Expand Down Expand Up @@ -1307,11 +1298,10 @@ test "yielding no continuation" {
return l.yield(1);
}
}.inner);
thread.pushFunction(func);
if (ziglua.lang == .luau) {
thread.pushFunction(func, "yieldfn");
_ = try thread.resumeThread(null, 0);
} else {
thread.pushFunction(func);
_ = try thread.resumeThread(0);
}

Expand Down Expand Up @@ -1367,28 +1357,28 @@ test "aux check functions" {
}
}.inner);

pushFunction(&lua, function);
lua.pushFunction(function);
lua.protectedCall(0, 0, 0) catch {
try expectStringContains("argument #1", try lua.toBytes(-1));
lua.pop(-1);
};

pushFunction(&lua, function);
lua.pushFunction(function);
lua.pushNil();
lua.protectedCall(1, 0, 0) catch {
try expectStringContains("number expected", try lua.toBytes(-1));
lua.pop(-1);
};

pushFunction(&lua, function);
lua.pushFunction(function);
lua.pushNil();
lua.pushInteger(3);
lua.protectedCall(2, 0, 0) catch {
try expectStringContains("string expected", try lua.toBytes(-1));
lua.pop(-1);
};

pushFunction(&lua, function);
lua.pushFunction(function);
lua.pushNil();
lua.pushInteger(3);
_ = lua.pushBytes("hello world");
Expand All @@ -1397,7 +1387,7 @@ test "aux check functions" {
lua.pop(-1);
};

pushFunction(&lua, function);
lua.pushFunction(function);
lua.pushNil();
lua.pushInteger(3);
_ = lua.pushBytes("hello world");
Expand All @@ -1407,7 +1397,7 @@ test "aux check functions" {
lua.pop(-1);
};

pushFunction(&lua, function);
lua.pushFunction(function);
lua.pushNil();
lua.pushInteger(3);
_ = lua.pushBytes("hello world");
Expand All @@ -1432,7 +1422,7 @@ test "aux check functions" {
};
}

pushFunction(&lua, function);
lua.pushFunction(function);
// test pushFail here (currently acts the same as pushNil)
if (ziglua.lang == .lua54) lua.pushFail() else lua.pushNil();
lua.pushInteger(3);
Expand Down Expand Up @@ -1460,10 +1450,10 @@ test "aux opt functions" {
}
}.inner);

pushFunction(&lua, function);
lua.pushFunction(function);
try lua.protectedCall(0, 0, 0);

pushFunction(&lua, function);
lua.pushFunction(function);
lua.pushInteger(10);
_ = lua.pushBytes("zig");
lua.pushNumber(1.23);
Expand Down Expand Up @@ -1493,32 +1483,32 @@ test "checkOption" {
}
}.inner);

pushFunction(&lua, function);
lua.pushFunction(function);
_ = lua.pushString("one");
try lua.protectedCall(1, 1, 0);
try expectEqual(1, try toInteger(&lua, -1));
lua.pop(1);

pushFunction(&lua, function);
lua.pushFunction(function);
_ = lua.pushString("two");
try lua.protectedCall(1, 1, 0);
try expectEqual(2, try toInteger(&lua, -1));
lua.pop(1);

pushFunction(&lua, function);
lua.pushFunction(function);
_ = lua.pushString("three");
try lua.protectedCall(1, 1, 0);
try expectEqual(3, try toInteger(&lua, -1));
lua.pop(1);

// try the default now
pushFunction(&lua, function);
lua.pushFunction(function);
try lua.protectedCall(0, 1, 0);
try expectEqual(1, try toInteger(&lua, -1));
lua.pop(1);

// check the raised error
pushFunction(&lua, function);
lua.pushFunction(function);
_ = lua.pushString("unknown");
try expectError(error.Runtime, lua.protectedCall(1, 1, 0));
try expectStringContains("(invalid option 'unknown')", try lua.toBytes(-1));
Expand Down Expand Up @@ -1569,7 +1559,7 @@ test "where" {
}
}.inner);

pushFunction(&lua, whereFn);
lua.pushFunction(whereFn);
lua.setGlobal("whereFn");

try lua.doString(
Expand Down Expand Up @@ -1662,7 +1652,7 @@ test "args and errors" {
}
}.inner);

pushFunction(&lua, argCheck);
lua.pushFunction(argCheck);
try expectError(error.Runtime, lua.protectedCall(0, 0, 0));

const raisesError = ziglua.wrap(struct {
Expand All @@ -1672,7 +1662,7 @@ test "args and errors" {
}
}.inner);

pushFunction(&lua, raisesError);
lua.pushFunction(raisesError);
try expectError(error.Runtime, lua.protectedCall(0, 0, 0));
try expectEqualStrings("some error zig!", try lua.toBytes(-1));

Expand Down Expand Up @@ -1755,7 +1745,7 @@ test "userdata" {
}
}.inner);

pushFunction(&lua, checkUdata);
lua.pushFunction(checkUdata);

{
var t = if (ziglua.lang == .lua54) lua.newUserdata(Type, 0) else lua.newUserdata(Type);
Expand Down Expand Up @@ -1840,7 +1830,7 @@ test "userdata slices" {
}
}.inner;

pushFunction(&lua, ziglua.wrap(udataFn));
lua.pushFunction(ziglua.wrap(udataFn));
lua.pushValue(2);

try lua.protectedCall(1, 0, 0);
Expand Down Expand Up @@ -2387,7 +2377,7 @@ test "namecall" {

try lua.newMetatable("vector");
lua.pushString("__namecall");
lua.pushFunction(ziglua.wrap(funcs.vectorNamecall), "vector_namecall");
lua.pushFunctionNamed(ziglua.wrap(funcs.vectorNamecall), "vector_namecall");
lua.setTable(-3);

lua.setReadonly(-1, true);
Expand Down

0 comments on commit a213c57

Please sign in to comment.