@@ -197,25 +197,52 @@ def build_train_transforms(
197197 return Compose (transforms )
198198
199199
200- def build_val_transforms (cfg : Config , keys : list [str ] = None ) -> Compose :
200+ def _build_eval_transforms_impl (
201+ cfg : Config , mode : str = "val" , keys : list [str ] = None
202+ ) -> Compose :
201203 """
202- Build validation transforms from Hydra config.
204+ Internal implementation for building evaluation transforms (validation or test).
205+
206+ This function contains the shared logic between validation and test transforms,
207+ with mode-specific branching for key differences.
203208
204209 Args:
205210 cfg: Hydra Config object
206- keys: Keys to transform (default: ['image', 'label'] or ['image', 'label', 'mask'] if masks are used)
211+ mode: 'val' or 'test' mode
212+ keys: Keys to transform (default: auto-detected based on mode)
207213
208214 Returns:
209215 Composed MONAI transforms (no augmentation)
210216 """
211217 if keys is None :
212- # Auto-detect keys based on config
213- keys = ["image" , "label" ]
214- # Add mask to keys if it's specified in the config (check both train and val masks)
215- if (hasattr (cfg .data , "val_mask" ) and cfg .data .val_mask is not None ) or (
216- hasattr (cfg .data , "train_mask" ) and cfg .data .train_mask is not None
217- ):
218- keys .append ("mask" )
218+ # Auto-detect keys based on mode
219+ if mode == "val" :
220+ # Validation: default to image+label
221+ keys = ["image" , "label" ]
222+ # Add mask if val_mask or train_mask exists
223+ if (hasattr (cfg .data , "val_mask" ) and cfg .data .val_mask is not None ) or (
224+ hasattr (cfg .data , "train_mask" ) and cfg .data .train_mask is not None
225+ ):
226+ keys .append ("mask" )
227+ else : # mode == "test"
228+ # Test/inference: default to image only
229+ keys = ["image" ]
230+ # Only add label if test_label is explicitly specified
231+ if (
232+ hasattr (cfg , "inference" )
233+ and hasattr (cfg .inference , "data" )
234+ and hasattr (cfg .inference .data , "test_label" )
235+ and cfg .inference .data .test_label is not None
236+ ):
237+ keys .append ("label" )
238+ # Add mask if test_mask is explicitly specified
239+ if (
240+ hasattr (cfg , "inference" )
241+ and hasattr (cfg .inference , "data" )
242+ and hasattr (cfg .inference .data , "test_mask" )
243+ and cfg .inference .data .test_mask is not None
244+ ):
245+ keys .append ("mask" )
219246
220247 transforms = []
221248
@@ -229,9 +256,24 @@ def build_val_transforms(cfg: Config, keys: list[str] = None) -> Compose:
229256 transforms .append (EnsureChannelFirstd (keys = keys ))
230257 else :
231258 # For volume-based datasets (HDF5, TIFF volumes), use custom LoadVolumed
232- val_transpose = cfg .data .val_transpose if cfg .data .val_transpose else []
259+ # Get transpose axes based on mode
260+ if mode == "val" :
261+ transpose_axes = cfg .data .val_transpose if cfg .data .val_transpose else []
262+ else : # mode == "test"
263+ # Check both data.test_transpose and inference.data.test_transpose
264+ transpose_axes = []
265+ if cfg .data .test_transpose :
266+ transpose_axes = cfg .data .test_transpose
267+ if (
268+ hasattr (cfg , "inference" )
269+ and hasattr (cfg .inference , "data" )
270+ and hasattr (cfg .inference .data , "test_transpose" )
271+ and cfg .inference .data .test_transpose
272+ ):
273+ transpose_axes = cfg .inference .data .test_transpose # inference takes precedence
274+
233275 transforms .append (
234- LoadVolumed (keys = keys , transpose_axes = val_transpose if val_transpose else None )
276+ LoadVolumed (keys = keys , transpose_axes = transpose_axes if transpose_axes else None )
235277 )
236278
237279 # Apply volumetric split if enabled
@@ -270,155 +312,18 @@ def build_val_transforms(cfg: Config, keys: list[str] = None) -> Compose:
270312 )
271313 )
272314
273- # Add spatial cropping to prevent loading full volumes (OOM fix)
274- # NOTE: If split is enabled with padding, this crop will be applied AFTER padding
275- if patch_size and all (size > 0 for size in patch_size ):
276- transforms .append (
277- CenterSpatialCropd (
278- keys = keys ,
279- roi_size = patch_size ,
280- )
281- )
282-
283- # Normalization - use smart normalization
284- if cfg .data .image_transform .normalize != "none" :
285- transforms .append (
286- SmartNormalizeIntensityd (
287- keys = ["image" ],
288- mode = cfg .data .image_transform .normalize ,
289- clip_percentile_low = cfg .data .image_transform .clip_percentile_low ,
290- clip_percentile_high = cfg .data .image_transform .clip_percentile_high ,
291- )
292- )
293-
294- # Normalize labels to 0-1 range if enabled
295- if getattr (cfg .data , "normalize_labels" , False ):
296- transforms .append (NormalizeLabelsd (keys = ["label" ]))
297-
298- # Label transformations (affinity, distance transform, etc.)
299- if hasattr (cfg .data , "label_transform" ):
300- from ..process .build import create_label_transform_pipeline
301- from ..process .monai_transforms import SegErosionInstanced
302-
303- label_cfg = cfg .data .label_transform
304-
305- # Apply instance erosion first if specified
306- if hasattr (label_cfg , "erosion" ) and label_cfg .erosion > 0 :
307- transforms .append (SegErosionInstanced (keys = ["label" ], tsz_h = label_cfg .erosion ))
308-
309- # Build label transform pipeline directly from label_transform config
310- label_pipeline = create_label_transform_pipeline (label_cfg )
311- transforms .extend (label_pipeline .transforms )
312-
313- # NOTE: Do NOT squeeze labels here!
314- # - DiceLoss needs (B, 1, H, W) with to_onehot_y=True
315- # - CrossEntropyLoss needs (B, H, W)
316- # Squeezing is handled in the loss wrapper instead
317-
318- # Final conversion to tensor with float32 dtype
319- transforms .append (ToTensord (keys = keys , dtype = torch .float32 ))
320-
321- return Compose (transforms )
322-
323-
324- def build_test_transforms (cfg : Config , keys : list [str ] = None ) -> Compose :
325- """
326- Build test/inference transforms from Hydra config.
327-
328- Similar to validation transforms but WITHOUT cropping to enable
329- sliding window inference on full volumes.
330-
331- Args:
332- cfg: Hydra Config object
333- keys: Keys to transform (default: ['image', 'label'] or ['image', 'label', 'mask'] if masks are used)
334-
335- Returns:
336- Composed MONAI transforms (no augmentation, no cropping)
337- """
338- if keys is None :
339- # Auto-detect keys based on config
340- keys = ["image" ]
341- # Only add label if test_label is specified in the config
342- if (
343- hasattr (cfg , "inference" )
344- and hasattr (cfg .inference , "data" )
345- and hasattr (cfg .inference .data , "test_label" )
346- and cfg .inference .data .test_label is not None
347- ):
348- keys .append ("label" )
349- # Add mask to keys if it's specified in the config (check test mask)
350- if (
351- hasattr (cfg , "inference" )
352- and hasattr (cfg .inference , "data" )
353- and hasattr (cfg .inference .data , "test_mask" )
354- and cfg .inference .data .test_mask is not None
355- ):
356- keys .append ("mask" )
357-
358- transforms = []
359-
360- # Load images first - use appropriate loader based on dataset type
361- dataset_type = getattr (cfg .data , "dataset_type" , "volume" ) # Default to volume for backward compatibility
362-
363- if dataset_type == "filename" :
364- # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
365- transforms .append (LoadImaged (keys = keys , image_only = False ))
366- # Ensure channel-first format [C, H, W] or [C, D, H, W]
367- transforms .append (EnsureChannelFirstd (keys = keys ))
368- else :
369- # For volume-based datasets (HDF5, TIFF volumes), use custom LoadVolumed
370- # Get transpose axes for test data (check both data.test_transpose and inference.data.test_transpose)
371- test_transpose = []
372- if cfg .data .test_transpose :
373- test_transpose = cfg .data .test_transpose
374- if (
375- hasattr (cfg , "inference" )
376- and hasattr (cfg .inference , "data" )
377- and hasattr (cfg .inference .data , "test_transpose" )
378- and cfg .inference .data .test_transpose
379- ):
380- test_transpose = cfg .inference .data .test_transpose # inference takes precedence
381- transforms .append (
382- LoadVolumed (keys = keys , transpose_axes = test_transpose if test_transpose else None )
383- )
384-
385- # Apply volumetric split if enabled (though typically not used for test)
386- if cfg .data .split_enabled :
387- from connectomics .data .utils import ApplyVolumetricSplitd
388-
389- transforms .append (ApplyVolumetricSplitd (keys = keys ))
390-
391- # Apply resize if configured (before padding)
392- if hasattr (cfg .data .image_transform , "resize" ) and cfg .data .image_transform .resize is not None :
393- resize_factors = cfg .data .image_transform .resize
394- if resize_factors :
395- # Use bilinear for images, nearest for labels/masks
315+ # Add spatial cropping - MODE-SPECIFIC
316+ # Validation: Apply center crop for patch-based validation
317+ # Test: Skip cropping to enable sliding window inference on full volumes
318+ if mode == "val" :
319+ if patch_size and all (size > 0 for size in patch_size ):
396320 transforms .append (
397- Resized (keys = ["image" ], scale = resize_factors , mode = "bilinear" , align_corners = True )
398- )
399- # Resize labels and masks with nearest-neighbor
400- label_mask_keys = [k for k in keys if k in ["label" , "mask" ]]
401- if label_mask_keys :
402- transforms .append (
403- Resized (
404- keys = label_mask_keys ,
405- scale = resize_factors ,
406- mode = "nearest" ,
407- align_corners = None ,
408- )
321+ CenterSpatialCropd (
322+ keys = keys ,
323+ roi_size = patch_size ,
409324 )
410-
411- patch_size = tuple (cfg .data .patch_size ) if hasattr (cfg .data , "patch_size" ) else None
412- if patch_size and all (size > 0 for size in patch_size ):
413- transforms .append (
414- SpatialPadd (
415- keys = keys ,
416- spatial_size = patch_size ,
417- constant_values = 0.0 ,
418325 )
419- )
420-
421- # NOTE: No CenterSpatialCropd here - we want full volumes for sliding window inference!
326+ # else: mode == "test" -> no cropping for sliding window inference
422327
423328 # Normalization - use smart normalization
424329 if cfg .data .image_transform .normalize != "none" :
@@ -431,25 +336,25 @@ def build_test_transforms(cfg: Config, keys: list[str] = None) -> Compose:
431336 )
432337 )
433338
434- # Only apply label transforms if 'label' is in keys
339+ # Only process labels if 'label' is in keys
435340 if "label" in keys :
436341 # Normalize labels to 0-1 range if enabled
437342 if getattr (cfg .data , "normalize_labels" , False ):
438343 transforms .append (NormalizeLabelsd (keys = ["label" ]))
439344
440- # Check if any evaluation metric is enabled (requires original instance labels )
345+ # Check if we should skip label transforms (test mode with evaluation metrics )
441346 skip_label_transform = False
442- if hasattr (cfg , "inference" ) and hasattr (cfg .inference , "evaluation" ):
443- evaluation_enabled = getattr (cfg .inference .evaluation , "enabled" , False )
444- metrics = getattr (cfg .inference .evaluation , "metrics" , [])
445- if evaluation_enabled and metrics :
446- skip_label_transform = True
447- print (
448- f" ⚠️ Skipping label transforms for metric evaluation (keeping original labels for { metrics } )"
449- )
347+ if mode == "test" :
348+ if hasattr (cfg , "inference" ) and hasattr (cfg .inference , "evaluation" ):
349+ evaluation_enabled = getattr (cfg .inference .evaluation , "enabled" , False )
350+ metrics = getattr (cfg .inference .evaluation , "metrics" , [])
351+ if evaluation_enabled and metrics :
352+ skip_label_transform = True
353+ print (
354+ f" ⚠️ Skipping label transforms for metric evaluation (keeping original labels for { metrics } )"
355+ )
450356
451357 # Label transformations (affinity, distance transform, etc.)
452- # Skip if evaluation metrics are enabled (need original labels for metric computation)
453358 if hasattr (cfg .data , "label_transform" ) and not skip_label_transform :
454359 from ..process .build import create_label_transform_pipeline
455360 from ..process .monai_transforms import SegErosionInstanced
@@ -475,6 +380,37 @@ def build_test_transforms(cfg: Config, keys: list[str] = None) -> Compose:
475380 return Compose (transforms )
476381
477382
383+ def build_val_transforms (cfg : Config , keys : list [str ] = None ) -> Compose :
384+ """
385+ Build validation transforms from Hydra config.
386+
387+ Args:
388+ cfg: Hydra Config object
389+ keys: Keys to transform (default: auto-detected as ['image', 'label'])
390+
391+ Returns:
392+ Composed MONAI transforms (no augmentation, center cropping)
393+ """
394+ return _build_eval_transforms_impl (cfg , mode = "val" , keys = keys )
395+
396+
397+ def build_test_transforms (cfg : Config , keys : list [str ] = None ) -> Compose :
398+ """
399+ Build test/inference transforms from Hydra config.
400+
401+ Similar to validation transforms but WITHOUT cropping to enable
402+ sliding window inference on full volumes.
403+
404+ Args:
405+ cfg: Hydra Config object
406+ keys: Keys to transform (default: auto-detected as ['image'] only)
407+
408+ Returns:
409+ Composed MONAI transforms (no augmentation, no cropping)
410+ """
411+ return _build_eval_transforms_impl (cfg , mode = "test" , keys = keys )
412+
413+
478414def build_inference_transforms (cfg : Config ) -> Compose :
479415 """
480416 Build inference transforms from Hydra config.
0 commit comments