Skip to content

Commit 74d3223

Browse files
committed
fix grad_out passing
and apply pytorch/pytorch#167002 to custom overlap_f_b for dualpipe. ghstack-source-id: 4d524cc. Pull Request resolved: #2036
1 parent 2f2ee95 commit 74d3223

File tree

2 files changed

+37
-40
lines changed

2 files changed

+37
-40
lines changed

torchtitan/distributed/dual_pipe_v.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,11 @@ def run_forward():
271271
arg_mbs[forward_mb_index],
272272
kwarg_mbs[forward_mb_index],
273273
)
274+
# TODO: it's error-prone to have this logic scattered inside and outside the runtime file.
275+
# this goes along with the patch to pytorch: https://github.com/pytorch/pytorch/pull/167002/
276+
key = f"{forward_stage.stage_index}_{forward_mb_index}"
277+
assert key not in schedule.ownership_tokens
278+
schedule.ownership_tokens[key] = output.view_as(output).grad_fn
274279
schedule._maybe_compute_loss(
275280
forward_stage, output, ctx.target_mbs, forward_mb_index
276281
)

torchtitan/distributed/pipeline_parallel.py

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch.nn as nn
1414
from torch.distributed.device_mesh import DeviceMesh
1515
from torch.distributed.pipelining import PipelineStage
16-
16+
from torch._subclasses.fake_tensor import FakeTensorMode
1717
from torch.distributed.pipelining.schedules import (
1818
_PipelineSchedule,
1919
_PipelineScheduleRuntime,
@@ -47,7 +47,8 @@ class MmSeparateWeightGrad(torch.autograd.Function):
4747
@staticmethod
4848
def forward(ctx, i, w):
4949
ctx.save_for_backward(i)
50-
return w
50+
grad_shape = (i.shape[0], w.shape[1])
51+
return torch.empty(grad_shape, dtype=w.dtype, device=w.device, requires_grad=True)
5152

5253
@staticmethod
5354
def backward(ctx, grad_output):
@@ -59,56 +60,34 @@ class MmSeparateInputGrad(torch.autograd.Function):
5960
@staticmethod
6061
def forward(ctx, i, w):
6162
ctx.save_for_backward(w)
62-
return i
63+
grad_shape = (i.shape[0], w.shape[1])
64+
return torch.empty(grad_shape, dtype=w.dtype, device=w.device, requires_grad=True)
6365

6466
@staticmethod
6567
def backward(ctx, grad_output):
6668
(w,) = ctx.saved_tensors
67-
"""
68-
A[m,k] @ B[k,n] -> O[m,n]
69-
grad_o[m,n] @ B.t()[n,k] -> grad_a[m,k]
70-
looks right..
71-
getting
72-
[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 67, in backward
73-
[rank4]:[rank4]: grad_input = grad_output.mm(w.t())
74-
[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 88, in split_mm
75-
[rank4]:[rank4]: return MmPassThrough.apply(i1, w1)
76-
[rank4]:[rank4]: File "/data/users/whc/pytorch/torch/autograd/function.py", line 583, in apply
77-
[rank4]:[rank4]: return super().apply(*args, **kwargs) # type: ignore[misc]
78-
[rank4]:[rank4]: File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 74, in forward
79-
[rank4]:[rank4]: return torch.mm(x, y)
80-
[rank4]:[rank4]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x2816 and 2048x2816)
81-
82-
[rank4]:[rank4]: RuntimeError:
83-
[rank4]:[rank4]: Failed to run stage backward:
84-
[rank4]:[rank4]: Stage output: ('Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)',)
85-
[rank4]:[rank4]: Output gradient: ('Tensor(torch.Size([1, 4096, 2048]), grad=False, dtype=torch.bfloat16)',)
86-
[rank4]:[rank4]: Input: ['Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)']
87-
[rank4]:[rank4]:
88-
"""
89-
logger.error(f"MmSeparateInputGrad backward: {grad_output.shape=}, {w.t().shape=}")
9069
grad_input = grad_output.mm(w.t())
9170
return grad_input, None
9271

9372
class MmPassThrough(torch.autograd.Function):
9473
@staticmethod
95-
def forward(ctx, x, y):
74+
def forward(ctx, x, y, fake_1, fake_2):
9675
with torch._C._AutoDispatchBelowAutograd():
9776
return torch.mm(x, y)
9877

9978
@staticmethod
10079
def backward(ctx, gO):
101-
# TODO(whc) - claude first wrote it this way and later tried to return None, None, i'm not sure which is correct
102-
return gO, gO
80+
return None, None, gO, gO
10381

10482
def split_mm(i, w):
10583
# Apply the pass-through node. y is passed to this node so that it can be
10684
# saved for backward, but detach because we don't want to actually build
10785
# this edge of the graph
108-
logger.error(f"split_mm forward: {i.shape=}, {w.shape=}")
109-
w1 = MmSeparateWeightGrad.apply(i.detach(), w)
110-
i1 = MmSeparateInputGrad.apply(i, w.detach())
111-
return MmPassThrough.apply(i1, w1)
86+
# logger.error(f"split_mm forward: {i.shape=}, {w.shape=}")
87+
fake_1 = MmSeparateWeightGrad.apply(i.detach(), w)
88+
fake_2 = MmSeparateInputGrad.apply(i, w.detach())
89+
90+
return MmPassThrough.apply(i.detach(), w.detach(), fake_1, fake_2)
11291

11392
# addmm operator: out = beta * input + alpha * (mat1 @ mat2)
11493
class AddmmSeparateMat2Grad(torch.autograd.Function):
@@ -237,7 +216,12 @@ class GroupedMmSeparateMat2Grad(torch.autograd.Function):
237216
def forward(ctx, input, mat2, offs, bias, out_dtype):
238217
ctx.save_for_backward(input)
239218
ctx.offs = offs
240-
return mat2
219+
with FakeTensorMode(allow_non_fake_inputs=True), torch._C._AutoDispatchBelowAutograd(), torch.no_grad():
220+
fake_output = torch.ops.aten._grouped_mm.default(
221+
input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
222+
)
223+
224+
return torch.empty((1,), dtype=input.dtype, device=input.device, requires_grad=True).expand_as(fake_output)
241225

242226
@staticmethod
243227
def backward(ctx, grad_output):
@@ -253,7 +237,15 @@ class GroupedMmSeparateInputGrad(torch.autograd.Function):
253237
def forward(ctx, input, mat2, offs, bias, out_dtype):
254238
ctx.save_for_backward(mat2)
255239
ctx.offs = offs
256-
return input
240+
with FakeTensorMode(allow_non_fake_inputs=True), torch._C._AutoDispatchBelowAutograd(), torch.no_grad():
241+
fake_output = torch.ops.aten._grouped_mm.default(
242+
input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
243+
)
244+
245+
# grad_shape = fake_output.shape
246+
# print(f"Shape: {fake_output.shape=}")
247+
return torch.empty((1,), dtype=input.dtype, device=input.device, requires_grad=True).expand_as(fake_output)
248+
# return torch.empty(grad_shape, dtype=input.dtype, device=input.device, requires_grad=True)
257249

258250
@staticmethod
259251
def backward(ctx, grad_output):
@@ -266,24 +258,24 @@ def backward(ctx, grad_output):
266258

267259
class GroupedMmPassThrough(torch.autograd.Function):
268260
@staticmethod
269-
def forward(ctx, input, mat2, offs, bias, out_dtype):
261+
def forward(ctx, input, mat2, offs, bias, out_dtype, fake_1, fake_2):
270262
with torch._C._AutoDispatchBelowAutograd():
271263
return torch.ops.aten._grouped_mm.default(
272264
input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
273265
)
274266

275267
@staticmethod
276268
def backward(ctx, gO):
277-
return gO, gO, None, None, None
269+
return None, None, None, None, None, gO, gO
278270

279271
def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
280-
mat2_1 = GroupedMmSeparateMat2Grad.apply(
272+
fake_1 = GroupedMmSeparateMat2Grad.apply(
281273
input.detach(), mat2, offs, bias, out_dtype
282274
)
283-
input_1 = GroupedMmSeparateInputGrad.apply(
275+
fake_2 = GroupedMmSeparateInputGrad.apply(
284276
input, mat2.detach(), offs, bias, out_dtype
285277
)
286-
return GroupedMmPassThrough.apply(input_1, mat2_1, offs, bias, out_dtype)
278+
return GroupedMmPassThrough.apply(input.detach(), mat2.detach(), offs, bias, out_dtype, fake_1, fake_2)
287279

288280
lib.impl("mm", split_mm, "Autograd")
289281
lib.impl("addmm", split_addmm, "Autograd")

0 commit comments

Comments
 (0)