Skip to content

Commit 5b8698a

Browse files
committed
fix: write tests to avoid regressions
Resolves: #267.
1 parent aca0ac9 commit 5b8698a

File tree

2 files changed

+45
-32
lines changed

2 files changed

+45
-32
lines changed

nitransforms/tests/test_nonlinear.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,24 @@
1010

1111
import numpy as np
1212
import nibabel as nb
13-
from nitransforms.resampling import apply
1413
from nitransforms.base import TransformError
1514
from nitransforms.nonlinear import (
1615
BSplineFieldTransform,
1716
DenseFieldTransform,
1817
)
1918

2019

20+
SOME_TEST_POINTS = np.array(
21+
[
22+
[0.0, 0.0, 0.0],
23+
[1.0, 2.0, 3.0],
24+
[10.0, -10.0, 5.0],
25+
[-5.0, 7.0, -2.0],
26+
[12.0, 0.0, -11.0],
27+
]
28+
)
29+
30+
2131
def test_displacements_init():
2232
identity1 = DenseFieldTransform(
2333
np.zeros((10, 10, 10, 3)),
@@ -62,41 +72,12 @@ def test_bsplines_references(testdata_path):
6272
testdata_path / "someones_bspline_coefficients.nii.gz"
6373
).to_field()
6474

65-
with pytest.raises(TransformError):
66-
apply(
67-
BSplineFieldTransform(
68-
testdata_path / "someones_bspline_coefficients.nii.gz"
69-
),
70-
testdata_path / "someones_anatomy.nii.gz",
71-
)
72-
73-
apply(
74-
BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"),
75-
testdata_path / "someones_anatomy.nii.gz",
75+
BSplineFieldTransform(
76+
testdata_path / "someones_bspline_coefficients.nii.gz",
7677
reference=testdata_path / "someones_anatomy.nii.gz",
7778
)
7879

7980

80-
@pytest.mark.xfail(
81-
reason="Disable while #266 is developed.",
82-
strict=False,
83-
)
84-
def test_bspline(tmp_path, testdata_path):
85-
"""
86-
Cross-check B-Splines and deformation field.
87-
88-
This test is disabled and will be split into two separate tests.
89-
The current implementation will be moved into test_resampling.py,
90-
since that's what it actually tests.
91-
92-
In GH-266, this test will be re-implemented by testing the equivalence
93-
of the B-Spline and deformation field transforms by calling the
94-
transform's `map()` method on points.
95-
96-
"""
97-
assert True
98-
99-
10081
def test_map_bspline_vs_displacement(tmp_path, testdata_path):
10182
"""Cross-check B-Splines and deformation field."""
10283
os.chdir(str(tmp_path))
@@ -107,6 +88,8 @@ def test_map_bspline_vs_displacement(tmp_path, testdata_path):
10788

10889
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name).to_field()
10990
dispxfm = DenseFieldTransform(disp_name)
91+
# Interpolating field should be reasonably similar
92+
np.testing.assert_allclose(dispxfm._field, bsplxfm._field, atol=1e-1, rtol=1e-4)
11093

11194
# Interpolating the field should be reasonably similar
11295
np.testing.assert_allclose(dispxfm._field, bsplxfm._field, atol=1e-1, rtol=1e-4)

nitransforms/tests/test_resampling.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,3 +402,33 @@ def test_apply_4d(serialize_4d):
402402
data = np.asanyarray(moved.dataobj)
403403
idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)]
404404
assert idxs == [(9 - i, 2, 2) for i in range(nvols)]
405+
406+
407+
@pytest.mark.xfail(
408+
reason="GH-267: disabled while debugging",
409+
strict=False,
410+
)
411+
def test_apply_bspline(tmp_path, testdata_path):
412+
"""Cross-check B-Splines and deformation field."""
413+
os.chdir(str(tmp_path))
414+
415+
img_name = testdata_path / "someones_anatomy.nii.gz"
416+
disp_name = testdata_path / "someones_displacement_field.nii.gz"
417+
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"
418+
419+
bsplxfm = nitnl.BSplineFieldTransform(bs_name, reference=img_name)
420+
dispxfm = nitnl.DenseFieldTransform(disp_name)
421+
422+
out_disp = apply(dispxfm, img_name)
423+
out_bspl = apply(bsplxfm, img_name)
424+
425+
out_disp.to_filename("resampled_field.nii.gz")
426+
out_bspl.to_filename("resampled_bsplines.nii.gz")
427+
428+
assert (
429+
np.sqrt(
430+
(out_disp.get_fdata(dtype="float32") - out_bspl.get_fdata(dtype="float32"))
431+
** 2
432+
).mean()
433+
< 0.2
434+
)

0 commit comments

Comments
 (0)