Skip to content

Commit efc4af2

Browse files
committed
Add resnet 50
1 parent a394262 commit efc4af2

File tree

4 files changed

+102
-10
lines changed

4 files changed

+102
-10
lines changed

optimum/exporters/executorch/integrations.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Dict
16+
from typing import Dict, Optional
1717

1818
import torch
1919
from packaging.version import parse
@@ -173,25 +173,35 @@ class VisionEncoderExportableModule(torch.nn.Module):
173173
This module ensures that the exported model is compatible with ExecuTorch.
174174
"""
175175

176-
def __init__(self, model):
176+
def __init__(self, model, model_id: Optional[str] = None):
177177
super().__init__()
178178
self.model = model
179179
self.config = model.config
180180
# Metadata to be recorded in the pte model file
181181
self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
182-
182+
183+
self.model_id = model_id
184+
183185
def forward(self, pixel_values):
184186
print(f"DEBUG: pixel_values: {pixel_values.shape}")
185187
print(f"DEBUG: forward: {self.model.method_meta('forward')}")
186188
return self.model(pixel_values=pixel_values)
187189

188190
def export(self, pixel_values=None) -> Dict[str, ExportedProgram]:
189191
if pixel_values is None:
190-
batch_size = 1
191-
num_channels = self.config.num_channels
192-
height = self.config.image_size
193-
width = self.config.image_size
194-
pixel_values = torch.rand(batch_size, num_channels, height, width)
192+
model_to_pixel_values_size = {
193+
"microsoft/resnet-50": [1, 3, 224, 224],
194+
}
195+
if self.model_id in model_to_pixel_values_size:
196+
# If an explicit shape is provided for this model, use it
197+
pixel_values = torch.rand(*model_to_pixel_values_size[self.model_id])
198+
else:
199+
# If no explicit shape is provided for this model, infer a shape from config
200+
batch_size = 1
201+
num_channels = self.config.num_channels
202+
height = self.config.image_size
203+
width = self.config.image_size
204+
pixel_values = torch.rand(batch_size, num_channels, height, width)
195205

196206
with torch.no_grad():
197207
return {

optimum/exporters/executorch/recipes/coreml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _lower_to_executorch(
9696
],
9797
compile_config=EdgeCompileConfig(
9898
_check_ir_validity=False,
99-
_skip_dim_order=False,
99+
_skip_dim_order=True,
100100
),
101101
constant_methods=metadata,
102102
).to_executorch(

optimum/exporters/executorch/tasks/image_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ def load_image_classification_model(model_name_or_path: str, **kwargs) -> Vision
3939
"""
4040

4141
eager_model = AutoModelForImageClassification.from_pretrained(model_name_or_path, **kwargs).to("cpu").eval()
42-
return VisionEncoderExportableModule(eager_model)
42+
return VisionEncoderExportableModule(eager_model, model_name_or_path)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# coding=utf-8
2+
# Copyright 2024 The HuggingFace Team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
import subprocess
18+
import sys
19+
import tempfile
20+
import unittest
21+
22+
import pytest
23+
import torch
24+
from transformers.testing_utils import slow
25+
26+
from optimum.executorch import ExecuTorchModelForImageClassification
27+
28+
from ..utils import check_close_recursively
29+
30+
31+
is_not_macos = sys.platform != "darwin"
32+
33+
34+
class ExecuTorchModelIntegrationTest(unittest.TestCase):
35+
def __init__(self, *args, **kwargs):
36+
super().__init__(*args, **kwargs)
37+
38+
@slow
39+
@pytest.mark.run_slow
40+
def test_vit_export_to_executorch(self):
41+
model_id = "microsoft/resnet-50"
42+
task = "image-classification"
43+
recipe = "xnnpack"
44+
with tempfile.TemporaryDirectory() as tempdir:
45+
subprocess.run(
46+
f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch",
47+
shell=True,
48+
check=True,
49+
)
50+
self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))
51+
52+
@slow
53+
@pytest.mark.run_slow
54+
@pytest.mark.skipif(is_not_macos, reason="Only runs on MacOS")
55+
def test_vit_image_classification_coreml_fp32_cpu(self):
56+
model_id = "microsoft/resnet-50"
57+
58+
batch_size = 1
59+
num_channels = 3
60+
height = 224
61+
width = 224
62+
pixel_values = torch.rand(batch_size, num_channels, height, width)
63+
64+
# Test fetching and lowering the model to ExecuTorch
65+
import coremltools as ct
66+
67+
et_model = ExecuTorchModelForImageClassification.from_pretrained(
68+
model_id=model_id,
69+
recipe="coreml",
70+
recipe_kwargs={"compute_precision": ct.precision.FLOAT32, "compute_units": ct.ComputeUnit.CPU_ONLY},
71+
)
72+
et_output = et_model.forward(pixel_values)
73+
74+
# Reference (using XNNPACK as reference because eager model currently segfaults in a PyTorch kernel)
75+
et_xnnpack = ExecuTorchModelForImageClassification.from_pretrained(
76+
model_id=model_id,
77+
recipe="xnnpack",
78+
)
79+
et_xnnpack_output = et_xnnpack.forward(pixel_values)
80+
81+
# Compare with reference
82+
self.assertTrue(check_close_recursively(et_output, et_xnnpack_output))

0 commit comments

Comments
 (0)