33import pickle
44import struct
55import unittest
6- from typing import Any
6+ from typing import Any , Optional
77from unittest import mock
88
99import fastapi
1010import google .protobuf .any_pb2
1111import google .protobuf .wrappers_pb2
1212import httpx
13- from cryptography .hazmat .primitives .asymmetric .ed25519 import Ed25519PublicKey
13+ from cryptography .hazmat .primitives .asymmetric .ed25519 import (
14+ Ed25519PrivateKey ,
15+ Ed25519PublicKey ,
16+ )
1417from fastapi .testclient import TestClient
1518
19+ import dispatch
1620from dispatch .experimental .durable .registry import clear_functions
1721from dispatch .fastapi import Dispatch
1822from dispatch .function import Arguments , Error , Function , Input , Output
1923from dispatch .proto import _any_unpickle as any_unpickle
2024from dispatch .sdk .v1 import call_pb2 as call_pb
2125from dispatch .sdk .v1 import function_pb2 as function_pb
22- from dispatch .signature import parse_verification_key , public_key_from_pem
26+ from dispatch .signature import (
27+ parse_verification_key ,
28+ private_key_from_pem ,
29+ public_key_from_pem ,
30+ )
2331from dispatch .status import Status
24- from dispatch .test import EndpointClient
32+ from dispatch .test import DispatchServer , DispatchService , EndpointClient
2533
2634
27- def create_dispatch_instance (app , endpoint ):
35+ def create_dispatch_instance (app : fastapi . FastAPI , endpoint : str ):
2836 return Dispatch (
2937 app ,
3038 endpoint = endpoint ,
@@ -33,6 +41,13 @@ def create_dispatch_instance(app, endpoint):
3341 )
3442
3543
44+ def create_endpoint_client (
45+ app : fastapi .FastAPI , signing_key : Optional [Ed25519PrivateKey ] = None
46+ ):
47+ http_client = TestClient (app )
48+ return EndpointClient (http_client , signing_key )
49+
50+
3651class TestFastAPI (unittest .TestCase ):
3752 def test_Dispatch (self ):
3853 app = fastapi .FastAPI ()
@@ -54,10 +69,6 @@ def read_root():
5469 resp = client .post ("/dispatch.sdk.v1.FunctionService/Run" )
5570 self .assertEqual (resp .status_code , 400 )
5671
57- def test_Dispatch_no_app (self ):
58- with self .assertRaises (ValueError ):
59- create_dispatch_instance (None , endpoint = "http://127.0.0.1:9999" )
60-
6172 @mock .patch .dict (os .environ , {"DISPATCH_ENDPOINT_URL" : "" })
6273 def test_Dispatch_no_endpoint (self ):
6374 app = fastapi .FastAPI ()
@@ -79,8 +90,7 @@ def my_function(input: Input) -> Output:
7990 f"You told me: '{ input .input } ' ({ len (input .input )} characters)"
8091 )
8192
82- client = EndpointClient .from_app (app )
83-
93+ client = create_endpoint_client (app )
8494 pickled = pickle .dumps ("Hello World!" )
8595 input_any = google .protobuf .any_pb2 .Any ()
8696 input_any .Pack (google .protobuf .wrappers_pb2 .BytesValue (value = pickled ))
@@ -102,6 +112,96 @@ def my_function(input: Input) -> Output:
102112 self .assertEqual (output , "You told me: 'Hello World!' (12 characters)" )
103113
104114
115+ signing_key = private_key_from_pem (
116+ """
117+ -----BEGIN PRIVATE KEY-----
118+ MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF
119+ -----END PRIVATE KEY-----
120+ """
121+ )
122+
123+ verification_key = public_key_from_pem (
124+ """
125+ -----BEGIN PUBLIC KEY-----
126+ MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
127+ -----END PUBLIC KEY-----
128+ """
129+ )
130+
131+
132+ class TestFullFastapi (unittest .TestCase ):
133+ def setUp (self ):
134+ self .endpoint_app = fastapi .FastAPI ()
135+ endpoint_client = create_endpoint_client (self .endpoint_app , signing_key )
136+
137+ api_key = "0000000000000000"
138+ self .dispatch_service = DispatchService (
139+ endpoint_client , api_key , collect_roundtrips = True
140+ )
141+ self .dispatch_server = DispatchServer (self .dispatch_service )
142+ self .dispatch_client = dispatch .Client (
143+ api_key , api_url = self .dispatch_server .url
144+ )
145+
146+ self .dispatch = Dispatch (
147+ self .endpoint_app ,
148+ endpoint = "http://function-service" , # unused
149+ verification_key = verification_key ,
150+ api_key = api_key ,
151+ api_url = self .dispatch_server .url ,
152+ )
153+
154+ self .dispatch_server .start ()
155+
156+ def tearDown (self ):
157+ self .dispatch_server .stop ()
158+
159+ def test_simple_end_to_end (self ):
160+ # The FastAPI server.
161+ @self .dispatch .function
162+ def my_function (name : str ) -> str :
163+ return f"Hello world: { name } "
164+
165+ call = my_function .build_call (52 )
166+ self .assertEqual (call .function .split ("." )[- 1 ], "my_function" )
167+
168+ # The client.
169+ [dispatch_id ] = self .dispatch_client .dispatch ([my_function .build_call (52 )])
170+
171+ # Simulate execution for testing purposes.
172+ self .dispatch_service .dispatch_calls ()
173+
174+ # Validate results.
175+ roundtrips = self .dispatch_service .roundtrips [dispatch_id ]
176+ self .assertEqual (len (roundtrips ), 1 )
177+ _ , response = roundtrips [0 ]
178+ self .assertEqual (any_unpickle (response .exit .result .output ), "Hello world: 52" )
179+
180+ def test_simple_missing_signature (self ):
181+ @self .dispatch .function
182+ async def my_function (name : str ) -> str :
183+ return f"Hello world: { name } "
184+
185+ call = my_function .build_call (52 )
186+ self .assertEqual (call .function .split ("." )[- 1 ], "my_function" )
187+
188+ [dispatch_id ] = self .dispatch_client .dispatch ([call ])
189+
190+ self .dispatch_service .endpoint_client = create_endpoint_client (
191+ self .endpoint_app
192+ ) # no signing key
193+ try :
194+ self .dispatch_service .dispatch_calls ()
195+ except httpx .HTTPStatusError as e :
196+ assert e .response .status_code == 403
197+ assert e .response .json () == {
198+ "code" : "permission_denied" ,
199+ "message" : 'Expected "Signature-Input" header field to be present' ,
200+ }
201+ else :
202+ assert False , "Expected HTTPStatusError"
203+
204+
105205def response_output (resp : function_pb .RunResponse ) -> Any :
106206 return any_unpickle (resp .exit .result .output )
107207
@@ -120,7 +220,7 @@ def root():
120220 self .app , endpoint = "https://127.0.0.1:9999"
121221 )
122222 self .http_client = TestClient (self .app )
123- self .client = EndpointClient . from_app (self .app )
223+ self .client = create_endpoint_client (self .app )
124224
125225 def execute (
126226 self , func : Function , input = None , state = None , calls = None
0 commit comments