2323import reactor .core .publisher .Mono ;
2424
2525import java .nio .ByteBuffer ;
26- import java .util .Arrays ;
27- import java .util .Collections ;
28- import java .util .List ;
2926import java .util .Set ;
3027import java .util .concurrent .ConcurrentSkipListSet ;
3128import java .util .concurrent .atomic .AtomicLong ;
3229import java .util .concurrent .atomic .AtomicReference ;
3330import java .util .function .Function ;
3431import java .util .function .Supplier ;
32+ import java .util .stream .Collectors ;
3533
3634import org .springframework .data .redis .connection .ReactiveSubscription ;
3735import org .springframework .data .redis .connection .SubscriptionListener ;
36+ import org .springframework .data .redis .connection .util .ByteArrayWrapper ;
3837import org .springframework .lang .Nullable ;
3938import org .springframework .util .Assert ;
4039import org .springframework .util .ObjectUtils ;
@@ -50,19 +49,22 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
5049
5150 private final LettuceByteBufferPubSubListenerWrapper listener ;
5251 private final StatefulRedisPubSubConnection <ByteBuffer , ByteBuffer > connection ;
53- private final RedisPubSubReactiveCommands <ByteBuffer , ByteBuffer > commands ;
52+
53+ private final RedisPubSubReactiveCommands <ByteBuffer , ByteBuffer > reactive ;
54+ private final LettuceReactivePubSubCommands commands ;
5455
5556 private final State patternState ;
5657 private final State channelState ;
5758
5859 LettuceReactiveSubscription (SubscriptionListener subscriptionListener ,
59- StatefulRedisPubSubConnection <ByteBuffer , ByteBuffer > connection ,
60+ StatefulRedisPubSubConnection <ByteBuffer , ByteBuffer > connection , LettuceReactivePubSubCommands commands ,
6061 Function <Throwable , Throwable > exceptionTranslator ) {
6162
6263 this .listener = new LettuceByteBufferPubSubListenerWrapper (
6364 new LettuceMessageListener ((messages , pattern ) -> {}, subscriptionListener ));
6465 this .connection = connection ;
65- this .commands = connection .reactive ();
66+ this .reactive = connection .reactive ();
67+ this .commands = commands ;
6668 connection .addListener (listener );
6769
6870 this .patternState = new State (exceptionTranslator );
@@ -84,7 +86,7 @@ public Mono<Void> pSubscribe(ByteBuffer... patterns) {
8486 Assert .notNull (patterns , "Patterns must not be null" );
8587 Assert .noNullElements (patterns , "Patterns must not contain null elements" );
8688
87- return patternState .subscribe (patterns , commands ::psubscribe );
89+ return patternState .subscribe (patterns , commands ::pSubscribe );
8890 }
8991
9092 @ Override
@@ -112,7 +114,7 @@ public Mono<Void> pUnsubscribe(ByteBuffer... patterns) {
112114 Assert .notNull (patterns , "Patterns must not be null" );
113115 Assert .noNullElements (patterns , "Patterns must not contain null elements" );
114116
115- return ObjectUtils .isEmpty (patterns ) ? Mono .empty () : patternState .unsubscribe (patterns , commands ::punsubscribe );
117+ return ObjectUtils .isEmpty (patterns ) ? Mono .empty () : patternState .unsubscribe (patterns , commands ::pUnsubscribe );
116118 }
117119
118120 @ Override
@@ -128,12 +130,12 @@ public Set<ByteBuffer> getPatterns() {
128130 @ Override
129131 public Flux <Message <ByteBuffer , ByteBuffer >> receive () {
130132
131- Flux <Message <ByteBuffer , ByteBuffer >> channelMessages = channelState .receive (() -> commands .observeChannels () //
132- .filter (message -> channelState .getTargets (). contains (message .getChannel ())) //
133+ Flux <Message <ByteBuffer , ByteBuffer >> channelMessages = channelState .receive (() -> reactive .observeChannels () //
134+ .filter (message -> channelState .contains (message .getChannel ())) //
133135 .map (message -> new ChannelMessage <>(message .getChannel (), message .getMessage ())));
134136
135- Flux <Message <ByteBuffer , ByteBuffer >> patternMessages = patternState .receive (() -> commands .observePatterns () //
136- .filter (message -> patternState .getTargets (). contains (message .getPattern ())) //
137+ Flux <Message <ByteBuffer , ByteBuffer >> patternMessages = patternState .receive (() -> reactive .observePatterns () //
138+ .filter (message -> patternState .contains (message .getPattern ())) //
137139 .map (message -> new PatternMessage <>(message .getPattern (), message .getChannel (), message .getMessage ())));
138140
139141 return channelMessages .mergeWith (patternMessages );
@@ -149,7 +151,7 @@ public Mono<Void> cancel() {
149151
150152 // this is to ensure completion of the futures and result processing. Since we're unsubscribing first, we expect
151153 // that we receive pub/sub confirmations before the PING response.
152- return commands .ping ().then (Mono .fromRunnable (() -> {
154+ return reactive .ping ().then (Mono .fromRunnable (() -> {
153155 connection .removeListener (listener );
154156 }));
155157 }));
@@ -162,7 +164,7 @@ public Mono<Void> cancel() {
162164 */
163165 static class State {
164166
165- private final Set <ByteBuffer > targets = new ConcurrentSkipListSet <>();
167+ private final Set <ByteArrayWrapper > targets = new ConcurrentSkipListSet <>();
166168 private final AtomicLong subscribers = new AtomicLong ();
167169 private final AtomicReference <Flux <?>> flux = new AtomicReference <>();
168170 private final Function <Throwable , Throwable > exceptionTranslator ;
@@ -182,8 +184,12 @@ static class State {
182184 */
183185 Mono <Void > subscribe (ByteBuffer [] targets , Function <ByteBuffer [], Mono <Void >> subscribeFunction ) {
184186
185- return subscribeFunction .apply (targets ).doOnSuccess ((discard ) -> this .targets .addAll (Arrays .asList (targets )))
186- .onErrorMap (exceptionTranslator );
187+ return subscribeFunction .apply (targets ).doOnSuccess ((discard ) -> {
188+
189+ for (ByteBuffer target : targets ) {
190+ this .targets .add (getWrapper (target ));
191+ }
192+ }).onErrorMap (exceptionTranslator );
187193 }
188194
189195 /**
@@ -198,16 +204,18 @@ Mono<Void> unsubscribe(ByteBuffer[] targets, Function<ByteBuffer[], Mono<Void>>
198204
199205 return Mono .defer (() -> {
200206
201- List <ByteBuffer > targetCollection = Arrays .asList (targets );
202-
203207 return unsubscribeFunction .apply (targets ).doOnSuccess ((discard ) -> {
204- this .targets .removeAll (targetCollection );
208+
209+ for (ByteBuffer byteBuffer : targets ) {
210+ this .targets .remove (getWrapper (byteBuffer ));
211+ }
205212 }).onErrorMap (exceptionTranslator );
206213 });
207214 }
208215
209216 Set <ByteBuffer > getTargets () {
210- return Collections .unmodifiableSet (targets );
217+ return targets .stream ().map (ByteArrayWrapper ::getArray ).map (ByteBuffer ::wrap )
218+ .collect (Collectors .toUnmodifiableSet ());
211219 }
212220
213221 /**
@@ -263,5 +271,13 @@ void terminate() {
263271 disposable .dispose ();
264272 }
265273 }
274+
275+ public boolean contains (ByteBuffer target ) {
276+ return this .targets .contains (getWrapper (target ));
277+ }
278+
279+ private static ByteArrayWrapper getWrapper (ByteBuffer byteBuffer ) {
280+ return new ByteArrayWrapper (byteBuffer );
281+ }
266282 }
267283}
0 commit comments