103103 help = 'Name of model to train (default: "resnet50")' )
104104group .add_argument ('--pretrained' , action = 'store_true' , default = False ,
105105 help = 'Start with pretrained version of specified network (if avail)' )
106+ group .add_argument ('--pretrained-path' , default = None , type = str ,
107+ help = 'Load this checkpoint as if they were the pretrained weights (with adaptation).' )
106108group .add_argument ('--initial-checkpoint' , default = '' , type = str , metavar = 'PATH' ,
107- help = 'Initialize model from this checkpoint (default: none)' )
109+ help = 'Load this checkpoint into model after initialization (default: none)' )
108110group .add_argument ('--resume' , default = '' , type = str , metavar = 'PATH' ,
109111 help = 'Resume full model and optimizer state from checkpoint (default: none)' )
110112group .add_argument ('--no-resume-opt' , action = 'store_true' , default = False ,
@@ -420,6 +422,11 @@ def main():
420422 elif args .input_size is not None :
421423 in_chans = args .input_size [0 ]
422424
425+ factory_kwargs = {}
426+ if args .pretrained_path :
427+ # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
428+ factory_kwargs ['pretrained_cfg_overlay' ] = dict (file = args .pretrained_path )
429+
423430 model = create_model (
424431 args .model ,
425432 pretrained = args .pretrained ,
@@ -433,6 +440,7 @@ def main():
433440 bn_eps = args .bn_eps ,
434441 scriptable = args .torchscript ,
435442 checkpoint_path = args .initial_checkpoint ,
443+ ** factory_kwargs ,
436444 ** args .model_kwargs ,
437445 )
438446 if args .head_init_scale is not None :
0 commit comments