33import argparse
44import os
55import torch
6- import torch .nn as nn
76from torch_mlir import fx
87from torch_mlir .fx import OutputType
98
9+ from utils import parse_shape_str , load_callable_symbol , generate_fake_tensor
10+
11+
1012# Parse arguments for selecting which model to load and which MLIR dialect to generate
1113def parse_args ():
1214 parser = argparse .ArgumentParser (description = "Generate MLIR for Torch-MLIR models." )
1315 parser .add_argument (
14- "--model" ,
16+ "--model-entrypoint " ,
1517 type = str ,
1618 required = True ,
17- help = "Path to the Torch model file." ,
19+ help = "Path to the model entrypoint, e.g. 'torchvision.models:resnet18' or '/path/to/model.py:build_model'." ,
20+ )
21+ parser .add_argument (
22+ "--model-state-path" ,
23+ type = str ,
24+ required = False ,
25+ help = "Path to the state file of the Torch model (usually has .pt or .pth extension)." ,
26+ )
27+ parser .add_argument (
28+ "--model-args" ,
29+ type = str ,
30+ required = False ,
31+ default = "[]" ,
32+ help = ""
33+ "Positional arguments to pass to the model-entry "
34+ "(note that this argument will be passed to an 'eval',"
35+ " so the string should contain a valid python code)." ,
36+ )
37+ parser .add_argument (
38+ "--model-kwargs" ,
39+ type = str ,
40+ required = False ,
41+ default = "{}" ,
42+ help = ""
43+ "Keyword arguments to pass to the model-entry "
44+ "(note that this argument will be passed to an 'eval',"
45+ " so the string should contain a valid python code)." ,
46+ )
47+ parser .add_argument (
48+ "--sample-shapes" ,
49+ type = str ,
50+ required = False ,
51+ help = "Tensor shapes/dtype that the 'forward' method of the model will be called with,"
52+ " e.g. '1,3,224,224,float32'. Must be specified if '--sample-fn' is not given." ,
53+ )
54+ parser .add_argument (
55+ "--sample-fn" ,
56+ type = str ,
57+ required = False ,
58+ help = "Path to a function that generates sample arguments for the model's 'forward' method."
59+ " The function should return a tuple of (args, kwargs). If this is given, '--sample-shapes' is ignored." ,
1860 )
1961 parser .add_argument (
2062 "--dialect" ,
@@ -23,50 +65,83 @@ def parse_args():
2365 default = "linalg" ,
2466 help = "MLIR dialect to generate." ,
2567 )
68+ parser .add_argument (
69+ "--out-mlir" ,
70+ type = str ,
71+ required = False ,
72+ help = "Path to save the generated MLIR module." ,
73+ )
2674 return parser .parse_args ()
2775
# Function to load the Torch model
def load_torch_model(model_path, map_location=None):
    """Load a serialized Torch object (e.g. a model or a state dict) from disk.

    Args:
        model_path: Filesystem path of the file to deserialize.
        map_location: Optional device remapping forwarded unchanged to
            ``torch.load`` (for example ``"cpu"``). ``None`` keeps
            ``torch.load``'s default behavior.

    Returns:
        The object produced by ``torch.load`` for ``model_path``.

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing regular file.
    """
    # isfile rather than exists: a directory path would pass an existence
    # check and only fail later inside torch.load with an unrelated error.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file {model_path} does not exist.")

    # NOTE(security): torch.load is pickle-based; only load files that come
    # from a trusted source.
    return torch.load(model_path, map_location=map_location)
