diff --git a/src/db.cpp b/src/db.cpp index fc786f095..589181e58 100644 --- a/src/db.cpp +++ b/src/db.cpp @@ -1183,17 +1183,10 @@ void scanGenericCommand(client *c, robj_roptr o, unsigned long cursor) { if (o == nullptr && count >= 100) { // Do an async version - const redisDbPersistentDataSnapshot *snapshot = nullptr; - if (!(c->flags & (CLIENT_MULTI | CLIENT_BLOCKED))) - snapshot = c->db->createSnapshot(c->mvccCheckpoint, false /* fOptional */); - if (snapshot != nullptr) - { - aeEventLoop *el = serverTL->el; - blockClient(c, BLOCKED_ASYNC); - redisDb *db = c->db; - sds patCopy = pat ? sdsdup(pat) : nullptr; - sds typeCopy = type ? sdsdup(type) : nullptr; - g_pserver->asyncworkqueue->AddWorkFunction([c, snapshot, cursor, count, keys, el, db, patCopy, typeCopy, use_pattern]{ + if (c->asyncCommand( + [c, keys, pat, type, cursor, count, use_pattern] (const redisDbPersistentDataSnapshot *snapshot, const std::vector &) { + sds patCopy = pat ? sdsdup(pat) : nullptr; + sds typeCopy = type ? sdsdup(type) : nullptr; auto cursorResult = snapshot->scan_threadsafe(cursor, count, typeCopy, keys); if (use_pattern) { listNode *ln = listFirst(keys); @@ -1214,30 +1207,17 @@ void scanGenericCommand(client *c, robj_roptr o, unsigned long cursor) { sdsfree(patCopy); if (typeCopy != nullptr) sdsfree(typeCopy); - - aePostFunction(el, [c, snapshot, keys, db, cursorResult, use_pattern]{ - aeReleaseLock(); // we need to lock with coordination of the client - - std::unique_locklock)> lock(c->lock); - AeLocker locker; - locker.arm(c); - - unblockClient(c); - mstime_t timeScanFilter; - latencyStartMonitor(timeScanFilter); - scanFilterAndReply(c, keys, nullptr, nullptr, false, nullptr, cursorResult); - latencyEndMonitor(timeScanFilter); - latencyAddSampleIfNeeded("scan-async-filter", timeScanFilter); - - locker.disarm(); - lock.unlock(); - - db->endSnapshotAsync(snapshot); - listSetFreeMethod(keys,decrRefCountVoid); - listRelease(keys); - aeAcquireLock(); - }); - }); + mstime_t timeScanFilter; + latencyStartMonitor(timeScanFilter); + scanFilterAndReply(c, keys, nullptr, nullptr, false, nullptr, cursorResult); + latencyEndMonitor(timeScanFilter); + latencyAddSampleIfNeeded("scan-async-filter", timeScanFilter); + }, + [keys] (const redisDbPersistentDataSnapshot *) { + listSetFreeMethod(keys,decrRefCountVoid); + listRelease(keys); + } + )) { return; } } diff --git a/src/server.cpp b/src/server.cpp index 0b2c5ed68..4f330ddd7 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -4952,6 +4952,46 @@ bool client::postFunction(std::function fn, bool fLock) { }, fLock) == AE_OK; } +std::vector clientArgs(client *c) { + std::vector args; + for (int j = 1; j < c->argc; j++) { + args.push_back(robj_sharedptr(c->argv[j])); + } + return args; +} + +bool client::asyncCommand(std::function &)> &&mainFn, + std::function &&postFn) +{ + serverAssert(FCorrectThread(this)); + const redisDbPersistentDataSnapshot *snapshot = nullptr; + if (!(this->flags & (CLIENT_MULTI | CLIENT_BLOCKED))) + snapshot = this->db->createSnapshot(this->mvccCheckpoint, false /* fOptional */); + if (snapshot == nullptr) { + return false; + } + aeEventLoop *el = serverTL->el; + blockClient(this, BLOCKED_ASYNC); + g_pserver->asyncworkqueue->AddWorkFunction([el, this, mainFn, postFn, snapshot] { + std::vector args = clientArgs(this); + aePostFunction(el, [this, mainFn, postFn, snapshot, args] { + aeReleaseLock(); + std::unique_locklock)> lock(this->lock); + AeLocker locker; + locker.arm(this); + unblockClient(this); + mainFn(snapshot, args); + locker.disarm(); + lock.unlock(); + if (postFn) + postFn(snapshot); + this->db->endSnapshotAsync(snapshot); + aeAcquireLock(); + }); + }); + return true; +} + /* ====================== Error lookup and execution ===================== */ void incrementErrorCount(const char *fullerr, size_t namelen) { diff --git a/src/server.h b/src/server.h index adf16dc41..70709efc4 100644 --- a/src/server.h +++ b/src/server.h @@ -1661,6 +1661,8 @@ struct client { // post a function from a non-client thread to run on its client thread bool postFunction(std::function fn, bool fLock = true); size_t argv_len_sum() const; + bool asyncCommand(std::function &)> &&mainFn, + std::function &&postFn = nullptr); }; struct saveparam { diff --git a/src/t_string.cpp b/src/t_string.cpp index c1f2e134a..a09cfd919 100644 --- a/src/t_string.cpp +++ b/src/t_string.cpp @@ -29,6 +29,7 @@ #include "server.h" #include /* isnan(), isinf() */ +#include "aelocker.h" /* Forward declarations */ int getGenericCommand(client *c); @@ -523,24 +524,37 @@ void getrangeCommand(client *c) { } } -void mgetCommand(client *c) { - int j; - - addReplyArrayLen(c,c->argc-1); - for (j = 1; j < c->argc; j++) { - robj_roptr o = lookupKeyRead(c->db,c->argv[j]); - if (o == nullptr) { +void mgetCore(client *c, robj **keys, int count, const redisDbPersistentDataSnapshot *snapshot = nullptr) { + addReplyArrayLen(c,count); + for (int i = 0; i < count; i++) { + robj_roptr o; + if (snapshot) + o = snapshot->find_cached_threadsafe(szFromObj(keys[i])).val(); + else + o = lookupKeyRead(c->db,keys[i]); + if (o == nullptr || o->type != OBJ_STRING) { addReplyNull(c); } else { - if (o->type != OBJ_STRING) { - addReplyNull(c); - } else { - addReplyBulk(c,o); - } + addReplyBulk(c,o); } } } +void mgetCommand(client *c) { + // Do async version for large number of arguments + if (c->argc > 100) { + if (c->asyncCommand( + [c] (const redisDbPersistentDataSnapshot *snapshot, const std::vector &keys) { + mgetCore(c, (robj **)keys.data(), keys.size(), snapshot); + } + )) { + return; + } + } + + mgetCore(c, c->argv + 1, c->argc - 1); +} + void msetGenericCommand(client *c, int nx) { int j;