
Commit a37e410: update demo
Author: Donglai Wei
Parent: 5d7854a

9 files changed (+379, -318 lines)

**README.md** (17 additions, 12 deletions)

```diff
@@ -31,9 +31,7 @@ PyTorch Connectomics (PyTC) helps neuroscientists:
 - **Train models** without deep ML expertise
 - **Process** large-scale connectomics datasets efficiently
 
-**Built on:** [PyTorch Lightning](https://lightning.ai/) + [MONAI](https://monai.io/) for modern, scalable deep learning.
-
-**Used by:** Harvard, MIT, Janelia Research Campus, and 100+ labs worldwide.
+**Built on:** [PyTorch Lightning](https://lightning.ai/) + [MONAI](https://monai.io/) + [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) for modern, scalable deep learning.
 
 ---
 
```

````diff
@@ -105,22 +103,29 @@ python scripts/main.py --demo
 Train on real mitochondria segmentation data:
 
 ```bash
-# Download tutorial data (~100 MB)
-mkdir -p datasets/
-wget https://huggingface.co/datasets/pytc/tutorial/resolve/main/Lucchi%2B%2B.zip
-unzip Lucchi++.zip -d datasets/
-rm Lucchi++.zip
+# Download tutorial data (~50 MB)
+just download lucchi++
 
 # Quick test (1 batch)
-python scripts/main.py --config tutorials/monai_lucchi++.yaml --fast-dev-run
+just train lucchi++ monai_unet --fast-dev-run
 
-# Full training
-python scripts/main.py --config tutorials/monai_lucchi++.yaml
+# Full training on a single GPU (choose your architecture: monai_unet, rsunet, mednext)
+just train lucchi++ monai_unet -- system.training.num_gpus=1
 ```
 
 **Monitor progress:**
 ```bash
-tensorboard --logdir outputs/lucchi++_monai_unet
+just tensorboard lucchi++_monai_unet
+```
+
+**Resume training from checkpoint:**
+```bash
+just resume lucchi++ monai_unet outputs/lucchi++_monai_unet/*/checkpoints/last.ckpt
+```
+
+**Run inference:**
+```bash
+just test lucchi++ monai_unet outputs/lucchi++_monai_unet/*/checkpoints/best.ckpt
 ```
 
 ---
````

**connectomics/models/arch/monai_models.py** (4 additions, 4 deletions)

```diff
@@ -102,9 +102,10 @@ def build_basic_unet(cfg) -> ConnectomicsModel:
     )
 
     # BasicUNet requires exactly 6 feature levels
-    base_features = list(cfg.model.filters) if hasattr(cfg.model, 'filters') else [32, 64, 128, 256, 512]
+    # Pad with last value repeated (not doubled) to keep memory usage low
+    base_features = list(cfg.model.filters) if hasattr(cfg.model, 'filters') else [32, 64, 128, 256, 512, 1024]
     while len(base_features) < 6:
-        base_features.append(base_features[-1] * 2)
+        base_features.append(base_features[-1])  # Repeat last value instead of doubling
     features = tuple(base_features[:6])
 
     model = BasicUNet(
```
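The effect of the padding change is easiest to see in isolation. A minimal, standalone sketch (not the repo's code) comparing both strategies on the demo's 4-level filter list:

```python
def pad_to_six(filters, doubling):
    """Pad a feature list to the 6 levels MONAI's BasicUNet expects."""
    feats = list(filters)
    while len(feats) < 6:
        feats.append(feats[-1] * 2 if doubling else feats[-1])
    return tuple(feats[:6])

demo_filters = [16, 32, 64, 128]
print(pad_to_six(demo_filters, doubling=True))   # old: (16, 32, 64, 128, 256, 512)
print(pad_to_six(demo_filters, doubling=False))  # new: (16, 32, 64, 128, 128, 128)
```

Repeating the last width keeps the auto-generated deepest levels no wider than what the user configured, trading some coarse-scale capacity for a smaller memory footprint.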
```diff
@@ -157,8 +158,7 @@ def build_monai_unet(cfg) -> ConnectomicsModel:
             "Use [H, W] for 2D or [D, H, W] for 3D. "
             "spatial_dims will be automatically inferred from input_size length."
         )
-    features = list(cfg.model.filters) if hasattr(cfg.model, 'filters') else [32, 64, 128, 256, 512]
-    channels = features[:5]  # Limit to 5 levels
+    channels = list(cfg.model.filters) if hasattr(cfg.model, 'filters') else [32, 64, 128, 256, 512]
    strides = [2] * (len(channels) - 1)  # 2x downsampling at each level
 
     # Handle normalization type and parameters
```

**connectomics/utils/demo.py** (24 additions, 11 deletions)

```diff
@@ -85,6 +85,8 @@ def create_demo_config():
     ModelConfig,
     DataConfig,
     OptimizationConfig,
+    OptimizerConfig,
+    SchedulerConfig,
     MonitorConfig,
     CheckpointConfig,
     EarlyStoppingConfig,
```
```diff
@@ -111,14 +113,15 @@ def create_demo_config():
             ),
         ),
         model=ModelConfig(
-            architecture="monai_basic_unet3d",
+            architecture="monai_unet",
             in_channels=1,
-            out_channels=2,
+            out_channels=1,  # Single channel binary segmentation
             spatial_dims=3,
             filters=[16, 32, 64, 128],  # Smaller for demo
             dropout=0.1,
             loss_functions=["DiceLoss"],
             loss_weights=[1.0],
+            loss_kwargs=[{"sigmoid": True, "smooth_nr": 1e-5, "smooth_dr": 1e-5}],
         ),
         data=DataConfig(
             train_image=None,  # Will be generated
```
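The new `loss_kwargs` entry matters because the model now emits a single raw logit channel, so MONAI's `DiceLoss` must apply the sigmoid itself. A minimal sketch of what those kwargs configure (tensor shapes are illustrative, not taken from the demo):

```python
import torch
from monai.losses import DiceLoss

# sigmoid=True squashes raw logits inside the loss, so out_channels=1
# suffices for binary segmentation; smooth_nr/smooth_dr stabilize the
# Dice ratio when a patch contains little or no foreground.
loss_fn = DiceLoss(sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5)

logits = torch.randn(2, 1, 16, 32, 32)                 # (B, C=1, D, H, W)
target = (torch.rand(2, 1, 16, 32, 32) > 0.5).float()  # binary mask
print(loss_fn(logits, target))                         # scalar Dice loss
```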
```diff
@@ -142,9 +145,9 @@ def create_demo_config():
             log_every_n_steps=1,
             deterministic=False,
             benchmark=True,
+            optimizer=OptimizerConfig(name="AdamW", lr=1e-3, weight_decay=1e-4),
+            scheduler=SchedulerConfig(name="ConstantLR", warmup_epochs=0),
         ),
-        optimizer={"name": "AdamW", "lr": 1e-3, "weight_decay": 1e-4},
-        scheduler={"name": "ConstantLR", "warmup_epochs": 0},
         monitor=MonitorConfig(
             checkpoint=CheckpointConfig(
                 dirpath="outputs/demo/checkpoints",
```
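Moving `optimizer`/`scheduler` from free-form dicts into typed sub-configs means a misspelled key fails at construction time instead of deep inside training. The real `OptimizerConfig`/`SchedulerConfig` definitions live in the PyTC config module; below is a hypothetical sketch of the shape this diff implies, with only the fields it exercises:

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for PyTC's config classes.
@dataclass
class OptimizerConfig:
    name: str = "AdamW"
    lr: float = 1e-3
    weight_decay: float = 1e-4

@dataclass
class SchedulerConfig:
    name: str = "ConstantLR"
    warmup_epochs: int = 0

@dataclass
class OptimizationConfig:
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)

cfg = OptimizationConfig(
    optimizer=OptimizerConfig(name="AdamW", lr=1e-3, weight_decay=1e-4),
    scheduler=SchedulerConfig(name="ConstantLR", warmup_epochs=0),
)
print(cfg.optimizer.lr)  # 1e-3; OptimizerConfig(foo=1) would raise TypeError
```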
```diff
@@ -323,13 +326,18 @@ def run_demo():
     except (ImportError, ModuleNotFoundError):
         pass
 
+    # Use TensorBoard logger to avoid "no logger" warnings
+    from pytorch_lightning.loggers import TensorBoardLogger
+    demo_logger = TensorBoardLogger(save_dir=str(temp_dir), name="demo_logs", version="")
+
     trainer = pl.Trainer(
         max_epochs=cfg.optimization.max_epochs,
         accelerator="gpu" if cfg.system.training.num_gpus > 0 else "cpu",
         devices=cfg.system.training.num_gpus if cfg.system.training.num_gpus > 0 else 1,
         precision=cfg.optimization.precision,
         callbacks=callbacks,
-        logger=False,  # Disable logging for demo
+        logger=demo_logger,
+        log_every_n_steps=cfg.optimization.log_every_n_steps,
         enable_checkpointing=True,
         enable_progress_bar=True,
         enable_model_summary=True,
```
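The logger swap follows the standard Lightning pattern. A minimal sketch (the save directory here is illustrative; the demo points it at its temp dir):

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

# version="" writes events directly under <save_dir>/demo_logs/ rather
# than the default auto-numbered version_0/ subfolder.
logger = TensorBoardLogger(save_dir="/tmp/pytc_demo", name="demo_logs", version="")
trainer = pl.Trainer(max_epochs=1, logger=logger, log_every_n_steps=1)
```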
```diff
@@ -352,13 +360,18 @@ def run_demo():
     print("=" * 60)
     print("\nYour installation is working correctly! 🎉")
     print("\n📚 Next steps:")
-    print("  1. Try a tutorial:")
-    print("     python scripts/main.py --config tutorials/lucchi.yaml --fast-dev-run")
-    print("\n  2. Download tutorial data:")
-    print("     just download-data lucchi  # Or see README for manual download")
+    print("  1. Download tutorial data:")
+    print("     just download lucchi++")
+    print("     just download-list  # See all available datasets")
+    print("\n  2. Try a fast dev run:")
+    print("     just train lucchi++ monai_unet -- --fast-dev-run")
     print("\n  3. Train on real data:")
-    print("     python scripts/main.py --config tutorials/lucchi.yaml")
-    print("\n  4. Read the documentation:")
+    print("     just train lucchi++ monai_unet")
+    print("     just train lucchi++ rsunet")
+    print("     just train lucchi++ mednext")
+    print("\n  4. Monitor training:")
+    print("     just tensorboard lucchi++_monai_unet")
+    print("\n  5. Read the documentation:")
     print("     https://connectomics.readthedocs.io")
     print("\n" + "=" * 60 + "\n")
 
```
**justfile** (20 additions, 10 deletions)

```diff
@@ -6,28 +6,38 @@ default:
     @just --list
 
 # ============================================================================
-# SLURM/HPC Setup
+# Setup & Data
 # ============================================================================
 
 # Setup SLURM environment: detect CUDA/cuDNN and install PyTorch with correct versions
 setup-slurm:
     bash connectomics/utils/setup_slurm.sh
 
+# Download dataset(s) (e.g., just download lucchi++, just download all)
+# Available: lucchi++, snemi, mitoem, cremi
+download +datasets:
+    python scripts/download_data.py {{datasets}}
+
+# List available datasets
+download-list:
+    python scripts/download_data.py --list
+
 # ============================================================================
 # Training Commands
 # ============================================================================
 
-# Train on Lucchi dataset (use '+' to pass extra args: just train monai lucchi++ -- data.batch_size=8, --fast-dev-run)
-train model dataset *ARGS='':
-    python scripts/main.py --config tutorials/{{model}}_{{dataset}}.yaml {{ARGS}}
+# Train with unified config (e.g., just train lucchi++ rsunet -- data.batch_size=8)
+# Architecture options: monai_basic_unet3d, rsunet, mednext
+train dataset arch *ARGS='':
+    python scripts/main.py --config tutorials/{{dataset}}.yaml model.architecture={{arch}} {{ARGS}}
 
-# Continue training from checkpoint (use '+' for extra args: just resume monai lucchi++ ckpt.pt -- --reset-optimizer)
-resume model dataset checkpoint *ARGS='':
-    python scripts/main.py --config tutorials/{{model}}_{{dataset}}.yaml --checkpoint {{checkpoint}} {{ARGS}}
+# Continue training from checkpoint (e.g., just resume lucchi++ rsunet ckpt.pt -- --reset-optimizer)
+resume dataset arch checkpoint *ARGS='':
+    python scripts/main.py --config tutorials/{{dataset}}.yaml model.architecture={{arch}} --checkpoint {{checkpoint}} {{ARGS}}
 
-# Test on Lucchi++ dataset (provide path to checkpoint)
-test model dataset checkpoint *ARGS='':
-    python scripts/main.py --config tutorials/{{model}}_{{dataset}}.yaml --mode test --checkpoint {{checkpoint}} {{ARGS}}
+# Test model (e.g., just test lucchi++ rsunet ckpt.pt)
+test dataset arch checkpoint *ARGS='':
+    python scripts/main.py --config tutorials/{{dataset}}.yaml model.architecture={{arch}} --mode test --checkpoint {{checkpoint}} {{ARGS}}
 
 # ============================================================================
 # Monitoring Commands
 # ============================================================================
```