diff --git a/src/main/scala/com/redis/IO.scala b/src/main/scala/com/redis/IO.scala index c63e2c09..758d7293 100644 --- a/src/main/scala/com/redis/IO.scala +++ b/src/main/scala/com/redis/IO.scala @@ -19,6 +19,8 @@ trait IO extends Log { socket != null && socket.isBound && !socket.isClosed && socket.isConnected && !socket.isInputShutdown && !socket.isOutputShutdown } + def onConnect(): Unit + // Connects the socket, and sets the input and output streams. def connect: Boolean = { try { @@ -31,6 +33,7 @@ trait IO extends Log { out = socket.getOutputStream in = new BufferedInputStream(socket.getInputStream) + onConnect() true } catch { case x: Throwable => diff --git a/src/main/scala/com/redis/RedisClient.scala b/src/main/scala/com/redis/RedisClient.scala index b9c1349a..738bb23e 100644 --- a/src/main/scala/com/redis/RedisClient.scala +++ b/src/main/scala/com/redis/RedisClient.scala @@ -30,10 +30,10 @@ trait Redis extends IO with Protocol { result } catch { case e: RedisConnectionException => - if (reconnect) send(command, args)(result) + if (disconnect) send(command, args)(result) else throw e case e: SocketException => - if (reconnect) send(command, args)(result) + if (disconnect) send(command, args)(result) else throw e } @@ -42,10 +42,10 @@ trait Redis extends IO with Protocol { result } catch { case e: RedisConnectionException => - if (reconnect) send(command)(result) + if (disconnect) send(command)(result) else throw e case e: SocketException => - if (reconnect) send(command)(result) + if (disconnect) send(command)(result) else throw e } @@ -54,11 +54,6 @@ trait Redis extends IO with Protocol { protected def flattenPairs(in: Iterable[Product2[Any, Any]]): List[Any] = in.iterator.flatMap(x => Iterator(x._1, x._2)).toList - def reconnect: Boolean = { - disconnect && initialize - } - - protected def initialize : Boolean } trait RedisCommand extends Redis @@ -78,16 +73,11 @@ trait RedisCommand extends Redis val database: Int = 0 val secret: Option[Any] = None - override def initialize : Boolean = { - if(connect) { - secret.foreach {s => - auth(s) - } - selectDatabase() - true - } else { - false + override def onConnect: Unit = { + secret.foreach {s => + auth(s) } + selectDatabase() } private def selectDatabase(): Unit = { @@ -106,8 +96,6 @@ class RedisClient(override val host: String, override val port: Int, override val database: Int = 0, override val secret: Option[Any] = None, override val timeout : Int = 0) extends RedisCommand with PubSub { - initialize - def this() = this("localhost", 6379) def this(connectionUri: java.net.URI) = this( host = connectionUri.getHost, @@ -217,13 +205,12 @@ class RedisClient(override val host: String, override val port: Int, // TODO: Find a better abstraction override def connected = parent.connected override def connect = parent.connect - override def reconnect = parent.reconnect override def disconnect = parent.disconnect override def clearFd = parent.clearFd override def write(data: Array[Byte]) = parent.write(data) override def readLine = parent.readLine override def readCounted(count: Int) = parent.readCounted(count) - override def initialize = parent.initialize + override def onConnect() = parent.onConnect() override def close(): Unit = parent.close() } diff --git a/src/main/scala/com/redis/ds/Deque.scala b/src/main/scala/com/redis/ds/Deque.scala index 5a637ddc..d530f786 100644 --- a/src/main/scala/com/redis/ds/Deque.scala +++ b/src/main/scala/com/redis/ds/Deque.scala @@ -84,7 +84,6 @@ class RedisDequeClient(val h: String, val p: Int, val d: Int = 0, val s: Option[ val key = k override val database = d override val secret = s - initialize override def close(): Unit = disconnect } diff --git a/src/test/scala/com/redis/RedisClientSpec.scala b/src/test/scala/com/redis/RedisClientSpec.scala index 5eb3f429..3a0aa9e7 100644 --- a/src/test/scala/com/redis/RedisClientSpec.scala +++ b/src/test/scala/com/redis/RedisClientSpec.scala @@ -1,10 +1,15 @@ package com.redis -import java.net.URI +import java.net.{ServerSocket, URI} +import com.github.dockerjava.core.DefaultDockerClientConfig import com.redis.api.ApiSpec -import org.scalatest.FunSpec -import org.scalatest.Matchers +import com.whisk.docker.DockerContainerManager +import com.whisk.docker.impl.dockerjava.Docker +import org.scalatest.{FunSpec, Matchers} + +import scala.concurrent.Await +import scala.concurrent.duration._ class RedisClientSpec extends FunSpec with Matchers with ApiSpec { @@ -67,4 +72,39 @@ class RedisClientSpec extends FunSpec r.get("vvl:qm") r.close() }} + + describe("test reconnect") { + it("should re-init after server restart") { + val docker = new Docker(DefaultDockerClientConfig.createDefaultConfigBuilder().build()).client + + val port = { + val s = new ServerSocket(0) + val p = s.getLocalPort + s.close() + p + } + + val manager = new DockerContainerManager( + createContainer(ports = Map(redisPort -> port)) :: Nil, dockerFactory.createExecutor() + ) + + val key = "test-1" + val value = "test-value-1" + + val (cs, _) :: _ = Await.result(manager.initReadyAll(20.seconds), 21.second) + val id = Await.result(cs.id, 10.seconds) + + val c = new RedisClient(redisContainerHost, port, 8, timeout = 10.seconds.toMillis.toInt) + c.set(key, value) + docker.stopContainerCmd(id).exec() + try {c.get(key)} catch { case e: Throwable => } + docker.startContainerCmd(id).exec() + val got = c.get(key) + c.close() + docker.removeContainerCmd(id).withForce(true).withRemoveVolumes(true).exec() + docker.close() + + got shouldBe Some(value) + } + } }