Skip to content

Commit

Permalink
moving NodeConnection HashMap to MVar and adding refreshCluster function
Browse files Browse the repository at this point in the history
  • Loading branch information
shashi-kant-juspay committed Apr 12, 2023
1 parent bc79b73 commit cab4f4a
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 54 deletions.
93 changes: 60 additions & 33 deletions src/Database/Redis/Cluster.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
module Database.Redis.Cluster
( Connection(..)
, NodeRole(..)
Expand All @@ -13,6 +15,9 @@ module Database.Redis.Cluster
, HashSlot
, Shard(..)
, TimeoutException(..)
, TcpInfo(..)
, Host
, NodeID
, connect
, destroyNodeResources
, requestPipelined
Expand Down Expand Up @@ -46,6 +51,7 @@ import Text.Read (readMaybe)

import Database.Redis.Protocol(Reply(Error), renderRequest, reply)
import qualified Database.Redis.Cluster.Command as CMD
import Network.TLS (ClientParams)

-- This module implements a clustered connection whilst maintaining
-- compatibility with the original Hedis codebase. In particular it still
Expand All @@ -60,7 +66,7 @@ import qualified Database.Redis.Cluster.Command as CMD
-- | 'NodeConnection's, a 'Pipeline', and a 'ShardMap'
type IsReadOnly = Bool

data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap IsReadOnly
data Connection = Connection (MVar NodeConnectionMap) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap IsReadOnly TcpInfo

-- | A connection to a single node in the cluster, similar to 'ProtocolPipelining.Connection'
data NodeConnection = NodeConnection (Pool CC.ConnectionContext) (IOR.IORef (Maybe B.ByteString)) NodeID
Expand Down Expand Up @@ -114,6 +120,17 @@ data Shard = Shard MasterNode [SlaveNode] deriving (Show, Eq, Ord)
-- A map from hashslot to shards
newtype ShardMap = ShardMap (IntMap.IntMap Shard) deriving (Show)

type NodeConnectionMap = HM.HashMap NodeID NodeConnection

-- Object for storing Tcp Connection Info which will be used when cluster is refreshed
data TcpInfo = TcpInfo
{ connectAuth :: Maybe B.ByteString
, connectTLSParams :: Maybe ClientParams
, idleTime :: Time.NominalDiffTime
, maxResources :: Int
, timeoutOpt :: Maybe Int
} deriving Show

newtype MissingNodeException = MissingNodeException [B.ByteString] deriving (Show, Typeable)
instance Exception MissingNodeException

Expand All @@ -129,8 +146,8 @@ instance Exception NoNodeException
data TimeoutException = TimeoutException String deriving (Show, Typeable)
instance Exception TimeoutException

connect :: (Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext) -> [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> Bool -> ([NodeConnection] -> IO ShardMap) -> Time.NominalDiffTime -> Int -> IO Connection
connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap idleTime maxResources = do
connect :: (Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext) -> [CMD.CommandInfo] -> MVar ShardMap -> Bool -> ([NodeConnection] -> IO ShardMap) -> TcpInfo -> IO Connection
connect withAuth commandInfos shardMapVar isReadOnly refreshShardMap (tcpInfo@TcpInfo{ timeoutOpt, maxResources, idleTime }) = do
shardMap <- readMVar shardMapVar
stateVar <- newMVar $ Pending []
pipelineVar <- newMVar $ Pipeline stateVar
Expand All @@ -149,7 +166,8 @@ connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap
throwIO NoNodeException
else
return eNodeConns
return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) isReadOnly where
nodeConnsVar <- newMVar nodeConns
return $ Connection nodeConnsVar pipelineVar shardMapVar (CMD.newInfoMap commandInfos) isReadOnly tcpInfo where
simpleNodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection)
simpleNodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ nodes shardMap)
nodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection, Bool)
Expand All @@ -169,27 +187,27 @@ connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap
refreshShardMapVar shardMap = hasLocked $ modifyMVar_ shardMapVar (const (pure shardMap))

