Skip to content

Commit b131236

Browse files
osyukselricardoV94
authored andcommitted
Remove VarName
1 parent 41f33ad commit b131236

File tree

3 files changed

+22
-27
lines changed

3 files changed

+22
-27
lines changed

pymc/model/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
)
6868
from pymc.util import (
6969
UNSET,
70-
VarName,
7170
WithMemoization,
7271
_UnsetType,
7372
get_transformed_name,
@@ -1968,7 +1967,7 @@ def debug_parameters(rv):
19681967
def to_graphviz(
19691968
self,
19701969
*,
1971-
var_names: Iterable[VarName] | None = None,
1970+
var_names: Iterable[str] | None = None,
19721971
formatting: str = "plain",
19731972
save: str | None = None,
19741973
figsize: tuple[int, int] | None = None,
@@ -2172,7 +2171,7 @@ def compile_fn(
21722171
)
21732172

21742173

2175-
def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]:
2174+
def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]:
21762175
"""Build a point.
21772176
21782177
Uses same args as dict() does.

pymc/model_graph.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from pymc.model.core import modelcontext
2929
from pymc.pytensorf import _cheap_eval_mode
30-
from pymc.util import VarName, get_default_varnames, get_var_name
30+
from pymc.util import get_default_varnames, get_var_name
3131

3232
__all__ = (
3333
"ModelGraph",
@@ -173,7 +173,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs:
173173
}
174174

175175

176-
def get_node_type(var_name: VarName, model) -> NodeType:
176+
def get_node_type(var_name: str, model) -> NodeType:
177177
"""Return the node type of the variable in the model."""
178178
v = model[var_name]
179179

@@ -242,7 +242,7 @@ def __init__(self, model):
242242
self._all_vars = {model[var_name] for var_name in self._all_var_names}
243243
self.var_list = self.model.named_vars.values()
244244

245-
def get_parent_names(self, var: Variable) -> set[VarName]:
245+
def get_parent_names(self, var: Variable) -> set[str]:
246246
if var.owner is None:
247247
return set()
248248

@@ -261,12 +261,12 @@ def _expand(x):
261261
return x.owner.inputs
262262

263263
return {
264-
cast(VarName, ancestor.name) # type: ignore[union-attr]
264+
cast(str, ancestor.name) # type: ignore[union-attr]
265265
for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
266266
if ancestor in named_vars
267267
}
268268

269-
def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
269+
def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]:
270270
if var_names is None:
271271
return self._all_var_names
272272

@@ -296,13 +296,11 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
296296
# ordering of self._all_var_names is important
297297
return [get_var_name(var) for var in selected_ancestors]
298298

299-
def make_compute_graph(
300-
self, var_names: Iterable[VarName] | None = None
301-
) -> dict[VarName, set[VarName]]:
299+
def make_compute_graph(self, var_names: Iterable[str] | None = None) -> dict[str, set[str]]:
302300
"""Get map of var_name -> set(input var names) for the model."""
303301
model = self.model
304302
named_vars = self._all_vars
305-
input_map: dict[VarName, set[VarName]] = defaultdict(set)
303+
input_map: dict[str, set[str]] = defaultdict(set)
306304

307305
var_names_to_plot = self.vars_to_plot(var_names)
308306
for var_name in var_names_to_plot:
@@ -319,15 +317,15 @@ def make_compute_graph(
319317
for ancestor in ancestors([obs_var]):
320318
if ancestor not in named_vars:
321319
continue
322-
obs_name = cast(VarName, ancestor.name)
320+
obs_name = cast(str, ancestor.name)
323321
input_map[var_name].discard(obs_name)
324322
input_map[obs_name].add(var_name)
325323

326324
return input_map
327325

328326
def get_plates(
329327
self,
330-
var_names: Iterable[VarName] | None = None,
328+
var_names: Iterable[str] | None = None,
331329
) -> list[Plate]:
332330
"""Rough but surprisingly accurate plate detection.
333331
@@ -337,7 +335,7 @@ def get_plates(
337335
Returns
338336
-------
339337
dict
340-
Maps plate labels to the set of ``VarName``s inside the plate.
338+
Maps plate labels to the set of strings inside the plate.
341339
"""
342340
plates = defaultdict(set)
343341

@@ -389,8 +387,8 @@ def get_plates(
389387

390388
def edges(
391389
self,
392-
var_names: Iterable[VarName] | None = None,
393-
) -> list[tuple[VarName, VarName]]:
390+
var_names: Iterable[str] | None = None,
391+
) -> list[tuple[str, str]]:
394392
"""Get edges between the variables in the model.
395393
396394
Parameters
@@ -405,7 +403,7 @@ def edges(
405403
406404
"""
407405
return [
408-
(VarName(child.replace(":", "&")), VarName(parent.replace(":", "&")))
406+
(str(child.replace(":", "&")), str(parent.replace(":", "&")))
409407
for child, parents in self.make_compute_graph(var_names=var_names).items()
410408
for parent in parents
411409
]
@@ -422,7 +420,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
422420
def make_graph(
423421
name: str,
424422
plates: list[Plate],
425-
edges: list[tuple[VarName, VarName]],
423+
edges: list[tuple[str, str]],
426424
formatting: str = "plain",
427425
save=None,
428426
figsize=None,
@@ -496,7 +494,7 @@ def make_graph(
496494
def make_networkx(
497495
name: str,
498496
plates: list[Plate],
499-
edges: list[tuple[VarName, VarName]],
497+
edges: list[tuple[str, str]],
500498
formatting: str = "plain",
501499
node_formatters: NodeTypeFormatterMapping | None = None,
502500
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
@@ -566,7 +564,7 @@ def make_networkx(
566564
def model_to_networkx(
567565
model=None,
568566
*,
569-
var_names: Iterable[VarName] | None = None,
567+
var_names: Iterable[str] | None = None,
570568
formatting: str = "plain",
571569
node_formatters: NodeTypeFormatterMapping | None = None,
572570
include_dim_lengths: bool = True,
@@ -660,7 +658,7 @@ def model_to_networkx(
660658
def model_to_graphviz(
661659
model=None,
662660
*,
663-
var_names: Iterable[VarName] | None = None,
661+
var_names: Iterable[str] | None = None,
664662
formatting: str = "plain",
665663
save: str | None = None,
666664
figsize: tuple[int, int] | None = None,

pymc/util.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from collections import namedtuple
1919
from collections.abc import Sequence
2020
from copy import deepcopy
21-
from typing import NewType, cast
21+
from typing import cast
2222

2323
import arviz
2424
import cloudpickle
@@ -31,8 +31,6 @@
3131

3232
from pymc.exceptions import BlockModelAccessError
3333

34-
VarName = NewType("VarName", str)
35-
3634

3735
class _UnsetType:
3836
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
@@ -214,9 +212,9 @@ def get_default_varnames(var_iterator, include_transformed):
214212
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]
215213

216214

217-
def get_var_name(var) -> VarName:
215+
def get_var_name(var) -> str:
218216
"""Get an appropriate, plain variable name for a variable."""
219-
return VarName(str(getattr(var, "name", var)))
217+
return var if isinstance(var, str) else str(var.name)
220218

221219

222220
def get_transformed(z):

0 commit comments

Comments
 (0)