diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index fe676f1781..934835c0fc 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -68,7 +68,8 @@ def backward(ctx, grad_output): class MmPassThrough(torch.autograd.Function): @staticmethod def forward(ctx, x, y): - return torch.mm(x, y) + with torch._C._AutoDispatchBelowAutograd(): + return torch.mm(x, y) @staticmethod def backward(ctx, gO): @@ -83,8 +84,176 @@ def split_mm(i, w): i1 = MmSeparateInputGrad.apply(i, w.detach()) return MmPassThrough.apply(i1, w1) + # addmm operator: out = beta * input + alpha * (mat1 @ mat2) + class AddmmSeparateMat2Grad(torch.autograd.Function): + @staticmethod + def forward(ctx, mat1, mat2, alpha): + ctx.save_for_backward(mat1) + ctx.alpha = alpha + return mat2 + + @staticmethod + def backward(ctx, grad_output): + (mat1,) = ctx.saved_tensors + # Gradient w.r.t. mat2: alpha * mat1.T @ grad_output + grad_mat2 = mat1.t().mm(grad_output) * ctx.alpha + return None, grad_mat2, None + + class AddmmSeparateMat1Grad(torch.autograd.Function): + @staticmethod + def forward(ctx, mat1, mat2, alpha): + ctx.save_for_backward(mat2) + ctx.alpha = alpha + return mat1 + + @staticmethod + def backward(ctx, grad_output): + (mat2,) = ctx.saved_tensors + # Gradient w.r.t. mat1: alpha * grad_output @ mat2.T + grad_mat1 = grad_output.mm(mat2.t()) * ctx.alpha + return grad_mat1, None, None + + class AddmmSeparateBiasGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, bias, beta): + ctx.beta = beta + return bias + + @staticmethod + def backward(ctx, grad_output): + # Gradient w.r.t. bias: beta * sum(grad_output, dim=0) + grad_bias = grad_output.sum(dim=0) * ctx.beta + return grad_bias, None + + class AddmmPassThrough(torch.autograd.Function): + @staticmethod + def forward(ctx, bias, mat1, mat2, beta, alpha): + with torch._C._AutoDispatchBelowAutograd(): + return torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha) + + @staticmethod + def backward(ctx, gO): + return gO, gO, gO, None, None + + def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1): + print("split addmm") + mat2_1 = AddmmSeparateMat2Grad.apply(mat1.detach(), mat2, alpha) + mat1_1 = AddmmSeparateMat1Grad.apply(mat1, mat2.detach(), alpha) + bias_1 = AddmmSeparateBiasGrad.apply(bias, beta) + return AddmmPassThrough.apply(bias_1, mat1_1, mat2_1, beta, alpha) + + # _fused_rms_norm operator: RMS normalization + class FusedRmsNormSeparateWeightGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps): + ctx.save_for_backward(input) + ctx.normalized_shape = normalized_shape + ctx.eps = eps + return weight + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + # Compute normalized input for weight gradient + variance = input.pow(2).mean(-1, keepdim=True) + rstd = torch.rsqrt(variance + ctx.eps) + normalized = input * rstd + # Gradient w.r.t. weight: sum over batch dimension + grad_weight = (grad_output * normalized).sum( + dim=tuple(range(grad_output.ndim - 1)) + ) + return None, grad_weight, None, None + + class FusedRmsNormSeparateInputGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps): + ctx.save_for_backward(weight) + ctx.normalized_shape = normalized_shape + ctx.eps = eps + return input + + @staticmethod + def backward(ctx, grad_output): + (weight,) = ctx.saved_tensors + # This is a placeholder - the actual gradient computation happens in PassThrough + # Here we just pass through the grad_output weighted by weight + return grad_output, None, None, None + + class FusedRmsNormPassThrough(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps): + with torch._C._AutoDispatchBelowAutograd(): + return torch.ops.aten._fused_rms_norm( + input, weight, normalized_shape, eps + ) + + @staticmethod + def backward(ctx, gO): + return gO, gO, None, None + + def split_fused_rms_norm(input, weight, normalized_shape, eps): + print("split fused_rms_norm") + weight_1 = FusedRmsNormSeparateWeightGrad.apply( + input.detach(), weight, normalized_shape, eps + ) + input_1 = FusedRmsNormSeparateInputGrad.apply( + input, weight.detach(), normalized_shape, eps + ) + return FusedRmsNormPassThrough.apply(input_1, weight_1, normalized_shape, eps) + + # _grouped_mm operator: Grouped matrix multiplication for MoE + class GroupedMmSeparateMat2Grad(torch.autograd.Function): + @staticmethod + def forward(ctx, input, mat2): + ctx.save_for_backward(input) + return mat2 + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + # Gradient w.r.t. mat2 for grouped mm + # This is simplified - actual implementation may need group-wise computation + grad_mat2 = torch.ops.aten._grouped_mm.default( + input.transpose(-1, -2), grad_output, reduce="sum" + ) + return None, grad_mat2 + + class GroupedMmSeparateInputGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, input, mat2): + ctx.save_for_backward(mat2) + return input + + @staticmethod + def backward(ctx, grad_output): + (mat2,) = ctx.saved_tensors + # Gradient w.r.t. input for grouped mm + grad_input = torch.ops.aten._grouped_mm.default( + grad_output, mat2.transpose(-1, -2), reduce="sum" + ) + return grad_input, None + + class GroupedMmPassThrough(torch.autograd.Function): + @staticmethod + def forward(ctx, input, mat2, reduce="sum"): + with torch._C._AutoDispatchBelowAutograd(): + return torch.ops.aten._grouped_mm.default(input, mat2, reduce=reduce) + + @staticmethod + def backward(ctx, gO): + return gO, gO, None + + def split_grouped_mm(input, mat2, reduce="sum"): + print("split grouped_mm") + mat2_1 = GroupedMmSeparateMat2Grad.apply(input.detach(), mat2) + input_1 = GroupedMmSeparateInputGrad.apply(input, mat2.detach()) + return GroupedMmPassThrough.apply(input_1, mat2_1, reduce) + lib = torch.library.Library("aten", "IMPL") lib.impl("mm", split_mm, "Autograd") + lib.impl("addmm", split_addmm, "Autograd") + lib.impl("_fused_rms_norm", split_fused_rms_norm, "Autograd") + lib.impl("_grouped_mm", split_grouped_mm, "Autograd") def pipeline_llm(