destroyNodeResources :: Connection -> IO ()
destroyNodeResources (Connection nodeConnMap _ _ _ _ ) = mapM_ disconnectNode (HM.elems nodeConnMap) where
destroyNodeResources (Connection nodeConnMapVar _ _ _ _ _) = readMVar nodeConnMapVar >>= (mapM_ disconnectNode . HM.elems) where
disconnectNode (NodeConnection nodePool _ _) = destroyAllResources nodePool

-- Add a request to the current pipeline for this connection. The pipeline will
-- be executed implicitly as soon as any result returned from this function is
-- evaluated.
requestPipelined :: IO ShardMap -> Connection -> [B.ByteString] -> IO Reply
requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
requestPipelined refreshShardmapAction conn@(Connection _ pipelineVar shardMapVar _ _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
(newStateVar, repliesIndex) <- hasLocked $ modifyMVar stateVar $ \case
Pending requests | isMulti nextRequest -> do
replies <- evaluatePipeline shardMapVar refreshAction conn requests
replies <- evaluatePipeline shardMapVar refreshShardmapAction conn requests
s' <- newMVar $ TransactionPending [nextRequest]
return (Executed replies, (s', 0))
Pending requests | length requests > 1000 -> do
replies <- evaluatePipeline shardMapVar refreshAction conn (nextRequest:requests)
replies <- evaluatePipeline shardMapVar refreshShardmapAction conn (nextRequest:requests)
return (Executed replies, (stateVar, length requests))
Pending requests ->
return (Pending (nextRequest:requests), (stateVar, length requests))
TransactionPending requests ->
if isExec nextRequest then do
replies <- evaluateTransactionPipeline shardMapVar refreshAction conn (nextRequest:requests)
replies <- evaluateTransactionPipeline shardMapVar refreshShardmapAction conn (nextRequest:requests)
return (Executed replies, (stateVar, length requests))
else
return (TransactionPending (nextRequest:requests), (stateVar, length requests))
Expand All @@ -205,10 +223,10 @@ requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) n
Executed replies ->
return (Executed replies, replies)
Pending requests-> do
replies <- evaluatePipeline shardMapVar refreshAction conn requests
replies <- evaluatePipeline shardMapVar refreshShardmapAction conn requests
return (Executed replies, replies)
TransactionPending requests-> do
replies <- evaluateTransactionPipeline shardMapVar refreshAction conn requests
replies <- evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests
return (Executed replies, replies)
return $ replies !! repliesIndex
return (Pipeline newStateVar, evaluateAction)
Expand Down Expand Up @@ -271,7 +289,7 @@ evaluatePipeline shardMapVar refreshShardmapAction conn requests = do
Left (err :: SomeException) ->
case fromException err of
Just (er :: TimeoutException) -> throwIO er
_ -> executeRequests (getRandomConnection cc conn) r
_ -> getRandomConnection cc conn >>= (`executeRequests` r)
) (zip eresps requestsByNode)
-- check for any moved in both responses and continue the flow.
when (any (moved . rawResponse) resps) refreshShardMapVar
Expand Down Expand Up @@ -306,14 +324,14 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies =
-- there is one.
case last replies of
(Error errString) | B.isPrefixOf "MOVED" errString -> do
let (Connection _ _ _ infoMap _) = conn
let (Connection _ _ _ infoMap _ _) = conn
keys <- mconcat <$> mapM (requestKeys infoMap) requests
hashSlot <- hashSlotForKeys (CrossSlotException requests) keys
nodeConn <- nodeConnForHashSlot shardMapVar conn (MissingNodeException (head requests)) hashSlot
requestNode nodeConn requests
(askingRedirection -> Just (host, port)) -> do
shardMap <- hasLocked $ readMVar shardMapVar
let maybeAskNode = nodeConnWithHostAndPort shardMap conn host port
maybeAskNode <- nodeConnWithHostAndPort shardMap conn host port
case maybeAskNode of
Just askNode -> tail <$> requestNode askNode (["ASKING"] : requests)
Nothing -> case retryCount of
Expand All @@ -328,7 +346,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies =
evaluateTransactionPipeline :: MVar ShardMap -> IO ShardMap -> Connection -> [[B.ByteString]] -> IO [Reply]
evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = do
let requests = reverse requests'
let (Connection _ _ _ infoMap _) = conn
let (Connection _ _ _ infoMap _ _) = conn
keys <- mconcat <$> mapM (requestKeys infoMap) requests
-- In cluster mode Redis expects commands in transactions to all work on the
-- same hashslot. We find that hashslot here.
Expand All @@ -346,7 +364,7 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d
resps <-
case eresps of
Right v -> return v
Left (_ :: SomeException) -> requestNode (getRandomConnection nodeConn conn) requests
Left (_ :: SomeException) -> getRandomConnection nodeConn conn >>= (`requestNode` requests)
-- The Redis documentation has the following to say on the effect of
-- resharding on multi-key operations:
--
Expand Down Expand Up @@ -380,8 +398,9 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d

nodeConnForHashSlot :: Exception e => MVar ShardMap -> Connection -> e -> HashSlot -> IO NodeConnection
nodeConnForHashSlot shardMapVar conn exception hashSlot = do
let (Connection nodeConns _ _ _ _) = conn
let (Connection nodeConnsVar _ _ _ _ _) = conn
(ShardMap shardMap) <- hasLocked $ readMVar shardMapVar
nodeConns <- readMVar nodeConnsVar
node <-
case IntMap.lookup (fromEnum hashSlot) shardMap of
Nothing -> throwIO exception
Expand Down Expand Up @@ -422,13 +441,16 @@ moved (Error errString) = case Char8.words errString of
moved _ = False


nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> Maybe NodeConnection
nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _ _) host port = do
node <- nodeWithHostAndPort shardMap host port
HM.lookup (nodeId node) nodeConns
nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> IO (Maybe NodeConnection)
nodeConnWithHostAndPort shardMap (Connection nodeConnsVar _ _ _ _ _) host port =
case nodeWithHostAndPort shardMap host port of
Nothing -> return Nothing
Just node -> do
nodeConns <- readMVar nodeConnsVar
return (HM.lookup (nodeId node) nodeConns)

