1616from IPython import get_ipython
1717from traitlets import Any , Bool , CFloat , Dict , HasTraits , Instance , Integer , List , Set
1818
19+ import ipyparallel as ipp
1920from ipyparallel import util
2021from ipyparallel .controller .dependency import Dependency , dependent
2122
@@ -767,7 +768,7 @@ def scatter(
767768 mapObject = Map .dists [dist ]()
768769 nparts = len (targets )
769770 futures = []
770- trackers = []
771+ _lengths = []
771772 for index , engineid in enumerate (targets ):
772773 partition = mapObject .getPartition (seq , index , nparts )
773774 if flatten and len (partition ) == 1 :
@@ -777,10 +778,12 @@ def scatter(
777778 r = self .push (ns , block = False , track = track , targets = engineid )
778779 r .owner = False
779780 futures .extend (r ._children )
781+ _lengths .append (len (partition ))
780782
781783 r = AsyncResult (
782784 self .client , futures , fname = 'scatter' , targets = targets , owner = True
783785 )
786+ r ._scatter_lengths = _lengths
784787 if block :
785788 r .wait ()
786789 else :
@@ -930,7 +933,6 @@ def _really_apply(
930933 track = self .track if track is None else track
931934 targets = self .targets if targets is None else targets
932935 idents , _targets = self .client ._build_targets (targets )
933- futures = []
934936
935937 pf = PrePickled (f )
936938 pargs = [PrePickled (arg ) for arg in args ]
@@ -1014,8 +1016,113 @@ def make_asyncresult(message_future):
10141016 pass
10151017 return ar
10161018
1017- def map (self , f , * sequences , ** kwargs ):
1018- raise NotImplementedError ("BroadcastView.map not yet implemented" )
@staticmethod
def _broadcast_map(f, *sequence_names):
    """Apply ``f`` across sequences that an earlier scatter left in the user namespace.

    This is the callable that ``map`` submits through ``apply``. It is the
    counterpart of a plain ``map`` call, except that the input data is
    not passed as arguments: a separate scatter step has already stored
    each chunk under a temporary variable name.

    Steps:

    - look up every name in ``sequence_names`` in the IPython user namespace
    - remove those temporary variables as they are read (``pop``)
    - return ``list(map(f, *sequences))`` over the retrieved chunks

    Parameters
    ----------
    f : callable
        function to call once per group of aligned items
    *sequence_names : str
        names of the user-namespace variables holding the scattered chunks

    Returns
    -------
    list
        the results of calling ``f`` on the aligned items of the chunks
    """
    shell = get_ipython()
    # pop (rather than read) so the temporary scatter variables do not linger
    local_chunks = [shell.user_ns.pop(name) for name in sequence_names]
    return list(map(f, *local_chunks))
1036+
@_not_coalescing
def map(self, f, *sequences, block=None, track=False, return_exceptions=False):
    """Parallel version of builtin `map`, using this View's `targets`.

    There will be one task per engine, so work will be chunked
    if the sequences are longer than `targets`.

    Results can be iterated as they are ready, but will become available in chunks.

    .. note::

        BroadcastView does not yet have a fully native map implementation.
        In particular, the scatter step is still one message per engine,
        identical to DirectView,
        and typically slower due to the more complex scheduler.

        It is more efficient to partition inputs via other means (e.g. SPMD based on rank & size)
        and use `apply` to submit all tasks in one broadcast.

    .. versionadded:: 8.8

    Parameters
    ----------
    f : callable
        function to be mapped
    *sequences : one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool [default self.block]
        whether to wait for the result or not
    track : bool [default False]
        Track underlying zmq send to indicate when it is safe to modify memory.
        Only for zero-copy sends such as numpy arrays that are going to be modified in-place.
        Passing ``None`` explicitly falls back to ``self.track``.
    return_exceptions : bool [default False]
        Return remote Exceptions in the result sequence instead of raising them.

    Returns
    -------
    If block=False
        An :class:`~ipyparallel.client.asyncresult.AsyncMapResult` instance.
        An object like AsyncResult, but which reassembles the sequence of results
        into a single list. AsyncMapResults can be iterated through before all
        results are complete.
    else
        A list, the result of ``map(f,*sequences)``

    Raises
    ------
    TypeError
        If no sequences are given (as with builtin ``map``).
    """
    if not sequences:
        # Without this guard, the chunk sizes below are never assigned and
        # the call fails later with an unrelated UnboundLocalError.
        raise TypeError("map() requires at least one sequence")

    if block is None:
        block = self.block
    if track is None:
        track = self.track

    # unique identifier, since we're living in the interactive namespace
    map_key = secrets.token_hex(5)
    dist = 'b'
    map_object = Map.dists[dist]()

    seq_names = []
    scatter_chunk_sizes = []
    for i, seq in enumerate(sequences):
        # the name must be a valid identifier: it becomes a variable
        # in the user namespace where the scattered chunk is stored
        seq_name = f"_seq_{map_key}_{i}"
        seq_names.append(seq_name)
        try:
            len(seq)
        except Exception:
            # cast length-less iterables (e.g. generators) to list
            # so they can be partitioned by scatter
            seq = list(seq)

        scattered = self.scatter(
            seq_name, seq, dist=dist, block=False, track=track
        )
        # Sequences are documented to have matching length, so every
        # scatter yields the same per-engine chunk sizes; the sizes from
        # the last scatter are the ones used below.
        scatter_chunk_sizes = scattered._scatter_lengths

    # submit the map tasks as an actual broadcast
    ar = self.apply(self._broadcast_map, f, *seq_names)
    # the AsyncMapResult created below wraps these same futures
    ar.owner = False
    # re-wrap messages in an AsyncMapResult to get map API
    # this is where the 'gather' reconstruction happens
    amr = ipp.AsyncMapResult(
        self.client,
        ar._children,
        map_object,
        fname=getname(f),
        return_exceptions=return_exceptions,
        chunk_sizes={
            future.msg_id: chunk_size
            for future, chunk_size in zip(ar._children, scatter_chunk_sizes)
        },
    )

    if block:
        return amr.get()
    return amr
10191126
10201127 # scatter/gather cannot be coalescing yet
10211128 scatter = _not_coalescing (DirectView .scatter )
0 commit comments