Skip to content

Commit 65f7a12

Browse files
author
Donglai Wei
committed
update mito documentation
1 parent a18b392 commit 65f7a12

24 files changed

+1063
-187
lines changed

connectomics/config/hydra_config.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ class ModelConfig:
181181
rsunet_act_negative_slope: float = 0.01 # For LeakyReLU
182182
rsunet_act_init: float = 0.25 # For PReLU
183183

184+
# nnUNet-specific parameters (for loading pretrained models)
185+
nnunet_checkpoint: Optional[str] = None # Path to pretrained checkpoint (.pth file)
186+
nnunet_plans: Optional[str] = None # Path to plans.json file
187+
nnunet_dataset: Optional[str] = None # Path to dataset.json file
188+
nnunet_device: str = "cuda" # Device for model loading ('cuda' or 'cpu')
189+
184190
# Deep supervision (supported by MedNeXt, RSUNet, and some MONAI models)
185191
deep_supervision: bool = False
186192
deep_supervision_weights: Optional[List[float]] = (
@@ -872,7 +878,10 @@ class SlidingWindowConfig:
872878

873879
@dataclass
874880
class TestTimeAugmentationConfig:
875-
"""Test-time augmentation configuration."""
881+
"""Test-time augmentation configuration.
882+
883+
Note: Saving predictions is now handled by SavePredictionConfig.
884+
"""
876885

877886
enabled: bool = False
878887
flip_axes: Any = (
@@ -886,12 +895,24 @@ class TestTimeAugmentationConfig:
886895
)
887896
ensemble_mode: str = "mean" # Ensemble mode for TTA: 'mean', 'min', 'max'
888897
apply_mask: bool = False # Multiply each channel by corresponding test_mask after ensemble
889-
save_predictions: bool = (
890-
False # Save intermediate TTA predictions (before decoding) to disk (default: False)
891-
)
892-
save_dtype: Optional[str] = (
893-
None # Data type for saving predictions: "float16", "float32", "uint8", "uint16", or None (keep original)
894-
)
898+
899+
900+
@dataclass
901+
class SavePredictionConfig:
902+
"""Configuration for saving intermediate predictions during inference.
903+
904+
Controls how raw model predictions are saved before any decoding or postprocessing.
905+
Useful for debugging, visualization, or running multiple decoding strategies.
906+
907+
Attributes:
908+
enabled: Enable saving intermediate predictions (default: True)
909+
intensity_scale: Scale factor for predictions (e.g., 255 for uint8 visualization)
910+
intensity_dtype: Data type for saved predictions (e.g., 'uint8', 'float32')
911+
"""
912+
913+
enabled: bool = True # Enable saving intermediate predictions
914+
intensity_scale: float = 255.0 # Scale predictions to [0, 255] for saving
915+
intensity_dtype: str = "uint8" # Save as uint8 for visualization
895916

896917

897918
@dataclass
@@ -971,28 +992,18 @@ class ConnectedComponentsConfig:
971992
class PostprocessingConfig:
972993
"""Postprocessing configuration for inference output.
973994
974-
Controls how predictions are transformed before saving:
995+
Controls how predictions are transformed after saving:
975996
- Binary refinement: Morphological operations and connected components filtering
976-
- Scaling: Multiply intensity values (e.g., 255 for uint8)
977-
- Dtype conversion: Convert to target data type with proper clamping
978997
- Transpose: Reorder axes (e.g., [2,1,0] for zyx->xyz)
998+
999+
Note: Intensity scaling and dtype conversion are handled by SavePredictionConfig.
9791000
"""
9801001

9811002
# Binary segmentation refinement (morphological ops, connected components)
9821003
binary: Optional[Dict[str, Any]] = field(
9831004
default_factory=dict
9841005
) # Binary postprocessing config (e.g., {'opening_iterations': 2})
9851006

986-
# Intensity scaling
987-
intensity_scale: Optional[float] = (
988-
None # Scale predictions for saving (e.g., 255.0 for uint8). None = no scaling
989-
)
990-
991-
# Data type conversion
992-
intensity_dtype: Optional[str] = (
993-
None # Output data type: 'uint8', 'uint16', 'float32'. None = no conversion (keep as-is)
994-
)
995-
9961007
# Axis permutation
9971008
output_transpose: List[int] = field(
9981009
default_factory=list
@@ -1017,6 +1028,7 @@ class InferenceConfig:
10171028
Key Features:
10181029
- Sliding window inference for large volumes
10191030
- Test-time augmentation (TTA) support
1031+
- Saving intermediate predictions
10201032
- Multiple decoding strategies
10211033
- Postprocessing and evaluation
10221034
- System resource overrides for inference
@@ -1028,6 +1040,7 @@ class InferenceConfig:
10281040
test_time_augmentation: TestTimeAugmentationConfig = field(
10291041
default_factory=TestTimeAugmentationConfig
10301042
)
1043+
save_prediction: SavePredictionConfig = field(default_factory=SavePredictionConfig)
10311044
decoding: Optional[List[DecodeModeConfig]] = None # List of decode modes to apply sequentially
10321045
postprocessing: PostprocessingConfig = field(default_factory=PostprocessingConfig)
10331046
evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
@@ -1255,6 +1268,7 @@ def configure_instance_segmentation(cfg: Config, boundary_thickness: int = 5) ->
12551268
"InferenceDataConfig",
12561269
"SlidingWindowConfig",
12571270
"TestTimeAugmentationConfig",
1271+
"SavePredictionConfig",
12581272
"PostprocessingConfig",
12591273
"BinaryPostprocessingConfig",
12601274
"ConnectedComponentsConfig",

connectomics/inference/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
"""Inference utilities package."""
22

33
from .manager import InferenceManager
4-
from .io import apply_postprocessing, apply_decode_mode, resolve_output_filenames, write_outputs
4+
from .io import (
5+
apply_save_prediction_transform,
6+
apply_postprocessing,
7+
apply_decode_mode,
8+
resolve_output_filenames,
9+
write_outputs,
10+
)
511
from .sliding import build_sliding_inferer, resolve_inferer_roi_size, resolve_inferer_overlap
612
from .tta import TTAPredictor
713

814
__all__ = [
915
"InferenceManager",
16+
"apply_save_prediction_transform",
1017
"apply_postprocessing",
1118
"apply_decode_mode",
1219
"resolve_output_filenames",

connectomics/inference/io.py

Lines changed: 70 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,73 @@
1616
from ..config import Config
1717

1818

19+
def apply_save_prediction_transform(cfg: Config | DictConfig, data: np.ndarray) -> np.ndarray:
20+
"""
21+
Apply intensity scaling and dtype conversion from save_prediction config.
22+
23+
This is used when saving intermediate predictions (before decoding).
24+
25+
Args:
26+
cfg: Configuration object
27+
data: Predictions array to transform
28+
29+
Returns:
30+
Transformed predictions with applied scaling and dtype conversion
31+
"""
32+
if not hasattr(cfg, "inference") or not hasattr(cfg.inference, "save_prediction"):
33+
return data
34+
35+
save_pred_cfg = cfg.inference.save_prediction
36+
37+
# Apply intensity scaling
38+
intensity_scale = getattr(save_pred_cfg, "intensity_scale", None)
39+
if intensity_scale is not None and intensity_scale != 1.0:
40+
data = data * float(intensity_scale)
41+
42+
# Apply dtype conversion
43+
target_dtype_str = getattr(save_pred_cfg, "intensity_dtype", None)
44+
if target_dtype_str is not None:
45+
dtype_map = {
46+
"uint8": np.uint8,
47+
"int8": np.int8,
48+
"uint16": np.uint16,
49+
"int16": np.int16,
50+
"uint32": np.uint32,
51+
"int32": np.int32,
52+
"float16": np.float16,
53+
"float32": np.float32,
54+
"float64": np.float64,
55+
}
56+
57+
if target_dtype_str not in dtype_map:
58+
warnings.warn(
59+
f"Unknown dtype '{target_dtype_str}' in save_prediction config. "
60+
f"Supported: {list(dtype_map.keys())}. Keeping current dtype.",
61+
UserWarning,
62+
)
63+
return data
64+
65+
target_dtype = dtype_map[target_dtype_str]
66+
67+
# Get dtype info for proper clamping
68+
if np.issubdtype(target_dtype, np.integer):
69+
info = np.iinfo(target_dtype)
70+
data = np.clip(data, info.min, info.max)
71+
72+
data = data.astype(target_dtype)
73+
74+
return data
75+
76+
1977
def apply_postprocessing(cfg: Config | DictConfig, data: np.ndarray) -> np.ndarray:
2078
"""
2179
Apply postprocessing transformations to predictions.
2280
23-
This method applies (in order):
81+
This method applies:
2482
1. Binary postprocessing (morphological operations, connected components filtering)
25-
2. Scaling (intensity_scale or output_scale): Multiply predictions by scale factor
26-
3. Dtype conversion (intensity_dtype or output_dtype): Convert to target dtype with clamping
83+
2. Axis transposition (output_transpose)
84+
85+
Note: Intensity scaling and dtype conversion are handled by apply_save_prediction_transform()
2786
"""
2887
if not hasattr(cfg, "inference") or not hasattr(cfg.inference, "postprocessing"):
2988
return data
@@ -82,53 +141,16 @@ def apply_postprocessing(cfg: Config | DictConfig, data: np.ndarray) -> np.ndarr
82141

83142
data = np.stack(results, axis=0)
84143

85-
intensity_scale = getattr(postprocessing, "intensity_scale", None)
86-
output_scale = getattr(postprocessing, "output_scale", None)
87-
scale = intensity_scale if intensity_scale is not None else output_scale
88-
if scale is not None:
89-
data = data * float(scale)
90-
91-
target_dtype_str = getattr(postprocessing, "intensity_dtype", None)
92-
if target_dtype_str is None:
93-
target_dtype_str = getattr(postprocessing, "output_dtype", None)
94-
95-
if target_dtype_str is not None:
96-
dtype_map = {
97-
"uint8": np.uint8,
98-
"int8": np.int8,
99-
"uint16": np.uint16,
100-
"int16": np.int16,
101-
"uint32": np.uint32,
102-
"int32": np.int32,
103-
"float16": np.float16,
104-
"float32": np.float32,
105-
"float64": np.float64,
106-
}
107-
108-
if target_dtype_str not in dtype_map:
144+
# Apply axis transposition if configured
145+
output_transpose = getattr(postprocessing, "output_transpose", [])
146+
if output_transpose and len(output_transpose) > 0:
147+
try:
148+
data = np.transpose(data, axes=output_transpose)
149+
except Exception as e:
109150
warnings.warn(
110-
f"Unknown dtype '{target_dtype_str}'. Supported: {list(dtype_map.keys())}. "
111-
f"Keeping float32.",
151+
f"Transpose failed with axes {output_transpose}: {e}. Keeping original shape.",
112152
UserWarning,
113153
)
114-
return data
115-
116-
target_dtype = dtype_map[target_dtype_str]
117-
118-
if target_dtype_str == "uint8":
119-
data = np.clip(data, 0, 255)
120-
elif target_dtype_str == "int8":
121-
data = np.clip(data, -128, 127)
122-
elif target_dtype_str == "uint16":
123-
data = np.clip(data, 0, 65535)
124-
elif target_dtype_str == "int16":
125-
data = np.clip(data, -32768, 32767)
126-
elif target_dtype_str == "uint32":
127-
data = np.clip(data, 0, 4294967295)
128-
elif target_dtype_str == "int32":
129-
data = np.clip(data, -2147483648, 2147483647)
130-
131-
data = data.astype(target_dtype)
132154

133155
return data
134156

@@ -342,8 +364,8 @@ def write_outputs(
342364
sample = np.squeeze(sample)
343365
write_hdf5(
344366
output_path,
345-
"main",
346367
sample.astype(np.float32) if not np.issubdtype(sample.dtype, np.integer) else sample,
368+
dataset="main",
347369
)
348370

349371
print(f" ✓ Saved prediction: {output_path}")

connectomics/models/arch/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def build_my_model(cfg):
5757
except ImportError:
5858
_RSUNET_AVAILABLE = False
5959

60+
# Import nnUNet models to trigger registration
61+
try:
62+
from . import nnunet_models
63+
_NNUNET_AVAILABLE = True
64+
except ImportError:
65+
_NNUNET_AVAILABLE = False
66+
6067
# Check what's available
6168
def get_available_architectures() -> dict:
6269
"""
@@ -75,6 +82,7 @@ def get_available_architectures() -> dict:
7582
'monai': [a for a in all_archs if a.startswith('monai_')] if _MONAI_AVAILABLE else [],
7683
'mednext': [a for a in all_archs if a.startswith('mednext')] if _MEDNEXT_AVAILABLE else [],
7784
'rsunet': [a for a in all_archs if a.startswith('rsunet')] if _RSUNET_AVAILABLE else [],
85+
'nnunet': [a for a in all_archs if a.startswith('nnunet')] if _NNUNET_AVAILABLE else [],
7886
}
7987

8088
return info
@@ -107,6 +115,13 @@ def print_available_architectures():
107115
for arch in info['rsunet']:
108116
print(f" - {arch}")
109117

118+
if info['nnunet']:
119+
print(f"\nnnUNet Models ({len(info['nnunet'])}):")
120+
for arch in info['nnunet']:
121+
print(f" - {arch}")
122+
else:
123+
print("\nnnUNet Models: Not available (install with: pip install nnunetv2)")
124+
110125
print(f"\nTotal: {len(info['all'])} architectures")
111126
print("="*60 + "\n")
112127

0 commit comments

Comments
 (0)