Skip to content

Commit

Permalink
Make CPUConcurrencyController a shared_ptr
Browse files Browse the repository at this point in the history
Summary: Technically this is a behavioral change, but Kirill think it should be harmless

Reviewed By: sazonovkirill

Differential Revision: D67947852

fbshipit-source-id: c419d10b70e7c6c4ea0556587a5e9855a3b98a48
  • Loading branch information
Evan Zou authored and facebook-github-bot committed Feb 4, 2025
1 parent b1c77bd commit 45447d0
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 27 deletions.
4 changes: 2 additions & 2 deletions third-party/thrift/src/thrift/lib/cpp2/server/ServerConfigs.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class ServerConfigs {
virtual const AdaptiveConcurrencyController&
getAdaptiveConcurrencyController() const = 0;

virtual CPUConcurrencyController& getCPUConcurrencyController() = 0;
virtual const CPUConcurrencyController& getCPUConcurrencyController()
virtual CPUConcurrencyController* getCPUConcurrencyController() = 0;
virtual const CPUConcurrencyController* getCPUConcurrencyController()
const = 0;

// @see ThriftServer::getNumIOWorkerThreads function.
Expand Down
15 changes: 11 additions & 4 deletions third-party/thrift/src/thrift/lib/cpp2/server/ThriftServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,10 @@ ThriftServer::ThriftServer()
apache::thrift::detail::makeAdaptiveConcurrencyConfig(),
thriftConfig_.getMaxRequests().getObserver(),
detail::getThriftServerConfig(*this)},
cpuConcurrencyController_{
cpuConcurrencyController_{std::make_shared<CPUConcurrencyController>(
makeCPUConcurrencyControllerConfigInternal(),
*this,
detail::getThriftServerConfig(*this)},
detail::getThriftServerConfig(*this))},
addresses_(1),
wShutdownSocketSet_(folly::tryGetShutdownSocketSet()),
lastRequestTime_(
Expand Down Expand Up @@ -1957,7 +1957,7 @@ folly::Optional<OverloadResult> ThriftServer::checkOverload(
!getMethodsBypassMaxRequestsLimit().contains(method) &&
static_cast<uint32_t>(getActiveRequests()) >= maxRequests) {
LoadShedder loadShedder = LoadShedder::MAX_REQUESTS;
if (getCPUConcurrencyController().requestShed(
if (notifyCPUConcurrencyControllerOnRequestLoadShed(
CPUConcurrencyController::Method::MAX_REQUESTS)) {
loadShedder = LoadShedder::CPU_CONCURRENCY_CONTROLLER;
} else if (getAdaptiveConcurrencyController().enabled()) {
Expand All @@ -1975,7 +1975,7 @@ folly::Optional<OverloadResult> ThriftServer::checkOverload(
!getMethodsBypassMaxRequestsLimit().contains(method) &&
!qpsTokenBucket_.consume(1.0, maxQps, maxQps)) {
LoadShedder loadShedder = LoadShedder::MAX_QPS;
if (getCPUConcurrencyController().requestShed(
if (notifyCPUConcurrencyControllerOnRequestLoadShed(
CPUConcurrencyController::Method::MAX_QPS)) {
loadShedder = LoadShedder::CPU_CONCURRENCY_CONTROLLER;
}
Expand Down Expand Up @@ -2544,4 +2544,11 @@ bool ThriftServer::getTaskExpireTimeForRequest(
}
return queueTimeout != taskTimeout;
}

bool ThriftServer::notifyCPUConcurrencyControllerOnRequestLoadShed(
std::optional<CPUConcurrencyController::Method> method) {
auto* cpuConcurrencyControllerPtr = getCPUConcurrencyController();
return cpuConcurrencyControllerPtr != nullptr &&
cpuConcurrencyControllerPtr->requestShed(method);
}
} // namespace apache::thrift
19 changes: 13 additions & 6 deletions third-party/thrift/src/thrift/lib/cpp2/server/ThriftServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <atomic>
#include <chrono>
#include <cstdlib>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
Expand Down Expand Up @@ -902,6 +901,11 @@ class ThriftServer : public apache::thrift::concurrency::Runnable,

void disableInfoLogging() { infoLoggingEnabled_ = false; }

void setCPUConcurrencyController(
std::shared_ptr<CPUConcurrencyController> controller) {
cpuConcurrencyController_ = std::move(controller);
}

private:
friend ThriftServerConfig& detail::getThriftServerConfig(ThriftServer&);

Expand Down Expand Up @@ -1015,7 +1019,7 @@ class ThriftServer : public apache::thrift::concurrency::Runnable,
mockCPUConcurrencyControllerConfig_{std::nullopt};
folly::observer::Observer<CPUConcurrencyController::Config>
makeCPUConcurrencyControllerConfigInternal();
CPUConcurrencyController cpuConcurrencyController_;
std::shared_ptr<CPUConcurrencyController> cpuConcurrencyController_;

//! The server's listening addresses
std::vector<folly::SocketAddress> addresses_;
Expand Down Expand Up @@ -1077,14 +1081,17 @@ class ThriftServer : public apache::thrift::concurrency::Runnable,
return adaptiveConcurrencyController_;
}

CPUConcurrencyController& getCPUConcurrencyController() final {
return cpuConcurrencyController_;
CPUConcurrencyController* getCPUConcurrencyController() final {
return cpuConcurrencyController_.get();
}

const CPUConcurrencyController& getCPUConcurrencyController() const final {
return cpuConcurrencyController_;
const CPUConcurrencyController* getCPUConcurrencyController() const final {
return cpuConcurrencyController_.get();
}

bool notifyCPUConcurrencyControllerOnRequestLoadShed(
std::optional<CPUConcurrencyController::Method> method);

void setMockCPUConcurrencyControllerConfig(
CPUConcurrencyController::Config config) {
mockCPUConcurrencyControllerConfig_.setValue(config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ folly::Optional<OverloadResult> QpsOverloadChecker::checkOverload(
if (!server_.getMethodsBypassMaxRequestsLimit().contains(*params.method) &&
!qpsTokenBucket_.consume(1.0, maxQps, maxQps)) {
LoadShedder loadShedder = LoadShedder::MAX_QPS;
if (server_.getCPUConcurrencyController().requestShed(
if (auto* cpuConcurrencyController = server_.getCPUConcurrencyController();
cpuConcurrencyController != nullptr &&
cpuConcurrencyController->requestShed(
CPUConcurrencyController::Method::MAX_QPS)) {
loadShedder = LoadShedder::CPU_CONCURRENCY_CONTROLLER;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ folly::Optional<OverloadResult> QueueConcurrencyOverloadChecker::checkOverload(
!server_.getMethodsBypassMaxRequestsLimit().contains(*params.method)) &&
static_cast<uint32_t>(server_.getActiveRequests()) >= maxRequests) {
LoadShedder loadShedder = LoadShedder::MAX_REQUESTS;
if (server_.getCPUConcurrencyController().requestShed(
if (auto* cpuConcurrencyController = server_.getCPUConcurrencyController();
cpuConcurrencyController->requestShed(
CPUConcurrencyController::Method::MAX_REQUESTS)) {
loadShedder = LoadShedder::CPU_CONCURRENCY_CONTROLLER;
} else if (server_.getAdaptiveConcurrencyController().enabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ TEST(MockCpuConcurrencyControllerConfigTest, testOverride) {
CPUConcurrencyController::Config{.concurrencyLowerBound = 1111});
folly::observer_detail::ObserverManager::waitForAllUpdates();

auto config = server.getCPUConcurrencyController().config();
auto* cpuConcurrencyController = server.getCPUConcurrencyController();
ASSERT_NE(cpuConcurrencyController, nullptr);
auto config = cpuConcurrencyController->config();
ASSERT_EQ(config->concurrencyLowerBound, 1111);
ASSERT_TRUE(kMakeCPUConcurrencyControllerConfigCalled);
}
Expand All @@ -57,7 +59,9 @@ TEST(MockCpuConcurrencyControllerConfigTest, testBase) {
kMakeCPUConcurrencyControllerConfigCalled = false;

ThriftServer server;
auto config = server.getCPUConcurrencyController().config();
auto* cpuConcurrencyController = server.getCPUConcurrencyController();
ASSERT_NE(cpuConcurrencyController, nullptr);
auto config = cpuConcurrencyController->config();
ASSERT_EQ(config->concurrencyLowerBound, 2222);
ASSERT_TRUE(kMakeCPUConcurrencyControllerConfigCalled);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

namespace cpp2 apache.thrift.test

service TestService {
string echo(1: string str);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>

#include <folly/coro/GtestHelpers.h>
#include <thrift/lib/cpp2/server/ThriftServer.h>
#include <thrift/lib/cpp2/server/test/gen-cpp2/TestService_clients.h>
#include <thrift/lib/cpp2/server/test/gen-cpp2/TestService_handlers.h>
#include <thrift/lib/cpp2/util/ScopedServerInterfaceThread.h>

using namespace ::testing;
using namespace apache::thrift;

namespace {

class TestServicehandler
: public apache::thrift::ServiceHandler<test::TestService> {
folly::coro::Task<std::unique_ptr<std::string>> co_echo(
std::unique_ptr<std::string> str) override {
co_return std::move(str);
}
};

} // namespace

CO_TEST(ThriftServerCpuCCTest, CPUConcurrencyControllerCanBeNull) {
auto handler = std::make_shared<TestServicehandler>();
auto server = std::make_shared<ScopedServerInterfaceThread>(
std::move(handler), [&](ThriftServer& thriftServer) {
thriftServer.setCPUConcurrencyController(nullptr);
});
EXPECT_EQ(server->getThriftServer().getCPUConcurrencyController(), nullptr);

std::unique_ptr<Client<test::TestService>> client =
server->newClient<Client<test::TestService>>();

std::string result = co_await client->co_echo("hello");
EXPECT_EQ(result, "hello");
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ namespace apache::thrift {
RequestStateMachine::RequestStateMachine(
bool includeInRecentRequests,
AdaptiveConcurrencyController& controller,
CPUConcurrencyController& cpuController)
CPUConcurrencyController* cpuController)
: includeInRecentRequests_(includeInRecentRequests),
adaptiveConcurrencyController_(controller),
cpuController_(cpuController) {
if (includeInRecentRequests_) {
adaptiveConcurrencyController_.requestStarted(started());
cpuController_.requestStarted();
if (cpuController_ != nullptr) {
cpuController_->requestStarted();
}
}
}

Expand Down Expand Up @@ -65,7 +67,9 @@ RequestStateMachine::~RequestStateMachine() {
[[nodiscard]] bool RequestStateMachine::tryStopProcessing() {
if (!startProcessingOrQueueTimeout_.exchange(
true, std::memory_order_relaxed)) {
cpuController_.requestShed();
if (cpuController_ != nullptr) {
cpuController_->requestShed();
}
dequeued_.store(
std::chrono::steady_clock::now(), std::memory_order_relaxed);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class RequestStateMachine {
RequestStateMachine(
bool includeInRecentRequests,
AdaptiveConcurrencyController& controller,
CPUConcurrencyController& cpuController);
CPUConcurrencyController* cpuController);

~RequestStateMachine();

Expand Down Expand Up @@ -97,7 +97,7 @@ class RequestStateMachine {
std::atomic<std::chrono::steady_clock::time_point> dequeued_{
std::chrono::steady_clock::time_point::min()};
AdaptiveConcurrencyController& adaptiveConcurrencyController_;
CPUConcurrencyController& cpuController_;
CPUConcurrencyController* cpuController_;
};

} // namespace apache::thrift
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,13 @@ ThriftRequestCore::LogRequestSampleCallback::buildRequestLoggingContext(
requestLoggingContext.clientTimeoutMs = thriftRequest.clientTimeout_;

// CPUConcurrencyController mode
requestLoggingContext.cpuConcurrencyControllerMode = static_cast<uint8_t>(
serverConfigs_.getCPUConcurrencyController().config()->mode);
if (serverConfigs_.getCPUConcurrencyController() != nullptr) {
requestLoggingContext.cpuConcurrencyControllerMode = static_cast<uint8_t>(
serverConfigs_.getCPUConcurrencyController()->config()->mode);
} else {
requestLoggingContext.cpuConcurrencyControllerMode =
static_cast<uint8_t>(CPUConcurrencyController::Mode::DISABLED);
}

return requestLoggingContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ class ServerConfigsMock : public ServerConfigs {
return adaptiveConcurrencyController_;
}

apache::thrift::CPUConcurrencyController& getCPUConcurrencyController()
apache::thrift::CPUConcurrencyController* getCPUConcurrencyController()
override {
return cpuConcurrencyController_;
return &cpuConcurrencyController_;
}

const apache::thrift::CPUConcurrencyController& getCPUConcurrencyController()
const apache::thrift::CPUConcurrencyController* getCPUConcurrencyController()
const override {
return cpuConcurrencyController_;
return &cpuConcurrencyController_;
}

uint32_t getMaxRequests() const override {
Expand Down

0 comments on commit 45447d0

Please sign in to comment.