Skip to content

Commit 7a89465

Browse files
committed
Try to refine pivot in choose_pivot
1 parent 5dd6143 commit 7a89465

File tree

3 files changed

+59
-26
lines changed

3 files changed

+59
-26
lines changed

src/quicksort.jl

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,30 +28,18 @@ function Base.sort!(
2828
return v
2929
end
3030

31-
function choose_pivot(xs, order)
32-
return _median(
33-
order,
34-
(
35-
xs[1],
36-
xs[end÷8],
37-
xs[end÷4],
38-
xs[3*(end÷8)],
39-
xs[end÷2],
40-
xs[5*(end÷8)],
41-
xs[3*(end÷4)],
42-
xs[7*(end÷8)],
43-
xs[end],
44-
),
45-
)
46-
end
47-
4831
function _quicksort!(ys, xs, alg, order, givenpivot = nothing)
4932
@check length(ys) == length(xs)
5033
if length(ys) <= max(8, alg.basesize)
5134
return _quicksort_serial!(ys, xs, alg, order)
5235
end
36+
isrefined = false
5337
pivot = if givenpivot === nothing
54-
choose_pivot(ys, order)
38+
let pivot, ishomogenous
39+
pivot, ishomogenous, isrefined = choose_pivot(ys, alg.basesize, order)
40+
ishomogenous && return ys
41+
pivot
42+
end
5543
else
5644
something(givenpivot)
5745
end
@@ -95,6 +83,7 @@ function _quicksort!(ys, xs, alg, order, givenpivot = nothing)
9583
total_nbelows = above_offsets[1]
9684
if total_nbelows == 0
9785
@assert givenpivot === nothing
86+
@assert !isrefined
9887
betterpivot, ishomogenous = refine_pivot(ys, pivot, alg.basesize, order)
9988
ishomogenous && return ys
10089
return _quicksort!(ys, xs, alg, order, Some(betterpivot))
@@ -124,7 +113,7 @@ function _quicksort_serial!(ys, xs, alg, order)
124113
if length(ys) <= max(8, alg.smallsize)
125114
return sort!(ys, alg.smallsort, order)
126115
end
127-
pivot = choose_pivot(ys, order)
116+
_, pivot = samples_and_pivot(ys, order)
128117

129118
nbelows, naboves = quicksort_partition!(xs, ys, pivot, order)
130119
@DBG @check nbelows + naboves == length(xs)
@@ -161,6 +150,50 @@ function quicksort_copyback!(ys, xs_chunk, nbelows, below_offset, above_offset)
161150
end
162151
end
163152

153+
@inline function samples_and_pivot(xs, order)
154+
samples = (
155+
xs[1],
156+
xs[end÷8],
157+
xs[end÷4],
158+
xs[3*(end÷8)],
159+
xs[end÷2],
160+
xs[5*(end÷8)],
161+
xs[3*(end÷4)],
162+
xs[7*(end÷8)],
163+
xs[end],
164+
)
165+
pivot = _median(order, samples)
166+
return samples, pivot
167+
end
168+
169+
"""
170+
choose_pivot(xs, basesize, order) -> (pivot, ishomogenous::Bool, isrefined::Bool)
171+
"""
172+
function choose_pivot(xs, basesize, order)
173+
samples, pivot = samples_and_pivot(xs, order)
174+
if (
175+
eq(order, samples[1], pivot) &&
176+
eq(order, samples[1], samples[2]) &&
177+
eq(order, samples[2], samples[3]) &&
178+
eq(order, samples[3], samples[4]) &&
179+
eq(order, samples[4], samples[5]) &&
180+
eq(order, samples[5], samples[6]) &&
181+
eq(order, samples[6], samples[7]) &&
182+
eq(order, samples[7], samples[8]) &&
183+
eq(order, samples[8], samples[9])
184+
)
185+
pivot, ishomogenous =
186+
refine_pivot_serial(@view(xs[begin:min(end, begin + 127)]), pivot, order)
187+
if ishomogenous
188+
length(xs) <= 128 && return (pivot, true, true)
189+
pivot, ishomogenous =
190+
refine_pivot(@view(xs[begin+128:end]), pivot, basesize, order)
191+
return (pivot, ishomogenous, true)
192+
end
193+
end
194+
return (pivot, false, false)
195+
end
196+
164197
"""
165198
refine_pivot(ys, badpivot::T, basesize, order) -> (pivot::T, ishomogenous::Bool)
166199
@@ -224,5 +257,3 @@ end
224257
# TODO: Check if the homogeneity check can be done in `quicksort_partition!`
225258
# without overall performance degradation? Use it to determine the pivot
226259
# for the next recursion.
227-
# TODO: Do this right after `choose_pivot` if it finds out that all samples are
228-
# equivalent?

src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ function elsizeof(::Type{T}) where {T}
3737
end
3838
end
3939

40-
eq(order, a, b) = !(Base.lt(order, a, b) || Base.lt(order, b, a))
40+
@inline eq(order, a, b) = !(Base.lt(order, a, b) || Base.lt(order, b, a))
4141

42-
function _median(order, (a, b, c)::NTuple{3,Any})
42+
@inline function _median(order, (a, b, c)::NTuple{3,Any})
4343
# Sort `(a, b, c)`:
4444
if Base.lt(order, b, a)
4545
a, b = b, a
@@ -53,7 +53,7 @@ function _median(order, (a, b, c)::NTuple{3,Any})
5353
return b
5454
end
5555

56-
_median(order, (a, b, c, d, e, f, g, h, i)::NTuple{9,Any}) = _median(
56+
@inline _median(order, (a, b, c, d, e, f, g, h, i)::NTuple{9,Any}) = _median(
5757
order,
5858
(_median(order, (a, b, c)), _median(order, (d, e, f)), _median(order, (g, h, i))),
5959
)

test/test_sort.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ using ThreadsX.Implementations: refine_pivot
1919
end
2020
end
2121

22+
divby(x) = Base.Fix2(÷, x)
23+
2224
@testset "stable sort" begin
2325
@testset for alg in [ThreadsX.MergeSort, ThreadsX.StableQuickSort]
24-
@test ThreadsX.sort(1:45; alg = alg, basesize = 25, by = _ -> 1) == 1:45
25-
@test ThreadsX.sort(1:1000; alg = alg, basesize = 200, by = _ -> 1) == 1:1000
26+
@test ThreadsX.sort(1:45; alg = alg, basesize = 25, by = divby(2)) == 1:45
27+
@test ThreadsX.sort(1:1000; alg = alg, basesize = 200, by = divby(2)) == 1:1000
2628
end
2729
end
2830

0 commit comments

Comments
 (0)