1010
1111import asyncio
1212from os import cpu_count
13+ from contextlib import suppress
1314from functools import partial
1415from pathlib import Path
1516from typing import Callable , TypeVar , Union
@@ -108,14 +109,16 @@ async def _apply_serial(
108109 semaphore = asyncio .Semaphore (max_concurrent )
109110
110111 for t in range (n_resamplings ):
111- xfm_t = transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
112+ xfm_t = (
113+ transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
114+ )
112115
113116 targets_t = (
114117 ImageGrid (spatialimage ).index (
115118 _as_homogeneous (xfm_t .map (ref_ndcoords ), dim = ref_ndim )
116119 )
117120 if targets is None
118- else targets
121+ else targets [ t , ...]
119122 )
120123
121124 data_t = (
@@ -258,11 +261,22 @@ def apply(
258261 dim = _ref .ndim ,
259262 )
260263 )
261- elif xfm_nvols == 1 :
262- targets = ImageGrid (spatialimage ).index ( # data should be an image
263- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
264+ else :
265+ # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266+ targets = (
267+ ImageGrid (spatialimage ).index (
268+ _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
269+ )
270+ if targets is None
271+ else targets
264272 )
265273
274+ if targets .ndim == 3 :
275+ targets = np .rollaxis (targets , targets .ndim - 1 , 0 )
276+ else :
277+ assert targets .ndim == 2
278+ targets = targets [np .newaxis , ...]
279+
266280 if serialize_4d :
267281 data = (
268282 np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
@@ -297,17 +311,24 @@ def apply(
297311 else :
298312 data = np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
299313
300- if targets is None :
301- targets = ImageGrid (spatialimage ).index ( # data should be an image
302- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
303- )
304-
314+ if data_nvols == 1 and xfm_nvols == 1 :
315+ targets = np .squeeze (targets )
316+ assert targets .ndim == 2
305317 # Cast 3D data into 4D if 4D nonsequential transform
306- if data_nvols == 1 and xfm_nvols > 1 :
318+ elif data_nvols == 1 and xfm_nvols > 1 :
307319 data = data [..., np .newaxis ]
308320
309- if transform .ndim == 4 :
310- targets = _as_homogeneous (targets .reshape (- 2 , targets .shape [0 ])).T
321+ if xfm_nvols > 1 :
322+ assert targets .ndim == 3
323+ n_time , n_dim , n_vox = targets .shape
324+ # Reshape to (3, n_time x n_vox)
325+ ijk_targets = np .rollaxis (targets , 0 , 2 ).reshape ((n_dim , - 1 ))
326+ time_row = np .repeat (np .arange (n_time ), n_vox )[None , :]
327+
328+ # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
329+ # t is the slowest-changing axis, so we put it first
330+ targets = np .vstack ((time_row , ijk_targets ))
331+ data = np .rollaxis (data , data .ndim - 1 , 0 )
311332
312333 resampled = ndi .map_coordinates (
313334 data ,
@@ -326,11 +347,19 @@ def apply(
326347 )
327348 hdr .set_data_dtype (output_dtype or spatialimage .header .get_data_dtype ())
328349
329- moved = spatialimage .__class__ (
330- resampled .reshape (_ref .shape if n_resamplings == 1 else _ref .shape + (- 1 ,)),
331- _ref .affine ,
332- hdr ,
333- )
350+ if serialize_4d :
351+ resampled = resampled .reshape (
352+ _ref .shape
353+ if n_resamplings == 1
354+ else _ref .shape + (resampled .shape [- 1 ],)
355+ )
356+ else :
357+ resampled = resampled .reshape ((- 1 , * _ref .shape ))
358+ resampled = np .rollaxis (resampled , 0 , resampled .ndim )
359+ with suppress (ValueError ):
360+ resampled = np .squeeze (resampled , axis = 3 )
361+
362+ moved = spatialimage .__class__ (resampled , _ref .affine , hdr )
334363 return moved
335364
336365 output_dtype = output_dtype or input_dtype
0 commit comments