Skip to content

Commit 78cb070

Browse files
committed
fix: write tests to avoid regressions
Resolves: #267.
1 parent aeb629f commit 78cb070

File tree

2 files changed

+57
-41
lines changed

2 files changed

+57
-41
lines changed

nitransforms/tests/test_nonlinear.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,24 @@
1010

1111
import numpy as np
1212
import nibabel as nb
13-
from nitransforms.resampling import apply
1413
from nitransforms.base import TransformError
1514
from nitransforms.nonlinear import (
1615
BSplineFieldTransform,
1716
DenseFieldTransform,
1817
)
1918

2019

20+
SOME_TEST_POINTS = np.array(
21+
[
22+
[0.0, 0.0, 0.0],
23+
[1.0, 2.0, 3.0],
24+
[10.0, -10.0, 5.0],
25+
[-5.0, 7.0, -2.0],
26+
[12.0, 0.0, -11.0],
27+
]
28+
)
29+
30+
2131
def test_displacements_init():
2232
identity1 = DenseFieldTransform(
2333
np.zeros((10, 10, 10, 3)),
@@ -62,41 +72,12 @@ def test_bsplines_references(testdata_path):
6272
testdata_path / "someones_bspline_coefficients.nii.gz"
6373
).to_field()
6474

65-
with pytest.raises(TransformError):
66-
apply(
67-
BSplineFieldTransform(
68-
testdata_path / "someones_bspline_coefficients.nii.gz"
69-
),
70-
testdata_path / "someones_anatomy.nii.gz",
71-
)
72-
73-
apply(
74-
BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"),
75-
testdata_path / "someones_anatomy.nii.gz",
75+
BSplineFieldTransform(
76+
testdata_path / "someones_bspline_coefficients.nii.gz",
7677
reference=testdata_path / "someones_anatomy.nii.gz",
7778
)
7879

7980

80-
@pytest.mark.xfail(
81-
reason="Disable while #266 is developed.",
82-
strict=False,
83-
)
84-
def test_bspline(tmp_path, testdata_path):
85-
"""
86-
Cross-check B-Splines and deformation field.
87-
88-
This test is disabled and will be split into two separate tests.
89-
The current implementation will be moved into test_resampling.py,
90-
since that's what it actually tests.
91-
92-
In GH-266, this test will be re-implemented by testing the equivalence
93-
of the B-Spline and deformation field transforms by calling the
94-
transform's `map()` method on points.
95-
96-
"""
97-
assert True
98-
99-
10081
def test_map_bspline_vs_displacement(tmp_path, testdata_path):
10182
"""Cross-check B-Splines and deformation field."""
10283
os.chdir(str(tmp_path))
@@ -107,6 +88,8 @@ def test_map_bspline_vs_displacement(tmp_path, testdata_path):
10788

10889
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name).to_field()
10990
dispxfm = DenseFieldTransform(disp_name)
91+
# Interpolating field should be reasonably similar
92+
np.testing.assert_allclose(dispxfm._field, bsplxfm._field, atol=1e-1, rtol=1e-4)
11093

11194
# Interpolating the field should be reasonably similar
11295
np.testing.assert_allclose(dispxfm._field, bsplxfm._field, atol=1e-1, rtol=1e-4)

nitransforms/tests/test_resampling.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ def test_apply_linear_transform(
150150

151151

152152
@pytest.mark.xfail(
153-
reason="Disable while #266 is developed.",
153+
reason="GH-267: disabled while debugging",
154154
strict=False,
155155
)
156156
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
157157
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
158158
@pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
159-
def test_displacements_field1(
159+
def test_apply_displacements_field1(
160160
tmp_path,
161161
get_testdata,
162162
get_testmask,
@@ -190,16 +190,17 @@ def test_displacements_field1(
190190
else:
191191
field.to_filename(xfm_fname)
192192

193-
xfm = nitnl.load(xfm_fname, fmt=sw_tool)
193+
# xfm = nitnl.load(xfm_fname, fmt=sw_tool)
194+
xfm = nitnl.DenseFieldTransform(fieldmap, reference=nii)
194195

196+
ants_output = tmp_path / "ants_brainmask.nii.gz"
195197
# Then apply the transform and cross-check with software
196198
cmd = APPLY_NONLINEAR_CMD[sw_tool](
197199
transform=os.path.abspath(xfm_fname),
198200
reference=tmp_path / "mask.nii.gz",
199201
moving=tmp_path / "mask.nii.gz",
200-
output=tmp_path / "resampled_brainmask.nii.gz",
201-
extra="",
202-
# extra="--output-data-type uchar" if sw_tool == "itk" else "",
202+
output=ants_output,
203+
extra="--output-data-type uchar" if sw_tool == "itk" else "",
203204
)
204205

205206
# skip test if command is not available on host
@@ -209,11 +210,13 @@ def test_displacements_field1(
209210

210211
# resample mask
211212
exit_code = check_call([cmd], shell=True)
212-
sw_moved_mask = nb.load("resampled_brainmask.nii.gz")
213+
assert exit_code == 0
214+
sw_moved_mask = nb.load(ants_output)
213215
nt_moved_mask = apply(xfm, msk, order=0)
214216

215-
# Calculate xor between both:
216-
sw_mask = np.asanyarray(sw_moved_mask.dataobj, dtype=bool)
217+
nt_moved_mask.to_filename(tmp_path / "nit_brainmask.nii.gz")
218+
219+
assert np.sqrt((diff**2).mean()) < RMSE_TOL_LINEAR
217220
brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool)
218221
percent_diff = (sw_mask != brainmask)[5:-5, 5:-5, 5:-5].sum() / brainmask.size
219222

@@ -403,3 +406,33 @@ def test_apply_4d(serialize_4d):
403406
data = np.asanyarray(moved.dataobj)
404407
idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)]
405408
assert idxs == [(9 - i, 2, 2) for i in range(nvols)]
409+
410+
411+
@pytest.mark.xfail(
412+
reason="GH-267: disabled while debugging",
413+
strict=False,
414+
)
415+
def test_apply_bspline(tmp_path, testdata_path):
416+
"""Cross-check B-Splines and deformation field."""
417+
os.chdir(str(tmp_path))
418+
419+
img_name = testdata_path / "someones_anatomy.nii.gz"
420+
disp_name = testdata_path / "someones_displacement_field.nii.gz"
421+
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"
422+
423+
bsplxfm = nitnl.BSplineFieldTransform(bs_name, reference=img_name)
424+
dispxfm = nitnl.DenseFieldTransform(disp_name)
425+
426+
out_disp = apply(dispxfm, img_name)
427+
out_bspl = apply(bsplxfm, img_name)
428+
429+
out_disp.to_filename("resampled_field.nii.gz")
430+
out_bspl.to_filename("resampled_bsplines.nii.gz")
431+
432+
assert (
433+
np.sqrt(
434+
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32"))
435+
** 2
436+
).mean()
437+
< 0.2
438+
)

0 commit comments

Comments
 (0)