|
30 | 30 | import java.util.Map; |
31 | 31 | import java.util.Objects; |
32 | 32 | import java.util.Optional; |
| 33 | +import java.util.concurrent.CountDownLatch; |
33 | 34 | import java.util.concurrent.atomic.AtomicReference; |
34 | 35 | import java.util.function.BiConsumer; |
35 | 36 | import java.util.function.Consumer; |
@@ -291,7 +292,7 @@ public String executeQuery(String query) { |
291 | 292 |
|
292 | 293 | private void doRequestAsync(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) { |
293 | 294 | if (configuration.isAsyncServletModeEnabled()) { |
294 | | - AsyncContext asyncContext = request.startAsync(); |
| 295 | + AsyncContext asyncContext = request.startAsync(request, response); |
295 | 296 | HttpServletRequest asyncRequest = (HttpServletRequest) asyncContext.getRequest(); |
296 | 297 | HttpServletResponse asyncResponse = (HttpServletResponse) asyncContext.getResponse(); |
297 | 298 | new Thread(() -> doRequest(asyncRequest, asyncResponse, handler, asyncContext)).start(); |
@@ -344,12 +345,22 @@ private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQL |
344 | 345 | resp.setContentType(APPLICATION_EVENT_STREAM_UTF8); |
345 | 346 | resp.setStatus(STATUS_OK); |
346 | 347 |
|
347 | | - HttpServletRequest req = invocationInput.getContext().getHttpServletRequest().get(); |
348 | | - AsyncContext asyncContext = req.startAsync(req, resp); |
349 | | - asyncContext.setTimeout(60 * 1000); |
| 348 | + HttpServletRequest req = invocationInput.getContext().getHttpServletRequest().orElseThrow(IllegalStateException::new); |
| 349 | + boolean isInAsyncThread = req.isAsyncStarted(); |
| 350 | + AsyncContext asyncContext = isInAsyncThread ? req.getAsyncContext() : req.startAsync(req, resp); |
| 351 | + asyncContext.setTimeout(configuration.getSubscriptionTimeout()); |
350 | 352 | AtomicReference<Subscription> subscriptionRef = new AtomicReference<>(); |
351 | 353 | asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef)); |
352 | | - ((Publisher<ExecutionResult>) result.getData()).subscribe(new ExecutionResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper)); |
| 354 | + ExecutionResultSubscriber subscriber = new ExecutionResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper); |
| 355 | + ((Publisher<ExecutionResult>) result.getData()).subscribe(subscriber); |
| 356 | + if (isInAsyncThread) { |
| 357 | + // We need to delay the completion of async context until after the subscription has terminated, otherwise the AsyncContext is prematurely closed. |
| 358 | + try { |
| 359 | + subscriber.await(); |
| 360 | + } catch (InterruptedException e) { |
| 361 | + Thread.currentThread().interrupt(); |
| 362 | + } |
| 363 | + } |
353 | 364 | } |
354 | 365 | } |
355 | 366 |
|
@@ -480,6 +491,7 @@ private static class ExecutionResultSubscriber implements Subscriber<ExecutionRe |
480 | 491 | private final AtomicReference<Subscription> subscriptionRef; |
481 | 492 | private final AsyncContext asyncContext; |
482 | 493 | private final GraphQLObjectMapper graphQLObjectMapper; |
| 494 | + private final CountDownLatch completedLatch = new CountDownLatch(1); |
483 | 495 |
|
484 | 496 | public ExecutionResultSubscriber(AtomicReference<Subscription> subscriptionRef, AsyncContext asyncContext, GraphQLObjectMapper graphQLObjectMapper) { |
485 | 497 | this.subscriptionRef = subscriptionRef; |
@@ -507,11 +519,17 @@ public void onNext(ExecutionResult executionResult) { |
507 | 519 | @Override |
508 | 520 | public void onError(Throwable t) { |
509 | 521 | asyncContext.complete(); |
| 522 | + completedLatch.countDown(); |
510 | 523 | } |
511 | 524 |
|
512 | 525 | @Override |
513 | 526 | public void onComplete() { |
514 | 527 | asyncContext.complete(); |
| 528 | + completedLatch.countDown(); |
| 529 | + } |
| 530 | + |
| 531 | + public void await() throws InterruptedException { |
| 532 | + completedLatch.await(); |
515 | 533 | } |
516 | 534 | } |
517 | 535 | } |
0 commit comments