Skip to content

Commit

Permalink
Merge pull request #46 from esl/delayed-check-servers
Browse files Browse the repository at this point in the history
Pause on all nodes during join
  • Loading branch information
chrzaszcz authored Jan 4, 2024
2 parents 0e3f83e + 6a49936 commit 7eca9e9
Show file tree
Hide file tree
Showing 3 changed files with 378 additions and 54 deletions.
90 changes: 59 additions & 31 deletions src/cets.erl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
delete_objects/2,
dump/1,
remote_dump/1,
send_dump/4,
send_dump/5,
table_name/1,
other_nodes/1,
get_nodes_request/1,
Expand Down Expand Up @@ -119,6 +119,7 @@
-type table_name() :: atom().
-type pause_monitor() :: reference().
-type servers() :: ordsets:ordset(server_pid()).
-type node_down_event() :: #{node => node(), pid => pid(), reason => term()}.
-type state() :: #{
tab := table_name(),
keypos := pos_integer(),
Expand All @@ -131,7 +132,7 @@
opts := start_opts(),
backlog := [backlog_entry()],
pause_monitors := [pause_monitor()],
node_down_history := [node()]
node_down_history := [node_down_event()]
}.

-type long_msg() ::
Expand All @@ -146,17 +147,19 @@
| {unpause, reference()}
| get_leader
| {set_leader, boolean()}
| {send_dump, servers(), join_ref(), [tuple()]}.
| {send_dump, servers(), join_ref(), pause_monitor(), [tuple()]}.

-type info() :: #{
table := table_name(),
nodes := [node()],
other_servers := [pid()],
size := non_neg_integer(),
memory := non_neg_integer(),
ack_pid := ack_pid(),
join_ref := join_ref(),
opts := start_opts(),
node_down_history := [node()]
node_down_history := [node_down_event()],
pause_monitors := [pause_monitor()]
}.

