Enable TLS connections
Former-commit-id: d05da0fabdfeb7eadce6546c7c1d85739b2794d7
This commit is contained in:
parent
d176ae50ec
commit
845027c291
@ -236,7 +236,8 @@ ifeq ($(MALLOC),memkind)
|
|||||||
endif
|
endif
|
||||||
|
|
||||||
ifeq ($(BUILD_TLS),yes)
|
ifeq ($(BUILD_TLS),yes)
|
||||||
FINAL_CFLAGS+=-DUSE_OPENSSL $(OPENSSL_CFLAGS)
|
FINAL_CFLAGS+=-DUSE_OPENSSL $(OPENSSL_CXXFLAGS)
|
||||||
|
FINAL_CXXFLAGS+=-DUSE_OPENSSL $(OPENSSL_CXXFLAGS)
|
||||||
FINAL_LDFLAGS+=$(OPENSSL_LDFLAGS)
|
FINAL_LDFLAGS+=$(OPENSSL_LDFLAGS)
|
||||||
FINAL_LIBS += ../deps/hiredis/libhiredis_ssl.a -lssl -lcrypto
|
FINAL_LIBS += ../deps/hiredis/libhiredis_ssl.a -lssl -lcrypto
|
||||||
endif
|
endif
|
||||||
|
@ -68,6 +68,7 @@ typedef struct ConnectionType {
|
|||||||
ssize_t (*sync_write)(struct connection *conn, const char *ptr, ssize_t size, long long timeout);
|
ssize_t (*sync_write)(struct connection *conn, const char *ptr, ssize_t size, long long timeout);
|
||||||
ssize_t (*sync_read)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
|
ssize_t (*sync_read)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
|
||||||
ssize_t (*sync_readline)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
|
ssize_t (*sync_readline)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
|
||||||
|
void (*marshal_thread)(struct connection *conn);
|
||||||
} ConnectionType;
|
} ConnectionType;
|
||||||
|
|
||||||
struct connection {
|
struct connection {
|
||||||
@ -198,6 +199,11 @@ static inline ssize_t connSyncReadLine(connection *conn, char *ptr, ssize_t size
|
|||||||
return conn->type->sync_readline(conn, ptr, size, timeout);
|
return conn->type->sync_readline(conn, ptr, size, timeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline void connMarshalThread(connection *conn) {
|
||||||
|
if (conn->type->marshal_thread != nullptr)
|
||||||
|
conn->type->marshal_thread(conn);
|
||||||
|
}
|
||||||
|
|
||||||
connection *connCreateSocket();
|
connection *connCreateSocket();
|
||||||
connection *connCreateAcceptedSocket(int fd);
|
connection *connCreateAcceptedSocket(int fd);
|
||||||
|
|
||||||
|
@ -1313,7 +1313,8 @@ void acceptOnThread(connection *conn, int flags, char *cip)
|
|||||||
szT = (char*)zmalloc(NET_IP_STR_LEN, MALLOC_LOCAL);
|
szT = (char*)zmalloc(NET_IP_STR_LEN, MALLOC_LOCAL);
|
||||||
memcpy(szT, cip, NET_IP_STR_LEN);
|
memcpy(szT, cip, NET_IP_STR_LEN);
|
||||||
}
|
}
|
||||||
int res = aePostFunction(g_pserver->rgthreadvar[ielTarget].el, [conn, flags, ielTarget, szT]{
|
int res = aePostFunction(g_pserver->rgthreadvar[ielTarget].el, [conn, flags, ielTarget, szT] {
|
||||||
|
connMarshalThread(conn);
|
||||||
acceptCommonHandler(conn,flags,szT,ielTarget);
|
acceptCommonHandler(conn,flags,szT,ielTarget);
|
||||||
if (!g_fTestMode && !g_pserver->loading)
|
if (!g_fTestMode && !g_pserver->loading)
|
||||||
rgacceptsInFlight[ielTarget].fetch_sub(1, std::memory_order_relaxed);
|
rgacceptsInFlight[ielTarget].fetch_sub(1, std::memory_order_relaxed);
|
||||||
|
@ -32,7 +32,9 @@
|
|||||||
#include "hiredis.h"
|
#include "hiredis.h"
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
#include "openssl/ssl.h"
|
#include "openssl/ssl.h"
|
||||||
|
extern "C" {
|
||||||
#include "hiredis_ssl.h"
|
#include "hiredis_ssl.h"
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
#include "async.h"
|
#include "async.h"
|
||||||
|
|
||||||
|
@ -2940,7 +2940,7 @@ static void initNetworking(int fReusePort)
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* Abort if there are no listening sockets at all. */
|
/* Abort if there are no listening sockets at all. */
|
||||||
if (g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].ipfd_count == 0 && g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].tlsfd_count && g_pserver->sofd < 0) {
|
if (g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].ipfd_count == 0 && g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].tlsfd_count == 0 && g_pserver->sofd < 0) {
|
||||||
serverLog(LL_WARNING, "Configured to not listen anywhere, exiting.");
|
serverLog(LL_WARNING, "Configured to not listen anywhere, exiting.");
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
@ -5305,6 +5305,7 @@ void *workerThreadMain(void *parg)
|
|||||||
int iel = (int)((int64_t)parg);
|
int iel = (int)((int64_t)parg);
|
||||||
serverLog(LOG_INFO, "Thread %d alive.", iel);
|
serverLog(LOG_INFO, "Thread %d alive.", iel);
|
||||||
serverTL = g_pserver->rgthreadvar+iel; // set the TLS threadsafe global
|
serverTL = g_pserver->rgthreadvar+iel; // set the TLS threadsafe global
|
||||||
|
tlsInitThread();
|
||||||
|
|
||||||
if (iel != IDX_EVENT_LOOP_MAIN)
|
if (iel != IDX_EVENT_LOOP_MAIN)
|
||||||
{
|
{
|
||||||
|
@ -3061,6 +3061,7 @@ inline int FCorrectThread(client *c)
|
|||||||
|
|
||||||
/* TLS stuff */
|
/* TLS stuff */
|
||||||
void tlsInit(void);
|
void tlsInit(void);
|
||||||
|
void tlsInitThread();
|
||||||
int tlsConfigure(redisTLSContextConfig *ctx_config);
|
int tlsConfigure(redisTLSContextConfig *ctx_config);
|
||||||
|
|
||||||
|
|
||||||
|
188
src/tls.cpp
188
src/tls.cpp
@ -31,6 +31,8 @@
|
|||||||
#include "server.h"
|
#include "server.h"
|
||||||
#include "connhelpers.h"
|
#include "connhelpers.h"
|
||||||
#include "adlist.h"
|
#include "adlist.h"
|
||||||
|
#include "aelocker.h"
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
#ifdef USE_OPENSSL
|
#ifdef USE_OPENSSL
|
||||||
|
|
||||||
@ -53,6 +55,7 @@
|
|||||||
extern ConnectionType CT_Socket;
|
extern ConnectionType CT_Socket;
|
||||||
|
|
||||||
SSL_CTX *redis_tls_ctx;
|
SSL_CTX *redis_tls_ctx;
|
||||||
|
fastlock g_ctxtlock("SSL CTX");
|
||||||
|
|
||||||
static int parseProtocolsConfig(const char *str) {
|
static int parseProtocolsConfig(const char *str) {
|
||||||
int i, count = 0;
|
int i, count = 0;
|
||||||
@ -91,7 +94,7 @@ static int parseProtocolsConfig(const char *str) {
|
|||||||
|
|
||||||
/* list of connections with pending data already read from the socket, but not
|
/* list of connections with pending data already read from the socket, but not
|
||||||
* served to the reader yet. */
|
* served to the reader yet. */
|
||||||
static list *pending_list = NULL;
|
static thread_local list *pending_list = NULL;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* OpenSSL global initialization and locking handling callbacks.
|
* OpenSSL global initialization and locking handling callbacks.
|
||||||
@ -147,10 +150,15 @@ void tlsInit(void) {
|
|||||||
serverLog(LL_WARNING, "OpenSSL: Failed to seed random number generator.");
|
serverLog(LL_WARNING, "OpenSSL: Failed to seed random number generator.");
|
||||||
}
|
}
|
||||||
|
|
||||||
pending_list = listCreate();
|
|
||||||
|
|
||||||
/* Server configuration */
|
/* Server configuration */
|
||||||
server.tls_auth_clients = 1; /* Secure by default */
|
g_pserver->tls_auth_clients = 1; /* Secure by default */
|
||||||
|
tlsInitThread();
|
||||||
|
}
|
||||||
|
|
||||||
|
void tlsInitThread(void)
|
||||||
|
{
|
||||||
|
serverAssert(pending_list == nullptr);
|
||||||
|
pending_list = listCreate();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Attempt to configure/reconfigure TLS. This operation is atomic and will
|
/* Attempt to configure/reconfigure TLS. This operation is atomic and will
|
||||||
@ -159,6 +167,7 @@ void tlsInit(void) {
|
|||||||
int tlsConfigure(redisTLSContextConfig *ctx_config) {
|
int tlsConfigure(redisTLSContextConfig *ctx_config) {
|
||||||
char errbuf[256];
|
char errbuf[256];
|
||||||
SSL_CTX *ctx = NULL;
|
SSL_CTX *ctx = NULL;
|
||||||
|
int protocols;
|
||||||
|
|
||||||
if (!ctx_config->cert_file) {
|
if (!ctx_config->cert_file) {
|
||||||
serverLog(LL_WARNING, "No tls-cert-file configured!");
|
serverLog(LL_WARNING, "No tls-cert-file configured!");
|
||||||
@ -184,7 +193,7 @@ int tlsConfigure(redisTLSContextConfig *ctx_config) {
|
|||||||
SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
|
SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int protocols = parseProtocolsConfig(ctx_config->protocols);
|
protocols = parseProtocolsConfig(ctx_config->protocols);
|
||||||
if (protocols == -1) goto error;
|
if (protocols == -1) goto error;
|
||||||
|
|
||||||
if (!(protocols & REDIS_TLS_PROTO_TLSv1))
|
if (!(protocols & REDIS_TLS_PROTO_TLSv1))
|
||||||
@ -272,8 +281,11 @@ int tlsConfigure(redisTLSContextConfig *ctx_config) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<fastlock> ul(g_ctxtlock);
|
||||||
SSL_CTX_free(redis_tls_ctx);
|
SSL_CTX_free(redis_tls_ctx);
|
||||||
redis_tls_ctx = ctx;
|
redis_tls_ctx = ctx;
|
||||||
|
}
|
||||||
|
|
||||||
return C_OK;
|
return C_OK;
|
||||||
|
|
||||||
@ -289,7 +301,7 @@ error:
|
|||||||
#define TLSCONN_DEBUG(fmt, ...)
|
#define TLSCONN_DEBUG(fmt, ...)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
ConnectionType CT_TLS;
|
extern ConnectionType CT_TLS;
|
||||||
|
|
||||||
/* Normal socket connections have a simple events/handler correlation.
|
/* Normal socket connections have a simple events/handler correlation.
|
||||||
*
|
*
|
||||||
@ -307,6 +319,7 @@ ConnectionType CT_TLS;
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
typedef enum {
|
typedef enum {
|
||||||
|
WANT_INVALID = 0,
|
||||||
WANT_READ = 1,
|
WANT_READ = 1,
|
||||||
WANT_WRITE
|
WANT_WRITE
|
||||||
} WantIOType;
|
} WantIOType;
|
||||||
@ -321,16 +334,24 @@ typedef struct tls_connection {
|
|||||||
SSL *ssl;
|
SSL *ssl;
|
||||||
char *ssl_error;
|
char *ssl_error;
|
||||||
listNode *pending_list_node;
|
listNode *pending_list_node;
|
||||||
|
aeEventLoop *el;
|
||||||
} tls_connection;
|
} tls_connection;
|
||||||
|
|
||||||
connection *connCreateTLS(void) {
|
connection *connCreateTLS(void) {
|
||||||
tls_connection *conn = zcalloc(sizeof(tls_connection));
|
tls_connection *conn = (tls_connection*)zcalloc(sizeof(tls_connection), MALLOC_LOCAL);
|
||||||
conn->c.type = &CT_TLS;
|
conn->c.type = &CT_TLS;
|
||||||
conn->c.fd = -1;
|
conn->c.fd = -1;
|
||||||
|
std::unique_lock<fastlock> ul(g_ctxtlock);
|
||||||
conn->ssl = SSL_new(redis_tls_ctx);
|
conn->ssl = SSL_new(redis_tls_ctx);
|
||||||
|
conn->el = serverTL->el;
|
||||||
return (connection *) conn;
|
return (connection *) conn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void connTLSMarshalThread(connection *c) {
|
||||||
|
tls_connection *conn = (tls_connection*)c;
|
||||||
|
conn->el = serverTL->el;
|
||||||
|
}
|
||||||
|
|
||||||
connection *connCreateAcceptedTLS(int fd, int require_auth) {
|
connection *connCreateAcceptedTLS(int fd, int require_auth) {
|
||||||
tls_connection *conn = (tls_connection *) connCreateTLS();
|
tls_connection *conn = (tls_connection *) connCreateTLS();
|
||||||
conn->c.fd = fd;
|
conn->c.fd = fd;
|
||||||
@ -374,7 +395,7 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *
|
|||||||
/* Error! */
|
/* Error! */
|
||||||
conn->c.last_errno = 0;
|
conn->c.last_errno = 0;
|
||||||
if (conn->ssl_error) zfree(conn->ssl_error);
|
if (conn->ssl_error) zfree(conn->ssl_error);
|
||||||
conn->ssl_error = zmalloc(512);
|
conn->ssl_error = (char*)zmalloc(512);
|
||||||
ERR_error_string_n(ERR_get_error(), conn->ssl_error, 512);
|
ERR_error_string_n(ERR_get_error(), conn->ssl_error, 512);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -386,17 +407,19 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *
|
|||||||
}
|
}
|
||||||
|
|
||||||
void registerSSLEvent(tls_connection *conn, WantIOType want) {
|
void registerSSLEvent(tls_connection *conn, WantIOType want) {
|
||||||
int mask = aeGetFileEvents(server.el, conn->c.fd);
|
int mask = aeGetFileEvents(serverTL->el, conn->c.fd);
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
switch (want) {
|
switch (want) {
|
||||||
case WANT_READ:
|
case WANT_READ:
|
||||||
if (mask & AE_WRITABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE);
|
if (mask & AE_WRITABLE) aeDeleteFileEvent(conn->el, conn->c.fd, AE_WRITABLE);
|
||||||
if (!(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE,
|
if (!(mask & AE_READABLE)) aeCreateFileEvent(conn->el, conn->c.fd, AE_READABLE|AE_READ_THREADSAFE,
|
||||||
tlsEventHandler, conn);
|
tlsEventHandler, conn);
|
||||||
break;
|
break;
|
||||||
case WANT_WRITE:
|
case WANT_WRITE:
|
||||||
if (mask & AE_READABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE);
|
if (mask & AE_READABLE) aeDeleteFileEvent(conn->el, conn->c.fd, AE_READABLE);
|
||||||
if (!(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE,
|
if (!(mask & AE_WRITABLE)) aeCreateFileEvent(conn->el, conn->c.fd, AE_WRITABLE|AE_WRITE_THREADSAFE,
|
||||||
tlsEventHandler, conn);
|
tlsEventHandler, conn);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
@ -405,24 +428,38 @@ void registerSSLEvent(tls_connection *conn, WantIOType want) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void updateSSLEvent(tls_connection *conn) {
|
void updateSSLEventCore(tls_connection *conn) {
|
||||||
int mask = aeGetFileEvents(server.el, conn->c.fd);
|
int mask = aeGetFileEvents(serverTL->el, conn->c.fd);
|
||||||
int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ);
|
int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ);
|
||||||
int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE);
|
int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE);
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
if (need_read && !(mask & AE_READABLE))
|
if (need_read && !(mask & AE_READABLE))
|
||||||
aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn);
|
aeCreateFileEvent(serverTL->el, conn->c.fd, AE_READABLE|AE_READ_THREADSAFE, tlsEventHandler, conn);
|
||||||
if (!need_read && (mask & AE_READABLE))
|
if (!need_read && (mask & AE_READABLE))
|
||||||
aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE);
|
aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_READABLE);
|
||||||
|
|
||||||
if (need_write && !(mask & AE_WRITABLE))
|
if (need_write && !(mask & AE_WRITABLE))
|
||||||
aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn);
|
aeCreateFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE|AE_WRITE_THREADSAFE, tlsEventHandler, conn);
|
||||||
if (!need_write && (mask & AE_WRITABLE))
|
if (!need_write && (mask & AE_WRITABLE))
|
||||||
aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE);
|
aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void tlsHandleEvent(tls_connection *conn, int mask) {
|
void updateSSLEvent(tls_connection *conn) {
|
||||||
|
if (conn->el != serverTL->el) {
|
||||||
|
aePostFunction(conn->el, [conn]{
|
||||||
|
updateSSLEventCore(conn);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
updateSSLEventCore(conn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void tlsHandleEvent(tls_connection *conn, int mask) {
|
||||||
int ret;
|
int ret;
|
||||||
|
serverAssert(!GlobalLocksAcquired());
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
TLSCONN_DEBUG("tlsEventHandler(): fd=%d, state=%d, mask=%d, r=%d, w=%d, flags=%d",
|
TLSCONN_DEBUG("tlsEventHandler(): fd=%d, state=%d, mask=%d, r=%d, w=%d, flags=%d",
|
||||||
fd, conn->c.state, mask, conn->c.read_handler != NULL, conn->c.write_handler != NULL,
|
fd, conn->c.state, mask, conn->c.read_handler != NULL, conn->c.write_handler != NULL,
|
||||||
@ -442,7 +479,7 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
|
|||||||
}
|
}
|
||||||
ret = SSL_connect(conn->ssl);
|
ret = SSL_connect(conn->ssl);
|
||||||
if (ret <= 0) {
|
if (ret <= 0) {
|
||||||
WantIOType want = 0;
|
WantIOType want = WANT_INVALID;
|
||||||
if (!handleSSLReturnCode(conn, ret, &want)) {
|
if (!handleSSLReturnCode(conn, ret, &want)) {
|
||||||
registerSSLEvent(conn, want);
|
registerSSLEvent(conn, want);
|
||||||
|
|
||||||
@ -460,13 +497,17 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
AeLocker locker;
|
||||||
|
locker.arm(nullptr);
|
||||||
if (!callHandler((connection *) conn, conn->c.conn_handler)) return;
|
if (!callHandler((connection *) conn, conn->c.conn_handler)) return;
|
||||||
|
}
|
||||||
conn->c.conn_handler = NULL;
|
conn->c.conn_handler = NULL;
|
||||||
break;
|
break;
|
||||||
case CONN_STATE_ACCEPTING:
|
case CONN_STATE_ACCEPTING:
|
||||||
ret = SSL_accept(conn->ssl);
|
ret = SSL_accept(conn->ssl);
|
||||||
if (ret <= 0) {
|
if (ret <= 0) {
|
||||||
WantIOType want = 0;
|
WantIOType want = WANT_INVALID;
|
||||||
if (!handleSSLReturnCode(conn, ret, &want)) {
|
if (!handleSSLReturnCode(conn, ret, &want)) {
|
||||||
/* Avoid hitting UpdateSSLEvent, which knows nothing
|
/* Avoid hitting UpdateSSLEvent, which knows nothing
|
||||||
* of what SSL_connect() wants and instead looks at our
|
* of what SSL_connect() wants and instead looks at our
|
||||||
@ -482,7 +523,11 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
|
|||||||
conn->c.state = CONN_STATE_CONNECTED;
|
conn->c.state = CONN_STATE_CONNECTED;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
AeLocker locker;
|
||||||
|
locker.arm(nullptr);
|
||||||
if (!callHandler((connection *) conn, conn->c.conn_handler)) return;
|
if (!callHandler((connection *) conn, conn->c.conn_handler)) return;
|
||||||
|
}
|
||||||
conn->c.conn_handler = NULL;
|
conn->c.conn_handler = NULL;
|
||||||
break;
|
break;
|
||||||
case CONN_STATE_CONNECTED:
|
case CONN_STATE_CONNECTED:
|
||||||
@ -506,12 +551,18 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
|
|||||||
int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER;
|
int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER;
|
||||||
|
|
||||||
if (!invert && call_read) {
|
if (!invert && call_read) {
|
||||||
|
AeLocker lock;
|
||||||
|
if (!(conn->c.flags & CONN_FLAG_READ_THREADSAFE))
|
||||||
|
lock.arm(nullptr);
|
||||||
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
|
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
|
||||||
if (!callHandler((connection *) conn, conn->c.read_handler)) return;
|
if (!callHandler((connection *) conn, conn->c.read_handler)) return;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Fire the writable event. */
|
/* Fire the writable event. */
|
||||||
if (call_write) {
|
if (call_write) {
|
||||||
|
AeLocker lock;
|
||||||
|
if (!(conn->c.flags & CONN_FLAG_WRITE_THREADSAFE))
|
||||||
|
lock.arm(nullptr);
|
||||||
conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ;
|
conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ;
|
||||||
if (!callHandler((connection *) conn, conn->c.write_handler)) return;
|
if (!callHandler((connection *) conn, conn->c.write_handler)) return;
|
||||||
}
|
}
|
||||||
@ -519,6 +570,9 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
|
|||||||
/* If we have to invert the call, fire the readable event now
|
/* If we have to invert the call, fire the readable event now
|
||||||
* after the writable one. */
|
* after the writable one. */
|
||||||
if (invert && call_read) {
|
if (invert && call_read) {
|
||||||
|
AeLocker lock;
|
||||||
|
if (!(conn->c.flags & CONN_FLAG_READ_THREADSAFE))
|
||||||
|
lock.arm(nullptr);
|
||||||
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
|
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
|
||||||
if (!callHandler((connection *) conn, conn->c.read_handler)) return;
|
if (!callHandler((connection *) conn, conn->c.read_handler)) return;
|
||||||
}
|
}
|
||||||
@ -550,12 +604,12 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
|
|||||||
static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) {
|
static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) {
|
||||||
UNUSED(el);
|
UNUSED(el);
|
||||||
UNUSED(fd);
|
UNUSED(fd);
|
||||||
tls_connection *conn = clientData;
|
tls_connection *conn = (tls_connection*)clientData;
|
||||||
tlsHandleEvent(conn, mask);
|
tlsHandleEvent(conn, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void connTLSClose(connection *conn_) {
|
static void connTLSCloseCore(tls_connection *conn) {
|
||||||
tls_connection *conn = (tls_connection *) conn_;
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
if (conn->ssl) {
|
if (conn->ssl) {
|
||||||
SSL_free(conn->ssl);
|
SSL_free(conn->ssl);
|
||||||
@ -572,13 +626,26 @@ static void connTLSClose(connection *conn_) {
|
|||||||
conn->pending_list_node = NULL;
|
conn->pending_list_node = NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
CT_Socket.close(conn_);
|
CT_Socket.close(&conn->c);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void connTLSClose(connection *conn_) {
|
||||||
|
tls_connection *conn = (tls_connection *) conn_;
|
||||||
|
if (conn->el != serverTL->el) {
|
||||||
|
aePostFunction(conn->el, [conn]{
|
||||||
|
connTLSCloseCore(conn);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
connTLSCloseCore(conn);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) {
|
static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) {
|
||||||
tls_connection *conn = (tls_connection *) _conn;
|
tls_connection *conn = (tls_connection *) _conn;
|
||||||
int ret;
|
int ret;
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR;
|
if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR;
|
||||||
ERR_clear_error();
|
ERR_clear_error();
|
||||||
|
|
||||||
@ -587,7 +654,7 @@ static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handle
|
|||||||
ret = SSL_accept(conn->ssl);
|
ret = SSL_accept(conn->ssl);
|
||||||
|
|
||||||
if (ret <= 0) {
|
if (ret <= 0) {
|
||||||
WantIOType want = 0;
|
WantIOType want = WANT_INVALID;
|
||||||
if (!handleSSLReturnCode(conn, ret, &want)) {
|
if (!handleSSLReturnCode(conn, ret, &want)) {
|
||||||
registerSSLEvent(conn, want); /* We'll fire back */
|
registerSSLEvent(conn, want); /* We'll fire back */
|
||||||
return C_OK;
|
return C_OK;
|
||||||
@ -607,6 +674,8 @@ static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handle
|
|||||||
static int connTLSConnect(connection *conn_, const char *addr, int port, const char *src_addr, ConnectionCallbackFunc connect_handler) {
|
static int connTLSConnect(connection *conn_, const char *addr, int port, const char *src_addr, ConnectionCallbackFunc connect_handler) {
|
||||||
tls_connection *conn = (tls_connection *) conn_;
|
tls_connection *conn = (tls_connection *) conn_;
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
if (conn->c.state != CONN_STATE_NONE) return C_ERR;
|
if (conn->c.state != CONN_STATE_NONE) return C_ERR;
|
||||||
ERR_clear_error();
|
ERR_clear_error();
|
||||||
|
|
||||||
@ -628,7 +697,7 @@ static int connTLSWrite(connection *conn_, const void *data, size_t data_len) {
|
|||||||
ret = SSL_write(conn->ssl, data, data_len);
|
ret = SSL_write(conn->ssl, data, data_len);
|
||||||
|
|
||||||
if (ret <= 0) {
|
if (ret <= 0) {
|
||||||
WantIOType want = 0;
|
WantIOType want = WANT_INVALID;
|
||||||
if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) {
|
if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) {
|
||||||
if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ;
|
if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ;
|
||||||
updateSSLEvent(conn);
|
updateSSLEvent(conn);
|
||||||
@ -654,11 +723,13 @@ static int connTLSRead(connection *conn_, void *buf, size_t buf_len) {
|
|||||||
int ret;
|
int ret;
|
||||||
int ssl_err;
|
int ssl_err;
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
if (conn->c.state != CONN_STATE_CONNECTED) return -1;
|
if (conn->c.state != CONN_STATE_CONNECTED) return -1;
|
||||||
ERR_clear_error();
|
ERR_clear_error();
|
||||||
ret = SSL_read(conn->ssl, buf, buf_len);
|
ret = SSL_read(conn->ssl, buf, buf_len);
|
||||||
if (ret <= 0) {
|
if (ret <= 0) {
|
||||||
WantIOType want = 0;
|
WantIOType want = WANT_INVALID;
|
||||||
if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) {
|
if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) {
|
||||||
if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE;
|
if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE;
|
||||||
updateSSLEvent(conn);
|
updateSSLEvent(conn);
|
||||||
@ -687,18 +758,32 @@ static const char *connTLSGetLastError(connection *conn_) {
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) {
|
int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier, bool fThreadSafe) {
|
||||||
|
serverAssert(((tls_connection*)conn)->el == serverTL->el);
|
||||||
conn->write_handler = func;
|
conn->write_handler = func;
|
||||||
if (barrier)
|
if (barrier)
|
||||||
conn->flags |= CONN_FLAG_WRITE_BARRIER;
|
conn->flags |= CONN_FLAG_WRITE_BARRIER;
|
||||||
else
|
else
|
||||||
conn->flags &= ~CONN_FLAG_WRITE_BARRIER;
|
conn->flags &= ~CONN_FLAG_WRITE_BARRIER;
|
||||||
|
|
||||||
|
if (fThreadSafe)
|
||||||
|
conn->flags |= CONN_FLAG_WRITE_THREADSAFE;
|
||||||
|
else
|
||||||
|
conn->flags &= ~CONN_FLAG_WRITE_THREADSAFE;
|
||||||
|
|
||||||
updateSSLEvent((tls_connection *) conn);
|
updateSSLEvent((tls_connection *) conn);
|
||||||
return C_OK;
|
return C_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) {
|
int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func, bool fThreadSafe) {
|
||||||
|
serverAssert(((tls_connection*)conn)->el == serverTL->el);
|
||||||
conn->read_handler = func;
|
conn->read_handler = func;
|
||||||
|
|
||||||
|
if (fThreadSafe)
|
||||||
|
conn->flags |= CONN_FLAG_READ_THREADSAFE;
|
||||||
|
else
|
||||||
|
conn->flags &= ~CONN_FLAG_READ_THREADSAFE;
|
||||||
|
|
||||||
updateSSLEvent((tls_connection *) conn);
|
updateSSLEvent((tls_connection *) conn);
|
||||||
return C_OK;
|
return C_OK;
|
||||||
}
|
}
|
||||||
@ -719,6 +804,8 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,
|
|||||||
tls_connection *conn = (tls_connection *) conn_;
|
tls_connection *conn = (tls_connection *) conn_;
|
||||||
int ret;
|
int ret;
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
if (conn->c.state != CONN_STATE_NONE) return C_ERR;
|
if (conn->c.state != CONN_STATE_NONE) return C_ERR;
|
||||||
|
|
||||||
/* Initiate socket blocking connect first */
|
/* Initiate socket blocking connect first */
|
||||||
@ -739,9 +826,11 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,
|
|||||||
return C_OK;
|
return C_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long long timeout) {
|
static ssize_t connTLSSyncWrite(connection *conn_, const char *ptr, ssize_t size, long long timeout) {
|
||||||
tls_connection *conn = (tls_connection *) conn_;
|
tls_connection *conn = (tls_connection *) conn_;
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
setBlockingTimeout(conn, timeout);
|
setBlockingTimeout(conn, timeout);
|
||||||
SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||||
int ret = SSL_write(conn->ssl, ptr, size);
|
int ret = SSL_write(conn->ssl, ptr, size);
|
||||||
@ -754,6 +843,8 @@ static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long
|
|||||||
static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) {
|
static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) {
|
||||||
tls_connection *conn = (tls_connection *) conn_;
|
tls_connection *conn = (tls_connection *) conn_;
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
setBlockingTimeout(conn, timeout);
|
setBlockingTimeout(conn, timeout);
|
||||||
int ret = SSL_read(conn->ssl, ptr, size);
|
int ret = SSL_read(conn->ssl, ptr, size);
|
||||||
unsetBlockingTimeout(conn);
|
unsetBlockingTimeout(conn);
|
||||||
@ -765,6 +856,8 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
|
|||||||
tls_connection *conn = (tls_connection *) conn_;
|
tls_connection *conn = (tls_connection *) conn_;
|
||||||
ssize_t nread = 0;
|
ssize_t nread = 0;
|
||||||
|
|
||||||
|
serverAssert(conn->el == serverTL->el);
|
||||||
|
|
||||||
setBlockingTimeout(conn, timeout);
|
setBlockingTimeout(conn, timeout);
|
||||||
|
|
||||||
size--;
|
size--;
|
||||||
@ -792,19 +885,20 @@ exit:
|
|||||||
}
|
}
|
||||||
|
|
||||||
ConnectionType CT_TLS = {
|
ConnectionType CT_TLS = {
|
||||||
.ae_handler = tlsEventHandler,
|
tlsEventHandler,
|
||||||
.accept = connTLSAccept,
|
connTLSConnect,
|
||||||
.connect = connTLSConnect,
|
connTLSWrite,
|
||||||
.blocking_connect = connTLSBlockingConnect,
|
connTLSRead,
|
||||||
.read = connTLSRead,
|
connTLSClose,
|
||||||
.write = connTLSWrite,
|
connTLSAccept,
|
||||||
.close = connTLSClose,
|
connTLSSetWriteHandler,
|
||||||
.set_write_handler = connTLSSetWriteHandler,
|
connTLSSetReadHandler,
|
||||||
.set_read_handler = connTLSSetReadHandler,
|
connTLSGetLastError,
|
||||||
.get_last_error = connTLSGetLastError,
|
connTLSBlockingConnect,
|
||||||
.sync_write = connTLSSyncWrite,
|
connTLSSyncWrite,
|
||||||
.sync_read = connTLSSyncRead,
|
connTLSSyncRead,
|
||||||
.sync_readline = connTLSSyncReadLine,
|
connTLSSyncReadLine,
|
||||||
|
connTLSMarshalThread,
|
||||||
};
|
};
|
||||||
|
|
||||||
int tlsHasPendingData() {
|
int tlsHasPendingData() {
|
||||||
@ -820,7 +914,7 @@ int tlsProcessPendingData() {
|
|||||||
int processed = listLength(pending_list);
|
int processed = listLength(pending_list);
|
||||||
listRewind(pending_list,&li);
|
listRewind(pending_list,&li);
|
||||||
while((ln = listNext(&li))) {
|
while((ln = listNext(&li))) {
|
||||||
tls_connection *conn = listNodeValue(ln);
|
tls_connection *conn = (tls_connection*)listNodeValue(ln);
|
||||||
tlsHandleEvent(conn, AE_READABLE);
|
tlsHandleEvent(conn, AE_READABLE);
|
||||||
}
|
}
|
||||||
return processed;
|
return processed;
|
||||||
@ -855,4 +949,6 @@ int tlsProcessPendingData() {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void tlsInitThread() {}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
x
Reference in New Issue
Block a user