diff --git a/src/db.c b/src/db.c index b537a29a4..6623f7f2f 100644 --- a/src/db.c +++ b/src/db.c @@ -613,7 +613,7 @@ int parseScanCursorOrReply(client *c, robj *o, unsigned long *cursor) { } /* This command implements SCAN, HSCAN and SSCAN commands. - * If object 'o' is passed, then it must be a Hash or Set object, otherwise + * If object 'o' is passed, then it must be a Hash, Set or Zset object, otherwise * if 'o' is NULL the command will operate on the dictionary associated with * the current database. * @@ -629,6 +629,7 @@ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { listNode *node, *nextnode; long count = 10; sds pat = NULL; + sds typename = NULL; int patlen = 0, use_pattern = 0; dict *ht; @@ -665,6 +666,10 @@ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { use_pattern = !(pat[0] == '*' && patlen == 1); i += 2; + } else if (!strcasecmp(c->argv[i]->ptr, "type") && o == NULL && j >= 2) { + /* SCAN for a particular type only applies to the db dict */ + typename = c->argv[i+1]->ptr; + i+= 2; } else { addReply(c,shared.syntaxerr); goto cleanup; @@ -759,6 +764,31 @@ void scanGenericCommand(client *c, robj *o, unsigned long cursor) { } } + /* Filter an element if it isn't the type we want. */ + if (!filter && o == NULL && typename){ + robj* typecheck; + char *type; + typecheck = lookupKeyReadWithFlags(c->db, kobj, LOOKUP_NOTOUCH); + if (typecheck == NULL) { + type = "none"; + } else { + switch(typecheck->type) { + case OBJ_STRING: type = "string"; break; + case OBJ_LIST: type = "list"; break; + case OBJ_SET: type = "set"; break; + case OBJ_ZSET: type = "zset"; break; + case OBJ_HASH: type = "hash"; break; + case OBJ_STREAM: type = "stream"; break; + case OBJ_MODULE: { + moduleValue *mv = typecheck->ptr; + type = mv->type->name; + }; break; + default: type = "unknown"; break; + } + } + if (strcasecmp((char*) typename, type)) filter = 1; + } + /* Filter element if it is an expired key. */ if (!filter && o == NULL && expireIfNeeded(c->db, kobj)) filter = 1; diff --git a/tests/unit/scan.tcl b/tests/unit/scan.tcl index c0f4349d2..9f9ff4df2 100644 --- a/tests/unit/scan.tcl +++ b/tests/unit/scan.tcl @@ -53,6 +53,51 @@ start_server {tags {"scan"}} { assert_equal 100 [llength $keys] } + test "SCAN TYPE" { + r flushdb + # populate only creates strings + r debug populate 1000 + + # Check non-strings are excluded + set cur 0 + set keys {} + while 1 { + set res [r scan $cur type "list"] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys {*}$k + if {$cur == 0} break + } + + assert_equal 0 [llength $keys] + + # Check strings are included + set cur 0 + set keys {} + while 1 { + set res [r scan $cur type "string"] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys {*}$k + if {$cur == 0} break + } + + assert_equal 1000 [llength $keys] + + # Check all three args work together + set cur 0 + set keys {} + while 1 { + set res [r scan $cur type "string" match "key:*" count 10] + set cur [lindex $res 0] + set k [lindex $res 1] + lappend keys {*}$k + if {$cur == 0} break + } + + assert_equal 1000 [llength $keys] + } + foreach enc {intset hashtable} { test "SSCAN with encoding $enc" { # Create the Set