
Commit ffd76d1

kshitij12345 authored and pytorchmergebot committed
[fix] take : backward batching rule (pytorch#95772)
Fixes pytorch#95738

Pull Request resolved: pytorch#95772
Approved by: https://github.com/zou3519
1 parent 7d5d5be commit ffd76d1
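
For context, a minimal sketch of the call this fix is about, adapted from the new test_take test added below (this snippet is an illustration using the torch.func API and is not part of the commit itself):

import torch
from torch.func import jacrev

x = torch.rand(5)

def func(x):
    # Gather three copies of x[1]; take's backward scatters the gradient
    # back with `put` (see the FunctionsManual.cpp hunk below).
    idx = torch.ones(3, dtype=torch.long)
    return torch.take(x, idx)

# jacrev vectorizes the vjp with vmap, so this call exercises the
# batching rule for take's backward that this commit fixes.
assert torch.allclose(jacrev(func)(x),
                      torch.autograd.functional.jacobian(func, x))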

3 files changed (+43, -22 lines)

test/functorch/test_eager_transforms.py

Lines changed: 41 additions & 18 deletions
@@ -1646,6 +1646,17 @@ def test_simple_not_flat(self, device, jacapi):
         expected = expected.view(2, 3, 2, 3)
         assert torch.allclose(y, expected)

+    @jacrev_and_jacfwd
+    def test_take(self, device, jacapi):
+        x = torch.rand(5)
+
+        def func(x):
+            y = torch.ones(3, dtype=torch.long)
+            z = torch.take(x, y)
+            return z
+
+        self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
+
     @FIXME_jacrev_only
     def test_diff_numel(self, device, jacapi):
         x = torch.randn(2, 4, device=device)
@@ -2172,26 +2183,38 @@ def f(x):
     def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
         # With chunk_size=1, we shouldn't `vmap` and hence not be limited
         # by it's constraints.
+        x = torch.randn(3, 3, device=device)

-        x = torch.randn(3, device=device)
-        idx_1 = torch.tensor([0, ], device=device)
-        idx_2 = torch.tensor([0, 1], device=device)
-        chunk_size = 1
-
-        def f(x, idx):
-            # `take` doesn't work with vmap
-            # as it returns an output with dynamic shape.
-            return torch.take(x, idx)
-
-        for fn, idx in ((f, idx_1), (f, idx_2)):
-            jacfn = jacrev(fn, chunk_size=chunk_size, _preallocate_and_copy=_preallocate_and_copy)
-            actual = jacfn(x, idx)
-            expected = torch.autograd.functional.jacobian(partial(fn, idx=idx), x, vectorize=False)
-            self.assertEqual(actual, expected)
+        # Function with Dynamic Op in Backward.
+        # This should cause jacrev/vmap(vjp) to fail.
+        class IdentityWithDynamicBackwardOp(torch.autograd.Function):
+            @staticmethod
+            def forward(input):
+                return input

-        msg = r"vmap: .* is not possible because there exists a Tensor"
-        with self.assertRaisesRegex(RuntimeError, msg):
-            jacrev(fn, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x, idx)
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                pass
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                # dynamic op in backward pass.
+                grad_output.nonzero()
+                return grad_output
+
+        def f(x):
+            return IdentityWithDynamicBackwardOp.apply(x)
+
+        # With `chunk_size=1`, we don't use vmap. So the following should work.
+        jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy)
+        actual = jacfn(x)
+        expected = torch.autograd.functional.jacobian(f, x, vectorize=False)
+        self.assertEqual(actual, expected)
+
+        # Should fail with `chunk_size=2`.
+        msg = r"vmap: We do not support batching operators that can output dynamic shape."
+        with self.assertRaisesRegex(RuntimeError, msg):
+            jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x)

     def test_complex_error(self, device):
         # Verify complex input raises error
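
The rewritten test above hinges on vmap's inability to batch operators whose output shape depends on the data. A standalone sketch of that constraint, outside the test harness (assumes torch.func.vmap; the exact error wording may differ across versions):

import torch
from torch.func import vmap

x = torch.tensor([[0.0, 1.0], [1.0, 1.0]])

# nonzero returns a different number of rows per sample, so the per-sample
# results cannot be stacked into one batched output.
try:
    vmap(torch.nonzero)(x)
except RuntimeError as err:
    print(err)  # e.g. "vmap: We do not support batching operators that can output dynamic shape."

With chunk_size=1, jacrev computes the Jacobian rows without vmap, which is why the custom Function calling grad_output.nonzero() in its backward still works in the first half of the test.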

test/functorch/test_ops.py

Lines changed: 0 additions & 2 deletions
@@ -788,7 +788,6 @@ def fn(inp, *args, **kwargs):
     xfail("normal"),  # calls random op
     xfail("normal", "number_mean"),  # calls random op
     xfail("pca_lowrank"),  # calls random op
-    xfail("put"),  # vmap: inplace into a regular tensor
     # https://github.com/pytorch/pytorch/issues/96560
     decorate('linalg.pinv', 'hermitian', decorator=skipIfRocm),
     xfail("quantile", device_type='cpu'),  # Batching rule not implemented for `at::equal`
@@ -882,7 +881,6 @@ def vjp_of_vjp(*args_and_cotangents):
     xfail('masked_scatter'),  # dynamic
     xfail('nn.functional.fractional_max_pool2d'),  # random
     xfail('nn.functional.fractional_max_pool3d'),  # random
-    xfail('take'),  # dynamic
     xfail('pca_lowrank', ''),  # randomness
     xfail('svd_lowrank', ''),  # randomness
     xfail('to_sparse', ''),  # non-dense output

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 2 additions & 2 deletions
@@ -6702,9 +6702,9 @@ Tensor take_backward(
     const Tensor& indices) {
   Tensor grad_self = at::zeros_like(self);
   // For Composite Compliance,
-  // if `grad` and `indices` are CCT but `self` is not
+  // if `grad` and `indices` are CCT but `grad_self` is not
   // then we use the out-of-place variant of `put`.
-  if (!isTensorSubclassLike(self) &&
+  if (!isTensorSubclassLike(grad_self) &&
       areAnyTensorSubclassLike({grad, indices})) {
     return grad_self.put(indices, grad, true);
   }
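
For readers who prefer not to parse the C++, a rough Python sketch of what take_backward computes (an illustration under simplified assumptions, not the actual implementation; take_backward_sketch and inp are made-up names following the C++ signature):

import torch

def take_backward_sketch(grad, inp, indices):
    # Gradient of torch.take(inp, indices): scatter the incoming gradient
    # back to the flat positions that `take` read, accumulating duplicates.
    grad_self = torch.zeros_like(inp)
    # The C++ code uses the out-of-place `put` when grad/indices are tensor
    # subclasses (Composite Compliance) and the in-place `put_` otherwise;
    # both produce the same values.
    return grad_self.put_(indices, grad, accumulate=True)

The one-line change above only swaps which tensor the subclass check inspects, grad_self (the freshly created zeros tensor) rather than self, presumably because grad_self is the tensor the in-place variant would mutate.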
