Skip to content

Commit

Permalink
Redis backend now uses BITFIELD to store counts`
Browse files Browse the repository at this point in the history
  • Loading branch information
jacob-pro committed Jan 21, 2024
1 parent 8fe3403 commit 8f64b48
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 61 deletions.
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changes

## 0.3.0 2024-01-21

- Breaking: Removes async-trait dependency
- Breaking: Redis backend now uses BITFIELD to store counts

## 0.2.2 2022-04-19

- Improve documentation.
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ actix-web = { version = "4", default-features = false, features = ["macros"] }
dashmap = { version = "5.4.0", optional = true }
futures = "0.3.28"
log = "0.4.19"
redis = { version = "0.23.0", default-features = false, features = ["tokio-comp", "aio", "connection-manager"], optional = true }
redis = { version = "0.24.0", default-features = false, features = ["tokio-comp", "aio", "connection-manager"], optional = true }
thiserror = "1.0.40"

[features]
Expand Down
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,25 @@ async fn main() -> std::io::Result<()> {
.await
}
```

Try it out:

```
$ curl -v http://127.0.0.1:8080
* Trying 127.0.0.1:8080...
* Connected to 127.0.0.1 (127.0.0.1) port 8080 (#0)
> GET / HTTP/1.1
> Host: 127.0.0.1:8080
> User-Agent: curl/7.83.1
> Accept: */*
>
* Mark bundle as not supporting multiuse
< HTTP/1.1 404 Not Found
< content-length: 0
< x-ratelimit-limit: 5
< x-ratelimit-reset: 60
< x-ratelimit-remaining: 4
< date: Sun, 21 Jan 2024 16:52:27 GMT
<
* Connection #0 to host 127.0.0.1 left intact
```
2 changes: 1 addition & 1 deletion src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Decision {
}

pub fn is_allowed(self) -> bool {
self == Self::Allowed
matches!(self, Self::Allowed)
}

