1010import java .io .IOException ;
1111import java .util .Collections ;
1212import java .util .HashMap ;
13+ import java .util .HashSet ;
1314import java .util .List ;
1415import java .util .Map ;
16+ import java .util .concurrent .atomic .AtomicBoolean ;
1517import java .util .stream .Collectors ;
1618import java .util .stream .Stream ;
1719
@@ -27,6 +29,7 @@ public class GraphQLWebsocketServlet extends Endpoint {
2729 private static final String HANDSHAKE_REQUEST_KEY = HandshakeRequest .class .getName ();
2830 private static final String PROTOCOL_HANDLER_REQUEST_KEY = SubscriptionProtocolHandler .class .getName ();
2931 private static final CloseReason ERROR_CLOSE_REASON = new CloseReason (CloseReason .CloseCodes .UNEXPECTED_CONDITION , "Internal Server Error" );
32+ private static final CloseReason SHUTDOWN_CLOSE_REASON = new CloseReason (CloseReason .CloseCodes .UNEXPECTED_CONDITION , "Server Shut Down" );
3033
3134 private static final List <SubscriptionProtocolFactory > subscriptionProtocolFactories = Collections .singletonList (new ApolloSubscriptionProtocolFactory ());
3235 private static final SubscriptionProtocolFactory fallbackSubscriptionProtocolFactory = new FallbackSubscriptionProtocolFactory ();
@@ -40,19 +43,29 @@ public class GraphQLWebsocketServlet extends Endpoint {
4043
4144 private final Map <Session , WsSessionSubscriptions > sessionSubscriptionCache = new HashMap <>();
4245 private final SubscriptionHandlerInput subscriptionHandlerInput ;
46+ private final AtomicBoolean isShuttingDown = new AtomicBoolean (false );
47+ private final AtomicBoolean isShutDown = new AtomicBoolean (false );
48+ private final Object cacheLock = new Object ();
4349
4450 public GraphQLWebsocketServlet (GraphQLQueryInvoker queryInvoker , GraphQLInvocationInputFactory invocationInputFactory , GraphQLObjectMapper graphQLObjectMapper ) {
4551 this .subscriptionHandlerInput = new SubscriptionHandlerInput (invocationInputFactory , queryInvoker , graphQLObjectMapper );
4652 }
4753
4854 @ Override
4955 public void onOpen (Session session , EndpointConfig endpointConfig ) {
50- log .debug ("Session opened: {}, {}" , session .getId (), endpointConfig );
5156 final WsSessionSubscriptions subscriptions = new WsSessionSubscriptions ();
5257 final HandshakeRequest request = (HandshakeRequest ) session .getUserProperties ().get (HANDSHAKE_REQUEST_KEY );
5358 final SubscriptionProtocolHandler subscriptionProtocolHandler = (SubscriptionProtocolHandler ) session .getUserProperties ().get (PROTOCOL_HANDLER_REQUEST_KEY );
5459
55- sessionSubscriptionCache .put (session , subscriptions );
60+ synchronized (cacheLock ) {
61+ if (isShuttingDown .get ()) {
62+ throw new IllegalStateException ("Server is shutting down!" );
63+ }
64+
65+ sessionSubscriptionCache .put (session , subscriptions );
66+ }
67+
68+ log .debug ("Session opened: {}, {}" , session .getId (), endpointConfig );
5669
5770 // This *cannot* be a lambda because of the way undertow checks the class...
5871 session .addMessageHandler (new MessageHandler .Whole <String >() {
@@ -71,7 +84,10 @@ public void onMessage(String text) {
7184 @ Override
7285 public void onClose (Session session , CloseReason closeReason ) {
7386 log .debug ("Session closed: {}, {}" , session .getId (), closeReason );
74- WsSessionSubscriptions subscriptions = sessionSubscriptionCache .remove (session );
87+ WsSessionSubscriptions subscriptions ;
88+ synchronized (cacheLock ) {
89+ subscriptions = sessionSubscriptionCache .remove (session );
90+ }
7591 if (subscriptions != null ) {
7692 subscriptions .close ();
7793 }
@@ -110,6 +126,42 @@ public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request,
110126 }
111127 }
112128
129+ /**
130+ * Stops accepting connections and closes all existing connections
131+ */
132+ public void beginShutDown () {
133+ synchronized (cacheLock ) {
134+ isShuttingDown .set (true );
135+ Map <Session , WsSessionSubscriptions > copy = new HashMap <>(sessionSubscriptionCache );
136+
137+ // Prevent comodification exception since #onClose() is called during session.close(), but we can't necessarily rely on that happening so we close subscriptions here anyway.
138+ copy .forEach ((session , wsSessionSubscriptions ) -> {
139+ wsSessionSubscriptions .close ();
140+ try {
141+ session .close (SHUTDOWN_CLOSE_REASON );
142+ } catch (IOException e ) {
143+ log .error ("Error closing websocket session!" , e );
144+ }
145+ });
146+
147+ copy .clear ();
148+
149+ if (!sessionSubscriptionCache .isEmpty ()) {
150+ log .error ("GraphQLWebsocketServlet did not shut down cleanly!" );
151+ sessionSubscriptionCache .clear ();
152+ }
153+ }
154+
155+ isShutDown .set (true );
156+ }
157+
158+ /**
159+ * @return true when shutdown is complete
160+ */
161+ public boolean isShutDown () {
162+ return isShutDown .get ();
163+ }
164+
113165 private static SubscriptionProtocolFactory getSubscriptionProtocolFactory (List <String > accept ) {
114166 for (String protocol : accept ) {
115167 for (SubscriptionProtocolFactory subscriptionProtocolFactory : subscriptionProtocolFactories ) {
0 commit comments