Commit 03701d9

Merge pull request #163 from BoyuShen2004/master
Fix batch processing for inference and add auto-matching test images
2 parents 5080376 + b318579

File tree

4 files changed: +246 -38 lines changed

connectomics/data/io/io.py

Lines changed: 25 additions & 1 deletion
@@ -281,7 +281,31 @@ def read_volume(
     if image_suffix in ["h5", "hdf5"]:
         data = read_hdf5(filename, dataset)
     elif "tif" in image_suffix:
-        data = imageio.volread(filename).squeeze()
+        # Check if filename contains glob patterns
+        if "*" in filename or "?" in filename:
+            # Expand glob pattern to get matching files
+            file_list = sorted(glob.glob(filename))
+            if len(file_list) == 0:
+                raise FileNotFoundError(f"No TIFF files found matching pattern: {filename}")
+
+            # Read each file and stack along depth dimension
+            volumes = []
+            for filepath in file_list:
+                vol = imageio.volread(filepath).squeeze()
+                # imageio.volread can return multi-page TIFF as (D, H, W) or single page as (H, W)
+                # Ensure all volumes have at least 3D (D, H, W)
+                if vol.ndim == 2:
+                    vol = vol[np.newaxis, ...]  # Add depth dimension: (H, W) -> (1, H, W)
+                # vol.ndim == 3 means (D, H, W), which is what we want
+                volumes.append(vol)
+
+            # Stack all volumes along depth dimension
+            # Each volume is (D_i, H, W), result will be (sum(D_i), H, W)
+            data = np.concatenate(volumes, axis=0)  # Stack along depth (first dimension)
+        else:
+            # Single file or multi-page TIFF
+            data = imageio.volread(filename).squeeze()
+
     if data.ndim == 4:
         # Convert (D, C, H, W) to (C, D, H, W) order
         data = data.transpose(1, 0, 2, 3)
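For context, a standalone sketch of what the new glob branch does, outside the repository (file names are hypothetical; assumes imageio with TIFF support and numpy, as the patched code does):

# Sketch only: mirrors the glob branch above; paths are hypothetical.
import glob
import tempfile
from pathlib import Path

import imageio
import numpy as np

tmp = Path(tempfile.mkdtemp())
for i in range(3):
    # Write three single-page (H, W) TIFF slices
    imageio.imwrite(tmp / f"slice_{i:03d}.tif", np.full((4, 4), i, dtype=np.uint8))

file_list = sorted(glob.glob(str(tmp / "slice_*.tif")))
volumes = []
for filepath in file_list:
    vol = imageio.volread(filepath).squeeze()
    if vol.ndim == 2:            # single page: (H, W) -> (1, H, W)
        vol = vol[np.newaxis, ...]
    volumes.append(vol)

data = np.concatenate(volumes, axis=0)
print(data.shape)                # (3, 4, 4): slices stacked along depth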

connectomics/lightning/lit_model.py

Lines changed: 160 additions & 32 deletions
@@ -536,38 +536,66 @@ def _apply_postprocessing(self, data: np.ndarray) -> np.ndarray:
         from connectomics.decoding.postprocess import apply_binary_postprocessing
 
         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim >= 4 else 1
-
-        # Handle different input shapes
-        if data.ndim == 2:  # (H, W) -> (1, 1, H, W)
-            data = data[np.newaxis, np.newaxis, ...]
-        elif data.ndim == 3:  # (D, H, W) or (C, H, W) -> assume (D, H, W) and add batch dim
-            data = data[np.newaxis, ...]  # (1, D, H, W)
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f" DEBUG: _apply_postprocessing - input data shape: {data.shape}, ndim: {data.ndim}")
+        if data.ndim == 4:
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f" DEBUG: _apply_postprocessing - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f" DEBUG: _apply_postprocessing - detected 3D data, batch_size: {batch_size}")
+        elif data.ndim == 3:
+            # Single 3D volume: (C, D, H, W) or (D, H, W) - add batch dimension
+            batch_size = 1
+            data = data[np.newaxis, ...]  # (1, C, D, H, W) or (1, D, H, W)
+            print(f" DEBUG: _apply_postprocessing - single 3D sample, added batch dimension")
+        elif data.ndim == 2:
+            # Single 2D image: (H, W) - add batch and channel dimensions
+            batch_size = 1
+            data = data[np.newaxis, np.newaxis, ...]  # (1, 1, H, W)
+            print(f" DEBUG: _apply_postprocessing - single 2D sample, added batch and channel dimensions")
+        else:
+            batch_size = 1
 
-        # Ensure we have at least 4D: (B, ...) where ... can be (D, H, W) or (C, D, H, W)
+        # Ensure we have at least 4D: (B, ...) where ... can be (C, H, W) for 2D or (C, D, H, W) for 3D
         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W) or (D, H, W)
-
-            # Extract foreground probability (handle both 3D and 4D)
-            if sample.ndim == 4:  # (C, D, H, W)
-                foreground_prob = sample[0]  # Use first channel
-            else:  # (D, H, W) - already single channel
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f" DEBUG: _apply_postprocessing - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
+
+            # Extract foreground probability (always use first channel if channel dimension exists)
+            if sample.ndim == 4:  # (C, D, H, W) - 3D with channel
+                foreground_prob = sample[0]  # Use first channel -> (D, H, W)
+            elif sample.ndim == 3:
+                # Could be (C, H, W) for 2D or (D, H, W) for 3D without channel
+                # If first dim is small (<=4), assume it's channel (2D), otherwise depth (3D)
+                if sample.shape[0] <= 4:
+                    foreground_prob = sample[0]  # (C, H, W) -> use first channel -> (H, W)
+                else:
+                    foreground_prob = sample  # (D, H, W) - already single channel
+            elif sample.ndim == 2:  # (H, W) - 2D single channel
+                foreground_prob = sample
+            else:
                 foreground_prob = sample
 
             # Apply binary postprocessing
             processed = apply_binary_postprocessing(foreground_prob, binary_config)
 
-            # Expand dims to maintain shape consistency
-            if sample.ndim == 4:
+            # Expand dims to maintain shape consistency with original sample structure
+            if sample.ndim == 4:  # (C, D, H, W) -> processed is (D, H, W)
                 processed = processed[np.newaxis, ...]  # (1, D, H, W)
-            else:
-                processed = processed  # Keep (D, H, W)
+            elif sample.ndim == 3 and sample.shape[0] <= 4:  # (C, H, W) -> processed is (H, W)
+                processed = processed[np.newaxis, ...]  # (1, H, W)
+            # else: processed is already correct shape (D, H, W) or (H, W)
 
             results.append(processed)
 
         # Stack results back into batch
+        print(f" DEBUG: _apply_postprocessing - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
         data = np.stack(results, axis=0)
+        print(f" DEBUG: _apply_postprocessing - after stacking, data shape: {data.shape}")
 
         # Step 2: Apply scaling if configured (support both new and legacy names)
         intensity_scale = getattr(postprocessing, 'intensity_scale', None)
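The shape handling above can be sanity-checked in isolation. Below is a pure-NumPy mirror of the two key rules (illustrative only, not the module's code): the ndim-based batch dispatch, and the heuristic that a first axis of size <= 4 is a channel axis.

# Illustrative mirror of the dispatch in _apply_postprocessing; not the module's code.
import numpy as np

def to_batched(data):
    """Add a leading batch axis unless the array is already 4D/5D batched."""
    if data.ndim in (4, 5):                       # (B, C, H, W) or (B, C, D, H, W)
        return data
    if data.ndim == 3:                            # single (C, H, W) or (D, H, W)
        return data[np.newaxis, ...]
    if data.ndim == 2:                            # single (H, W)
        return data[np.newaxis, np.newaxis, ...]
    return data

def foreground_of(sample):
    """Mirror the channel/depth heuristic: a first axis of size <= 4 is channels."""
    if sample.ndim == 4 or (sample.ndim == 3 and sample.shape[0] <= 4):
        return sample[0]
    return sample

assert to_batched(np.zeros((2, 1, 8, 8))).shape == (2, 1, 8, 8)     # batched 2D
assert to_batched(np.zeros((16, 8, 8))).shape == (1, 16, 8, 8)      # single volume
assert foreground_of(np.zeros((1, 8, 8))).shape == (8, 8)           # (C, H, W) -> channel 0
assert foreground_of(np.zeros((16, 8, 8))).shape == (16, 8, 8)      # (D, H, W) -> unchanged

Note the trade-off the heuristic makes: a genuine 3D volume with four or fewer slices would be misread as a channel axis; the sketch makes that boundary easy to test.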
@@ -666,13 +694,29 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
         }
 
         # Process each sample in batch
-        batch_size = data.shape[0] if data.ndim == 5 else 1
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f" DEBUG: _apply_decode_mode - input data shape: {data.shape}, ndim: {data.ndim}")
         if data.ndim == 4:
-            data = data[np.newaxis, ...]  # Add batch dimension
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f" DEBUG: _apply_decode_mode - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f" DEBUG: _apply_decode_mode - detected 3D data, batch_size: {batch_size}")
+        else:
+            # Single sample: add batch dimension
+            batch_size = 1
+            print(f" DEBUG: _apply_decode_mode - single sample, adding batch dimension")
+            if data.ndim == 3:
+                data = data[np.newaxis, ...]  # (C, H, W) -> (1, C, H, W)
+            elif data.ndim == 2:
+                data = data[np.newaxis, np.newaxis, ...]  # (H, W) -> (1, 1, H, W)
 
         results = []
         for batch_idx in range(batch_size):
-            sample = data[batch_idx]  # (C, D, H, W)
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f" DEBUG: _apply_decode_mode - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
 
             # Apply each decode mode sequentially
             for decode_cfg in decode_modes:
@@ -733,8 +777,10 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
             results.append(sample)
 
         # Stack results back into batch
-        decoded = np.stack(results, axis=0) if len(results) > 1 else results[0]
-
+        # Always preserve batch dimension, even for batch_size=1
+        print(f" DEBUG: _apply_decode_mode - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
+        decoded = np.stack(results, axis=0)
+        print(f" DEBUG: _apply_decode_mode - final decoded shape: {decoded.shape}")
         return decoded
 
     def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
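The decisive change in this hunk is the unconditional np.stack. A toy illustration of why the old expression silently dropped the batch axis for single-sample batches (shapes are made up):

import numpy as np

results = [np.zeros((1, 16, 16))]   # one decoded sample, (C, H, W)

old = np.stack(results, axis=0) if len(results) > 1 else results[0]
new = np.stack(results, axis=0)

print(old.shape)   # (1, 16, 16)    - downstream code saw no batch axis
print(new.shape)   # (1, 1, 16, 16) - batch axis preserved even for batch_size=1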
@@ -757,26 +803,59 @@ def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
 
         meta = batch.get('image_meta_dict')
         filenames: List[Optional[str]] = []
+
+        print(f" DEBUG: _resolve_output_filenames - meta type: {type(meta)}, batch_size: {batch_size}")
 
-        if isinstance(meta, dict):
+        # Handle different metadata structures
+        if isinstance(meta, list):
+            # Multiple metadata dicts (one per sample in batch)
+            print(f" DEBUG: _resolve_output_filenames - meta is list with {len(meta)} items")
+            for idx, meta_item in enumerate(meta):
+                if isinstance(meta_item, dict):
+                    filename = meta_item.get('filename_or_obj')
+                    if filename is not None:
+                        filenames.append(filename)
+                    else:
+                        print(f" DEBUG: _resolve_output_filenames - meta_item[{idx}] has no filename_or_obj")
+                else:
+                    print(f" DEBUG: _resolve_output_filenames - meta_item[{idx}] is not a dict: {type(meta_item)}")
+            # Update batch_size from metadata if we have a list
+            batch_size = max(batch_size, len(filenames))
+            print(f" DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from list")
+        elif isinstance(meta, dict):
+            # Single metadata dict
+            print(f" DEBUG: _resolve_output_filenames - meta is dict")
             meta_filenames = meta.get('filename_or_obj')
             if isinstance(meta_filenames, (list, tuple)):
-                filenames = list(meta_filenames)
+                filenames = [f for f in meta_filenames if f is not None]
             elif meta_filenames is not None:
                 filenames = [meta_filenames]
-        elif isinstance(meta, list):
-            for meta_item in meta:
-                if isinstance(meta_item, dict):
-                    filenames.append(meta_item.get('filename_or_obj'))
-            # Update batch_size from metadata if we have a list
-            batch_size = max(batch_size, len(filenames))
+            # Update batch_size from metadata
+            if len(filenames) > 0:
+                batch_size = max(batch_size, len(filenames))
+            print(f" DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from dict")
+        else:
+            # Handle case where meta might be None or other types
+            # This can happen if metadata wasn't preserved through transforms
+            # We'll use fallback filenames based on batch_size
+            print(f" DEBUG: _resolve_output_filenames - meta is None or unexpected type: {type(meta)}")
+            pass
 
         resolved_names: List[str] = []
         for idx in range(batch_size):
             if idx < len(filenames) and filenames[idx]:
                 resolved_names.append(Path(str(filenames[idx])).stem)
             else:
+                # Generate fallback filename - this shouldn't happen if metadata is preserved correctly
                 resolved_names.append(f"volume_{self.global_step}_{idx}")
+
+        print(f" DEBUG: _resolve_output_filenames - returning {len(resolved_names)} resolved names: {resolved_names[:3]}...")
+
+        # Always return exactly batch_size filenames
+        if len(resolved_names) < batch_size:
+            print(f" WARNING: _resolve_output_filenames - Only {len(resolved_names)} filenames but batch_size is {batch_size}, padding with fallback names")
+            while len(resolved_names) < batch_size:
+                resolved_names.append(f"volume_{self.global_step}_{len(resolved_names)}")
 
         return resolved_names
 
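A compact, self-contained mirror of the resolver's three metadata cases (per-sample list, single dict, missing), with hypothetical filenames; the real method also takes batch_size from the batch dict and uses self.global_step for fallbacks:

from pathlib import Path

def resolve_names(meta, batch_size, step=0):
    """Mirror of _resolve_output_filenames: gather names, then pad with fallbacks."""
    filenames = []
    if isinstance(meta, list):                        # one metadata dict per sample
        filenames = [m.get('filename_or_obj') for m in meta
                     if isinstance(m, dict) and m.get('filename_or_obj') is not None]
    elif isinstance(meta, dict):                      # one dict for the whole batch
        v = meta.get('filename_or_obj')
        if isinstance(v, (list, tuple)):
            filenames = [f for f in v if f is not None]
        elif v is not None:
            filenames = [v]
    batch_size = max(batch_size, len(filenames))
    names = [Path(str(f)).stem for f in filenames]
    while len(names) < batch_size:                    # metadata missing or incomplete
        names.append(f"volume_{step}_{len(names)}")
    return names

print(resolve_names([{'filename_or_obj': '/d/im_01.tif'},
                     {'filename_or_obj': '/d/im_02.tif'}], batch_size=2))
# ['im_01', 'im_02']
print(resolve_names(None, batch_size=2))
# ['volume_0_0', 'volume_0_1']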
@@ -814,8 +893,42 @@ def _write_outputs(
         if hasattr(self.cfg.inference, 'postprocessing'):
             output_transpose = getattr(self.cfg.inference.postprocessing, 'output_transpose', [])
 
+        # Determine actual batch size from predictions
+        # Handle both batched (B, ...) and unbatched (...) predictions
+        print(f" DEBUG: _write_outputs - predictions shape: {predictions.shape}, ndim: {predictions.ndim}, filenames count: {len(filenames)}")
+
+        if predictions.ndim >= 4:
+            # Has batch dimension: (B, C, D, H, W) or (B, C, H, W) or (B, D, H, W)
+            actual_batch_size = predictions.shape[0]
+        elif predictions.ndim == 3:
+            # Could be batched 2D data (B, H, W) or single 3D volume (D, H, W)
+            # Check if first dimension matches number of filenames -> it's batched 2D data
+            if len(filenames) > 0 and predictions.shape[0] == len(filenames):
+                # Batched 2D data: (B, H, W) where B matches number of filenames
+                actual_batch_size = predictions.shape[0]
+                print(f" DEBUG: _write_outputs - detected batched 2D data (B, H, W) with batch_size={actual_batch_size}")
+            else:
+                # Single 3D volume: (D, H, W) - treat as batch_size=1
+                actual_batch_size = 1
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+                print(f" DEBUG: _write_outputs - detected single 3D volume, added batch dimension")
+        elif predictions.ndim == 2:
+            # Single 2D image: (H, W) - treat as batch_size=1
+            actual_batch_size = 1
+            predictions = predictions[np.newaxis, ...]  # Add batch dimension
+        else:
+            # Unexpected shape, default to batch_size=1
+            actual_batch_size = 1
+            if predictions.ndim < 2:
+                predictions = predictions[np.newaxis, ...]  # Add batch dimension
+
+        # Ensure we don't exceed the actual batch size
+        batch_size = min(actual_batch_size, len(filenames))
+        print(f" DEBUG: _write_outputs - actual_batch_size: {actual_batch_size}, batch_size: {batch_size}, will save {batch_size} predictions")
+
         # Save predictions
-        for idx, name in enumerate(filenames):
+        for idx in range(batch_size):
+            name = filenames[idx]
             prediction = predictions[idx]
 
             # Apply output transpose if specified
13271440
labels = batch.get('label')
13281441
mask = batch.get('mask') # Get test mask if available
13291442

1443+
# Get batch size from images
1444+
actual_batch_size = images.shape[0]
1445+
print(f" DEBUG: test_step - images shape: {images.shape}, batch_size: {actual_batch_size}")
1446+
13301447
# Always use TTA (handles no-transform case) + sliding window
13311448
# TTA preprocessing (activation, masking) is applied regardless of flip augmentation
13321449
# Note: TTA always returns a simple tensor, not a dict (deep supervision not supported in test mode)
13331450
predictions = self._predict_with_tta(images, mask=mask)
13341451

13351452
# Convert predictions to numpy for saving/decoding
13361453
predictions_np = predictions.detach().cpu().float().numpy()
1454+
print(f" DEBUG: test_step - predictions_np shape: {predictions_np.shape}")
13371455

13381456
# Resolve filenames once for all saving operations
13391457
filenames = self._resolve_output_filenames(batch)
1458+
print(f" DEBUG: test_step - filenames count: {len(filenames)}, filenames: {filenames[:5]}...")
1459+
1460+
# Ensure filenames list matches actual batch size
1461+
# If we don't have enough filenames, generate default ones
1462+
while len(filenames) < actual_batch_size:
1463+
filenames.append(f"volume_{self.global_step}_{len(filenames)}")
1464+
print(f" DEBUG: test_step - after padding, filenames count: {len(filenames)}")
13401465

13411466
# Check if we should save intermediate predictions (before decoding)
13421467
save_intermediate = False
@@ -1348,10 +1473,13 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTP
13481473
self._write_outputs(predictions_np, filenames, suffix="tta_prediction")
13491474

13501475
# Apply decode mode (instance segmentation decoding)
1476+
print(f" DEBUG: test_step - before decode, predictions_np shape: {predictions_np.shape}")
13511477
decoded_predictions = self._apply_decode_mode(predictions_np)
1478+
print(f" DEBUG: test_step - after decode, decoded_predictions shape: {decoded_predictions.shape}")
13521479

13531480
# Apply postprocessing (scaling and dtype conversion) if configured
13541481
postprocessed_predictions = self._apply_postprocessing(decoded_predictions)
1482+
print(f" DEBUG: test_step - after postprocess, postprocessed_predictions shape: {postprocessed_predictions.shape}")
13551483

13561484
# Save final decoded and postprocessed predictions
13571485
self._write_outputs(postprocessed_predictions, filenames, suffix="prediction")
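Finally, the filename padding added to test_step, in isolation (a literal step value stands in for self.global_step):

step = 0                                 # stands in for self.global_step
filenames = ["im_001"]                   # resolver returned fewer names than samples
actual_batch_size = 3

while len(filenames) < actual_batch_size:
    filenames.append(f"volume_{step}_{len(filenames)}")

print(filenames)  # ['im_001', 'volume_0_1', 'volume_0_2']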
