diff --git a/src/AsyncWorkQueue.cpp b/src/AsyncWorkQueue.cpp index 91fc2d2d2..be85f5ac2 100644 --- a/src/AsyncWorkQueue.cpp +++ b/src/AsyncWorkQueue.cpp @@ -73,6 +73,18 @@ bool AsyncWorkQueue::removeClientAsyncWrites(client *c) return fFound; } +void AsyncWorkQueue::shutdown() +{ + std::unique_lock lock(m_mutex); + serverAssert(!GlobalLocksAcquired()); + m_fQuitting = true; + m_cvWakeup.notify_all(); + lock.unlock(); + + for (auto &thread : m_vecthreads) + thread.join(); +} + void AsyncWorkQueue::abandonThreads() { std::unique_lock lock(m_mutex); diff --git a/src/AsyncWorkQueue.h b/src/AsyncWorkQueue.h index 2a413404d..1f019324a 100644 --- a/src/AsyncWorkQueue.h +++ b/src/AsyncWorkQueue.h @@ -34,5 +34,7 @@ public: void AddWorkFunction(std::function &&fnAsync, bool fHiPri = false); bool removeClientAsyncWrites(struct client *c); + void shutdown(); + void abandonThreads(); }; \ No newline at end of file diff --git a/src/db.cpp b/src/db.cpp index 62d52a7ca..7d1b37672 100644 --- a/src/db.cpp +++ b/src/db.cpp @@ -950,6 +950,24 @@ int parseScanCursorOrReply(client *c, robj *o, unsigned long *cursor) { return C_OK; } + +static bool filterKey(robj_roptr kobj, sds pat, int patlen) +{ + bool filter = false; + if (sdsEncodedObject(kobj)) { + if (!stringmatchlen(pat, patlen, szFromObj(kobj), sdslen(szFromObj(kobj)), 0)) + filter = true; + } else { + char buf[LONG_STR_SIZE]; + int len; + + serverAssert(kobj->encoding == OBJ_ENCODING_INT); + len = ll2string(buf,sizeof(buf),(long)ptrFromObj(kobj)); + if (!stringmatchlen(pat, patlen, buf, len, 0)) filter = true; + } + return filter; +} + /* This command implements SCAN, HSCAN and SSCAN commands. 
* If object 'o' is passed, then it must be a Hash, Set or Zset object, otherwise * if 'o' is NULL the command will operate on the dictionary associated with @@ -961,10 +979,10 @@ int parseScanCursorOrReply(client *c, robj *o, unsigned long *cursor) { * * In the case of a Hash object the function returns both the field and value * of every element on the Hash. */ +void scanFilterAndReply(client *c, list *keys, sds pat, sds type, int use_pattern, robj_roptr o, unsigned long cursor); void scanGenericCommand(client *c, robj_roptr o, unsigned long cursor) { int i, j; list *keys = listCreate(); - listNode *node, *nextnode; long count = 10; sds pat = NULL; sds type = NULL; @@ -1014,6 +1032,59 @@ void scanGenericCommand(client *c, robj_roptr o, unsigned long cursor) { } } + if (o == nullptr && count >= 100) + { + // Do an async version + const redisDbPersistentDataSnapshot *snapshot = nullptr; + if (!(c->flags & (CLIENT_MULTI | CLIENT_BLOCKED))) + snapshot = c->db->createSnapshot(c->mvccCheckpoint, true /* fOptional */); + if (snapshot != nullptr) + { + aeEventLoop *el = serverTL->el; + blockClient(c, BLOCKED_ASYNC); + redisDb *db = c->db; + sds patCopy = pat ? sdsdup(pat) : nullptr; + sds typeCopy = type ? 
sdsdup(type) : nullptr; + g_pserver->asyncworkqueue->AddWorkFunction([c, snapshot, cursor, count, keys, el, db, patCopy, typeCopy, use_pattern]{ + auto cursorResult = snapshot->scan_threadsafe(cursor, count, typeCopy, keys); + if (use_pattern) { + listNode *ln = listFirst(keys); + int patlen = sdslen(patCopy); + while (ln != nullptr) + { + listNode *next = ln->next; + if (filterKey((robj*)listNodeValue(ln), patCopy, patlen)) + { + listDelNode(keys, ln); + } + ln = next; + } + } + if (patCopy != nullptr) + sdsfree(patCopy); + if (typeCopy != nullptr) + sdsfree(typeCopy); + + aePostFunction(el, [c, snapshot, keys, db, cursorResult, use_pattern]{ + aeReleaseLock(); // we need to lock with coordination of the client + + std::unique_locklock)> lock(c->lock); + AeLocker locker; + locker.arm(c); + + unblockClient(c); + scanFilterAndReply(c, keys, nullptr, nullptr, false, nullptr, cursorResult); + + db->endSnapshot(snapshot); + listSetFreeMethod(keys,decrRefCountVoid); + listRelease(keys); + aeAcquireLock(); + }); + }); + return; + } + } + /* Step 2: Iterate the collection. * * Note that if the object is encoded with a ziplist, intset, or any other @@ -1038,23 +1109,30 @@ void scanGenericCommand(client *c, robj_roptr o, unsigned long cursor) { } if (ht) { - void *privdata[2]; - /* We set the max number of iterations to ten times the specified - * COUNT, so if the hash table is in a pathological state (very - * sparsely populated) we avoid to block too much time at the cost - * of returning no or very few elements. */ - long maxiterations = count*10; + if (ht == c->db->dictUnsafeKeyOnly()) + { + cursor = c->db->scan_threadsafe(cursor, count, nullptr, keys); + } + else + { + void *privdata[2]; + /* We set the max number of iterations to ten times the specified + * COUNT, so if the hash table is in a pathological state (very + * sparsely populated) we avoid to block too much time at the cost + * of returning no or very few elements. 
*/ + long maxiterations = count*10; - /* We pass two pointers to the callback: the list to which it will - * add new elements, and the object containing the dictionary so that - * it is possible to fetch more data in a type-dependent way. */ - privdata[0] = keys; - privdata[1] = o.unsafe_robjcast(); - do { - cursor = dictScan(ht, cursor, scanCallback, NULL, privdata); - } while (cursor && - maxiterations-- && - listLength(keys) < (unsigned long)count); + /* We pass two pointers to the callback: the list to which it will + * add new elements, and the object containing the dictionary so that + * it is possible to fetch more data in a type-dependent way. */ + privdata[0] = keys; + privdata[1] = o.unsafe_robjcast(); + do { + cursor = dictScan(ht, cursor, scanCallback, NULL, privdata); + } while (cursor && + maxiterations-- && + listLength(keys) < (unsigned long)count); + } } else if (o->type == OBJ_SET) { int pos = 0; int64_t ll; @@ -1080,6 +1158,18 @@ void scanGenericCommand(client *c, robj_roptr o, unsigned long cursor) { serverPanic("Not handled encoding in SCAN."); } + scanFilterAndReply(c, keys, pat, type, use_pattern, o, cursor); + +cleanup: + listSetFreeMethod(keys,decrRefCountVoid); + listRelease(keys); +} + +void scanFilterAndReply(client *c, list *keys, sds pat, sds type, int use_pattern, robj_roptr o, unsigned long cursor) +{ + listNode *node, *nextnode; + int patlen = (pat != nullptr) ? sdslen(pat) : 0; + /* Step 3: Filter elements. */ node = listFirst(keys); while (node) { @@ -1089,17 +1179,8 @@ void scanGenericCommand(client *c, robj_roptr o, unsigned long cursor) { /* Filter element if it does not match the pattern. 
*/ if (!filter && use_pattern) { - if (sdsEncodedObject(kobj)) { - if (!stringmatchlen(pat, patlen, szFromObj(kobj), sdslen(szFromObj(kobj)), 0)) - filter = 1; - } else { - char buf[LONG_STR_SIZE]; - int len; - - serverAssert(kobj->encoding == OBJ_ENCODING_INT); - len = ll2string(buf,sizeof(buf),(long)ptrFromObj(kobj)); - if (!stringmatchlen(pat, patlen, buf, len, 0)) filter = 1; - } + if (filterKey(kobj, pat, patlen)) + filter = 1; } /* Filter an element if it isn't the type we want. */ @@ -1144,10 +1225,6 @@ void scanGenericCommand(client *c, robj_roptr o, unsigned long cursor) { decrRefCount(kobj); listDelNode(keys, node); } - -cleanup: - listSetFreeMethod(keys,decrRefCountVoid); - listRelease(keys); } /* The SCAN command completely relies on scanGenericCommand. */ diff --git a/src/dict.cpp b/src/dict.cpp index 0e8827165..ef7365fdb 100644 --- a/src/dict.cpp +++ b/src/dict.cpp @@ -952,7 +952,7 @@ unsigned long dictScan(dict *d, /* Having a safe iterator means no rehashing can happen, see _dictRehashStep. * This is needed in case the scan callback tries to do dictFind or alike. 
*/ - d->iterators++; + __atomic_fetch_add(&d->iterators, 1, __ATOMIC_SEQ_CST); if (!dictIsRehashing(d)) { t0 = &(d->ht[0]); @@ -1021,7 +1021,7 @@ unsigned long dictScan(dict *d, } /* undo the ++ at the top */ - d->iterators--; + __atomic_fetch_sub(&d->iterators, 1, __ATOMIC_SEQ_CST); return v; } diff --git a/src/server.cpp b/src/server.cpp index ce817947d..490490317 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -4097,11 +4097,13 @@ int processCommand(client *c, int callFlags, AeLocker &locker) { queueMultiCommand(c); addReply(c,shared.queued); } else { +#if 0 if (cserver.cthreads >= 2 && !g_fTestMode && g_pserver->m_pstorageFactory == nullptr && listLength(g_pserver->monitors) == 0 && c->cmd->proc == getCommand) { if (getCommandAsync(c)) return C_OK; } +#endif locker.arm(c); incrementMvccTstamp(); call(c,callFlags); @@ -4226,6 +4228,13 @@ int prepareForShutdown(int flags) { /* Close the listening sockets. Apparently this allows faster restarts. */ closeListeningSockets(1); + if (g_pserver->asyncworkqueue) + { + aeReleaseLock(); + g_pserver->asyncworkqueue->shutdown(); + aeAcquireLock(); + } + for (int iel = 0; iel < cserver.cthreads; ++iel) { aePostFunction(g_pserver->rgthreadvar[iel].el, [iel]{ diff --git a/src/server.h b/src/server.h index 008b47b6e..0632af792 100644 --- a/src/server.h +++ b/src/server.h @@ -1396,6 +1396,8 @@ public: bool FWillFreeChildDebug() const { return m_spdbSnapshotHOLDER != nullptr; } bool iterate_threadsafe(std::function fn, bool fKeyOnly = false, bool fCacheOnly = false) const; + unsigned long scan_threadsafe(unsigned long iterator, long count, sds type, list *keys) const; + using redisDbPersistentData::createSnapshot; using redisDbPersistentData::endSnapshot; using redisDbPersistentData::endSnapshotAsync; diff --git a/src/snapshot.cpp b/src/snapshot.cpp index 0f5bb079e..38ab9b7b2 100644 --- a/src/snapshot.cpp +++ b/src/snapshot.cpp @@ -61,8 +61,11 @@ const redisDbPersistentDataSnapshot *redisDbPersistentData::createSnapshot(uint6 
} // See if we have too many levels and can bail out of this to reduce load - if (fOptional && (levels >= 4)) + if (fOptional && (levels >= 6)) + { + serverLog(LL_DEBUG, "Snapshot nesting too deep, abandoning"); + return nullptr; + } auto spdb = std::unique_ptr(new (MALLOC_LOCAL) redisDbPersistentDataSnapshot()); @@ -394,6 +397,62 @@ dict_iter redisDbPersistentData::find_cached_threadsafe(const char *key) const return dict_iter(de); } +struct scan_callback_data +{ + dict *dictTombstone; + sds type; + list *keys; +}; +void snapshot_scan_callback(void *privdata, const dictEntry *de) +{ + scan_callback_data *data = (scan_callback_data*)privdata; + if (data->dictTombstone != nullptr && dictFind(data->dictTombstone, dictGetKey(de)) != nullptr) + return; + + sds sdskey = (sds)dictGetKey(de); + if (data->type != nullptr) + { + if (strcasecmp(data->type, getObjectTypeName((robj*)dictGetVal(de))) != 0) + return; + } + listAddNodeHead(data->keys, createStringObject(sdskey, sdslen(sdskey))); +} +unsigned long redisDbPersistentDataSnapshot::scan_threadsafe(unsigned long iterator, long count, sds type, list *keys) const +{ + unsigned long iteratorReturn = 0; + + scan_callback_data data; + data.dictTombstone = m_pdictTombstone; + data.keys = keys; + data.type = type; + + const redisDbPersistentDataSnapshot *psnapshot; + __atomic_load(&m_pdbSnapshot, &psnapshot, __ATOMIC_ACQUIRE); + if (psnapshot != nullptr) + { + // Always process the snapshot first as we assume it's bigger than we are + iteratorReturn = psnapshot->scan_threadsafe(iterator, count, type, keys); + + // Just catch up with our snapshot + do + { + iterator = dictScan(m_pdict, iterator, snapshot_scan_callback, nullptr, &data); + } while (iterator != 0 && (iterator < iteratorReturn || iteratorReturn == 0)); + } + else + { + long maxiterations = count * 10; // allow more iterations than keys for sparse tables + iteratorReturn = iterator; + do { + iteratorReturn = dictScan(m_pdict, iteratorReturn, snapshot_scan_callback, 
NULL, &data); + } while (iteratorReturn && + maxiterations-- && + listLength(keys) < (unsigned long)count); + } + + return iteratorReturn; +} + bool redisDbPersistentDataSnapshot::iterate_threadsafe(std::function fn, bool fKeyOnly, bool fCacheOnly) const { // Take the size so we can ensure we visited every element exactly once