Skip to content

Commit 1a4b755

Browse files
committed
[ingress][torch-mlir] Add utility functions to import models using torch-mlir
Signed-off-by: dchigarev <dmitry.chigarev@intel.com>
1 parent 20c7e72 commit 1a4b755

File tree

7 files changed

+358
-0
lines changed

7 files changed

+358
-0
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""
2+
Example demonstrating how to load a PyTorch model to MLIR using Lighthouse,
3+
without instantiating the model on our side.
4+
5+
The script uses 'lighthouse.ingress.torch.import_from_file' function that
6+
takes a filepath to a Python file containing the model definition, along with
7+
the names of functions to get model init arguments and sample inputs. The function
8+
imports the model class on its own, instantiates it, and passes it torch_mlir
9+
to get a MLIR module in the specified dialect.
10+
11+
The script uses model from 'DummyMLP/model.py' as an example.
12+
"""
13+
14+
import os
15+
from pyparsing import Path
16+
17+
# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module)
18+
import mlir.dialects.func as func
19+
from mlir import ir, passmanager
20+
21+
# Lighthouse imports
22+
from lighthouse.ingress.torch import import_from_file
23+
24+
# Step 1: Setup paths to locate the model definition file
25+
script_dir = Path(os.path.dirname(os.path.abspath(__file__)))
26+
model_path = script_dir / "DummyMLP" / "model.py"
27+
28+
ir_context = ir.Context()
29+
30+
# Step 2: Convert PyTorch model to MLIR
31+
# Conversion step where Lighthouse:
32+
# - Loads the DummyMLP class and instantiates with arguments obtained from 'get_init_inputs()'
33+
# - Calls get_sample_inputs() to get sample input tensors for shape inference
34+
# - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir
35+
mlir_module_ir : ir.Module = import_from_file(
36+
model_path, # Path to the Python file containing the model
37+
model_class_name="DummyMLP", # Name of the PyTorch nn.Module class to convert
38+
init_args_fn_name="get_init_inputs", # Function that returns args for model.__init__()
39+
inputs_args_fn_name="get_sample_inputs", # Function that returns sample inputs for tracing
40+
dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types)
41+
ir_context=ir_context # MLIR context for the conversion
42+
)
43+
44+
# Step 3: Extract the main function operation from the MLIR module and print its metadata
45+
func_op : func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0]
46+
print(f"entry-point name: {func_op.name}")
47+
print(f"entry-point type: {func_op.type}")
48+
49+
# Step 4: Apply some MLIR passes using a PassManager
50+
pm = passmanager.PassManager(context=ir_context)
51+
52+
pm.add("one-shot-bufferize")
53+
pm.add("linalg-specialize-generic-ops")
54+
pm.run(mlir_module_ir.operation)
55+
56+
# Step 5: Output the final MLIR
57+
print("\n\nModule dump after running pm.Pipeline:")
58+
mlir_module_ir.dump()
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
2+
Example demonstrating how to load an already instantiated PyTorch model
3+
to MLIR using Lighthouse.
4+
5+
The script uses 'lighthouse.ingress.torch.import_from_model' function that
6+
takes an already instantiated PyTorch model, along with its sample inputs.
7+
The function passes the model to torch_mlir to get a MLIR module in the
8+
specified dialect.
9+
10+
The script uses model from 'DummyMLP/model.py' as an example.
11+
"""
12+
13+
import torch
14+
15+
# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module)
16+
import mlir.dialects.func as func
17+
from mlir import ir, passmanager
18+
19+
# Lighthouse imports
20+
from lighthouse.ingress.torch import import_from_model
21+
22+
# Import a sample model definition
23+
from .DummyMLP.model import DummyMLP
24+
25+
# Step 1: Instantiate a model and prepare sample input
26+
model = DummyMLP()
27+
sample_input = torch.randn(1, 10)
28+
29+
ir_context = ir.Context()
30+
# Step 2: Convert PyTorch model to MLIR
31+
mlir_module_ir : ir.Module = import_from_model(
32+
model,
33+
sample_args=(sample_input,),
34+
ir_context=ir_context
35+
)
36+
37+
# Step 3: Extract the main function operation from the MLIR module and print its metadata
38+
func_op : func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0]
39+
print(f"entry-point name: {func_op.name}")
40+
print(f"entry-point type: {func_op.type}")
41+
42+
# Step 4: Apply some MLIR passes using a PassManager (optional)
43+
pm = passmanager.PassManager(context=ir_context)
44+
pm.add("one-shot-bufferize")
45+
pm.add("linalg-specialize-generic-ops")
46+
pm.run(mlir_module_ir.operation)
47+
48+
# Step 5: Output the final MLIR
49+
print("\n\nModule dump after running pm.Pipeline:")
50+
mlir_module_ir.dump()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Defines a simple PyTorch model to be used in lighthouse's ingress examples."""
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
import os
7+
8+
class DummyMLP(nn.Module):
9+
def __init__(self):
10+
super().__init__()
11+
self.net = nn.Sequential(
12+
nn.Linear(10, 32),
13+
nn.ReLU(),
14+
nn.Linear(32, 2)
15+
)
16+
17+
def forward(self, x):
18+
return self.net(x)
19+
20+
21+
def get_init_inputs():
22+
"""Function to return args to pass to DummyMLP.__init__()"""
23+
return ()
24+
25+
26+
def get_sample_inputs():
27+
"""Arguments to pass to DummyMLP.forward()"""
28+
return (torch.randn(1, 10),)
29+
30+
31+
if __name__ == "__main__":
32+
script_dir = os.path.dirname(os.path.abspath(__file__))
33+
torch.save(DummyMLP().state_dict(), os.path.join(script_dir, "dummy_mlp.pth"))

