Skip to content

Commit 124f603

Browse files
author
Donglai Wei
committed
fix zebrafish_neurons.yaml: need multiple loss functions for each channel
1 parent 65f7a12 commit 124f603

File tree

6 files changed

+587
-84
lines changed

6 files changed

+587
-84
lines changed

connectomics/config/hydra_config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,8 @@ class SavePredictionConfig:
911911
"""
912912

913913
enabled: bool = True # Enable saving intermediate predictions
914-
intensity_scale: float = 255.0 # Scale predictions to [0, 255] for saving
915-
intensity_dtype: str = "uint8" # Save as uint8 for visualization
914+
intensity_scale: float = -1.0 # If < 0, keep raw predictions (no normalization/scaling). If > 0, normalize to [0,1] then scale.
915+
intensity_dtype: str = "uint8" # Save as uint8 for visualization (ignored if intensity_scale < 0)
916916

917917

918918
@dataclass
@@ -999,6 +999,8 @@ class PostprocessingConfig:
999999
Note: Intensity scaling and dtype conversion are handled by SavePredictionConfig.
10001000
"""
10011001

1002+
enabled: bool = False # Enable postprocessing pipeline
1003+
10021004
# Binary segmentation refinement (morphological ops, connected components)
10031005
binary: Optional[Dict[str, Any]] = field(
10041006
default_factory=dict
@@ -1059,7 +1061,7 @@ class TestDataConfig:
10591061
test_image: Optional[str] = None
10601062
test_label: Optional[str] = None
10611063
test_mask: Optional[str] = None
1062-
test_resolution: Optional[List[int]] = None
1064+
test_resolution: Optional[List[float]] = None
10631065
test_transpose: Optional[List[int]] = None
10641066
output_path: Optional[str] = None
10651067
cache_suffix: str = "_prediction.h5"

connectomics/decoding/segmentation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from skimage.morphology import dilation, remove_small_objects
2020
from scipy.ndimage import zoom
2121
import mahotas
22+
from connectomics.data.process.target import seg_to_semantic_edt
2223

2324
try:
2425
from numba import jit
@@ -175,10 +176,7 @@ def decode_instance_binary_contour_distance(
175176
elif mode == "watershed":
176177
# Watershed mode requires distance channel
177178
if distance is None:
178-
raise ValueError(
179-
"Watershed mode requires distance channel. "
180-
"Please specify distance_channels in your decode configuration."
181-
)
179+
distance = seg_to_semantic_edt(foreground, mode="3d")
182180
# step 2: compute the instance seeds
183181
if precomputed_seed is not None:
184182
seed = precomputed_seed

connectomics/inference/io.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,59 @@ def apply_save_prediction_transform(cfg: Config | DictConfig, data: np.ndarray)
2222
2323
This is used when saving intermediate predictions (before decoding).
2424
25+
Default behavior (no config):
26+
- Normalizes predictions to [0, 1] using min-max normalization
27+
- Keeps dtype as float32
28+
29+
Config options:
30+
- intensity_scale: If < 0, disables normalization (raw values)
31+
If > 0, normalize to [0, 1] then multiply by scale
32+
- intensity_dtype: Target dtype for conversion (uint8, float32, etc.)
33+
2534
Args:
2635
cfg: Configuration object
2736
data: Predictions array to transform
2837
2938
Returns:
3039
Transformed predictions with applied scaling and dtype conversion
3140
"""
32-
if not hasattr(cfg, "inference") or not hasattr(cfg.inference, "save_prediction"):
33-
return data
41+
# Default: keep raw predictions if no config
42+
intensity_scale = -1.0 # Default: keep raw predictions
43+
44+
if hasattr(cfg, "inference") and hasattr(cfg.inference, "save_prediction"):
45+
save_pred_cfg = cfg.inference.save_prediction
46+
intensity_scale = getattr(save_pred_cfg, "intensity_scale", -1.0)
47+
48+
# Apply intensity scaling (if intensity_scale >= 0, normalize to [0, 1] then scale)
49+
if intensity_scale >= 0:
50+
# Convert to float32 for normalization
51+
data = data.astype(np.float32)
52+
53+
# Min-max normalization to [0, 1]
54+
data_min = data.min()
55+
data_max = data.max()
3456

35-
save_pred_cfg = cfg.inference.save_prediction
57+
if data_max > data_min:
58+
data = (data - data_min) / (data_max - data_min)
59+
print(f" Normalized predictions to [0, 1] (min={data_min:.4f}, max={data_max:.4f})")
60+
else:
61+
print(f" Warning: data_min == data_max ({data_min:.4f}), skipping normalization")
3662

37-
# Apply intensity scaling
38-
intensity_scale = getattr(save_pred_cfg, "intensity_scale", None)
39-
if intensity_scale is not None and intensity_scale != 1.0:
40-
data = data * float(intensity_scale)
63+
# Apply scaling factor
64+
if intensity_scale != 1.0:
65+
data = data * float(intensity_scale)
66+
print(f" Scaled predictions by {intensity_scale} -> range [{data.min():.4f}, {data.max():.4f}]")
67+
else:
68+
print(f" Intensity scaling disabled (scale={intensity_scale} < 0), keeping raw predictions")
69+
# Skip dtype conversion when intensity_scale < 0 to preserve raw predictions
70+
return data
4171

4272
# Apply dtype conversion
43-
target_dtype_str = getattr(save_pred_cfg, "intensity_dtype", None)
73+
target_dtype_str = None
74+
if hasattr(cfg, "inference") and hasattr(cfg.inference, "save_prediction"):
75+
save_pred_cfg = cfg.inference.save_prediction
76+
target_dtype_str = getattr(save_pred_cfg, "intensity_dtype", None)
77+
4478
if target_dtype_str is not None:
4579
dtype_map = {
4680
"uint8": np.uint8,
@@ -68,6 +102,7 @@ def apply_save_prediction_transform(cfg: Config | DictConfig, data: np.ndarray)
68102
if np.issubdtype(target_dtype, np.integer):
69103
info = np.iinfo(target_dtype)
70104
data = np.clip(data, info.min, info.max)
105+
print(f" Converting to {target_dtype_str} (clipped to [{info.min}, {info.max}])")
71106

72107
data = data.astype(target_dtype)
73108

@@ -89,6 +124,10 @@ def apply_postprocessing(cfg: Config | DictConfig, data: np.ndarray) -> np.ndarr
89124

90125
postprocessing = cfg.inference.postprocessing
91126

127+
# Check if postprocessing is enabled
128+
if not getattr(postprocessing, "enabled", False):
129+
return data
130+
92131
binary_config = getattr(postprocessing, "binary", None)
93132
if binary_config is not None and getattr(binary_config, "enabled", False):
94133
from connectomics.decoding.postprocess import apply_binary_postprocessing

scripts/main.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -144,41 +144,58 @@ def main():
144144
# Setup run directory (handles DDP coordination and config saving)
145145
# Determine output base directory from checkpoint for test/tune modes
146146
if args.mode in ["test", "tune", "tune-test"] and args.checkpoint:
147-
# Extract base directory from checkpoint path
148-
output_base = get_output_base_from_checkpoint(args.checkpoint)
149-
output_base.mkdir(parents=True, exist_ok=True)
147+
# Check if test.data.output_path is already set (skip checkpoint-based path extraction)
148+
test_output_path_set = (
149+
hasattr(cfg, "test")
150+
and hasattr(cfg.test, "data")
151+
and getattr(cfg.test.data, "output_path", None)
152+
)
150153

151-
# Create mode-specific subdirectories
152-
if args.mode in ["tune", "tune-test"]:
153-
dirpath = str(output_base / "tuning")
154-
results_path = str(output_base / "results")
155-
# Override tune output directories in config
156-
if cfg.tune is not None:
157-
cfg.tune.output.output_dir = dirpath
158-
cfg.tune.output.output_pred = results_path
159-
# For tune-test, also set test output directory and cache suffix
160-
if args.mode == "tune-test":
161-
print(f"🔍 Setting test config for tune-test mode")
162-
print(f"🔍 cfg.test is None: {cfg.test is None}")
163-
if cfg.test is not None:
164-
print(f"🔍 cfg.test.data is None: {cfg.test.data is None}")
165-
if cfg.test.data is not None:
166-
cfg.test.data.output_path = results_path
167-
cfg.test.data.cache_suffix = cfg.tune.output.cache_suffix
168-
print(f"📋 Test output: {cfg.test.data.output_path}")
169-
print(f"📋 Test cache suffix: {cfg.test.data.cache_suffix}")
154+
if args.mode == "test" and test_output_path_set:
155+
# Use the config value directly, skip checkpoint-based directory creation
156+
dirpath = str(cfg.test.data.output_path)
157+
output_base = None # Not needed when using config path
158+
else:
159+
# Extract base directory from checkpoint path
160+
output_base = get_output_base_from_checkpoint(args.checkpoint)
161+
output_base.mkdir(parents=True, exist_ok=True)
162+
163+
# Create mode-specific subdirectories
164+
if args.mode in ["tune", "tune-test"]:
165+
dirpath = str(output_base / "tuning")
166+
results_path = str(output_base / "results")
167+
# Override tune output directories in config
168+
if cfg.tune is not None:
169+
cfg.tune.output.output_dir = dirpath
170+
cfg.tune.output.output_pred = results_path
171+
# For tune-test, also set test output directory and cache suffix
172+
if args.mode == "tune-test":
173+
print(f"🔍 Setting test config for tune-test mode")
174+
print(f"🔍 cfg.test is None: {cfg.test is None}")
175+
if cfg.test is not None:
176+
print(f"🔍 cfg.test.data is None: {cfg.test.data is None}")
177+
if cfg.test.data is not None:
178+
cfg.test.data.output_path = results_path
179+
cfg.test.data.cache_suffix = cfg.tune.output.cache_suffix
180+
print(f"📋 Test output: {cfg.test.data.output_path}")
181+
print(f"📋 Test cache suffix: {cfg.test.data.cache_suffix}")
182+
else:
183+
print(f"❌ cfg.test.data is None, cannot set cache_suffix!")
184+
else:
185+
print(f"❌ cfg.test is None, cannot set cache_suffix!")
186+
else: # test mode
187+
dirpath = str(output_base / "results")
188+
# Override test output directory in config only if not already set
189+
if hasattr(cfg, "test") and hasattr(cfg.test, "data"):
190+
if not getattr(cfg.test.data, "output_path", None):
191+
cfg.test.data.output_path = dirpath
170192
else:
171-
print(f"❌ cfg.test.data is None, cannot set cache_suffix!")
172-
else:
173-
print(f"❌ cfg.test is None, cannot set cache_suffix!")
174-
else: # test mode
175-
dirpath = str(output_base / "results")
176-
# Override test output directory in config
177-
if hasattr(cfg, "test") and hasattr(cfg.test, "data"):
178-
cfg.test.data.output_path = dirpath
193+
# Use the config value, but ensure it's a string path
194+
dirpath = str(cfg.test.data.output_path)
179195

180196
run_dir = setup_run_directory(args.mode, cfg, dirpath)
181-
print(f"📂 Output base: {output_base}")
197+
if output_base is not None:
198+
print(f"📂 Output base: {output_base}")
182199
else:
183200
# Train mode or no checkpoint - use default config paths
184201
dirpath = cfg.monitor.checkpoint.dirpath

0 commit comments

Comments
 (0)