@@ -67,6 +67,10 @@
 from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
     FlashGemma2ForCausalLM,
 )
+from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
+    Gemma3ForConditionalGeneration,
+    FlashGemma3ForCausalLM,
+)
 from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
     FlashDbrxForCausalLM,
     DbrxConfig,
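In this file the flash-attention model imports (including the new Gemma3 ones) sit inside a try/except guard that is outside the hunk shown; if any import fails, the flash code paths are disabled. A minimal sketch of that pattern, assuming the surrounding structure rather than quoting it:

```python
# Sketch only: the real guard in __init__.py wraps many imports and logs
# the failure; the import matches the hunk above, the control flow is assumed.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
        Gemma3ForConditionalGeneration,
        FlashGemma3ForCausalLM,
    )
except ImportError:
    # Without a working flash-attention stack the Gemma3 classes are
    # unavailable and non-flash fallbacks are used instead.
    FLASH_ATTENTION = False
```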
@@ -220,6 +224,16 @@ class ModelType(enum.Enum):
         "name": "Gemma2",
         "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
     }
+    GEMMA3 = {
+        "type": "gemma3",
+        "name": "Gemma3",
+        "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
+    }
+    GEMMA3_TEXT = {
+        "type": "gemma3_text",
+        "name": "Gemma3 Text",
+        "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
+    }
     COHERE = {
         "type": "cohere",
         "name": "Cohere",
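The enum entries above are not compared directly in `get_model`; elsewhere in this file each member's `"type"` string is exported as a module-level constant under the member's name, so `GEMMA3` in the dispatch below is the plain string `"gemma3"`. A condensed sketch of that mechanism (paraphrased, not quoted):

```python
import enum

class ModelType(enum.Enum):
    GEMMA3 = {"type": "gemma3", "name": "Gemma3"}
    GEMMA3_TEXT = {"type": "gemma3_text", "name": "Gemma3 Text"}

# Export each member's "type" string under the member's name, so the
# dispatch in get_model can compare `model_type == GEMMA3` against the
# `model_type` field read from a checkpoint's config.json.
__GLOBALS = locals()
for member in ModelType:
    __GLOBALS[member.name] = member.value["type"]

assert GEMMA3 == "gemma3" and GEMMA3_TEXT == "gemma3_text"  # noqa: F821
```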
@@ -630,6 +644,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
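This hunk, and the matching ones below, threads the new `kv_cache_dtype` argument into each model constructor, letting the KV cache use a lower-precision dtype than the weights, typically FP8. A hedged sketch of the idea; the lookup table and helper name are hypothetical, not this file's actual code:

```python
from typing import Optional

import torch

# Hypothetical lookup table; TGI exposes FP8 KV-cache options along these
# lines, but the exact accepted strings are not shown in this diff.
KV_CACHE_DTYPES = {
    "fp8_e4m3fn": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

def resolve_kv_cache_dtype(
    kv_cache_dtype: Optional[str], model_dtype: torch.dtype
) -> torch.dtype:
    """Fall back to the model dtype when no override is requested."""
    if kv_cache_dtype is None:
        return model_dtype
    return KV_CACHE_DTYPES[kv_cache_dtype]
```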
@@ -675,6 +690,34 @@ def get_model(
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
         )
+    elif model_type == GEMMA3:
+        return FlashVlmCausalLM(
+            model_id=model_id,
+            model_class=Gemma3ForConditionalGeneration,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            default_dtype=torch.bfloat16,
+            trust_remote_code=trust_remote_code,
+            lora_adapter_ids=lora_adapter_ids,
+            support_chunking=False,
+        )
+    elif model_type == GEMMA3_TEXT:
+        return FlashCausalLM(
+            model_id=model_id,
+            model_class=FlashGemma3ForCausalLM,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            # Works better for these models
+            default_dtype=torch.bfloat16,
+            trust_remote_code=trust_remote_code,
+            lora_adapter_ids=lora_adapter_ids,
+        )
     elif model_type == COHERE:
         return FlashCausalLM(
             model_id=model_id,
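The split mirrors the Gemma 3 release: multimodal checkpoints report `model_type` `"gemma3"` and take the vision-language path, while the text-only checkpoint reports `"gemma3_text"` and reuses the standard flash path. The multimodal branch also sets `support_chunking=False`, presumably so a prefill is never split in the middle of an image's token span. A quick check of which branch a given checkpoint would take (checkpoint ids assumed from the public release):

```python
from transformers import AutoConfig

for model_id in ("google/gemma-3-4b-it", "google/gemma-3-1b-it"):
    model_type = AutoConfig.from_pretrained(model_id).model_type
    branch = "FlashVlmCausalLM" if model_type == "gemma3" else "FlashCausalLM"
    print(f"{model_id}: model_type={model_type} -> {branch}")
```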
@@ -864,6 +907,7 @@ def get_model(
             quantize=quantize,
             speculator=speculator,
             dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,