2020 ImageGrid ,
2121 SpatialReference ,
2222 _as_homogeneous ,
23+ EQUALITY_TOL ,
2324)
2425
2526
26- class DisplacementsFieldTransform (TransformBase ):
27- """Represents a dense field of displacements (one vector per voxel)."""
27+ class DeformationFieldTransform (TransformBase ):
28+ """Represents a dense field of deformed locations (corresponding to each voxel)."""
2829
2930 __slots__ = ["_field" ]
3031
@@ -34,8 +35,8 @@ def __init__(self, field, reference=None):
3435
3536 Example
3637 -------
37- >>> DisplacementsFieldTransform (test_dir / "someones_displacement_field.nii.gz")
38- <DisplacementFieldTransform [3D] (57, 67, 56)>
38+ >>> DeformationFieldTransform (test_dir / "someones_displacement_field.nii.gz")
39+ <DeformationFieldTransform [3D] (57, 67, 56)>
3940
4041 """
4142 super ().__init__ ()
@@ -59,13 +60,13 @@ def __init__(self, field, reference=None):
5960 ndim = self ._field .ndim - 1
6061 if self ._field .shape [- 1 ] != ndim :
6162 raise TransformError (
62- "The number of components of the displacements (%d) does not "
63+ "The number of components of the displacements (%d) does not match "
6364 "the number of dimensions (%d)" % (self ._field .shape [- 1 ], ndim )
6465 )
6566
6667 def __repr__ (self ):
6768 """Beautify the python representation."""
68- return f"<DisplacementFieldTransform [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
69+ return f"<{ self . __class__ . __name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
6970
7071 def map (self , x , inverse = False ):
7172 r"""
@@ -92,12 +93,12 @@ def map(self, x, inverse=False):
9293
9394 Examples
9495 --------
95- >>> xfm = DisplacementsFieldTransform (test_dir / "someones_displacement_field.nii.gz")
96+ >>> xfm = DeformationFieldTransform (test_dir / "someones_displacement_field.nii.gz")
9697 >>> xfm.map([-6.5, -36., -19.5]).tolist()
97- [[-6.5 , -36.475167989730835, -19.5 ]]
98+ [[0.0 , -0.47516798973083496, 0.0 ]]
9899
99100 >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
100- [[-6.5 , -36.475167989730835, -19.5 ], [-1 .0, -42.038356602191925, -11.25 ]]
101+ [[0.0 , -0.47516798973083496, 0.0 ], [0 .0, -0.538356602191925, 0.0 ]]
101102
102103 """
103104
@@ -108,7 +109,76 @@ def map(self, x, inverse=False):
108109 if np .any (np .abs (ijk - indexes ) > 0.05 ):
109110 warnings .warn ("Some coordinates are off-grid of the displacements field." )
110111 indexes = tuple (tuple (i ) for i in indexes .T )
111- return x + self ._field [indexes ]
112+ return self ._field [indexes ]
113+
114+ def __matmul__ (self , b ):
115+ """
116+ Compose with a transform on the right.
117+
118+ Examples
119+ --------
120+ >>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
121+ >>> xfm2 = xfm @ TransformBase()
122+ >>> xfm == xfm2
123+ True
124+
125+ """
126+ retval = b .map (
127+ self ._field .reshape ((- 1 , self ._field .shape [- 1 ]))
128+ ).reshape (self ._field .shape )
129+ return DeformationFieldTransform (retval , reference = self .reference )
130+
131+ def __eq__ (self , other ):
132+ """
133+ Overload equals operator.
134+
135+ Examples
136+ --------
137+ >>> xfm1 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
138+ >>> xfm2 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
139+ >>> xfm1 == xfm2
140+ True
141+
142+ """
143+ _eq = np .allclose (self ._field , other ._field , rtol = EQUALITY_TOL )
144+ if _eq and self ._reference != other ._reference :
145+ warnings .warn ("Fields are equal, but references do not match." )
146+ return _eq
147+
148+
149+ class DisplacementsFieldTransform (DeformationFieldTransform ):
150+ """
151+ Represents a dense field of displacements (one vector per voxel).
152+
153+ Converting to a field of deformations is straightforward by just adding the corresponding
154+ displacement to the :math:`(x, y, z)` coordinates of each voxel.
155+ Numerically, deformation fields are less susceptible to rounding errors
156+ than displacements fields.
157+ SPM generally prefers deformations for that reason.
158+
159+ """
160+
161+ __slots__ = ["_displacements" ]
162+
163+ def __init__ (self , field , reference = None ):
164+ """
165+ Create a transform supported by a field of voxel-wise displacements.
166+
167+ Example
168+ -------
169+ >>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
170+ >>> xfm
171+ <DisplacementsFieldTransform[3D] (57, 67, 56)>
172+
173+ >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
174+ [[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
175+
176+ """
177+ super ().__init__ (field , reference = reference )
178+ self ._displacements = self ._field
179+ # Convert from displacements to deformations fields
180+ # (just add the origin to the displacements vector)
181+ self ._field += self .reference .ndcoords .T .reshape (self ._field .shape )
112182
113183 @classmethod
114184 def from_filename (cls , filename , fmt = "X5" ):
0 commit comments