Skip to content

Commit 35feabc

Browse files
committed
Merge: [nnUNet/PyT] Fix case with checkpoint path set to None
2 parents 0915477 + 20bda77 commit 35feabc

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

PyTorch/Segmentation/nnUNet/main.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
import os
1616

1717
import torch
18+
from data_loading.data_module import DataModule
19+
from nnunet.nn_unet import NNUnet
1820
from pytorch_lightning import Trainer, seed_everything
1921
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary, RichProgressBar
2022
from pytorch_lightning.plugins.io import AsyncCheckpointIO
2123
from pytorch_lightning.strategies import DDPStrategy
22-
23-
from data_loading.data_module import DataModule
24-
from nnunet.nn_unet import NNUnet
2524
from utils.args import get_main_args
2625
from utils.logger import LoggingCallback
2726
from utils.utils import make_empty_dir, set_cuda_devices, set_granularity, verify_ckpt_path

PyTorch/Segmentation/nnUNet/nnunet/nn_unet.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,15 @@
2222
from data_loading.data_module import get_data_path, get_test_fnames
2323
from monai.inferers import sliding_window_inference
2424
from monai.networks.nets import DynUNet
25+
from nnunet.brats22_model import UNet3D
26+
from nnunet.loss import Loss, LossBraTS
27+
from nnunet.metrics import Dice
2528
from pytorch_lightning.utilities import rank_zero_only
2629
from scipy.special import expit, softmax
2730
from skimage.transform import resize
2831
from utils.logger import DLLogger
2932
from utils.utils import get_config_file, print0
3033

31-
from nnunet.brats22_model import UNet3D
32-
from nnunet.loss import Loss, LossBraTS
33-
from nnunet.metrics import Dice
34-
3534

3635
class NNUnet(pl.LightningModule):
3736
def __init__(self, args, triton=False, data_dir=None):
@@ -279,7 +278,7 @@ def test_epoch_end(self, outputs):
279278

280279
@rank_zero_only
281280
def on_fit_end(self):
282-
if not self.args.benchmark and self.args.skip_first_n_eval == 0:
281+
if not self.args.benchmark:
283282
metrics = {}
284283
metrics["dice_score"] = round(self.best_mean.item(), 2)
285284
metrics["train_loss"] = round(sum(self.train_loss) / len(self.train_loss), 4)

PyTorch/Segmentation/nnUNet/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def verify_ckpt_path(args):
5858
return resume_path_results
5959
print("[Warning] Checkpoint not found. Starting training from scratch.")
6060
return None
61-
if not os.path.isfile(args.ckpt_path):
61+
if args.ckpt_path is None or not os.path.isfile(args.ckpt_path):
6262
print(f"Provided checkpoint {args.ckpt_path} is not a file. Starting training from scratch.")
6363
return None
6464
return args.ckpt_path

0 commit comments

Comments
 (0)