Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resolve BasicTokenBucket thread-safe issue #2327

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 49 additions & 14 deletions folly/TokenBucket.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,8 @@ class BasicDynamicTokenBucket {
template <typename Policy = TokenBucketPolicyDefault>
class BasicTokenBucket {
private:
template <typename T>
using Atom = typename Policy::template atom<T>;
using Impl = BasicDynamicTokenBucket<Policy>;

public:
Expand All @@ -501,23 +503,37 @@ class BasicTokenBucket {
BasicTokenBucket(
double genRate, double burstSize, double zeroTime = 0) noexcept
: tokenBucket_(zeroTime), rate_(genRate), burstSize_(burstSize) {
assert(rate_ > 0);
assert(burstSize_ > 0);
assert(rate_.load(std::memory_order_acquire) > 0);
assert(burstSize_.load(std::memory_order_acquire) > 0);
}

/**
* Copy constructor.
*
* Warning: not thread safe!
*/
BasicTokenBucket(const BasicTokenBucket& other) noexcept = default;
BasicTokenBucket(const BasicTokenBucket& other) noexcept
: tokenBucket_(other.tokenBucket_),
rate_(other.rate_.load(std::memory_order_acquire)),
burstSize_(other.burstSize_.load(std::memory_order_acquire)) {}

/**
* Copy-assignment operator.
*
* Warning: not thread safe!
*/
BasicTokenBucket& operator=(const BasicTokenBucket& other) noexcept = default;
BasicTokenBucket& operator=(const BasicTokenBucket& other) noexcept {
if (this != &other) {
tokenBucket_ = other.tokenBucket_;
rate_.store(
other.rate_.load(std::memory_order_acquire),
std::memory_order_release);
burstSize_.store(
other.burstSize_.load(std::memory_order_acquire),
std::memory_order_release);
}
return *this;
}

/**
* Returns the current time in seconds since Epoch.
Expand Down Expand Up @@ -578,7 +594,11 @@ class BasicTokenBucket {
* @return True if the rate limit check passed, false otherwise.
*/
bool consume(double toConsume, double nowInSeconds = defaultClockNow()) {
return tokenBucket_.consume(toConsume, rate_, burstSize_, nowInSeconds);
return tokenBucket_.consume(
toConsume,
rate_.load(std::memory_order_acquire),
burstSize_.load(std::memory_order_acquire),
nowInSeconds);
}

/**
Expand All @@ -597,15 +617,19 @@ class BasicTokenBucket {
double consumeOrDrain(
double toConsume, double nowInSeconds = defaultClockNow()) {
return tokenBucket_.consumeOrDrain(
toConsume, rate_, burstSize_, nowInSeconds);
toConsume,
rate_.load(std::memory_order_acquire),
burstSize_.load(std::memory_order_acquire),
nowInSeconds);
}

/**
* Returns extra token back to the bucket. Cannot be negative.
* For negative tokens, setCapacity() can be used
*/
void returnTokens(double tokensToReturn) {
return tokenBucket_.returnTokens(tokensToReturn, rate_);
return tokenBucket_.returnTokens(
tokensToReturn, rate_.load(std::memory_order_acquire));
}

/**
Expand All @@ -615,7 +639,10 @@ class BasicTokenBucket {
Optional<double> consumeWithBorrowNonBlocking(
double toConsume, double nowInSeconds = defaultClockNow()) {
return tokenBucket_.consumeWithBorrowNonBlocking(
toConsume, rate_, burstSize_, nowInSeconds);
toConsume,
rate_.load(std::memory_order_acquire),
burstSize_.load(std::memory_order_acquire),
nowInSeconds);
}

/**
Expand All @@ -624,7 +651,10 @@ class BasicTokenBucket {
bool consumeWithBorrowAndWait(
double toConsume, double nowInSeconds = defaultClockNow()) {
return tokenBucket_.consumeWithBorrowAndWait(
toConsume, rate_, burstSize_, nowInSeconds);
toConsume,
rate_.load(std::memory_order_acquire),
burstSize_.load(std::memory_order_acquire),
nowInSeconds);
}

/**
Expand All @@ -644,27 +674,32 @@ class BasicTokenBucket {
* Thread-safe (but returned value may immediately be outdated).
*/
double balance(double nowInSeconds = defaultClockNow()) const noexcept {
return tokenBucket_.balance(rate_, burstSize_, nowInSeconds);
return tokenBucket_.balance(
rate_.load(std::memory_order_acquire),
burstSize_.load(std::memory_order_acquire),
nowInSeconds);
}

/**
* Returns the number of tokens generated per second.
*
* Thread-safe (but returned value may immediately be outdated).
*/
double rate() const noexcept { return rate_; }
double rate() const noexcept { return rate_.load(std::memory_order_acquire); }

/**
* Returns the maximum burst size.
*
* Thread-safe (but returned value may immediately be outdated).
*/
double burst() const noexcept { return burstSize_; }
double burst() const noexcept {
return burstSize_.load(std::memory_order_acquire);
}

private:
Impl tokenBucket_;
double rate_;
double burstSize_;
Atom<double> rate_;
Atom<double> burstSize_;
};

using TokenBucket = BasicTokenBucket<>;
Expand Down