diff --git a/src/main/java/io/nats/client/impl/MessageQueue.java b/src/main/java/io/nats/client/impl/MessageQueue.java index feae8f3c3..d62a3a409 100644 --- a/src/main/java/io/nats/client/impl/MessageQueue.java +++ b/src/main/java/io/nats/client/impl/MessageQueue.java @@ -15,15 +15,16 @@ import io.nats.client.NatsSystemClock; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; -import java.util.function.Predicate; import static io.nats.client.support.NatsConstants.*; @@ -33,8 +34,9 @@ class MessageQueue { protected static final int DRAINING = 2; protected static final long MIN_OFFER_TIMEOUT_NANOS = 100 * NANOS_PER_MILLI; - protected final AtomicLong length; protected final AtomicLong sizeInBytes; + protected final AtomicLong length; + protected final AtomicBoolean filtered; protected final AtomicInteger running; protected final boolean singleReaderMode; protected final LinkedBlockingQueue queue; @@ -45,12 +47,18 @@ class MessageQueue { protected final long offerTimeoutNanos; protected final Duration requestCleanupInterval; + static class MarkerMessage extends ProtocolMessage { + MarkerMessage(String mark) { + super(mark.getBytes(StandardCharsets.ISO_8859_1), false); + } + } + // SPECIAL MARKER MESSAGES // A simple == is used to resolve if any message is exactly the static pill object in question // ---------- // 1. Poison pill is a graphic, but common term for an item that breaks loops or stop something. // In this class the poison pill is used to break out of timed waits on the blocking queue. - protected static final NatsMessage POISON_PILL = new NatsMessage("_poison", null, EMPTY_BODY); + protected static final MarkerMessage POISON_PILL = new MarkerMessage("_poison"); MessageQueue(boolean singleReaderMode, Duration requestCleanupInterval) { this(singleReaderMode, -1, false, requestCleanupInterval, null); @@ -81,6 +89,7 @@ class MessageQueue { this.running = new AtomicInteger(RUNNING); sizeInBytes = new AtomicLong(0); length = new AtomicLong(0); + filtered = new AtomicBoolean(true); this.offerLockNanos = requestCleanupInterval.toNanos(); this.offerTimeoutNanos = Math.max(MIN_OFFER_TIMEOUT_NANOS, requestCleanupInterval.toMillis() * NANOS_PER_MILLI * 95 / 100) ; @@ -97,9 +106,11 @@ class MessageQueue { void drainTo(MessageQueue target) { editLock.lock(); try { - queue.drainTo(target.queue); - target.length.set(length.getAndSet(0)); + this.queue.drainTo(target.queue); target.sizeInBytes.set(sizeInBytes.getAndSet(0)); + target.length.set(length.getAndSet(0)); + target.filtered.set(false); + this.filtered.set(true); } finally { editLock.unlock(); } @@ -178,6 +189,7 @@ boolean push(NatsMessage msg, boolean internal) { } sizeInBytes.getAndAdd(msg.getSizeInBytes()); length.incrementAndGet(); + filtered.set(false); return true; } @@ -206,11 +218,11 @@ void poisonTheQueue() { } /** - * Marking the queue, like POISON, is a message we don't want to count. + * Marking the queue, like poisonTheQueue, is a message we don't want to count. * Intended to only be used with an unbounded queue. Use at your own risk. * @param msg the mark */ - void markTheQueue(NatsMessage msg) { + void markTheQueue(MarkerMessage msg) { queue.offer(msg); } @@ -250,7 +262,7 @@ NatsMessage pop(Duration timeout) throws InterruptedException { } sizeInBytes.getAndAdd(-msg.getSizeInBytes()); - length.decrementAndGet(); + filtered.set(length.decrementAndGet() == 0); return msg; } @@ -286,7 +298,7 @@ NatsMessage accumulate(long maxBytesToAccumulate, long maxMessagesToAccumulate, if (maxMessagesToAccumulate <= 1 || size >= maxBytesToAccumulate) { sizeInBytes.addAndGet(-size); - length.decrementAndGet(); + filtered.set(length.decrementAndGet() == 0); return msg; } @@ -320,7 +332,7 @@ NatsMessage accumulate(long maxBytesToAccumulate, long maxMessagesToAccumulate, } sizeInBytes.addAndGet(-size); - length.addAndGet(-accumulated); + filtered.set(length.addAndGet(-accumulated) == 0); return msg; } @@ -338,24 +350,32 @@ long sizeInBytes() { return sizeInBytes.get(); } - void filter(Predicate p) { + void filterOnStop() { editLock.lock(); try { if (this.isRunning()) { throw new IllegalStateException("Filter is only supported when the queue is paused"); } - ArrayList newQueue = new ArrayList<>(); - NatsMessage cursor = this.queue.poll(); - while (cursor != null) { - if (!p.test(cursor)) { - newQueue.add(cursor); - } else { - sizeInBytes.addAndGet(-cursor.getSizeInBytes()); - length.decrementAndGet(); + if (!filtered.get()) { + long removed = 0; + long removedBytes = 0; + ArrayList newQueue = new ArrayList<>(); + NatsMessage cursor = this.queue.poll(); + while (cursor != null) { + if (cursor.isProtocolFilterOnStop()) { + removedBytes += cursor.getSizeInBytes(); + removed++; + } + else { + newQueue.add(cursor); + } + cursor = this.queue.poll(); } - cursor = this.queue.poll(); + this.queue.addAll(newQueue); + sizeInBytes.addAndGet(-removedBytes); + length.addAndGet(-removed); + filtered.set(true); } - this.queue.addAll(newQueue); } finally { editLock.unlock(); } @@ -367,6 +387,7 @@ void clear() { this.queue.clear(); length.set(0); sizeInBytes.set(0); + filtered.set(true); } finally { editLock.unlock(); } diff --git a/src/main/java/io/nats/client/impl/NatsConnectionWriter.java b/src/main/java/io/nats/client/impl/NatsConnectionWriter.java index 221df6686..b0af7279f 100644 --- a/src/main/java/io/nats/client/impl/NatsConnectionWriter.java +++ b/src/main/java/io/nats/client/impl/NatsConnectionWriter.java @@ -31,7 +31,8 @@ import java.util.concurrent.locks.ReentrantLock; import static io.nats.client.support.BuilderBase.bufferAllocSize; -import static io.nats.client.support.NatsConstants.*; +import static io.nats.client.support.NatsConstants.CR; +import static io.nats.client.support.NatsConstants.LF; class NatsConnectionWriter implements Runnable { enum Mode { @@ -114,7 +115,7 @@ Future stop() { try { this.normalOutgoing.pause(); this.reconnectOutgoing.pause(); - this.normalOutgoing.filter(NatsMessage::isProtocolFilterOnStop); + this.normalOutgoing.filterOnStop(); } finally { this.startStopLock.unlock(); @@ -127,7 +128,7 @@ boolean isRunning() { return running.get(); } - private static final NatsMessage END_RECONNECT = new NatsMessage("_end", null, EMPTY_BODY); + private static final MessageQueue.MarkerMessage END_RECONNECT = new MessageQueue.MarkerMessage("_end_reconnect"); void sendMessageBatch(NatsMessage msg, DataPort dataPort, StatisticsCollector stats) throws IOException { writerLock.lock(); diff --git a/src/test/java/io/nats/client/impl/MessageQueueTests.java b/src/test/java/io/nats/client/impl/MessageQueueTests.java index 6585a6cd2..628695dd6 100644 --- a/src/test/java/io/nats/client/impl/MessageQueueTests.java +++ b/src/test/java/io/nats/client/impl/MessageQueueTests.java @@ -15,9 +15,7 @@ import org.junit.jupiter.api.Test; -import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.util.Arrays; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -29,9 +27,9 @@ public class MessageQueueTests { static final Duration REQUEST_CLEANUP_INTERVAL = Duration.ofSeconds(5); static final byte[] PING = "PING".getBytes(); - static final byte[] ONE = "one".getBytes(); - static final byte[] TWO = "two".getBytes(); - static final byte[] THREE = "three".getBytes(); + static final byte[] AAA = "aaa".getBytes(); + static final byte[] BBB = "bbb".getBytes(); + static final byte[] CCC = "ccc".getBytes(); @Test public void testEmptyPop() throws InterruptedException { @@ -457,9 +455,9 @@ public void testLength() throws InterruptedException { @Test public void testSizeInBytes() throws InterruptedException { MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL); - NatsMessage msg1 = new ProtocolMessage(ONE); - NatsMessage msg2 = new ProtocolMessage(TWO); - NatsMessage msg3 = new ProtocolMessage(THREE); + NatsMessage msg1 = new ProtocolMessage(AAA); + NatsMessage msg2 = new ProtocolMessage(BBB); + NatsMessage msg3 = new ProtocolMessage(CCC); long expected = 0; q.push(msg1); @@ -548,75 +546,55 @@ public void testDrainTo() { } @Test - public void testFilterTail() throws InterruptedException { - MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL); - NatsMessage msg1 = new ProtocolMessage(ONE); - NatsMessage msg2 = new ProtocolMessage(TWO); - NatsMessage msg3 = new ProtocolMessage(THREE); - byte[] expected = "one".getBytes(StandardCharsets.UTF_8); - - q.push(msg1); - q.push(msg2); - q.push(msg3); - - long before = q.sizeInBytes(); - q.pause(); - q.filter((msg) -> Arrays.equals(expected, msg.getProtocolBytes())); - q.resume(); - long after = q.sizeInBytes(); + public void testFilterFirstIn() throws InterruptedException { + _testFiltered(1); + } - assertEquals(2,q.length()); - assertEquals(before, after + expected.length + 2); - assertEquals(q.popNow(), msg2); - assertEquals(q.popNow(), msg3); + @Test + public void testFilterLastIn() throws InterruptedException { + _testFiltered(3); } @Test - public void testFilterHead() throws InterruptedException { - MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL); - NatsMessage msg1 = new ProtocolMessage(ONE); - NatsMessage msg2 = new ProtocolMessage(TWO); - NatsMessage msg3 = new ProtocolMessage(THREE); - byte[] expected = "three".getBytes(StandardCharsets.UTF_8); + public void testFilterMiddle() throws InterruptedException { + _testFiltered(2); + } + private static void _testFiltered(int filtered) throws InterruptedException { + NatsMessage msg1 = new ProtocolMessage(AAA, filtered == 1); + NatsMessage msg2 = new ProtocolMessage(BBB, filtered == 2); + NatsMessage msg3 = new ProtocolMessage(CCC, filtered == 3); + + MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL); q.push(msg1); q.push(msg2); q.push(msg3); long before = q.sizeInBytes(); q.pause(); - q.filter((msg) -> Arrays.equals(expected, msg.getProtocolBytes())); + q.filterOnStop(); q.resume(); long after = q.sizeInBytes(); - assertEquals(2,q.length()); - assertEquals(before, after + expected.length + 2); - assertEquals(q.popNow(), msg1); - assertEquals(q.popNow(), msg2); - } - - @Test - public void testFilterMiddle() throws InterruptedException { - MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL); - NatsMessage msg1 = new ProtocolMessage(ONE); - NatsMessage msg2 = new ProtocolMessage(TWO); - NatsMessage msg3 = new ProtocolMessage(THREE); - byte[] expected = "two".getBytes(StandardCharsets.UTF_8); - - q.push(msg1); - q.push(msg2); - q.push(msg3); + assertEquals(2, q.length()); + assertEquals(before, after + 3 + 2); - long before = q.sizeInBytes(); q.pause(); - q.filter((msg) -> Arrays.equals(expected, msg.getProtocolBytes())); + q.filterOnStop(); q.resume(); - long after = q.sizeInBytes(); - assertEquals(2,q.length()); - assertEquals(before, after + expected.length + 2); - assertEquals(q.popNow(), msg1); - assertEquals(q.popNow(), msg3); + assertEquals(2, q.length()); + assertEquals(before, after + 3 + 2); + + if (filtered != 1) { + assertEquals(q.popNow(), msg1); + } + if (filtered != 2) { + assertEquals(q.popNow(), msg2); + } + if (filtered != 3) { + assertEquals(q.popNow(), msg3); + } } @Test @@ -631,7 +609,7 @@ public void testPausedAccumulate() throws InterruptedException { public void testThrowOnFilterIfRunning() { assertThrows(IllegalStateException.class, () -> { MessageQueue q = new MessageQueue(true, REQUEST_CLEANUP_INTERVAL); - q.filter((msg) -> true); + q.filterOnStop(); fail(); }); } @@ -639,9 +617,9 @@ public void testThrowOnFilterIfRunning() { @Test public void testExceptionWhenQueueIsFull() { MessageQueue q = new MessageQueue(true, 2, false, REQUEST_CLEANUP_INTERVAL); - NatsMessage msg1 = new ProtocolMessage(ONE); - NatsMessage msg2 = new ProtocolMessage(TWO); - NatsMessage msg3 = new ProtocolMessage(THREE); + NatsMessage msg1 = new ProtocolMessage(AAA); + NatsMessage msg2 = new ProtocolMessage(BBB); + NatsMessage msg3 = new ProtocolMessage(CCC); assertTrue(q.push(msg1)); assertTrue(q.push(msg2)); @@ -656,9 +634,9 @@ public void testExceptionWhenQueueIsFull() { @Test public void testDiscardMessageWhenQueueFull() { MessageQueue q = new MessageQueue(true, 2, true, REQUEST_CLEANUP_INTERVAL); - NatsMessage msg1 = new ProtocolMessage(ONE); - NatsMessage msg2 = new ProtocolMessage(TWO); - NatsMessage msg3 = new ProtocolMessage(THREE); + NatsMessage msg1 = new ProtocolMessage(AAA); + NatsMessage msg2 = new ProtocolMessage(BBB); + NatsMessage msg3 = new ProtocolMessage(CCC); assertTrue(q.push(msg1)); assertTrue(q.push(msg2));