@@ -479,27 +479,26 @@ def test_linspace(num, dtype, endpoint, data):
479479 ah .assert_exactly_equal (out , expected )
480480
481481
482- @given (
482+ @given (dtype = xps .numeric_dtypes (), data = st .data ())
483+ def test_meshgrid (dtype , data ):
483484 # The number and size of generated arrays is arbitrarily limited to prevent
484485 # meshgrid() running out of memory.
485- dtypes = hh .mutually_promotable_dtypes (5 , dtypes = dh .numeric_dtypes ),
486- data = st .data (),
487- )
488- def test_meshgrid (dtypes , data ):
489- arrays = []
490486 shapes = data .draw (
491- hh .mutually_broadcastable_shapes (
492- len (dtypes ), min_dims = 1 , max_dims = 1 , max_side = 5
487+ st .integers (1 , 5 ).flatmap (
488+ lambda n : hh .mutually_broadcastable_shapes (
489+ n , min_dims = 1 , max_dims = 1 , max_side = 5
490+ )
493491 ),
494492 label = "shapes" ,
495493 )
496- for i , (dtype , shape ) in enumerate (zip (dtypes , shapes ), 1 ):
494+ arrays = []
495+ for i , shape in enumerate (shapes , 1 ):
497496 x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f"x{ i } " )
498497 arrays .append (x )
499498 assert math .prod (x .size for x in arrays ) <= hh .MAX_ARRAY_SIZE # sanity check
500499 out = xp .meshgrid (* arrays )
501500 for i , x in enumerate (out ):
502- ph .assert_dtype ("meshgrid" , dtypes , x .dtype , repr_name = f"out[{ i } ].dtype" )
501+ ph .assert_dtype ("meshgrid" , dtype , x .dtype , repr_name = f"out[{ i } ].dtype" )
503502
504503
505504def make_one (dtype : DataType ) -> Scalar :
0 commit comments