@@ -181,6 +181,12 @@ class ModelConfig:
181181 rsunet_act_negative_slope : float = 0.01 # For LeakyReLU
182182 rsunet_act_init : float = 0.25 # For PReLU
183183
184+ # nnUNet-specific parameters (for loading pretrained models)
185+ nnunet_checkpoint : Optional [str ] = None # Path to pretrained checkpoint (.pth file)
186+ nnunet_plans : Optional [str ] = None # Path to plans.json file
187+ nnunet_dataset : Optional [str ] = None # Path to dataset.json file
188+ nnunet_device : str = "cuda" # Device for model loading ('cuda' or 'cpu')
189+
184190 # Deep supervision (supported by MedNeXt, RSUNet, and some MONAI models)
185191 deep_supervision : bool = False
186192 deep_supervision_weights : Optional [List [float ]] = (
@@ -872,7 +878,10 @@ class SlidingWindowConfig:
872878
873879@dataclass
874880class TestTimeAugmentationConfig :
875- """Test-time augmentation configuration."""
881+ """Test-time augmentation configuration.
882+
883+ Note: Saving predictions is now handled by SavePredictionConfig.
884+ """
876885
877886 enabled : bool = False
878887 flip_axes : Any = (
@@ -886,12 +895,24 @@ class TestTimeAugmentationConfig:
886895 )
887896 ensemble_mode : str = "mean" # Ensemble mode for TTA: 'mean', 'min', 'max'
888897 apply_mask : bool = False # Multiply each channel by corresponding test_mask after ensemble
889- save_predictions : bool = (
890- False # Save intermediate TTA predictions (before decoding) to disk (default: False)
891- )
892- save_dtype : Optional [str ] = (
893- None # Data type for saving predictions: "float16", "float32", "uint8", "uint16", or None (keep original)
894- )
898+
899+
900+ @dataclass
901+ class SavePredictionConfig :
902+ """Configuration for saving intermediate predictions during inference.
903+
904+ Controls how raw model predictions are saved before any decoding or postprocessing.
905+ Useful for debugging, visualization, or running multiple decoding strategies.
906+
907+ Attributes:
908+ enabled: Enable saving intermediate predictions (default: True)
909+ intensity_scale: Scale factor for predictions (e.g., 255 for uint8 visualization)
910+ intensity_dtype: Data type for saved predictions (e.g., 'uint8', 'float32')
911+ """
912+
913+ enabled : bool = True # Enable saving intermediate predictions
914+ intensity_scale : float = 255.0 # Scale predictions to [0, 255] for saving
915+ intensity_dtype : str = "uint8" # Save as uint8 for visualization
895916
896917
897918@dataclass
@@ -971,28 +992,18 @@ class ConnectedComponentsConfig:
971992class PostprocessingConfig :
972993 """Postprocessing configuration for inference output.
973994
974- Controls how predictions are transformed before saving:
995+ Controls how predictions are transformed after saving:
975996 - Binary refinement: Morphological operations and connected components filtering
976- - Scaling: Multiply intensity values (e.g., 255 for uint8)
977- - Dtype conversion: Convert to target data type with proper clamping
978997 - Transpose: Reorder axes (e.g., [2,1,0] for zyx->xyz)
998+
999+ Note: Intensity scaling and dtype conversion are handled by SavePredictionConfig.
9791000 """
9801001
9811002 # Binary segmentation refinement (morphological ops, connected components)
9821003 binary : Optional [Dict [str , Any ]] = field (
9831004 default_factory = dict
9841005 ) # Binary postprocessing config (e.g., {'opening_iterations': 2})
9851006
986- # Intensity scaling
987- intensity_scale : Optional [float ] = (
988- None # Scale predictions for saving (e.g., 255.0 for uint8). None = no scaling
989- )
990-
991- # Data type conversion
992- intensity_dtype : Optional [str ] = (
993- None # Output data type: 'uint8', 'uint16', 'float32'. None = no conversion (keep as-is)
994- )
995-
9961007 # Axis permutation
9971008 output_transpose : List [int ] = field (
9981009 default_factory = list
@@ -1017,6 +1028,7 @@ class InferenceConfig:
10171028 Key Features:
10181029 - Sliding window inference for large volumes
10191030 - Test-time augmentation (TTA) support
1031+ - Saving intermediate predictions
10201032 - Multiple decoding strategies
10211033 - Postprocessing and evaluation
10221034 - System resource overrides for inference
@@ -1028,6 +1040,7 @@ class InferenceConfig:
10281040 test_time_augmentation : TestTimeAugmentationConfig = field (
10291041 default_factory = TestTimeAugmentationConfig
10301042 )
1043+ save_prediction : SavePredictionConfig = field (default_factory = SavePredictionConfig )
10311044 decoding : Optional [List [DecodeModeConfig ]] = None # List of decode modes to apply sequentially
10321045 postprocessing : PostprocessingConfig = field (default_factory = PostprocessingConfig )
10331046 evaluation : EvaluationConfig = field (default_factory = EvaluationConfig )
@@ -1255,6 +1268,7 @@ def configure_instance_segmentation(cfg: Config, boundary_thickness: int = 5) ->
12551268 "InferenceDataConfig" ,
12561269 "SlidingWindowConfig" ,
12571270 "TestTimeAugmentationConfig" ,
1271+ "SavePredictionConfig" ,
12581272 "PostprocessingConfig" ,
12591273 "BinaryPostprocessingConfig" ,
12601274 "ConnectedComponentsConfig" ,
0 commit comments