11import numpy as np
22from scipy .ndimage import convolve
3+ import torch
4+ from torch .nn .functional import conv2d , conv3d
35
46####################################################################
57## Process image stacks.
@@ -18,27 +20,52 @@ def crop_volume(data, sz, st=(0, 0, 0)):
1820 else : # crop spatial dimensions
1921 return data [:, st [0 ]:st [0 ]+ sz [0 ], st [1 ]:st [1 ]+ sz [1 ], st [2 ]:st [2 ]+ sz [2 ]]
2022
21- def get_valid_pos (mask , vol_sz , valid_ratio ):
22- mask_sum = convolve (mask , np .ones (vol_sz ), mode = 'constant' , cval = 0 )
23+ def get_valid_pos_torch (mask , vol_sz , valid_ratio ):
24+ # torch version
25+ # bug: out of memory
2326 valid_thres = valid_ratio * np .prod (vol_sz )
2427 data_sz = mask .shape
25- pad_sz_pre = (np .array (vol_sz ) - 1 ) // 2
26- pad_sz_post = data_sz - (vol_sz - pad_sz_pre - 1 )
2728 if len (vol_sz ) == 3 :
28- mask_sum = mask_sum [pad_sz_pre [0 ]:pad_sz_post [0 ], \
29- pad_sz_pre [1 ]:pad_sz_post [1 ], \
30- pad_sz_pre [2 ]:pad_sz_post [2 ]] >= valid_thres
29+ mask_sum = conv3d (torch .from_numpy (mask [None ,None ].astype (int )), torch .ones (tuple (vol_sz ))[None ,None ], padding = 'valid' )[0 ,0 ].numpy ()>= valid_thres
3130 zz , yy , xx = np .meshgrid (np .arange (mask_sum .shape [0 ]), \
3231 np .arange (mask_sum .shape [1 ]), \
3332 np .arange (mask_sum .shape [2 ]))
3433 valid_pos = np .stack ([zz .T [mask_sum ], \
3534 yy .T [mask_sum ], \
3635 xx .T [mask_sum ]], axis = 1 )
3736 else :
38- mask_sum = mask_sum [pad_sz_pre [0 ]:pad_sz_post [0 ], \
39- pad_sz_pre [1 ]:pad_sz_post [1 ]] >= valid_thres
37+ mask_sum = conv2d (torch .from_numpy (mask [None ,None ].astype (int )), torch .ones (tuple (vol_sz ))[None ,None ], padding = 'valid' )[0 ,0 ].numpy ()>= valid_thres
4038 yy , xx = np .meshgrid (np .arange (mask_sum .shape [0 ]), \
4139 np .arange (mask_sum .shape [1 ]))
4240 valid_pos = np .stack ([yy .T [mask_sum ], \
4341 xx .T [mask_sum ]], axis = 1 )
4442 return valid_pos
43+
44+ def get_valid_pos (mask , vol_sz , valid_ratio ):
45+ # scipy version
46+ valid_thres = valid_ratio * np .prod (vol_sz )
47+ data_sz = mask .shape
48+ mask_sum = convolve (mask .astype (int ), np .ones (vol_sz ), mode = 'constant' , cval = 0 )
49+ pad_sz_pre = (np .array (vol_sz ) - 1 ) // 2
50+ pad_sz_post = data_sz - (vol_sz - pad_sz_pre - 1 )
51+ valid_pos = np .zeros ([0 ,3 ])
52+ if len (vol_sz ) == 3 :
53+ mask_sum = mask_sum [pad_sz_pre [0 ]:pad_sz_post [0 ], \
54+ pad_sz_pre [1 ]:pad_sz_post [1 ], \
55+ pad_sz_pre [2 ]:pad_sz_post [2 ]] >= valid_thres
56+ if mask_sum .max () > 0 :
57+ zz , yy , xx = np .meshgrid (np .arange (mask_sum .shape [0 ]), \
58+ np .arange (mask_sum .shape [1 ]), \
59+ np .arange (mask_sum .shape [2 ]))
60+ valid_pos = np .stack ([zz .transpose ([1 ,0 ,2 ])[mask_sum ], \
61+ yy .transpose ([1 ,0 ,2 ])[mask_sum ], \
62+ xx .transpose ([1 ,0 ,2 ])[mask_sum ]], axis = 1 )
63+ else :
64+ mask_sum = mask_sum [pad_sz_pre [0 ]:pad_sz_post [0 ], \
65+ pad_sz_pre [1 ]:pad_sz_post [1 ]] >= valid_thres
66+ if mask_sum .max () > 0 :
67+ yy , xx = np .meshgrid (np .arange (mask_sum .shape [0 ]), \
68+ np .arange (mask_sum .shape [1 ]))
69+ valid_pos = np .stack ([yy .T [mask_sum ], \
70+ xx .T [mask_sum ]], axis = 1 )
71+ return valid_pos
0 commit comments