diff --git a/src/module.c b/src/module.c index 084ee5fd5..f81b2c042 100644 --- a/src/module.c +++ b/src/module.c @@ -3373,6 +3373,40 @@ int RM_GetClientInfoById(void *ci, uint64_t id) { return modulePopulateClientInfoStructure(ci,client,structver); } +/* Returns the name of the client connection with the given ID. + * + * If the client ID does not exist or if the client has no name associated with + * it, NULL is returned. */ +RedisModuleString *RM_GetClientNameById(RedisModuleCtx *ctx, uint64_t id) { + client *client = lookupClientByID(id); + if (client == NULL || client->name == NULL) return NULL; + robj *name = client->name; + incrRefCount(name); + autoMemoryAdd(ctx, REDISMODULE_AM_STRING, name); + return name; +} + +/* Sets the name of the client with the given ID. This is equivalent to the client calling + * `CLIENT SETNAME name`. + * + * Returns REDISMODULE_OK on success. On failure, REDISMODULE_ERR is returned + * and errno is set as follows: + * + * - ENOENT if the client does not exist + * - EINVAL if the name contains invalid characters */ +int RM_SetClientNameById(uint64_t id, RedisModuleString *name) { + client *client = lookupClientByID(id); + if (client == NULL) { + errno = ENOENT; + return REDISMODULE_ERR; + } + if (clientSetName(client, name) == C_ERR) { + errno = EINVAL; + return REDISMODULE_ERR; + } + return REDISMODULE_OK; +} + /* Publish a message to subscribers (see PUBLISH command). */ int RM_PublishMessage(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) { UNUSED(ctx); @@ -12601,6 +12635,8 @@ void moduleRegisterCoreAPI(void) { REGISTER_API(ServerInfoGetFieldUnsigned); REGISTER_API(ServerInfoGetFieldDouble); REGISTER_API(GetClientInfoById); + REGISTER_API(GetClientNameById); + REGISTER_API(SetClientNameById); REGISTER_API(PublishMessage); REGISTER_API(PublishMessageShard); REGISTER_API(SubscribeToServerEvent); diff --git a/src/networking.c b/src/networking.c index 44aa50642..8cace10d4 100644 --- a/src/networking.c +++ b/src/networking.c @@ -2832,18 +2832,9 @@ sds getAllClientsInfoString(int type) { return o; } -/* This function implements CLIENT SETNAME, including replying to the - * user with an error if the charset is wrong (in that case C_ERR is - * returned). If the function succeeded C_OK is returned, and it's up - * to the caller to send a reply if needed. - * - * Setting an empty string as name has the effect of unsetting the - * currently set name: the client will remain unnamed. - * - * This function is also used to implement the HELLO SETNAME option. */ -int clientSetNameOrReply(client *c, robj *name) { - int len = sdslen(name->ptr); - char *p = name->ptr; +/* Returns C_OK if the name has been set or C_ERR if the name is invalid. */ +int clientSetName(client *c, robj *name) { + int len = (name != NULL) ? sdslen(name->ptr) : 0; /* Setting the client name to an empty string actually removes * the current name. */ @@ -2856,11 +2847,9 @@ int clientSetNameOrReply(client *c, robj *name) { /* Otherwise check if the charset is ok. We need to do this otherwise * CLIENT LIST format will break. You should always be able to * split by space to get the different fields. */ + char *p = name->ptr; for (int j = 0; j < len; j++) { if (p[j] < '!' || p[j] > '~') { /* ASCII is assumed. */ - addReplyError(c, - "Client names cannot contain spaces, " - "newlines or special characters."); return C_ERR; } } @@ -2870,6 +2859,25 @@ int clientSetNameOrReply(client *c, robj *name) { return C_OK; } +/* This function implements CLIENT SETNAME, including replying to the + * user with an error if the charset is wrong (in that case C_ERR is + * returned). If the function succeeded C_OK is returned, and it's up + * to the caller to send a reply if needed. + * + * Setting an empty string as name has the effect of unsetting the + * currently set name: the client will remain unnamed. + * + * This function is also used to implement the HELLO SETNAME option. */ +int clientSetNameOrReply(client *c, robj *name) { + int result = clientSetName(c, name); + if (result == C_ERR) { + addReplyError(c, + "Client names cannot contain spaces, " + "newlines or special characters."); + } + return result; +} + /* Reset the client state to resemble a newly connected client. */ void resetCommand(client *c) { diff --git a/src/redismodule.h b/src/redismodule.h index 899bb519d..e36d5f3c0 100644 --- a/src/redismodule.h +++ b/src/redismodule.h @@ -1000,6 +1000,8 @@ REDISMODULE_API void (*RedisModule_ChannelAtPosWithFlags)(RedisModuleCtx *ctx, i REDISMODULE_API unsigned long long (*RedisModule_GetClientId)(RedisModuleCtx *ctx) REDISMODULE_ATTR; REDISMODULE_API RedisModuleString * (*RedisModule_GetClientUserNameById)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_GetClientInfoById)(void *ci, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleString * (*RedisModule_GetClientNameById)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_SetClientNameById)(uint64_t id, RedisModuleString *name) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_PublishMessage)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_PublishMessageShard)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_GetContextFlags)(RedisModuleCtx *ctx) REDISMODULE_ATTR; @@ -1426,6 +1428,8 @@ static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int REDISMODULE_GET_API(ServerInfoGetFieldUnsigned); REDISMODULE_GET_API(ServerInfoGetFieldDouble); REDISMODULE_GET_API(GetClientInfoById); + REDISMODULE_GET_API(GetClientNameById); + REDISMODULE_GET_API(SetClientNameById); REDISMODULE_GET_API(PublishMessage); REDISMODULE_GET_API(PublishMessageShard); REDISMODULE_GET_API(SubscribeToServerEvent); diff --git a/src/server.h b/src/server.h index abaa5f046..6beae765b 100644 --- a/src/server.h +++ b/src/server.h @@ -2498,6 +2498,7 @@ char *getClientPeerId(client *client); char *getClientSockName(client *client); sds catClientInfoString(sds s, client *client); sds getAllClientsInfoString(int type); +int clientSetName(client *c, robj *name); void rewriteClientCommandVector(client *c, int argc, ...); void rewriteClientCommandArgument(client *c, int i, robj *newval); void replaceClientCommandVector(client *c, int argc, robj **argv); diff --git a/tests/modules/misc.c b/tests/modules/misc.c index da6ee9f9e..dce78ca2b 100644 --- a/tests/modules/misc.c +++ b/tests/modules/misc.c @@ -270,6 +270,27 @@ int test_clientinfo(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_OK; } +int test_getname(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + (void)argv; + if (argc != 1) return RedisModule_WrongArity(ctx); + unsigned long long id = RedisModule_GetClientId(ctx); + RedisModuleString *name = RedisModule_GetClientNameById(ctx, id); + if (name == NULL) + return RedisModule_ReplyWithError(ctx, "-ERR No name"); + RedisModule_ReplyWithString(ctx, name); + RedisModule_FreeString(ctx, name); + return REDISMODULE_OK; +} + +int test_setname(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + if (argc != 2) return RedisModule_WrongArity(ctx); + unsigned long long id = RedisModule_GetClientId(ctx); + if (RedisModule_SetClientNameById(id, argv[1]) == REDISMODULE_OK) + return RedisModule_ReplyWithSimpleString(ctx, "OK"); + else + return RedisModule_ReplyWithError(ctx, strerror(errno)); +} + int test_log_tsctx(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { RedisModuleCtx *tsctx = RedisModule_GetDetachedThreadSafeContext(ctx); @@ -384,6 +405,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; if (RedisModule_CreateCommand(ctx,"test.clientinfo", test_clientinfo,"",0,0,0) == REDISMODULE_ERR) return REDISMODULE_ERR; + if (RedisModule_CreateCommand(ctx,"test.getname", test_getname,"",0,0,0) == REDISMODULE_ERR) + return REDISMODULE_ERR; + if (RedisModule_CreateCommand(ctx,"test.setname", test_setname,"",0,0,0) == REDISMODULE_ERR) + return REDISMODULE_ERR; if (RedisModule_CreateCommand(ctx,"test.redisversion", test_redisversion,"",0,0,0) == REDISMODULE_ERR) return REDISMODULE_ERR; if (RedisModule_CreateCommand(ctx,"test.getclientcert", test_getclientcert,"",0,0,0) == REDISMODULE_ERR) diff --git a/tests/unit/moduleapi/misc.tcl b/tests/unit/moduleapi/misc.tcl index c8a5107f9..3fca1cf7f 100644 --- a/tests/unit/moduleapi/misc.tcl +++ b/tests/unit/moduleapi/misc.tcl @@ -103,6 +103,18 @@ start_server {tags {"modules"}} { assert { [dict get $info flags] == "${ssl_flag}::tracking::" } } + test {test module get/set client name by id api} { + catch { r test.getname } e + assert_equal "-ERR No name" $e + r client setname nobody + catch { r test.setname "name with spaces" } e + assert_match "*Invalid argument*" $e + assert_equal nobody [r client getname] + assert_equal nobody [r test.getname] + r test.setname somebody + assert_equal somebody [r client getname] + } + test {test module getclientcert api} { set cert [r test.getclientcert]