python/lighthouse/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.1.0a1"

python/lighthouse/ingress/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .torch_import import import_from_file, import_from_model
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import importlib
2+
import importlib.util
3+
from pathlib import Path
4+
from typing import Iterable, Mapping
5+
6+
try:
7+
import torch
8+
import torch.nn as nn
9+
except ImportError as e:
10+
raise ImportError(
11+
"PyTorch is required to use the torch import functionality. "
12+
"Please run 'uv pip install .[torch-mlir]'"
13+
) from e
14+
15+
try:
16+
from torch_mlir import fx
17+
from torch_mlir.fx import OutputType
18+
except ImportError as e:
19+
raise ImportError(
20+
"torch-mlir is required to use the torch import functionality. "
21+
"Please run 'uv pip install .[torch-mlir]'"
22+
) from e
23+
24+
from mlir import ir
25+
26+
def import_from_model(
27+
model: nn.Module,
28+
sample_args : Iterable,
29+
sample_kwargs : Mapping = None,
30+
dialect : OutputType | str = OutputType.LINALG_ON_TENSORS,
31+
ir_context : ir.Context | None = None,
32+
**kwargs,
33+
) -> str | ir.Module:
34+
"""
35+
Import a PyTorch nn.Module into MLIR.
36+
37+
The function uses torch-mlir's FX importer to convert the given PyTorch model
38+
into an MLIR module in the specified dialect. The user has to provide sample
39+
input arguments (e.g. a torch.Tensor with the correct shape).
40+
41+
Parameters
42+
----------
43+
model : nn.Module
44+
The PyTorch model to import.
45+
sample_args : Iterable
46+
Sample input arguments to the model.
47+
sample_kwargs : Mapping, optional
48+
Sample keyword arguments to the model.
49+
dialect : torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}, default: OutputType.LINALG_ON_TENSORS
50+
The target dialect for the imported MLIR module.
51+
ir_context : ir.Context, optional
52+
An optional MLIR context to use for parsing the module.
53+
If not provided, the module is returned as a string.
54+
**kwargs
55+
Additional keyword arguments passed to the fx.export_and_import function.
56+
57+
Returns
58+
-------
59+
str | ir.Module
60+
The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
61+
62+
Examples
63+
--------
64+
>>> import torch
65+
>>> import torch.nn as nn
66+
>>> from lighthouse.ingress.torch_import import import_from_model
67+
>>> class SimpleModel(nn.Module):
68+
... def __init__(self):
69+
... super().__init__()
70+
... self.fc = nn.Linear(10, 5)
71+
... def forward(self, x):
72+
... return self.fc(x)
73+
>>> model = SimpleModel()
74+
>>> sample_input = (torch.randn(1, 10),)
75+
>>> #
76+
>>> # option 1: get MLIR module as a string
77+
>>> mlir_module : str = import_from_model(model, sample_input, dialect="linalg-on-tensors")
78+
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
79+
>>> # option 2: get MLIR module as an ir.Module
80+
>>> ir_context = ir.Context()
81+
>>> mlir_module_ir : ir.Module = import_from_model(model, sample_input, dialect="tosa", ir_context=ir_context)
82+
>>> # ... run pm.Pipeline on the ir.Module ...
83+
"""
84+
if dialect == "linalg":
85+
raise ValueError(
86+
"Dialect 'linalg' is not supported. Did you mean 'linalg-on-tensors'?"
87+
)
88+
89+
if sample_kwargs is None:
90+
sample_kwargs = {}
91+
92+
model.eval()
93+
module = fx.export_and_import(
94+
model, *sample_args, output_type=dialect, **sample_kwargs, **kwargs
95+
)
96+
97+
text_module = str(module)
98+
if ir_context is None:
99+
return text_module
100+
return ir.Module.parse(text_module, context=ir_context)
101+
102+
103+
def import_from_file(
104+
filepath: str | Path,
105+
model_class_name: str = "Model",
106+
init_args_fn_name: str | None = "get_init_inputs",
107+
inputs_args_fn_name: str = "get_inputs",
108+
state_path : str | Path | None = None,
109+
dialect : OutputType | str = OutputType.LINALG_ON_TENSORS,
110+
ir_context : ir.Context | None = None,
111+
**kwargs,
112+
) -> str | ir.Module:
113+
"""
114+
Load a PyTorch nn.Module from a file and import it into MLIR.
115+
116+
The function takes a `filepath` to a Python file containing the model definition,
117+
along with the names of functions to get model init arguments and sample inputs.
118+
The function imports the model class on its own, instantiates it, and passes
119+
it ``torch_mlir`` to get a MLIR module in the specified `dialect`.
120+
121+
Parameters
122+
----------
123+
filepath : str | Path
124+
Path to the Python file containing the model definition.
125+
model_class_name : str, default: "Model"
126+
The name of the model class in the file.
127+
init_args_fn_name : str | None, default: "get_init_inputs"
128+
The name of the function in the file that returns the arguments for
129+
initializing the model. If None, the model is initialized without arguments.
130+
inputs_args_fn_name : str, default: "get_inputs"
131+
The name of the function in the file that returns the sample input arguments
132+
for the model.
133+
state_path : str | Path | None, default: None
134+
Optional path to a file containing the model's ``state_dict``.
135+
dialect: torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}, default: OutputType.LINALG_ON_TENSORS
136+
The target dialect for the imported MLIR module.
137+
ir_context : ir.Context, optional
138+
An optional MLIR context to use for parsing the module.
139+
If not provided, the module is returned as a string.
140+
**kwargs
141+
Additional keyword arguments passed to the fx.export_and_import function.
142+
143+
Returns
144+
-------
145+
str | ir.Module
146+
The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
147+
148+
Examples
149+
--------
150+
Given a file `path/to/model_file.py` with the following content:
151+
```python
152+
import torch
153+
import torch.nn as nn
154+
155+
class MyModel(nn.Module):
156+
def __init__(self):
157+
super().__init__()
158+
self.fc = nn.Linear(10, 5)
159+
def forward(self, x):
160+
return self.fc(x)
161+
162+
def get_inputs():
163+
return (torch.randn(1, 10),)
164+
```
165+
166+
The import script would look line:
167+
>>> from lighthouse.ingress.torch_import import import_from_file
168+
>>> # option 1: get MLIR module as a string
169+
>>> mlir_module : str = import_from_file(
170+
... "path/to/model_file.py",
171+
... model_class_name="MyModel",
172+
... init_args_fn_name=None,
173+
... dialect="linalg-on-tensors"
174+
... )
175+
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
176+
>>> # option 2: get MLIR module as an ir.Module
177+
>>> ir_context = ir.Context()
178+
>>> mlir_module_ir : ir.Module = import_from_file(
179+
... "path/to/model_file.py",
180+
... model_class_name="MyModel",
181+
... init_args_fn_name=None,
182+
... dialect="linalg-on-tensors",
183+
... ir_context=ir_context
184+
... )
185+
>>> # ... run pm.Pipeline on the ir.Module ...
186+
"""
187+
if isinstance(filepath, str):
188+
filepath = Path(filepath)
189+
module_name = filepath.stem
190+
191+
spec = importlib.util.spec_from_file_location(module_name, filepath)
192+
module = importlib.util.module_from_spec(spec)
193+
spec.loader.exec_module(module)
194+
195+
model = getattr(module, model_class_name, None)
196+
if model is None:
197+
raise ValueError(f"Model class '{model_class_name}' not found in {filepath}")
198+
199+
if init_args_fn_name is None:
200+
init_args_fn = lambda *args, **kwargs: ()
201+
else:
202+
init_args_fn = getattr(module, init_args_fn_name, None)
203+
if init_args_fn is None:
204+
raise ValueError(f"Init args function '{init_args_fn_name}' not found in {filepath}")
205+
206+
inputs_args_fn = getattr(module, inputs_args_fn_name, None)
207+
if inputs_args_fn is None:
208+
raise ValueError(f"Inputs args function '{inputs_args_fn_name}' not found in {filepath}")
209+
210+
nn_model : nn.Module = model(*init_args_fn())
211+
if state_path is not None:
212+
state_dict = torch.load(state_path)
213+
nn_model.load_state_dict(state_dict)
214+
215+
return import_from_model(nn_model, *inputs_args_fn(), dialect=dialect, ir_context=ir_context, **kwargs)

0 commit comments

Comments
 (0)