
Commit 0b469fb

include ar+diff option as a separate style
1 parent 632dc7c commit 0b469fb

File tree: 5 files changed (+124 −65 lines)


fast_llm/data/data/gpt/data.py

Lines changed: 0 additions & 1 deletion

@@ -139,7 +139,6 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
     token_ids = torch.from_numpy(stacked_ids)

     if sampling_parameters.diffusion.style == DiffusionStyle.masked:
-
         diffusion_config = sampling_parameters.diffusion

         batch_size, seq_len = token_ids.shape

fast_llm/layers/language_model/config.py

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,8 @@ class LanguageModelKwargs:
     mask_indexes = "mask_indexes"
     mask_probabilities = "mask_probabilities"
     mask_inputs = "mask_inputs"
+    loss_weights = "loss_weights"
+    in_context = "in_context"


 @config_class()
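
The new kwargs above feed a branch keyed off a `DiffusionStyle` value whose definition is not part of this diff. Purely as a hypothetical sketch of the "separate style" named in the commit message (the actual Fast-LLM definition may differ), the option could be an extra enum member:

```python
# Hypothetical sketch only: how the diffusion style selector could look with the
# new ar+diff option added as a separate member. Not Fast-LLM's actual definition.
import enum


class DiffusionStyle(str, enum.Enum):
    masked = "masked"        # pure masked-diffusion training
    ar_masked = "ar_masked"  # autoregressive context + masked-diffusion continuation
```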

fast_llm/layers/language_model/head.py

Lines changed: 60 additions & 37 deletions

@@ -377,7 +377,7 @@ def __init__(
         prediction_distance: int,
     ):
         super().__init__(config, tensor_space, prediction_distance)
-        if config.transformer.diffusion is not None and config.transformer.diffusion == DiffusionStyle.masked:
+        if config.transformer.diffusion == DiffusionStyle.masked:
             self._loss_name = LanguageModelLossNames.mlm_loss

     def _logits_cross_entropy_forward_backward(
@@ -402,44 +402,67 @@ def _logits_cross_entropy_forward_backward(
             sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
         )

-        masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
-        p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
-        # index [0, 1, 2, 3, 4, 5] ->
-        # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
-        # embd  = [A, B, C, D, E]
-        # label = [B, C, D, E, F]
-
-        # Question (Pier): if token 2 is masked, the model needs to predict token 3; but token 3 is already
-        # seen by the model, so can it just learn to copy the next token into the masked position?
-        # Yes. We need to drop those positions from the loss if the next token is not masked.
-        # We could add corruption to strengthen this further, but the effect looks small in other CPT work (DiffuLLaMA).
-
-        last_weight = 0
-        B = logits.shape[0]
-
-        loss_weight = torch.cat(
-            (
-                # ar_weight * in_context[:, 1:] +  # not implemented yet
-                masked_indices[:, 1:] / p_mask[:, None],
-                # + un_weight * ((1 - epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
-                (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
-                # This may need some weighting in terms of masking. Use last_weight=0 for now. TODO: decide later.
-            ),
-            dim=1,
-        ).to(logits.dtype)
+        if self.config.transformer.diffusion == DiffusionStyle.masked:
+            masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
+            p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
+            # index [0, 1, 2, 3, 4, 5] ->
+            # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
+            # embd  = [A, B, C, D, E]
+            # label = [B, C, D, E, F]
+
+            # Question (Pier): if token 2 is masked, the model needs to predict token 3; but token 3 is already
+            # seen by the model, so can it just learn to copy the next token into the masked position?
+            # Yes. We need to drop those positions from the loss if the next token is not masked.
+            # We could add corruption to strengthen this further, but the effect looks small in other CPT work (DiffuLLaMA).
+
+            last_weight = 0
+            B = logits.shape[0]
+
+            loss_weight = torch.cat(
+                (
+                    # ar_weight * in_context[:, 1:] +  # not implemented yet
+                    masked_indices[:, 1:] / p_mask[:, None],
+                    # + un_weight * ((1 - epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
+                    (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
+                    # This may need some weighting in terms of masking. Use last_weight=0 for now. TODO: decide later.
+                ),
+                dim=1,
+            ).to(logits.dtype)
+
+            # print(f"Loss weight: {loss_weight}")

-        # print(f"Loss weight: {loss_weight}")
+            loss, grad = cross_entropy_forward_backward(
+                logits=logits.flatten(0, -2),
+                target=target,
+                loss_mask=None,
+                grad_output=grad_output,
+                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                implementation=self._cross_entropy_impl,
+                logits_scale_factor=self._logits_scale_factor,
+                loss_weight=loss_weight,
+            )

-        loss, grad = cross_entropy_forward_backward(
-            logits=logits.flatten(0, -2),
-            target=target,
-            loss_mask=None,
-            grad_output=grad_output,
-            group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
-            implementation=self._cross_entropy_impl,
-            logits_scale_factor=self._logits_scale_factor,
-            loss_weight=loss_weight,
-        )
+        elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
+
+            loss_weights = kwargs[LanguageModelKwargs.loss_weights]
+            context_index = kwargs[LanguageModelKwargs.in_context]
+            masked_index = kwargs[LanguageModelKwargs.mask_indexes]
+            B = loss_weights.shape[0]
+            masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+            context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+
+            loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward(
+                logits.flatten(0, -2),
+                target=target,
+                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                grad_output=grad_output,
+                implementation=self._cross_entropy_impl,
+                logits_scale_factor=self._logits_scale_factor,
+                loss_weight=loss_weights,
+            )
+
+            losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean())
+            losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean())

         # This happens with the loss_weight.
         # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274
fast_llm/layers/transformer/attention.py

Lines changed: 26 additions & 0 deletions

@@ -389,15 +389,41 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[
                 softmax_scale=self._softmax_scale,
             )
             input_ = input_.flatten(-2)
+
         else:
             # TODO: Avoid the flattens.
+
             input_ = self._attn_fused(
                 query.flatten(-2),
                 key.flatten(-2),
                 value.flatten(-2),
                 kwargs[TransformerKwargs.attention_mask],
                 kwargs[TransformerKwargs.attention_mask_value],
             )
+            # print(f"Fused: Attention: {input_.shape} {input_} ")
+
+            flash_input_ = _flash_attn_func(
+                query,
+                key,
+                value,
+                window_size=(-1, -1) if window_size is None else (window_size - 1, 0),
+                dropout_p=self._config.attention_dropout if self.training else 0.0,
+                causal=False,
+                softmax_scale=self._softmax_scale,
+            )
+            # print(f"1: Flash : Attention: {flash_input_.shape} {flash_input_} ")
+            flash_input_ = flash_input_.flatten(-2)
+            # print(f"2: Flash: Attention: {flash_input_.shape} {flash_input_} ")
+            diff = input_ - flash_input_
+            # print(f"Element-wise difference: {diff.shape} {diff}")
+            max_diff = diff.abs().max()
+            min_diff = diff.abs().min()
+            print(f"Min element-wise difference: {min_diff.item()}")
+            print(f"Max element-wise difference: {max_diff.item()}")
+            # if max_diff > 1e-3:
+            #     print("Warning: Max difference exceeds 1e-3")
+            #     import sys
+            #     sys.exit(1)

         if self._debug_transformer:
             self._debug_log(query, "query", self._QUERY_DIMS, kwargs)
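
The debug code added above runs the same query/key/value through the fused path and `_flash_attn_func` and prints the element-wise difference. The same cross-check pattern can be reproduced with stock PyTorch (no flash-attn install) by comparing explicit attention against `torch.nn.functional.scaled_dot_product_attention`; this is a generic sketch, not Fast-LLM code:

```python
# Generic sketch: run attention two ways on the same inputs and report the
# element-wise difference, mirroring the fused-vs-flash check in the diff above.
import torch
import torch.nn.functional as F

B, H, S, D = 2, 4, 16, 32
query = torch.randn(B, H, S, D)
key = torch.randn(B, H, S, D)
value = torch.randn(B, H, S, D)
scale = D ** -0.5

# Reference: explicit non-causal attention (analogous to the fused path).
scores = torch.matmul(query, key.transpose(-2, -1)) * scale
reference = torch.softmax(scores, dim=-1) @ value

# Optimized kernel (analogous to the flash path: causal=False, no dropout).
optimized = F.scaled_dot_product_attention(query, key, value, is_causal=False)

diff = (reference - optimized).abs()
print(f"Min element-wise difference: {diff.min().item()}")
print(f"Max element-wise difference: {diff.max().item()}")
assert torch.allclose(reference, optimized, atol=1e-5)
```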

fast_llm/models/gpt/model.py

Lines changed: 36 additions & 27 deletions

@@ -351,7 +351,8 @@ def preprocess(
         )
         # Set up bidirectional attention for masked diffusion.
         # It uses _flash_attn_func, so there is no need to set attention_mask and attention_mask_value.
-        # kwargs[TransformerKwargs.causal] = False
+        kwargs[TransformerKwargs.causal] = False
+
         batch_size, seq_len = batch.token_ids.shape
         seq_len -= 1  # last token is dropped from the inputs
         attention_mask = torch.ones(
@@ -395,40 +396,48 @@
         # seq_len -= 1  # last token is dropped from the input
         # # Compute attention mask for diffusion
         C = batch.in_context_length.to(device=self._tensor_space.distributed.device)
-        # row_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(1, seq_len, 1)
-        # col_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(1, 1, seq_len)
-        # C_exp = C.view(batch_size, 1, 1)
+        row_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(
+            1, seq_len, 1
+        )
+        col_idx = torch.arange(seq_len, device=self._tensor_space.distributed.device).view(
+            1, 1, seq_len
+        )
+        C_exp = C.view(batch_size, 1, 1)

-        # causal_mask = col_idx <= row_idx
-        # row_idx < C_exp
-        # col_idx < C_exp
+        causal_mask = col_idx <= row_idx
+        row_idx < C_exp
+        col_idx < C_exp

-        # attn_mask = torch.zeros(
-        #     batch_size, seq_len, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device
-        # )
+        attn_mask = torch.zeros(
+            batch_size,
+            seq_len,
+            seq_len,
+            dtype=torch.bool,
+            device=self._tensor_space.distributed.device,
+        )

-        # for b in range(batch_size):
-        #     C_val = C[b].item()
+        for b in range(batch_size):
+            C_val = C[b].item()

-        #     if C_val > 0:
-        #         context_causal = causal_mask[0, :C_val, :C_val]
-        #         attn_mask[b, :C_val, :C_val] = context_causal
+            if C_val > 0:
+                context_causal = causal_mask[0, :C_val, :C_val]
+                attn_mask[b, :C_val, :C_val] = context_causal

-        #     if C_val > 0 and C_val < seq_len:
-        #         attn_mask[b, C_val:, :C_val] = True
+            if C_val > 0 and C_val < seq_len:
+                attn_mask[b, C_val:, :C_val] = True

-        #     if C_val < seq_len:
-        #         attn_mask[b, C_val:, C_val:] = True
+            if C_val < seq_len:
+                attn_mask[b, C_val:, C_val:] = True

         # Handle padding if needed
-        # if batch.sequence_lengths is not None:
-        #     padded = torch.zeros(
-        #         batch_size, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device
-        #     )
-        #     for b in range(batch_size):
-        #         padded[b, batch.sequence_lengths[b] :] = True
-        #     not_padded = ~padded[:, 1:]
-        #     attn_mask = attn_mask & not_padded.unsqueeze(1) & not_padded.unsqueeze(2)
+        if batch.sequence_lengths is not None:
+            padded = torch.zeros(
+                batch_size, seq_len, dtype=torch.bool, device=self._tensor_space.distributed.device
+            )
+            for b in range(batch_size):
+                padded[b, batch.sequence_lengths[b] :] = True
+            not_padded = ~padded[:, 1:]
+            attn_mask = attn_mask & not_padded.unsqueeze(1) & not_padded.unsqueeze(2)

         # Reshape to match expected attention mask format
         attention_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # Add additional dimension
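
For reference, the uncommented loop above builds, per sample, a mask that is causal over the first `C` in-context tokens, lets the remaining (diffusion) tokens attend to the whole context, and is bidirectional inside the diffusion block. Below is a vectorized sketch of the same mask logic; the function name, shapes, and the boolean convention (True = attention allowed) are assumptions for the example:

```python
# Sketch of the AR + diffusion attention mask: causal inside the first C tokens
# (the autoregressive context), bidirectional inside the remaining tokens, and
# the diffusion block may attend back to the whole context. True = allowed.
import torch


def ar_diffusion_mask(context_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Return a (batch_size, seq_len, seq_len) boolean attention mask."""
    batch_size = context_lengths.shape[0]
    row_idx = torch.arange(seq_len).view(1, seq_len, 1)
    col_idx = torch.arange(seq_len).view(1, 1, seq_len)
    c = context_lengths.view(batch_size, 1, 1)

    in_context_row = row_idx < c
    in_context_col = col_idx < c
    causal = col_idx <= row_idx

    context_block = in_context_row & in_context_col & causal  # causal within the context
    context_visible = ~in_context_row & in_context_col        # diffusion tokens see the context
    diffusion_block = ~in_context_row & ~in_context_col       # bidirectional among diffusion tokens
    return context_block | context_visible | diffusion_block


mask = ar_diffusion_mask(torch.tensor([3, 5]), seq_len=8)
print(mask[0].int())  # rows 0-2: causal; rows 3-7: attend everywhere
```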
