|
16 | 16 |
|
17 | 17 | from text_generation_server.utils.speculate import get_speculate, set_speculate |
18 | 18 | from text_generation_server.models.model import Model |
19 | | -from text_generation_server.models.causal_lm import CausalLM |
20 | | -from text_generation_server.models.bloom import BLOOM |
21 | | -from text_generation_server.models.starcoder import StarCoder |
22 | 19 | from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import ( |
23 | 20 | PhiMoEConfig, |
24 | 21 | ) |
|
32 | 29 | from text_generation_server.adapters.lora import LoraWeights |
33 | 30 |
|
34 | 31 | from text_generation_server.utils.log import log_master |
35 | | -from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi |
36 | 32 |
|
37 | 33 | __all__ = [ |
38 | 34 | "Model", |
|
42 | 38 | ] |
43 | 39 | from text_generation_server.models.globals import ATTENTION |
44 | 40 |
|
| 41 | +VLM_BATCH_TYPES = set() |
45 | 42 | FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." |
46 | 43 |
|
47 | 44 | FLASH_ATTENTION = False |
|
63 | 60 | from text_generation_server.models.custom_modeling.flash_llama_modeling import ( |
64 | 61 | FlashLlamaForCausalLM, |
65 | 62 | ) |
| 63 | + from text_generation_server.models.custom_modeling.flash_llama4_modeling import ( |
| 64 | + Llama4ForConditionalGeneration, |
| 65 | + ) |
66 | 66 | from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( |
67 | 67 | FlashCohereForCausalLM, |
68 | 68 | ) |
|
140 | 140 | log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") |
141 | 141 | SUPPORTS_WINDOWING = False |
142 | 142 | FLASH_ATTENTION = False |
| 143 | + VLM_BATCH_TYPES = set() |
143 | 144 |
|
144 | 145 | if FLASH_ATTENTION: |
145 | 146 | __all__.append(FlashCausalLM) |
146 | 147 |
|
| 148 | + from text_generation_server.models.flash_vlm_causal_lm import ( |
| 149 | + FlashVlmCausalLMBatch, |
| 150 | + ) |
| 151 | + |
| 152 | + VLM_BATCH_TYPES = { |
| 153 | + PaliGemmaBatch, |
| 154 | + FlashVlmCausalLMBatch, |
| 155 | + FlashMllamaCausalLMBatch, |
| 156 | + } |
| 157 | + |
| 158 | + |
| 159 | +__all__.append("VLM_BATCH_TYPES")
| 160 | + |
147 | 161 |
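The hunk above introduces `VLM_BATCH_TYPES` as a registry of batch classes that carry image inputs: it is populated only when the Flash Attention model imports succeed and reset to an empty set otherwise. Below is a minimal sketch of how such a registry is typically consumed; the class names and the `needs_image_pipeline` helper are illustrative stand-ins, not the server's actual API.

```python
# Sketch of a batch-class registry; names are illustrative, not the real server API.

class TextOnlyBatch:
    """Stand-in for a text-only batch such as CausalLMBatch."""


class ImageTextBatch:
    """Stand-in for a multimodal batch such as FlashVlmCausalLMBatch."""


# Mirrors VLM_BATCH_TYPES: only batch classes that accept pixel inputs.
VLM_BATCH_TYPES = {ImageTextBatch}


def needs_image_pipeline(batch_cls: type) -> bool:
    # One membership test replaces a chain of isinstance checks; a new
    # multimodal model only needs to add its batch class to the set.
    return batch_cls in VLM_BATCH_TYPES


assert needs_image_pipeline(ImageTextBatch)
assert not needs_image_pipeline(TextOnlyBatch)
```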
|
148 | 162 | class ModelType(enum.Enum): |
149 | 163 | DEEPSEEK_V2 = { |
@@ -179,6 +193,11 @@ class ModelType(enum.Enum): |
179 | 193 | "name": "Llama", |
180 | 194 | "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", |
181 | 195 | } |
| 196 | + LLAMA4 = { |
| 197 | + "type": "llama4", |
| 198 | + "name": "Llama4", |
| 199 | + "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", |
| 200 | + } |
182 | 201 | PHI3 = { |
183 | 202 | "type": "phi3", |
184 | 203 | "name": "Phi 3", |
@@ -589,6 +608,19 @@ def get_model( |
589 | 608 | trust_remote_code=trust_remote_code, |
590 | 609 | lora_adapter_ids=lora_adapter_ids, |
591 | 610 | ) |
| 611 | + elif model_type == LLAMA4: |
| 612 | + print(f"Llama4 model detected: {model_id}") |
| 613 | + return FlashVlmCausalLM( |
| 614 | + model_id=model_id, |
| 615 | + model_class=Llama4ForConditionalGeneration, |
| 616 | + revision=revision, |
| 617 | + quantize=quantize, |
| 618 | + speculator=speculator, |
| 619 | + dtype=dtype, |
| 620 | + default_dtype=torch.bfloat16, |
| 621 | + trust_remote_code=trust_remote_code, |
| 622 | + lora_adapter_ids=lora_adapter_ids, |
| 623 | + ) |
592 | 624 | elif model_type == BAICHUAN: |
593 | 625 | return FlashCausalLM( |
594 | 626 | model_id=model_id, |
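For orientation, `get_model` selects a model wrapper by walking an `elif` chain over the config's `model_type`; the new branch above routes `llama4` checkpoints to `FlashVlmCausalLM` with a `bfloat16` default dtype. The sketch below compresses that dispatch shape into a lookup table; the builder functions and `DummyModel` type are placeholders rather than the real constructors, and the real function forwards many more arguments (quantize, speculator, LoRA adapters, and so on).

```python
# Hedged sketch of model_type dispatch; all names here are illustrative only.
from dataclasses import dataclass


@dataclass
class DummyModel:
    model_id: str
    default_dtype: str


def build_llama4(model_id: str) -> DummyModel:
    # Mirrors the new branch: a vision-language wrapper defaulting to bfloat16.
    return DummyModel(model_id=model_id, default_dtype="bfloat16")


def build_llama(model_id: str) -> DummyModel:
    return DummyModel(model_id=model_id, default_dtype="float16")


BUILDERS = {"llama4": build_llama4, "llama": build_llama}


def get_model_sketch(model_type: str, model_id: str) -> DummyModel:
    try:
        return BUILDERS[model_type](model_id)
    except KeyError as err:
        raise ValueError(f"Unsupported model_type: {model_type}") from err


print(get_model_sketch("llama4", "some-org/some-llama4-checkpoint"))
```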
@@ -823,20 +855,32 @@ def get_model( |
823 | 855 | trust_remote_code=trust_remote_code, |
824 | 856 | ) |
825 | 857 |
|
| 858 | + from text_generation_server.models.causal_lm import CausalLM |
826 | 859 | from text_generation_server.models.vlm_causal_lm import VlmCausalLM |
827 | 860 | from text_generation_server.models.custom_modeling.mllama import ( |
828 | 861 | MllamaForConditionalGeneration, |
829 | 862 | ) |
830 | 863 | from text_generation_server.models.custom_modeling.llava_next import ( |
831 | 864 | LlavaNextForConditionalGeneration, |
832 | 865 | ) |
| 866 | + from text_generation_server.models.vlm_causal_lm import ( |
| 867 | + VlmCausalLMBatch, |
| 868 | + ) |
| 869 | + |
| 870 | + VLM_BATCH_TYPES.add(VlmCausalLMBatch) |
| 871 | + |
| 872 | + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi |
833 | 873 |
|
834 | 874 | adapt_transformers_to_gaudi() |
835 | 875 | if SDP_ON_BF16 == 1: |
836 | 876 | torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) |
837 | 877 | if model_type == "gpt_bigcode": |
| 878 | + from text_generation_server.models.starcoder import StarCoder |
| 879 | + |
838 | 880 | return StarCoder(model_id=model_id, revision=revision, dtype=dtype) |
839 | 881 | if model_type == "bloom": |
| 882 | + from text_generation_server.models.bloom import BLOOM |
| 883 | + |
840 | 884 | return BLOOM( |
841 | 885 | model_id=model_id, |
842 | 886 | revision=revision, |
|
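The remaining hunks move the `CausalLM`, `StarCoder`, and `BLOOM` imports, along with `adapt_transformers_to_gaudi`, from module level into the branches of `get_model` that actually use them, so importing `text_generation_server.models` no longer pulls in those dependencies up front. The snippet below sketches that deferred-import pattern with standard-library stand-ins for the heavy modules; it illustrates the idea and is not the server code.

```python
# Deferred imports: the costly module is imported only on the code path that
# needs it, keeping module import time and baseline memory low.

def load_parser(kind: str):
    if kind == "json":
        import json  # cheap stand-in for a heavy optional dependency
        return json.loads
    if kind == "csv":
        import csv
        import io
        return lambda text: list(csv.reader(io.StringIO(text)))
    raise ValueError(f"Unknown parser kind: {kind}")


# Only the json module is imported for this call; the csv branch never runs.
parse = load_parser("json")
print(parse('{"model_type": "gpt_bigcode"}'))
```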