@@ -836,94 +836,90 @@ def sample_fn(rng, size, dtype, *parameters):
836836 compare_jax_and_py ([], [out ], [])
837837
838838
839- def test_random_concrete_shape ():
840- """JAX should compile when a `RandomVariable` is passed a concrete shape.
841-
842- There are three quantities that JAX considers as concrete:
843- 1. Constants known at compile time;
844- 2. The shape of an array.
845- 3. `static_argnums` parameters
846- This test makes sure that graphs with `RandomVariable`s compile when the
847- `size` parameter satisfies either of these criteria.
848-
849- """
850- rng = shared (np .random .default_rng (123 ))
851- x_pt = pt .dmatrix ()
852- out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
853- jax_fn = compile_random_function ([x_pt ], out )
854- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
855-
856-
857- def test_random_concrete_shape_from_param ():
858- rng = shared (np .random .default_rng (123 ))
859- x_pt = pt .dmatrix ()
860- out = pt .random .normal (x_pt , 1 , rng = rng )
861- jax_fn = compile_random_function ([x_pt ], out )
862- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
863-
864-
865- def test_random_concrete_shape_subtensor ():
866- """JAX should compile when a concrete value is passed for the `size` parameter.
867-
868- This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
869- inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
870- inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
871- rewrite.
872-
873- JAX does not accept scalars as `size` or `shape` arguments, so this is a
874- slight improvement over their API.
875-
876- """
877- rng = shared (np .random .default_rng (123 ))
878- x_pt = pt .dmatrix ()
879- out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
880- jax_fn = compile_random_function ([x_pt ], out )
881- assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
882-
883-
884- def test_random_concrete_shape_subtensor_tuple ():
885- """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
886-
887- This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
888- inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
889- scalar inputs into tuples of concrete values using the
890- `jax_size_parameter_as_tuple` rewrite.
891-
892- """
893- rng = shared (np .random .default_rng (123 ))
894- x_pt = pt .dmatrix ()
895- out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
896- jax_fn = compile_random_function ([x_pt ], out )
897- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
898-
899-
900- @pytest .mark .xfail (
901- reason = "`size_pt` should be specified as a static argument" , strict = True
902- )
903- def test_random_concrete_shape_graph_input ():
904- rng = shared (np .random .default_rng (123 ))
905- size_pt = pt .scalar ()
906- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
907- jax_fn = compile_random_function ([size_pt ], out )
908- assert jax_fn (10 ).shape == (10 ,)
909-
910-
911- def test_constant_shape_after_graph_rewriting ():
912- size = pt .vector ("size" , shape = (2 ,), dtype = int )
913- x = pt .random .normal (size = size )
914- assert x .type .shape == (None , None )
915-
916- with pytest .raises (TypeError ):
917- compile_random_function ([size ], x )([2 , 5 ])
918-
919- # Rebuild with strict=False so output type is not updated
920- # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
921- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
922- assert new_x .type .shape == (None , None )
923- assert compile_random_function ([], new_x )().shape == (2 , 5 )
924-
925- # Rebuild with strict=True, so output type is updated
926- # This uses a different path in the dispatch implementation
927- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
928- assert new_x .type .shape == (2 , 5 )
929- assert compile_random_function ([], new_x )().shape == (2 , 5 )
839+ class TestRandomShapeInputs :
840+ def test_random_concrete_shape (self ):
841+ """JAX should compile when a `RandomVariable` is passed a concrete shape.
842+
843+ There are three quantities that JAX considers as concrete:
844+ 1. Constants known at compile time;
845+ 2. The shape of an array.
846+ 3. `static_argnums` parameters
847+ This test makes sure that graphs with `RandomVariable`s compile when the
848+ `size` parameter satisfies either of these criteria.
849+
850+ """
851+ rng = shared (np .random .default_rng (123 ))
852+ x_pt = pt .dmatrix ()
853+ out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
854+ jax_fn = compile_random_function ([x_pt ], out )
855+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
856+
857+ def test_random_concrete_shape_from_param (self ):
858+ rng = shared (np .random .default_rng (123 ))
859+ x_pt = pt .dmatrix ()
860+ out = pt .random .normal (x_pt , 1 , rng = rng )
861+ jax_fn = compile_random_function ([x_pt ], out )
862+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
863+
864+ def test_random_concrete_shape_subtensor (self ):
865+ """JAX should compile when a concrete value is passed for the `size` parameter.
866+
867+ This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
868+ inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
869+ inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
870+ rewrite.
871+
872+ JAX does not accept scalars as `size` or `shape` arguments, so this is a
873+ slight improvement over their API.
874+
875+ """
876+ rng = shared (np .random .default_rng (123 ))
877+ x_pt = pt .dmatrix ()
878+ out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
879+ jax_fn = compile_random_function ([x_pt ], out )
880+ assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
881+
882+ def test_random_concrete_shape_subtensor_tuple (self ):
883+ """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
884+
885+ This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
886+ inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
887+ scalar inputs into tuples of concrete values using the
888+ `jax_size_parameter_as_tuple` rewrite.
889+
890+ """
891+ rng = shared (np .random .default_rng (123 ))
892+ x_pt = pt .dmatrix ()
893+ out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
894+ jax_fn = compile_random_function ([x_pt ], out )
895+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
896+
897+ @pytest .mark .xfail (
898+ reason = "`size_pt` should be specified as a static argument" , strict = True
899+ )
900+ def test_random_concrete_shape_graph_input (self ):
901+ rng = shared (np .random .default_rng (123 ))
902+ size_pt = pt .scalar ()
903+ out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
904+ jax_fn = compile_random_function ([size_pt ], out )
905+ assert jax_fn (10 ).shape == (10 ,)
906+
907+ def test_constant_shape_after_graph_rewriting (self ):
908+ size = pt .vector ("size" , shape = (2 ,), dtype = int )
909+ x = pt .random .normal (size = size )
910+ assert x .type .shape == (None , None )
911+
912+ with pytest .raises (TypeError ):
913+ compile_random_function ([size ], x )([2 , 5 ])
914+
915+ # Rebuild with strict=False so output type is not updated
916+ # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
917+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
918+ assert new_x .type .shape == (None , None )
919+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
920+
921+ # Rebuild with strict=True, so output type is updated
922+ # This uses a different path in the dispatch implementation
923+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
924+ assert new_x .type .shape == (2 , 5 )
925+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
0 commit comments