Skip to content

Commit b31c68c

Browse files
author
donglaiw
committed
add skeeleton mag prediction
1 parent 5ea8479 commit b31c68c

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

connectomics/data/utils/data_segmentation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from skimage.morphology import erosion, dilation, disk
77
from skimage.measure import label as label_cc # avoid namespace conflict
88
from scipy.signal import convolve2d
9+
from scipy.ndimage import maximum_filter
910

1011
from .data_affinity import *
1112
from .data_transform import *
@@ -333,6 +334,7 @@ def seg_to_targets(
333334
out[tid] = seg_to_instance_bd(label, bd_sz, do_bg)[
334335
None, :].astype(np.float32)
335336
elif topt[0] == '5': # distance transform (instance)
337+
# 5-3d-1-0-2.0-5
336338
distance = seg2inst_edt(label, topt)
337339
out[tid] = distance[np.newaxis, :].astype(np.float32)
338340
elif topt[0] == '6': # distance transform (semantic)
@@ -351,6 +353,15 @@ def seg_to_targets(
351353
out[tid] = np.concatenate((diffgrads,bin_mask), axis=0)
352354
else:
353355
out[tid] = seg2diffgrads(label)
356+
elif topt[0] == '8': # skeleton prediction
357+
# 8-5-3d-1-0-1.0-1
358+
index = topt[2:].find('-')
359+
# dilate the skeleton
360+
dilation = int(topt[2:2+index])
361+
if int(dilation) > 0 :
362+
label = maximum_filter(label, dilation)
363+
distance = seg2inst_edt(label, topt[2+index+1:])
364+
out[tid] = distance[np.newaxis, :].astype(np.float32)
354365
elif topt[0] == '9': # generic semantic segmentation
355366
out[tid] = label.astype(np.int64)
356367
else:

connectomics/utils/visualizer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ def visualize(self, volume, label, output, weight, iter_total, writer,
6262
output[idx] = self.get_semantic_map(output[idx], topt)
6363
label[idx] = self.get_semantic_map(label[idx], topt, argmax=False)
6464

65-
if topt[0] == '5': # distance transform
65+
if topt[0] in ['5','8']: # distance transform
66+
if topt[0] == '8':
67+
index = topt[2:].find('-')
68+
topt = topt[2+index+1:]
6669
if len(topt) == 1:
6770
topt = topt + '-2d-0-0-5.0-0' # default
6871
_, mode, padding, quant, z_res, erosion = topt.split('-')

0 commit comments

Comments
 (0)