Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 7ff3b19

Browse files
authored
Add tests for get_chunks and change formula for get_chunks (#618)
* Add tests for get_chunks and change formula for get_chunks * fix pep8 * Pass pool_size explicitly * Fix year * Use _check_get_chunks() * Rename date with data * Add parallel_chunks()
1 parent 7150b62 commit 7ff3b19

File tree

3 files changed

+95
-36
lines changed

3 files changed

+95
-36
lines changed

sdc/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from sdc.tests.test_hpat_jit import *
4646

4747
from sdc.tests.test_sdc_numpy import *
48+
from sdc.tests.test_prange_utils import *
4849

4950
# performance tests
5051
import sdc.tests.tests_perf

sdc/tests/test_prange_utils.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2020, Intel Corporation All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
#
10+
# Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
16+
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
17+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
18+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
22+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
23+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24+
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
from sdc.tests.test_base import TestCase
28+
29+
from sdc.utilities.prange_utils import get_chunks, Chunk
30+
31+
32+
class ChunkTest(TestCase):
33+
34+
def _get_chunks_data(self):
35+
yield (5, 5), [
36+
Chunk(start=0, stop=1),
37+
Chunk(start=1, stop=2),
38+
Chunk(start=2, stop=3),
39+
Chunk(start=3, stop=4),
40+
Chunk(start=4, stop=5),
41+
]
42+
yield (5, 6), [
43+
Chunk(start=0, stop=1),
44+
Chunk(start=1, stop=2),
45+
Chunk(start=2, stop=3),
46+
Chunk(start=3, stop=4),
47+
Chunk(start=4, stop=5),
48+
]
49+
yield (9, 5), [
50+
Chunk(start=0, stop=2),
51+
Chunk(start=2, stop=4),
52+
Chunk(start=4, stop=6),
53+
Chunk(start=6, stop=8),
54+
Chunk(start=8, stop=9),
55+
]
56+
yield (9, 4), [
57+
Chunk(start=0, stop=3),
58+
Chunk(start=3, stop=5),
59+
Chunk(start=5, stop=7),
60+
Chunk(start=7, stop=9),
61+
]
62+
yield (9, 2), [
63+
Chunk(start=0, stop=5),
64+
Chunk(start=5, stop=9),
65+
]
66+
yield (9, 3), [
67+
Chunk(start=0, stop=3),
68+
Chunk(start=3, stop=6),
69+
Chunk(start=6, stop=9),
70+
]
71+
72+
def _check_get_chunks(self, args, expected_chunks):
73+
pyfunc = get_chunks
74+
cfunc = self.jit(pyfunc)
75+
76+
self.assertEqual(pyfunc(*args), expected_chunks)
77+
self.assertEqual(cfunc(*args), expected_chunks)
78+
79+
def test_get_chunks(self):
80+
for args, expected_chunks in self._get_chunks_data():
81+
with self.subTest(args=args):
82+
self._check_get_chunks(args, expected_chunks)

sdc/utilities/prange_utils.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -29,61 +29,37 @@
2929
import sdc
3030

3131
from typing import NamedTuple
32-
from sdc.utilities.utils import sdc_overload
32+
from sdc.utilities.utils import sdc_overload, sdc_register_jitable
3333

3434

3535
class Chunk(NamedTuple):
3636
start: int
3737
stop: int
3838

3939

40+
@sdc_register_jitable
4041
def get_pool_size():
4142
if sdc.config.config_use_parallel_overloads:
4243
return numba.config.NUMBA_NUM_THREADS
4344

4445
return 1
4546

4647

47-
@sdc_overload(get_pool_size)
48-
def get_pool_size_overload():
49-
pool_size = get_pool_size()
50-
51-
def get_pool_size_impl():
52-
return pool_size
53-
54-
return get_pool_size_impl
55-
56-
57-
def get_chunks(size, pool_size=0):
58-
if pool_size == 0:
59-
pool_size = get_pool_size()
60-
61-
chunk_size = (size - 1) // pool_size + 1
48+
@sdc_register_jitable
49+
def get_chunks(size, pool_size):
50+
pool_size = min(pool_size, size)
51+
chunk_size = size // pool_size
52+
overload_size = size % pool_size
6253

6354
chunks = []
6455
for i in range(pool_size):
65-
start = min(i * chunk_size, size)
66-
stop = min((i + 1) * chunk_size, size)
56+
start = i * chunk_size + min(i, overload_size)
57+
stop = (i + 1) * chunk_size + min(i + 1, overload_size)
6758
chunks.append(Chunk(start, stop))
6859

6960
return chunks
7061

7162

72-
@sdc_overload(get_chunks)
73-
def get_chunks_overload(size, pool_size=0):
74-
def get_chunks_impl(size, pool_size=0):
75-
if pool_size == 0:
76-
pool_size = get_pool_size()
77-
78-
chunk_size = (size - 1) // pool_size + 1
79-
80-
chunks = []
81-
for i in range(pool_size):
82-
start = min(i * chunk_size, size)
83-
stop = min((i + 1) * chunk_size, size)
84-
chunk = Chunk(start, stop)
85-
chunks.append(chunk)
86-
87-
return chunks
88-
89-
return get_chunks_impl
63+
@sdc_register_jitable
64+
def parallel_chunks(size):
65+
return get_chunks(size, get_pool_size())

0 commit comments

Comments
 (0)