11import base64
22import os
33import pickle
4+ import struct
45import unittest
56from typing import Any
67from unittest import mock
@@ -282,7 +283,7 @@ def test_error_on_access_input_in_second_call(self):
282283 @self .dispatch .primitive_function
283284 def my_function (input : Input ) -> Output :
284285 if input .is_first_call :
285- return Output .poll (state = 42 )
286+ return Output .poll (coroutine_state = b"42" )
286287 try :
287288 print (input .input )
288289 except ValueError :
@@ -294,7 +295,7 @@ def my_function(input: Input) -> Output:
294295 return Output .value ("not reached" )
295296
296297 resp = self .execute (my_function , input = "cool stuff" )
297- self .assertEqual (42 , pickle . loads ( resp .poll .coroutine_state ) )
298+ self .assertEqual (b"42" , resp .poll .coroutine_state )
298299
299300 resp = self .execute (my_function , state = resp .poll .coroutine_state )
300301 self .assertEqual ("ValueError" , resp .exit .result .error .type )
@@ -337,11 +338,12 @@ def coroutine3(input: Input) -> Output:
337338 if input .is_first_call :
338339 counter = input .input
339340 else :
340- counter = input .coroutine_state
341+ ( counter ,) = struct . unpack ( "@i" , input .coroutine_state )
341342 counter -= 1
342343 if counter <= 0 :
343344 return Output .value ("done" )
344- return Output .poll (state = counter )
345+ coroutine_state = struct .pack ("@i" , counter )
346+ return Output .poll (coroutine_state = coroutine_state )
345347
346348 # first call
347349 resp = self .execute (coroutine3 , input = 4 )
@@ -375,9 +377,10 @@ def coroutine_main(input: Input) -> Output:
375377 if input .is_first_call :
376378 text : str = input .input
377379 return Output .poll (
378- state = text , calls = [coro_compute_len ._build_primitive_call (text )]
380+ coroutine_state = text .encode (),
381+ calls = [coro_compute_len ._build_primitive_call (text )],
379382 )
380- text = input .coroutine_state
383+ text = input .coroutine_state . decode ()
381384 length = input .call_results [0 ].output
382385 return Output .value (f"length={ length } text='{ text } '" )
383386
@@ -415,7 +418,8 @@ def coroutine_main(input: Input) -> Output:
415418 if input .is_first_call :
416419 text : str = input .input
417420 return Output .poll (
418- state = text , calls = [coro_compute_len ._build_primitive_call (text )]
421+ coroutine_state = text .encode (),
422+ calls = [coro_compute_len ._build_primitive_call (text )],
419423 )
420424 error = input .call_results [0 ].error
421425 if error is not None :
0 commit comments