From cab4f4af059026bbd3cf6f6d5dd259997afffdbf Mon Sep 17 00:00:00 2001 From: "shashi.kant" Date: Wed, 12 Apr 2023 18:00:55 +0530 Subject: [PATCH] moving NodeConnection HashMap to MVar and adding refreshCluster function --- src/Database/Redis/Cluster.hs | 93 ++++++++++++++++++++------------ src/Database/Redis/Connection.hs | 76 ++++++++++++++++++-------- 2 files changed, 115 insertions(+), 54 deletions(-) diff --git a/src/Database/Redis/Cluster.hs b/src/Database/Redis/Cluster.hs index 247843c2..58a5ebb6 100644 --- a/src/Database/Redis/Cluster.hs +++ b/src/Database/Redis/Cluster.hs @@ -4,6 +4,8 @@ {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE RecordWildCards #-} module Database.Redis.Cluster ( Connection(..) , NodeRole(..) @@ -13,6 +15,9 @@ module Database.Redis.Cluster , HashSlot , Shard(..) , TimeoutException(..) + , TcpInfo(..) + , Host + , NodeID , connect , destroyNodeResources , requestPipelined @@ -46,6 +51,7 @@ import Text.Read (readMaybe) import Database.Redis.Protocol(Reply(Error), renderRequest, reply) import qualified Database.Redis.Cluster.Command as CMD +import Network.TLS (ClientParams) -- This module implements a clustered connection whilst maintaining -- compatibility with the original Hedis codebase. In particular it still @@ -60,7 +66,7 @@ import qualified Database.Redis.Cluster.Command as CMD -- | 'NodeConnection's, a 'Pipeline', and a 'ShardMap' type IsReadOnly = Bool -data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap IsReadOnly +data Connection = Connection (MVar NodeConnectionMap) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap IsReadOnly TcpInfo -- | A connection to a single node in the cluster, similar to 'ProtocolPipelining.Connection' data NodeConnection = NodeConnection (Pool CC.ConnectionContext) (IOR.IORef (Maybe B.ByteString)) NodeID @@ -114,6 +120,17 @@ data Shard = Shard MasterNode [SlaveNode] deriving (Show, Eq, Ord) -- A map from hashslot to shards newtype ShardMap = ShardMap (IntMap.IntMap Shard) deriving (Show) +type NodeConnectionMap = HM.HashMap NodeID NodeConnection + +-- Object for storing Tcp Connection Info which will be used when cluster is refreshed +data TcpInfo = TcpInfo + { connectAuth :: Maybe B.ByteString + , connectTLSParams :: Maybe ClientParams + , idleTime :: Time.NominalDiffTime + , maxResources :: Int + , timeoutOpt :: Maybe Int + } deriving Show + newtype MissingNodeException = MissingNodeException [B.ByteString] deriving (Show, Typeable) instance Exception MissingNodeException @@ -129,8 +146,8 @@ instance Exception NoNodeException data TimeoutException = TimeoutException String deriving (Show, Typeable) instance Exception TimeoutException -connect :: (Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext) -> [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> Bool -> ([NodeConnection] -> IO ShardMap) -> Time.NominalDiffTime -> Int -> IO Connection -connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap idleTime maxResources = do +connect :: (Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext) -> [CMD.CommandInfo] -> MVar ShardMap -> Bool -> ([NodeConnection] -> IO ShardMap) -> TcpInfo -> IO Connection +connect withAuth commandInfos shardMapVar isReadOnly refreshShardMap (tcpInfo@TcpInfo{ timeoutOpt, maxResources, idleTime }) = do shardMap <- readMVar shardMapVar stateVar <- newMVar $ Pending [] pipelineVar <- newMVar $ Pipeline stateVar @@ -149,7 +166,8 @@ connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap throwIO NoNodeException else return eNodeConns - return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) isReadOnly where + nodeConnsVar <- newMVar nodeConns + return $ Connection nodeConnsVar pipelineVar shardMapVar (CMD.newInfoMap commandInfos) isReadOnly tcpInfo where simpleNodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection) simpleNodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ nodes shardMap) nodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection, Bool) @@ -169,27 +187,27 @@ connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap refreshShardMapVar shardMap = hasLocked $ modifyMVar_ shardMapVar (const (pure shardMap)) destroyNodeResources :: Connection -> IO () -destroyNodeResources (Connection nodeConnMap _ _ _ _ ) = mapM_ disconnectNode (HM.elems nodeConnMap) where +destroyNodeResources (Connection nodeConnMapVar _ _ _ _ _) = readMVar nodeConnMapVar >>= (mapM_ disconnectNode . HM.elems) where disconnectNode (NodeConnection nodePool _ _) = destroyAllResources nodePool -- Add a request to the current pipeline for this connection. The pipeline will -- be executed implicitly as soon as any result returned from this function is -- evaluated. requestPipelined :: IO ShardMap -> Connection -> [B.ByteString] -> IO Reply -requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do +requestPipelined refreshShardmapAction conn@(Connection _ pipelineVar shardMapVar _ _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do (newStateVar, repliesIndex) <- hasLocked $ modifyMVar stateVar $ \case Pending requests | isMulti nextRequest -> do - replies <- evaluatePipeline shardMapVar refreshAction conn requests + replies <- evaluatePipeline shardMapVar refreshShardmapAction conn requests s' <- newMVar $ TransactionPending [nextRequest] return (Executed replies, (s', 0)) Pending requests | length requests > 1000 -> do - replies <- evaluatePipeline shardMapVar refreshAction conn (nextRequest:requests) + replies <- evaluatePipeline shardMapVar refreshShardmapAction conn (nextRequest:requests) return (Executed replies, (stateVar, length requests)) Pending requests -> return (Pending (nextRequest:requests), (stateVar, length requests)) TransactionPending requests -> if isExec nextRequest then do - replies <- evaluateTransactionPipeline shardMapVar refreshAction conn (nextRequest:requests) + replies <- evaluateTransactionPipeline shardMapVar refreshShardmapAction conn (nextRequest:requests) return (Executed replies, (stateVar, length requests)) else return (TransactionPending (nextRequest:requests), (stateVar, length requests)) @@ -205,10 +223,10 @@ requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) n Executed replies -> return (Executed replies, replies) Pending requests-> do - replies <- evaluatePipeline shardMapVar refreshAction conn requests + replies <- evaluatePipeline shardMapVar refreshShardmapAction conn requests return (Executed replies, replies) TransactionPending requests-> do - replies <- evaluateTransactionPipeline shardMapVar refreshAction conn requests + replies <- evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests return (Executed replies, replies) return $ replies !! repliesIndex return (Pipeline newStateVar, evaluateAction) @@ -271,7 +289,7 @@ evaluatePipeline shardMapVar refreshShardmapAction conn requests = do Left (err :: SomeException) -> case fromException err of Just (er :: TimeoutException) -> throwIO er - _ -> executeRequests (getRandomConnection cc conn) r + _ -> getRandomConnection cc conn >>= (`executeRequests` r) ) (zip eresps requestsByNode) -- check for any moved in both responses and continue the flow. when (any (moved . rawResponse) resps) refreshShardMapVar @@ -306,14 +324,14 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies = -- there is one. case last replies of (Error errString) | B.isPrefixOf "MOVED" errString -> do - let (Connection _ _ _ infoMap _) = conn + let (Connection _ _ _ infoMap _ _) = conn keys <- mconcat <$> mapM (requestKeys infoMap) requests hashSlot <- hashSlotForKeys (CrossSlotException requests) keys nodeConn <- nodeConnForHashSlot shardMapVar conn (MissingNodeException (head requests)) hashSlot requestNode nodeConn requests (askingRedirection -> Just (host, port)) -> do shardMap <- hasLocked $ readMVar shardMapVar - let maybeAskNode = nodeConnWithHostAndPort shardMap conn host port + maybeAskNode <- nodeConnWithHostAndPort shardMap conn host port case maybeAskNode of Just askNode -> tail <$> requestNode askNode (["ASKING"] : requests) Nothing -> case retryCount of @@ -328,7 +346,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies = evaluateTransactionPipeline :: MVar ShardMap -> IO ShardMap -> Connection -> [[B.ByteString]] -> IO [Reply] evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = do let requests = reverse requests' - let (Connection _ _ _ infoMap _) = conn + let (Connection _ _ _ infoMap _ _) = conn keys <- mconcat <$> mapM (requestKeys infoMap) requests -- In cluster mode Redis expects commands in transactions to all work on the -- same hashslot. We find that hashslot here. @@ -346,7 +364,7 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d resps <- case eresps of Right v -> return v - Left (_ :: SomeException) -> requestNode (getRandomConnection nodeConn conn) requests + Left (_ :: SomeException) -> getRandomConnection nodeConn conn >>= (`requestNode` requests) -- The Redis documentation has the following to say on the effect of -- resharding on multi-key operations: -- @@ -380,8 +398,9 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d nodeConnForHashSlot :: Exception e => MVar ShardMap -> Connection -> e -> HashSlot -> IO NodeConnection nodeConnForHashSlot shardMapVar conn exception hashSlot = do - let (Connection nodeConns _ _ _ _) = conn + let (Connection nodeConnsVar _ _ _ _ _) = conn (ShardMap shardMap) <- hasLocked $ readMVar shardMapVar + nodeConns <- readMVar nodeConnsVar node <- case IntMap.lookup (fromEnum hashSlot) shardMap of Nothing -> throwIO exception @@ -422,13 +441,16 @@ moved (Error errString) = case Char8.words errString of moved _ = False -nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> Maybe NodeConnection -nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _ _) host port = do - node <- nodeWithHostAndPort shardMap host port - HM.lookup (nodeId node) nodeConns +nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> IO (Maybe NodeConnection) +nodeConnWithHostAndPort shardMap (Connection nodeConnsVar _ _ _ _ _) host port = + case nodeWithHostAndPort shardMap host port of + Nothing -> return Nothing + Just node -> do + nodeConns <- readMVar nodeConnsVar + return (HM.lookup (nodeId node) nodeConns) nodeConnectionForCommand :: Connection -> ShardMap -> [B.ByteString] -> IO [NodeConnection] -nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap shardMap) request = +nodeConnectionForCommand conn@(Connection nodeConnsVar _ _ infoMap _ _) (ShardMap shardMap) request = case request of ("FLUSHALL" : _) -> allNodes ("FLUSHDB" : _) -> allNodes @@ -440,16 +462,19 @@ nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap sha node <- case IntMap.lookup (fromEnum hashSlot) shardMap of Nothing -> throwIO $ MissingNodeException request Just (Shard master _) -> return master + nodeConns <- readMVar nodeConnsVar maybe (throwIO $ MissingNodeException request) (return . return) (HM.lookup (nodeId node) nodeConns) where - allNodes = - case allMasterNodes conn (ShardMap shardMap) of + allNodes = do + maybeNodes <- allMasterNodes conn (ShardMap shardMap) + case maybeNodes of Nothing -> throwIO $ MissingNodeException request Just allNodes' -> return allNodes' -allMasterNodes :: Connection -> ShardMap -> Maybe [NodeConnection] -allMasterNodes (Connection nodeConns _ _ _ _) (ShardMap shardMap) = - mapM (flip HM.lookup nodeConns . nodeId) onlyMasterNodes +allMasterNodes :: Connection -> ShardMap -> IO (Maybe [NodeConnection]) +allMasterNodes (Connection nodeConnsVar _ _ _ _ _) (ShardMap shardMap) = do + nodeConns <- readMVar nodeConnsVar + return $ mapM (flip HM.lookup nodeConns . nodeId) onlyMasterNodes where onlyMasterNodes = (\(Shard master _) -> master) <$> nub (IntMap.elems shardMap) @@ -508,14 +533,16 @@ requestMasterNodes conn req = do concat <$> mapM (`requestNode` [req]) masterNodeConns masterNodes :: Connection -> IO [NodeConnection] -masterNodes (Connection nodeConns _ shardMapVar _ _) = do +masterNodes (Connection nodeConnsVar _ shardMapVar _ _ _) = do (ShardMap shardMap) <- readMVar shardMapVar let masters = map ((\(Shard m _) -> m) . snd) $ IntMap.toList shardMap let masterNodeIds = map nodeId masters + nodeConns <- readMVar nodeConnsVar return $ mapMaybe (`HM.lookup` nodeConns) masterNodeIds -getRandomConnection :: NodeConnection -> Connection -> NodeConnection -getRandomConnection nc conn = - let (Connection hmn _ _ _ _) = conn - conns = HM.elems hmn - in fromMaybe (head conns) $ find (nc /= ) conns +getRandomConnection :: NodeConnection -> Connection -> IO NodeConnection +getRandomConnection nc conn = do + let (Connection hmnVar _ _ _ _ _) = conn + hmn <- readMVar hmnVar + let conns = HM.elems hmn + return $ fromMaybe (head conns) $ find (nc /= ) conns diff --git a/src/Database/Redis/Connection.hs b/src/Database/Redis/Connection.hs index b5373761..04321a08 100644 --- a/src/Database/Redis/Connection.hs +++ b/src/Database/Redis/Connection.hs @@ -2,13 +2,14 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE NamedFieldPuns #-} module Database.Redis.Connection where import Control.Exception import qualified Control.Monad.Catch as Catch import Control.Monad.IO.Class(liftIO, MonadIO) import Control.Monad(when) -import Control.Concurrent.MVar(MVar, newMVar) +import Control.Concurrent.MVar(MVar, newMVar, putMVar, readMVar, modifyMVar_) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as Char8 import Data.Functor(void) @@ -34,6 +35,8 @@ import Control.Concurrent (threadDelay) import Control.Concurrent.Async (race) import qualified Database.Redis.Types as T --import qualified Database.Redis.Cluster.Pipeline as ClusterPipeline +import qualified Data.IORef as IOR +import Data.List (nub) import Database.Redis.Commands ( ping @@ -224,32 +227,26 @@ connectCluster bootstrapConnInfo = do Right infos -> do let isConnectionReadOnly = connectReadOnly bootstrapConnInfo - clusterConnection = Cluster.connect withAuth infos shardMapVar timeoutOptUs isConnectionReadOnly refreshShardMapWithNodeConn (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo) + connectTLSParams' = connectTLSParams bootstrapConnInfo + connectAuth' = connectAuth bootstrapConnInfo + tcpInfo = Cluster.TcpInfo { + connectTLSParams = connectTLSParams', + connectAuth = connectAuth', + idleTime = connectMaxIdleTime bootstrapConnInfo, + maxResources = connectMaxConnections bootstrapConnInfo, + timeoutOpt = timeoutOptUs + } + withAuth = tcpConnWithAuth connectAuth' connectTLSParams' + clusterConnection = Cluster.connect withAuth infos shardMapVar isConnectionReadOnly refreshShardMapWithNodeConn tcpInfo -- pool <- createPool (clusterConnect isConnectionReadOnly clusterConnection) Cluster.disconnect 3 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo) connection <- clusterConnect isConnectionReadOnly clusterConnection return $ ClusteredConnection shardMapVar connection where - withAuth host port timeout = do - conn <- PP.connect host port timeout - conn' <- case connectTLSParams bootstrapConnInfo of - Nothing -> return conn - Just tlsParams -> PP.enableTLS tlsParams conn - PP.beginReceiving conn' - - runRedisInternal conn' $ do - -- AUTH - case connectAuth bootstrapConnInfo of - Nothing -> return () - Just pass -> do - resp <- auth pass - case resp of - Left r -> liftIO $ throwIO $ ConnectAuthError r - _ -> return () - return $ PP.toCtx conn' clusterConnect :: Bool -> IO Cluster.Connection -> IO Cluster.Connection clusterConnect readOnlyConnection connection = do - clusterConn@(Cluster.Connection nodeMap _ _ _ _) <- connection + clusterConn@(Cluster.Connection nodeMapVar _ _ _ _ _) <- connection + nodeMap <- readMVar nodeMapVar nodesConns <- sequence $ (ctxToConn . snd) <$> (HM.toList nodeMap) when readOnlyConnection $ mapM_ (\maybeConn -> case maybeConn of @@ -269,6 +266,26 @@ connectCluster bootstrapConnInfo = do putStrLn ("SomeException Occured in NodeID " ++ show nid) return Nothing + +tcpConnWithAuth :: Maybe B.ByteString -> Maybe ClientParams -> Cluster.Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext +tcpConnWithAuth connectAuth connectTLSParams host port timeout = do + conn <- PP.connect host port timeout + conn' <- case connectTLSParams of + Nothing -> return conn + Just tlsParams -> PP.enableTLS tlsParams conn + PP.beginReceiving conn' + + runRedisInternal conn' $ do + -- AUTH + case connectAuth of + Nothing -> return () + Just pass -> do + resp <- auth pass + case resp of + Left r -> liftIO $ throwIO $ ConnectAuthError r + _ -> return () + return $ PP.toCtx conn' + shardMapFromClusterSlotsResponse :: ClusterSlotsResponse -> IO ShardMap shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr mkShardMap (pure IntMap.empty) clusterSlotsResponseEntries where mkShardMap :: ClusterSlotsResponseEntry -> IO (IntMap.IntMap Shard) -> IO (IntMap.IntMap Shard) @@ -286,8 +303,25 @@ shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr m in Cluster.Node clusterSlotsNodeID role hostname (toEnum clusterSlotsNodePort) +refreshCluster :: Cluster.Connection -> IO () +refreshCluster conn@(Cluster.Connection nodeConnsVar _ shardMapVar _ _ Cluster.TcpInfo { idleTime, maxResources, timeoutOpt, connectAuth, connectTLSParams }) = do + newShardMap <- refreshShardMap conn + modifyMVar_ nodeConnsVar $ \_ -> do + putMVar shardMapVar newShardMap + simpleNodeConnections newShardMap + where + withAuth = tcpConnWithAuth connectAuth connectTLSParams + simpleNodeConnections :: ShardMap -> IO (HM.HashMap Cluster.NodeID Cluster.NodeConnection) + simpleNodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ Cluster.nodes shardMap) + connectNode :: Cluster.Node -> IO (Cluster.NodeID, Cluster.NodeConnection) + connectNode (Cluster.Node n _ host port) = do + ctx <- createPool (withAuth host (CC.PortNumber $ toEnum port) timeoutOpt) CC.disconnect 1 idleTime maxResources + ref <- IOR.newIORef Nothing + return (n, Cluster.NodeConnection ctx ref n) + refreshShardMap :: Cluster.Connection -> IO ShardMap -refreshShardMap (Cluster.Connection nodeConns _ _ _ _) = +refreshShardMap (Cluster.Connection nodeConnsVar _ _ _ _ _) = do + nodeConns <- readMVar nodeConnsVar refreshShardMapWithNodeConn (HM.elems nodeConns) refreshShardMapWithNodeConn :: [Cluster.NodeConnection] -> IO ShardMap