@@ -163,14 +163,17 @@ class EstimateResponseSHInputSpec(DipyBaseInterfaceInputSpec):
163163 exists = True , desc = ('input mask in which we find single fibers' ))
164164 fa_thresh = traits .Float (
165165 0.7 , usedefault = True , desc = ('default FA threshold' ))
166- save_glyph = traits .Bool (False , usedefault = True ,
167- desc = ('save a png file of the response' ))
168- response = File (desc = ('the output response file' ))
166+ recursive = traits .Bool (
167+ False , usedefault = True ,
168+ desc = 'use the recursive response estimator from dipy' )
169+ response = File (
170+ 'response.txt' , usedefault = True , desc = ('the output response file' ))
171+ out_mask = File ('wm_mask.nii.gz' , usedefault = True , desc = 'computed wm mask' )
169172
170173
171174class EstimateResponseSHOutputSpec (TraitedSpec ):
172- response = File (desc = ('the response file' ))
173- glyph_file = File (desc = 'graphical representation of the response' )
175+ response = File (exists = True , desc = ('the response file' ))
176+ out_mask = File (exists = True , desc = ( 'output wm mask' ) )
174177
175178
176179class EstimateResponseSH (DipyBaseInterface ):
@@ -199,7 +202,8 @@ class EstimateResponseSH(DipyBaseInterface):
199202
200203 def _run_interface (self , runtime ):
201204 from dipy .core .gradients import GradientTable
202- from dipy .reconst .csdeconv import fractional_anisotropy
205+ from dipy .reconst .dti import fractional_anisotropy , mean_diffusivity
206+ from dipy .reconst .csdeconv import recursive_response
203207
204208 img = nb .load (self .inputs .in_file )
205209 affine = img .get_affine ()
@@ -215,61 +219,46 @@ def _run_interface(self, runtime):
215219 gtab = self ._get_gradient_table ()
216220
217221 evals = nb .load (self .inputs .in_evals ).get_data ()
218- FA = fractional_anisotropy (evals )
219- FA [np .isnan (FA )] = 0
220- FA [msk != 1 ] = 0
221-
222- indices = np .where (FA > self .inputs .fa_thresh )
223-
224- lambdas = evals [indices ][:, :2 ]
225- S0s = data [indices ][:, np .nonzero (gtab .b0s_mask )[0 ]]
226- S0 = np .mean (S0s )
227- l01 = np .mean (lambdas , axis = 0 )
228- respev = np .array ([l01 [0 ], l01 [1 ], l01 [1 ]])
229- response = (respev , S0 )
230- ratio = respev [1 ] / respev [0 ]
231-
232- if abs (ratio - 0.2 ) > 0.1 :
233- iflogger .warn (('Estimated response is not prolate enough. '
234- 'Ratio=%0.3f.' ) % ratio )
235-
236- np .savetxt (self ._gen_outname (),
237- np .array (respev .tolist () + [S0 ]).reshape (- 1 ))
238-
239- if self .inputs .save_glyph :
240- from dipy .viz import fvtk
241- from dipy .data import get_sphere
242- from dipy .sims .voxel import single_tensor_odf
222+ FA = np .nan_to_num (fractional_anisotropy (evals )) * msk
223+
224+ if not self .inputs .recursive :
225+ indices = np .where (FA > self .inputs .fa_thresh )
226+ lambdas = evals [indices ][:, :2 ]
227+ S0s = data [indices ][:, np .nonzero (gtab .b0s_mask )[0 ]]
228+ S0 = np .mean (S0s )
229+ l01 = np .mean (lambdas , axis = 0 )
230+ respev = np .array ([l01 [0 ], l01 [1 ], l01 [1 ]])
231+ response = np .array (respev .tolist () + [S0 ]).reshape (- 1 )
232+
233+ ratio = abs (respev [1 ] / respev [0 ])
234+ if ratio > 0.25 :
235+ iflogger .warn (('Estimated response is not prolate enough. '
236+ 'Ratio=%0.3f.' ) % ratio )
237+ else :
238+ MD = np .nan_to_num (mean_diffusivity (evals )) * msk
239+ indices = np .logical_or (
240+ FA >= 0.4 , (np .logical_and (FA >= 0.15 , MD >= 0.0011 )))
241+ data = nb .load (self .inputs .in_file ).get_data ()
242+ response = recursive_response (gtab , data , mask = indices , sh_order = 8 ,
243+ peak_thr = 0.01 , init_fa = 0.08 ,
244+ init_trace = 0.0021 , iter = 8 ,
245+ convergence = 0.001 ,
246+ parallel = True )
247+
248+ np .savetxt (op .abspath (self .inputs .response ), response )
249+
250+ wm_mask = np .zeros_like (FA )
251+ wm_mask [indices ] = 1
252+ nb .Nifti1Image (
253+ wm_mask .astype (np .uint8 ), affine ,
254+ None ).to_filename (op .abspath (self .inputs .out_mask ))
243255
244- ren = fvtk .ren ()
245- evecs = np .array ([[0 , 1 , 0 ], [0 , 0 , 1 ], [1 , 0 , 0 ]]).T
246- sphere = get_sphere ('symmetric724' )
247- response_odf = single_tensor_odf (sphere .vertices , respev , evecs )
248- response_actor = fvtk .sphere_funcs (response_odf , sphere )
249- fvtk .add (ren , response_actor )
250- fvtk .record (ren , out_path = self ._gen_outname () + '.png' ,
251- size = (200 , 200 ))
252- fvtk .rm (ren , response_actor )
253256 return runtime
254257
255- def _gen_outname (self ):
256- if isdefined (self .inputs .response ):
257- return self .inputs .response
258- else :
259- fname , fext = op .splitext (op .basename (self .inputs .in_file ))
260- if fext == '.gz' :
261- fname , fext2 = op .splitext (fname )
262- fext = fext2 + fext
263- return op .abspath (fname ) + '_response.txt'
264- return out_file
265-
266258 def _list_outputs (self ):
267259 outputs = self ._outputs ().get ()
268- outputs ['response' ] = self ._gen_outname ()
269-
270- if isdefined (self .inputs .save_glyph ) and self .inputs .save_glyph :
271- outputs ['glyph_file' ] = self ._gen_outname () + '.png'
272-
260+ outputs ['response' ] = op .abspath (self .inputs .response )
261+ outputs ['out_mask' ] = op .abspath (self .inputs .out_mask )
273262 return outputs
274263
275264
0 commit comments