|
43 | 43 |
|
44 | 44 |
|
45 | 45 | class MetaOpDefLibrary(object): |
| 46 | + """A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct. |
| 47 | +
|
| 48 | + It provides a map of `OpDef` names (lower-cased) to the Python API |
| 49 | + functions in `tensorflow.raw_ops`, as well as `inspect.Signature` objects |
| 50 | + for said functions so that default values and lists of arguments (keywords |
| 51 | + included) can be more easily used. |
| 52 | +
|
| 53 | + """ |
46 | 54 |
|
47 | 55 | lower_op_name_to_raw = { |
48 | 56 | op_name.lower(): op_name |
@@ -130,6 +138,12 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None): |
130 | 138 |
|
131 | 139 | @classmethod |
132 | 140 | def get_op_info(cls, opdef): |
| 141 | + """Return the TF Python API function signature for a given `OpDef`. |
| 142 | +
|
| 143 | + Parameter |
| 144 | + --------- |
| 145 | + opdef: str or `OpDef` object (meta or base) |
| 146 | + """ |
133 | 147 | if isinstance(opdef, str): |
134 | 148 | opdef_name = opdef |
135 | 149 | opdef = op_def_registry.get(opdef_name) |
@@ -289,7 +303,14 @@ def input_args(self, *args, apply_defaults=True, **kwargs): |
289 | 303 | return op_args.arguments |
290 | 304 |
|
291 | 305 | def __call__(self, *args, **kwargs): |
292 | | - """Create the meta object(s) resulting from an application of this `OpDef`'s implied `Operation`.""" |
| 306 | + """Create the meta object(s) using the TF Python API's operator functions. |
| 307 | +
|
| 308 | + Each meta `OpDef` is associated with a TF Python function |
| 309 | + (`self._apply_func`) that is used to construct its `Operation`s. |
| 310 | +
|
| 311 | + See `TFlowMetaTensor.operator` and `TFlowMetaTensor.operator`. |
| 312 | +
|
| 313 | + """ |
293 | 314 |
|
294 | 315 | apply_arguments = self.input_args(*args, **kwargs) |
295 | 316 |
|
@@ -385,8 +406,7 @@ def __eq__(self, other): |
385 | 406 | if not (type(self) == type(other)): |
386 | 407 | return False |
387 | 408 |
|
388 | | - if not (self.base == other.base): |
389 | | - return False |
| 409 | + assert self.base == other.base |
390 | 410 |
|
391 | 411 | return self.obj.name == other.obj.name |
392 | 412 |
|
@@ -745,28 +765,44 @@ def name(self): |
745 | 765 |
|
746 | 766 | @property |
747 | 767 | def operator(self): |
| 768 | + """Return the meta OpDef for this tensor. |
| 769 | +
|
| 770 | + Since meta OpDefs are callable (and dispatch to the corresponding TF |
| 771 | + Python API function), this object called with arguments provided by |
| 772 | + `TFlowMetaTensor.inputs` recreates the underlying tensor using the TF |
| 773 | + Python interface. This approach has advantages over the purely |
| 774 | + graph-level approach to constructing meta objects, because--when all |
| 775 | + arguments are reifiable--it allows us to use purely TF means to |
| 776 | + construct a meta object (i.e. by first constructing the base object and |
| 777 | + then "metatizing" it). |
| 778 | +
|
| 779 | + Meta objects produced this way result in less unknown information |
| 780 | + (e.g. dtypes and shapes) and have the same default values as their base |
| 781 | + object counterparts (e.g. `Operator` names and `NodeDef.attr` values). |
| 782 | + """ |
748 | 783 | if self.op is not None and not isvar(self.op): |
749 | 784 | return self.op.op_def |
750 | 785 |
|
751 | 786 | @property |
752 | 787 | def inputs(self): |
753 | | - """Return the tensor's inputs/rands. |
| 788 | + """Return the inputs necessary to recreate this object using its TF Python API function. |
| 789 | +
|
| 790 | + These inputs differ from `self.op.inputs` primarily in that they |
| 791 | + contain the `node_def` parameters as keywords (e.g. to Python API |
| 792 | + functions like `tf.add`). |
754 | 793 |
|
755 | | - NOTE: These inputs differ from `self.op.inputs` in that they contain |
756 | | - the `node_def` parameters, as well. |
757 | | - In other words, these can be used to recreate this object (per |
758 | | - the meta object spec). |
| 794 | + See `TFlowMetaTensor.operator` for more information. |
759 | 795 | """ |
760 | 796 | # TODO: In keeping with our desire to return logic variables in cases |
761 | 797 | # where params aren't given/inferred, we could return something like |
762 | 798 | # `cons(var(), var())` here (although that wouldn't be necessarily imply |
763 | 799 | # that the result is a proper list/tuple). |
764 | | - if self.op is not None and not isvar(self.op): |
765 | | - input_args = self.op.op_def.input_args( |
766 | | - *self.op.inputs, |
767 | | - name=self.op.name if not isvar(self.op.name) else None, |
768 | | - **self.op.node_def.attr, |
769 | | - ) |
| 800 | + if self.op is not None and not isvar(self.op) and not isvar(self.op.inputs): |
| 801 | + if not isvar(self.op.node_def) and not isvar(self.op.node_def.attr): |
| 802 | + attr = self.op.node_def.attr |
| 803 | + else: |
| 804 | + attr = {} |
| 805 | + input_args = self.op.op_def.input_args(*self.op.inputs, name=self.op.name, **attr) |
770 | 806 | return tuple(input_args.values()) |
771 | 807 |
|
772 | 808 | def reify(self): |
|
0 commit comments