Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit df516cb

Browse files
Parallel sort via tbb (#844)
1 parent 37da684 commit df516cb

File tree

10 files changed

+1433
-3
lines changed

10 files changed

+1433
-3
lines changed

conda-recipe/meta.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ requirements:
2929
- pandas {{ PANDAS_VERSION }}
3030
- pyarrow {{ PYARROW_VERSION }}
3131
- wheel
32+
- tbb-devel
3233

3334
run:
3435
- python
@@ -37,6 +38,7 @@ requirements:
3738
- pyarrow {{ PYARROW_VERSION }}
3839
- numba {{ NUMBA_VERSION }}
3940
- setuptools
41+
- tbb4py
4042

4143
test:
4244
imports:
@@ -58,6 +60,7 @@ outputs:
5860
- numpy
5961
- pandas {{ PANDAS_VERSION }}
6062
- pyarrow {{ PYARROW_VERSION }}
63+
- tbb-devel
6164

6265
about:
6366
home: https://github.com/IntelPython/sdc

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@ numpy>=1.16
22
pandas==0.25.3
33
pyarrow==0.17.0
44
numba==0.49.1
5+
tbb
6+
tbb-devel

sdc/functions/sort.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2020, Intel Corporation All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
#
10+
# Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
16+
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
17+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
18+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
22+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
23+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24+
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
from numba import njit, cfunc, literally
28+
from numba.extending import intrinsic, overload
29+
from numba import types
30+
from numba.core import cgutils
31+
from numba import typed
32+
from numba import config
33+
import ctypes as ct
34+
35+
from sdc import concurrent_sort
36+
37+
38+
def bind(sym, sig):
    """Return a ctypes binding to symbol ``sym`` with signature ``sig``.

    The attribute named ``sym`` is read from the ``concurrent_sort`` module
    and cast to the ctypes function type ``sig``.
    """
    return ct.cast(getattr(concurrent_sort, sym), sig)
42+
43+
44+
# Prototype shared by the per-dtype sort entry points; they are called in this
# file with (array data pointer, number of elements).
parallel_sort_arithm_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_uint64)

# Prototype of the generic sort entry points; they are called in this file
# with (array data pointer, number of elements, item size, comparator address).
parallel_sort_sig = ct.CFUNCTYPE(
    None,
    ct.c_void_p,
    ct.c_uint64,
    ct.c_uint64,
    ct.c_void_p,
)

parallel_sort_sym = bind('parallel_sort', parallel_sort_sig)
parallel_stable_sort_sym = bind('parallel_stable_sort', parallel_sort_sig)

parallel_sort_t_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_uint64)

set_threads_count_sig = ct.CFUNCTYPE(None, ct.c_uint64)
set_threads_count_sym = bind('set_number_of_threads', set_threads_count_sig)

# Import-time side effect: hand Numba's configured thread count to the
# native library.
set_threads_count_sym(config.NUMBA_NUM_THREADS)
61+
62+
63+
def less(left, right):
    """Pure-Python stub; the jitted behaviour is attached with ``@overload``."""
    return None
65+
66+
67+
@overload(less, jit_options={'locals': {'result': types.int8}})
def less_overload(left, right):
    """Numba overload of ``less``: returns ``left < right`` as an int8."""

    def less_impl(left, right):
        # The local must be called ``result``: the ``locals`` jit option above
        # pins the variable of that name to int8.
        result = left < right
        return result

    return less_impl
