Skip to content

Commit 066f1c2

Browse files
author
Donglai Wei
committed
add 3d axon seg
1 parent 8812a17 commit 066f1c2

File tree

13 files changed

+755
-329
lines changed

13 files changed

+755
-329
lines changed

CHECKERBOARD_FIX.md

Lines changed: 0 additions & 247 deletions
This file was deleted.

connectomics/config/hydra_config.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class ModelConfig:
147147
kernel_size: int = 3 # Convolution kernel size
148148
strides: Optional[List[int]] = None # Downsampling strides (e.g., [2, 2, 2, 2] for 4 levels)
149149
act: str = "relu" # Activation function: 'relu', 'prelu', 'elu', etc.
150+
upsample: str = "deconv" # Upsampling mode for MONAI BasicUNet: 'deconv' (transposed conv), 'nontrainable' (interpolation + conv), or 'pixelshuffle'
150151

151152
# Transformer-specific (UNETR, etc.)
152153
feature_size: int = 16
@@ -926,17 +927,15 @@ class BinaryPostprocessingConfig:
926927
"""Binary segmentation postprocessing configuration.
927928
928929
Applies morphological operations and connected components filtering to binary predictions.
930+
Input should already be binary (from binary_thresholding decoding).
929931
Pipeline order:
930-
1. Threshold predictions to binary mask (using threshold_range if provided)
932+
1. Ensure input is binary (auto-threshold if needed: 0.5 for [0,1], 0 for >1)
931933
2. Apply morphological opening (erosion + dilation)
932934
3. Extract connected components
933935
4. Keep top-k largest components
934936
"""
935937

936938
enabled: bool = False # Enable binary postprocessing pipeline
937-
threshold_range: Optional[List[float]] = (
938-
None # Threshold range [min, max] for binarization. If None, uses 0.5 for [0,1] or 0 for >1
939-
)
940939
median_filter_size: Optional[Tuple[int, ...]] = (
941940
None # Median filter kernel size (e.g., (3, 3) for 2D)
942941
)
@@ -960,16 +959,16 @@ class PostprocessingConfig:
960959
"""Postprocessing configuration for inference output.
961960
962961
Controls how predictions are transformed before saving:
963-
- Thresholding: Binarize predictions using threshold_range
962+
- Binary refinement: Morphological operations and connected components filtering
964963
- Scaling: Multiply intensity values (e.g., 255 for uint8)
965964
- Dtype conversion: Convert to target data type with proper clamping
966965
- Transpose: Reorder axes (e.g., [2,1,0] for zyx->xyz)
967966
"""
968967

969-
# Thresholding configuration
968+
# Binary segmentation refinement (morphological ops, connected components)
970969
binary: Optional[Dict[str, Any]] = field(
971970
default_factory=dict
972-
) # Binary thresholding config (e.g., {'threshold_range': [0.5, 1.0]})
971+
) # Binary postprocessing config (e.g., {'opening_iterations': 2})
973972

