
Commit 79e84f9

Merge branch 'main' into better-init-2
2 parents: 99961fc + 16c7afd

File tree: 16 files changed, +295 -348 lines


CONTRIBUTING.md

Lines changed: 2 additions & 1 deletion
@@ -125,8 +125,9 @@ If you're contributing a **vision-language model** (or any multimodal model that
 All new models should use the modular architecture pattern. Create a `modular_<model_name>.py` file using the modular model converter:
 
 - Use the CLI, [`transformers add-new-model-like`](https://github.com/huggingface/transformers/blob/main/src/transformers/cli/add_new_model_like.py), to generate a modular skeleton and get started
-- All code should be in the modular file if possible. Modeling must be in it, it's better if configuration is in it as well.
+- All code should be in the modular file if possible. Modeling must be in it, it's better if configuration is in it as well. The [Modular guide](./modular_transformers#implementing-a-modular-file) shows a quick way to set up a modular file.
 - Reuse existing patterns from similar models as much as possible
+- You can make the model compatible with inference engines such as vLLM or SGLang, and enable zero-effort integration. See the specific requirements for model implementation in ["Transformers modeling backend"](./transformers_as_backend#multimodal-models)
 
 To verify your modular file is correct, run:
 
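Note on the modular pattern referenced above: a modular file mostly subclasses components from an existing model, and the converter expands it into the full `modeling_*.py`. The sketch below is illustrative only — the `MyModel`/Llama pairing and every class name are assumptions, not part of this commit:

```python
# modular_mymodel.py -- hypothetical sketch of a modular file; it reuses Llama
# components unchanged, just to show where modeling and configuration code live.
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel


class MyModelConfig(LlamaConfig):
    # the configuration preferably lives in the modular file as well
    model_type = "mymodel"


class MyModelModel(LlamaModel):
    # inherit the Llama architecture; override only what actually differs
    pass


class MyModelForCausalLM(LlamaForCausalLM):
    config_class = MyModelConfig
```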
docker/transformers-pytorch-amd-gpu/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-FROM rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
+FROM rocm/pytorch:rocm7.1_ubuntu22.04_py3.10_pytorch_release_2.8.0
 LABEL maintainer="Hugging Face"
 
 ARG DEBIAN_FRONTEND=noninteractive

docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@
 - local: tools
   title: Tools
 - local: transformers_as_backend
-  title: Inference server backends
+  title: Transformers as modeling backend
 - local: continuous_batching
   title: Continuous Batching
 title: Inference

docs/source/en/modular_transformers.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Contributing a new model to Transformers
 
-Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance.
+Modular Transformers lowers the bar for contributing models and significantly reduces the code required to add a model by allowing imports and inheritance. We recommend going through the [general contribution guidelines for new models](./contributing#do-you-want-to-implement-a-new-model) before diving into the details here.
 
 One of Transformers' core design feature is the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy) policy. Model components - such as attention layers - are repeated across many files and any independent implementations tend to diverge as fixes and changes are applied to specific parts of the code.
 
docs/source/en/transformers_as_backend.md

Lines changed: 8 additions & 52 deletions
@@ -14,9 +14,9 @@ rendered properly in your Markdown viewer.
 
 -->
 
-# Inference server backends
+# Transformers as modeling backend
 
-Transformers' models are compatible with different inference servers like vLLM and SGLang. Instead of implementing a model for each inference server, you only need one model, which can be plugged into any inference server. It simplifies maintenance and makes it easy for users to use different inference servers for different use cases.
+Transformers' models are compatible with different inference servers like vLLM and SGLang. Instead of implementing a new model architecture from scratch for each inference server, you only need a model definition in `transformers`, which can be plugged into any inference server. It simplifies maintenance and makes it easy for users to use different inference servers for different use cases.
 
 With Transformers as a backend, you can also serve any model - including custom and Hub-hosted models - without waiting for native support.
 
@@ -157,57 +157,13 @@ class MyConfig(PreTrainedConfig):
 
 ### Multimodal models
 
-For multimodal models, you need to include a few more changes on top of the general recommendations. These rules ensure that your model integrates properly with multimodal data.
+For multimodal models, you need to include a few more changes on top of the general recommendations outlined in ["contributing a model"](./contributing#vision-language-model-contribution-checklist). These rules ensure that your model integrates properly and enables processing multimodal data.
 
-1. A multimodal model requires a base `MyMultiModalModel` class to handle multimodal fusion without a language modeling head and a separate generative class that adds a head.
+1. A multimodal model's processing class must have the `self.image_token` and `self.image_token_ids` attributes. These are placeholder tokens used to indicate image positions in the input. This placeholder token is the same token used in the input prompt to denote images and used in model code to scatter image features.
 
-   The base model needs to implement the `get_image_features()` method to accept image pixel values and return encoded outputs. These are later merged with the language embeddings and don't require any postprocessing. The shape of the returned features must match the number of input images. If a vision encoder returns variable-length outputs (patch-based), return a list of 2D tensors of size `(image_seq_len, image_dim)` for each image.
+2. The processing class needs a `self._get_num_multimodal_tokens` method to compute the number of placeholder tokens needed for multimodal inputs with given sizes and to return a [`MultiModalData`] object. The placeholders between `<image>` tokens, such as row or column tokens, don't count as image placeholders. Only tokens that are actually replaced by image features later in modeling should be counted!
 
-   Expand the code below for an example.
-
-   <details>
-   <summary>modeling_my_multimodal_model.py</summary>
-
-   ```python
-   from transformers.generation import GenerationMixin
-
-   class MyMultimodalModel(MyMultimodalPreTrainedModel):
-       def __init__(self, config):
-           super().__init__(config)
-           self.language_model = AutoModel.from_config(config.text_config)
-           self.vision_tower = AutoModel.from_config(config.vision_config)
-           self.multimodal_projection = nn.Linear(vision_dim, text_dim)
-
-       def get_image_features(self, pixel_values):
-           return self.vision_tower(pixel_values).last_hidden_states
-
-       def forward(self, input_ids, pixel_values, **kwargs):
-           # process your inputs
-           return MyModelOutputWithPast(
-               last_hidden_state=last_hidden_state,
-               image_hidden_states=image_features,
-               [...]
-           )
-
-   class MyMultimodalModelForConditionalGeneration(MyMultimodalPreTrainedModel, GenerationMixin):
-       def __init__(self, config):
-           super().__init__(config)
-           self.model = MyMultimodalModel(config)
-           self.lm_head = nn.Linear(hidden_dim, vocab_size)
-   ```
-
-   </details>
-
-2. A multimodal model config must be nested with the following fields.
-   * text_config: decoder language model config
-   * vision_config: vision encoder config
-   * image_token_id: ID of the image placeholder token used in the input to indicate image position
-
-3. A multimodal model's processing class must have the `self.image_token` and `self.image_token_ids` attributes. These are placeholder tokens used to indicate image positions in the input. The placeholder token is the same token used in the input prompt and to mask scatter image features.
-
-   The processing class also needs `self._get_num_multimodal_tokens` method to compute the number of placeholder tokens needed for multimodal inputs with given sizes and to return a [`MultiModalData`] object. The placeholder for row and column tokens don't count as image placeholders. Only the tokens that are actually replaced by image features are computed.
-
-   Finally, when `return_mm_token_type_ids=True`, the class has to return `mm_token_type_ids` to indicate whether each position is a text token (`0`) or image placeholder token (`1`). Each image's token type IDs must be contiguous with no breaks between consecutive ones.
+3. The processor needs to check the value of `return_mm_token_type_ids` and return `mm_token_type_ids` to indicate whether each position is a text token (`0`), an image placeholder token (`1`), or a video placeholder token (`2`). Each multimodal token type ID sequence must be contiguous without breaks between consecutive tokens; therefore, special tokens for begin/end/row/column must be treated as placeholders.
 
 Expand the code below for an example.
 
@@ -246,5 +202,5 @@ class MyMultimodalProcessor(ProcessorMixin):
 
 ## Resources
 
-* Read the [Transformers backend integration in vLLM](https://blog.vllm.ai/2025/04/11/transformers-backend.html) blog post for more details about the Transformers backend in vLLM.
-* Read the [Transformers backend integration in SGLang](https://huggingface.co/blog/transformers-backend-sglang) blog post for more details about the Transformers backend in SGLang.
+* Read the [Transformers modeling backend integration in vLLM](https://blog.vllm.ai/2025/04/11/transformers-backend.html) blog post for more details about the Transformers modeling backend in vLLM.
+* Read the [Transformers modeling backend integration in SGLang](https://huggingface.co/blog/transformers-backend-sglang) blog post for more details about the Transformers modeling backend in SGLang.
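To make the three processor rules above concrete, here is a standalone sketch. It deliberately does not subclass the real `ProcessorMixin`, and the 256-tokens-per-image figure, the token id, and the dict in place of a real `MultiModalData` object are assumptions used only for illustration:

```python
import torch


class ToyMultimodalProcessor:
    """Illustrative only -- mirrors the attribute and method names from the doc text."""

    def __init__(self, image_token: str = "<image>", image_token_id: int = 32000):
        # 1. placeholder token (and its id) shared by the prompt and the modeling code
        self.image_token = image_token
        self.image_token_id = image_token_id

    def _get_num_multimodal_tokens(self, image_sizes, **kwargs):
        # 2. count only tokens that are later replaced by image features
        #    (a real processor wraps these counts in a `MultiModalData` object)
        return {"num_image_tokens": [256 for _ in image_sizes]}

    def get_mm_token_type_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        # 3. 0 = text, 1 = image placeholder (2 would mark video placeholders);
        #    each image's run of 1s must be contiguous
        return (input_ids == self.image_token_id).long()


proc = ToyMultimodalProcessor()
ids = torch.tensor([[1, 7, 32000, 32000, 32000, 42]])
print(proc.get_mm_token_type_ids(ids))  # tensor([[0, 0, 1, 1, 1, 0]])
```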

src/transformers/models/blt/modeling_blt.py

Lines changed: 13 additions & 26 deletions
@@ -28,7 +28,7 @@
 import torch.nn.functional as F
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -321,7 +321,6 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
-        use_cache: bool = False,
         past_key_values=None,
         cache_position=None,
         **kwargs,
@@ -393,9 +392,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cross_attention_states: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -404,27 +401,13 @@ def forward(
         query_states = self.q_proj(query_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-        if cross_attention_states is not None:
-            cross_attention_states = self.k_norm(cross_attention_states)
-            key_states = self.k_proj(cross_attention_states)
-            value_states = self.v_proj(cross_attention_states)
-            key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            if past_key_values is not None:
-                key_states, value_states = past_key_values.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
-        elif cache_position[0] != 0:
-            key_states, value_states = (
-                past_key_values.layers[self.layer_idx].keys,
-                past_key_values.layers[self.layer_idx].values,
-            )
-        else:
-            raise ValueError(
-                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
-            )
-        attention_interface: Callable = eager_attention_forward
+        cross_attention_states = self.k_norm(cross_attention_states)
+        key_states = self.k_proj(cross_attention_states)
+        value_states = self.v_proj(cross_attention_states)
+        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
@@ -1089,6 +1072,9 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
         # Extract input embeddings as early as possible
         if inputs_embeds is not None:
             encoder_embeds = inputs_embeds
@@ -1137,7 +1123,7 @@ def forward(
             input_embeds=encoder_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             position_ids=position_ids,
         )
 
@@ -1157,6 +1143,7 @@ def forward(
             encoder_attention_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             **kwargs,
         )
         encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
@@ -1192,7 +1179,7 @@ def forward(
             patch_embeds=global_hidden_states,
             attention_mask=causal_mask,
             position_ids=position_ids,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
             cache_position=cache_position,
             encoder_attention_mask=cross_attn_mask_dec,
             **kwargs,
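The cache change in both BLT files follows the same pattern: one `EncoderDecoderCache` is created up front and its two sub-caches are routed to the self-attention and cross-attention layers. Below is a minimal sketch of that split using only the public cache utilities that appear in the diff; the surrounding model wiring is omitted and illustrative:

```python
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# one wrapper holding two independent caches
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# self-attention layers (local encoder/decoder) read and write this sub-cache ...
self_attn_cache = past_key_values.self_attention_cache
# ... while the cross-attention layers use the other one
cross_attn_cache = past_key_values.cross_attention_cache
```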

src/transformers/models/blt/modular_blt.py

Lines changed: 13 additions & 46 deletions
@@ -22,7 +22,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import dynamic_rope_update
@@ -299,27 +299,6 @@ def __init__(self, config, layer_idx: int):
 class BltSelfAttention(MllamaTextSelfAttention):
     def __init__(self, config: BltConfig, layer_idx: int):
         super().__init__(config, layer_idx)
-        self.is_causal = True
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_embeddings: torch.Tensor,
-        use_cache: bool = False,
-        past_key_values=None,
-        cache_position=None,
-        **kwargs,
-    ):
-        return super().forward(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_embeddings=position_embeddings,
-            use_cache=use_cache,
-            past_key_values=past_key_values,
-            cache_position=cache_position,
-            **kwargs,
-        )
 
 
 class BltCrossAttention(MllamaTextCrossAttention):
@@ -335,37 +314,21 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cross_attention_states: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ):
         bsz, q_len, _ = hidden_states.size()
         query_states = self.q_norm(hidden_states)
         query_states = self.q_proj(query_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-        if cross_attention_states is not None:
-            cross_attention_states = self.k_norm(cross_attention_states)
-            key_states = self.k_proj(cross_attention_states)
-            value_states = self.v_proj(cross_attention_states)
-            key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            if past_key_values is not None:
-                key_states, value_states = past_key_values.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
-        elif cache_position[0] != 0:
-            key_states, value_states = (
-                past_key_values.layers[self.layer_idx].keys,
-                past_key_values.layers[self.layer_idx].values,
-            )
-        else:
-            raise ValueError(
-                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
-            )
-        attention_interface: Callable = eager_attention_forward
+        cross_attention_states = self.k_norm(cross_attention_states)
+        key_states = self.k_proj(cross_attention_states)
+        value_states = self.v_proj(cross_attention_states)
+        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
@@ -828,6 +791,9 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
         # Extract input embeddings as early as possible
         if inputs_embeds is not None:
             encoder_embeds = inputs_embeds
@@ -876,7 +842,7 @@ def forward(
             input_embeds=encoder_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             position_ids=position_ids,
         )
 
@@ -896,6 +862,7 @@ def forward(
             encoder_attention_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             **kwargs,
         )
         encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
@@ -931,7 +898,7 @@ def forward(
             patch_embeds=global_hidden_states,
             attention_mask=causal_mask,
             position_ids=position_ids,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
             cache_position=cache_position,
             encoder_attention_mask=cross_attn_mask_dec,
             **kwargs,

src/transformers/models/glm4v/modeling_glm4v.py

Lines changed: 0 additions & 3 deletions
@@ -1418,14 +1418,11 @@ def forward(
         pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
-        rope_deltas: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Glm4vCausalLMOutputWithPast]:
         r"""
-        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
-            The rope index difference between sequence length and multimodal rope.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
