Skip to content

Commit 673e9eb

Browse files
authored
Merge pull request #145 from wleroux/master
Ported @tdraier to latest version of graphql-java-servlet and added supporting tests
2 parents 6aa0828 + 27b3116 commit 673e9eb

9 files changed

+326
-17
lines changed

src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import graphql.schema.GraphQLFieldDefinition;
88
import graphql.servlet.internal.GraphQLRequest;
99
import graphql.servlet.internal.VariableMapper;
10+
import org.reactivestreams.Publisher;
11+
import org.reactivestreams.Subscriber;
12+
import org.reactivestreams.Subscription;
1013
import org.slf4j.Logger;
1114
import org.slf4j.LoggerFactory;
1215

13-
import javax.servlet.AsyncContext;
14-
import javax.servlet.Servlet;
15-
import javax.servlet.ServletConfig;
16-
import javax.servlet.ServletException;
16+
import javax.servlet.*;
1717
import javax.servlet.http.HttpServlet;
1818
import javax.servlet.http.HttpServletRequest;
1919
import javax.servlet.http.HttpServletResponse;
@@ -30,6 +30,8 @@
3030
import java.util.Map;
3131
import java.util.Objects;
3232
import java.util.Optional;
33+
import java.util.concurrent.CountDownLatch;
34+
import java.util.concurrent.atomic.AtomicReference;
3335
import java.util.function.BiConsumer;
3436
import java.util.function.Consumer;
3537
import java.util.function.Function;
@@ -43,6 +45,7 @@ public abstract class AbstractGraphQLHttpServlet extends HttpServlet implements
4345
public static final Logger log = LoggerFactory.getLogger(AbstractGraphQLHttpServlet.class);
4446

4547
public static final String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8";
48+
public static final String APPLICATION_EVENT_STREAM_UTF8 = "text/event-stream;charset=UTF-8";
4649
public static final String APPLICATION_GRAPHQL = "application/graphql";
4750
public static final int STATUS_OK = 200;
4851
public static final int STATUS_BAD_REQUEST = 400;
@@ -289,7 +292,7 @@ public String executeQuery(String query) {
289292

290293
private void doRequestAsync(HttpServletRequest request, HttpServletResponse response, HttpRequestHandler handler) {
291294
if (configuration.isAsyncServletModeEnabled()) {
292-
AsyncContext asyncContext = request.startAsync();
295+
AsyncContext asyncContext = request.startAsync(request, response);
293296
HttpServletRequest asyncRequest = (HttpServletRequest) asyncContext.getRequest();
294297
HttpServletResponse asyncResponse = (HttpServletResponse) asyncContext.getResponse();
295298
new Thread(() -> doRequest(asyncRequest, asyncResponse, handler, asyncContext)).start();
@@ -334,9 +337,31 @@ private Optional<Part> getFileItem(Map<String, List<Part>> fileItems, String nam
334337
private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLSingleInvocationInput invocationInput, HttpServletResponse resp) throws IOException {
335338
ExecutionResult result = queryInvoker.query(invocationInput);
336339

337-
resp.setContentType(APPLICATION_JSON_UTF8);
338-
resp.setStatus(STATUS_OK);
339-
resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result));
340+
if (!(result.getData() instanceof Publisher)) {
341+
resp.setContentType(APPLICATION_JSON_UTF8);
342+
resp.setStatus(STATUS_OK);
343+
resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result));
344+
} else {
345+
resp.setContentType(APPLICATION_EVENT_STREAM_UTF8);
346+
resp.setStatus(STATUS_OK);
347+
348+
HttpServletRequest req = invocationInput.getContext().getHttpServletRequest().orElseThrow(IllegalStateException::new);
349+
boolean isInAsyncThread = req.isAsyncStarted();
350+
AsyncContext asyncContext = isInAsyncThread ? req.getAsyncContext() : req.startAsync(req, resp);
351+
asyncContext.setTimeout(configuration.getSubscriptionTimeout());
352+
AtomicReference<Subscription> subscriptionRef = new AtomicReference<>();
353+
asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef));
354+
ExecutionResultSubscriber subscriber = new ExecutionResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper);
355+
((Publisher<ExecutionResult>) result.getData()).subscribe(subscriber);
356+
if (isInAsyncThread) {
357+
// We need to delay the completion of async context until after the subscription has terminated, otherwise the AsyncContext is prematurely closed.
358+
try {
359+
subscriber.await();
360+
} catch (InterruptedException e) {
361+
Thread.currentThread().interrupt();
362+
}
363+
}
364+
}
340365
}
341366

