|
5 | 5 | import scipy |
6 | 6 | import numpy as np |
7 | 7 | from scipy.ndimage import distance_transform_edt |
8 | | -from skimage.morphology import remove_small_holes, skeletonize, binary_erosion, disk |
| 8 | +from skimage.morphology import remove_small_holes, skeletonize, binary_erosion, disk, ball |
9 | 9 | from skimage.measure import label as label_cc # avoid namespace conflict |
10 | 10 | from skimage.filters import gaussian |
11 | 11 |
|
@@ -61,7 +61,7 @@ def edt_instance(label: np.ndarray, |
61 | 61 | mode: str = '2d', |
62 | 62 | quantize: bool = True, |
63 | 63 | resolution: Tuple[float] = (1.0, 1.0, 1.0), |
64 | | - padding: bool = False), |
| 64 | + padding: bool = False, |
65 | 65 | erosion: int = 0): |
66 | 66 | assert mode in ['2d', '3d'] |
67 | 67 | if mode == '3d': |
@@ -151,11 +151,14 @@ def distance_transform(label: np.ndarray, |
151 | 151 |
|
152 | 152 | if not all_bg_sample: |
153 | 153 | if erosion > 0: |
154 | | - erosion_disk = disk(erosion) |
| 154 | + if label.ndim == 2: |
| 155 | + footprint = disk(erosion) |
| 156 | + elif label.ndim == 3: |
| 157 | + footprint = ball(erosion) |
155 | 158 | for idx in indices: |
156 | 159 | temp2 = remove_small_holes(label == idx, 16, connectivity=1) |
157 | 160 | if erosion > 0: |
158 | | - temp2 = binary_erosion(temp2, erosion_disk) |
| 161 | + temp2 = binary_erosion(temp2, footprint) |
159 | 162 |
|
160 | 163 | semantic += temp2.astype(np.uint8) |
161 | 164 | boundary_edt = distance_transform_edt(temp2, resolution) |
|
0 commit comments