Skip to content

Commit e168b74

Browse files
committed
Ported @tdraier to latest version of graphql-java-servlet and added supporting tests
1 parent 6aa0828 commit e168b74

File tree

7 files changed

+270
-14
lines changed

7 files changed

+270
-14
lines changed

src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import graphql.schema.GraphQLFieldDefinition;
88
import graphql.servlet.internal.GraphQLRequest;
99
import graphql.servlet.internal.VariableMapper;
10+
import org.reactivestreams.Publisher;
11+
import org.reactivestreams.Subscriber;
12+
import org.reactivestreams.Subscription;
1013
import org.slf4j.Logger;
1114
import org.slf4j.LoggerFactory;
1215

13-
import javax.servlet.AsyncContext;
14-
import javax.servlet.Servlet;
15-
import javax.servlet.ServletConfig;
16-
import javax.servlet.ServletException;
16+
import javax.servlet.*;
1717
import javax.servlet.http.HttpServlet;
1818
import javax.servlet.http.HttpServletRequest;
1919
import javax.servlet.http.HttpServletResponse;
@@ -30,6 +30,7 @@
3030
import java.util.Map;
3131
import java.util.Objects;
3232
import java.util.Optional;
33+
import java.util.concurrent.atomic.AtomicReference;
3334
import java.util.function.BiConsumer;
3435
import java.util.function.Consumer;
3536
import java.util.function.Function;
@@ -43,6 +44,7 @@ public abstract class AbstractGraphQLHttpServlet extends HttpServlet implements
4344
public static final Logger log = LoggerFactory.getLogger(AbstractGraphQLHttpServlet.class);
4445

4546
public static final String APPLICATION_JSON_UTF8 = "application/json;charset=UTF-8";
47+
public static final String APPLICATION_EVENT_STREAM_UTF8 = "text/event-stream;charset=UTF-8";
4648
public static final String APPLICATION_GRAPHQL = "application/graphql";
4749
public static final int STATUS_OK = 200;
4850
public static final int STATUS_BAD_REQUEST = 400;
@@ -334,9 +336,21 @@ private Optional<Part> getFileItem(Map<String, List<Part>> fileItems, String nam
334336
private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLSingleInvocationInput invocationInput, HttpServletResponse resp) throws IOException {
335337
ExecutionResult result = queryInvoker.query(invocationInput);
336338

337-
resp.setContentType(APPLICATION_JSON_UTF8);
338-
resp.setStatus(STATUS_OK);
339-
resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result));
339+
if (!(result.getData() instanceof Publisher)) {
340+
resp.setContentType(APPLICATION_JSON_UTF8);
341+
resp.setStatus(STATUS_OK);
342+
resp.getWriter().write(graphQLObjectMapper.serializeResultAsJson(result));
343+
} else {
344+
resp.setContentType(APPLICATION_EVENT_STREAM_UTF8);
345+
resp.setStatus(STATUS_OK);
346+
347+
HttpServletRequest req = invocationInput.getContext().getHttpServletRequest().get();
348+
AsyncContext asyncContext = req.startAsync(req, resp);
349+
asyncContext.setTimeout(60 * 1000);
350+
AtomicReference<Subscription> subscriptionRef = new AtomicReference<>();
351+
asyncContext.addListener(new SubscriptionAsyncListener(subscriptionRef));
352+
((Publisher<ExecutionResult>) result.getData()).subscribe(new ExecutionResultSubscriber(subscriptionRef, asyncContext, graphQLObjectMapper));
353+
}
340354
}
341355

