
Commit 18cbecf

Enable Llama4 for Gaudi backend (#3223)
Signed-off-by: yuanwu <yuan.wu@intel.com>
1 parent 7e531f4 commit 18cbecf

File tree: 7 files changed (+1575, -33 lines)


backends/gaudi/server/text_generation_server/models/__init__.py

Lines changed: 48 additions & 4 deletions
@@ -16,9 +16,6 @@
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.bloom import BLOOM
-from text_generation_server.models.starcoder import StarCoder
 from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
     PhiMoEConfig,
 )
@@ -32,7 +29,6 @@
 from text_generation_server.adapters.lora import LoraWeights
 
 from text_generation_server.utils.log import log_master
-from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
 __all__ = [
     "Model",
@@ -42,6 +38,7 @@
 ]
 from text_generation_server.models.globals import ATTENTION
 
+VLM_BATCH_TYPES = set()
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
 FLASH_ATTENTION = False
@@ -63,6 +60,9 @@
     from text_generation_server.models.custom_modeling.flash_llama_modeling import (
         FlashLlamaForCausalLM,
     )
+    from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
+        Llama4ForConditionalGeneration,
+    )
     from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
         FlashCohereForCausalLM,
     )
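
The new Llama4ForConditionalGeneration import lives inside the same try/except guard as the other flash model classes, whose except branch is visible in the next hunk. A minimal sketch of that guard, with logging simplified to print() for self-containment:

```python
# Minimal sketch of the import guard these lines sit inside: if the flash
# modeling module (or its kernels) cannot be imported, flash attention is
# simply disabled instead of crashing the server at startup.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
        Llama4ForConditionalGeneration,
    )
except ImportError as e:
    print(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False
```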
@@ -140,10 +140,24 @@
     log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
     SUPPORTS_WINDOWING = False
     FLASH_ATTENTION = False
+    VLM_BATCH_TYPES = set()
 
 if FLASH_ATTENTION:
     __all__.append(FlashCausalLM)
 
+    from text_generation_server.models.flash_vlm_causal_lm import (
+        FlashVlmCausalLMBatch,
+    )
+
+    VLM_BATCH_TYPES = {
+        PaliGemmaBatch,
+        FlashVlmCausalLMBatch,
+        FlashMllamaCausalLMBatch,
+    }
+
+
+    __all__.append(VLM_BATCH_TYPES)
+
 
 class ModelType(enum.Enum):
     DEEPSEEK_V2 = {
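
VLM_BATCH_TYPES now acts as a registry of batch classes that need the multimodal processor rather than only a tokenizer. A hedged sketch of how such a registry is typically consumed on the serving side; the helper name, argument order, and attributes below are illustrative, not code from this commit:

```python
# Illustrative only: batch classes registered in a VLM registry carry image
# inputs and therefore need the processor; everything else is text-only.
def build_batch(model, pb_batch, vlm_batch_types):
    if model.batch_type in vlm_batch_types:
        # Vision-language batches: tokenizer plus image processor.
        return model.batch_type.from_pb_processor(
            pb_batch,
            model.tokenizer,
            model.processor,
            model.config,
            model.dtype,
            model.device,
        )
    # Text-only batches: the tokenizer is enough.
    return model.batch_type.from_pb(pb_batch, model.tokenizer, model.dtype, model.device)
```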
@@ -179,6 +193,11 @@ class ModelType(enum.Enum):
         "name": "Llama",
         "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
     }
+    LLAMA4 = {
+        "type": "llama4",
+        "name": "Llama4",
+        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
+    }
     PHI3 = {
         "type": "phi3",
         "name": "Phi 3",
@@ -589,6 +608,19 @@ def get_model(
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
+    elif model_type == LLAMA4:
+        print(f"Llama4 model detected: {model_id}")
+        return FlashVlmCausalLM(
+            model_id=model_id,
+            model_class=Llama4ForConditionalGeneration,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            default_dtype=torch.bfloat16,
+            trust_remote_code=trust_remote_code,
+            lora_adapter_ids=lora_adapter_ids,
+        )
     elif model_type == BAICHUAN:
         return FlashCausalLM(
             model_id=model_id,
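
The only non-pass-through argument in the new branch is `default_dtype=torch.bfloat16`: unless a dtype is explicitly requested, Llama4 weights load in bfloat16 on Gaudi. A tiny sketch of that fallback; `resolve_dtype` is a hypothetical helper, not the repository's code:

```python
from typing import Optional

import torch

# Illustrative helper: an explicitly requested dtype wins, otherwise the
# model-specific default (bfloat16 for Llama4 on Gaudi) is used.
def resolve_dtype(
    requested: Optional[torch.dtype], default: torch.dtype = torch.bfloat16
) -> torch.dtype:
    return requested if requested is not None else default

assert resolve_dtype(None) is torch.bfloat16
assert resolve_dtype(torch.float16) is torch.float16
```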
@@ -823,20 +855,32 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    from text_generation_server.models.causal_lm import CausalLM
     from text_generation_server.models.vlm_causal_lm import VlmCausalLM
     from text_generation_server.models.custom_modeling.mllama import (
         MllamaForConditionalGeneration,
     )
     from text_generation_server.models.custom_modeling.llava_next import (
         LlavaNextForConditionalGeneration,
     )
+    from text_generation_server.models.vlm_causal_lm import (
+        VlmCausalLMBatch,
+    )
+
+    VLM_BATCH_TYPES.add(VlmCausalLMBatch)
+
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
     adapt_transformers_to_gaudi()
     if SDP_ON_BF16 == 1:
         torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
     if model_type == "gpt_bigcode":
+        from text_generation_server.models.starcoder import StarCoder
+
         return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
     if model_type == "bloom":
+        from text_generation_server.models.bloom import BLOOM
+
         return BLOOM(
             model_id=model_id,
             revision=revision,
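
This hunk turns the CausalLM/BLOOM/StarCoder and optimum-habana imports into deferred, in-function imports, so the flash-attention path no longer pulls in optimum-habana's global patching of transformers. A condensed sketch of the pattern; `load_legacy_model` is a hypothetical wrapper, and only the imports and the StarCoder call are taken from the diff:

```python
# Condensed sketch of the deferred-import pattern above: the global
# optimum-habana patch and the legacy model classes are only imported once
# execution actually reaches the non-flash code path.
def load_legacy_model(model_id: str, model_type: str, revision=None, dtype=None):
    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

    adapt_transformers_to_gaudi()  # patches transformers for Gaudi, hence deferred

    if model_type == "gpt_bigcode":
        # Imported here so flash-only deployments never pay for it.
        from text_generation_server.models.starcoder import StarCoder

        return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
    raise ValueError(f"Unsupported legacy model_type: {model_type}")
```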
