Skip to content

Commit bea37c5

Browse files
Fix noise_shape validation in keras.layers.Dropout (#21819)
* Fix noise_shape validation in keras.layers.Dropout * updated dropout.py to validate str for noise_shape, added unit test for _vaidate_noise_shape * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * fix minor code format issue --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 68fb291 commit bea37c5

File tree

2 files changed

+92
-1
lines changed

2 files changed

+92
-1
lines changed

keras/src/layers/regularization/dropout.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,55 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
4848
)
4949
self.rate = rate
5050
self.seed = seed
51-
self.noise_shape = noise_shape
51+
self.noise_shape = self._validate_noise_shape(noise_shape)
5252
if rate > 0:
5353
self.seed_generator = backend.random.SeedGenerator(seed)
5454
self.supports_masking = True
5555

5656
self._build_at_init()
5757

58+
def _validate_noise_shape(self, noise_shape):
59+
if noise_shape is None:
60+
return None
61+
62+
if isinstance(noise_shape, str):
63+
raise ValueError(
64+
f"Invalid value received for argument `noise_shape`. "
65+
f"Expected a tuple or list of integers. "
66+
f"Received: noise_shape={noise_shape}"
67+
)
68+
69+
if not isinstance(noise_shape, tuple):
70+
try:
71+
noise_shape = tuple(noise_shape)
72+
except TypeError:
73+
raise ValueError(
74+
f"Invalid value received for argument `noise_shape`. "
75+
f"Expected an iterable of integers "
76+
f"(e.g., a tuple or list). "
77+
f"Received: noise_shape={noise_shape}"
78+
)
79+
80+
for i, dim in enumerate(noise_shape):
81+
if dim is not None:
82+
if not isinstance(dim, int):
83+
raise ValueError(
84+
f"Invalid value received for argument `noise_shape`. "
85+
f"Expected all elements to be integers or None. "
86+
f"Received element at index {i}: {dim} "
87+
f"(type: {type(dim).__name__})"
88+
)
89+
90+
if dim <= 0:
91+
raise ValueError(
92+
f"Invalid value received for argument `noise_shape`. "
93+
f"Expected all dimensions to be positive integers "
94+
f"or None. "
95+
f"Received negative or zero value at index {i}: {dim}"
96+
)
97+
98+
return noise_shape
99+
58100
def call(self, inputs, training=False):
59101
if training and self.rate > 0:
60102
return backend.random.dropout(

keras/src/layers/regularization/dropout_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,52 @@ def test_dropout_rate_greater_than_one(self):
6060
"Expected a float value between 0 and 1.",
6161
):
6262
_ = layers.Dropout(rate=1.5)
63+
64+
def test_validate_noise_shape_none(self):
65+
layer = layers.Dropout(0.5, noise_shape=None)
66+
self.assertIsNone(layer.noise_shape)
67+
68+
def test_validate_noise_shape_integer_tuple(self):
69+
layer = layers.Dropout(0.5, noise_shape=(20, 1, 10))
70+
self.assertEqual(layer.noise_shape, (20, 1, 10))
71+
72+
def test_validate_noise_shape_none_values(self):
73+
layer = layers.Dropout(0.5, noise_shape=(None, 1, None))
74+
self.assertEqual(layer.noise_shape, (None, 1, None))
75+
76+
def test_validate_noise_shape_cast_to_a_tuple(self):
77+
layer = layers.Dropout(0.5, noise_shape=[20, 1, 10])
78+
self.assertEqual(layer.noise_shape, (20, 1, 10))
79+
self.assertIsInstance(layer.noise_shape, tuple)
80+
81+
def test_validate_noise_shape_non_iterable(self):
82+
with self.assertRaisesRegex(
83+
ValueError,
84+
"Invalid value received for argument `noise_shape`. "
85+
"Expected a tuple or list of integers.",
86+
):
87+
layers.Dropout(0.5, noise_shape="Invalid")
88+
89+
def test_validate_noise_shape_invalid_type(self):
90+
with self.assertRaisesRegex(
91+
ValueError,
92+
"Invalid value received for argument `noise_shape`. "
93+
"Expected all elements to be integers or None.",
94+
):
95+
layers.Dropout(0.5, noise_shape=(20, 1.5, 10))
96+
97+
def test_validate_noise_shape_negative_value(self):
98+
with self.assertRaisesRegex(
99+
ValueError,
100+
"Invalid value received for argument `noise_shape`. "
101+
"Expected all dimensions to be positive integers or None.",
102+
):
103+
layers.Dropout(0.5, noise_shape=(20, -1, 10))
104+
105+
def test_validate_noise_shape_zero_value(self):
106+
with self.assertRaisesRegex(
107+
ValueError,
108+
"Invalid value received for argument `noise_shape`. "
109+
"Expected all dimensions to be positive integers or None.",
110+
):
111+
layers.Dropout(0.5, noise_shape=(20, 0, 10))

0 commit comments

Comments
 (0)