diff --git a/training/training_loop.py b/training/training_loop.py index 14836ad2e..f83f04a6e 100755 --- a/training/training_loop.py +++ b/training/training_loop.py @@ -7,6 +7,7 @@ # license agreement from NVIDIA CORPORATION is strictly prohibited. import os +import re import time import copy import json @@ -152,6 +153,7 @@ def training_loop( G_ema = copy.deepcopy(G).eval() # Resume from existing pickle. + cur_nimg = 0 if (resume_pkl is not None) and (rank == 0): print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: @@ -159,6 +161,12 @@ def training_loop( for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: misc.copy_params_and_buffers(resume_data[name], module, require_all=False) + # resume from a snapshot file(`network-snapshot-.pkl`) continues from + # where it stopped.. + match = re.match(r"^.*(network-snapshot-)(\d+)(.pkl)$", resume_pkl, re.IGNORECASE) + if match: + cur_nimg = int(match.group(2)) * 1000 + # Print network summary tables. if rank == 0: z = torch.empty([batch_gpu, G.z_dim], device=device) @@ -245,14 +253,14 @@ def training_loop( if rank == 0: print(f'Training for {total_kimg} kimg...') print() - cur_nimg = 0 + cur_tick = 0 tick_start_nimg = cur_nimg tick_start_time = time.time() maintenance_time = tick_start_time - start_time batch_idx = 0 if progress_fn is not None: - progress_fn(0, total_kimg) + progress_fn(cur_nimg // 1000, total_kimg) while True: # Fetch training data.