@@ -605,7 +605,7 @@ def test_change_specify_shape_size_multivariate():
605605
606606
607607@pytest .mark .parametrize (
608- "steps , shape, step_shape_offset, expected_steps , consistent" ,
608+ "support_shape , shape, support_shape_offset, expected_support_shape , consistent" ,
609609 [
610610 (10 , None , 0 , 10 , True ),
611611 (10 , None , 1 , 10 , True ),
@@ -621,44 +621,46 @@ def test_change_specify_shape_size_multivariate():
621621)
622622@pytest .mark .parametrize ("info_source" , ("shape" , "dims" , "observed" ))
623623def test_get_support_shape_1d (
624- info_source , steps , shape , step_shape_offset , expected_steps , consistent
624+ info_source , support_shape , shape , support_shape_offset , expected_support_shape , consistent
625625):
626626 if info_source == "shape" :
627- inferred_steps = get_support_shape_1d (
628- support_shape = steps , shape = shape , support_shape_offset = step_shape_offset
627+ inferred_support_shape = get_support_shape_1d (
628+ support_shape = support_shape , shape = shape , support_shape_offset = support_shape_offset
629629 )
630630
631631 elif info_source == "dims" :
632632 if shape is None :
633633 dims = None
634634 coords = {}
635635 else :
636- dims = tuple (str (i ) for i , shape in enumerate (shape ))
636+ dims = tuple (str (i ) for i , _ in enumerate (shape ))
637637 coords = {str (i ): range (shape ) for i , shape in enumerate (shape )}
638638 with Model (coords = coords ):
639- inferred_steps = get_support_shape_1d (
640- support_shape = steps , dims = dims , support_shape_offset = step_shape_offset
639+ inferred_support_shape = get_support_shape_1d (
640+ support_shape = support_shape , dims = dims , support_shape_offset = support_shape_offset
641641 )
642642
643643 elif info_source == "observed" :
644644 if shape is None :
645645 observed = None
646646 else :
647647 observed = np .zeros (shape )
648- inferred_steps = get_support_shape_1d (
649- support_shape = steps , observed = observed , support_shape_offset = step_shape_offset
648+ inferred_support_shape = get_support_shape_1d (
649+ support_shape = support_shape ,
650+ observed = observed ,
651+ support_shape_offset = support_shape_offset ,
650652 )
651653
652- if not isinstance (inferred_steps , TensorVariable ):
653- assert inferred_steps == expected_steps
654+ if not isinstance (inferred_support_shape , TensorVariable ):
655+ assert inferred_support_shape == expected_support_shape
654656 else :
655657 if consistent :
656- assert inferred_steps .eval () == expected_steps
658+ assert inferred_support_shape .eval () == expected_support_shape
657659 else :
658660 # check that inferred steps is still correct by ignoring the assert
659661 f = aesara .function (
660- [], inferred_steps , mode = Mode ().including ("local_remove_all_assert" )
662+ [], inferred_support_shape , mode = Mode ().including ("local_remove_all_assert" )
661663 )
662- assert f () == expected_steps
663- with pytest .raises (AssertionError , match = "Steps do not match" ):
664- inferred_steps .eval ()
664+ assert f () == expected_support_shape
665+ with pytest .raises (AssertionError , match = "support_shape does not match" ):
666+ inferred_support_shape .eval ()
0 commit comments