342367
private void queryBatched(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLBatchedInvocationInput invocationInput, HttpServletResponse resp) throws Exception {
@@ -437,4 +462,74 @@ default void accept(HttpServletRequest request, HttpServletResponse response) {
437462

438463
void handle(HttpServletRequest request, HttpServletResponse response) throws Exception;
439464
}
465+
466+
private static class SubscriptionAsyncListener implements AsyncListener {
467+
private final AtomicReference<Subscription> subscriptionRef;
468+
public SubscriptionAsyncListener(AtomicReference<Subscription> subscriptionRef) {
469+
this.subscriptionRef = subscriptionRef;
470+
}
471+
472+
@Override public void onComplete(AsyncEvent event) {
473+
subscriptionRef.get().cancel();
474+
}
475+
476+
@Override public void onTimeout(AsyncEvent event) {
477+
subscriptionRef.get().cancel();
478+
}
479+
480+
@Override public void onError(AsyncEvent event) {
481+
subscriptionRef.get().cancel();
482+
}
483+
484+
@Override public void onStartAsync(AsyncEvent event) {
485+
}
486+
}
487+
488+
489+
private static class ExecutionResultSubscriber implements Subscriber<ExecutionResult> {
490+
491+
private final AtomicReference<Subscription> subscriptionRef;
492+
private final AsyncContext asyncContext;
493+
private final GraphQLObjectMapper graphQLObjectMapper;
494+
private final CountDownLatch completedLatch = new CountDownLatch(1);
495+
496+
public ExecutionResultSubscriber(AtomicReference<Subscription> subscriptionRef, AsyncContext asyncContext, GraphQLObjectMapper graphQLObjectMapper) {
497+
this.subscriptionRef = subscriptionRef;
498+
this.asyncContext = asyncContext;
499+
this.graphQLObjectMapper = graphQLObjectMapper;
500+
}
501+
502+
@Override
503+
public void onSubscribe(Subscription subscription) {
504+
subscriptionRef.set(subscription);
505+
subscriptionRef.get().request(1);
506+
}
507+
508+
@Override
509+
public void onNext(ExecutionResult executionResult) {
510+
try {
511+
Writer writer = asyncContext.getResponse().getWriter();
512+
writer.write("data: " + graphQLObjectMapper.serializeResultAsJson(executionResult) + "\n\n");
513+
writer.flush();
514+
subscriptionRef.get().request(1);
515+
} catch (IOException ignored) {
516+
}
517+
}
518+
519+
@Override
520+
public void onError(Throwable t) {
521+
asyncContext.complete();
522+
completedLatch.countDown();
523+
}
524+
525+
@Override
526+
public void onComplete() {
527+
asyncContext.complete();
528+
completedLatch.countDown();
529+
}
530+
531+
public void await() throws InterruptedException {
532+
completedLatch.await();
533+
}
534+
}
440535
}

