@@ -159,6 +159,8 @@ class _ModelInfo:
159159 is_embedding_model : bool
160160 supports_multimodal : bool
161161 supports_pp : bool
162+ has_inner_state : bool
163+ is_attention_free : bool
162164
163165 @staticmethod
164166 def from_model_cls (model : Type [nn .Module ]) -> "_ModelInfo" :
@@ -167,6 +169,8 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
167169 is_embedding_model = is_embedding_model (model ),
168170 supports_multimodal = supports_multimodal (model ),
169171 supports_pp = supports_pp (model ),
172+ has_inner_state = has_inner_state (model ),
173+ is_attention_free = is_attention_free (model ),
170174 )
171175
172176
@@ -382,6 +386,12 @@ def is_pp_supported_model(
382386 ) -> bool :
383387 return self .inspect_model_cls (architectures ).supports_pp
384388
389+ def model_has_inner_state (self , architectures : Union [str , List [str ]]) -> bool :
390+ return self .inspect_model_cls (architectures ).has_inner_state
391+
392+ def is_attention_free_model (self , architectures : Union [str , List [str ]]) -> bool :
393+ return self .inspect_model_cls (architectures ).is_attention_free
394+
385395
386396ModelRegistry = _ModelRegistry ({
387397 model_arch : _LazyRegisteredModel (
@@ -430,32 +440,6 @@ def _run() -> None:
430440 with open (output_file , "wb" ) as f :
431441 f .write (pickle .dumps (result ))
432442
433- @staticmethod
434- def model_has_inner_state (architectures : Union [str , List [str ]]) -> bool :
435- if isinstance (architectures , str ):
436- architectures = [architectures ]
437- if not architectures :
438- logger .warning ("No model architectures are specified" )
439-
440- has_instate = partial (ModelRegistry ._check_stateless ,
441- has_inner_state ,
442- default = False )
443-
444- return any (has_instate (arch ) for arch in architectures )
445-
446- @staticmethod
447- def is_attention_free_model (architectures : Union [str , List [str ]]) -> bool :
448- if isinstance (architectures , str ):
449- architectures = [architectures ]
450- if not architectures :
451- logger .warning ("No model architectures are specified" )
452-
453- is_attn_free = partial (ModelRegistry ._check_stateless ,
454- is_attention_free ,
455- default = False )
456-
457- return any (is_attn_free (arch ) for arch in architectures )
458-
459443
460444if __name__ == "__main__" :
461445 _run ()
0 commit comments