Skip to content

Commit 900be8b

Browse files
Add some documentation and fix a bug in TFlowMetaTensor.inputs
1 parent b0944df commit 900be8b

File tree

2 files changed

+55
-15
lines changed

2 files changed

+55
-15
lines changed

symbolic_pymc/tensorflow/meta.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@
4343

4444

4545
class MetaOpDefLibrary(object):
46+
"""A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct.
47+
48+
It provides a map of `OpDef` names (lower-cased) to the Python API
49+
functions in `tensorflow.raw_ops`, as well as `inspect.Signature` objects
50+
for said functions so that default values and lists of arguments (keywords
51+
included) can be more easily used.
52+
53+
"""
4654

4755
lower_op_name_to_raw = {
4856
op_name.lower(): op_name
@@ -130,6 +138,12 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
130138

131139
@classmethod
132140
def get_op_info(cls, opdef):
141+
"""Return the TF Python API function signature for a given `OpDef`.
142+
143+
Parameter
144+
---------
145+
opdef: str or `OpDef` object (meta or base)
146+
"""
133147
if isinstance(opdef, str):
134148
opdef_name = opdef
135149
opdef = op_def_registry.get(opdef_name)
@@ -289,7 +303,14 @@ def input_args(self, *args, apply_defaults=True, **kwargs):
289303
return op_args.arguments
290304

291305
def __call__(self, *args, **kwargs):
292-
"""Create the meta object(s) resulting from an application of this `OpDef`'s implied `Operation`."""
306+
"""Create the meta object(s) using the TF Python API's operator functions.
307+
308+
Each meta `OpDef` is associated with a TF Python function
309+
(`self._apply_func`) that is used to construct its `Operation`s.
310+
311+
See `TFlowMetaTensor.operator` and `TFlowMetaTensor.operator`.
312+
313+
"""
293314

294315
apply_arguments = self.input_args(*args, **kwargs)
295316

@@ -385,8 +406,7 @@ def __eq__(self, other):
385406
if not (type(self) == type(other)):
386407
return False
387408

388-
if not (self.base == other.base):
389-
return False
409+
assert self.base == other.base
390410

391411
return self.obj.name == other.obj.name
392412

@@ -745,28 +765,44 @@ def name(self):
745765

746766
@property
747767
def operator(self):
768+
"""Return the meta OpDef for this tensor.
769+
770+
Since meta OpDefs are callable (and dispatch to the corresponding TF
771+
Python API function), this object called with arguments provided by
772+
`TFlowMetaTensor.inputs` recreates the underlying tensor using the TF
773+
Python interface. This approach has advantages over the purely
774+
graph-level approach to constructing meta objects, because--when all
775+
arguments are reifiable--it allows us to use purely TF means to
776+
construct a meta object (i.e. by first constructing the base object and
777+
then "metatizing" it).
778+
779+
Meta objects produced this way result in less unknown information
780+
(e.g. dtypes and shapes) and have the same default values as their base
781+
object counterparts (e.g. `Operator` names and `NodeDef.attr` values).
782+
"""
748783
if self.op is not None and not isvar(self.op):
749784
return self.op.op_def
750785

751786
@property
752787
def inputs(self):
753-
"""Return the tensor's inputs/rands.
788+
"""Return the inputs necessary to recreate this object using its TF Python API function.
789+
790+
These inputs differ from `self.op.inputs` primarily in that they
791+
contain the `node_def` parameters as keywords (e.g. to Python API
792+
functions like `tf.add`).
754793
755-
NOTE: These inputs differ from `self.op.inputs` in that they contain
756-
the `node_def` parameters, as well.
757-
In other words, these can be used to recreate this object (per
758-
the meta object spec).
794+
See `TFlowMetaTensor.operator` for more information.
759795
"""
760796
# TODO: In keeping with our desire to return logic variables in cases
761797
# where params aren't given/inferred, we could return something like
762798
# `cons(var(), var())` here (although that wouldn't be necessarily imply
763799
# that the result is a proper list/tuple).
764-
if self.op is not None and not isvar(self.op):
765-
input_args = self.op.op_def.input_args(
766-
*self.op.inputs,
767-
name=self.op.name if not isvar(self.op.name) else None,
768-
**self.op.node_def.attr,
769-
)
800+
if self.op is not None and not isvar(self.op) and not isvar(self.op.inputs):
801+
if not isvar(self.op.node_def) and not isvar(self.op.node_def.attr):
802+
attr = self.op.node_def.attr
803+
else:
804+
attr = {}
805+
input_args = self.op.op_def.input_args(*self.op.inputs, name=self.op.name, **attr)
770806
return tuple(input_args.values())
771807

772808
def reify(self):

tests/tensorflow/test_meta.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def test_meta_eager():
6464
@run_in_graph_mode
6565
def test_meta_basic():
6666

67+
assert mt.Add == mt.Add
68+
assert mt.Add != mt.Sub
69+
6770
var_mt = TFlowMetaTensor(var(), var(), var())
6871
# It should generate a logic variable for the name and use from here on.
6972
var_name = var_mt.name
@@ -248,9 +251,10 @@ def test_meta_lvars():
248251
assert all(isvar(getattr(tn_mt, s)) for s in tn_mt.__all_props__)
249252
assert isinstance(tn_mt.reify(), TFlowMetaTensor)
250253

251-
mo_mt = TFlowMetaOp(mt.Add, [tn_mt, tn_mt], var())
254+
mo_mt = TFlowMetaOp(mt.Add, var(), [tn_mt, var('a')])
252255
assert len(mo_mt.outputs) == 1
253256
assert isinstance(mo_mt.reify(), TFlowMetaOp)
257+
assert mo_mt.outputs[0].inputs == (tn_mt, var('a'), mo_mt.name)
254258

255259

256260
@pytest.mark.usefixtures("run_with_tensorflow")

0 commit comments

Comments
 (0)