Skip to content

Commit ae6c330

Browse files
committed
enh: update tests accordingly
1 parent b2a794b commit ae6c330

File tree

2 files changed

+134
-2
lines changed

2 files changed

+134
-2
lines changed

nitransforms/io/itk.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,9 @@ def from_image(cls, imgobj):
347347
warnings.warn("Incorrect intent identified.")
348348
hdr.set_intent("vector")
349349

350-
field = np.squeeze(np.asanyarray(imgobj.dataobj)).transpose(2, 1, 0, 3)
350+
field = np.squeeze(np.asanyarray(imgobj.dataobj))
351+
field[..., (0, 1)] *= 1.0
352+
field = field.transpose(2, 1, 0, 3)
351353
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
352354

353355
@classmethod
@@ -357,7 +359,9 @@ def to_image(cls, imgobj):
357359
hdr = imgobj.header.copy()
358360
hdr.set_intent("vector")
359361

360-
field = imgobj.get_fdata().transpose(2, 1, 0, 3)[..., None, :]
362+
field = imgobj.get_fdata()
363+
field = field.transpose(2, 1, 0, 3)[..., None, :]
364+
field[..., (0, 1)] *= 1.0
361365
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
362366

363367

nitransforms/tests/test_nonlinear.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
"""Tests of nonlinear transforms."""
44

55
import os
6+
from subprocess import check_call
7+
import shutil
8+
69
import pytest
710

811
import numpy as np
@@ -234,3 +237,128 @@ def manual_map(x):
234237
pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
235238
expected = np.vstack([manual_map(p) for p in pts])
236239
assert np.allclose(bspline.map(pts), expected, atol=1e-6)
240+
241+
242+
def test_densefield_map_against_ants(testdata_path, tmp_path):
243+
"""Map points with DenseFieldTransform and compare to ANTs."""
244+
warpfile = (
245+
testdata_path
246+
/ "regressions"
247+
/ ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz")
248+
)
249+
if not warpfile.exists():
250+
pytest.skip("Composite transform test data not available")
251+
252+
points = np.array(
253+
[
254+
[0.0, 0.0, 0.0],
255+
[1.0, 2.0, 3.0],
256+
[10.0, -10.0, 5.0],
257+
[-5.0, 7.0, -2.0],
258+
[-12.0, 12.0, 0.0],
259+
]
260+
)
261+
csvin = tmp_path / "points.csv"
262+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
263+
264+
csvout = tmp_path / "out.csv"
265+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
266+
exe = cmd.split()[0]
267+
if not shutil.which(exe):
268+
pytest.skip(f"Command {exe} not found on host")
269+
check_call(cmd, shell=True)
270+
271+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
272+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
273+
274+
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
275+
mapped = xfm.map(points)
276+
277+
assert np.allclose(mapped, ants_pts, atol=1e-6)
278+
279+
280+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
281+
@pytest.mark.parametrize("gridpoints", [True, False])
282+
def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, gridpoints):
283+
"""Create a constant displacement field and compare mappings."""
284+
285+
nii = get_testdata[image_orientation]
286+
287+
# Create a reference centered at the origin with various axis orders/flips
288+
shape = nii.shape
289+
ref_affine = nii.affine.copy()
290+
291+
field = np.hstack((
292+
np.zeros(np.prod(shape)),
293+
np.linspace(-80, 80, num=np.prod(shape)),
294+
np.linspace(-50, 50, num=np.prod(shape)),
295+
)).reshape(shape + (3, ))
296+
fieldnii = nb.Nifti1Image(field, ref_affine, None)
297+
298+
warpfile = tmp_path / "itk_transform.nii.gz"
299+
ITKDisplacementsField.to_filename(fieldnii, warpfile)
300+
301+
# Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
302+
xfm = DenseFieldTransform(fieldnii)
303+
itk_xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
304+
305+
assert xfm == itk_xfm
306+
np.testing.assert_allclose(xfm.reference.affine, itk_xfm.reference.affine)
307+
np.testing.assert_allclose(ref_affine, itk_xfm.reference.affine)
308+
np.testing.assert_allclose(xfm.reference.shape, itk_xfm.reference.shape)
309+
np.testing.assert_allclose(xfm._field, itk_xfm._field)
310+
311+
points = (
312+
xfm.reference.ndcoords.T if gridpoints
313+
else np.array(
314+
[
315+
[0.0, 0.0, 0.0],
316+
[1.0, 2.0, 3.0],
317+
[10.0, -10.0, 5.0],
318+
[-5.0, 7.0, -2.0],
319+
[12.0, 0.0, -11.0],
320+
]
321+
)
322+
)
323+
324+
mapped = xfm.map(points)
325+
nit_deltas = mapped - points
326+
327+
if gridpoints:
328+
np.testing.assert_array_equal(field, nit_deltas.reshape(*shape, -1))
329+
330+
csvin = tmp_path / "points.csv"
331+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
332+
333+
csvout = tmp_path / "out.csv"
334+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
335+
exe = cmd.split()[0]
336+
if not shutil.which(exe):
337+
pytest.skip(f"Command {exe} not found on host")
338+
check_call(cmd, shell=True)
339+
340+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
341+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
342+
343+
# if gridpoints:
344+
# ants_field = ants_pts.reshape(shape + (3, ))
345+
# diff = xfm._field[..., 0] - ants_field[..., 0]
346+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
347+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
348+
349+
# diff = xfm._field[..., 1] - ants_field[..., 1]
350+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
351+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
352+
353+
# diff = xfm._field[..., 2] - ants_field[..., 2]
354+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
355+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
356+
357+
ants_deltas = ants_pts - points
358+
np.testing.assert_array_equal(nit_deltas, ants_deltas)
359+
np.testing.assert_array_equal(mapped, ants_pts)
360+
361+
diff = mapped - ants_pts
362+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
363+
364+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"

0 commit comments

Comments
 (0)