@@ -257,14 +257,14 @@ def apply_binary_postprocessing(
257257 """Apply binary segmentation postprocessing pipeline.
258258
259259 Pipeline order:
260- 1. Threshold predictions to binary mask using threshold_range
260+ 1. Ensure input is binary (convert if needed)
261261 2. Apply median filter (optional)
262262 3. Apply morphological opening (erosion + dilation)
263263 4. Apply morphological closing (dilation + erosion)
264264 5. Extract connected components and filter by size/keep top-k
265265
266266 Args:
267- pred (numpy.ndarray): Predicted foreground probability in range [0, 1].
267+ pred (numpy.ndarray): Binary mask (values 0 or 1) or predicted probabilities in range [0, 1].
268268 Shape can be 2D (H, W) or 3D (D, H, W).
269269 config (BinaryPostprocessingConfig): Configuration for postprocessing pipeline.
270270
@@ -276,21 +276,30 @@ def apply_binary_postprocessing(
276276 >>> from connectomics.config import BinaryPostprocessingConfig, ConnectedComponentsConfig
277277 >>> config = BinaryPostprocessingConfig(
278278 ... enabled=True,
279- ... threshold_range=(0.8, 1.0),
280279 ... opening_iterations=2,
281280 ... connected_components=ConnectedComponentsConfig(top_k=1)
282281 ... )
283282 >>> pred = np.random.rand(128, 128) # Random probabilities
284283 >>> binary_mask = apply_binary_postprocessing(pred, config)
285284 """
286285 if not config or not config .enabled :
287- # Just threshold at 0.5 if postprocessing is disabled
288- return (pred > 0.5 ).astype (np .uint8 )
289-
290- # Step 1: Threshold to binary using threshold_range
291- # Use the minimum threshold from the range
292- threshold = config .threshold_range [0 ]
293- binary = (pred > threshold ).astype (np .uint8 )
286+ # If no postprocessing, ensure binary output
287+ if pred .max () <= 1 :
288+ return (pred > 0.5 ).astype (np .uint8 )
289+ else :
290+ return (pred > 0 ).astype (np .uint8 )
291+
292+ # Step 1: Ensure input is binary
293+ # Check if input is already binary (0/1) or needs thresholding
294+ if np .all ((pred == 0 ) | (pred == 1 )):
295+ # Already binary
296+ binary = pred .astype (np .uint8 )
297+ elif pred .max () <= 1.0 :
298+ # Probability values in [0, 1], threshold at 0.5
299+ binary = (pred > 0.5 ).astype (np .uint8 )
300+ else :
301+ # Assume already thresholded but not scaled to 0/1
302+ binary = (pred > 0 ).astype (np .uint8 )
294303
295304 # Step 2: Apply median filter (optional noise reduction)
296305 if config .median_filter_size is not None :
0 commit comments