@@ -21,6 +21,9 @@ def __init__(self, model):
 
     def forward(self, pixel_values):
         vision_embeds = self.model.extract_feature(pixel_values)
+        # Reshape from [num_patches, 256, hidden_dim] -> [1, num_patches*256, hidden_dim]
+        # to enable prefill chunking for num_patches > 1
+        vision_embeds = vision_embeds.reshape(1, -1, vision_embeds.shape[-1])
         return vision_embeds
 
 
@@ -35,14 +38,22 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
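+        # Keep a flattened [B*N, C] copy of the text embeddings so the token-level
+        # torch.where selection below can broadcast over a single flattened axis.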
+        input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
-        selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
+        # TODO: Find a better way to decide which token value to use
+        image_context_token = (
+            constants.INTERN_3_5_IMG_CONTEXT_TOKEN
+            if "Qwen3" in self.config.architectures[0]
+            else constants.INTERN_IMG_CONTEXT_TOKEN
+        )
+        selected = image_input_ids == image_context_token
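+        # cumsum over the boolean mask gives each image-context token a running
+        # index; adding image_idx offsets it into the current chunk of vision_embeds.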
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
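+        # Restore the [B, N, C] layout the language model expects.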
+        inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
@@ -84,12 +95,13 @@ def get_specializations(
             raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
 
         per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
-        vision_size = int(num_patches * per_patch_embed_size)
+        vision_size = int(batch_size * num_patches * per_patch_embed_size)
         vision = [
             {
                 "batch_size": batch_size,
                 "num_patches": num_patches,
                 "img_size": img_size,
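+                # batched_num_patches matches the dynamic axis name used for
+                # pixel_values in get_onnx_dynamic_axes below.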
+                "batched_num_patches": batch_size * num_patches,
             }
         ]
         lang = [
@@ -126,8 +138,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "vision_size"}
-        vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
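+        # vision_embeds is batch-flattened upstream, so only its token axis
+        # (vision_size) stays dynamic; pixel_values now stacks batch * patches on axis 0.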
+        lang_dynamic_axes["vision_embeds"] = {1: "vision_size"}
+        vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
@@ -182,16 +194,16 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
-            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
-            computed_feature_size,
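+            # batch is folded into the token axis to match the flattened
+            # [1, B * vision_size, C] layout produced by the vision encoder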
+            1,
+            computed_feature_size * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             self.language_model.config.hidden_size,
         )
         inputs_shapes["position_ids"] = (
             constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
         inputs_shapes["pixel_values"] = (
-            constants.INTERN_NUM_PATCHES,
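+            # dummy pixel_values stack batch_size * num_patches tiles on axis 0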
+            constants.INTERN_NUM_PATCHES * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.INTERN_NUM_CHANNELS,
             img_size,
             img_size,
@@ -237,14 +249,22 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val
         vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
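+        # Same flattening as the kv_offload path above: keep a [B*N, C] view of
+        # the text embeddings for the token-level selection.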
+        input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
-        selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
+        # TODO: Find a better way to decide which token value to use
+        image_context_token = (
+            constants.INTERN_3_5_IMG_CONTEXT_TOKEN
+            if "Qwen3" in self.config.architectures[0]
+            else constants.INTERN_IMG_CONTEXT_TOKEN
+        )
+        selected = image_input_ids == image_context_token
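+        # Same cumsum-based gather as above: each image-context token is routed
+        # to its corresponding vision feature.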
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
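+        # Restore [B, N, C] before calling the language model.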
+        inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )