diff --git a/src/db.c b/src/db.c index 12e1c2bbe..41ee53601 100644 --- a/src/db.c +++ b/src/db.c @@ -1242,12 +1242,12 @@ void copyCommand(client *c) { case OBJ_HASH: newobj = hashTypeDup(o); break; case OBJ_STREAM: newobj = streamDup(o); break; case OBJ_MODULE: - addReplyError(c, "Copying module type object is not supported"); - return; - default: { + newobj = moduleTypeDupOrReply(c, key, newkey, o); + if (!newobj) return; + break; + default: addReplyError(c, "unknown type object"); return; - }; } if (delete) { diff --git a/src/module.c b/src/module.c index b5921ed55..be0f939f2 100644 --- a/src/module.c +++ b/src/module.c @@ -3639,6 +3639,24 @@ void moduleTypeNameByID(char *name, uint64_t moduleid) { } } +/* Create a copy of a module type value using the copy callback. If failed + * or not supported, produce an error reply and return NULL. + */ +robj *moduleTypeDupOrReply(client *c, robj *fromkey, robj *tokey, robj *value) { + moduleValue *mv = value->ptr; + moduleType *mt = mv->type; + if (!mt->copy) { + addReplyError(c, "not supported for this module key"); + return NULL; + } + void *newval = mt->copy(fromkey, tokey, mv->value); + if (!newval) { + addReplyError(c, "module key failed to copy"); + return NULL; + } + return createModuleObject(mt, newval); +} + /* Register a new data type exported by the module. The parameters are the * following. Please for in depth documentation check the modules API * documentation, especially https://redis.io/topics/modules-native-types. @@ -3678,6 +3696,7 @@ void moduleTypeNameByID(char *name, uint64_t moduleid) { * .aux_save = myType_AuxRDBSaveCallBack, * .free_effort = myType_FreeEffortCallBack * .unlink = myType_UnlinkCallBack + * .copy = myType_CopyCallback * } * * * **rdb_load**: A callback function pointer that loads data from RDB files. @@ -3697,10 +3716,13 @@ void moduleTypeNameByID(char *name, uint64_t moduleid) { * been removed from the DB by redis, and may soon be freed by a background thread. Note that * it won't be called on FLUSHALL/FLUSHDB (both sync and async), and the module can use the * RedisModuleEvent_FlushDB to hook into that. + * * **copy**: A callback function pointer that is used to make a copy of the specified key. + * The module is expected to perform a deep copy of the specified value and return it. + * In addition, hints about the names of the source and destination keys is provided. + * A NULL return value is considered an error and the copy operation fails. + * Note: if the target key exists and is being overwritten, the copy callback will be + * called first, followed by a free callback to the value that is being replaced. * - * The **digest** and **mem_usage** methods should currently be omitted since - * they are not yet implemented inside the Redis modules core. - * * Note: the module name "AAAAAAAAA" is reserved and produces an error, it * happens to be pretty lame as well. * @@ -3743,6 +3765,7 @@ moduleType *RM_CreateDataType(RedisModuleCtx *ctx, const char *name, int encver, struct { moduleTypeFreeEffortFunc free_effort; moduleTypeUnlinkFunc unlink; + moduleTypeCopyFunc copy; } v3; } *tms = (struct typemethods*) typemethods_ptr; @@ -3763,6 +3786,7 @@ moduleType *RM_CreateDataType(RedisModuleCtx *ctx, const char *name, int encver, if (tms->version >= 3) { mt->free_effort = tms->v3.free_effort; mt->unlink = tms->v3.unlink; + mt->copy = tms->v3.copy; } memcpy(mt->name,name,sizeof(mt->name)); listAddNodeTail(ctx->module->types,mt); diff --git a/src/redismodule.h b/src/redismodule.h index 59f568c6f..e4e8b2b3a 100644 --- a/src/redismodule.h +++ b/src/redismodule.h @@ -494,6 +494,7 @@ typedef void (*RedisModuleTypeDigestFunc)(RedisModuleDigest *digest, void *value typedef void (*RedisModuleTypeFreeFunc)(void *value); typedef size_t (*RedisModuleTypeFreeEffortFunc)(RedisModuleString *key, const void *value); typedef void (*RedisModuleTypeUnlinkFunc)(RedisModuleString *key, const void *value); +typedef void *(*RedisModuleTypeCopyFunc)(RedisModuleString *fromkey, RedisModuleString *tokey, const void *value); typedef void (*RedisModuleClusterMessageReceiver)(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, const unsigned char *payload, uint32_t len); typedef void (*RedisModuleTimerProc)(RedisModuleCtx *ctx, void *data); typedef void (*RedisModuleCommandFilterFunc) (RedisModuleCommandFilterCtx *filter); @@ -516,6 +517,7 @@ typedef struct RedisModuleTypeMethods { int aux_save_triggers; RedisModuleTypeFreeEffortFunc free_effort; RedisModuleTypeUnlinkFunc unlink; + RedisModuleTypeCopyFunc copy; } RedisModuleTypeMethods; #define REDISMODULE_GET_API(name) \ diff --git a/src/server.h b/src/server.h index 42ce6d32a..97604a11b 100644 --- a/src/server.h +++ b/src/server.h @@ -529,6 +529,7 @@ typedef size_t (*moduleTypeMemUsageFunc)(const void *value); typedef void (*moduleTypeFreeFunc)(void *value); typedef size_t (*moduleTypeFreeEffortFunc)(struct redisObject *key, const void *value); typedef void (*moduleTypeUnlinkFunc)(struct redisObject *key, void *value); +typedef void *(*moduleTypeCopyFunc)(struct redisObject *fromkey, struct redisObject *tokey, const void *value); /* This callback type is called by moduleNotifyUserChanged() every time * a user authenticated via the module API is associated with a different @@ -550,6 +551,7 @@ typedef struct RedisModuleType { moduleTypeFreeFunc free; moduleTypeFreeEffortFunc free_effort; moduleTypeUnlinkFunc unlink; + moduleTypeCopyFunc copy; moduleTypeAuxLoadFunc aux_load; moduleTypeAuxSaveFunc aux_save; int aux_save_triggers; @@ -1694,6 +1696,7 @@ void moduleUnblockClient(client *c); int moduleClientIsBlockedOnKeys(client *c); void moduleNotifyUserChanged(client *c); void moduleNotifyKeyUnlink(robj *key, robj *val); +robj *moduleTypeDupOrReply(client *c, robj *fromkey, robj *tokey, robj *value); /* Utils */ long long ustime(void); diff --git a/tests/modules/datatype.c b/tests/modules/datatype.c index 6596f9368..0c6f95551 100644 --- a/tests/modules/datatype.c +++ b/tests/modules/datatype.c @@ -41,6 +41,32 @@ static void datatype_free(void *value) { } } +static void *datatype_copy(RedisModuleString *fromkey, RedisModuleString *tokey, const void *value) { + const DataType *old = value; + + /* Answers to ultimate questions cannot be copied! */ + if (old->intval == 42) + return NULL; + + DataType *new = (DataType *) RedisModule_Alloc(sizeof(DataType)); + + new->intval = old->intval; + new->strval = RedisModule_CreateStringFromString(NULL, old->strval); + + /* Breaking the rules here! We return a copy that also includes traces + * of fromkey/tokey to confirm we get what we expect. + */ + size_t len; + const char *str = RedisModule_StringPtrLen(fromkey, &len); + RedisModule_StringAppendBuffer(NULL, new->strval, "/", 1); + RedisModule_StringAppendBuffer(NULL, new->strval, str, len); + RedisModule_StringAppendBuffer(NULL, new->strval, "/", 1); + str = RedisModule_StringPtrLen(tokey, &len); + RedisModule_StringAppendBuffer(NULL, new->strval, str, len); + + return new; +} + static int datatype_set(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { if (argc != 4) { RedisModule_WrongArity(ctx); @@ -97,9 +123,13 @@ static int datatype_get(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) DataType *dt = RedisModule_ModuleTypeGetValue(key); RedisModule_CloseKey(key); - RedisModule_ReplyWithArray(ctx, 2); - RedisModule_ReplyWithLongLong(ctx, dt->intval); - RedisModule_ReplyWithString(ctx, dt->strval); + if (!dt) { + RedisModule_ReplyWithNullArray(ctx); + } else { + RedisModule_ReplyWithArray(ctx, 2); + RedisModule_ReplyWithLongLong(ctx, dt->intval); + RedisModule_ReplyWithString(ctx, dt->strval); + } return REDISMODULE_OK; } @@ -161,6 +191,7 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) .rdb_load = datatype_load, .rdb_save = datatype_save, .free = datatype_free, + .copy = datatype_copy }; datatype = RedisModule_CreateDataType(ctx, "test___dt", 1, &datatype_methods); diff --git a/tests/unit/moduleapi/datatype.tcl b/tests/unit/moduleapi/datatype.tcl index e235462ea..cd6ebb32a 100644 --- a/tests/unit/moduleapi/datatype.tcl +++ b/tests/unit/moduleapi/datatype.tcl @@ -41,4 +41,18 @@ start_server {tags {"modules"}} { catch {r datatype.swap key-a key-b} e set e } {*ERR*} + + test {DataType: Copy command works for modules} { + # Test failed copies + r datatype.set answer-to-universe 42 AAA + catch {r copy answer-to-universe answer2} e + assert_match {*module key failed to copy*} $e + + # Our module's data type copy function copies the int value as-is + # but appends // to the string value so we can + # track passed arguments. + r datatype.set sourcekey 1234 AAA + r copy sourcekey targetkey + r datatype.get targetkey + } {1234 AAA/sourcekey/targetkey} }