@@ -31,10 +31,21 @@ class DisplacementsFieldTransform(TransformBase):
3131 __slots__ = ["_field" ]
3232
3333 def __init__ (self , field , reference = None ):
34- """Create a dense deformation field transform."""
34+ """
35+ Create a dense deformation field transform.
36+
37+ Example
38+ -------
39+ >>> DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
40+ <(57, 67, 56) field of 3D displacements>
41+
42+ """
3543 super ().__init__ ()
3644
37- self ._field = np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
45+ field = _ensure_image (field )
46+ self ._field = np .squeeze (
47+ np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
48+ )
3849 self .reference = reference or field .__class__ (
3950 np .zeros (self ._field .shape [:- 1 ]), field .affine , field .header
4051 )
@@ -46,6 +57,10 @@ def __init__(self, field, reference=None):
4657 "the number of dimensions (%d)" % (self ._field .shape [- 1 ], ndim )
4758 )
4859
60+ def __repr__ (self ):
61+ """Beautify the python representation."""
62+ return f"<{ self ._field .shape [:3 ]} field of { self ._field .shape [- 1 ]} D displacements>"
63+
4964 def map (self , x , inverse = False ):
5065 r"""
5166 Apply the transformation to a list of physical coordinate points.
@@ -71,15 +86,12 @@ def map(self, x, inverse=False):
7186
7287 Examples
7388 --------
74- >>> field = np.zeros((10, 10, 10, 3))
75- >>> field[..., 0] = 4.0
76- >>> fieldimg = nb.Nifti1Image(field, np.diag([2., 2., 2., 1.]))
77- >>> xfm = DisplacementsFieldTransform(fieldimg)
78- >>> xfm([4.0, 4.0, 4.0]).tolist()
79- [[8.0, 4.0, 4.0]]
89+ >>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
90+ >>> xfm.map([-6.5, -36., -19.5]).tolist()
91+ [[-6.5, -36.475167989730835, -19.5]]
8092
81- >>> xfm([[4.0, 4.0, 4.0 ], [8, 2, 10 ]]).tolist()
82- [[8.0, 4.0, 4.0 ], [12 .0, 2.0, 10.0 ]]
93+ >>> xfm.map ([[-6.5, -36., -19.5 ], [-1., -41.5, -11.25 ]]).tolist()
94+ [[-6.5, -36.475167989730835, -19.5 ], [-1 .0, -42.038356602191925, -11.25 ]]
8395
8496 """
8597 if inverse is True :
@@ -112,26 +124,31 @@ class BSplineFieldTransform(TransformBase):
112124
113125 __slots__ = ['_coeffs' , '_knots' , '_weights' , '_order' , '_moving' ]
114126
115- def __init__ (self , reference , coefficients , order = 3 ):
127+ def __init__ (self , coefficients , reference = None , order = 3 ):
116128 """Create a smooth deformation field using B-Spline basis."""
117129 super (BSplineFieldTransform , self ).__init__ ()
118130 self ._order = order
119- self .reference = reference
120131
121132 coefficients = _ensure_image (coefficients )
122- if coefficients .shape [- 1 ] != self .ndim :
123- raise ValueError (
124- 'Number of components of the coefficients does '
125- 'not match the number of dimensions' )
126133
127134 self ._coeffs = np .asanyarray (coefficients .dataobj )
128135 self ._knots = ImageGrid (four_to_three (coefficients )[0 ])
129136 self ._weights = None
137+ if reference is not None :
138+ self .reference = reference
139+
140+ if coefficients .shape [- 1 ] != self .ndim :
141+ raise ValueError (
142+ 'Number of components of the coefficients does '
143+ 'not match the number of dimensions' )
130144
131- def to_field (self , reference = None ):
145+ def to_field (self , reference = None , dtype = "float32" ):
132146 """Generate a displacements deformation field from this B-Spline field."""
133147 reference = _ensure_image (reference )
134148 _ref = self .reference if reference is None else SpatialReference .factory (reference )
149+ if _ref is None :
150+ raise ValueError ("A reference must be defined" )
151+
135152 ndim = self ._coeffs .shape [- 1 ]
136153
137154 # If locations to be interpolated are on a grid, use faster tensor-bspline calculation
@@ -143,7 +160,7 @@ def to_field(self, reference=None):
143160 for d in range (ndim ):
144161 field [:, d ] = self ._coeffs [..., d ].reshape (- 1 ) @ self ._weights
145162
146- return field .astype ("float32" )
163+ return field .astype (dtype )
147164
148165 def apply (
149166 self ,
@@ -215,23 +232,22 @@ def map(self, x, inverse=False):
215232
216233 Examples
217234 --------
218- >>> field = np.zeros((10, 10, 10, 3))
219- >>> field[..., 0] = 4.0
220- >>> fieldimg = nb.Nifti1Image(field, np.diag([2., 2., 2., 1.]))
221- >>> xfm = DisplacementsFieldTransform(fieldimg)
222- >>> xfm([4.0, 4.0, 4.0]).tolist()
223- [[8.0, 4.0, 4.0]]
224-
225- >>> xfm([[4.0, 4.0, 4.0], [8, 2, 10]]).tolist()
226- [[8.0, 4.0, 4.0], [12.0, 2.0, 10.0]]
235+ >>> xfm = BSplineFieldTransform(test_dir / "someones_bspline_coefficients.nii.gz")
236+ >>> xfm.reference = test_dir / "someones_anatomy.nii.gz"
237+ >>> xfm.map([-6.5, -36., -19.5]).tolist()
238+ [[-6.5, -31.476097418406784, -19.5]]
239+
240+ >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
241+ [[-6.5, -31.476097418406784, -19.5], [-1.0, -3.8072675377121996, -11.25]]
242+
227243 """
228244 vfunc = partial (
229245 _map_xyz ,
230246 reference = self .reference ,
231247 knots = self ._knots ,
232248 coeffs = self ._coeffs ,
233249 )
234- return [vfunc (_x ) for _x in np .atleast_2d (x )]
250+ return np . array ( [vfunc (_x ). tolist () for _x in np .atleast_2d (x )])
235251
236252
237253def _map_xyz (x , reference , knots , coeffs ):
0 commit comments