Introduce TLS specified APIs

Introduce .get_peer_cert, .get_ctx and .get_client_ctx for TLS, also
hide redis_tls_ctx & redis_tls_client_ctx.

Then outside could access the variables by connection API only:
- redis_tls_ctx -> connTypeGetCtx(CONN_TYPE_TLS)
- redis_tls_client_ctx -> connTypeGetClientCtx(CONN_TYPE_TLS)

Also remove connTLSGetPeerCert(), use connGetPeerCert() instead.

Signed-off-by: zhenwei pi <pizhenwei@bytedance.com>
This commit is contained in:
zhenwei pi 2022-06-14 19:17:28 +08:00
parent 709b55b09d
commit c4c02f8036
5 changed files with 56 additions and 15 deletions

View File

@ -132,3 +132,23 @@ int connTypeProcessPendingData(void) {
return ret;
}
void *connTypeGetCtx(int type) {
ConnectionType *ct = connectionByType(type);
if (ct && ct->get_ctx) {
return ct->get_ctx();
}
return NULL;
}
void *connTypeGetClientCtx(int type) {
ConnectionType *ct = connectionByType(type);
if (ct && ct->get_client_ctx) {
return ct->get_client_ctx();
}
return NULL;
}

View File

@ -95,6 +95,11 @@ typedef struct ConnectionType {
/* pending data */
int (*has_pending_data)(void);
int (*process_pending_data)(void);
/* TLS specified methods */
sds (*get_peer_cert)(struct connection *conn);
void* (*get_ctx)(void);
void* (*get_client_ctx)(void);
} ConnectionType;
struct connection {
@ -335,7 +340,17 @@ int connSendTimeout(connection *conn, long long ms);
int connRecvTimeout(connection *conn, long long ms);
/* Helpers for tls special considerations */
sds connTLSGetPeerCert(connection *conn);
void *connTypeGetCtx(int type);
void *connTypeGetClientCtx(int type);
/* Get cert for the secure connection */
static inline sds connGetPeerCert(connection *conn) {
if (conn->type->get_peer_cert) {
return conn->type->get_peer_cert(conn);
}
return NULL;
}
/* Initialize the redis connection framework */
int connTypeInitialize();

View File

@ -8946,7 +8946,7 @@ RedisModuleString *RM_GetClientCertificate(RedisModuleCtx *ctx, uint64_t client_
client *c = lookupClientByID(client_id);
if (c == NULL) return NULL;
sds cert = connTLSGetPeerCert(c->conn);
sds cert = connGetPeerCert(c->conn);
if (!cert) return NULL;
RedisModuleString *s = createObject(OBJ_STRING, cert);

View File

@ -44,11 +44,6 @@
extern char **environ;
#ifdef USE_OPENSSL
extern SSL_CTX *redis_tls_ctx;
extern SSL_CTX *redis_tls_client_ctx;
#endif
#define REDIS_SENTINEL_PORT 26379
/* ======================== Sentinel global state =========================== */
@ -2381,6 +2376,9 @@ static int instanceLinkNegotiateTLS(redisAsyncContext *context) {
#ifndef USE_OPENSSL
(void) context;
#else
SSL_CTX *redis_tls_ctx = connTypeGetCtx(CONN_TYPE_TLS);
SSL_CTX *redis_tls_client_ctx = connTypeGetClientCtx(CONN_TYPE_TLS);
if (!redis_tls_ctx) return C_ERR;
SSL *ssl = SSL_new(redis_tls_client_ctx ? redis_tls_client_ctx : redis_tls_ctx);
if (!ssl) return C_ERR;

View File

@ -58,8 +58,8 @@
extern ConnectionType CT_Socket;
SSL_CTX *redis_tls_ctx = NULL;
SSL_CTX *redis_tls_client_ctx = NULL;
static SSL_CTX *redis_tls_ctx = NULL;
static SSL_CTX *redis_tls_client_ctx = NULL;
static int parseProtocolsConfig(const char *str) {
int i, count = 0;
@ -1043,7 +1043,7 @@ static int tlsProcessPendingData() {
/* Fetch the peer certificate used for authentication on the specified
* connection and return it as a PEM-encoded sds.
*/
sds connTLSGetPeerCert(connection *conn_) {
static sds connTLSGetPeerCert(connection *conn_) {
tls_connection *conn = (tls_connection *) conn_;
if (conn_->type->get_type(conn_) != CONN_TYPE_TLS || !conn->ssl) return NULL;
@ -1064,6 +1064,14 @@ sds connTLSGetPeerCert(connection *conn_) {
return cert_pem;
}
static void *tlsGetCtx(void) {
return redis_tls_ctx;
}
static void *tlsGetClientCtx(void) {
return redis_tls_client_ctx;
}
ConnectionType CT_TLS = {
/* connection type */
.get_type = connTLSGetType,
@ -1099,6 +1107,11 @@ ConnectionType CT_TLS = {
/* pending data */
.has_pending_data = tlsHasPendingData,
.process_pending_data = tlsProcessPendingData,
/* TLS specified methods */
.get_peer_cert = connTLSGetPeerCert,
.get_ctx = tlsGetCtx,
.get_client_ctx = tlsGetClientCtx
};
int RedisRegisterConnectionTypeTLS()
@ -1124,9 +1137,4 @@ connection *connCreateAcceptedTLS(int fd, int require_auth) {
return NULL;
}
sds connTLSGetPeerCert(connection *conn_) {
(void) conn_;
return NULL;
}
#endif