diff --git a/keras/src/backend/common/name_scope.py b/keras/src/backend/common/name_scope.py index 71a8408767b..9eb1dd8985f 100644 --- a/keras/src/backend/common/name_scope.py +++ b/keras/src/backend/common/name_scope.py @@ -58,7 +58,8 @@ def __exit__(self, *args, **kwargs): name_scope_stack = global_state.get_global_attribute( "name_scope_stack" ) - name_scope_stack.pop() + if name_scope_stack: + name_scope_stack.pop() def current_path(): diff --git a/keras/src/backend/common/name_scope_test.py b/keras/src/backend/common/name_scope_test.py index 2e79f214695..79d9472638e 100644 --- a/keras/src/backend/common/name_scope_test.py +++ b/keras/src/backend/common/name_scope_test.py @@ -1,4 +1,5 @@ from keras.src import testing +from keras.src.backend.common import global_state from keras.src.backend.common.name_scope import current_path from keras.src.backend.common.name_scope import name_scope @@ -46,3 +47,53 @@ def test_override_parent(self): current_path(), "absolute/path/middle/inner" ) self.assertEqual(current_path(), "outer") + + def test_exit_with_empty_stack(self): + global_state.set_global_attribute("name_scope_stack", []) + + scope = name_scope("test") + scope._pop_on_exit = True + + try: + scope.__exit__() + success = True + except (AttributeError, IndexError): + success = False + + self.assertTrue(success) + + def test_exit_with_none_stack(self): + global_state.set_global_attribute("name_scope_stack", None) + + scope = name_scope("test") + scope._pop_on_exit = True + + try: + scope.__exit__() + success = True + except (AttributeError, IndexError): + success = False + + self.assertTrue(success) + + def test_exit_without_pop_on_exit(self): + global_state.set_global_attribute("name_scope_stack", ["dummy"]) + + scope = name_scope("test") + scope._pop_on_exit = False + + scope.__exit__() + + name_scope_stack = global_state.get_global_attribute("name_scope_stack") + self.assertEqual(len(name_scope_stack), 1) + + def test_normal_exit_still_works(self): + self.assertEqual(current_path(), "") + + with name_scope("test1"): + self.assertEqual(current_path(), "test1") + with name_scope("test2"): + self.assertEqual(current_path(), "test1/test2") + self.assertEqual(current_path(), "test1") + + self.assertEqual(current_path(), "")