Commit 9c6d27c

Author: Donglai Wei (committed)

refactor training/inference

Parent: 2557c24


44 files changed (+2382, -3267 lines)

.gitignore

Lines changed: 4 additions & 0 deletions

@@ -149,6 +149,10 @@ build/
 .DS_Store
 Thumbs.db
 
+# external libraries
+lib/
+
 # Development logs and documentation
 .claude/
+.codex/
 .CLAUDE.md

CLAUDE.md

Lines changed: 25 additions & 20 deletions

@@ -156,11 +156,16 @@ connectomics/ # Main Python package (77 files, ~23K lines)
 │   ├── build.py            # Optimizer/scheduler factory
 │   └── lr_scheduler.py     # Custom LR schedulers
 
-├── lightning/               # PyTorch Lightning integration (PRIMARY)
-│   ├── lit_data.py          # LightningDataModule (Volume/Tile/Cloud datasets)
-│   ├── lit_model.py         # LightningModule (1.8K lines - deep supervision, TTA)
-│   ├── lit_trainer.py       # Trainer creation utilities
-│   └── callbacks.py         # Custom Lightning callbacks
+├── training/                # Training utilities
+│   ├── lit/                 # PyTorch Lightning integration (PRIMARY)
+│   │   ├── data.py          # LightningDataModule (Volume/Tile/Cloud datasets)
+│   │   ├── model.py         # LightningModule (deep supervision, TTA)
+│   │   ├── trainer.py       # Trainer creation utilities
+│   │   ├── callbacks.py     # Custom Lightning callbacks
+│   │   ├── config.py        # Factory functions for training setup
+│   │   └── utils.py         # CLI/config helpers
+│   ├── deep_supervision.py  # Deep supervision utilities
+│   └── debugging.py         # NaN detection and debugging utilities
 
 ├── data/                    # Data loading and preprocessing
 │   ├── dataset/             # Dataset classes (HDF5, TIFF, Zarr, Cloud)

@@ -399,7 +404,7 @@ Multiple losses can be combined with weights in the config.
 
 ## PyTorch Lightning Integration
 
-### LightningModule (`lightning/lit_model.py`)
+### LightningModule (`training/lit/model.py`)
 Wraps models with automatic training features:
 - Distributed training (DDP)
 - Mixed precision training (AMP)

@@ -410,7 +415,7 @@ Wraps models with automatic training features:
 - **Deep supervision**: Multi-scale loss computation with automatic target resizing
 
 ```python
-from connectomics.lightning import ConnectomicsModule
+from connectomics.training.lit import ConnectomicsModule
 
 # Create Lightning module
 lit_model = ConnectomicsModule(cfg)

@@ -419,23 +424,23 @@ lit_model = ConnectomicsModule(cfg)
 lit_model = ConnectomicsModule(cfg, model=custom_model)
 ```
 
-### LightningDataModule (`lightning/lit_data.py`)
+### LightningDataModule (`training/lit/data.py`)
 Handles data loading with MONAI transforms:
 - Train/val/test splits
 - MONAI CacheDataset for fast loading
 - Automatic augmentation pipeline
 - Persistent workers for efficiency
 
 ```python
-from connectomics.lightning import ConnectomicsDataModule
+from connectomics.training.lit import ConnectomicsDataModule
 
 datamodule = ConnectomicsDataModule(cfg)
 ```
 
-### Trainer (`lightning/lit_trainer.py`)
+### Trainer (`training/lit/trainer.py`)
 Convenience function for creating Lightning Trainer:
 ```python
-from connectomics.lightning import create_trainer
+from connectomics.training.lit import create_trainer
 
 trainer = create_trainer(cfg)
 trainer.fit(lit_model, datamodule=datamodule)

@@ -475,15 +480,15 @@ from pytorch_lightning import seed_everything
 seed_everything(cfg.system.seed)
 
 # 3. Create data module
-from connectomics.lightning import ConnectomicsDataModule
+from connectomics.training.lit import ConnectomicsDataModule
 datamodule = ConnectomicsDataModule(cfg)
 
 # 4. Create model
-from connectomics.lightning import ConnectomicsModule
+from connectomics.training.lit import ConnectomicsModule
 model = ConnectomicsModule(cfg)
 
 # 5. Create trainer
-from connectomics.lightning import create_trainer
+from connectomics.training.lit import create_trainer
 trainer = create_trainer(cfg)
 
 # 6. Train

@@ -555,9 +560,9 @@ scheduler:
 - `connectomics/models/solver/build.py`: Optimizer/scheduler factory
 
 ### Lightning
-- `connectomics/lightning/lit_model.py`: Lightning module wrapper
-- `connectomics/lightning/lit_data.py`: Data module
-- `connectomics/lightning/lit_trainer.py`: Trainer utilities
+- `connectomics/training/lit/model.py`: Lightning module wrapper
+- `connectomics/training/lit/data.py`: Data module
+- `connectomics/training/lit/trainer.py`: Trainer utilities
 
 ### Entry Points
 - `scripts/main.py`: Primary training script (Lightning + Hydra)

@@ -635,7 +640,7 @@ The codebase has **fully migrated** from legacy systems:
 
 **Current development stack:**
 - Hydra/OmegaConf configs (`tutorials/*.yaml`)
-- PyTorch Lightning modules (`connectomics/lightning/`)
+- PyTorch Lightning modules (`connectomics/training/lit/`)
 - `scripts/main.py` entry point
 - MONAI models and transforms
 - Type-safe dataclass configurations

@@ -687,7 +692,7 @@ Install via `pip install -e .[extra_name]` where `extra_name` is:
 
 #### `[wandb]` - Weights & Biases Integration
 - **wandb** (>=0.13.0): Experiment tracking and monitoring
-- Used in: `connectomics.lightning.lit_trainer` (optional logger)
+- Used in: `connectomics.training.lit.trainer` (optional logger)
 
 #### `[tiff]` - TIFF File Support
 - **tifffile** (>=2021.11.2): Advanced TIFF reading/writing

@@ -817,4 +822,4 @@ python -c "from connectomics.models.arch import list_architectures; print(list_a
 - [MONAI Docs](https://docs.monai.io/en/stable/) - Medical imaging toolkit
 - [Hydra Docs](https://hydra.cc/) - Configuration management
 - [Project Documentation](https://zudi-lin.github.io/pytorch_connectomics/build/html/index.html) - Full docs
-- [Slack Community](https://join.slack.com/t/pytorchconnectomics/shared_invite/zt-obufj5d1-v5_NndNS5yog8vhxy4L12w) - Get help
+- [Slack Community](https://join.slack.com/t/pytorchconnectomics/shared_invite/zt-obufj5d1-v5_NndNS5yog8vhxy4L12w) - Get help
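
Taken together, these CLAUDE.md edits document a pure namespace move: every `connectomics.lightning` import becomes `connectomics.training.lit`. Assembled from the snippets above (assuming `cfg` is a loaded Hydra/OmegaConf config with a `system.seed` field), the post-refactor workflow reads:

```python
# Post-refactor training flow, assembled from the CLAUDE.md snippets above.
# Assumes `cfg` is a loaded Hydra/OmegaConf config.
from pytorch_lightning import seed_everything
from connectomics.training.lit import (
    ConnectomicsDataModule,
    ConnectomicsModule,
    create_trainer,
)

seed_everything(cfg.system.seed)          # reproducible runs
datamodule = ConnectomicsDataModule(cfg)  # MONAI transforms + CacheDataset
lit_model = ConnectomicsModule(cfg)       # DDP/AMP/deep-supervision wrapper
trainer = create_trainer(cfg)             # configured Lightning Trainer
trainer.fit(lit_model, datamodule=datamodule)
```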

conda-recipe/meta.yaml

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ test:
     - connectomics
     - connectomics.config
     - connectomics.data
-    - connectomics.lightning
+    - connectomics.training.lit
     - connectomics.models
     - connectomics.utils
   commands:

connectomics/data/augment/build.py

Lines changed: 19 additions & 4 deletions

@@ -376,19 +376,34 @@ def _build_eval_transforms_impl(
     if getattr(cfg.data, "normalize_labels", False):
         transforms.append(NormalizeLabelsd(keys=["label"]))
 
-    # Check if we should skip label transforms (test mode with evaluation metrics)
+    # Check if we should skip label transforms (test/tune mode)
+    # Skip label transforms if test.data or tune.data has evaluation.enabled=True
+    # This preserves original instance labels for metric computation
     skip_label_transform = False
     if mode == "test":
-        if hasattr(cfg, "inference") and hasattr(cfg.inference, "evaluation"):
-            evaluation_enabled = getattr(cfg.inference.evaluation, "enabled", False)
-            metrics = getattr(cfg.inference.evaluation, "metrics", [])
+        # Check if test.evaluation or tune.evaluation is enabled (for adapted_rand, etc.)
+        evaluation_config = None
+        if hasattr(cfg, "test") and hasattr(cfg.test, "evaluation"):
+            evaluation_config = cfg.test.evaluation
+        elif hasattr(cfg, "tune") and cfg.tune and hasattr(cfg.tune, "optimization"):
+            # For tune mode, check if we're optimizing metrics that need instance labels
+            if hasattr(cfg.tune.optimization, "single_objective"):
+                metric = getattr(cfg.tune.optimization.single_objective, "metric", None)
+                if metric == "adapted_rand":
+                    skip_label_transform = True
+                    print(f"   ⚠️  Skipping label transforms for Optuna tuning (keeping original labels for {metric})")
+
+        if evaluation_config:
+            evaluation_enabled = getattr(evaluation_config, "enabled", False)
+            metrics = getattr(evaluation_config, "metrics", [])
             if evaluation_enabled and metrics:
                 skip_label_transform = True
                 print(
                     f"   ⚠️  Skipping label transforms for metric evaluation (keeping original labels for {metrics})"
                 )
 
     # Label transformations (affinity, distance transform, etc.)
+    # Only apply if not skipped AND label_transform is configured
     if hasattr(cfg.data, "label_transform") and not skip_label_transform:
         from ..process.build import create_label_transform_pipeline
         from ..process.monai_transforms import SegErosionInstanced
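
The new branch can be read as a single predicate: skip label transforms whenever raw instance labels are needed downstream. A standalone sketch of that predicate, with `types.SimpleNamespace` standing in for the Hydra config (the exact `cfg` layout is assumed from the diff above):

```python
from types import SimpleNamespace

def should_skip_label_transform(cfg, mode):
    """Standalone mirror of the branch above (cfg layout assumed from the diff)."""
    if mode != "test":
        return False
    evaluation_config = None
    if hasattr(cfg, "test") and hasattr(cfg.test, "evaluation"):
        evaluation_config = cfg.test.evaluation
    elif hasattr(cfg, "tune") and cfg.tune and hasattr(cfg.tune, "optimization"):
        # Tuning adapted_rand needs untouched instance labels
        single = getattr(cfg.tune.optimization, "single_objective", None)
        if getattr(single, "metric", None) == "adapted_rand":
            return True
    if evaluation_config is not None:
        enabled = getattr(evaluation_config, "enabled", False)
        metrics = getattr(evaluation_config, "metrics", [])
        return bool(enabled and metrics)
    return False

# Quick check with a mock config
cfg = SimpleNamespace(
    test=SimpleNamespace(evaluation=SimpleNamespace(enabled=True, metrics=["adapted_rand"])),
    tune=None,
)
assert should_skip_label_transform(cfg, "test")
assert not should_skip_label_transform(cfg, "train")
```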

connectomics/decoding/optuna_tuner.py

Lines changed: 2 additions & 2 deletions

@@ -620,7 +620,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
     6. Save best parameters to YAML file
 
     Example:
-        >>> from connectomics.lightning import ConnectomicsModule, create_trainer
+        >>> from connectomics.training.lit import ConnectomicsModule, create_trainer
         >>> from connectomics.decoding import run_tuning
         >>> model = ConnectomicsModule(cfg)
         >>> trainer = create_trainer(cfg)

@@ -655,7 +655,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
     print(f"Output directory: {output_dir}")
 
     # Step 1: Run inference on tune dataset
-    from connectomics.lightning import create_datamodule
+    from connectomics.training.lit import create_datamodule
     from connectomics.data.io import read_volume
     import glob
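
The docstring example above, expanded into a fuller sketch; the checkpoint path is a hypothetical placeholder, and the comments restate the docstring's own workflow:

```python
from connectomics.training.lit import ConnectomicsModule, create_trainer
from connectomics.decoding import run_tuning

# `cfg` is a loaded config with a `tune` section (see the diff above).
model = ConnectomicsModule(cfg)
trainer = create_trainer(cfg)

# Per the docstring: runs inference on the tune dataset, searches decoding
# hyperparameters with Optuna, and saves the best parameters to a YAML file.
# "checkpoints/last.ckpt" is a placeholder path, not from this commit.
run_tuning(model, trainer, cfg, checkpoint_path="checkpoints/last.ckpt")
```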

connectomics/decoding/segmentation.py

Lines changed: 0 additions & 3 deletions

@@ -202,9 +202,6 @@ def decode_instance_binary_contour_distance(
     seed = remove_small_objects(seed, min_seed_size)
 
     # step 3: compute the segmentation mask
-    import pdb
-
-    pdb.set_trace()
     distance[distance < 0] = 0
     segmentation = mahotas.cwatershed(-distance.astype(np.float64), seed)
     segmentation[~foreground] = (
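
For context, the surviving lines show the decode pattern this hunk cleans up (the deleted lines were a stray `pdb.set_trace()` left in the watershed step): seeds are grown over an inverted distance map. A self-contained sketch of that pattern, assuming `mahotas` and `scikit-image`; the seed construction and the background fill value of 0 are assumptions, since they fall outside (or are truncated in) the hunk:

```python
import numpy as np
import mahotas
from skimage.measure import label
from skimage.morphology import remove_small_objects

def watershed_decode(distance, foreground, seed_threshold=0.5, min_seed_size=32):
    """Seeded watershed over an inverted distance map (sketch of the pattern above)."""
    # Seeds: connected components of high-distance voxels (assumed; outside the hunk)
    seed = label(distance > seed_threshold)
    seed = remove_small_objects(seed, min_seed_size)
    # Step 3 (as in the diff): clamp negatives, then flood from seeds downhill
    distance = np.where(distance < 0, 0, distance)
    segmentation = mahotas.cwatershed(-distance.astype(np.float64), seed)
    # `foreground` is a boolean mask; fill value 0 is assumed (line truncated above)
    segmentation[~foreground] = 0
    return segmentation
```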

connectomics/inference/__init__.py

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
+"""Inference utilities package."""
+
+from .manager import InferenceManager
+from .io import apply_postprocessing, apply_decode_mode, resolve_output_filenames, write_outputs
+from .sliding import build_sliding_inferer, resolve_inferer_roi_size, resolve_inferer_overlap
+from .tta import TTAPredictor
+
+__all__ = [
+    "InferenceManager",
+    "apply_postprocessing",
+    "apply_decode_mode",
+    "resolve_output_filenames",
+    "write_outputs",
+    "build_sliding_inferer",
+    "resolve_inferer_roi_size",
+    "resolve_inferer_overlap",
+    "TTAPredictor",
+]
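
This new file defines the public inference surface in one place. The import below is grounded in the `__all__` above; the role comments are inferred from the submodule and function names and should be treated as assumptions, since the submodules' signatures are not shown in this commit:

```python
# Re-exported surface of the new connectomics.inference package.
# Roles are inferred from names (manager.py / io.py / sliding.py / tta.py);
# treat them as assumptions -- signatures are not part of this diff.
from connectomics.inference import (
    InferenceManager,          # orchestrates an inference run (manager.py)
    apply_postprocessing,      # post-hoc cleanup of predictions (io.py)
    apply_decode_mode,         # probability maps -> instance decoding (io.py)
    resolve_output_filenames,  # output path resolution (io.py)
    write_outputs,             # persists predictions to disk (io.py)
    build_sliding_inferer,     # sliding-window inferer construction (sliding.py)
    resolve_inferer_roi_size,  # ROI size from config (sliding.py)
    resolve_inferer_overlap,   # window overlap from config (sliding.py)
    TTAPredictor,              # test-time-augmentation wrapper (tta.py)
)
```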
