|
8 | 8 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## |
9 | 9 | """Nonlinear transforms.""" |
10 | 10 | import warnings |
| 11 | +from pathlib import Path |
11 | 12 | import numpy as np |
12 | | -from .base import TransformBase |
13 | | -from . import io |
| 13 | +from scipy.sparse import sparse_vstack |
| 14 | +from scipy import ndimage as ndi |
| 15 | +from nibabel.funcs import four_to_three |
| 16 | +from nibabel.loadsave import load as _nbload |
14 | 17 |
|
15 | | -# from .base import ImageGrid |
16 | | -# from nibabel.funcs import four_to_three |
| 18 | +from . import io |
| 19 | +from .interp.bspline import grid_bspline_weights |
| 20 | +from .base import ( |
| 21 | + TransformBase, |
| 22 | + ImageGrid, |
| 23 | + SpatialReference, |
| 24 | + _as_homogeneous, |
| 25 | +) |
17 | 26 |
|
18 | 27 |
|
19 | 28 | class DisplacementsFieldTransform(TransformBase): |
@@ -90,70 +99,118 @@ def from_filename(cls, filename, fmt="X5"): |
90 | 99 |
|
91 | 100 | load = DisplacementsFieldTransform.from_filename |
92 | 101 |
|
93 | | -# class BSplineFieldTransform(TransformBase): |
94 | | -# """Represent a nonlinear transform parameterized by BSpline basis.""" |
95 | | - |
96 | | -# __slots__ = ['_coeffs', '_knots', '_refknots', '_order', '_moving'] |
97 | | -# __s = (slice(None), ) |
98 | | - |
99 | | -# def __init__(self, reference, coefficients, order=3): |
100 | | -# """Create a smooth deformation field using B-Spline basis.""" |
101 | | -# super(BSplineFieldTransform, self).__init__() |
102 | | -# self._order = order |
103 | | -# self.reference = reference |
104 | | - |
105 | | -# if coefficients.shape[-1] != self.ndim: |
106 | | -# raise ValueError( |
107 | | -# 'Number of components of the coefficients does ' |
108 | | -# 'not match the number of dimensions') |
109 | | - |
110 | | -# self._coeffs = np.asanyarray(coefficients.dataobj) |
111 | | -# self._knots = ImageGrid(four_to_three(coefficients)[0]) |
112 | | -# self._cache_moving() |
113 | | - |
114 | | -# def _cache_moving(self): |
115 | | -# self._moving = np.zeros((self.reference.shape) + (3, ), |
116 | | -# dtype='float32') |
117 | | -# ijk = np.moveaxis(self.reference.ndindex, 0, -1).reshape(-1, self.ndim) |
118 | | -# xyz = np.moveaxis(self.reference.ndcoords, 0, -1).reshape(-1, self.ndim) |
119 | | -# print(np.shape(xyz)) |
120 | | - |
121 | | -# for i in range(np.shape(xyz)[0]): |
122 | | -# print(i, xyz[i, :]) |
123 | | -# self._moving[tuple(ijk[i]) + self.__s] = self._interp_transform(xyz[i, :]) |
124 | | - |
125 | | -# def _interp_transform(self, coords): |
126 | | -# # Calculate position in the grid of control points |
127 | | -# knots_ijk = self._knots.inverse.dot(np.hstack((coords, 1)))[:3] |
128 | | -# neighbors = [] |
129 | | -# offset = 0.0 if self._order & 1 else 0.5 |
130 | | -# # Calculate neighbors along each dimension |
131 | | -# for dim in range(self.ndim): |
132 | | -# first = int(np.floor(knots_ijk[dim] + offset) - self._order // 2) |
133 | | -# neighbors.append(list(range(first, first + self._order + 1))) |
134 | | - |
135 | | -# # Get indexes of the neighborings clique |
136 | | -# ndindex = np.moveaxis( |
137 | | -# np.array(np.meshgrid(*neighbors, indexing='ij')), 0, -1).reshape( |
138 | | -# -1, self.ndim) |
139 | | - |
140 | | -# # Calculate the tensor B-spline weights of each neighbor |
141 | | -# # weights = np.prod(vbspl(ndindex - knots_ijk), axis=-1) |
142 | | -# ndindex = [tuple(v) for v in ndindex] |
143 | | - |
144 | | -# # Retrieve coefficients and deal with boundary conditions |
145 | | -# zero = np.zeros(self.ndim) |
146 | | -# shape = np.array(self._knots.shape) |
147 | | -# coeffs = [] |
148 | | -# for ijk in ndindex: |
149 | | -# offbounds = (zero > ijk) | (shape <= ijk) |
150 | | -# coeffs.append( |
151 | | -# self._coeffs[ijk] if not np.any(offbounds) |
152 | | -# else [0.0] * self.ndim) |
153 | | - |
154 | | -# # coords[:3] += weights.dot(np.array(coeffs, dtype=float)) |
155 | | -# return self.reference.inverse.dot(np.hstack((coords, 1)))[:3] |
156 | | - |
157 | | -# def _map_voxel(self, index, moving=None): |
158 | | -# """Apply ijk' = f_ijk((i, j, k)), equivalent to the above with indexes.""" |
159 | | -# return tuple(self._moving[index + self.__s]) |
| 102 | + |
| 103 | +class BSplineFieldTransform(TransformBase): |
| 104 | + """Represent a nonlinear transform parameterized by BSpline basis.""" |
| 105 | + |
| 106 | + __slots__ = ['_coeffs', '_knots', '_weights', '_order', '_moving'] |
| 107 | + __s = (slice(None), ) |
| 108 | + |
| 109 | + def __init__(self, reference, coefficients, order=3): |
| 110 | + """Create a smooth deformation field using B-Spline basis.""" |
| 111 | + super(BSplineFieldTransform, self).__init__() |
| 112 | + self._order = order |
| 113 | + self.reference = reference |
| 114 | + |
| 115 | + if coefficients.shape[-1] != self.ndim: |
| 116 | + raise ValueError( |
| 117 | + 'Number of components of the coefficients does ' |
| 118 | + 'not match the number of dimensions') |
| 119 | + |
| 120 | + self._coeffs = np.asanyarray(coefficients.dataobj) |
| 121 | + self._knots = ImageGrid(four_to_three(coefficients)[0]) |
| 122 | + self._weights = None |
| 123 | + |
| 124 | + def apply( |
| 125 | + self, |
| 126 | + spatialimage, |
| 127 | + reference=None, |
| 128 | + order=3, |
| 129 | + mode="constant", |
| 130 | + cval=0.0, |
| 131 | + prefilter=True, |
| 132 | + output_dtype=None, |
| 133 | + ): |
| 134 | + """Apply a B-Spline transform on input data.""" |
| 135 | + |
| 136 | + if reference is not None and isinstance(reference, (str, Path)): |
| 137 | + reference = _nbload(str(reference)) |
| 138 | + |
| 139 | + _ref = ( |
| 140 | + self.reference if reference is None else SpatialReference.factory(reference) |
| 141 | + ) |
| 142 | + |
| 143 | + if isinstance(spatialimage, (str, Path)): |
| 144 | + spatialimage = _nbload(str(spatialimage)) |
| 145 | + |
| 146 | + if not isinstance(_ref, ImageGrid): |
| 147 | + return super().apply( |
| 148 | + spatialimage, |
| 149 | + reference=reference, |
| 150 | + order=order, |
| 151 | + mode=mode, |
| 152 | + cval=cval, |
| 153 | + prefilter=prefilter, |
| 154 | + output_dtype=output_dtype, |
| 155 | + ) |
| 156 | + |
| 157 | + # If locations to be interpolated are on a grid, use faster tensor-bspline calculation |
| 158 | + if self._weights is None: |
| 159 | + self._weights = grid_bspline_weights(_ref, self._knots) |
| 160 | + |
| 161 | + ycoords = _ref.ndcoords.T + ( |
| 162 | + np.squeeze(np.hstack(self._coeffs).T) @ sparse_vstack(self._weights) |
| 163 | + ) |
| 164 | + |
| 165 | + data = np.squeeze(np.asanyarray(spatialimage.dataobj)) |
| 166 | + output_dtype = output_dtype or data.dtype |
| 167 | + targets = ImageGrid(spatialimage).index( # data should be an image |
| 168 | + _as_homogeneous(np.vstack(ycoords), dim=_ref.ndim) |
| 169 | + ) |
| 170 | + |
| 171 | + if data.ndim == 4: |
| 172 | + if len(self) != data.shape[-1]: |
| 173 | + raise ValueError( |
| 174 | + "Attempting to apply %d transforms on a file with " |
| 175 | + "%d timepoints" % (len(self), data.shape[-1]) |
| 176 | + ) |
| 177 | + targets = targets.reshape((len(self), -1, targets.shape[-1])) |
| 178 | + resampled = np.stack( |
| 179 | + [ |
| 180 | + ndi.map_coordinates( |
| 181 | + data[..., t], |
| 182 | + targets[t, ..., : _ref.ndim].T, |
| 183 | + output=output_dtype, |
| 184 | + order=order, |
| 185 | + mode=mode, |
| 186 | + cval=cval, |
| 187 | + prefilter=prefilter, |
| 188 | + ) |
| 189 | + for t in range(data.shape[-1]) |
| 190 | + ], |
| 191 | + axis=0, |
| 192 | + ) |
| 193 | + elif data.ndim in (2, 3): |
| 194 | + resampled = ndi.map_coordinates( |
| 195 | + data, |
| 196 | + targets[..., : _ref.ndim].T, |
| 197 | + output=output_dtype, |
| 198 | + order=order, |
| 199 | + mode=mode, |
| 200 | + cval=cval, |
| 201 | + prefilter=prefilter, |
| 202 | + ) |
| 203 | + |
| 204 | + newdata = resampled.reshape((len(self), *_ref.shape)) |
| 205 | + moved = spatialimage.__class__( |
| 206 | + np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header |
| 207 | + ) |
| 208 | + moved.header.set_data_dtype(output_dtype) |
| 209 | + return moved |
| 210 | + |
| 211 | + def map(self, x, inverse=False): |
| 212 | + raise NotImplementedError |
| 213 | + |
| 214 | + def _map_voxel(self, index, moving=None): |
| 215 | + """Apply ijk' = f_ijk((i, j, k)), equivalent to the above with indexes.""" |
| 216 | + raise NotImplementedError |
0 commit comments