diff --git a/src/db.cpp b/src/db.cpp index 12c23ca21..e2a50164c 100644 --- a/src/db.cpp +++ b/src/db.cpp @@ -338,6 +338,7 @@ int dbMerge(redisDb *db, robj *key, robj *val, int fReplace) * The client 'c' argument may be set to NULL if the operation is performed * in a context where there is no clear client performing the operation. */ void genericSetKey(client *c, redisDb *db, robj *key, robj *val, int keepttl, int signal) { + db->prepOverwriteForSnapshot(szFromObj(key)); if (!dbAddCore(db, key, val)) { dbOverwrite(db, key, val, !keepttl); } @@ -2358,6 +2359,24 @@ bool redisDbPersistentData::insert(char *key, robj *o, bool fAssumeNew) return (res == DICT_OK); } +// This is a performance tool to prevent us copying over an object we're going to overwrite anyways +void redisDbPersistentData::prepOverwriteForSnapshot(char *key) +{ + if (g_pserver->maxmemory_policy & MAXMEMORY_FLAG_LFU) + return; + + if (m_pdbSnapshot != nullptr) + { + auto itr = m_pdbSnapshot->find_cached_threadsafe(key); + if (itr.key() != nullptr) + { + sds keyNew = sdsdupshared(itr.key()); + if (dictAdd(m_pdictTombstone, keyNew, (void*)dictHashKey(m_pdict, key)) != DICT_OK) + sdsfree(keyNew); + } + } +} + void redisDbPersistentData::tryResize() { if (htNeedsResize(m_pdict)) diff --git a/src/server.h b/src/server.h index 597ac8310..4f9375834 100644 --- a/src/server.h +++ b/src/server.h @@ -1326,6 +1326,9 @@ public: void setExpire(robj *key, robj *subkey, long long when); void setExpire(expireEntry &&e); void initialize(); + void prepOverwriteForSnapshot(char *key); + + bool FRehashing() const { return dictIsRehashing(m_pdict) || dictIsRehashing(m_pdictTombstone); } void setStorageProvider(StorageCache *pstorage); @@ -1527,6 +1530,8 @@ struct redisDb : public redisDbPersistentDataSnapshot using redisDbPersistentData::dictUnsafeKeyOnly; using redisDbPersistentData::resortExpire; using redisDbPersistentData::prefetchKeysAsync; + using redisDbPersistentData::prepOverwriteForSnapshot; + using redisDbPersistentData::FRehashing; public: expireset::setiter expireitr;