src/main/java/graphql/servlet/GraphQLConfiguration.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ public class GraphQLConfiguration {
1313
private GraphQLObjectMapper objectMapper;
1414
private List<GraphQLServletListener> listeners;
1515
private boolean asyncServletModeEnabled;
16+
private long subscriptionTimeout;
1617

1718
public static GraphQLConfiguration.Builder with(GraphQLSchema schema) {
1819
return with(new DefaultGraphQLSchemaProvider(schema));
@@ -26,12 +27,13 @@ public static GraphQLConfiguration.Builder with(GraphQLInvocationInputFactory in
2627
return new Builder(invocationInputFactory);
2728
}
2829

29-
private GraphQLConfiguration(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper objectMapper, List<GraphQLServletListener> listeners, boolean asyncServletModeEnabled) {
30+
private GraphQLConfiguration(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper objectMapper, List<GraphQLServletListener> listeners, boolean asyncServletModeEnabled, long subscriptionTimeout) {
3031
this.invocationInputFactory = invocationInputFactory;
3132
this.queryInvoker = queryInvoker;
3233
this.objectMapper = objectMapper;
3334
this.listeners = listeners;
3435
this.asyncServletModeEnabled = asyncServletModeEnabled;
36+
this.subscriptionTimeout = subscriptionTimeout;
3537
}
3638

3739
public GraphQLInvocationInputFactory getInvocationInputFactory() {
@@ -62,6 +64,10 @@ public boolean remove(GraphQLServletListener listener) {
6264
return listeners.remove(listener);
6365
}
6466

67+
public long getSubscriptionTimeout() {
68+
return subscriptionTimeout;
69+
}
70+
6571
public static class Builder {
6672

6773
private GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder;
@@ -70,6 +76,7 @@ public static class Builder {
7076
private GraphQLObjectMapper objectMapper = GraphQLObjectMapper.newBuilder().build();
7177
private List<GraphQLServletListener> listeners = new ArrayList<>();
7278
private boolean asyncServletModeEnabled = false;
79+
private long subscriptionTimeout = 0;
7380

7481
private Builder(GraphQLInvocationInputFactory.Builder invocationInputFactoryBuilder) {
7582
this.invocationInputFactoryBuilder = invocationInputFactoryBuilder;
@@ -115,13 +122,19 @@ public Builder with(GraphQLRootObjectBuilder rootObjectBuilder) {
115122
return this;
116123
}
117124

125+
public Builder with(long subscriptionTimeout) {
126+
this.subscriptionTimeout = subscriptionTimeout;
127+
return this;
128+
}
129+
118130
public GraphQLConfiguration build() {
119131
return new GraphQLConfiguration(
120132
this.invocationInputFactory != null ? this.invocationInputFactory : invocationInputFactoryBuilder.build(),
121133
queryInvoker,
122134
objectMapper,
123135
listeners,
124-
asyncServletModeEnabled
136+
asyncServletModeEnabled,
137+
subscriptionTimeout
125138
);
126139
}
127140

src/main/java/graphql/servlet/GraphQLSchemaProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ static GraphQLSchema copyReadOnly(GraphQLSchema schema) {
3434

3535
/**
3636
* @param request the http request
37-
* @return a read-only schema based on the request (auth, etc). Should return the same schema (query-only version) as {@link #getSchema(HttpServletRequest)} for a given request.
37+
* @return a read-only schema based on the request (auth, etc). Should return the same schema (query/subscription-only version) as {@link #getSchema(HttpServletRequest)} for a given request.
3838
*/
3939
GraphQLSchema getReadOnlySchema(HttpServletRequest request);
4040

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package graphql.servlet;
2+
3+
import graphql.schema.GraphQLFieldDefinition;
4+
5+
import java.util.Collection;
6+
7+
public interface GraphQLSubscriptionProvider extends GraphQLProvider {
8+
Collection<GraphQLFieldDefinition> getSubscriptions();
9+
}

src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public class OsgiGraphQLHttpServlet extends AbstractGraphQLHttpServlet {
3333

3434
private final List<GraphQLQueryProvider> queryProviders = new ArrayList<>();
3535
private final List<GraphQLMutationProvider> mutationProviders = new ArrayList<>();
36+
private final List<GraphQLSubscriptionProvider> subscriptionProviders = new ArrayList<>();
3637
private final List<GraphQLTypesProvider> typesProviders = new ArrayList<>();
3738

3839
private final GraphQLQueryInvoker queryInvoker;
@@ -151,8 +152,23 @@ private void doUpdateSchema() {
151152
}
152153
}
153154

155+
GraphQLObjectType subscriptionType = null;
156+
157+
if (!subscriptionProviders.isEmpty()) {
158+
final GraphQLObjectType.Builder subscriptionTypeBuilder = newObject().name("Subscription").description("Root subscription type");
159+
160+
for (GraphQLSubscriptionProvider provider : subscriptionProviders) {
161+
provider.getSubscriptions().forEach(subscriptionTypeBuilder::field);
162+
}
163+
164+
if (!subscriptionTypeBuilder.build().getFieldDefinitions().isEmpty()) {
165+
subscriptionType = subscriptionTypeBuilder.build();
166+
}
167+
}
168+
154169
this.schemaProvider = new DefaultGraphQLSchemaProvider(newSchema().query(queryTypeBuilder.build())
155170
.mutation(mutationType)
171+
.subscription(subscriptionType)
156172
.additionalTypes(types)
157173
.build());
158174
}
@@ -165,6 +181,9 @@ public void bindProvider(GraphQLProvider provider) {
165181
if (provider instanceof GraphQLMutationProvider) {
166182
mutationProviders.add((GraphQLMutationProvider) provider);
167183
}
184+
if (provider instanceof GraphQLSubscriptionProvider) {
185+
subscriptionProviders.add((GraphQLSubscriptionProvider) provider);
186+
}
168187
if (provider instanceof GraphQLTypesProvider) {
169188
typesProviders.add((GraphQLTypesProvider) provider);
170189
}
@@ -177,6 +196,9 @@ public void unbindProvider(GraphQLProvider provider) {
177196
if (provider instanceof GraphQLMutationProvider) {
178197
mutationProviders.remove(provider);
179198
}
199+
if (provider instanceof GraphQLSubscriptionProvider) {
200+
subscriptionProviders.remove(provider);
201+
}
180202
if (provider instanceof GraphQLTypesProvider) {
181203
typesProviders.remove(provider);
182204
}
@@ -203,6 +225,16 @@ public void unbindMutationProvider(GraphQLMutationProvider mutationProvider) {
203225
updateSchema();
204226
}
205227

228+
@Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC)
229+
public void bindSubscriptionProvider(GraphQLSubscriptionProvider subscriptionProvider) {
230+
subscriptionProviders.add(subscriptionProvider);
231+
updateSchema();
232+
}
233+
public void unbindSubscriptionProvider(GraphQLSubscriptionProvider subscriptionProvider) {
234+
subscriptionProviders.remove(subscriptionProvider);
235+
updateSchema();
236+
}
237+
206238
@Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC)
207239
public void bindTypesProvider(GraphQLTypesProvider typesProvider) {
208240
typesProviders.add(typesProvider);

src/main/java/graphql/servlet/SimpleGraphQLHttpServlet.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,21 @@ public SimpleGraphQLHttpServlet(GraphQLInvocationInputFactory invocationInputFac
3030
.build();
3131
}
3232

33+
/**
34+
* @deprecated use {@link GraphQLHttpServlet} instead
35+
*/
36+
@Deprecated
37+
public SimpleGraphQLHttpServlet(GraphQLInvocationInputFactory invocationInputFactory, GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, List<GraphQLServletListener> listeners, boolean asyncServletMode, long subscriptionTimeout) {
38+
super(listeners);
39+
this.configuration = GraphQLConfiguration.with(invocationInputFactory)
40+
.with(queryInvoker)
41+
.with(graphQLObjectMapper)
42+
.with(listeners != null ? listeners : new ArrayList<>())
43+
.with(asyncServletMode)
44+
.with(subscriptionTimeout)
45+
.build();
46+
}
47+
3348
private SimpleGraphQLHttpServlet(GraphQLConfiguration configuration) {
3449
this.configuration = Objects.requireNonNull(configuration, "configuration is required");
3550
}
@@ -77,6 +92,7 @@ public static class Builder {
7792
private GraphQLObjectMapper graphQLObjectMapper = GraphQLObjectMapper.newBuilder().build();
7893
private List<GraphQLServletListener> listeners;
7994
private boolean asyncServletMode;
95+
private long subscriptionTimeout;
8096

8197
Builder(GraphQLInvocationInputFactory invocationInputFactory) {
8298
this.invocationInputFactory = invocationInputFactory;
@@ -102,13 +118,19 @@ public Builder withListeners(List<GraphQLServletListener> listeners) {
102118
return this;
103119
}
104120

121+
public Builder withSubscriptionTimeout(long subscriptionTimeout) {
122+
this.subscriptionTimeout = subscriptionTimeout;
123+
return this;
124+
}
125+
105126
@Deprecated
106127
public SimpleGraphQLHttpServlet build() {
107128
GraphQLConfiguration configuration = GraphQLConfiguration.with(invocationInputFactory)
108129
.with(queryInvoker)
109130
.with(graphQLObjectMapper)
110131
.with(listeners != null ? listeners : new ArrayList<>())
111132
.with(asyncServletMode)
133+
.with(subscriptionTimeout)
112134
.build();
113135
return new SimpleGraphQLHttpServlet(configuration);
114136
}

0 commit comments

Comments
 (0)