|
| 1 | +import importlib |
| 2 | +import importlib.util |
| 3 | +from pathlib import Path |
| 4 | +from typing import Iterable, Mapping |
| 5 | + |
| 6 | +try: |
| 7 | + import torch |
| 8 | + import torch.nn as nn |
| 9 | +except ImportError as e: |
| 10 | + raise ImportError( |
| 11 | + "PyTorch is required to use the torch import functionality. " |
| 12 | + "Please run 'uv pip install .[torch-mlir]'" |
| 13 | + ) from e |
| 14 | + |
| 15 | +try: |
| 16 | + from torch_mlir import fx |
| 17 | + from torch_mlir.fx import OutputType |
| 18 | +except ImportError as e: |
| 19 | + raise ImportError( |
| 20 | + "torch-mlir is required to use the torch import functionality. " |
| 21 | + "Please run 'uv pip install .[torch-mlir]'" |
| 22 | + ) from e |
| 23 | + |
| 24 | +from mlir import ir |
| 25 | + |
| 26 | +def import_from_model( |
| 27 | + model: nn.Module, |
| 28 | + sample_args : Iterable, |
| 29 | + sample_kwargs : Mapping = None, |
| 30 | + dialect : OutputType | str = OutputType.LINALG_ON_TENSORS, |
| 31 | + ir_context : ir.Context | None = None, |
| 32 | + **kwargs, |
| 33 | +) -> str | ir.Module: |
| 34 | + """ |
| 35 | + Import a PyTorch nn.Module into MLIR. |
| 36 | +
|
| 37 | + The function uses torch-mlir's FX importer to convert the given PyTorch model |
| 38 | + into an MLIR module in the specified dialect. The user has to provide sample |
| 39 | + input arguments (e.g. a torch.Tensor with the correct shape). |
| 40 | +
|
| 41 | + Parameters |
| 42 | + ---------- |
| 43 | + model : nn.Module |
| 44 | + The PyTorch model to import. |
| 45 | + sample_args : Iterable |
| 46 | + Sample input arguments to the model. |
| 47 | + sample_kwargs : Mapping, optional |
| 48 | + Sample keyword arguments to the model. |
| 49 | + dialect : torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}, default: OutputType.LINALG_ON_TENSORS |
| 50 | + The target dialect for the imported MLIR module. |
| 51 | + ir_context : ir.Context, optional |
| 52 | + An optional MLIR context to use for parsing the module. |
| 53 | + If not provided, the module is returned as a string. |
| 54 | + **kwargs |
| 55 | + Additional keyword arguments passed to the fx.export_and_import function. |
| 56 | + |
| 57 | + Returns |
| 58 | + ------- |
| 59 | + str | ir.Module |
| 60 | + The imported MLIR module as a string or an ir.Module if `ir_context` is provided. |
| 61 | + |
| 62 | + Examples |
| 63 | + -------- |
| 64 | + >>> import torch |
| 65 | + >>> import torch.nn as nn |
| 66 | + >>> from lighthouse.ingress.torch_import import import_from_model |
| 67 | + >>> class SimpleModel(nn.Module): |
| 68 | + ... def __init__(self): |
| 69 | + ... super().__init__() |
| 70 | + ... self.fc = nn.Linear(10, 5) |
| 71 | + ... def forward(self, x): |
| 72 | + ... return self.fc(x) |
| 73 | + >>> model = SimpleModel() |
| 74 | + >>> sample_input = (torch.randn(1, 10),) |
| 75 | + >>> # |
| 76 | + >>> # option 1: get MLIR module as a string |
| 77 | + >>> mlir_module : str = import_from_model(model, sample_input, dialect="linalg-on-tensors") |
| 78 | + >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect |
| 79 | + >>> # option 2: get MLIR module as an ir.Module |
| 80 | + >>> ir_context = ir.Context() |
| 81 | + >>> mlir_module_ir : ir.Module = import_from_model(model, sample_input, dialect="tosa", ir_context=ir_context) |
| 82 | + >>> # ... run pm.Pipeline on the ir.Module ... |
| 83 | + """ |
| 84 | + if dialect == "linalg": |
| 85 | + raise ValueError( |
| 86 | + "Dialect 'linalg' is not supported. Did you mean 'linalg-on-tensors'?" |
| 87 | + ) |
| 88 | + |
| 89 | + if sample_kwargs is None: |
| 90 | + sample_kwargs = {} |
| 91 | + |
| 92 | + model.eval() |
| 93 | + module = fx.export_and_import( |
| 94 | + model, *sample_args, output_type=dialect, **sample_kwargs, **kwargs |
| 95 | + ) |
| 96 | + |
| 97 | + text_module = str(module) |
| 98 | + if ir_context is None: |
| 99 | + return text_module |
| 100 | + return ir.Module.parse(text_module, context=ir_context) |
| 101 | + |
| 102 | + |
| 103 | +def import_from_file( |
| 104 | + filepath: str | Path, |
| 105 | + model_class_name: str = "Model", |
| 106 | + init_args_fn_name: str | None = "get_init_inputs", |
| 107 | + inputs_args_fn_name: str = "get_inputs", |
| 108 | + state_path : str | Path | None = None, |
| 109 | + dialect : OutputType | str = OutputType.LINALG_ON_TENSORS, |
| 110 | + ir_context : ir.Context | None = None, |
| 111 | + **kwargs, |
| 112 | +) -> str | ir.Module: |
| 113 | + """ |
| 114 | + Load a PyTorch nn.Module from a file and import it into MLIR. |
| 115 | +
|
| 116 | + The function takes a `filepath` to a Python file containing the model definition, |
| 117 | + along with the names of functions to get model init arguments and sample inputs. |
| 118 | + The function imports the model class on its own, instantiates it, and passes |
| 119 | + it ``torch_mlir`` to get a MLIR module in the specified `dialect`. |
| 120 | +
|
| 121 | + Parameters |
| 122 | + ---------- |
| 123 | + filepath : str | Path |
| 124 | + Path to the Python file containing the model definition. |
| 125 | + model_class_name : str, default: "Model" |
| 126 | + The name of the model class in the file. |
| 127 | + init_args_fn_name : str | None, default: "get_init_inputs" |
| 128 | + The name of the function in the file that returns the arguments for |
| 129 | + initializing the model. If None, the model is initialized without arguments. |
| 130 | + inputs_args_fn_name : str, default: "get_inputs" |
| 131 | + The name of the function in the file that returns the sample input arguments |
| 132 | + for the model. |
| 133 | + state_path : str | Path | None, default: None |
| 134 | + Optional path to a file containing the model's ``state_dict``. |
| 135 | + dialect: torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}, default: OutputType.LINALG_ON_TENSORS |
| 136 | + The target dialect for the imported MLIR module. |
| 137 | + ir_context : ir.Context, optional |
| 138 | + An optional MLIR context to use for parsing the module. |
| 139 | + If not provided, the module is returned as a string. |
| 140 | + **kwargs |
| 141 | + Additional keyword arguments passed to the fx.export_and_import function. |
| 142 | + |
| 143 | + Returns |
| 144 | + ------- |
| 145 | + str | ir.Module |
| 146 | + The imported MLIR module as a string or an ir.Module if `ir_context` is provided. |
| 147 | + |
| 148 | + Examples |
| 149 | + -------- |
| 150 | + Given a file `path/to/model_file.py` with the following content: |
| 151 | + ```python |
| 152 | + import torch |
| 153 | + import torch.nn as nn |
| 154 | +
|
| 155 | + class MyModel(nn.Module): |
| 156 | + def __init__(self): |
| 157 | + super().__init__() |
| 158 | + self.fc = nn.Linear(10, 5) |
| 159 | + def forward(self, x): |
| 160 | + return self.fc(x) |
| 161 | + |
| 162 | + def get_inputs(): |
| 163 | + return (torch.randn(1, 10),) |
| 164 | + ``` |
| 165 | +
|
| 166 | + The import script would look line: |
| 167 | + >>> from lighthouse.ingress.torch_import import import_from_file |
| 168 | + >>> # option 1: get MLIR module as a string |
| 169 | + >>> mlir_module : str = import_from_file( |
| 170 | + ... "path/to/model_file.py", |
| 171 | + ... model_class_name="MyModel", |
| 172 | + ... init_args_fn_name=None, |
| 173 | + ... dialect="linalg-on-tensors" |
| 174 | + ... ) |
| 175 | + >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect |
| 176 | + >>> # option 2: get MLIR module as an ir.Module |
| 177 | + >>> ir_context = ir.Context() |
| 178 | + >>> mlir_module_ir : ir.Module = import_from_file( |
| 179 | + ... "path/to/model_file.py", |
| 180 | + ... model_class_name="MyModel", |
| 181 | + ... init_args_fn_name=None, |
| 182 | + ... dialect="linalg-on-tensors", |
| 183 | + ... ir_context=ir_context |
| 184 | + ... ) |
| 185 | + >>> # ... run pm.Pipeline on the ir.Module ... |
| 186 | + """ |
| 187 | + if isinstance(filepath, str): |
| 188 | + filepath = Path(filepath) |
| 189 | + module_name = filepath.stem |
| 190 | + |
| 191 | + spec = importlib.util.spec_from_file_location(module_name, filepath) |
| 192 | + module = importlib.util.module_from_spec(spec) |
| 193 | + spec.loader.exec_module(module) |
| 194 | + |
| 195 | + model = getattr(module, model_class_name, None) |
| 196 | + if model is None: |
| 197 | + raise ValueError(f"Model class '{model_class_name}' not found in {filepath}") |
| 198 | + |
| 199 | + if init_args_fn_name is None: |
| 200 | + init_args_fn = lambda *args, **kwargs: () |
| 201 | + else: |
| 202 | + init_args_fn = getattr(module, init_args_fn_name, None) |
| 203 | + if init_args_fn is None: |
| 204 | + raise ValueError(f"Init args function '{init_args_fn_name}' not found in {filepath}") |
| 205 | + |
| 206 | + inputs_args_fn = getattr(module, inputs_args_fn_name, None) |
| 207 | + if inputs_args_fn is None: |
| 208 | + raise ValueError(f"Inputs args function '{inputs_args_fn_name}' not found in {filepath}") |
| 209 | + |
| 210 | + nn_model : nn.Module = model(*init_args_fn()) |
| 211 | + if state_path is not None: |
| 212 | + state_dict = torch.load(state_path) |
| 213 | + nn_model.load_state_dict(state_dict) |
| 214 | + |
| 215 | + return import_from_model(nn_model, *inputs_args_fn(), dialect=dialect, ir_context=ir_context, **kwargs) |
0 commit comments