diff --git a/src/db.cpp b/src/db.cpp index 2f053d975..586f87e01 100644 --- a/src/db.cpp +++ b/src/db.cpp @@ -647,7 +647,7 @@ bool redisDbPersistentData::iterate(std::function fn) } dictReleaseIterator(di); - if (m_pdbSnapshot != nullptr) + if (fResult && m_pdbSnapshot != nullptr) { fResult = m_pdbSnapshot->iterate([&](const char *key){ // Before passing off to the user we need to make sure it's not already in the @@ -671,7 +671,7 @@ bool redisDbPersistentData::iterate(std::function fn) bool redisDbPersistentData::iterate_threadsafe(std::function fn) const { - dictIterator *di = dictGetSafeIterator(m_pdict); + dictIterator *di = dictGetIterator(m_pdict); dictEntry *de = nullptr; bool fResult = true; @@ -685,7 +685,7 @@ bool redisDbPersistentData::iterate_threadsafe(std::functioniterate_threadsafe([&](const char *key, robj_roptr o){ // Before passing off to the user we need to make sure it's not already in the @@ -707,7 +707,7 @@ bool redisDbPersistentData::iterate_threadsafe(std::function fn) const { - dictIterator *di = dictGetSafeIterator(m_pdict); + dictIterator *di = dictGetIterator(m_pdict); dictEntry *de = nullptr; bool fResult = true; while((de = dictNext(di)) != nullptr) @@ -720,7 +720,7 @@ bool redisDbPersistentData::iterate(std::function fn) const } dictReleaseIterator(di); - if (m_pdbSnapshot != nullptr) + if (fResult && m_pdbSnapshot != nullptr) { fResult = m_pdbSnapshot->iterate([&](const char *key){ // Before passing off to the user we need to make sure it's not already in the @@ -740,14 +740,17 @@ bool redisDbPersistentData::iterate(std::function fn) const return fResult; } -void keysCommandCore(client *c, const redisDbPersistentData *db, sds pattern) +client *createFakeClient(void); +void freeFakeClient(client *); +void keysCommandCore(client *cIn, const redisDbPersistentData *db, sds pattern) { int plen = sdslen(pattern), allkeys; unsigned long numkeys = 0; - aeAcquireLock(); - void *replylen = addReplyDeferredLenAsync(c); - aeReleaseLock(); + client *c = createFakeClient(); + c->flags |= CLIENT_FORCE_REPLY; + + void *replylen = addReplyDeferredLen(c); allkeys = (pattern[0] == '*' && pattern[1] == '\0'); db->iterate([&](const char *key)->bool { @@ -756,17 +759,28 @@ void keysCommandCore(client *c, const redisDbPersistentData *db, sds pattern) if (allkeys || stringmatchlen(pattern,plen,key,sdslen(key),0)) { keyobj = createStringObject(key,sdslen(key)); if (!keyIsExpired(c->db,keyobj)) { - aeAcquireLock(); - addReplyBulkAsync(c,keyobj); - aeReleaseLock(); + addReplyBulk(c,keyobj); numkeys++; } decrRefCount(keyobj); } - return true; + return !(cIn->flags.load(std::memory_order_relaxed) & CLIENT_CLOSE_ASAP); }); - setDeferredArrayLenAsync(c,replylen,numkeys); + setDeferredArrayLen(c,replylen,numkeys); + + aeAcquireLock(); + addReplyProtoAsync(cIn, c->buf, c->bufpos); + listIter li; + listNode *ln; + listRewind(c->reply, &li); + while ((ln = listNext(&li)) != nullptr) + { + clientReplyBlock *block = (clientReplyBlock*)listNodeValue(ln); + addReplyProtoAsync(cIn, block->buf(), block->used); + } + aeReleaseLock(); + freeFakeClient(c); } int prepareClientToWrite(client *c, bool fAsync); @@ -2076,7 +2090,7 @@ void redisDbPersistentData::ensure(const char *sdsKey, dictEntry **pde) dictEntry *deTombstone = dictFind(m_pdictTombstone, sdsKey); if (deTombstone == nullptr) { - auto itr = m_pdbSnapshot->find(sdsKey); + auto itr = m_pdbSnapshot->find_threadsafe(sdsKey); if (itr == m_pdbSnapshot->end()) return; // not found if (itr.val()->getrefcount(std::memory_order_relaxed) == OBJ_SHARED_REFCOUNT) @@ -2165,6 +2179,8 @@ void redisDbPersistentData::processChanges() const redisDbPersistentData *redisDbPersistentData::createSnapshot(uint64_t mvccCheckpoint) { serverAssert(GlobalLocksAcquired()); + serverAssert(m_refCount == 0); // do not call this on a snapshot + bool fNested = false; if (m_spdbSnapshotHOLDER != nullptr) { if (mvccCheckpoint <= m_spdbSnapshotHOLDER->mvccCheckpoint) @@ -2173,14 +2189,15 @@ const redisDbPersistentData *redisDbPersistentData::createSnapshot(uint64_t mvcc return m_spdbSnapshotHOLDER.get(); } serverLog(LL_WARNING, "Nested snapshot created"); + fNested = true; } auto spdb = std::unique_ptr(new (MALLOC_LOCAL) redisDbPersistentData()); spdb->m_fAllChanged = false; spdb->m_fTrackingChanges = 0; spdb->m_pdict = m_pdict; + spdb->m_pdict->iterators++; spdb->m_pdictTombstone = m_pdictTombstone; - spdb->m_pdict->iterators++; // fake an iterator so it doesn't rehash spdb->m_spdbSnapshotHOLDER = std::move(m_spdbSnapshotHOLDER); spdb->m_pdbSnapshot = m_pdbSnapshot; spdb->m_refCount = 1; @@ -2191,6 +2208,8 @@ const redisDbPersistentData *redisDbPersistentData::createSnapshot(uint64_t mvcc m_pdictTombstone = dictCreate(&dbDictType, this); m_setexpire = new (MALLOC_LOCAL) expireset(); + serverAssert(spdb->m_pdict->iterators == 1); + m_spdbSnapshotHOLDER = std::move(spdb); m_pdbSnapshot = m_spdbSnapshotHOLDER.get(); @@ -2234,15 +2253,19 @@ void redisDbPersistentData::endSnapshot(const redisDbPersistentData *psnapshot) return; } + // Alright we're ready to be free'd, but first dump all the refs on our child snapshots + if (m_spdbSnapshotHOLDER->m_refCount == 1) + recursiveFreeSnapshots(m_spdbSnapshotHOLDER.get()); + m_spdbSnapshotHOLDER->m_refCount--; if (m_spdbSnapshotHOLDER->m_refCount > 0) return; serverAssert(m_spdbSnapshotHOLDER->m_refCount == 0); + serverAssert((m_refCount == 0 && m_pdict->iterators == 0) || (m_refCount != 0 && m_pdict->iterators == 1)); - // Alright we're ready to be free'd, but first dump all the refs on our child snapshots - recursiveFreeSnapshots(m_spdbSnapshotHOLDER.get()); - - m_spdbSnapshotHOLDER->m_pdict->iterators--; + serverAssert(m_spdbSnapshotHOLDER->m_pdict->iterators == 1); // All iterators should have been free'd except the fake one from createSnapshot + if (m_refCount == 0) + m_spdbSnapshotHOLDER->m_pdict->iterators--; if (m_pdbSnapshot == nullptr) { @@ -2311,9 +2334,15 @@ void redisDbPersistentData::endSnapshot(const redisDbPersistentData *psnapshot) { m_pdbSnapshot = nullptr; } + + // Fixup the about to free'd snapshots iterator count so the dtor doesn't complain + if (m_refCount) + m_spdbSnapshotHOLDER->m_pdict->iterators--; + m_spdbSnapshotHOLDER = std::move(m_spdbSnapshotHOLDER->m_spdbSnapshotHOLDER); serverAssert(m_spdbSnapshotHOLDER != nullptr || m_pdbSnapshot == nullptr); serverAssert(m_pdbSnapshot == m_spdbSnapshotHOLDER.get() || m_pdbSnapshot == nullptr); + serverAssert((m_refCount == 0 && m_pdict->iterators == 0) || (m_refCount != 0 && m_pdict->iterators == 1)); } redisDbPersistentData::~redisDbPersistentData() @@ -2321,8 +2350,51 @@ redisDbPersistentData::~redisDbPersistentData() serverAssert(m_spdbSnapshotHOLDER == nullptr); serverAssert(m_pdbSnapshot == nullptr); serverAssert(m_refCount == 0); + serverAssert(m_pdict->iterators == 0); dictRelease(m_pdict); if (m_pdictTombstone) dictRelease(m_pdictTombstone); delete m_setexpire; +} + +dict_iter redisDbPersistentData::random() +{ + if (size() == 0) + return dict_iter(nullptr); + if (m_pdbSnapshot != nullptr && m_pdbSnapshot->size() > 0) + { + dict_iter iter(nullptr); + double pctInSnapshot = (double)m_pdbSnapshot->size() / (size() + m_pdbSnapshot->size()); + double randval = (double)rand()/RAND_MAX; + if (randval <= pctInSnapshot) + { + iter = m_pdbSnapshot->random_threadsafe(); + ensure(iter.key()); + dictEntry *de = dictFind(m_pdict, iter.key()); + return dict_iter(de); + } + } + dictEntry *de = dictGetRandomKey(m_pdict); + if (de != nullptr) + ensure((const char*)dictGetKey(de), &de); + return dict_iter(de); +} + +dict_iter redisDbPersistentData::random_threadsafe() const +{ + if (size() == 0) + return dict_iter(nullptr); + if (m_pdbSnapshot != nullptr && m_pdbSnapshot->size() > 0) + { + dict_iter iter(nullptr); + double pctInSnapshot = (double)m_pdbSnapshot->size() / (size() + m_pdbSnapshot->size()); + double randval = (double)rand()/RAND_MAX; + if (randval <= pctInSnapshot) + { + return m_pdbSnapshot->random_threadsafe(); + } + } + serverAssert(dictSize(m_pdict) > 0); + dictEntry *de = dictGetRandomKey(m_pdict); + return dict_iter(de); } \ No newline at end of file diff --git a/src/server.h b/src/server.h index d8c70c5b9..46348969e 100644 --- a/src/server.h +++ b/src/server.h @@ -1229,28 +1229,7 @@ public: return find(szFromObj(key)); } - dict_iter random() - { - if (size() == 0) - return dict_iter(nullptr); - if (m_pdbSnapshot != nullptr && m_pdbSnapshot->size() > 0) - { - dict_iter iter(nullptr); - double pctInSnapshot = (double)m_pdbSnapshot->size() / (size() + m_pdbSnapshot->size()); - double randval = (double)rand()/RAND_MAX; - if (randval <= pctInSnapshot) - { - iter = m_pdbSnapshot->random(); - ensure(iter.key()); - dictEntry *de = dictFind(m_pdict, iter.key()); - return dict_iter(de); - } - } - dictEntry *de = dictGetRandomKey(m_pdict); - if (de != nullptr) - ensure((const char*)dictGetKey(de), &de); - return dict_iter(de); - } + dict_iter random(); const expireEntry &random_expire() { @@ -1308,6 +1287,14 @@ private: void storeDatabase(); void storeKey(const char *key, size_t cchKey, robj *o); void recursiveFreeSnapshots(redisDbPersistentData *psnapshot); + + // These do not call ENSURE and so may have a NULL object + dict_iter random_threadsafe() const; + dict_iter find_threadsafe(const char *key) const + { + dictEntry *de = dictFind(m_pdict, key); + return dict_iter(de); + } // Keyspace dict *m_pdict = nullptr; /* The keyspace for this DB */ @@ -1324,7 +1311,7 @@ private: // These two pointers are the same, UNLESS the database has been cleared. // in which case m_pdbSnapshot is NULL and we continue as though we weren' // in a snapshot - redisDbPersistentData *m_pdbSnapshot = nullptr; + const redisDbPersistentData *m_pdbSnapshot = nullptr; std::unique_ptr m_spdbSnapshotHOLDER; int m_refCount = 0; };