@@ -54,6 +54,10 @@ def get_opt_parser():
5454 " If --data-max-abs-diff is also specified, only the data points "
5555 " with absolute difference greater than that value would be "
5656 " considered for relative difference check." ),
57+ Option ("--dt" , "--datatype" ,
58+ dest = "dtype" ,
59+ default = np .float64 ,
60+ help = "Enter a numpy datatype such as 'float32'." )
5761 ])
5862
5963 return p
@@ -116,7 +120,7 @@ def get_headers_diff(file_headers, names=None):
116120 return difference
117121
118122
119- def get_data_hash_diff (files ):
123+ def get_data_hash_diff (files , dtype = np . float64 ):
120124 """Get difference between md5 values of data
121125
122126 Parameters
@@ -130,7 +134,7 @@ def get_data_hash_diff(files):
130134 """
131135
132136 md5sums = [
133- hashlib .md5 (np .ascontiguousarray (nib .load (f ).get_fdata ())).hexdigest ()
137+ hashlib .md5 (np .ascontiguousarray (nib .load (f ).get_fdata (dtype = dtype ))).hexdigest ()
134138 for f in files
135139 ]
136140
@@ -140,7 +144,7 @@ def get_data_hash_diff(files):
140144 return md5sums
141145
142146
143- def get_data_diff (files , max_abs = 0 , max_rel = 0 ):
147+ def get_data_diff (files , max_abs = 0 , max_rel = 0 , dtype = np . float64 ):
144148 """Get difference between data
145149
146150 Parameters
@@ -153,6 +157,8 @@ def get_data_diff(files, max_abs=0, max_rel=0):
153157 Maximal relative (`abs(diff)/mean(diff)`) difference to tolerate.
154158 If `max_abs` is specified, then those data points with lesser than that
155159 absolute difference, are not considered for relative difference testing
160+ dtype: np, optional
161+ Datatype to be used when extracting data from files
156162
157163 Returns
158164 -------
@@ -167,7 +173,7 @@ def get_data_diff(files, max_abs=0, max_rel=0):
167173 """
168174
169175 # we are doomed to keep them in RAM now
170- data = [f if isinstance (f , np .ndarray ) else nib .load (f ).get_fdata ()
176+ data = [f if isinstance (f , np .ndarray ) else nib .load (f ).get_fdata (dtype = dtype )
171177 for f in files ]
172178 diffs = OrderedDict ()
173179 for i , d1 in enumerate (data [:- 1 ]):
@@ -268,7 +274,7 @@ def display_diff(files, diff):
268274 return output
269275
270276
271- def diff (files , header_fields = 'all' , data_max_abs_diff = None , data_max_rel_diff = None ):
277+ def diff (files , header_fields = 'all' , data_max_abs_diff = None , data_max_rel_diff = None , dtype = np . float64 ):
272278 assert len (files ) >= 2 , "Please enter at least two files"
273279
274280 file_headers = [nib .load (f ).header for f in files ]
@@ -282,13 +288,14 @@ def diff(files, header_fields='all', data_max_abs_diff=None, data_max_rel_diff=N
282288
283289 diff = get_headers_diff (file_headers , header_fields )
284290
285- data_md5_diffs = get_data_hash_diff (files )
291+ data_md5_diffs = get_data_hash_diff (files , dtype )
286292 if data_md5_diffs :
287293 # provide details, possibly triggering the ignore of the difference
288294 # in data
289295 data_diffs = get_data_diff (files ,
290296 max_abs = data_max_abs_diff ,
291- max_rel = data_max_rel_diff )
297+ max_rel = data_max_rel_diff ,
298+ dtype = dtype )
292299 if data_diffs :
293300 diff ['DATA(md5)' ] = data_md5_diffs
294301 diff .update (data_diffs )
@@ -313,7 +320,8 @@ def main(args=None, out=None):
313320 files ,
314321 header_fields = opts .header_fields ,
315322 data_max_abs_diff = opts .data_max_abs_diff ,
316- data_max_rel_diff = opts .data_max_rel_diff
323+ data_max_rel_diff = opts .data_max_rel_diff ,
324+ dtype = opts .dtype
317325 )
318326
319327 if files_diff :
0 commit comments