Skip to content

Commit 4b00a0f

Browse files
author
donglaiw
committed
add erosion to distance transform
1 parent 4a755e8 commit 4b00a0f

File tree

6 files changed

+17
-47
lines changed

6 files changed

+17
-47
lines changed

configs/NucMM/NucMM-Mouse-UNet-BCD-v0.yaml

Lines changed: 0 additions & 17 deletions
This file was deleted.

configs/NucMM/NucMM-Mouse-UNet-BCD.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MODEL:
22
OUT_PLANES: 3
3-
TARGET_OPT: ['0','4-0-1', '5-3d-1-0-1.0']
3+
TARGET_OPT: ['0','4-0-1', '5-3d-1-0-1.0-0']
44
LOSS_OPTION:
55
- - WeightedBCEWithLogitsLoss
66
- DiceLoss

configs/NucMM/NucMM-Zebrafish-UNet-BCD-v0.yaml

Lines changed: 0 additions & 17 deletions
This file was deleted.

configs/NucMM/NucMM-Zebrafish-UNet-BCD.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MODEL:
22
OUT_PLANES: 3
3-
TARGET_OPT: ['0','4-0-1', '5-3d-1-0-1.0']
3+
TARGET_OPT: ['0','4-0-1', '5-3d-1-0-1.0-0']
44
LOSS_OPTION:
55
- - WeightedBCEWithLogitsLoss
66
- DiceLoss

connectomics/data/utils/data_segmentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,12 @@ def seg2inst_edt(label, topt):
286286
# Format of the target option: 5-a-b-c-d
287287
# a: mode, b: padding, c: quantize, d: z_resolution
288288
if len(topt) == 1:
289-
topt = topt + '-2d-0-0-5.0' # 2d w/o padding or quantize (default)
289+
topt = topt + '-2d-0-0-5.0-0' # 2d w/o padding or quantize (default)
290290

291-
_, mode, padding, quant, z_res = topt.split('-')
291+
_, mode, padding, quant, z_res, erosion = topt.split('-')
292292
resolution = (float(z_res), 1.0, 1.0)
293293
return edt_instance(label.copy(), mode, resolution=resolution,
294-
quantize=bool(int(quant)), padding=bool(int(padding)))
294+
quantize=bool(int(quant)), padding=bool(int(padding)), erosion=erosion)
295295

296296

297297
def seg_to_targets(

connectomics/data/utils/data_transform.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import scipy
66
import numpy as np
77
from scipy.ndimage import distance_transform_edt
8-
from skimage.morphology import remove_small_holes, skeletonize
8+
from skimage.morphology import remove_small_holes, skeletonize, binary_erosion, disk
99
from skimage.measure import label as label_cc # avoid namespace conflict
1010
from skimage.filters import gaussian
1111

@@ -62,11 +62,12 @@ def edt_instance(label: np.ndarray,
6262
quantize: bool = True,
6363
resolution: Tuple[float] = (1.0, 1.0, 1.0),
6464
padding: bool = False):
65+
erosion: int = 0):
6566
assert mode in ['2d', '3d']
6667
if mode == '3d':
6768
# calculate 3d distance transform for instances
6869
vol_distance, vol_semantic = distance_transform(
69-
label, resolution=resolution, padding=padding)
70+
label, resolution=resolution, padding=padding, erosion=erosion)
7071
if quantize:
7172
vol_distance = energy_quantize(vol_distance)
7273
return vol_distance
@@ -75,7 +76,7 @@ def edt_instance(label: np.ndarray,
7576
vol_semantic = []
7677
for i in range(label.shape[0]):
7778
label_img = label[i].copy()
78-
distance, semantic = distance_transform(label_img, padding=padding)
79+
distance, semantic = distance_transform(label_img, padding=padding, erosion=erosion)
7980
vol_distance.append(distance)
8081
vol_semantic.append(semantic)
8182

@@ -120,7 +121,8 @@ def distance_transform(label: np.ndarray,
120121
bg_value: float = -1.0,
121122
relabel: bool = True,
122123
padding: bool = False,
123-
resolution: Tuple[float] = (1.0, 1.0)):
124+
resolution: Tuple[float] = (1.0, 1.0),
125+
erosion: int = 0):
124126
"""Euclidean distance transform (DT or EDT) for instance masks.
125127
"""
126128
eps = 1e-6
@@ -148,9 +150,12 @@ def distance_transform(label: np.ndarray,
148150
all_bg_sample = True
149151

150152
if not all_bg_sample:
153+
if erosion > 0:
154+
erosion_disk = disk(erosion)
151155
for idx in indices:
152-
temp1 = label.copy() == idx
153-
temp2 = remove_small_holes(temp1, 16, connectivity=1)
156+
temp2 = remove_small_holes(label == idx, 16, connectivity=1)
157+
if erosion > 0:
158+
temp2 = binary_erosion(temp2, erosion_disk)
154159

155160
semantic += temp2.astype(np.uint8)
156161
boundary_edt = distance_transform_edt(temp2, resolution)
@@ -222,8 +227,7 @@ def skeleton_aware_distance_transform(
222227

223228
if not all_bg_sample:
224229
for idx in indices:
225-
temp1 = label.copy() == idx
226-
temp2 = remove_small_holes(temp1, 16, connectivity=1)
230+
temp2 = remove_small_holes(label == idx, 16, connectivity=1)
227231
binary = temp2.copy()
228232

229233
if smooth:

0 commit comments

Comments
 (0)