Skip to content

Commit fa4dc1f

Browse files
committed
Added keep alive runner for periodically sending keep alives (fix #91)
1 parent 19fd060 commit fa4dc1f

File tree

6 files changed

+111
-26
lines changed

6 files changed

+111
-26
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package graphql.servlet.internal;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
6+
import javax.websocket.Session;
7+
import java.util.Map;
8+
import java.util.concurrent.ConcurrentHashMap;
9+
import java.util.concurrent.Executors;
10+
import java.util.concurrent.Future;
11+
import java.util.concurrent.ScheduledExecutorService;
12+
import java.util.concurrent.TimeUnit;
13+
14+
class ApolloSubscriptionKeepAliveRunner {
15+
16+
private static final Logger LOG = LoggerFactory.getLogger(ApolloSubscriptionKeepAliveRunner.class);
17+
18+
private static final long KEEP_ALIVE_INTERVAL_SEC = 15;
19+
private static final int EXECUTOR_POOL_SIZE = 10;
20+
21+
private ScheduledExecutorService executor;
22+
private SubscriptionSender sender;
23+
private ApolloSubscriptionProtocolHandler.OperationMessage keepAliveMessage;
24+
private Map<Session, Future<?>> futures;
25+
26+
ApolloSubscriptionKeepAliveRunner(SubscriptionSender sender) {
27+
this.sender = sender;
28+
this.keepAliveMessage = ApolloSubscriptionProtocolHandler.OperationMessage.newKeepAliveMessage();
29+
this.executor = Executors.newScheduledThreadPool(EXECUTOR_POOL_SIZE);
30+
this.futures = new ConcurrentHashMap<>();
31+
}
32+
33+
void keepAlive(Session session) {
34+
if (!futures.containsKey(session)) {
35+
Future<?> future = executor.scheduleAtFixedRate(() -> {
36+
try {
37+
if (session.isOpen()) {
38+
sender.send(session, keepAliveMessage);
39+
} else {
40+
LOG.warn("Session appears to be closed. Aborting keep alive");
41+
abort(session);
42+
}
43+
} catch (Throwable t) {
44+
LOG.error("Cannot send keep alive message. Aborting keep alive", t);
45+
abort(session);
46+
}
47+
}, 0, KEEP_ALIVE_INTERVAL_SEC, TimeUnit.SECONDS);
48+
futures.put(session, future);
49+
}
50+
}
51+
52+
void abort(Session session) {
53+
Future<?> future = futures.remove(session);
54+
if (future != null) {
55+
future.cancel(true);
56+
}
57+
}
58+
59+
}

src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ public class ApolloSubscriptionProtocolHandler extends SubscriptionProtocolHandl
3333
private static final CloseReason TERMINATE_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "client requested " + GQL_CONNECTION_TERMINATE.getType());
3434

3535
private final SubscriptionHandlerInput input;
36+
private final SubscriptionSender sender;
37+
private final ApolloSubscriptionKeepAliveRunner keepAliveRunner;
3638
private final ApolloSubscriptionConnectionListener connectionListener;
3739

3840
public ApolloSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) {
@@ -41,6 +43,8 @@ public ApolloSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHa
4143
.filter(ApolloSubscriptionConnectionListener.class::isInstance)
4244
.map(ApolloSubscriptionConnectionListener.class::cast)
4345
.orElse(new ApolloSubscriptionConnectionListener() {});
46+
this.sender = new SubscriptionSender(this.input.getGraphQLObjectMapper().getJacksonMapper());
47+
this.keepAliveRunner = new ApolloSubscriptionKeepAliveRunner(this.sender);
4448
}
4549

4650
@Override
@@ -54,20 +58,20 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
5458
return;
5559
}
5660

57-
switch(message.getType()) {
61+
switch (message.getType()) {
5862
case GQL_CONNECTION_INIT:
5963
try {
6064
Optional<Object> connectionResponse = connectionListener.onConnect(message.getPayload());
6165
connectionResponse.ifPresent(it -> session.getUserProperties().put(ApolloSubscriptionConnectionListener.CONNECT_RESULT_KEY, it));
6266
} catch (Throwable t) {
63-
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ERROR, t.getMessage());
67+
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ERROR, message.getId(), t);
6468
return;
6569
}
6670

6771
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_ACK, message.getId());
6872

6973
if (connectionListener.isKeepAliveEnabled()) {
70-
sendMessage(session, OperationMessage.Type.GQL_CONNECTION_KEEP_ALIVE, message.getId());
74+
keepAliveRunner.keepAlive(session);
7175
}
7276
break;
7377

@@ -82,10 +86,12 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
8286
break;
8387

8488
case GQL_STOP:
89+
keepAliveRunner.abort(session);
8590
unsubscribe(subscriptions, message.id);
8691
break;
8792

