diff --git a/keras/src/backend/common/name_scope.py b/keras/src/backend/common/name_scope.py index 71a8408767b..9eb1dd8985f 100644 --- a/keras/src/backend/common/name_scope.py +++ b/keras/src/backend/common/name_scope.py @@ -58,7 +58,8 @@ def __exit__(self, *args, **kwargs): name_scope_stack = global_state.get_global_attribute( "name_scope_stack" ) - name_scope_stack.pop() + if name_scope_stack: + name_scope_stack.pop() def current_path(): diff --git a/keras/src/backend/common/name_scope_test.py b/keras/src/backend/common/name_scope_test.py index 2e79f214695..a9873ca95ee 100644 --- a/keras/src/backend/common/name_scope_test.py +++ b/keras/src/backend/common/name_scope_test.py @@ -1,4 +1,7 @@ +import threading + from keras.src import testing +from keras.src.backend.common import global_state from keras.src.backend.common.name_scope import current_path from keras.src.backend.common.name_scope import name_scope @@ -46,3 +49,85 @@ def test_override_parent(self): current_path(), "absolute/path/middle/inner" ) self.assertEqual(current_path(), "outer") + + def test_exit_with_none_stack(self): + """Test that __exit__ handles None name_scope_stack gracefully.""" + # Create a name_scope instance + scope = name_scope("test") + # Enter the scope normally + scope.__enter__() + + # Simulate the scenario where global state is cleared + # (e.g., in a different thread) + global_state.set_global_attribute("name_scope_stack", None) + + # Exit should not raise an AttributeError + scope.__exit__() + + # Clean up: reset the stack + global_state.set_global_attribute("name_scope_stack", []) + + def test_exit_with_empty_stack(self): + """Test that __exit__ handles empty name_scope_stack gracefully.""" + # Create a name_scope instance + scope = name_scope("test") + # Enter the scope normally + scope.__enter__() + + # Simulate the scenario where the stack is cleared + name_scope_stack = global_state.get_global_attribute("name_scope_stack") + name_scope_stack.clear() + + # Exit should not raise an IndexError + scope.__exit__() + + # Verify stack is still empty + name_scope_stack = global_state.get_global_attribute( + "name_scope_stack", default=[] + ) + self.assertEqual(len(name_scope_stack), 0) + + def test_multithreaded_name_scope(self): + """Test name_scope in multithreaded environment.""" + results = [] + + def thread_function(thread_id): + # Each thread should have its own name_scope_stack + with name_scope(f"thread_{thread_id}"): + path = current_path() + results.append(path) + # Verify we get the expected path + self.assertEqual(path, f"thread_{thread_id}") + + # Create and start multiple threads + threads = [] + for i in range(5): + thread = threading.Thread(target=thread_function, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify all threads executed successfully + self.assertEqual(len(results), 5) + + def test_exit_without_pop_on_exit(self): + """Test that __exit__ respects _pop_on_exit flag.""" + # Create a name_scope but don't enter it + scope = name_scope("test") + # _pop_on_exit should be False + self.assertFalse(scope._pop_on_exit) + + # Set up a stack manually + global_state.set_global_attribute("name_scope_stack", [scope]) + + scope.__exit__() + + # Verify the stack still contains the scope + name_scope_stack = global_state.get_global_attribute("name_scope_stack") + self.assertEqual(len(name_scope_stack), 1) + + # Clean up + global_state.set_global_attribute("name_scope_stack", [])