Enable TLS connections

Former-commit-id: d05da0fabdfeb7eadce6546c7c1d85739b2794d7
This commit is contained in:
John Sully 2020-05-24 19:57:16 -04:00
parent d176ae50ec
commit 845027c291
7 changed files with 157 additions and 49 deletions

View File

@ -236,7 +236,8 @@ ifeq ($(MALLOC),memkind)
endif endif
ifeq ($(BUILD_TLS),yes) ifeq ($(BUILD_TLS),yes)
FINAL_CFLAGS+=-DUSE_OPENSSL $(OPENSSL_CFLAGS) FINAL_CFLAGS+=-DUSE_OPENSSL $(OPENSSL_CXXFLAGS)
FINAL_CXXFLAGS+=-DUSE_OPENSSL $(OPENSSL_CXXFLAGS)
FINAL_LDFLAGS+=$(OPENSSL_LDFLAGS) FINAL_LDFLAGS+=$(OPENSSL_LDFLAGS)
FINAL_LIBS += ../deps/hiredis/libhiredis_ssl.a -lssl -lcrypto FINAL_LIBS += ../deps/hiredis/libhiredis_ssl.a -lssl -lcrypto
endif endif

View File

@ -68,6 +68,7 @@ typedef struct ConnectionType {
ssize_t (*sync_write)(struct connection *conn, const char *ptr, ssize_t size, long long timeout); ssize_t (*sync_write)(struct connection *conn, const char *ptr, ssize_t size, long long timeout);
ssize_t (*sync_read)(struct connection *conn, char *ptr, ssize_t size, long long timeout); ssize_t (*sync_read)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
ssize_t (*sync_readline)(struct connection *conn, char *ptr, ssize_t size, long long timeout); ssize_t (*sync_readline)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
void (*marshal_thread)(struct connection *conn);
} ConnectionType; } ConnectionType;
struct connection { struct connection {
@ -198,6 +199,11 @@ static inline ssize_t connSyncReadLine(connection *conn, char *ptr, ssize_t size
return conn->type->sync_readline(conn, ptr, size, timeout); return conn->type->sync_readline(conn, ptr, size, timeout);
} }
static inline void connMarshalThread(connection *conn) {
if (conn->type->marshal_thread != nullptr)
conn->type->marshal_thread(conn);
}
connection *connCreateSocket(); connection *connCreateSocket();
connection *connCreateAcceptedSocket(int fd); connection *connCreateAcceptedSocket(int fd);

View File

@ -1313,7 +1313,8 @@ void acceptOnThread(connection *conn, int flags, char *cip)
szT = (char*)zmalloc(NET_IP_STR_LEN, MALLOC_LOCAL); szT = (char*)zmalloc(NET_IP_STR_LEN, MALLOC_LOCAL);
memcpy(szT, cip, NET_IP_STR_LEN); memcpy(szT, cip, NET_IP_STR_LEN);
} }
int res = aePostFunction(g_pserver->rgthreadvar[ielTarget].el, [conn, flags, ielTarget, szT]{ int res = aePostFunction(g_pserver->rgthreadvar[ielTarget].el, [conn, flags, ielTarget, szT] {
connMarshalThread(conn);
acceptCommonHandler(conn,flags,szT,ielTarget); acceptCommonHandler(conn,flags,szT,ielTarget);
if (!g_fTestMode && !g_pserver->loading) if (!g_fTestMode && !g_pserver->loading)
rgacceptsInFlight[ielTarget].fetch_sub(1, std::memory_order_relaxed); rgacceptsInFlight[ielTarget].fetch_sub(1, std::memory_order_relaxed);

View File

@ -32,7 +32,9 @@
#include "hiredis.h" #include "hiredis.h"
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
#include "openssl/ssl.h" #include "openssl/ssl.h"
extern "C" {
#include "hiredis_ssl.h" #include "hiredis_ssl.h"
}
#endif #endif
#include "async.h" #include "async.h"

View File

@ -2940,7 +2940,7 @@ static void initNetworking(int fReusePort)
} }
/* Abort if there are no listening sockets at all. */ /* Abort if there are no listening sockets at all. */
if (g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].ipfd_count == 0 && g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].tlsfd_count && g_pserver->sofd < 0) { if (g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].ipfd_count == 0 && g_pserver->rgthreadvar[IDX_EVENT_LOOP_MAIN].tlsfd_count == 0 && g_pserver->sofd < 0) {
serverLog(LL_WARNING, "Configured to not listen anywhere, exiting."); serverLog(LL_WARNING, "Configured to not listen anywhere, exiting.");
exit(1); exit(1);
} }
@ -5305,6 +5305,7 @@ void *workerThreadMain(void *parg)
int iel = (int)((int64_t)parg); int iel = (int)((int64_t)parg);
serverLog(LOG_INFO, "Thread %d alive.", iel); serverLog(LOG_INFO, "Thread %d alive.", iel);
serverTL = g_pserver->rgthreadvar+iel; // set the TLS threadsafe global serverTL = g_pserver->rgthreadvar+iel; // set the TLS threadsafe global
tlsInitThread();
if (iel != IDX_EVENT_LOOP_MAIN) if (iel != IDX_EVENT_LOOP_MAIN)
{ {

View File

@ -3061,6 +3061,7 @@ inline int FCorrectThread(client *c)
/* TLS stuff */ /* TLS stuff */
void tlsInit(void); void tlsInit(void);
void tlsInitThread();
int tlsConfigure(redisTLSContextConfig *ctx_config); int tlsConfigure(redisTLSContextConfig *ctx_config);

View File

@ -31,6 +31,8 @@
#include "server.h" #include "server.h"
#include "connhelpers.h" #include "connhelpers.h"
#include "adlist.h" #include "adlist.h"
#include "aelocker.h"
#include <mutex>
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
@ -53,6 +55,7 @@
extern ConnectionType CT_Socket; extern ConnectionType CT_Socket;
SSL_CTX *redis_tls_ctx; SSL_CTX *redis_tls_ctx;
fastlock g_ctxtlock("SSL CTX");
static int parseProtocolsConfig(const char *str) { static int parseProtocolsConfig(const char *str) {
int i, count = 0; int i, count = 0;
@ -91,7 +94,7 @@ static int parseProtocolsConfig(const char *str) {
/* list of connections with pending data already read from the socket, but not /* list of connections with pending data already read from the socket, but not
* served to the reader yet. */ * served to the reader yet. */
static list *pending_list = NULL; static thread_local list *pending_list = NULL;
/** /**
* OpenSSL global initialization and locking handling callbacks. * OpenSSL global initialization and locking handling callbacks.
@ -147,10 +150,15 @@ void tlsInit(void) {
serverLog(LL_WARNING, "OpenSSL: Failed to seed random number generator."); serverLog(LL_WARNING, "OpenSSL: Failed to seed random number generator.");
} }
pending_list = listCreate();
/* Server configuration */ /* Server configuration */
server.tls_auth_clients = 1; /* Secure by default */ g_pserver->tls_auth_clients = 1; /* Secure by default */
tlsInitThread();
}
void tlsInitThread(void)
{
serverAssert(pending_list == nullptr);
pending_list = listCreate();
} }
/* Attempt to configure/reconfigure TLS. This operation is atomic and will /* Attempt to configure/reconfigure TLS. This operation is atomic and will
@ -159,6 +167,7 @@ void tlsInit(void) {
int tlsConfigure(redisTLSContextConfig *ctx_config) { int tlsConfigure(redisTLSContextConfig *ctx_config) {
char errbuf[256]; char errbuf[256];
SSL_CTX *ctx = NULL; SSL_CTX *ctx = NULL;
int protocols;
if (!ctx_config->cert_file) { if (!ctx_config->cert_file) {
serverLog(LL_WARNING, "No tls-cert-file configured!"); serverLog(LL_WARNING, "No tls-cert-file configured!");
@ -184,7 +193,7 @@ int tlsConfigure(redisTLSContextConfig *ctx_config) {
SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
#endif #endif
int protocols = parseProtocolsConfig(ctx_config->protocols); protocols = parseProtocolsConfig(ctx_config->protocols);
if (protocols == -1) goto error; if (protocols == -1) goto error;
if (!(protocols & REDIS_TLS_PROTO_TLSv1)) if (!(protocols & REDIS_TLS_PROTO_TLSv1))
@ -272,8 +281,11 @@ int tlsConfigure(redisTLSContextConfig *ctx_config) {
} }
#endif #endif
{
std::unique_lock<fastlock> ul(g_ctxtlock);
SSL_CTX_free(redis_tls_ctx); SSL_CTX_free(redis_tls_ctx);
redis_tls_ctx = ctx; redis_tls_ctx = ctx;
}
return C_OK; return C_OK;
@ -289,7 +301,7 @@ error:
#define TLSCONN_DEBUG(fmt, ...) #define TLSCONN_DEBUG(fmt, ...)
#endif #endif
ConnectionType CT_TLS; extern ConnectionType CT_TLS;
/* Normal socket connections have a simple events/handler correlation. /* Normal socket connections have a simple events/handler correlation.
* *
@ -307,6 +319,7 @@ ConnectionType CT_TLS;
*/ */
typedef enum { typedef enum {
WANT_INVALID = 0,
WANT_READ = 1, WANT_READ = 1,
WANT_WRITE WANT_WRITE
} WantIOType; } WantIOType;
@ -321,16 +334,24 @@ typedef struct tls_connection {
SSL *ssl; SSL *ssl;
char *ssl_error; char *ssl_error;
listNode *pending_list_node; listNode *pending_list_node;
aeEventLoop *el;
} tls_connection; } tls_connection;
connection *connCreateTLS(void) { connection *connCreateTLS(void) {
tls_connection *conn = zcalloc(sizeof(tls_connection)); tls_connection *conn = (tls_connection*)zcalloc(sizeof(tls_connection), MALLOC_LOCAL);
conn->c.type = &CT_TLS; conn->c.type = &CT_TLS;
conn->c.fd = -1; conn->c.fd = -1;
std::unique_lock<fastlock> ul(g_ctxtlock);
conn->ssl = SSL_new(redis_tls_ctx); conn->ssl = SSL_new(redis_tls_ctx);
conn->el = serverTL->el;
return (connection *) conn; return (connection *) conn;
} }
void connTLSMarshalThread(connection *c) {
tls_connection *conn = (tls_connection*)c;
conn->el = serverTL->el;
}
connection *connCreateAcceptedTLS(int fd, int require_auth) { connection *connCreateAcceptedTLS(int fd, int require_auth) {
tls_connection *conn = (tls_connection *) connCreateTLS(); tls_connection *conn = (tls_connection *) connCreateTLS();
conn->c.fd = fd; conn->c.fd = fd;
@ -374,7 +395,7 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *
/* Error! */ /* Error! */
conn->c.last_errno = 0; conn->c.last_errno = 0;
if (conn->ssl_error) zfree(conn->ssl_error); if (conn->ssl_error) zfree(conn->ssl_error);
conn->ssl_error = zmalloc(512); conn->ssl_error = (char*)zmalloc(512);
ERR_error_string_n(ERR_get_error(), conn->ssl_error, 512); ERR_error_string_n(ERR_get_error(), conn->ssl_error, 512);
break; break;
} }
@ -386,17 +407,19 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *
} }
void registerSSLEvent(tls_connection *conn, WantIOType want) { void registerSSLEvent(tls_connection *conn, WantIOType want) {
int mask = aeGetFileEvents(server.el, conn->c.fd); int mask = aeGetFileEvents(serverTL->el, conn->c.fd);
serverAssert(conn->el == serverTL->el);
switch (want) { switch (want) {
case WANT_READ: case WANT_READ:
if (mask & AE_WRITABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); if (mask & AE_WRITABLE) aeDeleteFileEvent(conn->el, conn->c.fd, AE_WRITABLE);
if (!(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, if (!(mask & AE_READABLE)) aeCreateFileEvent(conn->el, conn->c.fd, AE_READABLE|AE_READ_THREADSAFE,
tlsEventHandler, conn); tlsEventHandler, conn);
break; break;
case WANT_WRITE: case WANT_WRITE:
if (mask & AE_READABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); if (mask & AE_READABLE) aeDeleteFileEvent(conn->el, conn->c.fd, AE_READABLE);
if (!(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, if (!(mask & AE_WRITABLE)) aeCreateFileEvent(conn->el, conn->c.fd, AE_WRITABLE|AE_WRITE_THREADSAFE,
tlsEventHandler, conn); tlsEventHandler, conn);
break; break;
default: default:
@ -405,24 +428,38 @@ void registerSSLEvent(tls_connection *conn, WantIOType want) {
} }
} }
void updateSSLEvent(tls_connection *conn) { void updateSSLEventCore(tls_connection *conn) {
int mask = aeGetFileEvents(server.el, conn->c.fd); int mask = aeGetFileEvents(serverTL->el, conn->c.fd);
int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ); int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ);
int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE); int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE);
serverAssert(conn->el == serverTL->el);
if (need_read && !(mask & AE_READABLE)) if (need_read && !(mask & AE_READABLE))
aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); aeCreateFileEvent(serverTL->el, conn->c.fd, AE_READABLE|AE_READ_THREADSAFE, tlsEventHandler, conn);
if (!need_read && (mask & AE_READABLE)) if (!need_read && (mask & AE_READABLE))
aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_READABLE);
if (need_write && !(mask & AE_WRITABLE)) if (need_write && !(mask & AE_WRITABLE))
aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn); aeCreateFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE|AE_WRITE_THREADSAFE, tlsEventHandler, conn);
if (!need_write && (mask & AE_WRITABLE)) if (!need_write && (mask & AE_WRITABLE))
aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); aeDeleteFileEvent(serverTL->el, conn->c.fd, AE_WRITABLE);
} }
static void tlsHandleEvent(tls_connection *conn, int mask) { void updateSSLEvent(tls_connection *conn) {
if (conn->el != serverTL->el) {
aePostFunction(conn->el, [conn]{
updateSSLEventCore(conn);
});
} else {
updateSSLEventCore(conn);
}
}
void tlsHandleEvent(tls_connection *conn, int mask) {
int ret; int ret;
serverAssert(!GlobalLocksAcquired());
serverAssert(conn->el == serverTL->el);
TLSCONN_DEBUG("tlsEventHandler(): fd=%d, state=%d, mask=%d, r=%d, w=%d, flags=%d", TLSCONN_DEBUG("tlsEventHandler(): fd=%d, state=%d, mask=%d, r=%d, w=%d, flags=%d",
fd, conn->c.state, mask, conn->c.read_handler != NULL, conn->c.write_handler != NULL, fd, conn->c.state, mask, conn->c.read_handler != NULL, conn->c.write_handler != NULL,
@ -442,7 +479,7 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
} }
ret = SSL_connect(conn->ssl); ret = SSL_connect(conn->ssl);
if (ret <= 0) { if (ret <= 0) {
WantIOType want = 0; WantIOType want = WANT_INVALID;
if (!handleSSLReturnCode(conn, ret, &want)) { if (!handleSSLReturnCode(conn, ret, &want)) {
registerSSLEvent(conn, want); registerSSLEvent(conn, want);
@ -460,13 +497,17 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
} }
} }
{
AeLocker locker;
locker.arm(nullptr);
if (!callHandler((connection *) conn, conn->c.conn_handler)) return; if (!callHandler((connection *) conn, conn->c.conn_handler)) return;
}
conn->c.conn_handler = NULL; conn->c.conn_handler = NULL;
break; break;
case CONN_STATE_ACCEPTING: case CONN_STATE_ACCEPTING:
ret = SSL_accept(conn->ssl); ret = SSL_accept(conn->ssl);
if (ret <= 0) { if (ret <= 0) {
WantIOType want = 0; WantIOType want = WANT_INVALID;
if (!handleSSLReturnCode(conn, ret, &want)) { if (!handleSSLReturnCode(conn, ret, &want)) {
/* Avoid hitting UpdateSSLEvent, which knows nothing /* Avoid hitting UpdateSSLEvent, which knows nothing
* of what SSL_connect() wants and instead looks at our * of what SSL_connect() wants and instead looks at our
@ -482,7 +523,11 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
conn->c.state = CONN_STATE_CONNECTED; conn->c.state = CONN_STATE_CONNECTED;
} }
{
AeLocker locker;
locker.arm(nullptr);
if (!callHandler((connection *) conn, conn->c.conn_handler)) return; if (!callHandler((connection *) conn, conn->c.conn_handler)) return;
}
conn->c.conn_handler = NULL; conn->c.conn_handler = NULL;
break; break;
case CONN_STATE_CONNECTED: case CONN_STATE_CONNECTED:
@ -506,12 +551,18 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER; int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER;
if (!invert && call_read) { if (!invert && call_read) {
AeLocker lock;
if (!(conn->c.flags & CONN_FLAG_READ_THREADSAFE))
lock.arm(nullptr);
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
if (!callHandler((connection *) conn, conn->c.read_handler)) return; if (!callHandler((connection *) conn, conn->c.read_handler)) return;
} }
/* Fire the writable event. */ /* Fire the writable event. */
if (call_write) { if (call_write) {
AeLocker lock;
if (!(conn->c.flags & CONN_FLAG_WRITE_THREADSAFE))
lock.arm(nullptr);
conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ; conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ;
if (!callHandler((connection *) conn, conn->c.write_handler)) return; if (!callHandler((connection *) conn, conn->c.write_handler)) return;
} }
@ -519,6 +570,9 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
/* If we have to invert the call, fire the readable event now /* If we have to invert the call, fire the readable event now
* after the writable one. */ * after the writable one. */
if (invert && call_read) { if (invert && call_read) {
AeLocker lock;
if (!(conn->c.flags & CONN_FLAG_READ_THREADSAFE))
lock.arm(nullptr);
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
if (!callHandler((connection *) conn, conn->c.read_handler)) return; if (!callHandler((connection *) conn, conn->c.read_handler)) return;
} }
@ -550,12 +604,12 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) { static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask) {
UNUSED(el); UNUSED(el);
UNUSED(fd); UNUSED(fd);
tls_connection *conn = clientData; tls_connection *conn = (tls_connection*)clientData;
tlsHandleEvent(conn, mask); tlsHandleEvent(conn, mask);
} }
static void connTLSClose(connection *conn_) { static void connTLSCloseCore(tls_connection *conn) {
tls_connection *conn = (tls_connection *) conn_; serverAssert(conn->el == serverTL->el);
if (conn->ssl) { if (conn->ssl) {
SSL_free(conn->ssl); SSL_free(conn->ssl);
@ -572,13 +626,26 @@ static void connTLSClose(connection *conn_) {
conn->pending_list_node = NULL; conn->pending_list_node = NULL;
} }
CT_Socket.close(conn_); CT_Socket.close(&conn->c);
}
static void connTLSClose(connection *conn_) {
tls_connection *conn = (tls_connection *) conn_;
if (conn->el != serverTL->el) {
aePostFunction(conn->el, [conn]{
connTLSCloseCore(conn);
});
} else {
connTLSCloseCore(conn);
}
} }
static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) { static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) {
tls_connection *conn = (tls_connection *) _conn; tls_connection *conn = (tls_connection *) _conn;
int ret; int ret;
serverAssert(conn->el == serverTL->el);
if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR; if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR;
ERR_clear_error(); ERR_clear_error();
@ -587,7 +654,7 @@ static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handle
ret = SSL_accept(conn->ssl); ret = SSL_accept(conn->ssl);
if (ret <= 0) { if (ret <= 0) {
WantIOType want = 0; WantIOType want = WANT_INVALID;
if (!handleSSLReturnCode(conn, ret, &want)) { if (!handleSSLReturnCode(conn, ret, &want)) {
registerSSLEvent(conn, want); /* We'll fire back */ registerSSLEvent(conn, want); /* We'll fire back */
return C_OK; return C_OK;
@ -607,6 +674,8 @@ static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handle
static int connTLSConnect(connection *conn_, const char *addr, int port, const char *src_addr, ConnectionCallbackFunc connect_handler) { static int connTLSConnect(connection *conn_, const char *addr, int port, const char *src_addr, ConnectionCallbackFunc connect_handler) {
tls_connection *conn = (tls_connection *) conn_; tls_connection *conn = (tls_connection *) conn_;
serverAssert(conn->el == serverTL->el);
if (conn->c.state != CONN_STATE_NONE) return C_ERR; if (conn->c.state != CONN_STATE_NONE) return C_ERR;
ERR_clear_error(); ERR_clear_error();
@ -628,7 +697,7 @@ static int connTLSWrite(connection *conn_, const void *data, size_t data_len) {
ret = SSL_write(conn->ssl, data, data_len); ret = SSL_write(conn->ssl, data, data_len);
if (ret <= 0) { if (ret <= 0) {
WantIOType want = 0; WantIOType want = WANT_INVALID;
if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) { if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) {
if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ;
updateSSLEvent(conn); updateSSLEvent(conn);
@ -654,11 +723,13 @@ static int connTLSRead(connection *conn_, void *buf, size_t buf_len) {
int ret; int ret;
int ssl_err; int ssl_err;
serverAssert(conn->el == serverTL->el);
if (conn->c.state != CONN_STATE_CONNECTED) return -1; if (conn->c.state != CONN_STATE_CONNECTED) return -1;
ERR_clear_error(); ERR_clear_error();
ret = SSL_read(conn->ssl, buf, buf_len); ret = SSL_read(conn->ssl, buf, buf_len);
if (ret <= 0) { if (ret <= 0) {
WantIOType want = 0; WantIOType want = WANT_INVALID;
if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) { if (!(ssl_err = handleSSLReturnCode(conn, ret, &want))) {
if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE;
updateSSLEvent(conn); updateSSLEvent(conn);
@ -687,18 +758,32 @@ static const char *connTLSGetLastError(connection *conn_) {
return NULL; return NULL;
} }
int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) { int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier, bool fThreadSafe) {
serverAssert(((tls_connection*)conn)->el == serverTL->el);
conn->write_handler = func; conn->write_handler = func;
if (barrier) if (barrier)
conn->flags |= CONN_FLAG_WRITE_BARRIER; conn->flags |= CONN_FLAG_WRITE_BARRIER;
else else
conn->flags &= ~CONN_FLAG_WRITE_BARRIER; conn->flags &= ~CONN_FLAG_WRITE_BARRIER;
if (fThreadSafe)
conn->flags |= CONN_FLAG_WRITE_THREADSAFE;
else
conn->flags &= ~CONN_FLAG_WRITE_THREADSAFE;
updateSSLEvent((tls_connection *) conn); updateSSLEvent((tls_connection *) conn);
return C_OK; return C_OK;
} }
int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) { int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func, bool fThreadSafe) {
serverAssert(((tls_connection*)conn)->el == serverTL->el);
conn->read_handler = func; conn->read_handler = func;
if (fThreadSafe)
conn->flags |= CONN_FLAG_READ_THREADSAFE;
else
conn->flags &= ~CONN_FLAG_READ_THREADSAFE;
updateSSLEvent((tls_connection *) conn); updateSSLEvent((tls_connection *) conn);
return C_OK; return C_OK;
} }
@ -719,6 +804,8 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,
tls_connection *conn = (tls_connection *) conn_; tls_connection *conn = (tls_connection *) conn_;
int ret; int ret;
serverAssert(conn->el == serverTL->el);
if (conn->c.state != CONN_STATE_NONE) return C_ERR; if (conn->c.state != CONN_STATE_NONE) return C_ERR;
/* Initiate socket blocking connect first */ /* Initiate socket blocking connect first */
@ -739,9 +826,11 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,
return C_OK; return C_OK;
} }
static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long long timeout) { static ssize_t connTLSSyncWrite(connection *conn_, const char *ptr, ssize_t size, long long timeout) {
tls_connection *conn = (tls_connection *) conn_; tls_connection *conn = (tls_connection *) conn_;
serverAssert(conn->el == serverTL->el);
setBlockingTimeout(conn, timeout); setBlockingTimeout(conn, timeout);
SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE); SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
int ret = SSL_write(conn->ssl, ptr, size); int ret = SSL_write(conn->ssl, ptr, size);
@ -754,6 +843,8 @@ static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long
static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) { static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) {
tls_connection *conn = (tls_connection *) conn_; tls_connection *conn = (tls_connection *) conn_;
serverAssert(conn->el == serverTL->el);
setBlockingTimeout(conn, timeout); setBlockingTimeout(conn, timeout);
int ret = SSL_read(conn->ssl, ptr, size); int ret = SSL_read(conn->ssl, ptr, size);
unsetBlockingTimeout(conn); unsetBlockingTimeout(conn);
@ -765,6 +856,8 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
tls_connection *conn = (tls_connection *) conn_; tls_connection *conn = (tls_connection *) conn_;
ssize_t nread = 0; ssize_t nread = 0;
serverAssert(conn->el == serverTL->el);
setBlockingTimeout(conn, timeout); setBlockingTimeout(conn, timeout);
size--; size--;
@ -792,19 +885,20 @@ exit:
} }
ConnectionType CT_TLS = { ConnectionType CT_TLS = {
.ae_handler = tlsEventHandler, tlsEventHandler,
.accept = connTLSAccept, connTLSConnect,
.connect = connTLSConnect, connTLSWrite,
.blocking_connect = connTLSBlockingConnect, connTLSRead,
.read = connTLSRead, connTLSClose,
.write = connTLSWrite, connTLSAccept,
.close = connTLSClose, connTLSSetWriteHandler,
.set_write_handler = connTLSSetWriteHandler, connTLSSetReadHandler,
.set_read_handler = connTLSSetReadHandler, connTLSGetLastError,
.get_last_error = connTLSGetLastError, connTLSBlockingConnect,
.sync_write = connTLSSyncWrite, connTLSSyncWrite,
.sync_read = connTLSSyncRead, connTLSSyncRead,
.sync_readline = connTLSSyncReadLine, connTLSSyncReadLine,
connTLSMarshalThread,
}; };
int tlsHasPendingData() { int tlsHasPendingData() {
@ -820,7 +914,7 @@ int tlsProcessPendingData() {
int processed = listLength(pending_list); int processed = listLength(pending_list);
listRewind(pending_list,&li); listRewind(pending_list,&li);
while((ln = listNext(&li))) { while((ln = listNext(&li))) {
tls_connection *conn = listNodeValue(ln); tls_connection *conn = (tls_connection*)listNodeValue(ln);
tlsHandleEvent(conn, AE_READABLE); tlsHandleEvent(conn, AE_READABLE);
} }
return processed; return processed;
@ -855,4 +949,6 @@ int tlsProcessPendingData() {
return 0; return 0;
} }
void tlsInitThread() {}
#endif #endif