Preserve original fd blocking state in TLS I/O operations (#1298)

This change prevents unintended side effects on connection state and
improves consistency with non-TLS sync operations.

For example, when invoking `connTLSSyncRead` with a blocking file
descriptor, the mode is switched to non-blocking upon `connTLSSyncRead`
exit. If the code assumes the file descriptor remains blocking and calls
the normal `read` expecting it to block, it may result in a short read.

This caused a crash in dual-channel, which was fixed in this PR by
relocating `connBlock()`:
https://github.com/valkey-io/valkey/pull/837

Signed-off-by: xbasel <103044017+xbasel@users.noreply.github.com>
This commit is contained in:
xbasel 2024-11-21 18:22:16 +02:00 committed by GitHub
parent 6038eda010
commit b486a41500
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 9 deletions

View File

@ -70,17 +70,24 @@ int anetGetError(int fd) {
return sockerr;
}
int anetSetBlock(char *err, int fd, int non_block) {
static int anetGetSocketFlags(char *err, int fd) {
int flags;
/* Set the socket blocking (if non_block is zero) or non-blocking.
* Note that fcntl(2) for F_GETFL and F_SETFL can't be
* interrupted by a signal. */
if ((flags = fcntl(fd, F_GETFL)) == -1) {
anetSetError(err, "fcntl(F_GETFL): %s", strerror(errno));
return ANET_ERR;
}
return flags;
}
int anetSetBlock(char *err, int fd, int non_block) {
int flags = anetGetSocketFlags(err, fd);
if (flags == ANET_ERR) {
return ANET_ERR;
}
/* Check if this flag has been set or unset, if so,
* then there is no need to call fcntl to set/unset it again. */
if (!!(flags & O_NONBLOCK) == !!non_block) return ANET_OK;
@ -105,6 +112,21 @@ int anetBlock(char *err, int fd) {
return anetSetBlock(err, fd, 0);
}
int anetIsBlock(char *err, int fd) {
int flags = anetGetSocketFlags(err, fd);
if (flags == ANET_ERR) {
return ANET_ERR;
}
/* Check if the O_NONBLOCK flag is set */
if (flags & O_NONBLOCK) {
return 0; /* Socket is non-blocking */
} else {
return 1; /* Socket is blocking */
}
}
/* Enable the FD_CLOEXEC on the given fd to avoid fd leaks.
* This function should be invoked for fd's on specific places
* where fork + execve system calls are called. */

View File

@ -61,6 +61,7 @@ int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port)
int anetUnixAccept(char *err, int serversock);
int anetNonBlock(char *err, int fd);
int anetBlock(char *err, int fd);
int anetIsBlock(char *err, int fd);
int anetCloexec(int fd);
int anetEnableTcpNoDelay(char *err, int fd);
int anetDisableTcpNoDelay(char *err, int fd);

View File

@ -974,6 +974,10 @@ static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func)
return C_OK;
}
static int isBlocking(tls_connection *conn) {
return anetIsBlock(NULL, conn->c.fd);
}
static void setBlockingTimeout(tls_connection *conn, long long timeout) {
anetBlock(NULL, conn->c.fd);
anetSendTimeout(NULL, conn->c.fd, timeout);
@ -1012,27 +1016,31 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,
static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long long timeout) {
tls_connection *conn = (tls_connection *)conn_;
int blocking = isBlocking(conn);
setBlockingTimeout(conn, timeout);
SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
ERR_clear_error();
int ret = SSL_write(conn->ssl, ptr, size);
ret = updateStateAfterSSLIO(conn, ret, 0);
SSL_set_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
unsetBlockingTimeout(conn);
if (!blocking) {
unsetBlockingTimeout(conn);
}
return ret;
}
static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) {
tls_connection *conn = (tls_connection *)conn_;
int blocking = isBlocking(conn);
setBlockingTimeout(conn, timeout);
ERR_clear_error();
int ret = SSL_read(conn->ssl, ptr, size);
updateSSLPendingFlag(conn);
ret = updateStateAfterSSLIO(conn, ret, 0);
unsetBlockingTimeout(conn);
if (!blocking) {
unsetBlockingTimeout(conn);
}
return ret;
}
@ -1041,6 +1049,7 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
tls_connection *conn = (tls_connection *)conn_;
ssize_t nread = 0;
int blocking = isBlocking(conn);
setBlockingTimeout(conn, timeout);
size--;
@ -1067,7 +1076,9 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l
size--;
}
exit:
unsetBlockingTimeout(conn);
if (!blocking) {
unsetBlockingTimeout(conn);
}
return nread;
}