11from __future__ import annotations
22
3+ import asyncio
34import inspect
45import logging
56import os
67from functools import wraps
78from types import CoroutineType
89from typing import (
910 Any ,
11+ Awaitable ,
1012 Callable ,
1113 Coroutine ,
1214 Dict ,
3335logger = logging .getLogger (__name__ )
3436
3537
36- PrimitiveFunctionType : TypeAlias = Callable [[Input ], Output ]
38+ PrimitiveFunctionType : TypeAlias = Callable [[Input ], Awaitable [ Output ] ]
3739"""A primitive function is a function that accepts a dispatch.proto.Input
3840and unconditionally returns a dispatch.proto.Output. It must not raise
3941exceptions.
@@ -70,8 +72,8 @@ def endpoint(self, value: str):
7072 def name (self ) -> str :
7173 return self ._name
7274
73- def _primitive_call (self , input : Input ) -> Output :
74- return self ._primitive_func (input )
75+ async def _primitive_call (self , input : Input ) -> Output :
76+ return await self ._primitive_func (input )
7577
7678 def _primitive_dispatch (self , input : Any = None ) -> DispatchID :
7779 [dispatch_id ] = self ._client .dispatch ([self ._build_primitive_call (input )])
@@ -226,6 +228,7 @@ def function(self, func: Callable[P, T]) -> Function[P, T]: ...
226228 def function (self , func ):
227229 """Decorator that registers functions."""
228230 name = func .__qualname__
231+
229232 if not inspect .iscoroutinefunction (func ):
230233 logger .info ("registering function: %s" , name )
231234 return self ._register_function (name , func )
@@ -237,23 +240,22 @@ def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]:
237240 func = durable (func )
238241
239242 @wraps (func )
240- async def async_wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T :
241- return func (* args , ** kwargs )
242-
243- async_wrapper .__qualname__ = f"{ name } _async"
243+ async def asyncio_wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T :
244+ loop = asyncio .get_running_loop ()
245+ return await loop .run_in_executor (None , func , * args , ** kwargs )
244246
245- return self ._register_coroutine (name , async_wrapper )
247+ asyncio_wrapper .__qualname__ = f"{ name } _asyncio"
248+ return self ._register_coroutine (name , asyncio_wrapper )
246249
247250 def _register_coroutine (
248251 self , name : str , func : Callable [P , Coroutine [Any , Any , T ]]
249252 ) -> Function [P , T ]:
250253 logger .info ("registering coroutine: %s" , name )
251-
252254 func = durable (func )
253255
254256 @wraps (func )
255- def primitive_func (input : Input ) -> Output :
256- return OneShotScheduler (func ).run (input )
257+ async def primitive_func (input : Input ) -> Output :
258+ return await OneShotScheduler (func ).run (input )
257259
258260 primitive_func .__qualname__ = f"{ name } _primitive"
259261 durable_primitive_func = durable (primitive_func )
0 commit comments