diff --git a/src/db.cpp b/src/db.cpp index 8253e531f..0b294b4ce 100644 --- a/src/db.cpp +++ b/src/db.cpp @@ -208,7 +208,7 @@ robj *lookupKeyWriteOrReply(client *c, robj *key, robj *reply) { bool dbAddCore(redisDb *db, robj *key, robj *val) { serverAssert(!val->FExpires()); - sds copy = sdsdup(szFromObj(key)); + sds copy = sdsdupshared(szFromObj(key)); bool fInserted = db->insert(copy, val); if (g_pserver->fActiveReplica) val->mvcc_tstamp = key->mvcc_tstamp = getMvccTstamp(); @@ -1414,12 +1414,12 @@ void setExpire(client *c, redisDb *db, robj *key, expireEntry &&e) /* Return the expire time of the specified key, or null if no expire * is associated with this key (i.e. the key is non volatile) */ -expireEntry *redisDbPersistentDataSnapshot::getExpire(robj_roptr key) { +expireEntry *redisDbPersistentDataSnapshot::getExpire(const char *key) { /* No expire? return ASAP */ if (expireSize() == 0) return nullptr; - auto itr = find_threadsafe(szFromObj(key)); + auto itr = find_threadsafe(key); if (itr == nullptr) return nullptr; if (!itr.val()->FExpires()) @@ -1429,7 +1429,7 @@ expireEntry *redisDbPersistentDataSnapshot::getExpire(robj_roptr key) { return itrExpire.operator->(); } -const expireEntry *redisDbPersistentDataSnapshot::getExpire(robj_roptr key) const +const expireEntry *redisDbPersistentDataSnapshot::getExpire(const char *key) const { return const_cast(this)->getExpire(key); } @@ -2041,28 +2041,29 @@ void redisDbPersistentData::ensure(const char *sdsKey, dictEntry **pde) if (itr == m_pdbSnapshot->end()) return; // not found + sds keyNew = sdsdupshared(itr.key()); // note: we use the iterator's key because the sdsKey may not be a shared string if (itr.val() != nullptr) { if (itr.val()->getrefcount(std::memory_order_relaxed) == OBJ_SHARED_REFCOUNT) { - dictAdd(m_pdict, sdsdup(sdsKey), itr.val()); + dictAdd(m_pdict, keyNew, itr.val()); } else { sds strT = serializeStoredObject(itr.val()); robj *objNew = deserializeStoredObject(this, sdsKey, strT, sdslen(strT)); sdsfree(strT); - dictAdd(m_pdict, sdsdup(sdsKey), objNew); + dictAdd(m_pdict, keyNew, objNew); serverAssert(objNew->getrefcount(std::memory_order_relaxed) == 1); serverAssert(objNew->mvcc_tstamp == itr.val()->mvcc_tstamp); } } else { - dictAdd(m_pdict, sdsdup(sdsKey), nullptr); + dictAdd(m_pdict, keyNew, nullptr); } *pde = dictFind(m_pdict, sdsKey); - dictAdd(m_pdictTombstone, sdsdup(sdsKey), nullptr); + dictAdd(m_pdictTombstone, sdsdupshared(itr.key()), nullptr); } } diff --git a/src/sds.h b/src/sds.h index e9817c6a2..88fad2cd5 100644 --- a/src/sds.h +++ b/src/sds.h @@ -335,15 +335,17 @@ int sdsTest(int argc, char *argv[]); class sdsview { - const char *m_str; +protected: + sds m_str = nullptr; + sdsview() = default; // Not allowed to create a sdsview directly with a nullptr public: sdsview(sds str) - : m_str((const char*) str) + : m_str(str) {} sdsview(const char *str) - : m_str(str) + : m_str((sds)str) {} bool operator<(const sdsview &other) const @@ -374,6 +376,60 @@ public: explicit operator const char*() const { return m_str; } }; +class sdsstring : public sdsview +{ +public: + sdsstring() = default; + explicit sdsstring(sds str) + : sdsview(str) + {} + + sdsstring(const sdsstring &other) + : sdsview(sdsdup(other.m_str)) + {} + + sdsstring(sdsstring &&other) + : sdsview(other.m_str) + { + other.m_str = nullptr; + } + + ~sdsstring() + { + sdsfree(m_str); + } +}; + +class sdsimmutablestring : public sdsstring +{ +public: + sdsimmutablestring() = default; + explicit sdsimmutablestring(sds str) + : sdsstring(str) + {} + + explicit sdsimmutablestring(const char *str) + : sdsstring((sds)str) + {} + + sdsimmutablestring(const sdsimmutablestring &other) + : sdsstring(sdsdupshared(other.m_str)) + {} + + sdsimmutablestring(sdsimmutablestring &&other) + : sdsstring(other.m_str) + { + other.m_str = nullptr; + } + + auto &operator=(const sdsimmutablestring &other) + { + sdsfree(m_str); + m_str = sdsdupshared(other.m_str); + return *this; + } +}; + #endif #endif diff --git a/src/server.h b/src/server.h index eb6cd01d2..c4666d650 100644 --- a/src/server.h +++ b/src/server.h @@ -941,11 +941,11 @@ public: }; private: - sds m_keyPrimary; + sdsimmutablestring m_keyPrimary; std::vector m_vecexpireEntries; // Note a NULL for the sds portion means the expire is for the primary key public: - expireEntryFat(sds keyPrimary) + expireEntryFat(const sdsimmutablestring &keyPrimary) : m_keyPrimary(keyPrimary) {} @@ -953,7 +953,7 @@ public: expireEntryFat(expireEntryFat &&e) = default; long long when() const noexcept { return m_vecexpireEntries.front().when; } - const char *key() const noexcept { return m_keyPrimary; } + const char *key() const noexcept { return static_cast(m_keyPrimary); } bool operator<(long long when) const noexcept { return this->when() < when; } @@ -990,9 +990,9 @@ public: }; class expireEntry { - union + struct { - sds m_key; + sdsimmutablestring m_key; expireEntryFat *m_pfatentry; } u; long long m_when; // LLONG_MIN means this is a fat entry and we should use the pointer @@ -1037,12 +1037,12 @@ public: if (subkey != nullptr) { m_when = LLONG_MIN; - u.m_pfatentry = new (MALLOC_LOCAL) expireEntryFat(key); + u.m_pfatentry = new (MALLOC_LOCAL) expireEntryFat(sdsimmutablestring(sdsdupshared(key))); u.m_pfatentry->expireSubKey(subkey, when); } else { - u.m_key = key; + u.m_key = sdsimmutablestring(sdsdupshared(key)); m_when = when; } } @@ -1063,9 +1063,8 @@ public: expireEntry(expireEntry &&e) { - u.m_key = e.u.m_key; + u.m_key = std::move(e.u.m_key); m_when = e.m_when; - e.u.m_key = (char*)key(); // we do this so it can still be found in the set e.m_when = 0; } @@ -1078,9 +1077,9 @@ public: void setKeyUnsafe(sds key) { if (FFat()) - u.m_pfatentry->m_keyPrimary = key; + u.m_pfatentry->m_keyPrimary = sdsimmutablestring(sdsdupshared(key)); else - u.m_key = key; + u.m_key = sdsimmutablestring(sdsdupshared(key)); } inline bool FFat() const noexcept { return m_when == LLONG_MIN; } @@ -1106,7 +1105,7 @@ public: { if (FFat()) return u.m_pfatentry->key(); - return u.m_key; + return static_cast(u.m_key); } long long when() const noexcept { @@ -1128,7 +1127,7 @@ public: { // we have to upgrade to a fat entry long long whenT = m_when; - sds keyPrimary = u.m_key; + sdsimmutablestring keyPrimary = u.m_key; m_when = LLONG_MIN; u.m_pfatentry = new (MALLOC_LOCAL) expireEntryFat(keyPrimary); u.m_pfatentry->expireSubKey(nullptr, whenT); @@ -1387,8 +1386,10 @@ public: dict_iter random_threadsafe() const; dict_iter find_threadsafe(const char *key) const; - expireEntry *getExpire(robj_roptr key); - const expireEntry *getExpire(robj_roptr key) const; + expireEntry *getExpire(robj_roptr key) { return getExpire(szFromObj(key)); } + expireEntry *getExpire(const char *key); + const expireEntry *getExpire(const char *key) const; + const expireEntry *getExpire(robj_roptr key) const { return getExpire(szFromObj(key)); } // These need to be fixed using redisDbPersistentData::size; @@ -3273,9 +3274,9 @@ inline int ielFromEventLoop(const aeEventLoop *eventLoop) inline int FCorrectThread(client *c) { - return (serverTL != NULL && (g_pserver->rgthreadvar[c->iel].el == serverTL->el)) + return (c->fd == -1) || (c->iel == IDX_EVENT_LOOP_MAIN && moduleGILAcquiredByModule()) - || (c->fd == -1); + || (serverTL != NULL && (g_pserver->rgthreadvar[c->iel].el == serverTL->el)); } #define AssertCorrectThread(c) serverAssert(FCorrectThread(c)) diff --git a/src/snapshot.cpp b/src/snapshot.cpp index 5e593ffcf..7972dcf9a 100644 --- a/src/snapshot.cpp +++ b/src/snapshot.cpp @@ -160,7 +160,7 @@ void redisDbPersistentData::endSnapshot(const redisDbPersistentDataSnapshot *psn { // The tombstone is for a grand child, propogate it serverAssert(m_spdbSnapshotHOLDER->m_pdbSnapshot->find_threadsafe((const char*)dictGetKey(de)) != nullptr); - dictAdd(m_spdbSnapshotHOLDER->m_pdictTombstone, sdsdup((sds)dictGetKey(de)), nullptr); + dictAdd(m_spdbSnapshotHOLDER->m_pdictTombstone, sdsdupshared((sds)dictGetKey(de)), nullptr); continue; } @@ -193,7 +193,7 @@ void redisDbPersistentData::endSnapshot(const redisDbPersistentDataSnapshot *psn } else { - dictAdd(m_spdbSnapshotHOLDER->m_pdict, sdsdup((sds)dictGetKey(de)), o); + dictAdd(m_spdbSnapshotHOLDER->m_pdict, sdsdupshared((sds)dictGetKey(de)), o); } if (dictGetVal(de) != nullptr) incrRefCount((robj*)dictGetVal(de)); @@ -368,7 +368,7 @@ void redisDbPersistentDataSnapshot::consolidate_children(redisDbPersistentData * m_pdbSnapshot->iterate_threadsafe([&](const char *key, robj_roptr o){ if (o != nullptr) incrRefCount(o); - dictAdd(spdb->m_pdict, sdsdup(key), o.unsafe_robjcast()); + dictAdd(spdb->m_pdict, sdsdupshared(key), o.unsafe_robjcast()); return true; }, true /*fKeyOnly*/); spdb->m_spstorage = m_pdbSnapshot->m_spstorage;