nodeConnectionForCommand :: Connection -> ShardMap -> [B.ByteString] -> IO [NodeConnection]
nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap shardMap) request =
nodeConnectionForCommand conn@(Connection nodeConnsVar _ _ infoMap _ _) (ShardMap shardMap) request =
case request of
("FLUSHALL" : _) -> allNodes
("FLUSHDB" : _) -> allNodes
Expand All @@ -440,16 +462,19 @@ nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap sha
node <- case IntMap.lookup (fromEnum hashSlot) shardMap of
Nothing -> throwIO $ MissingNodeException request
Just (Shard master _) -> return master
nodeConns <- readMVar nodeConnsVar
maybe (throwIO $ MissingNodeException request) (return . return) (HM.lookup (nodeId node) nodeConns)
where
allNodes =
case allMasterNodes conn (ShardMap shardMap) of
allNodes = do
maybeNodes <- allMasterNodes conn (ShardMap shardMap)
case maybeNodes of
Nothing -> throwIO $ MissingNodeException request
Just allNodes' -> return allNodes'

allMasterNodes :: Connection -> ShardMap -> Maybe [NodeConnection]
allMasterNodes (Connection nodeConns _ _ _ _) (ShardMap shardMap) =
mapM (flip HM.lookup nodeConns . nodeId) onlyMasterNodes
allMasterNodes :: Connection -> ShardMap -> IO (Maybe [NodeConnection])
allMasterNodes (Connection nodeConnsVar _ _ _ _ _) (ShardMap shardMap) = do
nodeConns <- readMVar nodeConnsVar
return $ mapM (flip HM.lookup nodeConns . nodeId) onlyMasterNodes
where
onlyMasterNodes = (\(Shard master _) -> master) <$> nub (IntMap.elems shardMap)

