@@ -262,6 +262,11 @@ class TSNRInputSpec(BaseInterfaceInputSpec):
262262 in_file = InputMultiPath (File (exists = True ), mandatory = True ,
263263 desc = 'realigned 4D file or a list of 3D files' )
264264 regress_poly = traits .Range (low = 1 , desc = 'Remove polynomials' )
265+ tsnr_file = File ('tsnr.nii.gz' , usedefault = True , desc = 'output tSNR file' )
266+ mean_file = File ('mean.nii.gz' , usedefault = True , desc = 'output mean file' )
267+ stddev_file = File ('stdev.nii.gz' , usedefault = True , desc = 'output tSNR file' )
268+ detrended_file = File ('detrend.nii.gz' , usedefault = True ,
269+ desc = 'input file after detrending' )
265270
266271
267272class TSNROutputSpec (TraitedSpec ):
@@ -287,24 +292,17 @@ class TSNR(BaseInterface):
287292 input_spec = TSNRInputSpec
288293 output_spec = TSNROutputSpec
289294
290- def _gen_output_file_name (self , suffix = None ):
291- _ , base , ext = split_filename (self .inputs .in_file [0 ])
292- if suffix in ['mean' , 'stddev' ]:
293- return os .path .abspath (base + "_tsnr_" + suffix + ext )
294- elif suffix in ['detrended' ]:
295- return os .path .abspath (base + "_" + suffix + ext )
296- else :
297- return os .path .abspath (base + "_tsnr" + ext )
298-
299295 def _run_interface (self , runtime ):
300296 img = nb .load (self .inputs .in_file [0 ])
301297 header = img .get_header ().copy ()
302298 vollist = [nb .load (filename ) for filename in self .inputs .in_file ]
303299 data = np .concatenate ([vol .get_data ().reshape (
304300 vol .get_shape ()[:3 ] + (- 1 ,)) for vol in vollist ], axis = 3 )
301+ data = data .nan_to_num ()
305302 if data .dtype .kind == 'i' :
306303 header .set_data_dtype (np .float32 )
307304 data = data .astype (np .float32 )
305+
308306 if isdefined (self .inputs .regress_poly ):
309307 timepoints = img .get_shape ()[- 1 ]
310308 X = np .ones ((timepoints , 1 ))
@@ -318,16 +316,18 @@ def _run_interface(self, runtime):
318316 0 , 4 )
319317 data = data - datahat
320318 img = nb .Nifti1Image (data , img .get_affine (), header )
321- nb .save (img , self ._gen_output_file_name ('detrended' ))
319+ nb .save (img , op .abspath (self .inputs .detrended_file ))
320+
322321 meanimg = np .mean (data , axis = 3 )
323322 stddevimg = np .std (data , axis = 3 )
324- tsnr = meanimg / stddevimg
323+ tsnr = np .zeros_like (meanimg )
324+ tsnr [stddevimg > 1.e-3 ] = meanimg [stddevimg > 1.e-3 ] / stddevimg [stddevimg > 1.e-3 ]
325325 img = nb .Nifti1Image (tsnr , img .get_affine (), header )
326- nb .save (img , self . _gen_output_file_name ( ))
326+ nb .save (img , op . abspath ( self . inputs . tsnr_file ))
327327 img = nb .Nifti1Image (meanimg , img .get_affine (), header )
328- nb .save (img , self . _gen_output_file_name ( 'mean' ))
328+ nb .save (img , op . abspath ( self . inputs . mean_file ))
329329 img = nb .Nifti1Image (stddevimg , img .get_affine (), header )
330- nb .save (img , self . _gen_output_file_name ( 'stddev' ))
330+ nb .save (img , op . abspath ( self . inputs . stddev_file ))
331331 return runtime
332332
333333 def _list_outputs (self ):
0 commit comments