342356
private void queryBatched(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLBatchedInvocationInput invocationInput, HttpServletResponse resp) throws Exception {
@@ -437,4 +451,67 @@ default void accept(HttpServletRequest request, HttpServletResponse response) {
437451

438452
void handle(HttpServletRequest request, HttpServletResponse response) throws Exception;
439453
}
454+
455+
private static class SubscriptionAsyncListener implements AsyncListener {
456+
private final AtomicReference<Subscription> subscriptionRef;
457+
public SubscriptionAsyncListener(AtomicReference<Subscription> subscriptionRef) {
458+
this.subscriptionRef = subscriptionRef;
459+
}
460+
461+
@Override public void onComplete(AsyncEvent event) {
462+
subscriptionRef.get().cancel();
463+
}
464+
465+
@Override public void onTimeout(AsyncEvent event) {
466+
subscriptionRef.get().cancel();
467+
}
468+
469+
@Override public void onError(AsyncEvent event) {
470+
subscriptionRef.get().cancel();
471+
}
472+
473+
@Override public void onStartAsync(AsyncEvent event) {
474+
}
475+
}
476+
477+
478+
private static class ExecutionResultSubscriber implements Subscriber<ExecutionResult> {
479+
480+
private final AtomicReference<Subscription> subscriptionRef;
481+
private final AsyncContext asyncContext;
482+
private final GraphQLObjectMapper graphQLObjectMapper;
483+
484+
public ExecutionResultSubscriber(AtomicReference<Subscription> subscriptionRef, AsyncContext asyncContext, GraphQLObjectMapper graphQLObjectMapper) {
485+
this.subscriptionRef = subscriptionRef;
486+
this.asyncContext = asyncContext;
487+
this.graphQLObjectMapper = graphQLObjectMapper;
488+
}
489+
490+
@Override
491+
public void onSubscribe(Subscription subscription) {
492+
subscriptionRef.set(subscription);
493+
subscriptionRef.get().request(1);
494+
}
495+
496+
@Override
497+
public void onNext(ExecutionResult executionResult) {
498+
try {
499+
Writer writer = asyncContext.getResponse().getWriter();
500+
writer.write("data: " + graphQLObjectMapper.serializeResultAsJson(executionResult) + "\n\n");
501+
writer.flush();
502+
subscriptionRef.get().request(1);
503+
} catch (IOException ignored) {
504+
}
505+
}
506+
507+
@Override
508+
public void onError(Throwable t) {
509+
asyncContext.complete();
510+
}
511+
512+
@Override
513+
public void onComplete() {
514+
asyncContext.complete();
515+
}
516+
}
440517
}

