@@ -4,13 +4,19 @@ import graphql.AssertException
44import graphql.annotations.annotationTypes.GraphQLField
55import graphql.annotations.annotationTypes.GraphQLName
66import graphql.annotations.processor.GraphQLAnnotations
7+ import graphql.execution.instrumentation.Instrumentation
8+ import graphql.execution.instrumentation.InstrumentationState
9+ import graphql.execution.instrumentation.SimpleInstrumentation
10+ import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters
711import graphql.kickstart.execution.GraphQLRequest
812import graphql.kickstart.execution.config.ExecutionStrategyProvider
13+ import graphql.kickstart.execution.config.InstrumentationProvider
914import graphql.kickstart.execution.context.DefaultGraphQLContext
1015import graphql.kickstart.execution.context.GraphQLContext
1116import graphql.kickstart.servlet.context.GraphQLServletContextBuilder
1217import graphql.kickstart.servlet.core.GraphQLServletListener
1318import graphql.kickstart.servlet.core.GraphQLServletRootObjectBuilder
19+ import graphql.kickstart.servlet.input.NoOpBatchInputPreProcessor
1420import graphql.kickstart.servlet.osgi.*
1521import graphql.schema.*
1622import org.dataloader.DataLoaderRegistry
@@ -25,19 +31,19 @@ class OsgiGraphQLHttpServletSpec extends Specification {
2531
2632 @Override
2733 Collection<GraphQLFieldDefinition > getQueries () {
28- List<GraphQLFieldDefinition > fieldDefinitions = new ArrayList<> ();
34+ List<GraphQLFieldDefinition > fieldDefinitions = new ArrayList<> ()
2935 fieldDefinitions. add(newFieldDefinition()
3036 .name(" query" )
3137 .type(new GraphQLAnnotations (). object(Query . class))
3238 .staticValue(new Query ())
33- .build());
34- return fieldDefinitions;
39+ .build())
40+ return fieldDefinitions
3541 }
3642
3743 @GraphQLName (" query" )
3844 static class Query {
3945 @GraphQLField
40- public String field;
46+ public String field
4147 }
4248
4349 }
@@ -55,15 +61,15 @@ class OsgiGraphQLHttpServletSpec extends Specification {
5561 query. getType(). name == " query"
5662
5763 when :
58- query = servlet. getConfiguration(). getInvocationInputFactory(). getSchemaProvider(). getReadOnlySchema(null ). getQueryType(). getFieldDefinition(" query" )
64+ query = servlet. getConfiguration(). getInvocationInputFactory(). getSchemaProvider(). getReadOnlySchema(). getQueryType(). getFieldDefinition(" query" )
5965 then :
6066 query. getType(). name == " query"
6167
6268 when :
6369 servlet. unbindQueryProvider(queryProvider)
6470 then :
6571 servlet. getConfiguration(). getInvocationInputFactory(). getSchemaProvider(). getSchema(). getQueryType(). getFieldDefinitions(). get(0 ). name == " _empty"
66- servlet. getConfiguration(). getInvocationInputFactory(). getSchemaProvider(). getReadOnlySchema(null ). getQueryType(). getFieldDefinitions(). get(0 ). name == " _empty"
72+ servlet. getConfiguration(). getInvocationInputFactory(). getSchemaProvider(). getReadOnlySchema(). getQueryType(). getFieldDefinitions(). get(0 ). name == " _empty"
6773 }
6874
6975 static class TestMutationProvider implements GraphQLMutationProvider {
@@ -110,7 +116,7 @@ class OsgiGraphQLHttpServletSpec extends Specification {
110116 @GraphQLName (" subscription" )
111117 static class Subscription {
112118 @GraphQLField
113- public String field;
119+ public String field
114120 }
115121 }
116122
@@ -127,7 +133,7 @@ class OsgiGraphQLHttpServletSpec extends Specification {
127133 subscription. getType(). getName() == " subscription"
128134
129135 when :
130- subscription = servlet. getConfiguration(). getInvocationInputFactory(). getSchemaProvider(). getReadOnlySchema(null ). getSubscriptionType(). getFieldDefinition(" subscription" )
136+ subscription = servlet. getConfiguration(). getInvocationInputFactory(). getSchemaProvider(). getReadOnlySchema(). getSubscriptionType(). getFieldDefinition(" subscription" )
131137 then :
132138 subscription. getType(). getName() == " subscription"
133139
@@ -151,7 +157,7 @@ class OsgiGraphQLHttpServletSpec extends Specification {
151157 static class TestCodeRegistryProvider implements GraphQLCodeRegistryProvider {
152158 @Override
153159 GraphQLCodeRegistry getCodeRegistry () {
154- return GraphQLCodeRegistry . newCodeRegistry(). typeResolver(" Type" , { env -> null }). build();
160+ return GraphQLCodeRegistry . newCodeRegistry(). typeResolver(" Type" , { env -> null }). build()
155161 }
156162 }
157163
@@ -341,4 +347,27 @@ class OsgiGraphQLHttpServletSpec extends Specification {
341347 then :
342348 0 * executionStrategy. getQueryExecutionStrategy()
343349 }
350+
351+ def " instrumentation provider is bound and unbound" () {
352+ setup :
353+ def servlet = new OsgiGraphQLHttpServlet ()
354+ def instrumentation = new SimpleInstrumentation ()
355+ def instrumentationProvider = Mock (InstrumentationProvider )
356+ instrumentationProvider. getInstrumentation() >> instrumentation
357+ def request = GraphQLRequest . createIntrospectionRequest()
358+ instrumentation. createState(_ as InstrumentationCreateStateParameters ) >> Mock (InstrumentationState )
359+
360+ when :
361+ servlet. setInstrumentationProvider(instrumentationProvider)
362+ def invocationInput = servlet. configuration. invocationInputFactory. create(request)
363+ servlet. configuration. graphQLInvoker. query(invocationInput)
364+
365+ then :
366+ noExceptionThrown()
367+
368+ when :
369+ servlet. unsetInstrumentationProvider(instrumentationProvider)
370+ then :
371+ noExceptionThrown()
372+ }
344373}
0 commit comments