# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import os
import unittest

import onnx
import torch
from parity_utilities import find_transformers_source

if find_transformers_source():
    from fusion_options import FusionOptions
    from onnx_model import OnnxModel
    from optimizer import optimize_model
else:
    from onnxruntime.transformers.fusion_options import FusionOptions
    from onnxruntime.transformers.onnx_model import OnnxModel
    from onnxruntime.transformers.optimizer import optimize_model


# From https://github.com/huggingface/transformers/blob/34f76bb62b915b43617aa88557aea97840e163f0/src/transformers/activations.py#L90
class PhiVCLIPQuickGelu(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
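        # QuickGELU: a sigmoid-based approximation of GELU, computed as x * sigmoid(1.702 * x).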
        return x * torch.sigmoid(1.702 * x)


# Line-by-line calculation of https://github.com/huggingface/transformers/blob/34f76bb62b915b43617aa88557aea97840e163f0/src/transformers/models/clip/modeling_clip.py#L613
class PhiVCLIPLayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(20)).to(torch.float16).detach()
        self.bias = torch.nn.Parameter(torch.ones(20)).to(torch.float16).detach()
        self.eps = 1e-05

    def forward(self, x):
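        # Manual LayerNorm: compute the mean and variance in float64 for precision,
        # normalize, then apply the float16 scale and bias.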
        mean = x.mean(-1, keepdim=True)
        diff = (x - mean).to(torch.float64)
        variance = diff.pow(2).mean(-1, keepdim=True)
        x = diff / torch.sqrt(variance + self.eps)
        x = x.to(torch.float16) * self.weight + self.bias
        return x


# From https://github.com/huggingface/transformers/blob/34f76bb62b915b43617aa88557aea97840e163f0/src/transformers/models/clip/modeling_clip.py#L300
class PhiVCLIPAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = 20
        self.num_heads = 2
        self.head_dim = self.embed_dim // self.num_heads

        self.scale = self.head_dim**-0.5

        self.k_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)

        self.k_proj.weight.data.fill_(1)
        self.k_proj.bias.data.fill_(1)
        self.v_proj.weight.data.fill_(1)
        self.v_proj.bias.data.fill_(1)
        self.q_proj.weight.data.fill_(1)
        self.q_proj.bias.data.fill_(1)
        self.out_proj.weight.data.fill_(1)
        self.out_proj.bias.data.fill_(1)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
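        # Reshape (bsz, seq_len, embed_dim) into (bsz, num_heads, seq_len, head_dim).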
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        causal_attention_mask=None,
        output_attentions=False,
    ):
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
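        # Scaled dot-product attention scores with shape (bsz * num_heads, tgt_len, src_len).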
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

        attn_probs = torch.nn.functional.dropout(attn_weights, p=0, training=False)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

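        # Merge the attention heads back into a single (bsz, tgt_len, embed_dim) tensor.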
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output


class PhiVCLIPAttentionAndLayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = PhiVCLIPAttention()
        self.ln = torch.nn.LayerNorm(20, eps=1e-05)

    def forward(self, x):
        # SkipLayerNorm ------+
        #       |             |
        #   Attention         |
        #       |             |
        #     MatMul          |
        #       |             |
        # SkipLayerNorm ------+

        # SkipLayerNorm
        x = x + x
        x = self.ln(x)
        residual = x

        # Attention + MatMul
        x = self.attn(x)

        # SkipLayerNorm
        x = residual + x
        x = self.ln(x)
        return x


class TestFusion(unittest.TestCase):
    def verify_fusion(self, optimized_model, expected_model_filename):
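        # Compare against the reference ONNX model: node-for-node equality after a
        # deterministic topological sort, plus matching initializer values.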
        optimized_model.topological_sort(is_deterministic=True)

        expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
        expected_model = OnnxModel(onnx.load(expected_model_path))
        expected_model.topological_sort(is_deterministic=True)

        nodes = optimized_model.model.graph.node
        self.assertEqual(len(nodes), len(expected_model.model.graph.node))

        for i in range(len(nodes)):
            self.assertEqual(nodes[i], expected_model.model.graph.node[i])

        for expected_initializer in expected_model.model.graph.initializer:
            self.assertTrue(
                OnnxModel.has_same_value(
                    optimized_model.get_initializer(expected_initializer.name), expected_initializer
                )
            )

    def export(self, model, inputs):
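        # Export the module to export.onnx next to this test file; tearDown removes it afterwards.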
        torch.onnx.export(
            model,
            args=inputs,
            f=os.path.join(os.path.dirname(__file__), "export.onnx"),
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
        )

    def tearDown(self):
        path = os.path.join(os.path.dirname(__file__), "export.onnx")
        if os.path.exists(path):
            os.remove(path)

    def test_phi_vision_layernorm(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is required since the float16 model is optimized with use_gpu=True")
        model = PhiVCLIPLayerNorm()
        inputs = (torch.randn(1, 2, 20).to(torch.float16),)
        self.export(model, inputs)
        original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx"))
        options = FusionOptions("clip")
        optimized_model = optimize_model(
            original_model,
            model_type="clip",
            num_heads=2,
            hidden_size=20,
            optimization_options=options,
            opt_level=0,
            use_gpu=True,
        )
        self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-layernorm.onnx")

    def test_phi_vision_quickgelu(self):
        model = PhiVCLIPQuickGelu()
        inputs = (torch.randn(1, 2, 20),)
        self.export(model, inputs)
        original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx"))
        options = FusionOptions("clip")
        optimized_model = optimize_model(
            original_model, model_type="clip", num_heads=2, hidden_size=20, optimization_options=options, opt_level=0
        )
        self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-quickgelu.onnx")

    def test_phi_vision_attention(self):
        model = PhiVCLIPAttentionAndLayerNorm()
        inputs = (torch.randn(1, 2, 20),)
        self.export(model, inputs)
        original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx"))
        options = FusionOptions("clip")
        optimized_model = optimize_model(
            original_model, model_type="clip", num_heads=2, hidden_size=20, optimization_options=options, opt_level=0
        )
        self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-attention.onnx")


if __name__ == "__main__":
    unittest.main()