From ebb30c4fb011e84b3c1cbc51d9eac9705d6780c5 Mon Sep 17 00:00:00 2001 From: Piston Date: Thu, 10 Sep 2015 22:26:49 +0300 Subject: [PATCH] fixes --- .../com/karasiq/proxychain/app/Server.scala | 98 +++---------------- .../proxychain/app/TLSHandlerTamper.scala | 81 +++++++++++++++ 2 files changed, 97 insertions(+), 82 deletions(-) create mode 100644 src/main/scala/com/karasiq/proxychain/app/TLSHandlerTamper.scala diff --git a/src/main/scala/com/karasiq/proxychain/app/Server.scala b/src/main/scala/com/karasiq/proxychain/app/Server.scala index 5d10402..39bae8e 100644 --- a/src/main/scala/com/karasiq/proxychain/app/Server.scala +++ b/src/main/scala/com/karasiq/proxychain/app/Server.scala @@ -1,17 +1,14 @@ package com.karasiq.proxychain.app -import java.io.IOException import java.net.InetSocketAddress import java.nio.channels.{ServerSocketChannel, SocketChannel} import java.util.concurrent.Executors import akka.actor._ import akka.event.Logging -import akka.io.Tcp -import com.karasiq.networkutils.SocketChannelWrapper import com.karasiq.tls.{TLS, TLSCertificateVerifier, TLSKeyStore, TLSServerWrapper} -import scala.concurrent.{ExecutionContext, Promise} +import scala.concurrent.{ExecutionContext, Future, Promise} import scala.util.control private[app] final class Server(cfg: AppConfig) extends Actor with ActorLogging { @@ -70,32 +67,32 @@ private[app] final class TLSServer(address: InetSocketAddress, cfg: AppConfig) e def receive = { case Accepted(socket) ⇒ // New connection accepted import context.dispatcher - val tlsTamper = Promise[ActorRef]() val handler = context.actorOf(Props(classOf[Handler], cfg)) val catcher = control.Exception.allCatch.withApply { exc ⇒ - if (!tlsTamper.tryFailure(exc)) { - tlsTamper.future.onSuccess { - case ar: ActorRef ⇒ - context.stop(ar) - } - } context.stop(handler) socket.close() } catcher { - val serverWrapper = new TLSServerWrapper(keySet, clientAuth, new TLSCertificateVerifier()) { - private val log = Logging(context.system, handler) + val log = Logging(context.system, handler) + val tlsSocket = Promise[SocketChannel]() + tlsSocket.future.onFailure { + case exc ⇒ + log.error(exc, "Error opening TLS socket") + handler ! ErrorClosed + } + + val serverWrapper = new TLSServerWrapper(keySet, clientAuth, new TLSCertificateVerifier()) { override protected def onInfo(message: String): Unit = { log.debug(message) } override protected def onHandshakeFinished(): Unit = { log.debug("TLS handhake finished") - tlsTamper.future.onSuccess { - case tamper ⇒ - tamper ! ResumeReading + tlsSocket.future.onSuccess { case socket: SocketChannel ⇒ + val actor = context.actorOf(Props(classOf[TLSHandlerTamper], socket)) + actor ! Register(handler) } } @@ -104,72 +101,9 @@ private[app] final class TLSServer(address: InetSocketAddress, cfg: AppConfig) e } } - val tlsSocket = serverWrapper(socket) - - val tamperActor = context.actorOf(Props(new Actor with Stash { - @throws[Exception](classOf[Exception]) - override def preStart(): Unit = { - val catcher = control.Exception.catching(classOf[IOException]).withApply { _ ⇒ - self ! ErrorClosed - } - - catcher { - super.preStart() - context.watch(handler) - SocketChannelWrapper.register(tlsSocket, self) - } - } - - override def postStop(): Unit = { - SocketChannelWrapper.unregister(tlsSocket) - tlsSocket.close() - super.postStop() - } - - def onClose: Receive = { - case c @ Tcp.Closed ⇒ - handler ! c - context.stop(self) - - case c @ Tcp.Close ⇒ - sender() ! ConfirmedClosed - context.stop(self) - - case Terminated(_) ⇒ - context.stop(self) - } - - def readSuspended: Receive = { - case Received(data) ⇒ - stash() - } - - def readResumed: Receive = { - case r @ Received(data) ⇒ - handler ! r - } - - def streaming: Receive = { - case SuspendReading ⇒ - context.become(onClose.orElse(readSuspended).orElse(streaming)) - - case ResumeReading ⇒ - unstashAll() - context.become(onClose.orElse(readResumed).orElse(streaming)) - - case w @ Write(data, ack) ⇒ - tlsSocket.write(data.toByteBuffer) - if (ack != Tcp.NoAck) sender() ! ack - - case event: Tcp.Event ⇒ - handler ! event - } - - override def receive: Receive = onClose.orElse(readSuspended).orElse(streaming) - })) - if (!tlsTamper.trySuccess(tamperActor)) { - context.stop(tamperActor) - } + tlsSocket.completeWith(Future { + serverWrapper(socket) + }) } } } \ No newline at end of file diff --git a/src/main/scala/com/karasiq/proxychain/app/TLSHandlerTamper.scala b/src/main/scala/com/karasiq/proxychain/app/TLSHandlerTamper.scala new file mode 100644 index 0000000..ed770aa --- /dev/null +++ b/src/main/scala/com/karasiq/proxychain/app/TLSHandlerTamper.scala @@ -0,0 +1,81 @@ +package com.karasiq.proxychain.app + +import java.nio.channels.SocketChannel + +import akka.actor._ +import akka.io.Tcp +import akka.io.Tcp._ +import com.karasiq.networkutils.SocketChannelWrapper + +import scala.util.control + +class TLSHandlerTamper(tlsSocket: SocketChannel) extends Actor with ActorLogging with Stash { + private var handler: Option[ActorRef] = None + + @throws[Exception](classOf[Exception]) + override def preStart(): Unit = { + val catcher = control.Exception.allCatch.withApply { exc ⇒ + log.error(exc, "TLS initialization error") + self ! ErrorClosed + } + + catcher { + super.preStart() + SocketChannelWrapper.register(tlsSocket, self) + } + } + + override def postStop(): Unit = { + log.debug("TLS tamper stopped: {}", tlsSocket) + SocketChannelWrapper.unregister(tlsSocket) + tlsSocket.close() + super.postStop() + } + + def onClose: Receive = { + case c @ Tcp.Closed ⇒ + handler.foreach(_ ! c) + context.stop(self) + + case c @ Tcp.Close ⇒ + sender() ! ConfirmedClosed + context.stop(self) + + case Terminated(_) ⇒ + context.stop(self) + + case Register(newHandler, _, _) ⇒ + handler.foreach(context.unwatch) + context.watch(newHandler) + handler = Some(newHandler) + self ! ResumeReading + } + + def readSuspended: Receive = { + case Received(data) ⇒ + stash() + } + + def readResumed: Receive = { + case r @ Received(data) ⇒ + handler.foreach(_ ! r) + } + + def streaming: Receive = { + case SuspendReading ⇒ + context.become(onClose.orElse(readSuspended).orElse(streaming)) + + case ResumeReading ⇒ + unstashAll() + context.become(onClose.orElse(readResumed).orElse(streaming)) + + case w @ Write(data, ack) ⇒ + tlsSocket.write(data.toByteBuffer) + if (ack != Tcp.NoAck) sender() ! ack + + case event: Tcp.Event ⇒ + handler.foreach(_ ! event) + } + + override def receive: Receive = onClose.orElse(readSuspended).orElse(streaming) +}