Skip to content

Commit 50bd4bb

Browse files
committed
fix(#15309): Optimize PriorityQueue.remove() function
1 parent 71e822e commit 50bd4bb

File tree

2 files changed, with 127 additions and 65 deletions.

lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java

Lines changed: 66 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,12 @@
1919
import java.util.Arrays;
2020
import java.util.Collection;
2121
import java.util.Comparator;
22+
import java.util.HashMap;
23+
import java.util.HashSet;
2224
import java.util.Iterator;
25+
import java.util.Map;
2326
import java.util.NoSuchElementException;
27+
import java.util.Set;
2428
import java.util.function.IntFunction;
2529
import java.util.function.Supplier;
2630

@@ -72,6 +76,7 @@ public static <T> PriorityQueue<T> usingComparator(
7276
private final int maxSize;
7377
private final T[] heap;
7478
private final LessThan<? super T> lessThan;
79+
private final Map<T, Set<Integer>> indexMap = new HashMap<>();
7580

7681
/** Create an empty priority queue of the configured size using the specified {@link LessThan}. */
7782
public PriorityQueue(int maxSize, LessThan<? super T> lessThan) {
@@ -182,9 +187,9 @@ public void addAll(Collection<T> elements) {
182187
* @return the new 'top' element in the queue.
183188
*/
184189
public final T add(T element) {
185-
// don't modify size until we know heap access didn't throw AIOOB.
186190
int index = size + 1;
187191
heap[index] = element;
192+
addIndex(element, index);
188193
size = index;
189194
upHeap(index);
190195
return heap[1];
@@ -272,6 +277,7 @@ public final int size() {
272277
public final void clear() {
273278
Arrays.fill(heap, 0, size + 1, null);
274279
size = 0;
280+
indexMap.clear();
275281
}
276282

277283
/**
@@ -280,20 +286,25 @@ public final void clear() {
280286
* constant remove time but the trade-off would be extra cost to all additions/insertions)
281287
*/
282288
public final boolean remove(T element) {
283-
for (int i = 1; i <= size; i++) {
284-
if (heap[i] == element) {
285-
heap[i] = heap[size];
286-
heap[size] = null; // permit GC of objects
287-
size--;
288-
if (i <= size) {
289-
if (!upHeap(i)) {
290-
downHeap(i);
291-
}
292-
}
293-
return true;
294-
}
289+
Set<Integer> indices = indexMap.get(element);
290+
if (indices == null || indices.isEmpty()) return false;
291+
Integer idx = indices.iterator().next();
292+
removeIndex(element, idx);
293+
T last = heap[size];
294+
if (idx == size) {
295+
heap[size] = null;
296+
size--;
297+
return true;
295298
}
296-
return false;
299+
removeIndex(last, size);
300+
heap[idx] = last;
301+
addIndex(last, idx);
302+
heap[size] = null;
303+
size--;
304+
if (!upHeap(idx)) {
305+
downHeap(idx);
306+
}
307+
return true;
297308
}
298309

299310
/**
@@ -320,36 +331,54 @@ public T[] drainToArrayHighestFirst(IntFunction<T[]> newArray) {
320331
return array;
321332
}
322333

323-
private boolean upHeap(int origPos) {
324-
int i = origPos;
325-
T node = heap[i]; // save bottom node
326-
int j = i >>> 1;
327-
while (j > 0 && lessThan.lessThan(node, heap[j])) {
328-
heap[i] = heap[j]; // shift parents down
329-
i = j;
330-
j = j >>> 1;
334+
private void addIndex(T element, int idx) {
335+
indexMap.computeIfAbsent(element, k -> new HashSet<>()).add(idx);
336+
}
337+
338+
private void removeIndex(T element, int idx) {
339+
Set<Integer> indices = indexMap.get(element);
340+
if (indices != null) {
341+
indices.remove(idx);
342+
if (indices.isEmpty()) indexMap.remove(element);
331343
}
332-
heap[i] = node; // install saved node
333-
return i != origPos;
334344
}
335345

336-
private void downHeap(int i) {
337-
T node = heap[i]; // save top node
338-
int j = i << 1; // find smaller child
339-
int k = j + 1;
340-
if (k <= size && lessThan.lessThan(heap[k], heap[j])) {
341-
j = k;
346+
protected boolean upHeap(int i) {
347+
T node = heap[i];
348+
int j = i;
349+
while (j > 1 && lessThan.lessThan(node, heap[j >> 1])) {
350+
heap[j] = heap[j >> 1];
351+
removeIndex(heap[j], j >> 1);
352+
addIndex(heap[j], j);
353+
j >>= 1;
342354
}
343-
while (j <= size && lessThan.lessThan(heap[j], node)) {
344-
heap[i] = heap[j]; // shift up child
345-
i = j;
346-
j = i << 1;
347-
k = j + 1;
348-
if (k <= size && lessThan.lessThan(heap[k], heap[j])) {
355+
heap[j] = node;
356+
removeIndex(node, i);
357+
addIndex(node, j);
358+
return j < i;
359+
}
360+
361+
protected boolean downHeap(int i) {
362+
T node = heap[i];
363+
int j = i;
364+
int k;
365+
while ((k = j << 1) <= size) {
366+
if (k < size && lessThan.lessThan(heap[k + 1], heap[k])) {
367+
k++;
368+
}
369+
if (lessThan.lessThan(heap[k], node)) {
370+
heap[j] = heap[k];
371+
removeIndex(heap[j], k);
372+
addIndex(heap[j], j);
349373
j = k;
374+
} else {
375+
break;
350376
}
351377
}
352-
heap[i] = node; // install saved node
378+
heap[j] = node;
379+
removeIndex(node, i);
380+
addIndex(node, j);
381+
return j > i;
353382
}
354383

355384
/**

lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -208,47 +208,80 @@ public void testAddAllDoesNotFitIntoQueue() {
208208
() -> pq.addAll(list));
209209
}
210210

211-
/** Randomly add and remove elements, comparing against the reference java.util.PriorityQueue. */
212-
public void testRemovalsAndInsertions() {
211+
/** Randomly remove elements, comparing against the reference java.util.PriorityQueue by value. */
212+
public void testRemovals() {
213213
int maxElement = RandomNumbers.randomIntBetween(random(), 1, 10_000);
214214
int size = maxElement / 2 + 1;
215-
216215
var reference = new java.util.PriorityQueue<Integer>();
217216
var pq = new IntegerQueue(size);
218-
219217
Random localRandom = nonAssertingRandom(random());
220-
221-
// Lucene's PriorityQueue.remove uses reference equality, not .equals to determine which
222-
// elements
223-
// to remove (!).
224218
HashMap<Integer, Integer> ints = new HashMap<>();
219+
// Fill both queues with up to maxSize elements
220+
for (int i = 0; i < size; i++) {
221+
Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k);
222+
pq.add(element);
223+
reference.add(element);
224+
}
225+
// Perform random removals and compare by value
226+
for (int i = 0; i < size; i++) {
227+
Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k);
228+
int pqCount = 0, refCount = 0;
229+
for (Integer val : pq) if (val.equals(element)) pqCount++;
230+
for (Integer val : reference) if (val.equals(element)) refCount++;
231+
boolean pqRemoved = pq.remove(element);
232+
boolean refRemoved = reference.remove(element);
233+
assertEquals("remove() should return true if value was present", refCount > 0, pqRemoved);
234+
assertEquals("remove() should return true if value was present", refCount > 0, refRemoved);
235+
int pqCountAfter = 0, refCountAfter = 0;
236+
for (Integer val : pq) if (val.equals(element)) pqCountAfter++;
237+
for (Integer val : reference) if (val.equals(element)) refCountAfter++;
238+
assertEquals("Should remove only one instance (value)", Math.max(0, refCount - 1), refCountAfter);
239+
assertEquals("Should remove only one instance (value)", Math.max(0, pqCount - 1), pqCountAfter);
240+
assertEquals("pq and reference should match counts after removal", refCountAfter, pqCountAfter);
241+
assertEquals("size after removal should match", reference.size(), pq.size());
242+
Integer pqTop = pq.top();
243+
Integer refTop = reference.peek();
244+
if (pqTop != null && refTop != null) {
245+
assertEquals("top() value difference after removal?", refTop.intValue(), pqTop.intValue());
246+
} else {
247+
assertEquals("top() value difference after removal?", refTop, pqTop);
248+
}
249+
}
250+
pq.checkValidity();
251+
}
225252

253+
/** Randomly add elements, comparing against the reference java.util.PriorityQueue by value. */
254+
public void testInsertions() {
255+
int maxElement = RandomNumbers.randomIntBetween(random(), 1, 10_000);
256+
int size = maxElement / 2 + 1;
257+
var reference = new java.util.PriorityQueue<Integer>();
258+
var pq = new IntegerQueue(size);
259+
Random localRandom = nonAssertingRandom(random());
260+
HashMap<Integer, Integer> ints = new HashMap<>();
226261
for (int i = 0, iters = size * 2; i < iters; i++) {
227262
Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k);
228-
229-
var action = localRandom.nextInt(100);
230-
if (action < 25) {
231-
// removals, possibly misses.
232-
assertEquals("remove() difference: " + i, reference.remove(element), pq.remove(element));
263+
var dropped = pq.insertWithOverflow(element);
264+
reference.add(element);
265+
Integer droppedReference;
266+
if (reference.size() > size) {
267+
droppedReference = reference.remove();
233268
} else {
234-
// additions.
235-
var dropped = pq.insertWithOverflow(element);
236-
237-
reference.add(element);
238-
Integer droppedReference;
239-
if (reference.size() > size) {
240-
droppedReference = reference.remove();
241-
} else {
242-
droppedReference = null;
243-
}
244-
245-
assertEquals("insertWithOverflow() difference.", dropped, droppedReference);
269+
droppedReference = null;
270+
}
271+
if (dropped != null && droppedReference != null) {
272+
assertEquals("insertWithOverflow() dropped value difference.", dropped.intValue(), droppedReference.intValue());
273+
} else {
274+
assertEquals("insertWithOverflow() dropped value difference.", droppedReference, dropped);
246275
}
247-
248276
assertEquals("insertWithOverflow() size difference?", reference.size(), pq.size());
249-
assertEquals("top() difference?", reference.peek(), pq.top());
277+
Integer pqTop = pq.top();
278+
Integer refTop = reference.peek();
279+
if (pqTop != null && refTop != null) {
280+
assertEquals("top() value difference?", refTop.intValue(), pqTop.intValue());
281+
} else {
282+
assertEquals("top() value difference?", refTop, pqTop);
283+
}
250284
}
251-
252285
pq.checkValidity();
253286
}
254287

0 commit comments

Comments
 (0)