88
99from fastapi .testclient import TestClient
1010
11- import dispatch .sdk .v1 .status_pb2 as status_pb
12-
13- from ... import function_service
14- from ...test_client import ServerTest
11+ from dispatch import Client
12+ from dispatch .sdk .v1 import status_pb2 as status_pb
13+ from dispatch .test import DispatchServer , DispatchService , EndpointClient
1514
1615
1716class TestAutoRetry (unittest .TestCase ):
@@ -22,29 +21,33 @@ class TestAutoRetry(unittest.TestCase):
2221 "DISPATCH_API_KEY" : "0000000000000000" ,
2322 },
2423 )
25- def test_foo (self ):
26- from . import app
24+ def test_app (self ):
25+ from .app import app , dispatch
26+
27+ # Setup a fake Dispatch server.
28+ endpoint_client = EndpointClient .from_app (app )
29+ dispatch_service = DispatchService (endpoint_client , collect_roundtrips = True )
30+ with DispatchServer (dispatch_service ) as dispatch_server :
2731
28- server = ServerTest ()
29- servicer = server .servicer
30- app .dispatch ._client = server .client
31- app .some_logic ._client = server .client
32+ # Use it when dispatching function calls.
33+ dispatch .set_client (Client (api_url = dispatch_server .url ))
3234
33- http_client = TestClient (app .app , base_url = "http://dispatch-service" )
34- app_client = function_service .client (http_client )
35+ http_client = TestClient (app )
36+ response = http_client .get ("/" )
37+ self .assertEqual (response .status_code , 200 )
3538
36- response = http_client .get ("/" )
37- self .assertEqual (response .status_code , 200 )
39+ dispatch_service .dispatch_calls ()
3840
39- server .execute (app_client )
41+ # Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6
42+ # calls, including 5 retries.
43+ for i in range (6 ):
44+ dispatch_service .dispatch_calls ()
4045
41- # Seed(2) used in the app outputs 0, 0, 0, 2, 1, 5. So we expect 6
42- # calls, including 5 retries.
43- for i in range (6 ):
44- server .execute (app_client )
45- self .assertEqual (len (servicer .responses ), 6 )
46+ self .assertEqual (len (dispatch_service .roundtrips ), 1 )
47+ roundtrips = list (dispatch_service .roundtrips .values ())[0 ]
48+ self .assertEqual (len (roundtrips ), 6 )
4649
47- statuses = [r [ " response" ] .status for r in servicer . responses ]
48- self .assertEqual (
49- statuses , [status_pb .STATUS_TEMPORARY_ERROR ] * 5 + [status_pb .STATUS_OK ]
50- )
50+ statuses = [response .status for request , response in roundtrips ]
51+ self .assertEqual (
52+ statuses , [status_pb .STATUS_TEMPORARY_ERROR ] * 5 + [status_pb .STATUS_OK ]
53+ )
0 commit comments