Skip to content

Commit 71e2de0

Browse files
committed
Fix name_scope_stack AttributeError and IndexError in __exit__
- Add guard clauses to handle None and empty stack cases in name_scope.__exit__ - Prevents AttributeError when name_scope_stack is None (e.g., thread-local state cleared) - Prevents IndexError when name_scope_stack is empty - Add comprehensive tests for None stack, empty stack, multithreading scenarios - Fixes #21831
1 parent 19ca9c1 commit 71e2de0

File tree

2 files changed

+111
-1
lines changed

2 files changed

+111
-1
lines changed

keras/src/backend/common/name_scope.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def __exit__(self, *args, **kwargs):
5858
name_scope_stack = global_state.get_global_attribute(
5959
"name_scope_stack"
6060
)
61-
name_scope_stack.pop()
61+
if name_scope_stack is not None and len(name_scope_stack) > 0:
62+
name_scope_stack.pop()
6263

6364

6465
def current_path():

keras/src/backend/common/name_scope_test.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import threading
2+
13
from keras.src import testing
4+
from keras.src.backend.common import global_state
25
from keras.src.backend.common.name_scope import current_path
36
from keras.src.backend.common.name_scope import name_scope
47

@@ -46,3 +49,109 @@ def test_override_parent(self):
4649
current_path(), "absolute/path/middle/inner"
4750
)
4851
self.assertEqual(current_path(), "outer")
52+
53+
def test_exit_with_none_stack(self):
54+
"""Test that __exit__ handles None name_scope_stack gracefully."""
55+
# Create a name_scope instance
56+
scope = name_scope("test")
57+
# Enter the scope normally
58+
scope.__enter__()
59+
60+
# Simulate the scenario where global state is cleared
61+
# (e.g., in a different thread)
62+
global_state.set_global_attribute("name_scope_stack", None)
63+
64+
# Exit should not raise an AttributeError
65+
try:
66+
scope.__exit__()
67+
except AttributeError as e:
68+
self.fail(f"__exit__ raised AttributeError: {e}")
69+
70+
# Clean up: reset the stack
71+
global_state.set_global_attribute("name_scope_stack", [])
72+
73+
def test_exit_with_empty_stack(self):
74+
"""Test that __exit__ handles empty name_scope_stack gracefully."""
75+
# Create a name_scope instance
76+
scope = name_scope("test")
77+
# Enter the scope normally
78+
scope.__enter__()
79+
80+
# Simulate the scenario where the stack is cleared
81+
name_scope_stack = global_state.get_global_attribute(
82+
"name_scope_stack"
83+
)
84+
name_scope_stack.clear()
85+
86+
# Exit should not raise an IndexError
87+
try:
88+
scope.__exit__()
89+
except IndexError as e:
90+
self.fail(f"__exit__ raised IndexError: {e}")
91+
92+
# Verify stack is still empty
93+
name_scope_stack = global_state.get_global_attribute(
94+
"name_scope_stack", default=[]
95+
)
96+
self.assertEqual(len(name_scope_stack), 0)
97+
98+
def test_multithreaded_name_scope(self):
99+
"""Test name_scope in multithreaded environment."""
100+
results = []
101+
errors = []
102+
103+
def thread_function(thread_id):
104+
try:
105+
# Each thread should have its own name_scope_stack
106+
with name_scope(f"thread_{thread_id}"):
107+
path = current_path()
108+
results.append(path)
109+
# Verify we get the expected path
110+
if path != f"thread_{thread_id}":
111+
errors.append(
112+
f"Thread {thread_id}: Expected "
113+
f"'thread_{thread_id}', got '{path}'"
114+
)
115+
except Exception as e:
116+
errors.append(f"Thread {thread_id}: {type(e).__name__}: {e}")
117+
118+
# Create and start multiple threads
119+
threads = []
120+
for i in range(5):
121+
thread = threading.Thread(target=thread_function, args=(i,))
122+
threads.append(thread)
123+
thread.start()
124+
125+
# Wait for all threads to complete
126+
for thread in threads:
127+
thread.join()
128+
129+
# Check for any errors
130+
if errors:
131+
self.fail(f"Errors in threads: {errors}")
132+
133+
# Verify all threads executed successfully
134+
self.assertEqual(len(results), 5)
135+
136+
def test_exit_without_pop_on_exit(self):
137+
"""Test that __exit__ respects _pop_on_exit flag."""
138+
# Create a name_scope but don't enter it
139+
scope = name_scope("test")
140+
# _pop_on_exit should be False
141+
self.assertFalse(scope._pop_on_exit)
142+
143+
# Set up a stack manually
144+
global_state.set_global_attribute(
145+
"name_scope_stack", [scope], set_to_default=False
146+
)
147+
148+
scope.__exit__()
149+
150+
# Verify the stack still contains the scope
151+
name_scope_stack = global_state.get_global_attribute(
152+
"name_scope_stack"
153+
)
154+
self.assertEqual(len(name_scope_stack), 1)
155+
156+
# Clean up
157+
global_state.set_global_attribute("name_scope_stack", [])

0 commit comments

Comments
 (0)