Skip to content

Commit 95e6774

Browse files
authored
[ingress][torch-mlir] Initial version of wrapper around torch-mlir's fx-importer (#4)
The PR adds two utility functions as part of lighthouse's python package to convert torch models to a mlir module using `torch_mlir`. A user can import one of the functions (`import_from_model` or `import_from_file`) and get a `mlir.ir.Module` that they can use to run passes on or simply write its content into a file. Some use cases: 1. Import from an instance of a model: ```python from lighthouse.ingress.torch import import_from_model from mlir import ir ctx = ir.Context() module : ir.Module = import_from_model(torch_model_instance, sample_args=(torch.rand(1, 10),), ir_context=ctx) # can now run some passes on the module ``` 2. Import from a file where a torch model is defined: Imagine we want to import a [model from `KernelBench`](https://github.com/ScalingIntelligence/KernelBench/blob/main/KernelBench/level1/10_3D_tensor_matrix_multiplication.py). They ship models as python files where models and their arguments are uniformly defined. ```python from lighthouse.ingress.torch import import_from_file from mlir import ir ctx = ir.Context() kernel_bench_root = Path(...) module : ir.Module = import_from_model( filepath=kernel_bench_root / "level1" / "10_3D_tensor_matrix_multiplication.py", ir_context=ctx ) # can now run some passes on the module ```
1 parent 2b209d5 commit 95e6774

File tree

10 files changed

+433
-114
lines changed

10 files changed

+433
-114
lines changed

ingress/Torch-MLIR/generate-mlir.py

Lines changed: 0 additions & 72 deletions
This file was deleted.

ingress/Torch-MLIR/generate-mlir.sh

Lines changed: 0 additions & 42 deletions
This file was deleted.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Defines a simple PyTorch model to be used in lighthouse's ingress examples."""
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
import os
7+
8+
class MLPModel(nn.Module):
9+
def __init__(self):
10+
super().__init__()
11+
self.net = nn.Sequential(
12+
nn.Linear(10, 32),
13+
nn.ReLU(),
14+
nn.Linear(32, 2)
15+
)
16+
17+
def forward(self, x):
18+
return self.net(x)
19+
20+
21+
def get_init_inputs():
22+
"""Function to return args to pass to MLPModel.__init__()"""
23+
return ()
24+
25+
26+
def get_sample_inputs():
27+
"""Arguments to pass to MLPModel.forward()"""
28+
return (torch.randn(1, 10),)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
Example demonstrating how to load a PyTorch model to MLIR using Lighthouse
3+
without initializing the model class on the user's side.
4+
5+
The script uses 'lighthouse.ingress.torch.import_from_file' function that
6+
takes a path to a Python file containing the model definition (a Python class derived from 'nn.Module'),
7+
along with the names of functions to get model init arguments and sample inputs. The function
8+
imports the model class on its own, initializes it, and passes it to torch_mlir
9+
to get a MLIR module in the specified dialect.
10+
11+
The script uses the model from 'MLPModel/model.py' as an example.
12+
"""
13+
14+
import os
15+
from pathlib import Path
16+
17+
# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module)
18+
import mlir.dialects.func as func
19+
from mlir import ir
20+
21+
# Lighthouse imports
22+
from lighthouse.ingress.torch import import_from_file
23+
24+
# Step 1: Set up paths to locate the model definition file
25+
script_dir = Path(os.path.dirname(os.path.abspath(__file__)))
26+
model_path = script_dir / "MLPModel" / "model.py"
27+
28+
ir_context = ir.Context()
29+
30+
# Step 2: Convert PyTorch model to MLIR
31+
# Conversion step where Lighthouse:
32+
# - Loads the MLPModel class and instantiates it with arguments obtained from 'get_init_inputs()'
33+
# - Calls get_sample_inputs() to get sample input tensors for shape inference
34+
# - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir
35+
mlir_module_ir: ir.Module = import_from_file(
36+
model_path, # Path to the Python file containing the model
37+
model_class_name="MLPModel", # Name of the PyTorch nn.Module class to convert
38+
init_args_fn_name="get_init_inputs", # Function that returns args for model.__init__()
39+
sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to 'model(...)'
40+
dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types)
41+
ir_context=ir_context # MLIR context for the conversion
42+
)
43+
44+
# The PyTorch model is now converted to MLIR at this point. You can now convert
45+
# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file.
46+
#
47+
# The following optional MLIR-processing steps are to give you an idea of what can
48+
# also be done with the MLIR module.
49+
50+
# Step 3: Extract the main function operation from the MLIR module and print its metadata
51+
func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0]
52+
print(f"entry-point name: {func_op.name}")
53+
print(f"entry-point type: {func_op.type}")
54+
55+
# Step 4: output the imported MLIR module
56+
print("\n\nModule dump:")
57+
mlir_module_ir.dump()
58+
59+
# You can alternatively write the MLIR module to a file:
60+
# with open("output.mlir", "w") as f:
61+
# f.write(str(mlir_module_ir))
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""
2+
Example demonstrating how to load an already initialized PyTorch model
3+
to MLIR using Lighthouse.
4+
5+
The script uses the 'lighthouse.ingress.torch.import_from_model' function that
6+
takes an initialized PyTorch model (an instance of a Python class derived from 'nn.Module'),
7+
along with its sample inputs. The function passes the model to torch_mlir
8+
to get a MLIR module in the specified dialect.
9+
10+
The script uses a model from 'MLPModel/model.py' as an example.
11+
"""
12+
13+
import torch
14+
15+
# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module)
16+
import mlir.dialects.func as func
17+
from mlir import ir
18+
19+
# Lighthouse imports
20+
from lighthouse.ingress.torch import import_from_model
21+
22+
# Import a sample model definition
23+
from MLPModel.model import MLPModel
24+
25+
# Step 1: Instantiate a model class and prepare sample input
26+
model = MLPModel()
27+
sample_input = torch.randn(1, 10)
28+
29+
ir_context = ir.Context()
30+
# Step 2: Convert the PyTorch model to MLIR
31+
mlir_module_ir: ir.Module = import_from_model(
32+
model,
33+
sample_args=(sample_input,),
34+
ir_context=ir_context
35+
)
36+
37+
# The PyTorch model is now converted to MLIR at this point. You can now convert
38+
# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file.
39+
#
40+
# The following optional MLIR-processing steps are to give you an idea of what can
41+
# also be done with the MLIR module.
42+
43+
# Step 3: Extract the main function operation from the MLIR module and print its metadata
44+
func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0]
45+
print(f"entry-point name: {func_op.name}")
46+
print(f"entry-point type: {func_op.type}")
47+
48+
# Step 4: output the imported MLIR module
49+
print("\n\nModule dump:")
50+
mlir_module_ir.dump()
51+
52+
# You can alternatively write the MLIR module to a file:
53+
# with open("output.mlir", "w") as f:
54+
# f.write(str(mlir_module_ir))
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Lighthouse Ingress
2+
3+
The `lighthouse.ingress` module converts various input formats to MLIR modules.
4+
5+
## Supported Formats
6+
7+
#### Torch
8+
Converts PyTorch models to MLIR using `lighthouse.ingress.torch`.
9+
10+
**Examples:** [torch examples](https://github.com/llvm/lighthouse/tree/main/python/examples/ingress/torch)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Provides functions to convert source objects (code, models, designs) into MLIR files that the MLIR project can consume"""
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Provides functions to convert PyTorch models to MLIR."""
2+
3+
from .importer import import_from_file, import_from_model

0 commit comments

Comments
 (0)