1- from functools import partial
1+ from functools import partial , reduce
22from inspect import isfunction
3- from itertools import chain
43
5- from typing import Callable , Iterator , Dict , Tuple , Any , Iterable , Optional , cast
4+ from typing import Callable , Iterator , Dict , Tuple , Any , Optional
65
76__all__ = ["MiddlewareManager" ]
87
@@ -41,8 +40,10 @@ def get_field_resolver(
4140 if self ._middleware_resolvers is None :
4241 return field_resolver
4342 if field_resolver not in self ._cached_resolvers :
44- self ._cached_resolvers [field_resolver ] = middleware_chain (
45- field_resolver , self ._middleware_resolvers
43+ self ._cached_resolvers [field_resolver ] = reduce (
44+ lambda chained_fns , next_fn : partial (next_fn , chained_fns ),
45+ self ._middleware_resolvers ,
46+ field_resolver ,
4647 )
4748 return self ._cached_resolvers [field_resolver ]
4849
@@ -56,19 +57,3 @@ def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]
5657 resolver_func = getattr (middleware , "resolve" , None )
5758 if resolver_func is not None :
5859 yield resolver_func
59-
60-
61- def middleware_chain (
62- func : GraphQLFieldResolver , middlewares : Iterable [Callable ]
63- ) -> GraphQLFieldResolver :
64- """Chain the given function with the provided middlewares.
65-
66- Returns a new resolver function that is the chain of both.
67- """
68- if not middlewares :
69- return func
70- middlewares = chain ((func ,), middlewares )
71- last_func : Optional [GraphQLFieldResolver ] = None
72- for middleware in middlewares :
73- last_func = partial (middleware , last_func ) if last_func else middleware
74- return cast (GraphQLFieldResolver , last_func )
0 commit comments