Fix a race condition in PUB/SUB and other async reply commands where the client can be freed before our handler is executed on the client thread. When this occurs, the handler is left with a dangling client pointer.

Former-commit-id: fad9483fc920e5b1fa67e56d4b8483138b565bd3
John Sully 2019-08-26 20:18:52 -04:00
parent 8d36bab0e1
commit cb6abcf3fa
6 changed files with 113 additions and 74 deletions
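
For context, the shape of the race and of the fix: a reply generated on one thread posts work to the client's own thread; if the client is freed before that work runs, the posted lambda dereferences a dangling pointer. The diff below guards against this with a pending-operation counter (casyncOpsPending) that forces freeClient() onto the async-free path while cross-thread work is in flight, and drops queued replies for clients already flagged CLIENT_CLOSE_ASAP. The sketch that follows is illustrative only, assuming a single global mutex stands in for the AE/global locks; Client, pendingOps, and postToClientThread are hypothetical names, not KeyDB's real types.

#include <deque>
#include <functional>
#include <mutex>

std::mutex g_lock;   // stand-in for the global/AE lock assumed to guard everything below

struct Client {
    int  pendingOps = 0;     // like casyncOpsPending: cross-thread work still in flight
    bool closeAsap  = false; // like CLIENT_CLOSE_ASAP: free later via the async queue
};

// freeClient-style helper: refuse to free while another thread still holds a
// raw pointer to this client for a queued operation.
bool tryFreeClient(Client *c) {
    std::lock_guard<std::mutex> g(g_lock);
    if (c->pendingOps > 0) {
        c->closeAsap = true;    // picked up later, once pendingOps drops to zero
        return false;
    }
    delete c;
    return true;
}

// Posting work bumps the counter *before* the pointer escapes to the other
// thread, so tryFreeClient() cannot delete the client underneath the lambda.
void postToClientThread(Client *c, std::deque<std::function<void()>> &clientThreadQueue) {
    std::lock_guard<std::mutex> g(g_lock);
    ++c->pendingOps;
    clientThreadQueue.push_back([c] {
        std::lock_guard<std::mutex> g(g_lock);
        --c->pendingOps;
        // ... install the write handler for c here ...
    });
}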

.vscode/settings.json (vendored; deleted by this commit)

@@ -1,56 +0,0 @@
{
"files.associations": {
"zmalloc.h": "c",
"stat.h": "c",
"array": "cpp",
"atomic": "cpp",
"*.tcc": "cpp",
"cctype": "cpp",
"chrono": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"condition_variable": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"list": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"exception": "cpp",
"fstream": "cpp",
"functional": "cpp",
"future": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"memory": "cpp",
"mutex": "cpp",
"new": "cpp",
"numeric": "cpp",
"optional": "cpp",
"ostream": "cpp",
"ratio": "cpp",
"scoped_allocator": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"thread": "cpp",
"cinttypes": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"typeinfo": "cpp",
"utility": "cpp"
}
}


@@ -191,6 +191,36 @@ void aeProcessCmd(aeEventLoop *eventLoop, int fd, void *, int )
     }
 }
 
+// Unlike write() this is an all or nothing thing. We will block if a partial write is hit
+ssize_t safe_write(int fd, const void *pv, size_t cb)
+{
+    const char *pcb = (const char*)pv;
+    ssize_t written = 0;
+    do
+    {
+        ssize_t rval = write(fd, pcb, cb);
+        if (rval > 0)
+        {
+            pcb += rval;
+            cb -= rval;
+            written += rval;
+        }
+        else if (errno == EAGAIN)
+        {
+            if (written == 0)
+                break;
+            // if we've already written something then we're committed so keep trying
+        }
+        else
+        {
+            if (rval == 0)
+                return written;
+            return rval;
+        }
+    } while (cb);
+    return written;
+}
+
 int aeCreateRemoteFileEvent(aeEventLoop *eventLoop, int fd, int mask,
                         aeFileProc *proc, void *clientData, int fSynchronous)
 {
@@ -212,9 +242,10 @@ int aeCreateRemoteFileEvent(aeEventLoop *eventLoop, int fd, int mask,
     std::unique_lock<std::mutex> ulock(cmd.pctl->mutexcv, std::defer_lock);
     if (fSynchronous)
         cmd.pctl->mutexcv.lock();
-    auto size = write(eventLoop->fdCmdWrite, &cmd, sizeof(cmd));
+    auto size = safe_write(eventLoop->fdCmdWrite, &cmd, sizeof(cmd));
     if (size != sizeof(cmd))
     {
+        AE_ASSERT(size == sizeof(cmd) || size <= 0);
         AE_ASSERT(errno == EAGAIN);
         ret = AE_ERR;
     }
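
The call-site change above relies on safe_write() being all-or-nothing once it has started writing, as the new AE_ASSERT(size == sizeof(cmd) || size <= 0) encodes: a short return means nothing was sent (EAGAIN before any byte) or the write failed outright. A hypothetical caller sketched against the declaration above — Msg and postCommand are illustrative names, not part of the codebase:

#include <cstdio>
#include <sys/types.h>

// Declared here for the sketch; the definition is the one added above.
ssize_t safe_write(int fd, const void *pv, size_t cb);

struct Msg { int op; int fd; };   // hypothetical fixed-size command struct

// Returns true only if the whole struct reached the pipe.
bool postCommand(int fdCmdWrite, const Msg &m) {
    ssize_t size = safe_write(fdCmdWrite, &m, sizeof(m));
    if (size != sizeof(m)) {
        // size <= 0 here: either nothing was written (EAGAIN) or write() failed;
        // the helper does not leave a partially written command behind once it starts.
        perror("postCommand");
        return false;
    }
    return true;
}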


@@ -55,6 +55,8 @@ typedef ucontext_t sigcontext_t;
 #endif
 #endif
 
+bool g_fInCrash = false;
+
 /* ================================= Debugging ============================== */
 
 /* Compute the sha1 of string at 's' with 'len' bytes long.
@@ -1356,6 +1358,7 @@ void dumpX86Calls(void *addr, size_t len) {
 
 void sigsegvHandler(int sig, siginfo_t *info, void *secret) {
     ucontext_t *uc = (ucontext_t*) secret;
+    g_fInCrash = true;
     void *eip = getMcontextEip(uc);
     sds infostring, clients;
     struct sigaction act;


@@ -174,6 +174,7 @@ client *createClient(int fd, int iel) {
     c->bufAsync = NULL;
     c->buflenAsync = 0;
     c->bufposAsync = 0;
+    c->casyncOpsPending = 0;
     memset(c->uuid, 0, UUID_BINARY_LEN);
 
     listSetFreeMethod(c->pubsub_patterns,decrRefCountVoid);
@@ -1003,7 +1004,6 @@ static void acceptCommonHandler(int fd, int flags, char *ip, int iel) {
             serverLog(LL_WARNING,
                 "Error registering fd event for the new client: %s (fd=%d)",
                 strerror(errno),fd);
-            close(fd); /* May be already closed, just ignore errors */
             return;
         }
     }
@@ -1266,7 +1266,7 @@ void unlinkClient(client *c) {
     }
 }
 
-void freeClient(client *c) {
+bool freeClient(client *c) {
     listNode *ln;
     serverAssert(c->fd == -1 || GlobalLocksAcquired());
     AssertCorrectThread(c);
@@ -1274,9 +1274,9 @@ void freeClient(client *c) {
 
     /* If a client is protected, yet we need to free it right now, make sure
      * to at least use asynchronous freeing. */
-    if (c->flags & CLIENT_PROTECTED) {
+    if (c->flags & CLIENT_PROTECTED || c->casyncOpsPending) {
         freeClientAsync(c);
-        return;
+        return false;
     }
 
     /* If it is our master that's beging disconnected we should make sure
@@ -1291,7 +1291,7 @@ void freeClient(client *c) {
             CLIENT_BLOCKED)))
         {
             replicationCacheMaster(MasterInfoFromClient(c), c);
-            return;
+            return false;
         }
     }
 
@@ -1370,6 +1370,7 @@ void freeClient(client *c) {
     ulock.unlock();
     fastlock_free(&c->lock);
     zfree(c);
+    return true;
 }
 
 /* Schedule a client to free it at a safe time in the serverCron() function.
@@ -1382,28 +1383,37 @@ void freeClientAsync(client *c) {
      * may access the list while Redis uses I/O threads. All the other accesses
      * are in the context of the main thread while the other threads are
      * idle. */
-    if (c->flags & CLIENT_CLOSE_ASAP || c->flags & CLIENT_LUA) return;
+    if (c->flags & CLIENT_CLOSE_ASAP || c->flags & CLIENT_LUA) return; // check without the lock first
     std::lock_guard<decltype(c->lock)> clientlock(c->lock);
     AeLocker lock;
     lock.arm(c);
+    if (c->flags & CLIENT_CLOSE_ASAP || c->flags & CLIENT_LUA) return; // race condition after we acquire the lock
     c->flags |= CLIENT_CLOSE_ASAP;
     listAddNodeTail(g_pserver->clients_to_close,c);
 }
 
 void freeClientsInAsyncFreeQueue(int iel) {
+    serverAssert(GlobalLocksAcquired());
     listIter li;
     listNode *ln;
     listRewind(g_pserver->clients_to_close,&li);
 
-    while((ln = listNext(&li))) {
+    // Store the clients in a temp vector since freeClient will modify this list
+    std::vector<client*> vecclientsFree;
+    while((ln = listNext(&li)))
+    {
         client *c = (client*)listNodeValue(ln);
-        if (c->iel != iel)
-            continue; // wrong thread
+        if (c->iel == iel)
+        {
+            vecclientsFree.push_back(c);
+            listDelNode(g_pserver->clients_to_close, ln);
+        }
+    }
 
+    for (client *c : vecclientsFree)
+    {
         c->flags &= ~CLIENT_CLOSE_ASAP;
         freeClient(c);
-        listDelNode(g_pserver->clients_to_close,ln);
-        listRewind(g_pserver->clients_to_close,&li);
     }
 }
 
@@ -1551,6 +1561,15 @@ void ProcessPendingAsyncWrites()
         std::lock_guard<decltype(c->lock)> lock(c->lock);
 
         serverAssert(c->fPendingAsyncWrite);
+        if (c->flags & (CLIENT_CLOSE_ASAP | CLIENT_CLOSE_AFTER_REPLY))
+        {
+            c->bufposAsync = 0;
+            c->buflenAsync = 0;
+            zfree(c->bufAsync);
+            c->bufAsync = nullptr;
+            c->fPendingAsyncWrite = FALSE;
+            continue;
+        }
 
         // TODO: Append to end of reply block?
@@ -1587,8 +1606,36 @@ void ProcessPendingAsyncWrites()
             continue;
 
         asyncCloseClientOnOutputBufferLimitReached(c);
 
-        if (aeCreateRemoteFileEvent(g_pserver->rgthreadvar[c->iel].el, c->fd, ae_flags, sendReplyToClient, c, FALSE) == AE_ERR)
-            continue; // We can retry later in the cron
+        if (c->flags & CLIENT_CLOSE_ASAP)
+            continue; // we will never write this so don't post an op
+
+        std::atomic_thread_fence(std::memory_order_seq_cst);
+
+        if (c->casyncOpsPending == 0)
+        {
+            if (FCorrectThread(c))
+            {
+                prepareClientToWrite(c, false); // queue an event
+            }
+            else
+            {
+                // We need to start the write on the client's thread
+                if (aePostFunction(g_pserver->rgthreadvar[c->iel].el, [c]{
+                        // Install a write handler. Don't do the actual write here since we don't want
+                        //  to duplicate the throttling and safety mechanisms of the normal write code
+                        std::lock_guard<decltype(c->lock)> lock(c->lock);
+                        serverAssert(c->casyncOpsPending > 0);
+                        c->casyncOpsPending--;
+                        aeCreateFileEvent(g_pserver->rgthreadvar[c->iel].el, c->fd, AE_WRITABLE|AE_WRITE_THREADSAFE, sendReplyToClient, c);
+                    }, false) == AE_ERR
+                    )
+                {
+                    // Posting the function failed
+                    continue; // We can retry later in the cron
+                }
+                ++c->casyncOpsPending; // race is handled by the client lock in the lambda
+            }
+        }
     }
 }
 
@@ -1628,13 +1675,15 @@ int handleClientsWithPendingWrites(int iel) {
         std::unique_lock<decltype(c->lock)> lock(c->lock);
 
         /* Try to write buffers to the client socket. */
-        if (writeToClient(c->fd,c,0) == C_ERR) {
+        if (writeToClient(c->fd,c,0) == C_ERR)
+        {
            if (c->flags & CLIENT_CLOSE_ASAP)
            {
                lock.release(); // still locked
                AeLocker ae;
                ae.arm(c);
-               freeClient(c); // writeToClient will only async close, but there's no need to wait
+               if (!freeClient(c)) // writeToClient will only async close, but there's no need to wait
+                   c->lock.unlock(); // if we just got put on the async close list, then we need to remove the lock
            }
            continue;
        }
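
The freeClientsInAsyncFreeQueue() rewrite above is a collect-then-process pattern: freeClient() can itself mutate clients_to_close (it re-queues the client via freeClientAsync() when it must defer), so the matching nodes are pulled into a local vector first and the list is walked only once. A generic sketch of the same idea, with hypothetical names (drainMatching, pending) and assuming single-threaded access to the list:

#include <list>
#include <vector>

// Pull the matching elements out first, then run a callback that is free to
// mutate the original list (e.g. by re-queueing an element).
template <typename T, typename Pred, typename Action>
void drainMatching(std::list<T> &pending, Pred matches, Action process) {
    std::vector<T> snapshot;
    for (auto it = pending.begin(); it != pending.end(); ) {
        if (matches(*it)) {
            snapshot.push_back(*it);
            it = pending.erase(it);   // erase returns the next valid iterator
        } else {
            ++it;
        }
    }
    for (T &item : snapshot)
        process(item);                // may push into `pending` without invalidating this loop
}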


@@ -143,6 +143,8 @@ int clientSubscriptionsCount(client *c) {
 /* Subscribe a client to a channel. Returns 1 if the operation succeeded, or
  * 0 if the client was already subscribed to that channel. */
 int pubsubSubscribeChannel(client *c, robj *channel) {
+    serverAssert(GlobalLocksAcquired());
+    serverAssert(c->lock.fOwnLock());
     dictEntry *de;
     list *clients = NULL;
     int retval = 0;
@@ -202,6 +204,7 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify) {
 /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */
 int pubsubSubscribePattern(client *c, robj *pattern) {
+    serverAssert(GlobalLocksAcquired());
     int retval = 0;
 
     if (listSearchKey(c->pubsub_patterns,pattern) == NULL) {
@@ -244,6 +247,7 @@ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) {
 /* Unsubscribe from all the channels. Return the number of channels the
  * client was subscribed to. */
 int pubsubUnsubscribeAllChannels(client *c, int notify) {
+    serverAssert(GlobalLocksAcquired());
     dictIterator *di = dictGetSafeIterator(c->pubsub_channels);
     dictEntry *de;
     int count = 0;
@@ -262,6 +266,7 @@ int pubsubUnsubscribeAllChannels(client *c, int notify) {
 /* Unsubscribe from all the patterns. Return the number of patterns the
  * client was subscribed from. */
 int pubsubUnsubscribeAllPatterns(client *c, int notify) {
+    serverAssert(GlobalLocksAcquired());
     listNode *ln;
     listIter li;
     int count = 0;
@@ -278,6 +283,7 @@ int pubsubUnsubscribeAllPatterns(client *c, int notify) {
 /* Publish a message */
 int pubsubPublishMessage(robj *channel, robj *message) {
+    serverAssert(GlobalLocksAcquired());
     int receivers = 0;
     dictEntry *de;
     listNode *ln;
@@ -293,6 +299,8 @@ int pubsubPublishMessage(robj *channel, robj *message) {
         listRewind(list,&li);
         while ((ln = listNext(&li)) != NULL) {
             client *c = reinterpret_cast<client*>(ln->value);
+            if (c->flags & CLIENT_CLOSE_ASAP) // avoid blocking if the write will be ignored
+                continue;
             fastlock_lock(&c->lock);
             addReplyPubsubMessage(c,channel,message);
             fastlock_unlock(&c->lock);
@@ -311,6 +319,8 @@ int pubsubPublishMessage(robj *channel, robj *message) {
                                 (char*)ptrFromObj(channel),
                                 sdslen(szFromObj(channel)),0))
             {
+                if (pat->pclient->flags & CLIENT_CLOSE_ASAP)
+                    continue;
                 fastlock_lock(&pat->pclient->lock);
                 addReplyPubsubPatMessage(pat->pclient,
                     pat->pattern,channel,message);


@@ -925,6 +925,7 @@ typedef struct client {
     time_t lastinteraction; /* Time of the last interaction, used for timeout */
     time_t obuf_soft_limit_reached_time;
     std::atomic<int> flags; /* Client flags: CLIENT_* macros. */
+    int casyncOpsPending;
     int fPendingAsyncWrite; /* NOTE: Not a flag because it is written to outside of the client lock (locked by the global lock instead) */
     int authenticated;      /* Needed when the default user requires auth. */
     int replstate;          /* Replication state if this is a slave. */
@@ -1694,7 +1695,7 @@ void redisSetProcTitle(const char *title);
 /* networking.c -- Networking and Client related operations */
 client *createClient(int fd, int iel);
 void closeTimedoutClients(void);
-void freeClient(client *c);
+bool freeClient(client *c);
 void freeClientAsync(client *c);
 void resetClient(client *c);
 void sendReplyToClient(aeEventLoop *el, int fd, void *privdata, int mask);
@@ -2503,9 +2504,10 @@ void xorDigest(unsigned char *digest, const void *ptr, size_t len);
 int populateCommandTableParseFlags(struct redisCommand *c, const char *strflags);
 int moduleGILAcquiredByModule(void);
 
+extern bool g_fInCrash;
 static inline int GlobalLocksAcquired(void)  // Used in asserts to verify all global locks are correctly acquired for a server-thread to operate
 {
-    return aeThreadOwnsLock() || moduleGILAcquiredByModule();
+    return aeThreadOwnsLock() || moduleGILAcquiredByModule() || g_fInCrash;
 }
 
 inline int ielFromEventLoop(const aeEventLoop *eventLoop)