@@ -472,20 +472,28 @@ end
472472```
473473"""
474474function addprocs (manager:: ClusterManager ; kwargs... )
475+ params = merge (default_addprocs_params (manager), Dict {Symbol, Any} (kwargs))
476+
475477 init_multi ()
476478
477479 cluster_mgmt_from_master_check ()
478480
479- lock (worker_lock)
480- try
481- addprocs_locked (manager:: ClusterManager ; kwargs... )
482- finally
483- unlock (worker_lock)
484- end
481+ # Call worker-starting callbacks
482+ warning_interval = params[:callback_warning_interval ]
483+ _run_callbacks_concurrently (" worker-starting" , worker_starting_callbacks,
484+ warning_interval, [(manager, params)])
485+
486+ # Add new workers
487+ new_workers = @lock worker_lock addprocs_locked (manager:: ClusterManager , params)
488+
489+ # Call worker-started callbacks
490+ _run_callbacks_concurrently (" worker-started" , worker_started_callbacks,
491+ warning_interval, new_workers)
492+
493+ return new_workers
485494end
486495
487- function addprocs_locked (manager:: ClusterManager ; kwargs... )
488- params = merge (default_addprocs_params (manager), Dict {Symbol,Any} (kwargs))
496+ function addprocs_locked (manager:: ClusterManager , params)
489497 topology (Symbol (params[:topology ]))
490498
491499 if PGRP. topology != = :all_to_all
@@ -572,7 +580,8 @@ default_addprocs_params() = Dict{Symbol,Any}(
572580 :exeflags => ` ` ,
573581 :env => [],
574582 :enable_threaded_blas => false ,
575- :lazy => true )
583+ :lazy => true ,
584+ :callback_warning_interval => 10 )
576585
577586
578587function setup_launched_worker (manager, wconfig, launched_q)
@@ -870,13 +879,160 @@ const HDR_COOKIE_LEN=16
870879const map_pid_wrkr = Dict {Int, Union{Worker, LocalProcess}} ()
871880const map_sock_wrkr = IdDict ()
872881const map_del_wrkr = Set {Int} ()
882+ const worker_starting_callbacks = Dict {Any, Base.Callable} ()
883+ const worker_started_callbacks = Dict {Any, Base.Callable} ()
884+ const worker_exiting_callbacks = Dict {Any, Base.Callable} ()
885+ const worker_exited_callbacks = Dict {Any, Base.Callable} ()
873886
874887# whether process is a master or worker in a distributed setup
875888myrole () = LPROCROLE[]
876889function myrole! (proctype:: Symbol )
877890 LPROCROLE[] = proctype
878891end
879892
893+ # Callbacks
894+
895+ function _run_callbacks_concurrently (callbacks_name, callbacks_dict, warning_interval, arglist)
896+ callback_tasks = Dict {Any, Task} ()
897+ for args in arglist
898+ for (name, callback) in callbacks_dict
899+ callback_tasks[name] = Threads. @spawn callback (args... )
900+ end
901+ end
902+
903+ running_callbacks = () -> [" '$(key) '" for (key, task) in callback_tasks if ! istaskdone (task)]
904+ while timedwait (() -> isempty (running_callbacks ()), warning_interval) === :timed_out
905+ callbacks_str = join (running_callbacks (), " , " )
906+ @warn " Waiting for these $(callbacks_name) callbacks to finish: $(callbacks_str) "
907+ end
908+
909+ # Wait on the tasks so that exceptions bubble up
910+ wait .(values (callback_tasks))
911+ end
912+
913+ function _add_callback (f, key, dict; arg_types= Tuple{Int})
914+ desired_signature = " f(" * join ([" ::$(t) " for t in arg_types. types], " , " ) * " )"
915+
916+ if ! hasmethod (f, arg_types)
917+ throw (ArgumentError (" Callback function is invalid, it must be able to be called with these argument types: $(desired_signature) " ))
918+ elseif haskey (dict, key)
919+ throw (ArgumentError (" A callback function with key '$(key) ' already exists" ))
920+ end
921+
922+ if isnothing (key)
923+ key = Symbol (gensym (), nameof (f))
924+ end
925+
926+ dict[key] = f
927+ return key
928+ end
929+
930+ _remove_callback (key, dict) = delete! (dict, key)
931+
932+ """
933+ add_worker_starting_callback(f::Base.Callable; key=nothing)
934+
935+ Register a callback to be called on the master process immediately before new
936+ workers are started. The callback `f` will be called with the `ClusterManager`
937+ instance that is being used and a dictionary of parameters related to adding
938+ workers, i.e. `f(manager, params)`. The `params` dictionary is specific to the
939+ `manager` type. Note that the `LocalManager` and `SSHManager` cluster managers
940+ in DistributedNext are not fully documented yet, see the
941+ [managers.jl](https://github.com/JuliaParallel/DistributedNext.jl/blob/master/src/managers.jl)
942+ file for their definitions.
943+
944+ !!! warning
945+ Adding workers can fail so it is not guaranteed that the workers requested
946+ will exist.
947+
948+ The worker-starting callbacks will be executed concurrently. If one throws an
949+ exception it will not be caught and will bubble up through [`addprocs`](@ref).
950+
951+ Keep in mind that the callbacks will add to the time taken to launch workers; so
952+ try to either keep the callbacks fast to execute, or do the actual work
953+ asynchronously by spawning a task in the callback (beware of race conditions if
954+ you do this).
955+ """
956+ add_worker_starting_callback (f:: Base.Callable ; key= nothing ) = _add_callback (f, key, worker_starting_callbacks;
957+ arg_types= Tuple{ClusterManager, Dict})
958+ """
959+ remove_worker_starting_callback(key)
960+
961+ Remove the callback for `key` that was added with [`add_worker_starting_callback()`](@ref).
962+ """
963+ remove_worker_starting_callback (key) = _remove_callback (key, worker_starting_callbacks)
964+
965+ """
966+ add_worker_started_callback(f::Base.Callable; key=nothing)
967+
968+ Register a callback to be called on the master process whenever a worker is
969+ added. The callback will be called with the added worker ID,
970+ e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
971+ not specified.
972+
973+ The worker-started callbacks will be executed concurrently. If one throws an
974+ exception it will not be caught and will bubble up through [`addprocs()`](@ref).
975+
976+ Keep in mind that the callbacks will add to the time taken to launch workers; so
977+ try to either keep the callbacks fast to execute, or do the actual
978+ initialization asynchronously by spawning a task in the callback (beware of race
979+ conditions if you do this).
980+ """
981+ add_worker_started_callback (f:: Base.Callable ; key= nothing ) = _add_callback (f, key, worker_started_callbacks)
982+
983+ """
984+ remove_worker_started_callback(key)
985+
986+ Remove the callback for `key` that was added with [`add_worker_started_callback()`](@ref).
987+ """
988+ remove_worker_started_callback (key) = _remove_callback (key, worker_started_callbacks)
989+
990+ """
991+ add_worker_exiting_callback(f::Base.Callable; key=nothing)
992+
993+ Register a callback to be called on the master process immediately before a
994+ worker is removed with [`rmprocs()`](@ref). The callback will be called with the
995+ worker ID, e.g. `f(w::Int)`. Chooses and returns a unique key for the callback
996+ if `key` is not specified.
997+
998+ All worker-exiting callbacks will be executed concurrently and if they don't
999+ all finish before the `callback_timeout` passed to `rmprocs()` then the process
1000+ will be removed anyway.
1001+ """
1002+ add_worker_exiting_callback (f:: Base.Callable ; key= nothing ) = _add_callback (f, key, worker_exiting_callbacks)
1003+
1004+ """
1005+ remove_worker_exiting_callback(key)
1006+
1007+ Remove the callback for `key` that was added with [`add_worker_exiting_callback()`](@ref).
1008+ """
1009+ remove_worker_exiting_callback (key) = _remove_callback (key, worker_exiting_callbacks)
1010+
1011+ """
1012+ add_worker_exited_callback(f::Base.Callable; key=nothing)
1013+
1014+ Register a callback to be called on the master process when a worker has exited
1015+ for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
1016+ segfaulting etc). Chooses and returns a unique key for the callback if `key` is
1017+ not specified.
1018+
1019+ The callback will be called with the worker ID and the final
1020+ `Distributed.WorkerState` of the worker, e.g. `f(w::Int, state)`. `state` is an
1021+ enum, a value of `WorkerState_terminated` means a graceful exit and a value of
1022+ `WorkerState_exterminated` means the worker died unexpectedly.
1023+
1024+ If the callback throws an exception it will be caught and printed.
1025+ """
1026+ add_worker_exited_callback (f:: Base.Callable ; key= nothing ) = _add_callback (f, key, worker_exited_callbacks;
1027+ arg_types= Tuple{Int, WorkerState})
1028+
1029+ """
1030+ remove_worker_exited_callback(key)
1031+
1032+ Remove the callback for `key` that was added with [`add_worker_exited_callback()`](@ref).
1033+ """
1034+ remove_worker_exited_callback (key) = _remove_callback (key, worker_exited_callbacks)
1035+
8801036# cluster management related API
8811037"""
8821038 myid()
@@ -1063,7 +1219,7 @@ function cluster_mgmt_from_master_check()
10631219end
10641220
10651221"""
1066- rmprocs(pids...; waitfor=typemax(Int))
1222+ rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10 )
10671223
10681224Remove the specified workers. Note that only process 1 can add or remove
10691225workers.
@@ -1077,6 +1233,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10771233 returned. The user should call [`wait`](@ref) on the task before invoking any other
10781234 parallel calls.
10791235
1236+ The `callback_timeout` specifies how long to wait for any callbacks to execute
1237+ before continuing to remove the workers (see
1238+ [`add_worker_exiting_callback()`](@ref)).
1239+
10801240# Examples
10811241```julia-repl
10821242\$ julia -p 5
@@ -1093,24 +1253,38 @@ julia> workers()
10931253 6
10941254```
10951255"""
1096- function rmprocs (pids... ; waitfor= typemax (Int))
1256+ function rmprocs (pids... ; waitfor= typemax (Int), callback_timeout = 10 )
10971257 cluster_mgmt_from_master_check ()
10981258
10991259 pids = vcat (pids... )
11001260 if waitfor == 0
1101- t = @async _rmprocs (pids, typemax (Int))
1261+ t = @async _rmprocs (pids, typemax (Int), callback_timeout )
11021262 yield ()
11031263 return t
11041264 else
1105- _rmprocs (pids, waitfor)
1265+ _rmprocs (pids, waitfor, callback_timeout )
11061266 # return a dummy task object that user code can wait on.
11071267 return @async nothing
11081268 end
11091269end
11101270
1111- function _rmprocs (pids, waitfor)
1271+ function _rmprocs (pids, waitfor, callback_timeout )
11121272 lock (worker_lock)
11131273 try
1274+ # Run the callbacks
1275+ callback_tasks = Dict {Any, Task} ()
1276+ for pid in pids
1277+ for (name, callback) in worker_exiting_callbacks
1278+ callback_tasks[name] = Threads. @spawn callback (pid)
1279+ end
1280+ end
1281+
1282+ if timedwait (() -> all (istaskdone .(values (callback_tasks))), callback_timeout) === :timed_out
1283+ timedout_callbacks = [" '$(key) '" for (key, task) in callback_tasks if ! istaskdone (task)]
1284+ callbacks_str = join (timedout_callbacks, " , " )
1285+ @warn " Some worker-exiting callbacks have not yet finished, continuing to remove workers anyway. These are the callbacks still running: $(callbacks_str) "
1286+ end
1287+
11141288 rmprocset = Union{LocalProcess, Worker}[]
11151289 for p in pids
11161290 if p == 1
@@ -1256,6 +1430,18 @@ function deregister_worker(pg, pid)
12561430 delete! (pg. refs, id)
12571431 end
12581432 end
1433+
1434+ # Call callbacks on the master
1435+ if myid () == 1
1436+ for (name, callback) in worker_exited_callbacks
1437+ try
1438+ callback (pid, w. state)
1439+ catch ex
1440+ @error " Error when running worker-exited callback '$(name) '" exception= (ex, catch_backtrace ())
1441+ end
1442+ end
1443+ end
1444+
12591445 return
12601446end
12611447
0 commit comments