diff --git a/src/eval.c b/src/eval.c index dd488a984..6eb6ed1d4 100644 --- a/src/eval.c +++ b/src/eval.c @@ -385,17 +385,17 @@ uint64_t evalGetCommandFlags(client *c, uint64_t cmd_flags) { int evalsha = c->cmd->proc == evalShaCommand || c->cmd->proc == evalShaRoCommand; if (evalsha && sdslen(c->argv[1]->ptr) != 40) return cmd_flags; + uint64_t script_flags; evalCalcFunctionName(evalsha, c->argv[1]->ptr, funcname); char *lua_cur_script = funcname + 2; - dictEntry *de = dictFind(lctx.lua_scripts, lua_cur_script); - uint64_t script_flags; - if (!de) { + c->cur_script = dictFind(lctx.lua_scripts, lua_cur_script); + if (!c->cur_script) { if (evalsha) return cmd_flags; if (evalExtractShebangFlags(c->argv[1]->ptr, &script_flags, NULL, NULL) == C_ERR) return cmd_flags; } else { - luaScript *l = dictGetVal(de); + luaScript *l = dictGetVal(c->cur_script); script_flags = l->flags; } if (script_flags & SCRIPT_FLAG_EVAL_COMPAT_MODE) @@ -502,7 +502,12 @@ void evalGenericCommand(client *c, int evalsha) { return; } - evalCalcFunctionName(evalsha, c->argv[1]->ptr, funcname); + if (c->cur_script) { + funcname[0] = 'f', funcname[1] = '_'; + memcpy(funcname+2, dictGetKey(c->cur_script), 40); + funcname[42] = '\0'; + } else + evalCalcFunctionName(evalsha, c->argv[1]->ptr, funcname); /* Push the pcall error handler function on the stack. */ lua_getglobal(lua, "__redis__err__handler"); @@ -531,7 +536,9 @@ void evalGenericCommand(client *c, int evalsha) { } char *lua_cur_script = funcname + 2; - dictEntry *de = dictFind(lctx.lua_scripts, lua_cur_script); + dictEntry *de = c->cur_script; + if (!de) + de = dictFind(lctx.lua_scripts, lua_cur_script); luaScript *l = dictGetVal(de); int ro = c->cmd->proc == evalRoCommand || c->cmd->proc == evalShaRoCommand; diff --git a/src/functions.c b/src/functions.c index a91c60ddb..3f7b3a7ef 100644 --- a/src/functions.c +++ b/src/functions.c @@ -608,9 +608,10 @@ void functionKillCommand(client *c) { * Note that it does not guarantee the command arguments are right. */ uint64_t fcallGetCommandFlags(client *c, uint64_t cmd_flags) { robj *function_name = c->argv[1]; - functionInfo *fi = dictFetchValue(curr_functions_lib_ctx->functions, function_name->ptr); - if (!fi) + c->cur_script = dictFind(curr_functions_lib_ctx->functions, function_name->ptr); + if (!c->cur_script) return cmd_flags; + functionInfo *fi = dictGetVal(c->cur_script); uint64_t script_flags = fi->f_flags; return scriptFlagsToCmdFlags(cmd_flags, script_flags); } @@ -620,11 +621,14 @@ static void fcallCommandGeneric(client *c, int ro) { replicationFeedMonitors(c,server.monitors,c->db->id,c->argv,c->argc); robj *function_name = c->argv[1]; - functionInfo *fi = dictFetchValue(curr_functions_lib_ctx->functions, function_name->ptr); - if (!fi) { + dictEntry *de = c->cur_script; + if (!de) + de = dictFind(curr_functions_lib_ctx->functions, function_name->ptr); + if (!de) { addReplyError(c, "Function not found"); return; } + functionInfo *fi = dictGetVal(de); engine *engine = fi->li->ei->engine; long long numkeys; diff --git a/src/networking.c b/src/networking.c index 757102b19..76566ddd7 100644 --- a/src/networking.c +++ b/src/networking.c @@ -158,6 +158,7 @@ client *createClient(connection *conn) { c->original_argc = 0; c->original_argv = NULL; c->cmd = c->lastcmd = c->realcmd = NULL; + c->cur_script = NULL; c->multibulklen = 0; c->bulklen = -1; c->sentlen = 0; @@ -2018,6 +2019,7 @@ void resetClient(client *c) { redisCommandProc *prevcmd = c->cmd ? c->cmd->proc : NULL; freeClientArgv(c); + c->cur_script = NULL; c->reqtype = 0; c->multibulklen = 0; c->bulklen = -1; diff --git a/src/server.h b/src/server.h index ce47d232f..190664ded 100644 --- a/src/server.h +++ b/src/server.h @@ -1136,6 +1136,7 @@ typedef struct client { time_t ctime; /* Client creation time. */ long duration; /* Current command duration. Used for measuring latency of blocking/non-blocking cmds */ int slot; /* The slot the client is executing against. Set to -1 if no slot is being used */ + dictEntry *cur_script; /* Cached pointer to the dictEntry of the script being executed. */ time_t lastinteraction; /* Time of the last interaction, used for timeout */ time_t obuf_soft_limit_reached_time; int authenticated; /* Needed when the default user requires auth. */