diff --git a/src/db.c b/src/db.c index fa14b54ff..79d754d3b 100644 --- a/src/db.c +++ b/src/db.c @@ -810,29 +810,70 @@ void keysCommand(client *c) { setDeferredArrayLen(c,replylen,numkeys); } +/* Data used by the dict scan callback. */ +typedef struct { + list *keys; /* elements collected from the dict */ + robj *o; /* o must be a hash/set/zset object, NULL means current db */ + long long type; /* the type to match when scanning the db, LLONG_MAX means no type filter */ + sds pattern; /* pattern string, NULL means no pattern */ + long sampled; /* cumulative number of entries sampled, counted before any filtering */ +} scanData; + +/* Helper function to compare key type in scan commands. Returns 1 if the type of o matches target and 0 otherwise. For module objects, target is compared against the negated module type signature. */ +int objectTypeCompare(robj *o, long long target) { + if (o->type != OBJ_MODULE) { + if (o->type != target) + return 0; + else + return 1; + } + /* module type compare */ + long long mt = (long long)REDISMODULE_TYPE_SIGN(((moduleValue *)o->ptr)->type->id); + if (target != -mt) + return 0; + else + return 1; +} /* This callback is used by scanGenericCommand in order to collect elements * returned by the dictionary iterator into a list. */ void scanCallback(void *privdata, const dictEntry *de) { - void **pd = (void**) privdata; - list *keys = pd[0]; - robj *o = pd[1]; - robj *key, *val = NULL; + scanData *data = (scanData *)privdata; + list *keys = data->keys; + robj *o = data->o; + sds val = NULL; + sds key = NULL; + data->sampled++; + + /* o and a type filter can not be set at the same time. */ + serverAssert(!((data->type != LLONG_MAX) && o)); + + /* Filter an element if it isn't the type we want. */ + /* TODO: uncomment in redis 8.0 + if (!o && data->type != LLONG_MAX) { + robj *rval = dictGetVal(de); + if (!objectTypeCompare(rval, data->type)) return; + }*/ + + /* Filter element if it does not match the pattern. 
*/ + sds keysds = dictGetKey(de); + if (data->pattern) { + if (!stringmatchlen(data->pattern, sdslen(data->pattern), keysds, sdslen(keysds), 0)) { + return; + } + } if (o == NULL) { - sds sdskey = dictGetKey(de); - key = createStringObject(sdskey, sdslen(sdskey)); + key = keysds; } else if (o->type == OBJ_SET) { - sds keysds = dictGetKey(de); - key = createStringObject(keysds,sdslen(keysds)); + key = keysds; } else if (o->type == OBJ_HASH) { - sds sdskey = dictGetKey(de); - sds sdsval = dictGetVal(de); - key = createStringObject(sdskey,sdslen(sdskey)); - val = createStringObject(sdsval,sdslen(sdsval)); + key = keysds; + val = dictGetVal(de); } else if (o->type == OBJ_ZSET) { - sds sdskey = dictGetKey(de); - key = createStringObject(sdskey,sdslen(sdskey)); - val = createStringObjectFromLongDouble(*(double*)dictGetVal(de),0); + char buf[MAX_LONG_DOUBLE_CHARS]; + int len = ld2string(buf, sizeof(buf), *(double *)dictGetVal(de), LD_STR_AUTO); + key = sdsdup(keysds); + val = sdsnewlen(buf, len); } else { serverPanic("Type not handled in SCAN callback."); } @@ -860,6 +901,46 @@ int parseScanCursorOrReply(client *c, robj *o, unsigned long *cursor) { return C_OK; } +char *obj_type_name[OBJ_TYPE_MAX] = { + "string", + "list", + "set", + "zset", + "hash", + NULL, /* module types have no fixed name here, they are resolved through the module type name */ + "stream" +}; + +/* Helper function to get type from a string in scan commands. Matching is case insensitive. Returns the object type index for a built-in type name, the negated module type signature for a module type name, or LLONG_MAX when the name is unknown. NOTE(review): the signature is negated here without the (long long) cast used in objectTypeCompare - confirm both produce the same value. 
*/ +long long getObjectTypeByName(char *name) { + + for (long long i = 0; i < OBJ_TYPE_MAX; i++) { + if (obj_type_name[i] && !strcasecmp(name, obj_type_name[i])) { + return i; + } + } + + moduleType *mt = moduleTypeLookupModuleByNameIgnoreCase(name); + if (mt != NULL) return -(REDISMODULE_TYPE_SIGN(mt->id)); + + return LLONG_MAX; +} + +char *getObjectTypeName(robj *o) { + if (o == NULL) { + return "none"; + } + + serverAssert(o->type >= 0 && o->type < OBJ_TYPE_MAX); + + if (o->type == OBJ_MODULE) { + moduleValue *mv = o->ptr; + return mv->type->name; + } else { + return obj_type_name[o->type]; + } +} + /* This 
command implements SCAN, HSCAN and SSCAN commands. * If object 'o' is passed, then it must be a Hash, Set or Zset object, otherwise * if 'o' is NULL the command will operate on the dictionary associated with @@ -873,11 +954,11 @@ int parseScanCursorOrReply(client *c, robj *o, unsigned long *cursor) { * of every element on the Hash. */ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { int i, j; - list *keys = listCreate(); - listNode *node, *nextnode; + listNode *node; long count = 10; sds pat = NULL; sds typename = NULL; + long long type = LLONG_MAX; int patlen = 0, use_pattern = 0; dict *ht; @@ -896,12 +977,12 @@ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { if (getLongFromObjectOrReply(c, c->argv[i+1], &count, NULL) != C_OK) { - goto cleanup; + return; } if (count < 1) { addReplyErrorObject(c,shared.syntaxerr); - goto cleanup; + return; } i += 2; @@ -917,10 +998,16 @@ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { } else if (!strcasecmp(c->argv[i]->ptr, "type") && o == NULL && j >= 2) { /* SCAN for a particular type only applies to the db dict */ typename = c->argv[i+1]->ptr; + type = getObjectTypeByName(typename); + if (type == LLONG_MAX) { + /* TODO: uncomment in redis 8.0 + addReplyErrorFormat(c, "unknown type name '%s'", typename); + return; */ + } i+= 2; } else { addReplyErrorObject(c,shared.syntaxerr); - goto cleanup; + return; } } @@ -940,42 +1027,67 @@ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { ht = o->ptr; } else if (o->type == OBJ_HASH && o->encoding == OBJ_ENCODING_HT) { ht = o->ptr; - count *= 2; /* We return key / value for this type. */ } else if (o->type == OBJ_ZSET && o->encoding == OBJ_ENCODING_SKIPLIST) { zset *zs = o->ptr; ht = zs->dict; - count *= 2; /* We return key / value for this type. */ + } + + list *keys = listCreate(); + /* Set a free callback for the contents of the collected keys list. 
+ * For the main keyspace dict, and when we scan a key that's dict encoded + * (we have 'ht'), we don't need to define a free method because the strings + * in the list are just a shallow copy from the pointer in the dictEntry. + * When scanning a key with other encodings (e.g. listpack), we need to + * free the temporary strings we add to that list. + * The exception to the above is ZSET, where we do allocate temporary + * strings even when scanning a dict. */ + if (o && (!ht || o->type == OBJ_ZSET)) { + listSetFreeMethod(keys, (void (*)(void*))sdsfree); } if (ht) { - void *privdata[2]; /* We set the max number of iterations to ten times the specified * COUNT, so if the hash table is in a pathological state (very * sparsely populated) we avoid to block too much time at the cost * of returning no or very few elements. */ long maxiterations = count*10; - /* We pass two pointers to the callback: the list to which it will - * add new elements, and the object containing the dictionary so that - * it is possible to fetch more data in a type-dependent way. */ - privdata[0] = keys; - privdata[1] = o; + /* We pass a scanData struct with five fields to the callback: + * 1. data.keys: the list to which it will add new elements; + * 2. data.o: the object containing the dictionary so that + * it is possible to fetch more data in a type-dependent way; + * 3. data.type: the type to match when scanning the db, LLONG_MAX means + * type matching is not needed; + * 4. data.pattern: the pattern string, NULL means no pattern; + * 5. data.sampled: the maxiterations limit is there in case we are + * working on an empty dict or one with a lot of empty buckets; for + * buckets that are not empty, we need to limit the sampled number + * to prevent a long hang time caused by filtering too many keys. */ + scanData data = { + .keys = keys, + .o = o, + .type = type, + .pattern = use_pattern ? 
pat : NULL, + .sampled = 0, + }; do { - cursor = dictScan(ht, cursor, scanCallback, privdata); - } while (cursor && - maxiterations-- && - listLength(keys) < (unsigned long)count); + cursor = dictScan(ht, cursor, scanCallback, &data); + } while (cursor && maxiterations-- && data.sampled < count); } else if (o->type == OBJ_SET) { char *str; + char buf[LONG_STR_SIZE]; size_t len; int64_t llele; setTypeIterator *si = setTypeInitIterator(o); while (setTypeNext(si, &str, &len, &llele) != -1) { if (str == NULL) { - listAddNodeTail(keys, createStringObjectFromLongLong(llele)); - } else { - listAddNodeTail(keys, createStringObject(str, len)); + len = ll2string(buf, sizeof(buf), llele); } + char *key = str ? str : buf; + if (use_pattern && !stringmatchlen(pat, sdslen(pat), key, len, 0)) { + continue; + } + listAddNodeTail(keys, sdsnewlen(key, len)); } setTypeReleaseIterator(si); cursor = 0; @@ -983,72 +1095,53 @@ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { o->encoding == OBJ_ENCODING_LISTPACK) { unsigned char *p = lpFirst(o->ptr); - unsigned char *vstr; - int64_t vlen; + unsigned char *str; + int64_t len; unsigned char intbuf[LP_INTBUF_SIZE]; while(p) { - vstr = lpGet(p,&vlen,intbuf); - listAddNodeTail(keys, createStringObject((char*)vstr,vlen)); - p = lpNext(o->ptr,p); + str = lpGet(p, &len, intbuf); + /* advance p so that it points to the value */ + p = lpNext(o->ptr, p); + if (use_pattern && !stringmatchlen(pat, sdslen(pat), (char *)str, len, 0)) { + /* jump to the next key/val pair */ + p = lpNext(o->ptr, p); + continue; + } + /* add the key string */ + listAddNodeTail(keys, sdsnewlen(str, len)); + /* add the value string */ + str = lpGet(p, &len, intbuf); + listAddNodeTail(keys, sdsnewlen(str, len)); + p = lpNext(o->ptr, p); } cursor = 0; } else { serverPanic("Not handled encoding in SCAN."); } - /* Step 3: Filter elements. 
*/ - node = listFirst(keys); - while (node) { - robj *kobj = listNodeValue(node); - nextnode = listNextNode(node); - int filter = 0; - - /* Filter element if it does not match the pattern. */ - if (use_pattern) { - if (sdsEncodedObject(kobj)) { - if (!stringmatchlen(pat, patlen, kobj->ptr, sdslen(kobj->ptr), 0)) - filter = 1; - } else { - char buf[LONG_STR_SIZE]; - int len; - - serverAssert(kobj->encoding == OBJ_ENCODING_INT); - len = ll2string(buf,sizeof(buf),(long)kobj->ptr); - if (!stringmatchlen(pat, patlen, buf, len, 0)) filter = 1; + /* Step 3: Filter out the expired keys and, until redis 8.0, the keys whose type does not match */ + if (o == NULL && listLength(keys)) { + robj kobj; + listIter li; + listNode *ln; + listRewind(keys, &li); + while ((ln = listNext(&li))) { + sds key = listNodeValue(ln); + initStaticStringObject(kobj, key); + /* Filter an element if it isn't the type we want. */ + /* TODO: remove this in redis 8.0 */ + if (typename) { + robj* typecheck = lookupKeyReadWithFlags(c->db, &kobj, LOOKUP_NOTOUCH|LOOKUP_NONOTIFY); + if (!typecheck || !objectTypeCompare(typecheck, type)) { + listDelNode(keys, ln); + } + continue; + } + if (expireIfNeeded(c->db, &kobj, 0)) { + listDelNode(keys, ln); } } - - /* Filter an element if it isn't the type we want. */ - if (!filter && o == NULL && typename){ - robj* typecheck = lookupKeyReadWithFlags(c->db, kobj, LOOKUP_NOTOUCH); - char* type = getObjectTypeName(typecheck); - if (strcasecmp((char*) typename, type)) filter = 1; - } - - /* Filter element if it is an expired key. */ - if (!filter && o == NULL && expireIfNeeded(c->db, kobj, 0)) filter = 1; - - /* Remove the element and its associated value if needed. */ - if (filter) { - decrRefCount(kobj); - listDelNode(keys, node); - } - - /* If this is a hash or a sorted set, we have a flat list of - * key-value elements, so if this element was filtered, remove the - * value, or skip it if it was not filtered: we only match keys. 
*/ - if (o && (o->type == OBJ_ZSET || o->type == OBJ_HASH)) { - node = nextnode; - serverAssert(node); /* assertion for valgrind (avoid NPD) */ - nextnode = listNextNode(node); - if (filter) { - kobj = listNodeValue(node); - decrRefCount(kobj); - listDelNode(keys, node); - } - } - node = nextnode; } /* Step 4: Reply to the client. */ @@ -1057,14 +1150,11 @@ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { addReplyArrayLen(c, listLength(keys)); while ((node = listFirst(keys)) != NULL) { - robj *kobj = listNodeValue(node); - addReplyBulk(c, kobj); - decrRefCount(kobj); + sds key = listNodeValue(node); + addReplyBulkCBuffer(c, key, sdslen(key)); listDelNode(keys, node); } -cleanup: - listSetFreeMethod(keys,decrRefCountVoid); listRelease(keys); } @@ -1083,28 +1173,6 @@ void lastsaveCommand(client *c) { addReplyLongLong(c,server.lastsave); } -char* getObjectTypeName(robj *o) { - char* type; - if (o == NULL) { - type = "none"; - } else { - switch(o->type) { - case OBJ_STRING: type = "string"; break; - case OBJ_LIST: type = "list"; break; - case OBJ_SET: type = "set"; break; - case OBJ_ZSET: type = "zset"; break; - case OBJ_HASH: type = "hash"; break; - case OBJ_STREAM: type = "stream"; break; - case OBJ_MODULE: { - moduleValue *mv = o->ptr; - type = mv->type->name; - }; break; - default: type = "unknown"; break; - } - } - return type; -} - void typeCommand(client *c) { robj *o; o = lookupKeyReadWithFlags(c->db,c->argv[1],LOOKUP_NOTOUCH); @@ -1736,8 +1804,16 @@ int expireIfNeeded(redisDb *db, robj *key, int flags) { * will have failed over and the new primary will send us the expire. 
*/ if (isPausedActionsWithUpdate(PAUSE_ACTION_EXPIRE)) return 1; + /* A static key object is replaced by a heap copy before being deleted; the copy is released after the deletion */ + int static_key = key->refcount == OBJ_STATIC_REFCOUNT; + if (static_key) { + key = createStringObject(key->ptr, sdslen(key->ptr)); + } /* Delete the key */ deleteExpiredKeyAndPropagate(db,key); + if (static_key) { + decrRefCount(key); + } return 1; } diff --git a/src/module.c b/src/module.c index aa493c971..0addeecde 100644 --- a/src/module.c +++ b/src/module.c @@ -6573,7 +6573,7 @@ uint64_t moduleTypeEncodeId(const char *name, int encver) { /* Search, in the list of exported data types of all the modules registered, * a type with the same name as the one given. Returns the moduleType * structure pointer if such a module is found, or NULL otherwise. */ -moduleType *moduleTypeLookupModuleByName(const char *name) { +moduleType *moduleTypeLookupModuleByNameInternal(const char *name, int ignore_case) { dictIterator *di = dictGetIterator(modules); dictEntry *de; @@ -6585,7 +6585,9 @@ moduleType *moduleTypeLookupModuleByName(const char *name) { listRewind(module->types,&li); while((ln = listNext(&li))) { moduleType *mt = ln->value; - if (memcmp(name,mt->name,sizeof(mt->name)) == 0) { + if ((!ignore_case && memcmp(name,mt->name,sizeof(mt->name)) == 0) + || (ignore_case && !strcasecmp(name, mt->name))) + { dictReleaseIterator(di); return mt; } @@ -6594,6 +6596,15 @@ moduleType *moduleTypeLookupModuleByName(const char *name) { dictReleaseIterator(di); return NULL; } +/* Search the data types of all registered modules by name, the match is case sensitive */ +moduleType *moduleTypeLookupModuleByName(const char *name) { + return moduleTypeLookupModuleByNameInternal(name, 0); +} + +/* Search the data types of all registered modules by name, the match is case insensitive */ +moduleType *moduleTypeLookupModuleByNameIgnoreCase(const char *name) { + return moduleTypeLookupModuleByNameInternal(name, 1); +} /* Lookup 
a module by ID, with caching. This function is used during RDB * loading. 
Modules exporting data types should never be able to unload, so diff --git a/src/server.h b/src/server.h index f301a315c..cb555031e 100644 --- a/src/server.h +++ b/src/server.h @@ -710,6 +710,7 @@ typedef enum { * encoding version. */ #define OBJ_MODULE 5 /* Module object. */ #define OBJ_STREAM 6 /* Stream object. */ +#define OBJ_TYPE_MAX 7 /* Maximum number of object types */ /* Extract encver / signature from a module type ID. */ #define REDISMODULE_TYPE_ENCVER_BITS 10 @@ -2471,6 +2472,8 @@ void moduleLoadFromQueue(void); int moduleGetCommandKeysViaAPI(struct redisCommand *cmd, robj **argv, int argc, getKeysResult *result); int moduleGetCommandChannelsViaAPI(struct redisCommand *cmd, robj **argv, int argc, getKeysResult *result); moduleType *moduleTypeLookupModuleByID(uint64_t id); +moduleType *moduleTypeLookupModuleByName(const char *name); +moduleType *moduleTypeLookupModuleByNameIgnoreCase(const char *name); void moduleTypeNameByID(char *name, uint64_t moduleid); const char *moduleTypeModuleName(moduleType *mt); const char *moduleNameFromCommand(struct redisCommand *cmd); diff --git a/tests/unit/moduleapi/datatype.tcl b/tests/unit/moduleapi/datatype.tcl index 0c87e9597..951c060e7 100644 --- a/tests/unit/moduleapi/datatype.tcl +++ b/tests/unit/moduleapi/datatype.tcl @@ -89,4 +89,46 @@ start_server {tags {"modules"}} { $rd read $rd close } + + test {DataType: check the type name} { + r flushdb + r datatype.set foo 111 bar + assert_type test___dt foo + } + + test {SCAN module datatype} { + r flushdb + populate 1000 + r datatype.set foo 111 bar + set type [r type foo] + set cur 0 + set keys {} + while 1 { + set res [r scan $cur type $type] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys {*}$k + if {$cur == 0} break + } + + assert_equal 1 [llength $keys] + } + + test {SCAN module datatype with case sensitive} { + r flushdb + populate 1000 + r datatype.set foo 111 bar + set type "tEsT___dT" + set cur 0 + set keys {} + while 1 { + set res [r scan 
$cur type $type] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys {*}$k + if {$cur == 0} break + } + + assert_equal 1 [llength $keys] + } } diff --git a/tests/unit/scan.tcl b/tests/unit/scan.tcl index 45397d7a3..d688d7cda 100644 --- a/tests/unit/scan.tcl +++ b/tests/unit/scan.tcl @@ -98,6 +98,108 @@ start_server {tags {"scan network"}} { assert_equal 1000 [llength $keys] } + test "SCAN unknown type" { + r flushdb + # make sure that passive expiration is triggered by the scan + r debug set-active-expire 0 + + populate 1000 + r hset hash f v + r pexpire hash 1 + + after 2 + + # TODO: remove this in redis 8.0 + set cur 0 + set keys {} + while 1 { + set res [r scan $cur type "string1"] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys {*}$k + if {$cur == 0} break + } + + assert_equal 0 [llength $keys] + # make sure that the expired key has been removed by the scan command + assert_equal 1000 [scan [regexp -inline {keys\=([\d]*)} [r info keyspace]] keys=%d] + + # TODO: uncomment in redis 8.0 + #assert_error "*unknown type name*" {r scan 0 type "string1"} + # the expired key will not be touched by the scan command + #assert_equal 1001 [scan [regexp -inline {keys\=([\d]*)} [r info keyspace]] keys=%d] + r debug set-active-expire 1 + } {OK} {needs:debug} + + test "SCAN with expired keys" { + r flushdb + # make sure that passive expiration is triggered by the scan + r debug set-active-expire 0 + + populate 1000 + r set foo bar + r pexpire foo 1 + + # add a hash type key + r hset hash f v + r pexpire hash 1 + + after 2 + + set cur 0 + set keys {} + while 1 { + set res [r scan $cur count 10] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys {*}$k + if {$cur == 0} break + } + + assert_equal 1000 [llength $keys] + + # make sure that the expired keys have been removed by the scan command + assert_equal 1000 [scan [regexp -inline {keys\=([\d]*)} [r info keyspace]] keys=%d] + + r debug set-active-expire 1 + } {OK} {needs:debug} + + test "SCAN with expired keys 
with TYPE filter" { + r flushdb + # make sure that passive expiration is triggered by the scan + r debug set-active-expire 0 + + populate 1000 + r set foo bar + r pexpire foo 1 + + # add a hash type key + r hset hash f v + r pexpire hash 1 + + after 2 + + set cur 0 + set keys {} + while 1 { + set res [r scan $cur type "string" count 10] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys {*}$k + if {$cur == 0} break + } + + assert_equal 1000 [llength $keys] + + # make sure that the expired keys have been removed by the scan command + assert_equal 1000 [scan [regexp -inline {keys\=([\d]*)} [r info keyspace]] keys=%d] + # TODO: uncomment in redis 8.0 + # make sure that only the expired key matching the type has been removed by the scan command + #assert_equal 1001 [scan [regexp -inline {keys\=([\d]*)} [r info keyspace]] keys=%d] + + r debug set-active-expire 1 + } {OK} {needs:debug} + foreach enc {intset listpack hashtable} { test "SSCAN with encoding $enc" { # Create the Set