From fbc92284642a003ad8bdcd66e616070e1143abb0 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 20 Jul 2022 17:04:59 +0200 Subject: [PATCH 01/16] ENH: Collapse linear and nonlinear transforms chains Very undertested, but currently there is a test that uses a "collapsed" transform on an ITK's .h5 file with one affine and one nonlinear. BSpline transforms not currently supported. Resolves #89. --- nitransforms/linear.py | 16 +++++++--------- nitransforms/manip.py | 16 +++++++--------- nitransforms/tests/test_linear.py | 6 +++--- nitransforms/tests/test_manip.py | 8 +++++++- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 9c430d3b..239f0ebc 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -123,19 +123,17 @@ def __matmul__(self, b): True >>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) - >>> xfm1 @ np.eye(4) == xfm1 + >>> xfm1 @ Affine() == xfm1 True """ - if not isinstance(b, self.__class__): - _b = self.__class__(b) - else: - _b = b + if isinstance(b, self.__class__): + return self.__class__( + b.matrix @ self.matrix, + reference=b.reference, + ) - retval = self.__class__(self.matrix.dot(_b.matrix)) - if _b.reference: - retval.reference = _b.reference - return retval + return b @ self @property def matrix(self): diff --git a/nitransforms/manip.py b/nitransforms/manip.py index 233f5adf..58d15058 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -8,7 +8,6 @@ ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Common interface for transforms.""" from collections.abc import Iterable -import numpy as np from .base import ( TransformBase, @@ -140,9 +139,9 @@ def map(self, x, inverse=False): return x - def asaffine(self, indices=None): + def collapse(self): """ - Combine a succession of linear transforms into one. + Combine a succession of transforms into one. Example ------ @@ -150,7 +149,7 @@ def asaffine(self, indices=None): ... Affine.from_matvec(vec=(2, -10, 3)), ... Affine.from_matvec(vec=(-2, 10, -3)), ... ]) - >>> chain.asaffine() + >>> chain.collapse() array([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], @@ -160,7 +159,7 @@ def asaffine(self, indices=None): ... Affine.from_matvec(vec=(1, 2, 3)), ... Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]), ... ]) - >>> chain.asaffine() + >>> chain.collapse() array([[0., 1., 0., 2.], [0., 0., 1., 3.], [1., 0., 0., 1.], @@ -168,7 +167,7 @@ def asaffine(self, indices=None): >>> np.allclose( ... chain.map((4, -2, 1)), - ... chain.asaffine().map((4, -2, 1)), + ... chain.collapse().map((4, -2, 1)), ... ) True @@ -178,9 +177,8 @@ def asaffine(self, indices=None): The indices of the values to extract. """ - affines = self.transforms if indices is None else np.take(self.transforms, indices) - retval = affines[0] - for xfm in affines[1:]: + retval = self.transforms[-1] + for xfm in reversed(self.transforms[:-1]): retval = xfm @ retval return retval diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index eea77b7f..f3f83b38 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -372,10 +372,10 @@ def test_mulmat_operator(testdata_path): mat2 = from_matvec(np.eye(3), (4, 2, -1)) aff = nitl.Affine(mat1, reference=ref) - composed = aff @ mat2 + composed = aff @ nitl.Affine(mat2) assert composed.reference is None - assert composed == nitl.Affine(mat1.dot(mat2)) + assert composed == nitl.Affine(mat2 @ mat1) composed = nitl.Affine(mat2) @ aff assert composed.reference == aff.reference - assert composed == nitl.Affine(mat2.dot(mat1), reference=ref) + assert composed == nitl.Affine(mat1 @ mat2, reference=ref) diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py index 6dee540e..59f7f3b7 100644 --- a/nitransforms/tests/test_manip.py +++ b/nitransforms/tests/test_manip.py @@ -60,6 +60,12 @@ def test_itk_h5(tmp_path, testdata_path): # A certain tolerance is necessary because of resampling at borders assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL + col_moved = xfm.collapse().apply(img_fname, order=0) + col_moved.to_filename("nt_collapse_resampled.nii.gz") + diff = sw_moved.get_fdata() - col_moved.get_fdata() + # A certain tolerance is necessary because of resampling at borders + assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL + @pytest.mark.parametrize("ext0", ["lta", "tfm"]) @pytest.mark.parametrize("ext1", ["lta", "tfm"]) @@ -81,7 +87,7 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2): ] ) assert np.allclose( - chain.asaffine().matrix, + chain.collapse().matrix, Affine.from_filename( data_path / "regressions" / f"from-fsnative_to-bold_mode-image.{ext2}", fmt=f"{FMT[ext2]}", From 0b134082298a7dedc54d8fa32eec8d849f733904 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 17 Jul 2025 18:29:30 +0200 Subject: [PATCH 02/16] enh: read X5 transform files --- nitransforms/io/x5.py | 49 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index 463a1336..378e03d3 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -25,6 +25,8 @@ import numpy as np +from .base import TransformFileError + @dataclass class X5Domain: @@ -136,3 +138,50 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): # "AdditionalParameters", data=node.additional_parameters # ) return fname + + +def from_filename(fname: str | Path) -> List[X5Transform]: + """Read a list of :class:`X5Transform` objects from an X5 HDF5 file.""" + try: + with h5py.File(str(fname), "r") as in_file: + if in_file.attrs.get("Format") != "X5": + raise TransformFileError("Input file is not in X5 format") + + tg = in_file["TransformGroup"] + return [ + _read_x5_group(node) + for _, node in sorted(tg.items(), key=lambda kv: int(kv[0])) + ] + except OSError as exc: # pragma: no cover - in case h5py not installed + raise TransformFileError(str(exc)) from exc + + +def _read_x5_group(node) -> X5Transform: + x5 = X5Transform( + type=node.attrs["Type"], + transform=np.asarray(node["Transform"]), + subtype=node.attrs.get("SubType"), + representation=node.attrs.get("Representation"), + metadata=json.loads(node.attrs["Metadata"]) + if "Metadata" in node.attrs + else None, + dimension_kinds=[ + k.decode() if isinstance(k, bytes) else k + for k in node["DimensionKinds"][()] + ], + domain=None, + inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None, + jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None, + array_length=int(node.attrs.get("ArrayLength", 1)), + ) + + if "Domain" in node: + dgrp = node["Domain"] + x5.domain = X5Domain( + grid=bool(int(np.asarray(dgrp["Grid"]))), + size=tuple(np.asarray(dgrp["Size"])), + mapping=np.asarray(dgrp["Mapping"]), + coordinates=dgrp.attrs.get("Coordinates"), + ) + + return x5 From a22a6d01cc0d8289bdfb1262678f0fa0d06ffd16 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 17 Jul 2025 18:38:03 +0200 Subject: [PATCH 03/16] refactor: simplify X5 loader --- nitransforms/io/x5.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index 378e03d3..ec36c708 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -25,8 +25,6 @@ import numpy as np -from .base import TransformFileError - @dataclass class X5Domain: @@ -142,18 +140,15 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): def from_filename(fname: str | Path) -> List[X5Transform]: """Read a list of :class:`X5Transform` objects from an X5 HDF5 file.""" - try: - with h5py.File(str(fname), "r") as in_file: - if in_file.attrs.get("Format") != "X5": - raise TransformFileError("Input file is not in X5 format") - - tg = in_file["TransformGroup"] - return [ - _read_x5_group(node) - for _, node in sorted(tg.items(), key=lambda kv: int(kv[0])) - ] - except OSError as exc: # pragma: no cover - in case h5py not installed - raise TransformFileError(str(exc)) from exc + with h5py.File(str(fname), "r") as in_file: + if in_file.attrs.get("Format") != "X5": + raise ValueError("Input file is not in X5 format") + + tg = in_file["TransformGroup"] + return [ + _read_x5_group(node) + for _, node in sorted(tg.items(), key=lambda kv: int(kv[0])) + ] def _read_x5_group(node) -> X5Transform: From 2d3ba2f748889cc68ae66e77bb11bdc7c4881ed1 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 17 Jul 2025 18:43:55 +0200 Subject: [PATCH 04/16] test: cover x5.from_filename --- nitransforms/tests/test_x5.py | 38 ++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/nitransforms/tests/test_x5.py b/nitransforms/tests/test_x5.py index 8502a387..51e39ceb 100644 --- a/nitransforms/tests/test_x5.py +++ b/nitransforms/tests/test_x5.py @@ -1,7 +1,8 @@ import numpy as np +import pytest from h5py import File as H5File -from ..io.x5 import X5Transform, X5Domain, to_filename +from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename def test_x5_transform_defaults(): @@ -39,3 +40,38 @@ def test_to_filename(tmp_path): assert "0" in grp assert grp["0"].attrs["Type"] == "linear" assert grp["0"].attrs["ArrayLength"] == 1 + + +def test_from_filename_roundtrip(tmp_path): + domain = X5Domain(grid=False, size=(5, 5, 5), mapping=np.eye(4)) + node = X5Transform( + type="linear", + transform=np.eye(4), + dimension_kinds=("space", "space", "space", "vector"), + domain=domain, + metadata={"foo": "bar"}, + inverse=np.eye(4), + ) + fname = tmp_path / "test.x5" + to_filename(fname, [node]) + + x5_list = from_filename(fname) + assert len(x5_list) == 1 + x5 = x5_list[0] + assert x5.type == node.type + assert np.allclose(x5.transform, node.transform) + assert x5.dimension_kinds == list(node.dimension_kinds) + assert x5.domain.grid == domain.grid + assert x5.domain.size == tuple(domain.size) + assert np.allclose(x5.domain.mapping, domain.mapping) + assert x5.metadata == node.metadata + assert np.allclose(x5.inverse, node.inverse) + + +def test_from_filename_invalid(tmp_path): + fname = tmp_path / "invalid.h5" + with H5File(fname, "w") as f: + f.attrs["Format"] = "NOTX5" + + with pytest.raises(ValueError): + from_filename(fname) From eeb4d5d4fcffcc354ceb911dcc92336b34b31d73 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 17 Jul 2025 19:29:45 +0200 Subject: [PATCH 05/16] enh: enable loading of X5 affines --- nitransforms/linear.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/nitransforms/linear.py b/nitransforms/linear.py index cf8f8465..13430285 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -27,7 +27,12 @@ EQUALITY_TOL, ) from nitransforms.io import get_linear_factory, TransformFileError -from nitransforms.io.x5 import X5Transform, X5Domain, to_filename as save_x5 +from nitransforms.io.x5 import ( + X5Transform, + X5Domain, + to_filename as save_x5, + from_filename as load_x5, +) class Affine(TransformBase): @@ -174,8 +179,20 @@ def ndim(self): return self._matrix.ndim + 1 @classmethod - def from_filename(cls, filename, fmt=None, reference=None, moving=None): + def from_filename( + cls, filename, fmt=None, reference=None, moving=None, x5_position=0 + ): """Create an affine from a transform file.""" + + if fmt and fmt.upper() == "X5": + x5_xfm = load_x5(filename)[x5_position] + Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping + if x5_xfm.domain: + # override reference + raise NotImplementedError + + return Transform(x5_xfm.transform, reference=reference) + fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl") if fmt is not None and not Path(filename).exists(): @@ -265,7 +282,9 @@ def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False): if fmt.upper() == "X5": return save_x5(filename, [self.to_x5(store_inverse=x5_inverse)]) - writer = get_linear_factory(fmt, is_array=isinstance(self, LinearTransformsMapping)) + writer = get_linear_factory( + fmt, is_array=isinstance(self, LinearTransformsMapping) + ) if fmt.lower() in ("itk", "ants", "elastix"): writer.from_ras(self.matrix).to_filename(filename) From 8e5096969c92fba7140a51e1e3d8ed8cf45edd6c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 17 Jul 2025 19:51:40 +0200 Subject: [PATCH 06/16] fix: process exceptions when trying to open X5 --- nitransforms/io/x5.py | 24 +++++++++++++++--------- nitransforms/linear.py | 16 +++++++++++++--- nitransforms/tests/test_linear.py | 6 ++++++ nitransforms/tests/test_x5.py | 2 +- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index ec36c708..a86a8554 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -140,15 +140,21 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): def from_filename(fname: str | Path) -> List[X5Transform]: """Read a list of :class:`X5Transform` objects from an X5 HDF5 file.""" - with h5py.File(str(fname), "r") as in_file: - if in_file.attrs.get("Format") != "X5": - raise ValueError("Input file is not in X5 format") - - tg = in_file["TransformGroup"] - return [ - _read_x5_group(node) - for _, node in sorted(tg.items(), key=lambda kv: int(kv[0])) - ] + try: + with h5py.File(str(fname), "r") as in_file: + if in_file.attrs.get("Format") != "X5": + raise TypeError("Input file is not in X5 format") + + tg = in_file["TransformGroup"] + return [ + _read_x5_group(node) + for _, node in sorted(tg.items(), key=lambda kv: int(kv[0])) + ] + except OSError as err: + if "file signature not found" in err.args[0]: + raise TypeError("Input file is not HDF5.") + + raise # pragma: no cover def _read_x5_group(node) -> X5Transform: diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 13430285..8797a1c8 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -9,6 +9,7 @@ """Linear transforms.""" import warnings +from collections import namedtuple import numpy as np from pathlib import Path @@ -187,9 +188,18 @@ def from_filename( if fmt and fmt.upper() == "X5": x5_xfm = load_x5(filename)[x5_position] Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping - if x5_xfm.domain: - # override reference - raise NotImplementedError + if ( + x5_xfm.domain + and not x5_xfm.domain.grid + and len(x5_xfm.domain.size) == 3 + ): # pragma: no cover + raise NotImplementedError( + "Only 3D regularly gridded domains are supported" + ) + elif x5_xfm.domain: + # Override reference + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) return Transform(x5_xfm.transform, reference=reference) diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 32634c61..89583d57 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -265,6 +265,9 @@ def test_linear_to_x5(tmpdir, store_inverse): aff.to_filename("export1.x5", x5_inverse=store_inverse) + # Test round trip + assert aff == nitl.Affine.from_filename("export1.x5", fmt="X5") + # Test with Domain img = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="float32"), np.eye(4)) img_path = Path(tmpdir) / "ref.nii.gz" @@ -275,6 +278,9 @@ def test_linear_to_x5(tmpdir, store_inverse): assert node.domain.size == aff.reference.shape aff.to_filename("export2.x5", x5_inverse=store_inverse) + # Test round trip + assert aff == nitl.Affine.from_filename("export2.x5", fmt="X5") + # Test with Jacobian node.jacobian = np.zeros((2, 2, 2), dtype="float32") io.x5.to_filename("export3.x5", [node]) diff --git a/nitransforms/tests/test_x5.py b/nitransforms/tests/test_x5.py index 51e39ceb..89b49e06 100644 --- a/nitransforms/tests/test_x5.py +++ b/nitransforms/tests/test_x5.py @@ -73,5 +73,5 @@ def test_from_filename_invalid(tmp_path): with H5File(fname, "w") as f: f.attrs["Format"] = "NOTX5" - with pytest.raises(ValueError): + with pytest.raises(TypeError): from_filename(fname) From b02077eb6e24981005355a8fcd9cb6f562cf01c3 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 07:06:21 +0200 Subject: [PATCH 07/16] tst: add round-trip test to linear mappings --- nitransforms/tests/test_linear.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 89583d57..d1e5e47e 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -286,16 +286,24 @@ def test_linear_to_x5(tmpdir, store_inverse): io.x5.to_filename("export3.x5", [node]) -def test_mapping_to_x5(): +@pytest.mark.parametrize("store_inverse", [True, False]) +def test_mapping_to_x5(tmp_path, store_inverse): mats = [ np.eye(4), np.array([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]), ] mapping = nitl.LinearTransformsMapping(mats) - node = mapping.to_x5() + node = mapping.to_x5( + metadata={"GeneratedBy": "FreeSurfer 8"}, store_inverse=store_inverse + ) assert node.array_length == 2 assert node.transform.shape == (2, 4, 4) + mapping.to_filename(tmp_path / "export1.x5", x5_inverse=store_inverse) + + # Test round trip + assert mapping == nitl.Affine.from_filename(tmp_path / "export1.x5", fmt="X5") + def test_mulmat_operator(testdata_path): """Check the @ operator.""" From 1b544cb575541dc8c257ab1db8e45bb84748d177 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 07:24:04 +0200 Subject: [PATCH 08/16] move flake8 config to pyproject --- pyproject.toml | 9 +++++++++ setup.cfg | 7 ------- 2 files changed, 9 insertions(+), 7 deletions(-) delete mode 100644 setup.cfg diff --git a/pyproject.toml b/pyproject.toml index f11e2e5e..a6ac0859 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,3 +98,12 @@ exclude_lines = [ "raise NotImplementedError", "warnings\\.warn", ] + +[tool.flake8] +max-line-length = 99 +doctests = false +ignore = [ + "E266", + "E231", + "W503", +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index f355be94..00000000 --- a/setup.cfg +++ /dev/null @@ -1,7 +0,0 @@ -[flake8] -max-line-length = 99 -doctests = False -ignore = - E266 - E231 - W503 From 33d91ad5ad3dcf3a99485b1abd860bf4a43ad9b0 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 07:49:08 +0200 Subject: [PATCH 09/16] hotfix: make tox pickup flake8 config from ``pyproject.toml`` --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index fe549039..50d167bc 100644 --- a/tox.ini +++ b/tox.ini @@ -59,6 +59,7 @@ description = Check our style guide labels = check deps = flake8 + flake8-pyproject skip_install = true commands = flake8 nitransforms From 55e69374fdb1d325f9fc287c29b40463e8dea739 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 07:35:02 +0200 Subject: [PATCH 10/16] tst: refactor io/lta to reduce one partial line --- nitransforms/io/lta.py | 68 +++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/nitransforms/io/lta.py b/nitransforms/io/lta.py index 334266bb..1e7445bf 100644 --- a/nitransforms/io/lta.py +++ b/nitransforms/io/lta.py @@ -1,4 +1,5 @@ """Read/write linear transforms.""" + import numpy as np from nibabel.volumeutils import Recoder from nibabel.affines import voxel_sizes, from_matvec @@ -29,12 +30,12 @@ class VolumeGeometry(StringBasedStruct): template_dtype = np.dtype( [ ("valid", "i4"), # Valid values: 0, 1 - ("volume", "i4", (3, )), # width, height, depth - ("voxelsize", "f4", (3, )), # xsize, ysize, zsize + ("volume", "i4", (3,)), # width, height, depth + ("voxelsize", "f4", (3,)), # xsize, ysize, zsize ("xras", "f8", (3, 1)), # x_r, x_a, x_s ("yras", "f8", (3, 1)), # y_r, y_a, y_s ("zras", "f8", (3, 1)), # z_r, z_a, z_s - ("cras", "f8", (3, )), # c_r, c_a, c_s + ("cras", "f8", (3,)), # c_r, c_a, c_s ("filename", "U1024"), ] ) # Not conformant (may be >1024 bytes) @@ -109,14 +110,19 @@ def from_string(cls, string): label, valstring = lines.pop(0).split(" =") assert label.strip() == key - val = "" - if valstring.strip(): - parsed = np.genfromtxt( + parsed = ( + np.genfromtxt( [valstring.encode()], autostrip=True, dtype=cls.dtype[key] ) - if parsed.size: - val = parsed.reshape(sa[key].shape) - sa[key] = val + if valstring.strip() + else None + ) + + if parsed is not None and parsed.size: + sa[key] = parsed.reshape(sa[key].shape) + else: # pragma: no coverage + """Do not set sa[key]""" + return volgeom @@ -218,11 +224,15 @@ def to_ras(self, moving=None, reference=None): def to_string(self, partial=False): """Convert this transform to text.""" sa = self.structarr - lines = [ - "# LTA file created by NiTransforms", - "type = {}".format(sa["type"]), - "nxforms = 1", - ] if not partial else [] + lines = ( + [ + "# LTA file created by NiTransforms", + "type = {}".format(sa["type"]), + "nxforms = 1", + ] + if not partial + else [] + ) # Standard preamble lines += [ @@ -232,10 +242,7 @@ def to_string(self, partial=False): ] # Format parameters matrix - lines += [ - " ".join(f"{v:18.15e}" for v in sa["m_L"][i]) - for i in range(4) - ] + lines += [" ".join(f"{v:18.15e}" for v in sa["m_L"][i]) for i in range(4)] lines += [ "src volume info", @@ -324,10 +331,7 @@ def __getitem__(self, idx): def to_ras(self, moving=None, reference=None): """Set type to RAS2RAS and return the new matrix.""" self.structarr["type"] = 1 - return [ - xfm.to_ras(moving=moving, reference=reference) - for xfm in self.xforms - ] + return [xfm.to_ras(moving=moving, reference=reference) for xfm in self.xforms] def to_string(self): """Convert this LTA into text format.""" @@ -396,9 +400,11 @@ def from_ras(cls, ras, moving=None, reference=None): sa["type"] = 1 sa["nxforms"] = ras.shape[0] for i in range(sa["nxforms"]): - lt._xforms.append(cls._inner_type.from_ras( - ras[i, ...], moving=moving, reference=reference - )) + lt._xforms.append( + cls._inner_type.from_ras( + ras[i, ...], moving=moving, reference=reference + ) + ) sa["subject"] = "unset" sa["fscale"] = 0.0 @@ -407,8 +413,10 @@ def from_ras(cls, ras, moving=None, reference=None): def _drop_comments(string): """Drop comments.""" - return "\n".join([ - line.split("#")[0].strip() - for line in string.splitlines() - if line.split("#")[0].strip() - ]) + return "\n".join( + [ + line.split("#")[0].strip() + for line in string.splitlines() + if line.split("#")[0].strip() + ] + ) From 987eaa8dbcc75f1e17738dca4a007f12b6350f8f Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 10:33:21 +0200 Subject: [PATCH 11/16] Add failing test for serialized resampling --- nitransforms/tests/test_resampling.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index 2384ad97..50533121 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -363,3 +363,25 @@ def test_LinearTransformsMapping_apply( reference=testdata_path / "sbref.nii.gz", serialize_nvols=2 if serialize_4d else np.inf, ) + + +def test_apply_serialized_4d_multiple_targets(): + """Regression test for per-volume transforms with serialized resampling.""" + nvols = 9 + shape = (10, 5, 5) + base = np.zeros(shape, dtype=np.float32) + base[9, 2, 2] = 1 + img = nb.Nifti1Image(np.stack([base] * nvols, axis=-1), np.eye(4)) + + transforms = [] + for i in range(nvols): + mat = np.eye(4) + mat[0, 3] = i + transforms.append(nitl.Affine(mat)) + + xfm = nitl.LinearTransformsMapping(transforms, reference=img) + moved = apply(xfm, img, order=0) + + data = np.asanyarray(moved.dataobj) + idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)] + assert idxs == [(9 - i, 2, 2) for i in range(nvols)] From 5f08a36d0ceabb263ce6c44ec699976c2cefde17 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 11:02:31 +0200 Subject: [PATCH 12/16] enh: add docstring and doctests for `nitransforms.io.get_linear_factory` --- nitransforms/io/__init__.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/nitransforms/io/__init__.py b/nitransforms/io/__init__.py index f9030724..a2ec7e6b 100644 --- a/nitransforms/io/__init__.py +++ b/nitransforms/io/__init__.py @@ -1,6 +1,7 @@ # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Read and write transforms.""" + from nitransforms.io import afni, fsl, itk, lta, x5 from nitransforms.io.base import TransformIOError, TransformFileError @@ -27,7 +28,37 @@ def get_linear_factory(fmt, is_array=True): - """Return the type required by a given format.""" + """ + Return the type required by a given format. + + Parameters + ---------- + fmt : :obj:`str` + A format identifying string. + is_array : :obj:`bool` + Whether the array version of the class should be returned. + + Returns + ------- + type + The class object (not an instance) of the linear transfrom to be created + (for example, :obj:`~nitransforms.io.itk.ITKLinearTransform`). + + Examples + -------- + >>> get_linear_factory("itk") + + >>> get_linear_factory("itk", is_array=False) + + >>> get_linear_factory("fsl") + + >>> get_linear_factory("fsl", is_array=False) + + >>> get_linear_factory("fakepackage") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + TypeError: Unsupported transform format . + + """ if fmt.lower() not in _IO_TYPES: raise TypeError(f"Unsupported transform format <{fmt}>.") From 5da27b1b37735a3633a2e58afc5c6d2a9f257ce6 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 11:46:28 +0200 Subject: [PATCH 13/16] FIX: recompute targets for serialized per-volume resampling --- nitransforms/resampling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 53750206..a4a8bf42 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -110,8 +110,9 @@ async def _apply_serial( for t in range(n_resamplings): xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t] - if targets is None: - targets = ImageGrid(spatialimage).index( # data should be an image + targets_t = targets + if targets_t is None: + targets_t = ImageGrid(spatialimage).index( # data should be an image _as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim) ) @@ -127,7 +128,7 @@ async def _apply_serial( partial( ndi.map_coordinates, data_t, - targets, + targets_t, output=output[..., t], order=order, mode=mode, From b562eb3823f712672ceac8fa71dc3c3f97ec3c1e Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 12:03:10 +0200 Subject: [PATCH 14/16] fix: remove ``__iter__()`` as iterator protocol is not met --- nitransforms/linear.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 8797a1c8..26bf3374 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -377,11 +377,6 @@ def __init__(self, transforms, reference=None): ) self._inverse = np.linalg.inv(self._matrix) - def __iter__(self): - """Enable iterating over the series of transforms.""" - for _m in self.matrix: - yield Affine(_m, reference=self._reference) - def __getitem__(self, i): """Enable indexed access to the series of matrices.""" return Affine(self.matrix[i, ...], reference=self._reference) From 72cd04f1dfd2ed81191616aefdf2f6d7f8b58fe5 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 12:07:48 +0200 Subject: [PATCH 15/16] fix: recompute coordinates per volume in serial resampling --- nitransforms/resampling.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index a4a8bf42..be658011 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -110,11 +110,13 @@ async def _apply_serial( for t in range(n_resamplings): xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t] - targets_t = targets - if targets_t is None: - targets_t = ImageGrid(spatialimage).index( # data should be an image + targets_t = ( + ImageGrid(spatialimage).index( _as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim) ) + if targets is None + else targets + ) data_t = ( data From 4e159c2ec38cdfd435853eafea4c256034e246d5 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Fri, 18 Jul 2025 17:08:32 +0200 Subject: [PATCH 16/16] fix: generalize targets, test all branches --- nitransforms/resampling.py | 65 +++++++++++++++++++-------- nitransforms/tests/test_resampling.py | 7 ++- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index be658011..98ef4454 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -10,6 +10,7 @@ import asyncio from os import cpu_count +from contextlib import suppress from functools import partial from pathlib import Path from typing import Callable, TypeVar, Union @@ -108,14 +109,16 @@ async def _apply_serial( semaphore = asyncio.Semaphore(max_concurrent) for t in range(n_resamplings): - xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t] + xfm_t = ( + transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t] + ) targets_t = ( ImageGrid(spatialimage).index( _as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim) ) if targets is None - else targets + else targets[t, ...] ) data_t = ( @@ -258,11 +261,22 @@ def apply( dim=_ref.ndim, ) ) - elif xfm_nvols == 1: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) + else: + # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints. + targets = ( + ImageGrid(spatialimage).index( + _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) + ) + if targets is None + else targets ) + if targets.ndim == 3: + targets = np.rollaxis(targets, targets.ndim - 1, 0) + else: + assert targets.ndim == 2 + targets = targets[np.newaxis, ...] + if serialize_4d: data = ( np.asanyarray(spatialimage.dataobj, dtype=input_dtype) @@ -297,17 +311,24 @@ def apply( else: data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype) - if targets is None: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) - ) - + if data_nvols == 1 and xfm_nvols == 1: + targets = np.squeeze(targets) + assert targets.ndim == 2 # Cast 3D data into 4D if 4D nonsequential transform - if data_nvols == 1 and xfm_nvols > 1: + elif data_nvols == 1 and xfm_nvols > 1: data = data[..., np.newaxis] - if transform.ndim == 4: - targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T + if xfm_nvols > 1: + assert targets.ndim == 3 + n_time, n_dim, n_vox = targets.shape + # Reshape to (3, n_time x n_vox) + ijk_targets = np.rollaxis(targets, 0, 2).reshape((n_dim, -1)) + time_row = np.repeat(np.arange(n_time), n_vox)[None, :] + + # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k) + # t is the slowest-changing axis, so we put it first + targets = np.vstack((time_row, ijk_targets)) + data = np.rollaxis(data, data.ndim - 1, 0) resampled = ndi.map_coordinates( data, @@ -326,11 +347,19 @@ def apply( ) hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype()) - moved = spatialimage.__class__( - resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)), - _ref.affine, - hdr, - ) + if serialize_4d: + resampled = resampled.reshape( + _ref.shape + if n_resamplings == 1 + else _ref.shape + (resampled.shape[-1],) + ) + else: + resampled = resampled.reshape((-1, *_ref.shape)) + resampled = np.rollaxis(resampled, 0, resampled.ndim) + with suppress(ValueError): + resampled = np.squeeze(resampled, axis=3) + + moved = spatialimage.__class__(resampled, _ref.affine, hdr) return moved output_dtype = output_dtype or input_dtype diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index 50533121..0e11df5b 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -365,7 +365,8 @@ def test_LinearTransformsMapping_apply( ) -def test_apply_serialized_4d_multiple_targets(): +@pytest.mark.parametrize("serialize_4d", [True, False]) +def test_apply_4d(serialize_4d): """Regression test for per-volume transforms with serialized resampling.""" nvols = 9 shape = (10, 5, 5) @@ -379,9 +380,11 @@ def test_apply_serialized_4d_multiple_targets(): mat[0, 3] = i transforms.append(nitl.Affine(mat)) + extraparams = {} if serialize_4d else {"serialize_nvols": nvols + 1} + xfm = nitl.LinearTransformsMapping(transforms, reference=img) - moved = apply(xfm, img, order=0) + moved = apply(xfm, img, order=0, **extraparams) data = np.asanyarray(moved.dataobj) idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)] assert idxs == [(9 - i, 2, 2) for i in range(nvols)]