diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index 934835c0f..8ecc7df49 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -39,6 +39,8 @@
     "pipeline_module_split",
 ]
 
+lib = torch.library.Library("aten", "IMPL")
+
 
 def _override_torch_ops_for_zero_bubble():
     class MmSeparateWeightGrad(torch.autograd.Function):
@@ -142,10 +144,10 @@ def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
         bias_1 = AddmmSeparateBiasGrad.apply(bias, beta)
         return AddmmPassThrough.apply(bias_1, mat1_1, mat2_1, beta, alpha)
 
-    # _fused_rms_norm operator: RMS normalization
-    class FusedRmsNormSeparateWeightGrad(torch.autograd.Function):
+    # rms_norm operator: RMS normalization
+    class RmsNormSeparateWeightGrad(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, weight, normalized_shape, eps):
+        def forward(ctx, input, normalized_shape, weight, eps):
             ctx.save_for_backward(input)
             ctx.normalized_shape = normalized_shape
             ctx.eps = eps
@@ -155,6 +157,8 @@ def forward(ctx, input, weight, normalized_shape, eps):
         def backward(ctx, grad_output):
             (input,) = ctx.saved_tensors
             # Compute normalized input for weight gradient
+            if grad_output is None:
+                return None, None, None, None
             variance = input.pow(2).mean(-1, keepdim=True)
             rstd = torch.rsqrt(variance + ctx.eps)
             normalized = input * rstd
@@ -162,66 +166,71 @@ def backward(ctx, grad_output):
             grad_weight = (grad_output * normalized).sum(
                 dim=tuple(range(grad_output.ndim - 1))
             )
-            return None, grad_weight, None, None
+            return None, None, grad_weight, None
 
-    class FusedRmsNormSeparateInputGrad(torch.autograd.Function):
+    class RmsNormSeparateInputGrad(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, weight, normalized_shape, eps):
-            ctx.save_for_backward(weight)
+        def forward(ctx, input, normalized_shape, weight, eps):
+            (
+                ctx.save_for_backward(weight)
+                if weight is not None
+                else ctx.save_for_backward()
+            )
             ctx.normalized_shape = normalized_shape
             ctx.eps = eps
             return input
 
         @staticmethod
         def backward(ctx, grad_output):
-            (weight,) = ctx.saved_tensors
             # This is a placeholder - the actual gradient computation happens in PassThrough
             # Here we just pass through the grad_output weighted by weight
             return grad_output, None, None, None
 
-    class FusedRmsNormPassThrough(torch.autograd.Function):
+    class RmsNormPassThrough(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, weight, normalized_shape, eps):
+        def forward(ctx, input, normalized_shape, weight, eps):
             with torch._C._AutoDispatchBelowAutograd():
-                return torch.ops.aten._fused_rms_norm(
-                    input, weight, normalized_shape, eps
-                )
+                return torch.rms_norm(input, normalized_shape, weight, eps)
 
         @staticmethod
         def backward(ctx, gO):
-            return gO, gO, None, None
+            return gO, None, gO, None
 
-    def split_fused_rms_norm(input, weight, normalized_shape, eps):
-        print("split fused_rms_norm")
-        weight_1 = FusedRmsNormSeparateWeightGrad.apply(
-            input.detach(), weight, normalized_shape, eps
+    def split_rms_norm(input, normalized_shape, weight=None, eps=None):
+        print("split rms_norm")
+        weight_1 = RmsNormSeparateWeightGrad.apply(
+            input.detach(), normalized_shape, weight, eps
         )
-        input_1 = FusedRmsNormSeparateInputGrad.apply(
-            input, weight.detach(), normalized_shape, eps
+        input_1 = RmsNormSeparateInputGrad.apply(
+            input,
+            normalized_shape,
+            weight.detach() if weight is not None else None,
+            eps,
         )
-        return FusedRmsNormPassThrough.apply(input_1, weight_1, normalized_shape, eps)
+        return RmsNormPassThrough.apply(input_1, normalized_shape, weight_1, eps)
 
     # _grouped_mm operator: Grouped matrix multiplication for MoE
     class GroupedMmSeparateMat2Grad(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, mat2):
+        def forward(ctx, input, mat2, offs, bias, out_dtype):
             ctx.save_for_backward(input)
+            ctx.offs = offs
             return mat2
 
         @staticmethod
         def backward(ctx, grad_output):
             (input,) = ctx.saved_tensors
             # Gradient w.r.t. mat2 for grouped mm
-            # This is simplified - actual implementation may need group-wise computation
             grad_mat2 = torch.ops.aten._grouped_mm.default(
-                input.transpose(-1, -2), grad_output, reduce="sum"
+                input.transpose(-1, -2), grad_output, offs=ctx.offs
             )
-            return None, grad_mat2
+            return None, grad_mat2, None, None, None
 
     class GroupedMmSeparateInputGrad(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, mat2):
+        def forward(ctx, input, mat2, offs, bias, out_dtype):
             ctx.save_for_backward(mat2)
+            ctx.offs = offs
             return input
 
         @staticmethod
@@ -229,30 +238,35 @@ def backward(ctx, grad_output):
             (mat2,) = ctx.saved_tensors
             # Gradient w.r.t. input for grouped mm
             grad_input = torch.ops.aten._grouped_mm.default(
-                grad_output, mat2.transpose(-1, -2), reduce="sum"
+                grad_output, mat2.transpose(-1, -2), offs=ctx.offs
             )
-            return grad_input, None
+            return grad_input, None, None, None, None
 
     class GroupedMmPassThrough(torch.autograd.Function):
         @staticmethod
-        def forward(ctx, input, mat2, reduce="sum"):
+        def forward(ctx, input, mat2, offs, bias, out_dtype):
             with torch._C._AutoDispatchBelowAutograd():
-                return torch.ops.aten._grouped_mm.default(input, mat2, reduce=reduce)
+                return torch.ops.aten._grouped_mm.default(
+                    input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
+                )
 
         @staticmethod
         def backward(ctx, gO):
-            return gO, gO, None
+            return gO, gO, None, None, None
 
-    def split_grouped_mm(input, mat2, reduce="sum"):
+    def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
         print("split grouped_mm")
-        mat2_1 = GroupedMmSeparateMat2Grad.apply(input.detach(), mat2)
-        input_1 = GroupedMmSeparateInputGrad.apply(input, mat2.detach())
-        return GroupedMmPassThrough.apply(input_1, mat2_1, reduce)
+        mat2_1 = GroupedMmSeparateMat2Grad.apply(
+            input.detach(), mat2, offs, bias, out_dtype
+        )
+        input_1 = GroupedMmSeparateInputGrad.apply(
+            input, mat2.detach(), offs, bias, out_dtype
+        )
+        return GroupedMmPassThrough.apply(input_1, mat2_1, offs, bias, out_dtype)
 
-    lib = torch.library.Library("aten", "IMPL")
     lib.impl("mm", split_mm, "Autograd")
     lib.impl("addmm", split_addmm, "Autograd")
-    lib.impl("_fused_rms_norm", split_fused_rms_norm, "Autograd")
+    lib.impl("rms_norm", split_rms_norm, "Autograd")
     lib.impl("_grouped_mm", split_grouped_mm, "Autograd")
 
 
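A minimal sketch of how these Autograd-key overrides are exercised, not part of the diff. It assumes split_mm (defined in the unchanged portion of this file) follows the same SeparateWeightGrad / SeparateInputGrad / PassThrough pattern as the ops above, and that the patched module is importable as torchtitan.distributed.pipeline_parallel. Once _override_torch_ops_for_zero_bubble() registers the impls on lib, a plain torch.mm yields separate autograd nodes for the input gradient (dX) and the weight gradient (dW), which is what lets a zero-bubble pipeline schedule run the dW step independently of the dX step.

import torch

from torchtitan.distributed.pipeline_parallel import (
    _override_torch_ops_for_zero_bubble,
)

# Registers split_mm / split_addmm / split_rms_norm / split_grouped_mm
# on the "Autograd" dispatch key of the module-level `lib`.
_override_torch_ops_for_zero_bubble()

x = torch.randn(4, 8, requires_grad=True)   # activation
w = torch.randn(8, 16, requires_grad=True)  # weight

# Dispatches to split_mm instead of the stock autograd kernel, building one
# node that will produce dX, one that will produce dW, and a pass-through
# that runs the real mm below the Autograd key.
out = torch.mm(x, w)
out.sum().backward()

# Both gradients are still produced, just by separate graph nodes, so a
# zero-bubble schedule can delay the dW node relative to the dX node.
assert x.grad is not None and w.grad is not None
print("dX:", x.grad.shape, "dW:", w.grad.shape)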