11from __future__ import annotations
22
3- import functools
43import inspect
54import logging
5+ from functools import wraps
66from types import FunctionType
77from typing import Any , Callable , Dict , TypeAlias
88
2323"""
2424
2525
26+ # https://stackoverflow.com/questions/653368/how-to-create-a-decorator-that-can-be-used-either-with-or-without-parameters
27+ def decorator (f ):
28+ """This decorator is intended to declare decorators that can be used with
29+ or without parameters. If the decorated function is called with a single
30+ callable argument, it is assumed to be a function and the decorator is
31+ applied to it. Otherwise, the decorator is called with the arguments
32+ provided and the result is returned.
33+ """
34+
35+ @wraps (f )
36+ def method (self , * args , ** kwargs ):
37+ if len (args ) == 1 and len (kwargs ) == 0 and callable (args [0 ]):
38+ return f (self , args [0 ])
39+
40+ def wrapper (func ):
41+ return f (self , func , * args , ** kwargs )
42+
43+ return wrapper
44+
45+ return method
46+
47+
2648class Function :
2749 """Callable wrapper around a function meant to be used throughout the
2850 Dispatch Python SDK.
2951 """
3052
31- __slots__ = ("_endpoint" , "_client" , "_name" , "_primitive_func" , "_func" , "call" )
53+ __slots__ = ("_endpoint" , "_client" , "_name" , "_primitive_func" , "_func" )
3254
3355 def __init__ (
3456 self ,
@@ -42,11 +64,12 @@ def __init__(
4264 self ._client = client
4365 self ._name = name
4466 self ._primitive_func = primitive_func
45- self ._func = func
46-
4767 # FIXME: is there a way to decorate the function at the definition
4868 # without making it a class method?
49- self .call = durable (self ._call_async )
69+ if inspect .iscoroutinefunction (func ):
70+ self ._func = durable (self ._call_async )
71+ else :
72+ self ._func = func
5073
5174 def __call__ (self , * args , ** kwargs ):
5275 return self ._func (* args , ** kwargs )
@@ -90,7 +113,7 @@ def _primitive_dispatch(self, input: Any = None) -> DispatchID:
90113 return dispatch_id
91114
92115 async def _call_async (self , * args , ** kwargs ) -> Any :
93- """Asynchronously call the function from a @dispatch.coroutine ."""
116+ """Asynchronously call the function from a @dispatch.function ."""
94117 return await dispatch .coroutine .call (
95118 self .build_call (* args , ** kwargs , correlation_id = None )
96119 )
@@ -142,39 +165,27 @@ def __init__(self, endpoint: str, client: Client | None):
142165 self ._endpoint = endpoint
143166 self ._client = client
144167
145- def function (self ) -> Callable [[FunctionType ], Function ]:
168+ @decorator
169+ def function (self , func : Callable ) -> Function :
146170 """Returns a decorator that registers functions."""
171+ return self ._register_function (func )
147172
148- # Note: the indirection here means that we can add parameters
149- # to the decorator later without breaking existing apps.
150- return self ._register_function
151-
152- def coroutine (self ) -> Callable [[FunctionType ], Function | FunctionType ]:
153- """Returns a decorator that registers coroutines."""
154-
155- # Note: the indirection here means that we can add parameters
156- # to the decorator later without breaking existing apps.
157- return self ._register_coroutine
158-
159- def primitive_function (self ) -> Callable [[PrimitiveFunctionType ], Function ]:
173+ @decorator
174+ def primitive_function (self , func : Callable ) -> Function :
160175 """Returns a decorator that registers primitive functions."""
161-
162- # Note: the indirection here means that we can add parameters
163- # to the decorator later without breaking existing apps.
164- return self ._register_primitive_function
176+ return self ._register_primitive_function (func )
165177
166178 def _register_function (self , func : Callable ) -> Function :
167179 if inspect .iscoroutinefunction (func ):
168- raise TypeError (
169- "async functions must be registered via @dispatch.coroutine"
170- )
180+ return self ._register_coroutine (func )
171181
172182 logger .info ("registering function: %s" , func .__qualname__ )
173183
174184 # Register the function with the experimental.durable package, in case
175185 # it's referenced from a @dispatch.coroutine.
176186 func = durable (func )
177187
188+ @wraps (func )
178189 def primitive_func (input : Input ) -> Output :
179190 try :
180191 try :
@@ -196,14 +207,11 @@ def primitive_func(input: Input) -> Output:
196207 return self ._register (func , primitive_func )
197208
198209 def _register_coroutine (self , func : Callable ) -> Function :
199- if not inspect .iscoroutinefunction (func ):
200- raise TypeError (f"{ func .__qualname__ } must be an async function" )
201-
202210 logger .info ("registering coroutine: %s" , func .__qualname__ )
203211
204212 func = durable (func )
205213
206- @functools . wraps (func )
214+ @wraps (func )
207215 def primitive_func (input : Input ) -> Output :
208216 return OneShotScheduler (func ).run (input )
209217
0 commit comments