2525import java .util .Set ;
2626import java .util .concurrent .ConcurrentHashMap ;
2727import java .util .concurrent .atomic .AtomicInteger ;
28+ import java .util .function .Consumer ;
2829
2930import org .apache .commons .logging .Log ;
3031import org .apache .commons .logging .LogFactory ;
@@ -106,9 +107,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
106107
107108 private @ Nullable MessageHeaderInitializer headerInitializer ;
108109
109- private @ Nullable Map <String , MessageChannel > orderedHandlingMessageChannels ;
110+ private final Map <String , SessionInfo > sessions = new ConcurrentHashMap <>() ;
110111
111- private final Map < String , Principal > stompAuthentications = new ConcurrentHashMap <>() ;
112+ private boolean preserveReceiveOrder ;
112113
113114 private @ Nullable Boolean immutableMessageInterceptorPresent ;
114115
@@ -201,7 +202,7 @@ public void setHeaderInitializer(@Nullable MessageHeaderInitializer headerInitia
201202 * @since 6.1
202203 */
203204 public void setPreserveReceiveOrder (boolean preserveReceiveOrder ) {
204- this .orderedHandlingMessageChannels = ( preserveReceiveOrder ? new ConcurrentHashMap <>() : null ) ;
205+ this .preserveReceiveOrder = preserveReceiveOrder ;
205206 }
206207
207208 /**
@@ -210,7 +211,7 @@ public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
210211 * @since 6.1
211212 */
212213 public boolean isPreserveReceiveOrder () {
213- return ( this .orderedHandlingMessageChannels != null ) ;
214+ return this .preserveReceiveOrder ;
214215 }
215216
216217 @ Override
@@ -245,7 +246,7 @@ public Stats getStats() {
245246 */
246247 @ Override
247248 public void handleMessageFromClient (WebSocketSession session ,
248- WebSocketMessage <?> webSocketMessage , MessageChannel targetChannel ) {
249+ WebSocketMessage <?> webSocketMessage , MessageChannel channel ) {
249250
250251 List <Message <byte []>> messages ;
251252 try {
@@ -288,35 +289,36 @@ else if (webSocketMessage instanceof BinaryMessage binaryMessage) {
288289 return ;
289290 }
290291
291- MessageChannel channelToUse = targetChannel ;
292- if (this .orderedHandlingMessageChannels != null ) {
293- channelToUse = this .orderedHandlingMessageChannels .computeIfAbsent (
294- session .getId (), id -> new OrderedMessageChannelDecorator (targetChannel , logger ));
295- }
292+ SessionInfo info = this .sessions .get (session .getId ());
293+ MessageChannel channelToUse = (info != null ? info .getMessageChannelToUse () : null );
296294
297295 for (Message <byte []> message : messages ) {
298- StompHeaderAccessor headerAccessor =
299- MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
296+ StompHeaderAccessor headerAccessor = MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
300297 Assert .state (headerAccessor != null , "No StompHeaderAccessor" );
301298
302299 StompCommand command = headerAccessor .getCommand ();
303- boolean isConnect = StompCommand .CONNECT .equals (command ) || StompCommand .STOMP .equals (command );
304-
300+ boolean isConnect = ( StompCommand .CONNECT .equals (command ) || StompCommand .STOMP .equals (command ) );
301+ String sessionId = session . getId ();
305302 boolean sent = false ;
303+
306304 try {
305+ if (isConnect ) {
306+ channelToUse = (this .preserveReceiveOrder ? new OrderedMessageChannelDecorator (channel , logger ) : channel );
307+ info = new SessionInfo (channelToUse , session .getPrincipal ());
308+ SessionInfo prevInfo = this .sessions .putIfAbsent (sessionId , info );
309+ Assert .state (prevInfo == null , "Session already exists" );
310+ headerAccessor .setUserChangeCallback (info );
311+ }
312+ else {
313+ Assert .state (channelToUse != null , "Unknown session: " + sessionId );
314+ }
307315
308- headerAccessor .setSessionId (session . getId () );
316+ headerAccessor .setSessionId (sessionId );
309317 headerAccessor .setSessionAttributes (session .getAttributes ());
310318 headerAccessor .setUser (getUser (session ));
311- if (isConnect ) {
312- headerAccessor .setUserChangeCallback (user -> {
313- if (user != null && user != session .getPrincipal ()) {
314- this .stompAuthentications .put (session .getId (), user );
315- }
316- });
317- }
318319 headerAccessor .setHeader (SimpMessageHeaderAccessor .HEART_BEAT_HEADER , headerAccessor .getHeartbeat ());
319- if (!detectImmutableMessageInterceptor (targetChannel )) {
320+
321+ if (!detectImmutableMessageInterceptor (channel )) {
320322 headerAccessor .setImmutable ();
321323 }
322324
@@ -356,23 +358,28 @@ else if (StompCommand.UNSUBSCRIBE.equals(command)) {
356358 }
357359 catch (Throwable ex ) {
358360 if (logger .isDebugEnabled ()) {
359- logger .debug ("Failed to send message to MessageChannel in session " + session . getId () , ex );
361+ logger .debug ("Failed to send message to MessageChannel in session " + sessionId , ex );
360362 }
361363 else if (logger .isErrorEnabled ()) {
362364 // Skip for unsent CONNECT or SUBSCRIBE (likely authentication/authorization issues)
363365 if (sent || !(isConnect || StompCommand .SUBSCRIBE .equals (command ))) {
364366 logger .error ("Failed to send message to MessageChannel in session " +
365- session . getId () + ":" + ex .getMessage ());
367+ sessionId + ":" + ex .getMessage ());
366368 }
367369 }
368370 handleError (session , ex , message );
369371 }
372+
373+ if (!sent && isConnect ) {
374+ this .sessions .remove (sessionId );
375+ break ;
376+ }
370377 }
371378 }
372379
373380 private @ Nullable Principal getUser (WebSocketSession session ) {
374- Principal user = this .stompAuthentications .get (session .getId ());
375- return (user != null ? user : session .getPrincipal ());
381+ SessionInfo info = this .sessions .get (session .getId ());
382+ return (info != null ? info . getUser () : session .getPrincipal ());
376383 }
377384
378385 private void handleError (WebSocketSession session , Throwable ex , @ Nullable Message <byte []> clientMessage ) {
@@ -674,10 +681,7 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
674681 outputChannel .send (message );
675682 }
676683 finally {
677- if (this .orderedHandlingMessageChannels != null ) {
678- this .orderedHandlingMessageChannels .remove (session .getId ());
679- }
680- this .stompAuthentications .remove (session .getId ());
684+ this .sessions .remove (session .getId ());
681685 SimpAttributesContextHolder .resetAttributes ();
682686 simpAttributes .sessionCompleted ();
683687 }
@@ -707,6 +711,36 @@ public String toString() {
707711 }
708712
709713
714+ private static class SessionInfo implements Consumer <Principal > {
715+
716+ private final MessageChannel channel ;
717+
718+ private final @ Nullable Principal webSocketUser ;
719+
720+ private volatile @ Nullable Principal stompUser ;
721+
722+ SessionInfo (MessageChannel channel , @ Nullable Principal user ) {
723+ this .channel = channel ;
724+ this .webSocketUser = user ;
725+ }
726+
727+ public MessageChannel getMessageChannelToUse () {
728+ return this .channel ;
729+ }
730+
731+ public @ Nullable Principal getUser () {
732+ return (this .stompUser != null ? this .stompUser : this .webSocketUser );
733+ }
734+
735+ @ Override
736+ public void accept (@ Nullable Principal stompUser ) {
737+ if (stompUser != null && stompUser != this .webSocketUser ) {
738+ this .stompUser = stompUser ;
739+ }
740+ }
741+ }
742+
743+
710744 /**
711745 * Contract for access to session counters.
712746 * @since 5.2
0 commit comments