
Commit b318579

Fix batch processing for inference and add auto-matching test images
- Fix IndexError in batch processing by preserving batch dimensions in _apply_decode_mode and _apply_postprocessing
- Fix glob pattern handling in read_volume for TIFF files
- Add automatic test_image matching in neuroglancer visualization when prediction files are provided
- Improve batch size detection in _write_outputs to handle 2D and 3D data correctly
1 parent e528dc8 commit b318579

File tree

4 files changed: +246 -38 lines changed


connectomics/data/io/io.py

Lines changed: 25 additions & 1 deletion
@@ -281,7 +281,31 @@ def read_volume(
     if image_suffix in ["h5", "hdf5"]:
         data = read_hdf5(filename, dataset)
     elif "tif" in image_suffix:
-        data = imageio.volread(filename).squeeze()
+        # Check if the filename contains glob patterns
+        if "*" in filename or "?" in filename:
+            # Expand the glob pattern to get the matching files
+            file_list = sorted(glob.glob(filename))
+            if len(file_list) == 0:
+                raise FileNotFoundError(f"No TIFF files found matching pattern: {filename}")
+
+            # Read each file and stack along the depth dimension
+            volumes = []
+            for filepath in file_list:
+                vol = imageio.volread(filepath).squeeze()
+                # imageio.volread can return a multi-page TIFF as (D, H, W) or a single page as (H, W)
+                # Ensure all volumes are at least 3D (D, H, W)
+                if vol.ndim == 2:
+                    vol = vol[np.newaxis, ...]  # Add depth dimension: (H, W) -> (1, H, W)
+                # vol.ndim == 3 means (D, H, W), which is what we want
+                volumes.append(vol)
+
+            # Stack all volumes along the depth (first) dimension
+            # Each volume is (D_i, H, W); the result is (sum(D_i), H, W)
+            data = np.concatenate(volumes, axis=0)
+        else:
+            # Single file or multi-page TIFF
+            data = imageio.volread(filename).squeeze()
+
     if data.ndim == 4:
         # Convert (D, C, H, W) to (C, D, H, W) order
         data = data.transpose(1, 0, 2, 3)
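
For reference, a minimal standalone sketch of the new glob branch, assuming only numpy and imageio; the helper name read_tiff_stack and the pattern "slices/*.tif" are illustrative, not part of the repository:

    import glob
    import imageio
    import numpy as np

    def read_tiff_stack(pattern: str) -> np.ndarray:
        """Read every TIFF matching a glob pattern into one (D, H, W) volume."""
        file_list = sorted(glob.glob(pattern))
        if not file_list:
            raise FileNotFoundError(f"No TIFF files found matching pattern: {pattern}")
        volumes = []
        for filepath in file_list:
            vol = imageio.volread(filepath).squeeze()
            if vol.ndim == 2:  # single-page TIFF: (H, W) -> (1, H, W)
                vol = vol[np.newaxis, ...]
            volumes.append(vol)  # each entry is (D_i, H, W)
        return np.concatenate(volumes, axis=0)  # (sum(D_i), H, W)

    # volume = read_tiff_stack("slices/*.tif")  # e.g. 30 single-page files -> (30, H, W)

Note that np.concatenate requires every file to share the same (H, W), and sorted() orders slices lexicographically, so zero-padded filenames (slice_001.tif, slice_002.tif, ...) are the safe convention.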

connectomics/lightning/lit_model.py

Lines changed: 160 additions & 32 deletions
@@ -521,38 +521,66 @@ def _apply_postprocessing(self, data: np.ndarray) -> np.ndarray:
         from connectomics.decoding.postprocess import apply_binary_postprocessing

         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim >= 4 else 1
-
-        # Handle different input shapes
-        if data.ndim == 2:  # (H, W) -> (1, 1, H, W)
-            data = data[np.newaxis, np.newaxis, ...]
-        elif data.ndim == 3:  # (D, H, W) or (C, H, W) -> assume (D, H, W) and add batch dim
-            data = data[np.newaxis, ...]  # (1, D, H, W)
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f"  DEBUG: _apply_postprocessing - input data shape: {data.shape}, ndim: {data.ndim}")
+        if data.ndim == 4:
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f"  DEBUG: _apply_postprocessing - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f"  DEBUG: _apply_postprocessing - detected 3D data, batch_size: {batch_size}")
+        elif data.ndim == 3:
+            # Single sample: (D, H, W) or (C, H, W) - add batch dimension
+            batch_size = 1
+            data = data[np.newaxis, ...]  # (1, D, H, W) or (1, C, H, W)
+            print(f"  DEBUG: _apply_postprocessing - single 3D sample, added batch dimension")
+        elif data.ndim == 2:
+            # Single 2D image: (H, W) - add batch and channel dimensions
+            batch_size = 1
+            data = data[np.newaxis, np.newaxis, ...]  # (1, 1, H, W)
+            print(f"  DEBUG: _apply_postprocessing - single 2D sample, added batch and channel dimensions")
+        else:
+            batch_size = 1

-        # Ensure we have at least 4D: (B, ...) where ... can be (D, H, W) or (C, D, H, W)
+        # Ensure we have at least 4D: (B, ...) where ... can be (C, H, W) for 2D or (C, D, H, W) for 3D
         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W) or (D, H, W)
-
-            # Extract foreground probability (handle both 3D and 4D)
-            if sample.ndim == 4:  # (C, D, H, W)
-                foreground_prob = sample[0]  # Use first channel
-            else:  # (D, H, W) - already single channel
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f"  DEBUG: _apply_postprocessing - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
+
+            # Extract foreground probability (always use first channel if a channel dimension exists)
+            if sample.ndim == 4:  # (C, D, H, W) - 3D with channel
+                foreground_prob = sample[0]  # Use first channel -> (D, H, W)
+            elif sample.ndim == 3:
+                # Could be (C, H, W) for 2D or (D, H, W) for 3D without channel
+                # If the first dim is small (<= 4), assume it is channel (2D), otherwise depth (3D)
+                if sample.shape[0] <= 4:
+                    foreground_prob = sample[0]  # (C, H, W) -> use first channel -> (H, W)
+                else:
+                    foreground_prob = sample  # (D, H, W) - already single channel
+            elif sample.ndim == 2:  # (H, W) - 2D single channel
+                foreground_prob = sample
+            else:
                 foreground_prob = sample

             # Apply binary postprocessing
             processed = apply_binary_postprocessing(foreground_prob, binary_config)

-            # Expand dims to maintain shape consistency
-            if sample.ndim == 4:
+            # Expand dims to maintain shape consistency with the original sample structure
+            if sample.ndim == 4:  # (C, D, H, W) -> processed is (D, H, W)
                 processed = processed[np.newaxis, ...]  # (1, D, H, W)
-            else:
-                processed = processed  # Keep (D, H, W)
+            elif sample.ndim == 3 and sample.shape[0] <= 4:  # (C, H, W) -> processed is (H, W)
+                processed = processed[np.newaxis, ...]  # (1, H, W)
+            # else: processed is already the correct shape, (D, H, W) or (H, W)

             results.append(processed)

         # Stack results back into batch
+        print(f"  DEBUG: _apply_postprocessing - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
         data = np.stack(results, axis=0)
+        print(f"  DEBUG: _apply_postprocessing - after stacking, data shape: {data.shape}")

         # Step 2: Apply scaling if configured (support both new and legacy names)
         intensity_scale = getattr(postprocessing, 'intensity_scale', None)
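
The shape[0] <= 4 test above is a heuristic for telling (C, H, W) apart from (D, H, W). A minimal sketch of just that logic (the function name first_channel and all sizes are illustrative):

    import numpy as np

    def first_channel(sample: np.ndarray) -> np.ndarray:
        """Standalone sketch of the channel-selection heuristic above."""
        if sample.ndim == 4:  # (C, D, H, W) -> (D, H, W)
            return sample[0]
        if sample.ndim == 3:
            # A small leading axis (<= 4) is assumed to be channels (C, H, W);
            # a larger one is assumed to be depth (D, H, W).
            return sample[0] if sample.shape[0] <= 4 else sample
        return sample  # (H, W) - already single channel

    assert first_channel(np.zeros((2, 64, 64))).shape == (64, 64)       # (C, H, W)
    assert first_channel(np.zeros((32, 64, 64))).shape == (32, 64, 64)  # (D, H, W)
    # Edge case: a genuine 4-slice volume (4, H, W) would be misread as channels.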
@@ -651,13 +679,29 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
         }

         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim == 5 else 1
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f"  DEBUG: _apply_decode_mode - input data shape: {data.shape}, ndim: {data.ndim}")
         if data.ndim == 4:
-            data = data[np.newaxis, ...]  # Add batch dimension
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f"  DEBUG: _apply_decode_mode - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f"  DEBUG: _apply_decode_mode - detected 3D data, batch_size: {batch_size}")
+        else:
+            # Single sample: add batch dimension
+            batch_size = 1
+            print(f"  DEBUG: _apply_decode_mode - single sample, adding batch dimension")
+            if data.ndim == 3:
+                data = data[np.newaxis, ...]  # (C, H, W) -> (1, C, H, W)
+            elif data.ndim == 2:
+                data = data[np.newaxis, np.newaxis, ...]  # (H, W) -> (1, 1, H, W)

         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W)
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f"  DEBUG: _apply_decode_mode - processing batch_idx {batch_idx}, sample shape: {sample.shape}")

             # Apply each decode mode sequentially
             for decode_cfg in decode_modes:
@@ -718,8 +762,10 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
             results.append(sample)

         # Stack results back into batch
-        decoded = np.stack(results, axis=0) if len(results) > 1 else results[0]
-
+        # Always preserve the batch dimension, even for batch_size=1
+        print(f"  DEBUG: _apply_decode_mode - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
+        decoded = np.stack(results, axis=0)
+        print(f"  DEBUG: _apply_decode_mode - final decoded shape: {decoded.shape}")
         return decoded

     def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
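
This one-line change from a conditional stack to an unconditional np.stack is what fixes the IndexError mentioned in the commit message: with batch_size=1 the old expression returned results[0] without a batch axis, so downstream code indexing predictions[idx] sliced into channels instead of samples. A tiny repro sketch (shapes illustrative):

    import numpy as np

    results = [np.zeros((1, 8, 64, 64))]  # one decoded sample, (C, D, H, W)

    old = np.stack(results, axis=0) if len(results) > 1 else results[0]
    new = np.stack(results, axis=0)

    print(old.shape)  # (1, 8, 64, 64)    - batch axis lost; old[0] is a channel slice
    print(new.shape)  # (1, 1, 8, 64, 64) - batch axis kept; new[0] is the whole sample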
@@ -742,26 +788,59 @@ def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:

         meta = batch.get('image_meta_dict')
         filenames: List[Optional[str]] = []
+
+        print(f"  DEBUG: _resolve_output_filenames - meta type: {type(meta)}, batch_size: {batch_size}")

-        if isinstance(meta, dict):
+        # Handle different metadata structures
+        if isinstance(meta, list):
+            # Multiple metadata dicts (one per sample in batch)
+            print(f"  DEBUG: _resolve_output_filenames - meta is list with {len(meta)} items")
+            for idx, meta_item in enumerate(meta):
+                if isinstance(meta_item, dict):
+                    filename = meta_item.get('filename_or_obj')
+                    if filename is not None:
+                        filenames.append(filename)
+                    else:
+                        print(f"  DEBUG: _resolve_output_filenames - meta_item[{idx}] has no filename_or_obj")
+                else:
+                    print(f"  DEBUG: _resolve_output_filenames - meta_item[{idx}] is not a dict: {type(meta_item)}")
+            # Update batch_size from metadata if we have a list
+            batch_size = max(batch_size, len(filenames))
+            print(f"  DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from list")
+        elif isinstance(meta, dict):
+            # Single metadata dict
+            print(f"  DEBUG: _resolve_output_filenames - meta is dict")
             meta_filenames = meta.get('filename_or_obj')
             if isinstance(meta_filenames, (list, tuple)):
-                filenames = list(meta_filenames)
+                filenames = [f for f in meta_filenames if f is not None]
             elif meta_filenames is not None:
                 filenames = [meta_filenames]
-        elif isinstance(meta, list):
-            for meta_item in meta:
-                if isinstance(meta_item, dict):
-                    filenames.append(meta_item.get('filename_or_obj'))
-            # Update batch_size from metadata if we have a list
-            batch_size = max(batch_size, len(filenames))
+            # Update batch_size from metadata
+            if len(filenames) > 0:
+                batch_size = max(batch_size, len(filenames))
+            print(f"  DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from dict")
+        else:
+            # Handle the case where meta might be None or another type
+            # This can happen if metadata wasn't preserved through transforms
+            # We'll use fallback filenames based on batch_size
+            print(f"  DEBUG: _resolve_output_filenames - meta is None or unexpected type: {type(meta)}")
+            pass

         resolved_names: List[str] = []
         for idx in range(batch_size):
             if idx < len(filenames) and filenames[idx]:
                 resolved_names.append(Path(str(filenames[idx])).stem)
             else:
+                # Generate a fallback filename - this shouldn't happen if metadata is preserved correctly
                 resolved_names.append(f"volume_{self.global_step}_{idx}")
+
+        print(f"  DEBUG: _resolve_output_filenames - returning {len(resolved_names)} resolved names: {resolved_names[:3]}...")
+
+        # Always return exactly batch_size filenames
+        if len(resolved_names) < batch_size:
+            print(f"  WARNING: _resolve_output_filenames - Only {len(resolved_names)} filenames but batch_size is {batch_size}, padding with fallback names")
+            while len(resolved_names) < batch_size:
+                resolved_names.append(f"volume_{self.global_step}_{len(resolved_names)}")

         return resolved_names
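
The two branches above cover the two layouts a MONAI-style loader can hand back: a decollated list of per-sample meta dicts, or a single collated dict whose filename_or_obj holds a list. A standalone sketch of the same resolution logic with illustrative inputs (resolve_names is a hypothetical helper, not the repository's API):

    from pathlib import Path

    def resolve_names(meta, batch_size, step=0):
        filenames = []
        if isinstance(meta, list):  # decollated: one dict per sample
            filenames = [m.get('filename_or_obj') for m in meta if isinstance(m, dict)]
            filenames = [f for f in filenames if f is not None]
        elif isinstance(meta, dict):  # collated: list under a single key
            f = meta.get('filename_or_obj')
            if isinstance(f, (list, tuple)):
                filenames = [x for x in f if x is not None]
            elif f is not None:
                filenames = [f]
        batch_size = max(batch_size, len(filenames))
        return [Path(str(filenames[i])).stem if i < len(filenames)
                else f"volume_{step}_{i}" for i in range(batch_size)]

    print(resolve_names({'filename_or_obj': ['a/x.tif', 'a/y.tif']}, 2))  # ['x', 'y']
    print(resolve_names(None, 2))  # ['volume_0_0', 'volume_0_1']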
@@ -799,8 +878,42 @@ def _write_outputs(
         if hasattr(self.cfg.inference, 'postprocessing'):
             output_transpose = getattr(self.cfg.inference.postprocessing, 'output_transpose', [])

+        # Determine the actual batch size from the predictions
+        # Handle both batched (B, ...) and unbatched (...) predictions
+        print(f"  DEBUG: _write_outputs - predictions shape: {predictions.shape}, ndim: {predictions.ndim}, filenames count: {len(filenames)}")
+
+        if predictions.ndim >= 4:
+            # Has a batch dimension: (B, C, D, H, W), (B, C, H, W), or (B, D, H, W)
+            actual_batch_size = predictions.shape[0]
+        elif predictions.ndim == 3:
+            # Could be batched 2D data (B, H, W) or a single 3D volume (D, H, W)
+            # If the first dimension matches the number of filenames -> it is batched 2D data
+            if len(filenames) > 0 and predictions.shape[0] == len(filenames):
+                # Batched 2D data: (B, H, W) where B matches the number of filenames
+                actual_batch_size = predictions.shape[0]
+                print(f"  DEBUG: _write_outputs - detected batched 2D data (B, H, W) with batch_size={actual_batch_size}")
+            else:
+                # Single 3D volume: (D, H, W) - treat as batch_size=1
+                actual_batch_size = 1
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+                print(f"  DEBUG: _write_outputs - detected single 3D volume, added batch dimension")
+        elif predictions.ndim == 2:
+            # Single 2D image: (H, W) - treat as batch_size=1
+            actual_batch_size = 1
+            predictions = predictions[np.newaxis, ...]  # Add batch dimension
+        else:
+            # Unexpected shape, default to batch_size=1
+            actual_batch_size = 1
+            if predictions.ndim < 2:
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+
+        # Ensure we don't exceed the actual batch size
+        batch_size = min(actual_batch_size, len(filenames))
+        print(f"  DEBUG: _write_outputs - actual_batch_size: {actual_batch_size}, batch_size: {batch_size}, will save {batch_size} predictions")
+
         # Save predictions
-        for idx, name in enumerate(filenames):
+        for idx in range(batch_size):
+            name = filenames[idx]
             prediction = predictions[idx]

             # Apply output transpose if specified
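
For a 3D array the batch question is inherently ambiguous: (B, H, W) and (D, H, W) look identical. The code above resolves it by comparing the leading axis against the filename count. A sketch of that heuristic in isolation (detect_batch is a hypothetical name):

    import numpy as np

    def detect_batch(predictions: np.ndarray, n_files: int):
        if predictions.ndim >= 4:  # explicit batch axis: (B, C, ...)
            return predictions, predictions.shape[0]
        if predictions.ndim == 3:
            if n_files > 0 and predictions.shape[0] == n_files:
                return predictions, predictions.shape[0]  # batched 2D: (B, H, W)
            return predictions[np.newaxis, ...], 1  # single 3D volume: (D, H, W)
        return predictions[np.newaxis, ...], 1  # single 2D image: (H, W)

    _, b = detect_batch(np.zeros((4, 64, 64)), n_files=4)
    print(b)  # 4 - treated as four 2D predictions
    _, b = detect_batch(np.zeros((4, 64, 64)), n_files=1)
    print(b)  # 1 - treated as one (4, 64, 64) volume

The heuristic can still misfire when a single volume's depth happens to equal the filename count, which is why the upstream fixes keep an explicit batch axis all the way through the pipeline.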
@@ -1303,16 +1416,28 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
         labels = batch.get('label')
         mask = batch.get('mask')  # Get test mask if available

+        # Get batch size from images
+        actual_batch_size = images.shape[0]
+        print(f"  DEBUG: test_step - images shape: {images.shape}, batch_size: {actual_batch_size}")
+
         # Always use TTA (handles no-transform case) + sliding window
         # TTA preprocessing (activation, masking) is applied regardless of flip augmentation
         # Note: TTA always returns a simple tensor, not a dict (deep supervision not supported in test mode)
         predictions = self._predict_with_tta(images, mask=mask)

         # Convert predictions to numpy for saving/decoding
         predictions_np = predictions.detach().cpu().float().numpy()
+        print(f"  DEBUG: test_step - predictions_np shape: {predictions_np.shape}")

         # Resolve filenames once for all saving operations
         filenames = self._resolve_output_filenames(batch)
+        print(f"  DEBUG: test_step - filenames count: {len(filenames)}, filenames: {filenames[:5]}...")
+
+        # Ensure the filenames list matches the actual batch size
+        # If we don't have enough filenames, generate default ones
+        while len(filenames) < actual_batch_size:
+            filenames.append(f"volume_{self.global_step}_{len(filenames)}")
+        print(f"  DEBUG: test_step - after padding, filenames count: {len(filenames)}")

         # Check if we should save intermediate predictions (before decoding)
         save_intermediate = False
@@ -1324,10 +1449,13 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
             self._write_outputs(predictions_np, filenames, suffix="tta_prediction")

         # Apply decode mode (instance segmentation decoding)
+        print(f"  DEBUG: test_step - before decode, predictions_np shape: {predictions_np.shape}")
         decoded_predictions = self._apply_decode_mode(predictions_np)
+        print(f"  DEBUG: test_step - after decode, decoded_predictions shape: {decoded_predictions.shape}")

         # Apply postprocessing (scaling and dtype conversion) if configured
         postprocessed_predictions = self._apply_postprocessing(decoded_predictions)
+        print(f"  DEBUG: test_step - after postprocess, postprocessed_predictions shape: {postprocessed_predictions.shape}")

         # Save final decoded and postprocessed predictions
         self._write_outputs(postprocessed_predictions, filenames, suffix="prediction")
