
Commit 9ef9231

Author: donglaiw
resolve valid_pos bug
2 parents f04d2ea + 1d8f16f, commit 9ef9231

File tree: 5 files changed, +71 -21 lines


connectomics/config/defaults.py

Lines changed: 1 addition & 0 deletions
@@ -278,6 +278,7 @@
 _C.DATASET.REJECT_SAMPLING.SIZE_THRES = -1
 _C.DATASET.REJECT_SAMPLING.DIVERSITY = -1
 _C.DATASET.REJECT_SAMPLING.P = 0.95
+_C.DATASET.REJECT_SAMPLING.NUM_TRIAL = 50
 
 # Normalize model inputs (the images are assumed to be gray-scale).
 _C.DATASET.MEAN = 0.5
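
The new DATASET.REJECT_SAMPLING.NUM_TRIAL option caps how many draws the rejection-sampling loop attempts before giving up. A minimal sketch of overriding it from user code, assuming the repo's usual yacs entry point get_cfg_defaults() (the override values are illustrative, not defaults introduced by this commit):

from connectomics.config import get_cfg_defaults

# Start from the library defaults and tighten the rejection-sampling budget.
cfg = get_cfg_defaults()
cfg.DATASET.REJECT_SAMPLING.SIZE_THRES = 1000  # require at least 1000 foreground voxels
cfg.DATASET.REJECT_SAMPLING.NUM_TRIAL = 20     # stop retrying after 20 draws
cfg.freeze()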

connectomics/data/dataset/dataset_volume.py

Lines changed: 32 additions & 9 deletions
@@ -1,6 +1,7 @@
 from typing import Optional, List
 import numpy as np
 import random
+import warnings
 
 import torch
 import torch.utils.data
@@ -68,6 +69,7 @@ def __init__(self,
                  do_relabel: bool = True,
                  # rejection sampling
                  reject_size_thres: int = 0,
+                 reject_num_trial: int = 50,
                  reject_diversity: int = 0,
                  reject_p: float = 0.95,
                  # normalization
@@ -98,6 +100,7 @@ def __init__(self,
         # rejection sampling
         self.reject_size_thres = reject_size_thres
         self.reject_diversity = reject_diversity
+        self.reject_num_trial = reject_num_trial
         self.reject_p = reject_p
 
         # normalization
@@ -113,6 +116,17 @@ def __init__(self,
         assert len(set(x[0] for x in volume_size)) == 1, "All volumes should have the same number of channels"
         self.volume_size = [x[-3:] for x in volume_size]
 
+        volume_selection = [(sample_label_size <= x).all() for x in self.volume_size]
+        if not all(volume_selection):
+            print('remove volumes whose sizes are smaller than the model input', volume_selection)
+            self.volume = [x for i, x in enumerate(self.volume) if volume_selection[i]]
+            volume_size = [np.array(x.shape) for x in self.volume]
+            self.volume_size = [x[-3:] for x in volume_size]
+            if self.label is not None:
+                self.label = [x for i, x in enumerate(self.label) if volume_selection[i]]
+            if valid_mask is not None:
+                valid_mask = [x for i, x in enumerate(valid_mask) if volume_selection[i]]
+
         self.sample_volume_size = np.array(
             sample_volume_size).astype(int)  # model input size
         if self.label is not None:
@@ -122,7 +136,7 @@ def __init__(self,
         if self.augmentor is not None:
             assert np.array_equal(
                 self.augmentor.sample_size, self.sample_label_size)
-        self._assert_valid_shape()
+        #self._assert_valid_shape()
 
         # compute number of samples for each dataset (multi-volume input)
         self.sample_stride = np.array(sample_stride).astype(int)
@@ -138,15 +152,19 @@ def __init__(self,
         self.valid_mask = valid_mask
         self.valid_ratio = valid_ratio
         # precompute valid region
+        # can be memory intensive
+        self.valid_pos = [None] * len(self.valid_mask) if self.valid_mask is not None else [None] * len(self.volume)
+        """
         if self.valid_mask is not None:
-            self.valid_pos = [None] * len(self.valid_mask)
             for i, x in enumerate(self.valid_mask):
                 if x is not None:
                     self.valid_pos[i] = get_valid_pos(x, sample_volume_size, valid_ratio)
                     self.sample_num[i] = self.valid_pos[i].shape[0]
+                    print(i, self.sample_num[i])
         self.sample_num_a = np.sum(self.sample_num)
         self.sample_num_c = np.cumsum([0] + list(self.sample_num))
-
+        """
+
         if self.mode in ['val', 'test']:  # for validation and test
             self.sample_size_test = [
                 np.array([np.prod(x[1:3]), x[2]]) for x in self.sample_size]
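
The eager valid_pos precomputation above is commented out because materializing every candidate position for every volume can be memory-intensive, so self.valid_pos now stays a list of None placeholders. If per-volume positions are still wanted, one option is to build them lazily on first use; the helper below is only an illustrative sketch, not part of this commit:

    def _get_valid_pos_lazy(self, did):
        # Hypothetical helper: compute and cache the valid positions of volume
        # `did` the first time it is sampled, instead of for all volumes up front.
        if self.valid_pos[did] is None and self.valid_mask is not None \
                and self.valid_mask[did] is not None:
            self.valid_pos[did] = get_valid_pos(
                self.valid_mask[did], self.sample_volume_size, self.valid_ratio)
        return self.valid_pos[did]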
@@ -240,17 +258,17 @@ def _get_pos_train(self, vol_size):
         # np.random: same seed
         pos = [0, 0, 0, 0]
         # pick a dataset
-        did = self._index_to_dataset(random.randint(0, self.sample_num_a))
+        did = self._index_to_dataset(random.randint(0, self.sample_num_a - 1))
         pos[0] = did
         # pick a position
         # all regions are valid
         if self.valid_pos[did] is None:
             tmp_size = count_volume(
                 self.volume_size[did], vol_size, self.sample_stride)
-            tmp_pos = [random.randint(0, tmp_size[x]) * self.sample_stride[x]
+            tmp_pos = [random.randint(0, tmp_size[x] - 1) * self.sample_stride[x]
                        for x in range(len(tmp_size))]
         else:
-            tmp_pos = self.valid_pos[did][random.randint(0, self.valid_pos[did].shape[0])]
+            tmp_pos = self.valid_pos[did][random.randint(0, self.valid_pos[did].shape[0] - 1)]
 
         pos[1:] = tmp_pos
         return pos
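
The heart of the valid_pos bug: random.randint(a, b) includes both endpoints, so the previous upper bounds could land one past the last valid index. Subtracting 1 keeps every draw inside [0, n - 1]. A standalone illustration, not repo code:

import random

positions = list(range(5))                    # five valid indices: 0..4
bad = random.randint(0, len(positions))       # may return 5; positions[bad] would raise IndexError
good = random.randint(0, len(positions) - 1)  # always a valid index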
@@ -282,16 +300,21 @@ def _rejection_sampling(self, vol_size):
                 out_valid = augmented['valid_mask']
 
             if self._is_valid(out_valid) and self._is_fg(out_label):
+                #print('yes', sample_count)
                 return pos, out_volume, out_label, out_valid
 
             sample_count += 1
-            if sample_count > 100:
+            if sample_count > self.reject_num_trial:
                 err_msg = (
                     "Can not find any valid subvolume after sampling the "
-                    "dataset for more than 100 times. Please adjust the "
+                    f"dataset for more than {self.reject_num_trial} times. Please adjust the "
                     "valid mask or rejection sampling configurations."
                 )
-                raise RuntimeError(err_msg)
+                #raise RuntimeError(err_msg)
+                # return anyway with a useless sample
+                warnings.warn(err_msg)
+                #print('no..')
+                return pos, out_volume, out_label, out_valid
 
     def _random_sampling(self, vol_size):
         """Randomly sample a subvolume from all the volumes.

connectomics/data/utils/data_crop.py

Lines changed: 36 additions & 9 deletions
@@ -1,5 +1,7 @@
 import numpy as np
 from scipy.ndimage import convolve
+import torch
+from torch.nn.functional import conv2d, conv3d
 
 ####################################################################
 ## Process image stacks.
@@ -18,27 +20,52 @@ def crop_volume(data, sz, st=(0, 0, 0)):
     else:  # crop spatial dimensions
         return data[:, st[0]:st[0]+sz[0], st[1]:st[1]+sz[1], st[2]:st[2]+sz[2]]
 
-def get_valid_pos(mask, vol_sz, valid_ratio):
-    mask_sum = convolve(mask, np.ones(vol_sz), mode='constant', cval=0)
+def get_valid_pos_torch(mask, vol_sz, valid_ratio):
+    # torch version
+    # bug: out of memory
     valid_thres = valid_ratio * np.prod(vol_sz)
     data_sz = mask.shape
-    pad_sz_pre = (np.array(vol_sz) - 1) // 2
-    pad_sz_post = data_sz - (vol_sz - pad_sz_pre - 1)
     if len(vol_sz) == 3:
-        mask_sum = mask_sum[pad_sz_pre[0]:pad_sz_post[0], \
-                            pad_sz_pre[1]:pad_sz_post[1], \
-                            pad_sz_pre[2]:pad_sz_post[2]] >= valid_thres
+        mask_sum = conv3d(torch.from_numpy(mask[None, None].astype(int)), torch.ones(tuple(vol_sz))[None, None], padding='valid')[0, 0].numpy() >= valid_thres
         zz, yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), \
                                  np.arange(mask_sum.shape[1]), \
                                  np.arange(mask_sum.shape[2]))
         valid_pos = np.stack([zz.T[mask_sum], \
                               yy.T[mask_sum], \
                               xx.T[mask_sum]], axis=1)
     else:
-        mask_sum = mask_sum[pad_sz_pre[0]:pad_sz_post[0], \
-                            pad_sz_pre[1]:pad_sz_post[1]] >= valid_thres
+        mask_sum = conv2d(torch.from_numpy(mask[None, None].astype(int)), torch.ones(tuple(vol_sz))[None, None], padding='valid')[0, 0].numpy() >= valid_thres
         yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), \
                              np.arange(mask_sum.shape[1]))
         valid_pos = np.stack([yy.T[mask_sum], \
                               xx.T[mask_sum]], axis=1)
     return valid_pos
+
+def get_valid_pos(mask, vol_sz, valid_ratio):
+    # scipy version
+    valid_thres = valid_ratio * np.prod(vol_sz)
+    data_sz = mask.shape
+    mask_sum = convolve(mask.astype(int), np.ones(vol_sz), mode='constant', cval=0)
+    pad_sz_pre = (np.array(vol_sz) - 1) // 2
+    pad_sz_post = data_sz - (vol_sz - pad_sz_pre - 1)
+    valid_pos = np.zeros([0, 3])
+    if len(vol_sz) == 3:
+        mask_sum = mask_sum[pad_sz_pre[0]:pad_sz_post[0], \
+                            pad_sz_pre[1]:pad_sz_post[1], \
+                            pad_sz_pre[2]:pad_sz_post[2]] >= valid_thres
+        if mask_sum.max() > 0:
+            zz, yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), \
+                                     np.arange(mask_sum.shape[1]), \
+                                     np.arange(mask_sum.shape[2]))
+            valid_pos = np.stack([zz.transpose([1, 0, 2])[mask_sum], \
+                                  yy.transpose([1, 0, 2])[mask_sum], \
+                                  xx.transpose([1, 0, 2])[mask_sum]], axis=1)
+    else:
+        mask_sum = mask_sum[pad_sz_pre[0]:pad_sz_post[0], \
+                            pad_sz_pre[1]:pad_sz_post[1]] >= valid_thres
+        if mask_sum.max() > 0:
+            yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), \
+                                 np.arange(mask_sum.shape[1]))
+            valid_pos = np.stack([yy.T[mask_sum], \
+                                  xx.T[mask_sum]], axis=1)
+    return valid_pos
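
A small usage sketch of the rewritten scipy-based get_valid_pos: the returned array has one row per candidate z, y, x crop position where at least valid_ratio of the vol_sz window lies inside the mask. The toy mask and sizes below are made up for illustration:

import numpy as np
from connectomics.data.utils.data_crop import get_valid_pos

mask = np.zeros((32, 64, 64), dtype=np.uint8)
mask[8:24, 16:48, 16:48] = 1                       # mark a central block as valid
valid_pos = get_valid_pos(mask, (8, 32, 32), 0.9)  # sample window of 8 x 32 x 32
print(valid_pos.shape)                             # (N, 3) candidate crop positions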

docs/environment_docs.yml

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ dependencies:
   - connectomics==0.0.1.dev4
   - contourpy==1.1.1
   - cycler==0.12.1
-  - cython==0.29.21
+  - cython==0.29.22
   - debugpy==1.8.1
   - decorator==5.1.1
   - defusedxml==0.7.1

setup.py

Lines changed: 1 addition & 2 deletions
@@ -10,12 +10,11 @@
         'scikit-image>=0.17.2',
         'opencv-python>=4.3.0',
         'matplotlib>=3.3.0',
-        'Cython==0.29.21',
+        'Cython>=0.29.22',
         'yacs>=0.1.8',
         'h5py>=2.10.0',
         'gputil>=1.4.0',
         'imageio>=2.9.0',
-        'tensorflow>=2.2.0',
         'tensorboard>=2.2.2',
         'einops>=0.3.0',
         'tqdm>=4.58.0',
