From 978e1fdda5648eff82489729866ff0f952625fc0 Mon Sep 17 00:00:00 2001 From: John Sully Date: Wed, 29 Jan 2020 19:03:18 -0500 Subject: [PATCH 1/2] fix leak in cron Former-commit-id: c1f4e344bdaf21bc74fae6e1b0cb7fc1ce687e62 --- src/ae.cpp | 2 ++ src/cron.cpp | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ae.cpp b/src/ae.cpp index a08b23454..d84a5b0d6 100644 --- a/src/ae.cpp +++ b/src/ae.cpp @@ -130,7 +130,9 @@ struct aeCommand void *clientData; aeCommandControl *pctl; }; +#ifdef PIPE_BUF static_assert(sizeof(aeCommand) <= PIPE_BUF); +#endif void aeProcessCmd(aeEventLoop *eventLoop, int fd, void *, int ) { diff --git a/src/cron.cpp b/src/cron.cpp index 2c20a5c0e..11d1f6d8c 100644 --- a/src/cron.cpp +++ b/src/cron.cpp @@ -67,6 +67,7 @@ void cronCommand(client *c) robj *o = createObject(OBJ_CRON, spjob.release()); setKey(c->db, c->argv[ARG_NAME], o); + decrRefCount(o); // use an expire to trigger execution. Note: We use a subkey expire here so legacy clients don't delete it. setExpire(c, c->db, c->argv[ARG_NAME], c->argv[ARG_NAME], base + interval); addReply(c, shared.ok); @@ -102,7 +103,7 @@ void executeCronJobExpireHook(const char *key, robj *o) int dbId = job->dbNum; if (job->fSingleShot) { - dbSyncDelete(cFake->db, keyobj); + serverAssert(dbSyncDelete(cFake->db, keyobj)); } else { From f5b08185b15fe376bc0e8cbaaa54256b318646e9 Mon Sep 17 00:00:00 2001 From: John Sully Date: Wed, 29 Jan 2020 21:21:47 -0500 Subject: [PATCH 2/2] TLS Thread Safety fixes Former-commit-id: e98a5fc108c5448307a8cc38182c79263f01102a --- src/connection.cpp | 19 +++++++- src/connhelpers.h | 11 +---- src/networking.cpp | 1 + src/server.cpp | 44 ++++------------- src/server.h | 1 + src/tls.cpp | 114 +++++++++++++++++++++++++++++++++++++++------ 6 files changed, 130 insertions(+), 60 deletions(-) diff --git a/src/connection.cpp b/src/connection.cpp index 72e3f9f4b..09ecabf21 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -109,7 +109,7 @@ static int connSocketConnect(connection *conn, const char *addr, int port, const conn->state = CONN_STATE_CONNECTING; conn->conn_handler = connect_handler; - aeCreateFileEvent(serverTL->el, conn->fd, AE_WRITABLE, + aeCreateFileEvent(serverTL->el, conn->fd, AE_WRITABLE|AE_WRITE_THREADSAFE, conn->type->ae_handler, conn); return C_OK; @@ -244,7 +244,7 @@ static const char *connSocketGetLastError(connection *conn) { return strerror(conn->last_errno); } -static void connSocketEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) +void connSocketEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) { UNUSED(el); UNUSED(fd); @@ -262,7 +262,11 @@ static void connSocketEventHandler(struct aeEventLoop *el, int fd, void *clientD if (!conn->write_handler) aeDeleteFileEvent(serverTL->el,conn->fd,AE_WRITABLE); + { + AeLocker locker; + locker.arm(nullptr); if (!callHandler(conn, conn->conn_handler)) return; + } conn->conn_handler = NULL; } @@ -438,3 +442,14 @@ const char *connGetInfo(connection *conn, char *buf, size_t buf_len) { return buf; } + +int callHandler(connection *conn, ConnectionCallbackFunc handler) { + conn->flags |= CONN_FLAG_IN_HANDLER; + if (handler) handler(conn); + conn->flags &= ~CONN_FLAG_IN_HANDLER; + if (conn->flags & CONN_FLAG_CLOSE_SCHEDULED) { + connClose(conn); + return 0; + } + return 1; +} \ No newline at end of file diff --git a/src/connhelpers.h b/src/connhelpers.h index f237c9b1d..1f3e2e914 100644 --- a/src/connhelpers.h +++ b/src/connhelpers.h @@ -71,15 +71,6 @@ static inline int exitHandler(connection *conn) { * 3. Mark the handler as NOT in use and perform deferred close if was * requested by the handler at any time. */ -static inline int callHandler(connection *conn, ConnectionCallbackFunc handler) { - conn->flags |= CONN_FLAG_IN_HANDLER; - if (handler) handler(conn); - conn->flags &= ~CONN_FLAG_IN_HANDLER; - if (conn->flags & CONN_FLAG_CLOSE_SCHEDULED) { - connClose(conn); - return 0; - } - return 1; -} +int callHandler(connection *conn, ConnectionCallbackFunc handler); #endif /* __REDIS_CONNHELPERS_H */ diff --git a/src/networking.cpp b/src/networking.cpp index 8624c4e2c..3b9befa6c 100644 --- a/src/networking.cpp +++ b/src/networking.cpp @@ -2215,6 +2215,7 @@ void readQueryFromClient(connection *conn) { size_t qblen; serverAssert(FCorrectThread(c)); + serverAssert(!GlobalLocksAcquired()); AeLocker aelock; AssertCorrectThread(c); diff --git a/src/server.cpp b/src/server.cpp index a2509c654..af69e1ffd 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -2152,12 +2152,15 @@ int serverCronLite(struct aeEventLoop *eventLoop, long long id, void *clientData * for ready file descriptors. */ void beforeSleep(struct aeEventLoop *eventLoop) { UNUSED(eventLoop); + int iel = ielFromEventLoop(eventLoop); /* Handle TLS pending data. (must be done before flushAppendOnlyFile) */ tlsProcessPendingData(); /* If tls still has pending unread data don't sleep at all. */ aeSetDontWait(eventLoop, tlsHasPendingData()); + aeAcquireLock(); + /* Call the Redis Cluster before sleep function. Note that this function * may change the state of Redis Cluster (from ok to fail or vice versa), * so it's a good idea to call it before serving the unblocked clients @@ -2194,9 +2197,9 @@ void beforeSleep(struct aeEventLoop *eventLoop) { if (moduleCount()) moduleHandleBlockedClients(ielFromEventLoop(eventLoop)); /* Try to process pending commands for clients that were just unblocked. */ - if (listLength(g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].unblocked_clients)) + if (listLength(g_pserver->rgthreadvar[iel].unblocked_clients)) { - processUnblockedClients(IDX_EVENT_LOOP_MAIN); + processUnblockedClients(iel); } /* Write the AOF buffer on disk */ @@ -2204,47 +2207,20 @@ void beforeSleep(struct aeEventLoop *eventLoop) { /* Handle writes with pending output buffers. */ aeReleaseLock(); - handleClientsWithPendingWrites(IDX_EVENT_LOOP_MAIN); - aeAcquireLock(); - - /* Close clients that need to be closed asynchronous */ - freeClientsInAsyncFreeQueue(IDX_EVENT_LOOP_MAIN); - - /* Before we are going to sleep, let the threads access the dataset by - * releasing the GIL. Redis main thread will not touch anything at this - * time. */ - if (moduleCount()) moduleReleaseGIL(TRUE /*fServerThread*/); -} - -void beforeSleepLite(struct aeEventLoop *eventLoop) -{ - int iel = ielFromEventLoop(eventLoop); - - /* Try to process pending commands for clients that were just unblocked. */ - aeAcquireLock(); - if (listLength(g_pserver->rgthreadvar[iel].unblocked_clients)) { - processUnblockedClients(iel); - } - - /* Check if there are clients unblocked by modules that implement - * blocking commands. */ - if (moduleCount()) moduleHandleBlockedClients(ielFromEventLoop(eventLoop)); - aeReleaseLock(); - - /* Handle writes with pending output buffers. */ handleClientsWithPendingWrites(iel); - aeAcquireLock(); + /* Close clients that need to be closed asynchronous */ freeClientsInAsyncFreeQueue(iel); - aeReleaseLock(); /* Before we are going to sleep, let the threads access the dataset by * releasing the GIL. Redis main thread will not touch anything at this * time. */ if (moduleCount()) moduleReleaseGIL(TRUE /*fServerThread*/); + aeReleaseLock(); } + /* This function is called immadiately after the event loop multiplexing * API returned, and the control is going to soon return to Redis by invoking * the different events callbacks. */ @@ -5149,11 +5125,11 @@ void *workerThreadMain(void *parg) int iel = (int)((int64_t)parg); serverLog(LOG_INFO, "Thread %d alive.", iel); serverTL = g_pserver->rgthreadvar+iel; // set the TLS threadsafe global + tlsInitThread(); moduleAcquireGIL(true); // Normally afterSleep acquires this, but that won't be called on the first run - int isMainThread = (iel == IDX_EVENT_LOOP_MAIN); aeEventLoop *el = g_pserver->rgthreadvar[iel].el; - aeSetBeforeSleepProc(el, isMainThread ? beforeSleep : beforeSleepLite, isMainThread ? 0 : AE_SLEEP_THREADSAFE); + aeSetBeforeSleepProc(el, beforeSleep, AE_SLEEP_THREADSAFE); aeSetAfterSleepProc(el, afterSleep, AE_SLEEP_THREADSAFE); aeMain(el); aeDeleteEventLoop(el); diff --git a/src/server.h b/src/server.h index aa0cb4cf0..aef9a4df9 100644 --- a/src/server.h +++ b/src/server.h @@ -2982,6 +2982,7 @@ inline int FCorrectThread(client *c) /* TLS stuff */ void tlsInit(void); +void tlsInitThread(); int tlsConfigure(redisTLSContextConfig *ctx_config); #define redisDebug(fmt, ...) \ diff --git a/src/tls.cpp b/src/tls.cpp index c429e8d69..3ec256c7f 100644 --- a/src/tls.cpp +++ b/src/tls.cpp @@ -31,6 +31,8 @@ #include "server.h" #include "connhelpers.h" #include "adlist.h" +#include "aelocker.h" +#include #ifdef USE_OPENSSL @@ -53,6 +55,7 @@ extern ConnectionType CT_Socket; SSL_CTX *redis_tls_ctx; +fastlock g_ctxtlock("SSL CTX"); static int parseProtocolsConfig(const char *str) { int i, count = 0; @@ -91,7 +94,7 @@ static int parseProtocolsConfig(const char *str) { /* list of connections with pending data already read from the socket, but not * served to the reader yet. */ -static list *pending_list = NULL; +static thread_local list *pending_list = NULL; void tlsInit(void) { ERR_load_crypto_strings(); @@ -102,10 +105,15 @@ void tlsInit(void) { serverLog(LL_WARNING, "OpenSSL: Failed to seed random number generator."); } - pending_list = listCreate(); - /* Server configuration */ g_pserver->tls_auth_clients = 1; /* Secure by default */ + tlsInitThread(); +} + +void tlsInitThread(void) +{ + serverAssert(pending_list == nullptr); + pending_list = listCreate(); } /* Attempt to configure/reconfigure TLS. This operation is atomic and will @@ -226,8 +234,11 @@ int tlsConfigure(redisTLSContextConfig *ctx_config) { } #endif + { + std::unique_lock ul(g_ctxtlock); SSL_CTX_free(redis_tls_ctx); redis_tls_ctx = ctx; + } return C_OK; @@ -276,13 +287,16 @@ typedef struct tls_connection { SSL *ssl; char *ssl_error; listNode *pending_list_node; + aeEventLoop *el; } tls_connection; connection *connCreateTLS(void) { tls_connection *conn = (tls_connection*)zcalloc(sizeof(tls_connection), MALLOC_LOCAL); conn->c.type = &CT_TLS; conn->c.fd = -1; + std::unique_lock ul(g_ctxtlock); conn->ssl = SSL_new(redis_tls_ctx); + conn->el = serverTL->el; return (connection *) conn; } @@ -343,15 +357,17 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType * void registerSSLEvent(tls_connection *conn, WantIOType want) { int mask = aeGetFileEvents(serverTL->el, conn->c.fd); + serverAssert(conn->el == serverTL->el); + switch (want) { case WANT_READ: - if (mask & AE_WRITABLE) aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE); - if (!(mask & AE_READABLE)) aeCreateFileEvent(serverTL->el, conn->c.fd, AE_READABLE, + if (mask & AE_WRITABLE) aeDeleteFileEvent(conn->el, conn->c.fd, AE_WRITABLE); + if (!(mask & AE_READABLE)) aeCreateFileEvent(conn->el, conn->c.fd, AE_READABLE|AE_READ_THREADSAFE, tlsEventHandler, conn); break; case WANT_WRITE: - if (mask & AE_READABLE) aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_READABLE); - if (!(mask & AE_WRITABLE)) aeCreateFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE, + if (mask & AE_READABLE) aeDeleteFileEvent(conn->el, conn->c.fd, AE_READABLE); + if (!(mask & AE_WRITABLE)) aeCreateFileEvent(conn->el, conn->c.fd, AE_WRITABLE|AE_WRITE_THREADSAFE, tlsEventHandler, conn); break; default: @@ -360,24 +376,38 @@ void registerSSLEvent(tls_connection *conn, WantIOType want) { } } -void updateSSLEvent(tls_connection *conn) { +void updateSSLEventCore(tls_connection *conn) { int mask = aeGetFileEvents(serverTL->el, conn->c.fd); int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ); int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE); + serverAssert(conn->el == serverTL->el); + if (need_read && !(mask & AE_READABLE)) - aeCreateFileEvent(serverTL->el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); + aeCreateFileEvent(serverTL->el, conn->c.fd, AE_READABLE|AE_READ_THREADSAFE, tlsEventHandler, conn); if (!need_read && (mask & AE_READABLE)) aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_READABLE); if (need_write && !(mask & AE_WRITABLE)) - aeCreateFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn); + aeCreateFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE|AE_WRITE_THREADSAFE, tlsEventHandler, conn); if (!need_write && (mask & AE_WRITABLE)) aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE); } -static void tlsHandleEvent(tls_connection *conn, int mask) { +void updateSSLEvent(tls_connection *conn) { + if (conn->el != serverTL->el) { + aePostFunction(conn->el, [conn]{ + updateSSLEventCore(conn); + }); + } else { + updateSSLEventCore(conn); + } +} + +void tlsHandleEvent(tls_connection *conn, int mask) { int ret; + serverAssert(!GlobalLocksAcquired()); + serverAssert(conn->el == serverTL->el); TLSCONN_DEBUG("tlsEventHandler(): fd=%d, state=%d, mask=%d, r=%d, w=%d, flags=%d", fd, conn->c.state, mask, conn->c.read_handler != NULL, conn->c.write_handler != NULL, @@ -415,7 +445,11 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { } } + { + AeLocker locker; + locker.arm(nullptr); if (!callHandler((connection *) conn, conn->c.conn_handler)) return; + } conn->c.conn_handler = NULL; break; case CONN_STATE_ACCEPTING: @@ -437,7 +471,11 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { conn->c.state = CONN_STATE_CONNECTED; } + { + AeLocker locker; + locker.arm(nullptr); if (!callHandler((connection *) conn, conn->c.conn_handler)) return; + } conn->c.conn_handler = NULL; break; case CONN_STATE_CONNECTED: @@ -461,12 +499,18 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER; if (!invert && call_read) { + AeLocker lock; + if (!(conn->c.flags & CONN_FLAG_READ_THREADSAFE)) + lock.arm(nullptr); conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *) conn, conn->c.read_handler)) return; } /* Fire the writable event. */ if (call_write) { + AeLocker lock; + if (!(conn->c.flags & CONN_FLAG_WRITE_THREADSAFE)) + lock.arm(nullptr); conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ; if (!callHandler((connection *) conn, conn->c.write_handler)) return; } @@ -474,6 +518,9 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { /* If we have to invert the call, fire the readable event now * after the writable one. */ if (invert && call_read) { + AeLocker lock; + if (!(conn->c.flags & CONN_FLAG_READ_THREADSAFE)) + lock.arm(nullptr); conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *) conn, conn->c.read_handler)) return; } @@ -509,8 +556,8 @@ static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, in tlsHandleEvent(conn, mask); } -static void connTLSClose(connection *conn_) { - tls_connection *conn = (tls_connection *) conn_; +static void connTLSCloseCore(tls_connection *conn) { + serverAssert(conn->el == serverTL->el); if (conn->ssl) { SSL_free(conn->ssl); @@ -527,13 +574,26 @@ static void connTLSClose(connection *conn_) { conn->pending_list_node = NULL; } - CT_Socket.close(conn_); + CT_Socket.close(&conn->c); +} + +static void connTLSClose(connection *conn_) { + tls_connection *conn = (tls_connection *) conn_; + if (conn->el != serverTL->el) { + aePostFunction(conn->el, [conn]{ + connTLSCloseCore(conn); + }); + } else { + connTLSCloseCore(conn); + } } static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) { tls_connection *conn = (tls_connection *) _conn; int ret; + serverAssert(conn->el == serverTL->el); + if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR; ERR_clear_error(); @@ -562,6 +622,8 @@ static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handle static int connTLSConnect(connection *conn_, const char *addr, int port, const char *src_addr, ConnectionCallbackFunc connect_handler) { tls_connection *conn = (tls_connection *) conn_; + serverAssert(conn->el == serverTL->el); + if (conn->c.state != CONN_STATE_NONE) return C_ERR; ERR_clear_error(); @@ -609,6 +671,8 @@ static int connTLSRead(connection *conn_, void *buf, size_t buf_len) { int ret; int ssl_err; + serverAssert(conn->el == serverTL->el); + if (conn->c.state != CONN_STATE_CONNECTED) return -1; ERR_clear_error(); ret = SSL_read(conn->ssl, buf, buf_len); @@ -643,17 +707,31 @@ static const char *connTLSGetLastError(connection *conn_) { } int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier, bool fThreadSafe) { + serverAssert(((tls_connection*)conn)->el == serverTL->el); conn->write_handler = func; if (barrier) conn->flags |= CONN_FLAG_WRITE_BARRIER; else conn->flags &= ~CONN_FLAG_WRITE_BARRIER; + + if (fThreadSafe) + conn->flags |= CONN_FLAG_WRITE_THREADSAFE; + else + conn->flags &= ~CONN_FLAG_WRITE_THREADSAFE; + updateSSLEvent((tls_connection *) conn); return C_OK; } int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func, bool fThreadSafe) { + serverAssert(((tls_connection*)conn)->el == serverTL->el); conn->read_handler = func; + + if (fThreadSafe) + conn->flags |= CONN_FLAG_READ_THREADSAFE; + else + conn->flags &= ~CONN_FLAG_READ_THREADSAFE; + updateSSLEvent((tls_connection *) conn); return C_OK; } @@ -674,6 +752,8 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port, tls_connection *conn = (tls_connection *) conn_; int ret; + serverAssert(conn->el == serverTL->el); + if (conn->c.state != CONN_STATE_NONE) return C_ERR; /* Initiate socket blocking connect first */ @@ -697,6 +777,8 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port, static ssize_t connTLSSyncWrite(connection *conn_, const char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *) conn_; + serverAssert(conn->el == serverTL->el); + setBlockingTimeout(conn, timeout); SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE); int ret = SSL_write(conn->ssl, ptr, size); @@ -709,6 +791,8 @@ static ssize_t connTLSSyncWrite(connection *conn_, const char *ptr, ssize_t size static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *) conn_; + serverAssert(conn->el == serverTL->el); + setBlockingTimeout(conn, timeout); int ret = SSL_read(conn->ssl, ptr, size); unsetBlockingTimeout(conn); @@ -720,6 +804,8 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l tls_connection *conn = (tls_connection *) conn_; ssize_t nread = 0; + serverAssert(conn->el == serverTL->el); + setBlockingTimeout(conn, timeout); size--;