diff --git a/build.sbt b/build.sbt index ea93f02..f26195e 100644 --- a/build.sbt +++ b/build.sbt @@ -2,6 +2,14 @@ name := "ratelimiter2" organization := "org.tunnelbear" -version := "0.1" +version := "1.0" scalaVersion := "2.12.6" + +libraryDependencies += "org.specs2" %% "specs2-core" % "4.8.3" % "test" +libraryDependencies += "org.specs2" %% "specs2-mock" % "4.8.3" % "test" + +// Cross-build to support Scala 2.11 projects (remembackend...), and Scala 2.12 projects (tbearDashboard2, polarbackend...) +// See https://www.scala-sbt.org/1.x/docs/Cross-Build.html +lazy val root = (project in file(".")).settings( + crossScalaVersions := List("2.12.6", "2.11.7")) diff --git a/src/main/scala/RateLimiter/RateLimiterService.scala b/src/main/scala/RateLimiter/RateLimiterService.scala index ba22a67..6ce4e0c 100644 --- a/src/main/scala/RateLimiter/RateLimiterService.scala +++ b/src/main/scala/RateLimiter/RateLimiterService.scala @@ -1,32 +1,44 @@ package RateLimiter -import RateLimiter.RateLimiters.{AuthLimiter, IPLimiter, TagLimiter} +import RateLimiter.RateLimiters._ import scala.concurrent.ExecutionContext import scala.concurrent.duration.Duration +// TODO: change blacklistOnBlock field to enableBlacklisting? trait RateLimiterService { implicit def storage: RateLimiterStorage def dictLimit: Long def dictExpiry: Duration + def dictBlacklist: Boolean def bruteLimit: Long def bruteExpiry: Duration + def bruteBlacklist: Boolean def ipLimit: Long def ipExpiry: Duration + def ipBlacklist: Boolean + def tagLimit(tag: String): Long + def tagExpiry(tag: String): Duration + def tagBlacklist(tag: String): Boolean def authLimiter(ip: String, userIdentifier: String)(implicit executionContext: ExecutionContext): AuthLimiter = { - AuthLimiter(ip, userIdentifier, dictLimit, dictExpiry.toMillis, bruteLimit, bruteExpiry.toMillis) + AuthLimiter(ip, userIdentifier, dictLimit, dictExpiry.toMillis, dictBlacklist, bruteLimit, bruteExpiry.toMillis, bruteBlacklist) } - def ipLimiter(ip: String)(implicit executionContext: ExecutionContext) : IPLimiter = { - IPLimiter(ip, ipLimit, ipExpiry.toMillis) + def ipLimiter(ip: String)(implicit executionContext: ExecutionContext): IPLimiter = { + IPLimiter(ip, ipLimit, ipExpiry.toMillis, ipBlacklist) } - def tagLimiter(tag: String, ip: String, limit: Long, expiry: Duration)(implicit executionContext: ExecutionContext): TagLimiter = { - TagLimiter(tag, ip, limit, expiry.toMillis) + def tagLimiter(tag: String, ip: String)(implicit executionContext: ExecutionContext): TagLimiter = { + TagLimiter(tag, ip, tagLimit(tag), tagExpiry(tag).toMillis, tagBlacklist(tag)) + } + + // TODO: define separate methods for this? + def globalTagLimiter(tag: String)(implicit executionContext: ExecutionContext): GlobalTagLimiter = { + GlobalTagLimiter(tag, tagLimit(tag), tagExpiry(tag).toMillis) } } diff --git a/src/main/scala/RateLimiter/RateLimiterStatus.scala b/src/main/scala/RateLimiter/RateLimiterStatus.scala new file mode 100644 index 0000000..f12e356 --- /dev/null +++ b/src/main/scala/RateLimiter/RateLimiterStatus.scala @@ -0,0 +1,9 @@ +package RateLimiter + +object RateLimiterStatus extends Enumeration { + type RateLimiterStatus = Value + + val Allow: RateLimiterStatus = Value("Allow") + val Block: RateLimiterStatus = Value("Block") + val Blacklist: RateLimiterStatus = Value("Blacklist") +} diff --git a/src/main/scala/RateLimiter/RateLimiters/AuthLimiter.scala b/src/main/scala/RateLimiter/RateLimiters/AuthLimiter.scala index 09849c2..b1ece87 100644 --- a/src/main/scala/RateLimiter/RateLimiters/AuthLimiter.scala +++ b/src/main/scala/RateLimiter/RateLimiters/AuthLimiter.scala @@ -3,26 +3,24 @@ package RateLimiter.RateLimiters import RateLimiter.RateLimiterStorage import RateLimiter.Strategies.{BruteForceStrategy, DictionaryStrategy} -import scala.concurrent.{ExecutionContext, Future} - -case class AuthLimiter(ip: String, userIdentifier: String, dictLimit: Long, dictExpiry: Long, bruteLimit: Long, bruteExpiry: Long)(implicit rateLimiterStorage: RateLimiterStorage, executionContext: ExecutionContext) extends BaseRateLimiter { +import scala.concurrent.ExecutionContext + +case class AuthLimiter( + ip: String, + userIdentifier: String, + dictLimit: Long, + dictExpiry: Long, + dictBlacklist: Boolean, + bruteLimit: Long, + bruteExpiry: Long, + bruteBlacklist: Boolean +)(implicit rateLimiterStorage: RateLimiterStorage, override val executionContext: ExecutionContext) extends StrategyRateLimiter { private final val DictIdentifier = "DictAuthLimiter" private final val BruteIdentifier = "BruteAuthLimiter" - private final val Strategies = List( - DictionaryStrategy(DictIdentifier, ip, userIdentifier, dictLimit, dictExpiry), - BruteForceStrategy(BruteIdentifier, ip, userIdentifier, bruteLimit, bruteExpiry) + protected final override def strategies = Seq( + DictionaryStrategy(DictIdentifier, ip, userIdentifier, dictLimit, dictExpiry, dictBlacklist), + BruteForceStrategy(BruteIdentifier, ip, userIdentifier, bruteLimit, bruteExpiry, bruteBlacklist) ) - - override def allow: Future[Boolean] = { - Future.traverse(Strategies)(strategy => strategy.allow) - .map(_.forall(identity)) - } - - override def increment: Future[Unit] = { - Future.traverse(Strategies)(strategy => strategy.increment()) - .map(_.tail) - } } - diff --git a/src/main/scala/RateLimiter/RateLimiters/BaseRateLimiter.scala b/src/main/scala/RateLimiter/RateLimiters/BaseRateLimiter.scala index 9f4d6b9..79ff126 100644 --- a/src/main/scala/RateLimiter/RateLimiters/BaseRateLimiter.scala +++ b/src/main/scala/RateLimiter/RateLimiters/BaseRateLimiter.scala @@ -1,8 +1,11 @@ package RateLimiter.RateLimiters +import RateLimiter.RateLimiterStatus.RateLimiterStatus + import scala.concurrent.Future trait BaseRateLimiter { - def allow: Future[Boolean] - def increment: Future[Unit] + def status: Future[RateLimiterStatus] + def increment(): Future[Unit] + def statusWithIncrement(): Future[RateLimiterStatus] } diff --git a/src/main/scala/RateLimiter/RateLimiters/GlobalTagLimiter.scala b/src/main/scala/RateLimiter/RateLimiters/GlobalTagLimiter.scala new file mode 100644 index 0000000..ac8af34 --- /dev/null +++ b/src/main/scala/RateLimiter/RateLimiters/GlobalTagLimiter.scala @@ -0,0 +1,19 @@ +package RateLimiter.RateLimiters + +import RateLimiter.RateLimiterStorage +import RateLimiter.Strategies.GlobalTagStrategy + +import scala.concurrent.ExecutionContext + +case class GlobalTagLimiter( + tag: String, + limit: Long, + expiry: Long +)(implicit rateLimiterStorage: RateLimiterStorage, override val executionContext: ExecutionContext) extends StrategyRateLimiter { + + private final val Identifier = "GlobalTagLimiter" + + protected final override def strategies = Seq( + GlobalTagStrategy(Identifier, tag, limit, expiry) + ) +} diff --git a/src/main/scala/RateLimiter/RateLimiters/IPLimiter.scala b/src/main/scala/RateLimiter/RateLimiters/IPLimiter.scala index bd2fbcd..60e7303 100644 --- a/src/main/scala/RateLimiter/RateLimiters/IPLimiter.scala +++ b/src/main/scala/RateLimiter/RateLimiters/IPLimiter.scala @@ -3,18 +3,12 @@ package RateLimiter.RateLimiters import RateLimiter.RateLimiterStorage import RateLimiter.Strategies.IPStrategy -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext -case class IPLimiter(ip: String, limit: Long, expiry: Long)(implicit rateLimiterStorage: RateLimiterStorage, executionContext: ExecutionContext) extends BaseRateLimiter { - private final val Identifier = "IPLimiter" - - override def allow: Future[Boolean] = { - IPStrategy(Identifier, ip, limit, expiry).allow - } - - override def increment: Future[Unit] = { - IPStrategy(Identifier, ip, limit, expiry).increment() - } +case class IPLimiter(ip: String, limit: Long, expiry: Long, blacklistOnBlock: Boolean)(implicit rateLimiterStorage: RateLimiterStorage, override val executionContext: ExecutionContext) extends StrategyRateLimiter { + private final val Identifier = s"IPLimiter" + protected final override def strategies = Seq( + IPStrategy(Identifier, ip, limit, expiry, blacklistOnBlock) + ) } - diff --git a/src/main/scala/RateLimiter/RateLimiters/StrategyRateLimiter.scala b/src/main/scala/RateLimiter/RateLimiters/StrategyRateLimiter.scala new file mode 100644 index 0000000..f4fa0dc --- /dev/null +++ b/src/main/scala/RateLimiter/RateLimiters/StrategyRateLimiter.scala @@ -0,0 +1,36 @@ +package RateLimiter.RateLimiters + +import RateLimiter.RateLimiterStatus._ +import RateLimiter.Strategies.BaseStrategy + +import scala.concurrent.{ExecutionContext, Future} + +trait StrategyRateLimiter extends BaseRateLimiter { + protected def strategies: Seq[BaseStrategy] + implicit val executionContext: ExecutionContext + + override def status: Future[RateLimiterStatus] = { + Future + .traverse(strategies)(strategy => strategy.status) + .map(_.fold(Allow) { + case (Allow, status) => status + case (Block, status) => if (status != Allow) status else Block + case (Blacklist, _) => Blacklist + }) + } + + override def increment(): Future[Unit] = { + Future.traverse(strategies)(strategy => strategy.increment()) + .map(_.tail) + } + + // TODO: does this logic make sense, and is it intuitive? Should this logic live here? + override def statusWithIncrement(): Future[RateLimiterStatus] = { + status.map { + case Allow => + increment() + Allow + case status => status + } + } +} diff --git a/src/main/scala/RateLimiter/RateLimiters/TagLimiter.scala b/src/main/scala/RateLimiter/RateLimiters/TagLimiter.scala index fdc7b5e..8ef3c06 100644 --- a/src/main/scala/RateLimiter/RateLimiters/TagLimiter.scala +++ b/src/main/scala/RateLimiter/RateLimiters/TagLimiter.scala @@ -3,19 +3,18 @@ package RateLimiter.RateLimiters import RateLimiter.RateLimiterStorage import RateLimiter.Strategies.TagStrategy -import scala.concurrent.{ExecutionContext, Future} - -case class TagLimiter(tag: String, ip: String, limit: Long, expiry: Long)(implicit rateLimiterStorage: RateLimiterStorage, executionContext: ExecutionContext) extends BaseRateLimiter { - +import scala.concurrent.ExecutionContext + +case class TagLimiter( + tag: String, + ip: String, + limit: Long, + expiry: Long, + blacklistOnBlock: Boolean +)(implicit rateLimiterStorage: RateLimiterStorage, override val executionContext: ExecutionContext) extends StrategyRateLimiter { private final val Identifier = "TagLimiter" - override def allow: Future[Boolean] = { - TagStrategy(Identifier, tag, ip, limit, expiry).allow - } - - override def increment: Future[Unit] = { - TagStrategy(Identifier, tag, ip, limit, expiry).increment() - } - + protected final override def strategies = Seq( + TagStrategy(Identifier, tag, ip, limit, expiry, blacklistOnBlock) + ) } - diff --git a/src/main/scala/RateLimiter/Strategies/BaseStrategy.scala b/src/main/scala/RateLimiter/Strategies/BaseStrategy.scala index a3b6b53..06c00e8 100644 --- a/src/main/scala/RateLimiter/Strategies/BaseStrategy.scala +++ b/src/main/scala/RateLimiter/Strategies/BaseStrategy.scala @@ -1,6 +1,7 @@ package RateLimiter.Strategies import RateLimiter.RateLimiterStorage +import RateLimiter.RateLimiterStatus._ import scala.concurrent.{ExecutionContext, Future} @@ -9,14 +10,18 @@ trait BaseStrategy { implicit def storage: RateLimiterStorage def identifier: String - def ip: String def limit: Long def expiry: Long - - def key: String = s"$identifier:$ip" - - def allow(implicit executionContext: ExecutionContext): Future[Boolean] = { - storage.getCount(key, expiry).map(_ < limit) + def key: String + def blacklistOnBlock: Boolean + + def status(implicit executionContext: ExecutionContext): Future[RateLimiterStatus] = { + storage.getCount(key, expiry).map { count => + println(s"CHECKING: $identifier, $count") + if (count < limit) Allow + else if (!blacklistOnBlock) Block + else Blacklist + } } def increment(): Future[Unit] = { diff --git a/src/main/scala/RateLimiter/Strategies/BruteForceStrategy.scala b/src/main/scala/RateLimiter/Strategies/BruteForceStrategy.scala index d81faf4..dcb24c4 100644 --- a/src/main/scala/RateLimiter/Strategies/BruteForceStrategy.scala +++ b/src/main/scala/RateLimiter/Strategies/BruteForceStrategy.scala @@ -5,8 +5,8 @@ import RateLimiter.RateLimiterStorage /* Ratelimits based on number of attempts on a single user */ -case class BruteForceStrategy(identifier: String, ip: String, userIdentifier: String, limit: Long, expiry: Long)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { - override def storage = rateLimiterStorage +case class BruteForceStrategy(identifier: String, ip: String, userIdentifier: String, limit: Long, expiry: Long, blacklistOnBlock: Boolean)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { + override implicit def storage: RateLimiterStorage = rateLimiterStorage - override def key: String = s"$identifier:$userIdentifier" + def key: String = s"$identifier:$userIdentifier" } diff --git a/src/main/scala/RateLimiter/Strategies/DictionaryStrategy.scala b/src/main/scala/RateLimiter/Strategies/DictionaryStrategy.scala index 13321de..cbc9f21 100644 --- a/src/main/scala/RateLimiter/Strategies/DictionaryStrategy.scala +++ b/src/main/scala/RateLimiter/Strategies/DictionaryStrategy.scala @@ -7,10 +7,12 @@ import scala.concurrent.Future /* Ratelimits based on a single IP attempting many different users */ -case class DictionaryStrategy(identifier: String, ip: String, userIdentifier: String, limit: Long, expiry: Long)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { +case class DictionaryStrategy(identifier: String, ip: String, userIdentifier: String, limit: Long, expiry: Long, blacklistOnBlock: Boolean)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { override implicit def storage: RateLimiterStorage = rateLimiterStorage - override def increment: Future[Unit] = { + def key: String = s"$identifier:$ip" + + override def increment(): Future[Unit] = { storage.incrementCount(key, userIdentifier, expiry) } diff --git a/src/main/scala/RateLimiter/Strategies/GlobalTagStrategy.scala b/src/main/scala/RateLimiter/Strategies/GlobalTagStrategy.scala new file mode 100644 index 0000000..07cda6d --- /dev/null +++ b/src/main/scala/RateLimiter/Strategies/GlobalTagStrategy.scala @@ -0,0 +1,16 @@ +package RateLimiter.Strategies + +import RateLimiter.RateLimiterStorage + +/* + Ratelimits based on number of requests with this tag for all users + Can be used to ratelimit specific actions for example + */ +case class GlobalTagStrategy(identifier: String, tag: String, limit: Long, expiry: Long)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { + override implicit def storage: RateLimiterStorage = rateLimiterStorage + + def key: String = s"$identifier:$tag" + + // Should never blacklist since that would effectively block all users + override def blacklistOnBlock = false +} diff --git a/src/main/scala/RateLimiter/Strategies/IPStrategy.scala b/src/main/scala/RateLimiter/Strategies/IPStrategy.scala index fde0186..aed9212 100644 --- a/src/main/scala/RateLimiter/Strategies/IPStrategy.scala +++ b/src/main/scala/RateLimiter/Strategies/IPStrategy.scala @@ -5,6 +5,8 @@ import RateLimiter.RateLimiterStorage /* Ratelimits based on number of requests for a single ip */ -case class IPStrategy(identifier: String, ip: String, limit: Long, expiry: Long)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { - override def storage = rateLimiterStorage +case class IPStrategy(identifier: String, ip: String, limit: Long, expiry: Long, blacklistOnBlock: Boolean)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { + override implicit def storage: RateLimiterStorage = rateLimiterStorage + + def key: String = s"$identifier:$ip" } diff --git a/src/main/scala/RateLimiter/Strategies/TagStrategy.scala b/src/main/scala/RateLimiter/Strategies/TagStrategy.scala index 1e2354f..63ab11c 100644 --- a/src/main/scala/RateLimiter/Strategies/TagStrategy.scala +++ b/src/main/scala/RateLimiter/Strategies/TagStrategy.scala @@ -6,8 +6,8 @@ import RateLimiter.RateLimiterStorage Ratelimits based on number of requests with this tag for a single ip. Can be used to ratelimit specific actions for example */ -case class TagStrategy(identifier: String, tag: String, ip: String, limit: Long, expiry: Long)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { +case class TagStrategy(identifier: String, tag: String, ip: String, limit: Long, expiry: Long, blacklistOnBlock: Boolean)(implicit rateLimiterStorage: RateLimiterStorage) extends BaseStrategy { override implicit def storage: RateLimiterStorage = rateLimiterStorage - override def key: String = s"$identifier:$tag:$ip" + def key: String = s"$identifier:$tag:$ip" } diff --git a/src/test/scala/Example/ExampleController.scala b/src/test/scala/Example/ExampleController.scala index 24221bf..5fc720c 100644 --- a/src/test/scala/Example/ExampleController.scala +++ b/src/test/scala/Example/ExampleController.scala @@ -1,26 +1,48 @@ package Example -import RateLimiter.RateLimiters.TagLimiter +import RateLimiter.RateLimiters.{AuthLimiter, TagLimiter} +import RateLimiter.RateLimiterStatus._ -import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future class ExampleController extends RateLimitedController { - def someAuthAction(ip: String, email: String): Int = { - if (!authLimiter(ip, email).allow) return 429 - else authLimiter(ip, email).increment + private val Action1 = "Action1" + private val Action2 = "Action2" - // do actions - 200 + // Example of defining a rate limit depending on the action + override def tagLimit(tag: String): Long = tag match { + case `Action1` => 10 + case `Action2` => 20 + case _ => super.tagLimit(tag) } - def someSpecificAction(ip: String): Int = { - val limiter: TagLimiter = tagLimiter("specific", ip, 10, 1 minute) - - if (!limiter.allow) return 429 - else limiter.increment + def someAuthAction(ip: String, email: String): Future[Int] = { + val limiter: AuthLimiter = authLimiter(ip, email) + + // Note that you wouldn't need to explicitly wrap your action with beforeAction if you were using one of Play's + // action composition design patterns (e.g., Stackable Controller) + beforeAction(ip) { + limiter.statusWithIncrement().map { + case Allow => + // do stuff + 200 + case _ => 429 + } + } + } - // do stuff - 200 + def someSpecificAction(ip: String): Future[Int] = { + val limiter: TagLimiter = tagLimiter(Action1, ip) + + beforeAction(ip) { + limiter.statusWithIncrement().map { + case Allow => + // do stuff + 200 + case _ => 429 + } + } } } diff --git a/src/test/scala/Example/RateLimitedController.scala b/src/test/scala/Example/RateLimitedController.scala index 9a914c8..7b0a0dc 100644 --- a/src/test/scala/Example/RateLimitedController.scala +++ b/src/test/scala/Example/RateLimitedController.scala @@ -1,16 +1,23 @@ package Example import RateLimiter.RateLimiterStorage +import RateLimiter.RateLimiters.IPLimiter +import RateLimiter.RateLimiterStatus._ + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future class RateLimitedController extends RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = RateLimiterStorageImpl // this theoretical function would wrap all requests inherited from RateLimitedController - def beforeAction(ip: String): Int = { - if (!ipLimiter(ip).allow) return 429 - else ipLimiter(ip).increment + def beforeAction(ip: String)(f: => Future[Int]): Future[Int] = { + val limiter: IPLimiter = ipLimiter(ip) - 200 + // TODO: note that if blacklisting is enabled, the case isn't handled. + limiter.statusWithIncrement().flatMap { + case Allow => f + case Block => Future.successful(429) + } } - } diff --git a/src/test/scala/Example/RateLimiterServiceImpl.scala b/src/test/scala/Example/RateLimiterServiceImpl.scala index 508e9e7..58d9a65 100644 --- a/src/test/scala/Example/RateLimiterServiceImpl.scala +++ b/src/test/scala/Example/RateLimiterServiceImpl.scala @@ -6,13 +6,20 @@ import scala.concurrent.duration._ trait RateLimiterServiceImpl extends RateLimiterService { - def dictLimit: Long = 5 - def dictExpiry: Duration = 1 day + override def dictLimit: Long = 5 + override def dictExpiry: Duration = 1 day + override def dictBlacklist: Boolean = false - def bruteLimit: Long = 10 - def bruteExpiry: Duration = 10 minutes + override def bruteLimit: Long = 10 + override def bruteExpiry: Duration = 10 minutes + override def bruteBlacklist: Boolean = false - def ipLimit: Long = 50 - def ipExpiry: Duration = 2 minutes + override def ipLimit: Long = 50 + override def ipExpiry: Duration = 2 minutes + override def ipBlacklist: Boolean = false + + override def tagLimit(tag: String): Long = ipLimit + override def tagExpiry(tag: String): Duration = ipExpiry + override def tagBlacklist(tag: String): Boolean = ipBlacklist } diff --git a/src/test/scala/Example/RateLimiterStorageImpl.scala b/src/test/scala/Example/RateLimiterStorageImpl.scala index 2811557..e7ce20d 100644 --- a/src/test/scala/Example/RateLimiterStorageImpl.scala +++ b/src/test/scala/Example/RateLimiterStorageImpl.scala @@ -2,23 +2,26 @@ package Example import RateLimiter.RateLimiterStorage +import scala.concurrent.Future + // Implementation using in memory cache object RateLimiterStorageImpl extends RateLimiterStorage { - var Storage = Map[String, Map[String, Long]]() + private var Storage = Map[String, Map[String, Long]]() - def incrementCount(key: String, value: String, expiry: Long) = { + def incrementCount(key: String, value: String, expiry: Long): Future[Unit] = { // Add new entry val entries = Storage.getOrElse(key, Map[String, Long]()) + (value -> System.currentTimeMillis) // Update storage Storage += (key -> entries) + Future.successful(()) } - def getCount(key: String, expiry: Long): Long = { + def getCount(key: String, expiry: Long): Future[Long] = { val expires: Long = System.currentTimeMillis - expiry // non-expired entries val entries = Storage.getOrElse(key, Map[String, Long]()).filter(_._2 > expires) // Update storage to remove expired Storage += (key -> entries) - entries.size + Future.successful(entries.size) } } diff --git a/src/test/scala/RateLimiterServiceSpec.scala b/src/test/scala/RateLimiterServiceSpec.scala new file mode 100644 index 0000000..bc6cb15 --- /dev/null +++ b/src/test/scala/RateLimiterServiceSpec.scala @@ -0,0 +1,201 @@ +import RateLimiter.{RateLimiterService, RateLimiterStorage} +import RateLimiter.RateLimiterStatus._ +import org.specs2.concurrent.ExecutionEnv +import org.specs2.mock.Mockito +import org.specs2.mutable._ + +import scala.concurrent.Future +import scala.concurrent.duration._ + +class RateLimiterServiceSpec(implicit ee: ExecutionEnv) extends Specification with Mockito { + + val ip = "123.123.123.123" + val userIdentifier = "michael@tunnelbear.com" + + val ipLimiterKey = s"IPLimiter:$ip" + val dictLimiterKey = s"DictAuthLimiter:$ip" + val bruteLimiterKey = s"BruteAuthLimiter:$userIdentifier" + + val tag = "tag1" + val tagWithBlacklist = "tag2" + val tagLimiterKey = s"TagLimiter:$tag:$ip" + val tagLimiterKeyBlacklist = s"TagLimiter:$tagWithBlacklist:$ip" + + val globalTag1 = "tag" + val globalTag2 = "globalTag2" + val globalTagLimiterKey1 = s"GlobalTagLimiter:$globalTag1" + val globalTagLimiterKey2 = s"GlobalTagLimiter:$globalTag2" + + trait RateLimiterServiceImpl extends RateLimiterService { + + override def dictLimit: Long = 10 + override def dictExpiry: Duration = 1 minute + override def dictBlacklist: Boolean = false + + override def bruteLimit: Long = 20 + override def bruteExpiry: Duration = 2 minutes + override def bruteBlacklist: Boolean = false + + override def ipLimit: Long = 30 + override def ipExpiry: Duration = 3 minutes + override def ipBlacklist: Boolean = false + + override def tagLimit(t: String): Long = t match { + case `tag` => 40 + case `tagWithBlacklist` => 50 + case _ => 60 + } + override def tagExpiry(t: String): Duration = t match { + case `tag` => 4 minutes + case `tagWithBlacklist` => 5 minutes + case _ => 6 minutes + } + override def tagBlacklist(t: String): Boolean = t match { + case `tag` => false + case `tagWithBlacklist` => true + case _ => false + } + + } + + // TODO: test status and statusWithIncrement + "RateLimiterServiceTestImpl provides an IPLimiter" should { + + "allows request that does not exceed given rate" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(ipLimiterKey, rls.ipExpiry.toMillis) returns Future.successful(1) + + rls.ipLimiter(ip).status must be_==(Allow).await + } + + "blocks request that exceeds given rate" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(ipLimiterKey, rls.ipExpiry.toMillis) returns Future.successful(rls.ipLimit) + + rls.ipLimiter(ip).status must be_==(Block).await + } + + "blacklist user if request is blocked and blacklisting is enabled" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { + override implicit def storage: RateLimiterStorage = mockStorage + override def ipBlacklist = true + } + mockStorage.getCount(ipLimiterKey, rls.ipExpiry.toMillis) returns Future.successful(rls.ipLimit) + + rls.ipLimiter(ip).status must be_==(Blacklist).await + } + + } + + + + "RateLimiterServiceTestImpl provides an AuthLimiter" should { + + "allows request that does not exceed given rates" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(dictLimiterKey, rls.dictExpiry.toMillis) returns Future.successful(1) + mockStorage.getCount(bruteLimiterKey, rls.bruteExpiry.toMillis) returns Future.successful(1) + + rls.authLimiter(ip, userIdentifier).status must be_==(Allow).await + } + + "blocks request that exceeds dict rate" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(dictLimiterKey, rls.dictExpiry.toMillis) returns Future.successful(rls.dictLimit) + mockStorage.getCount(bruteLimiterKey, rls.bruteExpiry.toMillis) returns Future.successful(1) + + rls.authLimiter(ip, userIdentifier).status must be_==(Block).await + } + + "blocks request that exceeds brute rate" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(dictLimiterKey, rls.dictExpiry.toMillis) returns Future.successful(1) + mockStorage.getCount(bruteLimiterKey, rls.bruteExpiry.toMillis) returns Future.successful(rls.bruteLimit) + + rls.authLimiter(ip, userIdentifier).status must be_==(Block).await + } + + "blacklist a user if request is blocked and blacklisting is enabled" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { + override implicit def storage: RateLimiterStorage = mockStorage + override def dictBlacklist = true + } + mockStorage.getCount(dictLimiterKey, rls.dictExpiry.toMillis) returns Future.successful(rls.dictLimit) + mockStorage.getCount(bruteLimiterKey, rls.bruteExpiry.toMillis) returns Future.successful(1) + + rls.authLimiter(ip, userIdentifier).status must be_==(Blacklist).await + } + + } + + + + "RateLimiterServiceTestImpl provides a TagLimiter" should { + + "allows request that does not exceed given rates" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + + mockStorage.getCount(tagLimiterKey, rls.tagExpiry(tag).toMillis) returns Future.successful(1) + mockStorage.getCount(tagLimiterKeyBlacklist, rls.tagExpiry(tagWithBlacklist).toMillis) returns Future.successful(1) + + rls.tagLimiter(tag, ip).status must be_==(Allow).await + rls.tagLimiter(tagWithBlacklist, ip).status must be_==(Allow).await + } + + "blocks request that exceeds given rate" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(tagLimiterKey, rls.tagExpiry(tag).toMillis) returns Future.successful(rls.tagLimit(tag)) + mockStorage.getCount(tagLimiterKeyBlacklist, rls.tagExpiry(tagWithBlacklist).toMillis) returns Future.successful(1) + + rls.tagLimiter(tag, ip).status must be_==(Block).await + rls.tagLimiter(tagWithBlacklist, ip).status must be_==(Allow).await + } + + "blacklist a user if request is blocked and blacklisting is enabled" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(tagLimiterKey, rls.tagExpiry(tag).toMillis) returns Future.successful(1) + mockStorage.getCount(tagLimiterKeyBlacklist, rls.tagExpiry(tagWithBlacklist).toMillis) returns Future.successful(rls.tagLimit(tagWithBlacklist)) + + rls.tagLimiter(tag, ip).status must be_==(Allow).await + rls.tagLimiter(tagWithBlacklist, ip).status must be_==(Blacklist).await + } + + } + + + + "RateLimiterServiceTestImpl provides a GlobalTagLimiter" should { + + "allows request that does not exceed given rates" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(globalTagLimiterKey1, rls.tagExpiry(globalTag1).toMillis) returns Future.successful(1) + mockStorage.getCount(globalTagLimiterKey2, rls.tagExpiry(globalTag2).toMillis) returns Future.successful(1) + + rls.globalTagLimiter(globalTag1).status must be_==(Allow).await + rls.globalTagLimiter(globalTag2).status must be_==(Allow).await + } + + "blocks request that exceeds given rate" in { + val mockStorage = mock[RateLimiterStorage] + val rls = new RateLimiterServiceImpl { override implicit def storage: RateLimiterStorage = mockStorage } + mockStorage.getCount(globalTagLimiterKey1, rls.tagExpiry(globalTag1).toMillis) returns Future.successful(rls.tagLimit(globalTag1)) + mockStorage.getCount(globalTagLimiterKey2, rls.tagExpiry(globalTag2).toMillis) returns Future.successful(1) + + rls.globalTagLimiter(globalTag1).status must be_==(Block).await + rls.globalTagLimiter(globalTag2).status must be_==(Allow).await + } + + } + +}