29 changes: 25 additions & 4 deletions torchtitan/distributed/pipeline_parallel.py
@@ -64,6 +64,29 @@ def forward(ctx, i, w):
@staticmethod
def backward(ctx, grad_output):
(w,) = ctx.saved_tensors
"""
A[m,k] @ B[k,n] -> O[m,n]
grad_o[m,n] @ B.t()[n,k] -> grad_a[m,k]
This looks right, but I'm getting:
[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 67, in backward
[rank4]:[rank4]: grad_input = grad_output.mm(w.t())
[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 88, in split_mm
[rank4]:[rank4]: return MmPassThrough.apply(i1, w1)
[rank4]:[rank4]: File "/data/users/whc/pytorch/torch/autograd/function.py", line 583, in apply
[rank4]:[rank4]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 74, in forward
[rank4]:[rank4]: return torch.mm(x, y)
[rank4]:[rank4]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x2816 and 2048x2816)

[rank4]:[rank4]: RuntimeError:
[rank4]:[rank4]: Failed to run stage backward:
[rank4]:[rank4]: Stage output: ('Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)',)
[rank4]:[rank4]: Output gradient: ('Tensor(torch.Size([1, 4096, 2048]), grad=False, dtype=torch.bfloat16)',)
[rank4]:[rank4]: Input: ['Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)']
[rank4]:[rank4]:
"""
logger.error(f"MmSeparateInputGrad backward: {grad_output.shape=}, {w.t().shape=}")
grad_input = grad_output.mm(w.t())
return grad_input, None
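
The shape algebra in the docstring above can be sanity-checked in isolation. A minimal sketch in plain PyTorch, with m, k, n read off the trace (4096, 2048, 2816); none of this is the PR's code:

```python
import torch

m, k, n = 4096, 2048, 2816          # sizes taken from the trace above

A = torch.randn(m, k)               # input
B = torch.randn(k, n)               # weight
O = A @ B                           # [4096, 2816]
grad_o = torch.randn_like(O)        # same shape as O

grad_a = grad_o @ B.t()             # [4096, 2816] @ [2816, 2048] -> [4096, 2048]
assert grad_a.shape == (m, k)
```

So the math in the note holds. The failing `torch.mm` in the trace is fed [4096, 2816] and [2048, 2816], i.e. the second operand has B's un-transposed shape, which suggests the `.t()` gets dropped somewhere between `grad_output.mm(w.t())` and the re-intercepted `torch.mm(x, y)`.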

@@ -75,13 +98,14 @@ def forward(ctx, x, y):

@staticmethod
def backward(ctx, gO):
# TODO(whc): Claude first wrote it this way and later tried returning None, None; I'm not sure which is correct.
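# Note: returning (None, None) would cut the backward edges to i1 and w1, so
# neither MmSeparateInputGrad nor MmSeparateWeightGrad backward would ever run
# and no grads would flow; (gO, gO) keeps both branches reachable (assuming the
# grad shapes line up with whatever those nodes' forwards return).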
return gO, gO

def split_mm(i, w):
print("split mul")
# Apply the pass-through node. y is passed to this node so that it can be
# saved for backward, but detach because we don't want to actually build
# this edge of the graph
logger.error(f"split_mm forward: {i.shape=}, {w.shape=}")
w1 = MmSeparateWeightGrad.apply(i.detach(), w)
i1 = MmSeparateInputGrad.apply(i, w.detach())
return MmPassThrough.apply(i1, w1)
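
For reference, here is a self-contained sketch of the same split-backward idea that is shape-consistent end to end: the matmul output is computed once with no grad edges, and two identity-like autograd nodes are attached to it, one owning dL/dinput and one owning dL/dweight. This is not the wiring used in this diff, and the `AttachInputGrad` / `AttachWeightGrad` / `split_mm_sketch` names are made up for illustration:

```python
import torch
from torch.autograd import Function


class AttachInputGrad(Function):
    """Identity on `out`; its backward contributes only the input gradient."""

    @staticmethod
    def forward(ctx, out, x, w):
        ctx.save_for_backward(w)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (w,) = ctx.saved_tensors
        return grad_out, grad_out.mm(w.t()), None


class AttachWeightGrad(Function):
    """Identity on `out`; its backward contributes only the weight gradient."""

    @staticmethod
    def forward(ctx, out, x, w):
        ctx.save_for_backward(x)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out, None, x.t().mm(grad_out)


def split_mm_sketch(x, w):
    out = torch.mm(x.detach(), w.detach())            # compute once, no grad edges
    out = AttachInputGrad.apply(out, x, w.detach())   # dL/dx lives on this node
    out = AttachWeightGrad.apply(out, x.detach(), w)  # dL/dw lives on this node
    return out


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 3, requires_grad=True)
split_mm_sketch(x, w).sum().backward()
ref_x, ref_w = torch.autograd.grad(torch.mm(x, w).sum(), (x, w))
assert torch.allclose(x.grad, ref_x) and torch.allclose(w.grad, ref_w)
```

Because each gradient lives on its own autograd node, a pipeline schedule can in principle run the input-grad node early (to unblock the previous stage) and defer the weight-grad node, which appears to be the motivation for splitting mm in this PR.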
@@ -138,7 +162,6 @@ def backward(ctx, gO):
return gO, gO, gO, None, None

def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
print("split addmm")
mat2_1 = AddmmSeparateMat2Grad.apply(mat1.detach(), mat2, alpha)
mat1_1 = AddmmSeparateMat1Grad.apply(mat1, mat2.detach(), alpha)
bias_1 = AddmmSeparateBiasGrad.apply(bias, beta)
@@ -197,7 +220,6 @@ def backward(ctx, gO):
return gO, None, gO, None

def split_rms_norm(input, normalized_shape, weight=None, eps=None):
print("split rms_norm")
weight_1 = RmsNormSeparateWeightGrad.apply(
input.detach(), normalized_shape, weight, eps
)
@@ -255,7 +277,6 @@ def backward(ctx, gO):
return gO, gO, None, None, None

def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
print("split grouped_mm")
mat2_1 = GroupedMmSeparateMat2Grad.apply(
input.detach(), mat2, offs, bias, out_dtype
)
@@ -61,7 +61,7 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
mode = "none" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]