diff --git a/src/connection.c b/src/connection.c index 09fa12f2a..23b44a314 100644 --- a/src/connection.c +++ b/src/connection.c @@ -329,6 +329,11 @@ static ssize_t connSocketSyncReadLine(connection *conn, char *ptr, ssize_t size, return syncReadLine(conn->fd, ptr, size, timeout); } +static int connSocketGetType(connection *conn) { + (void) conn; + + return CONN_TYPE_SOCKET; +} ConnectionType CT_Socket = { .ae_handler = connSocketEventHandler, @@ -343,7 +348,8 @@ ConnectionType CT_Socket = { .blocking_connect = connSocketBlockingConnect, .sync_write = connSocketSyncWrite, .sync_read = connSocketSyncRead, - .sync_readline = connSocketSyncReadLine + .sync_readline = connSocketSyncReadLine, + .get_type = connSocketGetType }; diff --git a/src/connection.h b/src/connection.h index 0fd6c5f24..85585a3d0 100644 --- a/src/connection.h +++ b/src/connection.h @@ -48,6 +48,9 @@ typedef enum { #define CONN_FLAG_CLOSE_SCHEDULED (1<<0) /* Closed scheduled by a handler */ #define CONN_FLAG_WRITE_BARRIER (1<<1) /* Write barrier requested */ +#define CONN_TYPE_SOCKET 1 +#define CONN_TYPE_TLS 2 + typedef void (*ConnectionCallbackFunc)(struct connection *conn); typedef struct ConnectionType { @@ -64,6 +67,7 @@ typedef struct ConnectionType { ssize_t (*sync_write)(struct connection *conn, char *ptr, ssize_t size, long long timeout); ssize_t (*sync_read)(struct connection *conn, char *ptr, ssize_t size, long long timeout); ssize_t (*sync_readline)(struct connection *conn, char *ptr, ssize_t size, long long timeout); + int (*get_type)(struct connection *conn); } ConnectionType; struct connection { @@ -194,6 +198,11 @@ static inline ssize_t connSyncReadLine(connection *conn, char *ptr, ssize_t size return conn->type->sync_readline(conn, ptr, size, timeout); } +/* Return CONN_TYPE_* for the specified connection */ +static inline int connGetType(connection *conn) { + return conn->type->get_type(conn); +} + connection *connCreateSocket(); connection *connCreateAcceptedSocket(int fd); diff --git a/src/module.c b/src/module.c index 0714458f7..94a7344e1 100644 --- a/src/module.c +++ b/src/module.c @@ -1753,6 +1753,8 @@ int modulePopulateClientInfoStructure(void *ci, client *client, int structver) { ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_TRACKING; if (client->flags & CLIENT_BLOCKED) ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_BLOCKED; + if (connGetType(client->conn) == CONN_TYPE_TLS) + ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_SSL; int port; connPeerToString(client->conn,ci1->addr,sizeof(ci1->addr),&port); diff --git a/src/tls.c b/src/tls.c index 4f0ea4d65..52887cd23 100644 --- a/src/tls.c +++ b/src/tls.c @@ -823,6 +823,12 @@ exit: return nread; } +static int connTLSGetType(connection *conn_) { + (void) conn_; + + return CONN_TYPE_TLS; +} + ConnectionType CT_TLS = { .ae_handler = tlsEventHandler, .accept = connTLSAccept, @@ -837,6 +843,7 @@ ConnectionType CT_TLS = { .sync_write = connTLSSyncWrite, .sync_read = connTLSSyncRead, .sync_readline = connTLSSyncReadLine, + .get_type = connTLSGetType }; int tlsHasPendingData() { diff --git a/tests/modules/misc.c b/tests/modules/misc.c index 1048d5065..1f9cb1932 100644 --- a/tests/modules/misc.c +++ b/tests/modules/misc.c @@ -195,6 +195,42 @@ int test_setlfu(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_OK; } +int test_clientinfo(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) +{ + (void) argv; + (void) argc; + + RedisModuleClientInfo ci = { .version = REDISMODULE_CLIENTINFO_VERSION }; + + if (RedisModule_GetClientInfoById(&ci, RedisModule_GetClientId(ctx)) == REDISMODULE_ERR) { + RedisModule_ReplyWithError(ctx, "failed to get client info"); + return REDISMODULE_OK; + } + + RedisModule_ReplyWithArray(ctx, 10); + char flags[512]; + snprintf(flags, sizeof(flags) - 1, "%s:%s:%s:%s:%s:%s", + ci.flags & REDISMODULE_CLIENTINFO_FLAG_SSL ? "ssl" : "", + ci.flags & REDISMODULE_CLIENTINFO_FLAG_PUBSUB ? "pubsub" : "", + ci.flags & REDISMODULE_CLIENTINFO_FLAG_BLOCKED ? "blocked" : "", + ci.flags & REDISMODULE_CLIENTINFO_FLAG_TRACKING ? "tracking" : "", + ci.flags & REDISMODULE_CLIENTINFO_FLAG_UNIXSOCKET ? "unixsocket" : "", + ci.flags & REDISMODULE_CLIENTINFO_FLAG_MULTI ? "multi" : ""); + + RedisModule_ReplyWithCString(ctx, "flags"); + RedisModule_ReplyWithCString(ctx, flags); + RedisModule_ReplyWithCString(ctx, "id"); + RedisModule_ReplyWithLongLong(ctx, ci.id); + RedisModule_ReplyWithCString(ctx, "addr"); + RedisModule_ReplyWithCString(ctx, ci.addr); + RedisModule_ReplyWithCString(ctx, "port"); + RedisModule_ReplyWithLongLong(ctx, ci.port); + RedisModule_ReplyWithCString(ctx, "db"); + RedisModule_ReplyWithLongLong(ctx, ci.db); + + return REDISMODULE_OK; +} + int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { REDISMODULE_NOT_USED(argv); REDISMODULE_NOT_USED(argc); @@ -221,6 +257,8 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; if (RedisModule_CreateCommand(ctx,"test.getlfu", test_getlfu,"",0,0,0) == REDISMODULE_ERR) return REDISMODULE_ERR; + if (RedisModule_CreateCommand(ctx,"test.clientinfo", test_clientinfo,"",0,0,0) == REDISMODULE_ERR) + return REDISMODULE_ERR; return REDISMODULE_OK; } diff --git a/tests/unit/moduleapi/misc.tcl b/tests/unit/moduleapi/misc.tcl index 748016f1a..b57a94f6a 100644 --- a/tests/unit/moduleapi/misc.tcl +++ b/tests/unit/moduleapi/misc.tcl @@ -67,4 +67,23 @@ start_server {tags {"modules"}} { assert { $was_set == 0 } } + test {test module clientinfo api} { + # Test basic sanity and SSL flag + set info [r test.clientinfo] + set ssl_flag [expr $::tls ? {"ssl:"} : {":"}] + + assert { [dict get $info db] == 9 } + assert { [dict get $info flags] == "${ssl_flag}::::" } + + # Test MULTI flag + r multi + r test.clientinfo + set info [lindex [r exec] 0] + assert { [dict get $info flags] == "${ssl_flag}::::multi" } + + # Test TRACKING flag + r client tracking on + set info [r test.clientinfo] + assert { [dict get $info flags] == "${ssl_flag}::tracking::" } + } }