77import graphql .schema .GraphQLFieldDefinition ;
88import graphql .servlet .internal .GraphQLRequest ;
99import graphql .servlet .internal .VariableMapper ;
10+ import org .reactivestreams .Publisher ;
11+ import org .reactivestreams .Subscriber ;
12+ import org .reactivestreams .Subscription ;
1013import org .slf4j .Logger ;
1114import org .slf4j .LoggerFactory ;
1215
13- import javax .servlet .AsyncContext ;
14- import javax .servlet .Servlet ;
15- import javax .servlet .ServletConfig ;
16- import javax .servlet .ServletException ;
16+ import javax .servlet .*;
1717import javax .servlet .http .HttpServlet ;
1818import javax .servlet .http .HttpServletRequest ;
1919import javax .servlet .http .HttpServletResponse ;
3030import java .util .Map ;
3131import java .util .Objects ;
3232import java .util .Optional ;
33+ import java .util .concurrent .CountDownLatch ;
34+ import java .util .concurrent .atomic .AtomicReference ;
3335import java .util .function .BiConsumer ;
3436import java .util .function .Consumer ;
3537import java .util .function .Function ;
@@ -43,6 +45,7 @@ public abstract class AbstractGraphQLHttpServlet extends HttpServlet implements
4345 public static final Logger log = LoggerFactory .getLogger (AbstractGraphQLHttpServlet .class );
4446
4547 public static final String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8" ;
48+ public static final String APPLICATION_EVENT_STREAM_UTF8 = "text/event-stream;charset=UTF-8" ;
4649 public static final String APPLICATION_GRAPHQL = "application/graphql" ;
4750 public static final int STATUS_OK = 200 ;
4851 public static final int STATUS_BAD_REQUEST = 400 ;
@@ -289,7 +292,7 @@ public String executeQuery(String query) {
289292
290293 private void doRequestAsync (HttpServletRequest request , HttpServletResponse response , HttpRequestHandler handler ) {
291294 if (configuration .isAsyncServletModeEnabled ()) {
292- AsyncContext asyncContext = request .startAsync ();
295+ AsyncContext asyncContext = request .startAsync (request , response );
293296 HttpServletRequest asyncRequest = (HttpServletRequest ) asyncContext .getRequest ();
294297 HttpServletResponse asyncResponse = (HttpServletResponse ) asyncContext .getResponse ();
295298 new Thread (() -> doRequest (asyncRequest , asyncResponse , handler , asyncContext )).start ();
@@ -334,9 +337,31 @@ private Optional<Part> getFileItem(Map<String, List<Part>> fileItems, String nam
334337 private void query (GraphQLQueryInvoker queryInvoker , GraphQLObjectMapper graphQLObjectMapper , GraphQLSingleInvocationInput invocationInput , HttpServletResponse resp ) throws IOException {
335338 ExecutionResult result = queryInvoker .query (invocationInput );
336339
337- resp .setContentType (APPLICATION_JSON_UTF8 );
338- resp .setStatus (STATUS_OK );
339- resp .getWriter ().write (graphQLObjectMapper .serializeResultAsJson (result ));
340+ if (!(result .getData () instanceof Publisher )) {
341+ resp .setContentType (APPLICATION_JSON_UTF8 );
342+ resp .setStatus (STATUS_OK );
343+ resp .getWriter ().write (graphQLObjectMapper .serializeResultAsJson (result ));
344+ } else {
345+ resp .setContentType (APPLICATION_EVENT_STREAM_UTF8 );
346+ resp .setStatus (STATUS_OK );
347+
348+ HttpServletRequest req = invocationInput .getContext ().getHttpServletRequest ().orElseThrow (IllegalStateException ::new );
349+ boolean isInAsyncThread = req .isAsyncStarted ();
350+ AsyncContext asyncContext = isInAsyncThread ? req .getAsyncContext () : req .startAsync (req , resp );
351+ asyncContext .setTimeout (configuration .getSubscriptionTimeout ());
352+ AtomicReference <Subscription > subscriptionRef = new AtomicReference <>();
353+ asyncContext .addListener (new SubscriptionAsyncListener (subscriptionRef ));
354+ ExecutionResultSubscriber subscriber = new ExecutionResultSubscriber (subscriptionRef , asyncContext , graphQLObjectMapper );
355+ ((Publisher <ExecutionResult >) result .getData ()).subscribe (subscriber );
356+ if (isInAsyncThread ) {
357+ // We need to delay the completion of async context until after the subscription has terminated, otherwise the AsyncContext is prematurely closed.
358+ try {
359+ subscriber .await ();
360+ } catch (InterruptedException e ) {
361+ Thread .currentThread ().interrupt ();
362+ }
363+ }
364+ }
340365 }
341366
342367 private void queryBatched (GraphQLQueryInvoker queryInvoker , GraphQLObjectMapper graphQLObjectMapper , GraphQLBatchedInvocationInput invocationInput , HttpServletResponse resp ) throws Exception {
@@ -437,4 +462,74 @@ default void accept(HttpServletRequest request, HttpServletResponse response) {
437462
438463 void handle (HttpServletRequest request , HttpServletResponse response ) throws Exception ;
439464 }
465+
466+ private static class SubscriptionAsyncListener implements AsyncListener {
467+ private final AtomicReference <Subscription > subscriptionRef ;
468+ public SubscriptionAsyncListener (AtomicReference <Subscription > subscriptionRef ) {
469+ this .subscriptionRef = subscriptionRef ;
470+ }
471+
472+ @ Override public void onComplete (AsyncEvent event ) {
473+ subscriptionRef .get ().cancel ();
474+ }
475+
476+ @ Override public void onTimeout (AsyncEvent event ) {
477+ subscriptionRef .get ().cancel ();
478+ }
479+
480+ @ Override public void onError (AsyncEvent event ) {
481+ subscriptionRef .get ().cancel ();
482+ }
483+
484+ @ Override public void onStartAsync (AsyncEvent event ) {
485+ }
486+ }
487+
488+
489+ private static class ExecutionResultSubscriber implements Subscriber <ExecutionResult > {
490+
491+ private final AtomicReference <Subscription > subscriptionRef ;
492+ private final AsyncContext asyncContext ;
493+ private final GraphQLObjectMapper graphQLObjectMapper ;
494+ private final CountDownLatch completedLatch = new CountDownLatch (1 );
495+
496+ public ExecutionResultSubscriber (AtomicReference <Subscription > subscriptionRef , AsyncContext asyncContext , GraphQLObjectMapper graphQLObjectMapper ) {
497+ this .subscriptionRef = subscriptionRef ;
498+ this .asyncContext = asyncContext ;
499+ this .graphQLObjectMapper = graphQLObjectMapper ;
500+ }
501+
502+ @ Override
503+ public void onSubscribe (Subscription subscription ) {
504+ subscriptionRef .set (subscription );
505+ subscriptionRef .get ().request (1 );
506+ }
507+
508+ @ Override
509+ public void onNext (ExecutionResult executionResult ) {
510+ try {
511+ Writer writer = asyncContext .getResponse ().getWriter ();
512+ writer .write ("data: " + graphQLObjectMapper .serializeResultAsJson (executionResult ) + "\n \n " );
513+ writer .flush ();
514+ subscriptionRef .get ().request (1 );
515+ } catch (IOException ignored ) {
516+ }
517+ }
518+
519+ @ Override
520+ public void onError (Throwable t ) {
521+ asyncContext .complete ();
522+ completedLatch .countDown ();
523+ }
524+
525+ @ Override
526+ public void onComplete () {
527+ asyncContext .complete ();
528+ completedLatch .countDown ();
529+ }
530+
531+ public void await () throws InterruptedException {
532+ completedLatch .await ();
533+ }
534+ }
440535}
0 commit comments