Skip to content

Commit 435de70

Browse files
authored
Merge pull request #172 from PytorchConnectomics/claude/remove-backward-compatibility-01YXAHvH6ebTvvYLmfr18rZY
Claude/remove backward compatibility 01 yxa hv h6eb tvv y lmfr18r zy
2 parents 96e2aca + 15f3c8a commit 435de70

File tree

8 files changed

+778
-73
lines changed

8 files changed

+778
-73
lines changed

REFACTORING_PLAN.md

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,18 @@ connectomics/lightning/
259259
└── lit_data.py # LightningDataModule (684 lines, existing)
260260
```
261261

262+
**Migration Steps:**
263+
1. Create new module files
264+
2. Move functionality in logical chunks
265+
3. Update imports in `lit_model.py`
266+
4. Add integration tests for each module
267+
5. Update documentation
268+
269+
**Success Criteria:**
270+
- [ ] Each file < 500 lines
271+
- [ ] Clear separation of concerns
272+
- [ ] All existing tests pass
273+
- [ ] Documentation updated
262274
**Note:** Multi-task learning was integrated into `deep_supervision.py` (not a separate module) since the logic is tightly coupled with deep supervision.
263275

264276
**Completed Actions:**
@@ -408,7 +420,6 @@ class DataConfig:
408420
- [x] `deep_supervision_clamp_min: float` (default: -20.0)
409421
- [x] `deep_supervision_clamp_max: float` (default: 20.0)
410422
- [x] Validation logic with warning for insufficient weights
411-
- [x] Backward compatible (defaults match old behavior)
412423
- [ ] Other hardcoded values (target interpolation, rejection sampling) - Future work
413424

414425
**Status:** ✅ Phase 2.3 (Deep Supervision) completed. Users can now customize deep supervision weights and clamping ranges via config.
@@ -500,14 +511,12 @@ def build_test_transforms(cfg: Config, keys: list[str] = None) -> Compose:
500511
- File size reduced from 791 to 727 lines (-64 lines, ~8% reduction)
501512
- Eliminated ~80% code duplication
502513
- Single source of truth for shared transform logic
503-
- Backward compatible (same public API)
504514

505515
**Action Items:**
506516
- [x] Extract shared logic into `_build_eval_transforms_impl()`
507517
- [x] Identify val/test-specific differences (4 key differences)
508518
- [x] Create mode-specific branching with clear comments
509519
- [x] Keep wrapper functions for API compatibility
510-
- [x] Backward compatible (public API unchanged)
511520

512521
**Status:** ✅ Phase 2.5 complete. Code duplication eliminated while preserving all functionality.
513522

@@ -996,10 +1005,8 @@ See Priority 1.3 above for full details.
9961005

9971006
### Mitigation Strategies
9981007
1. **Comprehensive testing** before and after each change
999-
2. **Feature flags** for backward compatibility
1000-
3. **Deprecation warnings** before removal
1001-
4. **Rollback plan** for each phase
1002-
5. **User communication** via release notes
1008+
2. **Rollback plan** for each phase
1009+
3. **User communication** via release notes
10031010

10041011
---
10051012

connectomics/config/hydra_config.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,6 @@ class DataConfig:
392392
default_factory=list
393393
) # Axis permutation for training data (e.g., [2,1,0] for xyz->zyx)
394394
val_transpose: List[int] = field(default_factory=list) # Axis permutation for validation data
395-
test_transpose: List[int] = field(
396-
default_factory=list
397-
) # Axis permutation for test data (deprecated, use inference.data.test_transpose)
398395

399396
# Dataset statistics (for auto-planning)
400397
target_spacing: Optional[List[float]] = None # Target voxel spacing [z, y, x] in mm
@@ -868,9 +865,6 @@ class TestTimeAugmentationConfig:
868865
flip_axes: Any = (
869866
None # TTA flip strategy: "all" (8 flips), null (no aug), or list like [[0], [1], [2]]
870867
)
871-
act: Optional[str] = (
872-
None # Single activation for all channels: 'softmax', 'sigmoid', 'tanh', None (deprecated, use channel_activations)
873-
)
874868
channel_activations: Optional[List[Any]] = (
875869
None # Per-channel activations: [[start_ch, end_ch, 'activation'], ...] e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
876870
)

connectomics/config/hydra_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,8 @@ def resolve_data_paths(cfg: Config) -> Config:
244244
Supported paths:
245245
- Training: cfg.data.train_path + cfg.data.train_image/train_label/train_mask
246246
- Validation: cfg.data.val_path + cfg.data.val_image/val_label/val_mask
247-
- Testing (legacy): cfg.data.test_path + cfg.data.test_image/test_label/test_mask
248-
- Inference (primary): cfg.inference.data.test_path + cfg.inference.data.test_image/test_label/test_mask
247+
- Testing: cfg.data.test_path + cfg.data.test_image/test_label/test_mask
248+
- Inference: cfg.inference.data.test_path + cfg.inference.data.test_image/test_label/test_mask
249249
250250
Args:
251251
cfg: Config object to resolve paths for
@@ -316,7 +316,7 @@ def _combine_path(base_path: str, file_path: Optional[Union[str, List[str]]]) ->
316316
cfg.data.val_mask = _combine_path(cfg.data.val_path, cfg.data.val_mask)
317317
cfg.data.val_json = _combine_path(cfg.data.val_path, cfg.data.val_json)
318318

319-
# Resolve test paths (legacy support for cfg.data.test_path)
319+
# Resolve test paths
320320
if cfg.data.test_path:
321321
cfg.data.test_image = _combine_path(cfg.data.test_path, cfg.data.test_image)
322322
cfg.data.test_label = _combine_path(cfg.data.test_path, cfg.data.test_label)

connectomics/data/augment/build.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def build_train_transforms(
7373
# Load images first (unless using pre-cached dataset)
7474
if not skip_loading:
7575
# Use appropriate loader based on dataset type
76-
dataset_type = getattr(cfg.data, "dataset_type", "volume") # Default to volume for backward compatibility
76+
dataset_type = getattr(cfg.data, "dataset_type", "volume")
7777

7878
if dataset_type == "filename":
7979
# For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
@@ -94,12 +94,9 @@ def build_train_transforms(
9494
transforms.append(ApplyVolumetricSplitd(keys=keys))
9595

9696
# Apply resize if configured (before cropping)
97-
# Check data_transform first (new), then fall back to image_transform.resize (legacy)
9897
resize_size = None
9998
if hasattr(cfg.data, "data_transform") and hasattr(cfg.data.data_transform, "resize") and cfg.data.data_transform.resize is not None:
10099
resize_size = cfg.data.data_transform.resize
101-
elif hasattr(cfg.data.image_transform, "resize") and cfg.data.image_transform.resize is not None:
102-
resize_size = cfg.data.image_transform.resize
103100

104101
if resize_size:
105102
# Use bilinear for images, nearest for labels/masks
@@ -247,7 +244,7 @@ def _build_eval_transforms_impl(
247244
transforms = []
248245

249246
# Load images first - use appropriate loader based on dataset type
250-
dataset_type = getattr(cfg.data, "dataset_type", "volume") # Default to volume for backward compatibility
247+
dataset_type = getattr(cfg.data, "dataset_type", "volume")
251248

252249
if dataset_type == "filename":
253250
# For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
@@ -260,17 +257,15 @@ def _build_eval_transforms_impl(
260257
if mode == "val":
261258
transpose_axes = cfg.data.val_transpose if cfg.data.val_transpose else []
262259
else: # mode == "test"
263-
# Check both data.test_transpose and inference.data.test_transpose
260+
# Use inference.data.test_transpose
264261
transpose_axes = []
265-
if cfg.data.test_transpose:
266-
transpose_axes = cfg.data.test_transpose
267262
if (
268263
hasattr(cfg, "inference")
269264
and hasattr(cfg.inference, "data")
270265
and hasattr(cfg.inference.data, "test_transpose")
271266
and cfg.inference.data.test_transpose
272267
):
273-
transpose_axes = cfg.inference.data.test_transpose # inference takes precedence
268+
transpose_axes = cfg.inference.data.test_transpose
274269

275270
transforms.append(
276271
LoadVolumed(keys=keys, transpose_axes=transpose_axes if transpose_axes else None)
@@ -455,8 +450,8 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str], do_2d: bo
455450
List of MONAI transforms
456451
"""
457452
transforms = []
458-
459-
# Get preset mode (default to "some" for backward compatibility)
453+
454+
# Get preset mode
460455
preset = getattr(aug_cfg, "preset", "some")
461456

462457
# Helper function to check if augmentation should be applied

connectomics/lightning/callbacks.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -428,25 +428,7 @@ def create_callbacks(cfg) -> list:
428428
callbacks.append(vis_callback)
429429

430430
# Model checkpoint callback
431-
# Support both new unified config (training.checkpoint_*) and old separate config (checkpoint.*)
432-
if hasattr(cfg, 'checkpoint') and cfg.checkpoint is not None:
433-
# Old config style (backward compatibility)
434-
monitor = getattr(cfg.checkpoint, 'monitor', 'val/loss')
435-
default_filename = f'epoch={{epoch:03d}}-{monitor}={{{monitor}:.4f}}'
436-
filename = getattr(cfg.checkpoint, 'filename', default_filename)
437-
438-
checkpoint_callback = ModelCheckpoint(
439-
monitor=monitor,
440-
mode=getattr(cfg.checkpoint, 'mode', 'min'),
441-
save_top_k=getattr(cfg.checkpoint, 'save_top_k', 3),
442-
save_last=getattr(cfg.checkpoint, 'save_last', True),
443-
dirpath=getattr(cfg.checkpoint, 'dirpath', 'checkpoints'),
444-
filename=filename,
445-
verbose=True
446-
)
447-
callbacks.append(checkpoint_callback)
448-
elif hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'checkpoint'):
449-
# New unified config style (monitor.checkpoint.*)
431+
if hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'checkpoint'):
450432
monitor = getattr(cfg.monitor.checkpoint, 'monitor', 'val/loss')
451433
filename = getattr(cfg.monitor.checkpoint, 'filename', None)
452434
if filename is None:
@@ -465,19 +447,7 @@ def create_callbacks(cfg) -> list:
465447
callbacks.append(checkpoint_callback)
466448

467449
# Early stopping callback
468-
# Support both new unified config (training.early_stopping_*) and old separate config (early_stopping.*)
469-
if hasattr(cfg, 'early_stopping') and cfg.early_stopping is not None and cfg.early_stopping.enabled:
470-
# Old config style (backward compatibility)
471-
early_stop_callback = EarlyStopping(
472-
monitor=getattr(cfg.early_stopping, 'monitor', 'val/loss'),
473-
patience=getattr(cfg.early_stopping, 'patience', 10),
474-
mode=getattr(cfg.early_stopping, 'mode', 'min'),
475-
min_delta=getattr(cfg.early_stopping, 'min_delta', 0.0),
476-
verbose=True
477-
)
478-
callbacks.append(early_stop_callback)
479-
elif hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'early_stopping') and getattr(cfg.monitor.early_stopping, 'enabled', False):
480-
# New unified config style (monitor.early_stopping.*)
450+
if hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'early_stopping') and getattr(cfg.monitor.early_stopping, 'enabled', False):
481451
early_stop_callback = EarlyStopping(
482452
monitor=getattr(cfg.monitor.early_stopping, 'monitor', 'val/loss'),
483453
patience=getattr(cfg.monitor.early_stopping, 'patience', 10),

0 commit comments

Comments
 (0)