src/main/java/graphql/servlet/GraphQLSchemaProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ static GraphQLSchema copyReadOnly(GraphQLSchema schema) {
3434

3535
/**
3636
* @param request the http request
37-
* @return a read-only schema based on the request (auth, etc). Should return the same schema (query-only version) as {@link #getSchema(HttpServletRequest)} for a given request.
37+
* @return a read-only schema based on the request (auth, etc). Should return the same schema (query/subscription-only version) as {@link #getSchema(HttpServletRequest)} for a given request.
3838
*/
3939
GraphQLSchema getReadOnlySchema(HttpServletRequest request);
4040

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package graphql.servlet;
2+
3+
import graphql.schema.GraphQLFieldDefinition;
4+
5+
import java.util.Collection;
6+
7+
public interface GraphQLSubscriptionProvider extends GraphQLProvider {
8+
Collection<GraphQLFieldDefinition> getSubscriptions();
9+
}

src/main/java/graphql/servlet/OsgiGraphQLHttpServlet.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public class OsgiGraphQLHttpServlet extends AbstractGraphQLHttpServlet {
3333

3434
private final List<GraphQLQueryProvider> queryProviders = new ArrayList<>();
3535
private final List<GraphQLMutationProvider> mutationProviders = new ArrayList<>();
36+
private final List<GraphQLSubscriptionProvider> subscriptionProviders = new ArrayList<>();
3637
private final List<GraphQLTypesProvider> typesProviders = new ArrayList<>();
3738

3839
private final GraphQLQueryInvoker queryInvoker;
@@ -151,8 +152,23 @@ private void doUpdateSchema() {
151152
}
152153
}
153154

155+
GraphQLObjectType subscriptionType = null;
156+
157+
if (!subscriptionProviders.isEmpty()) {
158+
final GraphQLObjectType.Builder subscriptionTypeBuilder = newObject().name("Subscription").description("Root subscription type");
159+
160+
for (GraphQLSubscriptionProvider provider : subscriptionProviders) {
161+
provider.getSubscriptions().forEach(subscriptionTypeBuilder::field);
162+
}
163+
164+
if (!subscriptionTypeBuilder.build().getFieldDefinitions().isEmpty()) {
165+
subscriptionType = subscriptionTypeBuilder.build();
166+
}
167+
}
168+
154169
this.schemaProvider = new DefaultGraphQLSchemaProvider(newSchema().query(queryTypeBuilder.build())
155170
.mutation(mutationType)
171+
.subscription(subscriptionType)
156172
.additionalTypes(types)
157173
.build());
158174
}
@@ -165,6 +181,9 @@ public void bindProvider(GraphQLProvider provider) {
165181
if (provider instanceof GraphQLMutationProvider) {
166182
mutationProviders.add((GraphQLMutationProvider) provider);
167183
}
184+
if (provider instanceof GraphQLSubscriptionProvider) {
185+
subscriptionProviders.add((GraphQLSubscriptionProvider) provider);
186+
}
168187
if (provider instanceof GraphQLTypesProvider) {
169188
typesProviders.add((GraphQLTypesProvider) provider);
170189
}
@@ -177,6 +196,9 @@ public void unbindProvider(GraphQLProvider provider) {
177196
if (provider instanceof GraphQLMutationProvider) {
178197
mutationProviders.remove(provider);
179198
}
199+
if (provider instanceof GraphQLSubscriptionProvider) {
200+
subscriptionProviders.remove(provider);
201+
}
180202
if (provider instanceof GraphQLTypesProvider) {
181203
typesProviders.remove(provider);
182204
}
@@ -203,6 +225,16 @@ public void unbindMutationProvider(GraphQLMutationProvider mutationProvider) {
203225
updateSchema();
204226
}
205227

228+
@Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC)
229+
public void bindSubscriptionProvider(GraphQLSubscriptionProvider subscriptionProvider) {
230+
subscriptionProviders.add(subscriptionProvider);
231+
updateSchema();
232+
}
233+
public void unbindSubscriptionProvider(GraphQLSubscriptionProvider subscriptionProvider) {
234+
subscriptionProviders.remove(subscriptionProvider);
235+
updateSchema();
236+
}
237+
206238
@Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC)
207239
public void bindTypesProvider(GraphQLTypesProvider typesProvider) {
208240
typesProviders.add(typesProvider);

src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ package graphql.servlet
22

33
import com.fasterxml.jackson.databind.ObjectMapper
44
import graphql.Scalars
5-
import graphql.annotations.annotationTypes.GraphQLType
65
import graphql.execution.ExecutionStepInfo
76
import graphql.execution.instrumentation.ChainedInstrumentation
8-
97
import graphql.execution.instrumentation.Instrumentation
8+
import graphql.execution.reactive.SingleSubscriberPublisher
109
import graphql.schema.GraphQLNonNull
1110
import org.dataloader.DataLoaderRegistry
1211
import org.springframework.mock.web.MockHttpServletRequest
@@ -17,6 +16,9 @@ import spock.lang.Specification
1716

1817
import javax.servlet.ServletInputStream
1918
import javax.servlet.http.HttpServletRequest
19+
import java.util.concurrent.CountDownLatch
20+
import java.util.concurrent.TimeUnit
21+
import java.util.concurrent.atomic.AtomicReference
2022

2123
/**
2224
* @author Andrew Potter
@@ -27,17 +29,32 @@ class AbstractGraphQLHttpServletSpec extends Specification {
2729
public static final int STATUS_BAD_REQUEST = 400
2830
public static final int STATUS_ERROR = 500
2931
public static final String CONTENT_TYPE_JSON_UTF8 = 'application/json;charset=UTF-8'
32+
public static final String CONTENT_TYPE_SERVER_SENT_EVENTS = 'text/event-stream;charset=UTF-8'
3033

3134
@Shared
3235
ObjectMapper mapper = new ObjectMapper()
3336

3437
AbstractGraphQLHttpServlet servlet
3538
MockHttpServletRequest request
3639
MockHttpServletResponse response
40+
CountDownLatch subscriptionLatch
3741

3842
def setup() {
39-
servlet = TestUtils.createServlet()
43+
subscriptionLatch = new CountDownLatch(1)
44+
servlet = TestUtils.createServlet({ env -> env.arguments.arg }, { env -> env.arguments.arg }, { env ->
45+
AtomicReference<SingleSubscriberPublisher<String>> publisherRef = new AtomicReference<>()
46+
publisherRef.set(new SingleSubscriberPublisher<String>({
47+
SingleSubscriberPublisher<String> publisher = publisherRef.get()
48+
publisher.offer("First\n\n" + env.arguments.arg)
49+
publisher.offer("Second\n\n" + env.arguments.arg)
50+
publisher.noMoreData()
51+
subscriptionLatch.countDown()
52+
}))
53+
return publisherRef.get()
54+
})
55+
4056
request = new MockHttpServletRequest()
57+
request.asyncSupported = true
4158
response = new MockHttpServletResponse()
4259
}
4360

@@ -46,6 +63,17 @@ class AbstractGraphQLHttpServletSpec extends Specification {
4663
mapper.readValue(response.getContentAsByteArray(), Map)
4764
}
4865

66+
List<Map<String, Object>> getSubscriptionResponseContent() {
67+
String[] data = response.getContentAsString().split("\n\n")
68+
return data.collect { dataLine ->
69+
if (dataLine.startsWith("data: ")) {
70+
return mapper.readValue(dataLine.substring(5), Map)
71+
} else {
72+
throw new IllegalStateException("Could not read event stream")
73+
}
74+
}
75+
}
76+
4977
List<Map<String, Object>> getBatchedResponseContent() {
5078
mapper.readValue(response.getContentAsByteArray(), List)
5179
}
@@ -263,6 +291,26 @@ class AbstractGraphQLHttpServletSpec extends Specification {
263291
getBatchedResponseContent()[1].errors.size() == 1
264292
}
265293

294+
def "subscription query over HTTP GET with variables as string returns data"() {
295+
setup:
296+
request.addParameter('query', 'subscription Subscription($arg: String!) { echo(arg: $arg) }')
297+
request.addParameter('operationName', 'Subscription')
298+
request.addParameter( 'variables', '{"arg": "test"}')
299+
request.setAsyncSupported(true)
300+
301+
when:
302+
servlet.doGet(request, response)
303+
then:
304+
response.getStatus() == STATUS_OK
305+
response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS
306+
307+
when:
308+
subscriptionLatch.await(1, TimeUnit.SECONDS)
309+
then:
310+
getSubscriptionResponseContent()[0].data.echo == "First\n\ntest"
311+
getSubscriptionResponseContent()[1].data.echo == "Second\n\ntest"
312+
}
313+
266314
def "query over HTTP POST without part or body returns bad request"() {
267315
when:
268316
servlet.doPost(request, response)
@@ -903,6 +951,24 @@ class AbstractGraphQLHttpServletSpec extends Specification {
903951
getBatchedResponseContent()[1].data.echo == "test"
904952
}
905953

954+
def "subscription query over HTTP POST with variables as string returns data"() {
955+
setup:
956+
request.setContent('{"query": "subscription Subscription($arg: String!) { echo(arg: $arg) }", "operationName": "Subscription", "variables": {"arg": "test"}}'.bytes)
957+
request.setAsyncSupported(true)
958+
959+
when:
960+
servlet.doPost(request, response)
961+
then:
962+
response.getStatus() == STATUS_OK
963+
response.getContentType() == CONTENT_TYPE_SERVER_SENT_EVENTS
964+
965+
when:
966+
subscriptionLatch.await(1, TimeUnit.SECONDS)
967+
then:
968+
getSubscriptionResponseContent()[0].data.echo == "First\n\ntest"
969+
getSubscriptionResponseContent()[1].data.echo == "Second\n\ntest"
970+
}
971+
906972
def "errors before graphql schema execution return internal server error"() {
907973
setup:
908974
servlet = SimpleGraphQLHttpServlet.newBuilder(GraphQLInvocationInputFactory.newBuilder {

src/test/groovy/graphql/servlet/OsgiGraphQLHttpServletSpec.groovy

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,41 @@ class OsgiGraphQLHttpServletSpec extends Specification {
7979
then:
8080
servlet.getSchemaProvider().getSchema().getMutationType() == null
8181
}
82+
83+
static class TestSubscriptionProvider implements GraphQLSubscriptionProvider {
84+
@Override
85+
Collection<GraphQLFieldDefinition> getSubscriptions() {
86+
return Collections.singletonList(newFieldDefinition().name("subscription").type(GraphQLAnnotations.object(Subscription.class)).build())
87+
}
88+
89+
90+
@GraphQLName("subscription")
91+
static class Subscription {
92+
@GraphQLField
93+
public String field;
94+
}
95+
}
96+
97+
def "subscription provider adds subscription objects"() {
98+
setup:
99+
OsgiGraphQLHttpServlet servlet = new OsgiGraphQLHttpServlet()
100+
TestSubscriptionProvider subscriptionProvider = new TestSubscriptionProvider()
101+
servlet.bindSubscriptionProvider(subscriptionProvider)
102+
GraphQLFieldDefinition subscription
103+
104+
when:
105+
subscription = servlet.getSchemaProvider().getSchema().getSubscriptionType().getFieldDefinition("subscription")
106+
then:
107+
subscription.getType().getName() == "subscription"
108+
109+
when:
110+
subscription = servlet.getSchemaProvider().getReadOnlySchema(null).getSubscriptionType().getFieldDefinition("subscription")
111+
then:
112+
subscription.getType().getName() == "subscription"
113+
114+
when:
115+
servlet.unbindSubscriptionProvider(subscriptionProvider)
116+
then:
117+
servlet.getSchemaProvider().getSchema().getSubscriptionType() == null
118+
}
82119
}

0 commit comments

Comments
 (0)