pub fn is_denied(self) -> bool {
Expand Down
124 changes: 65 additions & 59 deletions src/backend/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,13 @@ use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput
use actix_web::rt::time::Instant;
use actix_web::{HttpResponse, ResponseError};
use redis::aio::ConnectionManager;
use redis::AsyncCommands;
use redis::{AsyncCommands, Cmd};
use std::borrow::Cow;
use std::time::Duration;
use thiserror::Error;

// https://github.com/mitsuhiko/redis-rs/issues/353
macro_rules! async_transaction {
($conn:expr, $keys:expr, $body:expr) => {
loop {
redis::cmd("WATCH").arg($keys).query_async($conn).await?;

if let Some(response) = $body {
redis::cmd("UNWATCH").query_async($conn).await?;
break response;
}
}
};
}
const BITFIELD_ENCODING: &str = "u63";
const BITFIELD_OFFSET: u8 = 0;

#[derive(Debug, Error)]
pub enum Error {
Expand All @@ -29,7 +18,7 @@ pub enum Error {
#[from]
redis::RedisError,
),
#[error("Unexpected negative TTL response")]
#[error("Unexpected negative TTL response for the rate limit key")]
NegativeTtl,
}

Expand Down Expand Up @@ -58,7 +47,7 @@ impl RedisBackend {
/// ```no_run
/// # use actix_extensible_rate_limit::backend::redis::RedisBackend;
/// # use redis::aio::ConnectionManager;
/// # async {
/// # async fn example() {
/// let client = redis::Client::open("redis://127.0.0.1/").unwrap();
/// let manager = ConnectionManager::new(client).await.unwrap();
/// let backend = RedisBackend::builder(manager).build();
Expand Down Expand Up @@ -112,26 +101,37 @@ impl Backend<SimpleInput> for RedisBackend {
input: SimpleInput,
) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> {
let key = self.make_key(&input.key);
// https://github.com/actix/actix-extras/blob/master/actix-limitation/src/lib.rs#L123

let mut pipe = redis::pipe();
pipe.atomic()
.cmd("SET") // Set key and value
// Increment the rate limit count
.cmd("BITFIELD")
.arg(key.as_ref())
.arg(0i64)
.arg("EX") // Set the specified expire time, in seconds.
.arg(input.interval.as_secs())
.arg("NX") // Only set the key if it does not already exist.
.ignore() // --- ignore returned value of SET command ---
.cmd("INCR") // Increment key
.arg("OVERFLOW")
.arg("SAT")
.arg("INCRBY")
.arg(BITFIELD_ENCODING)
.arg(BITFIELD_OFFSET)
.arg(1)
.arg("GET")
.arg(BITFIELD_ENCODING)
.arg(BITFIELD_OFFSET)
// Set the key to expire (only if it doesn't already have an expiry)
.cmd("EXPIRE")
.arg(key.as_ref())
.cmd("TTL") // Return time-to-live of key
.arg(input.interval.as_secs())
.arg("NX")
.ignore()
// Return time-to-live of key
.cmd("TTL")
.arg(key.as_ref());

let mut con = self.connection.clone();
let (count, ttl): (u64, i64) = pipe.query_async(&mut con).await?;
let (counts, ttl): (Vec<u64>, i64) = pipe.query_async(&mut con).await?;
if ttl < 0 {
return Err(Self::Error::NegativeTtl);
return Err(Error::NegativeTtl);
}
let count = *counts.first().expect("BITFIELD should return one value");

let allow = count <= input.max_requests;
let output = SimpleOutput {
Expand All @@ -144,24 +144,20 @@ impl Backend<SimpleInput> for RedisBackend {

async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> {
let key = self.make_key(&token);

let mut con = self.connection.clone();
async_transaction!(&mut con, &[key.as_ref()], {
let old_val: Option<u64> = con.get(key.as_ref()).await?;
if let Some(old_val) = old_val {
if old_val >= 1 {
redis::pipe()
.atomic()
.decr::<_, u64>(key.as_ref(), 1)
.ignore()
.query_async::<_, Option<()>>(&mut con)
.await?
} else {
Some(())
}
} else {
Some(())
}
});
let mut cmd = Cmd::new();
cmd.arg("BITFIELD")
.arg(key.as_ref())
.arg("OVERFLOW")
.arg("SAT")
.arg("INCRBY")
.arg(BITFIELD_ENCODING)
.arg(BITFIELD_OFFSET)
.arg(-1);

cmd.query_async(&mut con).await?;

Ok(())
}
}
Expand Down Expand Up @@ -203,13 +199,17 @@ mod tests {
max_requests: 5,
key: "test_allow_deny".to_string(),
};
for _ in 0..5 {
for i in (0..5).rev() {
// First 5 should be allowed
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
let (decision, output, _) = backend.request(input.clone()).await.unwrap();
assert_eq!(output.remaining, i);
assert_eq!(output.limit, 5);
assert!(decision.is_allowed());
}
// Sixth should be denied
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
let (decision, output, _) = backend.request(input.clone()).await.unwrap();
assert_eq!(output.remaining, 0);
assert_eq!(output.limit, 5);
assert!(decision.is_denied());
}

Expand All @@ -224,9 +224,11 @@ mod tests {
// Make first request, should be allowed
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
// Request again, should be denied

// Request again immediately afterwards, should now be denied
let (decision, out, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_denied());

// Sleep until reset, should now be allowed
tokio::time::sleep(Duration::from_secs(out.seconds_until_reset())).await;
let (decision, _, _) = backend.request(input).await.unwrap();
Expand All @@ -247,12 +249,14 @@ mod tests {
assert_eq!(output.remaining, 1);
assert_eq!(output.limit, 2);
assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);

// Second of 2 should be allowed.
let (decision, output, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
assert_eq!(output.remaining, 0);
assert_eq!(output.limit, 2);
assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);

// Should be denied
let (decision, output, _) = backend.request(input).await.unwrap();
assert!(decision.is_denied());
Expand Down Expand Up @@ -281,18 +285,20 @@ mod tests {

#[actix_web::test]
async fn test_rollback_key_gone() {
let backend = make_backend("test_rollback_key_gone").await.build();
let key = "test_rollback_key_gone";
let backend = make_backend(key).await.build();
let mut con = backend.connection.clone();
// The rollback could happen after the key has already expired
backend
.rollback("test_rollback_key_gone".to_string())
.await
.unwrap();
// In which case nothing should happen
assert!(!con
.exists::<_, bool>("test_rollback_key_gone")
.await
.unwrap());
// The rollback could happen after the key has already expired / gone
backend.rollback(key.to_string()).await.unwrap();
// In which case the count should remain at 0
let mut cmd = Cmd::new();
cmd.arg("BITFIELD")
.arg(key)
.arg("GET")
.arg(BITFIELD_ENCODING)
.arg(BITFIELD_OFFSET);
let value: Vec<u64> = cmd.query_async(&mut con).await.unwrap();
assert_eq!(value[0], 0u64);
}

#[actix_web::test]
Expand Down

0 comments on commit 8f64b48

Please sign in to comment.