3684
37- # Function to generate MLIR from the Torch model
38- # See: https://github.com/MrSidims/PytorchExplorer/blob/main/backend/server.py#L237
39- def generate_mlir (model , dialect ):
4085
def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]:
    """
    Generate sample arguments for the model's 'forward' method.
    (Required by torch_mlir.fx.export_and_import)

    Args:
        shape_str: Shape/dtype description such as '1,3,224,224,float32'.
            Only consulted when ``sample_fn_path`` is None.
        sample_fn_path: Optional reference to a callable that produces the
            sample arguments. When given, it takes precedence and
            ``shape_str`` is ignored.

    Returns:
        A ``(args, kwargs)`` pair. For the shape-string path this is a
        one-element positional tuple holding a single generated tensor and an
        empty kwargs dict. For the callable path it is whatever the callable
        returns (expected to be an ``(args, kwargs)`` tuple; not validated
        here).

    Raises:
        ValueError: If neither ``shape_str`` nor ``sample_fn_path`` is given.
    """
    if sample_fn_path is not None:
        return load_callable_symbol(sample_fn_path)()

    # Fail with an actionable message instead of handing None to the parser.
    if shape_str is None:
        raise ValueError(
            "Either '--sample-shapes' or '--sample-fn' must be specified."
        )

    shape, dtype = parse_shape_str(shape_str)
    return (generate_fake_tensor(shape, dtype),), {}
96+
97+
def generate_mlir(model, dialect, sample_args, sample_kwargs):
    """Convert a Torch model to an MLIR module in the requested dialect.

    Args:
        model: The model to export. ``eval()`` is called on it before export,
            so it is presumably a ``torch.nn.Module``.
        dialect: One of ``"torch"``, ``"linalg"``, ``"stablehlo"`` or
            ``"tosa"``.
        sample_args: Positional sample inputs forwarded to
            ``fx.export_and_import`` after the model.
        sample_kwargs: Keyword sample inputs forwarded to
            ``fx.export_and_import``.

    Returns:
        The module returned by ``fx.export_and_import``.

    Raises:
        ValueError: If ``dialect`` is not one of the supported names.
    """
    # Map the CLI dialect name to the matching OutputType member name.
    output_type_names = {
        "torch": "TORCH",
        "linalg": "LINALG_ON_TENSORS",
        "stablehlo": "STABLEHLO",
        "tosa": "TOSA",
    }
    # Validate first so a bad dialect is reported before any export work.
    if dialect not in output_type_names:
        raise ValueError(f"Unsupported dialect: {dialect}")
    output_type = getattr(OutputType, output_type_names[dialect])

    # The model is switched via eval() before being handed to the exporter.
    model.eval()
    return fx.export_and_import(
        model, *sample_args, output_type=output_type, **sample_kwargs
    )
56117
118+
# Main function to execute the script
def main():
    """Build the model, export it to MLIR, and print or save the result.

    Steps, in order: parse the command line, resolve and call the model
    entrypoint, optionally load a state dict into the model, produce sample
    inputs, export to the requested dialect, then write the module to
    ``--out-mlir`` if given or print it to stdout otherwise.

    Raises:
        SystemExit: If neither ``--sample-shapes`` nor ``--sample-fn`` was
            supplied.
    """
    args = parse_args()

    # Fail fast, before the (potentially expensive) model construction.
    if args.sample_shapes is None and args.sample_fn is None:
        raise SystemExit(
            "error: one of '--sample-shapes' or '--sample-fn' must be specified."
        )

    # Load the Torch model
    entrypoint = load_callable_symbol(args.model_entrypoint)

    # NOTE(security): --model-args / --model-kwargs are evaluated as Python
    # code, as documented in their help text. Only run this script with
    # trusted command-line input.
    model_args = eval(args.model_args)
    model_kwargs = eval(args.model_kwargs)
    model = entrypoint(*model_args, **model_kwargs)

    if args.model_state_path is not None:
        state_dict = load_torch_model(args.model_state_path)
        model.load_state_dict(state_dict)

    sample_args, sample_kwargs = generate_sample_args(
        args.sample_shapes, args.sample_fn
    )

    # Generate MLIR from the model
    mlir_module = generate_mlir(model, args.dialect, sample_args, sample_kwargs)

    # Print or save the MLIR module
    if args.out_mlir:
        # Explicit encoding so the output does not depend on the locale.
        with open(args.out_mlir, "w", encoding="utf-8") as f:
            f.write(str(mlir_module))
    else:
        print(mlir_module)
143+
69144
# Entry point for the script: run main() only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
0 commit comments