@@ -472,20 +472,28 @@ end
472472```
473473"""
474474function addprocs (manager:: ClusterManager ; kwargs... )
475+ params = merge (default_addprocs_params (manager), Dict {Symbol, Any} (kwargs))
476+
475477 init_multi ()
476478
477479 cluster_mgmt_from_master_check ()
478480
479- lock (worker_lock)
480- try
481- addprocs_locked (manager:: ClusterManager ; kwargs... )
482- finally
483- unlock (worker_lock)
484- end
481+ # Call worker-starting callbacks
482+ warning_interval = params[:callback_warning_interval ]
483+ _run_callbacks_concurrently (" worker-starting" , worker_starting_callbacks,
484+ warning_interval, [(manager, params)])
485+
486+ # Add new workers
487+ new_workers = @lock worker_lock addprocs_locked (manager:: ClusterManager , params)
488+
489+ # Call worker-started callbacks
490+ _run_callbacks_concurrently (" worker-started" , worker_started_callbacks,
491+ warning_interval, new_workers)
492+
493+ return new_workers
485494end
486495
487- function addprocs_locked (manager:: ClusterManager ; kwargs... )
488- params = merge (default_addprocs_params (manager), Dict {Symbol,Any} (kwargs))
496+ function addprocs_locked (manager:: ClusterManager , params)
489497 topology (Symbol (params[:topology ]))
490498
491499 if PGRP. topology != = :all_to_all
@@ -572,7 +580,8 @@ default_addprocs_params() = Dict{Symbol,Any}(
572580 :exeflags => ` ` ,
573581 :env => [],
574582 :enable_threaded_blas => false ,
575- :lazy => true )
583+ :lazy => true ,
584+ :callback_warning_interval => 10 )
576585
577586
578587function setup_launched_worker (manager, wconfig, launched_q)
@@ -870,13 +879,151 @@ const HDR_COOKIE_LEN=16
870879const map_pid_wrkr = Dict {Int, Union{Worker, LocalProcess}} ()
871880const map_sock_wrkr = IdDict ()
872881const map_del_wrkr = Set {Int} ()
882+ const worker_starting_callbacks = Dict {Any, Base.Callable} ()
883+ const worker_started_callbacks = Dict {Any, Base.Callable} ()
884+ const worker_exiting_callbacks = Dict {Any, Base.Callable} ()
885+ const worker_exited_callbacks = Dict {Any, Base.Callable} ()
873886
874887# whether process is a master or worker in a distributed setup
875888myrole () = LPROCROLE[]
876889function myrole! (proctype:: Symbol )
877890 LPROCROLE[] = proctype
878891end
879892
893+ # Callbacks
894+
895+ function _run_callbacks_concurrently (callbacks_name, callbacks_dict, warning_interval, arglist)
896+ callback_tasks = Dict {Any, Task} ()
897+ for args in arglist
898+ for (name, callback) in callbacks_dict
899+ callback_tasks[name] = Threads. @spawn callback (args... )
900+ end
901+ end
902+
903+ running_callbacks = () -> [" '$(key) '" for (key, task) in callback_tasks if ! istaskdone (task)]
904+ while timedwait (() -> isempty (running_callbacks ()), warning_interval) === :timed_out
905+ callbacks_str = join (running_callbacks (), " , " )
906+ @warn " Waiting for these $(callbacks_name) callbacks to finish: $(callbacks_str) "
907+ end
908+
909+ # Wait on the tasks so that exceptions bubble up
910+ wait .(values (callback_tasks))
911+ end
912+
913+ function _add_callback (f, key, dict; arg_types= Tuple{Int})
914+ desired_signature = " f(" * join ([" ::$(t) " for t in arg_types. types], " , " ) * " )"
915+
916+ if ! hasmethod (f, arg_types)
917+ throw (ArgumentError (" Callback function is invalid, it must be able to be called with these argument types: $(desired_signature) " ))
918+ elseif haskey (dict, key)
919+ throw (ArgumentError (" A callback function with key '$(key) ' already exists" ))
920+ end
921+
922+ if isnothing (key)
923+ key = Symbol (gensym (), nameof (f))
924+ end
925+
926+ dict[key] = f
927+ return key
928+ end
929+
930+ _remove_callback (key, dict) = delete! (dict, key)
931+
932+ """
933+ add_worker_starting_callback(f::Base.Callable; key=nothing)
934+
935+ Register a callback to be called on the master process immediately before new
936+ workers are started. The callback `f` will be called with the `ClusterManager`
937+ instance that is being used and a dictionary of parameters related to adding
938+ workers, i.e. `f(manager, params)`. The `params` dictionary is specific to the
939+ `manager` type. Note that the `LocalManager` and `SSHManager` cluster managers
940+ in DistributedNext are not fully documented yet, see the
941+ [managers.jl](https://github.com/JuliaParallel/DistributedNext.jl/blob/master/src/managers.jl)
942+ file for their definitions.
943+
944+ !!! warning
945+ Adding workers can fail so it is not guaranteed that the workers requested
946+ will exist.
947+
948+ The worker-starting callbacks will be executed concurrently. If one throws an
949+ exception it will not be caught and will bubble up through [`addprocs`](@ref).
950+
951+ Keep in mind that the callbacks will add to the time taken to launch workers; so
952+ try to either keep the callbacks fast to execute, or do the actual work
953+ asynchronously by spawning a task in the callback (beware of race conditions if
954+ you do this).
955+ """
956+ add_worker_starting_callback (f:: Base.Callable ; key= nothing ) = _add_callback (f, key, worker_starting_callbacks;
957+ arg_types= Tuple{ClusterManager, Dict})
958+
959+ remove_worker_starting_callback (key) = _remove_callback (key, worker_starting_callbacks)
960+
961+ """
962+ add_worker_started_callback(f::Base.Callable; key=nothing)
963+
964+ Register a callback to be called on the master process whenever a worker is
965+ added. The callback will be called with the added worker ID,
966+ e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
967+ not specified.
968+
969+ The worker-started callbacks will be executed concurrently. If one throws an
970+ exception it will not be caught and will bubble up through [`addprocs()`](@ref).
971+
972+ Keep in mind that the callbacks will add to the time taken to launch workers; so
973+ try to either keep the callbacks fast to execute, or do the actual
974+ initialization asynchronously by spawning a task in the callback (beware of race
975+ conditions if you do this).
976+ """
977+ add_worker_started_callback (f:: Base.Callable ; key= nothing ) = _add_callback (f, key, worker_started_callbacks)
978+
979+ """
980+ remove_worker_started_callback(key)
981+
982+ Remove the callback for `key` that was added with [`add_worker_started_callback()`](@ref).
983+ """
984+ remove_worker_started_callback (key) = _remove_callback (key, worker_started_callbacks)
985+
986+ """
987+ add_worker_exiting_callback(f::Base.Callable; key=nothing)
988+
989+ Register a callback to be called on the master process immediately before a
990+ worker is removed with [`rmprocs()`](@ref). The callback will be called with the
991+ worker ID, e.g. `f(w::Int)`. Chooses and returns a unique key for the callback
992+ if `key` is not specified.
993+
994+ All worker-exiting callbacks will be executed concurrently and if they don't
995+ all finish before the `callback_timeout` passed to `rmprocs()` then the process
996+ will be removed anyway.
997+ """
998+ add_worker_exiting_callback (f:: Base.Callable ; key= nothing ) = _add_callback (f, key, worker_exiting_callbacks)
999+
1000+ """
1001+ remove_worker_exiting_callback(key)
1002+
1003+ Remove the callback for `key` that was added with [`add_worker_exiting_callback()`](@ref).
1004+ """
1005+ remove_worker_exiting_callback (key) = _remove_callback (key, worker_exiting_callbacks)
1006+
1007+ """
1008+ add_worker_exited_callback(f::Base.Callable; key=nothing)
1009+
1010+ Register a callback to be called on the master process when a worker has exited
1011+ for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
1012+ segfaulting etc). The callback will be called with the worker ID,
1013+ e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
1014+ not specified.
1015+
1016+ If the callback throws an exception it will be caught and printed.
1017+ """
1018+ add_worker_exited_callback (f:: Base.Callable ; key= nothing ) = _add_callback (f, key, worker_exited_callbacks)
1019+
1020+ """
1021+ remove_worker_exited_callback(key)
1022+
1023+ Remove the callback for `key` that was added with [`add_worker_exited_callback()`](@ref).
1024+ """
1025+ remove_worker_exited_callback (key) = _remove_callback (key, worker_exited_callbacks)
1026+
8801027# cluster management related API
8811028"""
8821029 myid()
@@ -1063,7 +1210,7 @@ function cluster_mgmt_from_master_check()
10631210end
10641211
10651212"""
1066- rmprocs(pids...; waitfor=typemax(Int))
1213+ rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10 )
10671214
10681215Remove the specified workers. Note that only process 1 can add or remove
10691216workers.
@@ -1077,6 +1224,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10771224 returned. The user should call [`wait`](@ref) on the task before invoking any other
10781225 parallel calls.
10791226
1227+ The `callback_timeout` specifies how long to wait for any callbacks to execute
1228+ before continuing to remove the workers (see
1229+ [`add_worker_exiting_callback()`](@ref)).
1230+
10801231# Examples
10811232```julia-repl
10821233\$ julia -p 5
@@ -1093,24 +1244,38 @@ julia> workers()
10931244 6
10941245```
10951246"""
1096- function rmprocs (pids... ; waitfor= typemax (Int))
1247+ function rmprocs (pids... ; waitfor= typemax (Int), callback_timeout = 10 )
10971248 cluster_mgmt_from_master_check ()
10981249
10991250 pids = vcat (pids... )
11001251 if waitfor == 0
1101- t = @async _rmprocs (pids, typemax (Int))
1252+ t = @async _rmprocs (pids, typemax (Int), callback_timeout )
11021253 yield ()
11031254 return t
11041255 else
1105- _rmprocs (pids, waitfor)
1256+ _rmprocs (pids, waitfor, callback_timeout )
11061257 # return a dummy task object that user code can wait on.
11071258 return @async nothing
11081259 end
11091260end
11101261
1111- function _rmprocs (pids, waitfor)
1262+ function _rmprocs (pids, waitfor, callback_timeout )
11121263 lock (worker_lock)
11131264 try
1265+ # Run the callbacks
1266+ callback_tasks = Dict {Any, Task} ()
1267+ for pid in pids
1268+ for (name, callback) in worker_exiting_callbacks
1269+ callback_tasks[name] = Threads. @spawn callback (pid)
1270+ end
1271+ end
1272+
1273+ if timedwait (() -> all (istaskdone .(values (callback_tasks))), callback_timeout) === :timed_out
1274+ timedout_callbacks = [" '$(key) '" for (key, task) in callback_tasks if ! istaskdone (task)]
1275+ callbacks_str = join (timedout_callbacks, " , " )
1276+ @warn " Some worker-exiting callbacks have not yet finished, continuing to remove workers anyway. These are the callbacks still running: $(callbacks_str) "
1277+ end
1278+
11141279 rmprocset = Union{LocalProcess, Worker}[]
11151280 for p in pids
11161281 if p == 1
@@ -1256,6 +1421,18 @@ function deregister_worker(pg, pid)
12561421 delete! (pg. refs, id)
12571422 end
12581423 end
1424+
1425+ # Call callbacks on the master
1426+ if myid () == 1
1427+ for (name, callback) in worker_exited_callbacks
1428+ try
1429+ callback (pid)
1430+ catch ex
1431+ @error " Error when running worker-exited callback '$(name) '" exception= (ex, catch_backtrace ())
1432+ end
1433+ end
1434+ end
1435+
12591436 return
12601437end
12611438
0 commit comments