Skip to content

Commit 5397ac1

Browse files
mp911dechristophstrobl
authored andcommitted
Track subscriptions and unsubscriptions in LettuceReactiveRedisConnection.
We now track subscriptions and unsubscriptions in the reactive API to ensure that we do not prematurely unsubscribe from a channel or pattern if the topic was subscribed multiple times. Original Pull Request: #2467
1 parent f58d4e9 commit 5397ac1

File tree

7 files changed

+484
-43
lines changed

7 files changed

+484
-43
lines changed

src/main/java/org/springframework/data/redis/connection/lettuce/LettuceReactivePubSubCommands.java

Lines changed: 159 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,21 @@
2020
import reactor.core.publisher.Mono;
2121

2222
import java.nio.ByteBuffer;
23+
import java.util.ArrayList;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.concurrent.ConcurrentHashMap;
27+
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
28+
import java.util.function.BiFunction;
2329
import java.util.function.Function;
2430

2531
import org.reactivestreams.Publisher;
26-
2732
import org.springframework.data.redis.connection.ReactivePubSubCommands;
2833
import org.springframework.data.redis.connection.ReactiveSubscription;
2934
import org.springframework.data.redis.connection.ReactiveSubscription.ChannelMessage;
3035
import org.springframework.data.redis.connection.SubscriptionListener;
36+
import org.springframework.data.redis.connection.util.ByteArrayWrapper;
37+
import org.springframework.data.redis.util.ByteUtils;
3138
import org.springframework.util.Assert;
3239

