1212from functools import partial
1313from collections import namedtuple
1414import numpy as np
15+ import nibabel as nb
1516
1617from nitransforms import io
1718from nitransforms .io .base import _ensure_image
19+ from nitransforms .io .x5 import from_filename as load_x5
1820from nitransforms .interp .bspline import grid_bspline_weights , _cubic_bspline
1921from nitransforms .base import (
2022 TransformBase ,
3436class DenseFieldTransform (TransformBase ):
3537 """Represents dense field (voxel-wise) transforms."""
3638
37- __slots__ = ("_field" , "_deltas" )
39+ __slots__ = ("_field" , "_deltas" , "_is_deltas" )
3840
3941 def __init__ (self , field = None , is_deltas = True , reference = None ):
4042 """
@@ -68,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):
6870
6971 super ().__init__ ()
7072
71- if field is not None :
72- field = _ensure_image (field )
73- self ._field = np .squeeze (
74- np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
75- )
76- else :
77- self ._field = np .zeros ((* reference .shape , reference .ndim ), dtype = "float32" )
78- is_deltas = True
73+ self ._is_deltas = is_deltas
7974
8075 try :
8176 self .reference = ImageGrid (reference if reference is not None else field )
@@ -86,22 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
8681 else "Reference is not a spatial image"
8782 )
8883
84+ fieldshape = (* self .reference .shape , self .reference .ndim )
85+ if field is not None :
86+ field = _ensure_image (field )
87+ self ._field = np .squeeze (
88+ np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
89+ )
90+ if fieldshape != self ._field .shape :
91+ raise TransformError (
92+ f"Shape of the field ({ 'x' .join (str (i ) for i in self ._field .shape )} ) "
93+ f"doesn't match that of the reference({ 'x' .join (str (i ) for i in fieldshape )} )"
94+ )
95+ else :
96+ self ._field = np .zeros (fieldshape , dtype = "float32" )
97+ self ._is_deltas = True
98+
8999 if self ._field .shape [- 1 ] != self .ndim :
90100 raise TransformError (
91101 "The number of components of the field (%d) does not match "
92102 "the number of dimensions (%d)" % (self ._field .shape [- 1 ], self .ndim )
93103 )
94104
95- if is_deltas :
96- self ._deltas = self ._field
105+ if self ._is_deltas :
106+ self ._deltas = (
107+ self ._field .copy ()
108+ ) # IMPORTANT: you don't want to update deltas
97109 # Convert from displacements (deltas) to deformations fields
98110 # (just add its origin to each delta vector)
99- self ._field += self .reference .ndcoords .T .reshape (self . _field . shape )
111+ self ._field += self .reference .ndcoords .T .reshape (fieldshape )
100112
101113 def __repr__ (self ):
102114 """Beautify the python representation."""
103115 return f"<{ self .__class__ .__name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
104116
117+ @property
118+ def is_deltas (self ):
119+ """Check whether this is a displacements (``True``) or a deformation (``False``) field."""
120+ return self ._is_deltas
121+
105122 @property
106123 def ndim (self ):
107124 """Get the dimensions of the transform."""
@@ -230,7 +247,7 @@ def __eq__(self, other):
230247 True
231248
232249 """
233- _eq = np .array_equal (self ._field , other ._field )
250+ _eq = np .allclose (self ._field , other ._field )
234251 if _eq and self ._reference != other ._reference :
235252 warnings .warn ("Fields are equal, but references do not match." )
236253 return _eq
@@ -253,9 +270,9 @@ def to_x5(self, metadata=None):
253270 return io .x5 .X5Transform (
254271 type = "nonlinear" ,
255272 subtype = "densefield" ,
256- representation = "displacements" ,
273+ representation = "displacements" if self . is_deltas else "deformations" ,
257274 metadata = metadata ,
258- transform = self ._deltas ,
275+ transform = self ._deltas if self . is_deltas else self . _field ,
259276 dimension_kinds = kinds ,
260277 domain = domain ,
261278 )
@@ -273,12 +290,15 @@ def from_filename(cls, filename, fmt="X5"):
273290 raise NotImplementedError (f"Unsupported format <{ fmt } >" )
274291
275292 if fmt == "X5" :
276- from .io .x5 import from_filename as load_x5
277-
278293 x5_xfm = load_x5 (filename )[0 ]
279294 Domain = namedtuple ("Domain" , "affine shape" )
280295 reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
281- return cls (x5_xfm .transform , is_deltas = True , reference = reference )
296+ field = nb .Nifti1Image (x5_xfm .transform , reference .affine )
297+ return cls (
298+ field ,
299+ is_deltas = x5_xfm .representation == "displacements" ,
300+ reference = reference ,
301+ )
282302
283303 return cls (_factory [fmt .lower ()].from_filename (filename ))
284304
@@ -315,6 +335,24 @@ def ndim(self):
315335 """Get the dimensions of the transform."""
316336 return self ._coeffs .ndim - 1
317337
338+ @classmethod
339+ def from_filename (cls , filename , fmt = "X5" ):
340+ _factory = {
341+ "X5" : None ,
342+ }
343+ fmt = fmt .upper ()
344+ if fmt not in {k .upper () for k in _factory }:
345+ raise NotImplementedError (f"Unsupported format <{ fmt } >" )
346+
347+ x5_xfm = load_x5 (filename )[0 ]
348+ Domain = namedtuple ("Domain" , "affine shape" )
349+ reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
350+
351+ coefficients = nb .Nifti1Image (x5_xfm .transform , x5_xfm .additional_parameters )
352+ return cls (coefficients , reference = reference )
353+
354+ # return cls(_factory[fmt.lower()].from_filename(filename))
355+
318356 def to_field (self , reference = None , dtype = "float32" ):
319357 """Generate a displacements deformation field from this B-Spline field."""
320358 _ref = (
@@ -349,21 +387,17 @@ def to_x5(self, metadata=None):
349387 coordinates = "cartesian" ,
350388 )
351389
352- meta = metadata | {
353- "KnotsAffine" : self ._knots .affine .tolist (),
354- "KnotsShape" : self ._knots .shape ,
355- }
356-
357390 kinds = tuple ("space" for _ in range (self .ndim )) + ("vector" ,)
358391
359392 return io .x5 .X5Transform (
360393 type = "nonlinear" ,
361394 subtype = "bspline" ,
362395 representation = "coefficients" ,
363- metadata = meta ,
396+ metadata = metadata ,
364397 transform = self ._coeffs ,
365398 dimension_kinds = kinds ,
366399 domain = domain ,
400+ additional_parameters = self ._knots .affine ,
367401 )
368402
369403 def map (self , x , inverse = False ):
0 commit comments