Skip to content

Commit 9d89225

Browse files
committed
Shut down websocket servlet gracefully
1 parent f656740 commit 9d89225

File tree

3 files changed

+59
-6
lines changed

3 files changed

+59
-6
lines changed

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
version =
1+
version = 6.1.1
22
group = com.graphql-java

src/main/java/graphql/servlet/GraphQLWebsocketServlet.java

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import java.io.IOException;
1111
import java.util.Collections;
1212
import java.util.HashMap;
13+
import java.util.HashSet;
1314
import java.util.List;
1415
import java.util.Map;
16+
import java.util.concurrent.atomic.AtomicBoolean;
1517
import java.util.stream.Collectors;
1618
import java.util.stream.Stream;
1719

@@ -27,6 +29,7 @@ public class GraphQLWebsocketServlet extends Endpoint {
2729
private static final String HANDSHAKE_REQUEST_KEY = HandshakeRequest.class.getName();
2830
private static final String PROTOCOL_HANDLER_REQUEST_KEY = SubscriptionProtocolHandler.class.getName();
2931
private static final CloseReason ERROR_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, "Internal Server Error");
32+
private static final CloseReason SHUTDOWN_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.UNEXPECTED_CONDITION, "Server Shut Down");
3033

3134
private static final List<SubscriptionProtocolFactory> subscriptionProtocolFactories = Collections.singletonList(new ApolloSubscriptionProtocolFactory());
3235
private static final SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory = new FallbackSubscriptionProtocolFactory();
@@ -40,19 +43,29 @@ public class GraphQLWebsocketServlet extends Endpoint {
4043

4144
private final Map<Session, WsSessionSubscriptions> sessionSubscriptionCache = new HashMap<>();
4245
private final SubscriptionHandlerInput subscriptionHandlerInput;
46+
private final AtomicBoolean isShuttingDown = new AtomicBoolean(false);
47+
private final AtomicBoolean isShutDown = new AtomicBoolean(false);
48+
private final Object cacheLock = new Object();
4349

4450
public GraphQLWebsocketServlet(GraphQLQueryInvoker queryInvoker, GraphQLInvocationInputFactory invocationInputFactory, GraphQLObjectMapper graphQLObjectMapper) {
4551
this.subscriptionHandlerInput = new SubscriptionHandlerInput(invocationInputFactory, queryInvoker, graphQLObjectMapper);
4652
}
4753

4854
@Override
4955
public void onOpen(Session session, EndpointConfig endpointConfig) {
50-
log.debug("Session opened: {}, {}", session.getId(), endpointConfig);
5156
final WsSessionSubscriptions subscriptions = new WsSessionSubscriptions();
5257
final HandshakeRequest request = (HandshakeRequest) session.getUserProperties().get(HANDSHAKE_REQUEST_KEY);
5358
final SubscriptionProtocolHandler subscriptionProtocolHandler = (SubscriptionProtocolHandler) session.getUserProperties().get(PROTOCOL_HANDLER_REQUEST_KEY);
5459

55-
sessionSubscriptionCache.put(session, subscriptions);
60+
synchronized (cacheLock) {
61+
if (isShuttingDown.get()) {
62+
throw new IllegalStateException("Server is shutting down!");
63+
}
64+
65+
sessionSubscriptionCache.put(session, subscriptions);
66+
}
67+
68+
log.debug("Session opened: {}, {}", session.getId(), endpointConfig);
5669

5770
// This *cannot* be a lambda because of the way undertow checks the class...
5871
session.addMessageHandler(new MessageHandler.Whole<String>() {
@@ -71,7 +84,10 @@ public void onMessage(String text) {
7184
@Override
7285
public void onClose(Session session, CloseReason closeReason) {
7386
log.debug("Session closed: {}, {}", session.getId(), closeReason);
74-
WsSessionSubscriptions subscriptions = sessionSubscriptionCache.remove(session);
87+
WsSessionSubscriptions subscriptions;
88+
synchronized (cacheLock) {
89+
subscriptions = sessionSubscriptionCache.remove(session);
90+
}
7591
if (subscriptions != null) {
7692
subscriptions.close();
7793
}
@@ -110,6 +126,42 @@ public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request,
110126
}
111127
}
112128

129+
/**
130+
* Stops accepting connections and closes all existing connections
131+
*/
132+
public void beginShutDown() {
133+
synchronized (cacheLock) {
134+
isShuttingDown.set(true);
135+
Map<Session, WsSessionSubscriptions> copy = new HashMap<>(sessionSubscriptionCache);
136+
137+
// Prevent comodification exception since #onClose() is called during session.close(), but we can't necessarily rely on that happening so we close subscriptions here anyway.
138+
copy.forEach((session, wsSessionSubscriptions) -> {
139+
wsSessionSubscriptions.close();
140+
try {
141+
session.close(SHUTDOWN_CLOSE_REASON);
142+
} catch (IOException e) {
143+
log.error("Error closing websocket session!", e);
144+
}
145+
});
146+
147+
copy.clear();
148+
149+
if(!sessionSubscriptionCache.isEmpty()) {
150+
log.error("GraphQLWebsocketServlet did not shut down cleanly!");
151+
sessionSubscriptionCache.clear();
152+
}
153+
}
154+
155+
isShutDown.set(true);
156+
}
157+
158+
/**
159+
* @return true when shutdown is complete
160+
*/
161+
public boolean isShutDown() {
162+
return isShutDown.get();
163+
}
164+
113165
private static SubscriptionProtocolFactory getSubscriptionProtocolFactory(List<String> accept) {
114166
for (String protocol : accept) {
115167
for (SubscriptionProtocolFactory subscriptionProtocolFactory : subscriptionProtocolFactories) {

src/main/java/graphql/servlet/internal/ApolloSubscriptionProtocolHandler.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
public class ApolloSubscriptionProtocolHandler extends SubscriptionProtocolHandler {
2828

2929
private static final Logger log = LoggerFactory.getLogger(ApolloSubscriptionProtocolHandler.class);
30+
private static final CloseReason TERMINATE_CLOSE_REASON = new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "client requested " + GQL_CONNECTION_TERMINATE.getType());
3031

3132
private final SubscriptionHandlerInput input;
3233

@@ -68,9 +69,9 @@ public void onMessage(HandshakeRequest request, Session session, WsSessionSubscr
6869

6970
case GQL_CONNECTION_TERMINATE:
7071
try {
71-
session.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "client requested " + GQL_CONNECTION_TERMINATE.getType()));
72+
session.close(TERMINATE_CLOSE_REASON);
7273
} catch (IOException e) {
73-
log.error("Unable to close websocket session!", e);
74+
log.error("Error closing websocket session!", e);
7475
}
7576
break;
7677

0 commit comments

Comments
 (0)