Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: client unpause and client help #4630

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/server/debugcmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ void DebugCmd::Run(CmdArgList args, facade::SinkReplyBuilder* builder) {
"TX",
" Performs transaction analysis per shard.",
"TRAFFIC <path> | [STOP]",
" Starts traffic logging to the specified path. If path is not specified,"
" Starts traffic logging to the specified path. If path is not specified,",
" traffic logging is stopped.",
"RECVSIZE [<tid> | ENABLE | DISABLE]",
" Prints the histogram of the received request sizes on the given thread",
Expand Down
16 changes: 9 additions & 7 deletions src/server/main_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1153,13 +1153,15 @@ void Service::DispatchCommand(ArgSlice args, SinkReplyBuilder* builder,

// Don't interrupt running multi commands or admin connections.
if (etl.IsPaused() && !dispatching_in_multi && cntx->conn() && !cntx->conn()->IsPrivileged()) {
bool is_write = cid->IsWriteOnly();
is_write |= cid->name() == "PUBLISH" || cid->name() == "EVAL" || cid->name() == "EVALSHA";
is_write |= cid->name() == "EXEC" && dfly_cntx->conn_state.exec_info.is_write;

cntx->paused = true;
etl.AwaitPauseState(is_write);
cntx->paused = false;
bool has_sub = args.size() == 2;
if (cid->name() != "CLIENT" || !has_sub || absl::AsciiStrToUpper(args[1]) != "UNPAUSE") {
bool is_write = cid->IsWriteOnly();
is_write |= cid->name() == "PUBLISH" || cid->name() == "EVAL" || cid->name() == "EVALSHA";
is_write |= cid->name() == "EXEC" && dfly_cntx->conn_state.exec_info.is_write;
cntx->paused = true;
etl.AwaitPauseState(is_write);
cntx->paused = false;
}
}

if (auto err = VerifyCommandState(cid, args_no_cmd, *dfly_cntx); err) {
Expand Down
120 changes: 90 additions & 30 deletions src/server/server_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,35 +361,6 @@ void ClientList(CmdArgList args, absl::Span<facade::Listener*> listeners, SinkRe
return rb->SendVerbatimString(result);
}

// CLIENT PAUSE <timeout> [WRITE|ALL]: suspends all (or write-only) clients on
// the given listeners for <timeout> milliseconds via a detached pause fiber.
void ClientPauseCmd(CmdArgList args, vector<facade::Listener*> listeners, SinkReplyBuilder* builder,
                    ConnectionContext* cntx) {
  CmdArgParser parser(args);

  // First argument: pause duration in milliseconds.
  auto timeout = parser.Next<uint64_t>();
  // Optional second argument selects the scope; defaults to pausing ALL commands.
  ClientPause pause_state = ClientPause::ALL;
  if (parser.HasNext()) {
    pause_state = parser.MapNext("WRITE", ClientPause::WRITE, "ALL", ClientPause::ALL);
  }
  if (auto err = parser.Error(); err) {
    return builder->SendError(err->MakeReply());
  }

  const auto timeout_ms = timeout * 1ms;
  // Predicate polled by the pause fiber: keep the pause active until the
  // deadline passes or the server starts shutting down.
  auto is_pause_in_progress = [end_time = chrono::steady_clock::now() + timeout_ms] {
    return ServerState::tlocal()->gstate() != GlobalState::SHUTTING_DOWN &&
           chrono::steady_clock::now() < end_time;
  };

  if (auto pause_fb_opt =
          Pause(listeners, cntx->ns, cntx->conn(), pause_state, std::move(is_pause_in_progress));
      pause_fb_opt) {
    // Fire-and-forget: the fiber resumes the listeners once the predicate
    // turns false. NOTE(review): a detached fiber cannot be cancelled early.
    pause_fb_opt->Detach();
    builder->SendOk();
  } else {
    builder->SendError("Failed to pause all running clients");
  }
}

void ClientTracking(CmdArgList args, SinkReplyBuilder* builder, ConnectionContext* cntx) {
auto* rb = static_cast<RedisReplyBuilder*>(builder);
if (!rb->IsResp3())
Expand Down Expand Up @@ -964,6 +935,8 @@ void ServerFamily::JoinSnapshotSchedule() {
void ServerFamily::Shutdown() {
VLOG(1) << "ServerFamily::Shutdown";

client_pause_fb_.JoinIfNeeded();

load_fiber_.JoinIfNeeded();

JoinSnapshotSchedule();
Expand Down Expand Up @@ -1882,6 +1855,56 @@ void ServerFamily::Auth(CmdArgList args, const CommandContext& cmd_cntx) {
}
}

// CLIENT UNPAUSE: ends an in-progress CLIENT PAUSE and resumes traffic.
// Accepts no extra arguments; replies +OK once the pause fiber has finished.
void ServerFamily::ClientUnPauseCmd(CmdArgList args, SinkReplyBuilder* builder) {
  // The subcommand takes no arguments at all.
  if (!args.empty())
    return builder->SendError(facade::kSyntaxErr);

  // Serialize with concurrent CLIENT PAUSE handlers.
  std::unique_lock guard(client_pause_mu_);
  // Tell the pause fiber's predicate to stop reporting "in progress"...
  client_pause_.store(false, std::memory_order_relaxed);
  // ...then wait until the fiber has actually unpaused the listeners.
  client_pause_fb_.JoinIfNeeded();
  builder->SendOk();
}

// Replies to CLIENT HELP with a usage summary of all supported subcommands.
// Convention: subcommand syntax lines are flush-left inside the string, and
// description lines carry a leading indent.
void ClientHelp(SinkReplyBuilder* builder) {
  string_view help_arr[] = {
      "CLIENT <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
      "CACHING (YES|NO)",
      " Enable/disable tracking of the keys for next command in OPTIN/OPTOUT modes.",
      "GETNAME",
      " Return the name of the current connection.",
      "ID",
      " Return the ID of the current connection.",
      "KILL <ip:port>",
      " Kill connection made from <ip:port>.",
      "KILL <option> <value> [<option> <value> [...]]",
      " Kill connections. Options are:",
      " * ADDR (<ip:port>|<unixsocket>:0)",
      " Kill connections made from the specified address",
      " * LADDR (<ip:port>|<unixsocket>:0)",
      " Kill connections made to specified local address",
      " * ID <client-id>",
      " Kill connections by client id.",
      "LIST",
      " Return information about client connections.",
      "UNPAUSE",
      " Stop the current client pause, resuming traffic.",
      "PAUSE <timeout> [WRITE|ALL]",
      " Suspend all, or just write, clients for <timeout> milliseconds.",
      "SETNAME <name>",
      " Assign the name <name> to the current connection.",
      "SETINFO <option> <value>",
      // BUG FIX: this description line was missing its leading indent, making
      // it render as if it were a subcommand in the CLIENT HELP output.
      " Set client meta attr. Options are:",
      " * LIB-NAME: the client lib name.",
      " * LIB-VER: the client lib version.",
      "TRACKING (ON|OFF) [OPTIN] [OPTOUT] [NOLOOP]",
      " Control server assisted client side caching.",
      "HELP",
      " Print this help."};
  auto* rb = static_cast<RedisReplyBuilder*>(builder);
  return rb->SendSimpleStrArr(help_arr);
}

void ServerFamily::Client(CmdArgList args, const CommandContext& cmd_cntx) {
string sub_cmd = absl::AsciiStrToUpper(ArgS(args, 0));
CmdArgList sub_args = args.subspan(1);
Expand All @@ -1895,7 +1918,9 @@ void ServerFamily::Client(CmdArgList args, const CommandContext& cmd_cntx) {
} else if (sub_cmd == "LIST") {
return ClientList(sub_args, absl::MakeSpan(listeners_), builder, cntx);
} else if (sub_cmd == "PAUSE") {
return ClientPauseCmd(sub_args, GetNonPriviligedListeners(), builder, cntx);
return ClientPauseCmd(sub_args, builder, cntx);
} else if (sub_cmd == "UNPAUSE") {
return ClientUnPauseCmd(sub_args, builder);
} else if (sub_cmd == "TRACKING") {
return ClientTracking(sub_args, builder, cntx);
} else if (sub_cmd == "KILL") {
Expand All @@ -1906,6 +1931,8 @@ void ServerFamily::Client(CmdArgList args, const CommandContext& cmd_cntx) {
return ClientSetInfo(sub_args, builder, cntx);
} else if (sub_cmd == "ID") {
return ClientId(sub_args, builder, cntx);
} else if (sub_cmd == "HELP") {
return ClientHelp(builder);
}

LOG_FIRST_N(ERROR, 10) << "Subcommand " << sub_cmd << " not supported";
Expand Down Expand Up @@ -3179,6 +3206,39 @@ void ServerFamily::Module(CmdArgList args, const CommandContext& cmd_cntx) {
rb->SendLong(20'000); // we target v2
}

// CLIENT PAUSE <timeout> [WRITE|ALL]: suspends all (or write-only)
// non-privileged clients for <timeout> milliseconds. The pause can be ended
// early by CLIENT UNPAUSE (which clears client_pause_) or by server shutdown.
void ServerFamily::ClientPauseCmd(CmdArgList args, SinkReplyBuilder* builder,
                                  ConnectionContext* cntx) {
  CmdArgParser parser(args);
  auto listeners = GetNonPriviligedListeners();

  // First argument: pause duration in milliseconds.
  auto timeout = parser.Next<uint64_t>();
  // Optional second argument selects the scope; defaults to pausing ALL commands.
  ClientPause pause_state = ClientPause::ALL;
  if (parser.HasNext()) {
    pause_state = parser.MapNext("WRITE", ClientPause::WRITE, "ALL", ClientPause::ALL);
  }
  if (auto err = parser.Error(); err) {
    return builder->SendError(err->MakeReply());
  }

  // Serialize with CLIENT UNPAUSE and other CLIENT PAUSE handlers.
  std::unique_lock lk(client_pause_mu_);

  // BUG FIX: cancel and reap any previous pause *before* arming a new one.
  // Previously the old fiber was joined while client_pause_ was still true,
  // so a second CLIENT PAUSE blocked until the first pause's full deadline.
  client_pause_.store(false, std::memory_order_relaxed);
  client_pause_fb_.JoinIfNeeded();

  // BUG FIX: arm the flag before Pause() spawns the new fiber; otherwise the
  // fiber's predicate could observe a stale 'false' (e.g. after a prior
  // UNPAUSE) and terminate the new pause immediately.
  client_pause_.store(true, std::memory_order_relaxed);

  const auto timeout_ms = timeout * 1ms;
  // Predicate polled by the pause fiber: pause stays active until the deadline
  // passes, the server begins shutting down, or UNPAUSE clears client_pause_.
  auto is_pause_in_progress = [this, end_time = chrono::steady_clock::now() + timeout_ms] {
    return ServerState::tlocal()->gstate() != GlobalState::SHUTTING_DOWN &&
           chrono::steady_clock::now() < end_time && client_pause_.load(std::memory_order_relaxed);
  };

  if (auto pause_fb_opt =
          Pause(listeners, cntx->ns, cntx->conn(), pause_state, std::move(is_pause_in_progress));
      pause_fb_opt) {
    client_pause_fb_ = std::move(*pause_fb_opt);
    builder->SendOk();
  } else {
    // Roll the flag back so a failed pause does not look active to UNPAUSE.
    client_pause_.store(false, std::memory_order_relaxed);
    builder->SendError("Failed to pause all running clients");
  }
}

#define HFUNC(x) SetHandler(HandlerFunc(this, &ServerFamily::x))

namespace acl {
Expand Down
8 changes: 7 additions & 1 deletion src/server/server_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ class ServerFamily {

static bool DoAuth(ConnectionContext* cntx, std::string_view username, std::string_view password);

void ClientPauseCmd(CmdArgList args, SinkReplyBuilder* builder, ConnectionContext* cntx);
void ClientUnPauseCmd(CmdArgList args, SinkReplyBuilder* builder);

util::fb2::Fiber snapshot_schedule_fb_;
util::fb2::Fiber load_fiber_;

Expand All @@ -334,7 +337,7 @@ class ServerFamily {
bool accepting_connections_ = true;
util::ProactorBase* pb_task_ = nullptr;

mutable util::fb2::Mutex replicaof_mu_, save_mu_;
mutable util::fb2::Mutex replicaof_mu_, save_mu_, client_pause_mu_;
std::shared_ptr<Replica> replica_ ABSL_GUARDED_BY(replicaof_mu_);
std::vector<std::unique_ptr<Replica>> cluster_replicas_
ABSL_GUARDED_BY(replicaof_mu_); // used to replicating multiple nodes to single dragonfly
Expand All @@ -358,6 +361,9 @@ class ServerFamily {
std::unique_ptr<util::fb2::FiberQueueThreadPool> fq_threadpool_;
std::shared_ptr<detail::SnapshotStorage> snapshot_storage_;

std::atomic<bool> client_pause_ = false;
util::fb2::Fiber client_pause_fb_;

// protected by save_mu_
util::fb2::Fiber bg_save_fb_;

Expand Down
26 changes: 26 additions & 0 deletions tests/dragonfly/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,3 +1160,29 @@ async def push_pipeline(bad_actor_client, size=1):
info = await good_client.info()

assert info["dispatch_queue_bytes"] == 0


async def test_client_unpause(df_factory):
    """Verify that CLIENT UNPAUSE releases a pending CLIENT PAUSE early."""
    server = df_factory.create()
    server.start()

    async_client = server.client()
    # Pause write commands for 15s (all=False maps to the WRITE pause mode).
    await async_client.client_pause(15000, all=False)

    async def set_foo():
        # A write issued from a second connection; must block while paused.
        client = server.client()
        await client.execute_command("SET", "foo", "bar")

    p1 = asyncio.create_task(set_foo())

    # Give the SET time to be dispatched; it must still be blocked by the pause.
    await asyncio.sleep(2)
    assert not p1.done()

    # UNPAUSE must return promptly (well before the 15s deadline).
    async with async_timeout.timeout(2):
        await async_client.client_unpause()

    # The blocked writer should now complete.
    await p1
    assert p1.done()

    # Regression check: stopping the server while a pause is active must not hang.
    await async_client.client_pause(5, all=False)
    server.stop()
Loading