Skip to content

Commit 209b90a

Browse files
Updates for recent TF op_def_registry changes
1 parent 29de09b commit 209b90a

File tree

2 files changed

+35
-34
lines changed

2 files changed

+35
-34
lines changed

requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
scipy>=1.2.0
22
Theano>=1.0.4
33
pymc4 @ git+https://github.com/pymc-devs/pymc4.git@master#egg=pymc4-0.0.1
4-
tfp-nightly==0.9.0.dev20190908
5-
tf-nightly-2.0-preview==2.0.0.dev20190908
6-
tensorflow-estimator-2.0-preview==1.14.0.dev2019090801
4+
tfp-nightly==0.9.0.dev20191003
5+
tf-nightly-2.0-preview==2.0.0.dev20191002
6+
tensorflow-estimator-2.0-preview>=1.14.0.dev2019090801
77
pymc3>=3.6
88
multipledispatch>=0.6.0
99
unification>=0.2.2

symbolic_pymc/tensorflow/meta.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,18 @@
3737
)
3838

3939

40-
class MetaOpDefLibrary(op_def_library.OpDefLibrary):
41-
def __init__(self, *args, **kwargs):
42-
# This is a lame way to fix the numerous naming inconsistencies between
43-
# TF `Operation`s, `OpDef`s, and the corresponding user-level functions.
44-
self.lower_op_name_to_raw = {
45-
op_name.lower(): op_name
46-
for op_name in dir(tf.raw_ops)
47-
if callable(getattr(tf.raw_ops, op_name))
48-
}
49-
super().__init__(*args, **kwargs)
40+
class MetaOpDefLibrary(object):
41+
42+
lower_op_name_to_raw = {
43+
op_name.lower(): op_name
44+
for op_name in dir(tf.raw_ops)
45+
if callable(getattr(tf.raw_ops, op_name))
46+
}
47+
opdef_signatures = {}
48+
49+
@classmethod
50+
def apply_op(cls, *args, **kwargs):
51+
return op_def_library.apply_op(*args, **kwargs)
5052

5153
@classmethod
5254
def make_opdef_sig(cls, opdef, opdef_py_func=None):
@@ -121,21 +123,22 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
121123
)
122124
return opdef_sig, opdef_py_func
123125

124-
def add_op(self, opdef):
125-
op_info = self._ops.get(opdef.name, None)
126-
if op_info is None:
127-
super().add_op(opdef)
128-
op_info = self._ops[opdef.name]
126+
@classmethod
127+
def get_op_info(cls, opdef):
128+
if isinstance(opdef, str):
129+
opdef_name = opdef
130+
opdef = op_def_registry.get(opdef_name)
131+
else:
132+
opdef_name = opdef.name
133+
134+
opdef_sig = cls.opdef_signatures.get(opdef_name, None)
135+
136+
if opdef_sig is None and opdef is not None:
129137
opdef_func = getattr(tf.raw_ops, opdef.name, None)
130-
opdef_sig, opdef_func = self.make_opdef_sig(op_info.op_def, opdef_func)
131-
op_info.opdef_sig = opdef_sig
132-
op_info.opdef_func = opdef_func
133-
return op_info
138+
opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
139+
cls.opdef_signatures[opdef.name] = cls.make_opdef_sig(opdef, opdef_func)
134140

135-
def get_opinfo(self, opdef):
136-
if isinstance(opdef, str):
137-
opdef = op_def_registry.get_registered_ops()[opdef]
138-
return self.add_op(opdef)
141+
return opdef_sig
139142

140143

141144
op_def_lib = MetaOpDefLibrary()
@@ -251,7 +254,7 @@ class TFlowMetaOpDef(MetaOp, metaclass=OpDefFactoryType):
251254
>>> from google.protobuf import json_format
252255
>>> print(json_format.MessageToJson(opdef))
253256
- If you want to use an `OpDef` to construct a node, see
254-
`op_def_library.OpDefLibrary.apply_op`.
257+
`op_def_library.apply_op`.
255258
256259
"""
257260

@@ -260,9 +263,7 @@ class TFlowMetaOpDef(MetaOp, metaclass=OpDefFactoryType):
260263

261264
def __init__(self, obj=None):
262265
super().__init__(obj=obj)
263-
op_info = op_def_lib.add_op(obj)
264-
self._apply_func_sig = op_info.opdef_sig
265-
self._apply_func = op_info.opdef_func
266+
self._apply_func_sig, self._apply_func = op_def_lib.get_op_info(obj)
266267

267268
def out_meta_types(self, inputs=None):
268269
def _convert_outputs(o):
@@ -411,8 +412,8 @@ def __init__(self, op, name, attr, obj=None):
411412
# We want to limit the attributes we'll consider to those that show
412413
# up in an OpDef function's signature (e.g. ignore info about
413414
# permissible types).
414-
opinfo = op_def_lib.get_opinfo(self.op)
415-
op_param_names = opinfo.opdef_sig.parameters.keys()
415+
opdef_sig, _ = op_def_lib.get_op_info(self.op)
416+
op_param_names = opdef_sig.parameters.keys()
416417

417418
_attr = dict()
418419
for k, v in attr.items():
@@ -496,7 +497,7 @@ def __init__(self, op_def, node_def, inputs, outputs=None, obj=None):
496497
super().__init__(obj=obj)
497498

498499
if isinstance(op_def, str):
499-
op_def = op_def_registry.get_registered_ops()[op_def]
500+
op_def = op_def_registry.get(op_def)
500501

501502
self.op_def = metatize(op_def)
502503
self.node_def = metatize(node_def)
@@ -798,7 +799,7 @@ def __call__(self, x):
798799
def find_opdef(cls, name):
799800
"""Attempt to create a meta `OpDef` for a given TF function/`Operation` name."""
800801
raw_op_name = op_def_lib.lower_op_name_to_raw.get(name.lower(), name)
801-
op_def = op_def_registry.get_registered_ops()[raw_op_name]
802+
op_def = op_def_registry.get(raw_op_name)
802803

803804
if op_def is not None:
804805
meta_obj = TFlowMetaOpDef(obj=op_def)

0 commit comments

Comments
 (0)