@@ -377,7 +377,7 @@ def __init__(
         prediction_distance: int,
     ):
         super().__init__(config, tensor_space, prediction_distance)
-        if config.transformer.diffusion is not None and config.transformer.diffusion == DiffusionStyle.masked:
+        if config.transformer.diffusion == DiffusionStyle.masked:
             self._loss_name = LanguageModelLossNames.mlm_loss
 
     def _logits_cross_entropy_forward_backward(
@@ -402,44 +402,67 @@ def _logits_cross_entropy_forward_backward(
             sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
         )
 
-        masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
-        p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
-        # index [0, 1, 2, 3, 4, 5] ->
-        # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
-        # embd  = [A, B, C, D, E]
-        # label = [B, C, D, E, F]
-
-        # Question (Pier): if 2 is the masked token, the model still needs to predict 3; but 3 is already seen by the model,
-        # so can it just learn to copy 3, i.e. copy the next token to the masked position?
-        # Yes. We need to drop those positions from the loss if the next token is not masked.
-        # We could include corruption to further enhance this, but it seems not too big an issue looking at other CPT (diffuLlama).
-
-        last_weight = 0
-        B = logits.shape[0]
-
-        loss_weight = torch.cat(
-            (
-                # ar_weight * in_context[:, 1:] +  # not implemented yet
-                masked_indices[:, 1:] / p_mask[:, None],
-                # + un_weight * ((1 - epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
-                (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
-                # This may need some weighting in terms of masking. Let's use last_weight=0 for now. TODO: Decide later
-            ),
-            dim=1,
-        ).to(logits.dtype)
+        if self.config.transformer.diffusion == DiffusionStyle.masked:
+            masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
+            p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
+            # index [0, 1, 2, 3, 4, 5] ->
+            # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
+            # embd  = [A, B, C, D, E]
+            # label = [B, C, D, E, F]
+
+            # Question (Pier): if 2 is the masked token, the model still needs to predict 3; but 3 is already seen by the model,
+            # so can it just learn to copy 3, i.e. copy the next token to the masked position?
+            # Yes. We need to drop those positions from the loss if the next token is not masked.
+            # We could include corruption to further enhance this, but it seems not too big an issue looking at other CPT (diffuLlama).
+
+            last_weight = 0
+            B = logits.shape[0]
+
+            loss_weight = torch.cat(
+                (
+                    # ar_weight * in_context[:, 1:] +  # not implemented yet
+                    masked_indices[:, 1:] / p_mask[:, None],
+                    # + un_weight * ((1 - epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
+                    (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
+                    # This may need some weighting in terms of masking. Let's use last_weight=0 for now. TODO: Decide later
+                ),
+                dim=1,
+            ).to(logits.dtype)
+
+            # print(f"Loss weight: {loss_weight}")
 
-        # print(f"Loss weight: {loss_weight}")
+            loss, grad = cross_entropy_forward_backward(
+                logits=logits.flatten(0, -2),
+                target=target,
+                loss_mask=None,
+                grad_output=grad_output,
+                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                implementation=self._cross_entropy_impl,
+                logits_scale_factor=self._logits_scale_factor,
+                loss_weight=loss_weight,
+            )
 
-        loss, grad = cross_entropy_forward_backward(
-            logits=logits.flatten(0, -2),
-            target=target,
-            loss_mask=None,
-            grad_output=grad_output,
-            group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
-            implementation=self._cross_entropy_impl,
-            logits_scale_factor=self._logits_scale_factor,
-            loss_weight=loss_weight,
-        )
+        elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
+
+            loss_weights = kwargs[LanguageModelKwargs.loss_weights]
+            context_index = kwargs[LanguageModelKwargs.in_context]
+            masked_index = kwargs[LanguageModelKwargs.mask_indexes]
+            B = loss_weights.shape[0]
+            masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+            context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+
+            loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward(
+                logits.flatten(0, -2),
+                target=target,
+                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                grad_output=grad_output,
+                implementation=self._cross_entropy_impl,
+                logits_scale_factor=self._logits_scale_factor,
+                loss_weight=loss_weights,
+            )
+
+            losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean())
+            losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean())
 
         # This happens with the loss_weight.
         # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274
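
For reference, here is a minimal, self-contained sketch of the masked-diffusion weighting added above, using toy tensors in place of the real kwargs (the shapes and values are made up for illustration; only the one-position shift and the 1/p_mask scaling mirror the diff):

# Illustrative sketch only (not part of the PR): assembles the per-position loss
# weight for masked-diffusion training on a toy batch.
import torch

B, S, V = 2, 6, 8                                   # batch, sequence length, toy vocab size
logits = torch.randn(B, S, V)

masked_indices = (torch.rand(B, S) < 0.5).float()   # 1.0 where the input token was masked
p_mask = torch.full((B,), 0.5)                      # per-sequence masking probability

last_weight = 0  # the final position has no next-token label, so it contributes nothing
loss_weight = torch.cat(
    (
        # Shift by one: position i predicts token i+1, so it only gets weight if that
        # *target* token was masked (otherwise the model could simply copy the visible
        # next token). Dividing by p_mask keeps the loss scale comparable across
        # different masking rates.
        masked_indices[:, 1:] / p_mask[:, None],
        (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
    ),
    dim=1,
).to(logits.dtype)

assert loss_weight.shape == (B, S)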