@@ -305,12 +305,13 @@ def valueFromFunc(k, v): # type: (Any, Any) -> Any
305305 # https://github.com/python/mypy/issues/797
306306 ** kwargs )
307307 elif method == "flat_crossproduct" :
308- jobs = flat_crossproduct_scatter (step , inputobj ,
309- scatter ,
310- cast (Callable [[Any ], Any ],
308+ jobs = cast (Generator ,
309+ flat_crossproduct_scatter (step , inputobj ,
310+ scatter ,
311+ cast (Callable [[Any ], Any ],
311312 # known bug in mypy
312313 # https://github.com/python/mypy/issues/797
313- callback ), 0 , ** kwargs )
314+ callback ), 0 , ** kwargs ) )
314315 else :
315316 _logger .debug (u"[job %s] job input %s" , step .name , json .dumps (inputobj , indent = 4 ))
316317 inputobj = postScatterEval (inputobj )
@@ -332,7 +333,7 @@ def run(self, **kwargs):
332333 _logger .debug (u"[%s] workflow starting" , self .name )
333334
334335 def job (self , joborder , output_callback , ** kwargs ):
335- # type: (Dict[Text, Any], Callable[[Any, Any], Any], **Any) -> Generator[WorkflowJob, None, None]
336+ # type: (Dict[Text, Any], Callable[[Any, Any], Any], **Any) -> Generator
336337 self .state = {}
337338 self .processStatus = "success"
338339
@@ -405,7 +406,7 @@ def __init__(self, toolpath_object, **kwargs):
405406 # TODO: statically validate data links instead of doing it at runtime.
406407
407408 def job (self , joborder , output_callback , ** kwargs ):
408- # type: (Dict[Text, Text], Callable[[Any, Any], Any], **Any) -> Generator[WorkflowJob, None, None]
409+ # type: (Dict[Text, Text], Callable[[Any, Any], Any], **Any) -> Generator
409410 builder = self ._init_job (joborder , ** kwargs )
410411 wj = WorkflowJob (self , ** kwargs )
411412 yield wj
@@ -577,9 +578,25 @@ def setTotal(self, total): # type: (int) -> None
577578 if self .completed == self .total :
578579 self .output_callback (self .dest , self .processStatus )
579580
581+ def parallel_steps (steps , rc , kwargs ): # type: (List[Generator], ReceiveScatterOutput, Dict[str, Any]) -> Generator
582+ while rc .completed < rc .total :
583+ made_progress = False
584+ for step in steps :
585+ if kwargs .get ("on_error" , "stop" ) == "stop" and rc .processStatus != "success" :
586+ break
587+ for j in step :
588+ if kwargs .get ("on_error" , "stop" ) == "stop" and rc .processStatus != "success" :
589+ break
590+ if j :
591+ made_progress = True
592+ yield j
593+ else :
594+ break
595+ if not made_progress and rc .completed < rc .total :
596+ yield None
580597
581598def dotproduct_scatter (process , joborder , scatter_keys , output_callback , ** kwargs ):
582- # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator[WorkflowJob, None, None]
599+ # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator
583600 l = None
584601 for s in scatter_keys :
585602 if l is None :
@@ -593,21 +610,23 @@ def dotproduct_scatter(process, joborder, scatter_keys, output_callback, **kwarg
593610
594611 rc = ReceiveScatterOutput (output_callback , output )
595612
613+ steps = []
596614 for n in range (0 , l ):
597615 jo = copy .copy (joborder )
598616 for s in scatter_keys :
599617 jo [s ] = joborder [s ][n ]
600618
601619 jo = kwargs ["postScatterEval" ](jo )
602620
603- for j in process .job (jo , functools .partial (rc .receive_scatter_output , n ), ** kwargs ):
604- yield j
621+ steps .append (process .job (jo , functools .partial (rc .receive_scatter_output , n ), ** kwargs ))
605622
606623 rc .setTotal (l )
607624
625+ return parallel_steps (steps , rc , kwargs )
626+
608627
609628def nested_crossproduct_scatter (process , joborder , scatter_keys , output_callback , ** kwargs ):
610- # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator[WorkflowJob, None, None]
629+ # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator
611630 scatter_key = scatter_keys [0 ]
612631 l = len (joborder [scatter_key ])
613632 output = {} # type: Dict[Text,List[Text]]
@@ -616,25 +635,24 @@ def nested_crossproduct_scatter(process, joborder, scatter_keys, output_callback
616635
617636 rc = ReceiveScatterOutput (output_callback , output )
618637
638+ steps = []
619639 for n in range (0 , l ):
620640 jo = copy .copy (joborder )
621641 jo [scatter_key ] = joborder [scatter_key ][n ]
622642
623643 if len (scatter_keys ) == 1 :
624644 jo = kwargs ["postScatterEval" ](jo )
625- for j in process .job (jo , functools .partial (rc .receive_scatter_output , n ), ** kwargs ):
626- yield j
645+ steps .append (process .job (jo , functools .partial (rc .receive_scatter_output , n ), ** kwargs ))
627646 else :
628- for j in nested_crossproduct_scatter (process , jo ,
647+ steps . append ( nested_crossproduct_scatter (process , jo ,
629648 scatter_keys [1 :], cast ( # known bug with mypy
630- # https://github.com/python/mypy/issues/797
649+ # https://github.com/python/mypy/issues/797g
631650 Callable [[Any ], Any ],
632- functools .partial (rc .receive_scatter_output , n )),
633- ** kwargs ):
634- yield j
651+ functools .partial (rc .receive_scatter_output , n )), ** kwargs ))
635652
636653 rc .setTotal (l )
637654
655+ return parallel_steps (steps , rc , kwargs )
638656
639657def crossproduct_size (joborder , scatter_keys ):
640658 # type: (Dict[Text, Any], List[Text]) -> int
@@ -650,7 +668,7 @@ def crossproduct_size(joborder, scatter_keys):
650668 return sum
651669
652670def flat_crossproduct_scatter (process , joborder , scatter_keys , output_callback , startindex , ** kwargs ):
653- # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Union[ReceiveScatterOutput,Callable[..., Any]], int, **Any) -> Generator[WorkflowJob, None, None ]
671+ # type: (WorkflowJobStep, Dict[Text, Any], List[Text], Union[ReceiveScatterOutput,Callable[..., Any]], int, **Any) -> Union[List[Generator], Generator ]
654672 scatter_key = scatter_keys [0 ]
655673 l = len (joborder [scatter_key ])
656674 rc = None # type: ReceiveScatterOutput
@@ -665,20 +683,23 @@ def flat_crossproduct_scatter(process, joborder, scatter_keys, output_callback,
665683 else :
666684 raise Exception ("Unhandled code path. Please report this." )
667685
686+ steps = []
668687 put = startindex
669688 for n in range (0 , l ):
670689 jo = copy .copy (joborder )
671690 jo [scatter_key ] = joborder [scatter_key ][n ]
672691
673692 if len (scatter_keys ) == 1 :
674693 jo = kwargs ["postScatterEval" ](jo )
675- for j in process .job (jo , functools .partial (rc .receive_scatter_output , put ), ** kwargs ):
676- yield j
694+ steps .append (process .job (jo , functools .partial (rc .receive_scatter_output , put ), ** kwargs ))
677695 put += 1
678696 else :
679- for j in flat_crossproduct_scatter (process , jo , scatter_keys [1 :], rc , put , ** kwargs ):
680- if j :
681- put += 1
682- yield j
697+ add = flat_crossproduct_scatter (process , jo , scatter_keys [1 :], rc , put , ** kwargs )
698+ put += len (cast (List [Generator ], add ))
699+ steps .extend (add )
683700
684- rc .setTotal (put )
701+ if startindex == 0 and not isinstance (output_callback , ReceiveScatterOutput ):
702+ rc .setTotal (put )
703+ return parallel_steps (steps , rc , kwargs )
704+ else :
705+ return steps
0 commit comments