2121from typing import Any , cast
2222
2323from pytensor import function
24- from pytensor .graph .basic import ancestors , walk
24+ from pytensor .graph .basic import Variable , ancestors , walk
2525from pytensor .tensor .shape import Shape
26- from pytensor .tensor .variable import TensorVariable
2726
2827from pymc .model .core import modelcontext
28+ from pymc .pytensorf import _cheap_eval_mode
2929from pymc .util import VarName , get_default_varnames , get_var_name
3030
3131__all__ = (
@@ -73,7 +73,7 @@ def create_plate_label_with_dim_length(
7373
7474
7575def fast_eval (var ):
76- return function ([], var , mode = "FAST_COMPILE" )()
76+ return function ([], var , mode = _cheap_eval_mode )()
7777
7878
7979class NodeType (str , Enum ):
@@ -88,7 +88,7 @@ class NodeType(str, Enum):
8888
8989@dataclass
9090class NodeInfo :
91- var : TensorVariable
91+ var : Variable
9292 node_type : NodeType
9393
9494 def __hash__ (self ):
@@ -108,10 +108,10 @@ def __eq__(self, other) -> bool:
108108
109109
110110GraphvizNodeKwargs = dict [str , Any ]
111- NodeFormatter = Callable [[TensorVariable ], GraphvizNodeKwargs ]
111+ NodeFormatter = Callable [[Variable ], GraphvizNodeKwargs ]
112112
113113
114- def default_potential (var : TensorVariable ) -> GraphvizNodeKwargs :
114+ def default_potential (var : Variable ) -> GraphvizNodeKwargs :
115115 """Return default data for potential in the graph."""
116116 return {
117117 "shape" : "octagon" ,
@@ -120,17 +120,19 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs:
120120 }
121121
122122
123- def random_variable_symbol (var : TensorVariable ) -> str :
123+ def random_variable_symbol (var : Variable ) -> str :
124124 """Get the symbol of the random variable."""
125- symbol = var .owner .op . __class__ . __name__
125+ op = var .owner .op
126126
127- if symbol .endswith ("RV" ):
128- symbol = symbol [:- 2 ]
127+ if name := getattr (op , "name" , None ):
128+ symbol = name [0 ].upper () + name [1 :]
129+ else :
130+ symbol = op .__class__ .__name__ .removesuffix ("RV" )
129131
130132 return symbol
131133
132134
133- def default_free_rv (var : TensorVariable ) -> GraphvizNodeKwargs :
135+ def default_free_rv (var : Variable ) -> GraphvizNodeKwargs :
134136 """Return default data for free RV in the graph."""
135137 symbol = random_variable_symbol (var )
136138
@@ -141,7 +143,7 @@ def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs:
141143 }
142144
143145
144- def default_observed_rv (var : TensorVariable ) -> GraphvizNodeKwargs :
146+ def default_observed_rv (var : Variable ) -> GraphvizNodeKwargs :
145147 """Return default data for observed RV in the graph."""
146148 symbol = random_variable_symbol (var )
147149
@@ -152,7 +154,7 @@ def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs:
152154 }
153155
154156
155- def default_deterministic (var : TensorVariable ) -> GraphvizNodeKwargs :
157+ def default_deterministic (var : Variable ) -> GraphvizNodeKwargs :
156158 """Return default data for the deterministic in the graph."""
157159 return {
158160 "shape" : "box" ,
@@ -161,7 +163,7 @@ def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs:
161163 }
162164
163165
164- def default_data (var : TensorVariable ) -> GraphvizNodeKwargs :
166+ def default_data (var : Variable ) -> GraphvizNodeKwargs :
165167 """Return default data for the data in the graph."""
166168 return {
167169 "shape" : "box" ,
@@ -239,7 +241,7 @@ def __init__(self, model):
239241 self ._all_vars = {model [var_name ] for var_name in self ._all_var_names }
240242 self .var_list = self .model .named_vars .values ()
241243
242- def get_parent_names (self , var : TensorVariable ) -> set [VarName ]:
244+ def get_parent_names (self , var : Variable ) -> set [VarName ]:
243245 if var .owner is None :
244246 return set ()
245247
@@ -345,7 +347,7 @@ def get_plates(
345347 dim_name : fast_eval (value ).item () for dim_name , value in self .model .dim_lengths .items ()
346348 }
347349 var_shapes : dict [str , tuple [int , ...]] = {
348- var_name : tuple (fast_eval (self .model [var_name ].shape ))
350+ var_name : tuple (map ( int , fast_eval (self .model [var_name ].shape ) ))
349351 for var_name in self .vars_to_plot (var_names )
350352 }
351353
0 commit comments