diff --git a/src/Makefile b/src/Makefile index da848d474..0eecd5500 100644 --- a/src/Makefile +++ b/src/Makefile @@ -236,7 +236,8 @@ ifeq ($(MALLOC),memkind) endif ifeq ($(BUILD_TLS),yes) - FINAL_CFLAGS+=-DUSE_OPENSSL $(OPENSSL_CFLAGS) + FINAL_CFLAGS+=-DUSE_OPENSSL $(OPENSSL_CXXFLAGS) + FINAL_CXXFLAGS+=-DUSE_OPENSSL $(OPENSSL_CXXFLAGS) FINAL_LDFLAGS+=$(OPENSSL_LDFLAGS) FINAL_LIBS += ../deps/hiredis/libhiredis_ssl.a -lssl -lcrypto endif diff --git a/src/connection.h b/src/connection.h index f5624b596..515229d6a 100644 --- a/src/connection.h +++ b/src/connection.h @@ -68,6 +68,7 @@ typedef struct ConnectionType { ssize_t (*sync_write)(struct connection *conn, const char *ptr, ssize_t size, long long timeout); ssize_t (*sync_read)(struct connection *conn, char *ptr, ssize_t size, long long timeout); ssize_t (*sync_readline)(struct connection *conn, char *ptr, ssize_t size, long long timeout); + void (*marshal_thread)(struct connection *conn); } ConnectionType; struct connection { @@ -198,6 +199,11 @@ static inline ssize_t connSyncReadLine(connection *conn, char *ptr, ssize_t size return conn->type->sync_readline(conn, ptr, size, timeout); } +static inline void connMarshalThread(connection *conn) { + if (conn->type->marshal_thread != nullptr) + conn->type->marshal_thread(conn); +} + connection *connCreateSocket(); connection *connCreateAcceptedSocket(int fd); diff --git a/src/networking.cpp b/src/networking.cpp index 58ca4f86b..8b1e889f6 100644 --- a/src/networking.cpp +++ b/src/networking.cpp @@ -1313,7 +1313,8 @@ void acceptOnThread(connection *conn, int flags, char *cip) szT = (char*)zmalloc(NET_IP_STR_LEN, MALLOC_LOCAL); memcpy(szT, cip, NET_IP_STR_LEN); } - int res = aePostFunction(g_pserver->rgthreadvar[ielTarget].el, [conn, flags, ielTarget, szT]{ + int res = aePostFunction(g_pserver->rgthreadvar[ielTarget].el, [conn, flags, ielTarget, szT] { + connMarshalThread(conn); acceptCommonHandler(conn,flags,szT,ielTarget); if (!g_fTestMode && !g_pserver->loading) rgacceptsInFlight[ielTarget].fetch_sub(1, std::memory_order_relaxed); diff --git a/src/sentinel.cpp b/src/sentinel.cpp index 16d408ad4..2d116982e 100644 --- a/src/sentinel.cpp +++ b/src/sentinel.cpp @@ -32,7 +32,9 @@ #include "hiredis.h" #ifdef USE_OPENSSL #include "openssl/ssl.h" +extern "C" { #include "hiredis_ssl.h" +} #endif #include "async.h" diff --git a/src/server.cpp b/src/server.cpp index 8566dd7f0..654044699 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -2940,7 +2940,7 @@ static void initNetworking(int fReusePort) } /* Abort if there are no listening sockets at all. */ - if (g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].ipfd_count == 0 && g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].tlsfd_count && g_pserver->sofd < 0) { + if (g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].ipfd_count == 0 && g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].tlsfd_count == 0 && g_pserver->sofd < 0) { serverLog(LL_WARNING, "Configured to not listen anywhere, exiting."); exit(1); } @@ -5305,6 +5305,7 @@ void *workerThreadMain(void *parg) int iel = (int)((int64_t)parg); serverLog(LOG_INFO, "Thread %d alive.", iel); serverTL = g_pserver->rgthreadvar+iel; // set the TLS threadsafe global + tlsInitThread(); if (iel != IDX_EVENT_LOOP_MAIN) { diff --git a/src/server.h b/src/server.h index 628e1d3b9..f42968994 100644 --- a/src/server.h +++ b/src/server.h @@ -3061,6 +3061,7 @@ inline int FCorrectThread(client *c) /* TLS stuff */ void tlsInit(void); +void tlsInitThread(); int tlsConfigure(redisTLSContextConfig *ctx_config); diff --git a/src/tls.cpp b/src/tls.cpp index 28a74df9a..d297695cc 100644 --- a/src/tls.cpp +++ b/src/tls.cpp @@ -31,6 +31,8 @@ #include "server.h" #include "connhelpers.h" #include "adlist.h" +#include "aelocker.h" +#include #ifdef USE_OPENSSL @@ -53,6 +55,7 @@ extern ConnectionType CT_Socket; SSL_CTX *redis_tls_ctx; +fastlock g_ctxtlock("SSL CTX"); static int parseProtocolsConfig(const char *str) { int i, count = 0; @@ -91,7 +94,7 @@ static int parseProtocolsConfig(const char *str) { /* list of connections with pending data already read from the socket, but not * served to the reader yet. */ -static list *pending_list = NULL; +static thread_local list *pending_list = NULL; /** * OpenSSL global initialization and locking handling callbacks. @@ -147,10 +150,15 @@ void tlsInit(void) { serverLog(LL_WARNING, "OpenSSL: Failed to seed random number generator."); } - pending_list = listCreate(); - /* Server configuration */ - server.tls_auth_clients = 1; /* Secure by default */ + g_pserver->tls_auth_clients = 1; /* Secure by default */ + tlsInitThread(); +} + +void tlsInitThread(void) +{ + serverAssert(pending_list == nullptr); + pending_list = listCreate(); } /* Attempt to configure/reconfigure TLS. This operation is atomic and will @@ -159,6 +167,7 @@ void tlsInit(void) { int tlsConfigure(redisTLSContextConfig *ctx_config) { char errbuf[256]; SSL_CTX *ctx = NULL; + int protocols; if (!ctx_config->cert_file) { serverLog(LL_WARNING, "No tls-cert-file configured!"); @@ -184,7 +193,7 @@ int tlsConfigure(redisTLSContextConfig *ctx_config) { SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); #endif - int protocols = parseProtocolsConfig(ctx_config->protocols); + protocols = parseProtocolsConfig(ctx_config->protocols); if (protocols == -1) goto error; if (!(protocols & REDIS_TLS_PROTO_TLSv1)) @@ -272,8 +281,11 @@ int tlsConfigure(redisTLSContextConfig *ctx_config) { } #endif + { + std::unique_lock ul(g_ctxtlock); SSL_CTX_free(redis_tls_ctx); redis_tls_ctx = ctx; + } return C_OK; @@ -289,7 +301,7 @@ error: #define TLSCONN_DEBUG(fmt, ...) #endif -ConnectionType CT_TLS; +extern ConnectionType CT_TLS; /* Normal socket connections have a simple events/handler correlation. * @@ -307,6 +319,7 @@ ConnectionType CT_TLS; */ typedef enum { + WANT_INVALID = 0, WANT_READ = 1, WANT_WRITE } WantIOType; @@ -321,16 +334,24 @@ typedef struct tls_connection { SSL *ssl; char *ssl_error; listNode *pending_list_node; + aeEventLoop *el; } tls_connection; connection *connCreateTLS(void) { - tls_connection *conn = zcalloc(sizeof(tls_connection)); + tls_connection *conn = (tls_connection*)zcalloc(sizeof(tls_connection), MALLOC_LOCAL); conn->c.type = &CT_TLS; conn->c.fd = -1; + std::unique_lock ul(g_ctxtlock); conn->ssl = SSL_new(redis_tls_ctx); + conn->el = serverTL->el; return (connection *) conn; } +void connTLSMarshalThread(connection *c) { + tls_connection *conn = (tls_connection*)c; + conn->el = serverTL->el; +} + connection *connCreateAcceptedTLS(int fd, int require_auth) { tls_connection *conn = (tls_connection *) connCreateTLS(); conn->c.fd = fd; @@ -374,7 +395,7 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType * /* Error! */ conn->c.last_errno = 0; if (conn->ssl_error) zfree(conn->ssl_error); - conn->ssl_error = zmalloc(512); + conn->ssl_error = (char*)zmalloc(512); ERR_error_string_n(ERR_get_error(), conn->ssl_error, 512); break; } @@ -386,17 +407,19 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType * } void registerSSLEvent(tls_connection *conn, WantIOType want) { - int mask = aeGetFileEvents(server.el, conn->c.fd); + int mask = aeGetFileEvents(serverTL->el, conn->c.fd); + + serverAssert(conn->el == serverTL->el); switch (want) { case WANT_READ: - if (mask & AE_WRITABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); - if (!(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, + if (mask & AE_WRITABLE) aeDeleteFileEvent(conn->el, conn->c.fd, AE_WRITABLE); + if (!(mask & AE_READABLE)) aeCreateFileEvent(conn->el, conn->c.fd, AE_READABLE|AE_READ_THREADSAFE, tlsEventHandler, conn); break; case WANT_WRITE: - if (mask & AE_READABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); - if (!(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, + if (mask & AE_READABLE) aeDeleteFileEvent(conn->el, conn->c.fd, AE_READABLE); + if (!(mask & AE_WRITABLE)) aeCreateFileEvent(conn->el, conn->c.fd, AE_WRITABLE|AE_WRITE_THREADSAFE, tlsEventHandler, conn); break; default: @@ -405,24 +428,38 @@ void registerSSLEvent(tls_connection *conn, WantIOType want) { } } -void updateSSLEvent(tls_connection *conn) { - int mask = aeGetFileEvents(server.el, conn->c.fd); +void updateSSLEventCore(tls_connection *conn) { + int mask = aeGetFileEvents(serverTL->el, conn->c.fd); int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ); int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE); + serverAssert(conn->el == serverTL->el); + if (need_read && !(mask & AE_READABLE)) - aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); + aeCreateFileEvent(serverTL->el, conn->c.fd, AE_READABLE|AE_READ_THREADSAFE, tlsEventHandler, conn); if (!need_read && (mask & AE_READABLE)) - aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); + aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_READABLE); if (need_write && !(mask & AE_WRITABLE)) - aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn); + aeCreateFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE|AE_WRITE_THREADSAFE, tlsEventHandler, conn); if (!need_write && (mask & AE_WRITABLE)) - aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); + aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE); } -static void tlsHandleEvent(tls_connection *conn, int mask) { +void updateSSLEvent(tls_connection *conn) { + if (conn->el != serverTL->el) { + aePostFunction(conn->el, [conn]{ + updateSSLEventCore(conn); + }); + } else { + updateSSLEventCore(conn); + } +} + +void tlsHandleEvent(tls_connection *conn, int mask) { int ret; + serverAssert(!GlobalLocksAcquired()); + serverAssert(conn->el == serverTL->el); TLSCONN_DEBUG("tlsEventHandler(): fd=%d, state=%d, mask=%d, r=%d, w=%d, flags=%d", fd, conn->c.state, mask, conn->c.read_handler != NULL, conn->c.write_handler != NULL, @@ -442,7 +479,7 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { } ret = SSL_connect(conn->ssl); if (ret <= 0) { - WantIOType want = 0; + WantIOType want = WANT_INVALID; if (!handleSSLReturnCode(conn, ret, &want)) { registerSSLEvent(conn, want); @@ -460,13 +497,17 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { } } + { + AeLocker locker; + locker.arm(nullptr); if (!callHandler((connection *) conn, conn->c.conn_handler)) return; + } conn->c.conn_handler = NULL; break; case CONN_STATE_ACCEPTING: ret = SSL_accept(conn->ssl); if (ret <= 0) { - WantIOType want = 0; + WantIOType want = WANT_INVALID; if (!handleSSLReturnCode(conn, ret, &want)) { /* Avoid hitting UpdateSSLEvent, which knows nothing * of what SSL_connect() wants and instead looks at our @@ -482,7 +523,11 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { conn->c.state = CONN_STATE_CONNECTED; } + { + AeLocker locker; + locker.arm(nullptr); if (!callHandler((connection *) conn, conn->c.conn_handler)) return; + } conn->c.conn_handler = NULL; break; case CONN_STATE_CONNECTED: @@ -506,12 +551,18 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER; if (!invert && call_read) { + AeLocker lock; + if (!(conn->c.flags & CONN_FLAG_READ_THREADSAFE)) + lock.arm(nullptr); conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *) conn, conn->c.read_handler)) return; } /* Fire the writable event. */ if (call_write) { + AeLocker lock; + if (!(conn->c.flags & CONN_FLAG_WRITE_THREADSAFE)) + lock.arm(nullptr); conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ; if (!callHandler((connection *) conn, conn->c.write_handler)) return; } @@ -519,6 +570,9 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { /* If we have to invert the call, fire the readable event now * after the writable one. */ if (invert && call_read) { + AeLocker lock; + if (!(conn->c.flags & CONN_FLAG_READ_THREADSAFE)) + lock.arm(nullptr); conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *) conn, conn->c.read_handler)) return; } @@ -550,12 +604,12 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) { UNUSED(el); UNUSED(fd); - tls_connection *conn = clientData; + tls_connection *conn = (tls_connection*)clientData; tlsHandleEvent(conn, mask); } -static void connTLSClose(connection *conn_) { - tls_connection *conn = (tls_connection *) conn_; +static void connTLSCloseCore(tls_connection *conn) { + serverAssert(conn->el == serverTL->el); if (conn->ssl) { SSL_free(conn->ssl); @@ -572,13 +626,26 @@ static void connTLSClose(connection *conn_) { conn->pending_list_node = NULL; } - CT_Socket.close(conn_); + CT_Socket.close(&conn->c); +} + +static void connTLSClose(connection *conn_) { + tls_connection *conn = (tls_connection *) conn_; + if (conn->el != serverTL->el) { + aePostFunction(conn->el, [conn]{ + connTLSCloseCore(conn); + }); + } else { + connTLSCloseCore(conn); + } } static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) { tls_connection *conn = (tls_connection *) _conn; int ret; + serverAssert(conn->el == serverTL->el); + if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR; ERR_clear_error(); @@ -587,7 +654,7 @@ static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handle ret = SSL_accept(conn->ssl); if (ret <= 0) { - WantIOType want = 0; + WantIOType want = WANT_INVALID; if (!handleSSLReturnCode(conn, ret, &want)) { registerSSLEvent(conn, want); /* We'll fire back */ return C_OK; @@ -607,6 +674,8 @@ static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handle static int connTLSConnect(connection *conn_, const char *addr, int port, const char *src_addr, ConnectionCallbackFunc connect_handler) { tls_connection *conn = (tls_connection *) conn_; + serverAssert(conn->el == serverTL->el); + if (conn->c.state != CONN_STATE_NONE) return C_ERR; ERR_clear_error(); @@ -628,7 +697,7 @@ static int connTLSWrite(connection *conn_, const void *data, size_t data_len) { ret = SSL_write(conn->ssl, data, data_len); if (ret <= 0) { - WantIOType want = 0; + WantIOType want = WANT_INVALID; if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) { if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; updateSSLEvent(conn); @@ -654,11 +723,13 @@ static int connTLSRead(connection *conn_, void *buf, size_t buf_len) { int ret; int ssl_err; + serverAssert(conn->el == serverTL->el); + if (conn->c.state != CONN_STATE_CONNECTED) return -1; ERR_clear_error(); ret = SSL_read(conn->ssl, buf, buf_len); if (ret <= 0) { - WantIOType want = 0; + WantIOType want = WANT_INVALID; if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) { if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; updateSSLEvent(conn); @@ -687,18 +758,32 @@ static const char *connTLSGetLastError(connection *conn_) { return NULL; } -int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) { +int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier, bool fThreadSafe) { + serverAssert(((tls_connection*)conn)->el == serverTL->el); conn->write_handler = func; if (barrier) conn->flags |= CONN_FLAG_WRITE_BARRIER; else conn->flags &= ~CONN_FLAG_WRITE_BARRIER; + + if (fThreadSafe) + conn->flags |= CONN_FLAG_WRITE_THREADSAFE; + else + conn->flags &= ~CONN_FLAG_WRITE_THREADSAFE; + updateSSLEvent((tls_connection *) conn); return C_OK; } -int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) { +int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func, bool fThreadSafe) { + serverAssert(((tls_connection*)conn)->el == serverTL->el); conn->read_handler = func; + + if (fThreadSafe) + conn->flags |= CONN_FLAG_READ_THREADSAFE; + else + conn->flags &= ~CONN_FLAG_READ_THREADSAFE; + updateSSLEvent((tls_connection *) conn); return C_OK; } @@ -719,6 +804,8 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port, tls_connection *conn = (tls_connection *) conn_; int ret; + serverAssert(conn->el == serverTL->el); + if (conn->c.state != CONN_STATE_NONE) return C_ERR; /* Initiate socket blocking connect first */ @@ -739,9 +826,11 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port, return C_OK; } -static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long long timeout) { +static ssize_t connTLSSyncWrite(connection *conn_, const char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *) conn_; + serverAssert(conn->el == serverTL->el); + setBlockingTimeout(conn, timeout); SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE); int ret = SSL_write(conn->ssl, ptr, size); @@ -754,6 +843,8 @@ static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *) conn_; + serverAssert(conn->el == serverTL->el); + setBlockingTimeout(conn, timeout); int ret = SSL_read(conn->ssl, ptr, size); unsetBlockingTimeout(conn); @@ -765,6 +856,8 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l tls_connection *conn = (tls_connection *) conn_; ssize_t nread = 0; + serverAssert(conn->el == serverTL->el); + setBlockingTimeout(conn, timeout); size--; @@ -792,19 +885,20 @@ exit: } ConnectionType CT_TLS = { - .ae_handler = tlsEventHandler, - .accept = connTLSAccept, - .connect = connTLSConnect, - .blocking_connect = connTLSBlockingConnect, - .read = connTLSRead, - .write = connTLSWrite, - .close = connTLSClose, - .set_write_handler = connTLSSetWriteHandler, - .set_read_handler = connTLSSetReadHandler, - .get_last_error = connTLSGetLastError, - .sync_write = connTLSSyncWrite, - .sync_read = connTLSSyncRead, - .sync_readline = connTLSSyncReadLine, + tlsEventHandler, + connTLSConnect, + connTLSWrite, + connTLSRead, + connTLSClose, + connTLSAccept, + connTLSSetWriteHandler, + connTLSSetReadHandler, + connTLSGetLastError, + connTLSBlockingConnect, + connTLSSyncWrite, + connTLSSyncRead, + connTLSSyncReadLine, + connTLSMarshalThread, }; int tlsHasPendingData() { @@ -820,7 +914,7 @@ int tlsProcessPendingData() { int processed = listLength(pending_list); listRewind(pending_list,&li); while((ln = listNext(&li))) { - tls_connection *conn = listNodeValue(ln); + tls_connection *conn = (tls_connection*)listNodeValue(ln); tlsHandleEvent(conn, AE_READABLE); } return processed; @@ -855,4 +949,6 @@ int tlsProcessPendingData() { return 0; } +void tlsInitThread() {} + #endif