@@ -779,8 +779,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl
779779 seq_len = constants .ONNX_EXPORT_EXAMPLE_SEQ_LEN ,
780780 )
781781
782- lang_inputs ["past_key_values" ] = [[] for _ in range (self .model .config .num_hidden_layers )]
783- for i in range (self .model .config .num_hidden_layers ):
782+ lang_inputs ["past_key_values" ] = [[] for _ in range (self .model .config .text_config . num_hidden_layers )]
783+ for i in range (self .model .config .text_config . num_hidden_layers ):
784784 for kv in ["key" , "value" ]:
785785 lang_inputs ["past_key_values" ][i ].append (torch .zeros (kv_cache_shape , dtype = torch .float32 ))
786786
@@ -811,10 +811,10 @@ def get_specializations(
811811 ** compiler_options ,
812812 ):
813813 if height is None or width is None :
814- height = 1365
815- width = 2048
814+ height = constants . QWEN2_5_VL_HEIGHT
815+ width = constants . QWEN2_5_VL_WIDTH
816816 logger .warning (
817- "Setting height and width to be 1365 and 2048 respectively, as it was neither passed nor found in vision_config"
817+ f "Setting height and width to be { height } and { width } respectively, as it was neither passed nor found in vision_config"
818818 )
819819 prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
820820 ctx_len = ctx_len if ctx_len else constants .INTERN_CTX_LEN
@@ -940,7 +940,7 @@ def smart_resize(
940940
941941 def get_onnx_dynamic_axes (self , comp_ctx_lengths : Optional [List [int ]] = None , kv_offload : bool = False ):
942942 # Define dynamic axes
943- num_layers = self .config .num_hidden_layers
943+ num_layers = self .config .text_config . num_hidden_layers
944944
945945 vision_dynamic_axes = {
946946 "pixel_values" : {0 : "grid_height" , 1 : "grid_width" },
@@ -961,6 +961,7 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv
961961 lang_dynamic_axes ["comp_ctx_lengths" ] = {0 : "comp_ctx_lengths" }
962962
963963 dynamic_axes = {}
964+
964965 if kv_offload :
965966 dynamic_axes ["vision" ] = vision_dynamic_axes
966967 dynamic_axes ["lang" ] = lang_dynamic_axes
@@ -972,7 +973,7 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv
972973 def get_output_names (self , kv_offload : bool = False ):
973974 vision_output_names = ["vision_embeds" ]
974975 lang_output_names = ["logits" ]
975- for i in range (self .model .config .num_hidden_layers ):
976+ for i in range (self .model .config .text_config . num_hidden_layers ):
976977 for kv in ["key" , "value" ]:
977978 lang_output_names .append (f"past_{ kv } .{ i } _RetainedState" )
978979
@@ -988,6 +989,32 @@ def get_output_names(self, kv_offload: bool = False):
988989 return lang_output_names
989990 return output_names
990991
def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=1):
    """Attach padded position ids to ``inputs`` ahead of chunked prefill.

    Builds a plain running index over the sequence, stacks the model's
    multimodal rope position ids beneath it along dim 0, pads the sequence
    axis up to the next multiple of ``prefill_seq_len`` with -1, and drops
    ``image_grid_thw`` once it has been consumed. Returns the mutated
    ``inputs`` mapping.
    """
    seq_len = inputs["input_ids"].shape[1]

    # Plain 0..seq_len-1 index broadcast over the batch: shape (1, batch, seq).
    base_pos = torch.arange(seq_len).view(1, 1, seq_len).expand(-1, batch_size, -1)
    inputs["position_ids"] = base_pos

    # image_grid_thw is only present for image inputs; forward None otherwise.
    grid_thw = inputs["image_grid_thw"] if "image_grid_thw" in inputs else None

    # Multimodal rope position ids from the model (presumably one leading dim
    # per rope axis — shape inferred from the concat below; rope_deltas unused).
    mrope_pos, _rope_deltas = self.model.get_rope_index(
        inputs["input_ids"],
        grid_thw,
        video_grid_thw=None,
        second_per_grid_ts=None,
        attention_mask=inputs["attention_mask"],
    )

    # Stack the plain index on top of the rope ids along the leading dim.
    inputs["position_ids"] = torch.cat((base_pos, mrope_pos), dim=0)

    # Round the sequence length up to a whole number of prefill chunks.
    num_chunks = -(seq_len // -prefill_seq_len)  # ceil divide without float
    padded_len = num_chunks * prefill_seq_len

    # Pad the sequence axis with -1 so every chunk is exactly prefill_seq_len.
    inputs["position_ids"] = F.pad(
        inputs["position_ids"], pad=(0, padded_len - seq_len), mode="constant", value=-1
    )

    inputs.pop("image_grid_thw", None)

    return inputs
1017+
9911018 def get_inputs_info (self ):
9921019 return [
9931020 IOInfo (name = "input_ids" , datatype = torch .int64 , shape = ("batch_size" , "seq_len" )),
0 commit comments