Skip to content

Commit 8b673bb

Browse files
Extend TF meta object creation options to metatize
These changes allow one to--for instance--create meta objects with logic variable names even when the objects are created through `metatize` (e.g. constant tensors). Existing meta objects are also preserved--instead of recreated--when they're used as inputs to a meta object that's derived from a "temporary" base object (per `TFlowMetaOpDef.__call__`).
1 parent 613e387 commit 8b673bb

File tree

2 files changed

+84
-27
lines changed

2 files changed

+84
-27
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from functools import partial
1414

15+
from cachetools import cachedmethod, Cache
16+
1517
from unification import Var, var, isvar
1618

1719
from google.protobuf.message import Message
@@ -37,6 +39,8 @@
3739

3840
from .. import meta
3941

42+
tf_metatize_cache = Cache(50)
43+
4044

4145
class MetaOpDefLibrary(object):
4246

@@ -147,26 +151,17 @@ def get_op_info(cls, opdef):
147151

148152
def _metatize_tf_object(obj):
149153
try:
150-
obj = tf.convert_to_tensor(obj)
154+
tf_obj = tf.convert_to_tensor(obj)
151155
except (TypeError, ValueError):
152156
raise ValueError("Could not find a TensorFlow MetaSymbol class for {obj}")
153157

154-
if isinstance(obj, tf.Tensor):
155-
try:
156-
obj.op
157-
except AttributeError:
158-
raise AttributeError(
159-
f"TensorFlow Operation not available; "
160-
"try recreating the object with eager-mode disabled"
161-
" (e.g. within `tensorflow.python.eager.context.graph_mode`)"
162-
)
163-
164-
return _metatize(obj)
158+
return _metatize(tf_obj)
165159

166160

167161
def load_dispatcher():
168162
"""Set/override dispatcher to default to TF objects."""
169163

164+
from tensorflow.python.framework.ops import EagerTensor
170165
from tensorflow.python.ops.gen_linalg_ops import _SvdOutput
171166

172167
def _metatize_tf_svd(obj):
@@ -175,6 +170,16 @@ def _metatize_tf_svd(obj):
175170

176171
_metatize.add((_SvdOutput,), _metatize_tf_svd)
177172

173+
def _metatize_tf_eager(obj):
174+
"""Catch eager tensor metatize issues early."""
175+
raise AttributeError(
176+
f"TensorFlow Operation not available; "
177+
"try recreating the object with eager-mode disabled"
178+
" (e.g. within `tensorflow.python.eager.context.graph_mode`)"
179+
)
180+
181+
_metatize.add((EagerTensor,), _metatize_tf_eager)
182+
178183
_metatize.add((object,), _metatize_tf_object)
179184

180185

@@ -269,16 +274,14 @@ def input_args(self, *args, apply_defaults=True, **kwargs):
269274
def __call__(self, *args, **kwargs):
270275
"""Create the meta object(s) resulting from an application of this `OpDef`'s implied `Operation`."""
271276

277+
apply_arguments = self.input_args(*args, **kwargs)
278+
272279
if not meta._auto_reification_disabled:
273-
op_args, op_args_unreified = meta_reify_iter(args)
274-
op_kwargs, op_kwargs_unreified = meta_reify_iter(kwargs)
280+
op_args, op_args_unreified = meta_reify_iter(apply_arguments)
275281
else:
276-
op_args, op_args_unreified = args, True
277-
op_kwargs, op_kwargs_unreified = kwargs, True
278-
279-
apply_arguments = self.input_args(*op_args, **op_kwargs)
282+
op_args, op_args_unreified = apply_arguments, True
280283

281-
if not (op_args_unreified or op_kwargs_unreified):
284+
if not op_args_unreified:
282285

283286
# them into meta objects. Doing so will yield information we
284287
# wouldn't be able to produce otherwise (e.g. shape info).
@@ -289,11 +292,22 @@ def __call__(self, *args, **kwargs):
289292
# the TF-`Operation` inferred values (e.g. shapes, dtypes, etc.)
290293

291294
# We have to use a primitive string or TF will complain.
292-
name = apply_arguments.get("name", None)
295+
name = op_args.get("name", None)
293296
if name is not None:
294-
apply_arguments["name"] = str(name)
297+
op_args["name"] = str(name)
298+
299+
tf_out = self._apply_func(**op_args)
300+
301+
# Ensure that the original meta objects will result
302+
# from the following `metatize`
303+
tf_metatize_cache.update(
304+
{
305+
k: v
306+
for k, v in zip(op_args.values(), apply_arguments.values())
307+
if isinstance(k, tf.Tensor)
308+
}
309+
)
295310

