Skip to content

Commit 032528b

Browse files
Fix name_scope_stack AttributeError and IndexError in __exit__ (#21834)
* Fix name_scope_stack AttributeError and IndexError in __exit__ - Add guard clauses to handle None and empty stack cases in name_scope.__exit__ - Prevents AttributeError when name_scope_stack is None (e.g., thread-local state cleared) - Prevents IndexError when name_scope_stack is empty - Add comprehensive tests for None stack, empty stack, multithreading scenarios - Fixes #21831 * Address code review comments - Simplify check to use idiomatic 'if name_scope_stack:' instead of explicit None and length checks - Remove unnecessary try-except blocks in tests (tests fail automatically on unexpected exceptions) - Use self.assertEqual() directly in multithreaded test for clearer assertions * chore: Update generated API files * fix: Apply formatting changes --------- Co-authored-by: SamareshSingh <ssam3003@gmail.com>
1 parent 39c475b commit 032528b

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

keras/src/backend/common/name_scope.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def __exit__(self, *args, **kwargs):
5858
name_scope_stack = global_state.get_global_attribute(
5959
"name_scope_stack"
6060
)
61-
name_scope_stack.pop()
61+
if name_scope_stack:
62+
name_scope_stack.pop()
6263

6364

6465
def current_path():

keras/src/backend/common/name_scope_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import threading
2+
13
from keras.src import testing
4+
from keras.src.backend.common import global_state
25
from keras.src.backend.common.name_scope import current_path
36
from keras.src.backend.common.name_scope import name_scope
47

@@ -46,3 +49,85 @@ def test_override_parent(self):
4649
current_path(), "absolute/path/middle/inner"
4750
)
4851
self.assertEqual(current_path(), "outer")
52+
53+
def test_exit_with_none_stack(self):
54+
"""Test that __exit__ handles None name_scope_stack gracefully."""
55+
# Create a name_scope instance
56+
scope = name_scope("test")
57+
# Enter the scope normally
58+
scope.__enter__()
59+
60+
# Simulate the scenario where global state is cleared
61+
# (e.g., in a different thread)
62+
global_state.set_global_attribute("name_scope_stack", None)
63+
64+
# Exit should not raise an AttributeError
65+
scope.__exit__()
66+
67+
# Clean up: reset the stack
68+
global_state.set_global_attribute("name_scope_stack", [])
69+
70+
def test_exit_with_empty_stack(self):
71+
"""Test that __exit__ handles empty name_scope_stack gracefully."""
72+
# Create a name_scope instance
73+
scope = name_scope("test")
74+
# Enter the scope normally
75+
scope.__enter__()
76+
77+
# Simulate the scenario where the stack is cleared
78+
name_scope_stack = global_state.get_global_attribute("name_scope_stack")
79+
name_scope_stack.clear()
80+
81+
# Exit should not raise an IndexError
82+
scope.__exit__()
83+
84+
# Verify stack is still empty
85+
name_scope_stack = global_state.get_global_attribute(
86+
"name_scope_stack", default=[]
87+
)
88+
self.assertEqual(len(name_scope_stack), 0)
89+
90+
def test_multithreaded_name_scope(self):
91+
"""Test name_scope in multithreaded environment."""
92+
results = []
93+
94+
def thread_function(thread_id):
95+
# Each thread should have its own name_scope_stack
96+
with name_scope(f"thread_{thread_id}"):
97+
path = current_path()
98+
results.append(path)
99+
# Verify we get the expected path
100+
self.assertEqual(path, f"thread_{thread_id}")
101+
102+
# Create and start multiple threads
103+
threads = []
104+
for i in range(5):
105+
thread = threading.Thread(target=thread_function, args=(i,))
106+
threads.append(thread)
107+
thread.start()
108+
109+
# Wait for all threads to complete
110+
for thread in threads:
111+
thread.join()
112+
113+
# Verify all threads executed successfully
114+
self.assertEqual(len(results), 5)
115+
116+
def test_exit_without_pop_on_exit(self):
117+
"""Test that __exit__ respects _pop_on_exit flag."""
118+
# Create a name_scope but don't enter it
119+
scope = name_scope("test")
120+
# _pop_on_exit should be False
121+
self.assertFalse(scope._pop_on_exit)
122+
123+
# Set up a stack manually
124+
global_state.set_global_attribute("name_scope_stack", [scope])
125+
126+
scope.__exit__()
127+
128+
# Verify the stack still contains the scope
129+
name_scope_stack = global_state.get_global_attribute("name_scope_stack")
130+
self.assertEqual(len(name_scope_stack), 1)
131+
132+
# Clean up
133+
global_state.set_global_attribute("name_scope_stack", [])

0 commit comments

Comments
 (0)