Skip to content

Commit 499db16

Browse files
committed
Restructure ingress/torch-mlir
Signed-off-by: dchigarev <dmitry.chigarev@intel.com>
1 parent 20ca817 commit 499db16

File tree

18 files changed

+232
-137
lines changed

18 files changed

+232
-137
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
__pycache__
2+
ingress/Torch-MLIR/examples/**/dumps/*.mlir

ingress/Torch-MLIR/README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
Using scripts in this directory one can convert a Torch Model to a MLIR module.
2+
3+
The conversion script is written in python and is basically a wrapper around [`torch-mlir` library](https://github.com/llvm/torch-mlir). One need to setup a python virtual environment with torch and torch-mlir libraries
4+
(`./scripts/install-virtualenv.sh`) to use the script.
5+
6+
In order to convert a model the script has to recieve:
7+
1. An instance of `torch.nn.Model` with proper state (weights).
8+
2. Sample input arguments to the model (e.g. empty tensor with proper shape and dtype).
9+
10+
There are two options of how this info can be provided to the converter:
11+
12+
### 1. Instantiate a model in your own script and use a function from the `py_src/export_lib` (recomended)
13+
14+
In this scenario a user is responsible for instantiating a model with proper state in their
15+
own python script. Then they should import a `generate_mlir` function from `py_src.export_lib`
16+
and call it in order to get a MLIR module:
17+
18+
```python
19+
model : nn.Model = get_model()
20+
sample_args = (get_sample_tensor(),)
21+
22+
# PYTHONPATH=$(pwd)/py_src/
23+
from export_lib import generate_mlir
24+
25+
mlir_module = generate_mlir(model, sample_args, dialect="linalg")
26+
print(mlir_module)
27+
```
28+
29+
### 2. Use `py_src/main.py` or `scripts/generate-mlir.sh` and pass Torch Model parameters via CLI
30+
31+
In this scenario the `py_src/main.py` script is fully responsible for instantiating a torch model
32+
and converting it to MLIR. User has to pass a proper python entrypoint for model's factory,
33+
its parameters if needed (`--model-args & --model-kwargs`), and sample model arguments (either
34+
as `--sample-shapes` or as an entrypoint to a function returning args and kwargs `--sample-fn`).
35+
36+
```
37+
# note that 'my_module' has to be in $PYTHONPATH
38+
python py_src/main.py --model-entrypoint my_module:my_factory \
39+
--module-state-path path/to/state.pth \
40+
--sample-shapes '1,2,324,float32' \
41+
--out-mlir res.mlir
42+
43+
# note that 'my_module' has to be in $PYTHONPATH
44+
./scripts/generate-mlir.sh --model-entrypoint torchvision.models:resnet18 \
45+
--sample-fn my_module:generate_resnet18_sample_args \
46+
--out-mlir res.mlir
47+
```
48+
49+
Look into `examples/` folder for more info.
4.11 KB
Binary file not shown.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
import os
5+
6+
class DummyMLP(nn.Module):
7+
def __init__(self):
8+
super().__init__()
9+
self.net = nn.Sequential(
10+
nn.Linear(10, 32),
11+
nn.ReLU(),
12+
nn.Linear(32, 2)
13+
)
14+
15+
def forward(self, x):
16+
return self.net(x)
17+
18+
def make_dummy_mlp():
19+
return DummyMLP()
20+
21+
if __name__ == "__main__":
22+
script_dir = os.path.dirname(os.path.abspath(__file__))
23+
torch.save(make_dummy_mlp().state_dict(), os.path.join(script_dir, "dummy_mlp.pth"))
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env bash
2+
3+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
4+
ROOT_DIR=$SCRIPT_DIR/../../scripts/
5+
6+
PYTHONPATH=$PYTHONPATH:$SCRIPT_DIR $ROOT_DIR/generate-mlir.sh --model-entrypoint dummy_mlp_factory:make_dummy_mlp \
7+
--model-state-path $SCRIPT_DIR/dummy_mlp.pth \
8+
--sample-shapes "1,10,float32" \
9+
--dialect linalg \
10+
--out-mlir $SCRIPT_DIR/dummy_mlp_sh.mlir
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env bash
2+
3+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
4+
ROOT_DIR=$SCRIPT_DIR/../../py_src/
5+
6+
PYTHONPATH=$PYTHONPATH:$ROOT_DIR:$SCRIPT_DIR python $ROOT_DIR/main.py --model-entrypoint dummy_mlp_factory:make_dummy_mlp \
7+
--model-state-path $SCRIPT_DIR/dummy_mlp.pth \
8+
--sample-shapes "1,10,float32" \
9+
--dialect linalg \
10+
--out-mlir $SCRIPT_DIR/dummy_mlp.mlir
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from export_lib.export import generate_mlir
5+
6+
class DummyMLP(nn.Module):
7+
def __init__(self):
8+
super().__init__()
9+
self.net = nn.Sequential(
10+
nn.Linear(10, 32),
11+
nn.ReLU(),
12+
nn.Linear(32, 2)
13+
)
14+
15+
def forward(self, x):
16+
return self.net(x)
17+
18+
def main():
19+
model = DummyMLP()
20+
dummy_input = torch.randn(1, 10)
21+
mlir_mod = generate_mlir(model, (dummy_input,), {})
22+
print(mlir_mod)
23+
24+
if __name__ == "__main__":
25+
main()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
2+
ROOT_DIR=$SCRIPT_DIR/../../py_src/
3+
4+
PYTHONPATH=$ROOT_DIR python $SCRIPT_DIR/export.py
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/usr/bin/env bash
2+
3+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
4+
ROOT_DIR=$SCRIPT_DIR/../../scripts/
5+
6+
$ROOT_DIR/generate-mlir.sh --model-entrypoint torchvision.models:resnet18 \
7+
--sample-shapes "1,3,224,224,float32" \
8+
--dialect linalg \
9+
--out-mlir $SCRIPT_DIR/resnet_18_sh.mlir
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/usr/bin/env bash
2+
3+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
4+
ROOT_DIR=$SCRIPT_DIR/../../py_src/
5+
6+
python $ROOT_DIR/main.py --model-entrypoint torchvision.models:resnet18 \
7+
--sample-shapes "1,3,224,224,float32" \
8+
--dialect linalg \
9+
--out-mlir $SCRIPT_DIR/resnet_18.mlir

0 commit comments

Comments
 (0)