74+
75+
76+
@intrinsic
def adaptor(tyctx, thing, another):
    """Compile a native comparator around ``less`` and return its address.

    Only the argument *types* matter: the generated code ignores the runtime
    values and yields, as an ``intp`` constant, the address of a ``cfunc``
    with signature ``int8(void*, void*)``. That cfunc dereferences both
    pointers as values of type ``thing`` and calls the ``less``
    implementation resolved for ``(thing, thing)``.
    """
    adaptor_sig = types.intp(thing, another)
    # Function type of ``less``; used below to pick a concrete implementation.
    fnty = tyctx.resolve_value_type(less)

    def codegen(cgctx, builder, sig, args):
        elem_ty = sig.args[0]
        # Resolve the ``less`` implementation specialised for the element type.
        less_sig = fnty.get_call_type(tyctx, (elem_ty, elem_ty), {})
        less_fn = cgctx.get_function(fnty, less_sig)

        # Close over the resolved function so the nested intrinsic can emit
        # the call (works around Python scoping).
        def resolved_codegen(cgctx, builder, sig, args):
            return less_fn(builder, args)

        # Call chain: the Python function ``wrapper`` is what ``@cfunc``
        # compiles; it calls the jitted ``wrappee``, which is compiled as part
        # of the cfunc's compilation chain. ``wrappee`` uses the ``dispatcher``
        # intrinsic, which holds the reference to the resolved ``less`` call.
        @intrinsic
        def dispatcher(_ityctx, _a, _b):
            return types.int8(thing, another), resolved_codegen

        @intrinsic
        def deref(_ityctx, _x):
            # Loads a value of type ``thing`` through the void pointer passed
            # in. TODO: nrt awareness
            deref_sig = thing(_x)

            def deref_codegen(cgctx, builder, sig, args):
                elem_ptr_ty = cgctx.get_value_type(sig.return_type).as_pointer()
                typed_ptr = builder.bitcast(args[0], elem_ptr_ty)
                index0 = cgctx.get_constant(types.intp, 0)
                elem_ref = builder.gep(typed_ptr, [index0], inbounds=True)

                return builder.load(elem_ref)

            return deref_sig, deref_codegen

        @njit
        def wrappee(ap, bp):
            lhs = deref(ap)
            rhs = deref(bp)
            return dispatcher(lhs, rhs)

        def wrapper(a, b):
            return wrappee(a, b)

        callback = cfunc(types.int8(types.voidptr, types.voidptr))(wrapper)

        # Bake the callback address into the generated code as an int constant.
        return cgctx.get_constant(types.intp, callback.address)

    return adaptor_sig, codegen
138+
139+
140+
@intrinsic
def asvoidp(tyctx, thing):
    """Store ``thing`` in an alloca'd slot and return the slot as a void pointer.

    The value is first converted to its data representation through the data
    model registered for its type.
    """

    def codegen(cgctx, builder, sig, args):
        model = cgctx.data_model_manager[sig.args[0]]
        slot = cgutils.alloca_once_value(builder, model.as_data(builder, args[0]))

        return builder.bitcast(slot, cgutils.voidptr_t)

    return types.voidptr(thing), codegen
152+
153+
154+
@intrinsic
def sizeof(context, t):
    """Return the ABI size, in bytes, of the type of ``t`` as a uint64 constant.

    Only the type of ``t`` is used; its runtime value is ignored.
    """
    sig = types.uint64(t)

    def codegen(cgctx, builder, sig, args):
        # get_abi_sizeof() operates on LLVM types, so the Numba type has to be
        # lowered first (same approach as list_itemsize in this module).
        ll_ty = cgctx.get_value_type(sig.args[0])
        size = cgctx.get_abi_sizeof(ll_ty)
        return cgctx.get_constant(types.uint64, size)

    return sig, codegen
163+
164+
165+
# Numba scalar types that have a dedicated native sort entry point, mapped to
# the suffix appended to the native symbol name (see load_symbols).
types_to_postfix = {
    types.int8: 'i8',
    types.uint8: 'u8',
    types.int16: 'i16',
    types.uint16: 'u16',
    types.int32: 'i32',
    types.uint32: 'u32',
    types.int64: 'i64',
    types.uint64: 'u64',
    types.float32: 'f32',
    types.float64: 'f64',
}
175+
176+
177+
def load_symbols(name, sig, types):
    """Bind one native symbol per entry of ``types``.

    For every ``(typ, postfix)`` pair in the mapping ``types`` the symbol
    ``f'{name}_{postfix}'`` is bound with signature ``sig`` via ``bind``.

    Parameters:
        name: common prefix of the native symbol names.
        sig: ctypes function type applied to every symbol.
        types: mapping from a type object to its symbol-name postfix.
            (The parameter shadows the ``numba.types`` module inside this
            function; the name is kept for interface compatibility.)

    Returns:
        dict mapping each key of ``types`` to its ctypes binding.
    """
    # Built directly rather than through exec() on generated source, which
    # required str(typ) to be a valid and unique Python identifier.
    return {typ: bind(f'{name}_{pstfx}', sig) for typ, pstfx in types.items()}
