|
| 1 | +import threading |
| 2 | + |
1 | 3 | from keras.src import testing |
| 4 | +from keras.src.backend.common import global_state |
2 | 5 | from keras.src.backend.common.name_scope import current_path |
3 | 6 | from keras.src.backend.common.name_scope import name_scope |
4 | 7 |
|
@@ -46,3 +49,109 @@ def test_override_parent(self): |
46 | 49 | current_path(), "absolute/path/middle/inner" |
47 | 50 | ) |
48 | 51 | self.assertEqual(current_path(), "outer") |
| 52 | + |
| 53 | + def test_exit_with_none_stack(self): |
| 54 | + """Test that __exit__ handles None name_scope_stack gracefully.""" |
| 55 | + # Create a name_scope instance |
| 56 | + scope = name_scope("test") |
| 57 | + # Enter the scope normally |
| 58 | + scope.__enter__() |
| 59 | + |
| 60 | + # Simulate the scenario where global state is cleared |
| 61 | + # (e.g., in a different thread) |
| 62 | + global_state.set_global_attribute("name_scope_stack", None) |
| 63 | + |
| 64 | + # Exit should not raise an AttributeError |
| 65 | + try: |
| 66 | + scope.__exit__() |
| 67 | + except AttributeError as e: |
| 68 | + self.fail(f"__exit__ raised AttributeError: {e}") |
| 69 | + |
| 70 | + # Clean up: reset the stack |
| 71 | + global_state.set_global_attribute("name_scope_stack", []) |
| 72 | + |
| 73 | + def test_exit_with_empty_stack(self): |
| 74 | + """Test that __exit__ handles empty name_scope_stack gracefully.""" |
| 75 | + # Create a name_scope instance |
| 76 | + scope = name_scope("test") |
| 77 | + # Enter the scope normally |
| 78 | + scope.__enter__() |
| 79 | + |
| 80 | + # Simulate the scenario where the stack is cleared |
| 81 | + name_scope_stack = global_state.get_global_attribute( |
| 82 | + "name_scope_stack" |
| 83 | + ) |
| 84 | + name_scope_stack.clear() |
| 85 | + |
| 86 | + # Exit should not raise an IndexError |
| 87 | + try: |
| 88 | + scope.__exit__() |
| 89 | + except IndexError as e: |
| 90 | + self.fail(f"__exit__ raised IndexError: {e}") |
| 91 | + |
| 92 | + # Verify stack is still empty |
| 93 | + name_scope_stack = global_state.get_global_attribute( |
| 94 | + "name_scope_stack", default=[] |
| 95 | + ) |
| 96 | + self.assertEqual(len(name_scope_stack), 0) |
| 97 | + |
| 98 | + def test_multithreaded_name_scope(self): |
| 99 | + """Test name_scope in multithreaded environment.""" |
| 100 | + results = [] |
| 101 | + errors = [] |
| 102 | + |
| 103 | + def thread_function(thread_id): |
| 104 | + try: |
| 105 | + # Each thread should have its own name_scope_stack |
| 106 | + with name_scope(f"thread_{thread_id}"): |
| 107 | + path = current_path() |
| 108 | + results.append(path) |
| 109 | + # Verify we get the expected path |
| 110 | + if path != f"thread_{thread_id}": |
| 111 | + errors.append( |
| 112 | + f"Thread {thread_id}: Expected " |
| 113 | + f"'thread_{thread_id}', got '{path}'" |
| 114 | + ) |
| 115 | + except Exception as e: |
| 116 | + errors.append(f"Thread {thread_id}: {type(e).__name__}: {e}") |
| 117 | + |
| 118 | + # Create and start multiple threads |
| 119 | + threads = [] |
| 120 | + for i in range(5): |
| 121 | + thread = threading.Thread(target=thread_function, args=(i,)) |
| 122 | + threads.append(thread) |
| 123 | + thread.start() |
| 124 | + |
| 125 | + # Wait for all threads to complete |
| 126 | + for thread in threads: |
| 127 | + thread.join() |
| 128 | + |
| 129 | + # Check for any errors |
| 130 | + if errors: |
| 131 | + self.fail(f"Errors in threads: {errors}") |
| 132 | + |
| 133 | + # Verify all threads executed successfully |
| 134 | + self.assertEqual(len(results), 5) |
| 135 | + |
| 136 | + def test_exit_without_pop_on_exit(self): |
| 137 | + """Test that __exit__ respects _pop_on_exit flag.""" |
| 138 | + # Create a name_scope but don't enter it |
| 139 | + scope = name_scope("test") |
| 140 | + # _pop_on_exit should be False |
| 141 | + self.assertFalse(scope._pop_on_exit) |
| 142 | + |
| 143 | + # Set up a stack manually |
| 144 | + global_state.set_global_attribute( |
| 145 | + "name_scope_stack", [scope], set_to_default=False |
| 146 | + ) |
| 147 | + |
| 148 | + scope.__exit__() |
| 149 | + |
| 150 | + # Verify the stack still contains the scope |
| 151 | + name_scope_stack = global_state.get_global_attribute( |
| 152 | + "name_scope_stack" |
| 153 | + ) |
| 154 | + self.assertEqual(len(name_scope_stack), 1) |
| 155 | + |
| 156 | + # Clean up |
| 157 | + global_state.set_global_attribute("name_scope_stack", []) |
0 commit comments