|
4 | 4 |
|
5 | 5 | from asyncio import ensure_future, gather, shield, wait_for |
6 | 6 | from contextlib import suppress |
| 7 | +from copy import copy |
7 | 8 | from typing import ( |
8 | 9 | Any, |
9 | 10 | AsyncGenerator, |
@@ -219,6 +220,7 @@ def build( |
219 | 220 | subscribe_field_resolver: GraphQLFieldResolver | None = None, |
220 | 221 | middleware: Middleware | None = None, |
221 | 222 | is_awaitable: Callable[[Any], bool] | None = None, |
| 223 | + **custom_args: Any, |
222 | 224 | ) -> list[GraphQLError] | ExecutionContext: |
223 | 225 | """Build an execution context |
224 | 226 |
|
@@ -292,24 +294,14 @@ def build( |
292 | 294 | IncrementalPublisher(), |
293 | 295 | middleware_manager, |
294 | 296 | is_awaitable, |
| 297 | + **custom_args, |
295 | 298 | ) |
296 | 299 |
|
297 | 300 | def build_per_event_execution_context(self, payload: Any) -> ExecutionContext: |
298 | 301 | """Create a copy of the execution context for usage with subscribe events.""" |
299 | | - return self.__class__( |
300 | | - self.schema, |
301 | | - self.fragments, |
302 | | - payload, |
303 | | - self.context_value, |
304 | | - self.operation, |
305 | | - self.variable_values, |
306 | | - self.field_resolver, |
307 | | - self.type_resolver, |
308 | | - self.subscribe_field_resolver, |
309 | | - self.incremental_publisher, |
310 | | - self.middleware_manager, |
311 | | - self.is_awaitable, |
312 | | - ) |
| 302 | + context = copy(self) |
| 303 | + context.root_value = payload |
| 304 | + return context |
313 | 305 |
|
314 | 306 | def execute_operation( |
315 | 307 | self, initial_result_record: InitialResultRecord |
@@ -1709,6 +1701,7 @@ def execute( |
1709 | 1701 | middleware: Middleware | None = None, |
1710 | 1702 | execution_context_class: type[ExecutionContext] | None = None, |
1711 | 1703 | is_awaitable: Callable[[Any], bool] | None = None, |
| 1704 | + **custom_context_args: Any, |
1712 | 1705 | ) -> AwaitableOrValue[ExecutionResult]: |
1713 | 1706 | """Execute a GraphQL operation. |
1714 | 1707 |
|
@@ -1741,6 +1734,7 @@ def execute( |
1741 | 1734 | middleware, |
1742 | 1735 | execution_context_class, |
1743 | 1736 | is_awaitable, |
| 1737 | + **custom_context_args, |
1744 | 1738 | ) |
1745 | 1739 | if isinstance(result, ExecutionResult): |
1746 | 1740 | return result |
@@ -1769,6 +1763,7 @@ def experimental_execute_incrementally( |
1769 | 1763 | middleware: Middleware | None = None, |
1770 | 1764 | execution_context_class: type[ExecutionContext] | None = None, |
1771 | 1765 | is_awaitable: Callable[[Any], bool] | None = None, |
| 1766 | + **custom_context_args: Any, |
1772 | 1767 | ) -> AwaitableOrValue[ExecutionResult | ExperimentalIncrementalExecutionResults]: |
1773 | 1768 | """Execute GraphQL operation incrementally (internal implementation). |
1774 | 1769 |
|
@@ -1797,6 +1792,7 @@ def experimental_execute_incrementally( |
1797 | 1792 | subscribe_field_resolver, |
1798 | 1793 | middleware, |
1799 | 1794 | is_awaitable, |
| 1795 | + **custom_context_args, |
1800 | 1796 | ) |
1801 | 1797 |
|
1802 | 1798 | # Return early errors if execution context failed. |
@@ -2127,6 +2123,7 @@ def subscribe( |
2127 | 2123 | subscribe_field_resolver: GraphQLFieldResolver | None = None, |
2128 | 2124 | execution_context_class: type[ExecutionContext] | None = None, |
2129 | 2125 | middleware: MiddlewareManager | None = None, |
| 2126 | + **custom_context_args: Any, |
2130 | 2127 | ) -> AwaitableOrValue[AsyncIterator[ExecutionResult] | ExecutionResult]: |
2131 | 2128 | """Create a GraphQL subscription. |
2132 | 2129 |
|
@@ -2167,6 +2164,7 @@ def subscribe( |
2167 | 2164 | type_resolver, |
2168 | 2165 | subscribe_field_resolver, |
2169 | 2166 | middleware=middleware, |
| 2167 | + **custom_context_args, |
2170 | 2168 | ) |
2171 | 2169 |
|
2172 | 2170 | # Return early errors if execution context failed. |
@@ -2202,6 +2200,7 @@ def create_source_event_stream( |
2202 | 2200 | type_resolver: GraphQLTypeResolver | None = None, |
2203 | 2201 | subscribe_field_resolver: GraphQLFieldResolver | None = None, |
2204 | 2202 | execution_context_class: type[ExecutionContext] | None = None, |
| 2203 | + **custom_context_args: Any, |
2205 | 2204 | ) -> AwaitableOrValue[AsyncIterable[Any] | ExecutionResult]: |
2206 | 2205 | """Create source event stream |
2207 | 2206 |
|
@@ -2238,6 +2237,7 @@ def create_source_event_stream( |
2238 | 2237 | field_resolver, |
2239 | 2238 | type_resolver, |
2240 | 2239 | subscribe_field_resolver, |
| 2240 | + **custom_context_args, |
2241 | 2241 | ) |
2242 | 2242 |
|
2243 | 2243 | # Return early errors if execution context failed. |
|
0 commit comments