Replace slots_to_channels radix tree with slot specific dictionaries for shard channels. (#12804)

We previously replaced the `slots_to_keys` radix tree with a key->slot
linked list (#9356), and then replaced that list with slot-specific
dictionaries for keys (#11695).

Shard channels behave much like keys, and we likewise need a
slots->channels mapping. Currently this is still done with a radix
tree. So we split `server.pubsubshard_channels` into 16384 dicts and
drop the radix tree, just as we did for the DBs.
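
Concretely, the mapping goes from one flat dict plus a radix tree to a
per-slot array of dicts. A minimal sketch of the layout change, using
the `redisServer` and `clusterState` fields as they appear in the diff
below:

    /* Before: one global dict for all shard channels, plus a radix tree
     * whose only job is answering slot -> channels queries. */
    dict *pubsubshard_channels;  /* channel -> list of subscribed clients */
    rax *slots_to_channels;      /* keys: 2-byte slot prefix + channel name */

    /* After: one lazily created dict per slot (16384 in cluster mode,
     * a single dict in standalone mode), plus a global counter. */
    dict **pubsubshard_channels; /* [slot] -> (channel -> clients), or NULL */
    unsigned long long shard_channel_count;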

Some benefits (essentially the same benefits the DB split brought):
1. Optimize counting the channels in a slot. This count is currently
used only when removing the channels in a slot, but it is potentially
more useful: during slot migration we sometimes need to know how many
channels a specific slot holds. Counting used to traverse the radix
tree; with this PR it is a single `dictSize` call, going from O(n) to
O(1) (see the sketch after this list).
2. The radix tree is removed from the cluster state, so shard channel
names no longer require a second copy there, which saves memory.
3. Potentially useful for slot migration: shard channels are physically
grouped by slot, making it easier to migrate, remove or add a slot's
channels as a whole.
4. Avoid rehashing one big dict when there is a large number of
channels.
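
A minimal sketch of benefit 1; this is essentially the new
`countChannelsInSlot` from the diff below, where the removed side of
the hunk shows the radix-tree walk it replaces:

    /* O(1) count: each slot owns at most one dict of shard channels. */
    unsigned int countChannelsInSlot(unsigned int hashslot) {
        dict *d = server.pubsubshard_channels[hashslot];
        return d ? dictSize(d) : 0; /* NULL until the slot's first channel */
    }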

Drawbacks:
1. Takes more memory than the radix tree when there are relatively few
shard channels.

What this PR does:
1. In cluster mode, split `server.pubsubshard_channels` into 16384
dicts; in standalone mode, keep a single dict.
2. Drop the `slots_to_channels` radix tree.
3. To save memory (addressing the drawback above), create all 16384
dicts lazily: a dict is initialized only when a channel is about to be
inserted into it, and released again once its last channel is deleted
(sketched below).
4. Use `server.shard_channel_count` to keep track of the total number
of shard channels.
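
The lazy life cycle in point 3 reduces to two symmetric steps. A
simplified sketch (the full versions, including the shard-only counter
updates, are in the `pubsubSubscribeChannel` and
`pubsubUnsubscribeChannel` hunks below):

    /* Subscribe path: create the slot's dict on first use. */
    dict **d_ptr = type.serverPubSubChannels(slot);
    if (*d_ptr == NULL) *d_ptr = dictCreate(&keylistDictType);

    /* Unsubscribe path: once the last channel in the slot is gone,
     * release the dict and reset the slot entry to NULL. */
    dictDelete(d, channel);
    if (type.shard && dictSize(d) == 0) {
        dictRelease(d);
        *type.serverPubSubChannels(slot) = NULL;
    }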

---------

Co-authored-by: Viktor Söderqvist <viktor.soderqvist@est.tech>
Chen Tianjie, 2023-12-27 17:40:45 +08:00 (committed by GitHub)
commit 8527959598 (parent fa751f9bef)
8 changed files with 159 additions and 149 deletions

@@ -1906,7 +1906,7 @@ int ACLCheckAllPerm(client *c, int *idxptr) {
 int totalSubscriptions(void) {
     return dictSize(server.pubsub_patterns) +
            dictSize(server.pubsub_channels) +
-           dictSize(server.pubsubshard_channels);
+           server.shard_channel_count;
 }

 /* If 'new' can access all channels 'original' could then return NULL;

@@ -48,8 +48,6 @@ void clusterUpdateMyselfHostname(void);
 void clusterUpdateMyselfAnnouncedPorts(void);
 void clusterUpdateMyselfHumanNodename(void);
-void slotToChannelAdd(sds channel);
-void slotToChannelDel(sds channel);
 void clusterPropagatePublish(robj *channel, robj *message, int sharded);
 unsigned long getClusterConnectionsCount(void);

@@ -1021,9 +1021,6 @@ void clusterInit(void) {
         exit(1);
     }

-    /* The slots -> channels map is a radix tree. Initialize it here. */
-    server.cluster->slots_to_channels = raxNew();
-
     /* Set myself->port/cport/pport to my listening ports, we'll just need to
      * discover the IP address via MEET messages. */
     deriveAnnouncedPorts(&myself->tcp_port, &myself->tls_port, &myself->cport);

@@ -5075,7 +5072,7 @@ int verifyClusterConfigWithData(void) {
 /* Remove all the shard channel related information not owned by the current shard. */
 static inline void removeAllNotOwnedShardChannelSubscriptions(void) {
-    if (!dictSize(server.pubsubshard_channels)) return;
+    if (!server.shard_channel_count) return;
     clusterNode *currmaster = clusterNodeIsMaster(myself) ? myself : myself->slaveof;
     for (int j = 0; j < CLUSTER_SLOTS; j++) {
         if (server.cluster->slots[j] != currmaster) {

@@ -5664,27 +5661,9 @@ sds genClusterInfoString(void) {
 void removeChannelsInSlot(unsigned int slot) {
-    unsigned int channelcount = countChannelsInSlot(slot);
-    if (channelcount == 0) return;
+    if (countChannelsInSlot(slot) == 0) return;

-    /* Retrieve all the channels for the slot. */
-    robj **channels = zmalloc(sizeof(robj*)*channelcount);
-    raxIterator iter;
-    int j = 0;
-    unsigned char indexed[2];
-
-    indexed[0] = (slot >> 8) & 0xff;
-    indexed[1] = slot & 0xff;
-    raxStart(&iter,server.cluster->slots_to_channels);
-    raxSeek(&iter,">=",indexed,2);
-    while(raxNext(&iter)) {
-        if (iter.key[0] != indexed[0] || iter.key[1] != indexed[1]) break;
-        channels[j++] = createStringObject((char*)iter.key + 2, iter.key_len - 2);
-    }
-    raxStop(&iter);
-
-    pubsubUnsubscribeShardChannels(channels, channelcount);
-    zfree(channels);
+    pubsubShardUnsubscribeAllChannelsInSlot(slot);
 }

@@ -5719,52 +5698,10 @@ unsigned int delKeysInSlot(unsigned int hashslot) {
     return j;
 }

-/* -----------------------------------------------------------------------------
- * Operation(s) on channel rax tree.
- * -------------------------------------------------------------------------- */
-
-void slotToChannelUpdate(sds channel, int add) {
-    size_t keylen = sdslen(channel);
-    unsigned int hashslot = keyHashSlot(channel,keylen);
-    unsigned char buf[64];
-    unsigned char *indexed = buf;
-
-    if (keylen+2 > 64) indexed = zmalloc(keylen+2);
-    indexed[0] = (hashslot >> 8) & 0xff;
-    indexed[1] = hashslot & 0xff;
-    memcpy(indexed+2,channel,keylen);
-    if (add) {
-        raxInsert(server.cluster->slots_to_channels,indexed,keylen+2,NULL,NULL);
-    } else {
-        raxRemove(server.cluster->slots_to_channels,indexed,keylen+2,NULL);
-    }
-    if (indexed != buf) zfree(indexed);
-}
-
-void slotToChannelAdd(sds channel) {
-    slotToChannelUpdate(channel,1);
-}
-
-void slotToChannelDel(sds channel) {
-    slotToChannelUpdate(channel,0);
-}
-
 /* Get the count of the channels for a given slot. */
 unsigned int countChannelsInSlot(unsigned int hashslot) {
-    raxIterator iter;
-    int j = 0;
-    unsigned char indexed[2];
-
-    indexed[0] = (hashslot >> 8) & 0xff;
-    indexed[1] = hashslot & 0xff;
-    raxStart(&iter,server.cluster->slots_to_channels);
-    raxSeek(&iter,">=",indexed,2);
-    while(raxNext(&iter)) {
-        if (iter.key[0] != indexed[0] || iter.key[1] != indexed[1]) break;
-        j++;
-    }
-    raxStop(&iter);
-    return j;
+    dict *d = server.pubsubshard_channels[hashslot];
+    return d ? dictSize(d) : 0;
 }

 int clusterNodeIsMyself(clusterNode *n) {

@@ -318,7 +318,6 @@ struct clusterState {
     clusterNode *migrating_slots_to[CLUSTER_SLOTS];
     clusterNode *importing_slots_from[CLUSTER_SLOTS];
     clusterNode *slots[CLUSTER_SLOTS];
-    rax *slots_to_channels;
     /* The following fields are used to take the slave state on elections. */
     mstime_t failover_auth_time; /* Time of previous or next election. */
     int failover_auth_count;     /* Number of votes received so far. */

@@ -36,7 +36,7 @@ typedef struct pubsubtype {
     int shard;
     dict *(*clientPubSubChannels)(client*);
     int (*subscriptionCount)(client*);
-    dict **serverPubSubChannels;
+    dict **(*serverPubSubChannels)(unsigned int);
    robj **subscribeMsg;
     robj **unsubscribeMsg;
     robj **messageBulk;
@@ -62,12 +62,22 @@ dict* getClientPubSubChannels(client *c);
  */
 dict* getClientPubSubShardChannels(client *c);

+/*
+ * Get server's global Pub/Sub channels dict.
+ */
+dict **getServerPubSubChannels(unsigned int slot);
+
+/*
+ * Get server's shard level Pub/Sub channels dict.
+ */
+dict **getServerPubSubShardChannels(unsigned int slot);
+
 /*
  * Get list of channels client is subscribed to.
  * If a pattern is provided, the subset of channels is returned
  * matching the pattern.
  */
-void channelList(client *c, sds pat, dict* pubsub_channels);
+void channelList(client *c, sds pat, dict** pubsub_channels, int is_sharded);

 /*
  * Pub/Sub type for global channels.
@@ -76,7 +86,7 @@ pubsubtype pubSubType = {
     .shard = 0,
     .clientPubSubChannels = getClientPubSubChannels,
     .subscriptionCount = clientSubscriptionsCount,
-    .serverPubSubChannels = &server.pubsub_channels,
+    .serverPubSubChannels = getServerPubSubChannels,
     .subscribeMsg = &shared.subscribebulk,
     .unsubscribeMsg = &shared.unsubscribebulk,
     .messageBulk = &shared.messagebulk,

@@ -89,7 +99,7 @@ pubsubtype pubSubShardType = {
     .shard = 1,
     .clientPubSubChannels = getClientPubSubShardChannels,
     .subscriptionCount = clientShardSubscriptionsCount,
-    .serverPubSubChannels = &server.pubsubshard_channels,
+    .serverPubSubChannels = getServerPubSubShardChannels,
     .subscribeMsg = &shared.ssubscribebulk,
     .unsubscribeMsg = &shared.sunsubscribebulk,
     .messageBulk = &shared.smessagebulk,

@@ -213,7 +223,7 @@ int serverPubsubSubscriptionCount(void) {
 /* Return the number of pubsub shard level channels is handled. */
 int serverPubsubShardSubscriptionCount(void) {
-    return dictSize(server.pubsubshard_channels);
+    return server.shard_channel_count;
 }
@@ -235,6 +245,16 @@ dict* getClientPubSubShardChannels(client *c) {
     return c->pubsubshard_channels;
 }

+dict **getServerPubSubChannels(unsigned int slot) {
+    UNUSED(slot);
+    return &server.pubsub_channels;
+}
+
+dict **getServerPubSubShardChannels(unsigned int slot) {
+    serverAssert(server.cluster_enabled || slot == 0);
+    return &server.pubsubshard_channels[slot];
+}
+
 /* Return the number of pubsub + pubsub shard level channels
  * a client is subscribed to. */
 int clientTotalPubSubSubscriptionCount(client *c) {
@@ -258,20 +278,32 @@ void unmarkClientAsPubSub(client *c) {
 /* Subscribe a client to a channel. Returns 1 if the operation succeeded, or
  * 0 if the client was already subscribed to that channel. */
 int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) {
+    dict **d_ptr;
     dictEntry *de;
     list *clients = NULL;
     int retval = 0;
+    unsigned int slot = 0;

     /* Add the channel to the client -> channels hash table */
     if (dictAdd(type.clientPubSubChannels(c),channel,NULL) == DICT_OK) {
         retval = 1;
         incrRefCount(channel);
         /* Add the client to the channel -> list of clients hash table */
-        de = dictFind(*type.serverPubSubChannels, channel);
+        if (server.cluster_enabled && type.shard) {
+            slot = c->slot;
+        }
+        d_ptr = type.serverPubSubChannels(slot);
+        if (*d_ptr == NULL) {
+            *d_ptr = dictCreate(&keylistDictType);
+        }
+        de = dictFind(*d_ptr, channel);
         if (de == NULL) {
             clients = listCreate();
-            dictAdd(*type.serverPubSubChannels, channel, clients);
+            dictAdd(*d_ptr, channel, clients);
             incrRefCount(channel);
+            if (type.shard) {
+                server.shard_channel_count++;
+            }
         } else {
             clients = dictGetVal(de);
         }
@@ -285,10 +317,12 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) {
 /* Unsubscribe a client from a channel. Returns 1 if the operation succeeded, or
  * 0 if the client was not subscribed to the specified channel. */
 int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) {
+    dict *d;
     dictEntry *de;
     list *clients;
     listNode *ln;
     int retval = 0;
+    int slot = 0;

     /* Remove the channel from the client -> channels hash table */
     incrRefCount(channel); /* channel may be just a pointer to the same object

@@ -296,7 +330,12 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) {
     if (dictDelete(type.clientPubSubChannels(c),channel) == DICT_OK) {
         retval = 1;
         /* Remove the client from the channel -> clients list hash table */
-        de = dictFind(*type.serverPubSubChannels, channel);
+        if (server.cluster_enabled && type.shard) {
+            slot = c->slot != -1 ? c->slot : (int)keyHashSlot(channel->ptr, sdslen(channel->ptr));
+        }
+        d = *type.serverPubSubChannels(slot);
+        serverAssertWithInfo(c,NULL,d != NULL);
+        de = dictFind(d, channel);
         serverAssertWithInfo(c,NULL,de != NULL);
         clients = dictGetVal(de);
         ln = listSearchKey(clients,c);

@@ -306,11 +345,14 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) {
             /* Free the list and associated hash entry at all if this was
              * the latest client, so that it will be possible to abuse
              * Redis PUBSUB creating millions of channels. */
-            dictDelete(*type.serverPubSubChannels, channel);
-            /* As this channel isn't subscribed by anyone, it's safe
-             * to remove the channel from the slot. */
-            if (server.cluster_enabled & type.shard) {
-                slotToChannelDel(channel->ptr);
+            dictDelete(d, channel);
+            if (type.shard) {
+                if (dictSize(d) == 0) {
+                    dictRelease(d);
+                    dict **d_ptr = type.serverPubSubChannels(slot);
+                    *d_ptr = NULL;
+                }
+                server.shard_channel_count--;
             }
         }
     }
@@ -322,19 +364,22 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) {
     return retval;
 }

-void pubsubShardUnsubscribeAllClients(robj *channel) {
-    int retval;
-    dictEntry *de = dictFind(server.pubsubshard_channels, channel);
-    serverAssertWithInfo(NULL,channel,de != NULL);
-    list *clients = dictGetVal(de);
-    if (listLength(clients) > 0) {
+/* Unsubscribe all shard channels in a slot. */
+void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) {
+    dict *d = server.pubsubshard_channels[slot];
+    if (!d) {
+        return;
+    }
+
+    dictIterator *di = dictGetSafeIterator(d);
+    dictEntry *de;
+    while ((de = dictNext(di)) != NULL) {
+        robj *channel = dictGetKey(de);
+        list *clients = dictGetVal(de);
         /* For each client subscribed to the channel, unsubscribe it. */
-        listIter li;
         listNode *ln;
-        listRewind(clients, &li);
-        while ((ln = listNext(&li)) != NULL) {
+        while ((ln = listFirst(clients)) != NULL) {
             client *c = listNodeValue(ln);
-            retval = dictDelete(c->pubsubshard_channels, channel);
+            int retval = dictDelete(c->pubsubshard_channels, channel);
             serverAssertWithInfo(c,channel,retval == DICT_OK);
             addReplyPubsubUnsubscribed(c, channel, pubSubShardType);
             /* If the client has no other pubsub subscription,

@@ -343,16 +388,14 @@ void pubsubShardUnsubscribeAllClients(robj *channel) {
                 unmarkClientAsPubSub(c);
             }
+            listDelNode(clients, ln);
         }
+        server.shard_channel_count--;
+        dictDelete(d, channel);
     }
-
-    /* Delete the channel from server pubsubshard channels hash table. */
-    retval = dictDelete(server.pubsubshard_channels, channel);
-    /* Delete the channel from slots_to_channel mapping. */
-    slotToChannelDel(channel->ptr);
-    serverAssertWithInfo(NULL,channel,retval == DICT_OK);
-    decrRefCount(channel); /* it is finally safe to release it */
+    dictReleaseIterator(di);
+    dictRelease(d);
+    server.pubsubshard_channels[slot] = NULL;
 }

 /* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */
 int pubsubSubscribePattern(client *c, robj *pattern) {
     dictEntry *de;
@@ -446,17 +489,6 @@ int pubsubUnsubscribeShardAllChannels(client *c, int notify) {
     return count;
 }

-/*
- * Unsubscribe a client from provided shard subscribed channel(s).
- */
-void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count) {
-    for (unsigned int j = 0; j < count; j++) {
-        /* Remove the channel from server and from the clients
-         * subscribed to it as well as notify them. */
-        pubsubShardUnsubscribeAllClients(channels[j]);
-    }
-}
-
 /* Unsubscribe from all the patterns. Return the number of patterns the
  * client was subscribed from. */
 int pubsubUnsubscribeAllPatterns(client *c, int notify) {
@@ -483,13 +515,19 @@ int pubsubUnsubscribeAllPatterns(client *c, int notify) {
  */
 int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) {
     int receivers = 0;
+    dict *d;
     dictEntry *de;
     dictIterator *di;
     listNode *ln;
     listIter li;
+    unsigned int slot = 0;

     /* Send to clients listening for that channel */
-    de = dictFind(*type.serverPubSubChannels, channel);
+    if (server.cluster_enabled && type.shard) {
+        slot = keyHashSlot(channel->ptr, sdslen(channel->ptr));
+    }
+    d = *type.serverPubSubChannels(slot);
+    de = d ? dictFind(d, channel) : NULL;
     if (de) {
         list *list = dictGetVal(de);
         listNode *ln;
@@ -658,7 +696,7 @@ NULL
     {
         /* PUBSUB CHANNELS [<pattern>] */
         sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr;
-        channelList(c, pat, server.pubsub_channels);
+        channelList(c, pat, &server.pubsub_channels, 0);
     } else if (!strcasecmp(c->argv[1]->ptr,"numsub") && c->argc >= 2) {
         /* PUBSUB NUMSUB [Channel_1 ... Channel_N] */
         int j;

@@ -678,14 +716,15 @@ NULL
     {
         /* PUBSUB SHARDCHANNELS */
         sds pat = (c->argc == 2) ? NULL : c->argv[2]->ptr;
-        channelList(c,pat,server.pubsubshard_channels);
+        channelList(c,pat,server.pubsubshard_channels,server.cluster_enabled);
     } else if (!strcasecmp(c->argv[1]->ptr,"shardnumsub") && c->argc >= 2) {
         /* PUBSUB SHARDNUMSUB [ShardChannel_1 ... ShardChannel_N] */
         int j;

         addReplyArrayLen(c, (c->argc-2)*2);
         for (j = 2; j < c->argc; j++) {
-            list *l = dictFetchValue(server.pubsubshard_channels, c->argv[j]);
+            unsigned int slot = calculateKeySlot(c->argv[j]->ptr);
+            dict *d = server.pubsubshard_channels[slot];
+            list *l = d ? dictFetchValue(d, c->argv[j]) : NULL;

             addReplyBulk(c,c->argv[j]);
             addReplyLongLong(c,l ? listLength(l) : 0);
@@ -695,13 +734,18 @@ NULL
     }
 }

-void channelList(client *c, sds pat, dict *pubsub_channels) {
-    dictIterator *di = dictGetIterator(pubsub_channels);
-    dictEntry *de;
+void channelList(client *c, sds pat, dict **pubsub_channels, int is_sharded) {
     long mblen = 0;
     void *replylen;
+    unsigned int slot_cnt = is_sharded ? CLUSTER_SLOTS : 1;

     replylen = addReplyDeferredLen(c);
+    for (unsigned int i = 0; i < slot_cnt; i++) {
+        if (pubsub_channels[i] == NULL) {
+            continue;
+        }
+        dictIterator *di = dictGetIterator(pubsub_channels[i]);
+        dictEntry *de;
     while((de = dictNext(di)) != NULL) {
         robj *cobj = dictGetKey(de);
         sds channel = cobj->ptr;

@@ -714,6 +758,7 @@ void channelList(client *c, sds pat, dict *pubsub_channels) {
         }
     }
     dictReleaseIterator(di);
+    }

     setDeferredArrayLen(c,replylen,mblen);
 }
@@ -735,14 +780,6 @@ void ssubscribeCommand(client *c) {
     }

     for (int j = 1; j < c->argc; j++) {
-        /* A channel is only considered to be added, if a
-         * subscriber exists for it. And if a subscriber
-         * already exists the slotToChannel doesn't needs
-         * to be incremented. */
-        if (server.cluster_enabled &
-            (dictFind(*pubSubShardType.serverPubSubChannels, c->argv[j]) == NULL)) {
-            slotToChannelAdd(c->argv[j]->ptr);
-        }
         pubsubSubscribeChannel(c, c->argv[j], pubSubShardType);
     }
     markClientAsPubSub(c);

@@ -2714,10 +2714,10 @@ void initServer(void) {
     server.db = zmalloc(sizeof(redisDb)*server.dbnum);

     /* Create the Redis databases, and initialize other internal state. */
+    int slot_count = (server.cluster_enabled) ? CLUSTER_SLOTS : 1;
     for (j = 0; j < server.dbnum; j++) {
-        int slotCount = (server.cluster_enabled) ? CLUSTER_SLOTS : 1;
-        server.db[j].dict = dictCreateMultiple(&dbDictType, slotCount);
-        server.db[j].expires = dictCreateMultiple(&dbExpiresDictType,slotCount);
+        server.db[j].dict = dictCreateMultiple(&dbDictType, slot_count);
+        server.db[j].expires = dictCreateMultiple(&dbExpiresDictType,slot_count);
         server.db[j].expires_cursor = 0;
         server.db[j].blocking_keys = dictCreate(&keylistDictType);
         server.db[j].blocking_keys_unblock_on_nokey = dictCreate(&objectKeyPointerValueDictType);

@@ -2726,7 +2726,7 @@ void initServer(void) {
         server.db[j].id = j;
         server.db[j].avg_ttl = 0;
         server.db[j].defrag_later = listCreate();
-        server.db[j].dict_count = slotCount;
+        server.db[j].dict_count = slot_count;
         initDbState(&server.db[j]);
         listSetFreeMethod(server.db[j].defrag_later,(void (*)(void*))sdsfree);
     }

@@ -2734,7 +2734,8 @@ void initServer(void) {
     evictionPoolAlloc(); /* Initialize the LRU keys pool. */
     server.pubsub_channels = dictCreate(&keylistDictType);
     server.pubsub_patterns = dictCreate(&keylistDictType);
-    server.pubsubshard_channels = dictCreate(&keylistDictType);
+    server.pubsubshard_channels = zcalloc(sizeof(dict *) * slot_count);
+    server.shard_channel_count = 0;
     server.pubsub_clients = 0;
     server.cronloops = 0;
     server.in_exec = 0;

@@ -5869,7 +5870,7 @@ sds genRedisInfoString(dict *section_dict, int all_sections, int everything) {
     "keyspace_misses:%lld\r\n", server.stat_keyspace_misses,
     "pubsub_channels:%ld\r\n", dictSize(server.pubsub_channels),
     "pubsub_patterns:%lu\r\n", dictSize(server.pubsub_patterns),
-    "pubsubshard_channels:%lu\r\n", dictSize(server.pubsubshard_channels),
+    "pubsubshard_channels:%llu\r\n", server.shard_channel_count,
     "latest_fork_usec:%lld\r\n", server.stat_fork_time,
     "total_forks:%lld\r\n", server.stat_total_forks,
     "migrate_cached_sockets:%ld\r\n", dictSize(server.migrate_cached_sockets),

@@ -1994,7 +1994,8 @@ struct redisServer {
     dict *pubsub_patterns;  /* A dict of pubsub_patterns */
     int notify_keyspace_events; /* Events to propagate via Pub/Sub. This is an
                                    xor of NOTIFY_... flags. */
-    dict *pubsubshard_channels;  /* Map shard channels to list of subscribed clients */
+    dict **pubsubshard_channels;  /* Map shard channels in every slot to list of subscribed clients */
+    unsigned long long shard_channel_count;
     unsigned int pubsub_clients; /* # of clients in Pub/Sub mode */
     /* Cluster */
     int cluster_enabled;      /* Is cluster enabled? */

@@ -2498,6 +2499,7 @@ extern dictType sdsHashDictType;
 extern dictType dbExpiresDictType;
 extern dictType modulesDictType;
 extern dictType sdsReplyDictType;
+extern dictType keylistDictType;
 extern dict *modules;

 /*-----------------------------------------------------------------------------

@@ -3197,7 +3199,7 @@ robj *hashTypeDup(robj *o);
 /* Pub / Sub */
 int pubsubUnsubscribeAllChannels(client *c, int notify);
 int pubsubUnsubscribeShardAllChannels(client *c, int notify);
-void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count);
+void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot);
 int pubsubUnsubscribeAllPatterns(client *c, int notify);
 int pubsubPublishMessage(robj *channel, robj *message, int sharded);
 int pubsubPublishMessageAndPropagateToCluster(robj *channel, robj *message, int sharded);

@@ -56,6 +56,21 @@ test "client can subscribe to multiple shard channels across different slots in
     $cluster sunsubscribe ch7
 }

+test "sunsubscribe without specifying any channel would unsubscribe all shard channels subscribed" {
+    set publishclient [redis_client_by_addr $publishnode(host) $publishnode(port)]
+    set subscribeclient [redis_deferring_client_by_addr $publishnode(host) $publishnode(port)]
+
+    set sub_res [ssubscribe $subscribeclient [list "\{channel.0\}1" "\{channel.0\}2" "\{channel.0\}3"]]
+    assert_equal [list 1 2 3] $sub_res
+    sunsubscribe $subscribeclient
+
+    assert_equal 0 [$publishclient spublish "\{channel.0\}1" hello]
+    assert_equal 0 [$publishclient spublish "\{channel.0\}2" hello]
+    assert_equal 0 [$publishclient spublish "\{channel.0\}3" hello]
+
+    $publishclient close
+    $subscribeclient close
+}
+
 test "Verify Pub/Sub and Pub/Sub shard no overlap" {
     set slot [$cluster cluster keyslot "channel.0"]

@@ -92,3 +107,24 @@ test "Verify Pub/Sub and Pub/Sub shard no overlap" {
     $subscribeclient close
     $subscribeshardclient close
 }
+
+test "PUBSUB channels/shardchannels" {
+    set subscribeclient [redis_deferring_client_by_addr $publishnode(host) $publishnode(port)]
+    set subscribeclient2 [redis_deferring_client_by_addr $publishnode(host) $publishnode(port)]
+    set subscribeclient3 [redis_deferring_client_by_addr $publishnode(host) $publishnode(port)]
+    set publishclient [redis_client_by_addr $publishnode(host) $publishnode(port)]
+
+    ssubscribe $subscribeclient [list "\{channel.0\}1"]
+    ssubscribe $subscribeclient2 [list "\{channel.0\}2"]
+    ssubscribe $subscribeclient3 [list "\{channel.0\}3"]
+    assert_equal {3} [llength [$publishclient pubsub shardchannels]]
+
+    subscribe $subscribeclient [list "\{channel.0\}4"]
+    assert_equal {3} [llength [$publishclient pubsub shardchannels]]
+
+    sunsubscribe $subscribeclient
+    set channel_list [$publishclient pubsub shardchannels]
+    assert_equal {2} [llength $channel_list]
+    assert {[lsearch -exact $channel_list "\{channel.0\}2"] >= 0}
+    assert {[lsearch -exact $channel_list "\{channel.0\}3"] >= 0}
+}