Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 609e9fb

Browse files
committed
1 parent e80b82a commit 609e9fb

File tree

1 file changed

+10
-26
lines changed

1 file changed

+10
-26
lines changed

vllm/model_executor/models/registry.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ class _ModelInfo:
159159
is_embedding_model: bool
160160
supports_multimodal: bool
161161
supports_pp: bool
162+
has_inner_state: bool
163+
is_attention_free: bool
162164

163165
@staticmethod
164166
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
@@ -167,6 +169,8 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
167169
is_embedding_model=is_embedding_model(model),
168170
supports_multimodal=supports_multimodal(model),
169171
supports_pp=supports_pp(model),
172+
has_inner_state=has_inner_state(model),
173+
is_attention_free=is_attention_free(model),
170174
)
171175

172176

@@ -382,6 +386,12 @@ def is_pp_supported_model(
382386
) -> bool:
383387
return self.inspect_model_cls(architectures).supports_pp
384388

389+
def model_has_inner_state(self, architectures: Union[str, List[str]]) -> bool:
390+
return self.inspect_model_cls(architectures).has_inner_state
391+
392+
def is_attention_free_model(self, architectures: Union[str, List[str]]) -> bool:
393+
return self.inspect_model_cls(architectures).is_attention_free
394+
385395

386396
ModelRegistry = _ModelRegistry({
387397
model_arch: _LazyRegisteredModel(
@@ -430,32 +440,6 @@ def _run() -> None:
430440
with open(output_file, "wb") as f:
431441
f.write(pickle.dumps(result))
432442

433-
@staticmethod
434-
def model_has_inner_state(architectures: Union[str, List[str]]) -> bool:
435-
if isinstance(architectures, str):
436-
architectures = [architectures]
437-
if not architectures:
438-
logger.warning("No model architectures are specified")
439-
440-
has_instate = partial(ModelRegistry._check_stateless,
441-
has_inner_state,
442-
default=False)
443-
444-
return any(has_instate(arch) for arch in architectures)
445-
446-
@staticmethod
447-
def is_attention_free_model(architectures: Union[str, List[str]]) -> bool:
448-
if isinstance(architectures, str):
449-
architectures = [architectures]
450-
if not architectures:
451-
logger.warning("No model architectures are specified")
452-
453-
is_attn_free = partial(ModelRegistry._check_stateless,
454-
is_attention_free,
455-
default=False)
456-
457-
return any(is_attn_free(arch) for arch in architectures)
458-
459443

460444
if __name__ == "__main__":
461445
_run()

0 commit comments

Comments
 (0)