77#
88### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99"""Common interface for transforms."""
10+
11+ import os
1012from collections .abc import Iterable
1113import numpy as np
1214
13- from .base import (
15+ import h5py
16+ from nitransforms .base import (
1417 TransformBase ,
1518 TransformError ,
1619)
17- from .linear import Affine , LinearTransformsMapping
18- from .nonlinear import DenseFieldTransform
20+ from nitransforms .io import itk , x5 as x5io
21+ from nitransforms .io .x5 import from_filename as load_x5
22+ from nitransforms .linear import (
23+ Affine ,
24+ from_x5 as linear_from_x5 , # noqa: F401
25+ )
26+ from nitransforms .nonlinear import (
27+ DenseFieldTransform ,
28+ from_x5 as nonlinear_from_x5 , # noqa: F401
29+ )
1930
2031
2132class TransformChain (TransformBase ):
@@ -183,7 +194,9 @@ def asaffine(self, indices=None):
183194 The indices of the values to extract.
184195
185196 """
186- affines = self .transforms if indices is None else np .take (self .transforms , indices )
197+ affines = (
198+ self .transforms if indices is None else np .take (self .transforms , indices )
199+ )
187200 retval = affines [0 ]
188201 for xfm in affines [1 :]:
189202 retval = xfm @ retval
@@ -192,51 +205,28 @@ def asaffine(self, indices=None):
192205 @classmethod
193206 def from_filename (cls , filename , fmt = "X5" , reference = None , moving = None , x5_chain = 0 ):
194207 """Load a transform file."""
195- from .io import itk , x5 as x5io
196- import h5py
197- import nibabel as nb
198- from collections import namedtuple
199208
200209 retval = []
201210 if fmt and fmt .upper () == "X5" :
211+ xfm_list = load_x5 (filename )
212+ if not xfm_list :
213+ raise TransformError ("Empty transform group" )
214+
202215 with h5py .File (str (filename ), "r" ) as f :
203- if f .attrs .get ("Format" ) == "X5" :
204- tg = [
205- x5io ._read_x5_group (node )
206- for _ , node in sorted (f ["TransformGroup" ].items (), key = lambda kv : int (kv [0 ]))
207- ]
208- chain_grp = f .get ("TransformChain" )
209- if chain_grp is None :
210- raise TransformError ("X5 file contains no TransformChain" )
211-
212- chain_path = chain_grp [str (x5_chain )][()]
213- if isinstance (chain_path , bytes ):
214- chain_path = chain_path .decode ()
215- indices = [int (idx ) for idx in chain_path .split ("/" ) if idx ]
216-
217- Domain = namedtuple ("Domain" , "affine shape" )
218- for idx in indices :
219- node = tg [idx ]
220- if node .type == "linear" :
221- Transform = Affine if node .array_length == 1 else LinearTransformsMapping
222- reference = None
223- if node .domain is not None :
224- reference = Domain (node .domain .mapping , node .domain .size )
225- retval .append (Transform (node .transform , reference = reference ))
226- elif node .type == "nonlinear" :
227- reference = Domain (node .domain .mapping , node .domain .size )
228- field = nb .Nifti1Image (node .transform , reference .affine )
229- retval .append (
230- DenseFieldTransform (
231- field ,
232- is_deltas = node .representation == "displacements" ,
233- reference = reference ,
234- )
235- )
236- else : # pragma: no cover - unsupported type
237- raise NotImplementedError (f"Unsupported transform type { node .type } " )
238-
239- return TransformChain (retval )
216+ chain_grp = f .get ("TransformChain" )
217+ if chain_grp is None :
218+ raise TransformError ("X5 file contains no TransformChain" )
219+
220+ chain_path = chain_grp [str (x5_chain )][()]
221+ if isinstance (chain_path , bytes ):
222+ chain_path = chain_path .decode ()
223+
224+ for idx in chain_path .split ("/" ):
225+ node = x5io ._read_x5_group (xfm_list [int (idx )])
226+ from_x5 = globals ()[f"{ node .type } _from_x5" ]
227+ retval .append (from_x5 ([node ]))
228+
229+ return TransformChain (retval )
240230
241231 if str (filename ).endswith (".h5" ):
242232 reference = None
@@ -253,57 +243,24 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None, x5_chain
253243
254244 def to_filename (self , filename , fmt = "X5" ):
255245 """Store the transform chain in X5 format."""
256- from .io import x5 as x5io
257- import os
258- import h5py
259246
260247 if fmt .upper () != "X5" :
261248 raise NotImplementedError ("Only X5 format is supported for chains" )
262249
263- if os .path .exists (filename ):
264- with h5py .File (str (filename ), "r" ) as f :
265- existing = [
266- x5io ._read_x5_group (node )
267- for _ , node in sorted (f ["TransformGroup" ].items (), key = lambda kv : int (kv [0 ]))
268- ]
269- else :
270- existing = []
271-
272- # convert to objects for equality check
273- from collections import namedtuple
274- import nibabel as nb
275-
276- def _as_transform (x5node ):
277- Domain = namedtuple ("Domain" , "affine shape" )
278- if x5node .type == "linear" :
279- Transform = Affine if x5node .array_length == 1 else LinearTransformsMapping
280- ref = None
281- if x5node .domain is not None :
282- ref = Domain (x5node .domain .mapping , x5node .domain .size )
283- return Transform (x5node .transform , reference = ref )
284- reference = Domain (x5node .domain .mapping , x5node .domain .size )
285- field = nb .Nifti1Image (x5node .transform , reference .affine )
286- return DenseFieldTransform (
287- field ,
288- is_deltas = x5node .representation == "displacements" ,
289- reference = reference ,
290- )
291-
292- existing_objs = [_as_transform (n ) for n in existing ]
293- path_indices = []
250+ existing = load_x5 (filename ) if os .path .exists (filename ) else []
251+ xfm_chain = []
294252 new_nodes = []
253+ next_xfm_index = len (existing )
295254 for xfm in self .transforms :
296- # find existing
297- idx = None
298- for i , obj in enumerate (existing_objs ):
299- if type (xfm ) is type (obj ) and xfm == obj :
300- idx = i
255+ for eidx , existing_xfm in enumerate (existing ):
256+ if xfm == existing_xfm :
257+ xfm_chain .append (eidx )
301258 break
302- if idx is None :
303- idx = len ( existing_objs )
304- new_nodes .append ((idx , xfm . to_x5 () ))
305- existing_objs .append (xfm )
306- path_indices . append ( idx )
259+ else :
260+ xfm_chain . append ( next_xfm_index )
261+ new_nodes .append ((next_xfm_index , xfm ))
262+ existing .append (xfm )
263+ next_xfm_index += 1
307264
308265 mode = "r+" if os .path .exists (filename ) else "w"
309266 with h5py .File (str (filename ), mode ) as f :
@@ -317,7 +274,7 @@ def _as_transform(x5node):
317274 x5io ._write_x5_group (g , node )
318275
319276 cg = f .require_group ("TransformChain" )
320- cg .create_dataset (str (len (cg )), data = "/" .join (str (i ) for i in path_indices ))
277+ cg .create_dataset (str (len (cg )), data = "/" .join (str (i ) for i in xfm_chain ))
321278
322279 return filename
323280
0 commit comments