Skip to content

Commit

Permalink
Pass-thru custom Redis commands using Lettuce.
Browse files Browse the repository at this point in the history
We now accept unknown custom Redis commands when using the Lettuce driver. Previously, custom commands were required to exist in Lettuce's CommandType enumeration and unknown commands (such as modules) failed to run.

Closes #1979
  • Loading branch information
mp911de committed Feb 23, 2021
1 parent 9cc6fa3 commit 773acd0
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
import io.lettuce.core.protocol.Command;
import io.lettuce.core.protocol.CommandArgs;
import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.ProtocolKeyword;
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection;
import io.lettuce.core.sentinel.api.StatefulRedisSentinelConnection;

import java.lang.reflect.Constructor;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -389,7 +391,7 @@ public Object execute(String command, @Nullable CommandOutput commandOutputTypeH
Assert.hasText(command, "a valid command needs to be specified");

String name = command.trim().toUpperCase();
CommandType commandType = CommandType.valueOf(name);
ProtocolKeyword commandType = getCommandType(name);

validateCommandIfRunningInTransactionMode(commandType, args);

Expand Down Expand Up @@ -1109,14 +1111,14 @@ io.lettuce.core.ScanCursor getScanCursor(long cursorId) {
return io.lettuce.core.ScanCursor.of(Long.toString(cursorId));
}

private void validateCommandIfRunningInTransactionMode(CommandType cmd, byte[]... args) {
private void validateCommandIfRunningInTransactionMode(ProtocolKeyword cmd, byte[]... args) {

if (this.isQueueing()) {
validateCommand(cmd, args);
}
}

private void validateCommand(CommandType cmd, @Nullable byte[]... args) {
private void validateCommand(ProtocolKeyword cmd, @Nullable byte[]... args) {

RedisCommand redisCommand = RedisCommand.failsafeCommandLookup(cmd.name());
if (!RedisCommand.UNKNOWN.equals(redisCommand) && redisCommand.requiresArguments()) {
Expand All @@ -1128,6 +1130,15 @@ private void validateCommand(CommandType cmd, @Nullable byte[]... args) {
}
}

private static ProtocolKeyword getCommandType(String name) {

try {
return CommandType.valueOf(name);
} catch (IllegalArgumentException e) {
return new CustomCommandType(name);
}
}

/**
* {@link TypeHints} provide {@link CommandOutput} information for a given {@link CommandType}.
*
Expand All @@ -1136,7 +1147,7 @@ private void validateCommand(CommandType cmd, @Nullable byte[]... args) {
static class TypeHints {

@SuppressWarnings("rawtypes") //
private static final Map<CommandType, Class<? extends CommandOutput>> COMMAND_OUTPUT_TYPE_MAPPING = new HashMap<>();
private static final Map<ProtocolKeyword, Class<? extends CommandOutput>> COMMAND_OUTPUT_TYPE_MAPPING = new HashMap<>();

@SuppressWarnings("rawtypes") //
private static final Map<Class<?>, Constructor<CommandOutput>> CONSTRUCTORS = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -1298,7 +1309,7 @@ static class TypeHints {
* @return {@link ByteArrayOutput} as default when no matching {@link CommandOutput} available.
*/
@SuppressWarnings("rawtypes")
public CommandOutput getTypeHint(CommandType type) {
public CommandOutput getTypeHint(ProtocolKeyword type) {
return getTypeHint(type, new ByteArrayOutput<>(CODEC));
}

Expand All @@ -1309,7 +1320,7 @@ public CommandOutput getTypeHint(CommandType type) {
* @return
*/
@SuppressWarnings("rawtypes")
public CommandOutput getTypeHint(CommandType type, CommandOutput defaultType) {
public CommandOutput getTypeHint(ProtocolKeyword type, CommandOutput defaultType) {

if (type == null || !COMMAND_OUTPUT_TYPE_MAPPING.containsKey(type)) {
return defaultType;
Expand Down Expand Up @@ -1552,4 +1563,49 @@ public void onClose(StatefulConnection<?, ?> connection) {
connection.setAutoFlushCommands(true);
}
}

/**
* @since 2.3.8
*/
static class CustomCommandType implements ProtocolKeyword {

private final String name;

CustomCommandType(String name) {
this.name = name;
}

@Override
public byte[] getBytes() {
return name.getBytes(StandardCharsets.US_ASCII);
}

@Override
public String name() {
return name;
}

@Override
public boolean equals(Object o) {

if (this == o) {
return true;
}
if (!(o instanceof CustomCommandType)) {
return false;
}
CustomCommandType that = (CustomCommandType) o;
return ObjectUtils.nullSafeEquals(name, that.name);
}

@Override
public int hashCode() {
return ObjectUtils.nullSafeHashCode(name);
}

@Override
public String toString() {
return name;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@
import static org.mockito.Mockito.*;

import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.XAddArgs;
import io.lettuce.core.XClaimArgs;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.api.sync.RedisCommands;
import io.lettuce.core.codec.ByteArrayCodec;
import io.lettuce.core.codec.RedisCodec;
import io.lettuce.core.output.StatusOutput;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.Command;
import io.lettuce.core.protocol.CommandArgs;

import java.lang.reflect.InvocationTargetException;
import java.time.Duration;
Expand All @@ -33,6 +39,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import org.springframework.dao.InvalidDataAccessResourceUsageException;
import org.springframework.data.redis.connection.AbstractConnectionUnitTestBase;
import org.springframework.data.redis.connection.RedisServerCommands.ShutdownOption;
Expand Down Expand Up @@ -198,7 +205,6 @@ void xClaimShouldNotAddJustIdFlagToArgs() {
}

assertThat(ReflectionTestUtils.getField(args.getValue(), "justid")).isEqualTo(false);

}

@Test // DATAREDIS-1226
Expand All @@ -216,6 +222,21 @@ void xClaimJustIdShouldAddJustIdFlagToArgs() {

assertThat(ReflectionTestUtils.getField(args.getValue(), "justid")).isEqualTo(true);
}

@Test // GH-1979
void executeShouldPassThruCustomCommands() {

Command<byte[], byte[], String> command = new Command<>(new LettuceConnection.CustomCommandType("FOO.BAR"),
new StatusOutput<>(ByteArrayCodec.INSTANCE));
AsyncCommand<byte[], byte[], String> future = new AsyncCommand<>(command);
future.complete();

when(asyncCommandsMock.dispatch(any(), any(), any())).thenReturn((RedisFuture) future);

connection.execute("foo.bar", command.getOutput());

verify(asyncCommandsMock).dispatch(eq(command.getType()), eq(command.getOutput()), any(CommandArgs.class));
}
}

public static class LettucePipelineConnectionUnitTests extends BasicUnitTests {
Expand Down

0 comments on commit 773acd0

Please sign in to comment.