33"""Tests of nonlinear transforms."""
44
55import os
6+ from subprocess import check_call
7+ import shutil
8+
9+ import SimpleITK as sitk
610import pytest
711
812import numpy as np
913import nibabel as nb
14+ from nibabel .affines import from_matvec
1015from nitransforms .resampling import apply
1116from nitransforms .base import TransformError
1217from nitransforms .io .base import TransformFileError
1520 DenseFieldTransform ,
1621)
1722from nitransforms import io
18- from . .io .itk import ITKDisplacementsField
23+ from nitransforms .io .itk import ITKDisplacementsField
1924
2025
2126@pytest .mark .parametrize ("size" , [(20 , 20 , 20 ), (20 , 20 , 20 , 3 )])
@@ -34,16 +39,6 @@ def test_displacements_bad_sizes(size):
3439 DenseFieldTransform (nb .Nifti1Image (np .zeros (size ), np .eye (4 ), None ))
3540
3641
37- def test_itk_disp_load_intent ():
38- """Checks whether the NIfTI intent is fixed."""
39- with pytest .warns (UserWarning ):
40- field = ITKDisplacementsField .from_image (
41- nb .Nifti1Image (np .zeros ((20 , 20 , 20 , 1 , 3 )), np .eye (4 ), None )
42- )
43-
44- assert field .header .get_intent ()[0 ] == "vector"
45-
46-
4742def test_displacements_init ():
4843 identity1 = DenseFieldTransform (
4944 np .zeros ((10 , 10 , 10 , 3 )),
@@ -67,6 +62,30 @@ def test_displacements_init():
6762 )
6863
6964
65+ @pytest .mark .parametrize ("is_deltas" , [True , False ])
66+ def test_densefield_oob_resampling (is_deltas ):
67+ """Ensure mapping outside the field returns input coordinates."""
68+ ref = nb .Nifti1Image (np .zeros ((2 , 2 , 2 ), dtype = "uint8" ), np .eye (4 ))
69+
70+ if is_deltas :
71+ field = nb .Nifti1Image (np .ones ((2 , 2 , 2 , 3 ), dtype = "float32" ), np .eye (4 ))
72+ else :
73+ grid = np .stack (
74+ np .meshgrid (* [np .arange (2 ) for _ in range (3 )], indexing = "ij" ),
75+ axis = - 1 ,
76+ ).astype ("float32" )
77+ field = nb .Nifti1Image (grid + 1.0 , np .eye (4 ))
78+
79+ xfm = DenseFieldTransform (field , is_deltas = is_deltas , reference = ref )
80+
81+ points = np .array ([[- 1.0 , - 1.0 , - 1.0 ], [0.5 , 0.5 , 0.5 ], [3.0 , 3.0 , 3.0 ]])
82+ mapped = xfm .map (points )
83+
84+ assert np .allclose (mapped [0 ], points [0 ])
85+ assert np .allclose (mapped [2 ], points [2 ])
86+ assert np .allclose (mapped [1 ], points [1 ] + 1 )
87+
88+
7089def test_bsplines_init ():
7190 with pytest .raises (TransformError ):
7291 BSplineFieldTransform (
@@ -122,76 +141,6 @@ def test_bspline(tmp_path, testdata_path):
122141 )
123142
124143
125- @pytest .mark .parametrize ("is_deltas" , [True , False ])
126- def test_densefield_x5_roundtrip (tmp_path , is_deltas ):
127- """Ensure dense field transforms roundtrip via X5."""
128- ref = nb .Nifti1Image (np .zeros ((2 , 2 , 2 ), dtype = "uint8" ), np .eye (4 ))
129- disp = nb .Nifti1Image (np .random .rand (2 , 2 , 2 , 3 ).astype ("float32" ), np .eye (4 ))
130-
131- xfm = DenseFieldTransform (disp , is_deltas = is_deltas , reference = ref )
132-
133- node = xfm .to_x5 (metadata = {"GeneratedBy" : "pytest" })
134- assert node .type == "nonlinear"
135- assert node .subtype == "densefield"
136- assert node .representation == "displacements" if is_deltas else "deformations"
137- assert node .domain .size == ref .shape
138- assert node .metadata ["GeneratedBy" ] == "pytest"
139-
140- fname = tmp_path / "test.x5"
141- io .x5 .to_filename (fname , [node ])
142-
143- xfm2 = DenseFieldTransform .from_filename (fname , fmt = "X5" )
144-
145- assert xfm2 .reference .shape == ref .shape
146- assert np .allclose (xfm2 .reference .affine , ref .affine )
147- assert xfm == xfm2
148-
149-
150- def test_bspline_to_x5 (tmp_path ):
151- """Check BSpline transforms export to X5."""
152- coeff = nb .Nifti1Image (np .zeros ((2 , 2 , 2 , 3 ), dtype = "float32" ), np .eye (4 ))
153- ref = nb .Nifti1Image (np .zeros ((2 , 2 , 2 ), dtype = "uint8" ), np .eye (4 ))
154-
155- xfm = BSplineFieldTransform (coeff , reference = ref )
156- node = xfm .to_x5 (metadata = {"tool" : "pytest" })
157- assert node .type == "nonlinear"
158- assert node .subtype == "bspline"
159- assert node .representation == "coefficients"
160- assert node .metadata ["tool" ] == "pytest"
161-
162- fname = tmp_path / "bspline.x5"
163- io .x5 .to_filename (fname , [node ])
164-
165- xfm2 = BSplineFieldTransform .from_filename (fname , fmt = "X5" )
166- assert np .allclose (xfm ._coeffs , xfm2 ._coeffs )
167- assert xfm2 .reference .shape == ref .shape
168- assert np .allclose (xfm2 .reference .affine , ref .affine )
169-
170-
171- @pytest .mark .parametrize ("is_deltas" , [True , False ])
172- def test_densefield_oob_resampling (is_deltas ):
173- """Ensure mapping outside the field returns input coordinates."""
174- ref = nb .Nifti1Image (np .zeros ((2 , 2 , 2 ), dtype = "uint8" ), np .eye (4 ))
175-
176- if is_deltas :
177- field = nb .Nifti1Image (np .ones ((2 , 2 , 2 , 3 ), dtype = "float32" ), np .eye (4 ))
178- else :
179- grid = np .stack (
180- np .meshgrid (* [np .arange (2 ) for _ in range (3 )], indexing = "ij" ),
181- axis = - 1 ,
182- ).astype ("float32" )
183- field = nb .Nifti1Image (grid + 1.0 , np .eye (4 ))
184-
185- xfm = DenseFieldTransform (field , is_deltas = is_deltas , reference = ref )
186-
187- points = np .array ([[- 1.0 , - 1.0 , - 1.0 ], [0.5 , 0.5 , 0.5 ], [3.0 , 3.0 , 3.0 ]])
188- mapped = xfm .map (points )
189-
190- assert np .allclose (mapped [0 ], points [0 ])
191- assert np .allclose (mapped [2 ], points [2 ])
192- assert np .allclose (mapped [1 ], points [1 ] + 1 )
193-
194-
195144def test_bspline_map_gridpoints ():
196145 """BSpline mapping matches dense field on grid points."""
197146 ref = nb .Nifti1Image (np .zeros ((5 , 5 , 5 ), dtype = "uint8" ), np .eye (4 ))
@@ -243,3 +192,128 @@ def manual_map(x):
243192 pts = np .array ([[1.2 , 1.5 , 2.0 ], [3.3 , 1.7 , 2.4 ]])
244193 expected = np .vstack ([manual_map (p ) for p in pts ])
245194 assert np .allclose (bspline .map (pts ), expected , atol = 1e-6 )
195+
196+
197+ def test_densefield_map_against_ants (testdata_path , tmp_path ):
198+ """Map points with DenseFieldTransform and compare to ANTs."""
199+ warpfile = (
200+ testdata_path
201+ / "regressions"
202+ / ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz" )
203+ )
204+ if not warpfile .exists ():
205+ pytest .skip ("Composite transform test data not available" )
206+
207+ points = np .array (
208+ [
209+ [0.0 , 0.0 , 0.0 ],
210+ [1.0 , 2.0 , 3.0 ],
211+ [10.0 , - 10.0 , 5.0 ],
212+ [- 5.0 , 7.0 , - 2.0 ],
213+ [- 12.0 , 12.0 , 0.0 ],
214+ ]
215+ )
216+ csvin = tmp_path / "points.csv"
217+ np .savetxt (csvin , points , delimiter = "," , header = "x,y,z" , comments = "" )
218+
219+ csvout = tmp_path / "out.csv"
220+ cmd = f"antsApplyTransformsToPoints -d 3 -i { csvin } -o { csvout } -t { warpfile } "
221+ exe = cmd .split ()[0 ]
222+ if not shutil .which (exe ):
223+ pytest .skip (f"Command { exe } not found on host" )
224+ check_call (cmd , shell = True )
225+
226+ ants_res = np .genfromtxt (csvout , delimiter = "," , names = True )
227+ ants_pts = np .vstack ([ants_res [n ] for n in ("x" , "y" , "z" )]).T
228+
229+ xfm = DenseFieldTransform (ITKDisplacementsField .from_filename (warpfile ))
230+ mapped = xfm .map (points )
231+
232+ assert np .allclose (mapped , ants_pts , atol = 1e-6 )
233+
234+
235+ @pytest .mark .parametrize ("image_orientation" , ["RAS" , "LAS" , "LPS" , "oblique" ])
236+ @pytest .mark .parametrize ("gridpoints" , [True , False ])
237+ def test_constant_field_vs_ants (tmp_path , get_testdata , image_orientation , gridpoints ):
238+ """Create a constant displacement field and compare mappings."""
239+
240+ nii = get_testdata [image_orientation ]
241+
242+ # Create a reference centered at the origin with various axis orders/flips
243+ shape = nii .shape
244+ ref_affine = nii .affine .copy ()
245+
246+ field = np .hstack ((
247+ np .zeros (np .prod (shape )),
248+ np .linspace (- 80 , 80 , num = np .prod (shape )),
249+ np .linspace (- 50 , 50 , num = np .prod (shape )),
250+ )).reshape (shape + (3 , ))
251+ fieldnii = nb .Nifti1Image (field , ref_affine , None )
252+
253+ warpfile = tmp_path / "itk_transform.nii.gz"
254+ ITKDisplacementsField .to_filename (fieldnii , warpfile )
255+
256+ # Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
257+ xfm = DenseFieldTransform (fieldnii )
258+ itk_xfm = DenseFieldTransform (ITKDisplacementsField .from_filename (warpfile ))
259+
260+ assert xfm == itk_xfm
261+ np .testing .assert_allclose (xfm .reference .affine , itk_xfm .reference .affine )
262+ np .testing .assert_allclose (ref_affine , itk_xfm .reference .affine )
263+ np .testing .assert_allclose (xfm .reference .shape , itk_xfm .reference .shape )
264+ np .testing .assert_allclose (xfm ._field , itk_xfm ._field )
265+
266+ points = (
267+ xfm .reference .ndcoords .T if gridpoints
268+ else np .array (
269+ [
270+ [0.0 , 0.0 , 0.0 ],
271+ [1.0 , 2.0 , 3.0 ],
272+ [10.0 , - 10.0 , 5.0 ],
273+ [- 5.0 , 7.0 , - 2.0 ],
274+ [12.0 , 0.0 , - 11.0 ],
275+ ]
276+ )
277+ )
278+
279+ mapped = xfm .map (points )
280+ nit_deltas = mapped - points
281+
282+ if gridpoints :
283+ np .testing .assert_array_equal (field , nit_deltas .reshape (* shape , - 1 ))
284+
285+ csvin = tmp_path / "points.csv"
286+ np .savetxt (csvin , points , delimiter = "," , header = "x,y,z" , comments = "" )
287+
288+ csvout = tmp_path / "out.csv"
289+ cmd = f"antsApplyTransformsToPoints -d 3 -i { csvin } -o { csvout } -t { warpfile } "
290+ exe = cmd .split ()[0 ]
291+ if not shutil .which (exe ):
292+ pytest .skip (f"Command { exe } not found on host" )
293+ check_call (cmd , shell = True )
294+
295+ ants_res = np .genfromtxt (csvout , delimiter = "," , names = True )
296+ ants_pts = np .vstack ([ants_res [n ] for n in ("x" , "y" , "z" )]).T
297+
298+ # if gridpoints:
299+ # ants_field = ants_pts.reshape(shape + (3, ))
300+ # diff = xfm._field[..., 0] - ants_field[..., 0]
301+ # mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
302+ # assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
303+
304+ # diff = xfm._field[..., 1] - ants_field[..., 1]
305+ # mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
306+ # assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
307+
308+ # diff = xfm._field[..., 2] - ants_field[..., 2]
309+ # mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
310+ # assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
311+
312+ ants_deltas = ants_pts - points
313+ np .testing .assert_array_equal (nit_deltas , ants_deltas )
314+ np .testing .assert_array_equal (mapped , ants_pts )
315+
316+ diff = mapped - ants_pts
317+ mask = np .argwhere (np .abs (diff ) > 1e-2 )[:, 0 ]
318+
319+ assert len (mask ) == 0 , f"A total of { len (mask )} /{ ants_pts .shape [0 ]} contained errors:\n { diff [mask ]} "
0 commit comments