|
7 | 7 | import warnings |
8 | 8 |
|
9 | 9 | import numba |
| 10 | +from numba import prange |
| 11 | +import numpy as np |
10 | 12 |
|
11 | 13 | from PySDM.backends.impl_numba import methods |
12 | 14 | from PySDM.backends.impl_numba.random import Random as ImportedRandom |
@@ -45,21 +47,43 @@ def __init__( |
45 | 47 | self.formulae_flattened = self.formulae.flatten |
46 | 48 |
|
47 | 49 | parallel_default = True |
48 | | - if platform.machine() == "arm64": |
49 | | - if "CI" not in os.environ: |
50 | | - warnings.warn( |
51 | | - "Disabling Numba threading due to ARM64 CPU (atomics do not work yet)" |
52 | | - ) |
53 | | - parallel_default = False # TODO #1183 - atomics don't work on ARM64! |
54 | | - |
55 | | - try: |
56 | | - numba.parfors.parfor.ensure_parallel_support() |
57 | | - except numba.core.errors.UnsupportedParforsError: |
58 | | - if "CI" not in os.environ: |
59 | | - warnings.warn( |
60 | | - "Numba version used does not support parallel for (32 bits?)" |
61 | | - ) |
62 | | - parallel_default = False |
| 50 | + |
| 51 | + if override_jit_flags is not None and "parallel" in override_jit_flags: |
| 52 | + parallel_default = override_jit_flags["parallel"] |
| 53 | + |
| 54 | + if parallel_default: |
| 55 | + if platform.machine() == "arm64": |
| 56 | + if "CI" not in os.environ: |
| 57 | + warnings.warn( |
| 58 | + "Disabling Numba threading due to ARM64 CPU (atomics do not work yet)" |
| 59 | + ) |
| 60 | + parallel_default = False # TODO #1183 - atomics don't work on ARM64! |
| 61 | + |
| 62 | + try: |
| 63 | + numba.parfors.parfor.ensure_parallel_support() |
| 64 | + except numba.core.errors.UnsupportedParforsError: |
| 65 | + if "CI" not in os.environ: |
| 66 | + warnings.warn( |
| 67 | + "Numba version used does not support parallel for (32 bits?)" |
| 68 | + ) |
| 69 | + parallel_default = False |
| 70 | + |
| 71 | + if not numba.config.DISABLE_JIT: # pylint: disable=no-member |
| 72 | + |
| 73 | + @numba.jit(parallel=True, nopython=True) |
| 74 | + def fill_array_with_thread_id(arr): |
| 75 | + """writes thread id to corresponding array element""" |
| 76 | + for i in prange( # pylint: disable=not-an-iterable |
| 77 | + numba.get_num_threads() |
| 78 | + ): |
| 79 | + arr[i] = numba.get_thread_id() |
| 80 | + |
| 81 | + fill_array_with_thread_id(arr := np.full(numba.get_num_threads(), -1)) |
| 82 | + if not max(arr) == arr[-1] == numba.get_num_threads() - 1: |
| 83 | + raise ValueError( |
| 84 | + "Numba threading enabled but does not work" |
| 85 | + " (try other setting of the NUMBA_THREADING_LAYER env var?)" |
| 86 | + ) |
63 | 87 |
|
64 | 88 | assert "fastmath" not in (override_jit_flags or {}) |
65 | 89 | self.default_jit_flags = { |
|
0 commit comments