-type handle_down_fun() :: fun((#{remote_pid := server_pid(), table := table_name()}) -> ok).
Expand All @@ -181,6 +184,7 @@
long_msg/0,
info/0,
table_name/0,
pause_monitor/0,
servers/0,
response_return/0,
response_timeout/0
Expand Down Expand Up @@ -230,10 +234,11 @@ table_name(Tab) when is_atom(Tab) ->
table_name(Server) ->
cets_call:long_call(Server, table_name).

-spec send_dump(server_ref(), servers(), join_ref(), [tuple()]) -> ok.
send_dump(Server, NewPids, JoinRef, OurDump) ->
-spec send_dump(server_ref(), servers(), join_ref(), pause_monitor(), [tuple()]) ->
ok | {error, ignored}.
send_dump(Server, NewPids, JoinRef, PauseRef, OurDump) ->
Info = #{msg => send_dump, join_ref => JoinRef, count => length(OurDump)},
cets_call:long_call(Server, {send_dump, NewPids, JoinRef, OurDump}, Info).
cets_call:long_call(Server, {send_dump, NewPids, JoinRef, PauseRef, OurDump}, Info).

%% Only the node that owns the data could update/remove the data.
%% Ideally, Key should contain inserter node info so cleaning and merging is simplified.
Expand Down Expand Up @@ -461,8 +466,8 @@ handle_call(remote_dump, From, State = #{tab := Tab}) ->
%% Do not block the main process (also reduces GC of the main process)
proc_lib:spawn_link(fun() -> gen_server:reply(From, {ok, dump(Tab)}) end),
{noreply, State};
handle_call({send_dump, NewPids, JoinRef, Dump}, _From, State) ->
handle_send_dump(NewPids, JoinRef, Dump, State);
handle_call({send_dump, NewPids, JoinRef, PauseRef, Dump}, _From, State) ->
handle_send_dump(NewPids, JoinRef, PauseRef, Dump, State);
handle_call(pause, _From = {FromPid, _}, State = #{pause_monitors := Mons}) ->
%% We monitor who pauses our server
Mon = erlang:monitor(process, FromPid),
Expand All @@ -483,11 +488,10 @@ handle_cast(Msg, State) ->
handle_info({remote_op, Op, From, AckPid, JoinRef}, State) ->
handle_remote_op(Op, From, AckPid, JoinRef, State),
{noreply, State};
handle_info({'DOWN', Mon, process, Pid, _Reason}, State) ->
{noreply, handle_down(Mon, Pid, State)};
handle_info({'DOWN', Mon, process, Pid, Reason}, State) ->
{noreply, handle_down(Mon, Pid, Reason, State)};
handle_info({check_server, FromPid, JoinRef}, State) ->
handle_check_server(FromPid, JoinRef, State),
{noreply, State};
{noreply, handle_check_server(FromPid, JoinRef, State)};
handle_info(Msg, State) ->
?LOG_ERROR(#{what => unexpected_info, msg => Msg}),
{noreply, State}.
Expand All @@ -500,34 +504,49 @@ code_change(_OldVsn, State, _Extra) ->

%% Internal logic

-spec handle_send_dump(servers(), join_ref(), [tuple()], state()) -> {reply, ok, state()}.
handle_send_dump(NewPids, JoinRef, Dump, State = #{tab := Tab, other_servers := Servers}) ->
ets:insert(Tab, Dump),
Servers2 = add_servers(NewPids, Servers),
{reply, ok, set_other_servers(Servers2, State#{join_ref := JoinRef})}.
-spec handle_send_dump(servers(), join_ref(), pause_monitor(), [tuple()], state()) ->
{reply, ok, state()}.
handle_send_dump(NewPids, JoinRef, PauseRef, Dump, State) ->
#{tab := Tab, other_servers := Servers, pause_monitors := PauseMons} = State,
case lists:member(PauseRef, PauseMons) of
true ->
ets:insert(Tab, Dump),
Servers2 = add_servers(NewPids, Servers),
{reply, ok, set_other_servers(Servers2, State#{join_ref := JoinRef})};
false ->
?LOG_ERROR(#{
what => send_dump_received_when_unpaused,
text => <<"Received send_dump message while in the unpaused state. Ignore it">>,
join_ref => JoinRef,
pause_ref => PauseRef,
state => State
}),
{reply, {error, ignored}, State}
end.

-spec handle_down(reference(), pid(), state()) -> state().
handle_down(Mon, Pid, State = #{pause_monitors := Mons}) ->
-spec handle_down(reference(), pid(), term(), state()) -> state().
handle_down(Mon, Pid, Reason, State = #{pause_monitors := Mons}) ->
case lists:member(Mon, Mons) of
true ->
?LOG_ERROR(#{
what => pause_owner_crashed,
state => State,
paused_by_pid => Pid
paused_by_pid => Pid,
reason => Reason
}),
handle_unpause2(Mon, Mons, State);
false ->
handle_down2(Pid, State)
handle_down2(Pid, Reason, State)
end.

-spec handle_down2(pid(), state()) -> state().
handle_down2(RemotePid, State = #{other_servers := Servers, ack_pid := AckPid}) ->
-spec handle_down2(pid(), term(), state()) -> state().
handle_down2(RemotePid, Reason, State = #{other_servers := Servers, ack_pid := AckPid}) ->
case lists:member(RemotePid, Servers) of
true ->
cets_ack:send_remote_down(AckPid, RemotePid),
call_user_handle_down(RemotePid, State),
Servers2 = lists:delete(RemotePid, Servers),
update_node_down_history(RemotePid, set_other_servers(Servers2, State));
update_node_down_history(RemotePid, Reason, set_other_servers(Servers2, State));
false ->
%% This should not happen
?LOG_ERROR(#{
Expand All @@ -538,8 +557,9 @@ handle_down2(RemotePid, State = #{other_servers := Servers, ack_pid := AckPid})
State
end.

update_node_down_history(RemotePid, State = #{node_down_history := History}) ->
State#{node_down_history := [node(RemotePid) | History]}.
update_node_down_history(RemotePid, Reason, State = #{node_down_history := History}) ->
Item = #{node => node(RemotePid), pid => RemotePid, reason => Reason},
State#{node_down_history := [Item | History]}.

%% Merge two lists of pids, create the missing monitors.
-spec add_servers(Servers, Servers) -> Servers when Servers :: servers().
Expand Down Expand Up @@ -726,9 +746,14 @@ send_check_server(Pid, JoinRef) ->
Pid ! {check_server, self(), JoinRef},
ok.

handle_check_server(_FromPid, JoinRef, #{join_ref := JoinRef}) ->
ok;
handle_check_server(FromPid, RemoteJoinRef, #{join_ref := JoinRef}) ->
%% That could actually arrive before we get fully unpaused
%% (though cets_join:pause_on_remote_node/2 would ensure that CETS server
%% would send check_server only after cets_join is down
%% and does not send new send_dump messages)
handle_check_server(_FromPid, JoinRef, State = #{join_ref := JoinRef}) ->
%% check_server passed - do nothing
State;
handle_check_server(FromPid, RemoteJoinRef, State = #{join_ref := JoinRef}) ->
?LOG_WARNING(#{
what => cets_check_server_failed,
text => <<"Disconnect the remote server">>,
Expand All @@ -739,7 +764,7 @@ handle_check_server(FromPid, RemoteJoinRef, #{join_ref := JoinRef}) ->
%% Ask the remote server to disconnect from us
Reason = {check_server_failed, {RemoteJoinRef, JoinRef}},
FromPid ! {'DOWN', make_ref(), process, self(), Reason},
ok.
State.

-spec handle_get_info(state()) -> info().
handle_get_info(
Expand All @@ -749,17 +774,20 @@ handle_get_info(
ack_pid := AckPid,
join_ref := JoinRef,
node_down_history := DownHistory,
pause_monitors := PauseMons,
opts := Opts
}
) ->
#{
table => Tab,
nodes => lists:usort(pids_to_nodes([self() | Servers])),
other_servers => Servers,
size => ets:info(Tab, size),
memory => ets:info(Tab, memory),
ack_pid => AckPid,
join_ref => JoinRef,
node_down_history => DownHistory,
pause_monitors => PauseMons,
opts => Opts
}.

Expand Down
74 changes: 65 additions & 9 deletions src/cets_join.erl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
-export([join/5]).
-include_lib("kernel/include/logger.hrl").

%% Export for RPC
-export([pause_on_remote_node/2]).

-ifdef(TEST).
-export([check_could_reach_each_other/3]).
-endif.

-type lock_key() :: term().
-type join_ref() :: reference().
-type server_pid() :: cets:server_pid().
-type rpc_result() :: {Class :: throw | exit | error, Reason :: term()} | {ok, ok}.

%% Critical events during the joining procedure
-type checkpoint() ::
Expand All @@ -20,14 +24,15 @@
| before_get_pids
| before_check_fully_connected
| before_unpause
| {before_send_dump, server_pid()}.
| {before_send_dump, server_pid()}
| {after_send_dump, server_pid(), Result :: term()}.

-type checkpoint_handler() :: fun((checkpoint()) -> ok).
-type join_opts() :: #{checkpoint_handler => checkpoint_handler()}.
-type join_opts() :: #{checkpoint_handler => checkpoint_handler(), join_ref => reference()}.

-export_type([join_ref/0]).

-ignore_xref([join/5]).
-ignore_xref([join/5, pause_on_remote_node/2]).

%% Adds a node to a cluster.
%% Writes from other nodes would wait for join completion.
Expand Down Expand Up @@ -100,7 +105,7 @@ join_loop(LockKey, Info, LocalPid, RemotePid, Start, JoinOpts) ->
-spec join2(cets_long:log_info(), server_pid(), server_pid(), join_opts()) -> ok.
join2(Info, LocalPid, RemotePid, JoinOpts) ->
checkpoint(join_start, JoinOpts),
JoinRef = make_ref(),
JoinRef = maps:get(join_ref, JoinOpts, make_ref()),
%% Joining is a symmetrical operation here - both servers exchange information between each other.
%% We still use LocalPid/RemotePid in names
%% (they are local and remote pids as passed from the cets_join and from the cets_discovery).
Expand All @@ -110,7 +115,7 @@ join2(Info, LocalPid, RemotePid, JoinOpts) ->
RemPids = get_pids(RemotePid),
check_pids(Info, LocPids, RemPids, JoinOpts),
AllPids = LocPids ++ RemPids,
Paused = [{Pid, cets:pause(Pid)} || Pid <- AllPids],
Paused = pause_servers(AllPids),
%% Merges data from two partitions together.
%% Each entry in the table is allowed to be updated by the node that owns
%% the key only, so merging is easy.
Expand All @@ -124,8 +129,8 @@ join2(Info, LocalPid, RemotePid, JoinOpts) ->
check_fully_connected(Info, LocPids),
check_fully_connected(Info, RemPids),
{LocalDump2, RemoteDump2} = maybe_apply_resolver(LocalDump, RemoteDump, ServerOpts),
RemF = fun(Pid) -> send_dump(Pid, LocPids, JoinRef, LocalDump2, JoinOpts) end,
LocF = fun(Pid) -> send_dump(Pid, RemPids, JoinRef, RemoteDump2, JoinOpts) end,
RemF = fun(Pid) -> send_dump(Pid, Paused, LocPids, JoinRef, LocalDump2, JoinOpts) end,
LocF = fun(Pid) -> send_dump(Pid, Paused, RemPids, JoinRef, RemoteDump2, JoinOpts) end,
lists:foreach(LocF, LocPids),
lists:foreach(RemF, RemPids),
ok
Expand All @@ -135,10 +140,51 @@ join2(Info, LocalPid, RemotePid, JoinOpts) ->
lists:foreach(fun({Pid, Ref}) -> catch cets:unpause(Pid, Ref) end, Paused)
end.

send_dump(Pid, Pids, JoinRef, Dump, JoinOpts) ->
-spec pause_servers(AllPids :: [pid(), ...]) -> Paused :: [{pid(), cets:pause_monitor()}].
pause_servers(AllPids) ->
%% We should create a pause helper process on each node in the cluster.
%% It is to ensure that node that losing a connection with cets_join coordinator
%% would not unpause one of the processes too soon
%% (because it could start sending remote ops to nodes which are still in the current joining procedure).
Paused = [{Pid, cets:pause(Pid)} || Pid <- AllPids],
OtherNodes = lists:delete(node(), lists:usort([node(Pid) || Pid <- AllPids])),
Results = erpc:multicall(
OtherNodes, ?MODULE, pause_on_remote_node, [self(), AllPids], timer:seconds(30)
),
assert_all_ok(OtherNodes, Results),
Paused.

-spec pause_on_remote_node(pid(), [pid()]) -> ok.
pause_on_remote_node(JoinerPid, AllPids) ->
Self = self(),
{Pid, Mon} = spawn_monitor(fun() ->
JoinerMon = erlang:monitor(process, JoinerPid),
MyNode = node(),
%% Ignore pids on the current node
%% (because we only interested in internode connections here).
%% Catching because we can ignore losing some connections here.
_Pauses = [catch cets:pause(Pid) || Pid <- AllPids, node(Pid) =/= MyNode],
Self ! {ready, self()},
receive
{'DOWN', JoinerMon, process, JoinerPid, _Reason} ->
%% Exit and release pauses
ok
end
end),
receive
{'DOWN', Mon, process, Pid, _Reason} ->
ok;
{ready, Pid} ->
ok
end.

send_dump(Pid, Paused, Pids, JoinRef, Dump, JoinOpts) ->
PauseRef = proplists:get_value(Pid, Paused),
checkpoint({before_send_dump, Pid}, JoinOpts),
%% Error reporting would be done by cets_long:call_tracked
catch cets:send_dump(Pid, Pids, JoinRef, Dump).
Result = catch cets:send_dump(Pid, Pids, JoinRef, PauseRef, Dump),
checkpoint({after_send_dump, Pid, Result}, JoinOpts),
ok.

remote_or_local_dump(Pid) when node(Pid) =:= node() ->
{ok, Tab} = cets:table_name(Pid),
Expand Down Expand Up @@ -305,6 +351,16 @@ pid_to_join_ref(Pid) ->
#{join_ref := JoinRef} = cets:info(Pid),
JoinRef.

-spec assert_all_ok(Nodes :: [node()], Results :: [rpc_result()]) -> ok.
assert_all_ok(Nodes, Results) ->
Zip = lists:zip(Nodes, Results),
case lists:filter(fun({_Node, Res}) -> Res =/= {ok, ok} end, Zip) of
[] ->
ok;
BadZip ->
error({assert_all_ok, BadZip})
end.

%% Checkpoints are used for testing
%% Checkpoints do nothing in production
-spec checkpoint(checkpoint(), join_opts()) -> ok.
Expand Down
Loading

0 comments on commit 7eca9e9

Please sign in to comment.