
Commit 7199074

[gaudi] Refine rope memory, do not need to keep sin/cos cache per layer (#3274)
1 parent 238fbd4 commit 7199074

26 files changed: +315 -3525 lines
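
In short, each Gaudi model now builds one rotary embedding (and therefore one sin/cos cache) at the top level and passes that same instance into every decoder layer, instead of letting each attention module construct its own copy. The sketch below only illustrates the sharing pattern; the class names SharedRotary, ToyLayer, and ToyModel are invented for illustration and do not appear in the repository.

# Minimal sketch of the pattern this commit applies (not the actual TGI code):
# build the rotary embedding -- and its cos/sin cache -- once in the model,
# then hand the same instance to every layer instead of re-creating it per layer.
import torch


class SharedRotary(torch.nn.Module):
    """Toy stand-in for a rotary embedding: one cos/sin table for all layers."""

    def __init__(self, dim: int, base: float, max_positions: int, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(max_positions, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # Precompute once; layers only index into these tensors.
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def get_cos_sin(self, position_ids: torch.Tensor):
        return (
            torch.index_select(self.cos_cached, 0, position_ids),
            torch.index_select(self.sin_cached, 0, position_ids),
        )


class ToyLayer(torch.nn.Module):
    def __init__(self, rotary_emb: SharedRotary):
        super().__init__()
        self.rotary_emb = rotary_emb  # shared reference, no per-layer cache


class ToyModel(torch.nn.Module):
    def __init__(self, num_layers: int):
        super().__init__()
        rotary_emb = SharedRotary(dim=128, base=10000.0, max_positions=4096)
        self.layers = torch.nn.ModuleList(ToyLayer(rotary_emb) for _ in range(num_layers))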

backends/gaudi/server/text_generation_server/layers/rotary.py

Lines changed: 10 additions & 7 deletions
@@ -36,7 +36,9 @@ def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
-        self.max_position_embeddings = max_position_embeddings
+        self._update_cos_sin_cache(
+            torch.float32, inv_freq.device, max_position_embeddings
+        )

     def forward(
         self,
@@ -268,9 +270,7 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
             self._sin_cached = torch.sin(freqs).to(dtype)

     def get_cos_sin(self, position_ids: torch.Tensor):
-        self._update_cos_sin_cache(
-            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
-        )
+
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)

@@ -298,6 +298,9 @@ def __init__(
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
+        self._update_cos_sin_cache(
+            torch.float32, short_inv_freq.device, max_position_embeddings
+        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         # Reset the tables if the sequence length has changed,
@@ -351,6 +354,9 @@ def __init__(
         self._cos_k_cached = None
         self._sin_k_cached = None
         self.dynamic_args = None
+        self._update_cos_sin_cache(
+            torch.float32, short_inv_freq.device, max_position_embeddings
+        )

     def _update_cos_sin_cache(self, dtype, device, seqlen):
         if (
@@ -592,9 +598,6 @@ def get_cos_sin(
         position_ids: torch.Tensor,
     ):
         slen = position_ids.shape[0]
-        self._update_cos_sin_cache(
-            torch.float32, position_ids.device, seqlen=self.max_position_embeddings
-        )

         cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
         sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
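
With the cos/sin tables now filled once in __init__ for the full max_position_embeddings, get_cos_sin reduces to a pure table lookup. The snippet below is a stand-alone illustration of that lookup under assumed shapes; only the torch.index_select call mirrors the diff, the tensor names and sizes are invented.

import torch

max_position_embeddings, half_dim = 4096, 64
cos_cached = torch.rand(max_position_embeddings, half_dim)  # built once in __init__
sin_cached = torch.rand(max_position_embeddings, half_dim)

position_ids = torch.tensor([0, 1, 2, 7, 7, 9])          # one position id per token
cos = torch.index_select(cos_cached, 0, position_ids)    # (num_tokens, half_dim)
sin = torch.index_select(sin_cached, 0, position_ids)
assert cos.shape == (position_ids.numel(), half_dim)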

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py

Lines changed: 14 additions & 8 deletions
@@ -160,18 +160,14 @@ def __init__(
         prefix: str,
         config,
         weights,
+        rotary_emb,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = CohereRotary.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        self.rotary_emb = rotary_emb

         self.softmax_scale = self.head_size**-0.5

@@ -325,11 +321,14 @@ def forward(self, hidden_states):


 class FlashCohereLayer(nn.Module):
-    def __init__(self, prefix: str, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
         super().__init__()
         prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = FlashCohereAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+            rotary_emb=rotary_emb,
         )
         self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

@@ -385,13 +384,20 @@ def __init__(self, prefix: str, config, weights):
         self.embed_tokens = TensorParallelEmbedding(
             prefix=f"{prefix}.embed_tokens", weights=weights
         )
+        rotary_emb = CohereRotary.static(
+            config=config,
+            dim=config.hidden_size // config.num_attention_heads,
+            base=config.rope_theta,
+            device=weights.device,
+        )
         self.layers = nn.ModuleList(
             [
                 FlashCohereLayer(
                     prefix,
                     layer_id,
                     config,
                     weights,
+                    rotary_emb,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py

Lines changed: 19 additions & 9 deletions
@@ -263,19 +263,15 @@ def __init__(
         prefix: str,
         config,
         weights,
+        rotary_emb,
     ):
         super().__init__()
         self.clip_qkv = config.attn_config.clip_qkv
         self.num_heads = config.n_heads
         self.hidden_size = config.d_model
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.attn_config.rope_theta,
-            device=weights.device,
-        )
+        self.rotary_emb = rotary_emb

         self.softmax_scale = self.head_size**-0.5

@@ -370,13 +366,17 @@ def __init__(
         prefix: str,
         config,
         weights,
+        rotary_emb,
     ):
         super().__init__()
         self.norm_1 = FastLayerNorm.load_no_bias(
             prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5
         )
         self.self_attn = DbrxAttention(
-            prefix=f"{prefix}.attn", config=config, weights=weights
+            prefix=f"{prefix}.attn",
+            config=config,
+            weights=weights,
+            rotary_emb=rotary_emb,
         )
         self.norm_2 = FastLayerNorm.load_no_bias(
             prefix=f"{prefix}.norm_2",
@@ -601,12 +601,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class DbrxLayer(nn.Module):
-    def __init__(self, prefix: str, layer_id, config, weights):
+    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
         super().__init__()
         prefix = f"{prefix}.blocks.{layer_id}"

         self.attn = DbrxNormAttentionNorm(
-            prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
+            prefix=f"{prefix}.norm_attn_norm",
+            config=config,
+            weights=weights,
+            rotary_emb=rotary_emb,
         )

         moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
@@ -649,6 +652,12 @@ def __init__(self, prefix: str, config, weights):
         self.embed_tokens = TensorParallelEmbedding(
             prefix=f"{prefix}.wte", weights=weights
         )
+        rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=config.d_model // config.n_heads,
+            base=config.attn_config.rope_theta,
+            device=weights.device,
+        )

         self.layers = nn.ModuleList(
             [
@@ -657,6 +666,7 @@ def __init__(self, prefix: str, config, weights):
                     layer_id,
                     config,
                     weights,
+                    rotary_emb,
                 )
                 for layer_id in range(config.n_layers)
             ]

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py

Lines changed: 11 additions & 8 deletions
@@ -156,6 +156,7 @@ def __init__(
         prefix: str,
         config,
         weights: Weights,
+        rotary_emb,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -167,13 +168,7 @@ def __init__(
         self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
         self.value_head_size = config.v_head_dim
         self.head_pad_size = max(self.head_size, self.value_head_size)
-
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.qk_rope_head_dim,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        self.rotary_emb = rotary_emb

         mscale = get_mscale(
             self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
@@ -459,14 +454,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class DeepseekV2Layer(nn.Module):
-    def __init__(self, prefix, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights, rotary_emb):
         super().__init__()
         prefix = f"{prefix}.layers.{layer_id}"

         self.self_attn = DeepseekV2Attention(
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
+            rotary_emb=rotary_emb,
         )

         if (
@@ -541,13 +537,20 @@ def __init__(self, prefix: str, config, weights: Weights):
             prefix=f"{prefix}.embed_tokens", weights=weights
         )

+        rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=config.qk_rope_head_dim,
+            base=config.rope_theta,
+            device=weights.device,
+        )
         self.layers = nn.ModuleList(
             [
                 DeepseekV2Layer(
                     prefix,
                     layer_id,
                     config,
                     weights,
+                    rotary_emb,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py

Lines changed: 11 additions & 8 deletions
@@ -169,6 +169,7 @@ def __init__(
         prefix: str,
         config,
         weights: Weights,
+        rotary_emb,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -180,13 +181,7 @@ def __init__(
         self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
         self.value_head_size = config.v_head_dim
         self.head_pad_size = max(self.head_size, self.value_head_size)
-
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.qk_rope_head_dim,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        self.rotary_emb = rotary_emb

         mscale = get_mscale(
             self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
@@ -535,14 +530,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class DeepseekV3Layer(nn.Module):
-    def __init__(self, prefix, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights, rotary_emb):
         super().__init__()
         prefix = f"{prefix}.layers.{layer_id}"

         self.self_attn = DeepseekV3Attention(
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
+            rotary_emb=rotary_emb,
         )

         if (
@@ -616,6 +612,12 @@ def __init__(self, prefix: str, config, weights: Weights):
         self.embed_tokens = TensorParallelEmbedding(
             prefix=f"{prefix}.embed_tokens", weights=weights
         )
+        rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=config.qk_rope_head_dim,
+            base=config.rope_theta,
+            device=weights.device,
+        )

         self.layers = nn.ModuleList(
             [
@@ -624,6 +626,7 @@ def __init__(self, prefix: str, config, weights: Weights):
                     layer_id,
                     config,
                     weights,
+                    rotary_emb,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py

Lines changed: 26 additions & 9 deletions
@@ -166,7 +166,14 @@ def _load_gqa(config, prefix: str, weights):

 class FlashGemma2Attention(torch.nn.Module):
     def __init__(
-        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+        self,
+        prefix: str,
+        config,
+        weights,
+        layer_id,
+        causal: bool,
+        is_sliding: bool,
+        rotary_emb,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -176,13 +183,7 @@ def __init__(
             self.window_size = config.sliding_window
         else:
             self.window_size = -1
-
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config,
-            dim=self.head_size,
-            base=config.rope_theta,
-            device=weights.device,
-        )
+        self.rotary_emb = rotary_emb

         # self.softmax_scale = self.head_size**-0.5
         self.softmax_scale = config.query_pre_attn_scalar**-0.5
@@ -354,7 +355,14 @@ def forward(self, hidden_states, adapter_data):

 class FlashGemma2Layer(nn.Module):
     def __init__(
-        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+        self,
+        prefix: str,
+        config,
+        weights,
+        layer_id,
+        causal: bool,
+        is_sliding: bool,
+        rotary_emb,
     ):
         super().__init__()
         self.self_attn = FlashGemma2Attention(
@@ -364,6 +372,7 @@ def __init__(
             layer_id=layer_id,
             causal=causal,
             is_sliding=is_sliding,
+            rotary_emb=rotary_emb,
         )
         self.mlp = Gemma2MLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
@@ -435,6 +444,13 @@ def __init__(self, prefix: str, config, weights, causal: bool):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
+        rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=config.head_dim,
+            base=config.rope_theta,
+            device=weights.device,
+        )
+
         self.layers = nn.ModuleList(
             [
                 FlashGemma2Layer(
@@ -444,6 +460,7 @@ def __init__(self, prefix: str, config, weights, causal: bool):
                     layer_id=layer_id,
                     causal=causal,
                     is_sliding=layer_id % 2 == 0,
+                    rotary_emb=rotary_emb,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
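
The payoff is easiest to see with a rough size estimate. The figures below (layer count, context length, rotary dimension) are assumed for illustration and are not taken from the commit:

# Back-of-the-envelope estimate of the memory this saves (illustrative numbers,
# not measured): with per-layer caches, every decoder layer kept its own float32
# cos and sin tables of shape (max_position_embeddings, rotary_dim // 2).
num_layers = 32                    # e.g. a Llama-7B-sized model
max_position_embeddings = 131072
rotary_dim = 128                   # head size
bytes_per_cache = 2 * max_position_embeddings * (rotary_dim // 2) * 4  # cos + sin, fp32
total_before = num_layers * bytes_per_cache
total_after = bytes_per_cache      # one shared cache for the whole model
print(f"before: {total_before / 2**20:.0f} MiB, after: {total_after / 2**20:.0f} MiB")
# before: 2048 MiB, after: 64 MiB

Whatever the exact model dimensions, the saving scales linearly with the number of decoder layers, since only one cache remains.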
