Skip to content

Commit d186d8f

Browse files
committed
Initial work up of fix for dispatching instrumentation
1 parent 8ed10e0 commit d186d8f

File tree

4 files changed

+185
-61
lines changed

4 files changed

+185
-61
lines changed

src/main/java/graphql/servlet/instrumentation/AbstractTrackingApproach.java

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,17 @@
1111
import graphql.execution.instrumentation.parameters.InstrumentationDeferredFieldParameters;
1212
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters;
1313
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters;
14+
import graphql.language.Field;
15+
import graphql.language.Selection;
16+
import graphql.language.SelectionSet;
17+
import graphql.schema.GraphQLOutputType;
1418
import org.dataloader.DataLoaderRegistry;
1519
import org.slf4j.Logger;
1620
import org.slf4j.LoggerFactory;
1721

1822
import java.util.Collections;
1923
import java.util.List;
24+
import java.util.Optional;
2025
import java.util.concurrent.CompletableFuture;
2126

2227
/**
@@ -45,6 +50,8 @@ protected RequestStack getStack() {
4550
public ExecutionStrategyInstrumentationContext beginExecutionStrategy(InstrumentationExecutionStrategyParameters parameters) {
4651
ExecutionId executionId = parameters.getExecutionContext().getExecutionId();
4752
ExecutionPath path = parameters.getExecutionStrategyParameters().getPath();
53+
List<Selection> selectionSet = Optional.ofNullable(parameters.getExecutionStrategyParameters().getField())
54+
.map(MergedField::getSingleField).map(Field::getSelectionSet).map(SelectionSet::getSelections).orElse(Collections.emptyList());
4855
int parentLevel = path.getLevel();
4956
int curLevel = parentLevel + 1;
5057
int fieldCount = parameters.getExecutionStrategyParameters().getFields().size();
@@ -67,7 +74,7 @@ public void onCompleted(ExecutionResult result, Throwable t) {
6774
@Override
6875
public void onFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList) {
6976
synchronized (stack) {
70-
stack.setStatus(executionId, handleOnFieldValuesInfo(fieldValueInfoList, stack, executionId, curLevel));
77+
stack.setStatus(executionId, handleOnFieldValuesInfo(fieldValueInfoList, stack, executionId, curLevel, selectionSet));
7178
if (stack.allReady()) {
7279
dispatchWithoutLocking();
7380
}
@@ -79,7 +86,7 @@ public void onDeferredField(MergedField field) {
7986
// fake fetch count for this field
8087
synchronized (stack) {
8188
stack.increaseFetchCount(executionId, curLevel);
82-
stack.setStatus(executionId, dispatchIfNeeded(stack, executionId, curLevel));
89+
stack.setStatus(executionId, dispatchIfNeeded(stack, executionId, curLevel, selectionSet));
8390
if (stack.allReady()) {
8491
dispatchWithoutLocking();
8592
}
@@ -91,7 +98,7 @@ public void onDeferredField(MergedField field) {
9198
//
9299
// thread safety : called with synchronised(stack)
93100
//
94-
private boolean handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, RequestStack stack, ExecutionId executionId, int curLevel) {
101+
private boolean handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, RequestStack stack, ExecutionId executionId, int curLevel, List<Selection> selectionSet) {
95102
stack.increaseHappenedOnFieldValueCalls(executionId, curLevel);
96103
int expectedStrategyCalls = 0;
97104
for (FieldValueInfo fieldValueInfo : fieldValueInfoList) {
@@ -102,7 +109,7 @@ private boolean handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList,
102109
}
103110
}
104111
stack.increaseExpectedStrategyCalls(executionId, curLevel + 1, expectedStrategyCalls);
105-
return dispatchIfNeeded(stack, executionId, curLevel + 1);
112+
return dispatchIfNeeded(stack, executionId, curLevel + 1, selectionSet);
106113
}
107114

108115
private int getCountForList(FieldValueInfo fieldValueInfo) {
@@ -121,6 +128,8 @@ private int getCountForList(FieldValueInfo fieldValueInfo) {
121128
public DeferredFieldInstrumentationContext beginDeferredField(InstrumentationDeferredFieldParameters parameters) {
122129
ExecutionId executionId = parameters.getExecutionContext().getExecutionId();
123130
int level = parameters.getExecutionStrategyParameters().getPath().getLevel();
131+
List<Selection> selectionSet = Optional.ofNullable(parameters.getExecutionStrategyParameters().getField())
132+
.map(MergedField::getSingleField).map(Field::getSelectionSet).map(SelectionSet::getSelections).orElse(Collections.emptyList());
124133
synchronized (stack) {
125134
stack.clearAndMarkCurrentLevelAsReady(executionId, level);
126135
}
@@ -138,7 +147,7 @@ public void onCompleted(ExecutionResult result, Throwable t) {
138147
@Override
139148
public void onFieldValueInfo(FieldValueInfo fieldValueInfo) {
140149
synchronized (stack) {
141-
stack.setStatus(executionId, handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), stack, executionId, level));
150+
stack.setStatus(executionId, handleOnFieldValuesInfo(Collections.singletonList(fieldValueInfo), stack, executionId, level, selectionSet));
142151
if (stack.allReady()) {
143152
dispatchWithoutLocking();
144153
}
@@ -151,14 +160,16 @@ public void onFieldValueInfo(FieldValueInfo fieldValueInfo) {
151160
public InstrumentationContext<Object> beginFieldFetch(InstrumentationFieldFetchParameters parameters) {
152161
ExecutionId executionId = parameters.getExecutionContext().getExecutionId();
153162
ExecutionPath path = parameters.getEnvironment().getExecutionStepInfo().getPath();
163+
List<Selection> selectionSet = Optional.ofNullable(parameters.getEnvironment().getField())
164+
.map(Field::getSelectionSet).map(SelectionSet::getSelections).orElse(Collections.emptyList());
154165
int level = path.getLevel();
155166
return new InstrumentationContext<Object>() {
156167

157168
@Override
158169
public void onDispatched(CompletableFuture result) {
159170
synchronized (stack) {
160171
stack.increaseFetchCount(executionId, level);
161-
stack.setStatus(executionId, dispatchIfNeeded(stack, executionId, level));
172+
stack.setStatus(executionId, dispatchIfNeeded(stack, executionId, level, selectionSet));
162173

163174
if (stack.allReady()) {
164175
dispatchWithoutLocking();
@@ -176,16 +187,19 @@ public void onCompleted(Object result, Throwable t) {
176187
public void removeTracking(ExecutionId executionId) {
177188
synchronized (stack) {
178189
stack.removeExecution(executionId);
190+
if (stack.allReady()) {
191+
dispatchWithoutLocking();
192+
}
179193
}
180194
}
181195

182196

183197
//
184198
// thread safety : called with synchronised(stack)
185199
//
186-
private boolean dispatchIfNeeded(RequestStack stack, ExecutionId executionId, int level) {
200+
private boolean dispatchIfNeeded(RequestStack stack, ExecutionId executionId, int level, List<Selection> selectionSet) {
187201
if (levelReady(stack, executionId, level)) {
188-
return stack.dispatchIfNotDispatchedBefore(executionId, level);
202+
return stack.dispatchIfNotDispatchedBefore(executionId, level, selectionSet);
189203
}
190204
return false;
191205
}

src/main/java/graphql/servlet/instrumentation/RequestStack.java

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
package graphql.servlet.instrumentation;
22

3+
import com.sun.org.apache.bcel.internal.generic.Select;
34
import graphql.Assert;
45
import graphql.execution.ExecutionId;
6+
import graphql.language.Selection;
7+
import graphql.language.SelectionSet;
8+
import graphql.schema.GraphQLOutputType;
59

10+
import java.util.Collections;
11+
import java.util.Iterator;
612
import java.util.LinkedHashMap;
713
import java.util.LinkedHashSet;
14+
import java.util.List;
815
import java.util.Map;
916
import java.util.Objects;
1017
import java.util.Set;
18+
import java.util.stream.Collectors;
1119

1220
/**
1321
* Manages sets of call stack state for ongoing executions.
@@ -27,7 +35,7 @@ private static class CallStack {
2735
private final Map<Integer, Integer> happenedOnFieldValueCallsPerLevel = new LinkedHashMap<>();
2836

2937

30-
private final Set<Integer> dispatchedLevels = new LinkedHashSet<>();
38+
private final Map<Integer, List<Selection>> dispatchedLevels = new LinkedHashMap<>();
3139

3240
private CallStack() {
3341
expectedStrategyCallsPerLevel.put(1, 1);
@@ -66,6 +74,10 @@ private boolean allFetchesHappened(int level) {
6674
return Objects.equals(fetchCountPerLevel.get(level), expectedFetchCountPerLevel.get(level));
6775
}
6876

77+
private Map<Integer, List<Selection>> getDispatchedLevels() {
78+
return dispatchedLevels;
79+
}
80+
6981
@Override
7082
public String toString() {
7183
return "CallStack{" +
@@ -78,12 +90,12 @@ public String toString() {
7890
'}';
7991
}
8092

81-
private boolean dispatchIfNotDispatchedBefore(int level) {
82-
if (dispatchedLevels.contains(level)) {
93+
private boolean dispatchIfNotDispatchedBefore(int level, List<Selection> selectionSet) {
94+
if (dispatchedLevels.containsKey(level)) {
8395
Assert.assertShouldNeverHappen("level " + level + " already dispatched");
8496
return false;
8597
}
86-
dispatchedLevels.add(level);
98+
dispatchedLevels.put(level, selectionSet);
8799
return true;
88100
}
89101

@@ -119,14 +131,52 @@ public void setStatus(ExecutionId executionId, boolean toState) {
119131
* @return if all managed executions are ready to be dispatched.
120132
*/
121133
public boolean allReady() {
122-
return status.values().stream().noneMatch(Boolean.FALSE::equals);
134+
return status.values().stream().noneMatch(Boolean.FALSE::equals) &&
135+
activeRequests.values().stream().map(CallStack::getDispatchedLevels).allMatch(dispatchMap ->
136+
verifyAgainstOthers(dispatchMap, activeRequests.values().stream().map(CallStack::getDispatchedLevels).collect(Collectors.toList())));
137+
}
138+
139+
boolean verifyAgainstOthers(Map<Integer, List<Selection>> current, List<Map<Integer, List<Selection>>> others) {
140+
for (Map<Integer, List<Selection>> other : others) {
141+
Iterator<Map.Entry<Integer, List<Selection>>> currentIter = current.entrySet().iterator();
142+
Iterator<Map.Entry<Integer, List<Selection>>> otherIter = other.entrySet().iterator();
143+
if (currentIter.hasNext() && otherIter.hasNext()) {
144+
Map.Entry<Integer, List<Selection>> currentFirstEntry = currentIter.next();
145+
Map.Entry<Integer, List<Selection>> otherFirstEntry = otherIter.next();
146+
boolean matching = selectionsEqual(currentFirstEntry.getValue(), otherFirstEntry.getValue());
147+
while (matching && currentIter.hasNext() && otherIter.hasNext()) {
148+
currentFirstEntry = currentIter.next();
149+
otherFirstEntry = otherIter.next();
150+
matching = selectionsEqual(currentFirstEntry.getValue(), otherFirstEntry.getValue());
151+
}
152+
if (matching && (currentIter.hasNext() || otherIter.hasNext())) {
153+
return false;
154+
}
155+
} else if (otherIter.hasNext() || currentIter.hasNext()) {
156+
return false;
157+
}
158+
}
159+
return true;
160+
}
161+
162+
private boolean selectionsEqual(List<Selection> first, List<Selection> second) {
163+
if (first.size() != second.size()) {
164+
return false;
165+
} else {
166+
for (int i = 0; i < first.size(); i++) {
167+
if (!first.get(i).isEqualTo(second.get(i))) {
168+
return false;
169+
}
170+
}
171+
return true;
172+
}
123173
}
124174

