@@ -162,9 +162,14 @@ class EstimateResponseSHInputSpec(DipyBaseInterfaceInputSpec):
162162 in_mask = File (
163163 exists = True , desc = ('input mask in which we find single fibers' ))
164164 fa_thresh = traits .Float (
165- 0.7 , usedefault = True , desc = ('default FA threshold' ))
165+ 0.7 , usedefault = True , desc = ('FA threshold' ))
166+ roi_radius = traits .Int (
167+ 10 , usedefault = True , desc = ('ROI radius to be used in auto_response' ))
168+ auto = traits .Bool (
169+ True , usedefault = True , xor = ['recursive' ],
170+ desc = 'use the auto_response estimator from dipy' )
166171 recursive = traits .Bool (
167- False , usedefault = True ,
172+ False , usedefault = True , xor = [ 'auto' ],
168173 desc = 'use the recursive response estimator from dipy' )
169174 response = File (
170175 'response.txt' , usedefault = True , desc = ('the output response file' ))
@@ -203,7 +208,7 @@ class EstimateResponseSH(DipyBaseInterface):
203208 def _run_interface (self , runtime ):
204209 from dipy .core .gradients import GradientTable
205210 from dipy .reconst .dti import fractional_anisotropy , mean_diffusivity
206- from dipy .reconst .csdeconv import recursive_response
211+ from dipy .reconst .csdeconv import recursive_response , auto_response
207212
208213 img = nb .load (self .inputs .in_file )
209214 affine = img .get_affine ()
@@ -218,23 +223,18 @@ def _run_interface(self, runtime):
218223 data = img .get_data ().astype (np .float32 )
219224 gtab = self ._get_gradient_table ()
220225
221- evals = nb .load (self .inputs .in_evals ).get_data ()
226+ evals = np . nan_to_num ( nb .load (self .inputs .in_evals ).get_data () )
222227 FA = np .nan_to_num (fractional_anisotropy (evals )) * msk
223-
224- if not self .inputs .recursive :
225- indices = np .where (FA > self .inputs .fa_thresh )
226- lambdas = evals [indices ][:, :2 ]
227- S0s = data [indices ][:, np .nonzero (gtab .b0s_mask )[0 ]]
228- S0 = np .mean (S0s )
229- l01 = np .mean (lambdas , axis = 0 )
230- respev = np .array ([l01 [0 ], l01 [1 ], l01 [1 ]])
231- response = np .array (respev .tolist () + [S0 ]).reshape (- 1 )
232-
233- ratio = abs (respev [1 ] / respev [0 ])
234- if ratio > 0.25 :
235- iflogger .warn (('Estimated response is not prolate enough. '
236- 'Ratio=%0.3f.' ) % ratio )
237- else :
228+ indices = np .where (FA > self .inputs .fa_thresh )
229+ S0s = data [indices ][:, np .nonzero (gtab .b0s_mask )[0 ]]
230+ S0 = np .mean (S0s )
231+
232+ if self .inputs .auto :
233+ response , ratio = auto_response (gtab , data ,
234+ roi_radius = self .inputs .roi_radius ,
235+ fa_thr = self .inputs .fa_thresh )
236+ response = response [0 ].tolist () + [S0 ]
237+ elif self .inputs .recursive :
238238 MD = np .nan_to_num (mean_diffusivity (evals )) * msk
239239 indices = np .logical_or (
240240 FA >= 0.4 , (np .logical_and (FA >= 0.15 , MD >= 0.0011 )))
@@ -244,6 +244,23 @@ def _run_interface(self, runtime):
244244 init_trace = 0.0021 , iter = 8 ,
245245 convergence = 0.001 ,
246246 parallel = True )
247+ ratio = abs (response [1 ] / response [0 ])
248+ else :
249+ lambdas = evals [indices ]
250+ l01 = np .sort (np .mean (lambdas , axis = 0 ))
251+
252+ response = np .array ([l01 [- 1 ], l01 [- 2 ], l01 [- 2 ], S0 ])
253+ ratio = abs (response [1 ] / response [0 ])
254+
255+ if ratio > 0.25 :
256+ iflogger .warn (('Estimated response is not prolate enough. '
257+ 'Ratio=%0.3f.' ) % ratio )
258+ elif ratio < 1.e-5 or np .any (np .isnan (response )):
259+ response = np .array ([1.8e-3 , 3.6e-4 , 3.6e-4 , S0 ])
260+ iflogger .warn (
261+ ('Estimated response is not valid, using a default one' ))
262+ else :
263+ iflogger .info (('Estimated response: %s' ) % str (response [:3 ]))
247264
248265 np .savetxt (op .abspath (self .inputs .response ), response )
249266
@@ -252,7 +269,6 @@ def _run_interface(self, runtime):
252269 nb .Nifti1Image (
253270 wm_mask .astype (np .uint8 ), affine ,
254271 None ).to_filename (op .abspath (self .inputs .out_mask ))
255-
256272 return runtime
257273
258274 def _list_outputs (self ):
0 commit comments