Commit fd79188

Refactor: eagle3 training loop & loss mask (#548)
## What does this PR do?

**Type of change:** Refactor

**Overview:** A few refactors to the eagle3 training code:

- Put the first eagle step into the TTT loop;
- Support TTT masks beyond 4 steps;
- Remove the answer-only loss mask in data loading.

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 1a89a88 commit fd79188

2 files changed: +60 -115 lines changed

examples/speculative_decoding/eagle_utils.py

Lines changed: 3 additions & 23 deletions
```diff
@@ -78,31 +78,11 @@ def get_role_content(item):
             return_tensors="pt",
             add_special_tokens=False,
             truncation=True,
-            return_offsets_mapping=True,
         )
         input_ids = output.input_ids[0]
         attention_mask = output.attention_mask[0]
-        offset_mapping = output.offset_mapping[0]
-        loss_mask = torch.zeros_like(input_ids)
-        labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
-
-        for turn in messages:
-            if turn["role"] == "assistant":
-                content = turn["content"]
-                # Unfortunate strip() necessary because chat templates are doing the same.
-                start = conversation.index(content.strip())
-                stop = start + len(content)
-                indices = []
-                for tok_index, (tok_start, tok_stop) in enumerate(offset_mapping):
-                    if tok_start >= start and tok_stop <= stop:
-                        indices.append(tok_index)
-                labels[indices] = input_ids[indices]
-                loss_mask[indices] = 1
-
-        # Shift loss_mask and labels to the left by 1 token
-        loss_mask = torch.cat([loss_mask[1:], torch.zeros(1, dtype=loss_mask.dtype)])
-        labels = torch.cat([labels[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=labels.dtype)])
-
+        loss_mask = torch.ones_like(input_ids)
+        labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)])
         new_examples["input_ids"].append(input_ids)
         new_examples["attention_mask"].append(attention_mask)
         new_examples["loss_mask"].append(loss_mask)
@@ -158,7 +138,7 @@ def convert_role(role):
         input_ids = output.input_ids[0]
         attention_mask = output.attention_mask[0]
         loss_mask = torch.ones_like(input_ids)
-        labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
+        labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)])
         # TODO: add labels and answer-only loss masking?

         new_examples["input_ids"].append(input_ids)
```
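For reference, the new data path trains on every position: the loss mask becomes all ones and the labels are simply `input_ids` shifted left by one. A minimal standalone sketch of that construction (token values are made up; `IGNORE_TOKEN_ID` is assumed to be -100, the usual HF ignore index):

```python
import torch

IGNORE_TOKEN_ID = -100  # assumption: the ignore index used in eagle_utils.py

input_ids = torch.tensor([101, 7592, 2088, 102])  # illustrative token ids
loss_mask = torch.ones_like(input_ids)  # every position now contributes to the loss

# Labels are input_ids shifted left by one; the final position has no next token.
labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)])
print(loss_mask.tolist())  # [1, 1, 1, 1]
print(labels.tolist())     # [7592, 2088, 102, -100]
```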

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 57 additions & 92 deletions
```diff
@@ -508,17 +508,15 @@ def modify(
         # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
         self.is_quantized = False

-        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        self._cached_attn_blk_masks = []
+        self.num_ttt_steps = 4  # NOTE: (hg) hardcoded for now. Might add to config later.
+        self._cached_attn_blk_masks = {}

     def _get_ttt_attention_mask(self, seq_length, ttt_step):
         # compile and cached flex attention masks in first call
-        if ttt_step >= len(self._cached_attn_blk_masks):
-            self._cached_attn_blk_masks.append(
-                self._compute_ttt_attention_mask(seq_length, ttt_step)
+        if ttt_step not in self._cached_attn_blk_masks:
+            self._cached_attn_blk_masks.update(
+                {ttt_step: self._compute_ttt_attention_mask(seq_length, ttt_step)}
             )
-
-        # return cached flex attention mask
         return self._cached_attn_blk_masks[ttt_step]

     def _prepare_decoder_attention_mask(
```
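The list-to-dict switch makes the mask cache order-independent: with a list, step k was only reachable after steps 0..k-1 had been appended in order. A tiny sketch of the keyed lazy-cache pattern (the compute function below is a stand-in, not the real `_compute_ttt_attention_mask`):

```python
_cached_masks: dict[int, str] = {}

def _compute_mask(seq_length: int, ttt_step: int) -> str:
    # Stand-in for the real mask construction; returns a tag for demonstration.
    print(f"compiling mask for step {ttt_step}")
    return f"mask(seq={seq_length}, step={ttt_step})"

def get_ttt_mask(seq_length: int, ttt_step: int) -> str:
    # Keyed by step, so steps can be requested in any order.
    if ttt_step not in _cached_masks:
        _cached_masks[ttt_step] = _compute_mask(seq_length, ttt_step)
    return _cached_masks[ttt_step]

get_ttt_mask(8, 3)  # compiles on first use
get_ttt_mask(8, 3)  # cache hit, no recompilation
```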
```diff
@@ -600,44 +598,26 @@ def _get_eagle_module_inputs(

     def _compute_ttt_attention_mask(self, seq_length, ttt_step) -> BlockMask | torch.Tensor:
         """Return TTT attention_mask tensor of type BlockMask or Tensor depends on eagle attn impl."""
-        if ttt_step == 0:
-
-            def msk(b, h, q_idx, kv_idx):
-                # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0
-                return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

-        elif ttt_step == 1:
-
-            def msk(b, h, q_idx, kv_idx):
-                # attention mask of shape [seq_len, 3* seq_len] for TTT step 1
-                return (
-                    (kv_idx <= (q_idx - 2))
-                    | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
-                    | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
-                )
-        elif ttt_step == 2:
-
-            def msk(b, h, q_idx, kv_idx):
-                # attention mask of shape [seq_len, 4* seq_len] for TTT step 2
-                return (
-                    (kv_idx <= (q_idx - 3))
-                    | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
-                    | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
-                    | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
+        def msk_func(b, h, q_idx, kv_idx):
+            mask = kv_idx <= (q_idx - ttt_step)
+            for i in range(1, ttt_step + 1):
+                mask_block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & (
+                    kv_idx >= seq_length * i
                 )
-        else:
-            raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
+                mask = mask | mask_block_i
+            return mask

         dtypemin = torch.finfo(self._base_llm_config.dtype).min
         q_len = seq_length
-        kv_len = seq_length * (2 + ttt_step)
+        kv_len = seq_length * (1 + ttt_step)
         if self.eagle_module.config._attn_implementation == "flex_attention":
             # Return block mask for flex attention
-            block_mask = create_block_mask(msk, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)
+            block_mask = create_block_mask(msk_func, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)
             return block_mask
         else:
             # Return tensor mask for non-flex attention
-            tensor_mask = msk(
+            tensor_mask = msk_func(
                 None,
                 None,
                 torch.arange(q_len).view(1, 1, q_len, 1),
```
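As a sanity check on the generalized predicate, it can be evaluated densely on index grids without flex attention. Note that moving the first eagle step into the TTT loop shifts the indexing by one: new step 1 should reproduce the old hardcoded step-0 rule `(kv_idx <= q_idx - 1) | (kv_idx == q_idx + seq_length)`. A small sketch under those assumptions:

```python
import torch

def dense_ttt_mask(seq_length: int, ttt_step: int) -> torch.Tensor:
    # Same predicate as msk_func above, broadcast over [q_len, kv_len] index grids.
    q_idx = torch.arange(seq_length).view(-1, 1)
    kv_idx = torch.arange(seq_length * (1 + ttt_step)).view(1, -1)
    mask = kv_idx <= (q_idx - ttt_step)
    for i in range(1, ttt_step + 1):
        block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & (
            kv_idx >= seq_length * i
        )
        mask = mask | block_i
    return mask

new_step1 = dense_ttt_mask(4, 1)  # shape [4, 8]
q = torch.arange(4).view(-1, 1)
kv = torch.arange(8).view(1, -1)
old_step0 = (kv <= (q - 1)) | (kv == q + 4)  # previous hardcoded step-0 mask
assert torch.equal(new_step1, old_step0)
```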
```diff
@@ -847,69 +827,54 @@ def forward(
         inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)

         position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
-
-        # Then, we run eagle forward
-        _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
-            eagle_input_hidden_states,
-            inputs_embeds,
-            attention_mask_0,
-            position_ids,
-            position_embeddings,
-            eagle_cache,
-        )
-
         past_key_values.eagle_cache = eagle_cache

-        # Compute loss on the eagle modules
-        classification_loss, acc = self._eagle_loss(
-            base_model_logits[:, 1:],
-            eagle_logits[:, :-1],
-            loss_mask[:, 1:],
-        )
-        eagle_loss = classification_loss
-        train_accs.append(acc)
-
         # ====Perform training-time-testing with 3 extra eagle forward passes====
-        if self.training:
-            for ttt_step in range(self.num_ttt_steps):
-                eagle_input_hidden_states = torch.cat(
+        for ttt_step in range(self.num_ttt_steps):
+            attention_mask = (
+                attention_mask_0
+                if ttt_step == 0
+                else self._get_ttt_attention_mask(seq_length, ttt_step)
+            )
+            _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward(
+                eagle_input_hidden_states,
+                inputs_embeds,
+                attention_mask,
+                position_ids,
+                position_embeddings,
+                eagle_cache,
+            )
+            eagle_input_hidden_states = torch.cat(
+                (
+                    torch.zeros(
+                        (b, 1, h),
+                        dtype=eagle_input_hidden_states.dtype,
+                        device=eagle_input_hidden_states.device,
+                    ),
+                    eagle_input_hidden_states[:, :-1, :],
+                ),
+                dim=1,
+            )
+            classification_loss, acc = self._eagle_loss(
+                # base model predict +1 tok, while eagle predict +2
+                # so we shift base model outputs compared to eagle outputs
+                base_model_logits[:, 1:],
+                eagle_logits[:, :-1],
+                # additionally, we mask the first n tok of eagle outputs at nth TTT step
+                torch.cat(
                     (
-                        torch.zeros(
-                            (b, 1, h),
-                            dtype=eagle_input_hidden_states.dtype,
-                            device=eagle_input_hidden_states.device,
-                        ),
-                        eagle_prenorm_h[:, :-1, :],
+                        torch.zeros(b, ttt_step, dtype=loss_mask.dtype, device=loss_mask.device),
+                        loss_mask[:, 1 + ttt_step :],
                     ),
                     dim=1,
-                )
-                attention_mask = self._get_ttt_attention_mask(seq_length, ttt_step)
-                _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
-                    eagle_input_hidden_states,
-                    inputs_embeds,
-                    attention_mask,
-                    position_ids,
-                    position_embeddings,
-                    eagle_cache,
-                )
-                classification_loss, acc = self._eagle_loss(
-                    # base model predict +1 tok, while eagle predict +2
-                    # so we shift base model outputs compared to eagle outputs
-                    base_model_logits[:, 1:],
-                    eagle_logits[:, :-1],
-                    # additionally, we mask the first n tok of eagle outputs at nth TTT step
-                    torch.cat(
-                        (
-                            torch.zeros(
-                                b, 1 + ttt_step, dtype=loss_mask.dtype, device=loss_mask.device
-                            ),
-                            loss_mask[:, 2 + ttt_step :],
-                        ),
-                        dim=1,
-                    ),
-                )
-                eagle_loss += classification_loss
-                train_accs.append(acc)
+                ),
+            )
+            eagle_loss = (
+                classification_loss if eagle_loss is None else eagle_loss + classification_loss
+            )
+            train_accs.append(acc)
+            if not self.training:
+                break
         # Finally, we merge base model loss and eagle loss, raise error if both are None
         if base_model_loss is not None and eagle_loss is not None:
             loss = base_model_loss + eagle_loss
```
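To make the per-step masking concrete: at TTT step n, the first n positions of the eagle outputs are zeroed and the rest of the mask is shifted, so the mask always matches the length seq - 1 logit slices. A toy illustration with made-up values:

```python
import torch

b, seq = 1, 6
loss_mask = torch.ones(b, seq, dtype=torch.long)  # made-up full-sequence mask

for ttt_step in range(3):
    step_mask = torch.cat(
        (
            torch.zeros(b, ttt_step, dtype=loss_mask.dtype),
            loss_mask[:, 1 + ttt_step:],
        ),
        dim=1,
    )
    # Always length seq - 1, matching base_model_logits[:, 1:] and eagle_logits[:, :-1].
    print(ttt_step, step_mask.tolist())

# 0 [[1, 1, 1, 1, 1]]
# 1 [[0, 1, 1, 1, 1]]
# 2 [[0, 0, 1, 1, 1]]
```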
