@@ -60,3 +60,52 @@ def test_dropout_rate_greater_than_one(self):
6060 "Expected a float value between 0 and 1." ,
6161 ):
6262 _ = layers .Dropout (rate = 1.5 )
63+
64+ def test_validate_noise_shape_none (self ):
65+ layer = layers .Dropout (0.5 , noise_shape = None )
66+ self .assertIsNone (layer .noise_shape )
67+
68+ def test_validate_noise_shape_integer_tuple (self ):
69+ layer = layers .Dropout (0.5 , noise_shape = (20 , 1 , 10 ))
70+ self .assertEqual (layer .noise_shape , (20 , 1 , 10 ))
71+
72+ def test_validate_noise_shape_none_values (self ):
73+ layer = layers .Dropout (0.5 , noise_shape = (None , 1 , None ))
74+ self .assertEqual (layer .noise_shape , (None , 1 , None ))
75+
76+ def test_validate_noise_shape_cast_to_a_tuple (self ):
77+ layer = layers .Dropout (0.5 , noise_shape = [20 , 1 , 10 ])
78+ self .assertEqual (layer .noise_shape , (20 , 1 , 10 ))
79+ self .assertIsInstance (layer .noise_shape , tuple )
80+
81+ def test_validate_noise_shape_non_iterable (self ):
82+ with self .assertRaisesRegex (
83+ ValueError ,
84+ "Invalid value received for argument `noise_shape`. "
85+ "Expected a tuple or list of integers." ,
86+ ):
87+ layers .Dropout (0.5 , noise_shape = "Invalid" )
88+
89+ def test_validate_noise_shape_invalid_type (self ):
90+ with self .assertRaisesRegex (
91+ ValueError ,
92+ "Invalid value received for argument `noise_shape`. "
93+ "Expected all elements to be integers or None." ,
94+ ):
95+ layers .Dropout (0.5 , noise_shape = (20 , 1.5 , 10 ))
96+
97+ def test_validate_noise_shape_negative_value (self ):
98+ with self .assertRaisesRegex (
99+ ValueError ,
100+ "Invalid value received for argument `noise_shape`. "
101+ "Expected all dimensions to be positive integers or None." ,
102+ ):
103+ layers .Dropout (0.5 , noise_shape = (20 , - 1 , 10 ))
104+
105+ def test_validate_noise_shape_zero_value (self ):
106+ with self .assertRaisesRegex (
107+ ValueError ,
108+ "Invalid value received for argument `noise_shape`. "
109+ "Expected all dimensions to be positive integers or None." ,
110+ ):
111+ layers .Dropout (0.5 , noise_shape = (20 , 0 , 10 ))
0 commit comments