11import logging
22import pickle
33import sys
4- from dataclasses import dataclass
5- from typing import Any , Callable , Protocol , TypeAlias
4+ from dataclasses import dataclass , field
5+ from typing import Any , Awaitable , Callable , Protocol , TypeAlias
66
7- from dispatch .coroutine import Gather
7+ from dispatch .coroutine import AllDirective , AnyDirective , AnyException , RaceDirective
88from dispatch .error import IncompatibleStateError
99from dispatch .experimental .durable .function import DurableCoroutine , DurableGenerator
1010from dispatch .proto import Call , Error , Input , Output
@@ -73,17 +73,18 @@ def error(self) -> Exception | None:
7373 return self .first_error
7474
7575 def value (self ) -> Any :
76+ assert self .first_error is None
7677 assert self .result is not None
7778 return self .result .value
7879
7980
8081@dataclass (slots = True )
81- class GatherFuture :
82- """A future result of a dispatch.coroutine.gather () operation."""
82+ class AllFuture :
83+ """A future result of a dispatch.coroutine.all () operation."""
8384
84- order : list [CoroutineID ]
85- waiting : set [CoroutineID ]
86- results : dict [CoroutineID , CoroutineResult ]
85+ order : list [CoroutineID ] = field ( default_factory = list )
86+ waiting : set [CoroutineID ] = field ( default_factory = set )
87+ results : dict [CoroutineID , CoroutineResult ] = field ( default_factory = dict )
8788 first_error : Exception | None = None
8889
8990 def add_result (self , result : CallResult | CoroutineResult ):
@@ -94,13 +95,15 @@ def add_result(self, result: CallResult | CoroutineResult):
9495 except KeyError :
9596 return
9697
97- if result .error is not None and self .first_error is None :
98- self .first_error = result .error
98+ if result .error is not None :
99+ if self .first_error is None :
100+ self .first_error = result .error
101+ return
99102
100103 self .results [result .coroutine_id ] = result
101104
102105 def add_error (self , error : Exception ):
103- if self .first_error is not None :
106+ if self .first_error is None :
104107 self .first_error = error
105108
106109 def ready (self ) -> bool :
@@ -113,9 +116,108 @@ def error(self) -> Exception | None:
113116 def value (self ) -> list [Any ]:
114117 assert self .ready ()
115118 assert len (self .waiting ) == 0
119+ assert self .first_error is None
116120 return [self .results [id ].value for id in self .order ]
117121
118122
123+ @dataclass (slots = True )
124+ class AnyFuture :
125+ """A future result of a dispatch.coroutine.any() operation."""
126+
127+ order : list [CoroutineID ] = field (default_factory = list )
128+ waiting : set [CoroutineID ] = field (default_factory = set )
129+ first_result : CoroutineResult | None = None
130+ errors : dict [CoroutineID , Exception ] = field (default_factory = dict )
131+ generic_error : Exception | None = None
132+
133+ def add_result (self , result : CallResult | CoroutineResult ):
134+ assert isinstance (result , CoroutineResult )
135+
136+ try :
137+ self .waiting .remove (result .coroutine_id )
138+ except KeyError :
139+ return
140+
141+ if result .error is None :
142+ if self .first_result is None :
143+ self .first_result = result
144+ return
145+
146+ self .errors [result .coroutine_id ] = result .error
147+
148+ def add_error (self , error : Exception ):
149+ if self .generic_error is None :
150+ self .generic_error = error
151+
152+ def ready (self ) -> bool :
153+ return (
154+ self .generic_error is not None
155+ or self .first_result is not None
156+ or len (self .waiting ) == 0
157+ )
158+
159+ def error (self ) -> Exception | None :
160+ assert self .ready ()
161+ if self .generic_error is not None :
162+ return self .generic_error
163+ if self .first_result is not None or len (self .errors ) == 0 :
164+ return None
165+ match len (self .errors ):
166+ case 0 :
167+ return None
168+ case 1 :
169+ return self .errors [self .order [0 ]]
170+ case _:
171+ return AnyException ([self .errors [id ] for id in self .order ])
172+
173+ def value (self ) -> Any :
174+ assert self .ready ()
175+ if len (self .order ) == 0 :
176+ return None
177+ assert self .first_result is not None
178+ return self .first_result .value
179+
180+
181+ @dataclass (slots = True )
182+ class RaceFuture :
183+ """A future result of a dispatch.coroutine.race() operation."""
184+
185+ waiting : set [CoroutineID ] = field (default_factory = set )
186+ first_result : CoroutineResult | None = None
187+ first_error : Exception | None = None
188+
189+ def add_result (self , result : CallResult | CoroutineResult ):
190+ assert isinstance (result , CoroutineResult )
191+
192+ if result .error is not None :
193+ if self .first_error is None :
194+ self .first_error = result .error
195+ else :
196+ if self .first_result is None :
197+ self .first_result = result
198+
199+ self .waiting .remove (result .coroutine_id )
200+
201+ def add_error (self , error : Exception ):
202+ if self .first_error is None :
203+ self .first_error = error
204+
205+ def ready (self ) -> bool :
206+ return (
207+ self .first_error is not None
208+ or self .first_result is not None
209+ or len (self .waiting ) == 0
210+ )
211+
212+ def error (self ) -> Exception | None :
213+ assert self .ready ()
214+ return self .first_error
215+
216+ def value (self ) -> Any :
217+ assert self .first_error is None
218+ return self .first_result .value if self .first_result else None
219+
220+
119221@dataclass (slots = True )
120222class Coroutine :
121223 """An in-flight coroutine."""
@@ -386,30 +488,35 @@ def _run(self, input: Input) -> Output:
386488 state .prev_callers .append (coroutine )
387489 state .outstanding_calls += 1
388490
389- case Gather ():
390- gather = coroutine_yield
391-
392- children = []
393- for awaitable in gather .awaitables :
394- g = awaitable .__await__ ()
395- if not isinstance (g , DurableGenerator ):
396- raise ValueError (
397- "gather awaitable is not a @dispatch.function"
398- )
399- child_id = state .next_coroutine_id
400- state .next_coroutine_id += 1
401- child = Coroutine (
402- id = child_id , parent_id = coroutine .id , coroutine = g
403- )
404- logger .debug ("enqueuing %s for %s" , child , coroutine )
405- children .append (child )
491+ case AllDirective ():
492+ children = spawn_children (
493+ state , coroutine , coroutine_yield .awaitables
494+ )
406495
407- # Prepend children to get a depth-first traversal of coroutines.
408- state .ready = children + state .ready
496+ child_ids = [child .id for child in children ]
497+ coroutine .result = AllFuture (
498+ order = child_ids , waiting = set (child_ids )
499+ )
500+ state .suspended [coroutine .id ] = coroutine
501+
502+ case AnyDirective ():
503+ children = spawn_children (
504+ state , coroutine , coroutine_yield .awaitables
505+ )
409506
410507 child_ids = [child .id for child in children ]
411- coroutine .result = GatherFuture (
412- order = child_ids , waiting = set (child_ids ), results = {}
508+ coroutine .result = AnyFuture (
509+ order = child_ids , waiting = set (child_ids )
510+ )
511+ state .suspended [coroutine .id ] = coroutine
512+
513+ case RaceDirective ():
514+ children = spawn_children (
515+ state , coroutine , coroutine_yield .awaitables
516+ )
517+
518+ coroutine .result = RaceFuture (
519+ waiting = {child .id for child in children }
413520 )
414521 state .suspended [coroutine .id ] = coroutine
415522
@@ -446,6 +553,26 @@ def _run(self, input: Input) -> Output:
446553 )
447554
448555
556+ def spawn_children (
557+ state : State , coroutine : Coroutine , awaitables : tuple [Awaitable [Any ], ...]
558+ ) -> list [Coroutine ]:
559+ children = []
560+ for awaitable in awaitables :
561+ g = awaitable .__await__ ()
562+ if not isinstance (g , DurableGenerator ):
563+ raise TypeError ("awaitable is not a @dispatch.function" )
564+ child_id = state .next_coroutine_id
565+ state .next_coroutine_id += 1
566+ child = Coroutine (id = child_id , parent_id = coroutine .id , coroutine = g )
567+ logger .debug ("enqueuing %s for %s" , child , coroutine )
568+ children .append (child )
569+
570+ # Prepend children to get a depth-first traversal of coroutines.
571+ state .ready = children + state .ready
572+
573+ return children
574+
575+
449576def correlation_id (coroutine_id : CoroutineID , call_id : CallID ) -> CorrelationID :
450577 return coroutine_id << 32 | call_id
451578
0 commit comments