|
1 | | -import types |
2 | | -import weakref |
3 | | -from collections.abc import MutableSet |
4 | | - |
5 | | - |
6 | | -def check_deterministic(iterable): |
7 | | - # Most places where OrderedSet is used, pytensor interprets any exception |
8 | | - # whatsoever as a problem that an optimization introduced into the graph. |
9 | | - # If I raise a TypeError when the DestroyHandler tries to do something |
10 | | - # non-deterministic, it will just result in optimizations getting ignored. |
11 | | - # So I must use an assert here. In the long term we should fix the rest of |
12 | | - # pytensor to use exceptions correctly, so that this can be a TypeError. |
13 | | - if iterable is not None: |
14 | | - if not isinstance( |
15 | | - iterable, list | tuple | OrderedSet | types.GeneratorType | str | dict |
16 | | - ): |
17 | | - if len(iterable) > 1: |
18 | | - # We need to accept length 1 size to allow unpickle in tests. |
19 | | - raise AssertionError( |
20 | | - "Get an not ordered iterable when one was expected" |
21 | | - ) |
22 | | - |
23 | | - |
24 | | -# Copyright (C) 2009 Raymond Hettinger |
25 | | -# Permission is hereby granted, free of charge, to any person obtaining a |
26 | | -# copy of this software and associated documentation files (the |
27 | | -# "Software"), to deal in the Software without restriction, including |
28 | | -# without limitation the rights to use, copy, modify, merge, publish, |
29 | | -# distribute, sublicense, and/or sell copies of the Software, and to permit |
30 | | -# persons to whom the Software is furnished to do so, subject to the |
31 | | -# following conditions: |
32 | | - |
33 | | -# The above copyright notice and this permission notice shall be included |
34 | | -# in all copies or substantial portions of the Software. |
35 | | - |
36 | | -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
37 | | -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
38 | | -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. |
39 | | -# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
40 | | -# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
41 | | -# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE |
42 | | -# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
43 | | -# {{{ http://code.activestate.com/recipes/576696/ (r5) |
44 | | - |
45 | | - |
46 | | -class Link: |
47 | | - # This make that we need to use a different pickle protocol |
48 | | - # then the default. Otherwise, there is pickling errors |
49 | | - __slots__ = "prev", "next", "key", "__weakref__" |
50 | | - |
51 | | - def __getstate__(self): |
52 | | - # weakref.proxy don't pickle well, so we use weakref.ref |
53 | | - # manually and don't pickle the weakref. |
54 | | - # We restore the weakref when we unpickle. |
55 | | - ret = [self.prev(), self.next()] |
56 | | - try: |
57 | | - ret.append(self.key) |
58 | | - except AttributeError: |
59 | | - pass |
60 | | - return ret |
61 | | - |
62 | | - def __setstate__(self, state): |
63 | | - self.prev = weakref.ref(state[0]) |
64 | | - self.next = weakref.ref(state[1]) |
65 | | - if len(state) == 3: |
66 | | - self.key = state[2] |
| 1 | +from collections.abc import Iterable, Iterator, MutableSet |
| 2 | +from typing import Any |
67 | 3 |
|
68 | 4 |
|
69 | 5 | class OrderedSet(MutableSet): |
70 | | - "Set the remembers the order elements were added" |
71 | | - |
72 | | - # Big-O running times for all methods are the same as for regular sets. |
73 | | - # The internal self.__map dictionary maps keys to links in a doubly linked list. |
74 | | - # The circular doubly linked list starts and ends with a sentinel element. |
75 | | - # The sentinel element never gets deleted (this simplifies the algorithm). |
76 | | - # The prev/next links are weakref proxies (to prevent circular references). |
77 | | - # Individual links are kept alive by the hard reference in self.__map. |
78 | | - # Those hard references disappear when a key is deleted from an OrderedSet. |
| 6 | + values: dict[Any, None] |
79 | 7 |
|
80 | | - # Added by IG-- pre-existing pytensor code expected sets |
81 | | - # to have this method |
82 | | - def update(self, iterable): |
83 | | - check_deterministic(iterable) |
84 | | - self |= iterable |
85 | | - |
86 | | - def __init__(self, iterable=None): |
87 | | - # Checks added by IG |
88 | | - check_deterministic(iterable) |
89 | | - self.__root = root = Link() # sentinel node for doubly linked list |
90 | | - root.prev = root.next = weakref.ref(root) |
91 | | - self.__map = {} # key --> link |
92 | | - if iterable is not None: |
93 | | - self |= iterable |
94 | | - |
95 | | - def __len__(self): |
96 | | - return len(self.__map) |
97 | | - |
98 | | - def __contains__(self, key): |
99 | | - return key in self.__map |
100 | | - |
101 | | - def add(self, key): |
102 | | - # Store new key in a new link at the end of the linked list |
103 | | - if key not in self.__map: |
104 | | - self.__map[key] = link = Link() |
105 | | - root = self.__root |
106 | | - last = root.prev |
107 | | - link.prev, link.next, link.key = last, weakref.ref(root), key |
108 | | - last().next = root.prev = weakref.ref(link) |
109 | | - |
110 | | - def union(self, s): |
111 | | - check_deterministic(s) |
112 | | - n = self.copy() |
113 | | - for elem in s: |
114 | | - if elem not in n: |
115 | | - n.add(elem) |
116 | | - return n |
117 | | - |
118 | | - def intersection_update(self, s): |
119 | | - l = [] |
120 | | - for elem in self: |
121 | | - if elem not in s: |
122 | | - l.append(elem) |
123 | | - for elem in l: |
124 | | - self.remove(elem) |
125 | | - return self |
| 8 | + def __init__(self, iterable: Iterable | None = None) -> None: |
| 9 | + if iterable is None: |
| 10 | + self.values = {} |
| 11 | + else: |
| 12 | + self.values = {value: None for value in iterable} |
126 | 13 |
|
127 | | - def difference_update(self, s): |
128 | | - check_deterministic(s) |
129 | | - for elem in s: |
130 | | - if elem in self: |
131 | | - self.remove(elem) |
132 | | - return self |
| 14 | + def __contains__(self, value) -> bool: |
| 15 | + return value in self.values |
133 | 16 |
|
134 | | - def copy(self): |
135 | | - n = OrderedSet() |
136 | | - n.update(self) |
137 | | - return n |
| 17 | + def __iter__(self) -> Iterator: |
| 18 | + yield from self.values |
138 | 19 |
|
139 | | - def discard(self, key): |
140 | | - # Remove an existing item using self.__map to find the link which is |
141 | | - # then removed by updating the links in the predecessor and successors. |
142 | | - if key in self.__map: |
143 | | - link = self.__map.pop(key) |
144 | | - link.prev().next = link.next |
145 | | - link.next().prev = link.prev |
| 20 | + def __len__(self) -> int: |
| 21 | + return len(self.values) |
146 | 22 |
|
147 | | - def __iter__(self): |
148 | | - # Traverse the linked list in order. |
149 | | - root = self.__root |
150 | | - curr = root.next() |
151 | | - while curr is not root: |
152 | | - yield curr.key |
153 | | - curr = curr.next() |
| 23 | + def add(self, value) -> None: |
| 24 | + self.values[value] = None |
154 | 25 |
|
155 | | - def __reversed__(self): |
156 | | - # Traverse the linked list in reverse order. |
157 | | - root = self.__root |
158 | | - curr = root.prev() |
159 | | - while curr is not root: |
160 | | - yield curr.key |
161 | | - curr = curr.prev() |
| 26 | + def discard(self, value) -> None: |
| 27 | + if value in self.values: |
| 28 | + del self.values[value] |
162 | 29 |
|
163 | | - def pop(self, last=True): |
164 | | - if not self: |
165 | | - raise KeyError("set is empty") |
166 | | - if last: |
167 | | - key = next(reversed(self)) |
168 | | - else: |
169 | | - key = next(iter(self)) |
170 | | - self.discard(key) |
171 | | - return key |
| 30 | + def copy(self) -> "OrderedSet": |
| 31 | + return OrderedSet(self) |
172 | 32 |
|
173 | | - def __repr__(self): |
174 | | - if not self: |
175 | | - return f"{self.__class__.__name__}()" |
176 | | - return f"{self.__class__.__name__}({list(self)!r})" |
177 | | - |
178 | | - def __eq__(self, other): |
179 | | - # Note that we implement only the comparison to another |
180 | | - # `OrderedSet`, and not to a regular `set`, because otherwise we |
181 | | - # could have a non-symmetric equality relation like: |
182 | | - # my_ordered_set == my_set and my_set != my_ordered_set |
183 | | - if isinstance(other, OrderedSet): |
184 | | - return len(self) == len(other) and list(self) == list(other) |
185 | | - elif isinstance(other, set): |
186 | | - # Raise exception to avoid confusion. |
187 | | - raise TypeError( |
188 | | - "Cannot compare an `OrderedSet` to a `set` because " |
189 | | - "this comparison cannot be made symmetric: please " |
190 | | - "manually cast your `OrderedSet` into `set` before " |
191 | | - "performing this comparison." |
192 | | - ) |
193 | | - else: |
194 | | - return NotImplemented |
| 33 | + def update(self, other: Iterable) -> None: |
| 34 | + for value in other: |
| 35 | + self.add(value) |
195 | 36 |
|
| 37 | + def union(self, other: Iterable) -> "OrderedSet": |
| 38 | + new_set = self.copy() |
| 39 | + new_set.update(other) |
| 40 | + return new_set |
196 | 41 |
|
197 | | -# end of http://code.activestate.com/recipes/576696/ }}} |
| 42 | + def difference_update(self, other: Iterable) -> None: |
| 43 | + for value in other: |
| 44 | + self.discard(value) |
0 commit comments