Skip to content

Commit 710ebe9

Browse files
add a sanity check ensuring that Numba threading actually works (#1690)
Co-authored-by: AgnieszkaZaba <56157996+AgnieszkaZaba@users.noreply.github.com>
1 parent f5479b6 commit 710ebe9

File tree

3 files changed

+57
-17
lines changed

3 files changed

+57
-17
lines changed

.github/workflows/pypi.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
unset CI
3939
cd ${{ matrix.packages-dir }}
4040
python -m build 2>&1 | tee build.log
41-
exit `fgrep -i warning build.log | grep -v impl_numba/warnings.py \
41+
exit `fgrep -i warning build.log | grep -v warnings.py \
4242
| grep -v "no previously-included files matching" \
4343
| grep -v "version of {dist_name} already set" \
4444
| grep -v -E "UserWarning: version of PySDM(-examples)? already set" \

PySDM/backends/numba.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import warnings
88

99
import numba
10+
from numba import prange
11+
import numpy as np
1012

1113
from PySDM.backends.impl_numba import methods
1214
from PySDM.backends.impl_numba.random import Random as ImportedRandom
@@ -45,21 +47,43 @@ def __init__(
4547
self.formulae_flattened = self.formulae.flatten
4648

4749
parallel_default = True
48-
if platform.machine() == "arm64":
49-
if "CI" not in os.environ:
50-
warnings.warn(
51-
"Disabling Numba threading due to ARM64 CPU (atomics do not work yet)"
52-
)
53-
parallel_default = False # TODO #1183 - atomics don't work on ARM64!
54-
55-
try:
56-
numba.parfors.parfor.ensure_parallel_support()
57-
except numba.core.errors.UnsupportedParforsError:
58-
if "CI" not in os.environ:
59-
warnings.warn(
60-
"Numba version used does not support parallel for (32 bits?)"
61-
)
62-
parallel_default = False
50+
51+
if override_jit_flags is not None and "parallel" in override_jit_flags:
52+
parallel_default = override_jit_flags["parallel"]
53+
54+
if parallel_default:
55+
if platform.machine() == "arm64":
56+
if "CI" not in os.environ:
57+
warnings.warn(
58+
"Disabling Numba threading due to ARM64 CPU (atomics do not work yet)"
59+
)
60+
parallel_default = False # TODO #1183 - atomics don't work on ARM64!
61+
62+
try:
63+
numba.parfors.parfor.ensure_parallel_support()
64+
except numba.core.errors.UnsupportedParforsError:
65+
if "CI" not in os.environ:
66+
warnings.warn(
67+
"Numba version used does not support parallel for (32 bits?)"
68+
)
69+
parallel_default = False
70+
71+
if not numba.config.DISABLE_JIT: # pylint: disable=no-member
72+
73+
@numba.jit(parallel=True, nopython=True)
74+
def fill_array_with_thread_id(arr):
75+
"""writes thread id to corresponding array element"""
76+
for i in prange( # pylint: disable=not-an-iterable
77+
numba.get_num_threads()
78+
):
79+
arr[i] = numba.get_thread_id()
80+
81+
fill_array_with_thread_id(arr := np.full(numba.get_num_threads(), -1))
82+
if not max(arr) == arr[-1] == numba.get_num_threads() - 1:
83+
raise ValueError(
84+
"Numba threading enabled but does not work"
85+
" (try other setting of the NUMBA_THREADING_LAYER env var?)"
86+
)
6387

6488
assert "fastmath" not in (override_jit_flags or {})
6589
self.default_jit_flags = {

tests/unit_tests/backends/test_ctor_defaults.py renamed to tests/unit_tests/backends/test_ctor_defaults_and_warnings.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
2+
from unittest import mock
3+
import warnings
24
import inspect
5+
import pytest
6+
import numba
37

48
from PySDM.backends import Numba, ThrustRTC
59

610

7-
class TestCtorDefaults:
11+
class TestCtorDefaultsAndWarnings:
812
@staticmethod
913
def test_gpu_ctor_defaults():
1014
signature = inspect.signature(ThrustRTC.__init__)
@@ -17,3 +21,15 @@ def test_gpu_ctor_defaults():
1721
def test_cpu_ctor_defaults():
1822
signature = inspect.signature(Numba.__init__)
1923
assert signature.parameters["formulae"].default is None
24+
25+
@staticmethod
26+
@mock.patch("PySDM.backends.numba.prange", new=range)
27+
def test_check_numba_threading_warning():
28+
if numba.config.DISABLE_JIT: # pylint: disable=no-member
29+
pytest.skip()
30+
31+
with warnings.catch_warnings():
32+
warnings.simplefilter("ignore")
33+
with pytest.raises(ValueError) as exc_info:
34+
Numba()
35+
assert exc_info.match(r"^Numba threading enabled but does not work")

0 commit comments

Comments
 (0)