diff --git a/lua/lapi.c b/lua/lapi.c index ae739e1921..7b9ed2a685 100644 --- a/lua/lapi.c +++ b/lua/lapi.c @@ -1105,22 +1105,11 @@ LUA_API struct stored_proc *lua_getsp(lua_State *L) { } -LUA_API void set_sqlrow_stmt(lua_State *L, struct dbstmt_t *dbstmt) { +LUA_API void luabb_set_sqlrow(lua_State *L) { lua_lock(L); int top = lua_gettop(L); TValue *obj = index2adr(L, top); Table *h = hvalue(obj); - h->dbstmt = dbstmt; + h->from_sql = 1; lua_unlock(L); } - - -LUA_API struct dbstmt_t *get_sqlrow_stmt(lua_State *L) { - lua_lock(L); - int top = lua_gettop(L); - TValue *obj = index2adr(L, top); - Table *h = hvalue(obj); - struct dbstmt_t *dbstmt = h->dbstmt; - lua_unlock(L); - return dbstmt; -} diff --git a/lua/lobject.h b/lua/lobject.h index b3abf2a46e..fc41f09b9c 100644 --- a/lua/lobject.h +++ b/lua/lobject.h @@ -345,7 +345,7 @@ typedef struct Table { Node *lastfree; /* any free position is before this position */ GCObject *gclist; int sizearray; /* size of `array' array */ - struct dbstmt_t *dbstmt; + int from_sql; } Table; diff --git a/lua/ltable.c b/lua/ltable.c index d9ce7adaa2..1955a61f62 100644 --- a/lua/ltable.c +++ b/lua/ltable.c @@ -196,7 +196,7 @@ static int findindex (lua_State *L, Table *t, StkId key) { int luaH_next (lua_State *L, Table *t, StkId key) { int i = findindex(L, t, key); /* find original element */ for (i++; i < t->sizearray; i++) { /* try first array part */ - if (t->dbstmt) continue; + if (t->from_sql) continue; if (!ttisnil(&t->array[i])) { /* a non-nil value? */ setnvalue(key, cast_num(i+1)); setobj2s(L, key+1, &t->array[i]); @@ -395,7 +395,7 @@ Table *luaH_new (lua_State *L, int narray, int nhash) { luaC_link(L, obj2gco(t), LUA_TTABLE); t->metatable = NULL; t->flags = cast_byte(~0); - t->dbstmt = NULL; + t->from_sql = 0; /* temporary values (kept only if some malloc fails) */ t->array = NULL; t->sizearray = 0; diff --git a/lua/lua.h b/lua/lua.h index b1598a692a..fe4717fad9 100644 --- a/lua/lua.h +++ b/lua/lua.h @@ -348,11 +348,7 @@ LUA_API int lua_gethookcount (lua_State *L); struct stored_proc; LUA_API void lua_setsp(lua_State *, struct stored_proc *); LUA_API struct stored_proc *lua_getsp(lua_State *); - -struct dbstmt_t; -LUA_API struct dbstmt_t *get_sqlrow_stmt(lua_State *L); -LUA_API void set_sqlrow_stmt(lua_State *L, struct dbstmt_t *); - +LUA_API void luabb_set_sqlrow(lua_State *L); struct lua_Debug { int event; diff --git a/lua/sp.c b/lua/sp.c index 6faaa936a6..3d7a0cb9ee 100644 --- a/lua/sp.c +++ b/lua/sp.c @@ -1443,11 +1443,39 @@ static int lua_sql_step(Lua lua, sqlite3_stmt *stmt) return rc; } +static void set_sqlrow_stmt(Lua L) +{ + /* + ** stack: + ** 1. row (lua table) + ** 2. stmt + ** tag stmt to row by: + ** newtbl = {} + ** newtbl.__metatable = stmt + ** setmetatable(row, newtbl) + */ + lua_newtable(L); + lua_pushvalue(L, -3); + lua_setfield(L, -2, "__metatable"); + lua_setmetatable(L, -2); + luabb_set_sqlrow(L); +} + +static dbstmt_t *get_sqlrow_stmt(Lua L) +{ + dbstmt_t *stmt = NULL; + if (lua_getmetatable(L, -1) == 0) return NULL; + lua_getfield(L, -1, "__metatable"); + if (luabb_type(L, -1) == DBTYPES_DBSTMT) stmt = lua_touserdata(L, -1); + lua_pop(L, 2); + return stmt; +} + static int stmt_sql_step(Lua L, dbstmt_t *dbstmt) { int rc; if ((rc = lua_sql_step(L, dbstmt->stmt)) == SQLITE_ROW) { - set_sqlrow_stmt(L, dbstmt); + set_sqlrow_stmt(L); } return rc; } @@ -2401,9 +2429,9 @@ static int luatable_emit(Lua L) { int cols; SP sp = getsp(L); + dbstmt_t *dbstmt; sqlite3_stmt *stmt = NULL; - dbstmt_t *dbstmt = get_sqlrow_stmt(L); - if (dbstmt) { + if ((dbstmt = get_sqlrow_stmt(L)) != NULL && dbstmt->stmt) { stmt = dbstmt->stmt; cols = column_count(NULL, stmt); } else if (sp->parent->ntypes) { diff --git a/tests/sp.test/sp.sh b/tests/sp.test/sp.sh index fca6997104..539f611366 100755 --- a/tests/sp.test/sp.sh +++ b/tests/sp.test/sp.sh @@ -1412,3 +1412,65 @@ local function main() end}$$ EOF cdb2sql $SP_OPTIONS "exec procedure dbtable_insert()" + +cdb2sql $SP_OPTIONS - <<'EOF' +CREATE PROCEDURE emit0 VERSION 'sptest' { +local function get_pairs_count(row) + --dbrow column can be retrieved by names as well as index. We don't want + --double counting, and so pairs(dbrow) should skip keys returned by + --ipairs(dbrow) + local cnt = 0 + for _, _ in pairs(row) do + cnt = cnt + 1 + end + return cnt +end +local function main() + local stmt = db:exec("select 1 as emit0") + local row = stmt:fetch() + while row do + local tmp = row + db:emit(tmp) + db:emit({emit0 = get_pairs_count(tmp)}) + row = stmt:fetch() + db:emit(tmp) + db:emit({emit0 = get_pairs_count(tmp)}) + end +end +}$$ +EXEC PROCEDURE emit0() +EOF + +cdb2sql $SP_OPTIONS - <<'EOF' +CREATE PROCEDURE emit1 VERSION 'sptest' { +local function main() + local stmt = db:exec("select 1 as emit1") + local row = stmt:fetch() + while row do + local tmp = row + row = stmt:fetch() + db:emit(tmp) + end +end +}$$ +EXEC PROCEDURE emit1() +EOF + +cdb2sql $SP_OPTIONS - <<'EOF' +CREATE PROCEDURE emit2 VERSION 'sptest' { +local function test() + local stmt = db:exec("select 1 as emit2") + local row = stmt:fetch() + db:emit(row) + stmt:fetch() + stmt:close() + stmt = nil + collectgarbage("collect") + return row +end +local function main() + db:emit(test()) +end +}$$ +EXEC PROCEDURE emit2() +EOF diff --git a/tests/sp.test/t01.req.out b/tests/sp.test/t01.req.out index 8960f9b199..b3687670de 100644 --- a/tests/sp.test/t01.req.out +++ b/tests/sp.test/t01.req.out @@ -1788,3 +1788,13 @@ SP: exec procedure bound(@i, @u, @ll, @ull, @s, @us, @f, @d, @dt, @cstr, @ba, @b ($0='{"control":"\u0007"}') ($0='DEADBEEF') ($0='incompatible values from SQL string of length 13 to bint4 field 'i' for table 't'') +(version='sptest') +(emit0=1) +(emit0=1) +(emit0=1) +(emit0=1) +(version='sptest') +[EXEC PROCEDURE emit1()] failed with rc -3 [db:emit(tmp)...]:8: attempt to emit row without defining columns +(version='sptest') +(emit2=1) +(emit2=1)