From 0ae02ce95b8a27ce9d19340e53dcdc9b5a060101 Mon Sep 17 00:00:00 2001 From: zhenwei pi Date: Wed, 27 Jul 2022 11:47:50 +0800 Subject: [PATCH] Abstract accept handler Abstract accept handler for socket&TLS, and add helper function 'connAcceptHandler' to get accept handler by specified type. Also move acceptTcpHandler into socket.c, and move acceptTLSHandler into tls.c. Signed-off-by: zhenwei pi --- src/config.c | 4 ++-- src/connection.h | 12 ++++++++++++ src/networking.c | 43 +------------------------------------------ src/server.c | 8 ++++---- src/server.h | 3 +-- src/socket.c | 21 +++++++++++++++++++++ src/tls.c | 21 +++++++++++++++++++++ 7 files changed, 62 insertions(+), 50 deletions(-) diff --git a/src/config.c b/src/config.c index 7b8c4b75c..2540124cf 100644 --- a/src/config.c +++ b/src/config.c @@ -2430,7 +2430,7 @@ static int updateHZ(const char **err) { } static int updatePort(const char **err) { - if (changeListenPort(server.port, &server.ipfd, acceptTcpHandler) == C_ERR) { + if (changeListenPort(server.port, &server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) == C_ERR) { *err = "Unable to listen on this port. Check server logs."; return 0; } @@ -2591,7 +2591,7 @@ static int applyTLSPort(const char **err) { return 0; } - if (changeListenPort(server.tls_port, &server.tlsfd, acceptTLSHandler) == C_ERR) { + if (changeListenPort(server.tls_port, &server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) == C_ERR) { *err = "Unable to listen on this port. Check server logs."; return 0; } diff --git a/src/connection.h b/src/connection.h index 4fca50fd1..feb084629 100644 --- a/src/connection.h +++ b/src/connection.h @@ -36,8 +36,11 @@ #include #include +#include "ae.h" + #define CONN_INFO_LEN 32 #define CONN_ADDR_STR_LEN 128 /* Similar to INET6_ADDRSTRLEN, hoping to handle other protocols. */ +#define MAX_ACCEPTS_PER_CALL 1000 struct aeEventLoop; typedef struct connection connection; @@ -71,6 +74,7 @@ typedef struct ConnectionType { /* ae & accept & listen & error & address handler */ void (*ae_handler)(struct aeEventLoop *el, int fd, void *clientData, int mask); + aeFileProc *accept_handler; int (*addr)(connection *conn, char *ip, size_t ip_len, int *port, int remote); /* create/close connection */ @@ -381,6 +385,14 @@ int connTypeHasPendingData(void); /* walk all the connection types and process pending data for each connection type */ int connTypeProcessPendingData(void); +/* Get accept_handler of a connection type */ +static inline aeFileProc *connAcceptHandler(int type) { + ConnectionType *ct = connectionByType(type); + if (ct) + return ct->accept_handler; + return NULL; +} + int RedisRegisterConnectionTypeSocket(); int RedisRegisterConnectionTypeTLS(); diff --git a/src/networking.c b/src/networking.c index 45ff585e2..d71e30b71 100644 --- a/src/networking.c +++ b/src/networking.c @@ -1255,8 +1255,7 @@ void clientAcceptHandler(connection *conn) { c); } -#define MAX_ACCEPTS_PER_CALL 1000 -static void acceptCommonHandler(connection *conn, int flags, char *ip) { +void acceptCommonHandler(connection *conn, int flags, char *ip) { client *c; char conninfo[100]; UNUSED(ip); @@ -1328,46 +1327,6 @@ static void acceptCommonHandler(connection *conn, int flags, char *ip) { } } -void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask) { - int cport, cfd, max = MAX_ACCEPTS_PER_CALL; - char cip[NET_IP_STR_LEN]; - UNUSED(el); - UNUSED(mask); - UNUSED(privdata); - - while(max--) { - cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport); - if (cfd == ANET_ERR) { - if (errno != EWOULDBLOCK) - serverLog(LL_WARNING, - "Accepting client connection: %s", server.neterr); - return; - } - serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); - acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),0,cip); - } -} - -void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask) { - int cport, cfd, max = MAX_ACCEPTS_PER_CALL; - char cip[NET_IP_STR_LEN]; - UNUSED(el); - UNUSED(mask); - UNUSED(privdata); - - while(max--) { - cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport); - if (cfd == ANET_ERR) { - if (errno != EWOULDBLOCK) - serverLog(LL_WARNING, - "Accepting client connection: %s", server.neterr); - return; - } - serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); - acceptCommonHandler(connCreateAccepted(CONN_TYPE_TLS, cfd, &server.tls_auth_clients),0,cip); - } -} - void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask) { int cfd, max = MAX_ACCEPTS_PER_CALL; UNUSED(el); diff --git a/src/server.c b/src/server.c index 5c7c7ce6f..795010473 100644 --- a/src/server.c +++ b/src/server.c @@ -2586,10 +2586,10 @@ void initServer(void) { /* Create an event handler for accepting new connections in TCP and Unix * domain sockets. */ - if (createSocketAcceptHandler(&server.ipfd, acceptTcpHandler) != C_OK) { + if (createSocketAcceptHandler(&server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) != C_OK) { serverPanic("Unrecoverable error creating TCP socket accept handler."); } - if (createSocketAcceptHandler(&server.tlsfd, acceptTLSHandler) != C_OK) { + if (createSocketAcceptHandler(&server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) != C_OK) { serverPanic("Unrecoverable error creating TLS socket accept handler."); } if (createSocketAcceptHandler(&server.sofd, acceptUnixHandler) != C_OK) { @@ -6282,10 +6282,10 @@ int changeBindAddr(void) { } /* Create TCP and TLS event handlers */ - if (createSocketAcceptHandler(&server.ipfd, acceptTcpHandler) != C_OK) { + if (createSocketAcceptHandler(&server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) != C_OK) { serverPanic("Unrecoverable error creating TCP socket accept handler."); } - if (createSocketAcceptHandler(&server.tlsfd, acceptTLSHandler) != C_OK) { + if (createSocketAcceptHandler(&server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) != C_OK) { serverPanic("Unrecoverable error creating TLS socket accept handler."); } diff --git a/src/server.h b/src/server.h index 169ec4ec3..ad91d5218 100644 --- a/src/server.h +++ b/src/server.h @@ -2460,8 +2460,7 @@ void setDeferredSetLen(client *c, void *node, long length); void setDeferredAttributeLen(client *c, void *node, long length); void setDeferredPushLen(client *c, void *node, long length); int processInputBuffer(client *c); -void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask); -void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask); +void acceptCommonHandler(connection *conn, int flags, char *ip); void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask); void readQueryFromClient(connection *conn); int prepareClientToWrite(client *c); diff --git a/src/socket.c b/src/socket.c index 5aea1954e..962056c74 100644 --- a/src/socket.c +++ b/src/socket.c @@ -301,6 +301,26 @@ static void connSocketEventHandler(struct aeEventLoop *el, int fd, void *clientD } } +static void connSocketAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) { + int cport, cfd, max = MAX_ACCEPTS_PER_CALL; + char cip[NET_IP_STR_LEN]; + UNUSED(el); + UNUSED(mask); + UNUSED(privdata); + + while(max--) { + cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport); + if (cfd == ANET_ERR) { + if (errno != EWOULDBLOCK) + serverLog(LL_WARNING, + "Accepting client connection: %s", server.neterr); + return; + } + serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); + acceptCommonHandler(connCreateAcceptedSocket(cfd, NULL),0,cip); + } +} + static int connSocketAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) { if (anetFdToString(conn->fd, ip, ip_len, port, remote) == 0) return C_OK; @@ -360,6 +380,7 @@ static ConnectionType CT_Socket = { /* ae & accept & listen & error & address handler */ .ae_handler = connSocketEventHandler, + .accept_handler = connSocketAcceptHandler, .addr = connSocketAddr, /* create/close connection */ diff --git a/src/tls.c b/src/tls.c index 39108afed..c459ba8f2 100644 --- a/src/tls.c +++ b/src/tls.c @@ -719,6 +719,26 @@ static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, in tlsHandleEvent(conn, mask); } +static void tlsAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) { + int cport, cfd, max = MAX_ACCEPTS_PER_CALL; + char cip[NET_IP_STR_LEN]; + UNUSED(el); + UNUSED(mask); + UNUSED(privdata); + + while(max--) { + cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport); + if (cfd == ANET_ERR) { + if (errno != EWOULDBLOCK) + serverLog(LL_WARNING, + "Accepting client connection: %s", server.neterr); + return; + } + serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); + acceptCommonHandler(connCreateAcceptedTLS(cfd, &server.tls_auth_clients),0,cip); + } +} + static int connTLSAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) { return anetFdToString(conn->fd, ip, ip_len, port, remote); } @@ -1082,6 +1102,7 @@ static ConnectionType CT_TLS = { /* ae & accept & listen & error & address handler */ .ae_handler = tlsEventHandler, + .accept_handler = tlsAcceptHandler, .addr = connTLSAddr, /* create/close connection */