|
1 | | -from typing import Callable, Iterator, Dict, Tuple, Any, Iterable, Optional, cast |
2 | | - |
3 | | -from inspect import isfunction |
4 | 1 | from functools import partial |
| 2 | +from inspect import isfunction |
5 | 3 | from itertools import chain |
6 | 4 |
|
| 5 | +from typing import ( |
| 6 | + Callable, Iterator, Dict, Tuple, Any, Iterable, Optional, cast) |
7 | 7 |
|
8 | | -from ..type import GraphQLFieldResolver |
9 | | - |
| 8 | +__all__ = ['MiddlewareManager'] |
10 | 9 |
|
11 | | -__all__ = ["MiddlewareManager", "middlewares"] |
12 | | - |
13 | | -# If the provided middleware is a class, this is the attribute we will look at |
14 | | -MIDDLEWARE_RESOLVER_FUNCTION = "resolve" |
| 10 | +GraphQLFieldResolver = Callable[..., Any] |
15 | 11 |
|
16 | 12 |
|
17 | 13 | class MiddlewareManager: |
18 | | - """MiddlewareManager helps to chain resolver functions with the provided |
19 | | - middleware functions and classes |
| 14 | + """Manager for the middleware chain. |
| 15 | +
|
| 16 | + This class helps to wrap resolver functions with the provided middleware |
| 17 | + functions and/or objects. The functions take the next middleware function |
| 18 | + as first argument. If middleware is provided as an object, it must provide |
| 19 | + a method 'resolve' that is used as the middleware function. |
20 | 20 | """ |
21 | 21 |
|
22 | | - __slots__ = ("middlewares", "_middleware_resolvers", "_cached_resolvers") |
| 22 | + __slots__ = 'middlewares', '_middleware_resolvers', '_cached_resolvers' |
23 | 23 |
|
24 | 24 | _cached_resolvers: Dict[GraphQLFieldResolver, GraphQLFieldResolver] |
25 | | - _middleware_resolvers: Optional[Tuple[Callable, ...]] |
| 25 | + _middleware_resolvers: Optional[Iterator[Callable]] |
26 | 26 |
|
27 | 27 | def __init__(self, *middlewares: Any) -> None: |
28 | 28 | self.middlewares = middlewares |
29 | | - if middlewares: |
30 | | - self._middleware_resolvers = tuple(get_middleware_resolvers(middlewares)) |
31 | | - else: |
32 | | - self.__middleware_resolvers = None |
| 29 | + self._middleware_resolvers = get_middleware_resolvers( |
| 30 | + middlewares) if middlewares else None |
33 | 31 | self._cached_resolvers = {} |
34 | 32 |
|
35 | 33 | def get_field_resolver( |
36 | | - self, field_resolver: GraphQLFieldResolver |
37 | | - ) -> GraphQLFieldResolver: |
38 | | - """Wraps the provided resolver returning a function that |
39 | | - executes chains the middleware functions with the resolver function""" |
| 34 | + self, field_resolver: GraphQLFieldResolver |
| 35 | + ) -> GraphQLFieldResolver: |
| 36 | + """Wrap the provided resolver with the middleware. |
| 37 | +
|
| 38 | + Returns a function that chains the middleware functions with the |
| 39 | + provided resolver function |
| 40 | + """ |
40 | 41 | if self._middleware_resolvers is None: |
41 | 42 | return field_resolver |
42 | 43 | if field_resolver not in self._cached_resolvers: |
43 | 44 | self._cached_resolvers[field_resolver] = middleware_chain( |
44 | | - field_resolver, self._middleware_resolvers |
45 | | - ) |
46 | | - |
| 45 | + field_resolver, self._middleware_resolvers) |
47 | 46 | return self._cached_resolvers[field_resolver] |
48 | 47 |
|
49 | 48 |
|
50 | | -middlewares = MiddlewareManager |
51 | | - |
52 | | - |
53 | | -def get_middleware_resolvers(middlewares: Tuple[Any, ...]) -> Iterator[Callable]: |
54 | | - """Returns the functions related to the middleware classes or functions""" |
| 49 | +def get_middleware_resolvers( |
| 50 | + middlewares: Tuple[Any, ...]) -> Iterator[Callable]: |
| 51 | + """Get a list of resolver functions from a list of classes or functions.""" |
55 | 52 | for middleware in middlewares: |
56 | | - # If the middleware is a function instead of a class |
57 | 53 | if isfunction(middleware): |
58 | 54 | yield middleware |
59 | | - resolver_func = getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION, None) |
60 | | - if resolver_func is not None: |
61 | | - yield resolver_func |
| 55 | + else: # middleware provided as object with 'resolve' method |
| 56 | + resolver_func = getattr(middleware, 'resolve', None) |
| 57 | + if resolver_func is not None: |
| 58 | + yield resolver_func |
62 | 59 |
|
63 | 60 |
|
64 | 61 | def middleware_chain( |
65 | | - func: GraphQLFieldResolver, middlewares: Iterable[Callable] |
66 | | -) -> GraphQLFieldResolver: |
67 | | - """Reduces the current function with the provided middlewares, |
68 | | - returning a new resolver function""" |
| 62 | + func: GraphQLFieldResolver, middlewares: Iterable[Callable] |
| 63 | + ) -> GraphQLFieldResolver: |
| 64 | + """Chain the given function with the provided middlewares. |
| 65 | +
|
| 66 | + Returns a new resolver function that is the chain of both. |
| 67 | + """ |
69 | 68 | if not middlewares: |
70 | 69 | return func |
71 | 70 | middlewares = chain((func,), middlewares) |
72 | 71 | last_func: Optional[GraphQLFieldResolver] = None |
73 | 72 | for middleware in middlewares: |
74 | 73 | last_func = partial(middleware, last_func) if last_func else middleware |
75 | | - |
76 | 74 | return cast(GraphQLFieldResolver, last_func) |
0 commit comments