296-
tf_out = self._apply_func(**apply_arguments)
297311
res_var = metatize(tf_out)
298312

299313
if "names" in meta._lvar_defaults_enabled:
@@ -324,7 +338,7 @@ def __call__(self, *args, **kwargs):
324338
node_attr = var()
325339

326340
if "names" not in meta._lvar_defaults_enabled:
327-
op_name = op_kwargs.get("name", self.obj.name)
341+
op_name = kwargs.get("name", self.obj.name)
328342
else:
329343
op_name = var()
330344

@@ -372,6 +386,18 @@ class TFlowMetaNodeDef(TFlowMetaSymbol):
372386
base = NodeDef
373387
__slots__ = ["op", "name", "attr", "_frozen_attr"]
374388

389+
@classmethod
390+
def _metatize(cls, obj):
391+
res = super()._metatize(obj)
392+
393+
if "node_attrs" in meta._lvar_defaults_enabled:
394+
res.attr = var()
395+
396+
if "names" in meta._lvar_defaults_enabled:
397+
res.name = var()
398+
399+
return res
400+
375401
@classmethod
376402
def _protobuf_convert(cls, k, v):
377403
"""Convert a small subset of protobuf objects.
@@ -473,7 +499,12 @@ def _metatize(cls, obj):
473499
new_args = [
474500
getattr(obj, s) if s != "inputs" else new_input for s in getattr(cls, "__props__", [])
475501
]
476-
return cls(*new_args, obj=obj)
502+
res = cls(*new_args, obj=obj)
503+
504+
if meta._lvar_defaults_enabled.issuperset(["node_attrs", "names"]):
505+
res.reset()
506+
507+
return res
477508

478509
def __init__(self, op_def, node_def, inputs, outputs=None, obj=None):
479510
"""Create a TensorFlow meta `Operation`.
@@ -654,6 +685,17 @@ class TFlowMetaTensor(TFlowMetaSymbol, MetaVariable):
654685
base = tf.Tensor
655686
__slots__ = ("op", "value_index", "dtype", "_shape", "_name")
656687

688+
@classmethod
689+
@cachedmethod(lambda cls: tf_metatize_cache)
690+
def _metatize(cls, obj):
691+
692+
res = super()._metatize(obj)
693+
694+
if meta._lvar_defaults_enabled.issuperset(["node_attrs", "names"]):
695+
res.reset()
696+
697+
return res
698+
657699
def __init__(self, op, value_index, dtype, obj=None):
658700
self.op = metatize(op)
659701
# TODO: Sync this value with `op.node_def.attr['dtype']` and/or
@@ -679,13 +721,14 @@ def name(self):
679721
if getattr(self, "_name", None):
680722
return self._name
681723

682-
if self.obj is not None and not isinstance(self.obj, Var):
683-
name = self.obj.name
684-
elif isinstance(getattr(self.op, "name", None), str) and not isvar(self.value_index):
724+
if isinstance(getattr(self.op, "name", None), str) and not isvar(self.value_index):
685725
name = f"{self.op.name}:{self.value_index}"
686726
else:
687727
name = var()
688728

729+
if self.obj is not None and not isinstance(self.obj, Var):
730+
assert name == self.obj.name
731+
689732
self._name = name
690733
return self._name
691734

tests/tensorflow/test_meta.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,3 +579,17 @@ def test_global_options():
579579
assert isvar(y_mt.op.node_def.attr)
580580
assert isvar(y_mt.op.inputs[0].op.node_def.attr)
581581
assert isvar(y_mt.op.inputs[1].op.node_def.attr)
582+
583+
with tf.Graph().as_default() as test_graph:
584+
a_mt = mt(2.0)
585+
assert a_mt.obj is not None
586+
587+
with test_graph.as_default(), enable_lvar_defaults('names', 'node_attrs'):
588+
a_new_mt = mt(a_mt)
589+
assert a_new_mt is a_mt
590+
591+
b_mt = 1.0 * a_mt
592+
assert a_mt.obj is not None
593+
assert isvar(b_mt.name)
594+
assert isvar(b_mt.op.node_def.attr)
595+
assert b_mt.op.inputs[1] is a_mt

0 commit comments

Comments
 (0)