@@ -227,7 +227,8 @@ def execute(
227227 any = any_pickle (input )
228228 req .input .CopyFrom (any )
229229 if state is not None :
230- req .poll_result .coroutine_state = state
230+ any = any_pickle (state )
231+ req .poll_result .typed_coroutine_state .CopyFrom (any )
231232 if calls is not None :
232233 for c in calls :
233234 req .poll_result .results .append (c )
@@ -247,10 +248,6 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult:
247248 resp = self .client .run (req )
248249 self .assertIsInstance (resp , function_pb .RunResponse )
249250
250- # Assert the response is terminal. Good enough until the test client can
251- # orchestrate coroutines.
252- self .assertTrue (len (resp .poll .coroutine_state ) == 0 )
253-
254251 resp .exit .result .correlation_id = call .correlation_id
255252 return resp .exit .result
256253
@@ -317,9 +314,10 @@ async def my_function(input: Input) -> Output:
317314 return Output .value ("not reached" )
318315
319316 resp = self .execute (my_function , input = "cool stuff" )
320- self .assertEqual (b"42" , resp .poll .coroutine_state )
317+ state = any_unpickle (resp .poll .typed_coroutine_state )
318+ self .assertEqual (b"42" , state )
321319
322- resp = self .execute (my_function , state = resp . poll . coroutine_state )
320+ resp = self .execute (my_function , state = state )
323321 self .assertEqual ("ValueError" , resp .exit .result .error .type )
324322 self .assertEqual (
325323 "This input is for a resumed coroutine" , resp .exit .result .error .message
@@ -360,32 +358,29 @@ async def coroutine3(input: Input) -> Output:
360358 if input .is_first_call :
361359 counter = input .input
362360 else :
363- ( counter ,) = struct . unpack ( "@i" , input .coroutine_state )
361+ counter = input .coroutine_state
364362 counter -= 1
365363 if counter <= 0 :
366364 return Output .value ("done" )
367- coroutine_state = struct .pack ("@i" , counter )
368- return Output .poll (coroutine_state = coroutine_state )
365+ return Output .poll (coroutine_state = counter )
369366
370367 # first call
371368 resp = self .execute (coroutine3 , input = 4 )
372- state = resp .poll .coroutine_state
373- self .assertTrue ( len ( state ) > 0 )
369+ state = any_unpickle ( resp .poll .typed_coroutine_state )
370+ self .assertEqual ( state , 3 )
374371
375372 # resume, state = 3
376373 resp = self .execute (coroutine3 , state = state )
377- state = resp .poll .coroutine_state
378- self .assertTrue ( len ( state ) > 0 )
374+ state = any_unpickle ( resp .poll .typed_coroutine_state )
375+ self .assertEqual ( state , 2 )
379376
380377 # resume, state = 2
381378 resp = self .execute (coroutine3 , state = state )
382- state = resp .poll .coroutine_state
383- self .assertTrue ( len ( state ) > 0 )
379+ state = any_unpickle ( resp .poll .typed_coroutine_state )
380+ self .assertEqual ( state , 1 )
384381
385382 # resume, state = 1
386383 resp = self .execute (coroutine3 , state = state )
387- state = resp .poll .coroutine_state
388- self .assertTrue (len (state ) == 0 )
389384 out = response_output (resp )
390385 self .assertEqual (out , "done" )
391386
@@ -399,18 +394,18 @@ async def coroutine_main(input: Input) -> Output:
399394 if input .is_first_call :
400395 text : str = input .input
401396 return Output .poll (
402- coroutine_state = text . encode () ,
397+ coroutine_state = text ,
403398 calls = [coro_compute_len ._build_primitive_call (text )],
404399 )
405- text = input .coroutine_state . decode ()
400+ text = input .coroutine_state
406401 length = input .call_results [0 ].output
407402 return Output .value (f"length={ length } text='{ text } '" )
408403
409404 resp = self .execute (coroutine_main , input = "cool stuff" )
410405
411406 # main saved some state
412- state = resp .poll .coroutine_state
413- self .assertTrue ( len ( state ) > 0 )
407+ state = any_unpickle ( resp .poll .typed_coroutine_state )
408+ self .assertEqual ( state , "cool stuff" )
414409 # main asks for 1 call to compute_len
415410 self .assertEqual (len (resp .poll .calls ), 1 )
416411 call = resp .poll .calls [0 ]
@@ -426,7 +421,6 @@ async def coroutine_main(input: Input) -> Output:
426421 # resume main with the result
427422 resp = self .execute (coroutine_main , state = state , calls = [resp2 ])
428423 # validate the final result
429- self .assertTrue (len (resp .poll .coroutine_state ) == 0 )
430424 out = response_output (resp )
431425 self .assertEqual ("length=10 text='cool stuff'" , out )
432426
@@ -440,7 +434,7 @@ async def coroutine_main(input: Input) -> Output:
440434 if input .is_first_call :
441435 text : str = input .input
442436 return Output .poll (
443- coroutine_state = text . encode () ,
437+ coroutine_state = text ,
444438 calls = [coro_compute_len ._build_primitive_call (text )],
445439 )
446440 error = input .call_results [0 ].error
@@ -452,8 +446,8 @@ async def coroutine_main(input: Input) -> Output:
452446 resp = self .execute (coroutine_main , input = "cool stuff" )
453447
454448 # main saved some state
455- state = resp .poll .coroutine_state
456- self .assertTrue ( len ( state ) > 0 )
449+ state = any_unpickle ( resp .poll .typed_coroutine_state )
450+ self .assertEqual ( state , "cool stuff" )
457451 # main asks for 1 call to compute_len
458452 self .assertEqual (len (resp .poll .calls ), 1 )
459453 call = resp .poll .calls [0 ]
@@ -466,7 +460,6 @@ async def coroutine_main(input: Input) -> Output:
466460 # resume main with the result
467461 resp = self .execute (coroutine_main , state = state , calls = [resp2 ])
468462 # validate the final result
469- self .assertTrue (len (resp .poll .coroutine_state ) == 0 )
470463 out = response_output (resp )
471464 self .assertEqual (out , "msg=Dead type='type'" )
472465
0 commit comments