diff --git a/src/commands.c b/src/commands.c index aee327f5f..63dd3d757 100644 --- a/src/commands.c +++ b/src/commands.c @@ -3144,24 +3144,6 @@ struct redisCommandArg FCALL_RO_Args[] = { {0} }; -/********** FUNCTION CREATE ********************/ - -/* FUNCTION CREATE history */ -#define FUNCTION_CREATE_History NULL - -/* FUNCTION CREATE hints */ -#define FUNCTION_CREATE_Hints NULL - -/* FUNCTION CREATE argument table */ -struct redisCommandArg FUNCTION_CREATE_Args[] = { -{"engine-name",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, -{"function-name",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, -{"replace",ARG_TYPE_PURE_TOKEN,-1,"REPLACE",NULL,NULL,CMD_ARG_OPTIONAL}, -{"function-description",ARG_TYPE_STRING,-1,"DESC",NULL,NULL,CMD_ARG_OPTIONAL}, -{"function-code",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, -{0} -}; - /********** FUNCTION DELETE ********************/ /* FUNCTION DELETE history */ @@ -3213,21 +3195,6 @@ struct redisCommandArg FUNCTION_FLUSH_Args[] = { /* FUNCTION HELP hints */ #define FUNCTION_HELP_Hints NULL -/********** FUNCTION INFO ********************/ - -/* FUNCTION INFO history */ -#define FUNCTION_INFO_History NULL - -/* FUNCTION INFO hints */ -#define FUNCTION_INFO_Hints NULL - -/* FUNCTION INFO argument table */ -struct redisCommandArg FUNCTION_INFO_Args[] = { -{"function-name",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, -{"withcode",ARG_TYPE_PURE_TOKEN,-1,"WITHCODE",NULL,NULL,CMD_ARG_OPTIONAL}, -{0} -}; - /********** FUNCTION KILL ********************/ /* FUNCTION KILL history */ @@ -3244,6 +3211,31 @@ struct redisCommandArg FUNCTION_INFO_Args[] = { /* FUNCTION LIST hints */ #define FUNCTION_LIST_Hints NULL +/* FUNCTION LIST argument table */ +struct redisCommandArg FUNCTION_LIST_Args[] = { +{"library-name-pattern",ARG_TYPE_STRING,-1,"LIBRARYNAME",NULL,NULL,CMD_ARG_OPTIONAL}, +{"withcode",ARG_TYPE_PURE_TOKEN,-1,"WITHCODE",NULL,NULL,CMD_ARG_OPTIONAL}, +{0} +}; + +/********** FUNCTION LOAD ********************/ + +/* FUNCTION LOAD history */ +#define FUNCTION_LOAD_History NULL + +/* FUNCTION LOAD hints */ +#define FUNCTION_LOAD_Hints NULL + +/* FUNCTION LOAD argument table */ +struct redisCommandArg FUNCTION_LOAD_Args[] = { +{"engine-name",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, +{"library-name",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, +{"replace",ARG_TYPE_PURE_TOKEN,-1,"REPLACE",NULL,NULL,CMD_ARG_OPTIONAL}, +{"library-description",ARG_TYPE_STRING,-1,"DESC",NULL,NULL,CMD_ARG_OPTIONAL}, +{"function-code",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE}, +{0} +}; + /********** FUNCTION RESTORE ********************/ /* FUNCTION RESTORE history */ @@ -3277,15 +3269,14 @@ struct redisCommandArg FUNCTION_RESTORE_Args[] = { /* FUNCTION command table */ struct redisCommand FUNCTION_Subcommands[] = { -{"create","Create a function with the given arguments (name, code, description)","O(1) (considering compilation time is redundant)","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_CREATE_History,FUNCTION_CREATE_Hints,functionCreateCommand,-5,CMD_NOSCRIPT|CMD_WRITE,ACL_CATEGORY_SCRIPTING,.args=FUNCTION_CREATE_Args}, {"delete","Delete a function by name","O(1)","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_DELETE_History,FUNCTION_DELETE_Hints,functionDeleteCommand,3,CMD_NOSCRIPT|CMD_WRITE,ACL_CATEGORY_SCRIPTING,.args=FUNCTION_DELETE_Args}, {"dump","Dump all functions into a serialized binary payload","O(N) where N is the number of functions","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_DUMP_History,FUNCTION_DUMP_Hints,functionDumpCommand,2,CMD_NOSCRIPT,ACL_CATEGORY_SCRIPTING}, {"flush","Deleting all functions","O(N) where N is the number of functions deleted","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_FLUSH_History,FUNCTION_FLUSH_Hints,functionFlushCommand,-2,CMD_NOSCRIPT|CMD_WRITE,ACL_CATEGORY_SCRIPTING,.args=FUNCTION_FLUSH_Args}, {"help","Show helpful text about the different subcommands","O(1)","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_HELP_History,FUNCTION_HELP_Hints,functionHelpCommand,2,CMD_LOADING|CMD_STALE,ACL_CATEGORY_SCRIPTING}, -{"info","Return information about a function by function name","O(1)","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_INFO_History,FUNCTION_INFO_Hints,functionInfoCommand,-3,CMD_NOSCRIPT,ACL_CATEGORY_SCRIPTING,.args=FUNCTION_INFO_Args}, {"kill","Kill the function currently in execution.","O(1)","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_KILL_History,FUNCTION_KILL_Hints,functionKillCommand,2,CMD_NOSCRIPT,ACL_CATEGORY_SCRIPTING}, -{"list","List information about all the functions","O(N) where N is the number of functions","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_LIST_History,FUNCTION_LIST_Hints,functionListCommand,2,CMD_NOSCRIPT,ACL_CATEGORY_SCRIPTING}, -{"restore","Restore all the functions on the given payload","O(N) where N is the number of functions on the payload","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_RESTORE_History,FUNCTION_RESTORE_Hints,functionRestoreCommand,-3,CMD_NOSCRIPT|CMD_WRITE,ACL_CATEGORY_SCRIPTING,.args=FUNCTION_RESTORE_Args}, +{"list","List information about all the functions","O(N) where N is the number of functions","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_LIST_History,FUNCTION_LIST_Hints,functionListCommand,-2,CMD_NOSCRIPT,ACL_CATEGORY_SCRIPTING,.args=FUNCTION_LIST_Args}, +{"load","Create a function with the given arguments (name, code, description)","O(1) (considering compilation time is redundant)","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_LOAD_History,FUNCTION_LOAD_Hints,functionLoadCommand,-5,CMD_NOSCRIPT|CMD_WRITE|CMD_DENYOOM,ACL_CATEGORY_SCRIPTING,.args=FUNCTION_LOAD_Args}, +{"restore","Restore all the functions on the given payload","O(N) where N is the number of functions on the payload","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_RESTORE_History,FUNCTION_RESTORE_Hints,functionRestoreCommand,-3,CMD_NOSCRIPT|CMD_WRITE|CMD_DENYOOM,ACL_CATEGORY_SCRIPTING,.args=FUNCTION_RESTORE_Args}, {"stats","Return information about the function currently running (name, description, duration)","O(1)","7.0.0",CMD_DOC_NONE,NULL,NULL,COMMAND_GROUP_SCRIPTING,FUNCTION_STATS_History,FUNCTION_STATS_Hints,functionStatsCommand,2,CMD_NOSCRIPT,ACL_CATEGORY_SCRIPTING}, {0} }; diff --git a/src/commands/function-info.json b/src/commands/function-info.json deleted file mode 100644 index 450b195f2..000000000 --- a/src/commands/function-info.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "INFO": { - "summary": "Return information about a function by function name", - "complexity": "O(1)", - "group": "scripting", - "since": "7.0.0", - "arity": -3, - "container": "FUNCTION", - "function": "functionInfoCommand", - "command_flags": [ - "NOSCRIPT" - ], - "acl_categories": [ - "SCRIPTING" - ], - "arguments": [ - { - "name": "function-name", - "type": "string" - }, - { - "name": "withcode", - "type": "pure-token", - "token": "WITHCODE", - "optional": true - } - ] - } -} diff --git a/src/commands/function-list.json b/src/commands/function-list.json index 66cf3e251..601299345 100644 --- a/src/commands/function-list.json +++ b/src/commands/function-list.json @@ -4,7 +4,7 @@ "complexity": "O(N) where N is the number of functions", "group": "scripting", "since": "7.0.0", - "arity": 2, + "arity": -2, "container": "FUNCTION", "function": "functionListCommand", "command_flags": [ @@ -12,6 +12,20 @@ ], "acl_categories": [ "SCRIPTING" + ], + "arguments": [ + { + "name": "library-name-pattern", + "type": "string", + "token": "LIBRARYNAME", + "optional": true + }, + { + "name": "withcode", + "type": "pure-token", + "token": "WITHCODE", + "optional": true + } ] } } diff --git a/src/commands/function-create.json b/src/commands/function-load.json similarity index 84% rename from src/commands/function-create.json rename to src/commands/function-load.json index 10cf5c1d9..202c70c80 100644 --- a/src/commands/function-create.json +++ b/src/commands/function-load.json @@ -1,15 +1,16 @@ { - "CREATE": { + "LOAD": { "summary": "Create a function with the given arguments (name, code, description)", "complexity": "O(1) (considering compilation time is redundant)", "group": "scripting", "since": "7.0.0", "arity": -5, "container": "FUNCTION", - "function": "functionCreateCommand", + "function": "functionLoadCommand", "command_flags": [ "NOSCRIPT", - "WRITE" + "WRITE", + "DENYOOM" ], "acl_categories": [ "SCRIPTING" @@ -20,7 +21,7 @@ "type": "string" }, { - "name": "function-name", + "name": "library-name", "type": "string" }, { @@ -30,7 +31,7 @@ "optional": true }, { - "name": "function-description", + "name": "library-description", "type": "string", "token": "DESC", "optional": true diff --git a/src/commands/function-restore.json b/src/commands/function-restore.json index bc9b32be4..ebdc36a4a 100644 --- a/src/commands/function-restore.json +++ b/src/commands/function-restore.json @@ -9,7 +9,8 @@ "function": "functionRestoreCommand", "command_flags": [ "NOSCRIPT", - "WRITE" + "WRITE", + "DENYOOM" ], "acl_categories": [ "SCRIPTING" diff --git a/src/db.c b/src/db.c index d2fa8aba9..0d19b00dd 100644 --- a/src/db.c +++ b/src/db.c @@ -461,7 +461,7 @@ long long emptyData(int dbnum, int flags, void(callback)(dict*)) { if (with_functions) { serverAssert(dbnum == -1); - functionsCtxClearCurrent(async); + functionsLibCtxClearCurrent(async); } /* Also fire the end event. Note that this event will fire almost diff --git a/src/function_lua.c b/src/function_lua.c index 864ced809..e6b8a2727 100644 --- a/src/function_lua.c +++ b/src/function_lua.c @@ -48,6 +48,9 @@ #define LUA_ENGINE_NAME "LUA" #define REGISTRY_ENGINE_CTX_NAME "__ENGINE_CTX__" #define REGISTRY_ERROR_HANDLER_NAME "__ERROR_HANDLER__" +#define REGISTRY_LOAD_CTX_NAME "__LIBRARY_CTX__" +#define LIBRARY_API_NAME "__LIBRARY_API__" +#define LOAD_TIMEOUT_MS 500 /* Lua engine ctx */ typedef struct luaEngineCtx { @@ -60,6 +63,27 @@ typedef struct luaFunctionCtx { int lua_function_ref; } luaFunctionCtx; +typedef struct loadCtx { + functionLibInfo *li; + monotime start_time; +} loadCtx; + +/* Hook for FUNCTION LOAD execution. + * Used to cancel the execution in case of a timeout (500ms). + * This execution should be fast and should only register + * functions so 500ms should be more than enough. */ +static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) { + UNUSED(ar); + loadCtx *load_ctx = luaGetFromRegistry(lua, REGISTRY_LOAD_CTX_NAME); + uint64_t duration = elapsedMs(load_ctx->start_time); + if (duration > LOAD_TIMEOUT_MS) { + lua_sethook(lua, luaEngineLoadHook, LUA_MASKLINE, 0); + + lua_pushstring(lua,"FUNCTION LOAD timeout"); + lua_error(lua); + } +} + /* * Compile a given blob and save it on the registry. * Return a function ctx with Lua ref that allows to later retrieve the @@ -67,25 +91,88 @@ typedef struct luaFunctionCtx { * * Return NULL on compilation error and set the error to the err variable */ -static void* luaEngineCreate(void *engine_ctx, sds blob, sds *err) { +static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, sds *err) { luaEngineCtx *lua_engine_ctx = engine_ctx; lua_State *lua = lua_engine_ctx->lua; - if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) { - *err = sdsempty(); - *err = sdscatprintf(*err, "Error compiling function: %s", - lua_tostring(lua, -1)); - lua_pop(lua, 1); - return NULL; - } + /* Each library will have its own global distinct table. + * We will create a new fresh Lua table and use + * lua_setfenv to set the table as the library globals + * (https://www.lua.org/manual/5.1/manual.html#lua_setfenv) + * + * At first, populate this new table with only the 'library' API + * to make sure only 'library' API is available at start. After the + * initial run is finished and all functions are registered, add + * all the default globals to the library global table and delete + * the library API. + * + * There are 2 ways to achieve the last part (add default + * globals to the new table): + * + * 1. Initialize the new table with all the default globals + * 2. Inheritance using metatable (https://www.lua.org/pil/14.3.html) + * + * For now we are choosing the second, we can change it in the future to + * achieve a better isolation between functions. */ + lua_newtable(lua); /* Global table for the library */ + lua_pushstring(lua, REDIS_API_NAME); + lua_pushstring(lua, LIBRARY_API_NAME); + lua_gettable(lua, LUA_REGISTRYINDEX); /* get library function from registry */ + lua_settable(lua, -3); /* push the library table to the new global table */ + + /* Set global protection on the new global table */ + luaSetGlobalProtection(lua_engine_ctx->lua); + + /* compile the code */ + if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) { + *err = sdscatprintf(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1)); + lua_pop(lua, 2); /* pops the error and globals table */ + return C_ERR; + } serverAssert(lua_isfunction(lua, -1)); - int lua_function_ref = luaL_ref(lua, LUA_REGISTRYINDEX); + loadCtx load_ctx = { + .li = li, + .start_time = getMonotonicUs(), + }; + luaSaveOnRegistry(lua, REGISTRY_LOAD_CTX_NAME, &load_ctx); - luaFunctionCtx *f_ctx = zmalloc(sizeof(*f_ctx)); - *f_ctx = (luaFunctionCtx ) { .lua_function_ref = lua_function_ref, }; + /* set the function environment so only 'library' API can be accessed. */ + lua_pushvalue(lua, -2); /* push global table to the front */ + lua_setfenv(lua, -2); - return f_ctx; + lua_sethook(lua,luaEngineLoadHook,LUA_MASKCOUNT,100000); + /* Run the compiled code to allow it to register functions */ + if (lua_pcall(lua,0,0,0)) { + *err = sdscatprintf(sdsempty(), "Error registering functions: %s", lua_tostring(lua, -1)); + lua_pop(lua, 2); /* pops the error and globals table */ + lua_sethook(lua,NULL,0,0); /* Disable hook */ + luaSaveOnRegistry(lua, REGISTRY_LOAD_CTX_NAME, NULL); + return C_ERR; + } + lua_sethook(lua,NULL,0,0); /* Disable hook */ + luaSaveOnRegistry(lua, REGISTRY_LOAD_CTX_NAME, NULL); + + /* stack contains the global table, lets rearrange it to contains the entire API. */ + /* delete 'redis' API */ + lua_pushstring(lua, REDIS_API_NAME); + lua_pushnil(lua); + lua_settable(lua, -3); + + /* create metatable */ + lua_newtable(lua); + lua_pushstring(lua, "__index"); + lua_pushvalue(lua, LUA_GLOBALSINDEX); /* push original globals */ + lua_settable(lua, -3); + lua_pushstring(lua, "__newindex"); + lua_pushvalue(lua, LUA_GLOBALSINDEX); /* push original globals */ + lua_settable(lua, -3); + + lua_setmetatable(lua, -2); + + lua_pop(lua, 1); /* pops the global table */ + + return C_OK; } /* @@ -137,6 +224,64 @@ static void luaEngineFreeFunction(void *engine_ctx, void *compiled_function) { zfree(f_ctx); } +static int luaRegisterFunction(lua_State *lua) { + int argc = lua_gettop(lua); + if (argc < 2 || argc > 3) { + luaPushError(lua, "wrong number of arguments to redis.register_function"); + return luaRaiseError(lua); + } + loadCtx *load_ctx = luaGetFromRegistry(lua, REGISTRY_LOAD_CTX_NAME); + if (!load_ctx) { + luaPushError(lua, "redis.register_function can only be called on FUNCTION LOAD command"); + return luaRaiseError(lua); + } + + if (!lua_isstring(lua, 1)) { + luaPushError(lua, "first argument to redis.register_function must be a string"); + return luaRaiseError(lua); + } + + if (!lua_isfunction(lua, 2)) { + luaPushError(lua, "second argument to redis.register_function must be a function"); + return luaRaiseError(lua); + } + + if (argc == 3 && !lua_isstring(lua, 3)) { + luaPushError(lua, "third argument to redis.register_function must be a string"); + return luaRaiseError(lua); + } + + size_t function_name_len; + const char *function_name = lua_tolstring(lua, 1, &function_name_len); + sds function_name_sds = sdsnewlen(function_name, function_name_len); + + sds desc_sds = NULL; + if (argc == 3){ + size_t desc_len; + const char *desc = lua_tolstring(lua, 3, &desc_len); + desc_sds = sdsnewlen(desc, desc_len); + lua_pop(lua, 1); /* pop out the description */ + } + + int lua_function_ref = luaL_ref(lua, LUA_REGISTRYINDEX); + + luaFunctionCtx *lua_f_ctx = zmalloc(sizeof(*lua_f_ctx)); + *lua_f_ctx = (luaFunctionCtx ) { .lua_function_ref = lua_function_ref, }; + + sds err = NULL; + if (functionLibCreateFunction(function_name_sds, lua_f_ctx, load_ctx->li, desc_sds, &err) != C_OK) { + sdsfree(function_name_sds); + if (desc_sds) sdsfree(desc_sds); + lua_unref(lua, lua_f_ctx->lua_function_ref); + zfree(lua_f_ctx); + luaPushError(lua, err); + sdsfree(err); + return luaRaiseError(lua); + } + + return 0; +} + /* Initialize Lua engine, should be called once on start. */ int luaEngineInitEngine() { luaEngineCtx *lua_engine_ctx = zmalloc(sizeof(*lua_engine_ctx)); @@ -144,6 +289,18 @@ int luaEngineInitEngine() { luaRegisterRedisAPI(lua_engine_ctx->lua); + /* Register the library commands table and fields and store it to registry */ + lua_pushstring(lua_engine_ctx->lua, LIBRARY_API_NAME); + lua_newtable(lua_engine_ctx->lua); + + lua_pushstring(lua_engine_ctx->lua, "register_function"); + lua_pushcfunction(lua_engine_ctx->lua, luaRegisterFunction); + lua_settable(lua_engine_ctx->lua, -3); + + luaRegisterLogFunction(lua_engine_ctx->lua); + + lua_settable(lua_engine_ctx->lua, LUA_REGISTRYINDEX); + /* Save error handler to registry */ lua_pushstring(lua_engine_ctx->lua, REGISTRY_ERROR_HANDLER_NAME); char *errh_func = "local dbg = debug\n" @@ -163,12 +320,17 @@ int luaEngineInitEngine() { lua_pcall(lua_engine_ctx->lua,0,1,0); lua_settable(lua_engine_ctx->lua, LUA_REGISTRYINDEX); + /* Save global protection to registry */ + luaRegisterGlobalProtectionFunction(lua_engine_ctx->lua); + + /* Set global protection on globals */ + lua_pushvalue(lua_engine_ctx->lua, LUA_GLOBALSINDEX); + luaSetGlobalProtection(lua_engine_ctx->lua); + lua_pop(lua_engine_ctx->lua, 1); + /* save the engine_ctx on the registry so we can get it from the Lua interpreter */ luaSaveOnRegistry(lua_engine_ctx->lua, REGISTRY_ENGINE_CTX_NAME, lua_engine_ctx); - luaEnableGlobalsProtection(lua_engine_ctx->lua, 0); - - engine *lua_engine = zmalloc(sizeof(*lua_engine)); *lua_engine = (engine) { .engine_ctx = lua_engine_ctx, diff --git a/src/functions.c b/src/functions.c index a515709ca..5a8cd27c1 100644 --- a/src/functions.c +++ b/src/functions.c @@ -41,8 +41,11 @@ static size_t engine_cache_memory = 0; /* Forward declaration */ static void engineFunctionDispose(dict *d, void *obj); +static void engineLibraryDispose(dict *d, void *obj); +static int functionsVerifyName(sds name); -struct functionsCtx { +struct functionsLibCtx { + dict *libraries; /* Function name -> Function object that can be used to run the function */ dict *functions; /* Function name -> Function object that can be used to run the function */ size_t cache_memory /* Overhead memory (structs, dictionaries, ..) used by all the functions */; }; @@ -63,23 +66,49 @@ dictType functionDictType = { NULL, /* val dup */ dictSdsKeyCaseCompare,/* key compare */ dictSdsDestructor, /* key destructor */ + NULL, /* val destructor */ + NULL /* allow to expand */ +}; + +dictType libraryFunctionDictType = { + dictSdsHash, /* hash function */ + dictSdsDup, /* key dup */ + NULL, /* val dup */ + dictSdsKeyCompare, /* key compare */ + dictSdsDestructor, /* key destructor */ engineFunctionDispose,/* val destructor */ NULL /* allow to expand */ }; +dictType librariesDictType = { + dictSdsHash, /* hash function */ + dictSdsDup, /* key dup */ + NULL, /* val dup */ + dictSdsKeyCompare, /* key compare */ + dictSdsDestructor, /* key destructor */ + engineLibraryDispose, /* val destructor */ + NULL /* allow to expand */ +}; + /* Dictionary of engines */ static dict *engines = NULL; -/* Functions Ctx. - * Contains the dictionary that map a function name to - * function object and the cache memory used by all the functions */ -static functionsCtx *functions_ctx = NULL; +/* Libraries Ctx. + * Contains the dictionary that map a library name to library object, + * Contains the dictionary that map a function name to function object, + * and the cache memory used by all the functions */ +static functionsLibCtx *curr_functions_lib_ctx = NULL; static size_t functionMallocSize(functionInfo *fi) { return zmalloc_size(fi) + sdsZmallocSize(fi->name) + (fi->desc ? sdsZmallocSize(fi->desc) : 0) - + sdsZmallocSize(fi->code) - + fi->ei->engine->get_function_memory_overhead(fi->function); + + fi->li->ei->engine->get_function_memory_overhead(fi->function); +} + +static size_t libraryMallocSize(functionLibInfo *li) { + return zmalloc_size(li) + sdsZmallocSize(li->name) + + (li->desc ? sdsZmallocSize(li->desc) : 0) + + sdsZmallocSize(li->code); } /* Dispose function memory */ @@ -89,99 +118,231 @@ static void engineFunctionDispose(dict *d, void *obj) { return; } functionInfo *fi = obj; - sdsfree(fi->code); sdsfree(fi->name); if (fi->desc) { sdsfree(fi->desc); } - engine *engine = fi->ei->engine; + engine *engine = fi->li->ei->engine; engine->free_function(engine->engine_ctx, fi->function); zfree(fi); } -/* Free function memory and detele it from the functions dictionary */ -static void engineFunctionFree(functionInfo *fi, functionsCtx *functions) { - functions->cache_memory -= functionMallocSize(fi); - - dictDelete(functions->functions, fi->name); +static void engineLibraryFree(functionLibInfo* li) { + if (!li) { + return; + } + dictRelease(li->functions); + sdsfree(li->name); + sdsfree(li->code); + if (li->desc) sdsfree(li->desc); + zfree(li); } -/* Clear all the functions from the given functions ctx */ -void functionsCtxClear(functionsCtx *functions_ctx) { - dictEmpty(functions_ctx->functions, NULL); - functions_ctx->cache_memory = 0; +static void engineLibraryDispose(dict *d, void *obj) { + UNUSED(d); + engineLibraryFree(obj); } -void functionsCtxClearCurrent(int async) { +/* Clear all the functions from the given library ctx */ +void functionsLibCtxClear(functionsLibCtx *lib_ctx) { + dictEmpty(lib_ctx->functions, NULL); + dictEmpty(lib_ctx->libraries, NULL); + curr_functions_lib_ctx->cache_memory = 0; +} + +void functionsLibCtxClearCurrent(int async) { if (async) { - functionsCtx *old_f_ctx = functions_ctx; - functions_ctx = functionsCtxCreate(); - freeFunctionsAsync(old_f_ctx); + functionsLibCtx *old_l_ctx = curr_functions_lib_ctx; + curr_functions_lib_ctx = functionsLibCtxCreate(); + freeFunctionsAsync(old_l_ctx); } else { - functionsCtxClear(functions_ctx); + functionsLibCtxClear(curr_functions_lib_ctx); } } /* Free the given functions ctx */ -void functionsCtxFree(functionsCtx *functions_ctx) { - functionsCtxClear(functions_ctx); - dictRelease(functions_ctx->functions); - zfree(functions_ctx); +void functionsLibCtxFree(functionsLibCtx *functions_lib_ctx) { + functionsLibCtxClear(functions_lib_ctx); + dictRelease(functions_lib_ctx->functions); + dictRelease(functions_lib_ctx->libraries); + zfree(functions_lib_ctx); } /* Swap the current functions ctx with the given one. * Free the old functions ctx. */ -void functionsCtxSwapWithCurrent(functionsCtx *new_functions_ctx) { - functionsCtxFree(functions_ctx); - functions_ctx = new_functions_ctx; +void functionsLibCtxSwapWithCurrent(functionsLibCtx *new_lib_ctx) { + functionsLibCtxFree(curr_functions_lib_ctx); + curr_functions_lib_ctx = new_lib_ctx; } /* return the current functions ctx */ -functionsCtx* functionsCtxGetCurrent() { - return functions_ctx; +functionsLibCtx* functionsLibCtxGetCurrent() { + return curr_functions_lib_ctx; } /* Create a new functions ctx */ -functionsCtx* functionsCtxCreate() { - functionsCtx *ret = zmalloc(sizeof(functionsCtx)); +functionsLibCtx* functionsLibCtxCreate() { + functionsLibCtx *ret = zmalloc(sizeof(functionsLibCtx)); + ret->libraries = dictCreate(&librariesDictType); ret->functions = dictCreate(&functionDictType); ret->cache_memory = 0; return ret; } /* - * Register a function info to functions dictionary - * 1. Set the function client - * 2. Add function to functions dictionary - * 3. update cache memory + * Creating a function inside the given library. + * On success, return C_OK. + * On error, return C_ERR and set err output parameter with a relevant error message. + * + * Note: the code assumes 'name' is NULL terminated but not require it to be binary safe. + * the function will verify that the given name is following the naming format + * and return an error if its not. */ -static void engineFunctionRegister(functionInfo *fi, functionsCtx *functions) { - int res = dictAdd(functions->functions, fi->name, fi); - serverAssert(res == DICT_OK); +int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds desc, sds *err) { + if (functionsVerifyName(name) != C_OK) { + *err = sdsnew("Function names can only contain letters and numbers and must be at least one character long"); + return C_ERR; + } - functions->cache_memory += functionMallocSize(fi); -} - -/* - * Creating a function info object and register it. - * Return the created object - */ -static functionInfo* engineFunctionCreate(sds name, void *function, engineInfo *ei, - sds desc, sds code, functionsCtx *functions) -{ + if (dictFetchValue(li->functions, name)) { + *err = sdsnew("Function already exists in the library"); + return C_ERR; + } functionInfo *fi = zmalloc(sizeof(*fi)); - *fi = (functionInfo ) { - .name = sdsdup(name), + *fi = (functionInfo) { + .name = name, .function = function, + .li = li, + .desc = desc, + }; + + int res = dictAdd(li->functions, fi->name, fi); + serverAssert(res == DICT_OK); + + return C_OK; +} + +static functionLibInfo* engineLibraryCreate(sds name, engineInfo *ei, sds desc, sds code) { + functionLibInfo *li = zmalloc(sizeof(*li)); + *li = (functionLibInfo) { + .name = sdsdup(name), + .functions = dictCreate(&libraryFunctionDictType), .ei = ei, .code = sdsdup(code), .desc = desc ? sdsdup(desc) : NULL, }; + return li; +} - engineFunctionRegister(fi, functions); +static void libraryUnlink(functionsLibCtx *lib_ctx, functionLibInfo* li) { + dictIterator *iter = dictGetIterator(li->functions); + dictEntry *entry = NULL; + while ((entry = dictNext(iter))) { + functionInfo *fi = dictGetVal(entry); + int ret = dictDelete(lib_ctx->functions, fi->name); + serverAssert(ret == DICT_OK); + lib_ctx->cache_memory -= functionMallocSize(fi); + } + dictReleaseIterator(iter); + entry = dictUnlink(lib_ctx->libraries, li->name); + dictSetVal(lib_ctx->libraries, entry, NULL); + dictFreeUnlinkedEntry(lib_ctx->libraries, entry); + lib_ctx->cache_memory += libraryMallocSize(li); +} - return fi; +static void libraryLink(functionsLibCtx *lib_ctx, functionLibInfo* li) { + dictIterator *iter = dictGetIterator(li->functions); + dictEntry *entry = NULL; + while ((entry = dictNext(iter))) { + functionInfo *fi = dictGetVal(entry); + dictAdd(lib_ctx->functions, fi->name, fi); + lib_ctx->cache_memory += functionMallocSize(fi); + } + dictReleaseIterator(iter); + + dictAdd(lib_ctx->libraries, li->name, li); + lib_ctx->cache_memory += libraryMallocSize(li); +} + +/* Takes all libraries from lib_ctx_src and add to lib_ctx_dst. + * On collision, if 'replace' argument is true, replace the existing library with the new one. + * Otherwise abort and leave 'lib_ctx_dst' and 'lib_ctx_src' untouched. + * Return C_OK on success and C_ERR if aborted. If C_ERR is retunred, set a relevant + * error message on the 'err' out parameter. + * */ +static int libraryJoin(functionsLibCtx *functions_lib_ctx_dst, functionsLibCtx *functions_lib_ctx_src, int replace, sds *err) { + int ret = C_ERR; + dictIterator *iter = NULL; + /* Stores the libraries we need to replace in case a revert is required. + * Only initialized when needed */ + list *old_libraries_list = NULL; + dictEntry *entry = NULL; + iter = dictGetIterator(functions_lib_ctx_src->libraries); + while ((entry = dictNext(iter))) { + functionLibInfo *li = dictGetVal(entry); + functionLibInfo *old_li = dictFetchValue(functions_lib_ctx_dst->libraries, li->name); + if (old_li) { + if (!replace) { + /* library already exists, failed the restore. */ + *err = sdscatfmt(sdsempty(), "Library %s already exists", li->name); + goto done; + } else { + if (!old_libraries_list) { + old_libraries_list = listCreate(); + listSetFreeMethod(old_libraries_list, (void (*)(void*))engineLibraryFree); + } + libraryUnlink(functions_lib_ctx_dst, old_li); + listAddNodeTail(old_libraries_list, old_li); + } + } + } + dictReleaseIterator(iter); + iter = NULL; + + /* Make sure no functions collision */ + iter = dictGetIterator(functions_lib_ctx_src->functions); + while ((entry = dictNext(iter))) { + functionInfo *fi = dictGetVal(entry); + if (dictFetchValue(functions_lib_ctx_dst->functions, fi->name)) { + *err = sdscatfmt(sdsempty(), "Function %s already exists", fi->name); + goto done; + } + } + dictReleaseIterator(iter); + iter = NULL; + + /* No collision, it is safe to link all the new libraries. */ + iter = dictGetIterator(functions_lib_ctx_src->libraries); + while ((entry = dictNext(iter))) { + functionLibInfo *li = dictGetVal(entry); + libraryLink(functions_lib_ctx_dst, li); + dictSetVal(functions_lib_ctx_src->libraries, entry, NULL); + } + dictReleaseIterator(iter); + iter = NULL; + + functionsLibCtxClear(functions_lib_ctx_src); + if (old_libraries_list) { + listRelease(old_libraries_list); + old_libraries_list = NULL; + } + ret = C_OK; + +done: + if (iter) dictReleaseIterator(iter); + if (old_libraries_list) { + /* Link back all libraries on tmp_l_ctx */ + while (listLength(old_libraries_list) > 0) { + listNode *head = listFirst(old_libraries_list); + functionLibInfo *li = listNodeValue(head); + listNodeValue(head) = NULL; + libraryLink(functions_lib_ctx_dst, li); + listDelNode(old_libraries_list, head); + } + listRelease(old_libraries_list); + } + return ret; } /* Register an engine, should be called once by the engine on startup and give the following: @@ -250,82 +411,111 @@ void functionStatsCommand(client *c) { } /* - * FUNCTION LIST + * FUNCTION LIST [LIBRARYNAME PATTERN] [WITHCODE] + * + * Return general information about all the libraries: + * * Library name + * * The engine used to run the Library + * * Library description + * * Functions list + * * Library code (if WITHCODE is given) + * + * It is also possible to given library name pattern using + * LIBRARYNAME argument, if given, return only libraries + * that matches the given pattern. */ void functionListCommand(client *c) { - /* general information on all the functions */ - addReplyArrayLen(c, dictSize(functions_ctx->functions)); - dictIterator *iter = dictGetIterator(functions_ctx->functions); + int with_code = 0; + sds library_name = NULL; + for (int i = 2 ; i < c->argc ; ++i) { + robj *next_arg = c->argv[i]; + if (!with_code && !strcasecmp(next_arg->ptr, "withcode")) { + with_code = 1; + continue; + } + if (!library_name && !strcasecmp(next_arg->ptr, "libraryname")) { + if (i >= c->argc - 1) { + addReplyError(c, "library name argument was not given"); + return; + } + library_name = c->argv[++i]->ptr; + continue; + } + addReplyErrorSds(c, sdscatfmt(sdsempty(), "Unknown argument %s", next_arg->ptr)); + return; + } + size_t reply_len = 0; + void *len_ptr = NULL; + if (library_name) { + len_ptr = addReplyDeferredLen(c); + } else { + /* If no pattern is asked we know the reply len and we can just set it */ + addReplyArrayLen(c, dictSize(curr_functions_lib_ctx->libraries)); + } + dictIterator *iter = dictGetIterator(curr_functions_lib_ctx->libraries); dictEntry *entry = NULL; while ((entry = dictNext(iter))) { - functionInfo *fi = dictGetVal(entry); - addReplyMapLen(c, 3); - addReplyBulkCString(c, "name"); - addReplyBulkCBuffer(c, fi->name, sdslen(fi->name)); + functionLibInfo *li = dictGetVal(entry); + if (library_name) { + if (!stringmatchlen(library_name, sdslen(library_name), li->name, sdslen(li->name), 1)) { + continue; + } + } + ++reply_len; + addReplyMapLen(c, with_code? 5 : 4); + addReplyBulkCString(c, "library_name"); + addReplyBulkCBuffer(c, li->name, sdslen(li->name)); addReplyBulkCString(c, "engine"); - addReplyBulkCBuffer(c, fi->ei->name, sdslen(fi->ei->name)); + addReplyBulkCBuffer(c, li->ei->name, sdslen(li->ei->name)); addReplyBulkCString(c, "description"); - if (fi->desc) { - addReplyBulkCBuffer(c, fi->desc, sdslen(fi->desc)); + if (li->desc) { + addReplyBulkCBuffer(c, li->desc, sdslen(li->desc)); } else { addReplyNull(c); } - } - dictReleaseIterator(iter); -} -/* - * FUNCTION INFO [WITHCODE] - */ -void functionInfoCommand(client *c) { - if (c->argc > 4) { - addReplyErrorFormat(c,"wrong number of arguments for '%s' command or subcommand", c->cmd->name); - return; - } - /* dedicated information on specific function */ - robj *function_name = c->argv[2]; - int with_code = 0; - if (c->argc == 4) { - robj *with_code_arg = c->argv[3]; - if (!strcasecmp(with_code_arg->ptr, "withcode")) { - with_code = 1; + addReplyBulkCString(c, "functions"); + addReplyArrayLen(c, dictSize(li->functions)); + dictIterator *functions_iter = dictGetIterator(li->functions); + dictEntry *function_entry = NULL; + while ((function_entry = dictNext(functions_iter))) { + functionInfo *fi = dictGetVal(function_entry); + addReplyMapLen(c, 2); + addReplyBulkCString(c, "name"); + addReplyBulkCBuffer(c, fi->name, sdslen(fi->name)); + addReplyBulkCString(c, "description"); + if (fi->desc) { + addReplyBulkCBuffer(c, fi->desc, sdslen(fi->desc)); + } else { + addReplyNull(c); + } + } + dictReleaseIterator(functions_iter); + + if (with_code) { + addReplyBulkCString(c, "library_code"); + addReplyBulkCBuffer(c, li->code, sdslen(li->code)); } } - - functionInfo *fi = dictFetchValue(functions_ctx->functions, function_name->ptr); - if (!fi) { - addReplyError(c, "Function does not exists"); - return; - } - addReplyMapLen(c, with_code? 4 : 3); - addReplyBulkCString(c, "name"); - addReplyBulkCBuffer(c, fi->name, sdslen(fi->name)); - addReplyBulkCString(c, "engine"); - addReplyBulkCBuffer(c, fi->ei->name, sdslen(fi->ei->name)); - addReplyBulkCString(c, "description"); - if (fi->desc) { - addReplyBulkCBuffer(c, fi->desc, sdslen(fi->desc)); - } else { - addReplyNull(c); - } - if (with_code) { - addReplyBulkCString(c, "code"); - addReplyBulkCBuffer(c, fi->code, sdslen(fi->code)); + dictReleaseIterator(iter); + if (len_ptr) { + setDeferredArrayLen(c, len_ptr, reply_len); } } /* - * FUNCTION DELETE + * FUNCTION DELETE */ void functionDeleteCommand(client *c) { robj *function_name = c->argv[2]; - functionInfo *fi = dictFetchValue(functions_ctx->functions, function_name->ptr); - if (!fi) { - addReplyError(c, "Function not found"); + functionLibInfo *li = dictFetchValue(curr_functions_lib_ctx->libraries, function_name->ptr); + if (!li) { + addReplyError(c, "Library not found"); return; } - engineFunctionFree(fi, functions_ctx); + libraryUnlink(curr_functions_lib_ctx, li); + engineLibraryFree(li); /* Indicate that the command changed the data so it will be replicated and * counted as a data change (for persistence configuration) */ server.dirty++; @@ -338,12 +528,12 @@ void functionKillCommand(client *c) { static void fcallCommandGeneric(client *c, int ro) { robj *function_name = c->argv[1]; - functionInfo *fi = dictFetchValue(functions_ctx->functions, function_name->ptr); + functionInfo *fi = dictFetchValue(curr_functions_lib_ctx->functions, function_name->ptr); if (!fi) { addReplyError(c, "Function not found"); return; } - engine *engine = fi->ei->engine; + engine *engine = fi->li->ei->engine; long long numkeys; /* Get the number of arguments that are keys */ @@ -361,7 +551,7 @@ static void fcallCommandGeneric(client *c, int ro) { scriptRunCtx run_ctx; - scriptPrepareForRun(&run_ctx, fi->ei->c, c, fi->name); + scriptPrepareForRun(&run_ctx, fi->li->ei->c, c, fi->name); if (ro) { run_ctx.flags |= SCRIPT_READ_ONLY; } @@ -387,17 +577,17 @@ void fcallroCommand(client *c) { /* * FUNCTION DUMP * - * Returns a binary payload representing all the functions. + * Returns a binary payload representing all the libraries. * Can be loaded using FUNCTION RESTORE * - * The payload structure is the same as on RDB. Each function + * The payload structure is the same as on RDB. Each library * is saved separately with the following information: - * * Function name + * * Library name * * Engine name - * * Function description - * * Function code - * RDB_OPCODE_FUNCTION is saved before each function to present - * that the payload is a function. + * * Library description + * * Library code + * RDB_OPCODE_FUNCTION is saved before each library to present + * that the payload is a library. * RDB version and crc64 is saved at the end of the payload. * The RDB version is saved for backward compatibility. * crc64 is saved so we can verify the payload content. @@ -427,12 +617,12 @@ void functionDumpCommand(client *c) { /* * FUNCTION RESTORE [FLUSH|APPEND|REPLACE] * - * Restore the functions represented by the give payload. - * Restore policy to can be given to control how to handle existing functions (default APPEND): - * * FLUSH: delete all existing functions. - * * APPEND: appends the restored functions to the existing functions. On collision, abort. - * * REPLACE: appends the restored functions to the existing functions. - * On collision, replace the old function with the new function. + * Restore the libraries represented by the give payload. + * Restore policy to can be given to control how to handle existing libraries (default APPEND): + * * FLUSH: delete all existing libraries. + * * APPEND: appends the restored libraries to the existing libraries. On collision, abort. + * * REPLACE: appends the restored libraries to the existing libraries. + * On collision, replace the old libraries with the new libraries. */ void functionRestoreCommand(client *c) { if (c->argc > 4) { @@ -444,7 +634,6 @@ void functionRestoreCommand(client *c) { sds data = c->argv[2]->ptr; size_t data_len = sdslen(data); rio payload; - dictIterator *iter = NULL; sds err = NULL; if (c->argc == 4) { @@ -467,7 +656,7 @@ void functionRestoreCommand(client *c) { return; } - functionsCtx *f_ctx = functionsCtxCreate(); + functionsLibCtx *functions_lib_ctx = functionsLibCtxCreate(); rioInitWithBuffer(&payload, data); /* Read until reaching last 10 bytes that should contain RDB version and checksum. */ @@ -481,7 +670,7 @@ void functionRestoreCommand(client *c) { err = sdsnew("given type is not a function"); goto load_error; } - if (rdbFunctionLoad(&payload, rdbver, f_ctx, RDBFLAGS_NONE, &err) != C_OK) { + if (rdbFunctionLoad(&payload, rdbver, functions_lib_ctx, RDBFLAGS_NONE, &err) != C_OK) { if (!err) { err = sdsnew("failed loading the given functions payload"); } @@ -490,29 +679,11 @@ void functionRestoreCommand(client *c) { } if (restore_replicy == restorePolicy_Flush) { - functionsCtxSwapWithCurrent(f_ctx); - f_ctx = NULL; /* avoid releasing the f_ctx in the end */ + functionsLibCtxSwapWithCurrent(functions_lib_ctx); + functions_lib_ctx = NULL; /* avoid releasing the f_ctx in the end */ } else { - if (restore_replicy == restorePolicy_Append) { - /* First make sure there is only new functions */ - iter = dictGetIterator(f_ctx->functions); - dictEntry *entry = NULL; - while ((entry = dictNext(iter))) { - functionInfo *fi = dictGetVal(entry); - if (dictFetchValue(functions_ctx->functions, fi->name)) { - /* function already exists, failed the restore. */ - err = sdscatfmt(sdsempty(), "Function %s already exists", fi->name); - goto load_error; - } - } - dictReleaseIterator(iter); - } - iter = dictGetIterator(f_ctx->functions); - dictEntry *entry = NULL; - while ((entry = dictNext(iter))) { - functionInfo *fi = dictGetVal(entry); - dictReplace(functions_ctx->functions, fi->name, fi); - dictSetVal(f_ctx->functions, entry, NULL); /* make sure value will not be disposed */ + if (libraryJoin(curr_functions_lib_ctx, functions_lib_ctx, restore_replicy == restorePolicy_Replace, &err) != C_OK) { + goto load_error; } } @@ -526,11 +697,8 @@ load_error: } else { addReply(c, shared.ok); } - if (iter) { - dictReleaseIterator(iter); - } - if (f_ctx) { - functionsCtxFree(f_ctx); + if (functions_lib_ctx) { + functionsLibCtxFree(functions_lib_ctx); } } @@ -551,7 +719,7 @@ void functionFlushCommand(client *c) { return; } - functionsCtxClearCurrent(async); + functionsLibCtxClearCurrent(async); /* Indicate that the command changed the data so it will be replicated and * counted as a data change (for persistence configuration) */ @@ -561,45 +729,43 @@ void functionFlushCommand(client *c) { void functionHelpCommand(client *c) { const char *help[] = { -"CREATE [REPLACE] [DESC ] ", -" Create a new function with the given function name and code.", -"DELETE ", -" Delete the given function.", -"INFO [WITHCODE]", -" For each function, print the following information about the function:", -" * Function name", -" * The engine used to run the function", -" * Function description", -" * Function code (only if WITHCODE is given)", -"LIST", -" Return general information on all the functions:", -" * Function name", -" * The engine used to run the function", -" * Function description", +"LOAD [REPLACE] [DESC ] ", +" Create a new library with the given library name and code.", +"DELETE ", +" Delete the given library.", +"LIST [LIBRARYNAME PATTERN] [WITHCODE]", +" Return general information on all the libraries:", +" * Library name", +" * The engine used to run the Library", +" * Library description", +" * Functions list", +" * Library code (if WITHCODE is given)", +" It also possible to get only function that matches a pattern using LIBRARYNAME argument.", "STATS", " Return information about the current function running:", " * Function name", " * Command used to run the function", " * Duration in MS that the function is running", -" If not function is running, return nil", +" If no function is running, return nil", " In addition, returns a list of available engines.", "KILL", " Kill the current running function.", "FLUSH [ASYNC|SYNC]", -" Delete all the functions.", +" Delete all the libraries.", " When called without the optional mode argument, the behavior is determined by the", " lazyfree-lazy-user-flush configuration directive. Valid modes are:", -" * ASYNC: Asynchronously flush the functions.", -" * SYNC: Synchronously flush the functions.", +" * ASYNC: Asynchronously flush the libraries.", +" * SYNC: Synchronously flush the libraries.", "DUMP", -" Returns a serialized payload representing the current functions, can be restored using FUNCTION RESTORE command", +" Returns a serialized payload representing the current libraries, can be restored using FUNCTION RESTORE command", "RESTORE [FLUSH|APPEND|REPLACE]", -" Restore the functions represented by the given payload, it is possible to give a restore policy to", -" control how to handle existing functions (default APPEND):", -" * FLUSH: delete all existing functions.", -" * APPEND: appends the restored functions to the existing functions. On collision, abort.", -" * REPLACE: appends the restored functions to the existing functions, On collision, replace the old", -" function with the new function.", +" Restore the libraries represented by the given payload, it is possible to give a restore policy to", +" control how to handle existing libraries (default APPEND):", +" * FLUSH: delete all existing libraries.", +" * APPEND: appends the restored libraries to the existing libraries. On collision, abort.", +" * REPLACE: appends the restored libraries to the existing libraries, On collision, replace the old", +" libraries with the new libraries (notice that even on this option there is a chance of failure", +" in case of functions name collision with another library).", NULL }; addReplyHelp(c, help); } @@ -623,12 +789,14 @@ static int functionsVerifyName(sds name) { return C_OK; } -/* Compile and save the given function, return C_OK on success and C_ERR on failure. +/* Compile and save the given library, return C_OK on success and C_ERR on failure. * In case on failure the err out param is set with relevant error message */ -int functionsCreateWithFunctionCtx(sds function_name,sds engine_name, sds desc, sds code, - int replace, sds* err, functionsCtx *functions) { - if (functionsVerifyName(function_name)) { - *err = sdsnew("Function names can only contain letters and numbers and must be at least one character long"); +int functionsCreateWithLibraryCtx(sds lib_name,sds engine_name, sds desc, sds code, + int replace, sds* err, functionsLibCtx *lib_ctx) { + dictIterator *iter = NULL; + dictEntry *entry = NULL; + if (functionsVerifyName(lib_name)) { + *err = sdsnew("Library names can only contain letters and numbers and must be at least one character long"); return C_ERR; } @@ -639,40 +807,69 @@ int functionsCreateWithFunctionCtx(sds function_name,sds engine_name, sds desc, } engine *engine = ei->engine; - functionInfo *fi = dictFetchValue(functions->functions, function_name); - if (fi && !replace) { - *err = sdsnew("Function already exists"); + functionLibInfo *old_li = dictFetchValue(lib_ctx->libraries, lib_name); + if (old_li && !replace) { + *err = sdsnew("Library already exists"); return C_ERR; } - void *function = engine->create(engine->engine_ctx, code, err); - if (*err) { - return C_ERR; + if (old_li) { + libraryUnlink(lib_ctx, old_li); } - if (fi) { - /* free the already existing function as we are going to replace it */ - engineFunctionFree(fi, functions); + functionLibInfo *new_li = engineLibraryCreate(lib_name, ei, desc, code); + if (engine->create(engine->engine_ctx, new_li, code, err) != C_OK) { + goto error; } - engineFunctionCreate(function_name, function, ei, desc, code, functions); + if (dictSize(new_li->functions) == 0) { + *err = sdsnew("No functions registered"); + goto error; + } + + /* Verify no duplicate functions */ + iter = dictGetIterator(new_li->functions); + while ((entry = dictNext(iter))) { + functionInfo *fi = dictGetVal(entry); + if (dictFetchValue(lib_ctx->functions, fi->name)) { + /* functions name collision, abort. */ + *err = sdscatfmt(sdsempty(), "Function %s already exists", fi->name); + goto error; + } + } + dictReleaseIterator(iter); + iter = NULL; + + libraryLink(lib_ctx, new_li); + + if (old_li) { + engineLibraryFree(old_li); + } return C_OK; + +error: + if (iter) dictReleaseIterator(iter); + engineLibraryFree(new_li); + if (old_li) { + libraryLink(lib_ctx, old_li); + } + return C_ERR; } /* - * FUNCTION CREATE - * [REPLACE] [DESC ] + * FUNCTION LOAD + * [REPLACE] [DESC ] * - * ENGINE NAME - name of the engine to use the run the function - * FUNCTION NAME - name to use to invoke the function - * REPLACE - optional, replace existing function - * DESCRIPTION - optional, function description - * FUNCTION CODE - function code to pass to the engine + * ENGINE NAME - name of the engine to use the run the library + * LIBRARY NAME - name of the library + * REPLACE - optional, replace existing library + * DESCRIPTION - optional, library description + * LIBRARY CODE - library code to pass to the engine */ -void functionCreateCommand(client *c) { +void functionLoadCommand(client *c) { robj *engine_name = c->argv[2]; - robj *function_name = c->argv[3]; + robj *library_name = c->argv[3]; int replace = 0; int argc_pos = 4; @@ -700,8 +897,8 @@ void functionCreateCommand(client *c) { robj *code = c->argv[argc_pos]; sds err = NULL; - if (functionsCreateWithFunctionCtx(function_name->ptr, engine_name->ptr, - desc, code->ptr, replace, &err, functions_ctx) != C_OK) + if (functionsCreateWithLibraryCtx(library_name->ptr, engine_name->ptr, + desc, code->ptr, replace, &err, curr_functions_lib_ctx) != C_OK) { addReplyErrorSds(c, err); return; @@ -731,9 +928,9 @@ unsigned long functionsMemory() { unsigned long functionsMemoryOverhead() { size_t memory_overhead = dictSize(engines) * sizeof(dictEntry) + dictSlots(engines) * sizeof(dictEntry*); - memory_overhead += dictSize(functions_ctx->functions) * sizeof(dictEntry) + - dictSlots(functions_ctx->functions) * sizeof(dictEntry*) + sizeof(functionsCtx); - memory_overhead += functions_ctx->cache_memory; + memory_overhead += dictSize(curr_functions_lib_ctx->functions) * sizeof(dictEntry) + + dictSlots(curr_functions_lib_ctx->functions) * sizeof(dictEntry*) + sizeof(functionsLibCtx); + memory_overhead += curr_functions_lib_ctx->cache_memory; memory_overhead += engine_cache_memory; return memory_overhead; @@ -741,14 +938,18 @@ unsigned long functionsMemoryOverhead() { /* Returns the number of functions */ unsigned long functionsNum() { - return dictSize(functions_ctx->functions); + return dictSize(curr_functions_lib_ctx->functions); } -dict* functionsGet() { - return functions_ctx->functions; +unsigned long functionsLibNum() { + return dictSize(curr_functions_lib_ctx->libraries); } -size_t functionsLen(functionsCtx *functions_ctx) { +dict* functionsLibGet() { + return curr_functions_lib_ctx->libraries; +} + +size_t functionsLibCtxfunctionsLen(functionsLibCtx *functions_ctx) { return dictSize(functions_ctx->functions); } @@ -756,7 +957,7 @@ size_t functionsLen(functionsCtx *functions_ctx) { * Should be called once on server initialization */ int functionsInit() { engines = dictCreate(&engineDictType); - functions_ctx = functionsCtxCreate(); + curr_functions_lib_ctx = functionsLibCtxCreate(); if (luaEngineInitEngine() != C_OK) { return C_ERR; diff --git a/src/functions.h b/src/functions.h index 11cb307e1..f49040193 100644 --- a/src/functions.h +++ b/src/functions.h @@ -47,13 +47,15 @@ #include "script.h" #include "redismodule.h" +typedef struct functionLibInfo functionLibInfo; + typedef struct engine { /* engine specific context */ void *engine_ctx; /* Create function callback, get the engine_ctx, and function code. * returns NULL on error and set sds to be the error message */ - void* (*create)(void *engine_ctx, sds code, sds *err); + int (*create)(void *engine_ctx, functionLibInfo *li, sds code, sds *err); /* Invoking a function, r_ctx is an opaque object (from engine POV). * The r_ctx should be used by the engine to interaction with Redis, @@ -89,29 +91,40 @@ typedef struct engineInfo { /* Hold information about the specific function. * Used on rdb.c so it must be declared here. */ typedef struct functionInfo { - sds name; /* Function name */ - void *function; /* Opaque object that set by the function's engine and allow it - to run the function, usually it's the function compiled code. */ - engineInfo *ei; /* Pointer to the function engine */ - sds code; /* Function code */ - sds desc; /* Function description */ + sds name; /* Function name */ + void *function; /* Opaque object that set by the function's engine and allow it + to run the function, usually it's the function compiled code. */ + functionLibInfo* li; /* Pointer to the library created the function */ + sds desc; /* Function description */ } functionInfo; +/* Hold information about the specific library. + * Used on rdb.c so it must be declared here. */ +struct functionLibInfo { + sds name; /* Library name */ + dict *functions; /* Functions dictionary */ + engineInfo *ei; /* Pointer to the function engine */ + sds code; /* Library code */ + sds desc; /* Library description */ +}; + int functionsRegisterEngine(const char *engine_name, engine *engine_ctx); -int functionsCreateWithFunctionCtx(sds function_name, sds engine_name, sds desc, sds code, - int replace, sds* err, functionsCtx *functions); +int functionsCreateWithLibraryCtx(sds lib_name, sds engine_name, sds desc, sds code, + int replace, sds* err, functionsLibCtx *lib_ctx); unsigned long functionsMemory(); unsigned long functionsMemoryOverhead(); -int functionsLoad(rio *rdb, int ver); unsigned long functionsNum(); -dict* functionsGet(); -size_t functionsLen(functionsCtx *functions_ctx); -functionsCtx* functionsCtxGetCurrent(); -functionsCtx* functionsCtxCreate(); -void functionsCtxClearCurrent(int async); -void functionsCtxFree(functionsCtx *functions_ctx); -void functionsCtxClear(functionsCtx *functions_ctx); -void functionsCtxSwapWithCurrent(functionsCtx *functions_ctx); +unsigned long functionsLibNum(); +dict* functionsLibGet(); +size_t functionsLibCtxfunctionsLen(functionsLibCtx *functions_ctx); +functionsLibCtx* functionsLibCtxGetCurrent(); +functionsLibCtx* functionsLibCtxCreate(); +void functionsLibCtxClearCurrent(int async); +void functionsLibCtxFree(functionsLibCtx *lib_ctx); +void functionsLibCtxClear(functionsLibCtx *lib_ctx); +void functionsLibCtxSwapWithCurrent(functionsLibCtx *lib_ctx); + +int functionLibCreateFunction(sds name, void *function, functionLibInfo *li, sds desc, sds *err); int luaEngineInitEngine(); int functionsInit(); diff --git a/src/lazyfree.c b/src/lazyfree.c index 1cc521cbd..4e336255e 100644 --- a/src/lazyfree.c +++ b/src/lazyfree.c @@ -49,9 +49,9 @@ void lazyFreeLuaScripts(void *args[]) { /* Release the functions ctx. */ void lazyFreeFunctionsCtx(void *args[]) { - functionsCtx *f_ctx = args[0]; - size_t len = functionsLen(f_ctx); - functionsCtxFree(f_ctx); + functionsLibCtx *functions_lib_ctx = args[0]; + size_t len = functionsLibCtxfunctionsLen(functions_lib_ctx); + functionsLibCtxFree(functions_lib_ctx); atomicDecr(lazyfree_objects,len); atomicIncr(lazyfreed_objects,len); } @@ -204,12 +204,12 @@ void freeLuaScriptsAsync(dict *lua_scripts) { } /* Free functions ctx, if the functions ctx contains enough functions, free it in async way. */ -void freeFunctionsAsync(functionsCtx *f_ctx) { - if (functionsLen(f_ctx) > LAZYFREE_THRESHOLD) { - atomicIncr(lazyfree_objects,functionsLen(f_ctx)); - bioCreateLazyFreeJob(lazyFreeFunctionsCtx,1,f_ctx); +void freeFunctionsAsync(functionsLibCtx *functions_lib_ctx) { + if (functionsLibCtxfunctionsLen(functions_lib_ctx) > LAZYFREE_THRESHOLD) { + atomicIncr(lazyfree_objects,functionsLibCtxfunctionsLen(functions_lib_ctx)); + bioCreateLazyFreeJob(lazyFreeFunctionsCtx,1,functions_lib_ctx); } else { - functionsCtxFree(f_ctx); + functionsLibCtxFree(functions_lib_ctx); } } diff --git a/src/rdb.c b/src/rdb.c index a8172a865..91645d981 100644 --- a/src/rdb.c +++ b/src/rdb.c @@ -1215,7 +1215,7 @@ ssize_t rdbSaveSingleModuleAux(rio *rdb, int when, moduleType *mt) { } ssize_t rdbSaveFunctions(rio *rdb) { - dict *functions = functionsGet(); + dict *functions = functionsLibGet(); dictIterator *iter = dictGetIterator(functions); dictEntry *entry = NULL; ssize_t written = 0; @@ -1223,23 +1223,23 @@ ssize_t rdbSaveFunctions(rio *rdb) { while ((entry = dictNext(iter))) { if ((ret = rdbSaveType(rdb, RDB_OPCODE_FUNCTION)) < 0) goto werr; written += ret; - functionInfo *fi = dictGetVal(entry); - if ((ret = rdbSaveRawString(rdb, (unsigned char *) fi->name, sdslen(fi->name))) < 0) goto werr; + functionLibInfo *li = dictGetVal(entry); + if ((ret = rdbSaveRawString(rdb, (unsigned char *) li->name, sdslen(li->name))) < 0) goto werr; written += ret; - if ((ret = rdbSaveRawString(rdb, (unsigned char *) fi->ei->name, sdslen(fi->ei->name))) < 0) goto werr; + if ((ret = rdbSaveRawString(rdb, (unsigned char *) li->ei->name, sdslen(li->ei->name))) < 0) goto werr; written += ret; - if (fi->desc) { + if (li->desc) { /* desc exists */ if ((ret = rdbSaveLen(rdb, 1)) < 0) goto werr; written += ret; - if ((ret = rdbSaveRawString(rdb, (unsigned char *) fi->desc, sdslen(fi->desc))) < 0) goto werr; + if ((ret = rdbSaveRawString(rdb, (unsigned char *) li->desc, sdslen(li->desc))) < 0) goto werr; written += ret; } else { /* desc not exists */ if ((ret = rdbSaveLen(rdb, 0)) < 0) goto werr; written += ret; } - if ((ret = rdbSaveRawString(rdb, (unsigned char *) fi->code, sdslen(fi->code))) < 0) goto werr; + if ((ret = rdbSaveRawString(rdb, (unsigned char *) li->code, sdslen(li->code))) < 0) goto werr; written += ret; } dictReleaseIterator(iter); @@ -2746,7 +2746,7 @@ void rdbLoadProgressCallback(rio *r, const void *buf, size_t len) { * The err output parameter is optional and will be set with relevant error * message on failure, it is the caller responsibility to free the error * message on failure. */ -int rdbFunctionLoad(rio *rdb, int ver, functionsCtx* functions_ctx, int rdbflags, sds *err) { +int rdbFunctionLoad(rio *rdb, int ver, functionsLibCtx* lib_ctx, int rdbflags, sds *err) { UNUSED(ver); sds name = NULL; sds engine_name = NULL; @@ -2756,7 +2756,7 @@ int rdbFunctionLoad(rio *rdb, int ver, functionsCtx* functions_ctx, int rdbflags sds error = NULL; int res = C_ERR; if (!(name = rdbGenericLoadStringObject(rdb, RDB_LOAD_SDS, NULL))) { - error = sdsnew("Failed loading function name"); + error = sdsnew("Failed loading library name"); goto error; } @@ -2766,23 +2766,23 @@ int rdbFunctionLoad(rio *rdb, int ver, functionsCtx* functions_ctx, int rdbflags } if ((has_desc = rdbLoadLen(rdb, NULL)) == RDB_LENERR) { - error = sdsnew("Failed loading function description indicator"); + error = sdsnew("Failed loading library description indicator"); goto error; } if (has_desc && !(desc = rdbGenericLoadStringObject(rdb, RDB_LOAD_SDS, NULL))) { - error = sdsnew("Failed loading function description"); + error = sdsnew("Failed loading library description"); goto error; } if (!(blob = rdbGenericLoadStringObject(rdb, RDB_LOAD_SDS, NULL))) { - error = sdsnew("Failed loading function blob"); + error = sdsnew("Failed loading library blob"); goto error; } - if (functionsCreateWithFunctionCtx(name, engine_name, desc, blob, rdbflags & RDBFLAGS_ALLOW_DUP, &error, functions_ctx) != C_OK) { + if (functionsCreateWithLibraryCtx(name, engine_name, desc, blob, rdbflags & RDBFLAGS_ALLOW_DUP, &error, lib_ctx) != C_OK) { if (!error) { - error = sdsnew("Failed creating the function"); + error = sdsnew("Failed creating the library"); } goto error; } @@ -2808,8 +2808,8 @@ error: /* Load an RDB file from the rio stream 'rdb'. On success C_OK is returned, * otherwise C_ERR is returned and 'errno' is set accordingly. */ int rdbLoadRio(rio *rdb, int rdbflags, rdbSaveInfo *rsi) { - functionsCtx* functions_ctx = functionsCtxGetCurrent(); - rdbLoadingCtx loading_ctx = { .dbarray = server.db, .functions_ctx = functions_ctx }; + functionsLibCtx* functions_lib_ctx = functionsLibCtxGetCurrent(); + rdbLoadingCtx loading_ctx = { .dbarray = server.db, .functions_lib_ctx = functions_lib_ctx }; int retval = rdbLoadRioWithLoadingCtx(rdb,rdbflags,rsi,&loading_ctx); return retval; } @@ -2818,7 +2818,7 @@ int rdbLoadRio(rio *rdb, int rdbflags, rdbSaveInfo *rsi) { /* Load an RDB file from the rio stream 'rdb'. On success C_OK is returned, * otherwise C_ERR is returned and 'errno' is set accordingly. * The rdb_loading_ctx argument holds objects to which the rdb will be loaded to, - * currently it only allow to set db object and functionsCtx to which the data + * currently it only allow to set db object and functionLibCtx to which the data * will be loaded (in the future it might contains more such objects). */ int rdbLoadRioWithLoadingCtx(rio *rdb, int rdbflags, rdbSaveInfo *rsi, rdbLoadingCtx *rdb_loading_ctx) { uint64_t dbid = 0; @@ -3023,8 +3023,8 @@ int rdbLoadRioWithLoadingCtx(rio *rdb, int rdbflags, rdbSaveInfo *rsi, rdbLoadin } } else if (type == RDB_OPCODE_FUNCTION) { sds err = NULL; - if (rdbFunctionLoad(rdb, rdbver, rdb_loading_ctx->functions_ctx, rdbflags, &err) != C_OK) { - serverLog(LL_WARNING,"Failed loading function, %s", err); + if (rdbFunctionLoad(rdb, rdbver, rdb_loading_ctx->functions_lib_ctx, rdbflags, &err) != C_OK) { + serverLog(LL_WARNING,"Failed loading library, %s", err); sdsfree(err); goto eoferr; } diff --git a/src/rdb.h b/src/rdb.h index f2a5a28fe..3c7b5ffcc 100644 --- a/src/rdb.h +++ b/src/rdb.h @@ -169,7 +169,7 @@ int rdbSaveBinaryFloatValue(rio *rdb, float val); int rdbLoadBinaryFloatValue(rio *rdb, float *val); int rdbLoadRio(rio *rdb, int rdbflags, rdbSaveInfo *rsi); int rdbLoadRioWithLoadingCtx(rio *rdb, int rdbflags, rdbSaveInfo *rsi, rdbLoadingCtx *rdb_loading_ctx); -int rdbFunctionLoad(rio *rdb, int ver, functionsCtx* functions_ctx, int rdbflags, sds *err); +int rdbFunctionLoad(rio *rdb, int ver, functionsLibCtx* lib_ctx, int rdbflags, sds *err); int rdbSaveRio(int req, rio *rdb, int *error, int rdbflags, rdbSaveInfo *rsi); ssize_t rdbSaveFunctions(rio *rdb); rdbSaveInfo *rdbPopulateSaveInfo(rdbSaveInfo *rsi); diff --git a/src/replication.c b/src/replication.c index a0a19f36d..75472bf9e 100644 --- a/src/replication.c +++ b/src/replication.c @@ -1792,7 +1792,7 @@ void readSyncBulkPayload(connection *conn) { ssize_t nread, readlen, nwritten; int use_diskless_load = useDisklessLoad(); redisDb *diskless_load_tempDb = NULL; - functionsCtx* temp_functions_ctx = NULL; + functionsLibCtx* temp_functions_lib_ctx = NULL; int empty_db_flags = server.repl_slave_lazy_flush ? EMPTYDB_ASYNC : EMPTYDB_NO_FLAGS; off_t left; @@ -1968,7 +1968,7 @@ void readSyncBulkPayload(connection *conn) { if (use_diskless_load && server.repl_diskless_load == REPL_DISKLESS_LOAD_SWAPDB) { /* Initialize empty tempDb dictionaries. */ diskless_load_tempDb = disklessLoadInitTempDb(); - temp_functions_ctx = functionsCtxCreate(); + temp_functions_lib_ctx = functionsLibCtxCreate(); moduleFireServerEvent(REDISMODULE_EVENT_REPL_ASYNC_LOAD, REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_STARTED, @@ -1991,7 +1991,7 @@ void readSyncBulkPayload(connection *conn) { if (use_diskless_load) { rio rdb; redisDb *dbarray; - functionsCtx* functions_ctx; + functionsLibCtx* functions_lib_ctx; int asyncLoading = 0; if (server.repl_diskless_load == REPL_DISKLESS_LOAD_SWAPDB) { @@ -2004,11 +2004,11 @@ void readSyncBulkPayload(connection *conn) { asyncLoading = 1; } dbarray = diskless_load_tempDb; - functions_ctx = temp_functions_ctx; + functions_lib_ctx = temp_functions_lib_ctx; } else { dbarray = server.db; - functions_ctx = functionsCtxGetCurrent(); - functionsCtxClear(functions_ctx); + functions_lib_ctx = functionsLibCtxGetCurrent(); + functionsLibCtxClear(functions_lib_ctx); } rioInitWithConn(&rdb,conn,server.repl_transfer_size); @@ -2020,7 +2020,7 @@ void readSyncBulkPayload(connection *conn) { startLoading(server.repl_transfer_size, RDBFLAGS_REPLICATION, asyncLoading); int loadingFailed = 0; - rdbLoadingCtx loadingCtx = { .dbarray = dbarray, .functions_ctx = functions_ctx }; + rdbLoadingCtx loadingCtx = { .dbarray = dbarray, .functions_lib_ctx = functions_lib_ctx }; if (rdbLoadRioWithLoadingCtx(&rdb,RDBFLAGS_REPLICATION,&rsi,&loadingCtx) != C_OK) { /* RDB loading failed. */ serverLog(LL_WARNING, @@ -2049,7 +2049,7 @@ void readSyncBulkPayload(connection *conn) { NULL); disklessLoadDiscardTempDb(diskless_load_tempDb); - functionsCtxFree(temp_functions_ctx); + functionsLibCtxFree(temp_functions_lib_ctx); serverLog(LL_NOTICE, "MASTER <-> REPLICA sync: Discarding temporary DB in background"); } else { /* Remove the half-loaded data in case we started with an empty replica. */ @@ -2073,7 +2073,7 @@ void readSyncBulkPayload(connection *conn) { swapMainDbWithTempDb(diskless_load_tempDb); /* swap existing functions ctx with the temporary one */ - functionsCtxSwapWithCurrent(temp_functions_ctx); + functionsLibCtxSwapWithCurrent(temp_functions_lib_ctx); moduleFireServerEvent(REDISMODULE_EVENT_REPL_ASYNC_LOAD, REDISMODULE_SUBEVENT_REPL_ASYNC_LOAD_COMPLETED, diff --git a/src/script_lua.c b/src/script_lua.c index 258c6c385..fc9fc812a 100644 --- a/src/script_lua.c +++ b/src/script_lua.c @@ -80,6 +80,9 @@ void* luaGetFromRegistry(lua_State* lua, const char* name) { lua_pushstring(lua, name); lua_gettable(lua, LUA_REGISTRYINDEX); + if (lua_isnil(lua, -1)) { + return NULL; + } /* must be light user data */ serverAssert(lua_islightuserdata(lua, -1)); @@ -427,7 +430,7 @@ static void redisProtocolToLuaType_Double(void *ctx, double d, const char *proto * with a single "err" field set to the error string. Note that this * table is never a valid reply by proper commands, since the returned * tables are otherwise always indexed by integers, never by strings. */ -static void luaPushError(lua_State *lua, char *error) { +void luaPushError(lua_State *lua, char *error) { lua_Debug dbg; /* If debugging is active and in step mode, log errors resulting from @@ -455,7 +458,7 @@ static void luaPushError(lua_State *lua, char *error) { * by the non-error-trapping version of redis.pcall(), which is redis.call(), * this function will raise the Lua error so that the execution of the * script will be halted. */ -static int luaRaiseError(lua_State *lua) { +int luaRaiseError(lua_State *lua) { lua_pushstring(lua,"err"); lua_gettable(lua,-2); return lua_error(lua); @@ -656,6 +659,10 @@ static void luaReplyToRedisReply(client *c, client* script_client, lua_State *lu static int luaRedisGenericCommand(lua_State *lua, int raise_error) { int j, argc = lua_gettop(lua); scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + if (!rctx) { + luaPushError(lua, "redis.call/pcall can only be called inside a script invocation"); + return luaRaiseError(lua); + } sds err = NULL; client* c = rctx->c; sds reply; @@ -911,6 +918,10 @@ static int luaRedisSetReplCommand(lua_State *lua) { int flags, argc = lua_gettop(lua); scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + if (!rctx) { + lua_pushstring(lua, "redis.set_repl can only be called inside a script invocation"); + return lua_error(lua); + } if (argc != 1) { lua_pushstring(lua, "redis.set_repl() requires two arguments."); @@ -966,6 +977,11 @@ static int luaLogCommand(lua_State *lua) { /* redis.setresp() */ static int luaSetResp(lua_State *lua) { + scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + if (!rctx) { + lua_pushstring(lua, "redis.setresp can only be called inside a script invocation"); + return lua_error(lua); + } int argc = lua_gettop(lua); if (argc != 1) { @@ -978,7 +994,6 @@ static int luaSetResp(lua_State *lua) { lua_pushstring(lua, "RESP version must be 2 or 3."); return lua_error(lua); } - scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); scriptSetResp(rctx, resp); return 0; } @@ -1031,8 +1046,8 @@ static void luaRemoveUnsupportedFunctions(lua_State *lua) { * sequence, because it may interact with creation of globals. * * On Legacy Lua (eval) we need to check 'w ~= \"main\"' otherwise we will not be able - * to create the global 'function ()' variable. On Lua engine we do not use this trick - * so its not needed. */ + * to create the global 'function ()' variable. On Functions Lua engine we do not use + * this trick so it's not needed. */ void luaEnableGlobalsProtection(lua_State *lua, int is_eval) { char *s[32]; sds code = sdsempty(); @@ -1067,6 +1082,89 @@ void luaEnableGlobalsProtection(lua_State *lua, int is_eval) { sdsfree(code); } +/* Create a global protection function and put it to registry. + * This need to be called once in the lua_State lifetime. + * After called it is possible to use luaSetGlobalProtection + * to set global protection on a give table. + * + * The function assumes the Lua stack have a least enough + * space to push 2 element, its up to the caller to verify + * this before calling this function. + * + * Notice, the difference between this and luaEnableGlobalsProtection + * is that luaEnableGlobalsProtection is enabling global protection + * on the current Lua globals. This registering a global protection + * function that later can be applied on any table. */ +void luaRegisterGlobalProtectionFunction(lua_State *lua) { + lua_pushstring(lua, REGISTRY_SET_GLOBALS_PROTECTION_NAME); + char *global_protection_func = "local dbg = debug\n" + "local globals_protection = function (t)\n" + " local mt = {}\n" + " setmetatable(t, mt)\n" + " mt.__newindex = function (t, n, v)\n" + " if dbg.getinfo(2) then\n" + " local w = dbg.getinfo(2, \"S\").what\n" + " if w ~= \"C\" then\n" + " error(\"Script attempted to create global variable '\"..tostring(n)..\"'\", 2)\n" + " end" + " end" + " rawset(t, n, v)\n" + " end\n" + " mt.__index = function (t, n)\n" + " if dbg.getinfo(2) and dbg.getinfo(2, \"S\").what ~= \"C\" then\n" + " error(\"Script attempted to access nonexistent global variable '\"..tostring(n)..\"'\", 2)\n" + " end\n" + " return rawget(t, n)\n" + " end\n" + "end\n" + "return globals_protection"; + int res = luaL_loadbuffer(lua, global_protection_func, strlen(global_protection_func), "@global_protection_def"); + serverAssert(res == 0); + res = lua_pcall(lua,0,1,0); + serverAssert(res == 0); + lua_settable(lua, LUA_REGISTRYINDEX); +} + +/* Set global protection on a given table. + * The table need to be located on the top of the lua stack. + * After called, it will no longer be possible to set + * new items on the table. The function is not removing + * the table from the top of the stack! + * + * The function assumes the Lua stack have a least enough + * space to push 2 element, its up to the caller to verify + * this before calling this function. */ +void luaSetGlobalProtection(lua_State *lua) { + lua_pushstring(lua, REGISTRY_SET_GLOBALS_PROTECTION_NAME); + lua_gettable(lua, LUA_REGISTRYINDEX); + lua_pushvalue(lua, -2); + int res = lua_pcall(lua, 1, 0, 0); + serverAssert(res == 0); +} + +void luaRegisterLogFunction(lua_State* lua) { + /* redis.log and log levels. */ + lua_pushstring(lua,"log"); + lua_pushcfunction(lua,luaLogCommand); + lua_settable(lua,-3); + + lua_pushstring(lua,"LOG_DEBUG"); + lua_pushnumber(lua,LL_DEBUG); + lua_settable(lua,-3); + + lua_pushstring(lua,"LOG_VERBOSE"); + lua_pushnumber(lua,LL_VERBOSE); + lua_settable(lua,-3); + + lua_pushstring(lua,"LOG_NOTICE"); + lua_pushnumber(lua,LL_NOTICE); + lua_settable(lua,-3); + + lua_pushstring(lua,"LOG_WARNING"); + lua_pushnumber(lua,LL_WARNING); + lua_settable(lua,-3); +} + void luaRegisterRedisAPI(lua_State* lua) { luaLoadLibraries(lua); luaRemoveUnsupportedFunctions(lua); @@ -1084,32 +1182,13 @@ void luaRegisterRedisAPI(lua_State* lua) { lua_pushcfunction(lua,luaRedisPCallCommand); lua_settable(lua,-3); - /* redis.log and log levels. */ - lua_pushstring(lua,"log"); - lua_pushcfunction(lua,luaLogCommand); - lua_settable(lua,-3); + luaRegisterLogFunction(lua); /* redis.setresp */ lua_pushstring(lua,"setresp"); lua_pushcfunction(lua,luaSetResp); lua_settable(lua,-3); - lua_pushstring(lua,"LOG_DEBUG"); - lua_pushnumber(lua,LL_DEBUG); - lua_settable(lua,-3); - - lua_pushstring(lua,"LOG_VERBOSE"); - lua_pushnumber(lua,LL_VERBOSE); - lua_settable(lua,-3); - - lua_pushstring(lua,"LOG_NOTICE"); - lua_pushnumber(lua,LL_NOTICE); - lua_settable(lua,-3); - - lua_pushstring(lua,"LOG_WARNING"); - lua_pushnumber(lua,LL_WARNING); - lua_settable(lua,-3); - /* redis.sha1hex */ lua_pushstring(lua, "sha1hex"); lua_pushcfunction(lua, luaRedisSha1hexCommand); @@ -1149,7 +1228,7 @@ void luaRegisterRedisAPI(lua_State* lua) { lua_settable(lua,-3); /* Finally set the table as 'redis' global var. */ - lua_setglobal(lua,"redis"); + lua_setglobal(lua,REDIS_API_NAME); /* Replace math.random and math.randomseed with our implementations. */ lua_getglobal(lua,"math"); @@ -1167,7 +1246,7 @@ void luaRegisterRedisAPI(lua_State* lua) { /* Set an array of Redis String Objects as a Lua array (table) stored into a * global variable. */ -static void luaSetGlobalArray(lua_State *lua, char *var, robj **elev, int elec) { +static void luaCreateArray(lua_State *lua, robj **elev, int elec) { int j; lua_newtable(lua); @@ -1175,7 +1254,6 @@ static void luaSetGlobalArray(lua_State *lua, char *var, robj **elev, int elec) lua_pushlstring(lua,(char*)elev[j]->ptr,sdslen(elev[j]->ptr)); lua_rawseti(lua,-2,j+1); } - lua_setglobal(lua,var); } /* --------------------------------------------------------------------------- @@ -1189,6 +1267,11 @@ static void luaSetGlobalArray(lua_State *lua, char *var, robj **elev, int elec) /* The following implementation is the one shipped with Lua itself but with * rand() replaced by redisLrand48(). */ static int redis_math_random (lua_State *L) { + scriptRunCtx* rctx = luaGetFromRegistry(L, REGISTRY_RUN_CTX_NAME); + if (!rctx) { + return luaL_error(L, "math.random can only be called inside a script invocation"); + } + /* the `%' avoids the (rare) case of r==1, and is needed also because on some systems (SunOS!) `rand()' may return a value larger than RAND_MAX */ lua_Number r = (lua_Number)(redisLrand48()%REDIS_LRAND48_MAX) / @@ -1217,6 +1300,10 @@ static int redis_math_random (lua_State *L) { } static int redis_math_randomseed (lua_State *L) { + scriptRunCtx* rctx = luaGetFromRegistry(L, REGISTRY_RUN_CTX_NAME); + if (!rctx) { + return luaL_error(L, "math.randomseed can only be called inside a script invocation"); + } redisSrand48(luaL_checkint(L, 1)); return 0; } @@ -1260,13 +1347,24 @@ void luaCallFunction(scriptRunCtx* run_ctx, lua_State *lua, robj** keys, size_t /* Populate the argv and keys table accordingly to the arguments that * EVAL received. */ - luaSetGlobalArray(lua,"KEYS",keys,nkeys); - luaSetGlobalArray(lua,"ARGV",args,nargs); + luaCreateArray(lua,keys,nkeys); + /* On eval, keys and arguments are globals. */ + if (run_ctx->flags & SCRIPT_EVAL_MODE) lua_setglobal(lua,"KEYS"); + luaCreateArray(lua,args,nargs); + if (run_ctx->flags & SCRIPT_EVAL_MODE) lua_setglobal(lua,"ARGV"); /* At this point whether this script was never seen before or if it was - * already defined, we can call it. We have zero arguments and expect - * a single return value. */ - int err = lua_pcall(lua,0,1,-2); + * already defined, we can call it. + * On eval mode, we have zero arguments and expect a single return value. + * In addition the error handler is located on position -2 on the Lua stack. + * On function mode, we pass 2 arguments (the keys and args tables), + * and the error handler is located on position -4 (stack: error_handler, callback, keys, args) */ + int err; + if (run_ctx->flags & SCRIPT_EVAL_MODE) { + err = lua_pcall(lua,0,1,-2); + } else { + err = lua_pcall(lua,2,1,-4); + } /* Call the Lua garbage collector from time to time to avoid a * full cycle performed by Lua, which adds too latency. diff --git a/src/script_lua.h b/src/script_lua.h index 40cdd00b5..b7862bee1 100644 --- a/src/script_lua.h +++ b/src/script_lua.h @@ -55,9 +55,16 @@ #include #define REGISTRY_RUN_CTX_NAME "__RUN_CTX__" +#define REGISTRY_SET_GLOBALS_PROTECTION_NAME "__GLOBAL_PROTECTION__" +#define REDIS_API_NAME "redis" void luaRegisterRedisAPI(lua_State* lua); void luaEnableGlobalsProtection(lua_State *lua, int is_eval); +void luaRegisterGlobalProtectionFunction(lua_State *lua); +void luaSetGlobalProtection(lua_State *lua); +void luaRegisterLogFunction(lua_State* lua); +void luaPushError(lua_State *lua, char *error); +int luaRaiseError(lua_State *lua); void luaSaveOnRegistry(lua_State* lua, const char* name, void* ptr); void* luaGetFromRegistry(lua_State* lua, const char* name); void luaCallFunction(scriptRunCtx* r_ctx, lua_State *lua, robj** keys, size_t nkeys, robj** args, size_t nargs, int debug_enabled); diff --git a/src/server.c b/src/server.c index 911441d56..a1d0f1cc5 100644 --- a/src/server.c +++ b/src/server.c @@ -4886,6 +4886,7 @@ sds genRedisInfoString(const char *section) { "used_memory_scripts_eval:%lld\r\n" "number_of_cached_scripts:%lu\r\n" "number_of_functions:%lu\r\n" + "number_of_libraries:%lu\r\n" "used_memory_vm_functions:%lld\r\n" "used_memory_vm_total:%lld\r\n" "used_memory_vm_total_human:%s\r\n" @@ -4936,6 +4937,7 @@ sds genRedisInfoString(const char *section) { (long long) mh->lua_caches, dictSize(evalScriptsDict()), functionsNum(), + functionsLibNum(), memory_functions, memory_functions + memory_lua, used_memory_vm_total_hmem, diff --git a/src/server.h b/src/server.h index fc16cdcd6..a591cc73f 100644 --- a/src/server.h +++ b/src/server.h @@ -873,7 +873,7 @@ typedef struct redisDb { } redisDb; /* forward declaration for functions ctx */ -typedef struct functionsCtx functionsCtx; +typedef struct functionsLibCtx functionsLibCtx; /* Holding object that need to be populated during * rdb loading. On loading end it is possible to decide @@ -882,7 +882,7 @@ typedef struct functionsCtx functionsCtx; * successful loading and dropped on failure. */ typedef struct rdbLoadingCtx { redisDb* dbarray; - functionsCtx* functions_ctx; + functionsLibCtx* functions_lib_ctx; }rdbLoadingCtx; /* Client MULTI/EXEC state */ @@ -3017,7 +3017,7 @@ int ldbPendingChildren(void); sds luaCreateFunction(client *c, robj *body); void luaLdbLineHook(lua_State *lua, lua_Debug *ar); void freeLuaScriptsAsync(dict *lua_scripts); -void freeFunctionsAsync(functionsCtx *f_ctx); +void freeFunctionsAsync(functionsLibCtx *lib_ctx); int ldbIsEnabled(); void ldbLog(sds entry); void ldbLogRedisReply(char *reply); @@ -3279,11 +3279,10 @@ void evalShaRoCommand(client *c); void scriptCommand(client *c); void fcallCommand(client *c); void fcallroCommand(client *c); -void functionCreateCommand(client *c); +void functionLoadCommand(client *c); void functionDeleteCommand(client *c); void functionKillCommand(client *c); void functionStatsCommand(client *c); -void functionInfoCommand(client *c); void functionListCommand(client *c); void functionHelpCommand(client *c); void functionFlushCommand(client *c); diff --git a/tests/integration/redis-cli.tcl b/tests/integration/redis-cli.tcl index de38f5ebe..9cbec3f96 100644 --- a/tests/integration/redis-cli.tcl +++ b/tests/integration/redis-cli.tcl @@ -322,7 +322,7 @@ if {!$::tls} { ;# fake_redis_node doesn't support TLS set dir [lindex [r config get dir] 1] assert_equal "OK" [r debug populate 100000 key 1000] - assert_equal "OK" [r function create lua func1 "return 123"] + assert_equal "OK" [r function load lua lib1 "redis.register_function('func1', function() return 123 end)"] if {$functions_only} { set args "--functions-rdb $dir/cli.rdb" } else { @@ -335,11 +335,10 @@ if {!$::tls} { ;# fake_redis_node doesn't support TLS file rename "$dir/cli.rdb" "$dir/dump.rdb" assert_equal "OK" [r set should-not-exist 1] - assert_equal "OK" [r function create lua should_not_exist_func "return 456"] + assert_equal "OK" [r function load lua should_not_exist_func "redis.register_function('should_not_exist_func', function() return 456 end)"] assert_equal "OK" [r debug reload nosave] assert_equal {} [r get should-not-exist] - assert_error "ERR Function does not exists" {r function info should_not_exist_func} - assert_equal "func1" [dict get [r function info func1] name] + assert_equal {{library_name lib1 engine LUA description {} functions {{name func1 description {}}}}} [r function list] if {$functions_only} { assert_equal 0 [r dbsize] } else { diff --git a/tests/integration/replication.tcl b/tests/integration/replication.tcl index 5755d3a3e..f2f1fe7e9 100644 --- a/tests/integration/replication.tcl +++ b/tests/integration/replication.tcl @@ -522,10 +522,10 @@ foreach testType {Successful Aborted} { $replica set mykey myvalue # Set a function value on replica to check status during loading, on failure and after swapping db - $replica function create LUA test {return 'hello1'} + $replica function load LUA test {redis.register_function('test', function() return 'hello1' end)} # Set a function value on master to check it reaches the replica when replication ends - $master function create LUA test {return 'hello2'} + $master function load LUA test {redis.register_function('test', function() return 'hello2' end)} # Force the replica to try another full sync (this time it will have matching master replid) $master multi @@ -658,7 +658,7 @@ test {diskless loading short read} { set start [clock clicks -milliseconds] # Set a function value to check short read handling on functions - r function create LUA test {return 'hello1'} + r function load LUA test {redis.register_function('test', function() return 'hello1' end)} for {set k 0} {$k < 3} {incr k} { for {set i 0} {$i < 10} {incr i} { diff --git a/tests/unit/cluster.tcl b/tests/unit/cluster.tcl index 2507d5208..d1c02a7c5 100644 --- a/tests/unit/cluster.tcl +++ b/tests/unit/cluster.tcl @@ -182,7 +182,7 @@ start_server [list overrides $base_conf] { # upload a function to all the cluster exec src/redis-cli --cluster-yes --cluster call 127.0.0.1:[srv 0 port] \ - FUNCTION CREATE LUA TEST {return 'hello'} + FUNCTION LOAD LUA TEST {redis.register_function('test', function() return 'hello' end)} # adding node to the cluster exec src/redis-cli --cluster-yes --cluster add-node \ @@ -199,13 +199,13 @@ start_server [list overrides $base_conf] { } # make sure 'test' function was added to the new node - assert_equal {{name TEST engine LUA description {}}} [$node4_rd FUNCTION LIST] + assert_equal {{library_name TEST engine LUA description {} functions {{name test description {}}}}} [$node4_rd FUNCTION LIST] # add function to node 5 - assert_equal {OK} [$node5_rd FUNCTION CREATE LUA TEST {return 'hello1'}] + assert_equal {OK} [$node5_rd FUNCTION LOAD LUA TEST {redis.register_function('test', function() return 'hello' end)}] # make sure functions was added to node 5 - assert_equal {{name TEST engine LUA description {}}} [$node5_rd FUNCTION LIST] + assert_equal {{library_name TEST engine LUA description {} functions {{name test description {}}}}} [$node5_rd FUNCTION LIST] # adding node 5 to the cluster should failed because it already contains the 'test' function catch { diff --git a/tests/unit/functions.tcl b/tests/unit/functions.tcl index 8b38688b4..84f9154e5 100644 --- a/tests/unit/functions.tcl +++ b/tests/unit/functions.tcl @@ -1,46 +1,50 @@ +proc get_function_code {args} { + return [format "redis.register_function('%s', function(KEYS, ARGV)\n %s \nend)" [lindex $args 0] [lindex $args 1]] +} + start_server {tags {"scripting"}} { test {FUNCTION - Basic usage} { - r function create LUA test {return 'hello'} + r function load LUA test [get_function_code test {return 'hello'}] r fcall test 0 } {hello} - test {FUNCTION - Create an already exiting function raise error} { + test {FUNCTION - Create an already exiting library raise error} { catch { - r function create LUA test {return 'hello1'} + r function load LUA test [get_function_code test {return 'hello1'}] } e set _ $e - } {*Function already exists*} + } {*already exists*} - test {FUNCTION - Create an already exiting function raise error (case insensitive)} { + test {FUNCTION - Create an already exiting library raise error (case insensitive)} { catch { - r function create LUA TEST {return 'hello1'} + r function load LUA TEST [get_function_code test {return 'hello1'}] } e set _ $e - } {*Function already exists*} + } {*already exists*} - test {FUNCTION - Create a function with wrong name format} { + test {FUNCTION - Create a library with wrong name format} { catch { - r function create LUA {bad\0foramat} {return 'hello1'} + r function load LUA {bad\0foramat} [get_function_code test {return 'hello1'}] } e set _ $e - } {*Function names can only contain letters and numbers*} + } {*Library names can only contain letters and numbers*} - test {FUNCTION - Create function with unexisting engine} { + test {FUNCTION - Create library with unexisting engine} { catch { - r function create bad_engine test {return 'hello1'} + r function load bad_engine test [get_function_code test {return 'hello1'}] } e set _ $e } {*Engine not found*} test {FUNCTION - Test uncompiled script} { catch { - r function create LUA test1 {bad script} + r function load LUA test1 {bad script} } e set _ $e } {*Error compiling function*} test {FUNCTION - test replace argument} { - r function create LUA test REPLACE {return 'hello1'} + r function load LUA test REPLACE [get_function_code test {return 'hello1'}] r fcall test 0 } {hello1} @@ -48,7 +52,7 @@ start_server {tags {"scripting"}} { r fcall TEST 0 } {hello1} - test {FUNCTION - test replace argument with function creation failure keeps old function} { + test {FUNCTION - test replace argument with failure keeps old libraries} { catch {r function create LUA test REPLACE {error}} r fcall test 0 } {hello1} @@ -62,31 +66,9 @@ start_server {tags {"scripting"}} { } {*Function not found*} test {FUNCTION - test description argument} { - r function create LUA test DESCRIPTION {some description} {return 'hello'} + r function load LUA test DESCRIPTION {some description} [get_function_code test {return 'hello'}] r function list - } {{name test engine LUA description {some description}}} - - test {FUNCTION - test info specific function} { - r function info test WITHCODE - } {name test engine LUA description {some description} code {return 'hello'}} - - test {FUNCTION - test info without code} { - r function info test - } {name test engine LUA description {some description}} - - test {FUNCTION - test info on function that does not exists} { - catch { - r function info bad_function_name - } e - set _ $e - } {*Function does not exists*} - - test {FUNCTION - test info with bad number of arguments} { - catch { - r function info test WITHCODE bad_arg - } e - set _ $e - } {*wrong number of arguments*} + } {{library_name test engine LUA description {some description} functions {{name test description {}}}}} test {FUNCTION - test fcall bad arguments} { catch { @@ -109,12 +91,12 @@ start_server {tags {"scripting"}} { set _ $e } {*Number of keys can't be negative*} - test {FUNCTION - test function delete on not exiting function} { + test {FUNCTION - test delete on not exiting library} { catch { r function delete test1 } e set _ $e - } {*Function not found*} + } {*Library not found*} test {FUNCTION - test function kill when function is not running} { catch { @@ -140,14 +122,14 @@ start_server {tags {"scripting"}} { assert_match "*Error trying to load the RDB*" $e r debug reload noflush merge r function list - } {{name test engine LUA description {some description}}} {needs:debug} + } {{library_name test engine LUA description {some description} functions {{name test description {}}}}} {needs:debug} test {FUNCTION - test debug reload with nosave and noflush} { r function delete test r set x 1 - r function create LUA test1 DESCRIPTION {some description} {return 'hello'} + r function load LUA test1 DESCRIPTION {some description} [get_function_code test1 {return 'hello'}] r debug reload - r function create LUA test2 DESCRIPTION {some description} {return 'hello'} + r function load LUA test2 DESCRIPTION {some description} [get_function_code test2 {return 'hello'}] r debug reload nosave noflush merge assert_equal [r fcall test1 0] {hello} assert_equal [r fcall test2 0] {hello} @@ -155,21 +137,21 @@ start_server {tags {"scripting"}} { test {FUNCTION - test flushall and flushdb do not clean functions} { r function flush - r function create lua test REPLACE {return redis.call('set', 'x', '1')} + r function load lua test REPLACE [get_function_code test {return redis.call('set', 'x', '1')}] r flushall r flushdb r function list - } {{name test engine LUA description {}}} + } {{library_name test engine LUA description {} functions {{name test description {}}}}} test {FUNCTION - test function dump and restore} { r function flush - r function create lua test description {some description} {return 'hello'} + r function load lua test description {some description} [get_function_code test {return 'hello'}] set e [r function dump] r function delete test assert_match {} [r function list] r function restore $e r function list - } {{name test engine LUA description {some description}}} + } {{library_name test engine LUA description {some description} functions {{name test description {}}}}} test {FUNCTION - test function dump and restore with flush argument} { set e [r function dump] @@ -177,17 +159,17 @@ start_server {tags {"scripting"}} { assert_match {} [r function list] r function restore $e FLUSH r function list - } {{name test engine LUA description {some description}}} + } {{library_name test engine LUA description {some description} functions {{name test description {}}}}} test {FUNCTION - test function dump and restore with append argument} { set e [r function dump] r function flush assert_match {} [r function list] - r function create lua test {return 'hello1'} + r function load lua test [get_function_code test {return 'hello1'}] catch {r function restore $e APPEND} err assert_match {*already exists*} $err r function flush - r function create lua test1 {return 'hello1'} + r function load lua test1 [get_function_code test1 {return 'hello1'}] r function restore $e APPEND assert_match {hello} [r fcall test 0] assert_match {hello1} [r fcall test1 0] @@ -195,11 +177,11 @@ start_server {tags {"scripting"}} { test {FUNCTION - test function dump and restore with replace argument} { r function flush - r function create LUA test DESCRIPTION {some description} {return 'hello'} + r function load LUA test DESCRIPTION {some description} [get_function_code test {return 'hello'}] set e [r function dump] r function flush assert_match {} [r function list] - r function create lua test {return 'hello1'} + r function load lua test [get_function_code test {return 'hello1'}] assert_match {hello1} [r fcall test 0] r function restore $e REPLACE assert_match {hello} [r fcall test 0] @@ -207,11 +189,11 @@ start_server {tags {"scripting"}} { test {FUNCTION - test function restore with bad payload do not drop existing functions} { r function flush - r function create LUA test DESCRIPTION {some description} {return 'hello'} + r function load LUA test DESCRIPTION {some description} [get_function_code test {return 'hello'}] catch {r function restore bad_payload} e assert_match {*payload version or checksum are wrong*} $e r function list - } {{name test engine LUA description {some description}}} + } {{library_name test engine LUA description {some description} functions {{name test description {}}}}} test {FUNCTION - test function restore with wrong number of arguments} { catch {r function restore arg1 args2 arg3} e @@ -219,19 +201,19 @@ start_server {tags {"scripting"}} { } {*wrong number of arguments*} test {FUNCTION - test fcall_ro with write command} { - r function create lua test REPLACE {return redis.call('set', 'x', '1')} + r function load lua test REPLACE [get_function_code test {return redis.call('set', 'x', '1')}] catch { r fcall_ro test 0 } e set _ $e } {*Write commands are not allowed from read-only scripts*} test {FUNCTION - test fcall_ro with read only commands} { - r function create lua test REPLACE {return redis.call('get', 'x')} + r function load lua test REPLACE [get_function_code test {return redis.call('get', 'x')}] r set x 1 r fcall_ro test 0 } {1} test {FUNCTION - test keys and argv} { - r function create lua test REPLACE {return redis.call('set', KEYS[1], ARGV[1])} + r function load lua test REPLACE [get_function_code test {return redis.call('set', KEYS[1], ARGV[1])}] r fcall test 1 x foo r get x } {foo} @@ -247,7 +229,7 @@ start_server {tags {"scripting"}} { test {FUNCTION - test function kill} { set rd [redis_deferring_client] r config set script-time-limit 10 - r function create lua test REPLACE {local a = 1 while true do a = a + 1 end} + r function load lua test REPLACE [get_function_code test {local a = 1 while true do a = a + 1 end}] $rd fcall test 0 after 200 catch {r ping} e @@ -261,7 +243,7 @@ start_server {tags {"scripting"}} { test {FUNCTION - test script kill not working on function} { set rd [redis_deferring_client] r config set script-time-limit 10 - r function create lua test REPLACE {local a = 1 while true do a = a + 1 end} + r function load lua test REPLACE [get_function_code test {local a = 1 while true do a = a + 1 end}] $rd fcall test 0 after 200 catch {r ping} e @@ -288,18 +270,18 @@ start_server {tags {"scripting"}} { } test {FUNCTION - test function flush} { - r function create lua test REPLACE {local a = 1 while true do a = a + 1 end} - assert_match {{name test engine LUA description {}}} [r function list] + r function load lua test REPLACE [get_function_code test {local a = 1 while true do a = a + 1 end}] + assert_match {{library_name test engine LUA description {} functions {{name test description {}}}}} [r function list] r function flush assert_match {} [r function list] - r function create lua test REPLACE {local a = 1 while true do a = a + 1 end} - assert_match {{name test engine LUA description {}}} [r function list] + r function load lua test REPLACE [get_function_code test {local a = 1 while true do a = a + 1 end}] + assert_match {{library_name test engine LUA description {} functions {{name test description {}}}}} [r function list] r function flush async assert_match {} [r function list] - r function create lua test REPLACE {local a = 1 while true do a = a + 1 end} - assert_match {{name test engine LUA description {}}} [r function list] + r function load lua test REPLACE [get_function_code test {local a = 1 while true do a = a + 1 end}] + assert_match {{library_name test engine LUA description {} functions {{name test description {}}}}} [r function list] r function flush sync assert_match {} [r function list] } @@ -326,9 +308,9 @@ start_server {tags {"scripting repl external:skip"}} { } test {FUNCTION - creation is replicated to replica} { - r function create LUA test DESCRIPTION {some description} {return 'hello'} + r function load LUA test DESCRIPTION {some description} [get_function_code test {return 'hello'}] wait_for_condition 50 100 { - [r -1 function list] eq {{name test engine LUA description {some description}}} + [r -1 function list] eq {{library_name test engine LUA description {some description} functions {{name test description {}}}}} } else { fail "Failed waiting for function to replicate to replica" } @@ -348,10 +330,10 @@ start_server {tags {"scripting repl external:skip"}} { fail "Failed waiting for function to replicate to replica" } - r function restore $e + assert_equal [r function restore $e] {OK} wait_for_condition 50 100 { - [r -1 function list] eq {{name test engine LUA description {some description}}} + [r -1 function list] eq {{library_name test engine LUA description {some description} functions {{name test description {}}}}} } else { fail "Failed waiting for function to replicate to replica" } @@ -367,9 +349,9 @@ start_server {tags {"scripting repl external:skip"}} { } test {FUNCTION - flush is replicated to replica} { - r function create LUA test DESCRIPTION {some description} {return 'hello'} + r function load LUA test DESCRIPTION {some description} [get_function_code test {return 'hello'}] wait_for_condition 50 100 { - [r -1 function list] eq {{name test engine LUA description {some description}}} + [r -1 function list] eq {{library_name test engine LUA description {some description} functions {{name test description {}}}}} } else { fail "Failed waiting for function to replicate to replica" } @@ -385,7 +367,7 @@ start_server {tags {"scripting repl external:skip"}} { r -1 slaveof no one # creating a function after disconnect to make sure function # is replicated on rdb phase - r function create LUA test DESCRIPTION {some description} {return 'hello'} + r function load LUA test DESCRIPTION {some description} [get_function_code test {return 'hello'}] # reconnect the replica r -1 slaveof [srv 0 host] [srv 0 port] @@ -402,12 +384,12 @@ start_server {tags {"scripting repl external:skip"}} { } {hello} test "FUNCTION - test replication to replica on rdb phase info command" { - r -1 function info test WITHCODE - } {name test engine LUA description {some description} code {return 'hello'}} + r -1 function list + } {{library_name test engine LUA description {some description} functions {{name test description {}}}}} test "FUNCTION - create on read only replica" { catch { - r -1 function create LUA test DESCRIPTION {some description} {return 'hello'} + r -1 function load LUA test DESCRIPTION {some description} [get_function_code test {return 'hello'}] } e set _ $e } {*can't write against a read only replica*} @@ -420,7 +402,7 @@ start_server {tags {"scripting repl external:skip"}} { } {*can't write against a read only replica*} test "FUNCTION - function effect is replicated to replica" { - r function create LUA test REPLACE {return redis.call('set', 'x', '1')} + r function load LUA test REPLACE [get_function_code test {return redis.call('set', 'x', '1')}] r fcall test 0 assert {[r get x] eq {1}} wait_for_condition 50 100 { @@ -443,12 +425,12 @@ test {FUNCTION can processes create, delete and flush commands in AOF when doing start_server {} { r config set appendonly yes waitForBgrewriteaof r - r FUNCTION CREATE lua test "return 'hello'" + r FUNCTION LOAD lua test "redis.register_function('test', function() return 'hello' end)" r config set slave-read-only yes r slaveof 127.0.0.1 0 r debug loadaof r slaveof no one - assert_equal [r function list] {{name test engine LUA description {}}} + assert_equal [r function list] {{library_name test engine LUA description {} functions {{name test description {}}}}} r FUNCTION DELETE test @@ -457,7 +439,7 @@ test {FUNCTION can processes create, delete and flush commands in AOF when doing r slaveof no one assert_equal [r function list] {} - r FUNCTION CREATE lua test "return 'hello'" + r FUNCTION LOAD lua test "redis.register_function('test', function() return 'hello' end)" r FUNCTION FLUSH r slaveof 127.0.0.1 0 @@ -466,3 +448,420 @@ test {FUNCTION can processes create, delete and flush commands in AOF when doing assert_equal [r function list] {} } } {} {needs:debug external:skip} + +start_server {tags {"scripting"}} { + test {LIBRARIES - test shared function can access default globals} { + r function load LUA lib1 { + local function ping() + return redis.call('ping') + end + redis.register_function( + 'f1', + function(keys, args) + return ping() + end + ) + } + r fcall f1 0 + } {PONG} + + test {LIBRARIES - usage and code sharing} { + r function load LUA lib1 REPLACE { + local function add1(a) + return a + 1 + end + redis.register_function( + 'f1', + function(keys, args) + return add1(1) + end, + 'f1 description' + ) + redis.register_function( + 'f2', + function(keys, args) + return add1(2) + end, + 'f2 description' + ) + } + assert_equal [r fcall f1 0] {2} + assert_equal [r fcall f2 0] {3} + r function list + } {{library_name lib1 engine LUA description {} functions {*}}} + + test {LIBRARIES - test registration failure revert the entire load} { + catch { + r function load LUA lib1 replace { + local function add1(a) + return a + 2 + end + redis.register_function( + 'f1', + function(keys, args) + return add1(1) + end + ) + redis.register_function( + 'f2', + 'not a function' + ) + } + } e + assert_match {*second argument to redis.register_function must be a function*} $e + assert_equal [r fcall f1 0] {2} + assert_equal [r fcall f2 0] {3} + } + + test {LIBRARIES - test registration function name collision} { + catch { + r function load LUA lib2 replace { + redis.register_function( + 'f1', + function(keys, args) + return 1 + end + ) + } + } e + assert_match {*Function f1 already exists*} $e + assert_equal [r fcall f1 0] {2} + assert_equal [r fcall f2 0] {3} + } + + test {LIBRARIES - test registration function name collision on same library} { + catch { + r function load LUA lib2 replace { + redis.register_function( + 'f1', + function(keys, args) + return 1 + end + ) + redis.register_function( + 'f1', + function(keys, args) + return 1 + end + ) + } + } e + set _ $e + } {*Function already exists in the library*} + + test {LIBRARIES - test registration with no argument} { + catch { + r function load LUA lib2 replace { + redis.register_function() + } + } e + set _ $e + } {*wrong number of arguments to redis.register_function*} + + test {LIBRARIES - test registration with only name} { + catch { + r function load LUA lib2 replace { + redis.register_function('f1') + } + } e + set _ $e + } {*wrong number of arguments to redis.register_function*} + + test {LIBRARIES - test registration with to many arguments} { + catch { + r function load LUA lib2 replace { + redis.register_function('f1', function() return 1 end, 'description', 'extra arg') + } + } e + set _ $e + } {*wrong number of arguments to redis.register_function*} + + test {LIBRARIES - test registration with no string name} { + catch { + r function load LUA lib2 replace { + redis.register_function(nil, function() return 1 end) + } + } e + set _ $e + } {*first argument to redis.register_function must be a string*} + + test {LIBRARIES - test registration with wrong name format} { + catch { + r function load LUA lib2 replace { + redis.register_function('test\0test', function() return 1 end) + } + } e + set _ $e + } {*Function names can only contain letters and numbers and must be at least one character long*} + + test {LIBRARIES - test registration with empty name} { + catch { + r function load LUA lib2 replace { + redis.register_function('', function() return 1 end) + } + } e + set _ $e + } {*Function names can only contain letters and numbers and must be at least one character long*} + + test {LIBRARIES - math.random from function load} { + catch { + r function load LUA lib2 replace { + return math.random() + } + } e + set _ $e + } {*attempted to access nonexistent global variable 'math'*} + + test {LIBRARIES - redis.call from function load} { + catch { + r function load LUA lib2 replace { + return redis.call('ping') + } + } e + set _ $e + } {*attempt to call field 'call' (a nil value)*} + + test {LIBRARIES - redis.call from function load} { + catch { + r function load LUA lib2 replace { + return redis.setresp(3) + } + } e + set _ $e + } {*attempt to call field 'setresp' (a nil value)*} + + test {LIBRARIES - redis.set_repl from function load} { + catch { + r function load LUA lib2 replace { + return redis.set_repl(redis.REPL_NONE) + } + } e + set _ $e + } {*attempt to call field 'set_repl' (a nil value)*} + + test {LIBRARIES - malicious access test} { + # the 'library' API is not exposed inside a + # function context and the 'redis' API is not + # expose on the library registration context. + # But a malicious user might find a way to hack it + # (as demonstrated in this test). This is why we + # have another level of protection on the C + # code itself and we want to test it and verify + # that it works properly. + r function load LUA lib1 replace { + local lib = redis + lib.register_function('f1', function () + lib.redis = redis + lib.math = math + return {ok='OK'} + end) + + lib.register_function('f2', function () + lib.register_function('f1', function () + lib.redis = redis + lib.math = math + return {ok='OK'} + end) + end) + } + assert_equal {OK} [r fcall f1 0] + + catch {[r function load LUA lib2 {redis.math.random()}]} e + assert_match {*can only be called inside a script invocation*} $e + + catch {[r function load LUA lib2 {redis.math.randomseed()}]} e + assert_match {*can only be called inside a script invocation*} $e + + catch {[r function load LUA lib2 {redis.redis.call('ping')}]} e + assert_match {*can only be called inside a script invocation*} $e + + catch {[r function load LUA lib2 {redis.redis.pcall('ping')}]} e + assert_match {*can only be called inside a script invocation*} $e + + catch {[r function load LUA lib2 {redis.redis.setresp(3)}]} e + assert_match {*can only be called inside a script invocation*} $e + + catch {[r function load LUA lib2 {redis.redis.set_repl(redis.redis.REPL_NONE)}]} e + assert_match {*can only be called inside a script invocation*} $e + + catch {[r fcall f2 0]} e + assert_match {*can only be called on FUNCTION LOAD command*} $e + } + + test {LIBRARIES - delete removed all functions on library} { + r function delete lib1 + r function list + } {} + + test {LIBRARIES - register function inside a function} { + r function load LUA lib { + redis.register_function( + 'f1', + function(keys, args) + redis.register_function( + 'f2', + function(key, args) + return 2 + end + ) + return 1 + end + ) + } + catch {r fcall f1 0} e + set _ $e + } {*attempt to call field 'register_function' (a nil value)*} + + test {LIBRARIES - register library with no functions} { + r function flush + catch { + r function load LUA lib { + return 1 + } + } e + set _ $e + } {*No functions registered*} + + test {LIBRARIES - load timeout} { + catch { + r function load LUA lib { + local a = 1 + while 1 do a = a + 1 end + } + } e + set _ $e + } {*FUNCTION LOAD timeout*} + + test {LIBRARIES - verify global protection on the load run} { + catch { + r function load LUA lib { + a = 1 + } + } e + set _ $e + } {*attempted to create global variable 'a'*} + + test {FUNCTION - test function restore with function name collision} { + r function flush + r function load lua lib1 { + local function add1(a) + return a + 1 + end + redis.register_function( + 'f1', + function(keys, args) + return add1(1) + end + ) + redis.register_function( + 'f2', + function(keys, args) + return add1(2) + end + ) + redis.register_function( + 'f3', + function(keys, args) + return add1(3) + end + ) + } + set e [r function dump] + r function flush + + # load a library with different name but with the same function name + r function load lua lib1 { + redis.register_function( + 'f6', + function(keys, args) + return 7 + end + ) + } + r function load lua lib2 { + local function add1(a) + return a + 1 + end + redis.register_function( + 'f4', + function(keys, args) + return add1(4) + end + ) + redis.register_function( + 'f5', + function(keys, args) + return add1(5) + end + ) + redis.register_function( + 'f3', + function(keys, args) + return add1(3) + end + ) + } + + catch {r function restore $e} error + assert_match {*Library lib1 already exists*} $error + assert_equal [r fcall f3 0] {4} + assert_equal [r fcall f4 0] {5} + assert_equal [r fcall f5 0] {6} + assert_equal [r fcall f6 0] {7} + + catch {r function restore $e replace} error + assert_match {*Function f3 already exists*} $error + assert_equal [r fcall f3 0] {4} + assert_equal [r fcall f4 0] {5} + assert_equal [r fcall f5 0] {6} + assert_equal [r fcall f6 0] {7} + } + + test {FUNCTION - test function list with code} { + r function flush + r function load lua library1 {redis.register_function('f6', function(keys, args) return 7 end)} + r function list withcode + } {{library_name library1 engine LUA description {} functions {{name f6 description {}}} library_code {redis.register_function('f6', function(keys, args) return 7 end)}}} + + test {FUNCTION - test function list with pattern} { + r function load lua lib1 {redis.register_function('f7', function(keys, args) return 7 end)} + r function list libraryname library* + } {{library_name library1 engine LUA description {} functions {{name f6 description {}}}}} + + test {FUNCTION - test function list wrong argument} { + catch {r function list bad_argument} e + set _ $e + } {*Unknown argument bad_argument*} + + test {FUNCTION - test function list with bad argument to library name} { + catch {r function list libraryname} e + set _ $e + } {*library name argument was not given*} + + test {FUNCTION - test function list withcode multiple times} { + catch {r function list withcode withcode} e + set _ $e + } {*Unknown argument withcode*} + + test {FUNCTION - test function list libraryname multiple times} { + catch {r function list withcode libraryname foo libraryname foo} e + set _ $e + } {*Unknown argument libraryname*} + + test {FUNCTION - verify OOM on function load and function restore} { + r function flush + r function load lua test replace {redis.register_function('f1', function() return 1 end)} + set payload [r function dump] + r config set maxmemory 1 + + r function flush + catch {r function load lua test replace {redis.register_function('f1', function() return 1 end)}} e + assert_match {*command not allowed when used memory*} $e + + r function flush + catch {r function restore $payload} e + assert_match {*command not allowed when used memory*} $e + + r config set maxmemory 0 + } +} diff --git a/tests/unit/scripting.tcl b/tests/unit/scripting.tcl index 970cab992..f342a92b6 100644 --- a/tests/unit/scripting.tcl +++ b/tests/unit/scripting.tcl @@ -15,16 +15,16 @@ if {$is_eval == 1} { } } else { proc run_script {args} { - r function create LUA test replace [lindex $args 0] + r function load LUA test replace [format "redis.register_function('test', function(KEYS, ARGV)\n %s \nend)" [lindex $args 0]] r fcall test {*}[lrange $args 1 end] } proc run_script_ro {args} { - r function create LUA test replace [lindex $args 0] + r function load LUA test replace [format "redis.register_function('test', function(KEYS, ARGV)\n %s \nend)" [lindex $args 0]] r fcall_ro test {*}[lrange $args 1 end] } proc run_script_on_connection {args} { set rd [lindex $args 0] - $rd function create LUA test replace [lindex $args 1] + $rd function load LUA test replace [format "redis.register_function('test', function(KEYS, ARGV)\n %s \nend)" [lindex $args 1]] # read the ok reply of function create $rd read $rd fcall test {*}[lrange $args 2 end] @@ -37,7 +37,7 @@ if {$is_eval == 1} { start_server {tags {"scripting"}} { test {Script - disallow write on OOM} { - r FUNCTION create lua f1 replace { return redis.call('set', 'x', '1') } + r FUNCTION load lua f1 replace { redis.register_function('f1', function() return redis.call('set', 'x', '1') end) } r config set maxmemory 1 @@ -737,7 +737,7 @@ start_server {tags {"scripting"}} { set buf "*3\r\n\$4\r\neval\r\n\$33\r\nwhile 1 do redis.call('ping') end\r\n\$1\r\n0\r\n" append buf "*1\r\n\$4\r\nping\r\n" } else { - set buf "*6\r\n\$8\r\nfunction\r\n\$6\r\ncreate\r\n\$3\r\nlua\r\n\$4\r\ntest\r\n\$7\r\nreplace\r\n\$33\r\nwhile 1 do redis.call('ping') end\r\n" + set buf "*6\r\n\$8\r\nfunction\r\n\$4\r\nload\r\n\$3\r\nlua\r\n\$4\r\ntest\r\n\$7\r\nreplace\r\n\$81\r\nredis.register_function('test', function() while 1 do redis.call('ping') end end)\r\n" append buf "*3\r\n\$5\r\nfcall\r\n\$4\r\ntest\r\n\$1\r\n0\r\n" append buf "*1\r\n\$4\r\nping\r\n" }