
Commit 225a4f4

[mlir][ingress][RFC] Initial version of fx-importer script using torch-mlir
Signed-off-by: dchigarev <dmitry.chigarev@intel.com>
1 parent 20c7e72 commit 225a4f4

File tree: 3 files changed (+214, -24 lines)


ingress/Torch-MLIR/generate-mlir.py

Lines changed: 93 additions & 18 deletions
@@ -3,18 +3,60 @@
 import argparse
 import os
 import torch
-import torch.nn as nn
 from torch_mlir import fx
 from torch_mlir.fx import OutputType
 
+from utils import parse_shape_str, load_callable_symbol, generate_fake_tensor
+
+
 # Parse arguments for selecting which model to load and which MLIR dialect to generate
 def parse_args():
     parser = argparse.ArgumentParser(description="Generate MLIR for Torch-MLIR models.")
     parser.add_argument(
-        "--model",
+        "--model-entrypoint",
         type=str,
         required=True,
-        help="Path to the Torch model file.",
+        help="Path to the model entrypoint, e.g. 'torchvision.models:resnet18' or '/path/to/model.py:build_model'.",
+    )
+    parser.add_argument(
+        "--model-state-path",
+        type=str,
+        required=False,
+        help="Path to the state file of the Torch model (usually has a .pt or .pth extension).",
+    )
+    parser.add_argument(
+        "--model-args",
+        type=str,
+        required=False,
+        default="[]",
+        help=""
+        "Positional arguments to pass to the model entrypoint "
+        "(note that this argument is passed to 'eval', "
+        "so the string must be valid Python code).",
+    )
+    parser.add_argument(
+        "--model-kwargs",
+        type=str,
+        required=False,
+        default="{}",
+        help=""
+        "Keyword arguments to pass to the model entrypoint "
+        "(note that this argument is passed to 'eval', "
+        "so the string must be valid Python code).",
+    )
+    parser.add_argument(
+        "--sample-shapes",
+        type=str,
+        required=False,
+        help="Tensor shapes/dtype that the model's 'forward' method will be called with,"
+        " e.g. '1,3,224,224,float32'. Must be specified if '--sample-fn' is not given.",
+    )
+    parser.add_argument(
+        "--sample-fn",
+        type=str,
+        required=False,
+        help="Path to a function that generates sample arguments for the model's 'forward' method."
+        " The function should return a tuple of (args, kwargs). If this is given, '--sample-shapes' is ignored.",
     )
     parser.add_argument(
         "--dialect",
@@ -23,50 +65,83 @@ def parse_args():
         default="linalg",
         help="MLIR dialect to generate.",
     )
+    parser.add_argument(
+        "--out-mlir",
+        type=str,
+        required=False,
+        help="Path to save the generated MLIR module.",
+    )
     return parser.parse_args()
 
-# Functin to load the Torch model
-def load_torch_model(model_path):
 
+# Function to load the Torch model
+def load_torch_model(model_path):
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file {model_path} does not exist.")
-
+
     model = torch.load(model_path)
     return model
 
-# Function to generate MLIR from the Torch model
-# See: https://github.com/MrSidims/PytorchExplorer/blob/main/backend/server.py#L237
-def generate_mlir(model, dialect):
 
+def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]:
+    """
+    Generate sample arguments for the model's 'forward' method.
+    (Required by torch_mlir.fx.export_and_import)
+    """
+    if sample_fn_path is None:
+        shape, dtype = parse_shape_str(shape_str)
+        return (generate_fake_tensor(shape, dtype),), {}
+
+    return load_callable_symbol(sample_fn_path)()
+
+
+def generate_mlir(model, dialect, sample_args, sample_kwargs):
     # Convert the Torch model to MLIR
     output_type = None
     if dialect == "torch":
         output_type = OutputType.TORCH
     elif dialect == "linalg":
-        output_type = OutputType.LINALG
+        output_type = OutputType.LINALG_ON_TENSORS
     elif dialect == "stablehlo":
         output_type = OutputType.STABLEHLO
    elif dialect == "tosa":
         output_type = OutputType.TOSA
     else:
         raise ValueError(f"Unsupported dialect: {dialect}")
 
-    module = fx.export_and_import(model, "", output_type=output_type)
+    model.eval()
+    module = fx.export_and_import(
+        model, *sample_args, output_type=output_type, **sample_kwargs
+    )
     return module
 
+
 # Main function to execute the script
 def main():
     args = parse_args()
-
+
     # Load the Torch model
-    model = load_torch_model(args.model)
-
+    entrypoint = load_callable_symbol(args.model_entrypoint)
+
+    model = entrypoint(*eval(args.model_args), **eval(args.model_kwargs))
+    if args.model_state_path is not None:
+        state_dict = load_torch_model(args.model_state_path)
+        model.load_state_dict(state_dict)
+
+    sample_args, sample_kwargs = generate_sample_args(
+        args.sample_shapes, args.sample_fn
+    )
     # Generate MLIR from the model
-    mlir_module = generate_mlir(model, args.dialect)
-
+    mlir_module = generate_mlir(model, args.dialect, sample_args, sample_kwargs)
+
     # Print or save the MLIR module
-    print(mlir_module)
+    if args.out_mlir:
+        with open(args.out_mlir, "w") as f:
+            f.write(str(mlir_module))
+    else:
+        print(mlir_module)
+
 
 # Entry point for the script
 if __name__ == "__main__":
-    main()
+    main()
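
For reference, the flow the updated script drives can be sketched as follows. This is a minimal sketch that mirrors generate_sample_args and generate_mlir above; the tiny model and the 1,3,224,224/float32 shape are illustrative, and a torch-mlir installation providing fx.export_and_import is assumed.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch_mlir import fx
from torch_mlir.fx import OutputType


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


model = TinyModel()
model.eval()

# Sample input as a FakeTensor: it carries shape/dtype but allocates no real buffer.
with FakeTensorMode():
    sample = torch.empty((1, 3, 224, 224), dtype=torch.float32)

# Export the model and lower it to the linalg-on-tensors dialect.
module = fx.export_and_import(model, sample, output_type=OutputType.LINALG_ON_TENSORS)
print(module)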
Lines changed: 41 additions & 6 deletions
@@ -1,28 +1,54 @@
 #!/usr/bin/env bash
 
 # Command line argument for model to load and MLIR dialect to generate
-while getopts "m:d:" opt; do
+while getopts "m:d:s:a:k:S:f:o:" opt; do
     case $opt in
         m)
             MODEL=$OPTARG
             ;;
         d)
             DIALECT=$OPTARG
             ;;
+        s)
+            STATE_PATH=$OPTARG
+            ;;
+        a)
+            MODEL_ARGS=$OPTARG
+            ;;
+        k)
+            MODEL_KWARGS=$OPTARG
+            ;;
+        S)
+            SAMPLE_SHAPES=$OPTARG
+            ;;
+        f)
+            SAMPLE_FN=$OPTARG
+            ;;
+        o)
+            OUT_MLIR=$OPTARG
+            ;;
         *)
-            echo "Usage: $0 [-m model] [-d dialect]"
+            echo "Usage: $0 [-m model-entrypoint] [-d dialect] [-s state_path] [-a model_args] [-k model_kwargs] [-S sample_shapes] [-f sample_fn] [-o out_mlir]"
             exit 1
             ;;
     esac
 done
+
 if [ -z "$MODEL" ]; then
-    echo "Model not specified. Please provide a model using -m option."
+    echo "Model not specified. Please provide a model entrypoint using the -m option (e.g. torchvision.models:resnet18)."
     exit 1
 fi
 if [ -z "$DIALECT" ]; then
     DIALECT="linalg"
 fi
 
+# If neither sample shapes nor a sample fn is provided, the Python script will error out.
+# Do a friendly check here to fail early.
+if [ -z "$SAMPLE_SHAPES" ] && [ -z "$SAMPLE_FN" ]; then
+    echo "Either -S sample_shapes or -f sample_fn must be provided."
+    exit 1
+fi
+
 # Enable local virtualenv created by install-virtualenv.sh
 if [ ! -d "torch-mlir-venv" ]; then
     echo "Virtual environment not found. Please run install-virtualenv.sh first."
@@ -33,10 +59,19 @@ source torch-mlir-venv/bin/activate
 # Find script directory
 SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
 
+# Build the Python argument list
+args=( "--model-entrypoint" "$MODEL" "--dialect" "$DIALECT" )
+[ -n "$STATE_PATH" ] && args+=( "--model-state-path" "$STATE_PATH" )
+[ -n "$MODEL_ARGS" ] && args+=( "--model-args" "$MODEL_ARGS" )
+[ -n "$MODEL_KWARGS" ] && args+=( "--model-kwargs" "$MODEL_KWARGS" )
+[ -n "$SAMPLE_SHAPES" ] && args+=( "--sample-shapes" "$SAMPLE_SHAPES" )
+[ -n "$SAMPLE_FN" ] && args+=( "--sample-fn" "$SAMPLE_FN" )
+[ -n "$OUT_MLIR" ] && args+=( "--out-mlir" "$OUT_MLIR" )
+
 # Use the Python script to generate MLIR
-echo "Generating MLIR for model '$MODEL' with dialect '$DIALECT'..."
-python $SCRIPT_DIR/generate-mlir.py --model "$MODEL" --dialect "$DIALECT"
+echo "Generating MLIR for model entrypoint '$MODEL' with dialect '$DIALECT'..."
+python "$SCRIPT_DIR/generate-mlir.py" "${args[@]}"
 if [ $? -ne 0 ]; then
     echo "Failed to generate MLIR for model '$MODEL'."
     exit 1
-fi
+fi
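
An example invocation of the wrapper might look like the following. The wrapper's filename is not shown in this diff, so 'generate-mlir.sh' below is a hypothetical placeholder; the model entrypoint and shape are the examples from the Python script's help text, and the virtualenv created by install-virtualenv.sh is assumed to exist.

# Illustrative invocation; "./generate-mlir.sh" is a placeholder name for the wrapper above.
./generate-mlir.sh \
    -m torchvision.models:resnet18 \
    -S 1,3,224,224,float32 \
    -d linalg \
    -o resnet18.mlir

# Given how the args array is built, this expands to roughly:
#   python generate-mlir.py --model-entrypoint torchvision.models:resnet18 \
#       --dialect linalg --sample-shapes 1,3,224,224,float32 --out-mlir resnet18.mlir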

ingress/Torch-MLIR/utils.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+import importlib
+import importlib.util
+import sys
+import os
+
+import torch
+from torch._subclasses.fake_tensor import FakeTensorMode
+
+from typing import Callable
+
+
+def load_callable_symbol(entry: str) -> Callable:
+    """
+    Load a callable Python symbol from a module or a file.
+
+    Parameters
+    ----------
+    entry : str
+        A string specifying the module or file and the attribute path,
+        in the format 'module_or_path:attr', e.g.
+        'torchvision.models:resnet18' or '/path/to/model.py:build_model'.
+
+    Returns
+    -------
+    Callable
+    """
+    if ":" not in entry:
+        raise ValueError("Entry must be like 'module_or_path:attr'")
+
+    left, right = entry.split(":", 1)
+    attr_path = right.split(".")
+
+    if os.path.exists(left) and left.endswith(".py"):
+        mod_dir = os.path.abspath(os.path.dirname(left))
+        mod_name = os.path.splitext(os.path.basename(left))[0]
+        sys_path_was = list(sys.path)
+        try:
+            if mod_dir not in sys.path:
+                sys.path.insert(0, mod_dir)
+            spec = importlib.util.spec_from_file_location(mod_name, left)
+            if spec is None or spec.loader is None:
+                raise ImportError(f"Cannot load spec from {left}")
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+        finally:
+            sys.path = sys_path_was
+    else:
+        module = importlib.import_module(left)
+
+    obj = module
+    for name in attr_path:
+        obj = getattr(obj, name)
+
+    return obj
+
+
+def parse_shape_str(shape: str) -> tuple[tuple[int], torch.dtype]:
+    """
+    Parse a shape string into a shape tuple and a torch dtype.
+
+    Parameters
+    ----------
+    shape : str
+        A string representing the shape and dtype, e.g. '1,3,224,224,float32'.
+    """
+    components = shape.split(",")
+    shapes = components[:-1]
+    dtype = components[-1]
+    tdtype = getattr(torch, dtype, None)
+    if tdtype is None:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+    if any(dim == "?" for dim in shapes):
+        raise ValueError(f"Dynamic shapes are not supported yet: {shape}")
+    return (tuple(int(dim) for dim in shapes if dim), tdtype)
+
+
+def generate_fake_tensor(shape: tuple[int], dtype: torch.dtype) -> torch.Tensor:
+    """Generate a fake tensor (has no actual buffer) with the given shape and dtype."""
+    with FakeTensorMode():
+        return torch.empty(shape, dtype=dtype)
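
A short sketch of how these helpers compose (assumes the snippet runs from ingress/Torch-MLIR so 'utils' is importable and that torchvision is installed; the entrypoint string is the example from the docstring above):

from utils import load_callable_symbol, parse_shape_str, generate_fake_tensor

# Resolve a 'module_or_path:attr' string to a callable model constructor.
build_model = load_callable_symbol("torchvision.models:resnet18")
model = build_model()

# '1,3,224,224,float32' -> ((1, 3, 224, 224), torch.float32)
shape, dtype = parse_shape_str("1,3,224,224,float32")

# Buffer-less FakeTensor usable as a sample 'forward' argument.
sample = generate_fake_tensor(shape, dtype)
print(type(model).__name__, sample.shape, sample.dtype)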
