Skip to content

Commit 14ee6e7

Browse files
authored
[gaudi] gemma3 text and vlm model initial support; sliding window support to be added later (#3270)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent bd1bdeb commit 14ee6e7

File tree

6 files changed

+796
-33
lines changed

6 files changed

+796
-33
lines changed

backends/gaudi/server/text_generation_server/models/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@
6767
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
6868
FlashGemma2ForCausalLM,
6969
)
70+
from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
71+
Gemma3ForConditionalGeneration,
72+
FlashGemma3ForCausalLM,
73+
)
7074
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
7175
FlashDbrxForCausalLM,
7276
DbrxConfig,
@@ -220,6 +224,16 @@ class ModelType(enum.Enum):
220224
"name": "Gemma2",
221225
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
222226
}
227+
GEMMA3 = {
228+
"type": "gemma3",
229+
"name": "Gemma3",
230+
"url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
231+
}
232+
GEMMA3_TEXT = {
233+
"type": "gemma3_text",
234+
"name": "Gemma3 Text",
235+
"url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
236+
}
223237
COHERE = {
224238
"type": "cohere",
225239
"name": "Cohere",
@@ -630,6 +644,7 @@ def get_model(
630644
quantize=quantize,
631645
speculator=speculator,
632646
dtype=dtype,
647+
kv_cache_dtype=kv_cache_dtype,
633648
default_dtype=torch.bfloat16,
634649
trust_remote_code=trust_remote_code,
635650
lora_adapter_ids=lora_adapter_ids,
@@ -675,6 +690,34 @@ def get_model(
675690
trust_remote_code=trust_remote_code,
676691
lora_adapter_ids=lora_adapter_ids,
677692
)
693+
elif model_type == GEMMA3:
694+
return FlashVlmCausalLM(
695+
model_id=model_id,
696+
model_class=Gemma3ForConditionalGeneration,
697+
revision=revision,
698+
quantize=quantize,
699+
speculator=speculator,
700+
dtype=dtype,
701+
kv_cache_dtype=kv_cache_dtype,
702+
default_dtype=torch.bfloat16,
703+
trust_remote_code=trust_remote_code,
704+
lora_adapter_ids=lora_adapter_ids,
705+
support_chunking=False,
706+
)
707+
elif model_type == GEMMA3_TEXT:
708+
return FlashCausalLM(
709+
model_id=model_id,
710+
model_class=FlashGemma3ForCausalLM,
711+
revision=revision,
712+
quantize=quantize,
713+
speculator=speculator,
714+
dtype=dtype,
715+
kv_cache_dtype=kv_cache_dtype,
716+
# Works better for these models
717+
default_dtype=torch.bfloat16,
718+
trust_remote_code=trust_remote_code,
719+
lora_adapter_ids=lora_adapter_ids,
720+
)
678721
elif model_type == COHERE:
679722
return FlashCausalLM(
680723
model_id=model_id,
@@ -864,6 +907,7 @@ def get_model(
864907
quantize=quantize,
865908
speculator=speculator,
866909
dtype=dtype,
910+
kv_cache_dtype=kv_cache_dtype,
867911
default_dtype=torch.bfloat16,
868912
trust_remote_code=trust_remote_code,
869913
lora_adapter_ids=lora_adapter_ids,

0 commit comments

Comments
 (0)