974973
# Intensity scaling
975974
intensity_scale: Optional[float] = (

connectomics/config/hydra_utils.py

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,43 @@ def _combine_path(base_path: str, file_path: Optional[Union[str, List[str]]]) ->
291291
if base_path and not os.path.isabs(file_path):
292292
file_path = os.path.join(base_path, file_path)
293293

294-
# Expand glob patterns
295-
if "*" in file_path or "?" in file_path:
294+
# Expand glob patterns with optional selector support
295+
# Format: path/*.tiff[0] or path/*.tiff[filename]
296+
import re
297+
selector_match = re.match(r'^(.+)\[(.+)\]$', file_path)
298+
299+
if selector_match:
300+
# Has selector - extract glob pattern and selector
301+
glob_pattern = selector_match.group(1)
302+
selector = selector_match.group(2)
303+
304+
expanded = sorted(glob(glob_pattern))
305+
if not expanded:
306+
return file_path # No matches - return original
307+
308+
# Select file based on selector
309+
try:
310+
# Try numeric index
311+
index = int(selector)
312+
if index < -len(expanded) or index >= len(expanded):
313+
print(f"Warning: Index {index} out of range for {len(expanded)} files, using first")
314+
return expanded[0]
315+
return expanded[index]
316+
except ValueError:
317+
# Not a number, try filename match
318+
from pathlib import Path
319+
matching = [f for f in expanded if Path(f).name == selector or Path(f).stem == selector]
320+
if not matching:
321+
# Try partial match
322+
matching = [f for f in expanded if selector in Path(f).name]
323+
if matching:
324+
return matching[0]
325+
else:
326+
print(f"Warning: No file matches selector '{selector}', using first of {len(expanded)} files")
327+
return expanded[0]
328+
329+
elif "*" in file_path or "?" in file_path:
330+
# Standard glob without selector
296331
expanded = sorted(glob(file_path))
297332
if expanded:
298333
return expanded
@@ -302,32 +337,34 @@ def _combine_path(base_path: str, file_path: Optional[Union[str, List[str]]]) ->
302337

303338
return file_path
304339

305-
# Resolve training paths
306-
if cfg.data.train_path:
307-
cfg.data.train_image = _combine_path(cfg.data.train_path, cfg.data.train_image)
308-
cfg.data.train_label = _combine_path(cfg.data.train_path, cfg.data.train_label)
309-
cfg.data.train_mask = _combine_path(cfg.data.train_path, cfg.data.train_mask)
310-
cfg.data.train_json = _combine_path(cfg.data.train_path, cfg.data.train_json)
311-
312-
# Resolve validation paths
313-
if cfg.data.val_path:
314-
cfg.data.val_image = _combine_path(cfg.data.val_path, cfg.data.val_image)
315-
cfg.data.val_label = _combine_path(cfg.data.val_path, cfg.data.val_label)
316-
cfg.data.val_mask = _combine_path(cfg.data.val_path, cfg.data.val_mask)
317-
cfg.data.val_json = _combine_path(cfg.data.val_path, cfg.data.val_json)
318-
319-
# Resolve test paths
320-
if cfg.data.test_path:
321-
cfg.data.test_image = _combine_path(cfg.data.test_path, cfg.data.test_image)
322-
cfg.data.test_label = _combine_path(cfg.data.test_path, cfg.data.test_label)
323-
cfg.data.test_mask = _combine_path(cfg.data.test_path, cfg.data.test_mask)
324-
cfg.data.test_json = _combine_path(cfg.data.test_path, cfg.data.test_json)
340+
# Resolve training paths (always expand globs, use train_path as base if available)
341+
train_base = cfg.data.train_path if cfg.data.train_path else ""
342+
cfg.data.train_image = _combine_path(train_base, cfg.data.train_image)
343+
cfg.data.train_label = _combine_path(train_base, cfg.data.train_label)
344+
cfg.data.train_mask = _combine_path(train_base, cfg.data.train_mask)
345+
cfg.data.train_json = _combine_path(train_base, cfg.data.train_json)
346+
347+
# Resolve validation paths (always expand globs, use val_path as base if available)
348+
val_base = cfg.data.val_path if cfg.data.val_path else ""
349+
cfg.data.val_image = _combine_path(val_base, cfg.data.val_image)
350+
cfg.data.val_label = _combine_path(val_base, cfg.data.val_label)
351+
cfg.data.val_mask = _combine_path(val_base, cfg.data.val_mask)
352+
cfg.data.val_json = _combine_path(val_base, cfg.data.val_json)
353+
354+
# Resolve test paths (always expand globs, use test_path as base if available)
355+
test_base = cfg.data.test_path if cfg.data.test_path else ""
356+
cfg.data.test_image = _combine_path(test_base, cfg.data.test_image)
357+
cfg.data.test_label = _combine_path(test_base, cfg.data.test_label)
358+
cfg.data.test_mask = _combine_path(test_base, cfg.data.test_mask)
359+
cfg.data.test_json = _combine_path(test_base, cfg.data.test_json)
325360

326361
# Resolve inference data paths (primary location for test_path)
362+
inference_test_base = ""
327363
if hasattr(cfg.inference.data, 'test_path') and cfg.inference.data.test_path:
328-
cfg.inference.data.test_image = _combine_path(cfg.inference.data.test_path, cfg.inference.data.test_image)
329-
cfg.inference.data.test_label = _combine_path(cfg.inference.data.test_path, cfg.inference.data.test_label)
330-
cfg.inference.data.test_mask = _combine_path(cfg.inference.data.test_path, cfg.inference.data.test_mask)
364+
inference_test_base = cfg.inference.data.test_path
365+
cfg.inference.data.test_image = _combine_path(inference_test_base, cfg.inference.data.test_image)
366+
cfg.inference.data.test_label = _combine_path(inference_test_base, cfg.inference.data.test_label)
367+
cfg.inference.data.test_mask = _combine_path(inference_test_base, cfg.inference.data.test_mask)
331368

332369
return cfg
333370

0 commit comments

Comments
 (0)