From 7b070423b8e20cf14a9ab80dfa7bea5ba88c54a3 Mon Sep 17 00:00:00 2001 From: Binbin Date: Mon, 18 Mar 2024 23:41:54 +0800 Subject: [PATCH] Fix dictionary use-after-free in active expire and make kvstore iter to respect EMPTY flag (#13135) After #13072, there is a use-after-free error. In expireScanCallback, we will delete the dict, and then in dictScan we will continue to use the dict, for example by calling `dictResumeRehashing(d)` at the end, which caused an error. In this PR, in freeDictIfNeeded, if the dict's pauserehash is set, don't delete the dict yet, and then when scan returns try to delete it again. At the same time, we noticed that there will be similar problems in iterators. We may also delete elements during the iteration process, causing the dict to be deleted, so the part related to iter in the PR has also been modified. dictResetIterator was also missing from the previous kvstoreIteratorNextDict. We currently have no scenario in which elements are deleted during the kvstoreIterator process, but we deal with it here as well to avoid future problems. Added some simple tests to verify the changes. In addition, the modification in #13072 omitted initTempDb and emptyDbAsync, so the same change is applied to them here. This PR also removes the slow flag from the expire test (consumes 1.3s) so that problems can be found in CI in the future. --- src/db.c | 11 ++- src/dict.h | 1 + src/kvstore.c | 185 +++++++++++++++++++++++++++++++++++++++++- src/kvstore.h | 4 + src/lazyfree.c | 11 ++- src/server.c | 3 +- tests/unit/expire.tcl | 2 +- 7 files changed, 205 insertions(+), 12 deletions(-) diff --git a/src/db.c b/src/db.c index f1eb6c027..82e61b496 100644 --- a/src/db.c +++ b/src/db.c @@ -563,12 +563,17 @@ long long emptyData(int dbnum, int flags, void(callback)(dict*)) { /* Initialize temporary db on replica for use during diskless replication. 
*/ redisDb *initTempDb(void) { + int slot_count_bits = 0; + int flags = KVSTORE_ALLOCATE_DICTS_ON_DEMAND; + if (server.cluster_enabled) { + slot_count_bits = CLUSTER_SLOT_MASK_BITS; + flags |= KVSTORE_FREE_EMPTY_DICTS; + } redisDb *tempDb = zcalloc(sizeof(redisDb)*server.dbnum); for (int i=0; irehashidx != -1) #define dictPauseRehashing(d) ((d)->pauserehash++) #define dictResumeRehashing(d) ((d)->pauserehash--) +#define dictIsRehashingPaused(d) ((d)->pauserehash > 0) #define dictPauseAutoResize(d) ((d)->pauseAutoResize++) #define dictResumeAutoResize(d) ((d)->pauseAutoResize--) diff --git a/src/kvstore.c b/src/kvstore.c index 5517594c6..505f92957 100644 --- a/src/kvstore.c +++ b/src/kvstore.c @@ -98,6 +98,12 @@ static dict **kvstoreGetDictRef(kvstore *kvs, int didx) { return &kvs->dicts[didx]; } +static int kvstoreDictIsRehashingPaused(kvstore *kvs, int didx) +{ + dict *d = kvstoreGetDict(kvs, didx); + return d ? dictIsRehashingPaused(d) : 0; +} + /* Returns total (cumulative) number of keys up until given dict-index (inclusive). * Time complexity is O(log(kvs->num_dicts)). */ static unsigned long long cumulativeKeyCountRead(kvstore *kvs, int didx) { @@ -167,10 +173,18 @@ static dict *createDictIfNeeded(kvstore *kvs, int didx) { return kvs->dicts[didx]; } +/* Called when the dict will delete entries, the function will check + * KVSTORE_FREE_EMPTY_DICTS to determine whether the empty dict needs + * to be freed. + * + * Note that for rehashing dicts, that is, in the case of safe iterators + * and Scan, we won't delete the dict. We will check whether it needs + * to be deleted when we're releasing the iterator. 
*/ static void freeDictIfNeeded(kvstore *kvs, int didx) { if (!(kvs->flags & KVSTORE_FREE_EMPTY_DICTS) || !kvstoreGetDict(kvs, didx) || - kvstoreDictSize(kvs, didx) != 0) + kvstoreDictSize(kvs, didx) != 0 || + kvstoreDictIsRehashingPaused(kvs, didx)) return; dictRelease(kvs->dicts[didx]); kvs->dicts[didx] = NULL; @@ -391,6 +405,8 @@ unsigned long long kvstoreScan(kvstore *kvs, unsigned long long cursor, int skip = !d || (skip_cb && skip_cb(d)); if (!skip) { _cursor = dictScan(d, cursor, scan_cb, privdata); + /* In dictScan, scan_cb may delete entries (e.g., in active expire case). */ + freeDictIfNeeded(kvs, didx); } /* scanning done for the current dictionary or if the scanning wasn't possible, move to the next dict index. */ if (_cursor == 0 || skip) { @@ -568,7 +584,8 @@ kvstoreIterator *kvstoreIteratorInit(kvstore *kvs) { void kvstoreIteratorRelease(kvstoreIterator *kvs_it) { dictIterator *iter = &kvs_it->di; dictResetIterator(iter); - + /* In the safe iterator context, we may delete entries. */ + freeDictIfNeeded(kvs_it->kvs, kvs_it->didx); zfree(kvs_it); } @@ -576,6 +593,16 @@ void kvstoreIteratorRelease(kvstoreIterator *kvs_it) { dict *kvstoreIteratorNextDict(kvstoreIterator *kvs_it) { if (kvs_it->next_didx == -1) return NULL; + + /* The dict may be deleted during the iteration process, so here need to check for NULL. */ + if (kvs_it->didx != -1 && kvstoreGetDict(kvs_it->kvs, kvs_it->didx)) { + /* Before we move to the next dict, reset the iter of the previous dict. */ + dictIterator *iter = &kvs_it->di; + dictResetIterator(iter); + /* In the safe iterator context, we may delete entries. */ + freeDictIfNeeded(kvs_it->kvs, kvs_it->didx); + } + kvs_it->didx = kvs_it->next_didx; kvs_it->next_didx = kvstoreGetNextNonEmptyDictIndex(kvs_it->kvs, kvs_it->didx); return kvs_it->kvs->dicts[kvs_it->didx]; @@ -597,6 +624,8 @@ dictEntry *kvstoreIteratorNext(kvstoreIterator *kvs_it) { /* Before we move to the next dict, reset the iter of the previous dict. 
*/ dictIterator *iter = &kvs_it->di; dictResetIterator(iter); + /* In the safe iterator context, we may delete entries. */ + freeDictIfNeeded(kvs_it->kvs, kvs_it->didx); } dictInitSafeIterator(&kvs_it->di, d); de = dictNext(&kvs_it->di); @@ -690,7 +719,11 @@ kvstoreDictIterator *kvstoreGetDictSafeIterator(kvstore *kvs, int didx) void kvstoreReleaseDictIterator(kvstoreDictIterator *kvs_di) { /* The dict may be deleted during the iteration process, so here need to check for NULL. */ - if (kvstoreGetDict(kvs_di->kvs, kvs_di->didx)) dictResetIterator(&kvs_di->di); + if (kvstoreGetDict(kvs_di->kvs, kvs_di->didx)) { + dictResetIterator(&kvs_di->di); + /* In the safe iterator context, we may delete entries. */ + freeDictIfNeeded(kvs_di->kvs, kvs_di->didx); + } zfree(kvs_di); } @@ -825,10 +858,154 @@ int kvstoreDictDelete(kvstore *kvs, int didx, const void *key) { dict *d = kvstoreGetDict(kvs, didx); if (!d) return DICT_ERR; - int ret = dictDelete(kvstoreGetDict(kvs, didx), key); + int ret = dictDelete(d, key); if (ret == DICT_OK) { cumulativeKeyCountAdd(kvs, didx, -1); freeDictIfNeeded(kvs, didx); } return ret; } + +#ifdef REDIS_TEST +#include +#include "testhelp.h" + +#define TEST(name) printf("test — %s\n", name); + +uint64_t hashTestCallback(const void *key) { + return dictGenHashFunction((unsigned char*)key, strlen((char*)key)); +} + +void freeTestCallback(dict *d, void *val) { + UNUSED(d); + zfree(val); +} + +dictType KvstoreDictTestType = { + hashTestCallback, + NULL, + NULL, + NULL, + freeTestCallback, + NULL, + NULL +}; + +char *stringFromInt(int value) { + char buf[32]; + int len; + char *s; + + len = snprintf(buf, sizeof(buf), "%d",value); + s = zmalloc(len+1); + memcpy(s, buf, len); + s[len] = '\0'; + return s; +} + +/* ./redis-server test kvstore */ +int kvstoreTest(int argc, char **argv, int flags) { + UNUSED(argc); + UNUSED(argv); + UNUSED(flags); + + int i; + void *key; + dictEntry *de; + kvstoreIterator *kvs_it; + kvstoreDictIterator *kvs_di; + + int didx 
= 0; + int curr_slot = 0; + kvstore *kvs1 = kvstoreCreate(&KvstoreDictTestType, 0, KVSTORE_ALLOCATE_DICTS_ON_DEMAND); + kvstore *kvs2 = kvstoreCreate(&KvstoreDictTestType, 0, KVSTORE_ALLOCATE_DICTS_ON_DEMAND | KVSTORE_FREE_EMPTY_DICTS); + + TEST("Add 16 keys") { + for (i = 0; i < 16; i++) { + de = kvstoreDictAddRaw(kvs1, didx, stringFromInt(i), NULL); + assert(de != NULL); + de = kvstoreDictAddRaw(kvs2, didx, stringFromInt(i), NULL); + assert(de != NULL); + } + assert(kvstoreDictSize(kvs1, didx) == 16); + assert(kvstoreSize(kvs1) == 16); + assert(kvstoreDictSize(kvs2, didx) == 16); + assert(kvstoreSize(kvs2) == 16); + } + + TEST("kvstoreIterator case 1: removing all keys does not delete the empty dict") { + kvs_it = kvstoreIteratorInit(kvs1); + while((de = kvstoreIteratorNext(kvs_it)) != NULL) { + curr_slot = kvstoreIteratorGetCurrentDictIndex(kvs_it); + key = dictGetKey(de); + assert(kvstoreDictDelete(kvs1, curr_slot, key) == DICT_OK); + } + kvstoreIteratorRelease(kvs_it); + + dict *d = kvstoreGetDict(kvs1, didx); + assert(d != NULL); + assert(kvstoreDictSize(kvs1, didx) == 0); + assert(kvstoreSize(kvs1) == 0); + } + + TEST("kvstoreIterator case 2: removing all keys will delete the empty dict") { + kvs_it = kvstoreIteratorInit(kvs2); + while((de = kvstoreIteratorNext(kvs_it)) != NULL) { + curr_slot = kvstoreIteratorGetCurrentDictIndex(kvs_it); + key = dictGetKey(de); + assert(kvstoreDictDelete(kvs2, curr_slot, key) == DICT_OK); + } + kvstoreIteratorRelease(kvs_it); + + dict *d = kvstoreGetDict(kvs2, didx); + assert(d == NULL); + assert(kvstoreDictSize(kvs2, didx) == 0); + assert(kvstoreSize(kvs2) == 0); + } + + TEST("Add 16 keys again") { + for (i = 0; i < 16; i++) { + de = kvstoreDictAddRaw(kvs1, didx, stringFromInt(i), NULL); + assert(de != NULL); + de = kvstoreDictAddRaw(kvs2, didx, stringFromInt(i), NULL); + assert(de != NULL); + } + assert(kvstoreDictSize(kvs1, didx) == 16); + assert(kvstoreSize(kvs1) == 16); + assert(kvstoreDictSize(kvs2, didx) == 16); + 
assert(kvstoreSize(kvs2) == 16); + } + + TEST("kvstoreDictIterator case 1: removing all keys does not delete the empty dict") { + kvs_di = kvstoreGetDictSafeIterator(kvs1, didx); + while((de = kvstoreDictIteratorNext(kvs_di)) != NULL) { + key = dictGetKey(de); + assert(kvstoreDictDelete(kvs1, didx, key) == DICT_OK); + } + kvstoreReleaseDictIterator(kvs_di); + + dict *d = kvstoreGetDict(kvs1, didx); + assert(d != NULL); + assert(kvstoreDictSize(kvs1, didx) == 0); + assert(kvstoreSize(kvs1) == 0); + } + + TEST("kvstoreDictIterator case 2: removing all keys will delete the empty dict") { + kvs_di = kvstoreGetDictSafeIterator(kvs2, didx); + while((de = kvstoreDictIteratorNext(kvs_di)) != NULL) { + key = dictGetKey(de); + assert(kvstoreDictDelete(kvs2, didx, key) == DICT_OK); + } + kvstoreReleaseDictIterator(kvs_di); + + dict *d = kvstoreGetDict(kvs2, didx); + assert(d == NULL); + assert(kvstoreDictSize(kvs2, didx) == 0); + assert(kvstoreSize(kvs2) == 0); + } + + kvstoreRelease(kvs1); + kvstoreRelease(kvs2); + return 0; +} +#endif diff --git a/src/kvstore.h b/src/kvstore.h index 56a486199..bce45fe4c 100644 --- a/src/kvstore.h +++ b/src/kvstore.h @@ -72,4 +72,8 @@ dictEntry *kvstoreDictTwoPhaseUnlinkFind(kvstore *kvs, int didx, const void *key void kvstoreDictTwoPhaseUnlinkFree(kvstore *kvs, int didx, dictEntry *he, dictEntry **plink, int table_index); int kvstoreDictDelete(kvstore *kvs, int didx, const void *key); +#ifdef REDIS_TEST +int kvstoreTest(int argc, char *argv[], int flags); +#endif + #endif /* DICTARRAY_H_ */ diff --git a/src/lazyfree.c b/src/lazyfree.c index aa084464b..645da2b34 100644 --- a/src/lazyfree.c +++ b/src/lazyfree.c @@ -177,10 +177,15 @@ void freeObjAsync(robj *key, robj *obj, int dbid) { * create a new empty set of hash tables and scheduling the old ones for * lazy freeing. */ void emptyDbAsync(redisDb *db) { - int slotCountBits = server.cluster_enabled? 
CLUSTER_SLOT_MASK_BITS : 0; + int slot_count_bits = 0; + int flags = KVSTORE_ALLOCATE_DICTS_ON_DEMAND; + if (server.cluster_enabled) { + slot_count_bits = CLUSTER_SLOT_MASK_BITS; + flags |= KVSTORE_FREE_EMPTY_DICTS; + } kvstore *oldkeys = db->keys, *oldexpires = db->expires; - db->keys = kvstoreCreate(&dbDictType, slotCountBits, KVSTORE_ALLOCATE_DICTS_ON_DEMAND); - db->expires = kvstoreCreate(&dbExpiresDictType, slotCountBits, KVSTORE_ALLOCATE_DICTS_ON_DEMAND); + db->keys = kvstoreCreate(&dbDictType, slot_count_bits, flags); + db->expires = kvstoreCreate(&dbExpiresDictType, slot_count_bits, flags); atomicIncr(lazyfree_objects, kvstoreSize(oldkeys)); bioCreateLazyFreeJob(lazyfreeFreeDatabase, 2, oldkeys, oldexpires); } diff --git a/src/server.c b/src/server.c index cff9db24c..29637dae5 100644 --- a/src/server.c +++ b/src/server.c @@ -6840,7 +6840,8 @@ struct redisTest { {"zmalloc", zmalloc_test}, {"sds", sdsTest}, {"dict", dictTest}, - {"listpack", listpackTest} + {"listpack", listpackTest}, + {"kvstore", kvstoreTest}, }; redisTestProc *getTestProcByName(const char *name) { int numtests = sizeof(redisTests)/sizeof(struct redisTest); diff --git a/tests/unit/expire.tcl b/tests/unit/expire.tcl index 3e58bd4f7..08fa88a10 100644 --- a/tests/unit/expire.tcl +++ b/tests/unit/expire.tcl @@ -834,7 +834,7 @@ start_server {tags {"expire"}} { } {} {needs:debug} } -start_cluster 1 0 {tags {"expire external:skip cluster slow"}} { +start_cluster 1 0 {tags {"expire external:skip cluster"}} { test "expire scan should skip dictionaries with lot's of empty buckets" { r debug set-active-expire 0