@@ -894,15 +894,55 @@ def test_random_concrete_shape_subtensor_tuple(self):
894894 jax_fn = compile_random_function ([x_pt ], out )
895895 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
896896
897+ def test_random_scalar_shape_input (self ):
898+ dim0 = pt .scalar ("dim0" , dtype = int )
899+ dim1 = pt .scalar ("dim1" , dtype = int )
900+
901+ out = pt .random .normal (0 , 1 , size = dim0 )
902+ jax_fn = compile_random_function ([dim0 ], out )
903+ assert jax_fn (np .array (2 )).shape == (2 ,)
904+ assert jax_fn (np .array (3 )).shape == (3 ,)
905+
906+ out = pt .random .normal (0 , 1 , size = [dim0 , dim1 ])
907+ jax_fn = compile_random_function ([dim0 , dim1 ], out )
908+ assert jax_fn (np .array (2 ), np .array (3 )).shape == (2 , 3 )
909+ assert jax_fn (np .array (4 ), np .array (5 )).shape == (4 , 5 )
910+
897911 @pytest .mark .xfail (
898- reason = "`size_pt` should be specified as a static argument" , strict = True
912+ raises = TypeError , reason = "Cannot convert scalar input to integer"
899913 )
900- def test_random_concrete_shape_graph_input (self ):
901- rng = shared (np .random .default_rng (123 ))
902- size_pt = pt .scalar ()
903- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
904- jax_fn = compile_random_function ([size_pt ], out )
905- assert jax_fn (10 ).shape == (10 ,)
914+ def test_random_scalar_shape_input_not_supported (self ):
915+ dim = pt .scalar ("dim" , dtype = int )
916+ out1 = pt .random .normal (0 , 1 , size = dim )
917+ # An operation that wouldn't work if we replaced 0d array by integer
918+ out2 = dim [...].set (1 )
919+ jax_fn = compile_random_function ([dim ], [out1 , out2 ])
920+
921+ res1 , res2 = jax_fn (np .array (2 ))
922+ assert res1 .shape == (2 ,)
923+ assert res2 == 1
924+
925+ @pytest .mark .xfail (
926+ raises = TypeError , reason = "Cannot convert scalar input to integer"
927+ )
928+ def test_random_scalar_shape_input_not_supported2 (self ):
929+ dim = pt .scalar ("dim" , dtype = int )
930+ # This could theoretically be supported
931+ # but would require knowing that * 2 is a safe operation for a python integer
932+ out = pt .random .normal (0 , 1 , size = dim * 2 )
933+ jax_fn = compile_random_function ([dim ], out )
934+ assert jax_fn (np .array (2 )).shape == (4 ,)
935+
936+ @pytest .mark .xfail (
937+ raises = TypeError , reason = "Cannot convert tensor input to shape tuple"
938+ )
939+ def test_random_vector_shape_graph_input (self ):
940+ shape = pt .vector ("shape" , shape = (2 ,), dtype = int )
941+ out = pt .random .normal (0 , 1 , size = shape )
942+
943+ jax_fn = compile_random_function ([shape ], out )
944+ assert jax_fn (np .array ([2 , 3 ])).shape == (2 , 3 )
945+ assert jax_fn (np .array ([4 , 5 ])).shape == (4 , 5 )
906946
907947 def test_constant_shape_after_graph_rewriting (self ):
908948 size = pt .vector ("size" , shape = (2 ,), dtype = int )
@@ -912,13 +952,13 @@ def test_constant_shape_after_graph_rewriting(self):
912952 with pytest .raises (TypeError ):
913953 compile_random_function ([size ], x )([2 , 5 ])
914954
915- # Rebuild with strict=False so output type is not updated
955+ # Rebuild with strict=True so output type is not updated
916956 # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
917957 new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
918958 assert new_x .type .shape == (None , None )
919959 assert compile_random_function ([], new_x )().shape == (2 , 5 )
920960
921- # Rebuild with strict=True , so output type is updated
961+ # Rebuild with strict=False , so output type is updated
922962 # This uses a different path in the dispatch implementation
923963 new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
924964 assert new_x .type .shape == (2 , 5 )
0 commit comments