55import scipy
66import numpy as np
77from scipy .ndimage import distance_transform_edt
8- from skimage .morphology import remove_small_holes , skeletonize
8+ from skimage .morphology import remove_small_holes , skeletonize , binary_erosion , disk
99from skimage .measure import label as label_cc # avoid namespace conflict
1010from skimage .filters import gaussian
1111
@@ -62,11 +62,12 @@ def edt_instance(label: np.ndarray,
6262 quantize : bool = True ,
6363 resolution : Tuple [float ] = (1.0 , 1.0 , 1.0 ),
6464 padding : bool = False ):
65+ erosion : int = 0 ):
6566 assert mode in ['2d' , '3d' ]
6667 if mode == '3d' :
6768 # calculate 3d distance transform for instances
6869 vol_distance , vol_semantic = distance_transform (
69- label , resolution = resolution , padding = padding )
70+ label , resolution = resolution , padding = padding , erosion = erosion )
7071 if quantize :
7172 vol_distance = energy_quantize (vol_distance )
7273 return vol_distance
@@ -75,7 +76,7 @@ def edt_instance(label: np.ndarray,
7576 vol_semantic = []
7677 for i in range (label .shape [0 ]):
7778 label_img = label [i ].copy ()
78- distance , semantic = distance_transform (label_img , padding = padding )
79+ distance , semantic = distance_transform (label_img , padding = padding , erosion = erosion )
7980 vol_distance .append (distance )
8081 vol_semantic .append (semantic )
8182
@@ -120,7 +121,8 @@ def distance_transform(label: np.ndarray,
120121 bg_value : float = - 1.0 ,
121122 relabel : bool = True ,
122123 padding : bool = False ,
123- resolution : Tuple [float ] = (1.0 , 1.0 )):
124+ resolution : Tuple [float ] = (1.0 , 1.0 ),
125+ erosion : int = 0 ):
124126 """Euclidean distance transform (DT or EDT) for instance masks.
125127 """
126128 eps = 1e-6
@@ -148,9 +150,12 @@ def distance_transform(label: np.ndarray,
148150 all_bg_sample = True
149151
150152 if not all_bg_sample :
153+ if erosion > 0 :
154+ erosion_disk = disk (erosion )
151155 for idx in indices :
152- temp1 = label .copy () == idx
153- temp2 = remove_small_holes (temp1 , 16 , connectivity = 1 )
156+ temp2 = remove_small_holes (label == idx , 16 , connectivity = 1 )
157+ if erosion > 0 :
158+ temp2 = binary_erosion (temp2 , erosion_disk )
154159
155160 semantic += temp2 .astype (np .uint8 )
156161 boundary_edt = distance_transform_edt (temp2 , resolution )
@@ -222,8 +227,7 @@ def skeleton_aware_distance_transform(
222227
223228 if not all_bg_sample :
224229 for idx in indices :
225- temp1 = label .copy () == idx
226- temp2 = remove_small_holes (temp1 , 16 , connectivity = 1 )
230+ temp2 = remove_small_holes (label == idx , 16 , connectivity = 1 )
227231 binary = temp2 .copy ()
228232
229233 if smooth :
0 commit comments