Skip to content

Commit

Permalink
add "address_remap" feature to RedisCluster (#2726)
Browse files Browse the repository at this point in the history
* add cluster "host_port_remap" feature for asyncio.RedisCluster

* Add a unittest for asyncio.RedisCluster

* Add host_port_remap to _sync_ RedisCluster

* add synchronous tests

* rename arg to `address_remap` and take and return an address tuple.

* Add class documentation

* Add CHANGES
  • Loading branch information
kristjanvalur authored May 2, 2023
1 parent ac15d52 commit a7857e1
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Add `address_remap` parameter to `RedisCluster`
* Fix incorrect usage of once flag in async Sentinel
* asyncio: Fix memory leak caused by hiredis (#2693)
* Allow data to drain from async PythonParser when reading during a disconnect()
Expand Down
31 changes: 30 additions & 1 deletion redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import warnings
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
maximum number of connections are already created, a
:class:`~.MaxConnectionsError` is raised. This error may be retried as defined
by :attr:`connection_error_retry_attempts`
:param address_remap:
| An optional callable which, when provided with an internal network
address of a node, e.g. a `(host, port)` tuple, will return the address
where the node is reachable. This can be used to map the addresses at
which the nodes _think_ they are, to addresses at which a client may
reach them, such as when they sit behind a proxy.
| Rest of the arguments will be passed to the
:class:`~redis.asyncio.connection.Connection` instances when created
Expand Down Expand Up @@ -250,6 +258,7 @@ def __init__(
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -337,7 +346,12 @@ def __init__(
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
self.nodes_manager = NodesManager(
startup_nodes,
require_full_coverage,
kwargs,
address_remap=address_remap,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
Expand Down Expand Up @@ -1059,17 +1073,20 @@ class NodesManager:
"require_full_coverage",
"slots_cache",
"startup_nodes",
"address_remap",
)

def __init__(
self,
startup_nodes: List["ClusterNode"],
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
self.connection_kwargs = connection_kwargs
self.address_remap = address_remap

self.default_node: "ClusterNode" = None
self.nodes_cache: Dict[str, "ClusterNode"] = {}
Expand Down Expand Up @@ -1228,6 +1245,7 @@ async def initialize(self) -> None:
if host == "":
host = startup_node.host
port = int(primary_node[1])
host, port = self.remap_host_port(host, port)

target_node = tmp_nodes_cache.get(get_node_name(host, port))
if not target_node:
Expand All @@ -1246,6 +1264,7 @@ async def initialize(self) -> None:
for replica_node in replica_nodes:
host = replica_node[0]
port = replica_node[1]
host, port = self.remap_host_port(host, port)

target_replica_node = tmp_nodes_cache.get(
get_node_name(host, port)
Expand Down Expand Up @@ -1319,6 +1338,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
)
)

def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
    """
    Translate a node address reported by the cluster into the address
    the client should connect to.

    When an ``address_remap`` callable is set, it is called with a single
    ``(host, port)`` tuple and its result is returned as is. Otherwise the
    given host and port are returned unchanged.
    """
    remap = self.address_remap
    if not remap:
        # no remapping configured: use the address the cluster reported
        return host, port
    return remap((host, port))


class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
"""
Expand Down
22 changes: 22 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def __init__(
read_from_replicas: bool = False,
dynamic_startup_nodes: bool = True,
url: Optional[str] = None,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -514,6 +515,12 @@ def __init__(
reinitialize_steps to 1.
To avoid reinitializing the cluster on moved errors, set
reinitialize_steps to 0.
:param address_remap:
An optional callable which, when provided with an internal network
address of a node, e.g. a `(host, port)` tuple, will return the address
where the node is reachable. This can be used to map the addresses at
which the nodes _think_ they are, to addresses at which a client may
reach them, such as when they sit behind a proxy.
:**kwargs:
Extra arguments that will be sent into Redis instance when created
Expand Down Expand Up @@ -594,6 +601,7 @@ def __init__(
from_url=from_url,
require_full_coverage=require_full_coverage,
dynamic_startup_nodes=dynamic_startup_nodes,
address_remap=address_remap,
**kwargs,
)

Expand Down Expand Up @@ -1269,6 +1277,7 @@ def __init__(
lock=None,
dynamic_startup_nodes=True,
connection_pool_class=ConnectionPool,
address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
self.nodes_cache = {}
Expand All @@ -1280,6 +1289,7 @@ def __init__(
self._require_full_coverage = require_full_coverage
self._dynamic_startup_nodes = dynamic_startup_nodes
self.connection_pool_class = connection_pool_class
self.address_remap = address_remap
self._moved_exception = None
self.connection_kwargs = kwargs
self.read_load_balancer = LoadBalancer()
Expand Down Expand Up @@ -1502,6 +1512,7 @@ def initialize(self):
if host == "":
host = startup_node.host
port = int(primary_node[1])
host, port = self.remap_host_port(host, port)

target_node = self._get_or_create_cluster_node(
host, port, PRIMARY, tmp_nodes_cache
Expand All @@ -1518,6 +1529,7 @@ def initialize(self):
for replica_node in replica_nodes:
host = str_if_bytes(replica_node[0])
port = replica_node[1]
host, port = self.remap_host_port(host, port)

target_replica_node = self._get_or_create_cluster_node(
host, port, REPLICA, tmp_nodes_cache
Expand Down Expand Up @@ -1591,6 +1603,16 @@ def reset(self):
# The read_load_balancer is None, do nothing
pass

def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
    """
    Map a node address as announced by the cluster to the address at
    which the client can reach that node.

    If an ``address_remap`` callable is set, it receives one
    ``(host, port)`` tuple and whatever it returns is passed back to the
    caller. Without one, the host and port are returned untouched.
    """
    mapper = self.address_remap
    if mapper:
        address = (host, port)
        return mapper(address)
    return host, port


class ClusterPubSub(PubSub):
"""
Expand Down
110 changes: 109 additions & 1 deletion tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from _pytest.fixtures import FixtureRequest

from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
from redis.asyncio.connection import Connection, SSLConnection
from redis.asyncio.connection import Connection, SSLConnection, async_timeout
from redis.asyncio.parser import CommandsParser
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
Expand Down Expand Up @@ -49,6 +49,71 @@
]


class NodeProxy:
    """A class to proxy a node connection to a different port

    Connections accepted on ``addr`` are forwarded, byte for byte, to
    ``redis_addr``. ``n_connections`` counts the upstream connections that
    have been opened, so a test can check that the proxy was really used.
    """

    def __init__(self, addr, redis_addr):
        # (host, port) the proxy listens on
        self.addr = addr
        # (host, port) the proxied traffic is forwarded to
        self.redis_addr = redis_addr
        # not used by the proxy itself; kept so the attribute stays available
        self.send_event = asyncio.Event()
        self.server = None
        self.task = None
        self.n_connections = 0

    async def start(self):
        """Verify that redis is reachable, then start serving on ``addr``."""
        # test that we can connect to redis
        async with async_timeout(2):
            _, redis_writer = await asyncio.open_connection(*self.redis_addr)
        redis_writer.close()
        self.server = await asyncio.start_server(
            self.handle, *self.addr, reuse_address=True
        )
        self.task = asyncio.create_task(self.server.serve_forever())

    async def handle(self, reader, writer):
        """Serve one client connection by piping it to and from redis.

        Both the client and the upstream connection are closed when the
        handler finishes, whether normally, on error or on cancellation.
        """
        redis_writer = None
        try:
            # establish connection to redis
            redis_reader, redis_writer = await asyncio.open_connection(
                *self.redis_addr
            )
            self.n_connections += 1
            pipe1 = asyncio.create_task(self.pipe(reader, redis_writer))
            pipe2 = asyncio.create_task(self.pipe(redis_reader, writer))
            await asyncio.gather(pipe1, pipe2)
        finally:
            # closing is idempotent, so it does not matter that pipe() may
            # already have closed one or both of these
            if redis_writer is not None:
                redis_writer.close()
            writer.close()

    async def aclose(self):
        """Stop the proxy server.

        Does nothing for the parts that were never started, so it can be
        called unconditionally during cleanup.
        """
        if self.task is not None:
            self.task.cancel()
            try:
                await self.task
            except asyncio.CancelledError:
                pass
        if self.server is not None:
            self.server.close()
            await self.server.wait_closed()

    async def pipe(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        """Copy data from ``reader`` to ``writer`` until ``reader`` hits EOF.

        ``writer`` is closed on the way out so that the peer sees the
        disconnect; without this the pipe running in the opposite direction
        would keep waiting for data that never arrives.
        """
        try:
            while True:
                data = await reader.read(1000)
                if not data:
                    break
                writer.write(data)
                await writer.drain()
        finally:
            writer.close()


@pytest.fixture
def redis_addr(request):
    """Return the ``(host, port)`` named by the ``--redis-url`` option.

    The URL must use the ``redis`` scheme. The port defaults to 6379 when
    the URL does not carry one.
    """
    redis_url = request.config.getoption("--redis-url")
    parsed = urlparse(redis_url)
    assert parsed.scheme == "redis"
    # Use the parsed hostname/port instead of splitting netloc on ":", which
    # fails for URLs with credentials ("user:pass@host:port") or bracketed
    # IPv6 literals and would treat "user@host" as the host name.
    host = parsed.hostname or parsed.netloc
    port = parsed.port
    if port is None:
        port = 6379
    return host, port


@pytest_asyncio.fixture()
async def slowlog(r: RedisCluster) -> None:
"""
Expand Down Expand Up @@ -809,6 +874,49 @@ async def test_default_node_is_replaced_after_exception(self, r):
# Rollback to the old default node
r.replace_default_node(curr_default_node)

async def test_address_remap(self, create_redis, redis_addr):
    """Check that a rediscluster object created with an address remapper
    routes its connections through local proxy objects
    """

    # the `n` consecutive ports starting at the configured redis port are
    # each proxied on a local port shifted by `offset`
    offset = 1000
    n = 6
    ports = [redis_addr[1] + i for i in range(n)]

    def address_remap(address):
        # addresses outside the proxied port range pass through untouched
        host, port = address
        if int(port) not in ports:
            return host, port
        return "127.0.0.1", int(port) + offset

    # one proxy per remapped port
    proxies = []
    for port in ports:
        listen_addr = ("127.0.0.1", port + offset)
        node_addr = (redis_addr[0], port)
        proxies.append(NodeProxy(listen_addr, node_addr))
    await asyncio.gather(*(proxy.start() for proxy in proxies))

    try:
        # build the cluster client with the remapper installed
        r = await create_redis(
            cls=RedisCluster, flushdb=False, address_remap=address_remap
        )
        try:
            assert await r.ping() is True
            assert await r.set("byte_string", b"giraffe")
            assert await r.get("byte_string") == b"giraffe"
        finally:
            await r.close()
    finally:
        await asyncio.gather(*(proxy.aclose() for proxy in proxies))

    # more than one proxy must have seen traffic
    n_used = len([proxy for proxy in proxies if proxy.n_connections])
    assert n_used > 1


class TestClusterRedisCommands:
"""
Expand Down
Loading

0 comments on commit a7857e1

Please sign in to comment.