diff --git a/src/pubsub.c b/src/pubsub.c index 6c69431b8..9e3958b36 100644 --- a/src/pubsub.c +++ b/src/pubsub.c @@ -280,7 +280,7 @@ void unmarkClientAsPubSub(client *c) { int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { dict **d_ptr; dictEntry *de; - list *clients = NULL; + dict *clients = NULL; int retval = 0; unsigned int slot = 0; @@ -294,13 +294,13 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { } d_ptr = type.serverPubSubChannels(slot); if (*d_ptr == NULL) { - *d_ptr = dictCreate(&keylistDictType); + *d_ptr = dictCreate(&objToDictDictType); de = NULL; } else { de = dictFind(*d_ptr, channel); } if (de == NULL) { - clients = listCreate(); + clients = dictCreate(&clientDictType); dictAdd(*d_ptr, channel, clients); incrRefCount(channel); if (type.shard) { @@ -309,7 +309,7 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { } else { clients = dictGetVal(de); } - listAddNodeTail(clients,c); + serverAssert(dictAdd(clients, c, NULL) != DICT_ERR); } /* Notify the client */ addReplyPubsubSubscribed(c,channel,type); @@ -321,8 +321,7 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) { int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) { dict *d; dictEntry *de; - list *clients; - listNode *ln; + dict *clients; int retval = 0; int slot = 0; @@ -340,11 +339,9 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype ty de = dictFind(d, channel); serverAssertWithInfo(c,NULL,de != NULL); clients = dictGetVal(de); - ln = listSearchKey(clients,c); - serverAssertWithInfo(c,NULL,ln != NULL); - listDelNode(clients,ln); - if (listLength(clients) == 0) { - /* Free the list and associated hash entry at all if this was + serverAssertWithInfo(c, NULL, dictDelete(clients, c) == DICT_OK); + if (dictSize(clients) == 0) { + /* Free the dict and associated hash entry at all if this was * the latest client, so that it will be possible to abuse * Redis PUBSUB creating millions of channels. */ dictDelete(d, channel); @@ -376,11 +373,13 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) { dictEntry *de; while ((de = dictNext(di)) != NULL) { robj *channel = dictGetKey(de); - list *clients = dictGetVal(de); + dict *clients = dictGetVal(de); + if (dictSize(clients) == 0) goto cleanup; /* For each client subscribed to the channel, unsubscribe it. */ - listNode *ln; - while ((ln = listFirst(clients)) != NULL) { - client *c = listNodeValue(ln); + dictIterator *iter = dictGetSafeIterator(clients); + dictEntry *entry; + while ((entry = dictNext(iter)) != NULL) { + client *c = dictGetKey(entry); int retval = dictDelete(c->pubsubshard_channels, channel); serverAssertWithInfo(c,channel,retval == DICT_OK); addReplyPubsubUnsubscribed(c, channel, pubSubShardType); @@ -389,8 +388,9 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) { if (clientTotalPubSubSubscriptionCount(c) == 0) { unmarkClientAsPubSub(c); } - listDelNode(clients, ln); } + dictReleaseIterator(iter); +cleanup: server.shard_channel_count--; dictDelete(d, channel); } @@ -402,7 +402,7 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) { /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */ int pubsubSubscribePattern(client *c, robj *pattern) { dictEntry *de; - list *clients; + dict *clients; int retval = 0; if (dictAdd(c->pubsub_patterns, pattern, NULL) == DICT_OK) { @@ -411,13 +411,13 @@ int pubsubSubscribePattern(client *c, robj *pattern) { /* Add the client to the pattern -> list of clients hash table */ de = dictFind(server.pubsub_patterns,pattern); if (de == NULL) { - clients = listCreate(); + clients = dictCreate(&clientDictType); dictAdd(server.pubsub_patterns,pattern,clients); incrRefCount(pattern); } else { clients = dictGetVal(de); } - listAddNodeTail(clients,c); + serverAssert(dictAdd(clients, c, NULL) != DICT_ERR); } /* Notify the client */ addReplyPubsubPatSubscribed(c,pattern); @@ -428,8 +428,7 @@ int pubsubSubscribePattern(client *c, robj *pattern) { * 0 if the client was not subscribed to the specified channel. */ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { dictEntry *de; - list *clients; - listNode *ln; + dict *clients; int retval = 0; incrRefCount(pattern); /* Protect the object. May be the same we remove */ @@ -439,11 +438,9 @@ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) { de = dictFind(server.pubsub_patterns,pattern); serverAssertWithInfo(c,NULL,de != NULL); clients = dictGetVal(de); - ln = listSearchKey(clients,c); - serverAssertWithInfo(c,NULL,ln != NULL); - listDelNode(clients,ln); - if (listLength(clients) == 0) { - /* Free the list and associated hash entry at all if this was + serverAssertWithInfo(c, NULL, dictDelete(clients, c) == DICT_OK); + if (dictSize(clients) == 0) { + /* Free the dict and associated hash entry at all if this was * the latest client. */ dictDelete(server.pubsub_patterns,pattern); } @@ -521,8 +518,6 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) dict *d; dictEntry *de; dictIterator *di; - listNode *ln; - listIter li; unsigned int slot = 0; /* Send to clients listening for that channel */ @@ -532,17 +527,16 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) d = *type.serverPubSubChannels(slot); de = d ? dictFind(d, channel) : NULL; if (de) { - list *list = dictGetVal(de); - listNode *ln; - listIter li; - - listRewind(list,&li); - while ((ln = listNext(&li)) != NULL) { - client *c = ln->value; + dict *clients = dictGetVal(de); + dictEntry *entry; + dictIterator *iter = dictGetSafeIterator(clients); + while ((entry = dictNext(iter)) != NULL) { + client *c = dictGetKey(entry); addReplyPubsubMessage(c,channel,message,*type.messageBulk); updateClientMemUsageAndBucket(c); receivers++; } + dictReleaseIterator(iter); } if (type.shard) { @@ -556,19 +550,21 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) channel = getDecodedObject(channel); while((de = dictNext(di)) != NULL) { robj *pattern = dictGetKey(de); - list *clients = dictGetVal(de); + dict *clients = dictGetVal(de); if (!stringmatchlen((char*)pattern->ptr, sdslen(pattern->ptr), (char*)channel->ptr, sdslen(channel->ptr),0)) continue; - listRewind(clients,&li); - while ((ln = listNext(&li)) != NULL) { - client *c = listNodeValue(ln); + dictEntry *entry; + dictIterator *iter = dictGetSafeIterator(clients); + while ((entry = dictNext(iter)) != NULL) { + client *c = dictGetKey(entry); addReplyPubsubPatMessage(c,pattern,channel,message); updateClientMemUsageAndBucket(c); receivers++; } + dictReleaseIterator(iter); } decrRefCount(channel); dictReleaseIterator(di); @@ -706,10 +702,10 @@ NULL addReplyArrayLen(c,(c->argc-2)*2); for (j = 2; j < c->argc; j++) { - list *l = dictFetchValue(server.pubsub_channels,c->argv[j]); + dict *d = dictFetchValue(server.pubsub_channels, c->argv[j]); addReplyBulk(c,c->argv[j]); - addReplyLongLong(c,l ? listLength(l) : 0); + addReplyLongLong(c, d ? dictSize(d) : 0); } } else if (!strcasecmp(c->argv[1]->ptr,"numpat") && c->argc == 2) { /* PUBSUB NUMPAT */ @@ -727,10 +723,10 @@ NULL for (j = 2; j < c->argc; j++) { unsigned int slot = calculateKeySlot(c->argv[j]->ptr); dict *d = server.pubsubshard_channels[slot]; - list *l = d ? dictFetchValue(d, c->argv[j]) : NULL; + dict *clients = d ? dictFetchValue(d, c->argv[j]) : NULL; addReplyBulk(c,c->argv[j]); - addReplyLongLong(c,l ? listLength(l) : 0); + addReplyLongLong(c, d ? dictSize(clients) : 0); } } else { addReplySubcommandSyntaxError(c); diff --git a/src/server.c b/src/server.c index 4fd4a993c..280644e7f 100644 --- a/src/server.c +++ b/src/server.c @@ -282,6 +282,12 @@ void dictListDestructor(dict *d, void *val) listRelease((list*)val); } +void dictDictDestructor(dict *d, void *val) +{ + UNUSED(d); + dictRelease((dict*)val); +} + int dictSdsKeyCompare(dict *d, const void *key1, const void *key2) { @@ -351,6 +357,17 @@ uint64_t dictCStrCaseHash(const void *key) { return dictGenCaseHashFunction((unsigned char*)key, strlen((char*)key)); } +/* Dict hash function for client */ +uint64_t dictClientHash(const void *key) { + return ((client *)key)->id; +} + +/* Dict compare function for client */ +int dictClientKeyCompare(dict *d, const void *key1, const void *key2) { + UNUSED(d); + return ((client *)key1)->id == ((client *)key2)->id; +} + /* Dict compare function for null terminated string */ int dictCStrKeyCompare(dict *d, const void *key1, const void *key2) { int l1,l2; @@ -596,6 +613,18 @@ dictType keylistDictType = { NULL /* allow to expand */ }; +/* KeyDict hash table type has unencoded redis objects as keys and + * dicts as values. It's used for PUBSUB command to track clients subscribing the channels. */ +dictType objToDictDictType = { + dictObjHash, /* hash function */ + NULL, /* key dup */ + NULL, /* val dup */ + dictObjKeyCompare, /* key compare */ + dictObjectDestructor, /* key destructor */ + dictDictDestructor, /* val destructor */ + NULL /* allow to expand */ +}; + /* Modules system dictionary type. Keys are module name, * values are pointer to RedisModule struct. */ dictType modulesDictType = { @@ -655,6 +684,15 @@ dictType sdsHashDictType = { NULL /* allow to expand */ }; +/* Client Set dictionary type. Keys are client, values are not used. */ +dictType clientDictType = { + dictClientHash, /* hash function */ + NULL, /* key dup */ + NULL, /* val dup */ + dictClientKeyCompare, /* key compare */ + .no_value = 1 /* no values in this dict */ +}; + int htNeedsResize(dict *dict) { long long size, used; @@ -2745,8 +2783,8 @@ void initServer(void) { } server.rehashing = listCreate(); evictionPoolAlloc(); /* Initialize the LRU keys pool. */ - server.pubsub_channels = dictCreate(&keylistDictType); - server.pubsub_patterns = dictCreate(&keylistDictType); + server.pubsub_channels = dictCreate(&objToDictDictType); + server.pubsub_patterns = dictCreate(&objToDictDictType); server.pubsubshard_channels = zcalloc(sizeof(dict *) * slot_count); server.shard_channel_count = 0; server.pubsub_clients = 0; diff --git a/src/server.h b/src/server.h index b398d8ae9..be33bf803 100644 --- a/src/server.h +++ b/src/server.h @@ -2499,6 +2499,8 @@ extern dictType hashDictType; extern dictType stringSetDictType; extern dictType externalStringType; extern dictType sdsHashDictType; +extern dictType clientDictType; +extern dictType objToDictDictType; extern dictType dbExpiresDictType; extern dictType modulesDictType; extern dictType sdsReplyDictType;