33import copy
44from contextlib import contextmanager
55from functools import singledispatch
6- from typing import List , Optional
6+ from typing import TYPE_CHECKING , List , Optional
77
88from pytensor .graph .basic import Variable
99from pytensor .graph .utils import add_tag_trace
1010from pytensor .link .basic import Container
1111from pytensor .link .c .type import generic
1212
1313
14+ if TYPE_CHECKING :
15+ from pytensor .graph .type import Type
16+
17+
1418__SHARED_CONTEXT__ : Optional [List [Variable ]] = None
1519
1620
@@ -30,14 +34,39 @@ def collect_new_shareds():
3034class SharedVariable (Variable ):
3135 """Variable that is shared between compiled functions."""
3236
33- container : Optional [Container ] = None
34- """
35- A container to use for this SharedVariable when it is an implicit
36- function parameter.
37- """
37+ def __init__ (
38+ self ,
39+ type : "Type" ,
40+ value ,
41+ strict : bool ,
42+ allow_downcast = None ,
43+ container : Optional [Container ] = None ,
44+ name : Optional [str ] = None ,
45+ ):
46+ r"""
47+ Parameters
48+ ----------
49+ type
50+ The `Type` for this variable (see `Variable`).
51+ value
52+ A value to associate with this variable (a new container will be
53+ created).
54+ strict
55+ ``True`` means that values assigned to this variable will not be
56+ cast or copied, so they must have the correct `Type`\s.
57+ allow_downcast
58+ Only applies if `strict` is ``False``.
59+ ``True`` means that the assigned value can lose precision when cast
60+ during assignment. ``None`` means that only down-casting of a Python
61+ float to a scalar ``floatX`` is allowed.
62+ container
63+ The container to use for this variable. Illegal to pass this as well as
64+ a value.
65+ name
66+ The name for this variable (see `Variable`).
3867
39- def __init__ ( self , name , type , value , strict , allow_downcast = None , container = None ):
40- super ().__init__ (type = type , name = name , owner = None , index = None )
68+ """
69+ super ().__init__ (type = type , owner = None , index = None , name = name )
4170
4271 if container is not None :
4372 self .container = container
@@ -107,26 +136,6 @@ def set_value(self, new_value, borrow=False):
107136 def get_test_value (self ):
108137 return self .get_value (borrow = True , return_internal_type = True )
109138
110- def zero (self , borrow = False ):
111- """
112- Set the values of a shared variable to 0.
113-
114- Parameters
115- ----------
116- borrow : bbol
117- True to modify the value of a shared variable directly by using
118- its previous value. Potentially this can cause problems
119- regarding to the aliased memory.
120-
121- Changes done with this function will be visible to all functions using
122- this SharedVariable.
123-
124- """
125- if borrow :
126- self .container .value [...] = 0
127- else :
128- self .container .value = 0 * self .container .value
129-
130139 def clone (self , ** kwargs ):
131140 name = kwargs .get ("name" , self .name )
132141 cp = self .__class__ (
@@ -209,7 +218,7 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw
209218 return SharedVariable (
210219 type = generic ,
211220 value = value ,
212- name = name ,
213221 strict = strict ,
214222 allow_downcast = allow_downcast ,
223+ name = name ,
215224 )
0 commit comments