186+
187+
188+
# Per-dtype ctypes bindings of the typed native sort entry points, keyed by
# Numba scalar type.
sort_map = load_symbols(
    'parallel_sort', parallel_sort_arithm_sig, types_to_postfix)
stable_sort_map = load_symbols(
    'parallel_stable_sort', parallel_sort_arithm_sig, types_to_postfix)
190+
191+
192+
@intrinsic
def list_itemsize(tyctx, list_ty):
    """Return the ABI size of the list's item type as a uint64 constant.

    Only the type of the argument is inspected; the list value is not read.
    """

    def codegen(cgctx, builder, sig, args):
        # Lower the Numba item type to LLVM, then ask for its ABI size.
        ll_item_ty = cgctx.get_value_type(sig.args[0].item_type)
        return cgctx.get_constant(sig.return_type,
                                  cgctx.get_abi_sizeof(ll_item_ty))

    return types.uint64(list_ty), codegen
204+
205+
206+
def itemsize(arr):
    """Pure-Python stub; the jitted behaviour is attached with ``@overload``."""
    return None
208+
209+
210+
@overload(itemsize)
def itemsize_overload(arr):
    """Numba overload of ``itemsize`` for arrays and lists.

    Arrays report ``arr.itemsize``; lists go through ``list_itemsize``.
    Any other argument type raises ``NotImplementedError``.
    """
    if isinstance(arr, types.Array):
        def itemsize_impl(arr):
            return arr.itemsize
    elif isinstance(arr, types.List):
        def itemsize_impl(arr):
            return list_itemsize(arr)
    else:
        raise NotImplementedError

    return itemsize_impl
225+
226+
227+
def parallel_sort(arr):
    """Pure-Python stub; the jitted behaviour is attached with ``@overload``."""
    return None
229+
230+
231+
@overload(parallel_sort)
def parallel_sort_overload(arr):
    """Numba overload of ``parallel_sort`` for arrays.

    Dtypes listed in ``types_to_postfix`` use their dedicated native entry
    point from ``sort_map``. Every other dtype uses the generic native sort,
    which is given the item size and the address of a comparator built by
    ``adaptor``. Non-array arguments raise ``NotImplementedError``.
    """
    if not isinstance(arr, types.Array):
        raise NotImplementedError

    elem_ty = arr.dtype

    if elem_ty not in types_to_postfix:
        def parallel_sort_impl(arr):
            elem_size = itemsize(arr)
            # arr[0] is passed twice only to give adaptor the element type.
            return parallel_sort_sym(arr.ctypes, len(arr), elem_size,
                                     adaptor(arr[0], arr[0]))

        return parallel_sort_impl

    typed_sort = sort_map[elem_ty]

    def parallel_sort_arithm_impl(arr):
        return typed_sort(arr.ctypes, len(arr))

    return parallel_sort_arithm_impl
252+
253+
254+
def parallel_stable_sort(arr):
    """Pure-Python stub; the jitted behaviour is attached with ``@overload``."""
    return None
256+
257+
258+
@overload(parallel_stable_sort)
def parallel_stable_sort_overload(arr):
    """Numba overload of ``parallel_stable_sort`` for arrays.

    Dtypes listed in ``types_to_postfix`` use their dedicated native entry
    point from ``stable_sort_map``. Every other dtype uses the generic native
    stable sort, which is given the item size and the address of a comparator
    built by ``adaptor``. Non-array arguments raise ``NotImplementedError``.
    """
    if not isinstance(arr, types.Array):
        raise NotImplementedError

    elem_ty = arr.dtype

    if elem_ty not in types_to_postfix:
        def parallel_stable_sort_impl(arr):
            elem_size = itemsize(arr)
            # arr[0] is passed twice only to give adaptor the element type.
            return parallel_stable_sort_sym(arr.ctypes, len(arr), elem_size,
                                            adaptor(arr[0], arr[0]))

        return parallel_stable_sort_impl

    typed_sort = stable_sort_map[elem_ty]

    def parallel_stable_sort_arithm_impl(arr):
        return typed_sort(arr.ctypes, len(arr))

    return parallel_stable_sort_arithm_impl

0 commit comments

Comments
 (0)