Skip to content

Commit 01c1ab3

Browse files
authored
Fix export for pt 2.8 (#2288)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent 75e1be0 commit 01c1ab3

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

neural_compressor/torch/export/pt2e_export.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,20 @@
1313
# limitations under the License.
1414
"""Export model for quantization."""
1515

16+
from functools import partial
1617
from typing import Any, Dict, Optional, Tuple, Union
1718

1819
import torch
1920
from torch.fx.graph_module import GraphModule
2021

2122
from neural_compressor.common.utils import logger
22-
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, TORCH_VERSION_2_7_0, get_torch_version, is_ipex_imported
23+
from neural_compressor.torch.utils import (
24+
TORCH_VERSION_2_2_2,
25+
TORCH_VERSION_2_7_0,
26+
TORCH_VERSION_2_8_0,
27+
get_torch_version,
28+
is_ipex_imported,
29+
)
2330

2431
__all__ = ["export", "export_model_for_pt2e_quant"]
2532

@@ -52,7 +59,10 @@ def export_model_for_pt2e_quant(
5259
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
5360
# updated to use the official `torch.export` API when that is ready.
5461
cur_version = get_torch_version()
55-
if cur_version >= TORCH_VERSION_2_7_0:
62+
if cur_version >= TORCH_VERSION_2_8_0:
63+
export_func = torch.export.export
64+
export_func = partial(export_func, strict=True)
65+
elif cur_version >= TORCH_VERSION_2_7_0:
5666
export_func = torch.export.export_for_training
5767
else:
5868
export_func = torch._export.capture_pre_autograd_graph

neural_compressor/torch/utils/environ.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def get_ipex_version():
152152

153153
TORCH_VERSION_2_2_2 = Version("2.2.2")
154154
TORCH_VERSION_2_7_0 = Version("2.7.0")
155+
TORCH_VERSION_2_8_0 = Version("2.8.0")
155156

156157

157158
def get_torch_version():

0 commit comments

Comments
 (0)