11from typing import Optional , List
22import numpy as np
33import random
4+ import warnings
45
56import torch
67import torch .utils .data
@@ -68,6 +69,7 @@ def __init__(self,
6869 do_relabel : bool = True ,
6970 # rejection sampling
7071 reject_size_thres : int = 0 ,
72+ reject_num_trial : int = 50 ,
7173 reject_diversity : int = 0 ,
7274 reject_p : float = 0.95 ,
7375 # normalization
@@ -98,6 +100,7 @@ def __init__(self,
98100 # rejection sampling
99101 self .reject_size_thres = reject_size_thres
100102 self .reject_diversity = reject_diversity
103+ self .reject_num_trial = reject_num_trial
101104 self .reject_p = reject_p
102105
103106 # normalization
@@ -113,6 +116,17 @@ def __init__(self,
113116 assert len (set (x [0 ] for x in volume_size )) == 1 , "All volumes should have the same number of channels"
114117 self .volume_size = [x [- 3 :] for x in volume_size ]
115118
119+ volume_selection = [(sample_label_size <= x ).all () for x in self .volume_size ]
120+ if not all (volume_selection ):
121+ print ('remove volumes whose sizes are smaller than the model input' , volume_selection )
122+ self .volume = [x for i ,x in enumerate (self .volume ) if volume_selection [i ]]
123+ volume_size = [np .array (x .shape ) for x in self .volume ]
124+ self .volume_size = [x [- 3 :] for x in volume_size ]
125+ if self .label is not None :
126+ self .label = [x for i ,x in enumerate (self .label ) if volume_selection [i ]]
127+ if valid_mask is not None :
128+ valid_mask = [x for i ,x in enumerate (valid_mask ) if volume_selection [i ]]
129+
116130 self .sample_volume_size = np .array (
117131 sample_volume_size ).astype (int ) # model input size
118132 if self .label is not None :
@@ -122,7 +136,7 @@ def __init__(self,
122136 if self .augmentor is not None :
123137 assert np .array_equal (
124138 self .augmentor .sample_size , self .sample_label_size )
125- self ._assert_valid_shape ()
139+ # self._assert_valid_shape()
126140
127141 # compute number of samples for each dataset (multi-volume input)
128142 self .sample_stride = np .array (sample_stride ).astype (int )
@@ -138,15 +152,19 @@ def __init__(self,
138152 self .valid_mask = valid_mask
139153 self .valid_ratio = valid_ratio
140154 # precompute valid region
155+ # can be memory intensive
156+ self .valid_pos = [None ] * len (self .valid_mask ) if self .valid_mask is not None else [None ] * len (self .volume )
157+ """
141158 if self.valid_mask is not None:
142- self .valid_pos = [None ] * len (self .valid_mask )
143159 for i, x in enumerate(self.valid_mask):
144160 if x is not None:
145161 self.valid_pos[i] = get_valid_pos(x, sample_volume_size, valid_ratio)
146162 self.sample_num[i] = self.valid_pos[i].shape[0]
163+ print(i, self.sample_num[i])
147164 self.sample_num_a = np.sum(self.sample_num)
148165 self.sample_num_c = np.cumsum([0] + list(self.sample_num))
149-
166+ """
167+
150168 if self .mode in ['val' , 'test' ]: # for validation and test
151169 self .sample_size_test = [
152170 np .array ([np .prod (x [1 :3 ]), x [2 ]]) for x in self .sample_size ]
@@ -240,17 +258,17 @@ def _get_pos_train(self, vol_size):
240258 # np.random: same seed
241259 pos = [0 , 0 , 0 , 0 ]
242260 # pick a dataset
243- did = self ._index_to_dataset (random .randint (0 , self .sample_num_a ))
261+ did = self ._index_to_dataset (random .randint (0 , self .sample_num_a - 1 ))
244262 pos [0 ] = did
245263 # pick a position
246264 # all regions are valid
247265 if self .valid_pos [did ] is None :
248266 tmp_size = count_volume (
249267 self .volume_size [did ], vol_size , self .sample_stride )
250- tmp_pos = [random .randint (0 , tmp_size [x ]) * self .sample_stride [x ]
268+ tmp_pos = [random .randint (0 , tmp_size [x ] - 1 ) * self .sample_stride [x ]
251269 for x in range (len (tmp_size ))]
252270 else :
253- tmp_pos = self .valid_pos [did ][random .randint (0 , self .valid_pos [did ].shape [0 ])]
271+ tmp_pos = self .valid_pos [did ][random .randint (0 , self .valid_pos [did ].shape [0 ]) - 1 ]
254272
255273 pos [1 :] = tmp_pos
256274 return pos
@@ -282,16 +300,21 @@ def _rejection_sampling(self, vol_size):
282300 out_valid = augmented ['valid_mask' ]
283301
284302 if self ._is_valid (out_valid ) and self ._is_fg (out_label ):
303+ #print('yes', sample_count)
285304 return pos , out_volume , out_label , out_valid
286305
287306 sample_count += 1
288- if sample_count > 100 :
307+ if sample_count > self . reject_num_trial :
289308 err_msg = (
290309 "Can not find any valid subvolume after sampling the "
291- "dataset for more than 100 times. Please adjust the "
310+ f "dataset for more than { self . reject_num_trial } times. Please adjust the "
292311 "valid mask or rejection sampling configurations."
293312 )
294- raise RuntimeError (err_msg )
313+ #raise RuntimeError(err_msg)
314+ # return anyway with a useless sample
315+ warnings .warn (err_msg )
316+ #print('no..')
317+ return pos , out_volume , out_label , out_valid
295318
296319 def _random_sampling (self , vol_size ):
297320 """Randomly sample a subvolume from all the volumes.
0 commit comments