3340
/**
@@ -39,16 +46,27 @@ class LettuceReactivePubSubCommands implements ReactivePubSubCommands {
3946

4047
private final LettuceReactiveRedisConnection connection;
4148

49+
private final Map<ByteArrayWrapper, Target> channels = new ConcurrentHashMap<>();
50+
51+
private final Map<ByteArrayWrapper, Target> patterns = new ConcurrentHashMap<>();
52+
4253
LettuceReactivePubSubCommands(LettuceReactiveRedisConnection connection) {
4354
this.connection = connection;
4455
}
4556

57+
public Map<ByteArrayWrapper, Target> getChannels() {
58+
return channels;
59+
}
60+
61+
public Map<ByteArrayWrapper, Target> getPatterns() {
62+
return patterns;
63+
}
64+
4665
@Override
4766
public Mono<ReactiveSubscription> createSubscription(SubscriptionListener listener) {
4867

49-
return connection.getPubSubConnection()
50-
.map(pubSubConnection -> new LettuceReactiveSubscription(listener, pubSubConnection,
51-
connection.translateException()));
68+
return connection.getPubSubConnection().map(pubSubConnection -> new LettuceReactiveSubscription(listener,
69+
pubSubConnection, this, connection.translateException()));
5270
}
5371

5472
@Override
@@ -65,20 +83,157 @@ public Mono<Void> subscribe(ByteBuffer... channels) {
6583

6684
Assert.notNull(channels, "Channels must not be null");
6785

86+
Target.trackSubscriptions(channels, this.channels); // track usage but do not limit what to subscribe to
87+
6888
return doWithPubSub(commands -> commands.subscribe(channels));
6989
}
7090

91+
public Mono<Void> unsubscribe(ByteBuffer... channels) {
92+
93+
Assert.notNull(patterns, "Patterns must not be null");
94+
95+
ByteBuffer[] actualUnsubscribe = Target.trackUnsubscriptions(channels, this.channels);
96+
97+
if (actualUnsubscribe.length == 0 && channels.length != 0) {
98+
return Mono.empty();
99+
}
100+
101+
return doWithPubSub(commands -> commands.unsubscribe(actualUnsubscribe));
102+
}
103+
71104
@Override
72105
public Mono<Void> pSubscribe(ByteBuffer... patterns) {
73106

74107
Assert.notNull(patterns, "Patterns must not be null");
75108

109+
Target.trackSubscriptions(patterns, this.patterns); // track usage but do not limit what to subscribe to
110+
76111
return doWithPubSub(commands -> commands.psubscribe(patterns));
77112
}
78113

114+
public Mono<Void> pUnsubscribe(ByteBuffer... patterns) {
115+
116+
Assert.notNull(patterns, "Patterns must not be null");
117+
118+
ByteBuffer[] actualUnsubscribe = Target.trackUnsubscriptions(patterns, this.patterns);
119+
120+
if (actualUnsubscribe.length == 0 && patterns.length != 0) {
121+
return Mono.empty();
122+
}
123+
124+
return doWithPubSub(commands -> commands.punsubscribe(actualUnsubscribe));
125+
}
126+
79127
private <T> Mono<T> doWithPubSub(Function<RedisPubSubReactiveCommands<ByteBuffer, ByteBuffer>, Mono<T>> function) {
80128

81129
return connection.getPubSubConnection().flatMap(pubSubConnection -> function.apply(pubSubConnection.reactive()))
82130
.onErrorMap(connection.translateException());
83131
}
132+
133+
static class Target {
134+
135+
private static final AtomicLongFieldUpdater<Target> SUBSCRIBERS = AtomicLongFieldUpdater.newUpdater(Target.class,
136+
"subscribers");
137+
138+
private final byte[] raw;
139+
140+
private volatile long subscribers;
141+
142+
Target(byte[] raw) {
143+
this.raw = raw;
144+
}
145+
146+
/**
147+
* Record the subscriptions to {@code targets} and store these in {@code targetMap}.
148+
*
149+
* @param targets
150+
* @param targetMap
151+
*/
152+
public static void trackSubscriptions(ByteBuffer[] targets, Map<ByteArrayWrapper, Target> targetMap) {
153+
doWithTargets(targets, targetMap, Target::allocate);
154+
}
155+
156+
/**
157+
* Record the un-subscriptions to {@code targets} and store these in {@code targetMap}. Returns the targets to
158+
* actually unsubscribe from if there are no subscribers to a particular target.
159+
*
160+
* @param targets
161+
* @param targetMap
162+
*/
163+
public static ByteBuffer[] trackUnsubscriptions(ByteBuffer[] targets, Map<ByteArrayWrapper, Target> targetMap) {
164+
return doWithTargets(targets, targetMap, Target::deallocate);
165+
}
166+
167+
static ByteBuffer[] doWithTargets(ByteBuffer[] targets, Map<ByteArrayWrapper, Target> targetMap,
168+
BiFunction<ByteBuffer, Map<ByteArrayWrapper, Target>, Boolean> f) {
169+
170+
List<ByteBuffer> toSubscribe = new ArrayList<>(targets.length);
171+
172+
synchronized (targetMap) {
173+
for (ByteBuffer target : targets) {
174+
if (f.apply(target, targetMap)) {
175+
toSubscribe.add(target);
176+
}
177+
}
178+
}
179+
180+
return toSubscribe.toArray(new ByteBuffer[0]);
181+
}
182+
183+
boolean increment() {
184+
return SUBSCRIBERS.incrementAndGet(this) == 1;
185+
}
186+
187+
boolean decrement() {
188+
189+
long l = SUBSCRIBERS.get(this);
190+
191+
if (l > 0) {
192+
if (SUBSCRIBERS.compareAndSet(this, l, l - 1)) {
193+
return l == 1; // return true if this was the last subscriber
194+
}
195+
}
196+
197+
return false;
198+
}
199+
200+
static boolean allocate(ByteBuffer buffer, Map<ByteArrayWrapper, Target> targets) {
201+
202+
byte[] raw = ByteUtils.getBytes(buffer);
203+
204+
ByteArrayWrapper wrapper = new ByteArrayWrapper(raw);
205+
Target targetToUse = targets.get(wrapper);
206+
207+
if (targetToUse == null) {
208+
targetToUse = new Target(raw);
209+
targets.put(wrapper, targetToUse);
210+
}
211+
212+
return targetToUse.increment();
213+
}
214+
215+
static boolean deallocate(ByteBuffer buffer, Map<ByteArrayWrapper, Target> targets) {
216+
217+
byte[] raw = ByteUtils.getBytes(buffer);
218+
219+
ByteArrayWrapper wrapper = new ByteArrayWrapper(raw);
220+
Target targetToUse = targets.get(wrapper);
221+
222+
if (targetToUse == null) {
223+
return false;
224+
}
225+
226+
if (targetToUse.decrement()) {
227+
targets.remove(wrapper);
228+
return true;
229+
}
230+
231+
return false;
232+
}
233+
234+
@Override
235+
public String toString() {
236+
return String.format("%s: Subscribers: %s", new String(raw), SUBSCRIBERS.get(this));
237+
}
238+
}
84239
}

