
Commit 3b905bb

Merge pull request #80 from brandonwillard/add-meta-object-creation-options
Add ability to disable base object generation and use logic variable defaults
2 parents dff6ff4 + 7a56ed4 · commit 3b905bb

File tree

4 files changed: +184 −41 lines

.pylintrc

Lines changed: 1 addition & 2 deletions
@@ -54,7 +54,6 @@ enable=import-error,
        unused-wildcard-import,
        global-variable-not-assigned,
        undefined-loop-variable,
-       global-statement,
        global-at-module-level,
        bad-open-mode,
        redundant-unittest-assert,
@@ -264,7 +263,7 @@ ignore-mixin-members=yes
 # (useful for modules/projects where namespaces are manipulated during runtime
 # and thus existing member attributes cannot be deduced by static analysis. It
 # supports qualified module names, as well as Unix pattern matching.
-ignored-modules=tensorflow.core.framework,tensorflow.python.framework
+ignored-modules=tensorflow.core.framework,tensorflow.python.framework,tensorflow.python.ops.gen_linalg_ops
 
 # List of classes names for which member attributes should not be checked
 # (useful for classes with attributes dynamically set). This supports can work

symbolic_pymc/meta.py

Lines changed: 50 additions & 2 deletions
@@ -5,6 +5,7 @@
 from itertools import chain
 from functools import partial
 from collections import OrderedDict
+from contextlib import contextmanager
 from collections.abc import Iterator, Mapping
 
 from unification import isvar, Var
@@ -22,6 +23,46 @@
 
 metatize_cache = {}
 
+_auto_reification_disabled = False
+_lvar_defaults_enabled = set()
+
+
+@contextmanager
+def disable_auto_reification():
+    """Stop meta objects from automatically reifying themselves in order to determine unspecified properties."""
+    global _auto_reification_disabled
+    _current_value = _auto_reification_disabled
+    _auto_reification_disabled = True
+    try:
+        yield
+    finally:
+        _auto_reification_disabled = _current_value
+
+
+@contextmanager
+def enable_lvar_defaults(*types):
+    """Use logic variables instead of guessed/inferred values during meta object creation.
+
+    This is useful for handling unexpected values--created by default or behind
+    the scenes--in backend base objects (e.g. default names, TF NodeDef
+    attributes, etc.). By using logic variables instead, it's much easier to
+    create meta object "patterns" when certain types of exactness aren't
+    necessary.
+
+    Parameters
+    ----------
+    types: collection of str
+        String names for the types we want to make default to logic variables.
+        Currently allowed values are "names" and "node_attrs" (for TensorFlow).
+    """
+    global _lvar_defaults_enabled
+    _current_value = _lvar_defaults_enabled
+    _lvar_defaults_enabled = set(types)
+    try:
+        yield
+    finally:
+        _lvar_defaults_enabled = _current_value
+
 
 def metatize(obj):
     """Convert object to base type then meta object."""
@@ -107,6 +148,12 @@ def __new__(cls, name, bases, clsdict):
 
         if clsdict["__volatile_slots__"]:
 
+            def reset(self):
+                for s in self.__volatile_slots__:
+                    object.__setattr__(self, s, None)
+
+            clsdict["reset"] = reset
+
             def __setattr__(self, attr, obj):
                 """If a slot value is changed, reset cached slots."""
 
@@ -121,8 +168,7 @@ def __setattr__(self, attr, obj):
                     # Are we setting it to a new value?
                     and getattr(self, attr) is not obj
                 ):
-                    for s in self.__volatile_slots__:
-                        object.__setattr__(self, s, None)
+                    self.reset()
 
                 object.__setattr__(self, attr, obj)
 
@@ -150,6 +196,8 @@ def _cached_hash(self):
             if getattr(self, "_hash", None) is not None:
                 return self._hash
 
+            # TODO: We could also descend into `__props__` and reset their
+            # `_hash` values, as well.
             object.__setattr__(self, "_hash", _orig_hash(self))
 
             return self._hash
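
A minimal usage sketch of the two context managers added above, mirroring the new test_global_options test at the bottom of this diff; it assumes TensorFlow graph mode and the `mt` meta-object accessor from symbolic_pymc.tensorflow.meta:

import tensorflow as tf

from unification import isvar
from symbolic_pymc.meta import disable_auto_reification, enable_lvar_defaults
from symbolic_pymc.tensorflow.meta import mt

tf.compat.v1.disable_eager_execution()

with tf.Graph().as_default(), disable_auto_reification():
    # No base TF object is created behind the scenes.
    x_mt = mt.Placeholder('float')
    assert x_mt.obj is None

with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'):
    # Default names and NodeDef attributes become logic variables,
    # so the result can be used as a unification "pattern".
    y_mt = mt.Placeholder('float')
    assert isvar(y_mt.name) and isvar(y_mt.op.node_def.attr)

Both managers restore the previous module-level setting on exit (via try/finally), so they can be combined, as the last case in test_global_options does.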

symbolic_pymc/tensorflow/meta.py

Lines changed: 86 additions & 36 deletions
@@ -8,7 +8,7 @@
 
 from inspect import Parameter, Signature
 
-from collections import OrderedDict
+from collections import OrderedDict, Sequence
 
 from functools import partial
 
@@ -35,6 +35,8 @@
     metatize,
 )
 
+from .. import meta
+
 
 class MetaOpDefLibrary(object):
 
@@ -164,6 +166,15 @@ def _metatize_tf_object(obj):
 
 def load_dispatcher():
     """Set/override dispatcher to default to TF objects."""
+
+    from tensorflow.python.ops.gen_linalg_ops import _SvdOutput
+
+    def _metatize_tf_svd(obj):
+        """Turn a TensorFlow `Svd` object/tuple into a standard tuple."""
+        return _metatize(tuple(obj))
+
+    _metatize.add((_SvdOutput,), _metatize_tf_svd)
+
     _metatize.add((object,), _metatize_tf_object)
 
 
@@ -207,6 +218,7 @@ class TFlowMetaOpDef(MetaOp, metaclass=OpDefFactoryType):
 
     >>> from google.protobuf import json_format
     >>> print(json_format.MessageToJson(opdef))
+
     - If you want to use an `OpDef` to construct a node, see
       `op_def_library.apply_op`.
 
@@ -220,39 +232,53 @@ def __init__(self, obj=None):
         self._apply_func_sig, self._apply_func = op_def_lib.get_op_info(obj)
 
     def out_meta_types(self, inputs=None, node_def=None):
+        """Return a list of tuples containing object types and corresponding dtypes for the outputs of this OpDef."""
+
         def _convert_outputs(o):
-            if o.type_attr == "T" and node_def:
+            if o.type_attr == "T" and hasattr(node_def, "attr"):
                 return (TFlowMetaTensor, node_def.attr.get("T", var()))
             elif o.type_attr == "dtype" and inputs:
                 return (TFlowMetaTensor, inputs.get("dtype", var()))
             else:
                 return (TFlowMetaTensor, var())
 
+        # TODO: We also have permissible dtype information from objects in the
+        # array `self.obj.attr` under the field `allowed_values`.
+
         out_meta_types = tuple(_convert_outputs(o) for o in self.obj.output_arg)
-        # TODO: We also have permissible dtype information:
-        # from objects in the array `self.obj.attr` under the field
-        # `allowed_values`.
+
         return out_meta_types
 
-    def input_args(self, *args, **kwargs):
+    def input_args(self, *args, apply_defaults=True, **kwargs):
+        """Return a list of arguments for this OpDef's 'apply function'."""
         kwargs = OrderedDict(
             (k, v)
             for k, v in kwargs.items()
            # Filter out the optional keyword arguments so they we only pass
            # expected arguments to the `OpDef`'s apply function.
            if k in self._apply_func_sig.parameters
        )
+
         op_args = self._apply_func_sig.bind(*args, **kwargs)
-        op_args.apply_defaults()
+
+        if apply_defaults:
+            op_args.apply_defaults()
+
         return op_args.arguments
 
     def __call__(self, *args, **kwargs):
         """Create the meta object(s) resulting from an application of this `OpDef`'s implied `Operation`."""
-        op_args, op_args_unreified = meta_reify_iter(args)
-        op_kwargs, op_kwargs_unreified = meta_reify_iter(kwargs)
+
+        if not meta._auto_reification_disabled:
+            op_args, op_args_unreified = meta_reify_iter(args)
+            op_kwargs, op_kwargs_unreified = meta_reify_iter(kwargs)
+        else:
+            op_args, op_args_unreified = args, True
+            op_kwargs, op_kwargs_unreified = kwargs, True
+
         apply_arguments = self.input_args(*op_args, **op_kwargs)
 
-        if not op_args_unreified and not op_kwargs_unreified:
+        if not (op_args_unreified or op_kwargs_unreified):
 
             # them into meta objects. Doing so will yield information we
             # wouldn't be able to produce otherwise (e.g. shape info).
@@ -269,29 +295,43 @@ def __call__(self, *args, **kwargs):
 
             tf_out = self._apply_func(**apply_arguments)
             res_var = metatize(tf_out)
-            return res_var
-
-        #
-        # If we're here, that means we have to create the meta objects
-        # manually.
-        #
-        # TODO: `tf.placeholder`s are pretty flexible, we could probably use
-        # one as a stand-in for any un-reified tensor arguments and at least
-        # get some partial `dtype`, `shape` and `name` info.
-
-        op_input_args = tuple(
-            apply_arguments.get(i.name) for i in self.obj.input_arg if i.name in apply_arguments
-        )
 
-        node_attr = {a.name: apply_arguments.get(a.name, a) for a in self.obj.attr}
+            if "names" in meta._lvar_defaults_enabled:
+                # This should also reset the NodeDef's `obj`
+                res_var.op.node_def.name = var()
+                res_var.op.reset()
+                res_var.reset()
 
-        op_name = op_kwargs.get("name", self.obj.name)
+            if "node_attrs" in meta._lvar_defaults_enabled:
+                # This should also reset the NodeDef's `obj`
+                res_var.op.node_def.attr = var()
+                res_var.op.reset()
+                res_var.reset()
 
-        node_def = TFlowMetaNodeDef(self.obj.name, op_name, node_attr)
+        else:
+            #
+            # If we're here, that means we have to create the meta objects
+            # manually.
+            #
 
-        op_mt = TFlowMetaOp(self, node_def, op_input_args)
+            op_input_args = tuple(
+                apply_arguments.get(i.name) for i in self.obj.input_arg if i.name in apply_arguments
+            )
+
+            if "node_attrs" not in meta._lvar_defaults_enabled:
+                node_attr = {a.name: apply_arguments.get(a.name, a) for a in self.obj.attr}
+            else:
+                node_attr = var()
+
+            op_name = op_kwargs.get(
+                "name", self.obj.name if "names" not in meta._lvar_defaults_enabled else var()
+            )
+
+            node_def = TFlowMetaNodeDef(self.obj.name, op_name, node_attr)
 
-        res_var = op_mt.default_output
+            op_mt = TFlowMetaOp(self, node_def, op_input_args)
+
+            res_var = op_mt.default_output
 
         return res_var
 
@@ -517,16 +557,26 @@ def outputs(self):
         if getattr(self, "_outputs", None) is not None:
            return self._outputs
 
-        if (
-            isvar(self.op_def)
-            or isvar(self.inputs)
-            or isvar(self.node_def)
-            or isvar(self.node_def.attr)
-        ):
+        if isvar(self.op_def):
             self._outputs = var()
         else:
 
-            apply_arguments = self.op_def.input_args(*self.inputs, **self.node_def.attr)
+            if isvar(self.node_def) or isvar(getattr(self.node_def, "attr")):
+                node_attr = {}
+            else:
+                node_attr = self.node_def.attr
+
+            if isvar(self.inputs):
+                inputs = (None,) * len(self.op_def._apply_func_sig.parameters)
+                apply_defaults = False
+            else:
+                inputs = self.inputs
+                apply_defaults = True
+
+            apply_arguments = self.op_def.input_args(
+                *inputs, apply_defaults=apply_defaults, **node_attr
+            )
+
             out_types_mt = self.op_def.out_meta_types(
                 inputs=apply_arguments, node_def=self.node_def
             )
@@ -551,7 +601,7 @@ def default_output(self):
 
         mt_outs = self.outputs
 
-        if len(mt_outs) == 1:
+        if isinstance(mt_outs, Sequence) and len(mt_outs) == 1:
             out_var = mt_outs[0]
         else:
             out_var = mt_outs
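
The new apply_defaults keyword on input_args leans on standard-library inspect.Signature behavior: binding a partial argument list and only optionally filling in defaults. A small standalone sketch of that mechanism (the signature below is illustrative, not an actual OpDef apply function):

from inspect import Parameter, Signature

# Stand-in for an OpDef "apply function" signature: f(x, y, name=None)
params = [
    Parameter("x", Parameter.POSITIONAL_OR_KEYWORD),
    Parameter("y", Parameter.POSITIONAL_OR_KEYWORD),
    Parameter("name", Parameter.POSITIONAL_OR_KEYWORD, default=None),
]
sig = Signature(params)

bound = sig.bind(1, 2)
assert dict(bound.arguments) == {"x": 1, "y": 2}  # defaults omitted

bound.apply_defaults()
assert dict(bound.arguments) == {"x": 1, "y": 2, "name": None}  # defaults filled in

Skipping apply_defaults() is what lets the new TFlowMetaOp.outputs code pass placeholder None inputs for logic-variable arguments without injecting concrete default values into the resulting argument dict.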

tests/tensorflow/test_meta.py

Lines changed: 47 additions & 1 deletion
@@ -14,7 +14,7 @@
 
 from unification import var, isvar
 
-from symbolic_pymc.meta import MetaSymbol
+from symbolic_pymc.meta import MetaSymbol, disable_auto_reification, enable_lvar_defaults
 from symbolic_pymc.tensorflow.meta import (TFlowMetaTensor,
                                            TFlowMetaTensorShape,
                                            TFlowMetaOp,
@@ -228,6 +228,13 @@ def test_meta_lvars():
     mo_mt = TFlowMetaOp(var(), var(), var(), var())
     assert all(isvar(getattr(mo_mt, s)) for s in mo_mt.__all_props__)
 
+    mo_mt = TFlowMetaOp(var(), var(), var())
+    assert isvar(mo_mt.op_def)
+    assert isvar(mo_mt.outputs)
+
+    mo_mt = TFlowMetaOp(mt.Add, var(), var())
+    assert len(mo_mt.outputs) == 1
+
     ts_mt = TFlowMetaTensorShape(var())
     assert all(isvar(getattr(ts_mt, s)) for s in ts_mt.__all_props__)
 
@@ -286,6 +293,13 @@ def test_meta_multi_output():
     assert d.op.outputs == (d, U, V)
     assert d.op.default_output is d.op.outputs
 
+    tf.compat.v1.disable_eager_execution()
+
+    X_mt = mt(np.eye(2))
+    d, U, V = mt.linalg.svd(X_mt)
+    d.value_index = var()
+    assert isinstance(d.reify(), TFlowMetaTensor)
+
 
 
@@ -515,3 +529,35 @@ def test_tensor_ops():
     abs_mt = abs(x_mt)
     assert abs_mt.name == abs_tf.name
     assert abs_mt.op.type == abs_tf.op.type
+
+
+@pytest.mark.usefixtures("run_with_tensorflow")
+@run_in_graph_mode
+def test_global_options():
+
+    with tf.Graph().as_default():
+        x_mt = mt.Placeholder('float')
+        assert isinstance(x_mt.obj, tf.Tensor)
+        assert x_mt.name == 'Placeholder:0'
+
+    with tf.Graph().as_default(), disable_auto_reification():
+        y_mt = mt.Placeholder('float')
+        assert y_mt.obj is None
+        assert y_mt.name == 'Placeholder:0'
+        assert isinstance(y_mt.op.node_def.attr, dict)
+
+    with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'):
+        # This *will* auto-reify and have base versions of `names` and `attrs`;
+        # however, it will replace those with lvars.
+        z_mt = mt.Placeholder('float')
+        assert z_mt.obj is None
+        assert isvar(z_mt.name)
+        assert isvar(z_mt.op.node_def.attr)
+
+    with tf.Graph().as_default(), disable_auto_reification(), enable_lvar_defaults('names', 'node_attrs'):
+        # This will *not* auto-reify and simply create the object from scratch with meta types
+        # and the appropriate/desired logic variables.
+        z_mt = mt.Placeholder('float')
+        assert z_mt.obj is None
+        assert isvar(z_mt.name)
+        assert isvar(z_mt.op.node_def.attr)
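
The reset helper added in symbolic_pymc/meta.py is also what lets the d.value_index = var() test above invalidate cached values cleanly: assigning to a non-volatile slot clears the volatile ones. A simplified, standalone sketch of that pattern (class and slot names here are illustrative, not the library's):

class CachedThing:
    # Slots that hold derived/cached values and must be cleared
    # whenever a "real" attribute changes.
    __volatile_slots__ = ("_hash", "_outputs")

    __slots__ = ("value",) + __volatile_slots__

    def __init__(self, value):
        object.__setattr__(self, "value", value)
        self.reset()

    def reset(self):
        for s in self.__volatile_slots__:
            object.__setattr__(self, s, None)

    def __setattr__(self, attr, obj):
        # Changing a non-volatile attribute invalidates the cached slots.
        if attr not in self.__volatile_slots__ and getattr(self, attr, None) is not obj:
            self.reset()
        object.__setattr__(self, attr, obj)


t = CachedThing(1)
object.__setattr__(t, "_hash", 123)  # simulate a cached value
t.value = 2                          # triggers reset()
assert t._hash is None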
