@@ -695,7 +695,9 @@ def get_model_config(self) -> dict:
695695 dict
696696 The configuration dictionary.
697697 """
698- return self .model .model .vision_model .config .__dict__
698+ if hasattr (self .model .model , "vision_model" ):
699+ return self .model .model .vision_model .config .__dict__
700+ return self .model .model .config .__dict__
699701
700702
701703class QEffCausalLMForTextImageToTextModel (QEFFBaseModel ):
@@ -835,7 +837,9 @@ def get_model_config(self) -> dict:
835837 dict
836838 The configuration dictionary.
837839 """
838- return self .model .language_model .config .__dict__
840+ if hasattr (self .model , "language_model" ):
841+ return self .model .language_model .config .__dict__
842+ return self .model .config .__dict__
839843
840844
841845class _QEffAutoModelForImageTextToTextDualQPC :
@@ -1086,7 +1090,11 @@ def compile(
10861090
10871091 custom_io_vision = {}
10881092 kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
1093+ molmo = hasattr (self .model .config , "model_type" ) and self .model .config .model_type == "molmo"
1094+ if molmo :
1095+ custom_io_vision ["image_masks" ] = "float16"
10891096 custom_io_vision ["pixel_values" ] = "float16"
1097+
10901098 for output_name in output_names ["vision" ]:
10911099 if output_name .startswith ("past_" ):
10921100 custom_io_vision [output_name ] = kv_cache_dtype
@@ -1288,11 +1296,15 @@ def kv_offload_generate(
12881296 inputs [k ] = np .array (v )
12891297
12901298 vision_inputs = {
1291- k : v for k , v in inputs .items () if k in {"pixel_values" , "aspect_ratio_ids" , "aspect_ratio_mask" }
1299+ k : v
1300+ for k , v in inputs .items ()
1301+ if k
1302+ in {"pixel_values" , "image_masks" , "image_input_idx" , "valid_idx" , "aspect_ratio_ids" , "aspect_ratio_mask" }
12921303 }
12931304
1294- if vision_inputs :
1295- vision_inputs ["pixel_values" ] = vision_inputs ["pixel_values" ].astype ("float16" )
1305+ vision_inputs_fp16 = {"pixel_values" , "image_masks" }
1306+ vision_inputs .update ({k : vision_inputs [k ].astype ("float16" ) for k in vision_inputs_fp16 if k in vision_inputs })
1307+
12961308 vision_start = perf_counter ()
12971309
12981310 vision_outputs = {}
@@ -1429,7 +1441,10 @@ def __init__(
14291441 self .model .config .llm_config ._attn_implementation = "eager"
14301442 self .model .config .vision_config .use_flash_attn = "false"
14311443 else :
1432- self .model .config .text_config .use_cache = True
1444+ if hasattr (self .model .config , "text_config" ):
1445+ self .model .config .text_config .use_cache = True
1446+ else :
1447+ self .model .config .use_cache = True
14331448 self .hash_params ["qeff_auto_class" ] = self .__class__ .__name__
14341449
14351450 @classmethod
@@ -1980,7 +1995,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
19801995 return cls (model , kv_offload = kv_offload , pretrained_model_name_or_path = pretrained_model_name_or_path , ** kwargs )
19811996
19821997
1983- MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel" : QEFFAutoModelForImageTextToText }
1998+ MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {
1999+ "InternVLChatModel" : QEFFAutoModelForImageTextToText ,
2000+ "MolmoForCausalLM" : QEFFAutoModelForImageTextToText ,
2001+ }
19842002
19852003
19862004class QEFFAutoModelForCausalLM (QEFFBaseModel ):
0 commit comments