|
1 | 1 | import logging |
| 2 | +from abc import ABC |
2 | 3 | from typing import Any, Callable, Optional, Type, TypeVar |
3 | 4 |
|
4 | 5 | from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent |
|
9 | 10 | AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent) |
10 | 11 |
|
11 | 12 |
|
12 | | -class AppSyncResolver: |
| 13 | +class BaseRouter(ABC): |
| 14 | + current_event: AppSyncResolverEventT # type: ignore[valid-type] |
| 15 | + lambda_context: LambdaContext |
| 16 | + |
| 17 | + def __init__(self): |
| 18 | + self._resolvers: dict = {} |
| 19 | + |
| 20 | + def resolver(self, type_name: str = "*", field_name: Optional[str] = None): |
| 21 | + """Registers the resolver for field_name |
| 22 | +
|
| 23 | + Parameters |
| 24 | + ---------- |
| 25 | + type_name : str |
| 26 | + Type name |
| 27 | + field_name : str |
| 28 | + Field name |
| 29 | + """ |
| 30 | + |
| 31 | + def register_resolver(func): |
| 32 | + logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") |
| 33 | + self._resolvers[f"{type_name}.{field_name}"] = {"func": func} |
| 34 | + return func |
| 35 | + |
| 36 | + return register_resolver |
| 37 | + |
| 38 | + |
| 39 | +class AppSyncResolver(BaseRouter): |
13 | 40 | """ |
14 | 41 | AppSync resolver decorator |
15 | 42 |
|
@@ -40,29 +67,8 @@ def common_field() -> str: |
40 | 67 | return str(uuid.uuid4()) |
41 | 68 | """ |
42 | 69 |
|
43 | | - current_event: AppSyncResolverEventT # type: ignore[valid-type] |
44 | | - lambda_context: LambdaContext |
45 | | - |
46 | 70 | def __init__(self): |
47 | | - self._resolvers: dict = {} |
48 | | - |
49 | | - def resolver(self, type_name: str = "*", field_name: Optional[str] = None): |
50 | | - """Registers the resolver for field_name |
51 | | -
|
52 | | - Parameters |
53 | | - ---------- |
54 | | - type_name : str |
55 | | - Type name |
56 | | - field_name : str |
57 | | - Field name |
58 | | - """ |
59 | | - |
60 | | - def register_resolver(func): |
61 | | - logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") |
62 | | - self._resolvers[f"{type_name}.{field_name}"] = {"func": func} |
63 | | - return func |
64 | | - |
65 | | - return register_resolver |
| 71 | + super().__init__() |
66 | 72 |
|
67 | 73 | def resolve( |
68 | 74 | self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent |
@@ -136,10 +142,10 @@ def lambda_handler(event, context): |
136 | 142 | ValueError |
137 | 143 | If we could not find a field resolver |
138 | 144 | """ |
139 | | - self.current_event = data_model(event) |
140 | | - self.lambda_context = context |
141 | | - resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name) |
142 | | - return resolver(**self.current_event.arguments) |
| 145 | + BaseRouter.current_event = data_model(event) |
| 146 | + BaseRouter.lambda_context = context |
| 147 | + resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name) |
| 148 | + return resolver(**BaseRouter.current_event.arguments) |
143 | 149 |
|
144 | 150 | def _get_resolver(self, type_name: str, field_name: str) -> Callable: |
145 | 151 | """Get resolver for field_name |
@@ -167,3 +173,18 @@ def __call__( |
167 | 173 | ) -> Any: |
168 | 174 | """Implicit lambda handler which internally calls `resolve`""" |
169 | 175 | return self.resolve(event, context, data_model) |
| 176 | + |
| 177 | + def include_router(self, router: "Router") -> None: |
| 178 | + """Adds all resolvers defined in a router |
| 179 | +
|
| 180 | + Parameters |
| 181 | + ---------- |
| 182 | + router : Router |
| 183 | + A router containing a dict of field resolvers |
| 184 | + """ |
| 185 | + self._resolvers.update(router._resolvers) |
| 186 | + |
| 187 | + |
| 188 | +class Router(BaseRouter): |
| 189 | + def __init__(self): |
| 190 | + super().__init__() |
0 commit comments