@@ -536,38 +536,66 @@ def _apply_postprocessing(self, data: np.ndarray) -> np.ndarray:
         from connectomics.decoding.postprocess import apply_binary_postprocessing
 
         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim >= 4 else 1
-
-        # Handle different input shapes
-        if data.ndim == 2:  # (H, W) -> (1, 1, H, W)
-            data = data[np.newaxis, np.newaxis, ...]
-        elif data.ndim == 3:  # (D, H, W) or (C, H, W) -> assume (D, H, W) and add batch dim
-            data = data[np.newaxis, ...]  # (1, D, H, W)
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f"DEBUG: _apply_postprocessing - input data shape: {data.shape}, ndim: {data.ndim}")
+        if data.ndim == 4:
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f"DEBUG: _apply_postprocessing - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f"DEBUG: _apply_postprocessing - detected 3D data, batch_size: {batch_size}")
+        elif data.ndim == 3:
+            # Single volume without batch dim: (D, H, W) or (C, H, W) - add batch dimension
+            batch_size = 1
+            data = data[np.newaxis, ...]  # (1, D, H, W) or (1, C, H, W)
+            print(f"DEBUG: _apply_postprocessing - single sample, added batch dimension")
+        elif data.ndim == 2:
+            # Single 2D image: (H, W) - add batch and channel dimensions
+            batch_size = 1
+            data = data[np.newaxis, np.newaxis, ...]  # (1, 1, H, W)
+            print(f"DEBUG: _apply_postprocessing - single 2D sample, added batch and channel dimensions")
+        else:
+            batch_size = 1
 
-        # Ensure we have at least 4D: (B, ...) where ... can be (D, H, W) or (C, D, H, W)
+        # Ensure we have at least 4D: (B, ...) where ... can be (C, H, W) for 2D or (C, D, H, W) for 3D
         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W) or (D, H, W)
-
-            # Extract foreground probability (handle both 3D and 4D)
-            if sample.ndim == 4:  # (C, D, H, W)
-                foreground_prob = sample[0]  # Use first channel
-            else:  # (D, H, W) - already single channel
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f"DEBUG: _apply_postprocessing - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
+
+            # Extract foreground probability (always use first channel if channel dimension exists)
+            if sample.ndim == 4:  # (C, D, H, W) - 3D with channel
+                foreground_prob = sample[0]  # Use first channel -> (D, H, W)
+            elif sample.ndim == 3:
+                # Could be (C, H, W) for 2D or (D, H, W) for 3D without channel
+                # If first dim is small (<=4), assume it's channel (2D), otherwise depth (3D)
+                if sample.shape[0] <= 4:
+                    foreground_prob = sample[0]  # (C, H, W) -> use first channel -> (H, W)
+                else:
+                    foreground_prob = sample  # (D, H, W) - already single channel
+            elif sample.ndim == 2:  # (H, W) - 2D single channel
+                foreground_prob = sample
+            else:
                 foreground_prob = sample
 
             # Apply binary postprocessing
             processed = apply_binary_postprocessing(foreground_prob, binary_config)
 
-            # Expand dims to maintain shape consistency
-            if sample.ndim == 4:
+            # Expand dims to maintain shape consistency with original sample structure
+            if sample.ndim == 4:  # (C, D, H, W) -> processed is (D, H, W)
                 processed = processed[np.newaxis, ...]  # (1, D, H, W)
-            else:
-                processed = processed  # Keep (D, H, W)
+            elif sample.ndim == 3 and sample.shape[0] <= 4:  # (C, H, W) -> processed is (H, W)
+                processed = processed[np.newaxis, ...]  # (1, H, W)
+            # else: processed is already correct shape (D, H, W) or (H, W)
 
             results.append(processed)
 
         # Stack results back into batch
+        print(f"DEBUG: _apply_postprocessing - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
         data = np.stack(results, axis=0)
+        print(f"DEBUG: _apply_postprocessing - after stacking, data shape: {data.shape}")
 
         # Step 2: Apply scaling if configured (support both new and legacy names)
         intensity_scale = getattr(postprocessing, 'intensity_scale', None)
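
A minimal standalone sketch of the shape handling this hunk introduces; the helper names are illustrative only (not part of the connectomics codebase), and the <=4 channel heuristic mirrors the patched code:

import numpy as np

def normalize_to_batch(data: np.ndarray) -> np.ndarray:
    """Add a leading batch dimension when one is missing."""
    if data.ndim in (4, 5):                  # (B, C, H, W) or (B, C, D, H, W)
        return data
    if data.ndim == 3:                       # (D, H, W) or (C, H, W) single sample
        return data[np.newaxis, ...]
    if data.ndim == 2:                       # (H, W) single 2D image
        return data[np.newaxis, np.newaxis, ...]
    return data[np.newaxis, ...]

def foreground_channel(sample: np.ndarray) -> np.ndarray:
    """Pick the foreground probability map from one sample."""
    if sample.ndim == 4:                     # (C, D, H, W)
        return sample[0]
    if sample.ndim == 3:
        # A small leading dim is assumed to be channels (2D data);
        # anything larger is treated as depth (single-channel 3D volume).
        return sample[0] if sample.shape[0] <= 4 else sample
    return sample                            # (H, W)

for vol in (np.random.rand(2, 1, 32, 64, 64),    # 3D batch
            np.random.rand(64, 64)):             # single 2D image
    batch = normalize_to_batch(vol)
    print(batch.shape, foreground_channel(batch[0]).shape)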
@@ -666,13 +694,29 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
         }
 
         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim == 5 else 1
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f"DEBUG: _apply_decode_mode - input data shape: {data.shape}, ndim: {data.ndim}")
         if data.ndim == 4:
-            data = data[np.newaxis, ...]  # Add batch dimension
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f"DEBUG: _apply_decode_mode - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f"DEBUG: _apply_decode_mode - detected 3D data, batch_size: {batch_size}")
+        else:
+            # Single sample: add batch dimension
+            batch_size = 1
+            print(f"DEBUG: _apply_decode_mode - single sample, adding batch dimension")
+            if data.ndim == 3:
+                data = data[np.newaxis, ...]  # (C, H, W) -> (1, C, H, W)
+            elif data.ndim == 2:
+                data = data[np.newaxis, np.newaxis, ...]  # (H, W) -> (1, 1, H, W)
 
         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W)
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f"DEBUG: _apply_decode_mode - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
 
             # Apply each decode mode sequentially
             for decode_cfg in decode_modes:
@@ -733,8 +777,10 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
             results.append(sample)
 
         # Stack results back into batch
-        decoded = np.stack(results, axis=0) if len(results) > 1 else results[0]
-
+        # Always preserve batch dimension, even for batch_size=1
+        print(f"DEBUG: _apply_decode_mode - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
+        decoded = np.stack(results, axis=0)
+        print(f"DEBUG: _apply_decode_mode - final decoded shape: {decoded.shape}")
         return decoded
 
     def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
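
For reference, a tiny sketch (with illustrative shapes) of why the final stack now always keeps the batch axis: the old conditional dropped it whenever the batch held a single sample, so downstream shape logic saw (C, D, H, W) instead of (1, C, D, H, W).

import numpy as np

results = [np.zeros((1, 32, 64, 64))]                 # a single decoded sample
old = np.stack(results, axis=0) if len(results) > 1 else results[0]
new = np.stack(results, axis=0)                        # always re-add the batch axis
print(old.shape, new.shape)                            # (1, 32, 64, 64) vs (1, 1, 32, 64, 64)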
@@ -757,26 +803,59 @@ def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
 
         meta = batch.get('image_meta_dict')
         filenames: List[Optional[str]] = []
+
+        print(f"DEBUG: _resolve_output_filenames - meta type: {type(meta)}, batch_size: {batch_size}")
 
-        if isinstance(meta, dict):
+        # Handle different metadata structures
+        if isinstance(meta, list):
+            # Multiple metadata dicts (one per sample in batch)
+            print(f"DEBUG: _resolve_output_filenames - meta is list with {len(meta)} items")
+            for idx, meta_item in enumerate(meta):
+                if isinstance(meta_item, dict):
+                    filename = meta_item.get('filename_or_obj')
+                    if filename is not None:
+                        filenames.append(filename)
+                    else:
+                        print(f"DEBUG: _resolve_output_filenames - meta_item[{idx}] has no filename_or_obj")
+                else:
+                    print(f"DEBUG: _resolve_output_filenames - meta_item[{idx}] is not a dict: {type(meta_item)}")
+            # Update batch_size from metadata if we have a list
+            batch_size = max(batch_size, len(filenames))
+            print(f"DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from list")
+        elif isinstance(meta, dict):
+            # Single metadata dict
+            print(f"DEBUG: _resolve_output_filenames - meta is dict")
             meta_filenames = meta.get('filename_or_obj')
             if isinstance(meta_filenames, (list, tuple)):
-                filenames = list(meta_filenames)
+                filenames = [f for f in meta_filenames if f is not None]
             elif meta_filenames is not None:
                 filenames = [meta_filenames]
-        elif isinstance(meta, list):
-            for meta_item in meta:
-                if isinstance(meta_item, dict):
-                    filenames.append(meta_item.get('filename_or_obj'))
-            # Update batch_size from metadata if we have a list
-            batch_size = max(batch_size, len(filenames))
+            # Update batch_size from metadata
+            if len(filenames) > 0:
+                batch_size = max(batch_size, len(filenames))
+            print(f"DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from dict")
+        else:
+            # Handle case where meta might be None or other types
+            # This can happen if metadata wasn't preserved through transforms
+            # We'll use fallback filenames based on batch_size
+            print(f"DEBUG: _resolve_output_filenames - meta is None or unexpected type: {type(meta)}")
+            pass
 
         resolved_names: List[str] = []
         for idx in range(batch_size):
             if idx < len(filenames) and filenames[idx]:
                 resolved_names.append(Path(str(filenames[idx])).stem)
             else:
+                # Generate fallback filename - this shouldn't happen if metadata is preserved correctly
                 resolved_names.append(f"volume_{self.global_step}_{idx}")
+
+        print(f"DEBUG: _resolve_output_filenames - returning {len(resolved_names)} resolved names: {resolved_names[:3]}...")
+
+        # Always return exactly batch_size filenames
+        if len(resolved_names) < batch_size:
+            print(f"WARNING: _resolve_output_filenames - Only {len(resolved_names)} filenames but batch_size is {batch_size}, padding with fallback names")
+            while len(resolved_names) < batch_size:
+                resolved_names.append(f"volume_{self.global_step}_{len(resolved_names)}")
 
         return resolved_names
 
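
A condensed, hypothetical sketch of the filename-resolution fallback above; the meta layouts mimic MONAI-style 'image_meta_dict' entries, and resolve_names itself is a stand-in, not part of the trainer:

from pathlib import Path
from typing import Any, List

def resolve_names(meta: Any, batch_size: int, step: int = 0) -> List[str]:
    names: List[str] = []
    if isinstance(meta, list):                        # one metadata dict per sample
        names = [m.get('filename_or_obj') for m in meta if isinstance(m, dict)]
    elif isinstance(meta, dict):                      # single collated dict
        fn = meta.get('filename_or_obj')
        names = list(fn) if isinstance(fn, (list, tuple)) else ([fn] if fn else [])
    names = [Path(str(n)).stem for n in names if n]
    while len(names) < batch_size:                    # pad with fallback names
        names.append(f"volume_{step}_{len(names)}")
    return names

print(resolve_names([{'filename_or_obj': '/data/im_0.h5'}], batch_size=2))
# ['im_0', 'volume_0_1']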
@@ -814,8 +893,42 @@ def _write_outputs(
         if hasattr(self.cfg.inference, 'postprocessing'):
             output_transpose = getattr(self.cfg.inference.postprocessing, 'output_transpose', [])
 
+        # Determine actual batch size from predictions
+        # Handle both batched (B, ...) and unbatched (...) predictions
+        print(f"DEBUG: _write_outputs - predictions shape: {predictions.shape}, ndim: {predictions.ndim}, filenames count: {len(filenames)}")
+
+        if predictions.ndim >= 4:
+            # Has batch dimension: (B, C, D, H, W) or (B, C, H, W) or (B, D, H, W)
+            actual_batch_size = predictions.shape[0]
+        elif predictions.ndim == 3:
+            # Could be batched 2D data (B, H, W) or single 3D volume (D, H, W)
+            # Check if first dimension matches number of filenames -> it's batched 2D data
+            if len(filenames) > 0 and predictions.shape[0] == len(filenames):
+                # Batched 2D data: (B, H, W) where B matches number of filenames
+                actual_batch_size = predictions.shape[0]
+                print(f"DEBUG: _write_outputs - detected batched 2D data (B, H, W) with batch_size={actual_batch_size}")
+            else:
+                # Single 3D volume: (D, H, W) - treat as batch_size=1
+                actual_batch_size = 1
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+                print(f"DEBUG: _write_outputs - detected single 3D volume, added batch dimension")
+        elif predictions.ndim == 2:
+            # Single 2D image: (H, W) - treat as batch_size=1
+            actual_batch_size = 1
+            predictions = predictions[np.newaxis, ...]  # Add batch dimension
+        else:
+            # Unexpected shape, default to batch_size=1
+            actual_batch_size = 1
+            if predictions.ndim < 2:
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+
+        # Ensure we don't exceed the actual batch size
+        batch_size = min(actual_batch_size, len(filenames))
+        print(f"DEBUG: _write_outputs - actual_batch_size: {actual_batch_size}, batch_size: {batch_size}, will save {batch_size} predictions")
+
         # Save predictions
-        for idx, name in enumerate(filenames):
+        for idx in range(batch_size):
+            name = filenames[idx]
             prediction = predictions[idx]
 
             # Apply output transpose if specified
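
The batched-2D versus single-3D disambiguation used when writing outputs can be isolated as below; infer_batch is a made-up name used only to make the heuristic concrete:

import numpy as np
from typing import List, Tuple

def infer_batch(pred: np.ndarray, filenames: List[str]) -> Tuple[np.ndarray, int]:
    """Return (prediction with a batch axis, batch size to save)."""
    if pred.ndim >= 4:                                        # already batched
        return pred, pred.shape[0]
    if pred.ndim == 3 and filenames and pred.shape[0] == len(filenames):
        return pred, pred.shape[0]                            # batched 2D: (B, H, W)
    return pred[np.newaxis, ...], 1                           # single volume or image

print(infer_batch(np.zeros((4, 128, 128)), ["a", "b", "c", "d"])[1])   # 4 -> four 2D predictions
print(infer_batch(np.zeros((40, 128, 128)), ["a"])[0].shape)           # (1, 40, 128, 128) -> one 3D volume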
@@ -1327,16 +1440,28 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTP
         labels = batch.get('label')
         mask = batch.get('mask')  # Get test mask if available
 
+        # Get batch size from images
+        actual_batch_size = images.shape[0]
+        print(f"DEBUG: test_step - images shape: {images.shape}, batch_size: {actual_batch_size}")
+
         # Always use TTA (handles no-transform case) + sliding window
         # TTA preprocessing (activation, masking) is applied regardless of flip augmentation
         # Note: TTA always returns a simple tensor, not a dict (deep supervision not supported in test mode)
         predictions = self._predict_with_tta(images, mask=mask)
 
         # Convert predictions to numpy for saving/decoding
         predictions_np = predictions.detach().cpu().float().numpy()
+        print(f"DEBUG: test_step - predictions_np shape: {predictions_np.shape}")
 
         # Resolve filenames once for all saving operations
         filenames = self._resolve_output_filenames(batch)
+        print(f"DEBUG: test_step - filenames count: {len(filenames)}, filenames: {filenames[:5]}...")
+
+        # Ensure filenames list matches actual batch size
+        # If we don't have enough filenames, generate default ones
+        while len(filenames) < actual_batch_size:
+            filenames.append(f"volume_{self.global_step}_{len(filenames)}")
+        print(f"DEBUG: test_step - after padding, filenames count: {len(filenames)}")
 
         # Check if we should save intermediate predictions (before decoding)
         save_intermediate = False
@@ -1348,10 +1473,13 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTP
             self._write_outputs(predictions_np, filenames, suffix="tta_prediction")
 
         # Apply decode mode (instance segmentation decoding)
+        print(f"DEBUG: test_step - before decode, predictions_np shape: {predictions_np.shape}")
         decoded_predictions = self._apply_decode_mode(predictions_np)
+        print(f"DEBUG: test_step - after decode, decoded_predictions shape: {decoded_predictions.shape}")
 
         # Apply postprocessing (scaling and dtype conversion) if configured
         postprocessed_predictions = self._apply_postprocessing(decoded_predictions)
+        print(f"DEBUG: test_step - after postprocess, postprocessed_predictions shape: {postprocessed_predictions.shape}")
 
         # Save final decoded and postprocessed predictions
         self._write_outputs(postprocessed_predictions, filenames, suffix="prediction")
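
Condensed sketch of the resulting test_step flow (predict, pad filenames, optional intermediate save, decode, postprocess, write); every helper here is a stand-in, not the module's real API:

import numpy as np

def test_step_sketch(preds, names, decode, postprocess, write, step=0,
                     save_intermediate=False):
    # Pad filenames so every prediction in the batch has a name.
    while len(names) < preds.shape[0]:
        names.append(f"volume_{step}_{len(names)}")
    if save_intermediate:
        write(preds, names, "tta_prediction")        # raw TTA output
    out = postprocess(decode(preds))                 # instance decoding, then scaling
    write(out, names, "prediction")
    return out

identity = lambda x: x
writer = lambda p, n, s: print(f"saving {p.shape[0]} arrays with suffix '{s}'")
test_step_sketch(np.zeros((2, 1, 64, 64)), ["sample_a"], identity, identity, writer)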