2727
2828from pymc .model .core import modelcontext
2929from pymc .pytensorf import _cheap_eval_mode
30- from pymc .util import VarName , get_default_varnames , get_var_name
30+ from pymc .util import get_default_varnames , get_var_name
3131
3232__all__ = (
3333 "ModelGraph" ,
@@ -173,7 +173,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs:
173173 }
174174
175175
176- def get_node_type (var_name : VarName , model ) -> NodeType :
176+ def get_node_type (var_name : str , model ) -> NodeType :
177177 """Return the node type of the variable in the model."""
178178 v = model [var_name ]
179179
@@ -242,7 +242,7 @@ def __init__(self, model):
242242 self ._all_vars = {model [var_name ] for var_name in self ._all_var_names }
243243 self .var_list = self .model .named_vars .values ()
244244
245- def get_parent_names (self , var : Variable ) -> set [VarName ]:
245+ def get_parent_names (self , var : Variable ) -> set [str ]:
246246 if var .owner is None :
247247 return set ()
248248
@@ -261,12 +261,12 @@ def _expand(x):
261261 return x .owner .inputs
262262
263263 return {
264- cast (VarName , ancestor .name ) # type: ignore[union-attr]
264+ cast (str , ancestor .name ) # type: ignore[union-attr]
265265 for ancestor in walk (nodes = var .owner .inputs , expand = _expand )
266266 if ancestor in named_vars
267267 }
268268
269- def vars_to_plot (self , var_names : Iterable [VarName ] | None = None ) -> list [VarName ]:
269+ def vars_to_plot (self , var_names : Iterable [str ] | None = None ) -> list [str ]:
270270 if var_names is None :
271271 return self ._all_var_names
272272
@@ -296,13 +296,11 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
296296 # ordering of self._all_var_names is important
297297 return [get_var_name (var ) for var in selected_ancestors ]
298298
299- def make_compute_graph (
300- self , var_names : Iterable [VarName ] | None = None
301- ) -> dict [VarName , set [VarName ]]:
299+ def make_compute_graph (self , var_names : Iterable [str ] | None = None ) -> dict [str , set [str ]]:
302300 """Get map of var_name -> set(input var names) for the model."""
303301 model = self .model
304302 named_vars = self ._all_vars
305- input_map : dict [VarName , set [VarName ]] = defaultdict (set )
303+ input_map : dict [str , set [str ]] = defaultdict (set )
306304
307305 var_names_to_plot = self .vars_to_plot (var_names )
308306 for var_name in var_names_to_plot :
@@ -319,15 +317,15 @@ def make_compute_graph(
319317 for ancestor in ancestors ([obs_var ]):
320318 if ancestor not in named_vars :
321319 continue
322- obs_name = cast (VarName , ancestor .name )
320+ obs_name = cast (str , ancestor .name )
323321 input_map [var_name ].discard (obs_name )
324322 input_map [obs_name ].add (var_name )
325323
326324 return input_map
327325
328326 def get_plates (
329327 self ,
330- var_names : Iterable [VarName ] | None = None ,
328+ var_names : Iterable [str ] | None = None ,
331329 ) -> list [Plate ]:
332330 """Rough but surprisingly accurate plate detection.
333331
@@ -337,7 +335,7 @@ def get_plates(
337335 Returns
338336 -------
339337 dict
340- Maps plate labels to the set of ``VarName``s inside the plate.
338+ Maps plate labels to the set of strings inside the plate.
341339 """
342340 plates = defaultdict (set )
343341
@@ -389,8 +387,8 @@ def get_plates(
389387
390388 def edges (
391389 self ,
392- var_names : Iterable [VarName ] | None = None ,
393- ) -> list [tuple [VarName , VarName ]]:
390+ var_names : Iterable [str ] | None = None ,
391+ ) -> list [tuple [str , str ]]:
394392 """Get edges between the variables in the model.
395393
396394 Parameters
@@ -405,7 +403,7 @@ def edges(
405403
406404 """
407405 return [
408- (VarName (child .replace (":" , "&" )), VarName (parent .replace (":" , "&" )))
406+ (str (child .replace (":" , "&" )), str (parent .replace (":" , "&" )))
409407 for child , parents in self .make_compute_graph (var_names = var_names ).items ()
410408 for parent in parents
411409 ]
@@ -422,7 +420,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
422420def make_graph (
423421 name : str ,
424422 plates : list [Plate ],
425- edges : list [tuple [VarName , VarName ]],
423+ edges : list [tuple [str , str ]],
426424 formatting : str = "plain" ,
427425 save = None ,
428426 figsize = None ,
@@ -496,7 +494,7 @@ def make_graph(
496494def make_networkx (
497495 name : str ,
498496 plates : list [Plate ],
499- edges : list [tuple [VarName , VarName ]],
497+ edges : list [tuple [str , str ]],
500498 formatting : str = "plain" ,
501499 node_formatters : NodeTypeFormatterMapping | None = None ,
502500 create_plate_label : PlateLabelFunc = create_plate_label_with_dim_length ,
@@ -566,7 +564,7 @@ def make_networkx(
566564def model_to_networkx (
567565 model = None ,
568566 * ,
569- var_names : Iterable [VarName ] | None = None ,
567+ var_names : Iterable [str ] | None = None ,
570568 formatting : str = "plain" ,
571569 node_formatters : NodeTypeFormatterMapping | None = None ,
572570 include_dim_lengths : bool = True ,
@@ -660,7 +658,7 @@ def model_to_networkx(
660658def model_to_graphviz (
661659 model = None ,
662660 * ,
663- var_names : Iterable [VarName ] | None = None ,
661+ var_names : Iterable [str ] | None = None ,
664662 formatting : str = "plain" ,
665663 save : str | None = None ,
666664 figsize : tuple [int , int ] | None = None ,
0 commit comments