Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .latest_timestamp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
20251108_212916
3 changes: 3 additions & 0 deletions connectomics/config/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ class CheckpointConfig:
"""Model checkpointing configuration."""

monitor: str = "train_loss_total_epoch"
dirpath: Optional[str] = None
mode: str = "min"
save_top_k: int = 1
save_last: bool = True
Expand Down Expand Up @@ -831,6 +832,7 @@ class InferenceDataConfig:
test_transpose: List[int] = field(
default_factory=list
) # Axis permutation for test data (e.g., [2,1,0] for xyz->zyx)
output_path: Optional[str] = None # Optional explicit directory for inference outputs
output_name: str = (
"predictions.h5" # Output filename (auto-pathed to inference/{checkpoint}/{output_name})
)
Expand All @@ -845,6 +847,7 @@ class SlidingWindowConfig:

window_size: Optional[List[int]] = None
sw_batch_size: Optional[int] = None # If None, will use system.inference.batch_size
overlap: Optional[Any] = None # Overlap between window passes (float or sequence)
stride: Optional[List[int]] = None # Explicit stride for controlling window movement
blending: str = "gaussian" # 'gaussian' or 'constant' - blending mode for overlapping patches
sigma_scale: float = (
Expand Down
11 changes: 11 additions & 0 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,17 @@ def install_pytorch_connectomics(
print_success(f"Core packages installed: {', '.join(to_install)}")
else:
print_success("All core packages already installed")
print_info("Ensuring numpy and h5py are installed from conda-forge (force reinstall)...")
code, _, stderr = run_command(
f"conda install -n {env_name} -c conda-forge numpy h5py -y --force-reinstall",
check=False,
)
if code != 0:
print_warning("conda reinstall of numpy/h5py failed; please verify the environment manually")
if stderr.strip():
print_warning(stderr.strip())
else:
print_success("numpy and h5py verified via conda-forge")

# Group 2: Optional scientific packages (nice to have, but slow to install)
optional_packages = ["scipy", "scikit-learn", "scikit-image", "opencv"]
Expand Down
32 changes: 21 additions & 11 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,24 @@ def setup_config(args) -> Config:
config_name = config_path.stem # Get filename without extension
output_folder = f"outputs/{config_name}/"

# Update checkpoint dirpath to use the new output folder
cfg.monitor.checkpoint.dirpath = f"{output_folder}checkpoints/"
# Update checkpoint dirpath only if not provided by the user
if not getattr(cfg.monitor.checkpoint, "dirpath", None):
cfg.monitor.checkpoint.dirpath = str(Path(output_folder) / "checkpoints")
else:
cfg.monitor.checkpoint.dirpath = str(Path(cfg.monitor.checkpoint.dirpath))

# Update inference output path to use the new output folder
cfg.inference.data.output_path = f"{output_folder}results/"
# Update inference output path only if not provided by the user
if not getattr(cfg.inference.data, "output_path", None):
cfg.inference.data.output_path = str(Path(output_folder) / "results")
else:
cfg.inference.data.output_path = str(Path(cfg.inference.data.output_path))

# Note: We handle timestamping manually in main() to create run directories
# Set this to False to prevent PyTorch Lightning from adding its own timestamp
cfg.monitor.checkpoint.use_timestamp = False

print(f"📁 Output folder set to: {output_folder}")
print(f"📁 Checkpoints base directory: {cfg.monitor.checkpoint.dirpath}")
print(f"📂 Inference output directory: {cfg.inference.data.output_path}")

# Apply CLI overrides
if args.overrides:
Expand Down Expand Up @@ -1111,8 +1118,9 @@ def main():
# Subsequent invocations (with LOCAL_RANK set) reuse the existing timestamp.
if args.mode == "train":
# Derive the output base folder from the checkpoint dirpath by dropping its leaf directory (normally "checkpoints")
checkpoint_dirpath = cfg.monitor.checkpoint.dirpath
output_base = Path(checkpoint_dirpath).parent # This gives us outputs/experiment_name/
checkpoint_dir = Path(cfg.monitor.checkpoint.dirpath)
checkpoint_subdir = checkpoint_dir.name or "checkpoints"
output_base = checkpoint_dir.parent # Base directory containing timestamped runs

# Check if this is a DDP re-launch (LOCAL_RANK is set by PyTorch Lightning)
import os
Expand All @@ -1129,10 +1137,11 @@ def main():
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = output_base / timestamp

# Update checkpoint dirpath to use the timestamped directory
cfg.monitor.checkpoint.dirpath = str(run_dir / "checkpoints")
# Update checkpoint dirpath to use the timestamped directory (preserve leaf name)
checkpoint_path = run_dir / checkpoint_subdir
cfg.monitor.checkpoint.dirpath = str(checkpoint_path)

run_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path.mkdir(parents=True, exist_ok=True)
print(f"📁 Run directory: {run_dir}")

# Save config to run directory
Expand All @@ -1156,7 +1165,8 @@ def main():
if timestamp_file.exists():
timestamp = timestamp_file.read_text().strip()
run_dir = output_base / timestamp
cfg.monitor.checkpoint.dirpath = str(run_dir / "checkpoints")
checkpoint_path = run_dir / checkpoint_subdir
cfg.monitor.checkpoint.dirpath = str(checkpoint_path)
print(f"📁 [DDP Rank {local_rank}] Using run directory: {run_dir}")
else:
raise RuntimeError(
Expand Down
8 changes: 4 additions & 4 deletions tutorials/monai2d_worm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ data:
do_2d: true # Enable 2D data processing (extract 2D slices from 3D volumes)

# Volume configuration
train_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/*.tif
train_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/*.tif
train_image: /orcd/data/edboyden/002/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/*.tif
train_label: /orcd/data/edboyden/002/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/*.tif
train_resolution: [5, 5] # Pixel resolution; value carried over from the Lucchi EM config (5nm isotropic) - TODO confirm for the worm dataset
use_preloaded_cache: true # Load volumes into memory for fast training

Expand Down Expand Up @@ -158,7 +158,7 @@ monitor:
save_top_k: 1
save_last: true
save_every_n_epochs: 10
dirpath: checkpoints/ # Will be dynamically set to outputs/{yaml_filename}/YYYYMMDD_HHMMSS/checkpoints/
dirpath: outputs/monai2d_worm/checkpoints/ # Will be dynamically set to outputs/{yaml_filename}/YYYYMMDD_HHMMSS/checkpoints/
# checkpoint_filename: auto-generated from monitor metric (epoch={epoch:03d}-{monitor}={value:.4f})
use_timestamp: true # Enable timestamped subdirectories (YYYYMMDD_HHMMSS)

Expand All @@ -177,7 +177,7 @@ monitor:
inference:
data:
do_2d: true # Enable 2D data processing for inference
test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
test_image: /orcd/data/edboyden/002/shenb/wormbehavior/13/*.tif
test_label:
test_resolution: [5, 5]
output_path: outputs/monai2d_worm/results/
Expand Down
Loading