Preserve original fd blocking state in TLS I/O operations (#1298)
This change prevents unintended side effects on connection state and improves consistency with non-TLS sync operations. For example, when invoking `connTLSSyncRead` with a blocking file descriptor, the mode is switched to non-blocking upon `connTLSSyncRead` exit. If the code assumes the file descriptor remains blocking and calls the normal `read` expecting it to block, it may result in a short read. This caused a crash in dual-channel, which was fixed in this PR by relocating `connBlock()`: https://github.com/valkey-io/valkey/pull/837 Signed-off-by: xbasel <103044017+xbasel@users.noreply.github.com>
This commit is contained in:
parent
6038eda010
commit
b486a41500
30
src/anet.c
30
src/anet.c
@ -70,17 +70,24 @@ int anetGetError(int fd) {
|
||||
return sockerr;
|
||||
}
|
||||
|
||||
int anetSetBlock(char *err, int fd, int non_block) {
|
||||
static int anetGetSocketFlags(char *err, int fd) {
|
||||
int flags;
|
||||
|
||||
/* Set the socket blocking (if non_block is zero) or non-blocking.
|
||||
* Note that fcntl(2) for F_GETFL and F_SETFL can't be
|
||||
* interrupted by a signal. */
|
||||
if ((flags = fcntl(fd, F_GETFL)) == -1) {
|
||||
anetSetError(err, "fcntl(F_GETFL): %s", strerror(errno));
|
||||
return ANET_ERR;
|
||||
}
|
||||
|
||||
return flags;
|
||||
}
|
||||
|
||||
int anetSetBlock(char *err, int fd, int non_block) {
|
||||
int flags = anetGetSocketFlags(err, fd);
|
||||
|
||||
if (flags == ANET_ERR) {
|
||||
return ANET_ERR;
|
||||
}
|
||||
|
||||
/* Check if this flag has been set or unset, if so,
|
||||
* then there is no need to call fcntl to set/unset it again. */
|
||||
if (!!(flags & O_NONBLOCK) == !!non_block) return ANET_OK;
|
||||
@ -105,6 +112,21 @@ int anetBlock(char *err, int fd) {
|
||||
return anetSetBlock(err, fd, 0);
|
||||
}
|
||||
|
||||
int anetIsBlock(char *err, int fd) {
|
||||
int flags = anetGetSocketFlags(err, fd);
|
||||
|
||||
if (flags == ANET_ERR) {
|
||||
return ANET_ERR;
|
||||
}
|
||||
|
||||
/* Check if the O_NONBLOCK flag is set */
|
||||
if (flags & O_NONBLOCK) {
|
||||
return 0; /* Socket is non-blocking */
|
||||
} else {
|
||||
return 1; /* Socket is blocking */
|
||||
}
|
||||
}
|
||||
|
||||
/* Enable the FD_CLOEXEC on the given fd to avoid fd leaks.
|
||||
* This function should be invoked for fd's on specific places
|
||||
* where fork + execve system calls are called. */
|
||||
|
@ -61,6 +61,7 @@ int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port)
|
||||
int anetUnixAccept(char *err, int serversock);
|
||||
int anetNonBlock(char *err, int fd);
|
||||
int anetBlock(char *err, int fd);
|
||||
int anetIsBlock(char *err, int fd);
|
||||
int anetCloexec(int fd);
|
||||
int anetEnableTcpNoDelay(char *err, int fd);
|
||||
int anetDisableTcpNoDelay(char *err, int fd);
|
||||
|
21
src/tls.c
21
src/tls.c
@ -974,6 +974,10 @@ static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func)
|
||||
return C_OK;
|
||||
}
|
||||
|
||||
static int isBlocking(tls_connection *conn) {
|
||||
return anetIsBlock(NULL, conn->c.fd);
|
||||
}
|
||||
|
||||
static void setBlockingTimeout(tls_connection *conn, long long timeout) {
|
||||
anetBlock(NULL, conn->c.fd);
|
||||
anetSendTimeout(NULL, conn->c.fd, timeout);
|
||||
@ -1012,27 +1016,31 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,
|
||||
|
||||
static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long long timeout) {
|
||||
tls_connection *conn = (tls_connection *)conn_;
|
||||
|
||||
int blocking = isBlocking(conn);
|
||||
setBlockingTimeout(conn, timeout);
|
||||
SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||
ERR_clear_error();
|
||||
int ret = SSL_write(conn->ssl, ptr, size);
|
||||
ret = updateStateAfterSSLIO(conn, ret, 0);
|
||||
SSL_set_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
|
||||
unsetBlockingTimeout(conn);
|
||||
if (!blocking) {
|
||||
unsetBlockingTimeout(conn);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) {
|
||||
tls_connection *conn = (tls_connection *)conn_;
|
||||
|
||||
int blocking = isBlocking(conn);
|
||||
setBlockingTimeout(conn, timeout);
|
||||
ERR_clear_error();
|
||||
int ret = SSL_read(conn->ssl, ptr, size);
|
||||
updateSSLPendingFlag(conn);
|
||||
ret = updateStateAfterSSLIO(conn, ret, 0);
|
||||
unsetBlockingTimeout(conn);
|
||||
if (!blocking) {
|
||||
unsetBlockingTimeout(conn);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
@ -1041,6 +1049,7 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
|
||||
tls_connection *conn = (tls_connection *)conn_;
|
||||
ssize_t nread = 0;
|
||||
|
||||
int blocking = isBlocking(conn);
|
||||
setBlockingTimeout(conn, timeout);
|
||||
|
||||
size--;
|
||||
@ -1067,7 +1076,9 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
|
||||
size--;
|
||||
}
|
||||
exit:
|
||||
unsetBlockingTimeout(conn);
|
||||
if (!blocking) {
|
||||
unsetBlockingTimeout(conn);
|
||||
}
|
||||
return nread;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user