|
| 1 | +import threading |
| 2 | + |
1 | 3 | from keras.src import testing |
| 4 | +from keras.src.backend.common import global_state |
2 | 5 | from keras.src.backend.common.name_scope import current_path |
3 | 6 | from keras.src.backend.common.name_scope import name_scope |
4 | 7 |
|
@@ -46,3 +49,85 @@ def test_override_parent(self): |
46 | 49 | current_path(), "absolute/path/middle/inner" |
47 | 50 | ) |
48 | 51 | self.assertEqual(current_path(), "outer") |
| 52 | + |
| 53 | + def test_exit_with_none_stack(self): |
| 54 | + """Test that __exit__ handles None name_scope_stack gracefully.""" |
| 55 | + # Create a name_scope instance |
| 56 | + scope = name_scope("test") |
| 57 | + # Enter the scope normally |
| 58 | + scope.__enter__() |
| 59 | + |
| 60 | + # Simulate the scenario where global state is cleared |
| 61 | + # (e.g., in a different thread) |
| 62 | + global_state.set_global_attribute("name_scope_stack", None) |
| 63 | + |
| 64 | + # Exit should not raise an AttributeError |
| 65 | + scope.__exit__() |
| 66 | + |
| 67 | + # Clean up: reset the stack |
| 68 | + global_state.set_global_attribute("name_scope_stack", []) |
| 69 | + |
| 70 | + def test_exit_with_empty_stack(self): |
| 71 | + """Test that __exit__ handles empty name_scope_stack gracefully.""" |
| 72 | + # Create a name_scope instance |
| 73 | + scope = name_scope("test") |
| 74 | + # Enter the scope normally |
| 75 | + scope.__enter__() |
| 76 | + |
| 77 | + # Simulate the scenario where the stack is cleared |
| 78 | + name_scope_stack = global_state.get_global_attribute("name_scope_stack") |
| 79 | + name_scope_stack.clear() |
| 80 | + |
| 81 | + # Exit should not raise an IndexError |
| 82 | + scope.__exit__() |
| 83 | + |
| 84 | + # Verify stack is still empty |
| 85 | + name_scope_stack = global_state.get_global_attribute( |
| 86 | + "name_scope_stack", default=[] |
| 87 | + ) |
| 88 | + self.assertEqual(len(name_scope_stack), 0) |
| 89 | + |
| 90 | + def test_multithreaded_name_scope(self): |
| 91 | + """Test name_scope in multithreaded environment.""" |
| 92 | + results = [] |
| 93 | + |
| 94 | + def thread_function(thread_id): |
| 95 | + # Each thread should have its own name_scope_stack |
| 96 | + with name_scope(f"thread_{thread_id}"): |
| 97 | + path = current_path() |
| 98 | + results.append(path) |
| 99 | + # Verify we get the expected path |
| 100 | + self.assertEqual(path, f"thread_{thread_id}") |
| 101 | + |
| 102 | + # Create and start multiple threads |
| 103 | + threads = [] |
| 104 | + for i in range(5): |
| 105 | + thread = threading.Thread(target=thread_function, args=(i,)) |
| 106 | + threads.append(thread) |
| 107 | + thread.start() |
| 108 | + |
| 109 | + # Wait for all threads to complete |
| 110 | + for thread in threads: |
| 111 | + thread.join() |
| 112 | + |
| 113 | + # Verify all threads executed successfully |
| 114 | + self.assertEqual(len(results), 5) |
| 115 | + |
| 116 | + def test_exit_without_pop_on_exit(self): |
| 117 | + """Test that __exit__ respects _pop_on_exit flag.""" |
| 118 | + # Create a name_scope but don't enter it |
| 119 | + scope = name_scope("test") |
| 120 | + # _pop_on_exit should be False |
| 121 | + self.assertFalse(scope._pop_on_exit) |
| 122 | + |
| 123 | + # Set up a stack manually |
| 124 | + global_state.set_global_attribute("name_scope_stack", [scope]) |
| 125 | + |
| 126 | + scope.__exit__() |
| 127 | + |
| 128 | + # Verify the stack still contains the scope |
| 129 | + name_scope_stack = global_state.get_global_attribute("name_scope_stack") |
| 130 | + self.assertEqual(len(name_scope_stack), 1) |
| 131 | + |
| 132 | + # Clean up |
| 133 | + global_state.set_global_attribute("name_scope_stack", []) |
0 commit comments