88### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99"""Resampling utilities."""
1010
11+ import asyncio
1112from os import cpu_count
12- from concurrent . futures import ProcessPoolExecutor , as_completed
13+ from functools import partial
1314from pathlib import Path
14- from typing import Tuple
15+ from typing import Callable , TypeVar
1516
1617import numpy as np
1718from nibabel .loadsave import load as _nbload
2728 _as_homogeneous ,
2829)
2930
31+ R = TypeVar ("R" )
32+
3033SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8
3134"""Minimum number of volumes to automatically serialize 4D transforms."""
3235
3336
34- def _apply_volume (
35- index : int ,
37+ async def worker (job : Callable [[], R ], semaphore ) -> R :
38+ async with semaphore :
39+ loop = asyncio .get_running_loop ()
40+ return await loop .run_in_executor (None , job )
41+
42+
43+ async def _apply_serial (
3644 data : np .ndarray ,
45+ spatialimage : SpatialImage ,
3746 targets : np .ndarray ,
47+ transform : TransformBase ,
48+ ref_ndim : int ,
49+ ref_ndcoords : np .ndarray ,
50+ n_resamplings : int ,
51+ output : np .ndarray ,
52+ input_dtype : np .dtype ,
3853 order : int = 3 ,
3954 mode : str = "constant" ,
4055 cval : float = 0.0 ,
4156 prefilter : bool = True ,
42- ) -> Tuple [int , np .ndarray ]:
57+ max_concurrent : int = min (cpu_count (), 12 ),
58+ ):
4359 """
44- Decorate :obj:`~scipy.ndimage.map_coordinates` to return an order index for parallelization .
60+ Resample through a given transform serially, in a 3D+t setting .
4561
4662 Parameters
4763 ----------
48- index : :obj:`int`
49- The index of the volume to apply the interpolation to.
5064 data : :obj:`~numpy.ndarray`
5165 The input data array.
66+ spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike`
67+ The image object containing the data to be resampled in reference
68+ space
5269 targets : :obj:`~numpy.ndarray`
5370 The target coordinates for mapping.
71+ transform : :obj:`~nitransforms.base.TransformBase`
72+ The 3D, 3D+t, or 4D transform through which data will be resampled.
73+ ref_ndim : :obj:`int`
74+ Dimensionality of the resampling target (reference image).
75+ ref_ndcoords : :obj:`~numpy.ndarray`
76+ Physical coordinates (RAS+) where data will be interpolated, if the resampling
77+ target is a grid, the scanner coordinates of all voxels.
78+ n_resamplings : :obj:`int`
79+ Total number of 3D resamplings (can be defined by the input image, the transform,
80+ or be matched, that is, same number of volumes in the input and number of transforms).
81+ output : :obj:`~numpy.ndarray`
82+ The output data array where resampled values will be stored volume-by-volume.
5483 order : :obj:`int`, optional
5584 The order of the spline interpolation, default is 3.
5685 The order has to be in the range 0-5.
@@ -71,18 +100,46 @@ def _apply_volume(
71100
72101 Returns
73102 -------
74- (:obj:`int`, :obj:`~numpy .ndarray`)
75- The index and the array resulting from the interpolation .
103+ np .ndarray
104+ Data resampled on the 3D+t array of input coordinates .
76105
77106 """
78- return index , ndi .map_coordinates (
79- data ,
80- targets ,
81- order = order ,
82- mode = mode ,
83- cval = cval ,
84- prefilter = prefilter ,
85- )
107+ tasks = []
108+ semaphore = asyncio .Semaphore (max_concurrent )
109+
110+ for t in range (n_resamplings ):
111+ xfm_t = transform if n_resamplings == 1 else transform [t ]
112+
113+ if targets is None :
114+ targets = ImageGrid (spatialimage ).index ( # data should be an image
115+ _as_homogeneous (xfm_t .map (ref_ndcoords ), dim = ref_ndim )
116+ )
117+
118+ data_t = (
119+ data
120+ if data is not None
121+ else spatialimage .dataobj [..., t ].astype (input_dtype , copy = False )
122+ )
123+
124+ tasks .append (
125+ asyncio .create_task (
126+ worker (
127+ partial (
128+ ndi .map_coordinates ,
129+ data_t ,
130+ targets ,
131+ output = output [..., t ],
132+ order = order ,
133+ mode = mode ,
134+ cval = cval ,
135+ prefilter = prefilter ,
136+ ),
137+ semaphore ,
138+ )
139+ )
140+ )
141+ await asyncio .gather (* tasks )
142+ return output
86143
87144
88145def apply (
@@ -94,15 +151,17 @@ def apply(
94151 cval : float = 0.0 ,
95152 prefilter : bool = True ,
96153 output_dtype : np .dtype = None ,
97- serialize_nvols : int = SERIALIZE_VOLUME_WINDOW_WIDTH ,
98- njobs : int = None ,
99154 dtype_width : int = 8 ,
155+ serialize_nvols : int = SERIALIZE_VOLUME_WINDOW_WIDTH ,
156+ max_concurrent : int = min (cpu_count (), 12 ),
100157) -> SpatialImage | np .ndarray :
101158 """
102159 Apply a transformation to an image, resampling on the reference spatial object.
103160
104161 Parameters
105162 ----------
163+ transform: :obj:`~nitransforms.base.TransformBase`
164+ The 3D, 3D+t, or 4D transform through which data will be resampled.
106165 spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike`
107166 The image object containing the data to be resampled in reference
108167 space
@@ -118,15 +177,15 @@ def apply(
118177 or ``'wrap'``. Default is ``'constant'``.
119178 cval : :obj:`float`, optional
120179 Constant value for ``mode='constant'``. Default is 0.0.
121- prefilter: :obj:`bool`, optional
180+ prefilter : :obj:`bool`, optional
122181 Determines if the image's data array is prefiltered with
123182 a spline filter before interpolation. The default is ``True``,
124183 which will create a temporary *float64* array of filtered values
125184 if *order > 1*. If setting this to ``False``, the output will be
126185 slightly blurred if *order > 1*, unless the input is prefiltered,
127186 i.e. it is the result of calling the spline filter on the original
128187 input.
129- output_dtype: :obj:`~numpy.dtype`, optional
188+ output_dtype : :obj:`~numpy.dtype`, optional
130189 The dtype of the returned array or image, if specified.
131190 If ``None``, the default behavior is to use the effective dtype of
132191 the input image. If slope and/or intercept are defined, the effective
@@ -135,10 +194,17 @@ def apply(
135194 If ``reference`` is defined, then the return value is an image, with
136195 a data array of the effective dtype but with the on-disk dtype set to
137196 the input image's on-disk dtype.
138- dtype_width: :obj:`int`
197+ dtype_width : :obj:`int`
139198 Cap the width of the input data type to the given number of bytes.
140199 This argument is intended to work as a way to implement lower memory
141200 requirements in resampling.
201+ serialize_nvols : :obj:`int`
202+ Minimum number of volumes in a 3D+t (that is, a series of 3D transformations
203+ independent in time) to resample on a one-by-one basis.
204+ Serialized resampling can be executed concurrently (parallelized) with
205+ the argument ``max_concurrent``.
206+ max_concurrent : :obj:`int`
207+ Maximum number of 3D resamplings to be executed concurrently.
142208
143209 Returns
144210 -------
@@ -201,46 +267,30 @@ def apply(
201267 else None
202268 )
203269
204- njobs = cpu_count () if njobs is None or njobs < 1 else njobs
205-
206- with ProcessPoolExecutor (max_workers = min (njobs , n_resamplings )) as executor :
207- results = []
208- for t in range (n_resamplings ):
209- xfm_t = transform if n_resamplings == 1 else transform [t ]
210-
211- if targets is None :
212- targets = ImageGrid (spatialimage ).index ( # data should be an image
213- _as_homogeneous (xfm_t .map (ref_ndcoords ), dim = _ref .ndim )
214- )
215-
216- data_t = (
217- data
218- if data is not None
219- else spatialimage .dataobj [..., t ].astype (input_dtype , copy = False )
220- )
221-
222- results .append (
223- executor .submit (
224- _apply_volume ,
225- t ,
226- data_t ,
227- targets ,
228- order = order ,
229- mode = mode ,
230- cval = cval ,
231- prefilter = prefilter ,
232- )
233- )
270+ # Order F ensures individual volumes are contiguous in memory
271+ # Also matches NIfTI, making final save more efficient
272+ resampled = np .zeros (
273+ (len (ref_ndcoords ), len (transform )), dtype = input_dtype , order = "F"
274+ )
234275
235- # Order F ensures individual volumes are contiguous in memory
236- # Also matches NIfTI, making final save more efficient
237- resampled = np .zeros (
238- (len (ref_ndcoords ), len (transform )), dtype = input_dtype , order = "F"
276+ resampled = asyncio .run (
277+ _apply_serial (
278+ data ,
279+ spatialimage ,
280+ targets ,
281+ transform ,
282+ _ref .ndim ,
283+ ref_ndcoords ,
284+ n_resamplings ,
285+ resampled ,
286+ input_dtype ,
287+ order = order ,
288+ mode = mode ,
289+ cval = cval ,
290+ prefilter = prefilter ,
291+ max_concurrent = max_concurrent ,
239292 )
240-
241- for future in as_completed (results ):
242- t , resampled_t = future .result ()
243- resampled [..., t ] = resampled_t
293+ )
244294 else :
245295 data = np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
246296
0 commit comments