@@ -22,15 +22,15 @@ def parse_args():
2222 "--model-state-path" ,
2323 type = str ,
2424 required = False ,
25- help = "Path to the state file of the Torch model (usually has .pt or .pth extension)." ,
25+ help = "Path to a state file of the Torch model (usually has .pt or .pth extension)." ,
2626 )
2727 parser .add_argument (
2828 "--model-args" ,
2929 type = str ,
3030 required = False ,
3131 default = "[]" ,
3232 help = ""
33- "Positional arguments to pass to the model-entry "
33+ "Positional arguments to pass to the model's entrypoint "
3434 "(note that this argument will be passed to an 'eval',"
3535 " so the string should contain a valid python code)." ,
3636 )
@@ -40,7 +40,7 @@ def parse_args():
4040 required = False ,
4141 default = "{}" ,
4242 help = ""
43- "Keyword arguments to pass to the model-entry "
43+ "Keyword arguments to pass to the model's entrypoint "
4444 "(note that this argument will be passed to an 'eval',"
4545 " so the string should contain a valid python code)." ,
4646 )
@@ -95,7 +95,7 @@ def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]:
9595 return load_callable_symbol (sample_fn_path )()
9696
9797
98- def generate_mlir (model , dialect , sample_args , sample_kwargs ):
98+ def generate_mlir (model , sample_args , sample_kwargs = None , dialect = "linalg" ):
9999 # Convert the Torch model to MLIR
100100 output_type = None
101101 if dialect == "torch" :
@@ -109,6 +109,9 @@ def generate_mlir(model, dialect, sample_args, sample_kwargs):
109109 else :
110110 raise ValueError (f"Unsupported dialect: { dialect } " )
111111
112+ if sample_kwargs is None :
113+ sample_kwargs = {}
114+
112115 model .eval ()
113116 module = fx .export_and_import (
114117 model , * sample_args , output_type = output_type , ** sample_kwargs
@@ -132,7 +135,7 @@ def main():
132135 args .sample_shapes , args .sample_fn
133136 )
134137 # Generate MLIR from the model
135- mlir_module = generate_mlir (model , args . dialect , sample_args , sample_kwargs )
138+ mlir_module = generate_mlir (model , sample_args , sample_kwargs , args . dialect )
136139
137140 # Print or save the MLIR module
138141 if args .out_mlir :
0 commit comments