|
| 1 | +Using scripts in this directory one can convert a Torch Model to a MLIR module. |
| 2 | + |
| 3 | +The conversion script is written in python and is basically a wrapper around [`torch-mlir` library](https://github.com/llvm/torch-mlir). One need to setup a python virtual environment with torch and torch-mlir libraries |
| 4 | +(`./scripts/install-virtualenv.sh`) to use the script. |
| 5 | + |
| 6 | +In order to convert a model the script has to recieve: |
| 7 | +1. An instance of `torch.nn.Model` with proper state (weights). |
| 8 | +2. Sample input arguments to the model (e.g. empty tensor with proper shape and dtype). |
| 9 | + |
| 10 | +There are two options of how this info can be provided to the converter: |
| 11 | + |
| 12 | +### 1. Instantiate a model in your own script and use a function from the `py_src/export_lib` (recomended) |
| 13 | + |
| 14 | +In this scenario a user is responsible for instantiating a model with proper state in their |
| 15 | +own python script. Then they should import a `generate_mlir` function from `py_src.export_lib` |
| 16 | +and call it in order to get a MLIR module: |
| 17 | + |
| 18 | +```python |
| 19 | +model : nn.Model = get_model() |
| 20 | +sample_args = (get_sample_tensor(),) |
| 21 | + |
| 22 | +# PYTHONPATH=$(pwd)/py_src/ |
| 23 | +from export_lib import generate_mlir |
| 24 | + |
| 25 | +mlir_module = generate_mlir(model, sample_args, dialect="linalg") |
| 26 | +print(mlir_module) |
| 27 | +``` |
| 28 | + |
| 29 | +### 2. Use `py_src/main.py` or `scripts/generate-mlir.sh` and pass Torch Model parameters via CLI |
| 30 | + |
| 31 | +In this scenario the `py_src/main.py` script is fully responsible for instantiating a torch model |
| 32 | +and converting it to MLIR. User has to pass a proper python entrypoint for model's factory, |
| 33 | +its parameters if needed (`--model-args & --model-kwargs`), and sample model arguments (either |
| 34 | +as `--sample-shapes` or as an entrypoint to a function returning args and kwargs `--sample-fn`). |
| 35 | + |
| 36 | +``` |
| 37 | +# note that 'my_module' has to be in $PYTHONPATH |
| 38 | +python py_src/main.py --model-entrypoint my_module:my_factory \ |
| 39 | + --module-state-path path/to/state.pth \ |
| 40 | + --sample-shapes '1,2,324,float32' \ |
| 41 | + --out-mlir res.mlir |
| 42 | +
|
| 43 | +# note that 'my_module' has to be in $PYTHONPATH |
| 44 | +./scripts/generate-mlir.sh --model-entrypoint torchvision.models:resnet18 \ |
| 45 | + --sample-fn my_module:generate_resnet18_sample_args \ |
| 46 | + --out-mlir res.mlir |
| 47 | +``` |
| 48 | + |
| 49 | +Look into `examples/` folder for more info. |
0 commit comments