diff --git a/src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnection.java b/src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnection.java index d808cee848..708344cc28 100644 --- a/src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnection.java +++ b/src/main/java/org/springframework/data/redis/connection/lettuce/LettuceConnection.java @@ -36,10 +36,12 @@ import io.lettuce.core.protocol.Command; import io.lettuce.core.protocol.CommandArgs; import io.lettuce.core.protocol.CommandType; +import io.lettuce.core.protocol.ProtocolKeyword; import io.lettuce.core.pubsub.StatefulRedisPubSubConnection; import io.lettuce.core.sentinel.api.StatefulRedisSentinelConnection; import java.lang.reflect.Constructor; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -389,7 +391,7 @@ public Object execute(String command, @Nullable CommandOutput commandOutputTypeH Assert.hasText(command, "a valid command needs to be specified"); String name = command.trim().toUpperCase(); - CommandType commandType = CommandType.valueOf(name); + ProtocolKeyword commandType = getCommandType(name); validateCommandIfRunningInTransactionMode(commandType, args); @@ -1109,14 +1111,14 @@ io.lettuce.core.ScanCursor getScanCursor(long cursorId) { return io.lettuce.core.ScanCursor.of(Long.toString(cursorId)); } - private void validateCommandIfRunningInTransactionMode(CommandType cmd, byte[]... args) { + private void validateCommandIfRunningInTransactionMode(ProtocolKeyword cmd, byte[]... args) { if (this.isQueueing()) { validateCommand(cmd, args); } } - private void validateCommand(CommandType cmd, @Nullable byte[]... args) { + private void validateCommand(ProtocolKeyword cmd, @Nullable byte[]... args) { RedisCommand redisCommand = RedisCommand.failsafeCommandLookup(cmd.name()); if (!RedisCommand.UNKNOWN.equals(redisCommand) && redisCommand.requiresArguments()) { @@ -1128,6 +1130,15 @@ private void validateCommand(CommandType cmd, @Nullable byte[]... args) { } } + private static ProtocolKeyword getCommandType(String name) { + + try { + return CommandType.valueOf(name); + } catch (IllegalArgumentException e) { + return new CustomCommandType(name); + } + } + /** * {@link TypeHints} provide {@link CommandOutput} information for a given {@link CommandType}. * @@ -1136,7 +1147,7 @@ private void validateCommand(CommandType cmd, @Nullable byte[]... args) { static class TypeHints { @SuppressWarnings("rawtypes") // - private static final Map> COMMAND_OUTPUT_TYPE_MAPPING = new HashMap<>(); + private static final Map> COMMAND_OUTPUT_TYPE_MAPPING = new HashMap<>(); @SuppressWarnings("rawtypes") // private static final Map, Constructor> CONSTRUCTORS = new ConcurrentHashMap<>(); @@ -1298,7 +1309,7 @@ static class TypeHints { * @return {@link ByteArrayOutput} as default when no matching {@link CommandOutput} available. */ @SuppressWarnings("rawtypes") - public CommandOutput getTypeHint(CommandType type) { + public CommandOutput getTypeHint(ProtocolKeyword type) { return getTypeHint(type, new ByteArrayOutput<>(CODEC)); } @@ -1309,7 +1320,7 @@ public CommandOutput getTypeHint(CommandType type) { * @return */ @SuppressWarnings("rawtypes") - public CommandOutput getTypeHint(CommandType type, CommandOutput defaultType) { + public CommandOutput getTypeHint(ProtocolKeyword type, CommandOutput defaultType) { if (type == null || !COMMAND_OUTPUT_TYPE_MAPPING.containsKey(type)) { return defaultType; @@ -1552,4 +1563,49 @@ public void onClose(StatefulConnection connection) { connection.setAutoFlushCommands(true); } } + + /** + * @since 2.3.8 + */ + static class CustomCommandType implements ProtocolKeyword { + + private final String name; + + CustomCommandType(String name) { + this.name = name; + } + + @Override + public byte[] getBytes() { + return name.getBytes(StandardCharsets.US_ASCII); + } + + @Override + public String name() { + return name; + } + + @Override + public boolean equals(Object o) { + + if (this == o) { + return true; + } + if (!(o instanceof CustomCommandType)) { + return false; + } + CustomCommandType that = (CustomCommandType) o; + return ObjectUtils.nullSafeEquals(name, that.name); + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHashCode(name); + } + + @Override + public String toString() { + return name; + } + } } diff --git a/src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionUnitTests.java b/src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionUnitTests.java index e161afd871..e7882df5b9 100644 --- a/src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/lettuce/LettuceConnectionUnitTests.java @@ -19,12 +19,18 @@ import static org.mockito.Mockito.*; import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisFuture; import io.lettuce.core.XAddArgs; import io.lettuce.core.XClaimArgs; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.api.sync.RedisCommands; +import io.lettuce.core.codec.ByteArrayCodec; import io.lettuce.core.codec.RedisCodec; +import io.lettuce.core.output.StatusOutput; +import io.lettuce.core.protocol.AsyncCommand; +import io.lettuce.core.protocol.Command; +import io.lettuce.core.protocol.CommandArgs; import java.lang.reflect.InvocationTargetException; import java.time.Duration; @@ -33,6 +39,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; + import org.springframework.dao.InvalidDataAccessResourceUsageException; import org.springframework.data.redis.connection.AbstractConnectionUnitTestBase; import org.springframework.data.redis.connection.RedisServerCommands.ShutdownOption; @@ -198,7 +205,6 @@ void xClaimShouldNotAddJustIdFlagToArgs() { } assertThat(ReflectionTestUtils.getField(args.getValue(), "justid")).isEqualTo(false); - } @Test // DATAREDIS-1226 @@ -216,6 +222,21 @@ void xClaimJustIdShouldAddJustIdFlagToArgs() { assertThat(ReflectionTestUtils.getField(args.getValue(), "justid")).isEqualTo(true); } + + @Test // GH-1979 + void executeShouldPassThruCustomCommands() { + + Command command = new Command<>(new LettuceConnection.CustomCommandType("FOO.BAR"), + new StatusOutput<>(ByteArrayCodec.INSTANCE)); + AsyncCommand future = new AsyncCommand<>(command); + future.complete(); + + when(asyncCommandsMock.dispatch(any(), any(), any())).thenReturn((RedisFuture) future); + + connection.execute("foo.bar", command.getOutput()); + + verify(asyncCommandsMock).dispatch(eq(command.getType()), eq(command.getOutput()), any(CommandArgs.class)); + } } public static class LettucePipelineConnectionUnitTests extends BasicUnitTests {