Skip to content

Commit 1014d8d

Browse files
authored
Merge pull request #167 from BoyuShen2004/master
fix inference sliding_window bug
2 parents 03701d9 + 580a5cf commit 1014d8d

File tree

5 files changed

+40
-15
lines changed

5 files changed

+40
-15
lines changed

.latest_timestamp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
20251108_212916

connectomics/config/hydra_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ class CheckpointConfig:
514514
"""Model checkpointing configuration."""
515515

516516
monitor: str = "train_loss_total_epoch"
517+
dirpath: Optional[str] = None
517518
mode: str = "min"
518519
save_top_k: int = 1
519520
save_last: bool = True
@@ -831,6 +832,7 @@ class InferenceDataConfig:
831832
test_transpose: List[int] = field(
832833
default_factory=list
833834
) # Axis permutation for test data (e.g., [2,1,0] for xyz->zyx)
835+
output_path: Optional[str] = None # Optional explicit directory for inference outputs
834836
output_name: str = (
835837
"predictions.h5" # Output filename (auto-pathed to inference/{checkpoint}/{output_name})
836838
)
@@ -845,6 +847,7 @@ class SlidingWindowConfig:
845847

846848
window_size: Optional[List[int]] = None
847849
sw_batch_size: Optional[int] = None # If None, will use system.inference.batch_size
850+
overlap: Optional[Any] = None # Overlap between window passes (float or sequence)
848851
stride: Optional[List[int]] = None # Explicit stride for controlling window movement
849852
blending: str = "gaussian" # 'gaussian' or 'constant' - blending mode for overlapping patches
850853
sigma_scale: float = (

install.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,17 @@ def install_pytorch_connectomics(
431431
print_success(f"Core packages installed: {', '.join(to_install)}")
432432
else:
433433
print_success("All core packages already installed")
434+
print_info("Ensuring numpy and h5py are installed from conda-forge (force reinstall)...")
435+
code, _, stderr = run_command(
436+
f"conda install -n {env_name} -c conda-forge numpy h5py -y --force-reinstall",
437+
check=False,
438+
)
439+
if code != 0:
440+
print_warning("conda reinstall of numpy/h5py failed; please verify the environment manually")
441+
if stderr.strip():
442+
print_warning(stderr.strip())
443+
else:
444+
print_success("numpy and h5py verified via conda-forge")
434445

435446
# Group 2: Optional scientific packages (nice to have, but slow to install)
436447
optional_packages = ["scipy", "scikit-learn", "scikit-image", "opencv"]

scripts/main.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -182,17 +182,24 @@ def setup_config(args) -> Config:
182182
config_name = config_path.stem # Get filename without extension
183183
output_folder = f"outputs/{config_name}/"
184184

185-
# Update checkpoint dirpath to use the new output folder
186-
cfg.monitor.checkpoint.dirpath = f"{output_folder}checkpoints/"
185+
# Update checkpoint dirpath only if not provided by the user
186+
if not getattr(cfg.monitor.checkpoint, "dirpath", None):
187+
cfg.monitor.checkpoint.dirpath = str(Path(output_folder) / "checkpoints")
188+
else:
189+
cfg.monitor.checkpoint.dirpath = str(Path(cfg.monitor.checkpoint.dirpath))
187190

188-
# Update inference output path to use the new output folder
189-
cfg.inference.data.output_path = f"{output_folder}results/"
191+
# Update inference output path only if not provided by the user
192+
if not getattr(cfg.inference.data, "output_path", None):
193+
cfg.inference.data.output_path = str(Path(output_folder) / "results")
194+
else:
195+
cfg.inference.data.output_path = str(Path(cfg.inference.data.output_path))
190196

191197
# Note: We handle timestamping manually in main() to create run directories
192198
# Set this to False to prevent PyTorch Lightning from adding its own timestamp
193199
cfg.monitor.checkpoint.use_timestamp = False
194200

195-
print(f"📁 Output folder set to: {output_folder}")
201+
print(f"📁 Checkpoints base directory: {cfg.monitor.checkpoint.dirpath}")
202+
print(f"📂 Inference output directory: {cfg.inference.data.output_path}")
196203

197204
# Apply CLI overrides
198205
if args.overrides:
@@ -1111,8 +1118,9 @@ def main():
11111118
# Subsequent invocations (with LOCAL_RANK set) reuse the existing timestamp.
11121119
if args.mode == "train":
11131120
# Extract output folder from checkpoint dirpath (remove /checkpoints suffix)
1114-
checkpoint_dirpath = cfg.monitor.checkpoint.dirpath
1115-
output_base = Path(checkpoint_dirpath).parent # This gives us outputs/experiment_name/
1121+
checkpoint_dir = Path(cfg.monitor.checkpoint.dirpath)
1122+
checkpoint_subdir = checkpoint_dir.name or "checkpoints"
1123+
output_base = checkpoint_dir.parent # Base directory containing timestamped runs
11161124

11171125
# Check if this is a DDP re-launch (LOCAL_RANK is set by PyTorch Lightning)
11181126
import os
@@ -1129,10 +1137,11 @@ def main():
11291137
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
11301138
run_dir = output_base / timestamp
11311139

1132-
# Update checkpoint dirpath to use the timestamped directory
1133-
cfg.monitor.checkpoint.dirpath = str(run_dir / "checkpoints")
1140+
# Update checkpoint dirpath to use the timestamped directory (preserve leaf name)
1141+
checkpoint_path = run_dir / checkpoint_subdir
1142+
cfg.monitor.checkpoint.dirpath = str(checkpoint_path)
11341143

1135-
run_dir.mkdir(parents=True, exist_ok=True)
1144+
checkpoint_path.mkdir(parents=True, exist_ok=True)
11361145
print(f"📁 Run directory: {run_dir}")
11371146

11381147
# Save config to run directory
@@ -1156,7 +1165,8 @@ def main():
11561165
if timestamp_file.exists():
11571166
timestamp = timestamp_file.read_text().strip()
11581167
run_dir = output_base / timestamp
1159-
cfg.monitor.checkpoint.dirpath = str(run_dir / "checkpoints")
1168+
checkpoint_path = run_dir / checkpoint_subdir
1169+
cfg.monitor.checkpoint.dirpath = str(checkpoint_path)
11601170
print(f"📁 [DDP Rank {local_rank}] Using run directory: {run_dir}")
11611171
else:
11621172
raise RuntimeError(

tutorials/monai2d_worm.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ data:
6262
do_2d: true # Enable 2D data processing (extract 2D slices from 3D volumes)
6363

6464
# Volume configuration
65-
train_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/*.tif
66-
train_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/*.tif
65+
train_image: /orcd/data/edboyden/002/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/*.tif
66+
train_label: /orcd/data/edboyden/002/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/*.tif
6767
train_resolution: [5, 5] # Lucchi EM: 5nm isotropic resolution
6868
use_preloaded_cache: true # Load volumes into memory for fast training
6969

@@ -158,7 +158,7 @@ monitor:
158158
save_top_k: 1
159159
save_last: true
160160
save_every_n_epochs: 10
161-
dirpath: checkpoints/ # Will be dynamically set to outputs/{yaml_filename}/YYYYMMDD_HHMMSS/checkpoints/
161+
dirpath: outputs/monai2d_worm/checkpoints/ # Will be dynamically set to outputs/{yaml_filename}/YYYYMMDD_HHMMSS/checkpoints/
162162
# checkpoint_filename: auto-generated from monitor metric (epoch={epoch:03d}-{monitor}={value:.4f})
163163
use_timestamp: true # Enable timestamped subdirectories (YYYYMMDD_HHMMSS)
164164

@@ -177,7 +177,7 @@ monitor:
177177
inference:
178178
data:
179179
do_2d: true # Enable 2D data processing for inference
180-
test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
180+
test_image: /orcd/data/edboyden/002/shenb/wormbehavior/13/*.tif
181181
test_label:
182182
test_resolution: [5, 5]
183183
output_path: outputs/monai2d_worm/results/

0 commit comments

Comments
 (0)