diff --git a/src/aof.c b/src/aof.c index e0ca6fbb61d..9b035b4d647 100644 --- a/src/aof.c +++ b/src/aof.c @@ -2175,7 +2175,7 @@ static int rewriteFunctions(rio *aof) { dictIterator *iter = dictGetIterator(functions); dictEntry *entry = NULL; while ((entry = dictNext(iter))) { - functionLibInfo *li = dictGetVal(entry); + ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry); if (rioWrite(aof, "*3\r\n", 4) == 0) goto werr; char function_load[] = "$8\r\nFUNCTION\r\n$4\r\nLOAD\r\n"; if (rioWrite(aof, function_load, sizeof(function_load) - 1) == 0) goto werr; diff --git a/src/function_lua.c b/src/function_lua.c index fa9983bf7ed..55e1327ac42 100644 --- a/src/function_lua.c +++ b/src/function_lua.c @@ -64,7 +64,7 @@ typedef struct luaFunctionCtx { } luaFunctionCtx; typedef struct loadCtx { - functionLibInfo *li; + ValkeyModuleScriptingEngineFunctionLibrary *li; monotime start_time; size_t timeout; } loadCtx; @@ -100,7 +100,7 @@ static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) { * * Return NULL on compilation error and set the error to the err variable */ -static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size_t timeout, sds *err) { +static int luaEngineCreate(void *engine_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li, const char *blob, size_t timeout, char **err) { int ret = C_ERR; luaEngineCtx *lua_engine_ctx = engine_ctx; lua_State *lua = lua_engine_ctx->lua; @@ -114,7 +114,7 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size lua_pop(lua, 1); /* pop the metatable */ /* compile the code */ - if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) { + if (luaL_loadbuffer(lua, blob, strlen(blob), "@user_function")) { *err = sdscatprintf(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1)); lua_pop(lua, 1); /* pops the error */ goto done; @@ -158,7 +158,7 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size /* * Invole the give function with the given keys and args */ -static void luaEngineCall(scriptRunCtx *run_ctx, +static void luaEngineCall(ValkeyModuleEngineFunctionCallCtx *func_ctx, void *engine_ctx, void *compiled_function, robj **keys, @@ -177,6 +177,7 @@ static void luaEngineCall(scriptRunCtx *run_ctx, serverAssert(lua_isfunction(lua, -1)); + scriptRunCtx *run_ctx = moduleGetScriptRunCtxFromFunctionCtx(func_ctx); luaCallFunction(run_ctx, lua, keys, nkeys, args, nargs, 0); lua_pop(lua, 1); /* Pop error handler */ } @@ -495,8 +496,8 @@ int luaEngineInitEngine(void) { lua_replace(lua_engine_ctx->lua, LUA_GLOBALSINDEX); /* set new global table as the new globals */ - engine *lua_engine = zmalloc(sizeof(*lua_engine)); - *lua_engine = (engine){ + ValkeyModuleScriptingEngine *lua_engine = zmalloc(sizeof(*lua_engine)); + *lua_engine = (ValkeyModuleScriptingEngine){ .engine_ctx = lua_engine_ctx, .create = luaEngineCreate, .call = luaEngineCall, @@ -505,5 +506,5 @@ int luaEngineInitEngine(void) { .get_engine_memory_overhead = luaEngineMemoryOverhead, .free_function = luaEngineFreeFunction, }; - return functionsRegisterEngine(LUA_ENGINE_NAME, lua_engine); + return functionsRegisterEngine(LUA_ENGINE_NAME, NULL, lua_engine); } diff --git a/src/functions.c b/src/functions.c index e950024bad0..ee94c8f3a8a 100644 --- a/src/functions.c +++ b/src/functions.c @@ -122,7 +122,7 @@ static size_t functionMallocSize(functionInfo *fi) { fi->li->ei->engine->get_function_memory_overhead(fi->function); } -static size_t libraryMallocSize(functionLibInfo *li) { +static size_t libraryMallocSize(ValkeyModuleScriptingEngineFunctionLibrary *li) { return zmalloc_size(li) + sdsAllocSize(li->name) + sdsAllocSize(li->code); } @@ -143,12 +143,12 @@ static void engineFunctionDispose(dict *d, void *obj) { if (fi->desc) { sdsfree(fi->desc); } - engine *engine = fi->li->ei->engine; + ValkeyModuleScriptingEngine *engine = fi->li->ei->engine; engine->free_function(engine->engine_ctx, fi->function); zfree(fi); } -static void engineLibraryFree(functionLibInfo *li) { +static void engineLibraryFree(ValkeyModuleScriptingEngineFunctionLibrary *li) { if (!li) { return; } @@ -227,6 +227,15 @@ functionsLibCtx *functionsLibCtxCreate(void) { return ret; } +void functionsAddEngineStats(engineInfo *ei) { + serverAssert(curr_functions_lib_ctx != NULL); + dictEntry *entry = dictFind(curr_functions_lib_ctx->engines_stats, ei->name); + if (entry == NULL) { + functionsLibEngineStats *stats = zcalloc(sizeof(*stats)); + dictAdd(curr_functions_lib_ctx->engines_stats, ei->name, stats); + } +} + /* * Creating a function inside the given library. * On success, return C_OK. @@ -236,7 +245,7 @@ functionsLibCtx *functionsLibCtxCreate(void) { * the function will verify that the given name is following the naming format * and return an error if its not. */ -int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds desc, uint64_t f_flags, sds *err) { +int functionLibCreateFunction(sds name, void *function, ValkeyModuleScriptingEngineFunctionLibrary *li, sds desc, uint64_t f_flags, sds *err) { if (functionsVerifyName(name) != C_OK) { *err = sdsnew("Library names can only contain letters, numbers, or underscores(_) and must be at least one " "character long"); @@ -263,9 +272,9 @@ int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds return C_OK; } -static functionLibInfo *engineLibraryCreate(sds name, engineInfo *ei, sds code) { - functionLibInfo *li = zmalloc(sizeof(*li)); - *li = (functionLibInfo){ +static ValkeyModuleScriptingEngineFunctionLibrary *engineLibraryCreate(sds name, engineInfo *ei, sds code) { + ValkeyModuleScriptingEngineFunctionLibrary *li = zmalloc(sizeof(*li)); + *li = (ValkeyModuleScriptingEngineFunctionLibrary){ .name = sdsdup(name), .functions = dictCreate(&libraryFunctionDictType), .ei = ei, @@ -274,7 +283,7 @@ static functionLibInfo *engineLibraryCreate(sds name, engineInfo *ei, sds code) return li; } -static void libraryUnlink(functionsLibCtx *lib_ctx, functionLibInfo *li) { +static void libraryUnlink(functionsLibCtx *lib_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li) { dictIterator *iter = dictGetIterator(li->functions); dictEntry *entry = NULL; while ((entry = dictNext(iter))) { @@ -296,7 +305,7 @@ static void libraryUnlink(functionsLibCtx *lib_ctx, functionLibInfo *li) { stats->n_functions -= dictSize(li->functions); } -static void libraryLink(functionsLibCtx *lib_ctx, functionLibInfo *li) { +static void libraryLink(functionsLibCtx *lib_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li) { dictIterator *iter = dictGetIterator(li->functions); dictEntry *entry = NULL; while ((entry = dictNext(iter))) { @@ -332,8 +341,8 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l dictEntry *entry = NULL; iter = dictGetIterator(functions_lib_ctx_src->libraries); while ((entry = dictNext(iter))) { - functionLibInfo *li = dictGetVal(entry); - functionLibInfo *old_li = dictFetchValue(functions_lib_ctx_dst->libraries, li->name); + ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry); + ValkeyModuleScriptingEngineFunctionLibrary *old_li = dictFetchValue(functions_lib_ctx_dst->libraries, li->name); if (old_li) { if (!replace) { /* library already exists, failed the restore. */ @@ -367,7 +376,7 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l /* No collision, it is safe to link all the new libraries. */ iter = dictGetIterator(functions_lib_ctx_src->libraries); while ((entry = dictNext(iter))) { - functionLibInfo *li = dictGetVal(entry); + ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry); libraryLink(functions_lib_ctx_dst, li); dictSetVal(functions_lib_ctx_src->libraries, entry, NULL); } @@ -387,7 +396,7 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l /* Link back all libraries on tmp_l_ctx */ while (listLength(old_libraries_list) > 0) { listNode *head = listFirst(old_libraries_list); - functionLibInfo *li = listNodeValue(head); + ValkeyModuleScriptingEngineFunctionLibrary *li = listNodeValue(head); listNodeValue(head) = NULL; libraryLink(functions_lib_ctx_dst, li); listDelNode(old_libraries_list, head); @@ -401,7 +410,9 @@ libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_l * * - engine_name - name of the engine to register * - engine_ctx - the engine ctx that should be used by the server to interact with the engine */ -int functionsRegisterEngine(const char *engine_name, engine *engine) { +int functionsRegisterEngine(const char *engine_name, + ValkeyModule *engine_module, + ValkeyModuleScriptingEngine *engine) { sds engine_name_sds = sdsnew(engine_name); if (dictFetchValue(engines, engine_name_sds)) { serverLog(LL_WARNING, "Same engine was registered twice"); @@ -416,12 +427,15 @@ int functionsRegisterEngine(const char *engine_name, engine *engine) { engineInfo *ei = zmalloc(sizeof(*ei)); *ei = (engineInfo){ .name = engine_name_sds, + .engineModule = engine_module, .engine = engine, .c = c, }; dictAdd(engines, engine_name_sds, ei); + functionsAddEngineStats(ei); + engine_cache_memory += zmalloc_size(ei) + sdsAllocSize(ei->name) + zmalloc_size(engine) + engine->get_engine_memory_overhead(engine->engine_ctx); @@ -535,7 +549,7 @@ void functionListCommand(client *c) { dictIterator *iter = dictGetIterator(curr_functions_lib_ctx->libraries); dictEntry *entry = NULL; while ((entry = dictNext(iter))) { - functionLibInfo *li = dictGetVal(entry); + ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry); if (library_name) { if (!stringmatchlen(library_name, sdslen(library_name), li->name, sdslen(li->name), 1)) { continue; @@ -584,7 +598,7 @@ void functionListCommand(client *c) { */ void functionDeleteCommand(client *c) { robj *function_name = c->argv[2]; - functionLibInfo *li = dictFetchValue(curr_functions_lib_ctx->libraries, function_name->ptr); + ValkeyModuleScriptingEngineFunctionLibrary *li = dictFetchValue(curr_functions_lib_ctx->libraries, function_name->ptr); if (!li) { addReplyError(c, "Library not found"); return; @@ -614,55 +628,18 @@ uint64_t fcallGetCommandFlags(client *c, uint64_t cmd_flags) { return scriptFlagsToCmdFlags(cmd_flags, script_flags); } -static void fcallCommandGeneric(client *c, int ro) { - /* Functions need to be fed to monitors before the commands they execute. */ - replicationFeedMonitors(c, server.monitors, c->db->id, c->argv, c->argc); - - robj *function_name = c->argv[1]; - dictEntry *de = c->cur_script; - if (!de) de = dictFind(curr_functions_lib_ctx->functions, function_name->ptr); - if (!de) { - addReplyError(c, "Function not found"); - return; - } - functionInfo *fi = dictGetVal(de); - engine *engine = fi->li->ei->engine; - - long long numkeys; - /* Get the number of arguments that are keys */ - if (getLongLongFromObject(c->argv[2], &numkeys) != C_OK) { - addReplyError(c, "Bad number of keys provided"); - return; - } - if (numkeys > (c->argc - 3)) { - addReplyError(c, "Number of keys can't be greater than number of args"); - return; - } else if (numkeys < 0) { - addReplyError(c, "Number of keys can't be negative"); - return; - } - - scriptRunCtx run_ctx; - - if (scriptPrepareForRun(&run_ctx, fi->li->ei->c, c, fi->name, fi->f_flags, ro) != C_OK) return; - - engine->call(&run_ctx, engine->engine_ctx, fi->function, c->argv + 3, numkeys, c->argv + 3 + numkeys, - c->argc - 3 - numkeys); - scriptResetRun(&run_ctx); -} - /* * FCALL nkeys */ void fcallCommand(client *c) { - fcallCommandGeneric(c, 0); + fcallCommandGeneric(curr_functions_lib_ctx->functions, c, 0); } /* * FCALL_RO nkeys */ void fcallroCommand(client *c) { - fcallCommandGeneric(c, 1); + fcallCommandGeneric(curr_functions_lib_ctx->functions, c, 1); } /* @@ -952,9 +929,10 @@ void functionFreeLibMetaData(functionsLibMetaData *md) { sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibCtx *lib_ctx, size_t timeout) { dictIterator *iter = NULL; dictEntry *entry = NULL; - functionLibInfo *new_li = NULL; - functionLibInfo *old_li = NULL; + ValkeyModuleScriptingEngineFunctionLibrary *old_li = NULL; functionsLibMetaData md = {0}; + ValkeyModuleScriptingEngineFunctionLibrary *new_li = NULL; + if (functionExtractLibMetaData(code, &md, err) != C_OK) { return NULL; } @@ -970,7 +948,7 @@ sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibC *err = sdscatfmt(sdsempty(), "Engine '%S' not found", md.engine); goto error; } - engine *engine = ei->engine; + ValkeyModuleScriptingEngine *engine = ei->engine; old_li = dictFetchValue(lib_ctx->libraries, md.name); if (old_li && !replace) { @@ -1073,7 +1051,7 @@ unsigned long functionsMemory(void) { size_t engines_memory = 0; while ((entry = dictNext(iter))) { engineInfo *ei = dictGetVal(entry); - engine *engine = ei->engine; + ValkeyModuleScriptingEngine *engine = ei->engine; engines_memory += engine->get_used_memory(engine->engine_ctx); } dictReleaseIterator(iter); @@ -1114,12 +1092,11 @@ size_t functionsLibCtxFunctionsLen(functionsLibCtx *functions_ctx) { int functionsInit(void) { engines = dictCreate(&engineDictType); + curr_functions_lib_ctx = functionsLibCtxCreate(); + if (luaEngineInitEngine() != C_OK) { return C_ERR; } - /* Must be initialized after engines initialization */ - curr_functions_lib_ctx = functionsLibCtxCreate(); - return C_OK; } diff --git a/src/functions.h b/src/functions.h index da196cf1979..2ef6d76f7a9 100644 --- a/src/functions.h +++ b/src/functions.h @@ -52,78 +52,40 @@ #include "script.h" #include "valkeymodule.h" -typedef struct functionLibInfo functionLibInfo; - -typedef struct engine { - /* engine specific context */ - void *engine_ctx; - - /* Create function callback, get the engine_ctx, and function code - * engine_ctx - opaque struct that was created on engine initialization - * li - library information that need to be provided and when add functions - * code - the library code - * timeout - timeout for the library creation (0 for no timeout) - * err - description of error (if occurred) - * returns C_ERR on error and set err to be the error message */ - int (*create)(void *engine_ctx, functionLibInfo *li, sds code, size_t timeout, sds *err); - - /* Invoking a function, r_ctx is an opaque object (from engine POV). - * The r_ctx should be used by the engine to interaction with the server, - * such interaction could be running commands, set resp, or set - * replication mode - */ - void (*call)(scriptRunCtx *r_ctx, - void *engine_ctx, - void *compiled_function, - robj **keys, - size_t nkeys, - robj **args, - size_t nargs); - - /* get current used memory by the engine */ - size_t (*get_used_memory)(void *engine_ctx); - - /* Return memory overhead for a given function, - * such memory is not counted as engine memory but as general - * structs memory that hold different information */ - size_t (*get_function_memory_overhead)(void *compiled_function); - - /* Return memory overhead for engine (struct size holding the engine)*/ - size_t (*get_engine_memory_overhead)(void *engine_ctx); - - /* free the given function */ - void (*free_function)(void *engine_ctx, void *compiled_function); -} engine; +typedef struct ValkeyModuleScriptingEngineFunctionLibrary ValkeyModuleScriptingEngineFunctionLibrary; /* Hold information about an engine. * Used on rdb.c so it must be declared here. */ typedef struct engineInfo { - sds name; /* Name of the engine */ - engine *engine; /* engine callbacks that allows to interact with the engine */ - client *c; /* Client that is used to run commands */ + sds name; /* Name of the engine */ + ValkeyModule *engineModule; /* the module that implements the scripting engine */ + ValkeyModuleScriptingEngine *engine; /* engine callbacks that allows to interact with the engine */ + client *c; /* Client that is used to run commands */ } engineInfo; /* Hold information about the specific function. * Used on rdb.c so it must be declared here. */ typedef struct functionInfo { - sds name; /* Function name */ - void *function; /* Opaque object that set by the function's engine and allow it - to run the function, usually it's the function compiled code. */ - functionLibInfo *li; /* Pointer to the library created the function */ - sds desc; /* Function description */ - uint64_t f_flags; /* Function flags */ + sds name; /* Function name */ + void *function; /* Opaque object that set by the function's engine and allow it + to run the function, usually it's the function compiled code. */ + ValkeyModuleScriptingEngineFunctionLibrary *li; /* Pointer to the library created the function */ + sds desc; /* Function description */ + uint64_t f_flags; /* Function flags */ } functionInfo; /* Hold information about the specific library. * Used on rdb.c so it must be declared here. */ -struct functionLibInfo { +struct ValkeyModuleScriptingEngineFunctionLibrary { sds name; /* Library name */ dict *functions; /* Functions dictionary */ engineInfo *ei; /* Pointer to the function engine */ sds code; /* Library code */ }; -int functionsRegisterEngine(const char *engine_name, engine *engine_ctx); +int functionsRegisterEngine(const char *engine_name, + ValkeyModule *engine_module, + ValkeyModuleScriptingEngine *engine); sds functionsCreateWithLibraryCtx(sds code, int replace, sds *err, functionsLibCtx *lib_ctx, size_t timeout); unsigned long functionsMemory(void); unsigned long functionsMemoryOverhead(void); @@ -138,7 +100,9 @@ void functionsLibCtxFree(functionsLibCtx *lib_ctx); void functionsLibCtxClear(functionsLibCtx *lib_ctx); void functionsLibCtxSwapWithCurrent(functionsLibCtx *lib_ctx); -int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds desc, uint64_t f_flags, sds *err); +void fcallCommandGeneric(dict *functions, client *c, int ro); + +int functionLibCreateFunction(sds name, void *function, ValkeyModuleScriptingEngineFunctionLibrary *li, sds desc, uint64_t f_flags, sds *err); int luaEngineInitEngine(void); int functionsInit(void); diff --git a/src/module.c b/src/module.c index 28842392002..4803dd967c3 100644 --- a/src/module.c +++ b/src/module.c @@ -67,6 +67,7 @@ #include #include #include +#include "functions.h" /* -------------------------------------------------------------------------- * Private data structures used by the modules system. Those are data @@ -173,6 +174,11 @@ struct ValkeyModuleCtx { }; typedef struct ValkeyModuleCtx ValkeyModuleCtx; +struct ValkeyModuleEngineFunctionCallCtx { + ValkeyModuleCtx module_ctx; + scriptRunCtx run_ctx; +}; + #define VALKEYMODULE_CTX_NONE (0) #define VALKEYMODULE_CTX_AUTO_MEMORY (1 << 0) #define VALKEYMODULE_CTX_KEYS_POS_REQUEST (1 << 1) @@ -13032,6 +13038,106 @@ int VM_RdbSave(ValkeyModuleCtx *ctx, ValkeyModuleRdbStream *stream, int flags) { return VALKEYMODULE_OK; } +/* Registers a new scripting engine in the server. + * + * `engine_name` is the name of engine that matches the scripts header shebang. + * + */ +int VM_RegisterScriptingEngine(ValkeyModuleCtx *ctx, + const char *engine_name, + ValkeyModuleScriptingEngine *scripting_engine) { + UNUSED(ctx); + + serverLog(LL_DEBUG, "Registering a new scripting engine: %s", engine_name); + + if (functionsRegisterEngine(engine_name, + ctx->module, + scripting_engine) != C_OK) { + return VALKEYMODULE_ERR; + } + + return VALKEYMODULE_OK; +} + +/* Registers a new scripting function in the engine function library. + * + * This function should only be called in the context of the scripting engine + * creation callback function. + * + */ +int VM_RegisterScriptingEngineFunction(const char *name, + void *function, + ValkeyModuleScriptingEngineFunctionLibrary *li, + const char *desc, + uint64_t f_flags, + char **err) { + if (functionLibCreateFunction(sdsnew(name), function, li, sdsnew(desc), f_flags, err) != C_OK) { + return VALKEYMODULE_ERR; + } + + return VALKEYMODULE_OK; +} + +/* Implements the scripting engine function call logic. + * + */ +void fcallCommandGeneric(dict *functions, client *c, int ro) { + /* Functions need to be fed to monitors before the commands they execute. */ + replicationFeedMonitors(c, server.monitors, c->db->id, c->argv, c->argc); + + robj *function_name = c->argv[1]; + dictEntry *de = c->cur_script; + if (!de) de = dictFind(functions, function_name->ptr); + if (!de) { + addReplyError(c, "Function not found"); + return; + } + functionInfo *fi = dictGetVal(de); + ValkeyModuleScriptingEngine *engine = fi->li->ei->engine; + + long long numkeys; + /* Get the number of arguments that are keys */ + if (getLongLongFromObject(c->argv[2], &numkeys) != C_OK) { + addReplyError(c, "Bad number of keys provided"); + return; + } + if (numkeys > (c->argc - 3)) { + addReplyError(c, "Number of keys can't be greater than number of args"); + return; + } else if (numkeys < 0) { + addReplyError(c, "Number of keys can't be negative"); + return; + } + + struct ValkeyModuleEngineFunctionCallCtx func_ctx; + + if (scriptPrepareForRun(&func_ctx.run_ctx, fi->li->ei->c, c, fi->name, fi->f_flags, ro) != C_OK) return; + + if (fi->li->ei->engineModule != NULL) { + moduleCreateContext(&func_ctx.module_ctx, fi->li->ei->engineModule, VALKEYMODULE_CTX_NONE); + func_ctx.module_ctx.client = func_ctx.run_ctx.original_client; + } + + engine->call(&func_ctx, engine->engine_ctx, fi->function, c->argv + 3, numkeys, c->argv + 3 + numkeys, + c->argc - 3 - numkeys); + scriptResetRun(&func_ctx.run_ctx); + + if (fi->li->ei->engineModule != NULL) { + moduleFreeContext(&func_ctx.module_ctx); + } +} + +/* Allows to get the module context pointer from the function call context pointer. + * + */ +ValkeyModuleCtx *VM_GetModuleCtxFromFunctionCallCtx(ValkeyModuleEngineFunctionCallCtx *func_ctx) { + return &func_ctx->module_ctx; +} + +scriptRunCtx *moduleGetScriptRunCtxFromFunctionCtx(ValkeyModuleEngineFunctionCallCtx *func_ctx) { + return &func_ctx->run_ctx; +} + /* MODULE command. * * MODULE LIST @@ -13901,4 +14007,7 @@ void moduleRegisterCoreAPI(void) { REGISTER_API(RdbStreamFree); REGISTER_API(RdbLoad); REGISTER_API(RdbSave); + REGISTER_API(RegisterScriptingEngine); + REGISTER_API(RegisterScriptingEngineFunction); + REGISTER_API(GetModuleCtxFromFunctionCallCtx); } diff --git a/src/modules/CMakeLists.txt b/src/modules/CMakeLists.txt index 958796232f1..8181cf93a09 100644 --- a/src/modules/CMakeLists.txt +++ b/src/modules/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND MODULES_LIST "hellohook") list(APPEND MODULES_LIST "hellotimer") list(APPEND MODULES_LIST "hellotype") list(APPEND MODULES_LIST "helloworld") +list(APPEND MODULES_LIST "helloscripting") foreach (MODULE_NAME ${MODULES_LIST}) message(STATUS "Building module: ${MODULE_NAME}") diff --git a/src/modules/Makefile b/src/modules/Makefile index ba8c3dc169c..5053da0a9c0 100644 --- a/src/modules/Makefile +++ b/src/modules/Makefile @@ -20,7 +20,7 @@ endif .SUFFIXES: .c .so .xo .o -all: helloworld.so hellotype.so helloblock.so hellocluster.so hellotimer.so hellodict.so hellohook.so helloacl.so +all: helloworld.so hellotype.so helloblock.so hellocluster.so hellotimer.so hellodict.so hellohook.so helloacl.so helloscripting.so .c.xo: $(CC) -I. $(CFLAGS) $(SHOBJ_CFLAGS) -fPIC -c $< -o $@ @@ -65,5 +65,10 @@ helloacl.xo: ../valkeymodule.h helloacl.so: helloacl.xo $(LD) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc +helloscripting.xo: ../valkeymodule.h + +helloscripting.so: helloscripting.xo + $(LD) -o $@ $^ $(SHOBJ_LDFLAGS) $(LIBS) -lc + clean: rm -rf *.xo *.so diff --git a/src/modules/helloscripting.c b/src/modules/helloscripting.c new file mode 100644 index 00000000000..77c9843e3cd --- /dev/null +++ b/src/modules/helloscripting.c @@ -0,0 +1,254 @@ +#include "../valkeymodule.h" + +#include +#include +#include + +static ValkeyModuleScriptingEngine *scripting_engine = NULL; + + +typedef enum HelloInstKind { + FUNCTION = 0, + CONSTI, + ARGS, + RETURN, + _END, +} HelloInstKind; + +const char *HelloInstKindStr[] = { + "FUNCTION", + "CONSTI", + "ARGS", + "RETURN", +}; + +typedef struct HelloInst { + HelloInstKind kind; + union { + uint32_t integer; + const char *string; + } param; +} HelloInst; + +typedef struct HelloFunc { + char *name; + HelloInst instructions[256]; + uint32_t num_instructions; +} HelloFunc; + +typedef struct HelloProgram { + HelloFunc functions[16]; + uint32_t num_functions; +} HelloProgram; + +typedef struct HelloLangCtx { + HelloProgram *program; +} HelloLangCtx; + +static HelloInstKind helloLangParseInstruction(const char *token) { + ValkeyModule_Log(NULL, "debug", "[hellolang] parsing token '%s'", token); + for (HelloInstKind i = 0; i < _END; i++) { + if (strcmp(HelloInstKindStr[i], token) == 0) { + return i; + } + } + return _END; +} + +static void helloLangParseFunction(HelloFunc *func) { + char *token = strtok(NULL, " \n"); + ValkeyModule_Assert(token != NULL); + func->name = ValkeyModule_Alloc(sizeof(char) * strlen(token)); + strcpy(func->name, token); +} + +static uint32_t str2int(const char *str) { + char *end; + errno = 0; + uint32_t val = (uint32_t)strtoul(str, &end, 10); + ValkeyModule_Assert(errno == 0); + return val; +} + +static void helloLangParseIntegerParam(HelloFunc *func) { + char *token = strtok(NULL, " \n"); + func->instructions[func->num_instructions].param.integer = str2int(token); +} + +static void helloLangParseConstI(HelloFunc *func) { + helloLangParseIntegerParam(func); + func->num_instructions++; +} + +static void helloLangParseArgs(HelloFunc *func) { + helloLangParseIntegerParam(func); + func->num_instructions++; +} + +static HelloProgram *helloLangParseCode(const char *code) { + char *_code = ValkeyModule_Alloc(sizeof(char) * strlen(code)); + strcpy(_code, code); + + HelloProgram *program = ValkeyModule_Alloc(sizeof(HelloProgram)); + program->num_functions = 0; + + HelloFunc *currentFunc = NULL; + + char *token = strtok(_code, " \n"); + while (token != NULL) { + HelloInstKind kind = helloLangParseInstruction(token); + + if (currentFunc != NULL) { + currentFunc->instructions[currentFunc->num_instructions].kind = kind; + } + + switch (kind) { + case FUNCTION: + ValkeyModule_Assert(currentFunc == NULL); + currentFunc = &program->functions[program->num_functions++]; + helloLangParseFunction(currentFunc); + break; + case CONSTI: + ValkeyModule_Assert(currentFunc != NULL); + helloLangParseConstI(currentFunc); + break; + case ARGS: + ValkeyModule_Assert(currentFunc != NULL); + helloLangParseArgs(currentFunc); + break; + case RETURN: + ValkeyModule_Assert(currentFunc != NULL); + currentFunc->num_instructions++; + currentFunc = NULL; + break; + case _END: + ValkeyModule_Assert(0); + } + + token = strtok(NULL, " \n"); + } + + ValkeyModule_Free(_code); + + return program; +} + +static uint32_t executeHelloLangFunction(HelloFunc *func, ValkeyModuleString **args, int nargs) { + uint32_t stack[64]; + int sp = 0; + + for (uint32_t pc = 0; pc < func->num_instructions; pc++) { + HelloInst instr = func->instructions[pc]; + switch (instr.kind) { + case CONSTI: + stack[sp++] = instr.param.integer; + break; + case ARGS: + uint32_t idx = instr.param.integer; + ValkeyModule_Assert(idx < (uint32_t)nargs); + size_t len; + const char *argStr = ValkeyModule_StringPtrLen(args[idx], &len); + uint32_t arg = str2int(argStr); + stack[sp++] = arg; + break; + case RETURN: + uint32_t val = stack[--sp]; + ValkeyModule_Assert(sp == 0); + return val; + case FUNCTION: + case _END: + ValkeyModule_Assert(0); + } + } + + ValkeyModule_Assert(0); + return 0; +} + +static size_t engineGetUsedMemoy(void *engine_ctx) { + VALKEYMODULE_NOT_USED(engine_ctx); + return 0; +} + +static size_t engineMemoryOverhead(void *engine_ctx) { + HelloLangCtx *ctx = (HelloLangCtx *)engine_ctx; + size_t overhead = ValkeyModule_MallocSize(engine_ctx); + if (ctx->program != NULL) { + overhead += ValkeyModule_MallocSize(ctx->program); + } + return overhead; +} + +static size_t engineFunctionMemoryOverhead(void *compiled_function) { + VALKEYMODULE_NOT_USED(compiled_function); + return 0; +} + +static void engineFreeFunction(void *engine_ctx, void *compiled_function) { + VALKEYMODULE_NOT_USED(engine_ctx); + VALKEYMODULE_NOT_USED(compiled_function); +} + +static int createHelloLangEngine(void *engine_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li, const char *code, size_t timeout, char **err) { + VALKEYMODULE_NOT_USED(timeout); + VALKEYMODULE_NOT_USED(err); + + HelloLangCtx *ctx = (HelloLangCtx *)engine_ctx; + ctx->program = helloLangParseCode(code); + + for (uint32_t i = 0; i < ctx->program->num_functions; i++) { + HelloFunc *func = &ctx->program->functions[i]; + ValkeyModule_RegisterScriptingEngineFunction(func->name, func, li, NULL, 0, err); + } + + return 0; +} + +static void callHelloLangFunction(ValkeyModuleEngineFunctionCallCtx *func_ctx, + void *engine_ctx, + void *compiled_function, + ValkeyModuleString **keys, + size_t nkeys, + ValkeyModuleString **args, + size_t nargs) { + VALKEYMODULE_NOT_USED(engine_ctx); + VALKEYMODULE_NOT_USED(keys); + VALKEYMODULE_NOT_USED(nkeys); + + ValkeyModuleCtx *ctx = ValkeyModule_GetModuleCtxFromFunctionCallCtx(func_ctx); + + HelloFunc *func = (HelloFunc *)compiled_function; + ValkeyModule_Log(ctx, "debug", "calling function '%s'", func->name); + + uint32_t result = executeHelloLangFunction(func, args, nargs); + + ValkeyModule_Log(ctx, "debug", "function returned '%u'", result); + + ValkeyModule_ReplyWithLongLong(ctx, result); +} + +int ValkeyModule_OnLoad(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { + VALKEYMODULE_NOT_USED(argv); + VALKEYMODULE_NOT_USED(argc); + + if (ValkeyModule_Init(ctx, "helloengine", 1, VALKEYMODULE_APIVER_1) == VALKEYMODULE_ERR) return VALKEYMODULE_ERR; + + HelloLangCtx *hello_ctx = ValkeyModule_Alloc(sizeof(HelloLangCtx)); + hello_ctx->program = NULL; + + scripting_engine = ValkeyModule_Alloc(sizeof(ValkeyModuleScriptingEngine)); + + *scripting_engine = (ValkeyModuleScriptingEngine){ + .engine_ctx = hello_ctx, + .create = createHelloLangEngine, + .call = callHelloLangFunction, + .get_used_memory = engineGetUsedMemoy, + .get_function_memory_overhead = engineFunctionMemoryOverhead, + .get_engine_memory_overhead = engineMemoryOverhead, + .free_function = engineFreeFunction, + }; + + ValkeyModule_RegisterScriptingEngine(ctx, "HELLO", scripting_engine); + + return VALKEYMODULE_OK; +} diff --git a/src/rdb.c b/src/rdb.c index 1c200e54f53..092bcf95ec8 100644 --- a/src/rdb.c +++ b/src/rdb.c @@ -1303,7 +1303,7 @@ ssize_t rdbSaveFunctions(rio *rdb) { while ((entry = dictNext(iter))) { if ((ret = rdbSaveType(rdb, RDB_OPCODE_FUNCTION2)) < 0) goto werr; written += ret; - functionLibInfo *li = dictGetVal(entry); + ValkeyModuleScriptingEngineFunctionLibrary *li = dictGetVal(entry); if ((ret = rdbSaveRawString(rdb, (unsigned char *)li->code, sdslen(li->code))) < 0) goto werr; written += ret; } diff --git a/src/server.h b/src/server.h index 5cf56e9c868..e45ed58c71b 100644 --- a/src/server.h +++ b/src/server.h @@ -2716,6 +2716,7 @@ int moduleLateDefrag(robj *key, robj *value, unsigned long *cursor, long long en void moduleDefragGlobals(void); void *moduleGetHandleByName(char *modulename); int moduleIsModuleCommand(void *module_handle, struct serverCommand *cmd); +struct scriptRunCtx *moduleGetScriptRunCtxFromFunctionCtx(ValkeyModuleEngineFunctionCallCtx *func_ctx); /* Utils */ long long ustime(void); diff --git a/src/valkeymodule.h b/src/valkeymodule.h index c2cdb2f0e7b..411b91b2e42 100644 --- a/src/valkeymodule.h +++ b/src/valkeymodule.h @@ -787,6 +787,8 @@ typedef struct ValkeyModuleIO ValkeyModuleIO; typedef struct ValkeyModuleDigest ValkeyModuleDigest; typedef struct ValkeyModuleInfoCtx ValkeyModuleInfoCtx; typedef struct ValkeyModuleDefragCtx ValkeyModuleDefragCtx; +typedef struct ValkeyModuleScriptingEngineFunctionLibrary ValkeyModuleScriptingEngineFunctionLibrary; +typedef struct ValkeyModuleEngineFunctionCallCtx ValkeyModuleEngineFunctionCallCtx; /* Function pointers needed by both the core and modules, these needs to be * exposed since you can't cast a function pointer to (void *). */ @@ -794,6 +796,41 @@ typedef void (*ValkeyModuleInfoFunc)(ValkeyModuleInfoCtx *ctx, int for_crash_rep typedef void (*ValkeyModuleDefragFunc)(ValkeyModuleDefragCtx *ctx); typedef void (*ValkeyModuleUserChangedFunc)(uint64_t client_id, void *privdata); +typedef struct ValkeyModuleScriptingEngine { + /* engine specific context */ + void *engine_ctx; + + /* Create function callback, get the engine_ctx, and function code + * engine_ctx - opaque struct that was created on engine initialization + * li - library information that need to be provided and when add functions + * code - the library code + * timeout - timeout for the library creation (0 for no timeout) + * err - description of error (if occurred) + * returns C_ERR on error and set err to be the error message */ + int (*create)(void *engine_ctx, ValkeyModuleScriptingEngineFunctionLibrary *li, const char *code, size_t timeout, char **err); + + /* Invoking a function, func_ctx is an opaque object (from engine POV). + * The func_ctx should be used by the engine to interaction with the server, + * such interaction could be running commands, set resp, or set + * replication mode + */ + void (*call)(ValkeyModuleEngineFunctionCallCtx *func_ctx, void *engine_ctx, void *compiled_function, ValkeyModuleString **keys, size_t nkeys, ValkeyModuleString **args, size_t nargs); + + /* get current used memory by the engine */ + size_t (*get_used_memory)(void *engine_ctx); + + /* Return memory overhead for a given function, + * such memory is not counted as engine memory but as general + * structs memory that hold different information */ + size_t (*get_function_memory_overhead)(void *compiled_function); + + /* Return memory overhead for engine (struct size holding the engine)*/ + size_t (*get_engine_memory_overhead)(void *engine_ctx); + + /* free the given function */ + void (*free_function)(void *engine_ctx, void *compiled_function); +} ValkeyModuleScriptingEngine; + /* ------------------------- End of common defines ------------------------ */ /* ----------- The rest of the defines are only for modules ----------------- */ @@ -1649,6 +1686,19 @@ VALKEYMODULE_API int (*ValkeyModule_RdbSave)(ValkeyModuleCtx *ctx, ValkeyModuleRdbStream *stream, int flags) VALKEYMODULE_ATTR; +VALKEYMODULE_API int (*ValkeyModule_RegisterScriptingEngine)(ValkeyModuleCtx *ctx, + const char *engine_name, + ValkeyModuleScriptingEngine *scripting_engine) VALKEYMODULE_ATTR; + +VALKEYMODULE_API int (*ValkeyModule_RegisterScriptingEngineFunction)(const char *name, + void *function, + ValkeyModuleScriptingEngineFunctionLibrary *li, + const char *desc, + uint64_t f_flags, + char **err) VALKEYMODULE_ATTR; + +VALKEYMODULE_API ValkeyModuleCtx *(*ValkeyModule_GetModuleCtxFromFunctionCallCtx)(ValkeyModuleEngineFunctionCallCtx *func_ctx); + #define ValkeyModule_IsAOFClient(id) ((id) == UINT64_MAX) /* This is included inline inside each Valkey module. */ @@ -2015,6 +2065,9 @@ static int ValkeyModule_Init(ValkeyModuleCtx *ctx, const char *name, int ver, in VALKEYMODULE_GET_API(RdbStreamFree); VALKEYMODULE_GET_API(RdbLoad); VALKEYMODULE_GET_API(RdbSave); + VALKEYMODULE_GET_API(RegisterScriptingEngine); + VALKEYMODULE_GET_API(RegisterScriptingEngineFunction); + VALKEYMODULE_GET_API(GetModuleCtxFromFunctionCallCtx); if (ValkeyModule_IsModuleNameBusy && ValkeyModule_IsModuleNameBusy(name)) return VALKEYMODULE_ERR; ValkeyModule_SetModuleAttribs(ctx, name, ver, apiver);