 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
-
+from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed.pipelining.schedules import (
     _PipelineSchedule,
     _PipelineScheduleRuntime,
@@ -47,7 +47,8 @@ class MmSeparateWeightGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, i, w):
         ctx.save_for_backward(i)
-        return w
+        grad_shape = (i.shape[0], w.shape[1])
+        return torch.empty(grad_shape, dtype=w.dtype, device=w.device, requires_grad=True)

     @staticmethod
     def backward(ctx, grad_output):
@@ -59,56 +60,34 @@ class MmSeparateInputGrad(torch.autograd.Function):
     @staticmethod
     def forward(ctx, i, w):
         ctx.save_for_backward(w)
-        return i
+        grad_shape = (i.shape[0], w.shape[1])
+        return torch.empty(grad_shape, dtype=w.dtype, device=w.device, requires_grad=True)

     @staticmethod
     def backward(ctx, grad_output):
         (w,) = ctx.saved_tensors
-        """
-        A[m,k] @ B[k,n] -> O[m,n]
-        grad_o[m,n] @ B.t()[n,k] -> grad_a[m,k]
-        looks right..
-        getting
-        [rank4]:[rank4]:   File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 67, in backward
-        [rank4]:[rank4]:     grad_input = grad_output.mm(w.t())
-        [rank4]:[rank4]:   File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 88, in split_mm
-        [rank4]:[rank4]:     return MmPassThrough.apply(i1, w1)
-        [rank4]:[rank4]:   File "/data/users/whc/pytorch/torch/autograd/function.py", line 583, in apply
-        [rank4]:[rank4]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
-        [rank4]:[rank4]:   File "/data/users/whc/torchtitan/torchtitan/distributed/pipeline_parallel.py", line 74, in forward
-        [rank4]:[rank4]:     return torch.mm(x, y)
-        [rank4]:[rank4]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x2816 and 2048x2816)
-
-        [rank4]:[rank4]: RuntimeError:
-        [rank4]:[rank4]: Failed to run stage backward:
-        [rank4]:[rank4]: Stage output: ('Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)',)
-        [rank4]:[rank4]: Output gradient: ('Tensor(torch.Size([1, 4096, 2048]), grad=False, dtype=torch.bfloat16)',)
-        [rank4]:[rank4]: Input: ['Tensor(torch.Size([1, 4096, 2048]), grad=True, dtype=torch.bfloat16)']
-        [rank4]:[rank4]:
-        """
-        logger.error(f"MmSeparateInputGrad backward: {grad_output.shape=}, {w.t().shape=}")
         grad_input = grad_output.mm(w.t())
         return grad_input, None

 class MmPassThrough(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, y):
+    def forward(ctx, x, y, fake_1, fake_2):
         with torch._C._AutoDispatchBelowAutograd():
             return torch.mm(x, y)

     @staticmethod
     def backward(ctx, gO):
-        # TODO(whc) - claude first wrote it this way and later tried to return None, None, i'm not sure which is correct
-        return gO, gO
+        return None, None, gO, gO

10482 def split_mm (i , w ):
10583 # Apply the pass-through node. y is passed to this node so that it can be
10684 # saved for backward, but detach because we don't want to actually build
10785 # this edge of the graph
108- logger .error (f"split_mm forward: { i .shape = } , { w .shape = } " )
109- w1 = MmSeparateWeightGrad .apply (i .detach (), w )
110- i1 = MmSeparateInputGrad .apply (i , w .detach ())
111- return MmPassThrough .apply (i1 , w1 )
86+ # logger.error(f"split_mm forward: {i.shape=}, {w.shape=}")
87+ fake_1 = MmSeparateWeightGrad .apply (i .detach (), w )
88+ fake_2 = MmSeparateInputGrad .apply (i , w .detach ())
89+
90+ return MmPassThrough .apply (i .detach (), w .detach (), fake_1 , fake_2 )
11291
11392 # addmm operator: out = beta * input + alpha * (mat1 @ mat2)
11493 class AddmmSeparateMat2Grad (torch .autograd .Function ):
@@ -237,7 +216,12 @@ class GroupedMmSeparateMat2Grad(torch.autograd.Function):
     def forward(ctx, input, mat2, offs, bias, out_dtype):
         ctx.save_for_backward(input)
         ctx.offs = offs
-        return mat2
+        with FakeTensorMode(allow_non_fake_inputs=True), torch._C._AutoDispatchBelowAutograd(), torch.no_grad():
+            fake_output = torch.ops.aten._grouped_mm.default(
+                input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
+            )
+
+        return torch.empty((1,), dtype=input.dtype, device=input.device, requires_grad=True).expand_as(fake_output)

     @staticmethod
     def backward(ctx, grad_output):
@@ -253,7 +237,15 @@ class GroupedMmSeparateInputGrad(torch.autograd.Function):
     def forward(ctx, input, mat2, offs, bias, out_dtype):
         ctx.save_for_backward(mat2)
         ctx.offs = offs
-        return input
+        with FakeTensorMode(allow_non_fake_inputs=True), torch._C._AutoDispatchBelowAutograd(), torch.no_grad():
+            fake_output = torch.ops.aten._grouped_mm.default(
+                input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
+            )
+
+        # grad_shape = fake_output.shape
+        # print(f"Shape: {fake_output.shape=}")
+        return torch.empty((1,), dtype=input.dtype, device=input.device, requires_grad=True).expand_as(fake_output)
+        # return torch.empty(grad_shape, dtype=input.dtype, device=input.device, requires_grad=True)

     @staticmethod
     def backward(ctx, grad_output):
@@ -266,24 +258,24 @@ def backward(ctx, grad_output):

 class GroupedMmPassThrough(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input, mat2, offs, bias, out_dtype):
+    def forward(ctx, input, mat2, offs, bias, out_dtype, fake_1, fake_2):
         with torch._C._AutoDispatchBelowAutograd():
             return torch.ops.aten._grouped_mm.default(
                 input, mat2, offs=offs, bias=bias, out_dtype=out_dtype
             )

     @staticmethod
     def backward(ctx, gO):
-        return gO, gO, None, None, None
+        return None, None, None, None, None, gO, gO

 def split_grouped_mm(input, mat2, offs=None, bias=None, out_dtype=None):
-    mat2_1 = GroupedMmSeparateMat2Grad.apply(
+    fake_1 = GroupedMmSeparateMat2Grad.apply(
         input.detach(), mat2, offs, bias, out_dtype
     )
-    input_1 = GroupedMmSeparateInputGrad.apply(
+    fake_2 = GroupedMmSeparateInputGrad.apply(
         input, mat2.detach(), offs, bias, out_dtype
     )
-    return GroupedMmPassThrough.apply(input_1, mat2_1, offs, bias, out_dtype)
+    return GroupedMmPassThrough.apply(input.detach(), mat2.detach(), offs, bias, out_dtype, fake_1, fake_2)

 lib.impl("mm", split_mm, "Autograd")
 lib.impl("addmm", split_addmm, "Autograd")
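
For readers unfamiliar with the pattern: below is a minimal, self-contained sketch of the split-backward trick the mm hunks above implement. It is not code from this PR; names such as SplitWeightGrad, SplitInputGrad, PassThrough, and split_mm_demo are illustrative placeholders. The standalone version omits the _AutoDispatchBelowAutograd guard because it is not registered as the Autograd kernel for aten::mm, so calling torch.mm directly does not recurse. The dW and dX computations land on separate autograd nodes, which is what lets a pipeline schedule run them at different times.

import torch


class SplitWeightGrad(torch.autograd.Function):
    """Computes only dW in backward; its forward output is a placeholder."""

    @staticmethod
    def forward(ctx, i, w):
        ctx.save_for_backward(i)
        # Placeholder with the mm output shape, so the output grad can flow in here.
        return torch.empty((i.shape[0], w.shape[1]), dtype=w.dtype, device=w.device)

    @staticmethod
    def backward(ctx, grad_output):
        (i,) = ctx.saved_tensors
        return None, i.t().mm(grad_output)  # dW = X^T @ dY


class SplitInputGrad(torch.autograd.Function):
    """Computes only dX in backward; its forward output is a placeholder."""

    @staticmethod
    def forward(ctx, i, w):
        ctx.save_for_backward(w)
        return torch.empty((i.shape[0], w.shape[1]), dtype=i.dtype, device=i.device)

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        return grad_output.mm(w.t()), None  # dX = dY @ W^T


class PassThrough(torch.autograd.Function):
    """Runs the real mm; routes the incoming output grad to both placeholders."""

    @staticmethod
    def forward(ctx, x, y, fake_dw, fake_dx):
        return torch.mm(x, y)

    @staticmethod
    def backward(ctx, gO):
        return None, None, gO, gO


def split_mm_demo(i, w):
    fake_dw = SplitWeightGrad.apply(i.detach(), w)  # graph edge that only reaches w
    fake_dx = SplitInputGrad.apply(i, w.detach())   # graph edge that only reaches i
    return PassThrough.apply(i.detach(), w.detach(), fake_dw, fake_dx)


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 3, requires_grad=True)

# Reference gradients from ordinary autograd.
torch.mm(x, w).sum().backward()
ref_dx, ref_dw = x.grad.clone(), w.grad.clone()
x.grad, w.grad = None, None

# Gradients through the split graph should match the reference.
split_mm_demo(x, w).sum().backward()
assert torch.allclose(x.grad, ref_dx)
assert torch.allclose(w.grad, ref_dw)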