@@ -521,38 +521,66 @@ def _apply_postprocessing(self, data: np.ndarray) -> np.ndarray:
         from connectomics.decoding.postprocess import apply_binary_postprocessing
 
         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim >= 4 else 1
-
-        # Handle different input shapes
-        if data.ndim == 2:  # (H, W) -> (1, 1, H, W)
-            data = data[np.newaxis, np.newaxis, ...]
-        elif data.ndim == 3:  # (D, H, W) or (C, H, W) -> assume (D, H, W) and add batch dim
-            data = data[np.newaxis, ...]  # (1, D, H, W)
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f"DEBUG: _apply_postprocessing - input data shape: {data.shape}, ndim: {data.ndim}")
+        if data.ndim == 4:
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f"DEBUG: _apply_postprocessing - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f"DEBUG: _apply_postprocessing - detected 3D data, batch_size: {batch_size}")
+        elif data.ndim == 3:
+            # Single unbatched sample: (C, H, W) or (D, H, W) - add batch dimension
+            batch_size = 1
+            data = data[np.newaxis, ...]  # (1, C, H, W) or (1, D, H, W)
+            print(f"DEBUG: _apply_postprocessing - single 3D sample, added batch dimension")
+        elif data.ndim == 2:
+            # Single 2D image: (H, W) - add batch and channel dimensions
+            batch_size = 1
+            data = data[np.newaxis, np.newaxis, ...]  # (1, 1, H, W)
+            print(f"DEBUG: _apply_postprocessing - single 2D sample, added batch and channel dimensions")
+        else:
+            batch_size = 1
 
-        # Ensure we have at least 4D: (B, ...) where ... can be (D, H, W) or (C, D, H, W)
+        # Ensure we have at least 4D: (B, ...) where ... can be (C, H, W) for 2D or (C, D, H, W) for 3D
         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W) or (D, H, W)
-
-            # Extract foreground probability (handle both 3D and 4D)
-            if sample.ndim == 4:  # (C, D, H, W)
-                foreground_prob = sample[0]  # Use first channel
-            else:  # (D, H, W) - already single channel
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f"DEBUG: _apply_postprocessing - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
+
+            # Extract foreground probability (always use first channel if channel dimension exists)
+            if sample.ndim == 4:  # (C, D, H, W) - 3D with channel
+                foreground_prob = sample[0]  # Use first channel -> (D, H, W)
+            elif sample.ndim == 3:
+                # Could be (C, H, W) for 2D or (D, H, W) for 3D without channel
+                # If first dim is small (<=4), assume it's channel (2D), otherwise depth (3D)
+                if sample.shape[0] <= 4:
+                    foreground_prob = sample[0]  # (C, H, W) -> use first channel -> (H, W)
+                else:
+                    foreground_prob = sample  # (D, H, W) - already single channel
+            elif sample.ndim == 2:  # (H, W) - 2D single channel
+                foreground_prob = sample
+            else:
                 foreground_prob = sample
 
             # Apply binary postprocessing
             processed = apply_binary_postprocessing(foreground_prob, binary_config)
 
-            # Expand dims to maintain shape consistency
-            if sample.ndim == 4:
+            # Expand dims to maintain shape consistency with original sample structure
+            if sample.ndim == 4:  # (C, D, H, W) -> processed is (D, H, W)
                 processed = processed[np.newaxis, ...]  # (1, D, H, W)
-            else:
-                processed = processed  # Keep (D, H, W)
+            elif sample.ndim == 3 and sample.shape[0] <= 4:  # (C, H, W) -> processed is (H, W)
+                processed = processed[np.newaxis, ...]  # (1, H, W)
+            # else: processed is already correct shape (D, H, W) or (H, W)
 
             results.append(processed)
 
         # Stack results back into batch
+        print(f"DEBUG: _apply_postprocessing - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
         data = np.stack(results, axis=0)
+        print(f"DEBUG: _apply_postprocessing - after stacking, data shape: {data.shape}")
 
         # Step 2: Apply scaling if configured (support both new and legacy names)
         intensity_scale = getattr(postprocessing, 'intensity_scale', None)
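
Review note: the branching above funnels every input into a batched layout before the per-sample loop. Below is a minimal standalone sketch of that shape contract, runnable on its own; `normalize_batch` is a hypothetical helper for illustration, not part of the module.

import numpy as np

def normalize_batch(data: np.ndarray) -> np.ndarray:
    """Coerce unbatched inputs to a batched layout, mirroring the branching above."""
    if data.ndim in (4, 5):                      # (B, C, H, W) or (B, C, D, H, W): already batched
        return data
    if data.ndim == 3:                           # (C, H, W) or (D, H, W): add batch dim
        return data[np.newaxis, ...]
    if data.ndim == 2:                           # (H, W): add batch and channel dims
        return data[np.newaxis, np.newaxis, ...]
    return data                                  # anything else passes through unchanged

assert normalize_batch(np.zeros((64, 64))).shape == (1, 1, 64, 64)
assert normalize_batch(np.zeros((32, 64, 64))).shape == (1, 32, 64, 64)
assert normalize_batch(np.zeros((2, 1, 64, 64))).shape == (2, 1, 64, 64)

Note that the `sample.shape[0] <= 4` test used in the per-sample loop to tell (C, H, W) from (D, H, W) is a heuristic: a genuine 3D volume with depth 4 or less would be misread as multi-channel 2D.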
@@ -651,13 +679,29 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
         }
 
         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim == 5 else 1
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f"DEBUG: _apply_decode_mode - input data shape: {data.shape}, ndim: {data.ndim}")
         if data.ndim == 4:
-            data = data[np.newaxis, ...]  # Add batch dimension
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f"DEBUG: _apply_decode_mode - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f"DEBUG: _apply_decode_mode - detected 3D data, batch_size: {batch_size}")
+        else:
+            # Single sample: add batch dimension
+            batch_size = 1
+            print(f"DEBUG: _apply_decode_mode - single sample, adding batch dimension")
+            if data.ndim == 3:
+                data = data[np.newaxis, ...]  # (C, H, W) -> (1, C, H, W)
+            elif data.ndim == 2:
+                data = data[np.newaxis, np.newaxis, ...]  # (H, W) -> (1, 1, H, W)
 
         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W)
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f"DEBUG: _apply_decode_mode - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
 
             # Apply each decode mode sequentially
             for decode_cfg in decode_modes:
@@ -718,8 +762,10 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
             results.append(sample)
 
         # Stack results back into batch
-        decoded = np.stack(results, axis=0) if len(results) > 1 else results[0]
-
+        # Always preserve batch dimension, even for batch_size=1
+        print(f"DEBUG: _apply_decode_mode - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
+        decoded = np.stack(results, axis=0)
+        print(f"DEBUG: _apply_decode_mode - final decoded shape: {decoded.shape}")
         return decoded
 
     def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
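
Review note: the unconditional `np.stack` fixes a real batch-size-1 hazard; the old conditional returned `results[0]` and silently dropped the batch dimension, so a downstream (B, ...) consumer could misread channels as batch. A toy repro, with shapes assumed for illustration:

import numpy as np

results = [np.zeros((3, 16, 64, 64))]  # one decoded sample, (C, D, H, W)

old = np.stack(results, axis=0) if len(results) > 1 else results[0]
new = np.stack(results, axis=0)

print(old.shape)  # (3, 16, 64, 64) - batch dim lost; a (B, ...) consumer reads B=3
print(new.shape)  # (1, 3, 16, 64, 64) - batch dim preserved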
@@ -742,26 +788,59 @@ def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
 
         meta = batch.get('image_meta_dict')
         filenames: List[Optional[str]] = []
+
+        print(f"DEBUG: _resolve_output_filenames - meta type: {type(meta)}, batch_size: {batch_size}")
 
-        if isinstance(meta, dict):
+        # Handle different metadata structures
+        if isinstance(meta, list):
+            # Multiple metadata dicts (one per sample in batch)
+            print(f"DEBUG: _resolve_output_filenames - meta is list with {len(meta)} items")
+            for idx, meta_item in enumerate(meta):
+                if isinstance(meta_item, dict):
+                    filename = meta_item.get('filename_or_obj')
+                    if filename is not None:
+                        filenames.append(filename)
+                    else:
+                        print(f"DEBUG: _resolve_output_filenames - meta_item[{idx}] has no filename_or_obj")
+                else:
+                    print(f"DEBUG: _resolve_output_filenames - meta_item[{idx}] is not a dict: {type(meta_item)}")
+            # Update batch_size from metadata if we have a list
+            batch_size = max(batch_size, len(filenames))
+            print(f"DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from list")
+        elif isinstance(meta, dict):
+            # Single metadata dict
+            print(f"DEBUG: _resolve_output_filenames - meta is dict")
             meta_filenames = meta.get('filename_or_obj')
             if isinstance(meta_filenames, (list, tuple)):
-                filenames = list(meta_filenames)
+                filenames = [f for f in meta_filenames if f is not None]
             elif meta_filenames is not None:
                 filenames = [meta_filenames]
-        elif isinstance(meta, list):
-            for meta_item in meta:
-                if isinstance(meta_item, dict):
-                    filenames.append(meta_item.get('filename_or_obj'))
-            # Update batch_size from metadata if we have a list
-            batch_size = max(batch_size, len(filenames))
+            # Update batch_size from metadata
+            if len(filenames) > 0:
+                batch_size = max(batch_size, len(filenames))
+            print(f"DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from dict")
+        else:
+            # meta is None or another type, which can happen if metadata
+            # wasn't preserved through transforms; fall back to generated
+            # filenames based on batch_size below
+            print(f"DEBUG: _resolve_output_filenames - meta is None or unexpected type: {type(meta)}")
 
         resolved_names: List[str] = []
         for idx in range(batch_size):
             if idx < len(filenames) and filenames[idx]:
                 resolved_names.append(Path(str(filenames[idx])).stem)
             else:
+                # Generate fallback filename - this shouldn't happen if metadata is preserved correctly
                 resolved_names.append(f"volume_{self.global_step}_{idx}")
+
+        print(f"DEBUG: _resolve_output_filenames - returning {len(resolved_names)} resolved names: {resolved_names[:3]}...")
+
+        # Always return exactly batch_size filenames
+        if len(resolved_names) < batch_size:
+            print(f"WARNING: _resolve_output_filenames - Only {len(resolved_names)} filenames but batch_size is {batch_size}, padding with fallback names")
+            while len(resolved_names) < batch_size:
+                resolved_names.append(f"volume_{self.global_step}_{len(resolved_names)}")
 
         return resolved_names
 
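
Review note: the two metadata layouts handled above are easy to mix up, so here is each in miniature: a single dict whose 'filename_or_obj' holds the whole batch's paths (typical after default collation), versus one dict per sample. The paths and the `stems` helper are made up for illustration.

from pathlib import Path

meta_as_dict = {'filename_or_obj': ['/data/a.h5', '/data/b.h5']}
meta_as_list = [{'filename_or_obj': '/data/a.h5'}, {'filename_or_obj': '/data/b.h5'}]

def stems(meta):
    """Extract filename stems from either metadata layout."""
    if isinstance(meta, list):
        names = [m.get('filename_or_obj') for m in meta if isinstance(m, dict)]
    elif isinstance(meta, dict):
        raw = meta.get('filename_or_obj')
        names = list(raw) if isinstance(raw, (list, tuple)) else [raw]
    else:
        names = []
    return [Path(str(n)).stem for n in names if n is not None]

assert stems(meta_as_dict) == stems(meta_as_list) == ['a', 'b']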
@@ -799,8 +878,42 @@ def _write_outputs(
         if hasattr(self.cfg.inference, 'postprocessing'):
             output_transpose = getattr(self.cfg.inference.postprocessing, 'output_transpose', [])
 
+        # Determine actual batch size from predictions
+        # Handle both batched (B, ...) and unbatched (...) predictions
+        print(f"DEBUG: _write_outputs - predictions shape: {predictions.shape}, ndim: {predictions.ndim}, filenames count: {len(filenames)}")
+
+        if predictions.ndim >= 4:
+            # Has batch dimension: (B, C, D, H, W), (B, C, H, W), or (B, D, H, W)
+            actual_batch_size = predictions.shape[0]
+        elif predictions.ndim == 3:
+            # Could be batched 2D data (B, H, W) or a single 3D volume (D, H, W)
+            # If the first dimension matches the number of filenames, treat it as batched 2D data
+            if len(filenames) > 0 and predictions.shape[0] == len(filenames):
+                # Batched 2D data: (B, H, W) where B matches the number of filenames
+                actual_batch_size = predictions.shape[0]
+                print(f"DEBUG: _write_outputs - detected batched 2D data (B, H, W) with batch_size={actual_batch_size}")
+            else:
+                # Single 3D volume: (D, H, W) - treat as batch_size=1
+                actual_batch_size = 1
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+                print(f"DEBUG: _write_outputs - detected single 3D volume, added batch dimension")
+        elif predictions.ndim == 2:
+            # Single 2D image: (H, W) - treat as batch_size=1
+            actual_batch_size = 1
+            predictions = predictions[np.newaxis, ...]  # Add batch dimension
+        else:
+            # Unexpected shape, default to batch_size=1
+            actual_batch_size = 1
+            if predictions.ndim < 2:
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+
+        # Ensure we don't exceed the actual batch size
+        batch_size = min(actual_batch_size, len(filenames))
+        print(f"DEBUG: _write_outputs - actual_batch_size: {actual_batch_size}, batch_size: {batch_size}, will save {batch_size} predictions")
+
         # Save predictions
-        for idx, name in enumerate(filenames):
+        for idx in range(batch_size):
+            name = filenames[idx]
             prediction = predictions[idx]
 
             # Apply output transpose if specified
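
Review note: the ndim == 3 case is genuinely ambiguous, since a batch of 2D images (B, H, W) and a single 3D volume (D, H, W) have the same rank; the filename count serves as the tie-breaker. In isolation (`infer_batch` is a hypothetical helper, not the module's API):

import numpy as np

def infer_batch(predictions: np.ndarray, n_files: int):
    """Disambiguate (B, H, W) from (D, H, W) using the filename count."""
    if predictions.ndim == 3 and n_files > 0 and predictions.shape[0] == n_files:
        return predictions, predictions.shape[0]       # batched 2D: (B, H, W)
    if predictions.ndim == 3:
        return predictions[np.newaxis, ...], 1         # single 3D volume: (1, D, H, W)
    return predictions, predictions.shape[0]

arr = np.zeros((4, 128, 128))
assert infer_batch(arr, 4)[1] == 4  # four 2D slices with four filenames
assert infer_batch(arr, 1)[1] == 1  # one 3D volume of depth 4

The tie-breaker can still misfire when a lone 3D volume's depth happens to equal the filename count, so it is a heuristic rather than a guarantee.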
@@ -1303,16 +1416,28 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTP
         labels = batch.get('label')
         mask = batch.get('mask')  # Get test mask if available
 
+        # Get batch size from images
+        actual_batch_size = images.shape[0]
+        print(f"DEBUG: test_step - images shape: {images.shape}, batch_size: {actual_batch_size}")
+
         # Always use TTA (handles no-transform case) + sliding window
         # TTA preprocessing (activation, masking) is applied regardless of flip augmentation
         # Note: TTA always returns a simple tensor, not a dict (deep supervision not supported in test mode)
         predictions = self._predict_with_tta(images, mask=mask)
 
         # Convert predictions to numpy for saving/decoding
         predictions_np = predictions.detach().cpu().float().numpy()
+        print(f"DEBUG: test_step - predictions_np shape: {predictions_np.shape}")
 
         # Resolve filenames once for all saving operations
         filenames = self._resolve_output_filenames(batch)
+        print(f"DEBUG: test_step - filenames count: {len(filenames)}, filenames: {filenames[:5]}...")
+
+        # Ensure the filenames list matches the actual batch size;
+        # if there aren't enough filenames, generate default ones
+        while len(filenames) < actual_batch_size:
+            filenames.append(f"volume_{self.global_step}_{len(filenames)}")
+        print(f"DEBUG: test_step - after padding, filenames count: {len(filenames)}")
 
         # Check if we should save intermediate predictions (before decoding)
         save_intermediate = False
@@ -1324,10 +1449,13 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTP
             self._write_outputs(predictions_np, filenames, suffix="tta_prediction")
 
         # Apply decode mode (instance segmentation decoding)
+        print(f"DEBUG: test_step - before decode, predictions_np shape: {predictions_np.shape}")
         decoded_predictions = self._apply_decode_mode(predictions_np)
+        print(f"DEBUG: test_step - after decode, decoded_predictions shape: {decoded_predictions.shape}")
 
         # Apply postprocessing (scaling and dtype conversion) if configured
         postprocessed_predictions = self._apply_postprocessing(decoded_predictions)
+        print(f"DEBUG: test_step - after postprocess, postprocessed_predictions shape: {postprocessed_predictions.shape}")
 
         # Save final decoded and postprocessed predictions
         self._write_outputs(postprocessed_predictions, filenames, suffix="prediction")
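
Review note: taken together, these hunks establish one invariant through the test path: every stage consumes and returns a batched array, so no sample is lost between decoding, postprocessing, and writing. A toy end-to-end check with stand-in stages (the real `_apply_decode_mode` and `_apply_postprocessing` do more; shapes are assumed):

import numpy as np

def decode(x):                                   # stand-in: keep (B, 1, D, H, W)
    return np.stack([s[0] for s in x], axis=0)[:, np.newaxis]

def postprocess(x):                              # stand-in: binarize, keep batch
    return (x > 0.5).astype(np.uint8)

preds = np.random.rand(2, 3, 16, 64, 64)         # (B, C, D, H, W) after TTA
out = postprocess(decode(preds))
assert out.shape[0] == 2                         # both samples reach _write_outputs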