|
7 | 7 | from contextlib import contextmanager |
8 | 8 | from typing import List, Optional |
9 | 9 |
|
10 | | -import numpy as np |
11 | | - |
12 | 10 | from pytensor.graph.basic import Variable |
13 | 11 | from pytensor.graph.utils import add_tag_trace |
14 | 12 | from pytensor.link.basic import Container |
@@ -103,6 +101,8 @@ def __init__(self, name, type, value, strict, allow_downcast=None, container=Non |
103 | 101 | if isinstance(__SHARED_CONTEXT__, list): |
104 | 102 | __SHARED_CONTEXT__.append(self) |
105 | 103 |
|
| 104 | + self._default_update: Optional[Variable] = None |
| 105 | + |
106 | 106 | def get_value(self, borrow=False, return_internal_type=False): |
107 | 107 | """ |
108 | 108 | Get the non-symbolic value associated with this SharedVariable. |
@@ -179,47 +179,23 @@ def clone(self, **kwargs): |
179 | 179 | cp.tag = copy.copy(self.tag) |
180 | 180 | return cp |
181 | 181 |
|
182 | | - def __getitem__(self, *args): |
183 | | - # __getitem__ is not available for generic SharedVariable objects. |
184 | | - # We raise a TypeError like Python would do if __getitem__ was not |
185 | | - # implemented at all, but with a more explicit error message to help |
186 | | - # PyTensor users figure out the root of the problem more easily. |
187 | | - value = self.get_value(borrow=True) |
188 | | - if isinstance(value, np.ndarray): |
189 | | - # Array probably had an unknown dtype. |
190 | | - msg = ( |
191 | | - f"a Numpy array with dtype: '{value.dtype}'. This data type is not " |
192 | | - "currently recognized by PyTensor tensors: please cast " |
193 | | - "your data into a supported numeric type if you need " |
194 | | - "PyTensor tensor functionalities." |
195 | | - ) |
196 | | - else: |
197 | | - msg = ( |
198 | | - f"an object of type: {type(value)}. Did you forget to cast it into " |
199 | | - "a Numpy array before calling pytensor.shared()?" |
200 | | - ) |
201 | | - |
202 | | - raise TypeError( |
203 | | - "The generic 'SharedVariable' object is not subscriptable. " |
204 | | - f"This shared variable contains {msg}" |
205 | | - ) |
206 | | - |
207 | | - def _value_get(self): |
208 | | - raise Exception( |
209 | | - "sharedvar.value does not exist anymore. Use " |
210 | | - "sharedvar.get_value() or sharedvar.set_value()" |
211 | | - " instead." |
212 | | - ) |
| 182 | + @property |
| 183 | + def default_update(self) -> Optional[Variable]: |
| 184 | + """A default update expression for this `Variable`. |
213 | 185 |
|
214 | | - def _value_set(self, new_value): |
215 | | - raise Exception( |
216 | | - "sharedvar.value does not exist anymore. Use " |
217 | | - "sharedvar.get_value() or sharedvar.set_value()" |
218 | | - " instead." |
219 | | - ) |
| 186 | + If this value is non-``None``, its value will be used as the `update` |
| 187 | + (see `pytensor.function`) for this `Variable` when no updates are |
| 188 | + provided through `pytensor.function` and `no_default_updates` isn't |
| 189 | + enabled. |
| 190 | + """ |
| 191 | + return self._default_update |
220 | 192 |
|
221 | | - # We keep this just to raise an error |
222 | | - value = property(_value_get, _value_set) |
| 193 | + @default_update.setter |
| 194 | + def default_update(self, value): |
| 195 | + if value is not None: |
| 196 | + self._default_update = self.type.filter_variable(value, allow_convert=True) |
| 197 | + else: |
| 198 | + self._default_update = value |
223 | 199 |
|
224 | 200 |
|
225 | 201 | def shared_constructor(ctor, remove=False): |
|
0 commit comments