Commit 91d250e (parent: 7cb4280)

Reinstate self.scaling in Gemma3nTextAttention (#41751)

Maintenance: make Gemma3nTextAttention more amenable to modular inheritance.

File tree

2 files changed: +4 additions, −3 deletions

src/transformers/models/gemma3n/modeling_gemma3n.py

Lines changed: 2 additions & 1 deletion

@@ -1237,6 +1237,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.scaling = 1.0
         self.attention_dropout = self.config.attention_dropout
         self.is_causal = True

@@ -1335,7 +1336,7 @@ def forward(
             value_states,
             attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
-            scaling=1.0,
+            scaling=self.scaling,
             sliding_window=self.sliding_window,
             **kwargs,
         )

src/transformers/models/gemma3n/modular_gemma3n.py

Lines changed: 2 additions & 2 deletions

@@ -1703,7 +1703,7 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
         super().__init__(config, layer_idx)
         self.is_causal = True
         del self.attn_logit_softcapping
-        del self.scaling
+        self.scaling = 1.0
         self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)

         first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers

@@ -1782,7 +1782,7 @@ def forward(
             value_states,
             attention_mask,
             dropout=self.attention_dropout if self.training else 0.0,
-            scaling=1.0,
+            scaling=self.scaling,
             sliding_window=self.sliding_window,
             **kwargs,
         )

Comments: 0