Merge pull request #101 from JamesWrigley/jps/threadsafe_workerstate
IanButterworth authored Jan 22, 2025
2 parents 3b9e4fd + 90041ca commit 0e27cd1
Showing 9 changed files with 150 additions and 61 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
Manifest.toml
*.swp
112 changes: 66 additions & 46 deletions src/cluster.jl
@@ -99,10 +99,10 @@ mutable struct Worker
del_msgs::Array{Any,1} # XXX: Could del_msgs and add_msgs be Channels?
add_msgs::Array{Any,1}
@atomic gcflag::Bool
state::WorkerState
c_state::Condition # wait for state changes
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily
@atomic state::WorkerState
c_state::Threads.Condition # wait for state changes, lock for state
ct_time::Float64 # creation time
conn_func::Any # used to setup connections lazily

r_stream::IO
w_stream::IO
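
`Worker.state` is now an `@atomic` field, so every read and write of it must go through Julia's field-atomics API rather than plain assignment. A standalone sketch of that API on a toy struct (names hypothetical, not part of this diff):

```julia
mutable struct Flag
    @atomic state::Symbol
end

f = Flag(:created)

@atomic f.state = :running           # sequentially consistent write
s = @atomic f.state                  # atomic read
old = @atomicswap f.state = :done    # write :done, return the previous value
@assert s === :running && old === :running
```
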
@@ -134,7 +134,7 @@ mutable struct Worker
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
@@ -144,8 +144,10 @@
end

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
lock(w.c_state) do
@atomic w.state = state
notify(w.c_state; all=true)
end
end
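
`Threads.Condition`, unlike the single-threaded `Base.Condition` it replaces, requires its lock to be held around both `notify` and `wait`, which is why `set_worker_state` now wraps the state write and the notification in `lock(w.c_state) do ... end`. A minimal self-contained sketch of the discipline (hypothetical names):

```julia
using Base.Threads

cond  = Threads.Condition()
ready = Ref(false)

waiter = Threads.@spawn lock(cond) do
    while !ready[]       # always re-check the predicate after waking
        wait(cond)       # atomically releases the lock while sleeping
    end
end

lock(cond) do            # the lock must also be held to notify
    ready[] = true
    notify(cond; all=true)
end
wait(waiter)
```
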

function check_worker_state(w::Worker)
Expand All @@ -161,15 +163,16 @@ function check_worker_state(w::Worker)
else
w.ct_time = time()
if myid() > w.id
t = @async exec_conn_func(w)
t = Threads.@spawn Threads.threadpool() exec_conn_func(w)
else
# route request via node 1
t = @async remotecall_fetch((p,to_id) -> remotecall_fetch(exec_conn_func, p, to_id), 1, w.id, myid())
t = Threads.@spawn Threads.threadpool() remotecall_fetch((p,to_id) -> remotecall_fetch(exec_conn_func, p, to_id), 1, w.id, myid())
end
errormonitor(t)
wait_for_conn(w)
end
end
return nothing
end
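
Here and throughout the PR, `@async` becomes `Threads.@spawn Threads.threadpool()`: the task may run on any thread of the caller's pool instead of being pinned to the spawning thread, and `errormonitor` logs a failure that would otherwise vanish with the task. A hedged usage sketch:

```julia
using Base.Threads

t = Threads.@spawn Threads.threadpool() begin
    # runs on whichever thread of the current pool is free
    sleep(0.1)
    42
end
errormonitor(t)   # log the backtrace if the task throws
@assert fetch(t) == 42
```
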

exec_conn_func(id::Int) = exec_conn_func(worker_from_id(id)::Worker)
@@ -191,9 +194,17 @@ function wait_for_conn(w)
timeout = worker_timeout() - (time() - w.ct_time)
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

@async (sleep(timeout); notify(w.c_state; all=true))
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
T = Threads.@spawn Threads.threadpool() begin
sleep($timeout)
lock(w.c_state) do
notify(w.c_state; all=true)
end
end
errormonitor(T)
lock(w.c_state) do
wait(w.c_state)
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
end
end
nothing
end
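
Note the `$timeout` inside `Threads.@spawn`: as with `@async`, `$` interpolates the *current* value into the task's closure at spawn time, guarding against the variable changing before the task runs. A small sketch:

```julia
x = 1
t = Threads.@spawn $x + 1   # $x is snapshotted now, not when the task runs
x = 100
@assert fetch(t) == 2
```
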
@@ -247,7 +258,7 @@ function start_worker(out::IO, cookie::AbstractString=readline(stdin); close_std
else
sock = listen(interface, LPROC.bind_port)
end
errormonitor(@async while isopen(sock)
errormonitor(Threads.@spawn while isopen(sock)
client = accept(sock)
process_messages(client, client, true)
end)
@@ -279,7 +290,7 @@ end


function redirect_worker_output(ident, stream)
t = @async while !eof(stream)
t = Threads.@spawn while !eof(stream)
line = readline(stream)
if startswith(line, " From worker ")
# stdout's of "additional" workers started from an initial worker on a host are not available
@@ -318,7 +329,7 @@ function read_worker_host_port(io::IO)
leader = String[]
try
while ntries > 0
readtask = @async readline(io)
readtask = Threads.@spawn Threads.threadpool() readline(io)
yield()
while !istaskdone(readtask) && ((time_ns() - t0) < timeout)
sleep(0.05)
@@ -419,7 +430,7 @@ if launching workers programmatically, execute `addprocs` in its own task.
```julia
# On busy clusters, call `addprocs` asynchronously
t = @async addprocs(...)
t = Threads.@spawn addprocs(...)
```
```julia
@@ -485,20 +496,23 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
# call manager's `launch` in a separate task. This allows the master
# process to initiate the connection setup process as and when workers come
# online
t_launch = @async launch(manager, params, launched, launch_ntfy)
t_launch = Threads.@spawn Threads.threadpool() launch(manager, params, launched, launch_ntfy)

@sync begin
while true
if isempty(launched)
istaskdone(t_launch) && break
@async (sleep(1); notify(launch_ntfy))
Threads.@spawn Threads.threadpool() begin
sleep(1)
notify(launch_ntfy)
end
wait(launch_ntfy)
end

if !isempty(launched)
wconfig = popfirst!(launched)
let wconfig=wconfig
@async setup_launched_worker(manager, wconfig, launched_q)
Threads.@spawn Threads.threadpool() setup_launched_worker(manager, wconfig, launched_q)
end
end
end
@@ -578,7 +592,7 @@ function launch_n_additional_processes(manager, frompid, fromconfig, cnt, launch
wconfig.port = port

let wconfig=wconfig
@async begin
Threads.@spawn Threads.threadpool() begin
pid = create_worker(manager, wconfig)
remote_do(redirect_output_from_additional_worker, frompid, pid, port)
push!(launched_q, pid)
@@ -645,7 +659,12 @@ function create_worker(manager, wconfig)
# require the value of config.connect_at which is set only upon connection completion
for jw in PGRP.workers
if (jw.id != 1) && (jw.id < w.id)
(jw.state === W_CREATED) && wait(jw.c_state)
# wait for jw to join
if jw.state === W_CREATED
lock(jw.c_state) do
wait(jw.c_state)
end
end
push!(join_list, jw)
end
end
@@ -668,7 +687,12 @@ function create_worker(manager, wconfig)
end

for wl in wlist
(wl.state === W_CREATED) && wait(wl.c_state)
lock(wl.c_state) do
if wl.state === W_CREATED
# wait for wl to join
wait(wl.c_state)
end
end
push!(join_list, wl)
end
end
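
Previously the `wl.state === W_CREATED` test ran outside any lock, so a `notify` could fire between the test and the `wait`, leaving `create_worker` blocked forever; testing under the same lock closes that window. A self-contained sketch of the hazard and the fix (hypothetical names):

```julia
using Base.Threads

cond    = Threads.Condition()
created = Ref(true)

# Racy (old shape): `created[] && wait(...)` checks outside the lock, so a
# notify can slip in between the check and the wait and be lost forever.

# Safe (new shape): check and wait while holding the condition's lock.
waiter = Threads.@spawn lock(cond) do
    if created[]
        wait(cond)
    end
end

lock(cond) do
    created[] = false
    notify(cond; all=true)
end
wait(waiter)
```
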
@@ -727,23 +751,21 @@ function redirect_output_from_additional_worker(pid, port)
end

function check_master_connect()
timeout = worker_timeout() * 1e9
# If we do not have at least process 1 connect to us within timeout
# we log an error and exit, unless we're running on valgrind
if ccall(:jl_running_on_valgrind,Cint,()) != 0
return
end
@async begin
start = time_ns()
while !haskey(map_pid_wrkr, 1) && (time_ns() - start) < timeout
sleep(1.0)
end

if !haskey(map_pid_wrkr, 1)
print(stderr, "Master process (id 1) could not connect within $(timeout/1e9) seconds.\nexiting.\n")
exit(1)
errormonitor(
Threads.@spawn begin
timeout = worker_timeout()
if timedwait(() -> haskey(map_pid_wrkr, 1), timeout) === :timed_out
print(stderr, "Master process (id 1) could not connect within $(timeout) seconds.\nexiting.\n")
exit(1)
end
end
end
)
end
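
The hand-rolled poll loop is replaced by `Base.timedwait`, which polls a zero-argument predicate until it returns `true` or the timeout expires, returning `:ok` or `:timed_out`. A small sketch:

```julia
status = let t0 = time()
    timedwait(() -> time() - t0 > 0.2, 1.0; pollint=0.05)
end
@assert status === :ok   # predicate became true before the 1 s timeout
```
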


@@ -1028,13 +1050,13 @@ function rmprocs(pids...; waitfor=typemax(Int))

pids = vcat(pids...)
if waitfor == 0
t = @async _rmprocs(pids, typemax(Int))
t = Threads.@spawn Threads.threadpool() _rmprocs(pids, typemax(Int))
yield()
return t
else
_rmprocs(pids, waitfor)
# return a dummy task object that user code can wait on.
return @async nothing
return Threads.@spawn Threads.threadpool() nothing
end
end

@@ -1217,7 +1239,7 @@ function interrupt(pids::AbstractVector=workers())
@assert myid() == 1
@sync begin
for pid in pids
@async interrupt(pid)
Threads.@spawn Threads.threadpool() interrupt(pid)
end
end
end
@@ -1288,18 +1310,16 @@ end

using Random: randstring

let inited = false
# do initialization that's only needed when there is more than 1 processor
global function init_multi()
if !inited
inited = true
push!(Base.package_callbacks, _require_callback)
atexit(terminate_all_workers)
init_bind_addr()
cluster_cookie(randstring(HDR_COOKIE_LEN))
end
return nothing
# do initialization that's only needed when there is more than 1 processor
const inited = Threads.Atomic{Bool}(false)
function init_multi()
if !Threads.atomic_cas!(inited, false, true)
push!(Base.package_callbacks, _require_callback)
atexit(terminate_all_workers)
init_bind_addr()
cluster_cookie(randstring(HDR_COOKIE_LEN))
end
return nothing
end
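
The `let`-closure guard becomes a lock-free once-guard: `Threads.atomic_cas!` compares the atomic against `false` and, on a match, stores `true`, returning the *old* value, so exactly one caller observes `false` and runs the initialization. A sketch of the idiom (hypothetical setup function):

```julia
using Base.Threads

const did_init = Atomic{Bool}(false)

function setup_once()
    # atomic_cas! returns the previous value; only the first caller sees false
    if !atomic_cas!(did_init, false, true)
        println("initializing (runs exactly once)")
    end
    return nothing
end

setup_once()   # prints
setup_once()   # no-op
```
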

function init_parallel()
4 changes: 2 additions & 2 deletions src/macros.jl
@@ -230,7 +230,7 @@ function remotecall_eval(m::Module, procs, ex)
# execute locally last as we do not want local execution to block serialization
# of the request to remote nodes.
for _ in 1:run_locally
@async Core.eval(m, ex)
Threads.@spawn Threads.threadpool() Core.eval(m, ex)
end
end
nothing
@@ -275,7 +275,7 @@ function preduce(reducer, f, R)
end

function pfor(f, R)
t = @async @sync for c in splitrange(Int(firstindex(R)), Int(lastindex(R)), nworkers())
t = Threads.@spawn Threads.threadpool() @sync for c in splitrange(Int(firstindex(R)), Int(lastindex(R)), nworkers())
@spawnat :any f(R, first(c), last(c))
end
errormonitor(t)
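
`pfor`'s coordinating task likewise moves off the event loop onto a worker thread; it still fans work out with `@spawnat :any`, which the enclosing `@sync` waits on. A usage sketch, assuming the package is loaded as the stdlib `Distributed` (adjust the module name if building this repo directly):

```julia
using Distributed
addprocs(2)

@everywhere chunk_sum(lo, hi) = sum(lo:hi)

t = Threads.@spawn @sync for (lo, hi) in ((1, 500), (501, 1000))
    @spawnat :any chunk_sum(lo, hi)   # @sync waits for all @spawnat tasks
end
wait(t)
```
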
4 changes: 2 additions & 2 deletions src/managers.jl
@@ -178,7 +178,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
# Wait for all launches to complete.
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
let machine=machine, cnt=cnt
@async try
Threads.@spawn Threads.threadpool() try
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
catch e
print(stderr, "exception launching on machine $(machine) : $(e)\n")
@@ -744,7 +744,7 @@ function kill(manager::LocalManager, pid::Int, config::WorkerConfig; exit_timeou
# First, try sending `exit()` to the remote over the usual control channels
remote_do(exit, pid)

timer_task = @async begin
timer_task = Threads.@spawn Threads.threadpool() begin
sleep(exit_timeout)

# Check to see if our child exited, and if not, send an actual kill signal
2 changes: 1 addition & 1 deletion src/messages.jl
@@ -200,7 +200,7 @@ function flush_gc_msgs()
end
catch e
bt = catch_backtrace()
@async showerror(stderr, e, bt)
Threads.@spawn showerror(stderr, e, bt)
end
end

14 changes: 7 additions & 7 deletions src/process_messages.jl
@@ -85,7 +85,7 @@ function schedule_call(rid, thunk)
rv = RemoteValue(def_rv_channel())
(PGRP::ProcessGroup).refs[rid] = rv
push!(rv.clientset, rid.whence)
errormonitor(@async run_work_thunk(rv, thunk))
errormonitor(Threads.@spawn run_work_thunk(rv, thunk))
return rv
end
end
@@ -118,7 +118,7 @@ end

## message event handlers ##
function process_messages(r_stream::TCPSocket, w_stream::TCPSocket, incoming::Bool=true)
errormonitor(@async process_tcp_streams(r_stream, w_stream, incoming))
errormonitor(Threads.@spawn process_tcp_streams(r_stream, w_stream, incoming))
end

function process_tcp_streams(r_stream::TCPSocket, w_stream::TCPSocket, incoming::Bool)
@@ -148,7 +148,7 @@ Julia version number to perform the authentication handshake.
See also [`cluster_cookie`](@ref).
"""
function process_messages(r_stream::IO, w_stream::IO, incoming::Bool=true)
errormonitor(@async message_handler_loop(r_stream, w_stream, incoming))
errormonitor(Threads.@spawn message_handler_loop(r_stream, w_stream, incoming))
end

function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
@@ -283,7 +283,7 @@ function handle_msg(msg::CallMsg{:call}, header, r_stream, w_stream, version)
schedule_call(header.response_oid, ()->invokelatest(msg.f, msg.args...; msg.kwargs...))
end
function handle_msg(msg::CallMsg{:call_fetch}, header, r_stream, w_stream, version)
errormonitor(@async begin
errormonitor(Threads.@spawn begin
v = run_work_thunk(()->invokelatest(msg.f, msg.args...; msg.kwargs...), false)
if isa(v, SyncTake)
try
@@ -299,15 +299,15 @@ function handle_msg(msg::CallMsg{:call_fetch}, header, r_stream, w_stream, versi
end

function handle_msg(msg::CallWaitMsg, header, r_stream, w_stream, version)
errormonitor(@async begin
errormonitor(Threads.@spawn begin
rv = schedule_call(header.response_oid, ()->invokelatest(msg.f, msg.args...; msg.kwargs...))
deliver_result(w_stream, :call_wait, header.notify_oid, fetch(rv.c))
nothing
end)
end

function handle_msg(msg::RemoteDoMsg, header, r_stream, w_stream, version)
errormonitor(@async run_work_thunk(()->invokelatest(msg.f, msg.args...; msg.kwargs...), true))
errormonitor(Threads.@spawn run_work_thunk(()->invokelatest(msg.f, msg.args...; msg.kwargs...), true))
end

function handle_msg(msg::ResultMsg, header, r_stream, w_stream, version)
@@ -350,7 +350,7 @@ function handle_msg(msg::JoinPGRPMsg, header, r_stream, w_stream, version)
# The constructor registers the object with a global registry.
Worker(rpid, ()->connect_to_peer(cluster_manager, rpid, wconfig))
else
@async connect_to_peer(cluster_manager, rpid, wconfig)
Threads.@spawn connect_to_peer(cluster_manager, rpid, wconfig)
end
end
end
4 changes: 2 additions & 2 deletions src/remotecall.jl
@@ -205,7 +205,7 @@ or to use a local [`Channel`](@ref) as a proxy:
```julia
p = 1
f = Future(p)
errormonitor(@async put!(f, remotecall_fetch(long_computation, p)))
errormonitor(Threads.@spawn put!(f, remotecall_fetch(long_computation, p)))
isready(f) # will not block
```
"""
@@ -322,7 +322,7 @@ function process_worker(rr)
msg = (remoteref_id(rr), myid())

# Needs to acquire a lock on the del_msg queue
T = Threads.@spawn begin
T = Threads.@spawn Threads.threadpool() begin
publish_del_msg!($w, $msg)
end
Base.errormonitor(T)
