88### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99"""Nonlinear transforms."""
1010import warnings
11+ from functools import partial
1112from pathlib import Path
1213import numpy as np
13- from scipy .sparse import vstack as sparse_vstack
1414from scipy import ndimage as ndi
1515from nibabel .funcs import four_to_three
1616from nibabel .loadsave import load as _nbload
1717
18- from . import io
19- from .interp .bspline import grid_bspline_weights , _cubic_bspline
20- from .base import (
18+ from nitransforms import io
19+ from nitransforms .io .base import _ensure_image
20+ from nitransforms .interp .bspline import grid_bspline_weights , _cubic_bspline
21+ from nitransforms .base import (
2122 TransformBase ,
2223 ImageGrid ,
2324 SpatialReference ,
@@ -33,7 +34,11 @@ class DisplacementsFieldTransform(TransformBase):
3334 def __init__ (self , field , reference = None ):
3435 """Create a dense deformation field transform."""
3536 super ().__init__ ()
36- self ._field = np .asanyarray (field .dataobj )
37+
38+ self ._field = np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
39+ self .reference = reference or field .__class__ (
40+ np .zeros (self ._field .shape [:- 1 ]), field .affine , field .header
41+ )
3742
3843 ndim = self ._field .ndim - 1
3944 if self ._field .shape [- 1 ] != ndim :
@@ -42,13 +47,16 @@ def __init__(self, field, reference=None):
4247 "the number of dimensions (%d)" % (self ._field .shape [- 1 ], ndim )
4348 )
4449
45- self .reference = field .__class__ (
46- np .zeros (self ._field .shape [:- 1 ]), field .affine , field .header
47- )
48-
4950 def map (self , x , inverse = False ):
5051 r"""
51- Apply :math:`y = f(x)`.
52+ Apply the transformation to a list of physical coordinate points.
53+
54+ .. math::
55+ \mathbf{y} = \mathbf{x} + D(\mathbf{x}),
56+ \label{eq:2}\tag{2}
57+
58+ where :math:`D(\mathbf{x})` is the value of the discrete field of displacements
59+ :math:`D` interpolated at the location :math:`\mathbf{x}`.
5260
5361 Parameters
5462 ----------
@@ -104,14 +112,14 @@ class BSplineFieldTransform(TransformBase):
104112 """Represent a nonlinear transform parameterized by BSpline basis."""
105113
106114 __slots__ = ['_coeffs' , '_knots' , '_weights' , '_order' , '_moving' ]
107- __s = (slice (None ), )
108115
109116 def __init__ (self , reference , coefficients , order = 3 ):
110117 """Create a smooth deformation field using B-Spline basis."""
111118 super (BSplineFieldTransform , self ).__init__ ()
112119 self ._order = order
113120 self .reference = reference
114121
122+ coefficients = _ensure_image (coefficients )
115123 if coefficients .shape [- 1 ] != self .ndim :
116124 raise ValueError (
117125 'Number of components of the coefficients does '
@@ -121,6 +129,23 @@ def __init__(self, reference, coefficients, order=3):
121129 self ._knots = ImageGrid (four_to_three (coefficients )[0 ])
122130 self ._weights = None
123131
132+ def to_field (self , reference = None ):
133+ """Generate a displacements deformation field from this B-Spline field."""
134+ reference = _ensure_image (reference )
135+ _ref = self .reference if reference is None else SpatialReference .factory (reference )
136+ ndim = self ._coeffs .shape [- 1 ]
137+
138+ # If locations to be interpolated are on a grid, use faster tensor-bspline calculation
139+ if self ._weights is None :
140+ self ._weights = grid_bspline_weights (_ref , self ._knots )
141+
142+ field = np .zeros ((_ref .npoints , ndim ))
143+
144+ for d in range (ndim ):
145+ field [:, d ] = self ._coeffs [..., d ].reshape (- 1 ) @ self ._weights
146+
147+ return field .astype ("float32" )
148+
124149 def apply (
125150 self ,
126151 spatialimage ,
@@ -133,8 +158,8 @@ def apply(
133158 ):
134159 """Apply a B-Spline transform on input data."""
135160
136- if reference is not None and isinstance ( reference , ( str , Path )) :
137- reference = _nbload ( str ( reference ) )
161+ if reference is not None :
162+ reference = _ensure_image ( reference )
138163
139164 _ref = (
140165 self .reference if reference is None else SpatialReference .factory (reference )
@@ -143,6 +168,7 @@ def apply(
143168 if isinstance (spatialimage , (str , Path )):
144169 spatialimage = _nbload (str (spatialimage ))
145170
171+ # If locations to be interpolated are not on a grid, run map()
146172 if not isinstance (_ref , ImageGrid ):
147173 return super ().apply (
148174 spatialimage ,
@@ -154,72 +180,85 @@ def apply(
154180 output_dtype = output_dtype ,
155181 )
156182
157- # If locations to be interpolated are on a grid, use faster tensor-bspline calculation
158- if self ._weights is None :
159- self ._weights = grid_bspline_weights (_ref , self ._knots )
160-
161- ycoords = _ref .ndcoords .T + (
162- np .squeeze (np .hstack (self ._coeffs ).T ) @ sparse_vstack (self ._weights )
183+ # If locations to be interpolated are on a grid, generate a displacements field
184+ return DisplacementsFieldTransform (
185+ self .to_field ().reshape ((* (_ref .shape ), - 1 )),
186+ reference = _ref ,
187+ ).apply (
188+ spatialimage ,
189+ reference = reference ,
190+ order = order ,
191+ mode = mode ,
192+ cval = cval ,
193+ prefilter = prefilter ,
194+ output_dtype = output_dtype ,
163195 )
164196
165- data = np .squeeze (np .asanyarray (spatialimage .dataobj ))
166- output_dtype = output_dtype or data .dtype
167- targets = ImageGrid (spatialimage ).index ( # data should be an image
168- _as_homogeneous (np .vstack (ycoords ), dim = _ref .ndim )
169- )
197+ def map (self , x , inverse = False ):
198+ r"""
199+ Apply the transformation to a list of physical coordinate points.
170200
171- if data .ndim == 4 :
172- if len (self ) != data .shape [- 1 ]:
173- raise ValueError (
174- "Attempting to apply %d transforms on a file with "
175- "%d timepoints" % (len (self ), data .shape [- 1 ])
176- )
177- targets = targets .reshape ((len (self ), - 1 , targets .shape [- 1 ]))
178- resampled = np .stack (
179- [
180- ndi .map_coordinates (
181- data [..., t ],
182- targets [t , ..., : _ref .ndim ].T ,
183- output = output_dtype ,
184- order = order ,
185- mode = mode ,
186- cval = cval ,
187- prefilter = prefilter ,
188- )
189- for t in range (data .shape [- 1 ])
190- ],
191- axis = 0 ,
192- )
193- elif data .ndim in (2 , 3 ):
194- resampled = ndi .map_coordinates (
195- data ,
196- targets [..., : _ref .ndim ].T ,
197- output = output_dtype ,
198- order = order ,
199- mode = mode ,
200- cval = cval ,
201- prefilter = prefilter ,
202- )
201+ .. math::
202+ \mathbf{y} = \mathbf{x} + \Psi^3(\mathbf{k}, \mathbf{x}),
203+ \label{eq:2}\tag{2}
203204
204- newdata = resampled . reshape (( len ( self ), * _ref . shape ))
205- moved = spatialimage . __class__ (
206- np . moveaxis ( newdata , 0 , - 1 ), _ref . affine , spatialimage . header
207- )
208- moved . header . set_data_dtype ( output_dtype )
209- return moved
205+ Parameters
206+ ----------
207+ x : N x D numpy.ndarray
208+ Input RAS+ coordinates (i.e., physical coordinates).
209+ inverse : bool
210+ If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
210211
211- def map (self , x , inverse = False ):
212- """Apply :math:`y = f(x)`."""
213-
214- ijk = (self ._knots .inverse @ _as_homogeneous (x ).squeeze ())[:3 ]
215- w_start , w_end = np .ceil (ijk - 2 ).astype (int ), np .floor (ijk + 2 ).astype (int )
216- nonzero_knots = tuple ([
217- np .arange (start , end + 1 ) for start , end in zip (w_start , w_end )
218- ])
219- nonzero_knots = np .meshgrid (* nonzero_knots , indexing = "ij" )
220- window = np .array (nonzero_knots ).reshape ((self .reference .ndim , - 1 ))
221- distance = window .T - ijk
222- unique_d , indices = np .unique (distance .reshape (- 1 ), return_inverse = True )
223- tensor_bspline = _cubic_bspline (unique_d )[indices ].reshape (distance .shape ).prod (1 )
224- coeffs = self ._coeffs [nonzero_knots ].reshape ((- 1 , self ._coeffs .shape [- 1 ]))
225- return x + coeffs .T @ tensor_bspline
212+ Returns
213+ -------
214+ y : N x D numpy.ndarray
215+ Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
216+
217+ Examples
218+ --------
219+ >>> field = np.zeros((10, 10, 10, 3))
220+ >>> field[..., 0] = 4.0
221+ >>> fieldimg = nb.Nifti1Image(field, np.diag([2., 2., 2., 1.]))
222+ >>> xfm = DisplacementsFieldTransform(fieldimg)
223+ >>> xfm([4.0, 4.0, 4.0]).tolist()
224+ [[8.0, 4.0, 4.0]]
225+
226+ >>> xfm([[4.0, 4.0, 4.0], [8, 2, 10]]).tolist()
227+ [[8.0, 4.0, 4.0], [12.0, 2.0, 10.0]]
228+ """
229+ vfunc = partial (
230+ _map_xyz ,
231+ reference = self .reference ,
232+ knots = self ._knots ,
233+ coeffs = self ._coeffs ,
234+ )
235+ return [vfunc (_x ) for _x in np .atleast_2d (x )]
236+
237+
238+ def _map_xyz (x , reference , knots , coeffs ):
239+ """Apply the transformation to just one coordinate."""
240+ ndim = len (x )
241+ # Calculate the index coordinates of the point in the B-Spline grid
242+ ijk = (knots .inverse @ _as_homogeneous (x ).squeeze ())[:ndim ]
243+
244+ # Determine the window within distance 2.0 (where the B-Spline is nonzero)
245+ # Probably this will change if the order of the B-Spline is different
246+ w_start , w_end = np .ceil (ijk - 2 ).astype (int ), np .floor (ijk + 2 ).astype (int )
247+ # Generate a grid of indexes corresponding to the window
248+ nonzero_knots = tuple ([
249+ np .arange (start , end + 1 ) for start , end in zip (w_start , w_end )
250+ ])
251+ nonzero_knots = tuple (np .meshgrid (* nonzero_knots , indexing = "ij" ))
252+ window = np .array (nonzero_knots ).reshape ((ndim , - 1 ))
253+
254+ # Calculate the distance of the location w.r.t. to all voxels in window
255+ distance = window .T - ijk
256+ # Since this is a grid, distance only takes a few float values
257+ unique_d , indices = np .unique (distance .reshape (- 1 ), return_inverse = True )
258+ # Calculate the B-Spline weight corresponding to the distance.
259+ # Then multiply the three weights of each knot (tensor-product B-Spline)
260+ tensor_bspline = _cubic_bspline (unique_d )[indices ].reshape (distance .shape ).prod (1 )
261+ # Extract the values of the coefficients in the window
262+ coeffs = coeffs [nonzero_knots ].reshape ((- 1 , ndim ))
263+ # Inference: the displacement is the product of coefficients x tensor-product B-Splines
264+ return x + coeffs .T @ tensor_bspline
0 commit comments