8893
case GQL_CONNECTION_TERMINATE:
94+
keepAliveRunner.abort(session);
8995
try {
9096
session.close(TERMINATE_CLOSE_REASON);
9197
} catch (IOException e) {
@@ -112,7 +118,7 @@ private GraphQLSingleInvocationInput createInvocationInput(Session session, Oper
112118
private void handleSubscriptionStart(Session session, WsSessionSubscriptions subscriptions, String id, ExecutionResult executionResult) {
113119
executionResult = input.getGraphQLObjectMapper().sanitizeErrors(executionResult);
114120

115-
if(input.getGraphQLObjectMapper().areErrorsPresent(executionResult)) {
121+
if (input.getGraphQLObjectMapper().areErrorsPresent(executionResult)) {
116122
sendMessage(session, OperationMessage.Type.GQL_ERROR, id, input.getGraphQLObjectMapper().convertSanitizedExecutionResult(executionResult, false));
117123
return;
118124
}
@@ -127,11 +133,13 @@ protected void sendDataMessage(Session session, String id, Object payload) {
127133

128134
@Override
129135
protected void sendErrorMessage(Session session, String id) {
136+
keepAliveRunner.abort(session);
130137
sendMessage(session, GQL_ERROR, id);
131138
}
132139

133140
@Override
134141
protected void sendCompleteMessage(Session session, String id) {
142+
keepAliveRunner.abort(session);
135143
sendMessage(session, GQL_COMPLETE, id);
136144
}
137145

@@ -140,13 +148,7 @@ private void sendMessage(Session session, OperationMessage.Type type, String id)
140148
}
141149

142150
private void sendMessage(Session session, OperationMessage.Type type, String id, Object payload) {
143-
try {
144-
session.getBasicRemote().sendText(input.getGraphQLObjectMapper().getJacksonMapper().writeValueAsString(
145-
new OperationMessage(type, id, payload)
146-
));
147-
} catch (IOException e) {
148-
throw new RuntimeException("Error sending subscription response", e);
149-
}
151+
sender.send(session, new OperationMessage(type, id, payload));
150152
}
151153

152154
@JsonInclude(JsonInclude.Include.NON_NULL)
@@ -164,6 +166,10 @@ public OperationMessage(Type type, String id, Object payload) {
164166
this.payload = payload;
165167
}
166168

169+
static OperationMessage newKeepAliveMessage() {
170+
return new OperationMessage(Type.GQL_CONNECTION_KEEP_ALIVE, null, null);
171+
}
172+
167173
public Type getType() {
168174
return type;
169175
}

src/main/java/graphql/servlet/internal/FallbackSubscriptionProtocolHandler.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
public class FallbackSubscriptionProtocolHandler extends SubscriptionProtocolHandler {
1212

1313
private final SubscriptionHandlerInput input;
14+
private final SubscriptionSender sender;
1415

1516
public FallbackSubscriptionProtocolHandler(SubscriptionHandlerInput subscriptionHandlerInput) {
1617
this.input = subscriptionHandlerInput;
18+
sender = new SubscriptionSender(subscriptionHandlerInput.getGraphQLObjectMapper().getJacksonMapper());
1719
}
1820

1921
@Override
@@ -32,11 +34,7 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
3234

3335
@Override
3436
protected void sendDataMessage(Session session, String id, Object payload) {
35-
try {
36-
session.getBasicRemote().sendText(input.getGraphQLObjectMapper().getJacksonMapper().writeValueAsString(payload));
37-
} catch (IOException e) {
38-
throw new RuntimeException("Error sending subscription response", e);
39-
}
37+
sender.send(session, payload);
4038
}
4139

4240
@Override

src/main/java/graphql/servlet/internal/SubscriptionProtocolHandler.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ public void onNext(ExecutionResult executionResult) {
5555
@Override
5656
public void onError(Throwable throwable) {
5757
log.error("Subscription error", throwable);
58-
subscriptions.cancel(id);
58+
unsubscribe(subscriptions, id);
5959
sendErrorMessage(session, id);
6060
}
6161

6262
@Override
6363
public void onComplete() {
64-
subscriptions.cancel(id);
64+
unsubscribe(subscriptions, id);
6565
sendCompleteMessage(session, id);
6666
}
6767
});
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package graphql.servlet.internal;
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper;
4+
5+
import javax.websocket.Session;
6+
import java.io.IOException;
7+
8+
class SubscriptionSender {
9+
10+
private final ObjectMapper objectMapper;
11+
12+
SubscriptionSender(ObjectMapper objectMapper) {
13+
this.objectMapper = objectMapper;
14+
}
15+
16+
void send(Session session, Object payload) {
17+
try {
18+
session.getBasicRemote().sendText(objectMapper.writeValueAsString(payload));
19+
} catch (IOException e) {
20+
throw new RuntimeException("Error sending subscription response", e);
21+
}
22+
}
23+
}

src/main/java/graphql/servlet/internal/WsSessionSubscriptions.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import java.util.HashMap;
66
import java.util.Map;
7+
import java.util.concurrent.ConcurrentHashMap;
78

89
/**
910
* @author Andrew Potter
@@ -12,15 +13,15 @@ public class WsSessionSubscriptions {
1213
private final Object lock = new Object();
1314

1415
private boolean closed = false;
15-
private Map<String, Subscription> subscriptions = new HashMap<>();
16+
private Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();
1617

1718
public void add(Subscription subscription) {
1819
add(getImplicitId(subscription), subscription);
1920
}
2021

2122
public void add(String id, Subscription subscription) {
2223
synchronized (lock) {
23-
if(closed) {
24+
if (closed) {
2425
throw new IllegalStateException("Websocket was already closed!");
2526
}
2627
subscriptions.put(id, subscription);
@@ -32,19 +33,17 @@ public void cancel(Subscription subscription) {
3233
}
3334

3435
public void cancel(String id) {
35-
synchronized (lock) {
36-
Subscription subscription = subscriptions.remove(id);
37-
if(subscription != null) {
38-
subscription.cancel();
39-
}
36+
Subscription subscription = subscriptions.remove(id);
37+
if(subscription != null) {
38+
subscription.cancel();
4039
}
4140
}
4241

4342
public void close() {
4443
synchronized (lock) {
4544
closed = true;
4645
subscriptions.forEach((k, v) -> v.cancel());
47-
subscriptions = new HashMap<>();
46+
subscriptions.clear();
4847
}
4948
}
5049

0 commit comments

Comments
 (0)