1414 TransformBase ,
1515 TransformError ,
1616)
17- from .linear import Affine
17+ from .linear import Affine , LinearTransformsMapping
1818from .nonlinear import DenseFieldTransform
1919
2020
@@ -190,12 +190,15 @@ def asaffine(self, indices=None):
190190 return retval
191191
192192 @classmethod
193- def from_filename (cls , filename , fmt = "X5" , reference = None , moving = None ):
193+ def from_filename (cls , filename , fmt = "X5" , reference = None , moving = None , x5_chain = 0 ):
194194 """Load a transform file."""
195- from .io import itk
195+ from .io import itk , x5 as x5io
196+ import h5py
197+ import nibabel as nb
198+ from collections import namedtuple
196199
197200 retval = []
198- if str (filename ).endswith (".h5" ):
201+ if str (filename ).endswith (".h5" ) and ( fmt is None or fmt . upper () != "X5" ) :
199202 reference = None
200203 xforms = itk .ITKCompositeH5 .from_filename (filename )
201204 for xfmobj in xforms :
@@ -206,8 +209,120 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
206209
207210 return TransformChain (retval )
208211
212+ if fmt and fmt .upper () == "X5" :
213+ with h5py .File (str (filename ), "r" ) as f :
214+ if f .attrs .get ("Format" ) != "X5" :
215+ raise TypeError ("Input file is not in X5 format" )
216+
217+ tg = [
218+ x5io ._read_x5_group (node )
219+ for _ , node in sorted (f ["TransformGroup" ].items (), key = lambda kv : int (kv [0 ]))
220+ ]
221+ chain_grp = f .get ("TransformChain" )
222+ if chain_grp is None :
223+ raise TransformError ("X5 file contains no TransformChain" )
224+
225+ chain_path = chain_grp [str (x5_chain )][()]
226+ if isinstance (chain_path , bytes ):
227+ chain_path = chain_path .decode ()
228+ indices = [int (idx ) for idx in chain_path .split ("/" ) if idx ]
229+
230+ Domain = namedtuple ("Domain" , "affine shape" )
231+ for idx in indices :
232+ node = tg [idx ]
233+ if node .type == "linear" :
234+ Transform = Affine if node .array_length == 1 else LinearTransformsMapping
235+ reference = None
236+ if node .domain is not None :
237+ reference = Domain (node .domain .mapping , node .domain .size )
238+ retval .append (Transform (node .transform , reference = reference ))
239+ elif node .type == "nonlinear" :
240+ reference = Domain (node .domain .mapping , node .domain .size )
241+ field = nb .Nifti1Image (node .transform , reference .affine )
242+ retval .append (
243+ DenseFieldTransform (
244+ field ,
245+ is_deltas = node .representation == "displacements" ,
246+ reference = reference ,
247+ )
248+ )
249+ else : # pragma: no cover - unsupported type
250+ raise NotImplementedError (f"Unsupported transform type { node .type } " )
251+
252+ return TransformChain (retval )
253+
209254 raise NotImplementedError
210255
256+ def to_filename (self , filename , fmt = "X5" ):
257+ """Store the transform chain in X5 format."""
258+ from .io import x5 as x5io
259+ import os
260+ import h5py
261+
262+ if fmt .upper () != "X5" :
263+ raise NotImplementedError ("Only X5 format is supported for chains" )
264+
265+ if os .path .exists (filename ):
266+ with h5py .File (str (filename ), "r" ) as f :
267+ existing = [
268+ x5io ._read_x5_group (node )
269+ for _ , node in sorted (f ["TransformGroup" ].items (), key = lambda kv : int (kv [0 ]))
270+ ]
271+ else :
272+ existing = []
273+
274+ # convert to objects for equality check
275+ from collections import namedtuple
276+ import nibabel as nb
277+
278+ def _as_transform (x5node ):
279+ Domain = namedtuple ("Domain" , "affine shape" )
280+ if x5node .type == "linear" :
281+ Transform = Affine if x5node .array_length == 1 else LinearTransformsMapping
282+ ref = None
283+ if x5node .domain is not None :
284+ ref = Domain (x5node .domain .mapping , x5node .domain .size )
285+ return Transform (x5node .transform , reference = ref )
286+ reference = Domain (x5node .domain .mapping , x5node .domain .size )
287+ field = nb .Nifti1Image (x5node .transform , reference .affine )
288+ return DenseFieldTransform (
289+ field ,
290+ is_deltas = x5node .representation == "displacements" ,
291+ reference = reference ,
292+ )
293+
294+ existing_objs = [_as_transform (n ) for n in existing ]
295+ path_indices = []
296+ new_nodes = []
297+ for xfm in self .transforms :
298+ # find existing
299+ idx = None
300+ for i , obj in enumerate (existing_objs ):
301+ if type (xfm ) is type (obj ) and xfm == obj :
302+ idx = i
303+ break
304+ if idx is None :
305+ idx = len (existing_objs )
306+ new_nodes .append ((idx , xfm .to_x5 ()))
307+ existing_objs .append (xfm )
308+ path_indices .append (idx )
309+
310+ mode = "r+" if os .path .exists (filename ) else "w"
311+ with h5py .File (str (filename ), mode ) as f :
312+ if "Format" not in f .attrs :
313+ f .attrs ["Format" ] = "X5"
314+ f .attrs ["Version" ] = np .uint16 (1 )
315+
316+ tg = f .require_group ("TransformGroup" )
317+ for idx , node in new_nodes :
318+ g = tg .create_group (str (idx ))
319+ x5io ._write_x5_group (g , node )
320+
321+ cg = f .require_group ("TransformChain" )
322+ cg .create_dataset (str (len (cg )), data = "/" .join (str (i ) for i in path_indices ))
323+
324+ return filename
325+
211326
212327def _as_chain (x ):
213328 """Convert a value into a transform chain."""
0 commit comments