125175
/**
126176
* Removes all dispatch status. Should be used after a call to dispatch.
127177
*/
128178
public void allReset() {
129-
status.clear();
179+
status.keySet().forEach(key -> status.put(key, false));
130180
}
131181

132182
/**
@@ -262,12 +312,12 @@ public boolean allStrategyCallsHappened(ExecutionId executionId, int level) {
262312
* @param level the level to get the value of
263313
* @return dispatchIfNotDispattchedBefore
264314
*/
265-
public boolean dispatchIfNotDispatchedBefore(ExecutionId executionId, int level) {
315+
public boolean dispatchIfNotDispatchedBefore(ExecutionId executionId, int level, List<Selection> selectionSet) {
266316
if (!activeRequests.containsKey(executionId)) {
267317
throw new IllegalStateException(
268318
String.format("Execution %s not managed by this RequestStack, can not get dispatch if not dispatched before value", executionId));
269319
}
270-
return activeRequests.get(executionId).dispatchIfNotDispatchedBefore(level);
320+
return activeRequests.get(executionId).dispatchIfNotDispatchedBefore(level, selectionSet);
271321
}
272322

273323
/**

src/test/groovy/graphql/servlet/DataLoaderDispatchingSpec.groovy

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,28 @@ class DataLoaderDispatchingSpec extends Specification {
3737
AbstractGraphQLHttpServlet servlet
3838
MockHttpServletRequest request
3939
MockHttpServletResponse response
40-
AtomicInteger fetchCounter = new AtomicInteger()
41-
AtomicInteger loadCounter = new AtomicInteger()
42-
43-
BatchLoader<String, String> batchLoaderA = new BatchLoader<String, String>() {
44-
@Override
45-
CompletionStage<List<String>> load(List<String> keys) {
46-
fetchCounter.incrementAndGet()
47-
CompletableFuture.completedFuture(keys)
40+
AtomicInteger fetchCounterA = new AtomicInteger()
41+
AtomicInteger loadCounterA = new AtomicInteger()
42+
AtomicInteger fetchCounterB = new AtomicInteger()
43+
AtomicInteger loadCounterB = new AtomicInteger()
44+
AtomicInteger fetchCounterC = new AtomicInteger()
45+
AtomicInteger loadCounterC = new AtomicInteger()
46+
47+
BatchLoader<String, String> batchLoaderWithCounter(AtomicInteger fetchCounter) {
48+
return new BatchLoader<String, String>() {
49+
@Override
50+
CompletionStage<List<String>> load(List<String> keys) {
51+
fetchCounter.incrementAndGet()
52+
CompletableFuture.completedFuture(keys)
53+
}
4854
}
4955
}
5056

5157
def registry() {
5258
DataLoaderRegistry registry = new DataLoaderRegistry()
53-
registry.register("A", DataLoader.newDataLoader(batchLoaderA))
59+
registry.register("A", DataLoader.newDataLoader(batchLoaderWithCounter(fetchCounterA)))
60+
registry.register("B", DataLoader.newDataLoader(batchLoaderWithCounter(fetchCounterB)))
61+
registry.register("C", DataLoader.newDataLoader(batchLoaderWithCounter(fetchCounterC)))
5462
registry
5563
}
5664

@@ -61,13 +69,13 @@ class DataLoaderDispatchingSpec extends Specification {
6169
response = new MockHttpServletResponse()
6270
}
6371

64-
def queryDataFetcher() {
72+
def queryDataFetcher(String dataLoaderName, AtomicInteger loadCounter) {
6573
return new DataFetcher() {
6674
@Override
6775
Object get(DataFetchingEnvironment environment) {
6876
String id = environment.arguments.arg
6977
loadCounter.incrementAndGet()
70-
environment.getDataLoader("A").load(id)
78+
environment.getDataLoader(dataLoaderName).load(id)
7179
}
7280
}
7381
}
@@ -92,55 +100,67 @@ class DataLoaderDispatchingSpec extends Specification {
92100
}
93101

94102
def configureServlet(ContextSetting contextSetting) {
95-
servlet = TestUtils.createDataLoadingServlet( queryDataFetcher(),
96-
{ env -> env.arguments.arg },
97-
{ env ->
98-
AtomicReference<SingleSubscriberPublisher<String>> publisherRef = new AtomicReference<>();
99-
publisherRef.set(new SingleSubscriberPublisher<>({ subscription ->
100-
publisherRef.get().offer(env.arguments.arg)
101-
publisherRef.get().noMoreData()
102-
}))
103-
return publisherRef.get()
104-
}, false, contextSetting,
103+
servlet = TestUtils.createDataLoadingServlet(queryDataFetcher("A", loadCounterA),
104+
queryDataFetcher("B", loadCounterB), queryDataFetcher("C", loadCounterC),
105+
false, contextSetting,
105106
contextBuilder())
106107
}
107108

109+
def resetCounters() {
110+
fetchCounterA.set(0)
111+
fetchCounterB.set(0)
112+
loadCounterA.set(0)
113+
loadCounterB.set(0)
114+
}
115+
108116
def "batched query with per query context does not batch loads together"() {
109117
setup:
110118
configureServlet(ContextSetting.PER_QUERY)
111-
request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]')
112-
fetchCounter.set(0)
113-
loadCounter.set(0)
119+
request.addParameter('query', '[{ "query": "query { query(arg:\\"test\\") { echo(arg:\\"test\\") { echo(arg:\\"test\\") } }}" }, { "query": "query{query(arg:\\"test\\") { echo (arg:\\"test\\") { echo(arg:\\"test\\")} }}" },' +
120+
' { "query": "query{queryTwo(arg:\\"test\\") { echo (arg:\\"test\\")}}" }, { "query": "query{queryTwo(arg:\\"test\\") { echo (arg:\\"test\\")}}" }]')
121+
resetCounters()
114122

115123
when:
116124
servlet.doGet(request, response)
117125

118126
then:
119127
response.getStatus() == STATUS_OK
120128
response.getContentType() == CONTENT_TYPE_JSON_UTF8
121-
getBatchedResponseContent()[0].data.echo == "test"
122-
getBatchedResponseContent()[1].data.echo == "test"
123-
fetchCounter.get() == 2
124-
loadCounter.get() == 2
129+
getBatchedResponseContent()[0].data.query.echo.echo == "test"
130+
getBatchedResponseContent()[1].data.query.echo.echo == "test"
131+
getBatchedResponseContent()[2].data.queryTwo.echo == "test"
132+
getBatchedResponseContent()[3].data.queryTwo.echo == "test"
133+
fetchCounterA.get() == 2
134+
loadCounterA.get() == 2
135+
fetchCounterB.get() == 2
136+
loadCounterB.get() == 2
137+
fetchCounterC.get() == 2
138+
loadCounterC.get() == 2
125139
}
126140

127141
def "batched query with per request context batches all queries within the request"() {
128142
setup:
129143
servlet = configureServlet(ContextSetting.PER_REQUEST)
130-
request.addParameter('query', '[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]')
131-
fetchCounter.set(0)
132-
loadCounter.set(0)
144+
request.addParameter('query', '[{ "query": "query { query(arg:\\"test\\") { echo(arg:\\"test\\") { echo(arg:\\"test\\") } }}" }, { "query": "query{query(arg:\\"test\\") { echo (arg:\\"test\\") { echo(arg:\\"test\\")} }}" },' +
145+
' { "query": "query{queryTwo(arg:\\"test\\") { echo (arg:\\"test\\")}}" }, { "query": "query{queryTwo(arg:\\"test\\") { echo (arg:\\"test\\")}}" }]')
146+
resetCounters()
133147

134148
when:
135149
servlet.doGet(request, response)
136150

137151
then:
138152
response.getStatus() == STATUS_OK
139153
response.getContentType() == CONTENT_TYPE_JSON_UTF8
140-
getBatchedResponseContent()[0].data.echo == "test"
141-
getBatchedResponseContent()[1].data.echo == "test"
142-
fetchCounter.get() == 1
143-
loadCounter.get() == 2
154+
getBatchedResponseContent()[0].data.query.echo.echo == "test"
155+
getBatchedResponseContent()[1].data.query.echo.echo == "test"
156+
getBatchedResponseContent()[2].data.queryTwo.echo == "test"
157+
getBatchedResponseContent()[3].data.queryTwo.echo == "test"
158+
fetchCounterA.get() == 1
159+
loadCounterA.get() == 2
160+
fetchCounterB.get() == 1
161+
loadCounterB.get() == 2
162+
fetchCounterC.get() == 1
163+
loadCounterC.get() == 2
144164
}
145165

146166
List<Map<String, Object>> getBatchedResponseContent() {

0 commit comments

Comments
 (0)