33import pickle
44import struct
55import unittest
6- from typing import Any
6+ from typing import Any , Optional
77from unittest import mock
88
99import fastapi
1010import google .protobuf .any_pb2
1111import google .protobuf .wrappers_pb2
1212import httpx
13- from cryptography .hazmat .primitives .asymmetric .ed25519 import Ed25519PublicKey
13+ from cryptography .hazmat .primitives .asymmetric .ed25519 import (
14+ Ed25519PrivateKey ,
15+ Ed25519PublicKey ,
16+ )
1417from fastapi .testclient import TestClient
1518
19+ import dispatch
1620from dispatch .experimental .durable .registry import clear_functions
1721from dispatch .fastapi import Dispatch
1822from dispatch .function import Arguments , Error , Function , Input , Output
1923from dispatch .proto import _any_unpickle as any_unpickle
2024from dispatch .sdk .v1 import call_pb2 as call_pb
2125from dispatch .sdk .v1 import function_pb2 as function_pb
22- from dispatch .signature import parse_verification_key , public_key_from_pem
26+ from dispatch .signature import (
27+ parse_verification_key ,
28+ private_key_from_pem ,
29+ public_key_from_pem ,
30+ )
2331from dispatch .status import Status
24- from dispatch .test import EndpointClient
32+ from dispatch .test import DispatchServer , DispatchService , EndpointClient
2533
2634
27- def create_dispatch_instance (app , endpoint ):
35+ def create_dispatch_instance (app : fastapi . FastAPI , endpoint : str ):
2836 return Dispatch (
2937 app ,
3038 endpoint = endpoint ,
@@ -33,6 +41,13 @@ def create_dispatch_instance(app, endpoint):
3341 )
3442
3543
44+ def create_endpoint_client (
45+ app : fastapi .FastAPI , signing_key : Optional [Ed25519PrivateKey ] = None
46+ ):
47+ http_client = TestClient (app )
48+ return EndpointClient (http_client , signing_key )
49+
50+
3651class TestFastAPI (unittest .TestCase ):
3752 def test_Dispatch (self ):
3853 app = fastapi .FastAPI ()
@@ -79,8 +94,7 @@ def my_function(input: Input) -> Output:
7994 f"You told me: '{ input .input } ' ({ len (input .input )} characters)"
8095 )
8196
82- client = EndpointClient .from_app (app )
83-
97+ client = create_endpoint_client (app )
8498 pickled = pickle .dumps ("Hello World!" )
8599 input_any = google .protobuf .any_pb2 .Any ()
86100 input_any .Pack (google .protobuf .wrappers_pb2 .BytesValue (value = pickled ))
@@ -102,6 +116,96 @@ def my_function(input: Input) -> Output:
102116 self .assertEqual (output , "You told me: 'Hello World!' (12 characters)" )
103117
104118
119+ signing_key = private_key_from_pem (
120+ """
121+ -----BEGIN PRIVATE KEY-----
122+ MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF
123+ -----END PRIVATE KEY-----
124+ """
125+ )
126+
127+ verification_key = public_key_from_pem (
128+ """
129+ -----BEGIN PUBLIC KEY-----
130+ MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
131+ -----END PUBLIC KEY-----
132+ """
133+ )
134+
135+
136+ class TestFullFastapi (unittest .TestCase ):
137+ def setUp (self ):
138+ self .endpoint_app = fastapi .FastAPI ()
139+ endpoint_client = create_endpoint_client (self .endpoint_app , signing_key )
140+
141+ api_key = "0000000000000000"
142+ self .dispatch_service = DispatchService (
143+ endpoint_client , api_key , collect_roundtrips = True
144+ )
145+ self .dispatch_server = DispatchServer (self .dispatch_service )
146+ self .dispatch_client = dispatch .Client (
147+ api_key , api_url = self .dispatch_server .url
148+ )
149+
150+ self .dispatch = Dispatch (
151+ self .endpoint_app ,
152+ endpoint = "http://function-service" , # unused
153+ verification_key = verification_key ,
154+ api_key = api_key ,
155+ api_url = self .dispatch_server .url ,
156+ )
157+
158+ self .dispatch_server .start ()
159+
160+ def tearDown (self ):
161+ self .dispatch_server .stop ()
162+
163+ def test_simple_end_to_end (self ):
164+ # The FastAPI server.
165+ @self .dispatch .function
166+ def my_function (name : str ) -> str :
167+ return f"Hello world: { name } "
168+
169+ call = my_function .build_call (52 )
170+ self .assertEqual (call .function .split ("." )[- 1 ], "my_function" )
171+
172+ # The client.
173+ [dispatch_id ] = self .dispatch_client .dispatch ([my_function .build_call (52 )])
174+
175+ # Simulate execution for testing purposes.
176+ self .dispatch_service .dispatch_calls ()
177+
178+ # Validate results.
179+ roundtrips = self .dispatch_service .roundtrips [dispatch_id ]
180+ self .assertEqual (len (roundtrips ), 1 )
181+ _ , response = roundtrips [0 ]
182+ self .assertEqual (any_unpickle (response .exit .result .output ), "Hello world: 52" )
183+
184+ def test_simple_missing_signature (self ):
185+ @self .dispatch .function
186+ async def my_function (name : str ) -> str :
187+ return f"Hello world: { name } "
188+
189+ call = my_function .build_call (52 )
190+ self .assertEqual (call .function .split ("." )[- 1 ], "my_function" )
191+
192+ [dispatch_id ] = self .dispatch_client .dispatch ([call ])
193+
194+ self .dispatch_service .endpoint_client = create_endpoint_client (
195+ self .endpoint_app
196+ ) # no signing key
197+ try :
198+ self .dispatch_service .dispatch_calls ()
199+ except httpx .HTTPStatusError as e :
200+ assert e .response .status_code == 403
201+ assert e .response .json () == {
202+ "code" : "permission_denied" ,
203+ "message" : 'Expected "Signature-Input" header field to be present' ,
204+ }
205+ else :
206+ assert False , "Expected HTTPStatusError"
207+
208+
105209def response_output (resp : function_pb .RunResponse ) -> Any :
106210 return any_unpickle (resp .exit .result .output )
107211
@@ -120,7 +224,7 @@ def root():
120224 self .app , endpoint = "https://127.0.0.1:9999"
121225 )
122226 self .http_client = TestClient (self .app )
123- self .client = EndpointClient . from_app (self .app )
227+ self .client = create_endpoint_client (self .app )
124228
125229 def execute (
126230 self , func : Function , input = None , state = None , calls = None
0 commit comments