From 967fb3c6e812ed3ff4259b499a4a8401fe37b0fd Mon Sep 17 00:00:00 2001 From: Guillaume Koenig <106696198+knggk@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:50:18 -0500 Subject: [PATCH] Extend rax usage by allowing any long long value (#12837) The raxFind implementation uses a special pointer value (the address of a static string) as the "not found" value. It works as long as actual pointers were used. However we've seen usages where long long, non-pointer values have been used. It creates a risk that one of the long long value precisely is the address of the special "not found" value. This commit changes raxFind to return 1 or 0 to indicate elementhood, and take in a new void **value to optionally return the associated value. By extension, this also allow the RedisModule_DictSet/Replace operations to also safely insert integers instead of just pointers. --- src/acl.c | 10 +++++----- src/module.c | 50 ++++++++++++++++++++++++++++++------------------ src/networking.c | 5 +++-- src/rax.c | 18 +++++++---------- src/rax.h | 5 +---- src/rdb.c | 5 +++-- src/server.c | 12 +++++++----- src/t_stream.c | 40 ++++++++++++++++++++++---------------- src/tracking.c | 25 ++++++++++++++++-------- 9 files changed, 97 insertions(+), 73 deletions(-) diff --git a/src/acl.c b/src/acl.c index 58a9a3972..841f101cb 100644 --- a/src/acl.c +++ b/src/acl.c @@ -437,7 +437,7 @@ aclSelector *ACLUserGetRootSelector(user *u) { * * If the user with such name already exists NULL is returned. */ user *ACLCreateUser(const char *name, size_t namelen) { - if (raxFind(Users,(unsigned char*)name,namelen) != raxNotFound) return NULL; + if (raxFind(Users,(unsigned char*)name,namelen,NULL)) return NULL; user *u = zmalloc(sizeof(*u)); u->name = sdsnewlen(name,namelen); u->flags = USER_FLAG_DISABLED; @@ -1553,8 +1553,8 @@ unsigned long ACLGetCommandID(sds cmdname) { sds lowername = sdsdup(cmdname); sdstolower(lowername); if (commandId == NULL) commandId = raxNew(); - void *id = raxFind(commandId,(unsigned char*)lowername,sdslen(lowername)); - if (id != raxNotFound) { + void *id; + if (raxFind(commandId,(unsigned char*)lowername,sdslen(lowername),&id)) { sdsfree(lowername); return (unsigned long)id; } @@ -1585,8 +1585,8 @@ void ACLClearCommandID(void) { /* Return an username by its name, or NULL if the user does not exist. */ user *ACLGetUserByName(const char *name, size_t namelen) { - void *myuser = raxFind(Users,(unsigned char*)name,namelen); - if (myuser == raxNotFound) return NULL; + void *myuser = NULL; + raxFind(Users,(unsigned char*)name,namelen,&myuser); return myuser; } diff --git a/src/module.c b/src/module.c index 96bc61e0f..b966998c6 100644 --- a/src/module.c +++ b/src/module.c @@ -9130,7 +9130,7 @@ RedisModuleTimerID RM_CreateTimer(RedisModuleCtx *ctx, mstime_t period, RedisMod while(1) { key = htonu64(expiretime); - if (raxFind(Timers, (unsigned char*)&key,sizeof(key)) == raxNotFound) { + if (!raxFind(Timers, (unsigned char*)&key,sizeof(key),NULL)) { raxInsert(Timers,(unsigned char*)&key,sizeof(key),timer,NULL); break; } else { @@ -9169,8 +9169,11 @@ RedisModuleTimerID RM_CreateTimer(RedisModuleCtx *ctx, mstime_t period, RedisMod * If not NULL, the data pointer is set to the value of the data argument when * the timer was created. */ int RM_StopTimer(RedisModuleCtx *ctx, RedisModuleTimerID id, void **data) { - RedisModuleTimer *timer = raxFind(Timers,(unsigned char*)&id,sizeof(id)); - if (timer == raxNotFound || timer->module != ctx->module) + void *result; + if (!raxFind(Timers,(unsigned char*)&id,sizeof(id),&result)) + return REDISMODULE_ERR; + RedisModuleTimer *timer = result; + if (timer->module != ctx->module) return REDISMODULE_ERR; if (data) *data = timer->data; raxRemove(Timers,(unsigned char*)&id,sizeof(id),NULL); @@ -9185,8 +9188,11 @@ int RM_StopTimer(RedisModuleCtx *ctx, RedisModuleTimerID id, void **data) { * REDISMODULE_OK is returned. The arguments remaining or data can be NULL if * the caller does not need certain information. */ int RM_GetTimerInfo(RedisModuleCtx *ctx, RedisModuleTimerID id, uint64_t *remaining, void **data) { - RedisModuleTimer *timer = raxFind(Timers,(unsigned char*)&id,sizeof(id)); - if (timer == raxNotFound || timer->module != ctx->module) + void *result; + if (!raxFind(Timers,(unsigned char*)&id,sizeof(id),&result)) + return REDISMODULE_ERR; + RedisModuleTimer *timer = result; + if (timer->module != ctx->module) return REDISMODULE_ERR; if (remaining) { int64_t rem = ntohu64(id)-ustime(); @@ -9954,9 +9960,10 @@ int RM_DictReplace(RedisModuleDict *d, RedisModuleString *key, void *ptr) { * be set by reference to 1 if the key does not exist, or to 0 if the key * exists. */ void *RM_DictGetC(RedisModuleDict *d, void *key, size_t keylen, int *nokey) { - void *res = raxFind(d->rax,key,keylen); - if (nokey) *nokey = (res == raxNotFound); - return (res == raxNotFound) ? NULL : res; + void *res = NULL; + int found = raxFind(d->rax,key,keylen,&res); + if (nokey) *nokey = !found; + return res; } /* Like RedisModule_DictGetC() but takes the key as a RedisModuleString. */ @@ -10378,8 +10385,10 @@ void RM_FreeServerInfo(RedisModuleCtx *ctx, RedisModuleServerInfoData *data) { * mechanism to release the returned string. Return value will be NULL if the * field was not found. */ RedisModuleString *RM_ServerInfoGetField(RedisModuleCtx *ctx, RedisModuleServerInfoData *data, const char* field) { - sds val = raxFind(data->rax, (unsigned char *)field, strlen(field)); - if (val == raxNotFound) return NULL; + void *result; + if (!raxFind(data->rax, (unsigned char *)field, strlen(field), &result)) + return NULL; + sds val = result; RedisModuleString *o = createStringObject(val,sdslen(val)); if (ctx != NULL) autoMemoryAdd(ctx,REDISMODULE_AM_STRING,o); return o; @@ -10387,9 +10396,9 @@ RedisModuleString *RM_ServerInfoGetField(RedisModuleCtx *ctx, RedisModuleServerI /* Similar to RM_ServerInfoGetField, but returns a char* which should not be freed but the caller. */ const char *RM_ServerInfoGetFieldC(RedisModuleServerInfoData *data, const char* field) { - sds val = raxFind(data->rax, (unsigned char *)field, strlen(field)); - if (val == raxNotFound) return NULL; - return val; + void *result = NULL; + raxFind(data->rax, (unsigned char *)field, strlen(field), &result); + return result; } /* Get the value of a field from data collected with RM_GetServerInfo(). If the @@ -10397,11 +10406,12 @@ const char *RM_ServerInfoGetFieldC(RedisModuleServerInfoData *data, const char* * 0, and the optional out_err argument will be set to REDISMODULE_ERR. */ long long RM_ServerInfoGetFieldSigned(RedisModuleServerInfoData *data, const char* field, int *out_err) { long long ll; - sds val = raxFind(data->rax, (unsigned char *)field, strlen(field)); - if (val == raxNotFound) { + void *result; + if (!raxFind(data->rax, (unsigned char *)field, strlen(field), &result)) { if (out_err) *out_err = REDISMODULE_ERR; return 0; } + sds val = result; if (!string2ll(val,sdslen(val),&ll)) { if (out_err) *out_err = REDISMODULE_ERR; return 0; @@ -10415,11 +10425,12 @@ long long RM_ServerInfoGetFieldSigned(RedisModuleServerInfoData *data, const cha * 0, and the optional out_err argument will be set to REDISMODULE_ERR. */ unsigned long long RM_ServerInfoGetFieldUnsigned(RedisModuleServerInfoData *data, const char* field, int *out_err) { unsigned long long ll; - sds val = raxFind(data->rax, (unsigned char *)field, strlen(field)); - if (val == raxNotFound) { + void *result; + if (!raxFind(data->rax, (unsigned char *)field, strlen(field), &result)) { if (out_err) *out_err = REDISMODULE_ERR; return 0; } + sds val = result; if (!string2ull(val,&ll)) { if (out_err) *out_err = REDISMODULE_ERR; return 0; @@ -10433,11 +10444,12 @@ unsigned long long RM_ServerInfoGetFieldUnsigned(RedisModuleServerInfoData *data * optional out_err argument will be set to REDISMODULE_ERR. */ double RM_ServerInfoGetFieldDouble(RedisModuleServerInfoData *data, const char* field, int *out_err) { double dbl; - sds val = raxFind(data->rax, (unsigned char *)field, strlen(field)); - if (val == raxNotFound) { + void *result; + if (!raxFind(data->rax, (unsigned char *)field, strlen(field), &result)) { if (out_err) *out_err = REDISMODULE_ERR; return 0; } + sds val = result; if (!string2d(val,sdslen(val),&dbl)) { if (out_err) *out_err = REDISMODULE_ERR; return 0; diff --git a/src/networking.c b/src/networking.c index 4d8daecb3..c020faf89 100644 --- a/src/networking.c +++ b/src/networking.c @@ -1812,8 +1812,9 @@ int freeClientsInAsyncFreeQueue(void) { * are not registered clients. */ client *lookupClientByID(uint64_t id) { id = htonu64(id); - client *c = raxFind(server.clients_index,(unsigned char*)&id,sizeof(id)); - return (c == raxNotFound) ? NULL : c; + void *c = NULL; + raxFind(server.clients_index,(unsigned char*)&id,sizeof(id),&c); + return c; } /* This function should be called from _writeToClient when the reply list is not empty, diff --git a/src/rax.c b/src/rax.c index 304b26fe8..100744d79 100644 --- a/src/rax.c +++ b/src/rax.c @@ -44,11 +44,6 @@ #include RAX_MALLOC_INCLUDE -/* This is a special pointer that is guaranteed to never have the same value - * of a radix tree node. It's used in order to report "not found" error without - * requiring the function to have multiple return values. */ -void *raxNotFound = (void*)"rax-not-found-pointer"; - /* -------------------------------- Debugging ------------------------------ */ void raxDebugShowNode(const char *msg, raxNode *n); @@ -912,18 +907,19 @@ int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old) return raxGenericInsert(rax,s,len,data,old,0); } -/* Find a key in the rax, returns raxNotFound special void pointer value - * if the item was not found, otherwise the value associated with the - * item is returned. */ -void *raxFind(rax *rax, unsigned char *s, size_t len) { +/* Find a key in the rax: return 1 if the item is found, 0 otherwise. + * If there is an item and 'value' is passed in a non-NULL pointer, + * the value associated with the item is set at that address. */ +int raxFind(rax *rax, unsigned char *s, size_t len, void **value) { raxNode *h; debugf("### Lookup: %.*s\n", (int)len, s); int splitpos = 0; size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,NULL); if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) - return raxNotFound; - return raxGetData(h); + return 0; + if (value != NULL) *value = raxGetData(h); + return 1; } /* Return the memory address where the 'parent' node stores the specified diff --git a/src/rax.h b/src/rax.h index 6b1fd4188..c58c28b2c 100644 --- a/src/rax.h +++ b/src/rax.h @@ -185,15 +185,12 @@ typedef struct raxIterator { raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */ } raxIterator; -/* A special pointer returned for not found items. */ -extern void *raxNotFound; - /* Exported API. */ rax *raxNew(void); int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old); int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old); int raxRemove(rax *rax, unsigned char *s, size_t len, void **old); -void *raxFind(rax *rax, unsigned char *s, size_t len); +int raxFind(rax *rax, unsigned char *s, size_t len, void **value); void raxFree(rax *rax); void raxFreeWithCallback(rax *rax, void (*free_callback)(void*)); void raxStart(raxIterator *it, rax *rt); diff --git a/src/rdb.c b/src/rdb.c index b50ea7867..f6b0054cc 100644 --- a/src/rdb.c +++ b/src/rdb.c @@ -2751,13 +2751,14 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) { decrRefCount(o); return NULL; } - streamNACK *nack = raxFind(cgroup->pel,rawid,sizeof(rawid)); - if (nack == raxNotFound) { + void *result; + if (!raxFind(cgroup->pel,rawid,sizeof(rawid),&result)) { rdbReportCorruptRDB("Consumer entry not found in " "group global PEL"); decrRefCount(o); return NULL; } + streamNACK *nack = result; /* Set the NACK consumer, that was left to NULL when * loading the global PEL. Then set the same shared diff --git a/src/server.c b/src/server.c index 29282958d..25a20a49b 100644 --- a/src/server.c +++ b/src/server.c @@ -4279,13 +4279,15 @@ int processCommand(client *c) { /* ====================== Error lookup and execution ===================== */ void incrementErrorCount(const char *fullerr, size_t namelen) { - struct redisError *error = raxFind(server.errors,(unsigned char*)fullerr,namelen); - if (error == raxNotFound) { - error = zmalloc(sizeof(*error)); - error->count = 0; + void *result; + if (!raxFind(server.errors,(unsigned char*)fullerr,namelen,&result)) { + struct redisError *error = zmalloc(sizeof(*error)); + error->count = 1; raxInsert(server.errors,(unsigned char*)fullerr,namelen,error,NULL); + } else { + struct redisError *error = result; + error->count++; } - error->count++; } /*================================== Shutdown =============================== */ diff --git a/src/t_stream.c b/src/t_stream.c index ccb566bae..733ccfc8c 100644 --- a/src/t_stream.c +++ b/src/t_stream.c @@ -242,10 +242,12 @@ robj *streamDup(robj *o) { raxStart(&ri_cpel, consumer->pel); raxSeek(&ri_cpel, "^", NULL, 0); while (raxNext(&ri_cpel)) { - streamNACK *new_nack = raxFind(new_cg->pel,ri_cpel.key,sizeof(streamID)); + void *result; + int found = raxFind(new_cg->pel,ri_cpel.key,sizeof(streamID),&result); - serverAssert(new_nack != raxNotFound); + serverAssert(found); + streamNACK *new_nack = result; new_nack->consumer = new_consumer; raxInsert(new_consumer->pel,ri_cpel.key,sizeof(streamID),new_nack,NULL); } @@ -1760,8 +1762,10 @@ size_t streamReplyWithRange(client *c, stream *s, streamID *start, streamID *end * or update it if the consumer is the same as before. */ if (group_inserted == 0) { streamFreeNACK(nack); - nack = raxFind(group->pel,buf,sizeof(buf)); - serverAssert(nack != raxNotFound); + void *result; + int found = raxFind(group->pel,buf,sizeof(buf),&result); + serverAssert(found); + nack = result; raxRemove(nack->consumer->pel,buf,sizeof(buf),NULL); /* Update the consumer and NACK metadata. */ nack->consumer = consumer; @@ -2473,7 +2477,7 @@ void streamFreeConsumer(streamConsumer *sc) { * consumer group is returned. */ streamCG *streamCreateCG(stream *s, char *name, size_t namelen, streamID *id, long long entries_read) { if (s->cgroups == NULL) s->cgroups = raxNew(); - if (raxFind(s->cgroups,(unsigned char*)name,namelen) != raxNotFound) + if (raxFind(s->cgroups,(unsigned char*)name,namelen,NULL)) return NULL; streamCG *cg = zmalloc(sizeof(*cg)); @@ -2496,9 +2500,9 @@ void streamFreeCG(streamCG *cg) { * pointer, otherwise if there is no such group, NULL is returned. */ streamCG *streamLookupCG(stream *s, sds groupname) { if (s->cgroups == NULL) return NULL; - streamCG *cg = raxFind(s->cgroups,(unsigned char*)groupname, - sdslen(groupname)); - return (cg == raxNotFound) ? NULL : cg; + void *cg = NULL; + raxFind(s->cgroups,(unsigned char*)groupname,sdslen(groupname),&cg); + return cg; } /* Create a consumer with the specified name in the group 'cg' and return. @@ -2528,9 +2532,8 @@ streamConsumer *streamCreateConsumer(streamCG *cg, sds name, robj *key, int dbid /* Lookup the consumer with the specified name in the group 'cg'. */ streamConsumer *streamLookupConsumer(streamCG *cg, sds name) { if (cg == NULL) return NULL; - streamConsumer *consumer = raxFind(cg->consumers,(unsigned char*)name, - sdslen(name)); - if (consumer == raxNotFound) return NULL; + void *consumer = NULL; + raxFind(cg->consumers,(unsigned char*)name,sdslen(name),&consumer); return consumer; } @@ -2844,8 +2847,9 @@ void xackCommand(client *c) { /* Lookup the ID in the group PEL: it will have a reference to the * NACK structure that will have a reference to the consumer, so that * we are able to remove the entry from both PELs. */ - streamNACK *nack = raxFind(group->pel,buf,sizeof(buf)); - if (nack != raxNotFound) { + void *result; + if (raxFind(group->pel,buf,sizeof(buf),&result)) { + streamNACK *nack = result; raxRemove(group->pel,buf,sizeof(buf),NULL); raxRemove(nack->consumer->pel,buf,sizeof(buf),NULL); streamFreeNACK(nack); @@ -3224,12 +3228,14 @@ void xclaimCommand(client *c) { streamEncodeID(buf,&id); /* Lookup the ID in the group PEL. */ - streamNACK *nack = raxFind(group->pel,buf,sizeof(buf)); + void *result = NULL; + raxFind(group->pel,buf,sizeof(buf),&result); + streamNACK *nack = result; /* Item must exist for us to transfer it to another consumer. */ if (!streamEntryExists(o->ptr,&id)) { /* Clear this entry from the PEL, it no longer exists */ - if (nack != raxNotFound) { + if (nack != NULL) { /* Propagate this change (we are going to delete the NACK). */ streamPropagateXCLAIM(c,c->argv[1],group,c->argv[2],c->argv[j],nack); propagate_last_id = 0; /* Will be propagated by XCLAIM itself. */ @@ -3247,13 +3253,13 @@ void xclaimCommand(client *c) { * entry in the PEL from scratch, so that XCLAIM can also * be used to create entries in the PEL. Useful for AOF * and replication of consumer groups. */ - if (force && nack == raxNotFound) { + if (force && nack == NULL) { /* Create the NACK. */ nack = streamCreateNACK(NULL); raxInsert(group->pel,buf,sizeof(buf),nack,NULL); } - if (nack != raxNotFound) { + if (nack != NULL) { /* We need to check if the minimum idle time requested * by the caller is satisfied by this entry. * diff --git a/src/tracking.c b/src/tracking.c index 5a9b114aa..429770065 100644 --- a/src/tracking.c +++ b/src/tracking.c @@ -72,8 +72,10 @@ void disableTracking(client *c) { raxStart(&ri,c->client_tracking_prefixes); raxSeek(&ri,"^",NULL,0); while(raxNext(&ri)) { - bcastState *bs = raxFind(PrefixTable,ri.key,ri.key_len); - serverAssert(bs != raxNotFound); + void *result; + int found = raxFind(PrefixTable,ri.key,ri.key_len,&result); + serverAssert(found); + bcastState *bs = result; raxRemove(bs->clients,(unsigned char*)&c,sizeof(c),NULL); /* Was it the last client? Remove the prefix from the * table. */ @@ -153,14 +155,17 @@ int checkPrefixCollisionsOrReply(client *c, robj **prefixes, size_t numprefix) { /* Set the client 'c' to track the prefix 'prefix'. If the client 'c' is * already registered for the specified prefix, no operation is performed. */ void enableBcastTrackingForPrefix(client *c, char *prefix, size_t plen) { - bcastState *bs = raxFind(PrefixTable,(unsigned char*)prefix,plen); + void *result; + bcastState *bs; /* If this is the first client subscribing to such prefix, create * the prefix in the table. */ - if (bs == raxNotFound) { + if (!raxFind(PrefixTable,(unsigned char*)prefix,plen,&result)) { bs = zmalloc(sizeof(*bs)); bs->keys = raxNew(); bs->clients = raxNew(); raxInsert(PrefixTable,(unsigned char*)prefix,plen,bs,NULL); + } else { + bs = result; } if (raxTryInsert(bs->clients,(unsigned char*)&c,sizeof(c),NULL,NULL)) { if (c->client_tracking_prefixes == NULL) @@ -240,12 +245,15 @@ void trackingRememberKeys(client *tracking, client *executing) { for(int j = 0; j < numkeys; j++) { int idx = keys[j].pos; sds sdskey = executing->argv[idx]->ptr; - rax *ids = raxFind(TrackingTable,(unsigned char*)sdskey,sdslen(sdskey)); - if (ids == raxNotFound) { + void *result; + rax *ids; + if (!raxFind(TrackingTable,(unsigned char*)sdskey,sdslen(sdskey),&result)) { ids = raxNew(); int inserted = raxTryInsert(TrackingTable,(unsigned char*)sdskey, sdslen(sdskey),ids, NULL); serverAssert(inserted == 1); + } else { + ids = result; } if (raxTryInsert(ids,(unsigned char*)&tracking->id,sizeof(tracking->id),NULL,NULL)) TrackingTableTotalItems++; @@ -372,8 +380,9 @@ void trackingInvalidateKey(client *c, robj *keyobj, int bcast) { if (bcast && raxSize(PrefixTable) > 0) trackingRememberKeyToBroadcast(c,(char *)key,keylen); - rax *ids = raxFind(TrackingTable,key,keylen); - if (ids == raxNotFound) return; + void *result; + if (!raxFind(TrackingTable,key,keylen,&result)) return; + rax *ids = result; raxIterator ri; raxStart(&ri,ids);