#!/usr/bin/env python
""" Checkpoint Averaging Script

This script averages all model weights for checkpoints in a specified path that match
the specified filter wildcard. All checkpoints must be from the exact same model.

For any hope of decent results, the checkpoints should be from the same or a child
(via resumes) training session. This can be viewed as similar to maintaining a running
EMA (exponential moving average) of the model weights or performing SWA (stochastic
weight averaging), but post-training.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import glob
import hashlib
import os

import torch

from timm.models.helpers import load_state_dict

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
                    help='path to base input folder containing checkpoints')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                    help='checkpoint filter (path wildcard)')
parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
                    help='output filename')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='Force not using ema version of weights (if present)')
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                    help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
parser.add_argument('-n', type=int, default=10, metavar='N',
                    help='Number of checkpoints to average')
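
# Example invocation (script name, paths, and checkpoint naming below are illustrative only):
#   python avg_checkpoints.py --input ./output/train/my_run --filter 'checkpoint-*.pth.tar' \
#       -n 5 --output ./averaged.pth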


def checkpoint_metric(checkpoint_path):
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        # return None so missing files are skipped by the metric check in main()
        return None
    print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    metric = None
    if 'metric' in checkpoint:
        metric = checkpoint['metric']
    return metric


def main():
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        exit(1)

    # assemble the glob pattern from the input folder and filter wildcard
    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        # keep the n checkpoints with the highest metric
        checkpoint_metrics = sorted(checkpoint_metrics)[-args.n:]
        print("Selected checkpoints:")
        for m, c in checkpoint_metrics:
            print(m, c)
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        print("Selected checkpoints:")
        for c in checkpoints:
            print(c)

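    # accumulate each parameter in float64 (counting contributions per key) so the
    # division below yields a simple per-key mean with reduced rounding error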
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print("Error: Checkpoint ({}) doesn't exist".format(c))
            continue

        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

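    # divide each accumulated sum by the number of checkpoints that contributed that key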
    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

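    # save the averaged state_dict and report a SHA256 of the saved file for reference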
    torch.save(final_state_dict, args.output)
    with open(args.output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print("=> Saved state_dict to '{}', SHA256: {}".format(args.output, sha_hash))


if __name__ == '__main__':
    main()