66from skimage .morphology import erosion , dilation , disk
77from skimage .measure import label as label_cc # avoid namespace conflict
88from scipy .signal import convolve2d
9+ from scipy .ndimage import maximum_filter
910
1011from .data_affinity import *
1112from .data_transform import *
@@ -333,6 +334,7 @@ def seg_to_targets(
333334 out [tid ] = seg_to_instance_bd (label , bd_sz , do_bg )[
334335 None , :].astype (np .float32 )
335336 elif topt [0 ] == '5' : # distance transform (instance)
337+ # 5-3d-1-0-2.0-5
336338 distance = seg2inst_edt (label , topt )
337339 out [tid ] = distance [np .newaxis , :].astype (np .float32 )
338340 elif topt [0 ] == '6' : # distance transform (semantic)
@@ -351,6 +353,15 @@ def seg_to_targets(
351353 out [tid ] = np .concatenate ((diffgrads ,bin_mask ), axis = 0 )
352354 else :
353355 out [tid ] = seg2diffgrads (label )
356+ elif topt [0 ] == '8' : # skeleton prediction
357+ # 8-5-3d-1-0-1.0-1
358+ index = topt [2 :].find ('-' )
359+ # dilate the skeleton
360+ dilation = int (topt [2 :2 + index ])
361+ if int (dilation ) > 0 :
362+ label = maximum_filter (label , dilation )
363+ distance = seg2inst_edt (label , topt [2 + index + 1 :])
364+ out [tid ] = distance [np .newaxis , :].astype (np .float32 )
354365 elif topt [0 ] == '9' : # generic semantic segmentation
355366 out [tid ] = label .astype (np .int64 )
356367 else :
0 commit comments