src/main/java/org/springframework/data/redis/connection/lettuce/LettuceReactiveRedisConnection.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class LettuceReactiveRedisConnection implements ReactiveRedisConnection {
5151
private final AsyncConnect<StatefulConnection<ByteBuffer, ByteBuffer>> dedicatedConnection;
5252
private final AsyncConnect<StatefulRedisPubSubConnection<ByteBuffer, ByteBuffer>> pubSubConnection;
5353

54+
private final LettuceReactivePubSubCommands pubSub = new LettuceReactivePubSubCommands(this);
55+
5456
private @Nullable Mono<StatefulConnection<ByteBuffer, ByteBuffer>> sharedConnection;
5557

5658
/**
@@ -137,7 +139,7 @@ public ReactiveHyperLogLogCommands hyperLogLogCommands() {
137139

138140
@Override
139141
public ReactivePubSubCommands pubSubCommands() {
140-
return new LettuceReactivePubSubCommands(this);
142+
return pubSub;
141143
}
142144

143145
@Override

src/main/java/org/springframework/data/redis/connection/lettuce/LettuceReactiveSubscription.java

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,17 @@
2323
import reactor.core.publisher.Mono;
2424

2525
import java.nio.ByteBuffer;
26-
import java.util.Arrays;
27-
import java.util.Collections;
28-
import java.util.List;
2926
import java.util.Set;
3027
import java.util.concurrent.ConcurrentSkipListSet;
3128
import java.util.concurrent.atomic.AtomicLong;
3229
import java.util.concurrent.atomic.AtomicReference;
3330
import java.util.function.Function;
3431
import java.util.function.Supplier;
32+
import java.util.stream.Collectors;
3533

3634
import org.springframework.data.redis.connection.ReactiveSubscription;
3735
import org.springframework.data.redis.connection.SubscriptionListener;
36+
import org.springframework.data.redis.connection.util.ByteArrayWrapper;
3837
import org.springframework.lang.Nullable;
3938
import org.springframework.util.Assert;
4039
import org.springframework.util.ObjectUtils;
@@ -50,19 +49,22 @@ class LettuceReactiveSubscription implements ReactiveSubscription {
5049

5150
private final LettuceByteBufferPubSubListenerWrapper listener;
5251
private final StatefulRedisPubSubConnection<ByteBuffer, ByteBuffer> connection;
53-
private final RedisPubSubReactiveCommands<ByteBuffer, ByteBuffer> commands;
52+
53+
private final RedisPubSubReactiveCommands<ByteBuffer, ByteBuffer> reactive;
54+
private final LettuceReactivePubSubCommands commands;
5455

5556
private final State patternState;
5657
private final State channelState;
5758

5859
LettuceReactiveSubscription(SubscriptionListener subscriptionListener,
59-
StatefulRedisPubSubConnection<ByteBuffer, ByteBuffer> connection,
60+
StatefulRedisPubSubConnection<ByteBuffer, ByteBuffer> connection, LettuceReactivePubSubCommands commands,
6061
Function<Throwable, Throwable> exceptionTranslator) {
6162

6263
this.listener = new LettuceByteBufferPubSubListenerWrapper(
6364
new LettuceMessageListener((messages, pattern) -> {}, subscriptionListener));
6465
this.connection = connection;
65-
this.commands = connection.reactive();
66+
this.reactive = connection.reactive();
67+
this.commands = commands;
6668
connection.addListener(listener);
6769

6870
this.patternState = new State(exceptionTranslator);
@@ -84,7 +86,7 @@ public Mono<Void> pSubscribe(ByteBuffer... patterns) {
8486
Assert.notNull(patterns, "Patterns must not be null");
8587
Assert.noNullElements(patterns, "Patterns must not contain null elements");
8688

87-
return patternState.subscribe(patterns, commands::psubscribe);
89+
return patternState.subscribe(patterns, commands::pSubscribe);
8890
}
8991

9092
@Override
@@ -112,7 +114,7 @@ public Mono<Void> pUnsubscribe(ByteBuffer... patterns) {
112114
Assert.notNull(patterns, "Patterns must not be null");
113115
Assert.noNullElements(patterns, "Patterns must not contain null elements");
114116

115-
return ObjectUtils.isEmpty(patterns) ? Mono.empty() : patternState.unsubscribe(patterns, commands::punsubscribe);
117+
return ObjectUtils.isEmpty(patterns) ? Mono.empty() : patternState.unsubscribe(patterns, commands::pUnsubscribe);
116118
}
117119

118120
@Override
@@ -128,12 +130,12 @@ public Set<ByteBuffer> getPatterns() {
128130
@Override
129131
public Flux<Message<ByteBuffer, ByteBuffer>> receive() {
130132

131-
Flux<Message<ByteBuffer, ByteBuffer>> channelMessages = channelState.receive(() -> commands.observeChannels() //
132-
.filter(message -> channelState.getTargets().contains(message.getChannel())) //
133+
Flux<Message<ByteBuffer, ByteBuffer>> channelMessages = channelState.receive(() -> reactive.observeChannels() //
134+
.filter(message -> channelState.contains(message.getChannel())) //
133135
.map(message -> new ChannelMessage<>(message.getChannel(), message.getMessage())));
134136

135-
Flux<Message<ByteBuffer, ByteBuffer>> patternMessages = patternState.receive(() -> commands.observePatterns() //
136-
.filter(message -> patternState.getTargets().contains(message.getPattern())) //
137+
Flux<Message<ByteBuffer, ByteBuffer>> patternMessages = patternState.receive(() -> reactive.observePatterns() //
138+
.filter(message -> patternState.contains(message.getPattern())) //
137139
.map(message -> new PatternMessage<>(message.getPattern(), message.getChannel(), message.getMessage())));
138140

139141
return channelMessages.mergeWith(patternMessages);
@@ -149,7 +151,7 @@ public Mono<Void> cancel() {
149151

150152
// this is to ensure completion of the futures and result processing. Since we're unsubscribing first, we expect
151153
// that we receive pub/sub confirmations before the PING response.
152-
return commands.ping().then(Mono.fromRunnable(() -> {
154+
return reactive.ping().then(Mono.fromRunnable(() -> {
153155
connection.removeListener(listener);
154156
}));
155157
}));
@@ -162,7 +164,7 @@ public Mono<Void> cancel() {
162164
*/
163165
static class State {
164166

165-
private final Set<ByteBuffer> targets = new ConcurrentSkipListSet<>();
167+
private final Set<ByteArrayWrapper> targets = new ConcurrentSkipListSet<>();
166168
private final AtomicLong subscribers = new AtomicLong();
167169
private final AtomicReference<Flux<?>> flux = new AtomicReference<>();
168170
private final Function<Throwable, Throwable> exceptionTranslator;
@@ -182,8 +184,12 @@ static class State {
182184
*/
183185
Mono<Void> subscribe(ByteBuffer[] targets, Function<ByteBuffer[], Mono<Void>> subscribeFunction) {
184186

185-
return subscribeFunction.apply(targets).doOnSuccess((discard) -> this.targets.addAll(Arrays.asList(targets)))
186-
.onErrorMap(exceptionTranslator);
187+
return subscribeFunction.apply(targets).doOnSuccess((discard) -> {
188+
189+
for (ByteBuffer target : targets) {
190+
this.targets.add(getWrapper(target));
191+
}
192+
}).onErrorMap(exceptionTranslator);
187193
}
188194

189195
/**
@@ -198,16 +204,18 @@ Mono<Void> unsubscribe(ByteBuffer[] targets, Function<ByteBuffer[], Mono<Void>>
198204

199205
return Mono.defer(() -> {
200206

201-
List<ByteBuffer> targetCollection = Arrays.asList(targets);
202-
203207
return unsubscribeFunction.apply(targets).doOnSuccess((discard) -> {
204-
this.targets.removeAll(targetCollection);
208+
209+
for (ByteBuffer byteBuffer : targets) {
210+
this.targets.remove(getWrapper(byteBuffer));
211+
}
205212
}).onErrorMap(exceptionTranslator);
206213
});
207214
}
208215

209216
Set<ByteBuffer> getTargets() {
210-
return Collections.unmodifiableSet(targets);
217+
return targets.stream().map(ByteArrayWrapper::getArray).map(ByteBuffer::wrap)
218+
.collect(Collectors.toUnmodifiableSet());
211219
}
212220

213221
/**
@@ -263,5 +271,13 @@ void terminate() {
263271
disposable.dispose();
264272
}
265273
}
274+
275+
public boolean contains(ByteBuffer target) {
276+
return this.targets.contains(getWrapper(target));
277+
}
278+
279+
private static ByteArrayWrapper getWrapper(ByteBuffer byteBuffer) {
280+
return new ByteArrayWrapper(byteBuffer);
281+
}
266282
}
267283
}

0 commit comments

Comments
 (0)