
Commit 8f434a8

Phase 2.2-2.3: Remove dummy validation dataset and make deep supervision configurable
Changes:

1. Phase 2.2: Remove dummy validation dataset workaround (lit_data.py)
   - Replaced DummyDataset with a proper warning when validation setup fails
   - More honest error handling instead of masking configuration issues

2. Phase 2.3: Make hardcoded deep supervision values configurable (hydra_config.py, lit_model.py)
   - Added ModelConfig fields:
     * deep_supervision_weights (default: [1.0, 0.5, 0.25, 0.125, 0.0625])
     * deep_supervision_clamp_min (default: -20.0)
     * deep_supervision_clamp_max (default: 20.0)
   - Updated lit_model.py to use the configurable values in 3 locations:
     * Multi-task output clamping
     * Deep supervision output clamping
     * Deep supervision scale weights, with validation

Benefits:
- Users can now customize deep supervision behavior without code changes
- Removes technical debt items #2 and #4 from REFACTORING_PLAN.md
- Better error handling for missing validation data
1 parent 0946eaa commit 8f434a8

3 files changed (+32, -25 lines)


connectomics/config/hydra_config.py

Lines changed: 3 additions & 0 deletions

@@ -182,6 +182,9 @@ class ModelConfig:
 
     # Deep supervision (supported by MedNeXt, RSUNet, and some MONAI models)
     deep_supervision: bool = False
+    deep_supervision_weights: Optional[List[float]] = None  # None = auto: [1.0, 0.5, 0.25, 0.125, 0.0625]
+    deep_supervision_clamp_min: float = -20.0  # Clamp logits to prevent numerical instability
+    deep_supervision_clamp_max: float = 20.0  # Especially important at coarser scales
 
     # Loss configuration
     loss_functions: List[str] = field(default_factory=lambda: ["DiceLoss", "BCEWithLogitsLoss"])
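
For reference, a minimal sketch of setting the new fields directly in Python. This assumes ModelConfig can be instantiated with keyword arguments and that its remaining fields have defaults; the values shown are illustrative, not recommendations.

    from connectomics.config.hydra_config import ModelConfig

    cfg = ModelConfig(
        deep_supervision=True,
        # Explicit per-scale weights; leaving this as None keeps the
        # auto default [1.0, 0.5, 0.25, 0.125, 0.0625].
        deep_supervision_weights=[1.0, 0.5, 0.25, 0.125, 0.0625],
        # Clamp range applied to logits before the loss; tighten it if
        # coarser scales still overflow under mixed precision.
        deep_supervision_clamp_min=-20.0,
        deep_supervision_clamp_max=20.0,
    )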

connectomics/lightning/lit_data.py

Lines changed: 10 additions & 22 deletions

@@ -7,8 +7,9 @@
 
 from __future__ import annotations
 from typing import Dict, List, Any, Optional, Union, Tuple
-import numpy as np
+import warnings
 
+import numpy as np
 import torch
 import pytorch_lightning as pl
 from torch.utils.data import DataLoader
@@ -179,29 +180,16 @@ def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
             return []
 
         dataloader = self._create_dataloader(self.val_dataset, shuffle=False)
-        if dataloader is None:
-            from torch.utils.data import Dataset
-
-            class DummyDataset(Dataset):
-                def __len__(self):
-                    return 1
 
-                def __getitem__(self, idx):
-                    zero = torch.zeros(1, dtype=torch.float32)
-                    return {
-                        "image": zero,
-                        "label": zero,
-                    }
-
-            return DataLoader(
-                dataset=DummyDataset(),
-                batch_size=1,
-                shuffle=False,
-                num_workers=0,
-                pin_memory=False,
-                persistent_workers=False,
-                collate_fn=self._collate_fn,
+        # If validation dataset exists but dataloader creation failed,
+        # skip validation rather than using dummy data
+        if dataloader is None:
+            warnings.warn(
+                "Validation dataloader creation failed despite validation dataset being provided. "
+                "Skipping validation. Check your data configuration.",
+                UserWarning
            )
+            return []
 
         return dataloader
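
A short sketch of what the new behavior looks like from the caller's side; dm is a hypothetical, already-configured instance of the data module defined in lit_data.py (name assumed for illustration only).

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        loaders = dm.val_dataloader()

    if loaders == []:
        # Lightning simply skips validation; the warning explains why,
        # instead of silently training against dummy tensors.
        for w in caught:
            print(w.category.__name__, "-", w.message)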

connectomics/lightning/lit_model.py

Lines changed: 19 additions & 3 deletions

@@ -1114,7 +1114,9 @@ def _compute_loss_for_scale(
             # At coarser scales (especially with mixed precision), logits can explode
             # BCEWithLogitsLoss: clamp to [-20, 20] (sigmoid maps to [2e-9, 1-2e-9])
             # MSELoss with tanh: clamp to [-10, 10] (tanh maps to [-0.9999, 0.9999])
-            task_output = torch.clamp(task_output, min=-20.0, max=20.0)
+            clamp_min = getattr(self.cfg.model, 'deep_supervision_clamp_min', -20.0)
+            clamp_max = getattr(self.cfg.model, 'deep_supervision_clamp_max', 20.0)
+            task_output = torch.clamp(task_output, min=clamp_min, max=clamp_max)
 
             # Apply specified losses for this task
             for loss_idx in loss_indices:
@@ -1142,7 +1144,9 @@ def _compute_loss_for_scale(
         else:
             # Standard deep supervision: apply all losses to all outputs
             # Clamp outputs to prevent numerical instability at coarser scales
-            output_clamped = torch.clamp(output, min=-20.0, max=20.0)
+            clamp_min = getattr(self.cfg.model, 'deep_supervision_clamp_min', -20.0)
+            clamp_max = getattr(self.cfg.model, 'deep_supervision_clamp_max', 20.0)
+            output_clamped = torch.clamp(output, min=clamp_min, max=clamp_max)
 
             for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
                 loss = loss_fn(output_clamped, target)
@@ -1191,7 +1195,19 @@ def _compute_deep_supervision_loss(
         main_output = outputs['output']
         ds_outputs = [outputs[f'ds_{i}'] for i in range(1, 5) if f'ds_{i}' in outputs]
 
-        ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
+        # Use configured weights or default exponential decay
+        if hasattr(self.cfg.model, 'deep_supervision_weights') and self.cfg.model.deep_supervision_weights is not None:
+            ds_weights = self.cfg.model.deep_supervision_weights
+            # Ensure we have enough weights for all outputs
+            if len(ds_weights) < len(ds_outputs) + 1:
+                warnings.warn(
+                    f"deep_supervision_weights has {len(ds_weights)} weights but "
+                    f"{len(ds_outputs) + 1} outputs. Using exponential decay for missing weights."
+                )
+                ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
+        else:
+            ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
+
         all_outputs = [main_output] + ds_outputs
 
         total_loss = 0.0
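
The weight-selection logic above, restated as a standalone sketch (a hypothetical helper, not part of the repo) so the fallback behavior is easy to verify in isolation.

    import warnings
    from typing import List, Optional

    def resolve_ds_weights(configured: Optional[List[float]], num_ds_outputs: int) -> List[float]:
        # Default: exponential decay over the main output plus each deep-supervision output.
        default = [1.0] + [0.5 ** i for i in range(1, num_ds_outputs + 1)]
        if configured is None:
            return default
        if len(configured) < num_ds_outputs + 1:
            warnings.warn(
                f"deep_supervision_weights has {len(configured)} weights but "
                f"{num_ds_outputs + 1} outputs. Using exponential decay for missing weights."
            )
            return default
        return configured

    print(resolve_ds_weights(None, 4))        # [1.0, 0.5, 0.25, 0.125, 0.0625]
    print(resolve_ds_weights([1.0, 0.4], 4))  # warns, then falls back to the default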