Expand Down Expand Up @@ -508,14 +533,16 @@ requestMasterNodes conn req = do
concat <$> mapM (`requestNode` [req]) masterNodeConns

masterNodes :: Connection -> IO [NodeConnection]
masterNodes (Connection nodeConns _ shardMapVar _ _) = do
masterNodes (Connection nodeConnsVar _ shardMapVar _ _ _) = do
(ShardMap shardMap) <- readMVar shardMapVar
let masters = map ((\(Shard m _) -> m) . snd) $ IntMap.toList shardMap
let masterNodeIds = map nodeId masters
nodeConns <- readMVar nodeConnsVar
return $ mapMaybe (`HM.lookup` nodeConns) masterNodeIds

getRandomConnection :: NodeConnection -> Connection -> NodeConnection
getRandomConnection nc conn =
let (Connection hmn _ _ _ _) = conn
conns = HM.elems hmn
in fromMaybe (head conns) $ find (nc /= ) conns
getRandomConnection :: NodeConnection -> Connection -> IO NodeConnection
getRandomConnection nc conn = do
let (Connection hmnVar _ _ _ _ _) = conn
hmn <- readMVar hmnVar
let conns = HM.elems hmn
return $ fromMaybe (head conns) $ find (nc /= ) conns
76 changes: 55 additions & 21 deletions src/Database/Redis/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NamedFieldPuns #-}
module Database.Redis.Connection where

import Control.Exception
import qualified Control.Monad.Catch as Catch
import Control.Monad.IO.Class(liftIO, MonadIO)
import Control.Monad(when)
import Control.Concurrent.MVar(MVar, newMVar)
import Control.Concurrent.MVar(MVar, newMVar, putMVar, readMVar, modifyMVar_)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as Char8
import Data.Functor(void)
Expand All @@ -34,6 +35,8 @@ import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (race)
import qualified Database.Redis.Types as T
--import qualified Database.Redis.Cluster.Pipeline as ClusterPipeline
import qualified Data.IORef as IOR
import Data.List (nub)

