1010
1111import asyncio
1212from os import cpu_count
13+ from contextlib import suppress
1314from functools import partial
1415from pathlib import Path
1516from typing import Callable , TypeVar , Union
@@ -108,12 +109,17 @@ async def _apply_serial(
108109 semaphore = asyncio .Semaphore (max_concurrent )
109110
110111 for t in range (n_resamplings ):
111- xfm_t = transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
112+ xfm_t = (
113+ transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
114+ )
112115
113- if targets is None :
114- targets = ImageGrid (spatialimage ).index ( # data should be an image
116+ targets_t = (
117+ ImageGrid (spatialimage ).index (
115118 _as_homogeneous (xfm_t .map (ref_ndcoords ), dim = ref_ndim )
116119 )
120+ if targets is None
121+ else targets [t , ...]
122+ )
117123
118124 data_t = (
119125 data
@@ -127,7 +133,7 @@ async def _apply_serial(
127133 partial (
128134 ndi .map_coordinates ,
129135 data_t ,
130- targets ,
136+ targets_t ,
131137 output = output [..., t ],
132138 order = order ,
133139 mode = mode ,
@@ -255,11 +261,22 @@ def apply(
255261 dim = _ref .ndim ,
256262 )
257263 )
258- elif xfm_nvols == 1 :
259- targets = ImageGrid (spatialimage ).index ( # data should be an image
260- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
264+ else :
265+ # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266+ targets = (
267+ ImageGrid (spatialimage ).index (
268+ _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
269+ )
270+ if targets is None
271+ else targets
261272 )
262273
274+ if targets .ndim == 3 :
275+ targets = np .rollaxis (targets , targets .ndim - 1 , 0 )
276+ else :
277+ assert targets .ndim == 2
278+ targets = targets [np .newaxis , ...]
279+
263280 if serialize_4d :
264281 data = (
265282 np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
@@ -294,17 +311,24 @@ def apply(
294311 else :
295312 data = np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
296313
297- if targets is None :
298- targets = ImageGrid (spatialimage ).index ( # data should be an image
299- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
300- )
301-
314+ if data_nvols == 1 and xfm_nvols == 1 :
315+ targets = np .squeeze (targets )
316+ assert targets .ndim == 2
302317 # Cast 3D data into 4D if 4D nonsequential transform
303- if data_nvols == 1 and xfm_nvols > 1 :
318+ elif data_nvols == 1 and xfm_nvols > 1 :
304319 data = data [..., np .newaxis ]
305320
306- if transform .ndim == 4 :
307- targets = _as_homogeneous (targets .reshape (- 2 , targets .shape [0 ])).T
321+ if xfm_nvols > 1 :
322+ assert targets .ndim == 3
323+ n_time , n_dim , n_vox = targets .shape
324+ # Reshape to (3, n_time x n_vox)
325+ ijk_targets = np .rollaxis (targets , 0 , 2 ).reshape ((n_dim , - 1 ))
326+ time_row = np .repeat (np .arange (n_time ), n_vox )[None , :]
327+
328+ # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
329+ # t is the slowest-changing axis, so we put it first
330+ targets = np .vstack ((time_row , ijk_targets ))
331+ data = np .rollaxis (data , data .ndim - 1 , 0 )
308332
309333 resampled = ndi .map_coordinates (
310334 data ,
@@ -323,11 +347,19 @@ def apply(
323347 )
324348 hdr .set_data_dtype (output_dtype or spatialimage .header .get_data_dtype ())
325349
326- moved = spatialimage .__class__ (
327- resampled .reshape (_ref .shape if n_resamplings == 1 else _ref .shape + (- 1 ,)),
328- _ref .affine ,
329- hdr ,
330- )
350+ if serialize_4d :
351+ resampled = resampled .reshape (
352+ _ref .shape
353+ if n_resamplings == 1
354+ else _ref .shape + (resampled .shape [- 1 ],)
355+ )
356+ else :
357+ resampled = resampled .reshape ((- 1 , * _ref .shape ))
358+ resampled = np .rollaxis (resampled , 0 , resampled .ndim )
359+ with suppress (ValueError ):
360+ resampled = np .squeeze (resampled , axis = 3 )
361+
362+ moved = spatialimage .__class__ (resampled , _ref .affine , hdr )
331363 return moved
332364
333365 output_dtype = output_dtype or input_dtype
0 commit comments