Skip to content

Commit 20ca817

Browse files
committed
fix typos
Signed-off-by: dchigarev <dmitry.chigarev@intel.com>
1 parent 225a4f4 commit 20ca817

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

ingress/Torch-MLIR/generate-mlir.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def parse_args():
2222
"--model-state-path",
2323
type=str,
2424
required=False,
25-
help="Path to the state file of the Torch model (usually has .pt or .pth extension).",
25+
help="Path to a state file of the Torch model (usually has .pt or .pth extension).",
2626
)
2727
parser.add_argument(
2828
"--model-args",
2929
type=str,
3030
required=False,
3131
default="[]",
3232
help=""
33-
"Positional arguments to pass to the model-entry "
33+
"Positional arguments to pass to the model's entrypoint "
3434
"(note that this argument will be passed to an 'eval',"
3535
" so the string should contain a valid python code).",
3636
)
@@ -40,7 +40,7 @@ def parse_args():
4040
required=False,
4141
default="{}",
4242
help=""
43-
"Keyword arguments to pass to the model-entry "
43+
"Keyword arguments to pass to the model's entrypoint "
4444
"(note that this argument will be passed to an 'eval',"
4545
" so the string should contain a valid python code).",
4646
)
@@ -95,7 +95,7 @@ def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]:
9595
return load_callable_symbol(sample_fn_path)()
9696

9797

98-
def generate_mlir(model, dialect, sample_args, sample_kwargs):
98+
def generate_mlir(model, sample_args, sample_kwargs=None, dialect="linalg"):
9999
# Convert the Torch model to MLIR
100100
output_type = None
101101
if dialect == "torch":
@@ -109,6 +109,9 @@ def generate_mlir(model, dialect, sample_args, sample_kwargs):
109109
else:
110110
raise ValueError(f"Unsupported dialect: {dialect}")
111111

112+
if sample_kwargs is None:
113+
sample_kwargs = {}
114+
112115
model.eval()
113116
module = fx.export_and_import(
114117
model, *sample_args, output_type=output_type, **sample_kwargs
@@ -132,7 +135,7 @@ def main():
132135
args.sample_shapes, args.sample_fn
133136
)
134137
# Generate MLIR from the model
135-
mlir_module = generate_mlir(model, args.dialect, sample_args, sample_kwargs)
138+
mlir_module = generate_mlir(model, sample_args, sample_kwargs, args.dialect)
136139

137140
# Print or save the MLIR module
138141
if args.out_mlir:

0 commit comments

Comments
 (0)