@@ -203,18 +203,33 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
203203 """Create an affine from a transform file."""
204204 fmtlist = [fmt ] if fmt is not None else ("itk" , "lta" , "afni" , "fsl" )
205205
206+ is_array = cls != Affine
207+
208+ errors = []
206209 for potential_fmt in fmtlist :
210+ if (potential_fmt == "itk" and Path (filename ).suffix == ".mat" ):
211+ is_array = False
212+ cls = Affine
213+
207214 try :
208- struct = get_linear_factory (potential_fmt ).from_filename (filename )
209- matrix = struct .to_ras (reference = reference , moving = moving )
210- if cls == Affine :
211- if np .shape (matrix )[0 ] != 1 :
212- raise TypeError ("Cannot load transform array '%s'" % filename )
213- matrix = matrix [0 ]
214- return cls (matrix , reference = reference )
215- except (TransformFileError , FileNotFoundError ):
215+ struct = get_linear_factory (
216+ potential_fmt ,
217+ is_array = is_array
218+ ).from_filename (filename )
219+ except (TransformFileError , FileNotFoundError ) as err :
220+ errors .append ((potential_fmt , err ))
216221 continue
217222
223+ matrix = struct .to_ras (reference = reference , moving = moving )
224+
225+ # Process matrix
226+ if not is_array and np .ndim (matrix ) == 3 :
227+ if np .shape (matrix )[0 ] != 1 :
228+ raise TypeError ("Cannot load transform array '%s'" % filename )
229+ matrix = matrix [0 ]
230+
231+ return cls (matrix , reference = reference )
232+
218233 raise TransformFileError (
219234 f"Could not open <{ filename } > (formats tried: { ', ' .join (fmtlist )} )."
220235 )
@@ -499,6 +514,8 @@ def load(filename, fmt=None, reference=None, moving=None):
499514 xfm = LinearTransformsMapping .from_filename (
500515 filename , fmt = fmt , reference = reference , moving = moving
501516 )
502- if len (xfm ) == 1 :
503- return xfm [0 ]
517+
518+ if isinstance (xfm , LinearTransformsMapping ) and len (xfm ) == 1 :
519+ xfm = xfm [0 ]
520+
504521 return xfm
0 commit comments