2121# Downloads all model files again if manifest file is not present
2222MANIFEST_FILE = "model_manifest.json"
2323
24+ # Valid paths for model-saving specification
25+ VALID_PATHS = ("script" , "trace" , "torchscript" , "pytorch" , "all" )
26+
27+ # Key models selected for benchmarking with their respective paths
2428BENCHMARK_MODELS = {
25- "vgg16" : {"model" : models .vgg16 (weights = None ), "path" : "script" },
26- "resnet50" : {"model" : models .resnet50 (weights = None ), "path" : "script" },
29+ "vgg16" : {"model" : models .vgg16 (pretrained = True ), "path" : ["script" , "pytorch" ]},
30+ "resnet50" : {
31+ "model" : models .resnet50 (weights = None ),
32+ "path" : ["script" , "pytorch" ],
33+ },
2734 "efficientnet_b0" : {
2835 "model" : timm .create_model ("efficientnet_b0" , pretrained = True ),
29- "path" : "script" ,
36+ "path" : [ "script" , "pytorch" ] ,
3037 },
3138 "vit" : {
3239 "model" : timm .create_model ("vit_base_patch16_224" , pretrained = True ),
@@ -40,18 +47,26 @@ def get(n, m, manifest):
4047 print ("Downloading {}" .format (n ))
4148 traced_filename = "models/" + n + "_traced.jit.pt"
4249 script_filename = "models/" + n + "_scripted.jit.pt"
50+ pytorch_filename = "models/" + n + "_pytorch.pt"
4351 x = torch .ones ((1 , 3 , 300 , 300 )).cuda ()
44- if n == "bert-base-uncased " :
52+ if n == "bert_base_uncased " :
4553 traced_model = m ["model" ]
4654 torch .jit .save (traced_model , traced_filename )
4755 manifest .update ({n : [traced_filename ]})
4856 else :
4957 m ["model" ] = m ["model" ].eval ().cuda ()
50- if m ["path" ] == "both" or m ["path" ] == "trace" :
58+
59+ # Get all desired model save specifications as list
60+ paths = [m ["path" ]] if isinstance (m ["path" ], str ) else m ["path" ]
61+
62+ # Depending on specified model save specifications, save desired model formats
63+ if any (path in ("all" , "torchscript" , "trace" ) for path in paths ):
64+ # (TorchScript) Traced model
5165 trace_model = torch .jit .trace (m ["model" ], [x ])
5266 torch .jit .save (trace_model , traced_filename )
5367 manifest .update ({n : [traced_filename ]})
54- if m ["path" ] == "both" or m ["path" ] == "script" :
68+ if any (path in ("all" , "torchscript" , "script" ) for path in paths ):
69+ # (TorchScript) Scripted model
5570 script_model = torch .jit .script (m ["model" ])
5671 torch .jit .save (script_model , script_filename )
5772 if n in manifest .keys ():
@@ -60,6 +75,15 @@ def get(n, m, manifest):
6075 manifest .update ({n : files })
6176 else :
6277 manifest .update ({n : [script_filename ]})
78+ if any (path in ("all" , "pytorch" ) for path in paths ):
79+ # (PyTorch Module) model
80+ torch .save (m ["model" ], pytorch_filename )
81+ if n in manifest .keys ():
82+ files = list (manifest [n ]) if type (manifest [n ]) != list else manifest [n ]
83+ files .append (script_filename )
84+ manifest .update ({n : files })
85+ else :
86+ manifest .update ({n : [script_filename ]})
6387 return manifest
6488
6589
@@ -72,15 +96,35 @@ def download_models(version_matches, manifest):
7296 for n , m in BENCHMARK_MODELS .items ():
7397 scripted_filename = "models/" + n + "_scripted.jit.pt"
7498 traced_filename = "models/" + n + "_traced.jit.pt"
99+ pytorch_filename = "models/" + n + "_pytorch.pt"
75100 # Check if model file exists on disk
101+
102+ # Extract model specifications as list and ensure all desired formats exist
103+ paths = [m ["path" ]] if isinstance (m ["path" ], str ) else m ["path" ]
76104 if (
77105 (
78- m ["path" ] == "both"
106+ any (path == "all" for path in paths )
107+ and os .path .exists (scripted_filename )
108+ and os .path .exists (traced_filename )
109+ and os .path .exists (pytorch_filename )
110+ )
111+ or (
112+ any (path == "torchscript" for path in paths )
79113 and os .path .exists (scripted_filename )
80114 and os .path .exists (traced_filename )
81115 )
82- or (m ["path" ] == "script" and os .path .exists (scripted_filename ))
83- or (m ["path" ] == "trace" and os .path .exists (traced_filename ))
116+ or (
117+ any (path == "script" for path in paths )
118+ and os .path .exists (scripted_filename )
119+ )
120+ or (
121+ any (path == "trace" for path in paths )
122+ and os .path .exists (traced_filename )
123+ )
124+ or (
125+ any (path == "pytorch" for path in paths )
126+ and os .path .exists (pytorch_filename )
127+ )
84128 ):
85129 print ("Skipping {} " .format (n ))
86130 continue
@@ -90,7 +134,6 @@ def download_models(version_matches, manifest):
90134def main ():
91135 manifest = None
92136 version_matches = False
93- manifest_exists = False
94137
95138 # Check if Manifest file exists or is empty
96139 if not os .path .exists (MANIFEST_FILE ) or os .stat (MANIFEST_FILE ).st_size == 0 :
@@ -99,7 +142,6 @@ def main():
99142 # Creating an empty manifest file for overwriting post setup
100143 os .system ("touch {}" .format (MANIFEST_FILE ))
101144 else :
102- manifest_exists = True
103145
104146 # Load manifest if already exists
105147 with open (MANIFEST_FILE , "r" ) as f :
@@ -129,4 +171,13 @@ def main():
129171 f .truncate ()
130172
131173
132- main ()
174+ if __name__ == "__main__" :
175+ # Ensure all specified desired model formats exist and are valid
176+ paths = [
177+ [m ["path" ]] if isinstance (m ["path" ], str ) else m ["path" ]
178+ for m in BENCHMARK_MODELS .values ()
179+ ]
180+ assert all (
181+ (path in VALID_PATHS ) for path_list in paths for path in path_list
182+ ), "Not all 'path' attributes in BENCHMARK_MODELS are valid"
183+ main ()
0 commit comments