1010from mantidimaging import helper as h
1111from mantidimaging .core .data import Images
1212from mantidimaging .core .operations .base_filter import BaseFilter , FilterGroup
13- from mantidimaging .core .parallel import two_shared_mem as ptsm
14- from mantidimaging .core .parallel import utility as pu
13+ from mantidimaging .core .parallel import utility as pu , shared as ps
1514from mantidimaging .core .utility .progress_reporting import Progress
1615from mantidimaging .gui .utility .qt_helpers import Type
1716from mantidimaging .gui .widgets .stack_selector import StackSelectorWidgetView
@@ -58,7 +57,7 @@ class FlatFieldFilter(BaseFilter):
5857 filter_name = 'Flat-fielding'
5958
6059 @staticmethod
61- def filter_func (data : Images ,
60+ def filter_func (images : Images ,
6261 flat_before : Images = None ,
6362 flat_after : Images = None ,
6463 dark_before : Images = None ,
@@ -80,7 +79,7 @@ def filter_func(data: Images,
8079 :param chunksize: The number of chunks that each worker will receive.
8180 :return: Filtered data (stack of images)
8281 """
83- h .check_data_stack (data )
82+ h .check_data_stack (images )
8483
8584 if selected_flat_fielding is not None :
8685 if selected_flat_fielding == "Both, concatenated" and flat_after is not None and flat_before is not None \
@@ -101,19 +100,19 @@ def filter_func(data: Images,
101100 if 2 != flat_avg .ndim or 2 != dark_avg .ndim :
102101 raise ValueError (
103102 f"Incorrect shape of the flat image ({ flat_avg .shape } ) or dark image ({ dark_avg .shape } ) \
104- which should match the shape of the sample images ({ data .data .shape } )" )
103+ which should match the shape of the sample images ({ images .data .shape } )" )
105104
106- if not data .data .shape [1 :] == flat_avg .shape == dark_avg .shape :
107- raise ValueError (f"Not all images are the expected shape: { data .data .shape [1 :]} , instead "
105+ if not images .data .shape [1 :] == flat_avg .shape == dark_avg .shape :
106+ raise ValueError (f"Not all images are the expected shape: { images .data .shape [1 :]} , instead "
108107 f"flat had shape: { flat_avg .shape } , and dark had shape: { dark_avg .shape } " )
109108
110109 progress = Progress .ensure_instance (progress ,
111- num_steps = data .data .shape [0 ],
110+ num_steps = images .data .shape [0 ],
112111 task_name = 'Background Correction' )
113- _execute (data .data , flat_avg , dark_avg , cores , chunksize , progress )
112+ _execute (images .data , flat_avg , dark_avg , cores , chunksize , progress )
114113
115- h .check_data_stack (data )
116- return data
114+ h .check_data_stack (images )
115+ return images
117116
118117 @staticmethod
119118 def register_gui (form , on_change , view : FiltersWindowView ) -> Dict [str , Any ]:
@@ -260,7 +259,7 @@ def _subtract(data, dark=None):
260259 np .subtract (data , dark , out = data )
261260
262261
263- def _execute (data , flat = None , dark = None , cores = None , chunksize = None , progress = None ):
262+ def _execute (data : np . ndarray , flat = None , dark = None , cores = None , chunksize = None , progress = None ):
264263 """A benchmark justifying the current implementation, performed on
265264 500x2048x2048 images.
266265
@@ -289,11 +288,13 @@ def _execute(data, flat=None, dark=None, cores=None, chunksize=None, progress=No
289288 norm_divide [norm_divide == 0 ] = MINIMUM_PIXEL_VALUE
290289
291290 # subtract the dark from all images
292- f = ptsm .create_partial (_subtract , fwd_function = ptsm .inplace_second_2d )
293- data , dark = ptsm .execute (data , dark , f , cores , chunksize , progress = progress )
291+ do_subtract = ps .create_partial (_subtract , fwd_function = ps .inplace_second_2d )
292+ ps .shared_list = [data , dark ]
293+ ps .execute (do_subtract , data .shape [0 ], progress , cores = cores )
294294
295295 # divide the data by (flat - dark)
296- f = ptsm .create_partial (_divide , fwd_function = ptsm .inplace_second_2d )
297- data , norm_divide = ptsm .execute (data , norm_divide , f , cores , chunksize , progress = progress )
296+ do_divide = ps .create_partial (_divide , fwd_function = ps .inplace_second_2d )
297+ ps .shared_list = [data , norm_divide ]
298+ ps .execute (do_divide , data .shape [0 ], progress , cores = cores )
298299
299300 return data
0 commit comments