import Database.Redis.Commands
( ping
Expand Down Expand Up @@ -224,32 +227,26 @@ connectCluster bootstrapConnInfo = do
Right infos -> do
let
isConnectionReadOnly = connectReadOnly bootstrapConnInfo
clusterConnection = Cluster.connect withAuth infos shardMapVar timeoutOptUs isConnectionReadOnly refreshShardMapWithNodeConn (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
connectTLSParams' = connectTLSParams bootstrapConnInfo
connectAuth' = connectAuth bootstrapConnInfo
tcpInfo = Cluster.TcpInfo {
connectTLSParams = connectTLSParams',
connectAuth = connectAuth',
idleTime = connectMaxIdleTime bootstrapConnInfo,
maxResources = connectMaxConnections bootstrapConnInfo,
timeoutOpt = timeoutOptUs
}
withAuth = tcpConnWithAuth connectAuth' connectTLSParams'
clusterConnection = Cluster.connect withAuth infos shardMapVar isConnectionReadOnly refreshShardMapWithNodeConn tcpInfo
-- pool <- createPool (clusterConnect isConnectionReadOnly clusterConnection) Cluster.disconnect 3 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
connection <- clusterConnect isConnectionReadOnly clusterConnection
return $ ClusteredConnection shardMapVar connection
where
withAuth host port timeout = do
conn <- PP.connect host port timeout
conn' <- case connectTLSParams bootstrapConnInfo of
Nothing -> return conn
Just tlsParams -> PP.enableTLS tlsParams conn
PP.beginReceiving conn'

runRedisInternal conn' $ do
-- AUTH
case connectAuth bootstrapConnInfo of
Nothing -> return ()
Just pass -> do
resp <- auth pass
case resp of
Left r -> liftIO $ throwIO $ ConnectAuthError r
_ -> return ()
return $ PP.toCtx conn'

clusterConnect :: Bool -> IO Cluster.Connection -> IO Cluster.Connection
clusterConnect readOnlyConnection connection = do
clusterConn@(Cluster.Connection nodeMap _ _ _ _) <- connection
clusterConn@(Cluster.Connection nodeMapVar _ _ _ _ _) <- connection
nodeMap <- readMVar nodeMapVar
nodesConns <- sequence $ (ctxToConn . snd) <$> (HM.toList nodeMap)
when readOnlyConnection $
mapM_ (\maybeConn -> case maybeConn of
Expand All @@ -269,6 +266,26 @@ connectCluster bootstrapConnInfo = do
putStrLn ("SomeException Occured in NodeID " ++ show nid)
return Nothing


tcpConnWithAuth :: Maybe B.ByteString -> Maybe ClientParams -> Cluster.Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext
tcpConnWithAuth connectAuth connectTLSParams host port timeout = do
conn <- PP.connect host port timeout
conn' <- case connectTLSParams of
Nothing -> return conn
Just tlsParams -> PP.enableTLS tlsParams conn
PP.beginReceiving conn'

runRedisInternal conn' $ do
-- AUTH
case connectAuth of
Nothing -> return ()
Just pass -> do
resp <- auth pass
case resp of
Left r -> liftIO $ throwIO $ ConnectAuthError r
_ -> return ()
return $ PP.toCtx conn'

shardMapFromClusterSlotsResponse :: ClusterSlotsResponse -> IO ShardMap
shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr mkShardMap (pure IntMap.empty) clusterSlotsResponseEntries where
mkShardMap :: ClusterSlotsResponseEntry -> IO (IntMap.IntMap Shard) -> IO (IntMap.IntMap Shard)
Expand All @@ -286,8 +303,25 @@ shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr m
in
Cluster.Node clusterSlotsNodeID role hostname (toEnum clusterSlotsNodePort)

refreshCluster :: Cluster.Connection -> IO ()
refreshCluster conn@(Cluster.Connection nodeConnsVar _ shardMapVar _ _ Cluster.TcpInfo { idleTime, maxResources, timeoutOpt, connectAuth, connectTLSParams }) = do
newShardMap <- refreshShardMap conn
modifyMVar_ nodeConnsVar $ \_ -> do
putMVar shardMapVar newShardMap
simpleNodeConnections newShardMap
where
withAuth = tcpConnWithAuth connectAuth connectTLSParams
simpleNodeConnections :: ShardMap -> IO (HM.HashMap Cluster.NodeID Cluster.NodeConnection)
simpleNodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ Cluster.nodes shardMap)
connectNode :: Cluster.Node -> IO (Cluster.NodeID, Cluster.NodeConnection)
connectNode (Cluster.Node n _ host port) = do
ctx <- createPool (withAuth host (CC.PortNumber $ toEnum port) timeoutOpt) CC.disconnect 1 idleTime maxResources
ref <- IOR.newIORef Nothing
return (n, Cluster.NodeConnection ctx ref n)

refreshShardMap :: Cluster.Connection -> IO ShardMap
refreshShardMap (Cluster.Connection nodeConns _ _ _ _) =
refreshShardMap (Cluster.Connection nodeConnsVar _ _ _ _ _) = do
nodeConns <- readMVar nodeConnsVar
refreshShardMapWithNodeConn (HM.elems nodeConns)

refreshShardMapWithNodeConn :: [Cluster.NodeConnection] -> IO ShardMap
Expand Down

0 comments on commit cab4f4a

Please sign in to comment.