Skip to content

Commit 16f5399

Browse files
author
Donglai Wei
committed
add optuna parameter search
1 parent 07657e2 commit 16f5399

29 files changed

+4146
-168
lines changed

OPTUNA_INTEGRATION.md

Lines changed: 466 additions & 0 deletions
Large diffs are not rendered by default.

QUICKSTART.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ The Lucchi++ dataset contains mitochondria segmentation data from EM images.
9999
```bash
100100
# Download from HuggingFace (recommended)
101101
mkdir -p datasets/
102-
wget https://huggingface.co/datasets/pytc/tutorial/resolve/main/Lucchi%2B%2B.zip
103-
unzip Lucchi++.zip -d datasets/
104-
rm Lucchi++.zip
102+
wget https://huggingface.co/datasets/pytc/tutorial/resolve/main/lucchi%2B%2B.zip
103+
unzip lucchi++.zip -d datasets/
104+
rm lucchi++.zip
105105
```
106106

107107
**Size:** ~100 MB

TROUBLESHOOTING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,8 @@ data:
321321
**Solution - Manual download:**
322322
```bash
323323
# Download from HuggingFace
324-
wget https://huggingface.co/datasets/pytc/tutorial/resolve/main/Lucchi%2B%2B.zip
325-
unzip Lucchi++.zip -d datasets/
324+
wget https://huggingface.co/datasets/pytc/tutorial/resolve/main/lucchi%2B%2B.zip
325+
unzip lucchi++.zip -d datasets/
326326

327327
# Or use git-lfs
328328
git lfs install

connectomics/decoding/segmentation.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,20 @@ def decode_binary_thresholding(
140140

141141
# Apply morphological opening (erosion + dilation) - removes small objects
142142
if opening_iterations > 0:
143-
binary_mask = ndimage.binary_opening(
144-
binary_mask, iterations=opening_iterations
145-
).astype(np.uint8)
143+
binary_mask = ndimage.binary_opening(binary_mask, iterations=opening_iterations).astype(
144+
np.uint8
145+
)
146146

147147
# Apply morphological closing (dilation + erosion) - fills small holes
148148
if closing_iterations > 0:
149-
binary_mask = ndimage.binary_closing(
150-
binary_mask, iterations=closing_iterations
151-
).astype(np.uint8)
149+
binary_mask = ndimage.binary_closing(binary_mask, iterations=closing_iterations).astype(
150+
np.uint8
151+
)
152152

153153
# Apply connected components filtering
154-
if connected_components and connected_components.get('enabled', False):
155-
remove_small = connected_components.get('remove_small', 0)
156-
connectivity = connected_components.get('connectivity', 26)
154+
if connected_components and connected_components.get("enabled", False):
155+
remove_small = connected_components.get("remove_small", 0)
156+
connectivity = connected_components.get("connectivity", 26)
157157

158158
if remove_small > 0:
159159
# Label connected components
@@ -359,7 +359,7 @@ def decode_binary_contour_distance_watershed(
359359
min_seed_size: int = 32,
360360
return_seed: bool = False,
361361
precomputed_seed: Optional[np.ndarray] = None,
362-
prediction_scale: int = 255,
362+
prediction_scale: int = 1,
363363
binary_channels: Optional[List[int]] = None,
364364
contour_channels: Optional[List[int]] = None,
365365
distance_channels: Optional[List[int]] = None,
@@ -456,15 +456,19 @@ def decode_binary_contour_distance_watershed(
456456
else:
457457
# Position-based fallback (legacy behavior)
458458
if use_contour:
459-
assert predictions.shape[0] >= 3, f"Expected at least 3 channels (binary, contour, distance), got {predictions.shape[0]}"
459+
assert (
460+
predictions.shape[0] >= 3
461+
), f"Expected at least 3 channels (binary, contour, distance), got {predictions.shape[0]}"
460462
# If more than 3 channels, first N-2 channels are binary (average them)
461463
if predictions.shape[0] > 3:
462464
binary = predictions[:-2].mean(axis=0)
463465
contour, distance = predictions[-2], predictions[-1]
464466
else:
465467
binary, contour, distance = predictions[0], predictions[1], predictions[2]
466468
else:
467-
assert predictions.shape[0] >= 2, f"Expected at least 2 channels (binary, distance) when contour disabled, got {predictions.shape[0]}"
469+
assert (
470+
predictions.shape[0] >= 2
471+
), f"Expected at least 2 channels (binary, distance) when contour disabled, got {predictions.shape[0]}"
468472
# If more than 2 channels, first N-1 channels are binary (average them)
469473
if predictions.shape[0] > 2:
470474
binary = predictions[:-1].mean(axis=0)
@@ -476,10 +480,19 @@ def decode_binary_contour_distance_watershed(
476480
# Convert thresholds based on prediction scale
477481
if prediction_scale == 255:
478482
distance = (distance / prediction_scale) * 2.0 - 1.0
479-
binary_threshold = (binary_threshold[0] * prediction_scale, binary_threshold[1] * prediction_scale)
483+
binary_threshold = (
484+
binary_threshold[0] * prediction_scale,
485+
binary_threshold[1] * prediction_scale,
486+
)
480487
if use_contour:
481-
contour_threshold = (contour_threshold[0] * prediction_scale, contour_threshold[1] * prediction_scale)
482-
distance_threshold = (distance_threshold[0] * prediction_scale, distance_threshold[1] * prediction_scale)
488+
contour_threshold = (
489+
contour_threshold[0] * prediction_scale,
490+
contour_threshold[1] * prediction_scale,
491+
)
492+
distance_threshold = (
493+
distance_threshold[0] * prediction_scale,
494+
distance_threshold[1] * prediction_scale,
495+
)
483496

484497
if precomputed_seed is not None:
485498
seed = precomputed_seed
@@ -492,10 +505,7 @@ def decode_binary_contour_distance_watershed(
492505
)
493506
else:
494507
# No contour constraint - only binary and distance
495-
seed_map = (
496-
(binary > binary_threshold[0])
497-
* (distance > distance_threshold[0])
498-
)
508+
seed_map = (binary > binary_threshold[0]) * (distance > distance_threshold[0])
499509
seed = cc3d.connected_components(seed_map)
500510
seed = remove_small_objects(seed, min_seed_size)
501511

@@ -507,10 +517,7 @@ def decode_binary_contour_distance_watershed(
507517
)
508518
else:
509519
# No contour constraint - only binary and distance
510-
foreground = (
511-
(binary > binary_threshold[1])
512-
* (distance > distance_threshold[1])
513-
)
520+
foreground = (binary > binary_threshold[1]) * (distance > distance_threshold[1])
514521

515522
segmentation = mahotas.cwatershed(-distance.astype(np.float64), seed)
516523
segmentation[~foreground] = (

connectomics/models/loss/build.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,51 +64,32 @@ def create_loss(
6464
loss_registry = {
6565
# MONAI Dice variants
6666
'DiceLoss': DiceLoss,
67-
'Dice': DiceLoss, # Alias
6867
'DiceCELoss': DiceCELoss,
69-
'DiceCE': DiceCELoss, # Alias
7068
'DiceFocalLoss': DiceFocalLoss,
71-
'DiceFocal': DiceFocalLoss, # Alias
7269
'GeneralizedDiceLoss': GeneralizedDiceLoss,
73-
'GDiceLoss': GeneralizedDiceLoss, # Alias
7470

7571
# MONAI other losses
7672
'FocalLoss': FocalLoss,
77-
'Focal': FocalLoss, # Alias
7873
'TverskyLoss': TverskyLoss,
79-
'Tversky': TverskyLoss, # Alias
8074

8175
# PyTorch standard losses (for convenience)
8276
'BCEWithLogitsLoss': nn.BCEWithLogitsLoss,
83-
'BCE': nn.BCEWithLogitsLoss, # Alias
8477
'CrossEntropyLoss': CrossEntropyLossWrapper, # Use wrapper for shape handling
85-
'CE': CrossEntropyLossWrapper, # Alias
8678
'MSELoss': nn.MSELoss,
87-
'MSE': nn.MSELoss, # Alias
8879
'L1Loss': nn.L1Loss,
89-
'L1': nn.L1Loss, # Alias
9080

9181
# Custom connectomics losses
9282
'WeightedBCEWithLogitsLoss': WeightedBCEWithLogitsLoss,
93-
'WeightedBCE': WeightedBCEWithLogitsLoss, # Alias
9483
'WeightedMSELoss': WeightedMSELoss,
95-
'WeightedMSE': WeightedMSELoss, # Alias
9684
'WeightedMAELoss': WeightedMAELoss,
97-
'WeightedMAE': WeightedMAELoss, # Alias
9885
'GANLoss': GANLoss,
99-
'GAN': GANLoss, # Alias
10086

10187
# Regularization losses
10288
'BinaryRegularization': BinaryRegularization,
103-
'BinaryReg': BinaryRegularization, # Alias
10489
'ForegroundDistanceConsistency': ForegroundDistanceConsistency,
105-
'FgDTConsistency': ForegroundDistanceConsistency, # Alias
10690
'ContourDistanceConsistency': ContourDistanceConsistency,
107-
'ContourDTConsistency': ContourDistanceConsistency, # Alias
10891
'ForegroundContourConsistency': ForegroundContourConsistency,
109-
'FgContourConsistency': ForegroundContourConsistency, # Alias
11092
'NonOverlapRegularization': NonOverlapRegularization,
111-
'NonoverlapReg': NonOverlapRegularization, # Alias
11293
}
11394

11495
if loss_name not in loss_registry:

connectomics/utils/download.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,26 @@
1515
DATASETS = {
1616
"lucchi": {
1717
"name": "Lucchi++ Mitochondria Segmentation",
18-
"url": "https://huggingface.co/datasets/pytc/tutorial/resolve/main/Lucchi%2B%2B.zip",
18+
"url": "https://huggingface.co/datasets/pytc/tutorial/resolve/main/lucchi%2B%2B.zip",
1919
"size_mb": 100,
2020
"description": "EM images with mitochondria annotations from Lucchi et al.",
2121
"files": [
22-
"datasets/Lucchi++/train_image.h5",
23-
"datasets/Lucchi++/train_label.h5",
24-
"datasets/Lucchi++/test_image.h5",
25-
"datasets/Lucchi++/test_label.h5",
22+
"datasets/lucchi++/train_image.h5",
23+
"datasets/lucchi++/train_label.h5",
24+
"datasets/lucchi++/test_image.h5",
25+
"datasets/lucchi++/test_label.h5",
2626
],
2727
},
2828
"lucchi++": { # Alias
2929
"name": "Lucchi++ Mitochondria Segmentation",
30-
"url": "https://huggingface.co/datasets/pytc/tutorial/resolve/main/Lucchi%2B%2B.zip",
30+
"url": "https://huggingface.co/datasets/pytc/tutorial/resolve/main/lucchi%2B%2B.zip",
3131
"size_mb": 100,
3232
"description": "EM images with mitochondria annotations from Lucchi et al.",
3333
"files": [
34-
"datasets/Lucchi++/train_image.h5",
35-
"datasets/Lucchi++/train_label.h5",
36-
"datasets/Lucchi++/test_image.h5",
37-
"datasets/Lucchi++/test_label.h5",
34+
"datasets/lucchi++/train_image.h5",
35+
"datasets/lucchi++/train_label.h5",
36+
"datasets/lucchi++/test_image.h5",
37+
"datasets/lucchi++/test_label.h5",
3838
],
3939
},
4040
}

justfile

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,26 @@ test dataset arch_or_ckpt ckpt_or_args='' *ARGS='':
5656
python scripts/main.py --config tutorials/{{dataset}}.yaml model.architecture={{arch_or_ckpt}} --mode test --checkpoint {{ckpt_or_args}} {{ARGS}}
5757
fi
5858

59+
# Tune decoding parameters on validation set (e.g., just tune hydra-lv ckpt.pt)
60+
tune dataset ckpt *ARGS='':
61+
python scripts/main.py --config tutorials/{{dataset}}.yaml --mode tune --checkpoint {{ckpt}} {{ARGS}}
62+
63+
# Tune parameters then test (recommended for optimal results)
64+
tune-test dataset ckpt *ARGS='':
65+
python scripts/main.py --config tutorials/{{dataset}}.yaml --mode tune+test --checkpoint {{ckpt}} {{ARGS}}
66+
67+
# Quick tuning with 20 trials (for testing)
68+
tune-quick dataset ckpt *ARGS='':
69+
python scripts/main.py --config tutorials/{{dataset}}.yaml --mode tune --checkpoint {{ckpt}} --tune-trials 20 {{ARGS}}
70+
71+
# Test with specific parameter file
72+
test-with-params dataset ckpt params *ARGS='':
73+
python scripts/main.py --config tutorials/{{dataset}}.yaml --mode test --checkpoint {{ckpt}} --params {{params}} {{ARGS}}
74+
75+
# Inference (alias for test, clearer naming)
76+
infer dataset ckpt *ARGS='':
77+
python scripts/main.py --config tutorials/{{dataset}}.yaml --mode infer --checkpoint {{ckpt}} {{ARGS}}
78+
5979
# ============================================================================
6080
# Monitoring Commands
6181
# ============================================================================

0 commit comments

Comments
 (0)