@@ -1646,6 +1646,17 @@ def test_simple_not_flat(self, device, jacapi):
         expected = expected.view(2, 3, 2, 3)
         assert torch.allclose(y, expected)
 
+    @jacrev_and_jacfwd
+    def test_take(self, device, jacapi):
+        x = torch.rand(5)
+
+        def func(x):
+            y = torch.ones(3, dtype=torch.long)
+            z = torch.take(x, y)
+            return z
+
+        self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
+
     @FIXME_jacrev_only
     def test_diff_numel(self, device, jacapi):
         x = torch.randn(2, 4, device=device)
@@ -2172,26 +2183,38 @@ def f(x):
     def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
         # With chunk_size=1, we shouldn't `vmap` and hence not be limited
         # by its constraints.
+        x = torch.randn(3, 3, device=device)
 
-        x = torch.randn(3, device=device)
-        idx_1 = torch.tensor([0, ], device=device)
-        idx_2 = torch.tensor([0, 1], device=device)
-        chunk_size = 1
-
-        def f(x, idx):
-            # `take` doesn't work with vmap
-            # as it returns an output with dynamic shape.
-            return torch.take(x, idx)
-
-        for fn, idx in ((f, idx_1), (f, idx_2)):
-            jacfn = jacrev(fn, chunk_size=chunk_size, _preallocate_and_copy=_preallocate_and_copy)
-            actual = jacfn(x, idx)
-            expected = torch.autograd.functional.jacobian(partial(fn, idx=idx), x, vectorize=False)
-            self.assertEqual(actual, expected)
+        # Function with a dynamic op in its backward pass.
+        # This should cause jacrev/vmap(vjp) to fail.
+        class IdentityWithDynamicBackwardOp(torch.autograd.Function):
+            @staticmethod
+            def forward(input):
+                return input
 
-        msg = r"vmap: .* is not possible because there exists a Tensor"
-        with self.assertRaisesRegex(RuntimeError, msg):
-            jacrev(fn, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x, idx)
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                pass
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                # dynamic op in backward pass.
+                grad_output.nonzero()
+                return grad_output
+
+        def f(x):
+            return IdentityWithDynamicBackwardOp.apply(x)
+
+        # With `chunk_size=1`, we don't use vmap. So the following should work.
+        jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy)
+        actual = jacfn(x)
+        expected = torch.autograd.functional.jacobian(f, x, vectorize=False)
+        self.assertEqual(actual, expected)
+
+        # Should fail with `chunk_size=2`.
+        msg = r"vmap: We do not support batching operators that can output dynamic shape."
+        with self.assertRaisesRegex(RuntimeError, msg):
+            jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x)
 
     def test_complex_error(self, device):
         # Verify complex input raises error
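
For reference, here is a minimal, self-contained sketch of the behaviour these tests exercise (assuming a PyTorch 2.x build with `torch.func`; the `torch.sin` example is illustrative and not part of this patch): `vmap` refuses to batch operators whose output shape depends on the data, such as `nonzero`, while `jacrev(..., chunk_size=1)` loops over the output basis instead of calling `vmap` and therefore avoids that restriction.

```python
import torch
from torch.func import jacrev, vmap

x = torch.randn(3, 3)

# vmap cannot batch data-dependent-shape ops such as nonzero:
# each batch element could yield a different number of rows.
try:
    vmap(torch.nonzero)(x)
except RuntimeError as err:
    print(f"vmap failed as expected: {err}")

# jacrev is normally implemented as vmap(vjp); with chunk_size=1 it
# computes the Jacobian one row at a time without vmap, and still
# matches the reference from torch.autograd.functional.jacobian.
jac_loop = jacrev(torch.sin, chunk_size=1)(x)
jac_ref = torch.autograd.functional.jacobian(torch.sin, x)
print(torch.allclose(jac_loop, jac_ref))
```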