
Commit ab59e9e

Add unit tests for Phi vision (microsoft#23357)
### Description

This PR adds unit tests for [fusing the vision components](microsoft#20721) of Phi-3 vision and Phi-3.5 vision.

### Motivation and Context

Many multi-modal models use a CLIP encoder, or a variant of CLIP, as part of their encoders. These fusion unit tests ensure that the vision components of Phi-3 vision and Phi-3.5 vision can still be fused when the existing fusions are modified to support more models.
1 parent b67983c commit ab59e9e
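
For context, a minimal sketch of the fusion entry point these tests exercise (argument values taken from the tests below; the input model path "vision_encoder.onnx" is hypothetical):

# Sketch of the CLIP-style fusion path under test; the input path is a placeholder.
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

options = FusionOptions("clip")
optimized = optimize_model(
    "vision_encoder.onnx",  # hypothetical exported CLIP-like vision block
    model_type="clip",
    num_heads=2,
    hidden_size=20,
    optimization_options=options,
    opt_level=0,
)
# After fusion, the graph should contain fused ops such as Attention or QuickGelu.
print([node.op_type for node in optimized.model.graph.node])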

File tree: 4 files changed, +270 −0 lines changed
onnxruntime/test/python/transformers/test_data/models/phi-3.5-v-instruct-vision-attention.onnx
Binary file not shown.
onnxruntime/test/python/transformers/test_data/models/phi-3.5-v-instruct-vision-layernorm.onnx
Binary file not shown.

onnxruntime/test/python/transformers/test_data/models/phi-3.5-v-instruct-vision-quickgelu.onnx

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
(Serialized ONNX protobuf rendered as text; not human-readable. Recoverable details: the expected fused graph main_graph, produced by onnxruntime.transformers 1.20.1, contains a single QuickGelu node (QuickGelu_0) in the com.microsoft domain with an alpha attribute, applied to input onnx::Mul_0.)
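
The protobuf above is not meant to be read directly; a quick way to inspect what it encodes, assuming the file path listed above, is:

# Sketch: inspect the expected fused model checked in by this commit.
import onnx

m = onnx.load("onnxruntime/test/python/transformers/test_data/models/phi-3.5-v-instruct-vision-quickgelu.onnx")
for node in m.graph.node:
    print(node.op_type, node.domain, [a.name for a in node.attribute])
# Expected: a single QuickGelu node in the com.microsoft domain with an "alpha" attribute.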
Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import os
import unittest

import onnx
import torch
from parity_utilities import find_transformers_source

if find_transformers_source():
    from fusion_options import FusionOptions
    from onnx_model import OnnxModel
    from optimizer import optimize_model
else:
    from onnxruntime.transformers.fusion_options import FusionOptions
    from onnxruntime.transformers.onnx_model import OnnxModel
    from onnxruntime.transformers.optimizer import optimize_model


# From https://github.com/huggingface/transformers/blob/34f76bb62b915b43617aa88557aea97840e163f0/src/transformers/activations.py#L90
class PhiVCLIPQuickGelu(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


# Line-by-line calculation of https://github.com/huggingface/transformers/blob/34f76bb62b915b43617aa88557aea97840e163f0/src/transformers/models/clip/modeling_clip.py#L613
class PhiVCLIPLayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(20)).to(torch.float16).detach()
        self.bias = torch.nn.Parameter(torch.ones(20)).to(torch.float16).detach()
        self.eps = 1e-05

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        diff = (x - mean).to(torch.float64)
        variance = diff.pow(2).mean(-1, keepdim=True)
        x = diff / torch.sqrt(variance + self.eps)
        x = x.to(torch.float16) * self.weight + self.bias
        return x


# From https://github.com/huggingface/transformers/blob/34f76bb62b915b43617aa88557aea97840e163f0/src/transformers/models/clip/modeling_clip.py#L300
class PhiVCLIPAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = 20
        self.num_heads = 2
        self.head_dim = self.embed_dim // self.num_heads

        self.scale = self.head_dim**-0.5

        self.k_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)

        self.k_proj.weight.data.fill_(1)
        self.k_proj.bias.data.fill_(1)
        self.v_proj.weight.data.fill_(1)
        self.v_proj.bias.data.fill_(1)
        self.q_proj.weight.data.fill_(1)
        self.q_proj.bias.data.fill_(1)
        self.out_proj.weight.data.fill_(1)
        self.out_proj.bias.data.fill_(1)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        causal_attention_mask=None,
        output_attentions=False,
    ):
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_probs = torch.nn.functional.dropout(attn_weights, p=0, training=False)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output


class PhiVCLIPAttentionAndLayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = PhiVCLIPAttention()
        self.ln = torch.nn.LayerNorm(20, eps=1e-05)

    def forward(self, x):
        # SkipLayerNorm ------+
        #        |            |
        #    Attention        |
        #        |            |
        #      MatMul         |
        #        |            |
        # SkipLayerNorm ------+

        # SkipLayerNorm
        x = x + x
        x = self.ln(x)
        residual = x

        # Attention + MatMul
        x = self.attn(x)

        # SkipLayerNorm
        x = residual + x
        x = self.ln(x)
        return x


class TestFusion(unittest.TestCase):
    def verify_fusion(self, optimized_model, expected_model_filename):
        optimized_model.topological_sort(is_deterministic=True)

        expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
        expected_model = OnnxModel(onnx.load(expected_model_path))
        expected_model.topological_sort(is_deterministic=True)

        nodes = optimized_model.model.graph.node
        self.assertEqual(len(nodes), len(expected_model.model.graph.node))

        for i in range(len(nodes)):
            self.assertEqual(nodes[i], expected_model.model.graph.node[i])

        for expected_initializer in expected_model.model.graph.initializer:
            self.assertTrue(
                OnnxModel.has_same_value(
                    optimized_model.get_initializer(expected_initializer.name), expected_initializer
                )
            )

    def export(self, model, inputs):
        torch.onnx.export(
            model,
            args=inputs,
            f=os.path.join(os.path.dirname(__file__), "export.onnx"),
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
        )

    def tearDown(self):
        path = os.path.join(os.path.dirname(__file__), "export.onnx")
        if os.path.exists(path):
            os.remove(path)

    def test_phi_vision_layernorm(self):
        if not torch.cuda.is_available():
            return
        model = PhiVCLIPLayerNorm()
        inputs = (torch.randn(1, 2, 20).to(torch.float16),)
        self.export(model, inputs)
        original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx"))
        options = FusionOptions("clip")
        optimized_model = optimize_model(
            original_model,
            model_type="clip",
            num_heads=2,
            hidden_size=20,
            optimization_options=options,
            opt_level=0,
            use_gpu=True,
        )
        self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-layernorm.onnx")

    def test_phi_vision_quickgelu(self):
        model = PhiVCLIPQuickGelu()
        inputs = (torch.randn(1, 2, 20),)
        self.export(model, inputs)
        original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx"))
        options = FusionOptions("clip")
        optimized_model = optimize_model(
            original_model, model_type="clip", num_heads=2, hidden_size=20, optimization_options=options, opt_level=0
        )
        self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-quickgelu.onnx")

    def test_phi_vision_attention(self):
        model = PhiVCLIPAttentionAndLayerNorm()
        inputs = (torch.randn(1, 2, 20),)
        self.export(model, inputs)
        original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx"))
        options = FusionOptions("clip")
        optimized_model = optimize_model(
            original_model, model_type="clip", num_heads=2, hidden_size=20, optimization_options=options, opt_level=0
        )
        self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-attention.onnx")


if __name__ == "__main__":
    unittest.main()
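
A possible follow-up, not part of this commit: besides the graph comparison in verify_fusion, a numerical parity check against a fused model could look like the sketch below (the fused model path "fused_quickgelu.onnx" is hypothetical).

# Sketch (not part of this commit): numerical parity check for a fused QuickGelu model.
import numpy as np
import onnxruntime as ort
import torch

x = torch.randn(1, 2, 20)
reference = x * torch.sigmoid(1.702 * x)  # same formula as PhiVCLIPQuickGelu above

# "fused_quickgelu.onnx" is a hypothetical output of optimize_model(..., model_type="clip").
session = ort.InferenceSession("fused_quickgelu.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
(output,) = session.run(None, {input_name: x.numpy()})
np.testing.assert_allclose(output, reference.numpy(), rtol=1e-3, atol=1e-3)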

0 commit comments
