Skip to content

Commit 4e01e23

Browse files
committed
enh: update tests accordingly
1 parent dd180ba commit 4e01e23

File tree

3 files changed

+138
-7
lines changed

3 files changed

+138
-7
lines changed

nitransforms/io/itk.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,9 @@ def from_image(cls, imgobj):
347347
warnings.warn("Incorrect intent identified.")
348348
hdr.set_intent("vector")
349349

350-
field = np.squeeze(np.asanyarray(imgobj.dataobj)).transpose(2, 1, 0, 3)
350+
field = np.squeeze(np.asanyarray(imgobj.dataobj))
351+
field[..., (0, 1)] *= 1.0
352+
field = field.transpose(2, 1, 0, 3)
351353
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
352354

353355
@classmethod
@@ -357,7 +359,9 @@ def to_image(cls, imgobj):
357359
hdr = imgobj.header.copy()
358360
hdr.set_intent("vector")
359361

360-
field = imgobj.get_fdata().transpose(2, 1, 0, 3)[..., None, :]
362+
field = imgobj.get_fdata()
363+
field = field.transpose(2, 1, 0, 3)[..., None, :]
364+
field[..., (0, 1)] *= 1.0
361365
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
362366

363367

nitransforms/tests/test_nonlinear.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
"""Tests of nonlinear transforms."""
44

55
import os
6+
from subprocess import check_call
7+
import shutil
8+
69
import pytest
710

811
import numpy as np
@@ -184,3 +187,128 @@ def manual_map(x):
184187
pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
185188
expected = np.vstack([manual_map(p) for p in pts])
186189
assert np.allclose(bspline.map(pts), expected, atol=1e-6)
190+
191+
192+
def test_densefield_map_against_ants(testdata_path, tmp_path):
193+
"""Map points with DenseFieldTransform and compare to ANTs."""
194+
warpfile = (
195+
testdata_path
196+
/ "regressions"
197+
/ ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz")
198+
)
199+
if not warpfile.exists():
200+
pytest.skip("Composite transform test data not available")
201+
202+
points = np.array(
203+
[
204+
[0.0, 0.0, 0.0],
205+
[1.0, 2.0, 3.0],
206+
[10.0, -10.0, 5.0],
207+
[-5.0, 7.0, -2.0],
208+
[-12.0, 12.0, 0.0],
209+
]
210+
)
211+
csvin = tmp_path / "points.csv"
212+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
213+
214+
csvout = tmp_path / "out.csv"
215+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
216+
exe = cmd.split()[0]
217+
if not shutil.which(exe):
218+
pytest.skip(f"Command {exe} not found on host")
219+
check_call(cmd, shell=True)
220+
221+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
222+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
223+
224+
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
225+
mapped = xfm.map(points)
226+
227+
assert np.allclose(mapped, ants_pts, atol=1e-6)
228+
229+
230+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
231+
@pytest.mark.parametrize("gridpoints", [True, False])
232+
def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, gridpoints):
233+
"""Create a constant displacement field and compare mappings."""
234+
235+
nii = get_testdata[image_orientation]
236+
237+
# Create a reference centered at the origin with various axis orders/flips
238+
shape = nii.shape
239+
ref_affine = nii.affine.copy()
240+
241+
field = np.hstack((
242+
np.zeros(np.prod(shape)),
243+
np.linspace(-80, 80, num=np.prod(shape)),
244+
np.linspace(-50, 50, num=np.prod(shape)),
245+
)).reshape(shape + (3, ))
246+
fieldnii = nb.Nifti1Image(field, ref_affine, None)
247+
248+
warpfile = tmp_path / "itk_transform.nii.gz"
249+
ITKDisplacementsField.to_filename(fieldnii, warpfile)
250+
251+
# Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
252+
xfm = DenseFieldTransform(fieldnii)
253+
itk_xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
254+
255+
assert xfm == itk_xfm
256+
np.testing.assert_allclose(xfm.reference.affine, itk_xfm.reference.affine)
257+
np.testing.assert_allclose(ref_affine, itk_xfm.reference.affine)
258+
np.testing.assert_allclose(xfm.reference.shape, itk_xfm.reference.shape)
259+
np.testing.assert_allclose(xfm._field, itk_xfm._field)
260+
261+
points = (
262+
xfm.reference.ndcoords.T if gridpoints
263+
else np.array(
264+
[
265+
[0.0, 0.0, 0.0],
266+
[1.0, 2.0, 3.0],
267+
[10.0, -10.0, 5.0],
268+
[-5.0, 7.0, -2.0],
269+
[12.0, 0.0, -11.0],
270+
]
271+
)
272+
)
273+
274+
mapped = xfm.map(points)
275+
nit_deltas = mapped - points
276+
277+
if gridpoints:
278+
np.testing.assert_array_equal(field, nit_deltas.reshape(*shape, -1))
279+
280+
csvin = tmp_path / "points.csv"
281+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
282+
283+
csvout = tmp_path / "out.csv"
284+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
285+
exe = cmd.split()[0]
286+
if not shutil.which(exe):
287+
pytest.skip(f"Command {exe} not found on host")
288+
check_call(cmd, shell=True)
289+
290+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
291+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
292+
293+
# if gridpoints:
294+
# ants_field = ants_pts.reshape(shape + (3, ))
295+
# diff = xfm._field[..., 0] - ants_field[..., 0]
296+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
297+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
298+
299+
# diff = xfm._field[..., 1] - ants_field[..., 1]
300+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
301+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
302+
303+
# diff = xfm._field[..., 2] - ants_field[..., 2]
304+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
305+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
306+
307+
ants_deltas = ants_pts - points
308+
np.testing.assert_array_equal(nit_deltas, ants_deltas)
309+
np.testing.assert_array_equal(mapped, ants_pts)
310+
311+
diff = mapped - ants_pts
312+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
313+
314+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"

nitransforms/tests/test_resampling.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def test_displacements_field1(
192192

193193
xfm = nitnl.load(xfm_fname, fmt=sw_tool)
194194

195+
import pdb; pdb.set_trace()
196+
195197
# Then apply the transform and cross-check with software
196198
cmd = APPLY_NONLINEAR_CMD[sw_tool](
197199
transform=os.path.abspath(xfm_fname),
@@ -247,11 +249,7 @@ def test_displacements_field1(
247249
assert np.sqrt((diff[5:-5, 5:-5, 5:-5] ** 2).mean()) < 1e-6
248250

249251

250-
@pytest.mark.xfail(
251-
reason="Disable while #266 is developed.",
252-
strict=False,
253-
)
254-
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
252+
@pytest.mark.parametrize("sw_tool", ["afni"])
255253
def test_displacements_field2(tmp_path, testdata_path, sw_tool):
256254
"""Check a translation-only field on one or more axes, different image orientations."""
257255
os.chdir(str(tmp_path))
@@ -283,6 +281,7 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool):
283281
nt_moved = apply(xfm, img_fname, order=0)
284282
nt_moved.to_filename("nt_resampled.nii.gz")
285283
sw_moved.set_data_dtype(nt_moved.get_data_dtype())
284+
286285
diff = np.asanyarray(
287286
sw_moved.dataobj, dtype=sw_moved.get_data_dtype()
288287
) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())

0 commit comments

Comments
 (0)